From 5b23d7d4d5a96836485e9c30bc28558fa2177468 Mon Sep 17 00:00:00 2001 From: sofia Date: Sat, 28 Jun 2025 21:58:35 +0300 Subject: [PATCH] rework --- reid-llvm-lib/examples/libtest.rs | 16 ++-- reid-llvm-lib/src/lib.rs | 154 ++++++++++++++++-------------- reid-llvm-lib/src/types.rs | 112 ++++++++++++++++------ 3 files changed, 174 insertions(+), 108 deletions(-) diff --git a/reid-llvm-lib/examples/libtest.rs b/reid-llvm-lib/examples/libtest.rs index f6ce7e5..f9bd5bf 100644 --- a/reid-llvm-lib/examples/libtest.rs +++ b/reid-llvm-lib/examples/libtest.rs @@ -1,4 +1,4 @@ -use reid_lib::{Comparison, Context, types::BasicType}; +use reid_lib::{Context, IntPredicate, types::BasicType}; pub fn main() { // Notes from inkwell: @@ -11,29 +11,29 @@ pub fn main() { let module = context.module("testmodule"); - let int_32 = context.integer_type::<32, true>(); + let int_32 = context.type_i32(); let secondary = module.add_function(int_32.function_type(&[]), "secondary"); let s_entry = secondary.block("entry"); - s_entry.ret(&int_32.from_const(54)).unwrap(); + s_entry.ret(&int_32.from_signed(54)).unwrap(); let function = module.add_function(int_32.function_type(&[]), "main"); let entry = function.block("entry"); - let v1 = int_32.from_const(100); + let v1 = int_32.from_signed(100); let v2 = entry.call(&secondary, vec![], "call").unwrap(); let lhs_cmp = entry.add(&v1, &v2, "add").unwrap(); - let rhs_cmp = int_32.from_const(200); + let rhs_cmp = int_32.from_signed(200); let cond_res = entry - .integer_compare(&lhs_cmp, &rhs_cmp, &Comparison::LessThan, "cmp") + .integer_compare(&lhs_cmp, &rhs_cmp, &IntPredicate::SLT, "cmp") .unwrap(); let (lhs, rhs) = entry.conditional_br(&cond_res, "lhs", "rhs").unwrap(); - lhs.ret(&int_32.from_const(123)).unwrap(); - rhs.ret(&int_32.from_const(456)).unwrap(); + lhs.ret(&int_32.from_signed(123)).unwrap(); + rhs.ret(&int_32.from_signed(456)).unwrap(); match module.print_to_string() { Ok(v) => println!("{}", v), diff --git a/reid-llvm-lib/src/lib.rs b/reid-llvm-lib/src/lib.rs index 4a8040a..553f0e7 100644 --- a/reid-llvm-lib/src/lib.rs +++ b/reid-llvm-lib/src/lib.rs @@ -10,17 +10,25 @@ use llvm_sys::target_machine::{ LLVMCodeGenFileType, LLVMCreateTargetDataLayout, LLVMCreateTargetMachine, LLVMGetDefaultTargetTriple, LLVMGetTargetFromTriple, LLVMTargetMachineEmitToFile, }; -use llvm_sys::{LLVMBuilder, LLVMContext, core::*, prelude::*}; -use types::{BasicType, FunctionType, IntegerType}; +use llvm_sys::{LLVMBuilder, LLVMContext, LLVMIntPredicate, core::*, prelude::*}; +use types::{BasicType, BasicValue, FunctionType, IntegerType, Value}; use util::{ErrorMessageHolder, from_cstring, into_cstring}; -pub use types::OpaqueValue; - pub mod types; mod util; -pub enum Comparison { - LessThan, +pub enum IntPredicate { + ULT, + SLT, +} + +impl IntPredicate { + pub fn as_llvm(&self) -> LLVMIntPredicate { + match *self { + Self::ULT => LLVMIntPredicate::LLVMIntULT, + Self::SLT => LLVMIntPredicate::LLVMIntSLT, + } + } } pub struct Context { @@ -42,10 +50,20 @@ impl Context { } } - pub fn integer_type<'a, const WIDTH: u32, const SIGN: bool>( - &'a self, - ) -> IntegerType<'a, WIDTH, SIGN> { - IntegerType::in_context(&self) + pub fn type_i1<'a>(&'a self) -> IntegerType<'a> { + IntegerType::in_context(&self, 1) + } + + pub fn type_i8<'a>(&'a self) -> IntegerType<'a> { + IntegerType::in_context(&self, 8) + } + + pub fn type_i16<'a>(&'a self) -> IntegerType<'a> { + IntegerType::in_context(&self, 16) + } + + pub fn type_i32<'a>(&'a self) -> IntegerType<'a> { + IntegerType::in_context(&self, 32) } pub fn module>(&self, name: T) -> Module { @@ -83,11 +101,11 @@ impl<'ctx> Module<'ctx> { } } - pub fn add_function>( + pub fn add_function( &self, - fn_type: FunctionType<'ctx, ReturnType>, - name: T, - ) -> Function<'_, ReturnType> { + fn_type: FunctionType<'ctx, ReturnValue::BaseType>, + name: &str, + ) -> Function<'_, ReturnValue> { unsafe { let name_cstring = into_cstring(name); let function_ref = @@ -172,32 +190,32 @@ impl<'a> Drop for Module<'a> { } } -pub struct Function<'ctx, ReturnType: BasicType> { +pub struct Function<'ctx, ReturnValue: BasicValue> { module: &'ctx Module<'ctx>, name: CString, - fn_type: FunctionType<'ctx, ReturnType>, + fn_type: FunctionType<'ctx, ReturnValue::BaseType>, fn_ref: LLVMValueRef, } -impl<'ctx, ReturnType: BasicType> Function<'ctx, ReturnType> { - pub fn block>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnType> { +impl<'ctx, ReturnValue: BasicValue> Function<'ctx, ReturnValue> { + pub fn block>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnValue> { BasicBlock::in_function(&self, name.into()) } } -pub struct BasicBlock<'ctx, ReturnType: BasicType> { - function: &'ctx Function<'ctx, ReturnType>, +pub struct BasicBlock<'ctx, ReturnValue: BasicValue> { + function: &'ctx Function<'ctx, ReturnValue>, builder_ref: LLVMBuilderRef, name: CString, blockref: LLVMBasicBlockRef, inserted: bool, } -impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { +impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { fn in_function( - function: &'ctx Function, + function: &'ctx Function, name: String, - ) -> BasicBlock<'ctx, ReturnType> { + ) -> BasicBlock<'ctx, ReturnValue> { unsafe { let block_name = into_cstring(name); let block_ref = LLVMCreateBasicBlockInContext( @@ -215,49 +233,44 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { } #[must_use] - pub fn integer_compare>( + pub fn integer_compare( &self, - lhs: &'ctx OpaqueValue<'ctx>, - rhs: &'ctx OpaqueValue<'ctx>, - comparison: &Comparison, - name: T, - ) -> Result, ()> { - if lhs.basic_type != rhs.basic_type { - return Err(()); // TODO invalid amount of parameters - } + lhs: &'ctx T, + rhs: &'ctx T, + comparison: &IntPredicate, + name: &str, + ) -> Result { unsafe { LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); - let value = match comparison { - Comparison::LessThan => LLVMBuildICmp( - self.builder_ref, - llvm_sys::LLVMIntPredicate::LLVMIntSLT, - lhs.value_ref, - rhs.value_ref, - into_cstring(name.into()).as_ptr(), - ), - }; + let value = LLVMBuildICmp( + self.builder_ref, + comparison.as_llvm(), + lhs.llvm_value(), + rhs.llvm_value(), + into_cstring(name).as_ptr(), + ); - Ok(OpaqueValue::new(lhs.basic_type, value)) + Ok(T::from_llvm(value)) } } #[must_use] - pub fn call>( + pub fn call( &self, - callee: &'ctx Function<'ctx, ReturnType>, - params: Vec<&'ctx OpaqueValue<'ctx>>, - name: T, - ) -> Result, ()> { + callee: &'ctx Function<'ctx, ReturnValue>, + params: Vec>, + name: &str, + ) -> Result { if params.len() != callee.fn_type.param_types.len() { return Err(()); // TODO invalid amount of parameters } for (t1, t2) in callee.fn_type.param_types.iter().zip(¶ms) { - if t1 != &t2.basic_type.llvm_type() { + if t1 != &t2.llvm_type() { return Err(()); // TODO wrong types in parameters } } unsafe { - let mut param_list: Vec = params.iter().map(|p| p.value_ref).collect(); + let mut param_list: Vec = params.iter().map(|p| p.llvm_value()).collect(); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); let ret_val = LLVMBuildCall2( self.builder_ref, @@ -265,36 +278,31 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { callee.fn_ref, param_list.as_mut_ptr(), param_list.len() as u32, - into_cstring(name.into()).as_ptr(), + into_cstring(name).as_ptr(), ); - Ok(OpaqueValue::new(callee.fn_type.return_type, ret_val)) + Ok(ReturnValue::from_llvm(ret_val)) } } #[must_use] - pub fn add>( - &self, - lhs: &OpaqueValue<'ctx>, - rhs: &OpaqueValue<'ctx>, - name: T, - ) -> Result, ()> { - if lhs.basic_type != rhs.basic_type { + pub fn add(&self, lhs: &T, rhs: &T, name: &str) -> Result { + if lhs.llvm_type() != rhs.llvm_type() { return Err(()); // TODO error } unsafe { LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); let add_value_ref = LLVMBuildAdd( self.builder_ref, - lhs.value_ref, - rhs.value_ref, - into_cstring(name.into()).as_ptr(), + lhs.llvm_value(), + rhs.llvm_value(), + into_cstring(name).as_ptr(), ); - Ok(OpaqueValue::new(lhs.basic_type, add_value_ref)) + Ok(T::from_llvm(add_value_ref)) } } #[must_use] - pub fn br(self, into: BasicBlock<'ctx, ReturnType>) -> Result<(), ()> { + pub fn br(self, into: BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> { unsafe { LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMBuildBr(self.builder_ref, into.blockref); @@ -304,12 +312,12 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { } #[must_use] - pub fn conditional_br, U: Into>( + pub fn conditional_br( self, - condition: &OpaqueValue<'ctx>, - lhs_name: T, - rhs_name: U, - ) -> Result<(BasicBlock<'ctx, ReturnType>, BasicBlock<'ctx, ReturnType>), ()> { + condition: &T, + lhs_name: &str, + rhs_name: &str, + ) -> Result<(BasicBlock<'ctx, ReturnValue>, BasicBlock<'ctx, ReturnValue>), ()> { unsafe { let lhs = BasicBlock::in_function(&self.function, lhs_name.into()); let rhs = BasicBlock::in_function(&self.function, rhs_name.into()); @@ -317,7 +325,7 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMBuildCondBr( self.builder_ref, - condition.value_ref, + condition.llvm_value(), lhs.blockref, rhs.blockref, ); @@ -327,13 +335,13 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { } #[must_use] - pub fn ret(self, return_value: &OpaqueValue<'ctx>) -> Result<(), ()> { - if self.function.fn_type.return_type().llvm_type() != return_value.basic_type.llvm_type() { + pub fn ret(self, return_value: &ReturnValue) -> Result<(), ()> { + if self.function.fn_type.return_type().llvm_type() != return_value.llvm_type() { return Err(()); } unsafe { LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); - LLVMBuildRet(self.builder_ref, return_value.value_ref); + LLVMBuildRet(self.builder_ref, return_value.llvm_value()); self.terminate(); Ok(()) } @@ -347,7 +355,7 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { } } -impl<'ctx, ReturnType: BasicType> Drop for BasicBlock<'ctx, ReturnType> { +impl<'ctx, ReturnValue: BasicValue> Drop for BasicBlock<'ctx, ReturnValue> { fn drop(&mut self) { if !self.inserted { unsafe { diff --git a/reid-llvm-lib/src/types.rs b/reid-llvm-lib/src/types.rs index 64b5547..251137c 100644 --- a/reid-llvm-lib/src/types.rs +++ b/reid-llvm-lib/src/types.rs @@ -1,4 +1,7 @@ +use std::{any::Any, marker::PhantomData}; + use llvm_sys::{ + LLVMTypeKind, core::*, prelude::{LLVMTypeRef, LLVMValueRef}, }; @@ -48,46 +51,51 @@ impl PartialEq for &dyn BasicType { } } -pub struct IntegerType<'ctx, const WIDTH: u32, const SIGNED: bool> { +pub struct IntegerType<'ctx> { context: &'ctx Context, type_ref: LLVMTypeRef, } -impl<'ctx, const WIDTH: u32, const SIGNED: bool> BasicType for IntegerType<'ctx, WIDTH, SIGNED> { +impl<'ctx> BasicType for IntegerType<'ctx> { fn llvm_type(&self) -> LLVMTypeRef { self.type_ref } } -impl<'ctx, const WIDTH: u32, const SIGNED: bool> IntegerType<'ctx, WIDTH, SIGNED> { - pub(crate) fn in_context(context: &Context) -> IntegerType { +impl<'ctx> IntegerType<'ctx> { + pub(crate) fn in_context(context: &Context, width: u32) -> IntegerType { let type_ref = unsafe { - match WIDTH { + match width { 128 => LLVMInt128TypeInContext(context.context_ref), 64 => LLVMInt64TypeInContext(context.context_ref), 32 => LLVMInt32TypeInContext(context.context_ref), 16 => LLVMInt16TypeInContext(context.context_ref), 8 => LLVMInt8TypeInContext(context.context_ref), 1 => LLVMInt1TypeInContext(context.context_ref), - _ => LLVMIntTypeInContext(context.context_ref, WIDTH), + _ => LLVMIntTypeInContext(context.context_ref, width), } }; IntegerType { context, type_ref } } - pub fn from_const(&self, value: u64) -> OpaqueValue { - unsafe { - OpaqueValue { - basic_type: self, - value_ref: LLVMConstInt(self.type_ref, value, Self::sign_to_i32()), - } - } + pub fn from_signed(&self, value: i64) -> IntegerValue<'_> { + self.from_const(value as u64, true) } - const fn sign_to_i32() -> i32 { - match SIGNED { - true => 1, - false => 0, + pub fn from_unsigned(&self, value: i64) -> IntegerValue<'_> { + self.from_const(value as u64, false) + } + + fn from_const(&self, value: u64, sign: bool) -> IntegerValue<'_> { + unsafe { + IntegerValue::from_llvm(LLVMConstInt( + self.type_ref, + value, + match sign { + true => 1, + false => 0, + }, + )) } } } @@ -122,19 +130,69 @@ impl<'ctx, T: BasicType> BasicType for ArrayType<'ctx, T> { } } -pub struct OpaqueValue<'ctx> { - pub(crate) basic_type: &'ctx dyn BasicType, +pub trait BasicValue { + type BaseType: BasicType; + unsafe fn from_llvm(value: LLVMValueRef) -> Self + where + Self: Sized; + fn llvm_value(&self) -> LLVMValueRef; + fn llvm_type(&self) -> LLVMTypeRef; +} + +pub struct IntegerValue<'ctx> { + phantom: PhantomData<&'ctx ()>, pub(crate) value_ref: LLVMValueRef, } -impl<'ctx> OpaqueValue<'ctx> { - pub(crate) fn new( - basic_type: &'ctx dyn BasicType, - value_ref: LLVMValueRef, - ) -> OpaqueValue<'ctx> { - OpaqueValue { - basic_type, - value_ref, +impl<'ctx> BasicValue for IntegerValue<'ctx> { + type BaseType = IntegerType<'ctx>; + + unsafe fn from_llvm(value: LLVMValueRef) -> Self { + IntegerValue { + phantom: PhantomData, + value_ref: value, + } + } + + fn llvm_value(&self) -> LLVMValueRef { + self.value_ref + } + + fn llvm_type(&self) -> LLVMTypeRef { + unsafe { LLVMTypeOf(self.value_ref) } + } +} + +pub enum Value<'ctx> { + Integer(IntegerValue<'ctx>), +} + +impl<'ctx> Value<'ctx> { + unsafe fn from_llvm(value: LLVMValueRef) -> Self + where + Self: Sized, + { + unsafe { + use LLVMTypeKind::*; + + let llvm_type = LLVMTypeOf(value); + let type_kind = LLVMGetTypeKind(llvm_type); + match type_kind { + LLVMIntegerTypeKind => Value::Integer(IntegerValue::from_llvm(value)), + _ => panic!("asd"), + } + } + } + + pub fn llvm_value(&self) -> LLVMValueRef { + match self { + Self::Integer(i) => i.llvm_value(), + } + } + + pub fn llvm_type(&self) -> LLVMTypeRef { + match self { + Self::Integer(i) => i.llvm_type(), } } }