From 814b816450890ccace21e5113749271c255a44dc Mon Sep 17 00:00:00 2001 From: sofia Date: Sun, 29 Jun 2025 01:18:17 +0300 Subject: [PATCH] Add phi --- libtest.sh | 2 +- reid-llvm-lib/examples/libtest.rs | 32 ++++++++--- reid-llvm-lib/src/lib.rs | 89 +++++++++++++++++++++++++------ reid-llvm-lib/src/types.rs | 87 +++++++++++++++++++++++------- 4 files changed, 167 insertions(+), 43 deletions(-) diff --git a/libtest.sh b/libtest.sh index 9150b17..eb7d0e9 100755 --- a/libtest.sh +++ b/libtest.sh @@ -7,7 +7,7 @@ export .env cargo run --example libtest && \ -# clang++ main.cpp hello.o -o main && \ +# clang hello.o -o main && \ ld -dynamic-linker /lib64/ld-linux-x86-64.so.2 \ -o main /usr/lib/crt1.o hello.o -lc && \ ./main ; echo "Return value: ""$?" diff --git a/reid-llvm-lib/examples/libtest.rs b/reid-llvm-lib/examples/libtest.rs index f9bd5bf..7cfc9eb 100644 --- a/reid-llvm-lib/examples/libtest.rs +++ b/reid-llvm-lib/examples/libtest.rs @@ -1,4 +1,7 @@ -use reid_lib::{Context, IntPredicate, types::BasicType}; +use reid_lib::{ + Context, IntPredicate, + types::{BasicType, IntegerType, IntegerValue}, +}; pub fn main() { // Notes from inkwell: @@ -21,19 +24,34 @@ pub fn main() { let entry = function.block("entry"); - let v1 = int_32.from_signed(100); - let v2 = entry.call(&secondary, vec![], "call").unwrap(); - let lhs_cmp = entry.add(&v1, &v2, "add").unwrap(); + let call = entry.call(&secondary, vec![], "call").unwrap(); + let add = entry.add(&int_32.from_signed(100), &call, "add").unwrap(); let rhs_cmp = int_32.from_signed(200); let cond_res = entry - .integer_compare(&lhs_cmp, &rhs_cmp, &IntPredicate::SLT, "cmp") + .integer_compare(&add, &rhs_cmp, &IntPredicate::SLT, "cmp") .unwrap(); let (lhs, rhs) = entry.conditional_br(&cond_res, "lhs", "rhs").unwrap(); - lhs.ret(&int_32.from_signed(123)).unwrap(); - rhs.ret(&int_32.from_signed(456)).unwrap(); + let left = lhs.add(&call, &int_32.from_signed(20), "add").unwrap(); + let right = rhs.add(&call, &int_32.from_signed(30), "add").unwrap(); + + let final_block = function.block("final"); + let phi = final_block + .phi::(&int_32, "phi") + .unwrap() + .add_incoming(&left, &lhs) + .add_incoming(&right, &rhs) + .build(); + + lhs.br(&final_block).unwrap(); + rhs.br(&final_block).unwrap(); + + let val = final_block + .add(&phi, &int_32.from_signed(11), "add") + .unwrap(); + final_block.ret(&val).unwrap(); match module.print_to_string() { Ok(v) => println!("{}", v), diff --git a/reid-llvm-lib/src/lib.rs b/reid-llvm-lib/src/lib.rs index 553f0e7..4d307b1 100644 --- a/reid-llvm-lib/src/lib.rs +++ b/reid-llvm-lib/src/lib.rs @@ -1,4 +1,6 @@ use std::ffi::CString; +use std::marker::PhantomData; +use std::net::Incoming; use std::ptr::null_mut; use llvm_sys::analysis::LLVMVerifyModule; @@ -101,11 +103,11 @@ impl<'ctx> Module<'ctx> { } } - pub fn add_function( - &self, + pub fn add_function>( + &'ctx self, fn_type: FunctionType<'ctx, ReturnValue::BaseType>, name: &str, - ) -> Function<'_, ReturnValue> { + ) -> Function<'ctx, ReturnValue> { unsafe { let name_cstring = into_cstring(name); let function_ref = @@ -140,7 +142,7 @@ impl<'ctx> Module<'ctx> { triple, c"generic".as_ptr(), c"".as_ptr(), - llvm_sys::target_machine::LLVMCodeGenOptLevel::LLVMCodeGenLevelNone, + llvm_sys::target_machine::LLVMCodeGenOptLevel::LLVMCodeGenLevelDefault, llvm_sys::target_machine::LLVMRelocMode::LLVMRelocDefault, llvm_sys::target_machine::LLVMCodeModel::LLVMCodeModelDefault, ); @@ -190,20 +192,20 @@ impl<'a> Drop for Module<'a> { } } -pub struct Function<'ctx, ReturnValue: BasicValue> { +pub struct Function<'ctx, ReturnValue: BasicValue<'ctx>> { module: &'ctx Module<'ctx>, name: CString, fn_type: FunctionType<'ctx, ReturnValue::BaseType>, fn_ref: LLVMValueRef, } -impl<'ctx, ReturnValue: BasicValue> Function<'ctx, ReturnValue> { +impl<'ctx, ReturnValue: BasicValue<'ctx>> Function<'ctx, ReturnValue> { pub fn block>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnValue> { BasicBlock::in_function(&self, name.into()) } } -pub struct BasicBlock<'ctx, ReturnValue: BasicValue> { +pub struct BasicBlock<'ctx, ReturnValue: BasicValue<'ctx>> { function: &'ctx Function<'ctx, ReturnValue>, builder_ref: LLVMBuilderRef, name: CString, @@ -211,9 +213,9 @@ pub struct BasicBlock<'ctx, ReturnValue: BasicValue> { inserted: bool, } -impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { +impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { fn in_function( - function: &'ctx Function, + function: &'ctx Function<'ctx, ReturnValue>, name: String, ) -> BasicBlock<'ctx, ReturnValue> { unsafe { @@ -233,7 +235,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { } #[must_use] - pub fn integer_compare( + pub fn integer_compare>( &self, lhs: &'ctx T, rhs: &'ctx T, @@ -285,7 +287,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { } #[must_use] - pub fn add(&self, lhs: &T, rhs: &T, name: &str) -> Result { + pub fn add>(&self, lhs: &T, rhs: &T, name: &str) -> Result { if lhs.llvm_type() != rhs.llvm_type() { return Err(()); // TODO error } @@ -302,7 +304,24 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { } #[must_use] - pub fn br(self, into: BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> { + pub fn phi>( + &self, + phi_type: &'ctx PhiValue::BaseType, + name: &str, + ) -> Result, ()> { + unsafe { + LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); + let phi_node = LLVMBuildPhi( + self.builder_ref, + phi_type.llvm_type(), + into_cstring(name).as_ptr(), + ); + Ok(PhiBuilder::new(phi_node)) + } + } + + #[must_use] + pub fn br(self, into: &BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> { unsafe { LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMBuildBr(self.builder_ref, into.blockref); @@ -312,7 +331,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { } #[must_use] - pub fn conditional_br( + pub fn conditional_br>( self, condition: &T, lhs_name: &str, @@ -336,7 +355,13 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { #[must_use] pub fn ret(self, return_value: &ReturnValue) -> Result<(), ()> { - if self.function.fn_type.return_type().llvm_type() != return_value.llvm_type() { + if self + .function + .fn_type + .return_type(self.function.module.context) + .llvm_type() + != return_value.llvm_type() + { return Err(()); } unsafe { @@ -355,7 +380,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { } } -impl<'ctx, ReturnValue: BasicValue> Drop for BasicBlock<'ctx, ReturnValue> { +impl<'ctx, ReturnValue: BasicValue<'ctx>> Drop for BasicBlock<'ctx, ReturnValue> { fn drop(&mut self) { if !self.inserted { unsafe { @@ -364,3 +389,37 @@ impl<'ctx, ReturnValue: BasicValue> Drop for BasicBlock<'ctx, ReturnValue> { } } } + +pub struct PhiBuilder<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>> { + phi_node: LLVMValueRef, + phantom: PhantomData<&'ctx (PhiValue, ReturnValue)>, +} + +impl<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>> + PhiBuilder<'ctx, ReturnValue, PhiValue> +{ + fn new(phi_node: LLVMValueRef) -> PhiBuilder<'ctx, ReturnValue, PhiValue> { + PhiBuilder { + phi_node, + phantom: PhantomData, + } + } + + pub fn add_incoming(&self, value: &PhiValue, block: &BasicBlock<'ctx, ReturnValue>) -> &Self { + let mut values = vec![value.llvm_value()]; + let mut blocks = vec![block.blockref]; + unsafe { + LLVMAddIncoming( + self.phi_node, + values.as_mut_ptr(), + blocks.as_mut_ptr(), + values.len() as u32, + ); + self + } + } + + pub fn build(&self) -> PhiValue { + unsafe { PhiValue::from_llvm(self.phi_node) } + } +} diff --git a/reid-llvm-lib/src/types.rs b/reid-llvm-lib/src/types.rs index 251137c..acb0073 100644 --- a/reid-llvm-lib/src/types.rs +++ b/reid-llvm-lib/src/types.rs @@ -1,4 +1,4 @@ -use std::{any::Any, marker::PhantomData}; +use std::{any::Any, marker::PhantomData, ptr::null_mut}; use llvm_sys::{ LLVMTypeKind, @@ -8,10 +8,13 @@ use llvm_sys::{ use crate::Context; -pub trait BasicType { +pub trait BasicType<'ctx> { fn llvm_type(&self) -> LLVMTypeRef; + fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self + where + Self: Sized; - fn function_type<'a>(&'a self, params: &'a [&'a dyn BasicType]) -> FunctionType<'a, Self> + fn function_type(&'ctx self, params: &'ctx [&'ctx dyn BasicType]) -> FunctionType<'ctx, Self> where Self: Sized, { @@ -20,14 +23,14 @@ pub trait BasicType { let param_ptr = typerefs.as_mut_ptr(); let param_len = typerefs.len(); FunctionType { - return_type: self, + phantom: PhantomData, param_types: typerefs, type_ref: LLVMFunctionType(self.llvm_type(), param_ptr, param_len as u32, 0), } } } - fn array_type(&self, length: u32) -> ArrayType + fn array_type(&'ctx self, length: u32) -> ArrayType<'ctx, Self> where Self: Sized, { @@ -39,13 +42,13 @@ pub trait BasicType { } } -impl PartialEq for &dyn BasicType { +impl<'ctx> PartialEq for &dyn BasicType<'ctx> { fn eq(&self, other: &Self) -> bool { self.llvm_type() == other.llvm_type() } } -impl PartialEq for &dyn BasicType { +impl<'ctx> PartialEq for &dyn BasicType<'ctx> { fn eq(&self, other: &LLVMTypeRef) -> bool { self.llvm_type() == *other } @@ -56,10 +59,20 @@ pub struct IntegerType<'ctx> { type_ref: LLVMTypeRef, } -impl<'ctx> BasicType for IntegerType<'ctx> { +impl<'ctx> BasicType<'ctx> for IntegerType<'ctx> { fn llvm_type(&self) -> LLVMTypeRef { self.type_ref } + + fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self + where + Self: Sized, + { + IntegerType { + context, + type_ref: llvm_type, + } + } } impl<'ctx> IntegerType<'ctx> { @@ -100,38 +113,72 @@ impl<'ctx> IntegerType<'ctx> { } } -pub struct FunctionType<'ctx, ReturnType: BasicType> { - pub(crate) return_type: &'ctx ReturnType, +pub struct FunctionType<'ctx, ReturnType: BasicType<'ctx>> { + phantom: PhantomData<&'ctx ReturnType>, pub(crate) param_types: Vec, type_ref: LLVMTypeRef, } -impl<'ctx, ReturnType: BasicType> BasicType for FunctionType<'ctx, ReturnType> { +impl<'ctx, ReturnType: BasicType<'ctx>> BasicType<'ctx> for FunctionType<'ctx, ReturnType> { fn llvm_type(&self) -> LLVMTypeRef { self.type_ref } -} -impl<'ctx, ReturnType: BasicType> FunctionType<'ctx, ReturnType> { - pub fn return_type(&self) -> &ReturnType { - self.return_type + fn from_llvm(_context: &'ctx Context, fn_type: LLVMTypeRef) -> Self + where + Self: Sized, + { + unsafe { + let param_count = LLVMCountParamTypes(fn_type); + let param_types_ptr: *mut LLVMTypeRef = null_mut(); + LLVMGetParamTypes(fn_type, param_types_ptr); + let param_types: Vec = + std::slice::from_raw_parts(param_types_ptr, param_count as usize) + .iter() + .map(|t| *t) + .collect(); + FunctionType { + phantom: PhantomData, + param_types, + type_ref: fn_type, + } + } } } -pub struct ArrayType<'ctx, T: BasicType> { +impl<'ctx, ReturnType: BasicType<'ctx>> FunctionType<'ctx, ReturnType> { + pub fn return_type(&self, context: &'ctx Context) -> ReturnType { + unsafe { + let return_type = LLVMGetReturnType(self.type_ref); + ReturnType::from_llvm(context, return_type) + } + } +} + +pub struct ArrayType<'ctx, T: BasicType<'ctx>> { element_type: &'ctx T, length: u32, type_ref: LLVMTypeRef, } -impl<'ctx, T: BasicType> BasicType for ArrayType<'ctx, T> { +impl<'ctx, T: BasicType<'ctx>> BasicType<'ctx> for ArrayType<'ctx, T> { fn llvm_type(&self) -> LLVMTypeRef { self.type_ref } + + fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self + where + Self: Sized, + { + unsafe { + let length = LLVMGetArrayLength(llvm_type); + todo!() + } + } } -pub trait BasicValue { - type BaseType: BasicType; +pub trait BasicValue<'ctx> { + type BaseType: BasicType<'ctx>; unsafe fn from_llvm(value: LLVMValueRef) -> Self where Self: Sized; @@ -144,7 +191,7 @@ pub struct IntegerValue<'ctx> { pub(crate) value_ref: LLVMValueRef, } -impl<'ctx> BasicValue for IntegerValue<'ctx> { +impl<'ctx> BasicValue<'ctx> for IntegerValue<'ctx> { type BaseType = IntegerType<'ctx>; unsafe fn from_llvm(value: LLVMValueRef) -> Self {