diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index 7718fe2..18f33e3 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -8,8 +8,8 @@ use VagueType::*; use super::{ pass::{Pass, PassState, ScopeFunction, ScopeVariable}, - typerefs::{ScopeTypeRefs, TypeRef, TypeRefs}, - types::{pick_return, ReturnType}, + typerefs::TypeRefs, + types::ReturnType, }; #[derive(thiserror::Error, Debug, Clone)] @@ -129,10 +129,6 @@ impl Block { ); let res_t = if res_t.known().is_err() { - // state.ok::<_, Infallible>( - // Err(ErrorKind::TypeNotInferrable(res_t)), - // variable_reference.2 + expression.1, - // ); // Unable to infer variable type even from expression! Default it let res_t = state.or_else(res_t.or_default(), Vague(Unknown), variable_reference.2); @@ -146,7 +142,7 @@ impl Block { res_t }; - // Update typing to be more accurate + // Update typing variable_reference.0 = res_t; // Variable might already be defined, note error diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index 1fa6b4b..1bc6cd4 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -1,11 +1,15 @@ -use std::iter; +//! Type Inference is a pass where all of the potentially vague types are went +//! through, stored in an intermediary storage [`TypeRefs`], and then the types +//! in MIR are changed to [`TypeKind::TypeRef`]s with the correct ID. This MIR +//! must then be passed through TypeCheck with the same [`TypeRefs`] in order to +//! place the correct types from the IDs and check that there are no issues. -use reid_lib::Function; +use std::iter; use super::{ pass::{Pass, PassState, ScopeVariable}, typecheck::ErrorKind, - typerefs::{self, ScopeTypeRefs, TypeRef, TypeRefs}, + typerefs::{ScopeTypeRefs, TypeRef, TypeRefs}, types::{pick_return, ReturnType}, Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module, ReturnKind, StmtKind, @@ -13,8 +17,9 @@ use super::{ VagueType::*, }; -/// Struct used to implement a type-checking pass that can be performed on the -/// MIR. +/// Struct used to implement Type Inference, where an intermediary +/// TypeRefs-struct is used as a helper to go through the modules and change +/// types while inferring. pub struct TypeInference<'t> { pub refs: &'t TypeRefs, } @@ -24,14 +29,14 @@ impl<'t> Pass for TypeInference<'t> { fn module(&mut self, module: &mut Module, mut state: PassState) { for function in &mut module.functions { - let res = function.infer_hints(&self.refs, &mut state); + let res = function.infer_types(&self.refs, &mut state); state.ok(res, function.block_meta()); } } } impl FunctionDefinition { - fn infer_hints( + fn infer_types( &mut self, type_refs: &TypeRefs, state: &mut PassState, @@ -51,29 +56,29 @@ impl FunctionDefinition { .or(Err(ErrorKind::VariableAlreadyDefined(param.0.clone()))); state.ok(res, self.signature()); } - let scope_hints = ScopeTypeRefs::from(type_refs); - let return_type = self.return_type.clone(); - let return_type_hint = scope_hints.from_type(&return_type).unwrap(); - let mut ret = match &mut self.kind { + match &mut self.kind { FunctionDefinitionKind::Local(block, _) => { state.scope.return_type_hint = Some(self.return_type); - let block_res = block.infer_hints(state, &scope_hints); - state.ok(block_res.map(|(_, ty)| ty), self.block_meta()) - } - FunctionDefinitionKind::Extern => Some(scope_hints.from_type(&Vague(Unknown)).unwrap()), - }; + let scope_hints = ScopeTypeRefs::from(type_refs); - if let Some(ret) = &mut ret { - state.ok(ret.narrow(&return_type_hint), self.signature()); - } + // Infer block return type + let ret_res = block.infer_types(state, &scope_hints); + + // Narrow block type to declared function type + if let Some(mut ret_ty) = state.ok(ret_res.map(|(_, ty)| ty), self.block_meta()) { + ret_ty.narrow(&scope_hints.from_type(&self.return_type).unwrap()); + } + } + FunctionDefinitionKind::Extern => {} + }; Ok(()) } } impl Block { - fn infer_hints<'s>( + fn infer_types<'s>( &mut self, state: &mut PassState, outer_hints: &'s ScopeTypeRefs, @@ -84,82 +89,110 @@ impl Block { for statement in &mut self.statements { match &mut statement.0 { StmtKind::Let(var, mutable, expr) => { + // Get the TypeRef for this variable declaration let mut var_ref = state.ok(inner_hints.new_var(var.1.clone(), *mutable, var.0), var.2); + + // If ok, update the MIR type to this TypeRef if let Some(var_ref) = &var_ref { var.0 = var_ref.as_type(); } - let inferred = expr.infer_hints(&mut state, &inner_hints); + + // Infer hints for the expression itself + let inferred = expr.infer_types(&mut state, &inner_hints); let mut expr_ty_ref = state.ok(inferred, expr.1); + + // Try to narrow the variable type declaration with the + // expression if let (Some(var_ref), Some(expr_ty_ref)) = (var_ref.as_mut(), expr_ty_ref.as_mut()) { - state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1); + var_ref.narrow(&expr_ty_ref); } } StmtKind::Set(var, expr) => { + // Get the TypeRef for this variable declaration let var_ref = inner_hints.find_hint(&var.1); + + // If ok, update the MIR type to this TypeRef if let Some((_, var_ref)) = &var_ref { var.0 = var_ref.as_type() } - let inferred = expr.infer_hints(&mut state, &inner_hints); + + // Infer hints for the expression itself + let inferred = expr.infer_types(&mut state, &inner_hints); let expr_ty_ref = state.ok(inferred, expr.1); + + // Try to narrow the variable type declaration with the + // expression if let (Some((_, mut var_ref)), Some(expr_ty_ref)) = (var_ref, expr_ty_ref) { - state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1); + var_ref.narrow(&expr_ty_ref); } } StmtKind::Import(_) => todo!(), StmtKind::Expression(expr) => { - let expr_res = expr.infer_hints(&mut state, &inner_hints); + let expr_res = expr.infer_types(&mut state, &inner_hints); state.ok(expr_res, expr.1); } }; } + // If there is a return expression, infer it's type if let Some(ret_expr) = &mut self.return_expression { - let ret_res = ret_expr.1.infer_hints(&mut state, &inner_hints); + let ret_res = ret_expr.1.infer_types(&mut state, &inner_hints); state.ok(ret_res, ret_expr.1 .1); } + // Fetch the declared return type let (kind, ty) = self.return_type().ok().unwrap_or((ReturnKind::Soft, Void)); let mut ret_type_ref = outer_hints.from_type(&ty).unwrap(); + // Narow return type to declared type if hard return if kind == ReturnKind::Hard { if let Some(hint) = state.scope.return_type_hint { - state.ok( - ret_type_ref.narrow(&mut outer_hints.from_type(&hint).unwrap()), - self.meta, - ); + ret_type_ref.narrow(&mut outer_hints.from_type(&hint).unwrap()); } } + Ok((kind, ret_type_ref)) } } impl Expression { - fn infer_hints<'s>( + fn infer_types<'s>( &mut self, state: &mut PassState, type_refs: &'s ScopeTypeRefs<'s>, ) -> Result, ErrorKind> { match &mut self.0 { ExprKind::Variable(var) => { - let hint = type_refs + // Find variable type + let type_ref = type_refs .find_hint(&var.1) .map(|(_, hint)| hint) .ok_or(ErrorKind::VariableNotDefined(var.1.clone())); - if let Ok(hint) = &hint { + + // Update MIR type to TypeRef if found + if let Ok(hint) = &type_ref { var.0 = hint.as_type() } - hint + + type_ref } ExprKind::Literal(literal) => Ok(type_refs.from_type(&literal.as_type()).unwrap()), ExprKind::BinOp(op, lhs, rhs) => { - let mut lhs_ref = lhs.infer_hints(state, type_refs)?; - let mut rhs_ref = rhs.infer_hints(state, type_refs)?; - type_refs.binop(op, &mut lhs_ref, &mut rhs_ref) + // Infer LHS and RHS, and return binop type + let mut lhs_ref = lhs.infer_types(state, type_refs)?; + let mut rhs_ref = rhs.infer_types(state, type_refs)?; + type_refs + .binop(op, &mut lhs_ref, &mut rhs_ref) + .ok_or(ErrorKind::TypesIncompatible( + lhs_ref.as_type(), + rhs_ref.as_type(), + )) } ExprKind::FunctionCall(function_call) => { + // Get function definition and types let fn_call = state .scope .function_returns @@ -167,48 +200,52 @@ impl Expression { .ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()))? .clone(); + // Infer param expression types and narrow them to the + // expected function parameters (or Unknown types if too + // many were provided) let true_params_iter = fn_call.params.iter().chain(iter::repeat(&Vague(Unknown))); for (param_expr, param_t) in function_call.parameters.iter_mut().zip(true_params_iter) { - let expr_res = param_expr.infer_hints(state, type_refs); + let expr_res = param_expr.infer_types(state, type_refs); if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) { - state.ok( - param_ref.narrow(&mut type_refs.from_type(param_t).unwrap()), - param_expr.1, - ); + param_ref.narrow(&mut type_refs.from_type(param_t).unwrap()); } } + // Provide function return type Ok(type_refs.from_type(&fn_call.ret).unwrap()) } ExprKind::If(IfExpression(cond, lhs, rhs)) => { - let cond_res = cond.infer_hints(state, type_refs); + // Infer condition type + let cond_res = cond.infer_types(state, type_refs); let cond_hints = state.ok(cond_res, cond.1); + // Try to narrow condition type to boolean if let Some(mut cond_hints) = cond_hints { - state.ok( - cond_hints.narrow(&mut type_refs.from_type(&Bool).unwrap()), - cond.1, - ); + cond_hints.narrow(&mut type_refs.from_type(&Bool).unwrap()); } - let lhs_res = lhs.infer_hints(state, type_refs); + // Infer LHS return type + let lhs_res = lhs.infer_types(state, type_refs); let lhs_hints = state.ok(lhs_res, cond.1); if let Some(rhs) = rhs { - let rhs_res = rhs.infer_hints(state, type_refs); + // Infer RHS return type + let rhs_res = rhs.infer_types(state, type_refs); let rhs_hints = state.ok(rhs_res, cond.1); + // Narrow LHS to the same type as RHS and return it's return type if let (Some(mut lhs_hints), Some(mut rhs_hints)) = (lhs_hints, rhs_hints) { - state.ok(lhs_hints.1.narrow(&mut rhs_hints.1), self.1); + lhs_hints.1.narrow(&mut rhs_hints.1); Ok(pick_return(lhs_hints, rhs_hints).1) } else { // Failed to retrieve types from either Ok(type_refs.from_type(&Vague(Unknown)).unwrap()) } } else { + // Return LHS return type if let Some((_, type_ref)) = lhs_hints { Ok(type_ref) } else { @@ -217,7 +254,7 @@ impl Expression { } } ExprKind::Block(block) => { - let block_ref = block.infer_hints(state, type_refs)?; + let block_ref = block.infer_types(state, type_refs)?; match block_ref.0 { ReturnKind::Hard => Ok(type_refs.from_type(&Void).unwrap()), ReturnKind::Soft => Ok(block_ref.1), diff --git a/reid/src/mir/typerefs.rs b/reid/src/mir/typerefs.rs index 726ee12..81fb400 100644 --- a/reid/src/mir/typerefs.rs +++ b/reid/src/mir/typerefs.rs @@ -17,7 +17,7 @@ impl<'scope> TypeRef<'scope> { unsafe { *self.1.types.hints.borrow().get_unchecked(*self.0.borrow()) } } - pub fn narrow(&mut self, other: &TypeRef) -> Result, ErrorKind> { + pub fn narrow(&mut self, other: &TypeRef) -> Option> { self.1.combine_vars(self, other) } @@ -119,11 +119,11 @@ impl<'outer> ScopeTypeRefs<'outer> { if self.variables.borrow().contains_key(&name) { return Err(ErrorKind::VariableAlreadyDefined(name)); } - let idx = self.types.new(initial_ty); + let type_ref = self.from_type(&initial_ty).unwrap(); self.variables .borrow_mut() - .insert(name, (mutable, idx.clone())); - Ok(TypeRef(idx, self)) + .insert(name, (mutable, type_ref.0.clone())); + Ok(type_ref) } pub fn from_type(&'outer self, ty: &TypeKind) -> Option> { @@ -144,24 +144,16 @@ impl<'outer> ScopeTypeRefs<'outer> { Some(TypeRef(idx, self)) } - fn narrow_to_type( - &'outer self, - hint: &TypeRef, - ty: &TypeKind, - ) -> Result, ErrorKind> { + fn narrow_to_type(&'outer self, hint: &TypeRef, ty: &TypeKind) -> Option> { unsafe { let mut hints = self.types.hints.borrow_mut(); let existing = hints.get_unchecked_mut(*hint.0.borrow()); - *existing = existing.collapse_into(&ty)?; - Ok(TypeRef(hint.0.clone(), self)) + *existing = existing.collapse_into(&ty).ok()?; + Some(TypeRef(hint.0.clone(), self)) } } - fn combine_vars( - &'outer self, - hint1: &TypeRef, - hint2: &TypeRef, - ) -> Result, ErrorKind> { + fn combine_vars(&'outer self, hint1: &TypeRef, hint2: &TypeRef) -> Option> { unsafe { let ty = self .types @@ -175,7 +167,7 @@ impl<'outer> ScopeTypeRefs<'outer> { *idx.borrow_mut() = *hint1.0.borrow(); } } - Ok(TypeRef(hint1.0.clone(), self)) + Some(TypeRef(hint1.0.clone(), self)) } } @@ -200,9 +192,9 @@ impl<'outer> ScopeTypeRefs<'outer> { op: &BinaryOperator, lhs: &mut TypeRef<'outer>, rhs: &mut TypeRef<'outer>, - ) -> Result, ErrorKind> { + ) -> Option> { let ty = lhs.narrow(rhs)?; - Ok(match op { + Some(match op { BinaryOperator::Add => ty, BinaryOperator::Minus => ty, BinaryOperator::Mult => ty,