diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index 9f99114..2af6349 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -175,6 +175,22 @@ pub enum StackValueKind { Mutable(InstructionValue), } +impl StackValueKind { + unsafe fn get_instr(&self) -> &InstructionValue { + match self { + StackValueKind::Immutable(val) => val, + StackValueKind::Mutable(val) => val, + } + } + + fn with_instr(&self, instr: InstructionValue) -> StackValueKind { + match self { + StackValueKind::Immutable(_) => StackValueKind::Immutable(instr), + StackValueKind::Mutable(_) => StackValueKind::Mutable(instr), + } + } +} + impl<'ctx, 'a> Scope<'ctx, 'a> { fn with_block(&self, block: Block<'ctx>) -> Scope<'ctx, 'a> { Scope { @@ -242,34 +258,12 @@ impl mir::Statement { None } mir::StmtKind::Set(var, val) => { - if let Some((StackValue(kind, mut ty), indices)) = var.get_stack_value(scope) { + if let Some(StackValue(kind, _)) = var.get_stack_value(scope, false) { match kind { StackValueKind::Immutable(_) => { panic!("Tried to mutate an immutable variable") } - StackValueKind::Mutable(mut ptr) => { - for (i, idx_kind) in indices.iter().enumerate() { - let Type::Ptr(inner) = ty else { panic!() }; - ty = *inner; - - match idx_kind { - IndexKind::Array(idx) => { - ptr = scope - .block - .build(Instr::GetElemPtr(ptr, vec![*idx])) - .unwrap(); - } - IndexKind::Struct(idx) => { - ptr = scope - .block - .build(Instr::GetStructElemPtr(ptr, *idx)) - .unwrap(); - } - } - if i < (indices.len() - 1) { - ptr = scope.block.build(Instr::Load(ptr, ty.clone())).unwrap() - } - } + StackValueKind::Mutable(ptr) => { let expression = val.codegen(scope).unwrap(); Some(scope.block.build(Instr::Store(ptr, expression)).unwrap()) } @@ -524,34 +518,67 @@ pub enum IndexKind { } impl IndexedVariableReference { - fn get_stack_value(&self, scope: &mut Scope) -> Option<(StackValue, Vec)> { + fn get_stack_value(&self, scope: &mut Scope, load_after_gep: bool) -> Option { match &self.kind { - mir::IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => scope - .stack_values - .get(name) - .cloned() - .map(|v| (v, Vec::new())), + mir::IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => { + scope.stack_values.get(name).cloned().map(|v| v) + } mir::IndexedVariableReferenceKind::ArrayIndex(inner, idx) => { - let (inner_val, mut indices) = inner.get_stack_value(scope)?; + let inner_stack_val = inner.get_stack_value(scope, true)?; - match &inner_val.1 { - Type::Ptr(_) => { - indices.push(IndexKind::Array(*idx as u32)); - Some((inner_val, indices)) + let mut gep_instr = scope + .block + .build(Instr::GetElemPtr( + unsafe { *inner_stack_val.0.get_instr() }, + vec![*idx as u32], + )) + .unwrap(); + + match &inner_stack_val.1 { + Type::Ptr(inner_ty) => { + if load_after_gep { + gep_instr = scope + .block + .build(Instr::Load(gep_instr, *inner_ty.clone())) + .unwrap() + } + Some(StackValue( + inner_stack_val.0.with_instr(gep_instr), + *inner_ty.clone(), + )) } _ => panic!("Tried to codegen indexing a non-indexable value!"), } } mir::IndexedVariableReferenceKind::StructIndex(inner, field) => { - let (inner_val, mut indices) = inner.get_stack_value(scope)?; + let inner_stack_val = inner.get_stack_value(scope, true)?; - let (idx, elem_ty) = if let Type::Ptr(inner_ty) = inner_val.1 { + let (instr_value, inner_ty) = if let Type::Ptr(inner_ty) = inner_stack_val.1 { if let Type::CustomType(ty_val) = *inner_ty { match scope.types.get(&ty_val).unwrap() { - TypeDefinitionKind::Struct(struct_type) => Some(( - struct_type.find_index(field)?, - struct_type.get_field_ty(field)?, - )), + TypeDefinitionKind::Struct(struct_type) => { + let idx = struct_type.find_index(field)?; + let field_ty = struct_type + .get_field_ty(field)? + .get_type(scope.type_values, scope.types); + + let mut gep_instr = scope + .block + .build(Instr::GetStructElemPtr( + unsafe { *inner_stack_val.0.get_instr() }, + idx, + )) + .unwrap(); + + if load_after_gep { + gep_instr = scope + .block + .build(Instr::Load(gep_instr, field_ty.clone())) + .unwrap() + } + + Some((gep_instr, field_ty)) + } } } else { None @@ -560,13 +587,9 @@ impl IndexedVariableReference { None }?; - indices.push(IndexKind::Struct(idx as u32)); - Some(( - StackValue( - inner_val.0, - Type::Ptr(Box::new(elem_ty.get_type(scope.type_values, scope.types))), - ), - indices, + Some(StackValue( + inner_stack_val.0.with_instr(instr_value), + Type::Ptr(Box::new(inner_ty)), )) } }