From 31185d921efef4d9b135964ba3cf6e5f1c7b14bb Mon Sep 17 00:00:00 2001 From: sofia Date: Wed, 16 Jul 2025 15:46:55 +0300 Subject: [PATCH] Properly implement structs in lib --- reid-llvm-lib/src/builder.rs | 117 ++++++++++++++++++++++----------- reid-llvm-lib/src/compile.rs | 124 +++++++++++++++++++++++++++-------- reid-llvm-lib/src/debug.rs | 11 +++- reid-llvm-lib/src/lib.rs | 12 +++- reid/src/codegen.rs | 2 +- 5 files changed, 195 insertions(+), 71 deletions(-) diff --git a/reid-llvm-lib/src/builder.rs b/reid-llvm-lib/src/builder.rs index 7ec0f5b..d03e914 100644 --- a/reid-llvm-lib/src/builder.rs +++ b/reid-llvm-lib/src/builder.rs @@ -1,11 +1,11 @@ //! This module contains simply [`Builder`] and it's related utility Values. //! Builder is the actual struct being modified when building the LLIR. -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; use crate::{ - BlockData, ConstValue, FunctionData, Instr, InstructionData, ModuleData, TerminatorKind, Type, - TypeData, util::match_types, + BlockData, ConstValue, CustomTypeKind, FunctionData, Instr, InstructionData, ModuleData, + NamedStruct, TerminatorKind, Type, TypeData, util::match_types, }; #[derive(Clone, Hash, Copy, PartialEq, Eq)] @@ -213,6 +213,18 @@ impl Builder { } } + pub(crate) unsafe fn type_data(&self, value: &TypeValue) -> TypeData { + unsafe { + self.modules + .borrow() + .get_unchecked(value.0.0) + .types + .get_unchecked(value.1) + .data + .clone() + } + } + pub(crate) fn find_module<'ctx>(&'ctx self, value: ModuleValue) -> ModuleHolder { unsafe { self.modules.borrow().get_unchecked(value.0).clone() } } @@ -222,16 +234,15 @@ impl Builder { } pub fn check_instruction(&self, instruction: &InstructionValue) -> Result<(), ()> { - use super::Instr::*; unsafe { match self.instr_data(&instruction).kind { - Param(_) => Ok(()), - Constant(_) => Ok(()), - Add(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), - Sub(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), - Mult(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), - And(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), - ICmp(_, lhs, rhs) => { + Instr::Param(_) => Ok(()), + Instr::Constant(_) => Ok(()), + Instr::Add(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), + Instr::Sub(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), + Instr::Mult(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), + Instr::And(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), + Instr::ICmp(_, lhs, rhs) => { let t = match_types(&lhs, &rhs, self)?; if t.comparable() { Ok(()) @@ -239,7 +250,7 @@ impl Builder { Err(()) // TODO error: Types not comparable } } - FunctionCall(fun, params) => { + Instr::FunctionCall(fun, params) => { let param_types = self.function_data(&fun).params; if param_types.len() != params.len() { return Err(()); // TODO error: invalid amount of params @@ -251,7 +262,7 @@ impl Builder { } Ok(()) } - Phi(vals) => { + Instr::Phi(vals) => { let mut iter = vals.iter(); // TODO error: Phi must contain at least one item @@ -265,40 +276,53 @@ impl Builder { } Ok(()) } - Alloca(_, _) => Ok(()), - Load(ptr, load_ty) => { - if let Ok(ptr_ty) = ptr.get_type(&self) { - if let Type::Ptr(ptr_ty_inner) = ptr_ty { - if *ptr_ty_inner == load_ty { - Ok(()) - } else { - Err(()) - } - } else { - Err(()) - } - } else { - Err(()) - } - } - Store(ptr, _) => { - if let Ok(ty) = ptr.get_type(&self) { - if let Type::Ptr(_) = ty { + Instr::Alloca(_, _) => Ok(()), + Instr::Load(ptr, load_ty) => { + let ptr_ty = ptr.get_type(&self)?; + if let Type::Ptr(ptr_ty_inner) = ptr_ty { + if *ptr_ty_inner == load_ty { Ok(()) } else { - Err(()) + Err(()) // TODO error: inner type mismatch } } else { - Err(()) + Err(()) // TODO error: not a pointer } } - ArrayAlloca(_, _) => Ok(()), - GetElemPtr(arr, _) => { - let arr_ty = arr.get_type(&self)?; - if let Type::Ptr(_) = arr_ty { + Instr::Store(ptr, _) => { + let ty = ptr.get_type(&self)?; + if let Type::Ptr(_) = ty { Ok(()) } else { - Err(()) + Err(()) // TODO error: not a pointer + } + } + Instr::ArrayAlloca(_, _) => Ok(()), + Instr::GetElemPtr(ptr_val, _) => { + let ptr_ty = ptr_val.get_type(&self)?; + if let Type::Ptr(_) = ptr_ty { + Ok(()) + } else { + Err(()) // TODO error: not a pointer + } + } + Instr::GetStructElemPtr(ptr_val, idx) => { + let ptr_ty = ptr_val.get_type(&self)?; + if let Type::Ptr(ty) = ptr_ty { + if let Type::CustomType(val) = *ty { + match self.type_data(&val).kind { + CustomTypeKind::NamedStruct(NamedStruct(_, fields)) => { + if fields.len() <= idx as usize { + return Err(()); // TODO error: no such field + } + } + } + Ok(()) + } else { + Err(()) // TODO error: not a struct + } + } else { + Err(()) // TODO error: not a pointer } } } @@ -364,6 +388,17 @@ impl InstructionValue { Store(_, value) => value.get_type(builder), ArrayAlloca(ty, _) => Ok(Type::Ptr(Box::new(ty.clone()))), GetElemPtr(ptr, _) => ptr.get_type(builder), + GetStructElemPtr(instr, idx) => { + let instr_val = instr.get_type(builder)?; + let Type::CustomType(ty_value) = instr_val else { + panic!("GetStructElemPtr on non-struct! ({:?})", &instr_val) + }; + match builder.type_data(&ty_value).kind { + CustomTypeKind::NamedStruct(NamedStruct(_, fields)) => { + Ok(fields.get_unchecked(*idx as usize).clone()) + } + } + } } } } @@ -383,7 +418,7 @@ impl ConstValue { ConstValue::U32(_) => U32, ConstValue::U64(_) => U64, ConstValue::U128(_) => U128, - ConstValue::String(_) => Ptr(Box::new(I8)), + ConstValue::StringPtr(_) => Ptr(Box::new(I8)), ConstValue::Bool(_) => Bool, } } @@ -405,6 +440,7 @@ impl Type { Type::Bool => true, Type::Void => false, Type::Ptr(_) => false, + Type::CustomType(_) => false, } } @@ -423,6 +459,7 @@ impl Type { Type::Bool => false, Type::Void => false, Type::Ptr(_) => false, + Type::CustomType(_) => false, } } } diff --git a/reid-llvm-lib/src/compile.rs b/reid-llvm-lib/src/compile.rs index 418734d..5660959 100644 --- a/reid-llvm-lib/src/compile.rs +++ b/reid-llvm-lib/src/compile.rs @@ -20,7 +20,11 @@ use llvm_sys::{ }, }; -use crate::util::{ErrorMessageHolder, MemoryBufferHolder, from_cstring, into_cstring}; +use crate::{ + CustomTypeKind, NamedStruct, + builder::{TypeHolder, TypeValue}, + util::{ErrorMessageHolder, MemoryBufferHolder, from_cstring, into_cstring}, +}; use super::{ CmpPredicate, ConstValue, Context, TerminatorKind, Type, @@ -175,6 +179,7 @@ pub struct LLVMModule<'a> { functions: HashMap, blocks: HashMap, values: HashMap, + types: HashMap, } #[derive(Clone, Copy)] @@ -202,12 +207,16 @@ impl ModuleHolder { // Compile the contents - let mut functions = HashMap::new(); + let mut types = HashMap::new(); + for ty in &self.types { + types.insert(ty.value, ty.compile_type(context, &types)); + } + let mut functions = HashMap::new(); for function in &self.functions { functions.insert( function.value, - function.compile_signature(context, module_ref), + function.compile_signature(context, module_ref, &types), ); } @@ -217,6 +226,7 @@ impl ModuleHolder { builder_ref: context.builder_ref, module_ref, functions, + types, blocks: HashMap::new(), values: HashMap::new(), }; @@ -230,19 +240,46 @@ impl ModuleHolder { } } +impl TypeHolder { + unsafe fn compile_type( + &self, + context: &LLVMContext, + types: &HashMap, + ) -> LLVMTypeRef { + unsafe { + match &self.data.kind { + CustomTypeKind::NamedStruct(named_struct) => { + let mut elem_types = Vec::new(); + for ty in &named_struct.1 { + elem_types.push(ty.as_llvm(context.context_ref, types)); + } + let struct_ty = LLVMStructTypeInContext( + context.context_ref, + elem_types.as_mut_ptr(), + elem_types.len() as u32, + 0, + ); + struct_ty + } + } + } + } +} + impl FunctionHolder { unsafe fn compile_signature( &self, context: &LLVMContext, module_ref: LLVMModuleRef, + types: &HashMap, ) -> LLVMFunction { unsafe { - let ret_type = self.data.ret.as_llvm(context.context_ref); + let ret_type = self.data.ret.as_llvm(context.context_ref, types); let mut param_types: Vec = self .data .params .iter() - .map(|t| t.as_llvm(context.context_ref)) + .map(|t| t.as_llvm(context.context_ref, types)) .collect(); let param_ptr = param_types.as_mut_ptr(); let param_len = param_types.len(); @@ -346,7 +383,7 @@ impl InstructionHolder { use super::Instr::*; match &self.data.kind { Param(nth) => LLVMGetParam(function.value_ref, *nth as u32), - Constant(val) => val.as_llvm(module.context_ref, module.builder_ref), + Constant(val) => val.as_llvm(module), Add(lhs, rhs) => { let lhs_val = module.values.get(&lhs).unwrap().value_ref; let rhs_val = module.values.get(&rhs).unwrap().value_ref; @@ -412,7 +449,7 @@ impl InstructionHolder { } let phi = LLVMBuildPhi( module.builder_ref, - _ty.as_llvm(module.context_ref), + _ty.as_llvm(module.context_ref, &module.types), c"phi".as_ptr(), ); LLVMAddIncoming( @@ -425,12 +462,12 @@ impl InstructionHolder { } Alloca(name, ty) => LLVMBuildAlloca( module.builder_ref, - ty.as_llvm(module.context_ref), + ty.as_llvm(module.context_ref, &module.types), into_cstring(name).as_ptr(), ), Load(ptr, ty) => LLVMBuildLoad2( module.builder_ref, - ty.as_llvm(module.context_ref), + ty.as_llvm(module.context_ref, &module.types), module.values.get(&ptr).unwrap().value_ref, c"load".as_ptr(), ), @@ -440,11 +477,10 @@ impl InstructionHolder { module.values.get(&ptr).unwrap().value_ref, ), ArrayAlloca(ty, len) => { - let array_len = ConstValue::U16(*len as u16) - .as_llvm(module.context_ref, module.builder_ref); + let array_len = ConstValue::U16(*len as u16).as_llvm(module); LLVMBuildArrayAlloca( module.builder_ref, - ty.as_llvm(module.context_ref), + ty.as_llvm(module.context_ref, &module.types), array_len, c"array_alloca".as_ptr(), ) @@ -453,20 +489,51 @@ impl InstructionHolder { let t = arr.get_type(module.builder).unwrap(); let Type::Ptr(elem_t) = t else { panic!() }; - let mut indices: Vec<_> = indices + let mut llvm_indices: Vec<_> = indices .iter() - .map(|idx| { - ConstValue::U32(*idx).as_llvm(module.context_ref, module.builder_ref) - }) + .map(|idx| ConstValue::U32(*idx).as_llvm(module)) .collect(); LLVMBuildGEP2( module.builder_ref, - elem_t.as_llvm(module.context_ref), + elem_t.as_llvm(module.context_ref, &module.types), module.values.get(arr).unwrap().value_ref, - indices.as_mut_ptr(), - indices.len() as u32, - c"array_gep".as_ptr(), + llvm_indices.as_mut_ptr(), + llvm_indices.len() as u32, + into_cstring(format!( + "array_gep_{:?}", + indices + .iter() + .map(|v| v.to_string()) + .collect::>() + .join("_") + )) + .as_ptr(), + ) + } + GetStructElemPtr(struct_val, idx) => { + let t = struct_val.get_type(module.builder).unwrap(); + let Type::Ptr(inner_t) = t else { panic!() }; + + let Type::CustomType(struct_ty) = *inner_t else { + panic!(); + }; + let struct_ty_data = module.builder.type_data(&struct_ty); + let (name, elem_ty) = match struct_ty_data.kind { + CustomTypeKind::NamedStruct(NamedStruct(name, fields)) => ( + name, + fields + .get_unchecked(*idx as usize) + .as_llvm(module.context_ref, &module.types), + ), + }; + + LLVMBuildStructGEP2( + module.builder_ref, + elem_ty, + module.values.get(struct_val).unwrap().value_ref, + *idx, + into_cstring(format!("struct_gep_{}_{}", name, idx)).as_ptr(), ) } } @@ -532,9 +599,9 @@ impl CmpPredicate { } impl ConstValue { - fn as_llvm(&self, context: LLVMContextRef, builder: LLVMBuilderRef) -> LLVMValueRef { + fn as_llvm(&self, module: &LLVMModule) -> LLVMValueRef { unsafe { - let t = self.get_type().as_llvm(context); + let t = self.get_type().as_llvm(module.context_ref, &module.types); match self { ConstValue::Bool(val) => LLVMConstInt(t, *val as u64, 1), ConstValue::I8(val) => LLVMConstInt(t, *val as u64, 1), @@ -547,8 +614,8 @@ impl ConstValue { ConstValue::U32(val) => LLVMConstInt(t, *val as u64, 1), ConstValue::U64(val) => LLVMConstInt(t, *val as u64, 1), ConstValue::U128(val) => LLVMConstInt(t, *val as u64, 1), - ConstValue::String(val) => LLVMBuildGlobalStringPtr( - builder, + ConstValue::StringPtr(val) => LLVMBuildGlobalStringPtr( + module.builder_ref, into_cstring(val).as_ptr(), c"string".as_ptr(), ), @@ -558,7 +625,11 @@ impl ConstValue { } impl Type { - fn as_llvm(&self, context: LLVMContextRef) -> LLVMTypeRef { + fn as_llvm( + &self, + context: LLVMContextRef, + typemap: &HashMap, + ) -> LLVMTypeRef { use Type::*; unsafe { match self { @@ -569,7 +640,8 @@ impl Type { I128 | U128 => LLVMInt128TypeInContext(context), Bool => LLVMInt1TypeInContext(context), Void => LLVMVoidTypeInContext(context), - Ptr(ty) => LLVMPointerType(ty.as_llvm(context), 0), + Ptr(ty) => LLVMPointerType(ty.as_llvm(context, typemap), 0), + CustomType(struct_ty) => *typemap.get(struct_ty).unwrap(), } } } diff --git a/reid-llvm-lib/src/debug.rs b/reid-llvm-lib/src/debug.rs index 7842092..6755ae0 100644 --- a/reid-llvm-lib/src/debug.rs +++ b/reid-llvm-lib/src/debug.rs @@ -5,7 +5,7 @@ use std::{ marker::PhantomData, }; -use crate::{CmpPredicate, Instr, InstructionData, TerminatorKind, builder::*}; +use crate::{CmpPredicate, Instr, InstructionData, NamedStruct, TerminatorKind, builder::*}; impl Debug for Builder { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -99,6 +99,12 @@ impl Debug for InstructionValue { } } +impl Debug for TypeValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Ty[{:0>2}-{:0>2}]", &self.0.0, self.1) + } +} + impl Debug for Instr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -126,6 +132,9 @@ impl Debug for Instr { .collect::>() .join(", "), ), + Instr::GetStructElemPtr(instruction_value, index) => { + fmt_index(f, instruction_value, &index.to_string()) + } } } } diff --git a/reid-llvm-lib/src/lib.rs b/reid-llvm-lib/src/lib.rs index f57a23f..7d78550 100644 --- a/reid-llvm-lib/src/lib.rs +++ b/reid-llvm-lib/src/lib.rs @@ -4,7 +4,7 @@ use std::{fmt::Debug, marker::PhantomData}; -use builder::{BlockValue, Builder, FunctionValue, InstructionValue, ModuleValue}; +use builder::{BlockValue, Builder, FunctionValue, InstructionValue, ModuleValue, TypeValue}; use debug::PrintableModule; pub mod builder; @@ -205,6 +205,7 @@ pub enum CmpPredicate { pub enum Instr { Param(usize), Constant(ConstValue), + Add(InstructionValue, InstructionValue), Sub(InstructionValue, InstructionValue), Mult(InstructionValue, InstructionValue), @@ -216,6 +217,7 @@ pub enum Instr { Store(InstructionValue, InstructionValue), ArrayAlloca(Type, u32), GetElemPtr(InstructionValue, Vec), + GetStructElemPtr(InstructionValue, u32), /// Integer Comparison ICmp(CmpPredicate, InstructionValue, InstructionValue), @@ -237,6 +239,7 @@ pub enum Type { U128, Bool, Void, + CustomType(TypeValue), Ptr(Box), } @@ -253,7 +256,7 @@ pub enum ConstValue { U64(u64), U128(u128), Bool(bool), - String(String), + StringPtr(String), } #[derive(Clone, Hash)] @@ -272,5 +275,8 @@ pub struct TypeData { #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub enum CustomTypeKind { - Struct(Vec), + NamedStruct(NamedStruct), } + +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct NamedStruct(String, Vec); diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index a778c2a..edaa07e 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -494,7 +494,7 @@ impl mir::Literal { mir::Literal::U64(val) => ConstValue::U64(val), mir::Literal::U128(val) => ConstValue::U128(val), mir::Literal::Bool(val) => ConstValue::Bool(val), - mir::Literal::String(val) => ConstValue::String(val.clone()), + mir::Literal::String(val) => ConstValue::StringPtr(val.clone()), mir::Literal::Vague(_) => panic!("Got vague literal!"), }) }