diff --git a/examples/custom_binop.reid b/examples/custom_binop.reid index 3a80bac..40f8a2b 100644 --- a/examples/custom_binop.reid +++ b/examples/custom_binop.reid @@ -5,8 +5,8 @@ impl binop (lhs: u16) + (rhs: u32) -> u32 { } fn main() -> u32 { - let value = 6; - let other = 15; + let value = 6 as u16; + let other = 15 as u32; - return value * other + 7 * -value; + return value + other; } diff --git a/reid/src/ast/process.rs b/reid/src/ast/process.rs index 8387531..55134ea 100644 --- a/reid/src/ast/process.rs +++ b/reid/src/ast/process.rs @@ -110,7 +110,7 @@ impl ast::Module { lhs: (lhs.0.clone(), lhs.1 .0.into_mir(module_id)), op: op.mir(), rhs: (rhs.0.clone(), rhs.1 .0.into_mir(module_id)), - return_ty: return_ty.0.into_mir(module_id), + return_type: return_ty.0.into_mir(module_id), fn_kind: mir::FunctionDefinitionKind::Local( block.into_mir(module_id), block.2.as_meta(module_id), diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index fef1ac2..9eb3ef9 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -78,11 +78,11 @@ pub struct Scope<'ctx, 'scope> { tokens: &'ctx Vec, module: &'ctx Module<'ctx>, pub(super) module_id: SourceModuleId, - function: &'ctx StackFunction<'ctx>, + function: &'ctx Function<'ctx>, pub(super) block: Block<'ctx>, pub(super) types: &'scope HashMap, pub(super) type_values: &'scope HashMap, - functions: &'scope HashMap>, + functions: &'scope HashMap>, stack_values: HashMap, debug: Option>, allocator: Rc>, @@ -131,10 +131,6 @@ pub struct Debug<'ctx> { types: &'ctx HashMap, } -pub struct StackFunction<'ctx> { - ir: Function<'ctx>, -} - #[derive(Debug, Clone, PartialEq, Eq)] pub struct StackValue(StackValueKind, TypeKind); @@ -291,21 +287,6 @@ impl mir::Module { insert_debug!(&TypeKind::CustomType(type_key.clone())); } - // let mut binops = HashMap::new(); - // for binop in &self.binop_defs { - // binops.insert( - // ScopeBinopKey { - // operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), - // commutative: mir::pass::CommutativeKind::True, - // }, - // StackBinopDefinition { - // parameters: (binop.lhs.clone(), binop.rhs.clone()), - // return_ty: binop.return_ty.clone(), - // ir: todo!(), - // }, - // ); - // } - let mut functions = HashMap::new(); for function in &self.functions { @@ -348,12 +329,91 @@ impl mir::Module { ), }; - functions.insert(function.name.clone(), StackFunction { ir: func }); + functions.insert(function.name.clone(), func); + } + + let mut binops = HashMap::new(); + for binop in &self.binop_defs { + let binop_fn_name = format!( + "binop.{}.{:?}.{}.{}", + binop.lhs.1, binop.op, binop.rhs.1, binop.return_type + ); + let ir_function = module.function( + &binop_fn_name, + binop.return_type.get_type(&type_values), + vec![ + binop.lhs.1.get_type(&type_values), + binop.rhs.1.get_type(&type_values), + ], + FunctionFlags::default(), + ); + let mut entry = ir_function.block("entry"); + + let allocator = Allocator::from( + &binop.fn_kind, + &vec![binop.lhs.clone(), binop.rhs.clone()], + &mut AllocatorScope { + block: &mut entry, + module_id: self.module_id, + type_values: &type_values, + }, + ); + + let mut scope = Scope { + context, + modules: &modules, + tokens, + module: &module, + module_id: self.module_id, + function: &ir_function, + block: entry, + functions: &functions, + types: &types, + type_values: &type_values, + stack_values: HashMap::new(), + debug: Some(Debug { + info: &debug, + scope: compile_unit, + types: &debug_types, + }), + allocator: Rc::new(RefCell::new(allocator)), + }; + + binop + .fn_kind + .codegen( + binop_fn_name.clone(), + false, + &mut scope, + &vec![binop.lhs.clone(), binop.rhs.clone()], + &binop.return_type, + &ir_function, + match &binop.fn_kind { + FunctionDefinitionKind::Local(_, meta) => { + meta.into_debug(tokens, compile_unit) + } + FunctionDefinitionKind::Extern(_) => None, + FunctionDefinitionKind::Intrinsic(_) => None, + }, + ) + .unwrap(); + + binops.insert( + ScopeBinopKey { + operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), + commutative: mir::pass::CommutativeKind::True, + }, + StackBinopDefinition { + parameters: (binop.lhs.clone(), binop.rhs.clone()), + return_ty: binop.return_type.clone(), + ir: ir_function, + }, + ); } for mir_function in &self.functions { let function = functions.get(&mir_function.name).unwrap(); - let mut entry = function.ir.block("entry"); + let mut entry = function.block("entry"); let allocator = Allocator::from( &mir_function.kind, @@ -393,7 +453,7 @@ impl mir::Module { &mut scope, &mir_function.parameters, &mir_function.return_type, - &function.ir, + &function, match &mir_function.kind { FunctionDefinitionKind::Local(..) => { mir_function.signature().into_debug(tokens, compile_unit) @@ -670,9 +730,9 @@ impl mir::Statement { mir::StmtKind::While(WhileStatement { condition, block, .. }) => { - let condition_block = scope.function.ir.block("while.cond"); - let condition_true_block = scope.function.ir.block("while.body"); - let condition_failed_block = scope.function.ir.block("while.end"); + let condition_block = scope.function.block("while.cond"); + let condition_true_block = scope.function.block("while.body"); + let condition_failed_block = scope.function.block("while.end"); scope .block @@ -881,7 +941,7 @@ impl mir::Expression { .block .build_named( call.name.clone(), - Instr::FunctionCall(callee.ir.value(), param_instrs), + Instr::FunctionCall(callee.value(), param_instrs), ) .unwrap(); @@ -929,7 +989,7 @@ impl mir::Expression { } mir::ExprKind::If(if_expression) => if_expression.codegen(scope, state)?, mir::ExprKind::Block(block) => { - let inner = scope.function.ir.block("inner"); + let inner = scope.function.block("inner"); scope.block.terminate(Term::Br(inner.value())).unwrap(); let mut inner_scope = scope.with_block(inner); @@ -938,7 +998,7 @@ impl mir::Expression { } else { None }; - let outer = scope.function.ir.block("outer"); + let outer = scope.function.block("outer"); inner_scope.block.terminate(Term::Br(outer.value())).ok(); scope.swap_block(outer); ret @@ -1341,9 +1401,9 @@ impl mir::IfExpression { let condition = self.0.codegen(scope, state)?.unwrap(); // Create blocks - let mut then_b = scope.function.ir.block("then"); - let mut else_b = scope.function.ir.block("else"); - let after_b = scope.function.ir.block("after"); + let mut then_b = scope.function.block("then"); + let mut else_b = scope.function.block("else"); + let after_b = scope.function.block("after"); if let Some(debug) = &scope.debug { let before_location = self.0 .1.into_debug(scope.tokens, debug.scope).unwrap(); diff --git a/reid/src/mir/fmt.rs b/reid/src/mir/fmt.rs index 4505afd..fc3dcc7 100644 --- a/reid/src/mir/fmt.rs +++ b/reid/src/mir/fmt.rs @@ -64,7 +64,7 @@ impl Display for BinopDefinition { write!( f, "impl binop ({}: {:#}) {} ({}: {:#}) -> {:#} ", - self.lhs.0, self.lhs.1, self.op, self.rhs.0, self.rhs.1, self.return_ty + self.lhs.0, self.lhs.1, self.op, self.rhs.0, self.rhs.1, self.return_type )?; Display::fmt(&self.fn_kind, f) } diff --git a/reid/src/mir/implement.rs b/reid/src/mir/implement.rs index ccffdfc..f8951eb 100644 --- a/reid/src/mir/implement.rs +++ b/reid/src/mir/implement.rs @@ -1,4 +1,4 @@ -use super::{typecheck::ErrorKind, typerefs::TypeRefs, VagueType as Vague, *}; +use super::{pass::ScopeBinopDef, typecheck::ErrorKind, typerefs::TypeRefs, VagueType as Vague, *}; #[derive(Debug, Clone)] pub enum ReturnTypeOther { @@ -95,7 +95,7 @@ impl TypeKind { /// Return the type that is the result of a binary operator between two /// values of this type - pub fn binop_type(&self, op: &BinaryOperator) -> TypeKind { + pub fn simple_binop_type(&self, op: &BinaryOperator) -> TypeKind { // TODO make some type of mechanism that allows to binop two values of // differing types.. // TODO Return None for arrays later @@ -110,6 +110,20 @@ impl TypeKind { } } + pub fn binop_type<'o>( + lhs: &TypeKind, + rhs: &TypeKind, + binop: &ScopeBinopDef, + ) -> Option<(TypeKind, TypeKind, TypeKind)> { + let lhs_ty = lhs.collapse_into(&binop.operators.0); + let rhs_ty = rhs.collapse_into(&binop.operators.1); + if let (Ok(lhs_ty), Ok(rhs_ty)) = (lhs_ty, rhs_ty) { + Some((lhs_ty, rhs_ty, binop.return_ty.clone())) + } else { + None + } + } + /// Reverse of binop_type, where the given hint is the known required output /// type of the binop, and the output is the hint for the lhs/rhs type. pub fn binop_hint(&self, op: &BinaryOperator) -> Option { diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index dfee8a3..2e3055a 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -370,7 +370,7 @@ pub struct BinopDefinition { pub lhs: (String, TypeKind), pub op: BinaryOperator, pub rhs: (String, TypeKind), - pub return_ty: TypeKind, + pub return_type: TypeKind, pub fn_kind: FunctionDefinitionKind, pub meta: Metadata, } diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index d136df3..64b29e6 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -111,6 +111,10 @@ impl Storage { pub fn get(&self, key: &Key) -> Option<&T> { self.0.get(key) } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } } #[derive(Clone, Default, Debug)] @@ -362,7 +366,7 @@ impl Module { ScopeBinopDef { operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), commutative: true, - return_ty: binop.return_ty.clone(), + return_ty: binop.return_type.clone(), }, ); } diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index f15b555..a44801f 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -203,9 +203,9 @@ impl BinopDefinition { state.ok(res, self.signature()); } - let return_type = self.return_ty.clone().assert_known(typerefs, state)?; + let return_type = self.return_type.clone().assert_known(typerefs, state)?; - state.scope.return_type_hint = Some(self.return_ty.clone()); + state.scope.return_type_hint = Some(self.return_type.clone()); let inferred = self.fn_kind .typecheck(&typerefs, &mut state.inner(), Some(return_type.clone())); @@ -535,7 +535,7 @@ impl Expression { } } - Ok(both_t.binop_type(op)) + Ok(both_t.simple_binop_type(op)) } ExprKind::FunctionCall(function_call) => { let true_function = state diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index ae7132b..c42104e 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -124,11 +124,11 @@ impl BinopDefinition { self.signature(), ); - let ret_ty = self - .fn_kind - .infer_types(state, &scope_hints, Some(self.return_ty.clone()))?; + let ret_ty = + self.fn_kind + .infer_types(state, &scope_hints, Some(self.return_type.clone()))?; if let Some(mut ret_ty) = ret_ty { - ret_ty.narrow(&scope_hints.from_type(&self.return_ty).unwrap()); + ret_ty.narrow(&scope_hints.from_type(&self.return_type).unwrap()); } Ok(()) @@ -312,7 +312,7 @@ impl Expression { let mut lhs_ref = lhs.infer_types(state, type_refs)?; let mut rhs_ref = rhs.infer_types(state, type_refs)?; type_refs - .binop(op, &mut lhs_ref, &mut rhs_ref) + .binop(op, &mut lhs_ref, &mut rhs_ref, &state.scope.binops) .ok_or(ErrorKind::TypesIncompatible( lhs_ref.resolve_deep().unwrap(), rhs_ref.resolve_deep().unwrap(), diff --git a/reid/src/mir/typerefs.rs b/reid/src/mir/typerefs.rs index 2be4ad7..e000e84 100644 --- a/reid/src/mir/typerefs.rs +++ b/reid/src/mir/typerefs.rs @@ -6,7 +6,11 @@ use std::{ use crate::mir::VagueType; -use super::{typecheck::ErrorKind, BinaryOperator, TypeKind}; +use super::{ + pass::{ScopeBinopDef, ScopeBinopKey, Storage}, + typecheck::ErrorKind, + BinaryOperator, TypeKind, +}; #[derive(Clone)] pub struct TypeRef<'scope>( @@ -227,8 +231,31 @@ impl<'outer> ScopeTypeRefs<'outer> { op: &BinaryOperator, lhs: &mut TypeRef<'outer>, rhs: &mut TypeRef<'outer>, + binops: &Storage, ) -> Option> { + for (_, binop) in binops.iter() { + if let Some(ret) = try_binop(lhs, rhs, binop) { + return Some(ret); + } + if binop.commutative { + if let Some(ret) = try_binop(rhs, lhs, binop) { + return Some(ret); + } + } + } let ty = lhs.narrow(rhs)?; - self.from_type(&ty.as_type().binop_type(op)) + self.from_type(&ty.as_type().simple_binop_type(op)) } } + +fn try_binop<'o>( + lhs: &mut TypeRef<'o>, + rhs: &mut TypeRef<'o>, + binop: &ScopeBinopDef, +) -> Option> { + let (lhs_ty, rhs_ty, ret_ty) = + TypeKind::binop_type(&lhs.resolve_deep()?, &rhs.resolve_deep()?, binop)?; + lhs.narrow(&lhs.1.from_type(&lhs_ty).unwrap()).unwrap(); + rhs.narrow(&rhs.1.from_type(&rhs_ty).unwrap()).unwrap(); + lhs.1.from_type(&ret_ty) +}