Implement some kind of if/else

This commit is contained in:
Sofia 2024-08-25 23:17:52 +03:00
parent 8defa39b31
commit e21f47e34b
3 changed files with 106 additions and 37 deletions

View File

@ -214,7 +214,7 @@ impl Parse for FunctionCallExpression {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct IfExpression(Expression, pub Block, pub TokenRange); pub struct IfExpression(pub Expression, pub Block, pub TokenRange);
impl Parse for IfExpression { impl Parse for IfExpression {
fn parse(mut stream: TokenStream) -> Result<Self, Error> { fn parse(mut stream: TokenStream) -> Result<Self, Error> {

View File

@ -1,3 +1,4 @@
use std::borrow::BorrowMut;
use std::ffi::{CStr, CString}; use std::ffi::{CStr, CString};
use std::mem; use std::mem;
@ -21,6 +22,7 @@ pub enum Error {
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub enum IRType { pub enum IRType {
I32, I32,
Boolean,
} }
impl IRType { impl IRType {
@ -29,6 +31,7 @@ impl IRType {
unsafe { unsafe {
return match self { return match self {
I32 => LLVMInt32TypeInContext(context.context), I32 => LLVMInt32TypeInContext(context.context),
Boolean => LLVMInt1TypeInContext(context.context),
}; };
} }
} }
@ -140,26 +143,20 @@ impl<'a, 'b> IRFunction<'a, 'b> {
} }
} }
} }
pub fn attach(&mut self, block: IRBlock) {
unsafe { LLVMAppendExistingBasicBlock(self.value, block.blockref) }
}
} }
pub struct IRBlock<'a, 'b> { pub struct IRBlock<'a, 'b, 'c> {
pub module: &'b IRModule<'a>, pub function: &'c IRFunction<'a, 'b>,
blockref: *mut LLVMBasicBlock, blockref: *mut LLVMBasicBlock,
} }
impl<'a, 'b, 'c> IRBlock<'a, 'b> { impl<'a, 'b, 'c> IRBlock<'a, 'b, 'c> {
pub fn new(module: &'b IRModule<'a>) -> IRBlock<'a, 'b> { pub fn new(function: &'c IRFunction<'a, 'b>, name: &CStr) -> IRBlock<'a, 'b, 'c> {
unsafe { unsafe {
let blockref = LLVMCreateBasicBlockInContext( let blockref =
module.context.context, LLVMCreateBasicBlockInContext(function.module.context.context, name.as_ptr());
into_cstring("entryblock").as_ptr(),
);
IRBlock { module, blockref } IRBlock { function, blockref }
} }
} }
@ -169,12 +166,12 @@ impl<'a, 'b, 'c> IRBlock<'a, 'b> {
IRValue(rhs_t, rhs_v): IRValue, IRValue(rhs_t, rhs_v): IRValue,
) -> Result<IRValue, Error> { ) -> Result<IRValue, Error> {
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.module.context.builder, self.blockref); LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref);
if lhs_t == rhs_t { if lhs_t == rhs_t {
Ok(IRValue( Ok(IRValue(
lhs_t, lhs_t,
LLVMBuildAdd( LLVMBuildAdd(
self.module.context.builder, self.function.module.context.builder,
lhs_v, lhs_v,
rhs_v, rhs_v,
c"tmpadd".as_ptr(), c"tmpadd".as_ptr(),
@ -192,12 +189,12 @@ impl<'a, 'b, 'c> IRBlock<'a, 'b> {
IRValue(rhs_t, rhs_v): IRValue, IRValue(rhs_t, rhs_v): IRValue,
) -> Result<IRValue, Error> { ) -> Result<IRValue, Error> {
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.module.context.builder, self.blockref); LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref);
if lhs_t == rhs_t { if lhs_t == rhs_t {
Ok(IRValue( Ok(IRValue(
lhs_t, lhs_t,
LLVMBuildMul( LLVMBuildMul(
self.module.context.builder, self.function.module.context.builder,
lhs_v, lhs_v,
rhs_v, rhs_v,
c"tmpadd".as_ptr(), c"tmpadd".as_ptr(),
@ -209,14 +206,70 @@ impl<'a, 'b, 'c> IRBlock<'a, 'b> {
} }
} }
pub fn less_than(
&mut self,
IRValue(lhs_t, lhs_v): IRValue,
IRValue(rhs_t, rhs_v): IRValue,
) -> Result<IRValue, Error> {
unsafe {
LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref);
if lhs_t == rhs_t {
Ok(IRValue(
IRType::Boolean,
LLVMBuildICmp(
self.function.module.context.builder,
llvm_sys::LLVMIntPredicate::LLVMIntULT,
lhs_v,
rhs_v,
c"IntULT".as_ptr(),
),
))
} else {
Err(Error::TypeMismatch(lhs_t, rhs_t))
}
}
}
pub fn add_return(&mut self, value: Option<IRValue>) { pub fn add_return(&mut self, value: Option<IRValue>) {
unsafe { unsafe {
LLVMPositionBuilderAtEnd(self.module.context.builder, self.blockref); LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref);
if let Some(IRValue(_, value)) = value { if let Some(IRValue(_, value)) = value {
LLVMBuildRet(self.module.context.builder, value); LLVMBuildRet(self.function.module.context.builder, value);
} else { } else {
LLVMBuildRetVoid(self.module.context.builder); LLVMBuildRetVoid(self.function.module.context.builder);
} }
} }
} }
pub fn branch(
&mut self,
IRValue(_, condition): IRValue,
then_block: &mut IRBlock,
else_block: &mut IRBlock,
) {
unsafe {
LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref);
LLVMBuildCondBr(
self.function.module.context.builder,
condition,
then_block.blockref,
else_block.blockref,
);
}
}
pub fn move_into(&mut self, block: &mut IRBlock) {
unsafe {
LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref);
LLVMBuildBr(self.function.module.context.builder, block.blockref);
}
}
}
impl<'a, 'b, 'c> Drop for IRBlock<'a, 'b, 'c> {
fn drop(&mut self) {
unsafe {
LLVMAppendExistingBasicBlock(self.function.value, self.blockref);
}
}
} }

View File

@ -6,8 +6,8 @@ use llvm::{Error, IRBlock, IRContext, IRFunction, IRModule, IRValue};
use crate::{ use crate::{
ast::{ ast::{
Block, BlockLevelStatement, Expression, ExpressionKind, FunctionDefinition, LetStatement, Block, BlockLevelStatement, Expression, ExpressionKind, FunctionDefinition, IfExpression,
ReturnType, LetStatement, ReturnType,
}, },
TopLevelStatement, TopLevelStatement,
}; };
@ -41,13 +41,11 @@ impl TopLevelStatement {
impl FunctionDefinition { impl FunctionDefinition {
fn codegen(&self, scope: &mut ScopeData, module: &mut IRModule) { fn codegen(&self, scope: &mut ScopeData, module: &mut IRModule) {
let FunctionDefinition(signature, block, _) = self; let FunctionDefinition(signature, block, _) = self;
let mut ir_function = IRFunction::new(&signature.name, module); let ir_function = IRFunction::new(&signature.name, module);
let ir_block = IRBlock::new(&module); let ir_block = IRBlock::new(&ir_function, c"entry");
let mut scope = scope.inner(ir_block); let mut scope = scope.inner(ir_block);
block.codegen(&mut scope); block.codegen(&mut scope);
ir_function.attach(scope.block);
} }
} }
@ -87,7 +85,7 @@ impl Expression {
use ExpressionKind::*; use ExpressionKind::*;
match kind { match kind {
Literal(lit) => IRValue::from_literal(lit, &mut scope.block.module), Literal(lit) => IRValue::from_literal(lit, &scope.block.function.module),
VariableName(v) => scope.data.fetch(v), VariableName(v) => scope.data.fetch(v),
Binop(op, lhs, rhs) => { Binop(op, lhs, rhs) => {
let lhs = lhs.codegen(scope); let lhs = lhs.codegen(scope);
@ -96,10 +94,28 @@ impl Expression {
match op { match op {
Add => scope.block.add(lhs, rhs).unwrap(), Add => scope.block.add(lhs, rhs).unwrap(),
Mult => scope.block.mult(lhs, rhs).unwrap(), Mult => scope.block.mult(lhs, rhs).unwrap(),
LessThan => scope.block.less_than(lhs, rhs).unwrap(),
_ => panic!("operator not supported: {:?}", op), _ => panic!("operator not supported: {:?}", op),
} }
} }
_ => panic!("expression type not supported"), IfExpr(ifx) => {
let IfExpression(expr, block, _) = ifx.as_ref();
let condition = expr.codegen(scope);
let mut then = IRBlock::new(scope.block.function, c"then");
let mut after = IRBlock::new(scope.block.function, c"merge");
scope.block.branch(condition, &mut then, &mut after);
scope.block = after;
let mut inner = scope.inner(then);
block.codegen(&mut inner);
inner.block.move_into(&mut scope.block);
IRValue::from_literal(&crate::ast::Literal::I32(1), scope.block.function.module)
}
BlockExpr(_) => panic!("block expr not supported"),
FunctionCall(_) => panic!("function call expr not supported"),
} }
} }
} }
@ -116,11 +132,11 @@ impl ScopeData {
} }
} }
fn with_block<'a, 'b>(self, block: IRBlock<'a, 'b>) -> Scope<'a, 'b> { fn with_block<'a, 'b, 'c>(self, block: IRBlock<'a, 'b, 'c>) -> Scope<'a, 'b, 'c> {
Scope { data: self, block } Scope { data: self, block }
} }
fn inner<'a, 'b>(&self, block: IRBlock<'a, 'b>) -> Scope<'a, 'b> { fn inner<'a, 'b, 'c>(&self, block: IRBlock<'a, 'b, 'c>) -> Scope<'a, 'b, 'c> {
self.clone().with_block(block) self.clone().with_block(block)
} }
@ -139,13 +155,13 @@ impl ScopeData {
} }
} }
struct Scope<'a, 'b> { struct Scope<'a, 'b, 'c> {
data: ScopeData, data: ScopeData,
block: IRBlock<'a, 'b>, block: IRBlock<'a, 'b, 'c>,
} }
impl<'a, 'b> Scope<'a, 'b> { impl<'a, 'b, 'c> Scope<'a, 'b, 'c> {
fn inner(&self, block: IRBlock<'a, 'b>) -> Scope<'a, 'b> { fn inner(&self, block: IRBlock<'a, 'b, 'c>) -> Scope<'a, 'b, 'c> {
self.data.clone().with_block(block) self.data.clone().with_block(block)
} }
} }