diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index 7024878..194f176 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -67,6 +67,8 @@ pub enum VagueType { Unknown, #[error("Number")] Number, + #[error("Hinted({0})")] + Hinted(usize), } impl TypeKind { @@ -158,6 +160,14 @@ impl Literal { } } +impl VagueLiteral { + pub fn as_type(self: &VagueLiteral) -> VagueType { + match self { + VagueLiteral::Number(_) => VagueType::Number, + } + } +} + #[derive(Debug, Clone, Copy)] pub enum BinaryOperator { Add, diff --git a/reid/src/mir/scopehints.rs b/reid/src/mir/scopehints.rs index a8e1773..12eea26 100644 --- a/reid/src/mir/scopehints.rs +++ b/reid/src/mir/scopehints.rs @@ -1,46 +1,77 @@ -use std::{cell::RefCell, collections::HashMap}; +use std::{cell::RefCell, collections::HashMap, fmt::Error, rc::Rc}; use super::{ typecheck::{Collapsable, ErrorKind}, - TypeKind, + BinaryOperator, Literal, TypeKind, VagueLiteral, VagueType, }; -#[derive(Debug, Clone)] -pub struct ScopeHint<'scope>(usize, &'scope ScopeHints<'scope>); +#[derive(Clone)] +pub struct ScopeHint<'scope>(TypeIdRef, &'scope ScopeHints<'scope>); impl<'scope> ScopeHint<'scope> { - pub fn resolve(&self) -> TypeRef { - let mut scope = self.1; - while !scope.type_hints.borrow().contains_key(&self.0) { - scope = scope.outer.as_ref().unwrap(); - } - let ty = scope.type_hints.borrow().get(&self.0).unwrap().clone(); - match ty.known() { - Ok(narrow) => TypeRef::Literal(narrow), - Err(_) => TypeRef::Hint(self.clone()), + pub unsafe fn raw_type(&self) -> TypeKind { + if let Some(ty) = self.1.types.hints.borrow().get(*self.0.borrow()) { + *ty + } else { + panic!("TODO") } } - pub fn narrow(&self, ty_ref: &TypeRef) -> Result { + pub fn narrow(&mut self, ty_ref: &TypeRef) -> Result, ErrorKind> { match ty_ref { TypeRef::Hint(other) => self.1.combine_vars(self, other), TypeRef::Literal(ty) => self.1.narrow_to_type(self, ty), } } + + pub fn as_type(&self) -> TypeKind { + TypeKind::Vague(super::VagueType::Hinted(*self.0.borrow())) + } } +impl<'scope> std::fmt::Debug for ScopeHint<'scope> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("Hint") + .field(&self.0) + .field(unsafe { &self.raw_type() }) + .finish() + } +} + +type TypeIdRef = Rc>; + #[derive(Debug, Default)] +pub struct TypeHints { + /// Simple list of types that variables can refrence + hints: RefCell>, + types: RefCell>, +} + +impl TypeHints { + pub fn new(&self, ty: TypeKind) -> TypeIdRef { + let idx = self.hints.borrow().len(); + let typecell = Rc::new(RefCell::new(idx)); + self.types.borrow_mut().push(typecell.clone()); + self.hints.borrow_mut().push(ty); + typecell + } +} + +#[derive(Debug)] pub struct ScopeHints<'outer> { + types: &'outer TypeHints, outer: Option<&'outer ScopeHints<'outer>>, /// Mapping of what types variables point to - variables: RefCell>, - /// Simple list of types that variables can refrence - type_hints: RefCell>, + variables: RefCell>, } impl<'outer> ScopeHints<'outer> { - fn get_idx(&self) -> usize { - self.type_hints.borrow().len() + self.outer.as_ref().map(|o| o.get_idx()).unwrap_or(0) + pub fn from(types: &'outer TypeHints) -> ScopeHints<'outer> { + ScopeHints { + types, + outer: Default::default(), + variables: Default::default(), + } } pub fn new_var( @@ -52,21 +83,29 @@ impl<'outer> ScopeHints<'outer> { if self.variables.borrow().contains_key(&name) { return Err(ErrorKind::VariableAlreadyDefined(name)); } - let idx = self.get_idx(); - self.variables.borrow_mut().insert(name, (mutable, idx)); - self.type_hints.borrow_mut().insert(idx, initial_ty); + let idx = self.types.new(initial_ty); + self.variables + .borrow_mut() + .insert(name, (mutable, idx.clone())); Ok(ScopeHint(idx, self)) } + fn new_vague(&'outer self, vague: &VagueType) -> ScopeHint<'outer> { + let idx = self.types.new(TypeKind::Vague(*vague)); + ScopeHint(idx, self) + } + fn narrow_to_type( &'outer self, hint: &ScopeHint, ty: &TypeKind, ) -> Result, ErrorKind> { - let mut hints = self.type_hints.borrow_mut(); - let existing = hints.get_mut(&hint.0).unwrap(); - *existing = existing.collapse_into(&ty)?; - Ok(ScopeHint(hint.0, self)) + unsafe { + let mut hints = self.types.hints.borrow_mut(); + let existing = hints.get_unchecked_mut(*hint.0.borrow()); + *existing = existing.collapse_into(&ty)?; + Ok(ScopeHint(hint.0.clone(), self)) + } } fn combine_vars( @@ -74,21 +113,28 @@ impl<'outer> ScopeHints<'outer> { hint1: &ScopeHint, hint2: &ScopeHint, ) -> Result, ErrorKind> { - let ty = self.type_hints.borrow().get(&hint2.0).unwrap().clone(); - self.narrow_to_type(&hint1, &ty)?; - for (_, (_, idx)) in self.variables.borrow_mut().iter_mut() { - if *idx == hint2.0 { - *idx = hint1.0; + unsafe { + let ty = self + .types + .hints + .borrow() + .get_unchecked(*hint2.0.borrow()) + .clone(); + self.narrow_to_type(&hint1, &ty)?; + for idx in self.types.types.borrow_mut().iter_mut() { + if *idx == hint2.0 { + *idx.borrow_mut() = *hint1.0.borrow(); + } } + Ok(ScopeHint(hint1.0.clone(), self)) } - Ok(ScopeHint(hint1.0, self)) } pub fn inner(&'outer self) -> ScopeHints<'outer> { ScopeHints { + types: self.types, outer: Some(self), variables: Default::default(), - type_hints: Default::default(), } } @@ -96,20 +142,35 @@ impl<'outer> ScopeHints<'outer> { self.variables .borrow() .get(name) - .map(|(mutable, idx)| (*mutable, ScopeHint(*idx, self))) + .map(|(mutable, idx)| (*mutable, ScopeHint(idx.clone(), self))) + .or(self.outer.map(|o| o.find_hint(name)).flatten()) + } + + pub fn binop( + &'outer self, + op: &BinaryOperator, + lhs: &mut TypeRef<'outer>, + rhs: &mut TypeRef<'outer>, + ) -> Result, ErrorKind> { + let ty = lhs.narrow(rhs)?; + Ok(match op { + BinaryOperator::Add => ty, + BinaryOperator::Minus => ty, + BinaryOperator::Mult => ty, + BinaryOperator::And => TypeRef::Literal(TypeKind::Bool), + BinaryOperator::Cmp(_) => TypeRef::Literal(TypeKind::Bool), + }) } } +#[derive(Debug)] pub enum TypeRef<'scope> { Hint(ScopeHint<'scope>), Literal(TypeKind), } impl<'scope> TypeRef<'scope> { - pub fn narrow( - &'scope self, - other: &'scope TypeRef<'scope>, - ) -> Result, ErrorKind> { + pub fn narrow(&mut self, other: &mut TypeRef<'scope>) -> Result, ErrorKind> { match (self, other) { (TypeRef::Hint(hint), unk) | (unk, TypeRef::Hint(hint)) => { Ok(TypeRef::Hint(hint.narrow(unk)?)) @@ -119,4 +180,27 @@ impl<'scope> TypeRef<'scope> { } } } + + pub fn from_type(hints: &'scope ScopeHints<'scope>, ty: TypeKind) -> TypeRef<'scope> { + match &ty.known() { + Ok(ty) => TypeRef::Literal(*ty), + Err(vague) => match &vague { + super::VagueType::Hinted(idx) => TypeRef::Hint(ScopeHint( + unsafe { hints.types.types.borrow().get_unchecked(*idx).clone() }, + hints, + )), + _ => TypeRef::Hint(hints.new_vague(vague)), + }, + } + } + + pub fn from_literal( + hints: &'scope ScopeHints<'scope>, + lit: Literal, + ) -> Result, ErrorKind> { + Ok(match lit { + Literal::Vague(vague) => TypeRef::Hint(hints.new_vague(&vague.as_type())), + _ => TypeRef::Literal(lit.as_type()), + }) + } } diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index 3d4dfc6..6cdb424 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -1,6 +1,9 @@ //! This module contains code relevant to doing a type checking pass on the MIR. //! During typechecking relevant types are also coerced if possible. -use std::{cell::RefCell, collections::HashMap, convert::Infallible, iter, marker::PhantomData}; +use std::{ + cell::RefCell, collections::HashMap, convert::Infallible, iter, marker::PhantomData, + thread::scope, +}; use crate::{mir::*, util::try_all}; use TypeKind::*; @@ -8,8 +11,8 @@ use VagueType::*; use super::{ pass::{Pass, PassState, ScopeFunction, ScopeVariable, Storage}, - scopehints::ScopeHints, - types::ReturnType, + scopehints::{ScopeHints, TypeHints, TypeRef}, + types::{pick_return, ReturnType}, }; #[derive(thiserror::Error, Debug, Clone)] @@ -77,7 +80,16 @@ impl FunctionDefinition { let inferred = match &mut self.kind { FunctionDefinitionKind::Local(block, _) => { state.scope.return_type_hint = Some(self.return_type); - block.typecheck(state, &ScopeHints::default(), Some(return_type)) + + let types = TypeHints::default(); + let hints = ScopeHints::from(&types); + if let Ok(_) = block.infer_hints(state, &hints) { + dbg!(&block, &hints); + // block.typecheck(state, &hints, Some(return_type)) + Ok(Vague(Unknown)) + } else { + Ok(Vague(Unknown)) + } } FunctionDefinitionKind::Extern => Ok(Vague(Unknown)), }; @@ -92,6 +104,69 @@ impl FunctionDefinition { } impl Block { + fn infer_hints<'s>( + &mut self, + state: &mut PassState, + outer_hints: &'s ScopeHints, + ) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> { + let mut state = state.inner(); + let inner_hints = outer_hints.inner(); + + for statement in &mut self.statements { + match &mut statement.0 { + StmtKind::Let(var, mutable, expr) => { + let mut var_ref = + state.ok(inner_hints.new_var(var.1.clone(), *mutable, var.0), var.2); + if let Some(var_ref) = &var_ref { + var.0 = var_ref.as_type(); + } + let inferred = expr.infer_hints(&mut state, &inner_hints); + let mut expr_ty_ref = state.ok(inferred, expr.1); + if let (Some(var_ref), Some(expr_ty_ref)) = + (var_ref.as_mut(), expr_ty_ref.as_mut()) + { + state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1); + } + } + StmtKind::Set(var, expr) => { + let var_ref = inner_hints.find_hint(&var.1); + dbg!(&var_ref); + if let Some((_, var_ref)) = &var_ref { + var.0 = var_ref.as_type() + } + let inferred = expr.infer_hints(&mut state, &inner_hints); + let expr_ty_ref = state.ok(inferred, expr.1); + if let (Some((_, mut var_ref)), Some(expr_ty_ref)) = (var_ref, expr_ty_ref) { + state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1); + } + } + StmtKind::Import(_) => todo!(), + StmtKind::Expression(expr) => { + let expr_res = expr.infer_hints(&mut state, &inner_hints); + state.ok(expr_res, expr.1); + } + }; + } + + if let Some(ret_expr) = &mut self.return_expression { + let ret_res = ret_expr.1.infer_hints(&mut state, &inner_hints); + state.ok(ret_res, ret_expr.1 .1); + } + + let (kind, ty) = self.return_type().ok().unwrap_or((ReturnKind::Soft, Void)); + let mut ret_type_ref = TypeRef::from_type(&outer_hints, ty); + + if kind == ReturnKind::Hard { + if let Some(hint) = state.scope.return_type_hint { + state.ok( + ret_type_ref.narrow(&mut TypeRef::from_type(outer_hints, hint)), + self.meta, + ); + } + } + Ok((kind, ret_type_ref)) + } + fn typecheck( &mut self, state: &mut PassState, @@ -230,6 +305,92 @@ impl Block { } impl Expression { + fn infer_hints<'s>( + &mut self, + state: &mut PassState, + hints: &'s ScopeHints<'s>, + ) -> Result, ErrorKind> { + match &mut self.0 { + ExprKind::Variable(var) => { + let hint = hints + .find_hint(&var.1) + .map(|(_, hint)| hint) + .ok_or(ErrorKind::VariableNotDefined(var.1.clone())); + if let Ok(hint) = &hint { + var.0 = hint.as_type() + } + Ok(TypeRef::Hint(hint?)) + } + ExprKind::Literal(literal) => TypeRef::from_literal(hints, *literal), + ExprKind::BinOp(op, lhs, rhs) => { + let mut lhs_ref = lhs.infer_hints(state, hints)?; + let mut rhs_ref = rhs.infer_hints(state, hints)?; + hints.binop(op, &mut lhs_ref, &mut rhs_ref) + } + ExprKind::FunctionCall(function_call) => { + let fn_call = state + .scope + .function_returns + .get(&function_call.name) + .ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()))? + .clone(); + + let true_params_iter = fn_call.params.iter().chain(iter::repeat(&Vague(Unknown))); + + for (param_expr, param_t) in + function_call.parameters.iter_mut().zip(true_params_iter) + { + let expr_res = param_expr.infer_hints(state, hints); + if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) { + state.ok( + param_ref.narrow(&mut TypeRef::from_type(hints, *param_t)), + param_expr.1, + ); + } + } + + Ok(TypeRef::from_type(hints, fn_call.ret)) + } + ExprKind::If(IfExpression(cond, lhs, rhs)) => { + let cond_res = cond.infer_hints(state, hints); + let cond_hints = state.ok(cond_res, cond.1); + + if let Some(mut cond_hints) = cond_hints { + state.ok(cond_hints.narrow(&mut TypeRef::Literal(Bool)), cond.1); + } + + let lhs_res = lhs.infer_hints(state, hints); + let lhs_hints = state.ok(lhs_res, cond.1); + + if let Some(rhs) = rhs { + let rhs_res = rhs.infer_hints(state, hints); + let rhs_hints = state.ok(rhs_res, cond.1); + + if let (Some(mut lhs_hints), Some(mut rhs_hints)) = (lhs_hints, rhs_hints) { + state.ok(lhs_hints.1.narrow(&mut rhs_hints.1), self.1); + Ok(pick_return(lhs_hints, rhs_hints).1) + } else { + // Failed to retrieve types from either + Ok(TypeRef::from_type(hints, Vague(Unknown))) + } + } else { + if let Some((_, type_ref)) = lhs_hints { + Ok(type_ref) + } else { + Ok(TypeRef::from_type(hints, Vague(Unknown))) + } + } + } + ExprKind::Block(block) => { + let block_ref = block.infer_hints(state, hints)?; + match block_ref.0 { + ReturnKind::Hard => Ok(TypeRef::from_type(hints, Void)), + ReturnKind::Soft => Ok(block_ref.1), + } + } + } + } + fn typecheck( &mut self, state: &mut PassState, @@ -415,6 +576,7 @@ impl TypeKind { Vague(vague_type) => match vague_type { Unknown => Err(ErrorKind::TypeIsVague(*vague_type)), Number => Ok(TypeKind::I32), + Hinted(_) => panic!("Hinted default!"), }, _ => Ok(*self), } diff --git a/reid/src/mir/types.rs b/reid/src/mir/types.rs index 369ec7f..448908b 100644 --- a/reid/src/mir/types.rs +++ b/reid/src/mir/types.rs @@ -1,11 +1,14 @@ +use crate::mir::typecheck::Collapsable; + use super::*; #[derive(Debug, Clone)] pub enum ReturnTypeOther { - Import(TokenRange), - Let(TokenRange), - EmptyBlock(TokenRange), - NoBlockReturn(TokenRange), + Import(Metadata), + Let(Metadata), + Set(Metadata), + EmptyBlock(Metadata), + NoBlockReturn(Metadata), } pub trait ReturnType { @@ -14,9 +17,25 @@ pub trait ReturnType { impl ReturnType for Block { fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { + let mut early_return = None; + + for statement in &self.statements { + let ret = statement.return_type(); + if let Ok((ReturnKind::Hard, _)) = ret { + early_return = early_return.or(ret.ok()); + } + } + + // TODO should actually probably prune all instructions after this one + // as to not cause problems in codegen later (when unable to delete the + // block) + if let Some((ReturnKind::Hard, ret_ty)) = early_return { + return Ok((ReturnKind::Hard, ret_ty)); + } + self.return_expression .as_ref() - .ok_or(ReturnTypeOther::NoBlockReturn(self.meta.range)) + .ok_or(ReturnTypeOther::NoBlockReturn(self.meta)) .and_then(|(kind, stmt)| Ok((*kind, stmt.return_type()?.1))) } } @@ -25,10 +44,16 @@ impl ReturnType for Statement { fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { use StmtKind::*; match &self.0 { - Expression(e) => e.return_type(), - Set(_, _) => todo!(), - Import(_) => Err(ReturnTypeOther::Import(self.1.range)), - Let(_, _, _) => Err(ReturnTypeOther::Let(self.1.range)), + Let(var, _, expr) => if_hard( + expr.return_type()?, + Err(ReturnTypeOther::Let(var.2 + expr.1)), + ), + Set(var, expr) => if_hard( + expr.return_type()?, + Err(ReturnTypeOther::Set(var.2 + expr.1)), + ), + Import(_) => todo!(), + Expression(expression) => expression.return_type(), } } } @@ -43,12 +68,7 @@ impl ReturnType for Expression { let then_r = then_e.return_type()?; let else_r = else_e.return_type()?; - let kind = if then_r.0 == ReturnKind::Hard && else_r.0 == ReturnKind::Hard { - ReturnKind::Hard - } else { - ReturnKind::Soft - }; - Ok((kind, then_r.1)) + Ok(pick_return(then_r, else_r)) } Block(block) => block.return_type(), FunctionCall(fcall) => fcall.return_type(), @@ -87,3 +107,24 @@ impl ReturnType for FunctionCall { Ok((ReturnKind::Soft, self.return_type.clone())) } } + +fn if_hard( + return_type: (ReturnKind, TypeKind), + default: Result<(ReturnKind, TypeKind), TErr>, +) -> Result<(ReturnKind, TypeKind), TErr> { + if let (ReturnKind::Hard, _) = return_type { + Ok(return_type) + } else { + default + } +} + +pub fn pick_return(lhs: (ReturnKind, T), rhs: (ReturnKind, T)) -> (ReturnKind, T) { + use ReturnKind::*; + match (lhs.0, rhs.0) { + (Hard, Hard) => (Hard, lhs.1), + (Hard, Soft) => (Soft, rhs.1), + (Soft, Hard) => (Soft, lhs.1), + (_, _) => (Soft, lhs.1), + } +}