Fiddle around with if-expression codegen

This commit is contained in:
Sofia 2025-07-09 21:12:39 +03:00
parent c50474cc8e
commit d757ac4eb3
13 changed files with 99 additions and 86 deletions

View File

@ -1,8 +1,8 @@
use reid_lib::{ConstValue, Context, InstructionKind, CmpPredicate, TerminatorKind, Type};
use reid_lib::{ConstValue, Context, Instr, CmpPredicate, TerminatorKind, Type};
fn main() {
use ConstValue::*;
use InstructionKind::*;
use Instr::*;
let context = Context::new();

View File

@ -4,8 +4,8 @@
use std::{cell::RefCell, rc::Rc};
use crate::{
BlockData, ConstValue, FunctionData, InstructionData, InstructionKind, ModuleData,
TerminatorKind, Type, util::match_types,
BlockData, ConstValue, FunctionData, Instr, InstructionData, ModuleData, TerminatorKind, Type,
util::match_types,
};
#[derive(Clone, Hash, Copy, PartialEq, Eq)]
@ -196,7 +196,7 @@ impl Builder {
}
pub fn check_instruction(&self, instruction: &InstructionValue) -> Result<(), ()> {
use super::InstructionKind::*;
use super::Instr::*;
unsafe {
match self.instr_data(&instruction).kind {
Param(_) => Ok(()),
@ -228,6 +228,11 @@ impl Builder {
Phi(vals) => {
let mut iter = vals.iter();
// TODO error: Phi must contain at least one item
// TODO error: compile can actually crash here if any of the
// incoming values come from blocks that are added later
// than the one where this one exists.
let first = iter.next().ok_or(())?;
for item in iter {
match_types(first, item, &self)?;
@ -241,7 +246,7 @@ impl Builder {
impl InstructionValue {
pub(crate) fn get_type(&self, builder: &Builder) -> Result<Type, ()> {
use InstructionKind::*;
use Instr::*;
unsafe {
match &builder.instr_data(self).kind {
Param(nth) => builder
@ -323,7 +328,7 @@ impl TerminatorKind {
use TerminatorKind::*;
match self {
Ret(instr_val) => instr_val.get_type(builder),
Branch(_) => Ok(Type::Void),
Br(_) => Ok(Type::Void),
CondBr(_, _, _) => Ok(Type::Void),
}
}

View File

@ -253,7 +253,7 @@ impl InstructionHolder {
) -> LLVMValue {
let _ty = self.value.get_type(module.builder).unwrap();
let val = unsafe {
use super::InstructionKind::*;
use super::Instr::*;
match &self.data.kind {
Param(nth) => LLVMGetParam(function.value_ref, *nth as u32),
Constant(val) => val.as_llvm(module.context_ref),
@ -348,7 +348,7 @@ impl TerminatorKind {
let value = module.values.get(val).unwrap();
LLVMBuildRet(module.builder_ref, value.value_ref)
}
TerminatorKind::Branch(block_value) => {
TerminatorKind::Br(block_value) => {
let dest = *module.blocks.get(block_value).unwrap();
LLVMBuildBr(module.builder_ref, dest)
}

View File

@ -2,7 +2,7 @@
use std::fmt::{Debug, Write};
use crate::{CmpPredicate, InstructionData, InstructionKind, TerminatorKind, builder::*};
use crate::{CmpPredicate, Instr, InstructionData, TerminatorKind, builder::*};
impl Debug for Builder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -81,7 +81,7 @@ impl Debug for InstructionValue {
}
}
impl Debug for InstructionKind {
impl Debug for Instr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Param(nth) => fmt_call(f, &"Param", &nth),
@ -141,7 +141,7 @@ impl Debug for TerminatorKind {
write!(f, "Ret ")?;
val.fmt(f)
}
Self::Branch(val) => {
Self::Br(val) => {
write!(f, "Br ")?;
val.fmt(f)
}

View File

@ -119,7 +119,7 @@ pub struct Block<'builder> {
}
impl<'builder> Block<'builder> {
pub fn build(&mut self, instruction: InstructionKind) -> Result<InstructionValue, ()> {
pub fn build(&mut self, instruction: Instr) -> Result<InstructionValue, ()> {
unsafe {
self.builder
.add_instruction(&self.value, InstructionData { kind: instruction })
@ -137,7 +137,7 @@ impl<'builder> Block<'builder> {
#[derive(Clone, Hash)]
pub struct InstructionData {
kind: InstructionKind,
kind: Instr,
}
#[derive(Clone, Copy, Hash)]
@ -151,7 +151,7 @@ pub enum CmpPredicate {
}
#[derive(Clone, Hash)]
pub enum InstructionKind {
pub enum Instr {
Param(usize),
Constant(ConstValue),
Add(InstructionValue, InstructionValue),
@ -200,6 +200,6 @@ pub enum ConstValue {
#[derive(Clone, Hash)]
pub enum TerminatorKind {
Ret(InstructionValue),
Branch(BlockValue),
Br(BlockValue),
CondBr(InstructionValue, BlockValue, BlockValue),
}

View File

@ -5,8 +5,10 @@ fn main() -> bool {
// Fibonacci
fn fibonacci(value: u16) -> u16 {
if value <= 2 {
return 1;
}
return fibonacci(value - 1) + fibonacci(value - 2);
let ret = if value <= 2 {
1
} else {
fibonacci(value - 1) + fibonacci(value - 2)
};
ret
}

View File

@ -1,3 +1,6 @@
//! This is the module that contains relevant code to parsing Reid, that is to
//! say transforming a Vec of FullTokens into a loose parsed AST that can be
//! used for unwrapping syntax sugar, and then be transformed into Reid MIR.
use crate::token_stream::TokenRange;
pub mod parse;

View File

@ -204,10 +204,17 @@ impl Parse for FunctionCallExpression {
impl Parse for IfExpression {
fn parse(mut stream: TokenStream) -> Result<Self, Error> {
stream.expect(Token::If)?;
let cond = stream.parse()?;
let then_b = stream.parse()?;
let else_b = if let Ok(_) = stream.expect(Token::Else) {
Some(stream.parse()?)
} else {
None
};
Ok(IfExpression(
stream.parse()?,
stream.parse()?,
None,
cond,
then_b,
else_b,
stream.get_range().unwrap(),
))
}

View File

@ -1,11 +1,11 @@
use std::{collections::HashMap, mem};
use reid_lib::{
builder::InstructionValue, Block, CmpPredicate, ConstValue, Context, Function, InstructionKind,
Module, TerminatorKind, Type,
Block, CmpPredicate, ConstValue, Context, Function, Instr, Module, TerminatorKind as Term,
Type, builder::InstructionValue,
};
use crate::mir::{self, types::ReturnType, TypeKind, VariableReference};
use crate::mir::{self, TypeKind, VariableReference, types::ReturnType};
/// Context that contains all of the given modules as complete codegenerated
/// LLIR that can then be finally compiled into LLVM IR.
@ -74,10 +74,7 @@ impl mir::Module {
let mut stack_values = HashMap::new();
for (i, (p_name, _)) in mir_function.parameters.iter().enumerate() {
stack_values.insert(
p_name.clone(),
entry.build(InstructionKind::Param(i)).unwrap(),
);
stack_values.insert(p_name.clone(), entry.build(Instr::Param(i)).unwrap());
}
let mut scope = Scope {
@ -91,7 +88,7 @@ impl mir::Module {
match &mir_function.kind {
mir::FunctionDefinitionKind::Local(block, _) => {
if let Some(ret) = block.codegen(&mut scope) {
scope.block.terminate(TerminatorKind::Ret(ret)).unwrap();
scope.block.terminate(Term::Ret(ret)).unwrap();
}
}
mir::FunctionDefinitionKind::Extern => {}
@ -155,62 +152,53 @@ impl mir::IfExpression {
let condition = self.0.codegen(scope).unwrap();
// Create blocks
let then_bb = scope.function.block("then");
let after_bb = scope.function.block("after");
let mut before_bb = scope.swap_block(after_bb);
let then_b = scope.function.block("then");
let mut else_b = scope.function.block("else");
let after_b = scope.function.block("after");
let mut then_scope = scope.with_block(then_bb);
// Store for convenience
let then_bb = then_b.value();
let else_bb = else_b.value();
let after_bb = after_b.value();
// Generate then-block content
let mut then_scope = scope.with_block(then_b);
let then_res = self.1.codegen(&mut then_scope);
then_scope
.block
.terminate(TerminatorKind::Branch(scope.block.value()))
.ok();
let else_bb = scope.function.block("else");
let mut else_scope = scope.with_block(else_bb);
then_scope.block.terminate(Term::Br(after_bb)).ok();
let else_res = if let Some(else_block) = &self.2 {
before_bb
.terminate(TerminatorKind::CondBr(
condition,
then_scope.block.value(),
else_scope.block.value(),
))
let mut else_scope = scope.with_block(else_b);
scope
.block
.terminate(Term::CondBr(condition, then_bb, else_bb))
.unwrap();
let opt = else_block.codegen(&mut else_scope);
if let Some(ret) = opt {
else_scope
.block
.terminate(TerminatorKind::Branch(scope.block.value()))
.ok();
else_scope.block.terminate(Term::Br(after_bb)).ok();
Some(ret)
} else {
None
}
} else {
else_scope
else_b.terminate(Term::Br(after_bb)).unwrap();
scope
.block
.terminate(TerminatorKind::Branch(scope.block.value()))
.unwrap();
before_bb
.terminate(TerminatorKind::CondBr(
condition,
then_scope.block.value(),
scope.block.value(),
))
.terminate(Term::CondBr(condition, then_bb, after_bb))
.unwrap();
None
};
// Swap block to the after-block so that construction can continue correctly
scope.swap_block(after_b);
if then_res.is_none() && else_res.is_none() {
None
} else {
let mut inc = Vec::from(then_res.as_slice());
inc.extend(else_res);
Some(scope.block.build(InstructionKind::Phi(vec![])).unwrap())
let mut incoming = Vec::from(then_res.as_slice());
incoming.extend(else_res);
Some(scope.block.build(Instr::Phi(incoming)).unwrap())
}
}
}
@ -242,21 +230,13 @@ impl mir::Expression {
let lhs = lhs_exp.codegen(scope).expect("lhs has no return value");
let rhs = rhs_exp.codegen(scope).expect("rhs has no return value");
Some(match binop {
mir::BinaryOperator::Add => {
scope.block.build(InstructionKind::Add(lhs, rhs)).unwrap()
}
mir::BinaryOperator::Minus => {
scope.block.build(InstructionKind::Sub(lhs, rhs)).unwrap()
}
mir::BinaryOperator::Mult => {
scope.block.build(InstructionKind::Mult(lhs, rhs)).unwrap()
}
mir::BinaryOperator::And => {
scope.block.build(InstructionKind::And(lhs, rhs)).unwrap()
}
mir::BinaryOperator::Add => scope.block.build(Instr::Add(lhs, rhs)).unwrap(),
mir::BinaryOperator::Minus => scope.block.build(Instr::Sub(lhs, rhs)).unwrap(),
mir::BinaryOperator::Mult => scope.block.build(Instr::Mult(lhs, rhs)).unwrap(),
mir::BinaryOperator::And => scope.block.build(Instr::And(lhs, rhs)).unwrap(),
mir::BinaryOperator::Cmp(l) => scope
.block
.build(InstructionKind::ICmp(l.int_predicate(), lhs, rhs))
.build(Instr::ICmp(l.int_predicate(), lhs, rhs))
.unwrap(),
})
}
@ -277,7 +257,7 @@ impl mir::Expression {
Some(
scope
.block
.build(InstructionKind::FunctionCall(callee.value(), params))
.build(Instr::FunctionCall(callee.value(), params))
.unwrap(),
)
}
@ -287,7 +267,7 @@ impl mir::Expression {
if let Some(ret) = block.codegen(&mut inner_scope) {
inner_scope
.block
.terminate(TerminatorKind::Branch(scope.block.value()))
.terminate(Term::Br(scope.block.value()))
.unwrap();
Some(ret)
} else {
@ -321,7 +301,7 @@ impl mir::Block {
let ret = expr.codegen(&mut scope).unwrap();
match kind {
mir::ReturnKind::Hard => {
scope.block.terminate(TerminatorKind::Ret(ret)).unwrap();
scope.block.terminate(Term::Ret(ret)).unwrap();
None
}
mir::ReturnKind::Soft => Some(ret),
@ -337,8 +317,8 @@ impl mir::Literal {
block.build(self.as_const_kind()).unwrap()
}
fn as_const_kind(&self) -> InstructionKind {
InstructionKind::Constant(match *self {
fn as_const_kind(&self) -> Instr {
Instr::Constant(match *self {
mir::Literal::I8(val) => ConstValue::I8(val),
mir::Literal::I16(val) => ConstValue::I16(val),
mir::Literal::I32(val) => ConstValue::I32(val),

View File

@ -22,6 +22,8 @@ pub enum Token {
Arrow,
/// `if`
If,
/// `else`
Else,
/// `true`
True,
/// `false`
@ -172,6 +174,7 @@ pub fn tokenize<T: Into<String>>(to_tokenize: T) -> Result<Vec<FullToken>, Error
"return" => Token::ReturnKeyword,
"fn" => Token::FnKeyword,
"if" => Token::If,
"else" => Token::Else,
"true" => Token::True,
"false" => Token::False,
_ => Token::Identifier(value),

View File

@ -87,6 +87,8 @@ pub fn compile(source: &str) -> Result<String, ReidError> {
dbg!(&ast_module);
let mut mir_context = mir::Context::from(vec![ast_module]);
println!("{}", &mir_context);
let state = mir_context.pass(&mut TypeCheck);
dbg!(&state);

View File

@ -76,7 +76,7 @@ impl Display for Block {
if let Some(ret) = &self.return_expression {
match ret.0 {
ReturnKind::Hard => writeln!(inner_f, "Return(Hard): {}", ret.1),
ReturnKind::Soft => writeln!(inner_f, "Return(Hard): {}", ret.1),
ReturnKind::Soft => writeln!(inner_f, "Return(Soft): {}", ret.1),
}?;
} else {
writeln!(inner_f, "No Return")?;

View File

@ -252,7 +252,18 @@ impl Expression {
} else {
Vague(Unknown)
};
then_ret_t.collapse_into(&else_ret_t)
let collapsed = then_ret_t.collapse_into(&else_ret_t)?;
if let Some(rhs) = rhs {
// If rhs existed, typecheck both sides to perform type
// coercion.
let lhs_res = lhs.typecheck(state, Some(collapsed));
let rhs_res = rhs.typecheck(state, Some(collapsed));
state.ok(lhs_res, lhs.meta);
state.ok(rhs_res, rhs.meta);
}
Ok(collapsed)
}
ExprKind::Block(block) => block.typecheck(state, hint_t),
}