diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index 5a2539d..b970a55 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -99,21 +99,42 @@ pub struct StackFunction<'ctx> { #[derive(Debug, Clone, PartialEq, Eq)] pub struct StackValue(StackValueKind, TypeKind); +impl StackValue { + fn instr(&self) -> InstructionValue { + self.0.instr() + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StackValueKind { Immutable(InstructionValue), Mutable(InstructionValue), - Any(InstructionValue), + Literal(InstructionValue), } impl StackValueKind { - unsafe fn get_inner(&self) -> InstructionValue { + fn instr(&self) -> InstructionValue { match &self { StackValueKind::Immutable(val) => *val, StackValueKind::Mutable(val) => *val, - StackValueKind::Any(val) => *val, + StackValueKind::Literal(val) => *val, } } + + fn derive(&self, instr: InstructionValue) -> StackValueKind { + match &self { + StackValueKind::Immutable(_) => StackValueKind::Immutable(instr), + StackValueKind::Mutable(_) => StackValueKind::Mutable(instr), + StackValueKind::Literal(_) => StackValueKind::Literal(instr), + } + } + + fn map(&self, lambda: F) -> StackValueKind + where + F: FnOnce(InstructionValue) -> InstructionValue, + { + self.derive(lambda(self.instr())) + } } impl<'ctx, 'a> Scope<'ctx, 'a> { @@ -369,7 +390,7 @@ impl mir::Module { mir::FunctionDefinitionKind::Local(block, _) => { let state = State::default(); if let Some(ret) = block.codegen(&mut scope, &state) { - scope.block.terminate(Term::Ret(ret)).unwrap(); + scope.block.terminate(Term::Ret(ret.instr())).unwrap(); } else { if !scope.block.delete_if_unused().unwrap() { // Add a void return just in case if the block @@ -397,13 +418,13 @@ impl mir::Block { &self, mut scope: &mut Scope<'ctx, 'a>, state: &State, - ) -> Option { + ) -> Option { for stmt in &self.statements { stmt.codegen(&mut scope, state).map(|s| { if let Some(debug) = &scope.debug { let location = stmt.1.into_debug(scope.tokens).unwrap(); let loc_val = debug.info.location(&debug.scope, location); - s.with_location(&mut scope.block, loc_val); + s.instr().with_location(&mut scope.block, loc_val); } }); } @@ -412,7 +433,7 @@ impl mir::Block { match kind { mir::ReturnKind::Hard => { let ret = expr.codegen(&mut scope, &state)?; - scope.block.terminate(Term::Ret(ret)).unwrap(); + scope.block.terminate(Term::Ret(ret.instr())).unwrap(); None } mir::ReturnKind::Soft => expr.codegen(&mut scope, state), @@ -424,11 +445,7 @@ impl mir::Block { } impl mir::Statement { - fn codegen<'ctx, 'a>( - &self, - scope: &mut Scope<'ctx, 'a>, - state: &State, - ) -> Option { + fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>, state: &State) -> Option { let location = self.1.into_debug(scope.tokens).unwrap(); let location = scope .debug @@ -450,7 +467,7 @@ impl mir::Statement { let store = scope .block - .build(Instr::Store(alloca, value)) + .build(Instr::Store(alloca, value.instr())) .unwrap() .maybe_location(&mut scope.block, location); @@ -483,12 +500,12 @@ impl mir::Statement { InstructionDebugRecordData { variable: var, location, - kind: DebugRecordKind::Declare(value), + kind: DebugRecordKind::Declare(value.instr()), scope: debug.scope, }, ); } - StackValueKind::Any(_) => {} + StackValueKind::Literal(_) => {} } } None @@ -500,13 +517,23 @@ impl mir::Statement { let rhs_value = rhs.codegen(scope, state)?; - Some( - scope - .block - .build(Instr::Store(lhs_value, rhs_value)) - .unwrap() - .maybe_location(&mut scope.block, location), - ) + match lhs_value.0 { + StackValueKind::Immutable(_) => { + panic!("Tried to assign to immutable!") + } + StackValueKind::Mutable(instr) => { + scope + .block + .build(Instr::Store(instr, rhs_value.instr())) + .unwrap() + .maybe_location(&mut scope.block, location); + } + StackValueKind::Literal(_) => { + panic!("Tried to assign to a literal!") + } + }; + + None } mir::StmtKind::Import(_) => todo!(), mir::StmtKind::Expression(expression) => expression.codegen(scope, state), @@ -515,11 +542,7 @@ impl mir::Statement { } impl mir::Expression { - fn codegen<'ctx, 'a>( - &self, - scope: &mut Scope<'ctx, 'a>, - state: &State, - ) -> Option { + fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>, state: &State) -> Option { let location = if let Some(debug) = &scope.debug { Some( debug @@ -530,7 +553,7 @@ impl mir::Expression { None }; - match &self.0 { + let value = match &self.0 { mir::ExprKind::Variable(varref) => { varref.0.known().expect("variable type unknown"); let v = scope @@ -538,18 +561,23 @@ impl mir::Expression { .get(&varref.1) .expect("Variable reference not found?!"); dbg!(varref); - Some(match v.0 { - StackValueKind::Immutable(val) | StackValueKind::Mutable(val) => scope - .block - .build(Instr::Load( - val, - v.1.get_type(scope.type_values, scope.types), - )) - .unwrap(), - _ => panic!("Found an unknown-mutable variable!"), - }) + Some(StackValue( + v.0.map(|val| { + scope + .block + .build(Instr::Load( + val, + v.1.get_type(scope.type_values, scope.types), + )) + .unwrap() + }), + varref.0.clone(), + )) } - mir::ExprKind::Literal(lit) => Some(lit.as_const(&mut scope.block)), + mir::ExprKind::Literal(lit) => Some(StackValue( + StackValueKind::Literal(lit.as_const(&mut scope.block)), + lit.as_type(), + )), mir::ExprKind::BinOp(binop, lhs_exp, rhs_exp) => { lhs_exp .return_type() @@ -566,41 +594,58 @@ impl mir::Expression { let lhs = lhs_exp .codegen(scope, state) - .expect("lhs has no return value"); + .expect("lhs has no return value") + .instr(); let rhs = rhs_exp .codegen(scope, state) - .expect("rhs has no return value"); - Some(match binop { - mir::BinaryOperator::Add => scope.block.build(Instr::Add(lhs, rhs)).unwrap(), - mir::BinaryOperator::Minus => scope.block.build(Instr::Sub(lhs, rhs)).unwrap(), - mir::BinaryOperator::Mult => scope.block.build(Instr::Mult(lhs, rhs)).unwrap(), - mir::BinaryOperator::And => scope.block.build(Instr::And(lhs, rhs)).unwrap(), - mir::BinaryOperator::Cmp(l) => scope - .block - .build(Instr::ICmp(l.int_predicate(), lhs, rhs)) - .unwrap(), - }) + .expect("rhs has no return value") + .instr(); + Some(StackValue( + StackValueKind::Immutable(match binop { + mir::BinaryOperator::Add => { + scope.block.build(Instr::Add(lhs, rhs)).unwrap() + } + mir::BinaryOperator::Minus => { + scope.block.build(Instr::Sub(lhs, rhs)).unwrap() + } + mir::BinaryOperator::Mult => { + scope.block.build(Instr::Mult(lhs, rhs)).unwrap() + } + mir::BinaryOperator::And => { + scope.block.build(Instr::And(lhs, rhs)).unwrap() + } + mir::BinaryOperator::Cmp(l) => scope + .block + .build(Instr::ICmp(l.int_predicate(), lhs, rhs)) + .unwrap(), + }), + TypeKind::U32, + )) } mir::ExprKind::FunctionCall(call) => { - call.return_type + let ret_type = call + .return_type .known() .expect("function return type unknown"); let params = call .parameters .iter() - .map(|e| e.codegen(scope, state).unwrap()) + .map(|e| e.codegen(scope, state).unwrap().instr()) .collect(); let callee = scope .functions .get(&call.name) .expect("function not found!"); - Some( - scope - .block - .build(Instr::FunctionCall(callee.ir.value(), params)) - .unwrap(), - ) + Some(StackValue( + StackValueKind::Immutable( + scope + .block + .build(Instr::FunctionCall(callee.ir.value(), params)) + .unwrap(), + ), + ret_type, + )) } mir::ExprKind::If(if_expression) => if_expression.codegen(scope, state), mir::ExprKind::Block(block) => { @@ -616,16 +661,17 @@ impl mir::Expression { } } mir::ExprKind::Indexed(expression, val_t, idx_expr) => { - let array = expression + let StackValue(kind, ty) = expression .codegen(scope, &state.load(true)) .expect("array returned none!"); let idx = idx_expr .codegen(scope, &state.load(true)) - .expect("index returned none!"); + .expect("index returned none!") + .instr(); let mut ptr = scope .block - .build(Instr::GetElemPtr(array, vec![idx])) + .build(Instr::GetElemPtr(kind.instr(), vec![idx])) .unwrap() .maybe_location(&mut scope.block, location); @@ -640,16 +686,25 @@ impl mir::Expression { .maybe_location(&mut scope.block, location); } - Some(ptr) + let TypeKind::Array(elem_ty, _) = ty else { + panic!(); + }; + + Some(StackValue(kind.derive(ptr), *elem_ty)) } mir::ExprKind::Array(expressions) => { - let instr_list = expressions + let stack_value_list = expressions .iter() .map(|e| e.codegen(scope, state).unwrap()) .collect::>(); - let instr_t = expressions + let instr_list = stack_value_list .iter() - .map(|e| e.return_type().unwrap().1) + .map(|s| s.instr()) + .collect::>(); + + let instr_t = stack_value_list + .iter() + .map(|s| s.1.clone()) .next() .unwrap_or(TypeKind::Void); @@ -679,28 +734,27 @@ impl mir::Expression { .maybe_location(&mut scope.block, location); } - Some(array) + Some(StackValue( + StackValueKind::Literal(array), + TypeKind::Array(Box::new(instr_t), instr_list.len() as u64), + )) } mir::ExprKind::Accessed(expression, type_kind, field) => { let struct_val = expression.codegen(scope, &mut state.load(true)).unwrap(); - let struct_ty = expression - .return_type() - .map(|r| r.1.known()) - .unwrap() - .unwrap(); - - let TypeKind::CustomType(name) = struct_ty.deref_borrow() else { + let TypeKind::CustomType(name) = struct_val.1.deref_borrow() else { panic!("tried accessing non-custom-type"); }; - let TypeDefinitionKind::Struct(struct_ty) = scope.get_typedef(&name).unwrap(); + let TypeDefinitionKind::Struct(struct_ty) = + scope.get_typedef(&name).unwrap().clone(); let idx = struct_ty.find_index(field).unwrap(); let mut value = scope .block - .build(Instr::GetStructElemPtr(struct_val, idx as u32)) - .unwrap() - .maybe_location(&mut scope.block, location); + .build(Instr::GetStructElemPtr(struct_val.instr(), idx as u32)) + .unwrap(); + + // value.maybe_location(&mut scope.block, location); if state.should_load { value = scope @@ -712,7 +766,10 @@ impl mir::Expression { .unwrap(); } - Some(value) + Some(StackValue( + struct_val.0.derive(value), + struct_ty.get_field_ty(&field).unwrap().clone(), + )) } mir::ExprKind::Struct(name, items) => { let struct_ptr = scope @@ -733,25 +790,27 @@ impl mir::Expression { if let Some(val) = exp.codegen(scope, state) { scope .block - .build(Instr::Store(elem_ptr, val)) + .build(Instr::Store(elem_ptr, val.instr())) .unwrap() .maybe_location(&mut scope.block, location); } } - Some(struct_ptr) + Some(StackValue( + StackValueKind::Literal(struct_ptr), + TypeKind::CustomType(name.clone()), + )) } + }; + if let Some(value) = &value { + value.instr().maybe_location(&mut scope.block, location); } - .map(|i| i.maybe_location(&mut scope.block, location)) + value } } impl mir::IfExpression { - fn codegen<'ctx, 'a>( - &self, - scope: &mut Scope<'ctx, 'a>, - state: &State, - ) -> Option { + fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>, state: &State) -> Option { let condition = self.0.codegen(scope, state).unwrap(); // Create blocks @@ -792,7 +851,7 @@ impl mir::IfExpression { scope .block - .terminate(Term::CondBr(condition, then_bb, else_bb)) + .terminate(Term::CondBr(condition.instr(), then_bb, else_bb)) .unwrap(); let opt = else_block.codegen(&mut else_scope, state); @@ -807,7 +866,7 @@ impl mir::IfExpression { else_b.terminate(Term::Br(after_bb)).unwrap(); scope .block - .terminate(Term::CondBr(condition, then_bb, after_bb)) + .terminate(Term::CondBr(condition.instr(), then_bb, after_bb)) .unwrap(); None }; @@ -819,8 +878,29 @@ impl mir::IfExpression { None } else { let mut incoming = Vec::from(then_res.as_slice()); - incoming.extend(else_res); - Some(scope.block.build(Instr::Phi(incoming)).unwrap()) + incoming.extend(else_res.clone()); + let instr = scope + .block + .build(Instr::Phi(incoming.iter().map(|i| i.instr()).collect())) + .unwrap(); + + use StackValueKind::*; + let value = match (then_res, else_res) { + (None, None) => StackValue(StackValueKind::Immutable(instr), TypeKind::Void), + (Some(val), None) | (None, Some(val)) => StackValue(val.0.derive(instr), val.1), + (Some(lhs_val), Some(rhs_val)) => match (lhs_val.0, rhs_val.0) { + (Immutable(_), Immutable(_)) + | (Immutable(_), Mutable(_)) + | (Mutable(_), Immutable(_)) + | (Immutable(_), Literal(_)) + | (Literal(_), Immutable(_)) + | (Mutable(_), Literal(_)) + | (Literal(_), Mutable(_)) + | (Literal(_), Literal(_)) => StackValue(Immutable(instr), lhs_val.1), + (Mutable(_), Mutable(_)) => StackValue(Mutable(instr), lhs_val.1), + }, + }; + Some(value) } } }