Properly implement structs in lib

This commit is contained in:
Sofia 2025-07-16 15:46:55 +03:00
parent 97fc468d78
commit 31185d921e
5 changed files with 195 additions and 71 deletions

View File

@ -1,11 +1,11 @@
//! This module contains simply [`Builder`] and it's related utility Values.
//! Builder is the actual struct being modified when building the LLIR.
use std::{cell::RefCell, rc::Rc};
use std::{cell::RefCell, collections::HashMap, rc::Rc};
use crate::{
BlockData, ConstValue, FunctionData, Instr, InstructionData, ModuleData, TerminatorKind, Type,
TypeData, util::match_types,
BlockData, ConstValue, CustomTypeKind, FunctionData, Instr, InstructionData, ModuleData,
NamedStruct, TerminatorKind, Type, TypeData, util::match_types,
};
#[derive(Clone, Hash, Copy, PartialEq, Eq)]
@ -213,6 +213,18 @@ impl Builder {
}
}
pub(crate) unsafe fn type_data(&self, value: &TypeValue) -> TypeData {
unsafe {
self.modules
.borrow()
.get_unchecked(value.0.0)
.types
.get_unchecked(value.1)
.data
.clone()
}
}
pub(crate) fn find_module<'ctx>(&'ctx self, value: ModuleValue) -> ModuleHolder {
unsafe { self.modules.borrow().get_unchecked(value.0).clone() }
}
@ -222,16 +234,15 @@ impl Builder {
}
pub fn check_instruction(&self, instruction: &InstructionValue) -> Result<(), ()> {
use super::Instr::*;
unsafe {
match self.instr_data(&instruction).kind {
Param(_) => Ok(()),
Constant(_) => Ok(()),
Add(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
Sub(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
Mult(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
And(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
ICmp(_, lhs, rhs) => {
Instr::Param(_) => Ok(()),
Instr::Constant(_) => Ok(()),
Instr::Add(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
Instr::Sub(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
Instr::Mult(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
Instr::And(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()),
Instr::ICmp(_, lhs, rhs) => {
let t = match_types(&lhs, &rhs, self)?;
if t.comparable() {
Ok(())
@ -239,7 +250,7 @@ impl Builder {
Err(()) // TODO error: Types not comparable
}
}
FunctionCall(fun, params) => {
Instr::FunctionCall(fun, params) => {
let param_types = self.function_data(&fun).params;
if param_types.len() != params.len() {
return Err(()); // TODO error: invalid amount of params
@ -251,7 +262,7 @@ impl Builder {
}
Ok(())
}
Phi(vals) => {
Instr::Phi(vals) => {
let mut iter = vals.iter();
// TODO error: Phi must contain at least one item
@ -265,40 +276,53 @@ impl Builder {
}
Ok(())
}
Alloca(_, _) => Ok(()),
Load(ptr, load_ty) => {
if let Ok(ptr_ty) = ptr.get_type(&self) {
if let Type::Ptr(ptr_ty_inner) = ptr_ty {
if *ptr_ty_inner == load_ty {
Ok(())
} else {
Err(())
}
} else {
Err(())
}
} else {
Err(())
}
}
Store(ptr, _) => {
if let Ok(ty) = ptr.get_type(&self) {
if let Type::Ptr(_) = ty {
Instr::Alloca(_, _) => Ok(()),
Instr::Load(ptr, load_ty) => {
let ptr_ty = ptr.get_type(&self)?;
if let Type::Ptr(ptr_ty_inner) = ptr_ty {
if *ptr_ty_inner == load_ty {
Ok(())
} else {
Err(())
Err(()) // TODO error: inner type mismatch
}
} else {
Err(())
Err(()) // TODO error: not a pointer
}
}
ArrayAlloca(_, _) => Ok(()),
GetElemPtr(arr, _) => {
let arr_ty = arr.get_type(&self)?;
if let Type::Ptr(_) = arr_ty {
Instr::Store(ptr, _) => {
let ty = ptr.get_type(&self)?;
if let Type::Ptr(_) = ty {
Ok(())
} else {
Err(())
Err(()) // TODO error: not a pointer
}
}
Instr::ArrayAlloca(_, _) => Ok(()),
Instr::GetElemPtr(ptr_val, _) => {
let ptr_ty = ptr_val.get_type(&self)?;
if let Type::Ptr(_) = ptr_ty {
Ok(())
} else {
Err(()) // TODO error: not a pointer
}
}
Instr::GetStructElemPtr(ptr_val, idx) => {
let ptr_ty = ptr_val.get_type(&self)?;
if let Type::Ptr(ty) = ptr_ty {
if let Type::CustomType(val) = *ty {
match self.type_data(&val).kind {
CustomTypeKind::NamedStruct(NamedStruct(_, fields)) => {
if fields.len() <= idx as usize {
return Err(()); // TODO error: no such field
}
}
}
Ok(())
} else {
Err(()) // TODO error: not a struct
}
} else {
Err(()) // TODO error: not a pointer
}
}
}
@ -364,6 +388,17 @@ impl InstructionValue {
Store(_, value) => value.get_type(builder),
ArrayAlloca(ty, _) => Ok(Type::Ptr(Box::new(ty.clone()))),
GetElemPtr(ptr, _) => ptr.get_type(builder),
GetStructElemPtr(instr, idx) => {
let instr_val = instr.get_type(builder)?;
let Type::CustomType(ty_value) = instr_val else {
panic!("GetStructElemPtr on non-struct! ({:?})", &instr_val)
};
match builder.type_data(&ty_value).kind {
CustomTypeKind::NamedStruct(NamedStruct(_, fields)) => {
Ok(fields.get_unchecked(*idx as usize).clone())
}
}
}
}
}
}
@ -383,7 +418,7 @@ impl ConstValue {
ConstValue::U32(_) => U32,
ConstValue::U64(_) => U64,
ConstValue::U128(_) => U128,
ConstValue::String(_) => Ptr(Box::new(I8)),
ConstValue::StringPtr(_) => Ptr(Box::new(I8)),
ConstValue::Bool(_) => Bool,
}
}
@ -405,6 +440,7 @@ impl Type {
Type::Bool => true,
Type::Void => false,
Type::Ptr(_) => false,
Type::CustomType(_) => false,
}
}
@ -423,6 +459,7 @@ impl Type {
Type::Bool => false,
Type::Void => false,
Type::Ptr(_) => false,
Type::CustomType(_) => false,
}
}
}

View File

@ -20,7 +20,11 @@ use llvm_sys::{
},
};
use crate::util::{ErrorMessageHolder, MemoryBufferHolder, from_cstring, into_cstring};
use crate::{
CustomTypeKind, NamedStruct,
builder::{TypeHolder, TypeValue},
util::{ErrorMessageHolder, MemoryBufferHolder, from_cstring, into_cstring},
};
use super::{
CmpPredicate, ConstValue, Context, TerminatorKind, Type,
@ -175,6 +179,7 @@ pub struct LLVMModule<'a> {
functions: HashMap<FunctionValue, LLVMFunction>,
blocks: HashMap<BlockValue, LLVMBasicBlockRef>,
values: HashMap<InstructionValue, LLVMValue>,
types: HashMap<TypeValue, LLVMTypeRef>,
}
#[derive(Clone, Copy)]
@ -202,12 +207,16 @@ impl ModuleHolder {
// Compile the contents
let mut functions = HashMap::new();
let mut types = HashMap::new();
for ty in &self.types {
types.insert(ty.value, ty.compile_type(context, &types));
}
let mut functions = HashMap::new();
for function in &self.functions {
functions.insert(
function.value,
function.compile_signature(context, module_ref),
function.compile_signature(context, module_ref, &types),
);
}
@ -217,6 +226,7 @@ impl ModuleHolder {
builder_ref: context.builder_ref,
module_ref,
functions,
types,
blocks: HashMap::new(),
values: HashMap::new(),
};
@ -230,19 +240,46 @@ impl ModuleHolder {
}
}
impl TypeHolder {
unsafe fn compile_type(
&self,
context: &LLVMContext,
types: &HashMap<TypeValue, LLVMTypeRef>,
) -> LLVMTypeRef {
unsafe {
match &self.data.kind {
CustomTypeKind::NamedStruct(named_struct) => {
let mut elem_types = Vec::new();
for ty in &named_struct.1 {
elem_types.push(ty.as_llvm(context.context_ref, types));
}
let struct_ty = LLVMStructTypeInContext(
context.context_ref,
elem_types.as_mut_ptr(),
elem_types.len() as u32,
0,
);
struct_ty
}
}
}
}
}
impl FunctionHolder {
unsafe fn compile_signature(
&self,
context: &LLVMContext,
module_ref: LLVMModuleRef,
types: &HashMap<TypeValue, LLVMTypeRef>,
) -> LLVMFunction {
unsafe {
let ret_type = self.data.ret.as_llvm(context.context_ref);
let ret_type = self.data.ret.as_llvm(context.context_ref, types);
let mut param_types: Vec<LLVMTypeRef> = self
.data
.params
.iter()
.map(|t| t.as_llvm(context.context_ref))
.map(|t| t.as_llvm(context.context_ref, types))
.collect();
let param_ptr = param_types.as_mut_ptr();
let param_len = param_types.len();
@ -346,7 +383,7 @@ impl InstructionHolder {
use super::Instr::*;
match &self.data.kind {
Param(nth) => LLVMGetParam(function.value_ref, *nth as u32),
Constant(val) => val.as_llvm(module.context_ref, module.builder_ref),
Constant(val) => val.as_llvm(module),
Add(lhs, rhs) => {
let lhs_val = module.values.get(&lhs).unwrap().value_ref;
let rhs_val = module.values.get(&rhs).unwrap().value_ref;
@ -412,7 +449,7 @@ impl InstructionHolder {
}
let phi = LLVMBuildPhi(
module.builder_ref,
_ty.as_llvm(module.context_ref),
_ty.as_llvm(module.context_ref, &module.types),
c"phi".as_ptr(),
);
LLVMAddIncoming(
@ -425,12 +462,12 @@ impl InstructionHolder {
}
Alloca(name, ty) => LLVMBuildAlloca(
module.builder_ref,
ty.as_llvm(module.context_ref),
ty.as_llvm(module.context_ref, &module.types),
into_cstring(name).as_ptr(),
),
Load(ptr, ty) => LLVMBuildLoad2(
module.builder_ref,
ty.as_llvm(module.context_ref),
ty.as_llvm(module.context_ref, &module.types),
module.values.get(&ptr).unwrap().value_ref,
c"load".as_ptr(),
),
@ -440,11 +477,10 @@ impl InstructionHolder {
module.values.get(&ptr).unwrap().value_ref,
),
ArrayAlloca(ty, len) => {
let array_len = ConstValue::U16(*len as u16)
.as_llvm(module.context_ref, module.builder_ref);
let array_len = ConstValue::U16(*len as u16).as_llvm(module);
LLVMBuildArrayAlloca(
module.builder_ref,
ty.as_llvm(module.context_ref),
ty.as_llvm(module.context_ref, &module.types),
array_len,
c"array_alloca".as_ptr(),
)
@ -453,20 +489,51 @@ impl InstructionHolder {
let t = arr.get_type(module.builder).unwrap();
let Type::Ptr(elem_t) = t else { panic!() };
let mut indices: Vec<_> = indices
let mut llvm_indices: Vec<_> = indices
.iter()
.map(|idx| {
ConstValue::U32(*idx).as_llvm(module.context_ref, module.builder_ref)
})
.map(|idx| ConstValue::U32(*idx).as_llvm(module))
.collect();
LLVMBuildGEP2(
module.builder_ref,
elem_t.as_llvm(module.context_ref),
elem_t.as_llvm(module.context_ref, &module.types),
module.values.get(arr).unwrap().value_ref,
indices.as_mut_ptr(),
indices.len() as u32,
c"array_gep".as_ptr(),
llvm_indices.as_mut_ptr(),
llvm_indices.len() as u32,
into_cstring(format!(
"array_gep_{:?}",
indices
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join("_")
))
.as_ptr(),
)
}
GetStructElemPtr(struct_val, idx) => {
let t = struct_val.get_type(module.builder).unwrap();
let Type::Ptr(inner_t) = t else { panic!() };
let Type::CustomType(struct_ty) = *inner_t else {
panic!();
};
let struct_ty_data = module.builder.type_data(&struct_ty);
let (name, elem_ty) = match struct_ty_data.kind {
CustomTypeKind::NamedStruct(NamedStruct(name, fields)) => (
name,
fields
.get_unchecked(*idx as usize)
.as_llvm(module.context_ref, &module.types),
),
};
LLVMBuildStructGEP2(
module.builder_ref,
elem_ty,
module.values.get(struct_val).unwrap().value_ref,
*idx,
into_cstring(format!("struct_gep_{}_{}", name, idx)).as_ptr(),
)
}
}
@ -532,9 +599,9 @@ impl CmpPredicate {
}
impl ConstValue {
fn as_llvm(&self, context: LLVMContextRef, builder: LLVMBuilderRef) -> LLVMValueRef {
fn as_llvm(&self, module: &LLVMModule) -> LLVMValueRef {
unsafe {
let t = self.get_type().as_llvm(context);
let t = self.get_type().as_llvm(module.context_ref, &module.types);
match self {
ConstValue::Bool(val) => LLVMConstInt(t, *val as u64, 1),
ConstValue::I8(val) => LLVMConstInt(t, *val as u64, 1),
@ -547,8 +614,8 @@ impl ConstValue {
ConstValue::U32(val) => LLVMConstInt(t, *val as u64, 1),
ConstValue::U64(val) => LLVMConstInt(t, *val as u64, 1),
ConstValue::U128(val) => LLVMConstInt(t, *val as u64, 1),
ConstValue::String(val) => LLVMBuildGlobalStringPtr(
builder,
ConstValue::StringPtr(val) => LLVMBuildGlobalStringPtr(
module.builder_ref,
into_cstring(val).as_ptr(),
c"string".as_ptr(),
),
@ -558,7 +625,11 @@ impl ConstValue {
}
impl Type {
fn as_llvm(&self, context: LLVMContextRef) -> LLVMTypeRef {
fn as_llvm(
&self,
context: LLVMContextRef,
typemap: &HashMap<TypeValue, LLVMTypeRef>,
) -> LLVMTypeRef {
use Type::*;
unsafe {
match self {
@ -569,7 +640,8 @@ impl Type {
I128 | U128 => LLVMInt128TypeInContext(context),
Bool => LLVMInt1TypeInContext(context),
Void => LLVMVoidTypeInContext(context),
Ptr(ty) => LLVMPointerType(ty.as_llvm(context), 0),
Ptr(ty) => LLVMPointerType(ty.as_llvm(context, typemap), 0),
CustomType(struct_ty) => *typemap.get(struct_ty).unwrap(),
}
}
}

View File

@ -5,7 +5,7 @@ use std::{
marker::PhantomData,
};
use crate::{CmpPredicate, Instr, InstructionData, TerminatorKind, builder::*};
use crate::{CmpPredicate, Instr, InstructionData, NamedStruct, TerminatorKind, builder::*};
impl Debug for Builder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -99,6 +99,12 @@ impl Debug for InstructionValue {
}
}
impl Debug for TypeValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Ty[{:0>2}-{:0>2}]", &self.0.0, self.1)
}
}
impl Debug for Instr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@ -126,6 +132,9 @@ impl Debug for Instr {
.collect::<Vec<_>>()
.join(", "),
),
Instr::GetStructElemPtr(instruction_value, index) => {
fmt_index(f, instruction_value, &index.to_string())
}
}
}
}

View File

@ -4,7 +4,7 @@
use std::{fmt::Debug, marker::PhantomData};
use builder::{BlockValue, Builder, FunctionValue, InstructionValue, ModuleValue};
use builder::{BlockValue, Builder, FunctionValue, InstructionValue, ModuleValue, TypeValue};
use debug::PrintableModule;
pub mod builder;
@ -205,6 +205,7 @@ pub enum CmpPredicate {
pub enum Instr {
Param(usize),
Constant(ConstValue),
Add(InstructionValue, InstructionValue),
Sub(InstructionValue, InstructionValue),
Mult(InstructionValue, InstructionValue),
@ -216,6 +217,7 @@ pub enum Instr {
Store(InstructionValue, InstructionValue),
ArrayAlloca(Type, u32),
GetElemPtr(InstructionValue, Vec<u32>),
GetStructElemPtr(InstructionValue, u32),
/// Integer Comparison
ICmp(CmpPredicate, InstructionValue, InstructionValue),
@ -237,6 +239,7 @@ pub enum Type {
U128,
Bool,
Void,
CustomType(TypeValue),
Ptr(Box<Type>),
}
@ -253,7 +256,7 @@ pub enum ConstValue {
U64(u64),
U128(u128),
Bool(bool),
String(String),
StringPtr(String),
}
#[derive(Clone, Hash)]
@ -272,5 +275,8 @@ pub struct TypeData {
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum CustomTypeKind {
Struct(Vec<Type>),
NamedStruct(NamedStruct),
}
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct NamedStruct(String, Vec<Type>);

View File

@ -494,7 +494,7 @@ impl mir::Literal {
mir::Literal::U64(val) => ConstValue::U64(val),
mir::Literal::U128(val) => ConstValue::U128(val),
mir::Literal::Bool(val) => ConstValue::Bool(val),
mir::Literal::String(val) => ConstValue::String(val.clone()),
mir::Literal::String(val) => ConstValue::StringPtr(val.clone()),
mir::Literal::Vague(_) => panic!("Got vague literal!"),
})
}