diff --git a/src/ast.rs b/src/ast.rs index 2bb37ba..3c19a55 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -214,7 +214,7 @@ impl Parse for FunctionCallExpression { } #[derive(Debug, Clone)] -pub struct IfExpression(Expression, pub Block, pub TokenRange); +pub struct IfExpression(pub Expression, pub Block, pub TokenRange); impl Parse for IfExpression { fn parse(mut stream: TokenStream) -> Result { diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs index 03df00e..7b76565 100644 --- a/src/codegen/llvm.rs +++ b/src/codegen/llvm.rs @@ -1,3 +1,4 @@ +use std::borrow::BorrowMut; use std::ffi::{CStr, CString}; use std::mem; @@ -21,6 +22,7 @@ pub enum Error { #[derive(Clone, Debug, Eq, PartialEq)] pub enum IRType { I32, + Boolean, } impl IRType { @@ -29,6 +31,7 @@ impl IRType { unsafe { return match self { I32 => LLVMInt32TypeInContext(context.context), + Boolean => LLVMInt1TypeInContext(context.context), }; } } @@ -140,26 +143,20 @@ impl<'a, 'b> IRFunction<'a, 'b> { } } } - - pub fn attach(&mut self, block: IRBlock) { - unsafe { LLVMAppendExistingBasicBlock(self.value, block.blockref) } - } } -pub struct IRBlock<'a, 'b> { - pub module: &'b IRModule<'a>, +pub struct IRBlock<'a, 'b, 'c> { + pub function: &'c IRFunction<'a, 'b>, blockref: *mut LLVMBasicBlock, } -impl<'a, 'b, 'c> IRBlock<'a, 'b> { - pub fn new(module: &'b IRModule<'a>) -> IRBlock<'a, 'b> { +impl<'a, 'b, 'c> IRBlock<'a, 'b, 'c> { + pub fn new(function: &'c IRFunction<'a, 'b>, name: &CStr) -> IRBlock<'a, 'b, 'c> { unsafe { - let blockref = LLVMCreateBasicBlockInContext( - module.context.context, - into_cstring("entryblock").as_ptr(), - ); + let blockref = + LLVMCreateBasicBlockInContext(function.module.context.context, name.as_ptr()); - IRBlock { module, blockref } + IRBlock { function, blockref } } } @@ -169,12 +166,12 @@ impl<'a, 'b, 'c> IRBlock<'a, 'b> { IRValue(rhs_t, rhs_v): IRValue, ) -> Result { unsafe { - LLVMPositionBuilderAtEnd(self.module.context.builder, self.blockref); + LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref); if lhs_t == rhs_t { Ok(IRValue( lhs_t, LLVMBuildAdd( - self.module.context.builder, + self.function.module.context.builder, lhs_v, rhs_v, c"tmpadd".as_ptr(), @@ -192,12 +189,12 @@ impl<'a, 'b, 'c> IRBlock<'a, 'b> { IRValue(rhs_t, rhs_v): IRValue, ) -> Result { unsafe { - LLVMPositionBuilderAtEnd(self.module.context.builder, self.blockref); + LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref); if lhs_t == rhs_t { Ok(IRValue( lhs_t, LLVMBuildMul( - self.module.context.builder, + self.function.module.context.builder, lhs_v, rhs_v, c"tmpadd".as_ptr(), @@ -209,14 +206,70 @@ impl<'a, 'b, 'c> IRBlock<'a, 'b> { } } - pub fn add_return(&mut self, value: Option) { + pub fn less_than( + &mut self, + IRValue(lhs_t, lhs_v): IRValue, + IRValue(rhs_t, rhs_v): IRValue, + ) -> Result { unsafe { - LLVMPositionBuilderAtEnd(self.module.context.builder, self.blockref); - if let Some(IRValue(_, value)) = value { - LLVMBuildRet(self.module.context.builder, value); + LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref); + if lhs_t == rhs_t { + Ok(IRValue( + IRType::Boolean, + LLVMBuildICmp( + self.function.module.context.builder, + llvm_sys::LLVMIntPredicate::LLVMIntULT, + lhs_v, + rhs_v, + c"IntULT".as_ptr(), + ), + )) } else { - LLVMBuildRetVoid(self.module.context.builder); + Err(Error::TypeMismatch(lhs_t, rhs_t)) } } } + + pub fn add_return(&mut self, value: Option) { + unsafe { + LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref); + if let Some(IRValue(_, value)) = value { + LLVMBuildRet(self.function.module.context.builder, value); + } else { + LLVMBuildRetVoid(self.function.module.context.builder); + } + } + } + + pub fn branch( + &mut self, + IRValue(_, condition): IRValue, + then_block: &mut IRBlock, + else_block: &mut IRBlock, + ) { + unsafe { + LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref); + LLVMBuildCondBr( + self.function.module.context.builder, + condition, + then_block.blockref, + else_block.blockref, + ); + } + } + + pub fn move_into(&mut self, block: &mut IRBlock) { + unsafe { + LLVMPositionBuilderAtEnd(self.function.module.context.builder, self.blockref); + LLVMBuildBr(self.function.module.context.builder, block.blockref); + } + } +} + +impl<'a, 'b, 'c> Drop for IRBlock<'a, 'b, 'c> { + fn drop(&mut self) { + unsafe { + LLVMAppendExistingBasicBlock(self.function.value, self.blockref); + } + } } diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 9b1d2d6..c5ddfb9 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -6,8 +6,8 @@ use llvm::{Error, IRBlock, IRContext, IRFunction, IRModule, IRValue}; use crate::{ ast::{ - Block, BlockLevelStatement, Expression, ExpressionKind, FunctionDefinition, LetStatement, - ReturnType, + Block, BlockLevelStatement, Expression, ExpressionKind, FunctionDefinition, IfExpression, + LetStatement, ReturnType, }, TopLevelStatement, }; @@ -41,13 +41,11 @@ impl TopLevelStatement { impl FunctionDefinition { fn codegen(&self, scope: &mut ScopeData, module: &mut IRModule) { let FunctionDefinition(signature, block, _) = self; - let mut ir_function = IRFunction::new(&signature.name, module); + let ir_function = IRFunction::new(&signature.name, module); - let ir_block = IRBlock::new(&module); + let ir_block = IRBlock::new(&ir_function, c"entry"); let mut scope = scope.inner(ir_block); block.codegen(&mut scope); - - ir_function.attach(scope.block); } } @@ -87,7 +85,7 @@ impl Expression { use ExpressionKind::*; match kind { - Literal(lit) => IRValue::from_literal(lit, &mut scope.block.module), + Literal(lit) => IRValue::from_literal(lit, &scope.block.function.module), VariableName(v) => scope.data.fetch(v), Binop(op, lhs, rhs) => { let lhs = lhs.codegen(scope); @@ -96,10 +94,28 @@ impl Expression { match op { Add => scope.block.add(lhs, rhs).unwrap(), Mult => scope.block.mult(lhs, rhs).unwrap(), + LessThan => scope.block.less_than(lhs, rhs).unwrap(), _ => panic!("operator not supported: {:?}", op), } } - _ => panic!("expression type not supported"), + IfExpr(ifx) => { + let IfExpression(expr, block, _) = ifx.as_ref(); + let condition = expr.codegen(scope); + + let mut then = IRBlock::new(scope.block.function, c"then"); + let mut after = IRBlock::new(scope.block.function, c"merge"); + + scope.block.branch(condition, &mut then, &mut after); + scope.block = after; + + let mut inner = scope.inner(then); + block.codegen(&mut inner); + inner.block.move_into(&mut scope.block); + + IRValue::from_literal(&crate::ast::Literal::I32(1), scope.block.function.module) + } + BlockExpr(_) => panic!("block expr not supported"), + FunctionCall(_) => panic!("function call expr not supported"), } } } @@ -116,11 +132,11 @@ impl ScopeData { } } - fn with_block<'a, 'b>(self, block: IRBlock<'a, 'b>) -> Scope<'a, 'b> { + fn with_block<'a, 'b, 'c>(self, block: IRBlock<'a, 'b, 'c>) -> Scope<'a, 'b, 'c> { Scope { data: self, block } } - fn inner<'a, 'b>(&self, block: IRBlock<'a, 'b>) -> Scope<'a, 'b> { + fn inner<'a, 'b, 'c>(&self, block: IRBlock<'a, 'b, 'c>) -> Scope<'a, 'b, 'c> { self.clone().with_block(block) } @@ -139,13 +155,13 @@ impl ScopeData { } } -struct Scope<'a, 'b> { +struct Scope<'a, 'b, 'c> { data: ScopeData, - block: IRBlock<'a, 'b>, + block: IRBlock<'a, 'b, 'c>, } -impl<'a, 'b> Scope<'a, 'b> { - fn inner(&self, block: IRBlock<'a, 'b>) -> Scope<'a, 'b> { +impl<'a, 'b, 'c> Scope<'a, 'b, 'c> { + fn inner(&self, block: IRBlock<'a, 'b, 'c>) -> Scope<'a, 'b, 'c> { self.data.clone().with_block(block) } }