Fiddle around with if-expression codegen

This commit is contained in:
Sofia 2025-07-09 21:12:39 +03:00
parent c50474cc8e
commit d757ac4eb3
13 changed files with 99 additions and 86 deletions

View File

@ -1,8 +1,8 @@
use reid_lib::{ConstValue, Context, InstructionKind, CmpPredicate, TerminatorKind, Type}; use reid_lib::{ConstValue, Context, Instr, CmpPredicate, TerminatorKind, Type};
fn main() { fn main() {
use ConstValue::*; use ConstValue::*;
use InstructionKind::*; use Instr::*;
let context = Context::new(); let context = Context::new();

View File

@ -4,8 +4,8 @@
use std::{cell::RefCell, rc::Rc}; use std::{cell::RefCell, rc::Rc};
use crate::{ use crate::{
BlockData, ConstValue, FunctionData, InstructionData, InstructionKind, ModuleData, BlockData, ConstValue, FunctionData, Instr, InstructionData, ModuleData, TerminatorKind, Type,
TerminatorKind, Type, util::match_types, util::match_types,
}; };
#[derive(Clone, Hash, Copy, PartialEq, Eq)] #[derive(Clone, Hash, Copy, PartialEq, Eq)]
@ -196,7 +196,7 @@ impl Builder {
} }
pub fn check_instruction(&self, instruction: &InstructionValue) -> Result<(), ()> { pub fn check_instruction(&self, instruction: &InstructionValue) -> Result<(), ()> {
use super::InstructionKind::*; use super::Instr::*;
unsafe { unsafe {
match self.instr_data(&instruction).kind { match self.instr_data(&instruction).kind {
Param(_) => Ok(()), Param(_) => Ok(()),
@ -228,6 +228,11 @@ impl Builder {
Phi(vals) => { Phi(vals) => {
let mut iter = vals.iter(); let mut iter = vals.iter();
// TODO error: Phi must contain at least one item // TODO error: Phi must contain at least one item
// TODO error: compile can actually crash here if any of the
// incoming values come from blocks that are added later
// than the one where this one exists.
let first = iter.next().ok_or(())?; let first = iter.next().ok_or(())?;
for item in iter { for item in iter {
match_types(first, item, &self)?; match_types(first, item, &self)?;
@ -241,7 +246,7 @@ impl Builder {
impl InstructionValue { impl InstructionValue {
pub(crate) fn get_type(&self, builder: &Builder) -> Result<Type, ()> { pub(crate) fn get_type(&self, builder: &Builder) -> Result<Type, ()> {
use InstructionKind::*; use Instr::*;
unsafe { unsafe {
match &builder.instr_data(self).kind { match &builder.instr_data(self).kind {
Param(nth) => builder Param(nth) => builder
@ -323,7 +328,7 @@ impl TerminatorKind {
use TerminatorKind::*; use TerminatorKind::*;
match self { match self {
Ret(instr_val) => instr_val.get_type(builder), Ret(instr_val) => instr_val.get_type(builder),
Branch(_) => Ok(Type::Void), Br(_) => Ok(Type::Void),
CondBr(_, _, _) => Ok(Type::Void), CondBr(_, _, _) => Ok(Type::Void),
} }
} }

View File

@ -253,7 +253,7 @@ impl InstructionHolder {
) -> LLVMValue { ) -> LLVMValue {
let _ty = self.value.get_type(module.builder).unwrap(); let _ty = self.value.get_type(module.builder).unwrap();
let val = unsafe { let val = unsafe {
use super::InstructionKind::*; use super::Instr::*;
match &self.data.kind { match &self.data.kind {
Param(nth) => LLVMGetParam(function.value_ref, *nth as u32), Param(nth) => LLVMGetParam(function.value_ref, *nth as u32),
Constant(val) => val.as_llvm(module.context_ref), Constant(val) => val.as_llvm(module.context_ref),
@ -348,7 +348,7 @@ impl TerminatorKind {
let value = module.values.get(val).unwrap(); let value = module.values.get(val).unwrap();
LLVMBuildRet(module.builder_ref, value.value_ref) LLVMBuildRet(module.builder_ref, value.value_ref)
} }
TerminatorKind::Branch(block_value) => { TerminatorKind::Br(block_value) => {
let dest = *module.blocks.get(block_value).unwrap(); let dest = *module.blocks.get(block_value).unwrap();
LLVMBuildBr(module.builder_ref, dest) LLVMBuildBr(module.builder_ref, dest)
} }

View File

@ -2,7 +2,7 @@
use std::fmt::{Debug, Write}; use std::fmt::{Debug, Write};
use crate::{CmpPredicate, InstructionData, InstructionKind, TerminatorKind, builder::*}; use crate::{CmpPredicate, Instr, InstructionData, TerminatorKind, builder::*};
impl Debug for Builder { impl Debug for Builder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -81,7 +81,7 @@ impl Debug for InstructionValue {
} }
} }
impl Debug for InstructionKind { impl Debug for Instr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::Param(nth) => fmt_call(f, &"Param", &nth), Self::Param(nth) => fmt_call(f, &"Param", &nth),
@ -141,7 +141,7 @@ impl Debug for TerminatorKind {
write!(f, "Ret ")?; write!(f, "Ret ")?;
val.fmt(f) val.fmt(f)
} }
Self::Branch(val) => { Self::Br(val) => {
write!(f, "Br ")?; write!(f, "Br ")?;
val.fmt(f) val.fmt(f)
} }

View File

@ -119,7 +119,7 @@ pub struct Block<'builder> {
} }
impl<'builder> Block<'builder> { impl<'builder> Block<'builder> {
pub fn build(&mut self, instruction: InstructionKind) -> Result<InstructionValue, ()> { pub fn build(&mut self, instruction: Instr) -> Result<InstructionValue, ()> {
unsafe { unsafe {
self.builder self.builder
.add_instruction(&self.value, InstructionData { kind: instruction }) .add_instruction(&self.value, InstructionData { kind: instruction })
@ -137,7 +137,7 @@ impl<'builder> Block<'builder> {
#[derive(Clone, Hash)] #[derive(Clone, Hash)]
pub struct InstructionData { pub struct InstructionData {
kind: InstructionKind, kind: Instr,
} }
#[derive(Clone, Copy, Hash)] #[derive(Clone, Copy, Hash)]
@ -151,7 +151,7 @@ pub enum CmpPredicate {
} }
#[derive(Clone, Hash)] #[derive(Clone, Hash)]
pub enum InstructionKind { pub enum Instr {
Param(usize), Param(usize),
Constant(ConstValue), Constant(ConstValue),
Add(InstructionValue, InstructionValue), Add(InstructionValue, InstructionValue),
@ -200,6 +200,6 @@ pub enum ConstValue {
#[derive(Clone, Hash)] #[derive(Clone, Hash)]
pub enum TerminatorKind { pub enum TerminatorKind {
Ret(InstructionValue), Ret(InstructionValue),
Branch(BlockValue), Br(BlockValue),
CondBr(InstructionValue, BlockValue, BlockValue), CondBr(InstructionValue, BlockValue, BlockValue),
} }

View File

@ -5,8 +5,10 @@ fn main() -> bool {
// Fibonacci // Fibonacci
fn fibonacci(value: u16) -> u16 { fn fibonacci(value: u16) -> u16 {
if value <= 2 { let ret = if value <= 2 {
return 1; 1
} } else {
return fibonacci(value - 1) + fibonacci(value - 2); fibonacci(value - 1) + fibonacci(value - 2)
};
ret
} }

View File

@ -1,3 +1,6 @@
//! This is the module that contains relevant code to parsing Reid, that is to
//! say transforming a Vec of FullTokens into a loose parsed AST that can be
//! used for unwrapping syntax sugar, and then be transformed into Reid MIR.
use crate::token_stream::TokenRange; use crate::token_stream::TokenRange;
pub mod parse; pub mod parse;

View File

@ -204,10 +204,17 @@ impl Parse for FunctionCallExpression {
impl Parse for IfExpression { impl Parse for IfExpression {
fn parse(mut stream: TokenStream) -> Result<Self, Error> { fn parse(mut stream: TokenStream) -> Result<Self, Error> {
stream.expect(Token::If)?; stream.expect(Token::If)?;
let cond = stream.parse()?;
let then_b = stream.parse()?;
let else_b = if let Ok(_) = stream.expect(Token::Else) {
Some(stream.parse()?)
} else {
None
};
Ok(IfExpression( Ok(IfExpression(
stream.parse()?, cond,
stream.parse()?, then_b,
None, else_b,
stream.get_range().unwrap(), stream.get_range().unwrap(),
)) ))
} }
@ -324,7 +331,7 @@ impl Parse for Block {
ReturnType::Hard => { ReturnType::Hard => {
return_stmt = Some((*r_type, e.clone())); return_stmt = Some((*r_type, e.clone()));
break; // Return has to be the last statement break; // Return has to be the last statement
// TODO: Make a mechanism that "can" parse even after this // TODO: Make a mechanism that "can" parse even after this
} }
ReturnType::Soft => { ReturnType::Soft => {
return_stmt = Some((*r_type, e.clone())); return_stmt = Some((*r_type, e.clone()));

View File

@ -1,11 +1,11 @@
use std::{collections::HashMap, mem}; use std::{collections::HashMap, mem};
use reid_lib::{ use reid_lib::{
builder::InstructionValue, Block, CmpPredicate, ConstValue, Context, Function, InstructionKind, Block, CmpPredicate, ConstValue, Context, Function, Instr, Module, TerminatorKind as Term,
Module, TerminatorKind, Type, Type, builder::InstructionValue,
}; };
use crate::mir::{self, types::ReturnType, TypeKind, VariableReference}; use crate::mir::{self, TypeKind, VariableReference, types::ReturnType};
/// Context that contains all of the given modules as complete codegenerated /// Context that contains all of the given modules as complete codegenerated
/// LLIR that can then be finally compiled into LLVM IR. /// LLIR that can then be finally compiled into LLVM IR.
@ -74,10 +74,7 @@ impl mir::Module {
let mut stack_values = HashMap::new(); let mut stack_values = HashMap::new();
for (i, (p_name, _)) in mir_function.parameters.iter().enumerate() { for (i, (p_name, _)) in mir_function.parameters.iter().enumerate() {
stack_values.insert( stack_values.insert(p_name.clone(), entry.build(Instr::Param(i)).unwrap());
p_name.clone(),
entry.build(InstructionKind::Param(i)).unwrap(),
);
} }
let mut scope = Scope { let mut scope = Scope {
@ -91,7 +88,7 @@ impl mir::Module {
match &mir_function.kind { match &mir_function.kind {
mir::FunctionDefinitionKind::Local(block, _) => { mir::FunctionDefinitionKind::Local(block, _) => {
if let Some(ret) = block.codegen(&mut scope) { if let Some(ret) = block.codegen(&mut scope) {
scope.block.terminate(TerminatorKind::Ret(ret)).unwrap(); scope.block.terminate(Term::Ret(ret)).unwrap();
} }
} }
mir::FunctionDefinitionKind::Extern => {} mir::FunctionDefinitionKind::Extern => {}
@ -155,62 +152,53 @@ impl mir::IfExpression {
let condition = self.0.codegen(scope).unwrap(); let condition = self.0.codegen(scope).unwrap();
// Create blocks // Create blocks
let then_bb = scope.function.block("then"); let then_b = scope.function.block("then");
let after_bb = scope.function.block("after"); let mut else_b = scope.function.block("else");
let mut before_bb = scope.swap_block(after_bb); let after_b = scope.function.block("after");
let mut then_scope = scope.with_block(then_bb); // Store for convenience
let then_bb = then_b.value();
let else_bb = else_b.value();
let after_bb = after_b.value();
// Generate then-block content
let mut then_scope = scope.with_block(then_b);
let then_res = self.1.codegen(&mut then_scope); let then_res = self.1.codegen(&mut then_scope);
then_scope then_scope.block.terminate(Term::Br(after_bb)).ok();
.block
.terminate(TerminatorKind::Branch(scope.block.value()))
.ok();
let else_bb = scope.function.block("else");
let mut else_scope = scope.with_block(else_bb);
let else_res = if let Some(else_block) = &self.2 { let else_res = if let Some(else_block) = &self.2 {
before_bb let mut else_scope = scope.with_block(else_b);
.terminate(TerminatorKind::CondBr( scope
condition, .block
then_scope.block.value(), .terminate(Term::CondBr(condition, then_bb, else_bb))
else_scope.block.value(),
))
.unwrap(); .unwrap();
let opt = else_block.codegen(&mut else_scope); let opt = else_block.codegen(&mut else_scope);
if let Some(ret) = opt { if let Some(ret) = opt {
else_scope else_scope.block.terminate(Term::Br(after_bb)).ok();
.block
.terminate(TerminatorKind::Branch(scope.block.value()))
.ok();
Some(ret) Some(ret)
} else { } else {
None None
} }
} else { } else {
else_scope else_b.terminate(Term::Br(after_bb)).unwrap();
scope
.block .block
.terminate(TerminatorKind::Branch(scope.block.value())) .terminate(Term::CondBr(condition, then_bb, after_bb))
.unwrap();
before_bb
.terminate(TerminatorKind::CondBr(
condition,
then_scope.block.value(),
scope.block.value(),
))
.unwrap(); .unwrap();
None None
}; };
// Swap block to the after-block so that construction can continue correctly
scope.swap_block(after_b);
if then_res.is_none() && else_res.is_none() { if then_res.is_none() && else_res.is_none() {
None None
} else { } else {
let mut inc = Vec::from(then_res.as_slice()); let mut incoming = Vec::from(then_res.as_slice());
inc.extend(else_res); incoming.extend(else_res);
Some(scope.block.build(Instr::Phi(incoming)).unwrap())
Some(scope.block.build(InstructionKind::Phi(vec![])).unwrap())
} }
} }
} }
@ -242,21 +230,13 @@ impl mir::Expression {
let lhs = lhs_exp.codegen(scope).expect("lhs has no return value"); let lhs = lhs_exp.codegen(scope).expect("lhs has no return value");
let rhs = rhs_exp.codegen(scope).expect("rhs has no return value"); let rhs = rhs_exp.codegen(scope).expect("rhs has no return value");
Some(match binop { Some(match binop {
mir::BinaryOperator::Add => { mir::BinaryOperator::Add => scope.block.build(Instr::Add(lhs, rhs)).unwrap(),
scope.block.build(InstructionKind::Add(lhs, rhs)).unwrap() mir::BinaryOperator::Minus => scope.block.build(Instr::Sub(lhs, rhs)).unwrap(),
} mir::BinaryOperator::Mult => scope.block.build(Instr::Mult(lhs, rhs)).unwrap(),
mir::BinaryOperator::Minus => { mir::BinaryOperator::And => scope.block.build(Instr::And(lhs, rhs)).unwrap(),
scope.block.build(InstructionKind::Sub(lhs, rhs)).unwrap()
}
mir::BinaryOperator::Mult => {
scope.block.build(InstructionKind::Mult(lhs, rhs)).unwrap()
}
mir::BinaryOperator::And => {
scope.block.build(InstructionKind::And(lhs, rhs)).unwrap()
}
mir::BinaryOperator::Cmp(l) => scope mir::BinaryOperator::Cmp(l) => scope
.block .block
.build(InstructionKind::ICmp(l.int_predicate(), lhs, rhs)) .build(Instr::ICmp(l.int_predicate(), lhs, rhs))
.unwrap(), .unwrap(),
}) })
} }
@ -277,7 +257,7 @@ impl mir::Expression {
Some( Some(
scope scope
.block .block
.build(InstructionKind::FunctionCall(callee.value(), params)) .build(Instr::FunctionCall(callee.value(), params))
.unwrap(), .unwrap(),
) )
} }
@ -287,7 +267,7 @@ impl mir::Expression {
if let Some(ret) = block.codegen(&mut inner_scope) { if let Some(ret) = block.codegen(&mut inner_scope) {
inner_scope inner_scope
.block .block
.terminate(TerminatorKind::Branch(scope.block.value())) .terminate(Term::Br(scope.block.value()))
.unwrap(); .unwrap();
Some(ret) Some(ret)
} else { } else {
@ -321,7 +301,7 @@ impl mir::Block {
let ret = expr.codegen(&mut scope).unwrap(); let ret = expr.codegen(&mut scope).unwrap();
match kind { match kind {
mir::ReturnKind::Hard => { mir::ReturnKind::Hard => {
scope.block.terminate(TerminatorKind::Ret(ret)).unwrap(); scope.block.terminate(Term::Ret(ret)).unwrap();
None None
} }
mir::ReturnKind::Soft => Some(ret), mir::ReturnKind::Soft => Some(ret),
@ -337,8 +317,8 @@ impl mir::Literal {
block.build(self.as_const_kind()).unwrap() block.build(self.as_const_kind()).unwrap()
} }
fn as_const_kind(&self) -> InstructionKind { fn as_const_kind(&self) -> Instr {
InstructionKind::Constant(match *self { Instr::Constant(match *self {
mir::Literal::I8(val) => ConstValue::I8(val), mir::Literal::I8(val) => ConstValue::I8(val),
mir::Literal::I16(val) => ConstValue::I16(val), mir::Literal::I16(val) => ConstValue::I16(val),
mir::Literal::I32(val) => ConstValue::I32(val), mir::Literal::I32(val) => ConstValue::I32(val),

View File

@ -22,6 +22,8 @@ pub enum Token {
Arrow, Arrow,
/// `if` /// `if`
If, If,
/// `else`
Else,
/// `true` /// `true`
True, True,
/// `false` /// `false`
@ -172,6 +174,7 @@ pub fn tokenize<T: Into<String>>(to_tokenize: T) -> Result<Vec<FullToken>, Error
"return" => Token::ReturnKeyword, "return" => Token::ReturnKeyword,
"fn" => Token::FnKeyword, "fn" => Token::FnKeyword,
"if" => Token::If, "if" => Token::If,
"else" => Token::Else,
"true" => Token::True, "true" => Token::True,
"false" => Token::False, "false" => Token::False,
_ => Token::Identifier(value), _ => Token::Identifier(value),

View File

@ -87,6 +87,8 @@ pub fn compile(source: &str) -> Result<String, ReidError> {
dbg!(&ast_module); dbg!(&ast_module);
let mut mir_context = mir::Context::from(vec![ast_module]); let mut mir_context = mir::Context::from(vec![ast_module]);
println!("{}", &mir_context);
let state = mir_context.pass(&mut TypeCheck); let state = mir_context.pass(&mut TypeCheck);
dbg!(&state); dbg!(&state);

View File

@ -76,7 +76,7 @@ impl Display for Block {
if let Some(ret) = &self.return_expression { if let Some(ret) = &self.return_expression {
match ret.0 { match ret.0 {
ReturnKind::Hard => writeln!(inner_f, "Return(Hard): {}", ret.1), ReturnKind::Hard => writeln!(inner_f, "Return(Hard): {}", ret.1),
ReturnKind::Soft => writeln!(inner_f, "Return(Hard): {}", ret.1), ReturnKind::Soft => writeln!(inner_f, "Return(Soft): {}", ret.1),
}?; }?;
} else { } else {
writeln!(inner_f, "No Return")?; writeln!(inner_f, "No Return")?;

View File

@ -252,7 +252,18 @@ impl Expression {
} else { } else {
Vague(Unknown) Vague(Unknown)
}; };
then_ret_t.collapse_into(&else_ret_t)
let collapsed = then_ret_t.collapse_into(&else_ret_t)?;
if let Some(rhs) = rhs {
// If rhs existed, typecheck both sides to perform type
// coercion.
let lhs_res = lhs.typecheck(state, Some(collapsed));
let rhs_res = rhs.typecheck(state, Some(collapsed));
state.ok(lhs_res, lhs.meta);
state.ok(rhs_res, rhs.meta);
}
Ok(collapsed)
} }
ExprKind::Block(block) => block.typecheck(state, hint_t), ExprKind::Block(block) => block.typecheck(state, hint_t),
} }