diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index add82c0..544d528 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -56,9 +56,6 @@ pub struct Scope<'ctx, 'a> { type_values: &'a HashMap, functions: &'a HashMap>, stack_values: HashMap, - // True if the current expression should attemt to load it's pointer value, - // or keep it as a pointer (mainly used for Set-statement). - should_load: bool, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -97,7 +94,6 @@ impl<'ctx, 'a> Scope<'ctx, 'a> { types: self.types, type_values: self.type_values, stack_values: self.stack_values.clone(), - should_load: self.should_load, } } @@ -112,12 +108,19 @@ impl<'ctx, 'a> Scope<'ctx, 'a> { fn get_typedef(&self, name: &String) -> Option<&TypeDefinitionKind> { self.type_values.get(name).and_then(|v| self.types.get(v)) } +} - /// Sets should load, returning the old value - fn set_should_load(&mut self, should: bool) -> bool { - let old = self.should_load; - self.should_load = should; - old +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Default, Clone, Copy)] +struct State { + should_load: bool, +} + +impl State { + /// Sets should load, returning a new state + fn load(self, should: bool) -> State { + State { + should_load: should, + } } } @@ -207,11 +210,11 @@ impl mir::Module { types: &types, type_values: &type_values, stack_values, - should_load: true, }; match &mir_function.kind { mir::FunctionDefinitionKind::Local(block, _) => { - if let Some(ret) = block.codegen(&mut scope) { + let mut state = State::default(); + if let Some(ret) = block.codegen(&mut scope, &mut state) { scope.block.terminate(Term::Ret(ret)).unwrap(); } else { if !scope.block.delete_if_unused().unwrap() { @@ -230,20 +233,23 @@ impl mir::Module { } impl mir::Block { - fn codegen<'ctx, 'a>(&self, mut scope: &mut Scope<'ctx, 'a>) -> Option { + fn codegen<'ctx, 'a>( + &self, + mut scope: &mut Scope<'ctx, 'a>, + state: &State, + ) -> Option { for stmt in &self.statements { - stmt.codegen(&mut scope); + stmt.codegen(&mut scope, state); } if let Some((kind, expr)) = &self.return_expression { match kind { mir::ReturnKind::Hard => { - scope.should_load = true; - let ret = expr.codegen(&mut scope)?; + let ret = expr.codegen(&mut scope, &mut state.load(true))?; scope.block.terminate(Term::Ret(ret)).unwrap(); None } - mir::ReturnKind::Soft => expr.codegen(&mut scope), + mir::ReturnKind::Soft => expr.codegen(&mut scope, state), } } else { None @@ -252,10 +258,14 @@ impl mir::Block { } impl mir::Statement { - fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option { + fn codegen<'ctx, 'a>( + &self, + scope: &mut Scope<'ctx, 'a>, + state: &State, + ) -> Option { match &self.0 { mir::StmtKind::Let(NamedVariableRef(ty, name, _), mutable, expression) => { - let value = expression.codegen(scope).unwrap(); + let value = expression.codegen(scope, state).unwrap(); scope.stack_values.insert( name.clone(), StackValue( @@ -291,13 +301,11 @@ impl mir::Statement { None } mir::StmtKind::Set(lhs, rhs) => { - let old = scope.set_should_load(false); let lhs_value = lhs - .codegen(scope) + .codegen(scope, &mut state.load(false)) .expect("non-returning LHS snuck into codegen!"); - scope.should_load = old; - let rhs_value = rhs.codegen(scope)?; + let rhs_value = rhs.codegen(scope, state)?; Some( scope @@ -307,13 +315,17 @@ impl mir::Statement { ) } mir::StmtKind::Import(_) => todo!(), - mir::StmtKind::Expression(expression) => expression.codegen(scope), + mir::StmtKind::Expression(expression) => expression.codegen(scope, state), } } } impl mir::Expression { - fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option { + fn codegen<'ctx, 'a>( + &self, + scope: &mut Scope<'ctx, 'a>, + state: &State, + ) -> Option { match &self.0 { mir::ExprKind::Variable(varref) => { varref.0.known().expect("variable type unknown"); @@ -324,7 +336,7 @@ impl mir::Expression { Some(match v.0 { StackValueKind::Immutable(val) => val.clone(), StackValueKind::Mutable(val) => { - if scope.should_load { + if state.should_load { match v.1 { // TODO probably wrong ..? Type::Ptr(_) => val, @@ -351,8 +363,12 @@ impl mir::Expression { .known() .expect("rhs ret type is unknown"); - let lhs = lhs_exp.codegen(scope).expect("lhs has no return value"); - let rhs = rhs_exp.codegen(scope).expect("rhs has no return value"); + let lhs = lhs_exp + .codegen(scope, state) + .expect("lhs has no return value"); + let rhs = rhs_exp + .codegen(scope, state) + .expect("rhs has no return value"); Some(match binop { mir::BinaryOperator::Add => scope.block.build(Instr::Add(lhs, rhs)).unwrap(), mir::BinaryOperator::Minus => scope.block.build(Instr::Sub(lhs, rhs)).unwrap(), @@ -372,7 +388,7 @@ impl mir::Expression { let params = call .parameters .iter() - .map(|e| e.codegen(scope).unwrap()) + .map(|e| e.codegen(scope, state).unwrap()) .collect(); let callee = scope .functions @@ -385,10 +401,10 @@ impl mir::Expression { .unwrap(), ) } - mir::ExprKind::If(if_expression) => if_expression.codegen(scope), + mir::ExprKind::If(if_expression) => if_expression.codegen(scope, state), mir::ExprKind::Block(block) => { let mut inner_scope = scope.with_block(scope.function.block("inner")); - if let Some(ret) = block.codegen(&mut inner_scope) { + if let Some(ret) = block.codegen(&mut inner_scope, state) { inner_scope .block .terminate(Term::Br(scope.block.value())) @@ -399,14 +415,14 @@ impl mir::Expression { } } mir::ExprKind::Indexed(expression, val_t, idx_expr) => { - let array = expression.codegen(scope)?; - let idx = idx_expr.codegen(scope)?; + let array = expression.codegen(scope, state)?; + let idx = idx_expr.codegen(scope, state)?; let mut ptr = scope .block .build(Instr::GetElemPtr(array, vec![idx])) .unwrap(); - if scope.should_load { + if state.should_load { ptr = scope .block .build(Instr::Load( @@ -421,7 +437,7 @@ impl mir::Expression { mir::ExprKind::Array(expressions) => { let instr_list = expressions .iter() - .map(|e| e.codegen(scope).unwrap()) + .map(|e| e.codegen(scope, state).unwrap()) .collect::>(); let instr_t = expressions .iter() @@ -452,9 +468,7 @@ impl mir::Expression { Some(array) } mir::ExprKind::Accessed(expression, type_kind, field) => { - let old = scope.set_should_load(true); - let struct_val = expression.codegen(scope)?; - scope.should_load = old; + let struct_val = expression.codegen(scope, &mut state.load(true))?; let struct_ty = expression.return_type().ok()?.1.known().ok()?; let TypeKind::CustomType(name) = struct_ty else { @@ -468,7 +482,7 @@ impl mir::Expression { .build(Instr::GetStructElemPtr(struct_val, idx as u32)) .unwrap(); - if scope.should_load { + if state.should_load { value = scope .block .build(Instr::Load( @@ -494,7 +508,7 @@ impl mir::Expression { .block .build(Instr::GetStructElemPtr(struct_ptr, i as u32)) .unwrap(); - if let Some(val) = exp.codegen(scope) { + if let Some(val) = exp.codegen(scope, state) { scope.block.build(Instr::Store(elem_ptr, val)).unwrap(); } } @@ -506,8 +520,12 @@ impl mir::Expression { } impl mir::IfExpression { - fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option { - let condition = self.0.codegen(scope).unwrap(); + fn codegen<'ctx, 'a>( + &self, + scope: &mut Scope<'ctx, 'a>, + state: &State, + ) -> Option { + let condition = self.0.codegen(scope, state).unwrap(); // Create blocks let then_b = scope.function.block("then"); @@ -521,7 +539,7 @@ impl mir::IfExpression { // Generate then-block content let mut then_scope = scope.with_block(then_b); - let then_res = self.1.codegen(&mut then_scope); + let then_res = self.1.codegen(&mut then_scope, state); then_scope.block.terminate(Term::Br(after_bb)).ok(); let else_res = if let Some(else_block) = &self.2 { @@ -531,7 +549,7 @@ impl mir::IfExpression { .terminate(Term::CondBr(condition, then_bb, else_bb)) .unwrap(); - let opt = else_block.codegen(&mut else_scope); + let opt = else_block.codegen(&mut else_scope, state); if let Some(ret) = opt { else_scope.block.terminate(Term::Br(after_bb)).ok(); @@ -560,86 +578,6 @@ impl mir::IfExpression { } } } - -// impl IndexedVariableReference { -// fn get_stack_value(&self, scope: &mut Scope_after_gep: bool) -> Option { -// match &self.kind { -// mir::IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => { -// scope.stack_values.get(name).cloned().map(|v| v) -// } -// mir::IndexedVariableReferenceKind::ArrayIndex(inner, idx) => { -// let inner_stack_val = inner.get_stack_value(scope, true)?; - -// let mut gep_instr = scope -// .block -// .build(Instr::GetElemPtr( -// unsafe { *inner_stack_val.0.get_instr() }, -// vec![*idx as u32], -// )) -// .unwrap(); - -// match &inner_stack_val.1 { -// Type::Ptr(inner_ty) => { -// if load_after_gep { -// gep_instr = scope -// .block -// .build(Instr::Load(gep_instr, *inner_ty.clone())) -// .unwrap() -// } -// Some(StackValue( -// inner_stack_val.0.with_instr(gep_instr), -// *inner_ty.clone(), -// )) -// } -// _ => panic!("Tried to codegen indexing a non-indexable value!"), -// } -// } -// mir::IndexedVariableReferenceKind::StructIndex(inner, field) => { -// let inner_stack_val = inner.get_stack_value(scope, true)?; - -// let (instr_value, inner_ty) = if let Type::Ptr(inner_ty) = inner_stack_val.1 { -// if let Type::CustomType(ty_val) = *inner_ty { -// match scope.types.get(&ty_val).unwrap() { -// TypeDefinitionKind::Struct(struct_type) => { -// let idx = struct_type.find_index(field)?; -// let field_ty = struct_type -// .get_field_ty(field)? -// .get_type(scope.type_values, scope.types); - -// let mut gep_instr = scope -// .block -// .build(Instr::GetStructElemPtr( -// unsafe { *inner_stack_val.0.get_instr() }, -// idx, -// )) -// .unwrap(); - -// if load_after_gep { -// gep_instr = scope -// .block -// .build(Instr::Load(gep_instr, field_ty.clone())) -// .unwrap() -// } - -// Some((gep_instr, field_ty)) -// } -// } -// } else { -// None -// } -// } else { -// None -// }?; - -// Some(StackValue( -// inner_stack_val.0.with_instr(instr_value), -// Type::Ptr(Box::new(inner_ty)), -// )) -// } -// } -// } -// } - impl mir::CmpOperator { fn int_predicate(&self) -> CmpPredicate { match self {