diff --git a/examples/a.reid b/examples/a.reid new file mode 100644 index 0000000..8b53d10 --- /dev/null +++ b/examples/a.reid @@ -0,0 +1,6 @@ +pub fn abs(f: f32) -> f32 { + if f < 0.0 { + return f * (0.0 - 1.0); + } + return f; +} diff --git a/reid/src/ast/process.rs b/reid/src/ast/process.rs index 071b461..2d8c114 100644 --- a/reid/src/ast/process.rs +++ b/reid/src/ast/process.rs @@ -3,8 +3,8 @@ use std::path::PathBuf; use crate::{ ast::{self}, mir::{ - self, CustomTypeKey, ModuleMap, NamedVariableRef, ReturnKind, SourceModuleId, StmtKind, - StructField, StructType, WhileStatement, + self, CustomTypeKey, ModuleMap, NamedVariableRef, ReturnKind, SourceModuleId, StmtKind, StructField, + StructType, WhileStatement, }, }; @@ -162,9 +162,7 @@ impl ast::Block { *range, ), ast::BlockLevelStatement::Import { _i } => todo!(), - ast::BlockLevelStatement::Expression(e) => { - (StmtKind::Expression(e.process(module_id)), e.1) - } + ast::BlockLevelStatement::Expression(e) => (StmtKind::Expression(e.process(module_id)), e.1), ast::BlockLevelStatement::Return(_, e) => { if let Some(e) = e { (StmtKind::Expression(e.process(module_id)), e.1) @@ -197,11 +195,10 @@ impl ast::Block { counter_range.as_meta(module_id), )), Box::new(mir::Expression( - mir::ExprKind::Literal(mir::Literal::Vague( - mir::VagueLiteral::Number(1), - )), + mir::ExprKind::Literal(mir::Literal::Vague(mir::VagueLiteral::Number(1))), counter_range.as_meta(module_id), )), + mir::TypeKind::Vague(mir::VagueType::Unknown), ), counter_range.as_meta(module_id), ), @@ -220,6 +217,7 @@ impl ast::Block { counter_range.as_meta(module_id), )), Box::new(end.process(module_id)), + mir::TypeKind::Vague(mir::VagueType::Unknown), ), counter_range.as_meta(module_id), ), @@ -292,22 +290,15 @@ impl ast::Expression { binary_operator.mir(), Box::new(lhs.process(module_id)), Box::new(rhs.process(module_id)), + mir::TypeKind::Vague(mir::VagueType::Unknown), ), - ast::ExpressionKind::FunctionCall(fn_call_expr) => { - mir::ExprKind::FunctionCall(mir::FunctionCall { - name: fn_call_expr.0.clone(), - return_type: mir::TypeKind::Vague(mir::VagueType::Unknown), - parameters: fn_call_expr - .1 - .iter() - .map(|e| e.process(module_id)) - .collect(), - meta: fn_call_expr.2.as_meta(module_id), - }) - } - ast::ExpressionKind::BlockExpr(block) => { - mir::ExprKind::Block(block.into_mir(module_id)) - } + ast::ExpressionKind::FunctionCall(fn_call_expr) => mir::ExprKind::FunctionCall(mir::FunctionCall { + name: fn_call_expr.0.clone(), + return_type: mir::TypeKind::Vague(mir::VagueType::Unknown), + parameters: fn_call_expr.1.iter().map(|e| e.process(module_id)).collect(), + meta: fn_call_expr.2.as_meta(module_id), + }), + ast::ExpressionKind::BlockExpr(block) => mir::ExprKind::Block(block.into_mir(module_id)), ast::ExpressionKind::IfExpr(if_expression) => { let cond = if_expression.0.process(module_id); let then_block = if_expression.1.process(module_id); @@ -364,6 +355,7 @@ impl ast::Expression { expr.1.as_meta(module_id), )), Box::new(expr.process(module_id)), + mir::TypeKind::Vague(mir::VagueType::Unknown), ), ast::UnaryOperator::Minus => mir::ExprKind::BinOp( mir::BinaryOperator::Minus, @@ -372,6 +364,7 @@ impl ast::Expression { expr.1.as_meta(module_id), )), Box::new(expr.process(module_id)), + mir::TypeKind::Vague(mir::VagueType::Unknown), ), }, ast::ExpressionKind::CastTo(expression, ty) => mir::ExprKind::CastTo( @@ -457,15 +450,11 @@ impl ast::TypeKind { ast::TypeKind::Array(type_kind, length) => { mir::TypeKind::Array(Box::new(type_kind.clone().into_mir(source_mod)), *length) } - ast::TypeKind::Custom(name) => { - mir::TypeKind::CustomType(CustomTypeKey(name.clone(), source_mod)) - } + ast::TypeKind::Custom(name) => mir::TypeKind::CustomType(CustomTypeKey(name.clone(), source_mod)), ast::TypeKind::Borrow(type_kind, mutable) => { mir::TypeKind::Borrow(Box::new(type_kind.clone().into_mir(source_mod)), *mutable) } - ast::TypeKind::Ptr(type_kind) => { - mir::TypeKind::UserPtr(Box::new(type_kind.clone().into_mir(source_mod))) - } + ast::TypeKind::Ptr(type_kind) => mir::TypeKind::UserPtr(Box::new(type_kind.clone().into_mir(source_mod))), ast::TypeKind::F16 => mir::TypeKind::F16, ast::TypeKind::F32B => mir::TypeKind::F32B, ast::TypeKind::F32 => mir::TypeKind::F32, diff --git a/reid/src/codegen/allocator.rs b/reid/src/codegen/allocator.rs index 839d7b3..441064c 100644 --- a/reid/src/codegen/allocator.rs +++ b/reid/src/codegen/allocator.rs @@ -6,8 +6,7 @@ use reid_lib::{ }; use crate::mir::{ - self, CustomTypeKey, FunctionCall, FunctionDefinitionKind, IfExpression, SourceModuleId, - TypeKind, WhileStatement, + self, CustomTypeKey, FunctionCall, FunctionDefinitionKind, IfExpression, SourceModuleId, TypeKind, WhileStatement, }; #[derive(Debug)] @@ -74,9 +73,7 @@ impl mir::FunctionDefinitionKind { mir::FunctionDefinitionKind::Intrinsic(_) => {} } - Allocator { - allocations: allocated, - } + Allocator { allocations: allocated } } } @@ -126,9 +123,7 @@ impl mir::Statement { crate::mir::StmtKind::Expression(expression) => { allocated.extend(expression.allocate(scope)); } - crate::mir::StmtKind::While(WhileStatement { - condition, block, .. - }) => { + crate::mir::StmtKind::While(WhileStatement { condition, block, .. }) => { allocated.extend(condition.allocate(scope)); allocated.extend(block.allocate(scope)); } @@ -162,7 +157,7 @@ impl mir::Expression { } } crate::mir::ExprKind::Literal(_) => {} - crate::mir::ExprKind::BinOp(_, lhs, rhs) => { + crate::mir::ExprKind::BinOp(_, lhs, rhs, _) => { allocated.extend(lhs.allocate(scope)); allocated.extend(rhs.allocate(scope)); } diff --git a/reid/src/codegen/mod.rs b/reid/src/codegen/mod.rs index 1326944..b94366a 100644 --- a/reid/src/codegen/mod.rs +++ b/reid/src/codegen/mod.rs @@ -5,20 +5,18 @@ use intrinsics::*; use reid_lib::{ compile::CompiledModule, debug_information::{ - DebugFileData, DebugLocalVariable, DebugLocation, DebugMetadata, DebugRecordKind, - DebugSubprogramData, DebugSubprogramOptionals, DebugSubprogramType, DebugTypeData, - DwarfFlags, InstructionDebugRecordData, + DebugFileData, DebugLocalVariable, DebugLocation, DebugMetadata, DebugRecordKind, DebugSubprogramData, + DebugSubprogramOptionals, DebugSubprogramType, DebugTypeData, DwarfFlags, InstructionDebugRecordData, }, - CmpPredicate, ConstValue, Context, CustomTypeKind, Function, FunctionFlags, Instr, Module, - NamedStruct, TerminatorKind as Term, Type, + CmpPredicate, ConstValue, Context, CustomTypeKind, Function, FunctionFlags, Instr, Module, NamedStruct, + TerminatorKind as Term, Type, }; use scope::*; use crate::{ mir::{ - self, implement::TypeCategory, pass::ScopeBinopKey, CustomTypeKey, FunctionDefinitionKind, - NamedVariableRef, SourceModuleId, StructField, StructType, TypeDefinitionKind, TypeKind, - WhileStatement, + self, implement::TypeCategory, pass::ScopeBinopKey, CustomTypeKey, FunctionDefinitionKind, NamedVariableRef, + SourceModuleId, StructField, StructType, TypeDefinitionKind, TypeKind, WhileStatement, }, util::try_all, }; @@ -83,9 +81,7 @@ struct State { impl State { /// Sets should load, returning a new state fn load(self, should: bool) -> State { - State { - should_load: should, - } + State { should_load: should } } } @@ -235,10 +231,7 @@ impl mir::Module { let ir_function = module.function( &binop_fn_name, binop.return_type.get_type(&type_values), - vec![ - binop.lhs.1.get_type(&type_values), - binop.rhs.1.get_type(&type_values), - ], + vec![binop.lhs.1.get_type(&type_values), binop.rhs.1.get_type(&type_values)], FunctionFlags { inline: true, ..Default::default() @@ -287,9 +280,7 @@ impl mir::Module { &binop.return_type, &ir_function, match &binop.fn_kind { - FunctionDefinitionKind::Local(_, meta) => { - meta.into_debug(tokens, compile_unit) - } + FunctionDefinitionKind::Local(_, meta) => meta.into_debug(tokens, compile_unit), FunctionDefinitionKind::Extern(_) => None, FunctionDefinitionKind::Intrinsic(_) => None, }, @@ -352,9 +343,7 @@ impl mir::Module { &mir_function.return_type, &function, match &mir_function.kind { - FunctionDefinitionKind::Local(..) => { - mir_function.signature().into_debug(tokens, compile_unit) - } + FunctionDefinitionKind::Local(..) => mir_function.signature().into_debug(tokens, compile_unit), FunctionDefinitionKind::Extern(_) => None, FunctionDefinitionKind::Intrinsic(_) => None, }, @@ -386,13 +375,10 @@ impl FunctionDefinitionKind { let fn_param_ty = &return_type.get_debug_type(&debug, scope); - let debug_ty = - debug - .info - .debug_type(DebugTypeData::Subprogram(DebugSubprogramType { - parameters: vec![*fn_param_ty], - flags: DwarfFlags, - })); + let debug_ty = debug.info.debug_type(DebugTypeData::Subprogram(DebugSubprogramType { + parameters: vec![*fn_param_ty], + flags: DwarfFlags, + })); let subprogram = debug.info.subprogram(DebugSubprogramData { name: name.clone(), @@ -477,9 +463,7 @@ impl FunctionDefinitionKind { } if let Some(debug) = &scope.debug { - if let Some(location) = - &block.return_meta().into_debug(scope.tokens, debug.scope) - { + if let Some(location) = &block.return_meta().into_debug(scope.tokens, debug.scope) { let location = debug.info.location(&debug.scope, *location); scope.block.set_terminator_location(location).unwrap(); } @@ -536,11 +520,7 @@ impl mir::Block { } impl mir::Statement { - fn codegen<'ctx, 'a>( - &self, - scope: &mut Scope<'ctx, 'a>, - state: &State, - ) -> Result, ErrorKind> { + fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>, state: &State) -> Result, ErrorKind> { let location = scope.debug.clone().map(|d| { let location = self.1.into_debug(scope.tokens, d.scope).unwrap(); d.info.location(&d.scope, location) @@ -557,10 +537,7 @@ impl mir::Statement { let store = scope .block - .build_named( - format!("{}.store", name), - Instr::Store(alloca, value.instr()), - ) + .build_named(format!("{}.store", name), Instr::Store(alloca, value.instr())) .unwrap() .maybe_location(&mut scope.block, location); @@ -632,17 +609,12 @@ impl mir::Statement { } mir::StmtKind::Import(_) => todo!(), mir::StmtKind::Expression(expression) => expression.codegen(scope, state), - mir::StmtKind::While(WhileStatement { - condition, block, .. - }) => { + mir::StmtKind::While(WhileStatement { condition, block, .. }) => { let condition_block = scope.function.block("while.cond"); let condition_true_block = scope.function.block("while.body"); let condition_failed_block = scope.function.block("while.end"); - scope - .block - .terminate(Term::Br(condition_block.value())) - .unwrap(); + scope.block.terminate(Term::Br(condition_block.value())).unwrap(); let mut condition_scope = scope.with_block(condition_block); let condition_res = condition.codegen(&mut condition_scope, state)?.unwrap(); let true_instr = condition_scope @@ -651,11 +623,7 @@ impl mir::Statement { .unwrap(); let check = condition_scope .block - .build(Instr::ICmp( - CmpPredicate::EQ, - condition_res.instr(), - true_instr, - )) + .build(Instr::ICmp(CmpPredicate::EQ, condition_res.instr(), true_instr)) .unwrap(); condition_scope @@ -685,16 +653,13 @@ impl mir::Statement { } impl mir::Expression { - fn codegen<'ctx, 'a>( - &self, - scope: &mut Scope<'ctx, 'a>, - state: &State, - ) -> Result, ErrorKind> { + fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>, state: &State) -> Result, ErrorKind> { let location = if let Some(debug) = &scope.debug { - Some(debug.info.location( - &debug.scope, - self.1.into_debug(scope.tokens, debug.scope).unwrap(), - )) + Some( + debug + .info + .location(&debug.scope, self.1.into_debug(scope.tokens, debug.scope).unwrap()), + ) } else { None }; @@ -715,10 +680,7 @@ impl mir::Expression { .block .build_named( format!("{}", varref.1), - Instr::Load( - v.0.instr(), - inner.get_type(scope.type_values), - ), + Instr::Load(v.0.instr(), inner.get_type(scope.type_values)), ) .unwrap(), ), @@ -736,13 +698,9 @@ impl mir::Expression { StackValueKind::Literal(lit.as_const(&mut scope.block)), lit.as_type(), )), - mir::ExprKind::BinOp(binop, lhs_exp, rhs_exp) => { - let lhs_val = lhs_exp - .codegen(scope, state)? - .expect("lhs has no return value"); - let rhs_val = rhs_exp - .codegen(scope, state)? - .expect("rhs has no return value"); + mir::ExprKind::BinOp(binop, lhs_exp, rhs_exp, return_ty) => { + let lhs_val = lhs_exp.codegen(scope, state)?.expect("lhs has no return value"); + let rhs_val = rhs_exp.codegen(scope, state)?.expect("rhs has no return value"); let lhs = lhs_val.instr(); let rhs = rhs_val.instr(); @@ -755,15 +713,8 @@ impl mir::Expression { let a = operation.codegen(&lhs_val, &rhs_val, scope)?; Some(a) } else { - let lhs_type = lhs_exp - .return_type(&Default::default(), scope.module_id) - .unwrap() - .1; - let instr = match ( - binop, - lhs_type.signed(), - lhs_type.category() == TypeCategory::Real, - ) { + let lhs_type = lhs_exp.return_type(&Default::default(), scope.module_id).unwrap().1; + let instr = match (binop, lhs_type.signed(), lhs_type.category() == TypeCategory::Real) { (mir::BinaryOperator::Add, _, false) => Instr::Add(lhs, rhs), (mir::BinaryOperator::Add, _, true) => Instr::FAdd(lhs, rhs), (mir::BinaryOperator::Minus, _, false) => Instr::Sub(lhs, rhs), @@ -771,12 +722,8 @@ impl mir::Expression { (mir::BinaryOperator::Mult, _, false) => Instr::Mul(lhs, rhs), (mir::BinaryOperator::Mult, _, true) => Instr::FMul(lhs, rhs), (mir::BinaryOperator::And, _, _) => Instr::And(lhs, rhs), - (mir::BinaryOperator::Cmp(i), _, false) => { - Instr::ICmp(i.predicate(), lhs, rhs) - } - (mir::BinaryOperator::Cmp(i), _, true) => { - Instr::FCmp(i.predicate(), lhs, rhs) - } + (mir::BinaryOperator::Cmp(i), _, false) => Instr::ICmp(i.predicate(), lhs, rhs), + (mir::BinaryOperator::Cmp(i), _, true) => Instr::FCmp(i.predicate(), lhs, rhs), (mir::BinaryOperator::Div, false, false) => Instr::UDiv(lhs, rhs), (mir::BinaryOperator::Div, true, false) => Instr::SDiv(lhs, rhs), (mir::BinaryOperator::Div, _, true) => Instr::FDiv(lhs, rhs), @@ -828,15 +775,12 @@ impl mir::Expression { .unwrap() .maybe_location(&mut scope.block, location), ), - lhs_type, + return_ty.clone(), )) } } mir::ExprKind::FunctionCall(call) => { - let ret_type_kind = call - .return_type - .known() - .expect("function return type unknown"); + let ret_type_kind = call.return_type.known().expect("function return type unknown"); let ret_type = ret_type_kind.get_type(scope.type_values); @@ -852,17 +796,11 @@ impl mir::Expression { .collect::>(); let param_instrs = params.iter().map(|e| e.instr()).collect(); - let callee = scope - .functions - .get(&call.name) - .expect("function not found!"); + let callee = scope.functions.get(&call.name).expect("function not found!"); let val = scope .block - .build_named( - call.name.clone(), - Instr::FunctionCall(callee.value(), param_instrs), - ) + .build_named(call.name.clone(), Instr::FunctionCall(callee.value(), param_instrs)) .unwrap(); if let Some(debug) = &scope.debug { @@ -939,10 +877,7 @@ impl mir::Expression { let (ptr, contained_ty) = if let TypeKind::UserPtr(further_inner) = *inner.clone() { let loaded = scope .block - .build_named( - "load", - Instr::Load(kind.instr(), inner.get_type(scope.type_values)), - ) + .build_named("load", Instr::Load(kind.instr(), inner.get_type(scope.type_values))) .unwrap(); ( scope @@ -960,10 +895,7 @@ impl mir::Expression { ( scope .block - .build_named( - format!("array.gep"), - Instr::GetElemPtr(kind.instr(), vec![idx]), - ) + .build_named(format!("array.gep"), Instr::GetElemPtr(kind.instr(), vec![idx])) .unwrap() .maybe_location(&mut scope.block, location), val_t.clone(), @@ -980,10 +912,7 @@ impl mir::Expression { ( scope .block - .build_named( - format!("array.gep"), - Instr::GetElemPtr(kind.instr(), vec![first, idx]), - ) + .build_named(format!("array.gep"), Instr::GetElemPtr(kind.instr(), vec![first, idx])) .unwrap() .maybe_location(&mut scope.block, location), val_t.clone(), @@ -995,10 +924,7 @@ impl mir::Expression { kind.derive( scope .block - .build_named( - "array.load", - Instr::Load(ptr, contained_ty.get_type(scope.type_values)), - ) + .build_named("array.load", Instr::Load(ptr, contained_ty.get_type(scope.type_values))) .unwrap() .maybe_location(&mut scope.block, location), ), @@ -1012,21 +938,14 @@ impl mir::Expression { } } mir::ExprKind::Array(expressions) => { - let stack_value_list: Vec<_> = try_all( - expressions - .iter() - .map(|e| e.codegen(scope, state)) - .collect::>(), - ) - .map_err(|e| e.first().cloned().unwrap())? - .into_iter() - .map(|v| v.unwrap()) - .collect(); + let stack_value_list: Vec<_> = + try_all(expressions.iter().map(|e| e.codegen(scope, state)).collect::>()) + .map_err(|e| e.first().cloned().unwrap())? + .into_iter() + .map(|v| v.unwrap()) + .collect(); - let instr_list = stack_value_list - .iter() - .map(|s| s.instr()) - .collect::>(); + let instr_list = stack_value_list.iter().map(|s| s.instr()).collect::>(); let elem_ty_kind = stack_value_list .iter() @@ -1053,10 +972,7 @@ impl mir::Expression { let index_expr = scope .block - .build_named( - index.to_string(), - Instr::Constant(ConstValue::U32(index as u32)), - ) + .build_named(index.to_string(), Instr::Constant(ConstValue::U32(index as u32))) .unwrap(); let first = scope .block @@ -1094,8 +1010,7 @@ impl mir::Expression { let TypeKind::CustomType(key) = *inner.clone() else { panic!("tried accessing non-custom-type"); }; - let TypeDefinitionKind::Struct(struct_ty) = - scope.get_typedef(&key).unwrap().kind.clone(); + let TypeDefinitionKind::Struct(struct_ty) = scope.get_typedef(&key).unwrap().kind.clone(); let idx = struct_ty.find_index(field).unwrap(); let gep_n = format!("{}.{}.gep", key.0, field); @@ -1103,10 +1018,7 @@ impl mir::Expression { let value = scope .block - .build_named( - gep_n, - Instr::GetStructElemPtr(struct_val.instr(), idx as u32), - ) + .build_named(gep_n, Instr::GetStructElemPtr(struct_val.instr(), idx as u32)) .unwrap(); // value.maybe_location(&mut scope.block, location); @@ -1116,10 +1028,7 @@ impl mir::Expression { struct_val.0.derive( scope .block - .build_named( - load_n, - Instr::Load(value, type_kind.get_type(scope.type_values)), - ) + .build_named(load_n, Instr::Load(value, type_kind.get_type(scope.type_values))) .unwrap(), ), struct_ty.get_field_ty(&field).unwrap().clone(), @@ -1127,9 +1036,7 @@ impl mir::Expression { } else { Some(StackValue( struct_val.0.derive(value), - TypeKind::CodegenPtr(Box::new( - struct_ty.get_field_ty(&field).unwrap().clone(), - )), + TypeKind::CodegenPtr(Box::new(struct_ty.get_field_ty(&field).unwrap().clone())), )) } } @@ -1222,10 +1129,7 @@ impl mir::Expression { .block .build_named( format!("{}.deref.inner", varref.1), - Instr::Load( - var_ptr_instr, - inner.get_type(scope.type_values), - ), + Instr::Load(var_ptr_instr, inner.get_type(scope.type_values)), ) .unwrap(), ), @@ -1253,17 +1157,14 @@ impl mir::Expression { Some(val) } else { match (&val.1, type_kind) { - (TypeKind::CodegenPtr(inner), TypeKind::UserPtr(_)) => match *inner.clone() - { + (TypeKind::CodegenPtr(inner), TypeKind::UserPtr(_)) => match *inner.clone() { TypeKind::UserPtr(_) => Some(StackValue( val.0.derive( scope .block .build(Instr::BitCast( val.instr(), - Type::Ptr(Box::new( - type_kind.get_type(scope.type_values), - )), + Type::Ptr(Box::new(type_kind.get_type(scope.type_values))), )) .unwrap(), ), @@ -1278,10 +1179,7 @@ impl mir::Expression { val.0.derive( scope .block - .build(Instr::BitCast( - val.instr(), - type_kind.get_type(scope.type_values), - )) + .build(Instr::BitCast(val.instr(), type_kind.get_type(scope.type_values))) .unwrap(), ), type_kind.clone(), @@ -1290,10 +1188,7 @@ impl mir::Expression { let cast_instr = val .1 .get_type(scope.type_values) - .cast_instruction( - val.instr(), - &type_kind.get_type(scope.type_values), - ) + .cast_instruction(val.instr(), &type_kind.get_type(scope.type_values)) .unwrap(); Some(StackValue( @@ -1313,11 +1208,7 @@ impl mir::Expression { } impl mir::IfExpression { - fn codegen<'ctx, 'a>( - &self, - scope: &mut Scope<'ctx, 'a>, - state: &State, - ) -> Result, ErrorKind> { + fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>, state: &State) -> Result, ErrorKind> { let condition = self.0.codegen(scope, state)?.unwrap(); // Create blocks @@ -1389,10 +1280,7 @@ impl mir::IfExpression { incoming.extend(else_res.clone()); let instr = scope .block - .build_named( - "phi", - Instr::Phi(incoming.iter().map(|i| i.instr()).collect()), - ) + .build_named("phi", Instr::Phi(incoming.iter().map(|i| i.instr()).collect())) .unwrap(); use StackValueKind::*; diff --git a/reid/src/lib.rs b/reid/src/lib.rs index 6b445bb..a573898 100644 --- a/reid/src/lib.rs +++ b/reid/src/lib.rs @@ -182,7 +182,7 @@ pub fn perform_all_passes<'map>( #[cfg(debug_assertions)] println!("{}", &refs); #[cfg(debug_assertions)] - println!("{}", &context); + println!("{:#}", &context); #[cfg(debug_assertions)] dbg!(&state); diff --git a/reid/src/mir/fmt.rs b/reid/src/mir/fmt.rs index ce8fe9a..a761a51 100644 --- a/reid/src/mir/fmt.rs +++ b/reid/src/mir/fmt.rs @@ -152,24 +152,14 @@ impl Display for StmtKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { StmtKind::Let(var, mutable, block) => { - write!( - f, - "let{} {} = {}", - if *mutable { " mut" } else { "" }, - var, - block - ) + write!(f, "let{} {} = {}", if *mutable { " mut" } else { "" }, var, block) } StmtKind::Set(var, expr) => write!(f, "{} = {}", var, expr), StmtKind::Import(n) => write!(f, "import {}", n), StmtKind::Expression(exp) => Display::fmt(exp, f), StmtKind::While(while_statement) => { - write!( - f, - "while {} {}", - while_statement.condition, while_statement.block, - ) + write!(f, "while {} {}", while_statement.condition, while_statement.block,) } } } @@ -189,7 +179,12 @@ impl Display for ExprKind { match self { ExprKind::Variable(var) => Display::fmt(var, f), ExprKind::Literal(lit) => Display::fmt(lit, f), - ExprKind::BinOp(op, lhs, rhs) => write!(f, "{} {} {}", lhs, op, rhs), + ExprKind::BinOp(op, lhs, rhs, ty) => { + write!(f, "{} {} {} (= ", lhs, op, rhs)?; + Debug::fmt(ty, f)?; + f.write_char(')')?; + Ok(()) + } ExprKind::FunctionCall(fc) => Display::fmt(fc, f), ExprKind::If(if_exp) => Display::fmt(&if_exp, f), ExprKind::Block(block) => Display::fmt(block, f), diff --git a/reid/src/mir/implement.rs b/reid/src/mir/implement.rs index 0168d22..d685dd9 100644 --- a/reid/src/mir/implement.rs +++ b/reid/src/mir/implement.rs @@ -359,10 +359,7 @@ impl Statement { expr.return_type(refs, mod_id)?, Err(ReturnTypeOther::Let(var.2 + expr.1)), ), - Set(lhs, rhs) => if_hard( - rhs.return_type(refs, mod_id)?, - Err(ReturnTypeOther::Set(lhs.1 + rhs.1)), - ), + Set(lhs, rhs) => if_hard(rhs.return_type(refs, mod_id)?, Err(ReturnTypeOther::Set(lhs.1 + rhs.1))), Import(_) => todo!(), Expression(expression) => expression.return_type(refs, mod_id), While(_) => Err(ReturnTypeOther::Loop), @@ -390,11 +387,14 @@ impl Expression { match &self.0 { Literal(lit) => Ok((ReturnKind::Soft, lit.as_type())), Variable(var) => var.return_type(), - BinOp(_, then_e, else_e) => { + BinOp(_, then_e, else_e, return_ty) => { let then_r = then_e.return_type(refs, mod_id)?; let else_r = else_e.return_type(refs, mod_id)?; - Ok(pick_return(then_r, else_r)) + Ok(match (then_r.0, else_r.0) { + (ReturnKind::Hard, ReturnKind::Hard) => (ReturnKind::Hard, return_ty.clone()), + _ => (ReturnKind::Soft, return_ty.clone()), + }) } Block(block) => block.return_type(refs, mod_id), FunctionCall(fcall) => fcall.return_type(), @@ -457,7 +457,7 @@ impl Expression { ExprKind::Array(_) => None, ExprKind::Struct(_, _) => None, ExprKind::Literal(_) => None, - ExprKind::BinOp(_, _, _) => None, + ExprKind::BinOp(_, _, _, _) => None, ExprKind::FunctionCall(_) => None, ExprKind::If(_) => None, ExprKind::CastTo(expression, _) => expression.backing_var(), @@ -472,7 +472,7 @@ impl Expression { ExprKind::Array(_) => None, ExprKind::Struct(..) => None, ExprKind::Literal(literal) => literal.num_value(), - ExprKind::BinOp(op, lhs, rhs) => match op { + ExprKind::BinOp(op, lhs, rhs, _) => match op { BinaryOperator::Add => maybe(lhs.num_value()?, rhs.num_value()?, |a, b| a + b), BinaryOperator::Minus => maybe(lhs.num_value()?, rhs.num_value()?, |a, b| a - b), BinaryOperator::Mult => maybe(lhs.num_value()?, rhs.num_value()?, |a, b| a * b), @@ -604,9 +604,7 @@ pub enum EqualsIssue { impl FunctionDefinition { pub fn equals_as_imported(&self, other: &FunctionDefinition) -> Result<(), EqualsIssue> { match &self.kind { - FunctionDefinitionKind::Local(_, metadata) => { - Err(EqualsIssue::ExistsLocally(*metadata)) - } + FunctionDefinitionKind::Local(_, metadata) => Err(EqualsIssue::ExistsLocally(*metadata)), FunctionDefinitionKind::Extern(imported) => { if *imported { Err(EqualsIssue::ConflictWithImport(self.name.clone())) @@ -618,10 +616,7 @@ impl FunctionDefinition { { Ok(()) } else { - Err(EqualsIssue::AlreadyExtern( - self.name.clone(), - self.signature(), - )) + Err(EqualsIssue::AlreadyExtern(self.name.clone(), self.signature())) } } } diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index d43772a..2871072 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -256,7 +256,7 @@ pub enum ExprKind { Array(Vec), Struct(String, Vec<(String, Expression)>), Literal(Literal), - BinOp(BinaryOperator, Box, Box), + BinOp(BinaryOperator, Box, Box, TypeKind), FunctionCall(FunctionCall), If(IfExpression), Block(Block), @@ -270,11 +270,7 @@ pub struct Expression(pub ExprKind, pub Metadata); /// Condition, Then, Else #[derive(Debug)] -pub struct IfExpression( - pub Box, - pub Box, - pub Box>, -); +pub struct IfExpression(pub Box, pub Box, pub Box>); #[derive(Debug)] pub struct FunctionCall { diff --git a/reid/src/mir/typecheck/mod.rs b/reid/src/mir/typecheck/mod.rs index ea46dc4..1df3323 100644 --- a/reid/src/mir/typecheck/mod.rs +++ b/reid/src/mir/typecheck/mod.rs @@ -85,50 +85,41 @@ impl TypeKind { } match (self, other) { - (TypeKind::Vague(Vague::Integer), other) | (other, TypeKind::Vague(Vague::Integer)) => { - match other { - TypeKind::Vague(Vague::Unknown) => Ok(TypeKind::Vague(Vague::Integer)), - TypeKind::Vague(Vague::Integer) => Ok(TypeKind::Vague(Vague::Integer)), - TypeKind::I8 - | TypeKind::I16 - | TypeKind::I32 - | TypeKind::I64 - | TypeKind::I128 - | TypeKind::U8 - | TypeKind::U16 - | TypeKind::U32 - | TypeKind::U64 - | TypeKind::U128 => Ok(other.clone()), - _ => Err(ErrorKind::TypesIncompatible(self.clone(), other.clone())), - } - } - (TypeKind::Vague(Vague::Decimal), other) | (other, TypeKind::Vague(Vague::Decimal)) => { - match other { - TypeKind::Vague(Vague::Unknown) => Ok(TypeKind::Vague(Vague::Decimal)), - TypeKind::Vague(Vague::Decimal) => Ok(TypeKind::Vague(Vague::Decimal)), - TypeKind::F16 - | TypeKind::F32B - | TypeKind::F32 - | TypeKind::F64 - | TypeKind::F80 - | TypeKind::F128 - | TypeKind::F128PPC => Ok(other.clone()), - _ => Err(ErrorKind::TypesIncompatible(self.clone(), other.clone())), - } - } - (TypeKind::Vague(Vague::Unknown), other) | (other, TypeKind::Vague(Vague::Unknown)) => { - Ok(other.clone()) - } + (TypeKind::Vague(Vague::Integer), other) | (other, TypeKind::Vague(Vague::Integer)) => match other { + TypeKind::Vague(Vague::Unknown) => Ok(TypeKind::Vague(Vague::Integer)), + TypeKind::Vague(Vague::Integer) => Ok(TypeKind::Vague(Vague::Integer)), + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::U8 + | TypeKind::U16 + | TypeKind::U32 + | TypeKind::U64 + | TypeKind::U128 => Ok(other.clone()), + _ => Err(ErrorKind::TypesIncompatible(self.clone(), other.clone())), + }, + (TypeKind::Vague(Vague::Decimal), other) | (other, TypeKind::Vague(Vague::Decimal)) => match other { + TypeKind::Vague(Vague::Unknown) => Ok(TypeKind::Vague(Vague::Decimal)), + TypeKind::Vague(Vague::Decimal) => Ok(TypeKind::Vague(Vague::Decimal)), + TypeKind::F16 + | TypeKind::F32B + | TypeKind::F32 + | TypeKind::F64 + | TypeKind::F80 + | TypeKind::F128 + | TypeKind::F128PPC => Ok(other.clone()), + _ => Err(ErrorKind::TypesIncompatible(self.clone(), other.clone())), + }, + (TypeKind::Vague(Vague::Unknown), other) | (other, TypeKind::Vague(Vague::Unknown)) => Ok(other.clone()), (TypeKind::Borrow(val1, mut1), TypeKind::Borrow(val2, mut2)) => { // Extracted to give priority for other collapse-error let collapsed = val1.narrow_into(val2)?; if mut1 == mut2 { Ok(TypeKind::Borrow(Box::new(collapsed), *mut1 && *mut2)) } else { - Err(ErrorKind::TypesDifferMutability( - self.clone(), - other.clone(), - )) + Err(ErrorKind::TypesDifferMutability(self.clone(), other.clone())) } } (TypeKind::UserPtr(val1), TypeKind::UserPtr(val2)) => { @@ -146,36 +137,30 @@ impl TypeKind { (TypeKind::Vague(Vague::Unknown), other) | (other, TypeKind::Vague(Vague::Unknown)) => { TypeKind::Vague(VagueType::Unknown) } - (TypeKind::Vague(Vague::Integer), other) | (other, TypeKind::Vague(Vague::Integer)) => { - match other { - TypeKind::I8 - | TypeKind::I16 - | TypeKind::I32 - | TypeKind::I64 - | TypeKind::I128 - | TypeKind::U8 - | TypeKind::U16 - | TypeKind::U32 - | TypeKind::U64 - | TypeKind::U128 => TypeKind::Vague(VagueType::Integer), - _ => TypeKind::Vague(VagueType::Unknown), - } - } - (TypeKind::Vague(Vague::Decimal), other) | (other, TypeKind::Vague(Vague::Decimal)) => { - match other { - TypeKind::F16 - | TypeKind::F32B - | TypeKind::F32 - | TypeKind::F64 - | TypeKind::F80 - | TypeKind::F128 - | TypeKind::F128PPC => TypeKind::Vague(VagueType::Decimal), - _ => TypeKind::Vague(VagueType::Unknown), - } - } - (TypeKind::UserPtr(val1), TypeKind::UserPtr(val2)) => { - TypeKind::UserPtr(Box::new(val1.widen_into(val2))) - } + (TypeKind::Vague(Vague::Integer), other) | (other, TypeKind::Vague(Vague::Integer)) => match other { + TypeKind::I8 + | TypeKind::I16 + | TypeKind::I32 + | TypeKind::I64 + | TypeKind::I128 + | TypeKind::U8 + | TypeKind::U16 + | TypeKind::U32 + | TypeKind::U64 + | TypeKind::U128 => TypeKind::Vague(VagueType::Integer), + _ => TypeKind::Vague(VagueType::Unknown), + }, + (TypeKind::Vague(Vague::Decimal), other) | (other, TypeKind::Vague(Vague::Decimal)) => match other { + TypeKind::F16 + | TypeKind::F32B + | TypeKind::F32 + | TypeKind::F64 + | TypeKind::F80 + | TypeKind::F128 + | TypeKind::F128PPC => TypeKind::Vague(VagueType::Decimal), + _ => TypeKind::Vague(VagueType::Unknown), + }, + (TypeKind::UserPtr(val1), TypeKind::UserPtr(val2)) => TypeKind::UserPtr(Box::new(val1.widen_into(val2))), (TypeKind::CodegenPtr(val1), TypeKind::CodegenPtr(val2)) => { TypeKind::CodegenPtr(Box::new(val1.widen_into(val2))) } @@ -245,16 +230,10 @@ impl TypeKind { Vague::TypeRef(_) => panic!("Hinted default!"), VagueType::Decimal => TypeKind::F32, }, - TypeKind::Array(type_kind, len) => { - TypeKind::Array(Box::new(type_kind.or_default()?), *len) - } - TypeKind::Borrow(type_kind, mutable) => { - TypeKind::Borrow(Box::new(type_kind.or_default()?), *mutable) - } + TypeKind::Array(type_kind, len) => TypeKind::Array(Box::new(type_kind.or_default()?), *len), + TypeKind::Borrow(type_kind, mutable) => TypeKind::Borrow(Box::new(type_kind.or_default()?), *mutable), TypeKind::UserPtr(type_kind) => TypeKind::UserPtr(Box::new(type_kind.or_default()?)), - TypeKind::CodegenPtr(type_kind) => { - TypeKind::CodegenPtr(Box::new(type_kind.or_default()?)) - } + TypeKind::CodegenPtr(type_kind) => TypeKind::CodegenPtr(Box::new(type_kind.or_default()?)), _ => self.clone(), }) } @@ -270,37 +249,29 @@ impl TypeKind { let resolved = self.resolve_weak(refs); match resolved { TypeKind::Array(t, len) => TypeKind::Array(Box::new(t.resolve_ref(refs)), len), - TypeKind::Borrow(inner, mutable) => { - TypeKind::Borrow(Box::new(inner.resolve_ref(refs)), mutable) - } + TypeKind::Borrow(inner, mutable) => TypeKind::Borrow(Box::new(inner.resolve_ref(refs)), mutable), _ => resolved, } } - pub(super) fn assert_known( - &self, - refs: &TypeRefs, - state: &TypecheckPassState, - ) -> Result { + pub(super) fn assert_known(&self, refs: &TypeRefs, state: &TypecheckPassState) -> Result { self.is_known(refs, state).map(|_| self.clone()) } - pub(super) fn is_known( - &self, - refs: &TypeRefs, - state: &TypecheckPassState, - ) -> Result<(), ErrorKind> { + pub(super) fn is_known(&self, refs: &TypeRefs, state: &TypecheckPassState) -> Result<(), ErrorKind> { match &self { TypeKind::Array(type_kind, _) => type_kind.as_ref().is_known(refs, state), - TypeKind::CustomType(custom_type_key) => state - .scope - .types - .get(custom_type_key) - .map(|_| ()) - .ok_or(ErrorKind::NoSuchType( - custom_type_key.0.clone(), - state.module_id.unwrap(), - )), + TypeKind::CustomType(custom_type_key) => { + state + .scope + .types + .get(custom_type_key) + .map(|_| ()) + .ok_or(ErrorKind::NoSuchType( + custom_type_key.0.clone(), + state.module_id.unwrap(), + )) + } TypeKind::Borrow(type_kind, _) => type_kind.is_known(refs, state), TypeKind::UserPtr(type_kind) => type_kind.is_known(refs, state), TypeKind::CodegenPtr(type_kind) => type_kind.is_known(refs, state), diff --git a/reid/src/mir/typecheck/typecheck.rs b/reid/src/mir/typecheck/typecheck.rs index a9efa5b..9106acf 100644 --- a/reid/src/mir/typecheck/typecheck.rs +++ b/reid/src/mir/typecheck/typecheck.rs @@ -410,28 +410,34 @@ impl Expression { *literal = literal.clone().try_coerce(hint_t.cloned())?; Ok(literal.as_type()) } - ExprKind::BinOp(op, lhs, rhs) => { + ExprKind::BinOp(op, lhs, rhs, ret_ty) => { // First find unfiltered parameters to binop let lhs_res = lhs.typecheck(state, &typerefs, None); let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1); let rhs_res = rhs.typecheck(state, &typerefs, None); let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1); + let expected_return_ty = ret_ty.resolve_ref(typerefs); - if let Some(binop) = typerefs - .binop_types - .find(&pass::ScopeBinopKey { - params: (lhs_type.clone(), rhs_type.clone()), - operator: *op, - }) + let binops = typerefs.binop_types.filter(&pass::ScopeBinopKey { + params: (lhs_type.clone(), rhs_type.clone()), + operator: *op, + }); + if let Some(binop) = binops + .iter() + .filter(|f| f.1.return_ty == expected_return_ty) .map(|v| (v.1.clone())) + .next() { lhs.typecheck(state, &typerefs, Some(&binop.hands.0))?; rhs.typecheck(state, &typerefs, Some(&binop.hands.1))?; - Ok(binop.narrow(&lhs_type, &rhs_type).unwrap().2) + *ret_ty = binop.narrow(&lhs_type, &rhs_type).unwrap().2; + Ok(ret_ty.clone()) } else { + dbg!(&binops); dbg!(&op, &lhs, &rhs); dbg!(&lhs_type); dbg!(&rhs_type); + dbg!(&expected_return_ty); panic!() } } diff --git a/reid/src/mir/typecheck/typeinference.rs b/reid/src/mir/typecheck/typeinference.rs index efc61a0..9ed66bc 100644 --- a/reid/src/mir/typecheck/typeinference.rs +++ b/reid/src/mir/typecheck/typeinference.rs @@ -291,7 +291,7 @@ impl Expression { type_ref } ExprKind::Literal(literal) => Ok(type_refs.from_type(&literal.as_type()).unwrap()), - ExprKind::BinOp(op, lhs, rhs) => { + ExprKind::BinOp(op, lhs, rhs, return_ty) => { // Infer LHS and RHS, and return binop type let mut lhs_ref = lhs.infer_types(state, type_refs)?; let mut rhs_ref = rhs.infer_types(state, type_refs)?; @@ -306,11 +306,17 @@ impl Expression { widened_lhs = widened_lhs.widen_into(&binop.hands.0); widened_rhs = widened_rhs.widen_into(&binop.hands.1); } + let binop_res = type_refs.from_binop(*op, &lhs_ref, &rhs_ref); lhs_ref.narrow(&type_refs.from_type(&widened_lhs).unwrap()); rhs_ref.narrow(&type_refs.from_type(&widened_rhs).unwrap()); - Ok(type_refs.from_binop(*op, &lhs_ref, &rhs_ref)) + *return_ty = binop_res.as_type(); + Ok(binop_res) } else { - panic!(); + Err(ErrorKind::InvalidBinop( + *op, + lhs_ref.resolve_deep().unwrap(), + rhs_ref.resolve_deep().unwrap(), + )) } } ExprKind::FunctionCall(function_call) => { diff --git a/reid/src/mir/typecheck/typerefs.rs b/reid/src/mir/typecheck/typerefs.rs index 9cc3f8c..c870399 100644 --- a/reid/src/mir/typecheck/typerefs.rs +++ b/reid/src/mir/typecheck/typerefs.rs @@ -71,7 +71,10 @@ impl TypeRefKind { .binop_types .iter() .filter(|b| b.1.operator == *op) - .map(|b| b.1.narrow(&lhs, &rhs).map(|b| b.2)) + .map(|b| { + b.1.narrow(&lhs.resolve_ref(types), &rhs.resolve_ref(types)) + .map(|b| b.2) + }) .filter_map(|s| s); if let Some(mut ty) = binops.next() { while let Some(other) = binops.next() { @@ -105,11 +108,11 @@ impl std::fmt::Display for TypeRefs { let idx = *typeref.borrow(); writeln!( f, - "{:<3} = {:<3} = {:?} = {}", + "{:<3} = {:<3} = {:?} = {:?}", i, unsafe { *self.recurse_type_ref(idx).borrow() }, + self.retrieve_typeref(idx), self.retrieve_wide_type(idx), - TypeKind::Vague(VagueType::TypeRef(idx)).resolve_ref(self) )?; } Ok(()) @@ -176,9 +179,13 @@ impl TypeRefs { return refs.get_unchecked(idx).clone(); } - pub fn retrieve_wide_type(&self, idx: usize) -> Option { + pub fn retrieve_typeref(&self, idx: usize) -> Option { let inner_idx = unsafe { *self.recurse_type_ref(idx).borrow() }; - self.hints.borrow().get(inner_idx).cloned().map(|t| t.widen(self)) + self.hints.borrow().get(inner_idx).cloned() + } + + pub fn retrieve_wide_type(&self, idx: usize) -> Option { + self.retrieve_typeref(idx).map(|t| t.widen(self)) } } @@ -248,22 +255,45 @@ impl<'outer> ScopeTypeRefs<'outer> { unsafe { let mut hints = self.types.hints.borrow_mut(); let existing = hints.get_unchecked_mut(*hint.0.borrow()); + match existing { TypeRefKind::Direct(type_kind) => { *type_kind = type_kind.narrow_into(&ty).ok()?; } TypeRefKind::BinOp(op, lhs, rhs) => { + let op = op.clone(); + let lhs = lhs.clone(); + let rhs = rhs.clone(); + drop(hints); + + let lhs_resolved = lhs.resolve_ref(self.types); + let rhs_resolved = rhs.resolve_ref(self.types); + let binops = self .types .binop_types .iter() - .filter(|b| b.1.operator == *op && b.1.return_ty == *ty); + .filter(|b| b.1.operator == op && b.1.return_ty == *ty) + .collect::>(); for binop in binops { - if let (Ok(lhs_narrow), Ok(rhs_narrow)) = - (lhs.narrow_into(&binop.1.hands.0), rhs.narrow_into(&binop.1.hands.1)) - { - *lhs = lhs_narrow; - *rhs = rhs_narrow + if let (Ok(lhs_narrow), Ok(rhs_narrow)) = ( + lhs_resolved.narrow_into(&binop.1.hands.0), + rhs_resolved.narrow_into(&binop.1.hands.1), + ) { + match &lhs { + TypeKind::Vague(VagueType::TypeRef(idx)) => { + let mut lhs_ref = TypeRef(Rc::new(RefCell::new(*idx)), self); + let narrowed = self.narrow_to_type(&mut lhs_ref, &lhs_narrow).unwrap_or(lhs_ref); + } + _ => {} + }; + match &rhs { + TypeKind::Vague(VagueType::TypeRef(idx)) => { + let mut rhs_ref = TypeRef(Rc::new(RefCell::new(*idx)), self); + let narrowed = self.narrow_to_type(&mut rhs_ref, &rhs_narrow).unwrap_or(rhs_ref); + } + _ => {} + } } } }