From 6448b0c438643d5c1a6abe7bb3a0dbac663786a6 Mon Sep 17 00:00:00 2001 From: sofia Date: Wed, 21 Aug 2024 23:31:09 +0300 Subject: [PATCH] Add necessary codegen for easy.reid --- src/ast.rs | 8 +-- src/codegen/llvm.rs | 133 ++++++++++++++++++++++++++++++-------------- src/codegen/mod.rs | 41 ++++++++++++-- 3 files changed, 130 insertions(+), 52 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 774b3f4..2bb37ba 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -362,7 +362,7 @@ impl Parse for Block { statements.push(BlockLevelStatement::Expression(e)); } let statement = stream.parse()?; - if let BlockLevelStatement::Return((r_type, e)) = &statement { + if let BlockLevelStatement::Return(r_type, e) = &statement { match r_type { ReturnType::Hard => { return_stmt = Some((*r_type, e.clone())); @@ -387,7 +387,7 @@ pub enum BlockLevelStatement { Let(LetStatement), Import(ImportStatement), Expression(Expression), - Return((ReturnType, Expression)), + Return(ReturnType, Expression), } impl Parse for BlockLevelStatement { @@ -400,14 +400,14 @@ impl Parse for BlockLevelStatement { stream.next(); let exp = stream.parse()?; stream.expect(Token::Semi)?; - Stmt::Return((ReturnType::Hard, exp)) + Stmt::Return(ReturnType::Hard, exp) } _ => { if let Ok(e) = stream.parse() { if stream.expect(Token::Semi).is_ok() { Stmt::Expression(e) } else { - Stmt::Return((ReturnType::Soft, e)) + Stmt::Return(ReturnType::Soft, e) } } else { Err(stream.expected_err("expression")?)? diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs index 1252e13..cc5a2fa 100644 --- a/src/codegen/llvm.rs +++ b/src/codegen/llvm.rs @@ -12,6 +12,50 @@ fn into_cstring>(value: T) -> CString { unsafe { CString::from_vec_with_nul_unchecked((string + "\0").into_bytes()) } } +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Type mismatch: {0:?} vs {1:?}")] + TypeMismatch(IRType, IRType), +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum IRType { + I32, +} + +impl IRType { + fn in_context(&self, context: &mut IRContext) -> *mut LLVMType { + use IRType::*; + unsafe { + return match self { + I32 => LLVMInt32TypeInContext(context.context), + }; + } + } +} + +#[derive(Clone)] +pub struct IRValue(pub IRType, *mut LLVMValue); + +impl IRValue { + pub fn from_literal(literal: &ast::Literal, block: &mut IRBlock) -> Self { + use ast::Literal; + match literal { + Literal::I32(v) => { + let ir_type = IRType::I32; + unsafe { + let ir_value = LLVMConstInt( + ir_type.in_context(block.function.module.context), + mem::transmute(*v as i64), + 1, + ); + return IRValue(ir_type, ir_value); + } + } + }; + } +} + pub struct IRContext { context: *mut LLVMContext, builder: *mut LLVMBuilder, @@ -71,22 +115,6 @@ impl<'a> Drop for IRModule<'a> { } } -#[derive(Clone)] -pub enum IRType { - I32, -} - -impl IRType { - fn in_context(&self, context: &mut IRContext) -> *mut LLVMType { - use IRType::*; - unsafe { - return match self { - I32 => LLVMInt32TypeInContext(context.context), - }; - } - } -} - pub struct IRFunction<'a, 'b> { module: &'b mut IRModule<'a>, /// The actual function @@ -132,10 +160,54 @@ impl<'a, 'b, 'c> IRBlock<'a, 'b, 'c> { } } + pub fn add( + &mut self, + IRValue(lhs_t, lhs_v): IRValue, + IRValue(rhs_t, rhs_v): IRValue, + ) -> Result { + unsafe { + if lhs_t == rhs_t { + Ok(IRValue( + lhs_t, + LLVMBuildAdd( + self.function.module.context.builder, + lhs_v, + rhs_v, + c"tmpadd".as_ptr(), + ), + )) + } else { + Err(Error::TypeMismatch(lhs_t, rhs_t)) + } + } + } + + pub fn mult( + &mut self, + IRValue(lhs_t, lhs_v): IRValue, + IRValue(rhs_t, rhs_v): IRValue, + ) -> Result { + unsafe { + if lhs_t == rhs_t { + Ok(IRValue( + lhs_t, + LLVMBuildMul( + self.function.module.context.builder, + lhs_v, + rhs_v, + c"tmpadd".as_ptr(), + ), + )) + } else { + Err(Error::TypeMismatch(lhs_t, rhs_t)) + } + } + } + pub fn add_return(self, value: Option) { unsafe { - if let Some(value) = value { - LLVMBuildRet(self.function.module.context.builder, value.ir_value); + if let Some(IRValue(_, value)) = value { + LLVMBuildRet(self.function.module.context.builder, value); } else { LLVMBuildRetVoid(self.function.module.context.builder); } @@ -150,28 +222,3 @@ impl<'a, 'b, 'c> Drop for IRBlock<'a, 'b, 'c> { } } } - -#[derive(Clone)] -pub struct IRValue { - pub ir_type: IRType, - ir_value: *mut LLVMValue, -} - -impl IRValue { - pub fn from_literal(literal: &ast::Literal, block: &mut IRBlock) -> Self { - use ast::Literal; - match literal { - Literal::I32(v) => { - let ir_type = IRType::I32; - unsafe { - let ir_value = LLVMConstInt( - ir_type.in_context(block.function.module.context), - mem::transmute(*v as i64), - 1, - ); - return IRValue { ir_type, ir_value }; - } - } - }; - } -} diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 51ab6ef..cbdf29c 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -2,16 +2,16 @@ mod llvm; use std::collections::HashMap; -use llvm::{IRBlock, IRContext, IRFunction, IRModule, IRValue}; +use llvm::{Error, IRBlock, IRContext, IRFunction, IRModule, IRValue}; use crate::{ - ast::{Block, Expression, ExpressionKind, FunctionDefinition}, + ast::{ + BinaryOperator, Block, BlockLevelStatement, Expression, ExpressionKind, FunctionDefinition, + LetStatement, ReturnType, + }, TopLevelStatement, }; -#[derive(thiserror::Error, Debug)] -pub enum Error {} - pub fn form_context() -> IRContext { IRContext::new() } @@ -49,6 +49,10 @@ impl FunctionDefinition { impl Block { fn codegen(&self, mut scope: Scope) { + for statement in &self.0 { + statement.codegen(&mut scope); + } + if let Some((_, return_exp)) = &self.1 { let value = return_exp.codegen(&mut scope); scope.block.add_return(Some(value)); @@ -56,6 +60,23 @@ impl Block { } } +impl BlockLevelStatement { + fn codegen(&self, scope: &mut Scope) { + use BlockLevelStatement::*; + match self { + Expression(exp) | Return(ReturnType::Soft, exp) => { + exp.codegen(scope); + } + Let(LetStatement(name, exp, _)) => { + let val = exp.codegen(scope); + scope.data.insert(name, val); + } + Return(ReturnType::Hard, _) => panic!("hard returns here should not be possible.."), + Import(_) => panic!("block level import not supported"), + } + } +} + impl Expression { fn codegen(&self, scope: &mut Scope) -> IRValue { let Expression(kind, _) = self; @@ -64,6 +85,16 @@ impl Expression { match kind { Literal(lit) => IRValue::from_literal(lit, &mut scope.block), VariableName(v) => scope.data.fetch(v), + Binop(op, lhs, rhs) => { + let lhs = lhs.codegen(scope); + let rhs = rhs.codegen(scope); + use crate::ast::BinaryOperator::*; + match op { + Add => scope.block.add(lhs, rhs).unwrap(), + Mult => scope.block.mult(lhs, rhs).unwrap(), + _ => panic!("operator not supported: {:?}", op), + } + } _ => panic!("expression type not supported"), } }