diff --git a/examples/custom_binop.reid b/examples/custom_binop.reid index f337312..3a80bac 100644 --- a/examples/custom_binop.reid +++ b/examples/custom_binop.reid @@ -1,7 +1,7 @@ // Arithmetic, function calls and imports! -impl binop (lhs: u32) + (rhs: u32) -> u32 { - +impl binop (lhs: u16) + (rhs: u32) -> u32 { + return (lhs as u32) + rhs; } fn main() -> u32 { diff --git a/reid/src/ast/mod.rs b/reid/src/ast/mod.rs index c72a15c..d135995 100644 --- a/reid/src/ast/mod.rs +++ b/reid/src/ast/mod.rs @@ -217,6 +217,7 @@ pub struct BinopDefinition { pub rhs: (String, Type), pub return_ty: Type, pub block: Block, + pub signature_range: TokenRange, } #[derive(Debug)] diff --git a/reid/src/ast/parse.rs b/reid/src/ast/parse.rs index bfa0268..a44e2a2 100644 --- a/reid/src/ast/parse.rs +++ b/reid/src/ast/parse.rs @@ -554,7 +554,11 @@ impl Parse for Block { statements.push(statement); } stream.expect(Token::BraceClose)?; - Ok(Block(statements, return_stmt, stream.get_range().unwrap())) + Ok(Block( + statements, + return_stmt, + stream.get_range_prev().unwrap(), + )) } } @@ -814,6 +818,8 @@ impl Parse for BinopDefinition { let rhs_type = stream.parse()?; stream.expect(Token::ParenClose)?; + let signature_range = stream.get_range().unwrap(); + stream.expect(Token::Arrow)?; Ok(BinopDefinition { @@ -822,6 +828,7 @@ impl Parse for BinopDefinition { rhs: (rhs_name, rhs_type), return_ty: stream.parse()?, block: stream.parse()?, + signature_range, }) } } diff --git a/reid/src/ast/process.rs b/reid/src/ast/process.rs index 610dcc4..bad6d67 100644 --- a/reid/src/ast/process.rs +++ b/reid/src/ast/process.rs @@ -104,6 +104,7 @@ impl ast::Module { rhs, return_ty, block, + signature_range, }) => { binops.push(mir::BinopDefinition { lhs: (lhs.0.clone(), lhs.1 .0.into_mir(module_id)), @@ -111,6 +112,7 @@ impl ast::Module { rhs: (rhs.0.clone(), rhs.1 .0.into_mir(module_id)), return_ty: return_ty.0.into_mir(module_id), block: block.into_mir(module_id), + meta: signature_range.as_meta(module_id), }); } } diff --git a/reid/src/mir/fmt.rs b/reid/src/mir/fmt.rs index 68c0899..9a31e51 100644 --- a/reid/src/mir/fmt.rs +++ b/reid/src/mir/fmt.rs @@ -40,6 +40,9 @@ impl Display for Module { for import in &self.imports { writeln!(inner_f, "{}", import)?; } + for binop in &self.binop_defs { + writeln!(inner_f, "{}", binop)?; + } for typedef in &self.typedefs { writeln!(inner_f, "{}", typedef)?; } @@ -56,6 +59,17 @@ impl Display for Import { } } +impl Display for BinopDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "impl binop ({}: {:#}) {} ({}: {:#}) -> {:#} ", + self.lhs.0, self.lhs.1, self.op, self.rhs.0, self.rhs.1, self.return_ty + )?; + Display::fmt(&self.block, f) + } +} + impl Display for TypeDefinition { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "type {} = ", self.name)?; diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index 4cb3630..db1ce2b 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -372,6 +372,17 @@ pub struct BinopDefinition { pub rhs: (String, TypeKind), pub return_ty: TypeKind, pub block: Block, + pub meta: Metadata, +} + +impl BinopDefinition { + pub fn block_meta(&self) -> Metadata { + self.block.meta + } + + pub fn signature(&self) -> Metadata { + self.meta + } } #[derive(Debug)] diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index 23e586b..842a494 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -132,6 +132,11 @@ impl<'t> Pass for TypeCheck<'t> { check_typedefs_for_recursion(&defmap, typedef, HashSet::new(), &mut state); } + for binop in &mut module.binop_defs { + let res = binop.typecheck(&self.refs, &mut state.inner()); + state.ok(res, binop.block_meta()); + } + for function in &mut module.functions { let res = function.typecheck(&self.refs, &mut state.inner()); state.ok(res, function.block_meta()); @@ -170,6 +175,48 @@ fn check_typedefs_for_recursion<'a, 'b>( } } +impl BinopDefinition { + fn typecheck( + &mut self, + typerefs: &TypeRefs, + state: &mut TypecheckPassState, + ) -> Result { + for param in vec![&self.lhs, &self.rhs] { + let param_t = state.or_else( + param.1.assert_known(typerefs, state), + TypeKind::Vague(Vague::Unknown), + self.signature(), + ); + let res = state + .scope + .variables + .set( + param.0.clone(), + ScopeVariable { + ty: param_t.clone(), + mutable: param_t.is_mutable(), + }, + ) + .or(Err(ErrorKind::VariableAlreadyDefined(param.0.clone()))); + state.ok(res, self.signature()); + } + + let return_type = self.return_ty.clone().assert_known(typerefs, state)?; + + state.scope.return_type_hint = Some(self.return_ty.clone()); + let inferred = self + .block + .typecheck(&mut state.inner(), &typerefs, Some(&return_type)); + + match inferred { + Ok(t) => return_type + .collapse_into(&t.1) + .or(Err(ErrorKind::ReturnTypeMismatch(return_type, t.1))), + Err(e) => Ok(state.or_else(Err(e), return_type, self.block_meta())), + } + } +} + impl FunctionDefinition { fn typecheck( &mut self, diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index 5a805ba..99f637c 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -12,8 +12,8 @@ use super::{ pass::{Pass, PassResult, PassState}, typecheck::{ErrorKind, ErrorTypedefKind}, typerefs::{ScopeTypeRefs, TypeRef, TypeRefs}, - Block, CustomTypeKey, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, - IfExpression, Module, ReturnKind, StmtKind, + BinopDefinition, Block, CustomTypeKey, ExprKind, Expression, FunctionDefinition, + FunctionDefinitionKind, IfExpression, Module, ReturnKind, StmtKind, TypeKind::*, VagueType::*, WhileStatement, @@ -55,6 +55,11 @@ impl<'t> Pass for TypeInference<'t> { } } + for binop in &mut module.binop_defs { + let res = binop.infer_types(&self.refs, &mut state.inner()); + state.ok(res, binop.block_meta()); + } + for function in &mut module.functions { let res = function.infer_types(&self.refs, &mut state.inner()); state.ok(res, function.block_meta()); @@ -63,6 +68,56 @@ impl<'t> Pass for TypeInference<'t> { } } +impl BinopDefinition { + fn infer_types( + &mut self, + type_refs: &TypeRefs, + state: &mut TypeInferencePassState, + ) -> Result<(), ErrorKind> { + let scope_hints = ScopeTypeRefs::from(type_refs); + + let lhs_ty = state.or_else( + self.lhs.1.assert_unvague(), + Vague(Unknown), + self.signature(), + ); + state.ok( + scope_hints + .new_var(self.lhs.0.clone(), false, &lhs_ty) + .or(Err(ErrorKind::VariableAlreadyDefined(self.lhs.0.clone()))), + self.signature(), + ); + + let rhs_ty = state.or_else( + self.rhs.1.assert_unvague(), + Vague(Unknown), + self.signature(), + ); + + state.ok( + scope_hints + .new_var(self.rhs.0.clone(), false, &rhs_ty) + .or(Err(ErrorKind::VariableAlreadyDefined(self.rhs.0.clone()))), + self.signature(), + ); + + state.scope.return_type_hint = Some(self.return_ty.clone()); + let ret_res = self.block.infer_types(state, &scope_hints); + let (_, mut ret_ty) = state.or_else( + ret_res, + ( + ReturnKind::Soft, + scope_hints.from_type(&Vague(Unknown)).unwrap(), + ), + self.block_meta(), + ); + + ret_ty.narrow(&scope_hints.from_type(&self.return_ty).unwrap()); + + Ok(()) + } +} + impl FunctionDefinition { fn infer_types( &mut self,