diff --git a/examples/associated_functions.reid b/examples/associated_functions.reid index 7cfbc19..9421ad8 100644 --- a/examples/associated_functions.reid +++ b/examples/associated_functions.reid @@ -9,6 +9,6 @@ impl Otus { } fn main() -> u32 { - let otus = Otus {}; + let otus = Otus { field: 17 }; return Otus::test(&otus); } diff --git a/reid/src/mir/implement.rs b/reid/src/mir/implement.rs index 0dfa202..5cc62f5 100644 --- a/reid/src/mir/implement.rs +++ b/reid/src/mir/implement.rs @@ -443,7 +443,7 @@ impl Expression { }, Err(_) => Ok((ReturnKind::Soft, type_kind.clone())), }, - AssociatedFunctionCall(type_kind, function_call) => todo!(), + AssociatedFunctionCall(_, fcall) => fcall.return_type(), } } @@ -462,7 +462,7 @@ impl Expression { ExprKind::FunctionCall(_) => None, ExprKind::If(_) => None, ExprKind::CastTo(expression, _) => expression.backing_var(), - ExprKind::AssociatedFunctionCall(type_kind, function_call) => None, + ExprKind::AssociatedFunctionCall(..) => None, } } diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index 2ed78d8..45e53a3 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -123,7 +123,8 @@ pub type BinopMap = Storage; pub struct Scope { pub module_id: Option, pub binops: BinopMap, - pub function_returns: Storage, + pub associated_functions: Storage, + pub functions: Storage, pub variables: Storage, pub types: Storage, /// Hard Return type of this scope, if inside a function @@ -135,7 +136,8 @@ impl Scope { pub fn inner(&self) -> Scope { Scope { module_id: self.module_id, - function_returns: self.function_returns.clone(), + associated_functions: self.associated_functions.clone(), + functions: self.functions.clone(), variables: self.variables.clone(), binops: self.binops.clone(), types: self.types.clone(), @@ -181,6 +183,9 @@ pub struct ScopeVariable { pub mutable: bool, } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct AssociatedFunctionKey(pub TypeKind, pub String); + #[derive(Clone, Debug, Eq)] pub struct BinopKey { pub params: (TypeKind, TypeKind), @@ -389,7 +394,7 @@ impl Module { for function in &self.functions { scope - .function_returns + .functions .set( function.name.clone(), ScopeFunction { @@ -400,6 +405,19 @@ impl Module { .ok(); } + for (ty, function) in &self.associated_functions { + scope + .associated_functions + .set( + AssociatedFunctionKey(ty.clone(), function.name.clone()), + ScopeFunction { + ret: function.return_type.clone(), + params: function.parameters.iter().cloned().map(|v| v.1).collect(), + }, + ) + .ok(); + } + pass.module(self, PassState::from(state, scope, Some(self.module_id)))?; for function in &mut self.functions { diff --git a/reid/src/mir/typecheck/mod.rs b/reid/src/mir/typecheck/mod.rs index 59fd6a7..93b1fd7 100644 --- a/reid/src/mir/typecheck/mod.rs +++ b/reid/src/mir/typecheck/mod.rs @@ -24,12 +24,16 @@ pub enum ErrorKind { TypesIncompatible(TypeKind, TypeKind), #[error("Variable not defined: {0}")] VariableNotDefined(String), - #[error("Function not defined: {0}")] + #[error("Function {0} not defined")] FunctionNotDefined(String), + #[error("Function {0} not defined for type {1}")] + AssocFunctionNotDefined(String, TypeKind), #[error("Expected a return type of {0}, got {1} instead")] ReturnTypeMismatch(TypeKind, TypeKind), #[error("Function {0} already defined {1}")] FunctionAlreadyDefined(String, ErrorTypedefKind), + #[error("Function {0}::{1} already defined {2}")] + AssocFunctionAlreadyDefined(TypeKind, String, ErrorTypedefKind), #[error("Variable already defined: {0}")] VariableAlreadyDefined(String), #[error("Variable {0} is not declared as mutable")] diff --git a/reid/src/mir/typecheck/typecheck.rs b/reid/src/mir/typecheck/typecheck.rs index 4162949..1f38329 100644 --- a/reid/src/mir/typecheck/typecheck.rs +++ b/reid/src/mir/typecheck/typecheck.rs @@ -443,7 +443,7 @@ impl Expression { ExprKind::FunctionCall(function_call) => { let true_function = state .scope - .function_returns + .functions .get(&function_call.name) .cloned() .ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone())); @@ -724,7 +724,56 @@ impl Expression { let expr = expression.typecheck(state, typerefs, HintKind::Default)?; expr.resolve_ref(typerefs).cast_into(type_kind) } - ExprKind::AssociatedFunctionCall(type_kind, function_call) => todo!(), + ExprKind::AssociatedFunctionCall(type_kind, function_call) => { + let true_function = state + .scope + .associated_functions + .get(&pass::AssociatedFunctionKey( + type_kind.clone(), + function_call.name.clone(), + )) + .cloned() + .ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone())); + + if let Some(f) = state.ok(true_function, self.1) { + let param_len_given = function_call.parameters.len(); + let param_len_expected = f.params.len(); + + // Check that there are the same number of parameters given + // as expected + if param_len_given != param_len_expected { + state.ok::<_, Infallible>( + Err(ErrorKind::InvalidAmountParameters( + function_call.name.clone(), + param_len_given, + param_len_expected, + )), + self.1, + ); + } + + let true_params_iter = f + .params + .into_iter() + .chain(iter::repeat(TypeKind::Vague(Vague::Unknown))); + + for (param, true_param_t) in function_call.parameters.iter_mut().zip(true_params_iter) { + // Typecheck every param separately + let param_res = param.typecheck(state, &typerefs, HintKind::Coerce(true_param_t.clone())); + let param_t = state.or_else(param_res, TypeKind::Vague(Vague::Unknown), param.1); + state.ok(param_t.narrow_into(&true_param_t), param.1); + } + + // Make sure function return type is the same as the claimed + // return type + let ret_t = f.ret.narrow_into(&function_call.return_type.resolve_ref(typerefs))?; + // Update typing to be more accurate + function_call.return_type = ret_t.clone(); + Ok(ret_t.resolve_ref(typerefs)) + } else { + Ok(function_call.return_type.clone().resolve_ref(typerefs)) + } + } } } } diff --git a/reid/src/mir/typecheck/typeinference.rs b/reid/src/mir/typecheck/typeinference.rs index 3522b5a..926161b 100644 --- a/reid/src/mir/typecheck/typeinference.rs +++ b/reid/src/mir/typecheck/typeinference.rs @@ -12,8 +12,8 @@ use std::{ use crate::{ mir::{ - BinopDefinition, Block, CustomTypeKey, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, - IfExpression, Module, ReturnKind, StmtKind, TypeKind, WhileStatement, + pass::AssociatedFunctionKey, BinopDefinition, Block, CustomTypeKey, ExprKind, Expression, FunctionDefinition, + FunctionDefinitionKind, IfExpression, Module, ReturnKind, StmtKind, TypeKind, WhileStatement, }, util::try_all, }; @@ -60,6 +60,29 @@ impl<'t> Pass for TypeInference<'t> { } } + let mut seen_assoc_functions = HashMap::new(); + for (ty, function) in &mut module.associated_functions { + if let Some(kind) = seen_assoc_functions.get(&(ty.clone(), function.name.clone())) { + state.note_errors( + &vec![ErrorKind::AssocFunctionAlreadyDefined( + ty.clone(), + function.name.clone(), + *kind, + )], + function.signature(), + ); + } else { + seen_assoc_functions.insert( + (ty.clone(), function.name.clone()), + match function.kind { + FunctionDefinitionKind::Local(..) => ErrorTypedefKind::Local, + FunctionDefinitionKind::Extern(..) => ErrorTypedefKind::Extern, + FunctionDefinitionKind::Intrinsic(..) => ErrorTypedefKind::Intrinsic, + }, + ); + } + } + let mut seen_binops = HashSet::new(); for binop in &module.binop_defs { let binop_key = BinopKey { @@ -365,7 +388,7 @@ impl Expression { // Get function definition and types let fn_call = state .scope - .function_returns + .functions .get(&function_call.name) .ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()))? .clone(); @@ -566,7 +589,33 @@ impl Expression { expression.infer_types(state, type_refs)?; Ok(type_refs.from_type(type_kind).unwrap()) } - ExprKind::AssociatedFunctionCall(type_kind, function_call) => todo!(), + ExprKind::AssociatedFunctionCall(type_kind, function_call) => { + // Get function definition and types + let fn_call = state + .scope + .associated_functions + .get(&AssociatedFunctionKey(type_kind.clone(), function_call.name.clone())) + .ok_or(ErrorKind::AssocFunctionNotDefined( + function_call.name.clone(), + type_kind.clone(), + ))? + .clone(); + + // Infer param expression types and narrow them to the + // expected function parameters (or Unknown types if too + // many were provided) + let true_params_iter = fn_call.params.iter().chain(iter::repeat(&Vague(Unknown))); + + for (param_expr, param_t) in function_call.parameters.iter_mut().zip(true_params_iter) { + let expr_res = param_expr.infer_types(state, type_refs); + if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) { + param_ref.narrow(&mut type_refs.from_type(param_t).unwrap()); + } + } + + // Provide function return type + Ok(type_refs.from_type(&fn_call.ret).unwrap()) + } } } }