From 58117d86e469b14682b05e5432d6f014854bb341 Mon Sep 17 00:00:00 2001 From: sofia Date: Sun, 6 Jul 2025 19:47:05 +0300 Subject: [PATCH] Make a more Rusty LLIR for the lib that is compiled to LLVM IR --- reid-llvm-lib/examples/test.rs | 60 +++++ reid-llvm-lib/src/lib.rs | 1 + reid-llvm-lib/src/test/builder.rs | 334 +++++++++++++++++++++++++++ reid-llvm-lib/src/test/compile.rs | 364 ++++++++++++++++++++++++++++++ reid-llvm-lib/src/test/mod.rs | 230 +++++++++++++++++++ reid-llvm-lib/src/test/util.rs | 18 ++ 6 files changed, 1007 insertions(+) create mode 100644 reid-llvm-lib/examples/test.rs create mode 100644 reid-llvm-lib/src/test/builder.rs create mode 100644 reid-llvm-lib/src/test/compile.rs create mode 100644 reid-llvm-lib/src/test/mod.rs create mode 100644 reid-llvm-lib/src/test/util.rs diff --git a/reid-llvm-lib/examples/test.rs b/reid-llvm-lib/examples/test.rs new file mode 100644 index 0000000..2490faf --- /dev/null +++ b/reid-llvm-lib/examples/test.rs @@ -0,0 +1,60 @@ +use reid_lib::test::{ConstValue, Context, InstructionKind, IntPredicate, TerminatorKind, Type}; + +fn main() { + use ConstValue::*; + use InstructionKind::*; + + let context = Context::new(); + + let mut module = context.module("test"); + + let mut main = module.function("main", Type::I32, Vec::new()); + let mut m_entry = main.block("entry"); + + let mut fibonacci = module.function("fibonacci", Type::I32, vec![Type::I32]); + + let arg = m_entry.build(Constant(I32(5))).unwrap(); + let fibonacci_call = m_entry + .build(FunctionCall(fibonacci.value(), vec![arg])) + .unwrap(); + m_entry + .terminate(TerminatorKind::Ret(fibonacci_call)) + .unwrap(); + + let mut f_entry = fibonacci.block("entry"); + + let num_3 = f_entry.build(Constant(I32(3))).unwrap(); + let param_n = f_entry.build(Param(0)).unwrap(); + let cond = f_entry + .build(ICmp(IntPredicate::LessThan, param_n, num_3)) + .unwrap(); + + let mut then_b = fibonacci.block("then"); + let mut else_b = fibonacci.block("else"); + + f_entry + .terminate(TerminatorKind::CondBr(cond, then_b.value(), else_b.value())) + .unwrap(); + + let ret_const = then_b.build(Constant(I32(1))).unwrap(); + then_b.terminate(TerminatorKind::Ret(ret_const)).unwrap(); + + let const_1 = else_b.build(Constant(I32(1))).unwrap(); + let const_2 = else_b.build(Constant(I32(2))).unwrap(); + let param_1 = else_b.build(Sub(param_n, const_1)).unwrap(); + let param_2 = else_b.build(Sub(param_n, const_2)).unwrap(); + let call_1 = else_b + .build(FunctionCall(fibonacci.value(), vec![param_1])) + .unwrap(); + let call_2 = else_b + .build(FunctionCall(fibonacci.value(), vec![param_2])) + .unwrap(); + + let add = else_b.build(Add(call_1, call_2)).unwrap(); + + else_b.terminate(TerminatorKind::Ret(add)).unwrap(); + + dbg!(&context); + + context.compile(); +} diff --git a/reid-llvm-lib/src/lib.rs b/reid-llvm-lib/src/lib.rs index ec9b390..25d307f 100644 --- a/reid-llvm-lib/src/lib.rs +++ b/reid-llvm-lib/src/lib.rs @@ -16,6 +16,7 @@ use llvm_sys::{LLVMBuilder, LLVMContext, LLVMIntPredicate, core::*, prelude::*}; use types::{BasicType, BasicValue, FunctionType, IntegerType, Value}; use util::{ErrorMessageHolder, from_cstring, into_cstring}; +pub mod test; pub mod types; mod util; diff --git a/reid-llvm-lib/src/test/builder.rs b/reid-llvm-lib/src/test/builder.rs new file mode 100644 index 0000000..f28cb63 --- /dev/null +++ b/reid-llvm-lib/src/test/builder.rs @@ -0,0 +1,334 @@ +use std::{cell::RefCell, marker::PhantomData, rc::Rc}; + +use crate::test::{ConstValue, InstructionKind, TerminatorKind, Type}; + +use super::{BlockData, FunctionData, InstructionData, ModuleData, util::match_types}; + +#[derive(Debug, Clone, Hash, Copy, PartialEq, Eq)] +pub struct ModuleValue(usize); + +#[derive(Debug, Clone, Hash, Copy, PartialEq, Eq)] +pub struct FunctionValue(ModuleValue, usize); + +#[derive(Debug, Clone, Hash, Copy, PartialEq, Eq)] +pub struct BlockValue(FunctionValue, usize); + +#[derive(Debug, Clone, Hash, Copy, PartialEq, Eq)] +pub struct InstructionValue(BlockValue, usize); + +#[derive(Debug, Clone)] +pub struct ModuleHolder { + pub(crate) value: ModuleValue, + pub(crate) data: ModuleData, + pub(crate) functions: Vec, +} + +#[derive(Debug, Clone)] +pub struct FunctionHolder { + pub(crate) value: FunctionValue, + pub(crate) data: FunctionData, + pub(crate) blocks: Vec, +} + +#[derive(Debug, Clone)] +pub struct BlockHolder { + pub(crate) value: BlockValue, + pub(crate) data: BlockData, + pub(crate) instructions: Vec, +} + +#[derive(Debug, Clone)] +pub struct InstructionHolder { + pub(crate) value: InstructionValue, + pub(crate) data: InstructionData, +} + +#[derive(Debug, Clone)] +pub struct Builder { + modules: Rc>>, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + modules: Rc::new(RefCell::new(Vec::new())), + } + } + + pub(crate) fn add_module(&self, data: ModuleData) -> ModuleValue { + let value = ModuleValue(self.modules.borrow().len()); + self.modules.borrow_mut().push(ModuleHolder { + value, + data, + functions: Vec::new(), + }); + value + } + + pub(crate) unsafe fn add_function( + &self, + mod_val: &ModuleValue, + data: FunctionData, + ) -> FunctionValue { + unsafe { + let mut modules = self.modules.borrow_mut(); + let module = modules.get_unchecked_mut(mod_val.0); + let value = FunctionValue(module.value, module.functions.len()); + module.functions.push(FunctionHolder { + value, + data, + blocks: Vec::new(), + }); + value + } + } + + pub(crate) unsafe fn add_block(&self, fun_val: &FunctionValue, data: BlockData) -> BlockValue { + unsafe { + let mut modules = self.modules.borrow_mut(); + let module = modules.get_unchecked_mut(fun_val.0.0); + let function = module.functions.get_unchecked_mut(fun_val.1); + let value = BlockValue(function.value, function.blocks.len()); + function.blocks.push(BlockHolder { + value, + data, + instructions: Vec::new(), + }); + value + } + } + + pub(crate) unsafe fn add_instruction( + &self, + block_val: &BlockValue, + data: InstructionData, + ) -> Result { + unsafe { + let mut modules = self.modules.borrow_mut(); + let module = modules.get_unchecked_mut(block_val.0.0.0); + let function = module.functions.get_unchecked_mut(block_val.0.1); + let block = function.blocks.get_unchecked_mut(block_val.1); + let value = InstructionValue(block.value, block.instructions.len()); + block.instructions.push(InstructionHolder { value, data }); + + // Drop modules so that it is no longer mutable borrowed + // (check_instruction requires an immutable borrow). + drop(modules); + + self.check_instruction(&value)?; + Ok(value) + } + } + + pub(crate) unsafe fn terminate( + &self, + block: &BlockValue, + value: TerminatorKind, + ) -> Result<(), ()> { + unsafe { + let mut modules = self.modules.borrow_mut(); + let module = modules.get_unchecked_mut(block.0.0.0); + let function = module.functions.get_unchecked_mut(block.0.1); + let block = function.blocks.get_unchecked_mut(block.1); + if let Some(_) = &block.data.terminator { + Err(()) + } else { + block.data.terminator = Some(value); + Ok(()) + } + } + } + + pub(crate) unsafe fn module_data(&self, value: &ModuleValue) -> ModuleData { + unsafe { self.modules.borrow().get_unchecked(value.0).data.clone() } + } + + pub(crate) unsafe fn function_data(&self, value: &FunctionValue) -> FunctionData { + unsafe { + self.modules + .borrow() + .get_unchecked(value.0.0) + .functions + .get_unchecked(value.1) + .data + .clone() + } + } + + pub(crate) unsafe fn block_data(&self, value: &BlockValue) -> BlockData { + unsafe { + self.modules + .borrow() + .get_unchecked(value.0.0.0) + .functions + .get_unchecked(value.0.1) + .blocks + .get_unchecked(value.1) + .data + .clone() + } + } + + pub(crate) unsafe fn instr_data(&self, value: &InstructionValue) -> InstructionData { + unsafe { + self.modules + .borrow() + .get_unchecked(value.0.0.0.0) + .functions + .get_unchecked(value.0.0.1) + .blocks + .get_unchecked(value.0.1) + .instructions + .get_unchecked(value.1) + .data + .clone() + } + } + + pub(crate) fn get_modules(&self) -> Rc>> { + self.modules.clone() + } + + // pub(crate) fn get_functions(&self, module: ModuleValue) -> Vec<(FunctionValue, FunctionData)> { + // unsafe { + // self.modules + // .borrow() + // .get_unchecked(module.0) + // .2 + // .iter() + // .map(|h| (h.0, h.1.clone())) + // .collect() + // } + // } + + // pub(crate) fn get_blocks(&self, function: FunctionValue) -> Vec<(BlockValue, BlockData)> { + // unsafe { + // self.modules + // .borrow() + // .get_unchecked(function.0.0) + // .2 + // .get_unchecked(function.1) + // .2 + // .iter() + // .map(|h| (h.0, h.1.clone())) + // .collect() + // } + // } + + // pub(crate) fn get_instructions( + // &self, + // block: BlockValue, + // ) -> ( + // Vec<(InstructionValue, InstructionData)>, + // Option, + // ) { + // unsafe { + // let modules = self.modules.borrow(); + // let block = modules + // .get_unchecked(block.0.0.0) + // .2 + // .get_unchecked(block.0.1) + // .2 + // .get_unchecked(block.1); + // ( + // block.2.iter().map(|h| (h.0, h.1.clone())).collect(), + // block.1.terminator.clone(), + // ) + // } + // } + + pub fn check_instruction(&self, instruction: &InstructionValue) -> Result<(), ()> { + use super::InstructionKind::*; + unsafe { + match self.instr_data(&instruction).kind { + Param(_) => Ok(()), + Constant(_) => Ok(()), + Add(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), + Sub(lhs, rhs) => match_types(&lhs, &rhs, &self).map(|_| ()), + ICmp(_, lhs, rhs) => { + let t = match_types(&lhs, &rhs, self)?; + if t.comparable() { + Ok(()) + } else { + Err(()) // TODO error: Types not comparable + } + } + FunctionCall(fun, params) => { + let param_types = self.function_data(&fun).params; + if param_types.len() != params.len() { + return Err(()); // TODO error: invalid amount of params + } + for (a, b) in param_types.iter().zip(params) { + if *a != b.get_type(&self)? { + return Err(()); // TODO error: params do not match + } + } + Ok(()) + } + } + } + } +} + +impl InstructionValue { + pub fn get_type(&self, builder: &Builder) -> Result { + use InstructionKind::*; + use Type::*; + unsafe { + match &builder.instr_data(self).kind { + Param(nth) => builder + .function_data(&self.0.0) + .params + .get(*nth) + .copied() + .ok_or(()), + Constant(c) => Ok(c.get_type()), + Add(lhs, rhs) => match_types(lhs, rhs, &builder), + Sub(lhs, rhs) => match_types(lhs, rhs, &builder), + ICmp(pred, lhs, rhs) => Ok(Type::Bool), + FunctionCall(function_value, _) => Ok(builder.function_data(function_value).ret), + } + } + } +} + +impl ConstValue { + pub fn get_type(&self) -> Type { + use Type::*; + match self { + ConstValue::I32(_) => I32, + ConstValue::U32(_) => U32, + } + } +} + +impl Type { + pub fn comparable(&self) -> bool { + match self { + Type::I32 => true, + Type::U32 => true, + Type::Bool => true, + Type::Void => false, + } + } + + pub fn signed(&self) -> bool { + match self { + Type::I32 => true, + Type::U32 => false, + Type::Bool => false, + Type::Void => false, + } + } +} + +impl TerminatorKind { + pub fn get_type(&self, builder: &Builder) -> Result { + use TerminatorKind::*; + match self { + Ret(instr_val) => instr_val.get_type(builder), + Branch(_) => Ok(Type::Void), + CondBr(_, _, _) => Ok(Type::Void), + } + } +} diff --git a/reid-llvm-lib/src/test/compile.rs b/reid-llvm-lib/src/test/compile.rs new file mode 100644 index 0000000..28f77ed --- /dev/null +++ b/reid-llvm-lib/src/test/compile.rs @@ -0,0 +1,364 @@ +use std::{collections::HashMap, ffi::CString, hash::Hash, process::Termination, ptr::null_mut}; + +use llvm_sys::{ + LLVMIntPredicate, + analysis::LLVMVerifyModule, + core::*, + prelude::*, + target::{ + LLVM_InitializeAllAsmParsers, LLVM_InitializeAllAsmPrinters, LLVM_InitializeAllTargetInfos, + LLVM_InitializeAllTargetMCs, LLVM_InitializeAllTargets, LLVMSetModuleDataLayout, + }, + target_machine::{ + LLVMCodeGenFileType, LLVMCreateTargetDataLayout, LLVMCreateTargetMachine, + LLVMGetDefaultTargetTriple, LLVMGetTargetFromTriple, LLVMTargetMachineEmitToFile, + }, +}; + +use crate::util::{ErrorMessageHolder, from_cstring, into_cstring}; + +use super::{ + ConstValue, Context, Function, IntPredicate, Module, TerminatorKind, Type, + builder::{ + BlockHolder, BlockValue, Builder, FunctionHolder, FunctionValue, InstructionHolder, + InstructionValue, ModuleHolder, + }, +}; + +pub struct LLVMContext { + context_ref: LLVMContextRef, + builder_ref: LLVMBuilderRef, +} + +impl Context { + pub fn compile(&self) { + unsafe { + let context_ref = LLVMContextCreate(); + + let context = LLVMContext { + context_ref, + builder_ref: LLVMCreateBuilderInContext(context_ref), + }; + + for holder in self.builder.get_modules().borrow().iter() { + holder.compile(&context, &self.builder); + } + + LLVMDisposeBuilder(context.builder_ref); + LLVMContextDispose(context.context_ref); + } + } +} + +pub struct LLVMModule<'a> { + builder: &'a Builder, + context_ref: LLVMContextRef, + builder_ref: LLVMBuilderRef, + module_ref: LLVMModuleRef, + functions: HashMap, + blocks: HashMap, + values: HashMap, +} + +#[derive(Clone, Copy)] +pub struct LLVMFunction { + type_ref: LLVMTypeRef, + value_ref: LLVMValueRef, +} + +pub struct LLVMValue { + ty: Type, + value_ref: LLVMValueRef, +} + +impl ModuleHolder { + fn compile(&self, context: &LLVMContext, builder: &Builder) { + unsafe { + let module_ref = LLVMModuleCreateWithNameInContext( + into_cstring(&self.data.name).as_ptr(), + context.context_ref, + ); + + // Compile the contents + + let mut functions = HashMap::new(); + + for function in &self.functions { + functions.insert( + function.value, + function.compile_signature(context, module_ref), + ); + } + + let mut module = LLVMModule { + builder, + context_ref: context.context_ref, + builder_ref: context.builder_ref, + module_ref, + functions, + blocks: HashMap::new(), + values: HashMap::new(), + }; + + for function in &self.functions { + function.compile(&mut module); + } + + LLVM_InitializeAllTargets(); + LLVM_InitializeAllTargetInfos(); + LLVM_InitializeAllTargetMCs(); + LLVM_InitializeAllAsmParsers(); + LLVM_InitializeAllAsmPrinters(); + + let triple = LLVMGetDefaultTargetTriple(); + + let mut target: _ = null_mut(); + let mut err = ErrorMessageHolder::null(); + LLVMGetTargetFromTriple(triple, &mut target, err.borrow_mut()); + println!("{:?}, {:?}", from_cstring(triple), target); + err.into_result().unwrap(); + + let target_machine = LLVMCreateTargetMachine( + target, + triple, + c"generic".as_ptr(), + c"".as_ptr(), + llvm_sys::target_machine::LLVMCodeGenOptLevel::LLVMCodeGenLevelNone, + llvm_sys::target_machine::LLVMRelocMode::LLVMRelocDefault, + llvm_sys::target_machine::LLVMCodeModel::LLVMCodeModelDefault, + ); + + let data_layout = LLVMCreateTargetDataLayout(target_machine); + LLVMSetTarget(module_ref, triple); + LLVMSetModuleDataLayout(module_ref, data_layout); + + let mut err = ErrorMessageHolder::null(); + LLVMVerifyModule( + module_ref, + llvm_sys::analysis::LLVMVerifierFailureAction::LLVMPrintMessageAction, + err.borrow_mut(), + ); + err.into_result().unwrap(); + + let mut err = ErrorMessageHolder::null(); + LLVMTargetMachineEmitToFile( + target_machine, + module_ref, + CString::new("hello.asm").unwrap().into_raw(), + LLVMCodeGenFileType::LLVMAssemblyFile, + err.borrow_mut(), + ); + err.into_result().unwrap(); + + let mut err = ErrorMessageHolder::null(); + LLVMTargetMachineEmitToFile( + target_machine, + module_ref, + CString::new("hello.o").unwrap().into_raw(), + LLVMCodeGenFileType::LLVMObjectFile, + err.borrow_mut(), + ); + err.into_result().unwrap(); + + let module_str = from_cstring(LLVMPrintModuleToString(module_ref)); + println!("{}", module_str.unwrap()); + } + } +} + +impl FunctionHolder { + unsafe fn compile_signature( + &self, + context: &LLVMContext, + module_ref: LLVMModuleRef, + ) -> LLVMFunction { + unsafe { + let ret_type = self.data.ret.as_llvm(context.context_ref); + let mut param_types: Vec = self + .data + .params + .iter() + .map(|t| t.as_llvm(context.context_ref)) + .collect(); + let param_ptr = param_types.as_mut_ptr(); + let param_len = param_types.len(); + + let fn_type = LLVMFunctionType(ret_type, param_ptr, param_len as u32, 0); + + let function_ref = + LLVMAddFunction(module_ref, into_cstring(&self.data.name).as_ptr(), fn_type); + + LLVMFunction { + type_ref: fn_type, + value_ref: function_ref, + } + } + } + + unsafe fn compile(&self, module: &mut LLVMModule) { + unsafe { + let own_function = *module.functions.get(&self.value).unwrap(); + + for block in &self.blocks { + let block_ref = LLVMCreateBasicBlockInContext( + module.context_ref, + into_cstring(&self.data.name).as_ptr(), + ); + LLVMAppendExistingBasicBlock(own_function.value_ref, block_ref); + module.blocks.insert(block.value, block_ref); + } + + for block in &self.blocks { + block.compile(module, &own_function); + } + } + } +} + +impl BlockHolder { + unsafe fn compile(&self, module: &mut LLVMModule, function: &LLVMFunction) { + unsafe { + let block_ref = *module.blocks.get(&self.value).unwrap(); + LLVMPositionBuilderAtEnd(module.builder_ref, block_ref); + + for instruction in &self.instructions { + let key = instruction.value; + let ret = instruction.compile(module, function, block_ref); + module.values.insert(key, ret); + } + + self.data + .terminator + .clone() + .expect(&format!( + "Block {} does not have a terminator!", + self.data.name + )) + .compile(module, function, block_ref); + } + } +} + +impl InstructionHolder { + unsafe fn compile( + &self, + module: &LLVMModule, + function: &LLVMFunction, + block: LLVMBasicBlockRef, + ) -> LLVMValue { + let ty = self.value.get_type(module.builder).unwrap(); + let val = unsafe { + use super::InstructionKind::*; + match &self.data.kind { + Param(nth) => LLVMGetParam(function.value_ref, *nth as u32), + Constant(val) => val.as_llvm(module.context_ref), + Add(lhs, rhs) => { + let lhs_val = module.values.get(&lhs).unwrap().value_ref; + let rhs_val = module.values.get(&rhs).unwrap().value_ref; + LLVMBuildAdd(module.builder_ref, lhs_val, rhs_val, c"add".as_ptr()) + } + Sub(lhs, rhs) => { + let lhs_val = module.values.get(&lhs).unwrap().value_ref; + let rhs_val = module.values.get(&rhs).unwrap().value_ref; + LLVMBuildSub(module.builder_ref, lhs_val, rhs_val, c"sub".as_ptr()) + } + ICmp(pred, lhs, rhs) => { + let lhs_val = module.values.get(&lhs).unwrap().value_ref; + let rhs_val = module.values.get(&rhs).unwrap().value_ref; + LLVMBuildICmp( + module.builder_ref, + pred.as_llvm(ty.signed()), + lhs_val, + rhs_val, + c"icmp".as_ptr(), + ) + } + FunctionCall(function_value, instruction_values) => { + let fun = module.functions.get(&function_value).unwrap(); + let mut param_list: Vec = instruction_values + .iter() + .map(|i| module.values.get(i).unwrap().value_ref) + .collect(); + + LLVMBuildCall2( + module.builder_ref, + fun.type_ref, + fun.value_ref, + param_list.as_mut_ptr(), + param_list.len() as u32, + c"call".as_ptr(), + ) + } + } + }; + LLVMValue { ty, value_ref: val } + } +} + +impl TerminatorKind { + fn compile( + &self, + module: &LLVMModule, + function: &LLVMFunction, + block: LLVMBasicBlockRef, + ) -> LLVMValue { + let ty = self.get_type(module.builder).unwrap(); + let val = unsafe { + match self { + TerminatorKind::Ret(val) => { + let value = module.values.get(val).unwrap(); + LLVMBuildRet(module.builder_ref, value.value_ref) + } + TerminatorKind::Branch(block_value) => { + let dest = *module.blocks.get(block_value).unwrap(); + LLVMBuildBr(module.builder_ref, dest) + } + TerminatorKind::CondBr(cond, then_b, else_b) => { + let cond_val = module.values.get(cond).unwrap().value_ref; + let then_bb = *module.blocks.get(then_b).unwrap(); + let else_bb = *module.blocks.get(else_b).unwrap(); + LLVMBuildCondBr(module.builder_ref, cond_val, then_bb, else_bb) + } + } + }; + LLVMValue { ty, value_ref: val } + } +} + +impl IntPredicate { + fn as_llvm(&self, signed: bool) -> LLVMIntPredicate { + use IntPredicate::*; + use LLVMIntPredicate::*; + match (self, signed) { + (LessThan, true) => LLVMIntSLT, + (GreaterThan, true) => LLVMIntSGT, + (LessThan, false) => LLVMIntULT, + (GreaterThan, false) => LLVMIntUGT, + } + } +} + +impl ConstValue { + fn as_llvm(&self, context: LLVMContextRef) -> LLVMValueRef { + unsafe { + let t = self.get_type().as_llvm(context); + match *self { + ConstValue::I32(val) => LLVMConstInt(t, val as u64, 1), + ConstValue::U32(val) => LLVMConstInt(t, val as u64, 1), + } + } + } +} + +impl Type { + fn as_llvm(&self, context: LLVMContextRef) -> LLVMTypeRef { + unsafe { + match self { + Type::I32 => LLVMInt32TypeInContext(context), + Type::U32 => LLVMInt32TypeInContext(context), + Type::Bool => LLVMInt1TypeInContext(context), + Type::Void => LLVMVoidType(), + } + } + } +} diff --git a/reid-llvm-lib/src/test/mod.rs b/reid-llvm-lib/src/test/mod.rs new file mode 100644 index 0000000..429d1ca --- /dev/null +++ b/reid-llvm-lib/src/test/mod.rs @@ -0,0 +1,230 @@ +use std::marker::PhantomData; + +use builder::{BlockValue, Builder, FunctionValue, InstructionValue, ModuleValue}; + +mod builder; +mod compile; +mod util; + +// pub struct InstructionValue(BlockValue, usize); + +#[derive(Debug)] +pub struct Context { + builder: Builder, +} + +impl Context { + pub fn new() -> Context { + Context { + builder: Builder::new(), + } + } + + pub fn module<'ctx>(&'ctx self, name: &str) -> Module<'ctx> { + let value = self.builder.add_module(ModuleData { + name: name.to_owned(), + }); + Module { + phantom: PhantomData, + builder: self.builder.clone(), + value, + } + } +} + +#[derive(Debug, Clone, Hash)] +pub struct ModuleData { + name: String, +} + +pub struct Module<'ctx> { + phantom: PhantomData<&'ctx ()>, + builder: Builder, + value: ModuleValue, +} + +impl<'ctx> Module<'ctx> { + pub fn function(&mut self, name: &str, ret: Type, params: Vec) -> Function<'ctx> { + unsafe { + Function { + phantom: PhantomData, + builder: self.builder.clone(), + value: self.builder.add_function( + &self.value, + FunctionData { + name: name.to_owned(), + ret, + params, + }, + ), + } + } + } + + pub fn value(&self) -> ModuleValue { + self.value + } +} + +#[derive(Debug, Clone, Hash)] +pub struct FunctionData { + name: String, + ret: Type, + params: Vec, +} + +pub struct Function<'ctx> { + phantom: PhantomData<&'ctx ()>, + builder: Builder, + value: FunctionValue, +} + +impl<'ctx> Function<'ctx> { + pub fn block(&mut self, name: &str) -> Block<'ctx> { + unsafe { + Block { + phantom: PhantomData, + builder: self.builder.clone(), + value: self.builder.add_block( + &self.value, + BlockData { + name: name.to_owned(), + terminator: None, + }, + ), + } + } + } + + pub fn value(&self) -> FunctionValue { + self.value + } +} + +#[derive(Debug, Clone, Hash)] +pub struct BlockData { + name: String, + terminator: Option, +} + +pub struct Block<'builder> { + phantom: PhantomData<&'builder ()>, + builder: Builder, + value: BlockValue, +} + +impl<'builder> Block<'builder> { + pub fn build(&mut self, instruction: InstructionKind) -> Result { + unsafe { + self.builder + .add_instruction(&self.value, InstructionData { kind: instruction }) + } + } + + pub fn terminate(&mut self, instruction: TerminatorKind) -> Result<(), ()> { + unsafe { self.builder.terminate(&self.value, instruction) } + } + + pub fn value(&self) -> BlockValue { + self.value + } +} + +#[derive(Debug, Clone, Hash)] +pub struct InstructionData { + kind: InstructionKind, +} + +#[derive(Debug, Clone, Copy, Hash)] +pub enum IntPredicate { + LessThan, + GreaterThan, +} + +#[derive(Debug, Clone, Hash)] +pub enum InstructionKind { + Param(usize), + Constant(ConstValue), + Add(InstructionValue, InstructionValue), + Sub(InstructionValue, InstructionValue), + + /// Integer Comparison + ICmp(IntPredicate, InstructionValue, InstructionValue), + + FunctionCall(FunctionValue, Vec), +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +pub enum Type { + I32, + U32, + Bool, + Void, +} + +#[derive(Debug, Clone, Hash)] +pub enum ConstValue { + I32(i32), + U32(u32), +} + +#[derive(Debug, Clone, Hash)] +pub enum TerminatorKind { + Ret(InstructionValue), + Branch(BlockValue), + CondBr(InstructionValue, BlockValue, BlockValue), +} + +fn test() { + use ConstValue::*; + use InstructionKind::*; + + let context = Context::new(); + + let mut module = context.module("test"); + + let mut main = module.function("main", Type::I32, Vec::new()); + let mut m_entry = main.block("entry"); + + let mut fibonacci = module.function("fibonacci", Type::I32, vec![Type::I32]); + + let arg = m_entry.build(Constant(I32(5))).unwrap(); + m_entry + .build(FunctionCall(fibonacci.value, vec![arg])) + .unwrap(); + + let mut f_entry = fibonacci.block("entry"); + + let num_3 = f_entry.build(Constant(I32(3))).unwrap(); + let param_n = f_entry.build(Param(0)).unwrap(); + let cond = f_entry + .build(ICmp(IntPredicate::LessThan, param_n, num_3)) + .unwrap(); + + let mut then_b = fibonacci.block("then"); + let mut else_b = fibonacci.block("else"); + + f_entry + .terminate(TerminatorKind::CondBr(cond, then_b.value, else_b.value)) + .unwrap(); + + let ret_const = then_b.build(Constant(I32(1))).unwrap(); + then_b.terminate(TerminatorKind::Ret(ret_const)).unwrap(); + + let const_1 = else_b.build(Constant(I32(1))).unwrap(); + let const_2 = else_b.build(Constant(I32(2))).unwrap(); + let param_1 = else_b.build(Sub(param_n, const_1)).unwrap(); + let param_2 = else_b.build(Sub(param_n, const_2)).unwrap(); + let call_1 = else_b + .build(FunctionCall(fibonacci.value, vec![param_1])) + .unwrap(); + let call_2 = else_b + .build(FunctionCall(fibonacci.value, vec![param_2])) + .unwrap(); + + let add = else_b.build(Add(call_1, call_2)).unwrap(); + + else_b.terminate(TerminatorKind::Ret(add)).unwrap(); + + dbg!(context); +} diff --git a/reid-llvm-lib/src/test/util.rs b/reid-llvm-lib/src/test/util.rs new file mode 100644 index 0000000..22670d3 --- /dev/null +++ b/reid-llvm-lib/src/test/util.rs @@ -0,0 +1,18 @@ +use super::{ + Type, + builder::{Builder, InstructionValue}, +}; + +pub fn match_types( + lhs: &InstructionValue, + rhs: &InstructionValue, + builder: &Builder, +) -> Result { + let lhs_type = lhs.get_type(&builder); + let rhs_type = rhs.get_type(&builder); + if let (Ok(lhs_t), Ok(rhs_t)) = (lhs_type, rhs_type) { + if lhs_t == rhs_t { Ok(lhs_t) } else { Err(()) } + } else { + Err(()) + } +}