diff --git a/Cargo.lock b/Cargo.lock index 0de9c85..bcedbb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,6 +100,7 @@ name = "reid" version = "0.1.0" dependencies = [ "llvm-sys", + "reid-lib", "thiserror", ] diff --git a/libtest.sh b/libtest.sh index eb7d0e9..bf971aa 100755 --- a/libtest.sh +++ b/libtest.sh @@ -6,7 +6,7 @@ # Do note this file is extremely simply for my own personal convenience export .env -cargo run --example libtest && \ +cargo run --example testcodegen && \ # clang hello.o -o main && \ ld -dynamic-linker /lib64/ld-linux-x86-64.so.2 \ -o main /usr/lib/crt1.o hello.o -lc && \ diff --git a/reid-llvm-lib/examples/libtest.rs b/reid-llvm-lib/examples/libtest.rs index 43a3ecd..7fa3f81 100644 --- a/reid-llvm-lib/examples/libtest.rs +++ b/reid-llvm-lib/examples/libtest.rs @@ -1,6 +1,6 @@ use reid_lib::{ Context, IntPredicate, - types::{BasicType, IntegerType, IntegerValue, Value}, + types::{BasicType, IntegerValue, Value}, }; pub fn main() { @@ -16,15 +16,19 @@ pub fn main() { let int_32 = context.type_i32(); - let fibonacci = module.add_function(int_32.function_type(vec![&int_32]), "fibonacci"); - let f_main = fibonacci.block("main"); + let fibonacci = module.add_function(int_32.function_type(vec![int_32.into()]), "fibonacci"); + let mut f_main = fibonacci.block("main"); - let param = fibonacci.get_param::(0).unwrap(); - let cmp = f_main + let param = fibonacci + .get_param::(0, int_32.into()) + .unwrap(); + let mut cmp = f_main .integer_compare(¶m, &int_32.from_unsigned(3), &IntPredicate::ULT, "cmp") .unwrap(); - let (done, recurse) = f_main.conditional_br(&cmp, "done", "recurse").unwrap(); + let mut done = fibonacci.block("done"); + let mut recurse = fibonacci.block("recurse"); + f_main.conditional_br(&cmp, &done, &recurse).unwrap(); done.ret(&int_32.from_unsigned(1)).unwrap(); @@ -34,7 +38,7 @@ pub fn main() { let minus_two = recurse .sub(¶m, &int_32.from_unsigned(2), "minus_two") .unwrap(); - let one = recurse + let one: IntegerValue = recurse 
.call(&fibonacci, vec![Value::Integer(minus_one)], "call_one") .unwrap(); let two = recurse @@ -47,8 +51,8 @@ pub fn main() { let main_f = module.add_function(int_32.function_type(Vec::new()), "main"); - let main_b = main_f.block("main"); - let call = main_b + let mut main_b = main_f.block("main"); + let call: IntegerValue = main_b .call( &fibonacci, vec![Value::Integer(int_32.from_unsigned(8))], diff --git a/reid-llvm-lib/src/lib.rs b/reid-llvm-lib/src/lib.rs index 92a9688..ec9b390 100644 --- a/reid-llvm-lib/src/lib.rs +++ b/reid-llvm-lib/src/lib.rs @@ -20,15 +20,20 @@ pub mod types; mod util; pub enum IntPredicate { - ULT, SLT, + SGT, + + ULT, + UGT, } impl IntPredicate { pub fn as_llvm(&self) -> LLVMIntPredicate { match *self { - Self::ULT => LLVMIntPredicate::LLVMIntULT, Self::SLT => LLVMIntPredicate::LLVMIntSLT, + Self::SGT => LLVMIntPredicate::LLVMIntSGT, + Self::ULT => LLVMIntPredicate::LLVMIntULT, + Self::UGT => LLVMIntPredicate::LLVMIntUGT, } } } @@ -68,8 +73,8 @@ impl Context { IntegerType::in_context(&self, 32) } - pub fn module>(&self, name: T) -> Module { - Module::with_name(self, name.into()) + pub fn module(&self, name: &str) -> Module { + Module::with_name(self, name) } } @@ -90,7 +95,7 @@ pub struct Module<'ctx> { } impl<'ctx> Module<'ctx> { - fn with_name(context: &'ctx Context, name: String) -> Module<'ctx> { + fn with_name(context: &'ctx Context, name: &str) -> Module<'ctx> { unsafe { let cstring_name = into_cstring(name); let module_ref = @@ -103,11 +108,7 @@ impl<'ctx> Module<'ctx> { } } - pub fn add_function>( - &'ctx self, - fn_type: FunctionType<'ctx, ReturnValue::BaseType>, - name: &str, - ) -> Function<'ctx, ReturnValue> { + pub fn add_function(&'ctx self, fn_type: FunctionType<'ctx>, name: &str) -> Function<'ctx> { unsafe { let name_cstring = into_cstring(name); let function_ref = @@ -193,21 +194,26 @@ impl<'a> Drop for Module<'a> { } } -pub struct Function<'ctx, ReturnValue: BasicValue<'ctx>> { +#[derive(Clone)] +pub struct 
Function<'ctx> { module: &'ctx Module<'ctx>, name: CString, - fn_type: FunctionType<'ctx, ReturnValue::BaseType>, + fn_type: FunctionType<'ctx>, fn_ref: LLVMValueRef, } -impl<'ctx, ReturnValue: BasicValue<'ctx>> Function<'ctx, ReturnValue> { - pub fn block>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnValue> { +impl<'ctx> Function<'ctx> { + pub fn block>(&'ctx self, name: T) -> BasicBlock<'ctx> { BasicBlock::in_function(&self, name.into()) } - pub fn get_param>(&'ctx self, nth: usize) -> Result { - if let Some(param_type) = self.fn_type.param_types.iter().nth(nth) { - if self.fn_type.return_type(self.module.context).llvm_type() != *param_type { + pub fn get_param>( + &'ctx self, + nth: usize, + param_type: T::BaseType, + ) -> Result { + if let Some(actual_type) = self.fn_type.param_types.iter().nth(nth) { + if param_type.llvm_type() != *actual_type { return Err(String::from("Wrong type")); } } else { @@ -217,29 +223,27 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> Function<'ctx, ReturnValue> { } } -pub struct BasicBlock<'ctx, ReturnValue: BasicValue<'ctx>> { - function: &'ctx Function<'ctx, ReturnValue>, +pub struct BasicBlock<'ctx> { + function: &'ctx Function<'ctx>, builder_ref: LLVMBuilderRef, - name: CString, + name: String, blockref: LLVMBasicBlockRef, inserted: bool, } -impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { - fn in_function( - function: &'ctx Function<'ctx, ReturnValue>, - name: String, - ) -> BasicBlock<'ctx, ReturnValue> { +impl<'ctx> BasicBlock<'ctx> { + fn in_function(function: &'ctx Function<'ctx>, name: String) -> BasicBlock<'ctx> { unsafe { - let block_name = into_cstring(name); + let block_name = into_cstring(name.clone()); let block_ref = LLVMCreateBasicBlockInContext( function.module.context.context_ref, block_name.as_ptr(), ); + LLVMAppendExistingBasicBlock(function.fn_ref, block_ref); BasicBlock { function: function, builder_ref: function.module.context.builder_ref, - name: block_name, + name, blockref: 
block_ref, inserted: false, } @@ -269,12 +273,12 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { } #[must_use] - pub fn call( + pub fn call>( &self, - callee: &'ctx Function<'ctx, ReturnValue>, + callee: &Function<'ctx>, params: Vec>, name: &str, - ) -> Result { + ) -> Result { if params.len() != callee.fn_type.param_types.len() { return Err(()); // TODO invalid amount of parameters } @@ -283,6 +287,9 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { return Err(()); // TODO wrong types in parameters } } + if !T::BaseType::is_type(callee.fn_type.return_type) { + return Err(()); // TODO wrong return type + } unsafe { let mut param_list: Vec = params.iter().map(|p| p.llvm_value()).collect(); LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); @@ -294,7 +301,7 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { param_list.len() as u32, into_cstring(name).as_ptr(), ); - Ok(ReturnValue::from_llvm(ret_val)) + Ok(T::from_llvm(ret_val)) } } @@ -317,6 +324,8 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { #[must_use] pub fn sub>(&self, lhs: &T, rhs: &T, name: &str) -> Result { + dbg!(lhs, rhs); + dbg!(lhs.llvm_type(), rhs.llvm_type()); if lhs.llvm_type() != rhs.llvm_type() { return Err(()); // TODO error } @@ -335,9 +344,9 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { #[must_use] pub fn phi>( &self, - phi_type: &'ctx PhiValue::BaseType, + phi_type: &PhiValue::BaseType, name: &str, - ) -> Result, ()> { + ) -> Result, ()> { unsafe { LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); let phi_node = LLVMBuildPhi( @@ -350,26 +359,24 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { } #[must_use] - pub fn br(self, into: &BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> { + pub fn br(&mut self, into: &BasicBlock<'ctx>) -> Result<(), ()> { + self.try_insert()?; unsafe { 
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMBuildBr(self.builder_ref, into.blockref); - self.terminate(); Ok(()) } } #[must_use] pub fn conditional_br>( - self, + &mut self, condition: &T, - lhs_name: &str, - rhs_name: &str, - ) -> Result<(BasicBlock<'ctx, ReturnValue>, BasicBlock<'ctx, ReturnValue>), ()> { + lhs: &BasicBlock<'ctx>, + rhs: &BasicBlock<'ctx>, + ) -> Result<(), ()> { + self.try_insert()?; unsafe { - let lhs = BasicBlock::in_function(&self.function, lhs_name.into()); - let rhs = BasicBlock::in_function(&self.function, rhs_name.into()); - LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); LLVMBuildCondBr( self.builder_ref, @@ -377,39 +384,34 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> { lhs.blockref, rhs.blockref, ); - self.terminate(); - Ok((lhs, rhs)) - } - } - - #[must_use] - pub fn ret(self, return_value: &ReturnValue) -> Result<(), ()> { - if self - .function - .fn_type - .return_type(self.function.module.context) - .llvm_type() - != return_value.llvm_type() - { - return Err(()); - } - unsafe { - LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); - LLVMBuildRet(self.builder_ref, return_value.llvm_value()); - self.terminate(); Ok(()) } } - unsafe fn terminate(mut self) { - unsafe { - LLVMAppendExistingBasicBlock(self.function.fn_ref, self.blockref); - self.inserted = true; + #[must_use] + pub fn ret>(&mut self, return_value: &T) -> Result<(), ()> { + if self.function.fn_type.return_type != return_value.llvm_type() { + return Err(()); } + self.try_insert()?; + + unsafe { + LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref); + LLVMBuildRet(self.builder_ref, return_value.llvm_value()); + Ok(()) + } + } + + fn try_insert(&mut self) -> Result<(), ()> { + if self.inserted { + return Err(()); + } + self.inserted = true; + Ok(()) } } -impl<'ctx, ReturnValue: BasicValue<'ctx>> Drop for BasicBlock<'ctx, ReturnValue> { +impl<'ctx> Drop for BasicBlock<'ctx> { fn drop(&mut self) { if 
!self.inserted { unsafe { @@ -419,22 +421,20 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> Drop for BasicBlock<'ctx, ReturnValue> } } -pub struct PhiBuilder<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>> { +pub struct PhiBuilder<'ctx, PhiValue: BasicValue<'ctx>> { phi_node: LLVMValueRef, - phantom: PhantomData<&'ctx (PhiValue, ReturnValue)>, + phantom: PhantomData<&'ctx PhiValue>, } -impl<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>> - PhiBuilder<'ctx, ReturnValue, PhiValue> -{ - fn new(phi_node: LLVMValueRef) -> PhiBuilder<'ctx, ReturnValue, PhiValue> { +impl<'ctx, PhiValue: BasicValue<'ctx>> PhiBuilder<'ctx, PhiValue> { + fn new(phi_node: LLVMValueRef) -> PhiBuilder<'ctx, PhiValue> { PhiBuilder { phi_node, phantom: PhantomData, } } - pub fn add_incoming(&self, value: &PhiValue, block: &BasicBlock<'ctx, ReturnValue>) -> &Self { + pub fn add_incoming(&self, value: &PhiValue, block: &BasicBlock<'ctx>) -> &Self { let mut values = vec![value.llvm_value()]; let mut blocks = vec![block.blockref]; unsafe { diff --git a/reid-llvm-lib/src/types.rs b/reid-llvm-lib/src/types.rs index 9f585c5..48647fd 100644 --- a/reid-llvm-lib/src/types.rs +++ b/reid-llvm-lib/src/types.rs @@ -6,36 +6,40 @@ use llvm_sys::{ prelude::{LLVMTypeRef, LLVMValueRef}, }; -use crate::Context; +use crate::{BasicBlock, Context, PhiBuilder}; pub trait BasicType<'ctx> { fn llvm_type(&self) -> LLVMTypeRef; - fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self + + fn is_type(llvm_type: LLVMTypeRef) -> bool where Self: Sized; - fn function_type(&'ctx self, params: Vec<&'ctx dyn BasicType>) -> FunctionType<'ctx, Self> + unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self where - Self: Sized, - { + Self: Sized; + + fn function_type(&self, params: Vec) -> FunctionType<'ctx> { unsafe { let mut typerefs: Vec = params.iter().map(|b| b.llvm_type()).collect(); let param_ptr = typerefs.as_mut_ptr(); let param_len = typerefs.len(); 
FunctionType { phantom: PhantomData, + return_type: self.llvm_type(), param_types: typerefs, type_ref: LLVMFunctionType(self.llvm_type(), param_ptr, param_len as u32, 0), } } } - fn array_type(&'ctx self, length: u32) -> ArrayType<'ctx, Self> + fn array_type(&'ctx self, length: u32) -> ArrayType<'ctx> where Self: Sized, { ArrayType { - element_type: self, + phantom: PhantomData, + element_type: self.llvm_type(), length, type_ref: unsafe { LLVMArrayType(self.llvm_type(), length) }, } @@ -54,6 +58,7 @@ impl<'ctx> PartialEq for &dyn BasicType<'ctx> { } } +#[derive(Clone, Copy)] pub struct IntegerType<'ctx> { context: &'ctx Context, type_ref: LLVMTypeRef, @@ -64,7 +69,7 @@ impl<'ctx> BasicType<'ctx> for IntegerType<'ctx> { self.type_ref } - fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self + unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self where Self: Sized, { @@ -73,6 +78,10 @@ impl<'ctx> BasicType<'ctx> for IntegerType<'ctx> { type_ref: llvm_type, } } + + fn is_type(llvm_type: LLVMTypeRef) -> bool { + unsafe { LLVMGetTypeKind(llvm_type) == LLVMTypeKind::LLVMIntegerTypeKind } + } } impl<'ctx> IntegerType<'ctx> { @@ -91,15 +100,15 @@ impl<'ctx> IntegerType<'ctx> { IntegerType { context, type_ref } } - pub fn from_signed(&self, value: i64) -> IntegerValue<'_> { + pub fn from_signed(&self, value: i64) -> IntegerValue<'ctx> { self.from_const(value as u64, true) } - pub fn from_unsigned(&self, value: i64) -> IntegerValue<'_> { + pub fn from_unsigned(&self, value: i64) -> IntegerValue<'ctx> { self.from_const(value as u64, false) } - fn from_const(&self, value: u64, sign: bool) -> IntegerValue<'_> { + fn from_const(&self, value: u64, sign: bool) -> IntegerValue<'ctx> { unsafe { IntegerValue::from_llvm(LLVMConstInt( self.type_ref, @@ -113,18 +122,20 @@ impl<'ctx> IntegerType<'ctx> { } } -pub struct FunctionType<'ctx, ReturnType: BasicType<'ctx>> { - phantom: PhantomData<&'ctx ReturnType>, +#[derive(Clone)] +pub struct 
FunctionType<'ctx> { + phantom: PhantomData<&'ctx ()>, + pub(crate) return_type: LLVMTypeRef, pub(crate) param_types: Vec, type_ref: LLVMTypeRef, } -impl<'ctx, ReturnType: BasicType<'ctx>> BasicType<'ctx> for FunctionType<'ctx, ReturnType> { +impl<'ctx> BasicType<'ctx> for FunctionType<'ctx> { fn llvm_type(&self) -> LLVMTypeRef { self.type_ref } - fn from_llvm(_context: &'ctx Context, fn_type: LLVMTypeRef) -> Self + unsafe fn from_llvm(_context: &'ctx Context, fn_type: LLVMTypeRef) -> Self where Self: Sized, { @@ -139,34 +150,32 @@ impl<'ctx, ReturnType: BasicType<'ctx>> BasicType<'ctx> for FunctionType<'ctx, R .collect(); FunctionType { phantom: PhantomData, + return_type: LLVMGetReturnType(fn_type), param_types, type_ref: fn_type, } } } -} -impl<'ctx, ReturnType: BasicType<'ctx>> FunctionType<'ctx, ReturnType> { - pub fn return_type(&self, context: &'ctx Context) -> ReturnType { - unsafe { - let return_type = LLVMGetReturnType(self.type_ref); - ReturnType::from_llvm(context, return_type) - } + fn is_type(llvm_type: LLVMTypeRef) -> bool { + unsafe { LLVMGetTypeKind(llvm_type) == LLVMTypeKind::LLVMFunctionTypeKind } } } -pub struct ArrayType<'ctx, T: BasicType<'ctx>> { - element_type: &'ctx T, +#[derive(Clone, Copy)] +pub struct ArrayType<'ctx> { + phantom: PhantomData<&'ctx ()>, + element_type: LLVMTypeRef, length: u32, type_ref: LLVMTypeRef, } -impl<'ctx, T: BasicType<'ctx>> BasicType<'ctx> for ArrayType<'ctx, T> { +impl<'ctx> BasicType<'ctx> for ArrayType<'ctx> { fn llvm_type(&self) -> LLVMTypeRef { self.type_ref } - fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self + unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self where Self: Sized, { @@ -175,9 +184,81 @@ impl<'ctx, T: BasicType<'ctx>> BasicType<'ctx> for ArrayType<'ctx, T> { todo!() } } + + fn is_type(llvm_type: LLVMTypeRef) -> bool { + unsafe { LLVMGetTypeKind(llvm_type) == LLVMTypeKind::LLVMArrayTypeKind } + } } -pub trait BasicValue<'ctx> { +#[derive(Clone)] 
+pub enum TypeEnum<'ctx> { + Integer(IntegerType<'ctx>), + Array(ArrayType<'ctx>), + Function(FunctionType<'ctx>), +} + +impl<'ctx> From> for TypeEnum<'ctx> { + fn from(int: IntegerType<'ctx>) -> Self { + TypeEnum::Integer(int) + } +} + +impl<'ctx> From> for TypeEnum<'ctx> { + fn from(arr: ArrayType<'ctx>) -> Self { + TypeEnum::Array(arr) + } +} + +impl<'ctx> From> for TypeEnum<'ctx> { + fn from(func: FunctionType<'ctx>) -> Self { + TypeEnum::Function(func) + } +} + +impl<'ctx> TypeEnum<'ctx> { + fn inner_basic(&'ctx self) -> &'ctx dyn BasicType<'ctx> { + match self { + TypeEnum::Integer(integer_type) => integer_type, + TypeEnum::Array(array_type) => array_type, + TypeEnum::Function(function_type) => function_type, + } + } +} + +impl<'ctx> BasicType<'ctx> for TypeEnum<'ctx> { + fn llvm_type(&self) -> LLVMTypeRef { + self.inner_basic().llvm_type() + } + + fn is_type(llvm_type: LLVMTypeRef) -> bool + where + Self: Sized, + { + true + } + + unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self + where + Self: Sized, + { + unsafe { + match LLVMGetTypeKind(llvm_type) { + LLVMTypeKind::LLVMIntegerTypeKind => { + TypeEnum::Integer(IntegerType::from_llvm(context, llvm_type)) + } + LLVMTypeKind::LLVMArrayTypeKind => { + TypeEnum::Array(ArrayType::from_llvm(context, llvm_type)) + } + LLVMTypeKind::LLVMFunctionTypeKind => { + TypeEnum::Function(FunctionType::from_llvm(context, llvm_type)) + } + _ => todo!(), + } + } + } +} + +pub trait BasicValue<'ctx>: std::fmt::Debug { type BaseType: BasicType<'ctx>; unsafe fn from_llvm(value: LLVMValueRef) -> Self where @@ -186,6 +267,7 @@ pub trait BasicValue<'ctx> { fn llvm_type(&self) -> LLVMTypeRef; } +#[derive(Clone, Debug)] pub struct IntegerValue<'ctx> { phantom: PhantomData<&'ctx ()>, pub(crate) value_ref: LLVMValueRef, @@ -210,11 +292,14 @@ impl<'ctx> BasicValue<'ctx> for IntegerValue<'ctx> { } } +#[derive(Clone, Debug)] pub enum Value<'ctx> { Integer(IntegerValue<'ctx>), } -impl<'ctx> Value<'ctx> { 
+impl<'ctx> BasicValue<'ctx> for Value<'ctx> { + type BaseType = TypeEnum<'ctx>; + unsafe fn from_llvm(value: LLVMValueRef) -> Self where Self: Sized, @@ -231,15 +316,21 @@ impl<'ctx> Value<'ctx> { } } - pub fn llvm_value(&self) -> LLVMValueRef { + fn llvm_value(&self) -> LLVMValueRef { match self { Self::Integer(i) => i.llvm_value(), } } - pub fn llvm_type(&self) -> LLVMTypeRef { + fn llvm_type(&self) -> LLVMTypeRef { match self { Self::Integer(i) => i.llvm_type(), } } } + +impl<'ctx> From> for Value<'ctx> { + fn from(value: IntegerValue<'ctx>) -> Self { + Value::Integer(value) + } +} diff --git a/reid/Cargo.toml b/reid/Cargo.toml index 4d8f5f3..36a2c7c 100644 --- a/reid/Cargo.toml +++ b/reid/Cargo.toml @@ -9,4 +9,5 @@ edition = "2021" ## LLVM Bindings llvm-sys = "160" ## Make it easier to generate errors -thiserror = "1.0.44" \ No newline at end of file +thiserror = "1.0.44" +reid-lib = { path = "../reid-llvm-lib" } \ No newline at end of file diff --git a/reid/examples/testcodegen.rs b/reid/examples/testcodegen.rs new file mode 100644 index 0000000..9d567ef --- /dev/null +++ b/reid/examples/testcodegen.rs @@ -0,0 +1,172 @@ +use reid::mir::*; +use reid_lib::Context; + +fn main() { + let fibonacci_name = "fibonacci".to_owned(); + let fibonacci_n = "N".to_owned(); + + let fibonacci = FunctionDefinition { + name: fibonacci_name.clone(), + parameters: vec![(fibonacci_n.clone(), TypeKind::I32)], + kind: FunctionDefinitionKind::Local( + Block { + statements: vec![Statement( + StatementKind::If(IfExpression( + // If N > 2 + Box::new(Expression( + ExpressionKind::BinOp( + BinaryOperator::Logic(LogicOperator::GreaterThan), + Box::new(Expression( + ExpressionKind::Variable(VariableReference( + TypeKind::I32, + "N".to_string(), + Default::default(), + )), + Default::default(), + )), + Box::new(Expression( + ExpressionKind::Literal(Literal::I32(2)), + Default::default(), + )), + ), + Default::default(), + )), + // Then + Block { + statements: vec![], + return_expression: 
Some(( + ReturnKind::HardReturn, + // return fibonacci(n-1) + fibonacci(n-2) + Box::new(Expression( + ExpressionKind::BinOp( + BinaryOperator::Add, + // fibonacci(n-1) + Box::new(Expression( + ExpressionKind::FunctionCall(FunctionCall { + name: fibonacci_name.clone(), + return_type: TypeKind::I32, + parameters: vec![Expression( + ExpressionKind::BinOp( + BinaryOperator::Minus, + Box::new(Expression( + ExpressionKind::Variable( + VariableReference( + TypeKind::I32, + fibonacci_n.clone(), + Default::default(), + ), + ), + Default::default(), + )), + Box::new(Expression( + ExpressionKind::Literal(Literal::I32( + 1, + )), + Default::default(), + )), + ), + Default::default(), + )], + }), + Default::default(), + )), + // fibonacci(n-2) + Box::new(Expression( + ExpressionKind::FunctionCall(FunctionCall { + name: fibonacci_name.clone(), + return_type: TypeKind::I32, + parameters: vec![Expression( + ExpressionKind::BinOp( + BinaryOperator::Minus, + Box::new(Expression( + ExpressionKind::Variable( + VariableReference( + TypeKind::I32, + fibonacci_n.clone(), + Default::default(), + ), + ), + Default::default(), + )), + Box::new(Expression( + ExpressionKind::Literal(Literal::I32( + 2, + )), + Default::default(), + )), + ), + Default::default(), + )], + }), + Default::default(), + )), + ), + Default::default(), + )), + )), + range: Default::default(), + }, + // No else-block + None, + )), + Default::default(), + )], + // return 1 + return_expression: Some(( + ReturnKind::SoftReturn, + Box::new(Expression( + ExpressionKind::Literal(Literal::I32(1)), + Default::default(), + )), + )), + range: Default::default(), + }, + Default::default(), + ), + }; + + let main = FunctionDefinition { + name: "main".to_owned(), + parameters: vec![], + kind: FunctionDefinitionKind::Local( + Block { + statements: vec![], + return_expression: Some(( + ReturnKind::SoftReturn, + Box::new(Expression( + ExpressionKind::FunctionCall(FunctionCall { + name: fibonacci_name.clone(), + return_type: 
TypeKind::I32, + parameters: vec![Expression( + ExpressionKind::Literal(Literal::I32(5)), + Default::default(), + )], + }), + Default::default(), + )), + )), + range: Default::default(), + }, + Default::default(), + ), + }; + + println!("test1"); + + let module = Module { + name: "test module".to_owned(), + imports: vec![], + functions: vec![fibonacci, main], + }; + + println!("test2"); + let context = Context::new(); + let codegen_module = module.codegen(&context); + + println!("test3"); + + match codegen_module.module.print_to_string() { + Ok(v) => println!("{}", v), + Err(e) => println!("Err: {:?}", e), + } +} diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs new file mode 100644 index 0000000..41abe06 --- /dev/null +++ b/reid/src/codegen.rs @@ -0,0 +1,281 @@ +use std::{collections::HashMap, mem, ops::Deref}; + +use crate::mir::{self, types::ReturnType, TypeKind, VariableReference}; +use reid_lib::{ + types::{BasicType, BasicValue, IntegerValue, TypeEnum, Value}, + BasicBlock, Context, Function, IntPredicate, Module, +}; + +pub struct ModuleCodegen<'ctx> { + context: &'ctx Context, + pub module: Module<'ctx>, +} + +impl mir::Module { + pub fn codegen<'ctx>(&self, context: &'ctx Context) -> ModuleCodegen<'ctx> { + let module = context.module(&self.name); + + let mut functions = HashMap::new(); + + for function in &self.functions { + let ret_type = function.return_type().unwrap().get_type(&context); + let fn_type = ret_type.function_type( + function + .parameters + .iter() + .map(|(_, p)| p.get_type(&context)) + .collect(), + ); + + let func = match &function.kind { + mir::FunctionDefinitionKind::Local(_, _) => { + module.add_function(fn_type, &function.name) + } + mir::FunctionDefinitionKind::Extern(_) => todo!(), + }; + functions.insert(function.name.clone(), func); + } + + for mir_function in &self.functions { + let function = functions.get(&mir_function.name).unwrap(); + + let mut stack_values = HashMap::new(); + for (i, (p_name, p_type)) in 
mir_function.parameters.iter().enumerate() { + stack_values.insert( + p_name.clone(), + function.get_param(i, p_type.get_type(&context)).unwrap(), + ); + } + + let mut scope = Scope { + context, + module: &module, + function, + block: function.block("entry"), + functions: functions.clone(), + stack_values, + }; + match &mir_function.kind { + mir::FunctionDefinitionKind::Local(block, _) => { + if let Some(ret) = block.codegen(&mut scope) { + scope.block.ret(&ret).unwrap(); + } + } + mir::FunctionDefinitionKind::Extern(_) => {} + } + } + + ModuleCodegen { context, module } + } +} + +pub struct Scope<'ctx> { + context: &'ctx Context, + module: &'ctx Module<'ctx>, + function: &'ctx Function<'ctx>, + block: BasicBlock<'ctx>, + functions: HashMap>, + stack_values: HashMap>, +} + +impl<'ctx> Scope<'ctx> { + pub fn with_block(&self, block: BasicBlock<'ctx>) -> Scope<'ctx> { + Scope { + block, + context: self.context, + function: self.function, + module: self.module, + functions: self.functions.clone(), + stack_values: self.stack_values.clone(), + } + } + + /// Takes the block out from this scope, swaps the given block in its place + /// and returns the old block. 
+ pub fn swap_block(&mut self, block: BasicBlock<'ctx>) -> BasicBlock<'ctx> { + let mut old_block = block; + mem::swap(&mut self.block, &mut old_block); + old_block + } +} + +impl mir::Statement { + pub fn codegen<'ctx>(&self, scope: &mut Scope<'ctx>) -> Option> { + match &self.0 { + mir::StatementKind::Let(VariableReference(_, name, _), expression) => { + let value = expression.codegen(scope).unwrap(); + scope.stack_values.insert(name.clone(), value); + None + } + mir::StatementKind::If(if_expression) => if_expression.codegen(scope), + mir::StatementKind::Import(_) => todo!(), + mir::StatementKind::Expression(expression) => { + let value = expression.codegen(scope).unwrap(); + Some(value) + } + } + } +} + +impl mir::IfExpression { + pub fn codegen<'ctx>(&self, scope: &mut Scope<'ctx>) -> Option> { + let condition = self.0.codegen(scope).unwrap(); + + // Create blocks + let then_bb = scope.function.block("then"); + let after_bb = scope.function.block("after"); + let mut before_bb = scope.swap_block(after_bb); + + let mut then_scope = scope.with_block(then_bb); + let then_res = self.1.codegen(&mut then_scope); + then_scope.block.br(&scope.block).ok(); + + let else_bb = scope.function.block("else"); + let mut else_scope = scope.with_block(else_bb); + + let else_opt = if let Some(else_block) = &self.2 { + before_bb + .conditional_br(&condition, &then_scope.block, &else_scope.block) + .unwrap(); + + let opt = else_block.codegen(&mut else_scope); + + if let Some(ret) = opt { + else_scope.block.br(&scope.block).ok(); + Some((else_scope.block, ret)) + } else { + None + } + } else { + else_scope.block.br(&scope.block).unwrap(); + before_bb + .conditional_br(&condition, &then_scope.block, &scope.block) + .unwrap(); + None + }; + + if then_res.is_none() && else_opt.is_none() { + None + } else if let Ok(ret_type) = self.1.return_type() { + let phi = scope + .block + .phi(&ret_type.get_type(scope.context), "phi") + .unwrap(); + if let Some(then_ret) = then_res { + 
phi.add_incoming(&then_ret, &then_scope.block); + } + if let Some((else_bb, else_ret)) = else_opt { + phi.add_incoming(&else_ret, &else_bb); + } + + Some(phi.build()) + } else { + None + } + } +} + +impl mir::Expression { + pub fn codegen<'ctx>(&self, scope: &mut Scope<'ctx>) -> Option> { + match &self.0 { + mir::ExpressionKind::Variable(varref) => { + let v = scope + .stack_values + .get(&varref.1) + .expect("Variable reference not found?!"); + Some(v.clone()) + } + mir::ExpressionKind::Literal(lit) => Some(lit.codegen(scope.context)), + mir::ExpressionKind::BinOp(binop, lhs_exp, rhs_exp) => { + let lhs = lhs_exp.codegen(scope).expect("lhs has no return value"); + let rhs = rhs_exp.codegen(scope).expect("rhs has no return value"); + Some(match binop { + mir::BinaryOperator::Add => scope.block.add(&lhs, &rhs, "add").unwrap(), + mir::BinaryOperator::Minus => scope.block.sub(&lhs, &rhs, "sub").unwrap(), + mir::BinaryOperator::Mult => todo!(), + mir::BinaryOperator::And => todo!(), + mir::BinaryOperator::Logic(l) => { + let ret_type = lhs_exp.return_type().expect("No ret type in lhs?"); + scope + .block + .integer_compare(&lhs, &rhs, &l.int_predicate(ret_type.signed()), "cmp") + .unwrap() + } + }) + } + mir::ExpressionKind::FunctionCall(call) => { + let params = call + .parameters + .iter() + .map(|e| e.codegen(scope).unwrap()) + .collect(); + let callee = scope + .functions + .get(&call.name) + .expect("function not found!"); + Some(scope.block.call(callee, params, "call").unwrap()) + } + mir::ExpressionKind::If(if_expression) => if_expression.codegen(scope), + mir::ExpressionKind::Block(block) => { + let mut inner_scope = scope.with_block(scope.function.block("inner")); + if let Some(ret) = block.codegen(&mut inner_scope) { + inner_scope.block.br(&scope.block); + Some(ret) + } else { + None + } + } + } + } +} + +impl mir::LogicOperator { + fn int_predicate(&self, signed: bool) -> IntPredicate { + match (self, signed) { + (mir::LogicOperator::LessThan, true) => 
IntPredicate::SLT, + (mir::LogicOperator::GreaterThan, true) => IntPredicate::SGT, + (mir::LogicOperator::LessThan, false) => IntPredicate::ULT, + (mir::LogicOperator::GreaterThan, false) => IntPredicate::UGT, + } + } +} + +impl mir::Block { + pub fn codegen<'ctx>(&self, mut scope: &mut Scope<'ctx>) -> Option> { + for stmt in &self.statements { + stmt.codegen(&mut scope); + } + + if let Some((kind, expr)) = &self.return_expression { + let ret = expr.codegen(&mut scope).unwrap(); + match kind { + mir::ReturnKind::HardReturn => { + scope.block.ret(&ret).unwrap(); + None + } + mir::ReturnKind::SoftReturn => Some(ret), + } + } else { + None + } + } +} + +impl mir::Literal { + pub fn codegen<'ctx>(&self, context: &'ctx Context) -> Value<'ctx> { + let val: IntegerValue<'ctx> = match *self { + mir::Literal::I32(val) => context.type_i32().from_signed(val as i64), + mir::Literal::I16(val) => context.type_i16().from_signed(val as i64), + }; + Value::Integer(val) + } +} + +impl TypeKind { + fn get_type<'ctx>(&self, context: &'ctx Context) -> TypeEnum<'ctx> { + match &self { + TypeKind::I32 => TypeEnum::Integer(context.type_i32()), + TypeKind::I16 => TypeEnum::Integer(context.type_i16()), + } + } +} diff --git a/reid/src/lib.rs b/reid/src/lib.rs index e43c8be..de55536 100644 --- a/reid/src/lib.rs +++ b/reid/src/lib.rs @@ -1,11 +1,13 @@ -use codegen::{form_context, from_statements}; +use old_codegen::{form_context, from_statements}; -use crate::{ast::TopLevelStatement, lexer::Token, token_stream::TokenStream}; +use crate::{lexer::Token, parser::TopLevelStatement, token_stream::TokenStream}; -mod ast; -mod codegen; mod lexer; +pub mod mir; +mod old_codegen; +mod parser; // mod llvm_ir; +pub mod codegen; mod token_stream; // TODO: diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs new file mode 100644 index 0000000..45531e8 --- /dev/null +++ b/reid/src/mir/mod.rs @@ -0,0 +1,120 @@ +/// In this module are defined structs that are used for performing passes on +/// Reid. 
It contains a simplified version of Reid which must already be +/// type-checked beforehand. +use std::collections::HashMap; + +use types::*; + +use crate::token_stream::TokenRange; + +pub mod types; + +#[derive(Clone, Copy)] +pub enum TypeKind { + I32, + I16, +} + +impl TypeKind { + pub fn signed(&self) -> bool { + match self { + _ => true, + } + } +} + +#[derive(Clone, Copy)] +pub enum Literal { + I32(i32), + I16(i16), +} + +impl Literal { + fn as_type(self: &Literal) -> TypeKind { + match self { + Literal::I32(_) => TypeKind::I32, + Literal::I16(_) => TypeKind::I16, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub enum BinaryOperator { + Add, + Minus, + Mult, + And, + Logic(LogicOperator), +} + +#[derive(Debug, Clone, Copy)] +pub enum LogicOperator { + LessThan, + GreaterThan, +} + +#[derive(Debug, Clone, Copy)] +pub enum ReturnKind { + HardReturn, + SoftReturn, +} + +pub struct VariableReference(pub TypeKind, pub String, pub TokenRange); + +pub struct Import(pub String, pub TokenRange); + +pub enum ExpressionKind { + Variable(VariableReference), + Literal(Literal), + BinOp(BinaryOperator, Box, Box), + FunctionCall(FunctionCall), + If(IfExpression), + Block(Block), +} + +pub struct Expression(pub ExpressionKind, pub TokenRange); + +/// Condition, Then, Else +pub struct IfExpression(pub Box, pub Block, pub Option); + +pub struct FunctionCall { + pub name: String, + pub return_type: TypeKind, + pub parameters: Vec, +} + +pub struct FunctionDefinition { + pub name: String, + pub parameters: Vec<(String, TypeKind)>, + pub kind: FunctionDefinitionKind, +} + +pub enum FunctionDefinitionKind { + /// Actual definition block and surrounding signature range + Local(Block, TokenRange), + /// Return Type + Extern(TypeKind), +} + +pub struct Block { + /// List of non-returning statements + pub statements: Vec, + pub return_expression: Option<(ReturnKind, Box)>, + pub range: TokenRange, +} + +pub struct Statement(pub StatementKind, pub TokenRange); + +pub enum StatementKind { 
+ /// Variable name+type, evaluation + Let(VariableReference, Expression), + If(IfExpression), + Import(Import), + Expression(Expression), +} + +pub struct Module { + pub name: String, + pub imports: Vec, + pub functions: Vec, +} diff --git a/reid/src/mir/types.rs b/reid/src/mir/types.rs new file mode 100644 index 0000000..c62f299 --- /dev/null +++ b/reid/src/mir/types.rs @@ -0,0 +1,75 @@ +use super::*; + +#[derive(Debug, Clone)] +pub enum ReturnTypeOther { + Import(TokenRange), + Let(TokenRange), + EmptyBlock(TokenRange), + NoBlockReturn(TokenRange), +} + +pub trait ReturnType { + fn return_type(&self) -> Result; +} + +impl ReturnType for Block { + fn return_type(&self) -> Result { + self.return_expression + .as_ref() + .ok_or(ReturnTypeOther::NoBlockReturn(self.range.clone())) + .and_then(|(_, stmt)| stmt.return_type()) + } +} + +impl ReturnType for Statement { + fn return_type(&self) -> Result { + use StatementKind::*; + match &self.0 { + Expression(e) => e.return_type(), + If(e) => e.return_type(), + Import(_) => Err(ReturnTypeOther::Import(self.1)), + Let(_, _) => Err(ReturnTypeOther::Let(self.1)), + } + } +} + +impl ReturnType for Expression { + fn return_type(&self) -> Result { + use ExpressionKind::*; + match &self.0 { + Literal(lit) => Ok(lit.as_type()), + Variable(var) => var.return_type(), + BinOp(_, expr, _) => expr.return_type(), + Block(block) => block.return_type(), + FunctionCall(fcall) => fcall.return_type(), + If(expr) => expr.return_type(), + } + } +} + +impl ReturnType for IfExpression { + fn return_type(&self) -> Result { + self.1.return_type() + } +} + +impl ReturnType for VariableReference { + fn return_type(&self) -> Result { + Ok(self.0) + } +} + +impl ReturnType for FunctionCall { + fn return_type(&self) -> Result { + Ok(self.return_type) + } +} + +impl ReturnType for FunctionDefinition { + fn return_type(&self) -> Result { + match &self.kind { + FunctionDefinitionKind::Local(block, _) => block.return_type(), + 
FunctionDefinitionKind::Extern(type_kind) => Ok(*type_kind), + } + } +} diff --git a/reid/src/codegen/llvm.rs b/reid/src/old_codegen/llvm.rs similarity index 97% rename from reid/src/codegen/llvm.rs rename to reid/src/old_codegen/llvm.rs index 51f0ea7..1b74f7b 100644 --- a/reid/src/codegen/llvm.rs +++ b/reid/src/old_codegen/llvm.rs @@ -9,10 +9,10 @@ use llvm_sys::transforms::pass_manager_builder::{ LLVMPassManagerBuilderSetOptLevel, }; use llvm_sys::{ - LLVMBasicBlock, LLVMBuilder, LLVMContext, LLVMModule, LLVMType, LLVMValue, core::*, prelude::*, + core::*, prelude::*, LLVMBasicBlock, LLVMBuilder, LLVMContext, LLVMModule, LLVMType, LLVMValue, }; -use crate::ast; +use crate::parser; fn into_cstring>(value: T) -> CString { let string = value.into(); @@ -47,8 +47,8 @@ impl IRType { pub struct IRValue(pub IRType, *mut LLVMValue); impl IRValue { - pub fn from_literal(literal: &ast::Literal, module: &IRModule) -> Self { - use ast::Literal; + pub fn from_literal(literal: &parser::Literal, module: &IRModule) -> Self { + use parser::Literal; match literal { Literal::I32(v) => { let ir_type = IRType::I32; diff --git a/reid/src/codegen/mod.rs b/reid/src/old_codegen/mod.rs similarity index 96% rename from reid/src/codegen/mod.rs rename to reid/src/old_codegen/mod.rs index e13c66f..d709c41 100644 --- a/reid/src/codegen/mod.rs +++ b/reid/src/old_codegen/mod.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use llvm::{Error, IRBlock, IRContext, IRFunction, IRModule, IRValue}; use crate::{ - ast::{ + parser::{ Block, BlockLevelStatement, Expression, ExpressionKind, FunctionDefinition, IfExpression, LetStatement, ReturnType, }, @@ -97,7 +97,7 @@ impl Expression { Binop(op, lhs, rhs) => { let lhs = lhs.codegen(scope); let rhs = rhs.codegen(scope); - use crate::ast::BinaryOperator::*; + use crate::parser::BinaryOperator::*; match op { Add => scope.block.add(lhs, rhs).unwrap(), Mult => scope.block.mult(lhs, rhs).unwrap(), @@ -121,7 +121,7 @@ impl Expression { _ => 
then.block.move_into(&mut scope.block), } - IRValue::from_literal(&crate::ast::Literal::I32(1), scope.block.function.module) + IRValue::from_literal(&crate::parser::Literal::I32(1), scope.block.function.module) } BlockExpr(_) => panic!("block expr not supported"), FunctionCall(_) => panic!("function call expr not supported"), diff --git a/reid/src/ast.rs b/reid/src/parser.rs similarity index 100% rename from reid/src/ast.rs rename to reid/src/parser.rs diff --git a/reid/src/token_stream.rs b/reid/src/token_stream.rs index 12daa54..223ab9d 100644 --- a/reid/src/token_stream.rs +++ b/reid/src/token_stream.rs @@ -1,6 +1,6 @@ use crate::{ - ast::Parse, lexer::{FullToken, Position, Token}, + parser::Parse, }; pub struct TokenStream<'a, 'b> { @@ -156,7 +156,7 @@ impl Drop for TokenStream<'_, '_> { } } -#[derive(Clone)] +#[derive(Clone, Copy)] pub struct TokenRange { pub start: usize, pub end: usize, @@ -168,6 +168,15 @@ impl std::fmt::Debug for TokenRange { } } +impl Default for TokenRange { + fn default() -> Self { + Self { + start: Default::default(), + end: Default::default(), + } + } +} + #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Expected {} at Ln {}, Col {}, got {:?}", .0, (.2).1, (.2).0, .1)]