Refactor a bunch of stuff, produce compiling MIR

This commit is contained in:
Sofia 2025-07-04 21:30:40 +03:00
parent 05c585d47c
commit 8a32e66ba8
15 changed files with 884 additions and 128 deletions

1
Cargo.lock generated
View File

@ -100,6 +100,7 @@ name = "reid"
version = "0.1.0"
dependencies = [
"llvm-sys",
"reid-lib",
"thiserror",
]

View File

@ -6,7 +6,7 @@
# Do note this file is extremely simply for my own personal convenience
export .env
cargo run --example libtest && \
cargo run --example testcodegen && \
# clang hello.o -o main && \
ld -dynamic-linker /lib64/ld-linux-x86-64.so.2 \
-o main /usr/lib/crt1.o hello.o -lc && \

View File

@ -1,6 +1,6 @@
use reid_lib::{
Context, IntPredicate,
types::{BasicType, IntegerType, IntegerValue, Value},
types::{BasicType, IntegerValue, Value},
};
pub fn main() {
@ -16,15 +16,19 @@ pub fn main() {
let int_32 = context.type_i32();
let fibonacci = module.add_function(int_32.function_type(vec![&int_32]), "fibonacci");
let f_main = fibonacci.block("main");
let fibonacci = module.add_function(int_32.function_type(vec![int_32.into()]), "fibonacci");
let mut f_main = fibonacci.block("main");
let param = fibonacci.get_param::<IntegerValue>(0).unwrap();
let cmp = f_main
let param = fibonacci
.get_param::<IntegerValue>(0, int_32.into())
.unwrap();
let mut cmp = f_main
.integer_compare(&param, &int_32.from_unsigned(3), &IntPredicate::ULT, "cmp")
.unwrap();
let (done, recurse) = f_main.conditional_br(&cmp, "done", "recurse").unwrap();
let mut done = fibonacci.block("done");
let mut recurse = fibonacci.block("recurse");
f_main.conditional_br(&cmp, &done, &recurse).unwrap();
done.ret(&int_32.from_unsigned(1)).unwrap();
@ -34,7 +38,7 @@ pub fn main() {
let minus_two = recurse
.sub(&param, &int_32.from_unsigned(2), "minus_two")
.unwrap();
let one = recurse
let one: IntegerValue = recurse
.call(&fibonacci, vec![Value::Integer(minus_one)], "call_one")
.unwrap();
let two = recurse
@ -47,8 +51,8 @@ pub fn main() {
let main_f = module.add_function(int_32.function_type(Vec::new()), "main");
let main_b = main_f.block("main");
let call = main_b
let mut main_b = main_f.block("main");
let call: IntegerValue = main_b
.call(
&fibonacci,
vec![Value::Integer(int_32.from_unsigned(8))],

View File

@ -20,15 +20,20 @@ pub mod types;
mod util;
pub enum IntPredicate {
ULT,
SLT,
SGT,
ULT,
UGT,
}
impl IntPredicate {
pub fn as_llvm(&self) -> LLVMIntPredicate {
match *self {
Self::ULT => LLVMIntPredicate::LLVMIntULT,
Self::SLT => LLVMIntPredicate::LLVMIntSLT,
Self::SGT => LLVMIntPredicate::LLVMIntSGT,
Self::ULT => LLVMIntPredicate::LLVMIntULT,
Self::UGT => LLVMIntPredicate::LLVMIntUGT,
}
}
}
@ -68,8 +73,8 @@ impl Context {
IntegerType::in_context(&self, 32)
}
pub fn module<T: Into<String>>(&self, name: T) -> Module {
Module::with_name(self, name.into())
pub fn module(&self, name: &str) -> Module {
Module::with_name(self, name)
}
}
@ -90,7 +95,7 @@ pub struct Module<'ctx> {
}
impl<'ctx> Module<'ctx> {
fn with_name(context: &'ctx Context, name: String) -> Module<'ctx> {
fn with_name(context: &'ctx Context, name: &str) -> Module<'ctx> {
unsafe {
let cstring_name = into_cstring(name);
let module_ref =
@ -103,11 +108,7 @@ impl<'ctx> Module<'ctx> {
}
}
pub fn add_function<ReturnValue: BasicValue<'ctx>>(
&'ctx self,
fn_type: FunctionType<'ctx, ReturnValue::BaseType>,
name: &str,
) -> Function<'ctx, ReturnValue> {
pub fn add_function(&'ctx self, fn_type: FunctionType<'ctx>, name: &str) -> Function<'ctx> {
unsafe {
let name_cstring = into_cstring(name);
let function_ref =
@ -193,21 +194,26 @@ impl<'a> Drop for Module<'a> {
}
}
pub struct Function<'ctx, ReturnValue: BasicValue<'ctx>> {
#[derive(Clone)]
pub struct Function<'ctx> {
module: &'ctx Module<'ctx>,
name: CString,
fn_type: FunctionType<'ctx, ReturnValue::BaseType>,
fn_type: FunctionType<'ctx>,
fn_ref: LLVMValueRef,
}
impl<'ctx, ReturnValue: BasicValue<'ctx>> Function<'ctx, ReturnValue> {
pub fn block<T: Into<String>>(&'ctx self, name: T) -> BasicBlock<'ctx, ReturnValue> {
impl<'ctx> Function<'ctx> {
pub fn block<T: Into<String>>(&'ctx self, name: T) -> BasicBlock<'ctx> {
BasicBlock::in_function(&self, name.into())
}
pub fn get_param<T: BasicValue<'ctx>>(&'ctx self, nth: usize) -> Result<T, String> {
if let Some(param_type) = self.fn_type.param_types.iter().nth(nth) {
if self.fn_type.return_type(self.module.context).llvm_type() != *param_type {
pub fn get_param<T: BasicValue<'ctx>>(
&'ctx self,
nth: usize,
param_type: T::BaseType,
) -> Result<T, String> {
if let Some(actual_type) = self.fn_type.param_types.iter().nth(nth) {
if param_type.llvm_type() != *actual_type {
return Err(String::from("Wrong type"));
}
} else {
@ -217,29 +223,27 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> Function<'ctx, ReturnValue> {
}
}
pub struct BasicBlock<'ctx, ReturnValue: BasicValue<'ctx>> {
function: &'ctx Function<'ctx, ReturnValue>,
pub struct BasicBlock<'ctx> {
function: &'ctx Function<'ctx>,
builder_ref: LLVMBuilderRef,
name: CString,
name: String,
blockref: LLVMBasicBlockRef,
inserted: bool,
}
impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
fn in_function(
function: &'ctx Function<'ctx, ReturnValue>,
name: String,
) -> BasicBlock<'ctx, ReturnValue> {
impl<'ctx> BasicBlock<'ctx> {
fn in_function(function: &'ctx Function<'ctx>, name: String) -> BasicBlock<'ctx> {
unsafe {
let block_name = into_cstring(name);
let block_name = into_cstring(name.clone());
let block_ref = LLVMCreateBasicBlockInContext(
function.module.context.context_ref,
block_name.as_ptr(),
);
LLVMAppendExistingBasicBlock(function.fn_ref, block_ref);
BasicBlock {
function: function,
builder_ref: function.module.context.builder_ref,
name: block_name,
name,
blockref: block_ref,
inserted: false,
}
@ -269,12 +273,12 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
}
#[must_use]
pub fn call(
pub fn call<T: BasicValue<'ctx>>(
&self,
callee: &'ctx Function<'ctx, ReturnValue>,
callee: &Function<'ctx>,
params: Vec<Value<'ctx>>,
name: &str,
) -> Result<ReturnValue, ()> {
) -> Result<T, ()> {
if params.len() != callee.fn_type.param_types.len() {
return Err(()); // TODO invalid amount of parameters
}
@ -283,6 +287,9 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
return Err(()); // TODO wrong types in parameters
}
}
if !T::BaseType::is_type(callee.fn_type.return_type) {
return Err(()); // TODO wrong return type
}
unsafe {
let mut param_list: Vec<LLVMValueRef> = params.iter().map(|p| p.llvm_value()).collect();
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
@ -294,7 +301,7 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
param_list.len() as u32,
into_cstring(name).as_ptr(),
);
Ok(ReturnValue::from_llvm(ret_val))
Ok(T::from_llvm(ret_val))
}
}
@ -317,6 +324,8 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
#[must_use]
pub fn sub<T: BasicValue<'ctx>>(&self, lhs: &T, rhs: &T, name: &str) -> Result<T, ()> {
dbg!(lhs, rhs);
dbg!(lhs.llvm_type(), rhs.llvm_type());
if lhs.llvm_type() != rhs.llvm_type() {
return Err(()); // TODO error
}
@ -335,9 +344,9 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
#[must_use]
pub fn phi<PhiValue: BasicValue<'ctx>>(
&self,
phi_type: &'ctx PhiValue::BaseType,
phi_type: &PhiValue::BaseType,
name: &str,
) -> Result<PhiBuilder<'ctx, ReturnValue, PhiValue>, ()> {
) -> Result<PhiBuilder<'ctx, PhiValue>, ()> {
unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
let phi_node = LLVMBuildPhi(
@ -350,26 +359,24 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
}
#[must_use]
pub fn br(self, into: &BasicBlock<'ctx, ReturnValue>) -> Result<(), ()> {
pub fn br(&mut self, into: &BasicBlock<'ctx>) -> Result<(), ()> {
self.try_insert()?;
unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildBr(self.builder_ref, into.blockref);
self.terminate();
Ok(())
}
}
#[must_use]
pub fn conditional_br<T: BasicValue<'ctx>>(
self,
&mut self,
condition: &T,
lhs_name: &str,
rhs_name: &str,
) -> Result<(BasicBlock<'ctx, ReturnValue>, BasicBlock<'ctx, ReturnValue>), ()> {
lhs: &BasicBlock<'ctx>,
rhs: &BasicBlock<'ctx>,
) -> Result<(), ()> {
self.try_insert()?;
unsafe {
let lhs = BasicBlock::in_function(&self.function, lhs_name.into());
let rhs = BasicBlock::in_function(&self.function, rhs_name.into());
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildCondBr(
self.builder_ref,
@ -377,39 +384,34 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> BasicBlock<'ctx, ReturnValue> {
lhs.blockref,
rhs.blockref,
);
self.terminate();
Ok((lhs, rhs))
}
}
#[must_use]
pub fn ret(self, return_value: &ReturnValue) -> Result<(), ()> {
if self
.function
.fn_type
.return_type(self.function.module.context)
.llvm_type()
!= return_value.llvm_type()
{
return Err(());
}
unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildRet(self.builder_ref, return_value.llvm_value());
self.terminate();
Ok(())
}
}
unsafe fn terminate(mut self) {
unsafe {
LLVMAppendExistingBasicBlock(self.function.fn_ref, self.blockref);
self.inserted = true;
#[must_use]
pub fn ret<T: BasicValue<'ctx>>(&mut self, return_value: &T) -> Result<(), ()> {
if self.function.fn_type.return_type != return_value.llvm_type() {
return Err(());
}
self.try_insert()?;
unsafe {
LLVMPositionBuilderAtEnd(self.builder_ref, self.blockref);
LLVMBuildRet(self.builder_ref, return_value.llvm_value());
Ok(())
}
}
impl<'ctx, ReturnValue: BasicValue<'ctx>> Drop for BasicBlock<'ctx, ReturnValue> {
fn try_insert(&mut self) -> Result<(), ()> {
if self.inserted {
return Err(());
}
self.inserted = true;
Ok(())
}
}
impl<'ctx> Drop for BasicBlock<'ctx> {
fn drop(&mut self) {
if !self.inserted {
unsafe {
@ -419,22 +421,20 @@ impl<'ctx, ReturnValue: BasicValue<'ctx>> Drop for BasicBlock<'ctx, ReturnValue>
}
}
pub struct PhiBuilder<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>> {
pub struct PhiBuilder<'ctx, PhiValue: BasicValue<'ctx>> {
phi_node: LLVMValueRef,
phantom: PhantomData<&'ctx (PhiValue, ReturnValue)>,
phantom: PhantomData<&'ctx PhiValue>,
}
impl<'ctx, ReturnValue: BasicValue<'ctx>, PhiValue: BasicValue<'ctx>>
PhiBuilder<'ctx, ReturnValue, PhiValue>
{
fn new(phi_node: LLVMValueRef) -> PhiBuilder<'ctx, ReturnValue, PhiValue> {
impl<'ctx, PhiValue: BasicValue<'ctx>> PhiBuilder<'ctx, PhiValue> {
fn new(phi_node: LLVMValueRef) -> PhiBuilder<'ctx, PhiValue> {
PhiBuilder {
phi_node,
phantom: PhantomData,
}
}
pub fn add_incoming(&self, value: &PhiValue, block: &BasicBlock<'ctx, ReturnValue>) -> &Self {
pub fn add_incoming(&self, value: &PhiValue, block: &BasicBlock<'ctx>) -> &Self {
let mut values = vec![value.llvm_value()];
let mut blocks = vec![block.blockref];
unsafe {

View File

@ -6,36 +6,40 @@ use llvm_sys::{
prelude::{LLVMTypeRef, LLVMValueRef},
};
use crate::Context;
use crate::{BasicBlock, Context, PhiBuilder};
pub trait BasicType<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef;
fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
fn is_type(llvm_type: LLVMTypeRef) -> bool
where
Self: Sized;
fn function_type(&'ctx self, params: Vec<&'ctx dyn BasicType>) -> FunctionType<'ctx, Self>
unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
where
Self: Sized,
{
Self: Sized;
fn function_type(&self, params: Vec<TypeEnum>) -> FunctionType<'ctx> {
unsafe {
let mut typerefs: Vec<LLVMTypeRef> = params.iter().map(|b| b.llvm_type()).collect();
let param_ptr = typerefs.as_mut_ptr();
let param_len = typerefs.len();
FunctionType {
phantom: PhantomData,
return_type: self.llvm_type(),
param_types: typerefs,
type_ref: LLVMFunctionType(self.llvm_type(), param_ptr, param_len as u32, 0),
}
}
}
fn array_type(&'ctx self, length: u32) -> ArrayType<'ctx, Self>
fn array_type(&'ctx self, length: u32) -> ArrayType<'ctx>
where
Self: Sized,
{
ArrayType {
element_type: self,
phantom: PhantomData,
element_type: self.llvm_type(),
length,
type_ref: unsafe { LLVMArrayType(self.llvm_type(), length) },
}
@ -54,6 +58,7 @@ impl<'ctx> PartialEq<LLVMTypeRef> for &dyn BasicType<'ctx> {
}
}
#[derive(Clone, Copy)]
pub struct IntegerType<'ctx> {
context: &'ctx Context,
type_ref: LLVMTypeRef,
@ -64,7 +69,7 @@ impl<'ctx> BasicType<'ctx> for IntegerType<'ctx> {
self.type_ref
}
fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
where
Self: Sized,
{
@ -73,6 +78,10 @@ impl<'ctx> BasicType<'ctx> for IntegerType<'ctx> {
type_ref: llvm_type,
}
}
fn is_type(llvm_type: LLVMTypeRef) -> bool {
unsafe { LLVMGetTypeKind(llvm_type) == LLVMTypeKind::LLVMIntegerTypeKind }
}
}
impl<'ctx> IntegerType<'ctx> {
@ -91,15 +100,15 @@ impl<'ctx> IntegerType<'ctx> {
IntegerType { context, type_ref }
}
pub fn from_signed(&self, value: i64) -> IntegerValue<'_> {
pub fn from_signed(&self, value: i64) -> IntegerValue<'ctx> {
self.from_const(value as u64, true)
}
pub fn from_unsigned(&self, value: i64) -> IntegerValue<'_> {
pub fn from_unsigned(&self, value: i64) -> IntegerValue<'ctx> {
self.from_const(value as u64, false)
}
fn from_const(&self, value: u64, sign: bool) -> IntegerValue<'_> {
fn from_const(&self, value: u64, sign: bool) -> IntegerValue<'ctx> {
unsafe {
IntegerValue::from_llvm(LLVMConstInt(
self.type_ref,
@ -113,18 +122,20 @@ impl<'ctx> IntegerType<'ctx> {
}
}
pub struct FunctionType<'ctx, ReturnType: BasicType<'ctx>> {
phantom: PhantomData<&'ctx ReturnType>,
#[derive(Clone)]
pub struct FunctionType<'ctx> {
phantom: PhantomData<&'ctx ()>,
pub(crate) return_type: LLVMTypeRef,
pub(crate) param_types: Vec<LLVMTypeRef>,
type_ref: LLVMTypeRef,
}
impl<'ctx, ReturnType: BasicType<'ctx>> BasicType<'ctx> for FunctionType<'ctx, ReturnType> {
impl<'ctx> BasicType<'ctx> for FunctionType<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef {
self.type_ref
}
fn from_llvm(_context: &'ctx Context, fn_type: LLVMTypeRef) -> Self
unsafe fn from_llvm(_context: &'ctx Context, fn_type: LLVMTypeRef) -> Self
where
Self: Sized,
{
@ -139,34 +150,32 @@ impl<'ctx, ReturnType: BasicType<'ctx>> BasicType<'ctx> for FunctionType<'ctx, R
.collect();
FunctionType {
phantom: PhantomData,
return_type: LLVMGetReturnType(fn_type),
param_types,
type_ref: fn_type,
}
}
}
}
impl<'ctx, ReturnType: BasicType<'ctx>> FunctionType<'ctx, ReturnType> {
pub fn return_type(&self, context: &'ctx Context) -> ReturnType {
unsafe {
let return_type = LLVMGetReturnType(self.type_ref);
ReturnType::from_llvm(context, return_type)
}
fn is_type(llvm_type: LLVMTypeRef) -> bool {
unsafe { LLVMGetTypeKind(llvm_type) == LLVMTypeKind::LLVMFunctionTypeKind }
}
}
pub struct ArrayType<'ctx, T: BasicType<'ctx>> {
element_type: &'ctx T,
#[derive(Clone, Copy)]
pub struct ArrayType<'ctx> {
phantom: PhantomData<&'ctx ()>,
element_type: LLVMTypeRef,
length: u32,
type_ref: LLVMTypeRef,
}
impl<'ctx, T: BasicType<'ctx>> BasicType<'ctx> for ArrayType<'ctx, T> {
impl<'ctx> BasicType<'ctx> for ArrayType<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef {
self.type_ref
}
fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
where
Self: Sized,
{
@ -175,9 +184,81 @@ impl<'ctx, T: BasicType<'ctx>> BasicType<'ctx> for ArrayType<'ctx, T> {
todo!()
}
}
fn is_type(llvm_type: LLVMTypeRef) -> bool {
unsafe { LLVMGetTypeKind(llvm_type) == LLVMTypeKind::LLVMArrayTypeKind }
}
}
pub trait BasicValue<'ctx> {
#[derive(Clone)]
pub enum TypeEnum<'ctx> {
Integer(IntegerType<'ctx>),
Array(ArrayType<'ctx>),
Function(FunctionType<'ctx>),
}
impl<'ctx> From<IntegerType<'ctx>> for TypeEnum<'ctx> {
fn from(int: IntegerType<'ctx>) -> Self {
TypeEnum::Integer(int)
}
}
impl<'ctx> From<ArrayType<'ctx>> for TypeEnum<'ctx> {
fn from(arr: ArrayType<'ctx>) -> Self {
TypeEnum::Array(arr)
}
}
impl<'ctx> From<FunctionType<'ctx>> for TypeEnum<'ctx> {
fn from(func: FunctionType<'ctx>) -> Self {
TypeEnum::Function(func)
}
}
impl<'ctx> TypeEnum<'ctx> {
fn inner_basic(&'ctx self) -> &'ctx dyn BasicType<'ctx> {
match self {
TypeEnum::Integer(integer_type) => integer_type,
TypeEnum::Array(array_type) => array_type,
TypeEnum::Function(function_type) => function_type,
}
}
}
impl<'ctx> BasicType<'ctx> for TypeEnum<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef {
self.inner_basic().llvm_type()
}
fn is_type(llvm_type: LLVMTypeRef) -> bool
where
Self: Sized,
{
true
}
unsafe fn from_llvm(context: &'ctx Context, llvm_type: LLVMTypeRef) -> Self
where
Self: Sized,
{
unsafe {
match LLVMGetTypeKind(llvm_type) {
LLVMTypeKind::LLVMIntegerTypeKind => {
TypeEnum::Integer(IntegerType::from_llvm(context, llvm_type))
}
LLVMTypeKind::LLVMArrayTypeKind => {
TypeEnum::Array(ArrayType::from_llvm(context, llvm_type))
}
LLVMTypeKind::LLVMFunctionTypeKind => {
TypeEnum::Function(FunctionType::from_llvm(context, llvm_type))
}
_ => todo!(),
}
}
}
}
pub trait BasicValue<'ctx>: std::fmt::Debug {
type BaseType: BasicType<'ctx>;
unsafe fn from_llvm(value: LLVMValueRef) -> Self
where
@ -186,6 +267,7 @@ pub trait BasicValue<'ctx> {
fn llvm_type(&self) -> LLVMTypeRef;
}
#[derive(Clone, Debug)]
pub struct IntegerValue<'ctx> {
phantom: PhantomData<&'ctx ()>,
pub(crate) value_ref: LLVMValueRef,
@ -210,11 +292,14 @@ impl<'ctx> BasicValue<'ctx> for IntegerValue<'ctx> {
}
}
#[derive(Clone, Debug)]
pub enum Value<'ctx> {
Integer(IntegerValue<'ctx>),
}
impl<'ctx> Value<'ctx> {
impl<'ctx> BasicValue<'ctx> for Value<'ctx> {
type BaseType = TypeEnum<'ctx>;
unsafe fn from_llvm(value: LLVMValueRef) -> Self
where
Self: Sized,
@ -231,15 +316,21 @@ impl<'ctx> Value<'ctx> {
}
}
pub fn llvm_value(&self) -> LLVMValueRef {
fn llvm_value(&self) -> LLVMValueRef {
match self {
Self::Integer(i) => i.llvm_value(),
}
}
pub fn llvm_type(&self) -> LLVMTypeRef {
fn llvm_type(&self) -> LLVMTypeRef {
match self {
Self::Integer(i) => i.llvm_type(),
}
}
}
impl<'ctx> From<IntegerValue<'ctx>> for Value<'ctx> {
fn from(value: IntegerValue<'ctx>) -> Self {
Value::Integer(value)
}
}

View File

@ -10,3 +10,4 @@ edition = "2021"
llvm-sys = "160"
## Make it easier to generate errors
thiserror = "1.0.44"
reid-lib = { path = "../reid-llvm-lib" }

View File

@ -0,0 +1,172 @@
use reid::mir::*;
use reid_lib::Context;
fn main() {
let fibonacci_name = "fibonacci".to_owned();
let fibonacci_n = "N".to_owned();
let fibonacci = FunctionDefinition {
name: fibonacci_name.clone(),
parameters: vec![(fibonacci_n.clone(), TypeKind::I32)],
kind: FunctionDefinitionKind::Local(
Block {
statements: vec![Statement(
StatementKind::If(IfExpression(
// If N < 3
Box::new(Expression(
ExpressionKind::BinOp(
BinaryOperator::Logic(LogicOperator::GreaterThan),
Box::new(Expression(
ExpressionKind::Variable(VariableReference(
TypeKind::I32,
"N".to_string(),
Default::default(),
)),
Default::default(),
)),
Box::new(Expression(
ExpressionKind::Literal(Literal::I32(2)),
Default::default(),
)),
),
Default::default(),
)),
// Then
Block {
statements: vec![],
return_expression: Some((
ReturnKind::HardReturn,
// return fibonacci(n-1) + fibonacci(n-2)
Box::new(Expression(
ExpressionKind::BinOp(
BinaryOperator::Add,
// fibonacci(n-1)
Box::new(Expression(
ExpressionKind::FunctionCall(FunctionCall {
name: fibonacci_name.clone(),
return_type: TypeKind::I32,
parameters: vec![Expression(
ExpressionKind::BinOp(
BinaryOperator::Minus,
Box::new(Expression(
ExpressionKind::Variable(
VariableReference(
TypeKind::I32,
fibonacci_n.clone(),
Default::default(),
),
),
Default::default(),
)),
Box::new(Expression(
ExpressionKind::Literal(Literal::I32(
1,
)),
Default::default(),
)),
),
Default::default(),
)],
}),
Default::default(),
)),
// fibonacci(n-2)
Box::new(Expression(
ExpressionKind::FunctionCall(FunctionCall {
name: fibonacci_name.clone(),
return_type: TypeKind::I32,
parameters: vec![Expression(
ExpressionKind::BinOp(
BinaryOperator::Minus,
Box::new(Expression(
ExpressionKind::Variable(
VariableReference(
TypeKind::I32,
fibonacci_n.clone(),
Default::default(),
),
),
Default::default(),
)),
Box::new(Expression(
ExpressionKind::Literal(Literal::I32(
2,
)),
Default::default(),
)),
),
Default::default(),
)],
}),
Default::default(),
)),
),
Default::default(),
)),
)),
range: Default::default(),
},
// No else-block
None,
)),
Default::default(),
)],
// return 1
return_expression: Some((
ReturnKind::SoftReturn,
Box::new(Expression(
ExpressionKind::Literal(Literal::I32(1)),
Default::default(),
)),
)),
range: Default::default(),
},
Default::default(),
),
};
let main = FunctionDefinition {
name: "main".to_owned(),
parameters: vec![],
kind: FunctionDefinitionKind::Local(
Block {
statements: vec![],
return_expression: Some((
ReturnKind::SoftReturn,
Box::new(Expression(
ExpressionKind::FunctionCall(FunctionCall {
name: fibonacci_name.clone(),
return_type: TypeKind::I32,
parameters: vec![Expression(
ExpressionKind::Literal(Literal::I32(5)),
Default::default(),
)],
}),
Default::default(),
)),
)),
range: Default::default(),
},
Default::default(),
),
};
println!("test1");
let module = Module {
name: "test module".to_owned(),
imports: vec![],
functions: vec![fibonacci, main],
};
println!("test2");
let context = Context::new();
let codegen_module = module.codegen(&context);
println!("test3");
match codegen_module.module.print_to_string() {
Ok(v) => println!("{}", v),
Err(e) => println!("Err: {:?}", e),
}
}

281
reid/src/codegen.rs Normal file
View File

@ -0,0 +1,281 @@
use std::{collections::HashMap, mem, ops::Deref};
use crate::mir::{self, types::ReturnType, TypeKind, VariableReference};
use reid_lib::{
types::{BasicType, BasicValue, IntegerValue, TypeEnum, Value},
BasicBlock, Context, Function, IntPredicate, Module,
};
pub struct ModuleCodegen<'ctx> {
context: &'ctx Context,
pub module: Module<'ctx>,
}
impl mir::Module {
pub fn codegen<'ctx>(&self, context: &'ctx Context) -> ModuleCodegen<'ctx> {
let module = context.module(&self.name);
let mut functions = HashMap::new();
for function in &self.functions {
let ret_type = function.return_type().unwrap().get_type(&context);
let fn_type = ret_type.function_type(
function
.parameters
.iter()
.map(|(_, p)| p.get_type(&context))
.collect(),
);
let func = match &function.kind {
mir::FunctionDefinitionKind::Local(_, _) => {
module.add_function(fn_type, &function.name)
}
mir::FunctionDefinitionKind::Extern(_) => todo!(),
};
functions.insert(function.name.clone(), func);
}
for mir_function in &self.functions {
let function = functions.get(&mir_function.name).unwrap();
let mut stack_values = HashMap::new();
for (i, (p_name, p_type)) in mir_function.parameters.iter().enumerate() {
stack_values.insert(
p_name.clone(),
function.get_param(i, p_type.get_type(&context)).unwrap(),
);
}
let mut scope = Scope {
context,
module: &module,
function,
block: function.block("entry"),
functions: functions.clone(),
stack_values,
};
match &mir_function.kind {
mir::FunctionDefinitionKind::Local(block, _) => {
if let Some(ret) = block.codegen(&mut scope) {
scope.block.ret(&ret).unwrap();
}
}
mir::FunctionDefinitionKind::Extern(_) => {}
}
}
ModuleCodegen { context, module }
}
}
pub struct Scope<'ctx> {
context: &'ctx Context,
module: &'ctx Module<'ctx>,
function: &'ctx Function<'ctx>,
block: BasicBlock<'ctx>,
functions: HashMap<String, Function<'ctx>>,
stack_values: HashMap<String, Value<'ctx>>,
}
impl<'ctx> Scope<'ctx> {
pub fn with_block(&self, block: BasicBlock<'ctx>) -> Scope<'ctx> {
Scope {
block,
context: self.context,
function: self.function,
module: self.module,
functions: self.functions.clone(),
stack_values: self.stack_values.clone(),
}
}
/// Takes the block out from this scope, swaps the given block in it's place
/// and returns the old block.
pub fn swap_block(&mut self, block: BasicBlock<'ctx>) -> BasicBlock<'ctx> {
let mut old_block = block;
mem::swap(&mut self.block, &mut old_block);
old_block
}
}
impl mir::Statement {
pub fn codegen<'ctx>(&self, scope: &mut Scope<'ctx>) -> Option<Value<'ctx>> {
match &self.0 {
mir::StatementKind::Let(VariableReference(_, name, _), expression) => {
let value = expression.codegen(scope).unwrap();
scope.stack_values.insert(name.clone(), value);
None
}
mir::StatementKind::If(if_expression) => if_expression.codegen(scope),
mir::StatementKind::Import(_) => todo!(),
mir::StatementKind::Expression(expression) => {
let value = expression.codegen(scope).unwrap();
Some(value)
}
}
}
}
impl mir::IfExpression {
pub fn codegen<'ctx>(&self, scope: &mut Scope<'ctx>) -> Option<Value<'ctx>> {
let condition = self.0.codegen(scope).unwrap();
// Create blocks
let then_bb = scope.function.block("then");
let after_bb = scope.function.block("after");
let mut before_bb = scope.swap_block(after_bb);
let mut then_scope = scope.with_block(then_bb);
let then_res = self.1.codegen(&mut then_scope);
then_scope.block.br(&scope.block).ok();
let else_bb = scope.function.block("else");
let mut else_scope = scope.with_block(else_bb);
let else_opt = if let Some(else_block) = &self.2 {
before_bb
.conditional_br(&condition, &then_scope.block, &else_scope.block)
.unwrap();
let opt = else_block.codegen(&mut else_scope);
if let Some(ret) = opt {
else_scope.block.br(&scope.block).ok();
Some((else_scope.block, ret))
} else {
None
}
} else {
else_scope.block.br(&scope.block).unwrap();
before_bb
.conditional_br(&condition, &then_scope.block, &scope.block)
.unwrap();
None
};
if then_res.is_none() && else_opt.is_none() {
None
} else if let Ok(ret_type) = self.1.return_type() {
let phi = scope
.block
.phi(&ret_type.get_type(scope.context), "phi")
.unwrap();
if let Some(then_ret) = then_res {
phi.add_incoming(&then_ret, &then_scope.block);
}
if let Some((else_bb, else_ret)) = else_opt {
phi.add_incoming(&else_ret, &else_bb);
}
Some(phi.build())
} else {
None
}
}
}
impl mir::Expression {
pub fn codegen<'ctx>(&self, scope: &mut Scope<'ctx>) -> Option<Value<'ctx>> {
match &self.0 {
mir::ExpressionKind::Variable(varref) => {
let v = scope
.stack_values
.get(&varref.1)
.expect("Variable reference not found?!");
Some(v.clone())
}
mir::ExpressionKind::Literal(lit) => Some(lit.codegen(scope.context)),
mir::ExpressionKind::BinOp(binop, lhs_exp, rhs_exp) => {
let lhs = lhs_exp.codegen(scope).expect("lhs has no return value");
let rhs = rhs_exp.codegen(scope).expect("rhs has no return value");
Some(match binop {
mir::BinaryOperator::Add => scope.block.add(&lhs, &rhs, "add").unwrap(),
mir::BinaryOperator::Minus => scope.block.sub(&lhs, &rhs, "sub").unwrap(),
mir::BinaryOperator::Mult => todo!(),
mir::BinaryOperator::And => todo!(),
mir::BinaryOperator::Logic(l) => {
let ret_type = lhs_exp.return_type().expect("No ret type in lhs?");
scope
.block
.integer_compare(&lhs, &rhs, &l.int_predicate(ret_type.signed()), "cmp")
.unwrap()
}
})
}
mir::ExpressionKind::FunctionCall(call) => {
let params = call
.parameters
.iter()
.map(|e| e.codegen(scope).unwrap())
.collect();
let callee = scope
.functions
.get(&call.name)
.expect("function not found!");
Some(scope.block.call(callee, params, "call").unwrap())
}
mir::ExpressionKind::If(if_expression) => if_expression.codegen(scope),
mir::ExpressionKind::Block(block) => {
let mut inner_scope = scope.with_block(scope.function.block("inner"));
if let Some(ret) = block.codegen(&mut inner_scope) {
inner_scope.block.br(&scope.block);
Some(ret)
} else {
None
}
}
}
}
}
impl mir::LogicOperator {
fn int_predicate(&self, signed: bool) -> IntPredicate {
match (self, signed) {
(mir::LogicOperator::LessThan, true) => IntPredicate::SLT,
(mir::LogicOperator::GreaterThan, true) => IntPredicate::SGT,
(mir::LogicOperator::LessThan, false) => IntPredicate::ULT,
(mir::LogicOperator::GreaterThan, false) => IntPredicate::UGT,
}
}
}
impl mir::Block {
pub fn codegen<'ctx>(&self, mut scope: &mut Scope<'ctx>) -> Option<Value<'ctx>> {
for stmt in &self.statements {
stmt.codegen(&mut scope);
}
if let Some((kind, expr)) = &self.return_expression {
let ret = expr.codegen(&mut scope).unwrap();
match kind {
mir::ReturnKind::HardReturn => {
scope.block.ret(&ret).unwrap();
None
}
mir::ReturnKind::SoftReturn => Some(ret),
}
} else {
None
}
}
}
impl mir::Literal {
pub fn codegen<'ctx>(&self, context: &'ctx Context) -> Value<'ctx> {
let val: IntegerValue<'ctx> = match *self {
mir::Literal::I32(val) => context.type_i32().from_signed(val as i64),
mir::Literal::I16(val) => context.type_i16().from_signed(val as i64),
};
Value::Integer(val)
}
}
impl TypeKind {
fn get_type<'ctx>(&self, context: &'ctx Context) -> TypeEnum<'ctx> {
match &self {
TypeKind::I32 => TypeEnum::Integer(context.type_i32()),
TypeKind::I16 => TypeEnum::Integer(context.type_i16()),
}
}
}

View File

@ -1,11 +1,13 @@
use codegen::{form_context, from_statements};
use old_codegen::{form_context, from_statements};
use crate::{ast::TopLevelStatement, lexer::Token, token_stream::TokenStream};
use crate::{lexer::Token, parser::TopLevelStatement, token_stream::TokenStream};
mod ast;
mod codegen;
mod lexer;
pub mod mir;
mod old_codegen;
mod parser;
// mod llvm_ir;
pub mod codegen;
mod token_stream;
// TODO:

120
reid/src/mir/mod.rs Normal file
View File

@ -0,0 +1,120 @@
/// In this module are defined structs that are used for performing passes on
/// Reid. It contains a simplified version of Reid which must already be
/// type-checked beforehand.
use std::collections::HashMap;
use types::*;
use crate::token_stream::TokenRange;
pub mod types;
#[derive(Clone, Copy)]
pub enum TypeKind {
I32,
I16,
}
impl TypeKind {
pub fn signed(&self) -> bool {
match self {
_ => true,
}
}
}
#[derive(Clone, Copy)]
pub enum Literal {
I32(i32),
I16(i16),
}
impl Literal {
fn as_type(self: &Literal) -> TypeKind {
match self {
Literal::I32(_) => TypeKind::I32,
Literal::I16(_) => TypeKind::I16,
}
}
}
#[derive(Debug, Clone, Copy)]
pub enum BinaryOperator {
Add,
Minus,
Mult,
And,
Logic(LogicOperator),
}
#[derive(Debug, Clone, Copy)]
pub enum LogicOperator {
LessThan,
GreaterThan,
}
#[derive(Debug, Clone, Copy)]
pub enum ReturnKind {
HardReturn,
SoftReturn,
}
pub struct VariableReference(pub TypeKind, pub String, pub TokenRange);
pub struct Import(pub String, pub TokenRange);
pub enum ExpressionKind {
Variable(VariableReference),
Literal(Literal),
BinOp(BinaryOperator, Box<Expression>, Box<Expression>),
FunctionCall(FunctionCall),
If(IfExpression),
Block(Block),
}
pub struct Expression(pub ExpressionKind, pub TokenRange);
/// Condition, Then, Else
pub struct IfExpression(pub Box<Expression>, pub Block, pub Option<Block>);
pub struct FunctionCall {
pub name: String,
pub return_type: TypeKind,
pub parameters: Vec<Expression>,
}
pub struct FunctionDefinition {
pub name: String,
pub parameters: Vec<(String, TypeKind)>,
pub kind: FunctionDefinitionKind,
}
pub enum FunctionDefinitionKind {
/// Actual definition block and surrounding signature range
Local(Block, TokenRange),
/// Return Type
Extern(TypeKind),
}
pub struct Block {
/// List of non-returning statements
pub statements: Vec<Statement>,
pub return_expression: Option<(ReturnKind, Box<Expression>)>,
pub range: TokenRange,
}
pub struct Statement(pub StatementKind, pub TokenRange);
pub enum StatementKind {
/// Variable name+type, evaluation
Let(VariableReference, Expression),
If(IfExpression),
Import(Import),
Expression(Expression),
}
pub struct Module {
pub name: String,
pub imports: Vec<Import>,
pub functions: Vec<FunctionDefinition>,
}

75
reid/src/mir/types.rs Normal file
View File

@ -0,0 +1,75 @@
use super::*;
#[derive(Debug, Clone)]
pub enum ReturnTypeOther {
Import(TokenRange),
Let(TokenRange),
EmptyBlock(TokenRange),
NoBlockReturn(TokenRange),
}
pub trait ReturnType {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther>;
}
impl ReturnType for Block {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther> {
self.return_expression
.as_ref()
.ok_or(ReturnTypeOther::NoBlockReturn(self.range.clone()))
.and_then(|(_, stmt)| stmt.return_type())
}
}
impl ReturnType for Statement {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther> {
use StatementKind::*;
match &self.0 {
Expression(e) => e.return_type(),
If(e) => e.return_type(),
Import(_) => Err(ReturnTypeOther::Import(self.1)),
Let(_, _) => Err(ReturnTypeOther::Let(self.1)),
}
}
}
impl ReturnType for Expression {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther> {
use ExpressionKind::*;
match &self.0 {
Literal(lit) => Ok(lit.as_type()),
Variable(var) => var.return_type(),
BinOp(_, expr, _) => expr.return_type(),
Block(block) => block.return_type(),
FunctionCall(fcall) => fcall.return_type(),
If(expr) => expr.return_type(),
}
}
}
impl ReturnType for IfExpression {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther> {
self.1.return_type()
}
}
impl ReturnType for VariableReference {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther> {
Ok(self.0)
}
}
impl ReturnType for FunctionCall {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther> {
Ok(self.return_type)
}
}
impl ReturnType for FunctionDefinition {
fn return_type(&self) -> Result<TypeKind, ReturnTypeOther> {
match &self.kind {
FunctionDefinitionKind::Local(block, _) => block.return_type(),
FunctionDefinitionKind::Extern(type_kind) => Ok(*type_kind),
}
}
}

View File

@ -9,10 +9,10 @@ use llvm_sys::transforms::pass_manager_builder::{
LLVMPassManagerBuilderSetOptLevel,
};
use llvm_sys::{
LLVMBasicBlock, LLVMBuilder, LLVMContext, LLVMModule, LLVMType, LLVMValue, core::*, prelude::*,
core::*, prelude::*, LLVMBasicBlock, LLVMBuilder, LLVMContext, LLVMModule, LLVMType, LLVMValue,
};
use crate::ast;
use crate::parser;
fn into_cstring<T: Into<String>>(value: T) -> CString {
let string = value.into();
@ -47,8 +47,8 @@ impl IRType {
pub struct IRValue(pub IRType, *mut LLVMValue);
impl IRValue {
pub fn from_literal(literal: &ast::Literal, module: &IRModule) -> Self {
use ast::Literal;
pub fn from_literal(literal: &parser::Literal, module: &IRModule) -> Self {
use parser::Literal;
match literal {
Literal::I32(v) => {
let ir_type = IRType::I32;

View File

@ -5,7 +5,7 @@ use std::collections::HashMap;
use llvm::{Error, IRBlock, IRContext, IRFunction, IRModule, IRValue};
use crate::{
ast::{
parser::{
Block, BlockLevelStatement, Expression, ExpressionKind, FunctionDefinition, IfExpression,
LetStatement, ReturnType,
},
@ -97,7 +97,7 @@ impl Expression {
Binop(op, lhs, rhs) => {
let lhs = lhs.codegen(scope);
let rhs = rhs.codegen(scope);
use crate::ast::BinaryOperator::*;
use crate::parser::BinaryOperator::*;
match op {
Add => scope.block.add(lhs, rhs).unwrap(),
Mult => scope.block.mult(lhs, rhs).unwrap(),
@ -121,7 +121,7 @@ impl Expression {
_ => then.block.move_into(&mut scope.block),
}
IRValue::from_literal(&crate::ast::Literal::I32(1), scope.block.function.module)
IRValue::from_literal(&crate::parser::Literal::I32(1), scope.block.function.module)
}
BlockExpr(_) => panic!("block expr not supported"),
FunctionCall(_) => panic!("function call expr not supported"),

View File

@ -1,6 +1,6 @@
use crate::{
ast::Parse,
lexer::{FullToken, Position, Token},
parser::Parse,
};
pub struct TokenStream<'a, 'b> {
@ -156,7 +156,7 @@ impl Drop for TokenStream<'_, '_> {
}
}
#[derive(Clone)]
#[derive(Clone, Copy)]
pub struct TokenRange {
pub start: usize,
pub end: usize,
@ -168,6 +168,15 @@ impl std::fmt::Debug for TokenRange {
}
}
impl Default for TokenRange {
fn default() -> Self {
Self {
start: Default::default(),
end: Default::default(),
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Expected {} at Ln {}, Col {}, got {:?}", .0, (.2).1, (.2).0, .1)]