From d757ac4eb3e3ec98bdca0c06171a0a11a3c117a1 Mon Sep 17 00:00:00 2001 From: sofia Date: Wed, 9 Jul 2025 21:12:39 +0300 Subject: [PATCH] Fiddle around with if-expression codegen --- reid-llvm-lib/examples/libtest.rs | 4 +- reid-llvm-lib/src/builder.rs | 15 +++-- reid-llvm-lib/src/compile.rs | 4 +- reid-llvm-lib/src/debug.rs | 6 +- reid-llvm-lib/src/lib.rs | 8 +-- reid/examples/reid/fibonacci.reid | 10 +-- reid/src/ast/mod.rs | 3 + reid/src/ast/parse.rs | 15 +++-- reid/src/codegen.rs | 100 ++++++++++++------------------ reid/src/lexer.rs | 3 + reid/src/lib.rs | 2 + reid/src/mir/display.rs | 2 +- reid/src/mir/typecheck.rs | 13 +++- 13 files changed, 99 insertions(+), 86 deletions(-) diff --git a/reid-llvm-lib/examples/libtest.rs b/reid-llvm-lib/examples/libtest.rs index 0630a55..834a8fe 100644 --- a/reid-llvm-lib/examples/libtest.rs +++ b/reid-llvm-lib/examples/libtest.rs @@ -1,8 +1,8 @@ -use reid_lib::{ConstValue, Context, InstructionKind, CmpPredicate, TerminatorKind, Type}; +use reid_lib::{ConstValue, Context, Instr, CmpPredicate, TerminatorKind, Type}; fn main() { use ConstValue::*; - use InstructionKind::*; + use Instr::*; let context = Context::new(); diff --git a/reid-llvm-lib/src/builder.rs b/reid-llvm-lib/src/builder.rs index ec0dadc..d6819fd 100644 --- a/reid-llvm-lib/src/builder.rs +++ b/reid-llvm-lib/src/builder.rs @@ -4,8 +4,8 @@ use std::{cell::RefCell, rc::Rc}; use crate::{ - BlockData, ConstValue, FunctionData, InstructionData, InstructionKind, ModuleData, - TerminatorKind, Type, util::match_types, + BlockData, ConstValue, FunctionData, Instr, InstructionData, ModuleData, TerminatorKind, Type, + util::match_types, }; #[derive(Clone, Hash, Copy, PartialEq, Eq)] @@ -196,7 +196,7 @@ impl Builder { } pub fn check_instruction(&self, instruction: &InstructionValue) -> Result<(), ()> { - use super::InstructionKind::*; + use super::Instr::*; unsafe { match self.instr_data(&instruction).kind { Param(_) => Ok(()), @@ -228,6 +228,11 @@ impl Builder { Phi(vals) 
=> { let mut iter = vals.iter(); // TODO error: Phi must contain at least one item + + // TODO error: compile can actually crash here if any of the + // incoming values come from blocks that are added later + // than the one where this one exists. + let first = iter.next().ok_or(())?; for item in iter { match_types(first, item, &self)?; @@ -241,7 +246,7 @@ impl Builder { impl InstructionValue { pub(crate) fn get_type(&self, builder: &Builder) -> Result { - use InstructionKind::*; + use Instr::*; unsafe { match &builder.instr_data(self).kind { Param(nth) => builder @@ -323,7 +328,7 @@ impl TerminatorKind { use TerminatorKind::*; match self { Ret(instr_val) => instr_val.get_type(builder), - Branch(_) => Ok(Type::Void), + Br(_) => Ok(Type::Void), CondBr(_, _, _) => Ok(Type::Void), } } diff --git a/reid-llvm-lib/src/compile.rs b/reid-llvm-lib/src/compile.rs index ea123e4..c034402 100644 --- a/reid-llvm-lib/src/compile.rs +++ b/reid-llvm-lib/src/compile.rs @@ -253,7 +253,7 @@ impl InstructionHolder { ) -> LLVMValue { let _ty = self.value.get_type(module.builder).unwrap(); let val = unsafe { - use super::InstructionKind::*; + use super::Instr::*; match &self.data.kind { Param(nth) => LLVMGetParam(function.value_ref, *nth as u32), Constant(val) => val.as_llvm(module.context_ref), @@ -348,7 +348,7 @@ impl TerminatorKind { let value = module.values.get(val).unwrap(); LLVMBuildRet(module.builder_ref, value.value_ref) } - TerminatorKind::Branch(block_value) => { + TerminatorKind::Br(block_value) => { let dest = *module.blocks.get(block_value).unwrap(); LLVMBuildBr(module.builder_ref, dest) } diff --git a/reid-llvm-lib/src/debug.rs b/reid-llvm-lib/src/debug.rs index f179bab..9dd86ab 100644 --- a/reid-llvm-lib/src/debug.rs +++ b/reid-llvm-lib/src/debug.rs @@ -2,7 +2,7 @@ use std::fmt::{Debug, Write}; -use crate::{CmpPredicate, InstructionData, InstructionKind, TerminatorKind, builder::*}; +use crate::{CmpPredicate, Instr, InstructionData, TerminatorKind, builder::*}; impl Debug 
for Builder { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -81,7 +81,7 @@ impl Debug for InstructionValue { } } -impl Debug for InstructionKind { +impl Debug for Instr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Param(nth) => fmt_call(f, &"Param", &nth), @@ -141,7 +141,7 @@ impl Debug for TerminatorKind { write!(f, "Ret ")?; val.fmt(f) } - Self::Branch(val) => { + Self::Br(val) => { write!(f, "Br ")?; val.fmt(f) } diff --git a/reid-llvm-lib/src/lib.rs b/reid-llvm-lib/src/lib.rs index b0d3c79..af9e135 100644 --- a/reid-llvm-lib/src/lib.rs +++ b/reid-llvm-lib/src/lib.rs @@ -119,7 +119,7 @@ pub struct Block<'builder> { } impl<'builder> Block<'builder> { - pub fn build(&mut self, instruction: InstructionKind) -> Result { + pub fn build(&mut self, instruction: Instr) -> Result { unsafe { self.builder .add_instruction(&self.value, InstructionData { kind: instruction }) @@ -137,7 +137,7 @@ impl<'builder> Block<'builder> { #[derive(Clone, Hash)] pub struct InstructionData { - kind: InstructionKind, + kind: Instr, } #[derive(Clone, Copy, Hash)] @@ -151,7 +151,7 @@ pub enum CmpPredicate { } #[derive(Clone, Hash)] -pub enum InstructionKind { +pub enum Instr { Param(usize), Constant(ConstValue), Add(InstructionValue, InstructionValue), @@ -200,6 +200,6 @@ pub enum ConstValue { #[derive(Clone, Hash)] pub enum TerminatorKind { Ret(InstructionValue), - Branch(BlockValue), + Br(BlockValue), CondBr(InstructionValue, BlockValue, BlockValue), } diff --git a/reid/examples/reid/fibonacci.reid b/reid/examples/reid/fibonacci.reid index 3567752..d3a3034 100644 --- a/reid/examples/reid/fibonacci.reid +++ b/reid/examples/reid/fibonacci.reid @@ -5,8 +5,10 @@ fn main() -> bool { // Fibonacci fn fibonacci(value: u16) -> u16 { - if value <= 2 { - return 1; - } - return fibonacci(value - 1) + fibonacci(value - 2); + let ret = if value <= 2 { + 1 + } else { + fibonacci(value - 1) + fibonacci(value - 2) + }; + ret } diff 
--git a/reid/src/ast/mod.rs b/reid/src/ast/mod.rs index 8739dd8..57772e7 100644 --- a/reid/src/ast/mod.rs +++ b/reid/src/ast/mod.rs @@ -1,3 +1,6 @@ +//! This module contains the code relevant to parsing Reid, that is to say, +//! transforming a Vec of FullTokens into a loosely parsed AST that can be used +//! for unwrapping syntactic sugar and then be transformed into Reid MIR. use crate::token_stream::TokenRange; pub mod parse; diff --git a/reid/src/ast/parse.rs b/reid/src/ast/parse.rs index a21820f..a78c951 100644 --- a/reid/src/ast/parse.rs +++ b/reid/src/ast/parse.rs @@ -204,10 +204,17 @@ impl Parse for FunctionCallExpression { impl Parse for IfExpression { fn parse(mut stream: TokenStream) -> Result { stream.expect(Token::If)?; + let cond = stream.parse()?; + let then_b = stream.parse()?; + let else_b = if let Ok(_) = stream.expect(Token::Else) { + Some(stream.parse()?) + } else { + None + }; Ok(IfExpression( - stream.parse()?, - stream.parse()?, - None, + cond, + then_b, + else_b, stream.get_range().unwrap(), )) } @@ -324,7 +331,7 @@ impl Parse for Block { ReturnType::Hard => { return_stmt = Some((*r_type, e.clone())); break; // Return has to be the last statement - // TODO: Make a mechanism that "can" parse even after this + // TODO: Make a mechanism that "can" parse even after this } ReturnType::Soft => { return_stmt = Some((*r_type, e.clone())); diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index b1b11d2..f0e62ba 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -1,11 +1,11 @@ use std::{collections::HashMap, mem}; use reid_lib::{ - builder::InstructionValue, Block, CmpPredicate, ConstValue, Context, Function, InstructionKind, - Module, TerminatorKind, Type, + Block, CmpPredicate, ConstValue, Context, Function, Instr, Module, TerminatorKind as Term, + Type, builder::InstructionValue, }; -use crate::mir::{self, types::ReturnType, TypeKind, VariableReference}; +use crate::mir::{self, TypeKind, VariableReference, types::ReturnType}; /// 
Context that contains all of the given modules as complete codegenerated /// LLIR that can then be finally compiled into LLVM IR. @@ -74,10 +74,7 @@ impl mir::Module { let mut stack_values = HashMap::new(); for (i, (p_name, _)) in mir_function.parameters.iter().enumerate() { - stack_values.insert( - p_name.clone(), - entry.build(InstructionKind::Param(i)).unwrap(), - ); + stack_values.insert(p_name.clone(), entry.build(Instr::Param(i)).unwrap()); } let mut scope = Scope { @@ -91,7 +88,7 @@ impl mir::Module { match &mir_function.kind { mir::FunctionDefinitionKind::Local(block, _) => { if let Some(ret) = block.codegen(&mut scope) { - scope.block.terminate(TerminatorKind::Ret(ret)).unwrap(); + scope.block.terminate(Term::Ret(ret)).unwrap(); } } mir::FunctionDefinitionKind::Extern => {} @@ -155,62 +152,53 @@ impl mir::IfExpression { let condition = self.0.codegen(scope).unwrap(); // Create blocks - let then_bb = scope.function.block("then"); - let after_bb = scope.function.block("after"); - let mut before_bb = scope.swap_block(after_bb); + let then_b = scope.function.block("then"); + let mut else_b = scope.function.block("else"); + let after_b = scope.function.block("after"); - let mut then_scope = scope.with_block(then_bb); + // Store for convenience + let then_bb = then_b.value(); + let else_bb = else_b.value(); + let after_bb = after_b.value(); + + // Generate then-block content + let mut then_scope = scope.with_block(then_b); let then_res = self.1.codegen(&mut then_scope); - then_scope - .block - .terminate(TerminatorKind::Branch(scope.block.value())) - .ok(); - - let else_bb = scope.function.block("else"); - let mut else_scope = scope.with_block(else_bb); + then_scope.block.terminate(Term::Br(after_bb)).ok(); let else_res = if let Some(else_block) = &self.2 { - before_bb - .terminate(TerminatorKind::CondBr( - condition, - then_scope.block.value(), - else_scope.block.value(), - )) + let mut else_scope = scope.with_block(else_b); + scope + .block + 
.terminate(Term::CondBr(condition, then_bb, else_bb)) .unwrap(); let opt = else_block.codegen(&mut else_scope); if let Some(ret) = opt { - else_scope - .block - .terminate(TerminatorKind::Branch(scope.block.value())) - .ok(); + else_scope.block.terminate(Term::Br(after_bb)).ok(); Some(ret) } else { None } } else { - else_scope + else_b.terminate(Term::Br(after_bb)).unwrap(); + scope .block - .terminate(TerminatorKind::Branch(scope.block.value())) - .unwrap(); - before_bb - .terminate(TerminatorKind::CondBr( - condition, - then_scope.block.value(), - scope.block.value(), - )) + .terminate(Term::CondBr(condition, then_bb, after_bb)) .unwrap(); None }; + // Swap block to the after-block so that construction can continue correctly + scope.swap_block(after_b); + if then_res.is_none() && else_res.is_none() { None } else { - let mut inc = Vec::from(then_res.as_slice()); - inc.extend(else_res); - - Some(scope.block.build(InstructionKind::Phi(vec![])).unwrap()) + let mut incoming = Vec::from(then_res.as_slice()); + incoming.extend(else_res); + Some(scope.block.build(Instr::Phi(incoming)).unwrap()) } } } @@ -242,21 +230,13 @@ impl mir::Expression { let lhs = lhs_exp.codegen(scope).expect("lhs has no return value"); let rhs = rhs_exp.codegen(scope).expect("rhs has no return value"); Some(match binop { - mir::BinaryOperator::Add => { - scope.block.build(InstructionKind::Add(lhs, rhs)).unwrap() - } - mir::BinaryOperator::Minus => { - scope.block.build(InstructionKind::Sub(lhs, rhs)).unwrap() - } - mir::BinaryOperator::Mult => { - scope.block.build(InstructionKind::Mult(lhs, rhs)).unwrap() - } - mir::BinaryOperator::And => { - scope.block.build(InstructionKind::And(lhs, rhs)).unwrap() - } + mir::BinaryOperator::Add => scope.block.build(Instr::Add(lhs, rhs)).unwrap(), + mir::BinaryOperator::Minus => scope.block.build(Instr::Sub(lhs, rhs)).unwrap(), + mir::BinaryOperator::Mult => scope.block.build(Instr::Mult(lhs, rhs)).unwrap(), + mir::BinaryOperator::And => 
scope.block.build(Instr::And(lhs, rhs)).unwrap(), mir::BinaryOperator::Cmp(l) => scope .block - .build(InstructionKind::ICmp(l.int_predicate(), lhs, rhs)) + .build(Instr::ICmp(l.int_predicate(), lhs, rhs)) .unwrap(), }) } @@ -277,7 +257,7 @@ impl mir::Expression { Some( scope .block - .build(InstructionKind::FunctionCall(callee.value(), params)) + .build(Instr::FunctionCall(callee.value(), params)) .unwrap(), ) } @@ -287,7 +267,7 @@ impl mir::Expression { if let Some(ret) = block.codegen(&mut inner_scope) { inner_scope .block - .terminate(TerminatorKind::Branch(scope.block.value())) + .terminate(Term::Br(scope.block.value())) .unwrap(); Some(ret) } else { @@ -321,7 +301,7 @@ impl mir::Block { let ret = expr.codegen(&mut scope).unwrap(); match kind { mir::ReturnKind::Hard => { - scope.block.terminate(TerminatorKind::Ret(ret)).unwrap(); + scope.block.terminate(Term::Ret(ret)).unwrap(); None } mir::ReturnKind::Soft => Some(ret), @@ -337,8 +317,8 @@ impl mir::Literal { block.build(self.as_const_kind()).unwrap() } - fn as_const_kind(&self) -> InstructionKind { - InstructionKind::Constant(match *self { + fn as_const_kind(&self) -> Instr { + Instr::Constant(match *self { mir::Literal::I8(val) => ConstValue::I8(val), mir::Literal::I16(val) => ConstValue::I16(val), mir::Literal::I32(val) => ConstValue::I32(val), diff --git a/reid/src/lexer.rs b/reid/src/lexer.rs index e1daf65..689e3ba 100644 --- a/reid/src/lexer.rs +++ b/reid/src/lexer.rs @@ -22,6 +22,8 @@ pub enum Token { Arrow, /// `if` If, + /// `else` + Else, /// `true` True, /// `false` @@ -172,6 +174,7 @@ pub fn tokenize>(to_tokenize: T) -> Result, Error "return" => Token::ReturnKeyword, "fn" => Token::FnKeyword, "if" => Token::If, + "else" => Token::Else, "true" => Token::True, "false" => Token::False, _ => Token::Identifier(value), diff --git a/reid/src/lib.rs b/reid/src/lib.rs index ccc9367..fb69095 100644 --- a/reid/src/lib.rs +++ b/reid/src/lib.rs @@ -87,6 +87,8 @@ pub fn compile(source: &str) -> Result { 
dbg!(&ast_module); let mut mir_context = mir::Context::from(vec![ast_module]); + println!("{}", &mir_context); + let state = mir_context.pass(&mut TypeCheck); dbg!(&state); diff --git a/reid/src/mir/display.rs b/reid/src/mir/display.rs index ea44a57..fe7aea2 100644 --- a/reid/src/mir/display.rs +++ b/reid/src/mir/display.rs @@ -76,7 +76,7 @@ impl Display for Block { if let Some(ret) = &self.return_expression { match ret.0 { ReturnKind::Hard => writeln!(inner_f, "Return(Hard): {}", ret.1), - ReturnKind::Soft => writeln!(inner_f, "Return(Hard): {}", ret.1), + ReturnKind::Soft => writeln!(inner_f, "Return(Soft): {}", ret.1), }?; } else { writeln!(inner_f, "No Return")?; diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index 334d525..03eba81 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -252,7 +252,18 @@ impl Expression { } else { Vague(Unknown) }; - then_ret_t.collapse_into(&else_ret_t) + + let collapsed = then_ret_t.collapse_into(&else_ret_t)?; + if let Some(rhs) = rhs { + // If rhs existed, typecheck both sides to perform type + // coercion. + let lhs_res = lhs.typecheck(state, Some(collapsed)); + let rhs_res = rhs.typecheck(state, Some(collapsed)); + state.ok(lhs_res, lhs.meta); + state.ok(rhs_res, rhs.meta); + } + + Ok(collapsed) } ExprKind::Block(block) => block.typecheck(state, hint_t), }