This commit is contained in:
Sofia 2025-06-28 21:58:35 +03:00
parent 740aee1382
commit 5b23d7d4d5
3 changed files with 174 additions and 108 deletions

View File

@ -1,4 +1,4 @@
use reid_lib::{Comparison, Context, types::BasicType}; use reid_lib::{Context, IntPredicate, types::BasicType};
pub fn main() { pub fn main() {
// Notes from inkwell: // Notes from inkwell:
@ -11,29 +11,29 @@ pub fn main() {
let module = context.module("testmodule"); let module = context.module("testmodule");
let int_32 = context.integer_type::<32, true>(); let int_32 = context.type_i32();
let secondary = module.add_function(int_32.function_type(&[]), "secondary"); let secondary = module.add_function(int_32.function_type(&[]), "secondary");
let s_entry = secondary.block("entry"); let s_entry = secondary.block("entry");
s_entry.ret(&int_32.from_const(54)).unwrap(); s_entry.ret(&int_32.from_signed(54)).unwrap();
let function = module.add_function(int_32.function_type(&[]), "main"); let function = module.add_function(int_32.function_type(&[]), "main");
let entry = function.block("entry"); let entry = function.block("entry");
let v1 = int_32.from_const(100); let v1 = int_32.from_signed(100);
let v2 = entry.call(&secondary, vec![], "call").unwrap(); let v2 = entry.call(&secondary, vec![], "call").unwrap();
let lhs_cmp = entry.add(&v1, &v2, "add").unwrap(); let lhs_cmp = entry.add(&v1, &v2, "add").unwrap();
let rhs_cmp = int_32.from_const(200); let rhs_cmp = int_32.from_signed(200);
let cond_res = entry let cond_res = entry
.integer_compare(&lhs_cmp, &rhs_cmp, &Comparison::LessThan, "cmp") .integer_compare(&lhs_cmp, &rhs_cmp, &IntPredicate::SLT, "cmp")
.unwrap(); .unwrap();
let (lhs, rhs) = entry.conditional_br(&cond_res, "lhs", "rhs").unwrap(); let (lhs, rhs) = entry.conditional_br(&cond_res, "lhs", "rhs").unwrap();
lhs.ret(&int_32.from_const(123)).unwrap(); lhs.ret(&int_32.from_signed(123)).unwrap();
rhs.ret(&int_32.from_const(456)).unwrap(); rhs.ret(&int_32.from_signed(456)).unwrap();
match module.print_to_string() { match module.print_to_string() {
Ok(v) => println!("{}", v), Ok(v) => println!("{}", v),

View File

@ -10,17 +10,25 @@ use llvm_sys::target_machine::{
LLVMCodeGenFileType, LLVMCreateTargetDataLayout, LLVMCreateTargetMachine, LLVMCodeGenFileType, LLVMCreateTargetDataLayout, LLVMCreateTargetMachine,
LLVMGetDefaultTargetTriple, LLVMGetTargetFromTriple, LLVMTargetMachineEmitToFile, LLVMGetDefaultTargetTriple, LLVMGetTargetFromTriple, LLVMTargetMachineEmitToFile,
}; };
use llvm_sys::{LLVMBuilder, LLVMContext, core::*, prelude::*}; use llvm_sys::{LLVMBuilder, LLVMContext, LLVMIntPredicate, core::*, prelude::*};
use types::{BasicType, FunctionType, IntegerType}; use types::{BasicType, BasicValue, FunctionType, IntegerType, Value};
use util::{ErrorMessageHolder, from_cstring, into_cstring}; use util::{ErrorMessageHolder, from_cstring, into_cstring};
pub use types::OpaqueValue;
pub mod types; pub mod types;
mod util; mod util;
pub enum Comparison { pub enum IntPredicate {
LessThan, ULT,
SLT,
}
impl IntPredicate {
pub fn as_llvm(&self) -> LLVMIntPredicate {
match *self {
Self::ULT => LLVMIntPredicate::LLVMIntULT,
Self::SLT => LLVMIntPredicate::LLVMIntSLT,
}
}
} }
pub struct Context { pub struct Context {
@ -42,10 +50,20 @@ impl Context {
} }
} }
pub fn integer_type<'a, const WIDTH: u32, const SIGN: bool>( pub fn type_i1<'a>(&'a self) -> IntegerType<'a> {
&'a self, IntegerType::in_context(&self, 1)
) -> IntegerType<'a, WIDTH, SIGN> { }
IntegerType::in_context(&self)
pub fn type_i8<'a>(&'a self) -> IntegerType<'a> {
IntegerType::in_context(&self, 8)
}
pub fn type_i16<'a>(&'a self) -> IntegerType<'a> {
IntegerType::in_context(&self, 16)
}
pub fn type_i32<'a>(&'a self) -> IntegerType<'a> {
IntegerType::in_context(&self, 32)
} }
pub fn module<T: Into<String>>(&self, name: T) -> Module { pub fn module<T: Into<String>>(&self, name: T) -> Module {
@ -83,11 +101,11 @@ impl<'ctx> Module<'ctx> {
} }
} }
pub fn add_function<ReturnType: BasicType, T: Into<String>>( pub fn add_function<ReturnValue: BasicValue>(
&self, &self,
fn_type: FunctionType<'ctx, ReturnType>, fn_type: FunctionType<'ctx, ReturnValue::BaseType>,
name: T, name: &str,
) -> Function<'_, ReturnType> { ) -> Function<'_, ReturnValue> {
unsafe { unsafe {
let name_cstring = into_cstring(name); let name_cstring = into_cstring(name);
let function_ref = let function_ref =
@ -172,32 +190,32 @@ impl<'a> Drop for Module<'a> {
} }
} }
pub struct Function<'ctx, ReturnType: BasicType> { pub struct Function<'ctx, ReturnValue: BasicValue> {
module: &'ctx Module<'ctx>, module: &'ctx Module<'ctx>,
name: CString, name: CString,
fn_type: FunctionType<'ctx, ReturnType>, fn_type: FunctionType<'ctx, ReturnValue::BaseType>,
fn_ref: LLVMValueRef, fn_ref: LLVMValueRef,
} }
impl<'ctx, ReturnType: BasicType> Function<'ctx, ReturnType> { impl<'ctx, ReturnValue: BasicValue> Function<'ctx, ReturnValue> {
pub fn block<T: Into<String>>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnType> { pub fn block<T: Into<String>>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnValue> {
BasicBlock::in_function(&self, name.into()) BasicBlock::in_function(&self, name.into())
} }
} }
pub struct BasicBlock<'ctx, ReturnType: BasicType> { pub struct BasicBlock<'ctx, ReturnValue: BasicValue> {
function: &'ctx Function<'ctx, ReturnType>, function: &'ctx Function<'ctx, ReturnValue>,
builder_ref: LLVMBuilderRef, builder_ref: LLVMBuilderRef,
name: CString, name: CString,
blockref: LLVMBasicBlockRef, blockref: LLVMBasicBlockRef,
inserted: bool, inserted: bool,
} }
impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> { impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> {
fn in_function( fn in_function(
function: &'ctx Function<ReturnType>, function: &'ctx Function<ReturnValue>,
name: String, name: String,
) -> BasicBlock<'ctx, ReturnType> { ) -> BasicBlock<'ctx, ReturnValue> {
unsafe { unsafe {
let block_name = into_cstring(name); let block_name = into_cstring(name);
let block_ref = LLVMCreateBasicBlockInContext( let block_ref = LLVMCreateBasicBlockInContext(
@ -215,49 +233,44 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> {
} }
#[must_use] #[must_use]
pub fn integer_compare<T: Into<String>>( pub fn integer_compare<T: BasicValue>(
&self, &self,
lhs: &'ctx OpaqueValue<'ctx>, lhs: &'ctx T,
rhs: &'ctx OpaqueValue<'ctx>, rhs: &'ctx T,
comparison: &Comparison, comparison: &IntPredicate,
name: T, name: &str,
) -> Result<OpaqueValue<'ctx>, ()> { ) -> Result<T, ()> {
if lhs.basic_type != rhs.basic_type {
return Err(()); // TODO invalid amount of parameters
}
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
let value = match comparison { let value = LLVMBuildICmp(
Comparison::LessThan => LLVMBuildICmp(
self.builder_ref, self.builder_ref,
llvm_sys::LLVMIntPredicate::LLVMIntSLT, comparison.as_llvm(),
lhs.value_ref, lhs.llvm_value(),
rhs.value_ref, rhs.llvm_value(),
into_cstring(name.into()).as_ptr(), into_cstring(name).as_ptr(),
), );
};
Ok(OpaqueValue::new(lhs.basic_type, value)) Ok(T::from_llvm(value))
} }
} }
#[must_use] #[must_use]
pub fn call<T: Into<String>>( pub fn call(
&self, &self,
callee: &'ctx Function<'ctx, ReturnType>, callee: &'ctx Function<'ctx, ReturnValue>,
params: Vec<&'ctx OpaqueValue<'ctx>>, params: Vec<Value<'ctx>>,
name: T, name: &str,
) -> Result<OpaqueValue<'ctx>, ()> { ) -> Result<ReturnValue, ()> {
if params.len() != callee.fn_type.param_types.len() { if params.len() != callee.fn_type.param_types.len() {
return Err(()); // TODO invalid amount of parameters return Err(()); // TODO invalid amount of parameters
} }
for (t1, t2) in callee.fn_type.param_types.iter().zip(&params) { for (t1, t2) in callee.fn_type.param_types.iter().zip(&params) {
if t1 != &t2.basic_type.llvm_type() { if t1 != &t2.llvm_type() {
return Err(()); // TODO wrong types in parameters return Err(()); // TODO wrong types in parameters
} }
} }
unsafe { unsafe {
let mut param_list: Vec<LLVMValueRef> = params.iter().map(|p| p.value_ref).collect(); let mut param_list: Vec<LLVMValueRef> = params.iter().map(|p| p.llvm_value()).collect();
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
let ret_val = LLVMBuildCall2( let ret_val = LLVMBuildCall2(
self.builder_ref, self.builder_ref,
@ -265,36 +278,31 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> {
callee.fn_ref, callee.fn_ref,
param_list.as_mut_ptr(), param_list.as_mut_ptr(),
param_list.len() as u32, param_list.len() as u32,
into_cstring(name.into()).as_ptr(), into_cstring(name).as_ptr(),
); );
Ok(OpaqueValue::new(callee.fn_type.return_type, ret_val)) Ok(ReturnValue::from_llvm(ret_val))
} }
} }
#[must_use] #[must_use]
pub fn add<T: Into<String>>( pub fn add<T: BasicValue>(&self, lhs: &T, rhs: &T, name: &str) -> Result<T, ()> {
&self, if lhs.llvm_type() != rhs.llvm_type() {
lhs: &OpaqueValue<'ctx>,
rhs: &OpaqueValue<'ctx>,
name: T,
) -> Result<OpaqueValue<'ctx>, ()> {
if lhs.basic_type != rhs.basic_type {
return Err(()); // TODO error return Err(()); // TODO error
} }
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
let add_value_ref = LLVMBuildAdd( let add_value_ref = LLVMBuildAdd(
self.builder_ref, self.builder_ref,
lhs.value_ref, lhs.llvm_value(),
rhs.value_ref, rhs.llvm_value(),
into_cstring(name.into()).as_ptr(), into_cstring(name).as_ptr(),
); );
Ok(OpaqueValue::new(lhs.basic_type, add_value_ref)) Ok(T::from_llvm(add_value_ref))
} }
} }
#[must_use] #[must_use]
pub fn br(self, into: BasicBlock<'ctx, ReturnType>) -> Result<(), ()> { pub fn br(self, into: BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> {
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildBr(self.builder_ref, into.blockref); LLVMBuildBr(self.builder_ref, into.blockref);
@ -304,12 +312,12 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> {
} }
#[must_use] #[must_use]
pub fn conditional_br<T: Into<String>, U: Into<String>>( pub fn conditional_br<T: BasicValue>(
self, self,
condition: &OpaqueValue<'ctx>, condition: &T,
lhs_name: T, lhs_name: &str,
rhs_name: U, rhs_name: &str,
) -> Result<(BasicBlock<'ctx, ReturnType>, BasicBlock<'ctx, ReturnType>), ()> { ) -> Result<(BasicBlock<'ctx, ReturnValue>, BasicBlock<'ctx, ReturnValue>), ()> {
unsafe { unsafe {
let lhs = BasicBlock::in_function(&self.function, lhs_name.into()); let lhs = BasicBlock::in_function(&self.function, lhs_name.into());
let rhs = BasicBlock::in_function(&self.function, rhs_name.into()); let rhs = BasicBlock::in_function(&self.function, rhs_name.into());
@ -317,7 +325,7 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildCondBr( LLVMBuildCondBr(
self.builder_ref, self.builder_ref,
condition.value_ref, condition.llvm_value(),
lhs.blockref, lhs.blockref,
rhs.blockref, rhs.blockref,
); );
@ -327,13 +335,13 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> {
} }
#[must_use] #[must_use]
pub fn ret(self, return_value: &OpaqueValue<'ctx>) -> Result<(), ()> { pub fn ret(self, return_value: &ReturnValue) -> Result<(), ()> {
if self.function.fn_type.return_type().llvm_type() != return_value.basic_type.llvm_type() { if self.function.fn_type.return_type().llvm_type() != return_value.llvm_type() {
return Err(()); return Err(());
} }
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildRet(self.builder_ref, return_value.value_ref); LLVMBuildRet(self.builder_ref, return_value.llvm_value());
self.terminate(); self.terminate();
Ok(()) Ok(())
} }
@ -347,7 +355,7 @@ impl<'ctx, ReturnType: BasicType> BasicBlock<'ctx, ReturnType> {
} }
} }
impl<'ctx, ReturnType: BasicType> Drop for BasicBlock<'ctx, ReturnType> { impl<'ctx, ReturnValue: BasicValue> Drop for BasicBlock<'ctx, ReturnValue> {
fn drop(&mut self) { fn drop(&mut self) {
if !self.inserted { if !self.inserted {
unsafe { unsafe {

View File

@ -1,4 +1,7 @@
use std::{any::Any, marker::PhantomData};
use llvm_sys::{ use llvm_sys::{
LLVMTypeKind,
core::*, core::*,
prelude::{LLVMTypeRef, LLVMValueRef}, prelude::{LLVMTypeRef, LLVMValueRef},
}; };
@ -48,46 +51,51 @@ impl PartialEq<LLVMTypeRef> for &dyn BasicType {
} }
} }
pub struct IntegerType<'ctx, const WIDTH: u32, const SIGNED: bool> { pub struct IntegerType<'ctx> {
context: &'ctx Context, context: &'ctx Context,
type_ref: LLVMTypeRef, type_ref: LLVMTypeRef,
} }
impl<'ctx, const WIDTH: u32, const SIGNED: bool> BasicType for IntegerType<'ctx, WIDTH, SIGNED> { impl<'ctx> BasicType for IntegerType<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef { fn llvm_type(&self) -> LLVMTypeRef {
self.type_ref self.type_ref
} }
} }
impl<'ctx, const WIDTH: u32, const SIGNED: bool> IntegerType<'ctx, WIDTH, SIGNED> { impl<'ctx> IntegerType<'ctx> {
pub(crate) fn in_context(context: &Context) -> IntegerType<WIDTH, SIGNED> { pub(crate) fn in_context(context: &Context, width: u32) -> IntegerType {
let type_ref = unsafe { let type_ref = unsafe {
match WIDTH { match width {
128 => LLVMInt128TypeInContext(context.context_ref), 128 => LLVMInt128TypeInContext(context.context_ref),
64 => LLVMInt64TypeInContext(context.context_ref), 64 => LLVMInt64TypeInContext(context.context_ref),
32 => LLVMInt32TypeInContext(context.context_ref), 32 => LLVMInt32TypeInContext(context.context_ref),
16 => LLVMInt16TypeInContext(context.context_ref), 16 => LLVMInt16TypeInContext(context.context_ref),
8 => LLVMInt8TypeInContext(context.context_ref), 8 => LLVMInt8TypeInContext(context.context_ref),
1 => LLVMInt1TypeInContext(context.context_ref), 1 => LLVMInt1TypeInContext(context.context_ref),
_ => LLVMIntTypeInContext(context.context_ref, WIDTH), _ => LLVMIntTypeInContext(context.context_ref, width),
} }
}; };
IntegerType { context, type_ref } IntegerType { context, type_ref }
} }
pub fn from_const(&self, value: u64) -> OpaqueValue { pub fn from_signed(&self, value: i64) -> IntegerValue<'_> {
unsafe { self.from_const(value as u64, true)
OpaqueValue {
basic_type: self,
value_ref: LLVMConstInt(self.type_ref, value, Self::sign_to_i32()),
}
}
} }
const fn sign_to_i32() -> i32 { pub fn from_unsigned(&self, value: i64) -> IntegerValue<'_> {
match SIGNED { self.from_const(value as u64, false)
}
fn from_const(&self, value: u64, sign: bool) -> IntegerValue<'_> {
unsafe {
IntegerValue::from_llvm(LLVMConstInt(
self.type_ref,
value,
match sign {
true => 1, true => 1,
false => 0, false => 0,
},
))
} }
} }
} }
@ -122,19 +130,69 @@ impl<'ctx, T: BasicType> BasicType for ArrayType<'ctx, T> {
} }
} }
pub struct OpaqueValue<'ctx> { pub trait BasicValue {
pub(crate) basic_type: &'ctx dyn BasicType, type BaseType: BasicType;
unsafe fn from_llvm(value: LLVMValueRef) -> Self
where
Self: Sized;
fn llvm_value(&self) -> LLVMValueRef;
fn llvm_type(&self) -> LLVMTypeRef;
}
pub struct IntegerValue<'ctx> {
phantom: PhantomData<&'ctx ()>,
pub(crate) value_ref: LLVMValueRef, pub(crate) value_ref: LLVMValueRef,
} }
impl<'ctx> OpaqueValue<'ctx> { impl<'ctx> BasicValue for IntegerValue<'ctx> {
pub(crate) fn new( type BaseType = IntegerType<'ctx>;
basic_type: &'ctx dyn BasicType,
value_ref: LLVMValueRef, unsafe fn from_llvm(value: LLVMValueRef) -> Self {
) -> OpaqueValue<'ctx> { IntegerValue {
OpaqueValue { phantom: PhantomData,
basic_type, value_ref: value,
value_ref, }
}
fn llvm_value(&self) -> LLVMValueRef {
self.value_ref
}
fn llvm_type(&self) -> LLVMTypeRef {
unsafe { LLVMTypeOf(self.value_ref) }
}
}
pub enum Value<'ctx> {
Integer(IntegerValue<'ctx>),
}
impl<'ctx> Value<'ctx> {
unsafe fn from_llvm(value: LLVMValueRef) -> Self
where
Self: Sized,
{
unsafe {
use LLVMTypeKind::*;
let llvm_type = LLVMTypeOf(value);
let type_kind = LLVMGetTypeKind(llvm_type);
match type_kind {
LLVMIntegerTypeKind => Value::Integer(IntegerValue::from_llvm(value)),
_ => panic!("asd"),
}
}
}
pub fn llvm_value(&self) -> LLVMValueRef {
match self {
Self::Integer(i) => i.llvm_value(),
}
}
pub fn llvm_type(&self) -> LLVMTypeRef {
match self {
Self::Integer(i) => i.llvm_type(),
} }
} }
} }