This commit is contained in:
Sofia 2025-06-29 01:18:17 +03:00
parent 5b23d7d4d5
commit 814b816450
4 changed files with 167 additions and 43 deletions

View File

@ -7,7 +7,7 @@
export .env export .env
cargo run --example libtest && \ cargo run --example libtest && \
# clang++ main.cpp hello.o -o main && \ # clang hello.o -o main && \
ld -dynamic-linker /lib64/ld-linux-x86-64.so.2 \ ld -dynamic-linker /lib64/ld-linux-x86-64.so.2 \
-o main /usr/lib/crt1.o hello.o -lc && \ -o main /usr/lib/crt1.o hello.o -lc && \
./main ; echo "Return value: ""$?" ./main ; echo "Return value: ""$?"

View File

@ -1,4 +1,7 @@
use reid_lib::{Context, IntPredicate, types::BasicType}; use reid_lib::{
Context, IntPredicate,
types::{BasicType, IntegerType, IntegerValue},
};
pub fn main() { pub fn main() {
// Notes from inkwell: // Notes from inkwell:
@ -21,19 +24,34 @@ pub fn main() {
let entry = function.block("entry"); let entry = function.block("entry");
let v1 = int_32.from_signed(100); let call = entry.call(&secondary, vec![], "call").unwrap();
let v2 = entry.call(&secondary, vec![], "call").unwrap(); let add = entry.add(&int_32.from_signed(100), &call, "add").unwrap();
let lhs_cmp = entry.add(&v1, &v2, "add").unwrap();
let rhs_cmp = int_32.from_signed(200); let rhs_cmp = int_32.from_signed(200);
let cond_res = entry let cond_res = entry
.integer_compare(&lhs_cmp, &rhs_cmp, &IntPredicate::SLT, "cmp") .integer_compare(&add, &rhs_cmp, &IntPredicate::SLT, "cmp")
.unwrap(); .unwrap();
let (lhs, rhs) = entry.conditional_br(&cond_res, "lhs", "rhs").unwrap(); let (lhs, rhs) = entry.conditional_br(&cond_res, "lhs", "rhs").unwrap();
lhs.ret(&int_32.from_signed(123)).unwrap(); let left = lhs.add(&call, &int_32.from_signed(20), "add").unwrap();
rhs.ret(&int_32.from_signed(456)).unwrap(); let right = rhs.add(&call, &int_32.from_signed(30), "add").unwrap();
let final_block = function.block("final");
let phi = final_block
.phi::<IntegerValue>(&int_32, "phi")
.unwrap()
.add_incoming(&left, &lhs)
.add_incoming(&right, &rhs)
.build();
lhs.br(&final_block).unwrap();
rhs.br(&final_block).unwrap();
let val = final_block
.add(&phi, &int_32.from_signed(11), "add")
.unwrap();
final_block.ret(&val).unwrap();
match module.print_to_string() { match module.print_to_string() {
Ok(v) => println!("{}", v), Ok(v) => println!("{}", v),

View File

@ -1,4 +1,6 @@
use std::ffi::CString; use std::ffi::CString;
use std::marker::PhantomData;
use std::net::Incoming;
use std::ptr::null_mut; use std::ptr::null_mut;
use llvm_sys::analysis::LLVMVerifyModule; use llvm_sys::analysis::LLVMVerifyModule;
@ -101,11 +103,11 @@ impl<'ctx> Module<'ctx> {
} }
} }
pub fn add_function<ReturnValue: BasicValue>( pub fn add_function<ReturnValue: BasicValue<'ctx>>(
&self, &'ctx self,
fn_type: FunctionType<'ctx, ReturnValue::BaseType>, fn_type: FunctionType<'ctx, ReturnValue::BaseType>,
name: &str, name: &str,
) -> Function<'_, ReturnValue> { ) -> Function<'ctx, ReturnValue> {
unsafe { unsafe {
let name_cstring = into_cstring(name); let name_cstring = into_cstring(name);
let function_ref = let function_ref =
@ -140,7 +142,7 @@ impl<'ctx> Module<'ctx> {
triple, triple,
c"generic".as_ptr(), c"generic".as_ptr(),
c"".as_ptr(), c"".as_ptr(),
llvm_sys::target_machine::LLVMCodeGenOptLevel::LLVMCodeGenLevelNone, llvm_sys::target_machine::LLVMCodeGenOptLevel::LLVMCodeGenLevelDefault,
llvm_sys::target_machine::LLVMRelocMode::LLVMRelocDefault, llvm_sys::target_machine::LLVMRelocMode::LLVMRelocDefault,
llvm_sys::target_machine::LLVMCodeModel::LLVMCodeModelDefault, llvm_sys::target_machine::LLVMCodeModel::LLVMCodeModelDefault,
); );
@ -190,20 +192,20 @@ impl<'a> Drop for Module<'a> {
} }
} }
pub struct Function<'ctx, ReturnValue: BasicValue> { pub struct Function<'ctx, ReturnValue: BasicValue<'ctx>> {
module: &'ctx Module<'ctx>, module: &'ctx Module<'ctx>,
name: CString, name: CString,
fn_type: FunctionType<'ctx, ReturnValue::BaseType>, fn_type: FunctionType<'ctx, ReturnValue::BaseType>,
fn_ref: LLVMValueRef, fn_ref: LLVMValueRef,
} }
impl<'ctx, ReturnValue: BasicValue> Function<'ctx, ReturnValue> { impl<'ctx, ReturnValue: BasicValue<'ctx>> Function<'ctx, ReturnValue> {
pub fn block<T: Into<String>>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnValue> { pub fn block<T: Into<String>>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnValue> {
BasicBlock::in_function(&self, name.into()) BasicBlock::in_function(&self, name.into())
} }
} }
pub struct BasicBlock<'ctx, ReturnValue: BasicValue> { pub struct BasicBlock<'ctx, ReturnValue: BasicValue<'ctx>> {
function: &'ctx Function<'ctx, ReturnValue>, function: &'ctx Function<'ctx, ReturnValue>,
builder_ref: LLVMBuilderRef, builder_ref: LLVMBuilderRef,
name: CString, name: CString,
@ -211,9 +213,9 @@ pub struct BasicBlock<'ctx, ReturnValue: BasicValue> {
inserted: bool, inserted: bool,
} }
impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> { impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
fn in_function( fn in_function(
function: &'ctx Function<ReturnValue>, function: &'ctx Function<'ctx, ReturnValue>,
name: String, name: String,
) -> BasicBlock<'ctx, ReturnValue> { ) -> BasicBlock<'ctx, ReturnValue> {
unsafe { unsafe {
@ -233,7 +235,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> {
} }
#[must_use] #[must_use]
pub fn integer_compare<T: BasicValue>( pub fn integer_compare<T: BasicValue<'ctx>>(
&self, &self,
lhs: &'ctx T, lhs: &'ctx T,
rhs: &'ctx T, rhs: &'ctx T,
@ -285,7 +287,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> {
} }
#[must_use] #[must_use]
pub fn add<T: BasicValue>(&self, lhs: &T, rhs: &T, name: &str) -> Result<T, ()> { pub fn add<T: BasicValue<'ctx>>(&self, lhs: &T, rhs: &T, name: &str) -> Result<T, ()> {
if lhs.llvm_type() != rhs.llvm_type() { if lhs.llvm_type() != rhs.llvm_type() {
return Err(()); // TODO error return Err(()); // TODO error
} }
@ -302,7 +304,24 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> {
} }
#[must_use] #[must_use]
pub fn br(self, into: BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> { pub fn phi<PhiValue: BasicValue<'ctx>>(
&self,
phi_type: &'ctx PhiValue::BaseType,
name: &str,
) -> Result<PhiBuilder<'ctx, ReturnValue, PhiValue>, ()> {
unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
let phi_node = LLVMBuildPhi(
self.builder_ref,
phi_type.llvm_type(),
into_cstring(name).as_ptr(),
);
Ok(PhiBuilder::new(phi_node))
}
}
#[must_use]
pub fn br(self, into: &BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> {
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildBr(self.builder_ref, into.blockref); LLVMBuildBr(self.builder_ref, into.blockref);
@ -312,7 +331,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> {
} }
#[must_use] #[must_use]
pub fn conditional_br<T: BasicValue>( pub fn conditional_br<T: BasicValue<'ctx>>(
self, self,
condition: &T, condition: &T,
lhs_name: &str, lhs_name: &str,
@ -336,7 +355,13 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> {
#[must_use] #[must_use]
pub fn ret(self, return_value: &ReturnValue) -> Result<(), ()> { pub fn ret(self, return_value: &ReturnValue) -> Result<(), ()> {
if self.function.fn_type.return_type().llvm_type() != return_value.llvm_type() { if self
.function
.fn_type
.return_type(self.function.module.context)
.llvm_type()
!= return_value.llvm_type()
{
return Err(()); return Err(());
} }
unsafe { unsafe {
@ -355,7 +380,7 @@ impl<'ctx, ReturnValue: BasicValue> BasicBlock<'ctx, ReturnValue> {
} }
} }
impl<'ctx, ReturnValue: BasicValue> Drop for BasicBlock<'ctx, ReturnValue> { impl<'ctx, ReturnValue: BasicValue<'ctx>> Drop for BasicBlock<'ctx, ReturnValue> {
fn drop(&mut self) { fn drop(&mut self) {
if !self.inserted { if !self.inserted {
unsafe { unsafe {
@ -364,3 +389,37 @@ impl<'ctx, ReturnValue: BasicValue> Drop for BasicBlock<'ctx, ReturnValue> {
} }
} }
} }
pub struct PhiBuilder<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>> {
phi_node: LLVMValueRef,
phantom: PhantomData<&'ctx (PhiValue, ReturnValue)>,
}
impl<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>>
PhiBuilder<'ctx, ReturnValue, PhiValue>
{
fn new(phi_node: LLVMValueRef) -> PhiBuilder<'ctx, ReturnValue, PhiValue> {
PhiBuilder {
phi_node,
phantom: PhantomData,
}
}
pub fn add_incoming(&self, value: &PhiValue, block: &BasicBlock<'ctx, ReturnValue>) -> &Self {
let mut values = vec![value.llvm_value()];
let mut blocks = vec![block.blockref];
unsafe {
LLVMAddIncoming(
self.phi_node,
values.as_mut_ptr(),
blocks.as_mut_ptr(),
values.len() as u32,
);
self
}
}
pub fn build(&self) -> PhiValue {
unsafe { PhiValue::from_llvm(self.phi_node) }
}
}

View File

@ -1,4 +1,4 @@
use std::{any::Any, marker::PhantomData}; use std::{any::Any, marker::PhantomData, ptr::null_mut};
use llvm_sys::{ use llvm_sys::{
LLVMTypeKind, LLVMTypeKind,
@ -8,10 +8,13 @@ use llvm_sys::{
use crate::Context; use crate::Context;
pub trait BasicType { pub trait BasicType<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef; fn llvm_type(&self) -> LLVMTypeRef;
fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
where
Self: Sized;
fn function_type<'a>(&'a self, params: &'a [&'a dyn BasicType]) -> FunctionType<'a, Self> fn function_type(&'ctx self, params: &'ctx [&'ctx dyn BasicType]) -> FunctionType<'ctx, Self>
where where
Self: Sized, Self: Sized,
{ {
@ -20,14 +23,14 @@ pub trait BasicType {
let param_ptr = typerefs.as_mut_ptr(); let param_ptr = typerefs.as_mut_ptr();
let param_len = typerefs.len(); let param_len = typerefs.len();
FunctionType { FunctionType {
return_type: self, phantom: PhantomData,
param_types: typerefs, param_types: typerefs,
type_ref: LLVMFunctionType(self.llvm_type(), param_ptr, param_len as u32, 0), type_ref: LLVMFunctionType(self.llvm_type(), param_ptr, param_len as u32, 0),
} }
} }
} }
fn array_type(&self, length: u32) -> ArrayType<Self> fn array_type(&'ctx self, length: u32) -> ArrayType<'ctx, Self>
where where
Self: Sized, Self: Sized,
{ {
@ -39,13 +42,13 @@ pub trait BasicType {
} }
} }
impl PartialEq for &dyn BasicType { impl<'ctx> PartialEq for &dyn BasicType<'ctx> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.llvm_type() == other.llvm_type() self.llvm_type() == other.llvm_type()
} }
} }
impl PartialEq<LLVMTypeRef> for &dyn BasicType { impl<'ctx> PartialEq<LLVMTypeRef> for &dyn BasicType<'ctx> {
fn eq(&self, other: &LLVMTypeRef) -> bool { fn eq(&self, other: &LLVMTypeRef) -> bool {
self.llvm_type() == *other self.llvm_type() == *other
} }
@ -56,10 +59,20 @@ pub struct IntegerType<'ctx> {
type_ref: LLVMTypeRef, type_ref: LLVMTypeRef,
} }
impl<'ctx> BasicType for IntegerType<'ctx> { impl<'ctx> BasicType<'ctx> for IntegerType<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef { fn llvm_type(&self) -> LLVMTypeRef {
self.type_ref self.type_ref
} }
fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
where
Self: Sized,
{
IntegerType {
context,
type_ref: llvm_type,
}
}
} }
impl<'ctx> IntegerType<'ctx> { impl<'ctx> IntegerType<'ctx> {
@ -100,38 +113,72 @@ impl<'ctx> IntegerType<'ctx> {
} }
} }
pub struct FunctionType<'ctx, ReturnType: BasicType> { pub struct FunctionType<'ctx, ReturnType: BasicType<'ctx>> {
pub(crate) return_type: &'ctx ReturnType, phantom: PhantomData<&'ctx ReturnType>,
pub(crate) param_types: Vec<LLVMTypeRef>, pub(crate) param_types: Vec<LLVMTypeRef>,
type_ref: LLVMTypeRef, type_ref: LLVMTypeRef,
} }
impl<'ctx, ReturnType: BasicType> BasicType for FunctionType<'ctx, ReturnType> { impl<'ctx, ReturnType: BasicType<'ctx>> BasicType<'ctx> for FunctionType<'ctx, ReturnType> {
fn llvm_type(&self) -> LLVMTypeRef { fn llvm_type(&self) -> LLVMTypeRef {
self.type_ref self.type_ref
} }
}
impl<'ctx, ReturnType: BasicType> FunctionType<'ctx, ReturnType> { fn from_llvm(_context: &'ctx Context, fn_type: LLVMTypeRef) -> Self
pub fn return_type(&self) -> &ReturnType { where
self.return_type Self: Sized,
{
unsafe {
let param_count = LLVMCountParamTypes(fn_type);
let param_types_ptr: *mut LLVMTypeRef = null_mut();
LLVMGetParamTypes(fn_type, param_types_ptr);
let param_types: Vec<LLVMTypeRef> =
std::slice::from_raw_parts(param_types_ptr, param_count as usize)
.iter()
.map(|t| *t)
.collect();
FunctionType {
phantom: PhantomData,
param_types,
type_ref: fn_type,
}
}
} }
} }
pub struct ArrayType<'ctx, T: BasicType> { impl<'ctx, ReturnType: BasicType<'ctx>> FunctionType<'ctx, ReturnType> {
pub fn return_type(&self, context: &'ctx Context) -> ReturnType {
unsafe {
let return_type = LLVMGetReturnType(self.type_ref);
ReturnType::from_llvm(context, return_type)
}
}
}
pub struct ArrayType<'ctx, T: BasicType<'ctx>> {
element_type: &'ctx T, element_type: &'ctx T,
length: u32, length: u32,
type_ref: LLVMTypeRef, type_ref: LLVMTypeRef,
} }
impl<'ctx, T: BasicType> BasicType for ArrayType<'ctx, T> { impl<'ctx, T: BasicType<'ctx>> BasicType<'ctx> for ArrayType<'ctx, T> {
fn llvm_type(&self) -> LLVMTypeRef { fn llvm_type(&self) -> LLVMTypeRef {
self.type_ref self.type_ref
} }
fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
where
Self: Sized,
{
unsafe {
let length = LLVMGetArrayLength(llvm_type);
todo!()
}
}
} }
pub trait BasicValue { pub trait BasicValue<'ctx> {
type BaseType: BasicType; type BaseType: BasicType<'ctx>;
unsafe fn from_llvm(value: LLVMValueRef) -> Self unsafe fn from_llvm(value: LLVMValueRef) -> Self
where where
Self: Sized; Self: Sized;
@ -144,7 +191,7 @@ pub struct IntegerValue<'ctx> {
pub(crate) value_ref: LLVMValueRef, pub(crate) value_ref: LLVMValueRef,
} }
impl<'ctx> BasicValue for IntegerValue<'ctx> { impl<'ctx> BasicValue<'ctx> for IntegerValue<'ctx> {
type BaseType = IntegerType<'ctx>; type BaseType = IntegerType<'ctx>;
unsafe fn from_llvm(value: LLVMValueRef) -> Self { unsafe fn from_llvm(value: LLVMValueRef) -> Self {