From b19a32cd8a0380bf98c0b24e3095ee3ee79a4427 Mon Sep 17 00:00:00 2001 From: sofia Date: Wed, 9 Jul 2025 22:20:08 +0300 Subject: [PATCH] Make early returns work even without an explicit return --- reid/examples/reid/fibonacci.reid | 2 +- reid/src/codegen.rs | 4 +++- reid/src/mir/mod.rs | 2 +- reid/src/mir/typecheck.rs | 27 ++++++++++++++++++++++--- reid/src/mir/types.rs | 33 ++++++++++++++++++++----------- 5 files changed, 50 insertions(+), 18 deletions(-) diff --git a/reid/examples/reid/fibonacci.reid b/reid/examples/reid/fibonacci.reid index 9eb4150..2c0224a 100644 --- a/reid/examples/reid/fibonacci.reid +++ b/reid/examples/reid/fibonacci.reid @@ -9,5 +9,5 @@ fn fibonacci(value: u16) -> u16 { return 1; } else { return fibonacci(value - 1) + fibonacci(value - 2); - } + }; } diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index de56801..95de547 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -93,7 +93,7 @@ impl mir::Module { if !scope.block.delete_if_unused().unwrap() { // Add a void return just in case if the block // wasn't unused but didn't have a terminator yet - scope.block.terminate(Term::RetVoid).unwrap(); + scope.block.terminate(Term::RetVoid).ok(); } } } @@ -225,11 +225,13 @@ impl mir::Expression { lhs_exp .return_type() .expect("No ret type in lhs?") + .1 .is_known() .expect("lhs ret type is unknown"); rhs_exp .return_type() .expect("No ret type in rhs?") + .1 .is_known() .expect("rhs ret type is unknown"); diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index 76cf592..84ef2dc 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -177,7 +177,7 @@ pub enum CmpOperator { NE, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum ReturnKind { Hard, Soft, diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index 03eba81..aff931d 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -6,7 +6,10 @@ use crate::{mir::*, util::try_all}; use TypeKind::*; use VagueType::*; -use super::pass::{Pass, PassState, ScopeFunction}; +use super::{ + pass::{Pass, PassState, ScopeFunction}, + types::ReturnType, +}; #[derive(thiserror::Error, Debug, Clone)] pub enum ErrorKind { @@ -86,8 +89,10 @@ impl Block { ) -> Result { let mut state = state.inner(); + let mut early_return = None; + for statement in &mut self.statements { - match &mut statement.0 { + let ret = match &mut statement.0 { StmtKind::Let(variable_reference, expression) => { let res = expression.typecheck(&mut state, Some(variable_reference.0)); @@ -118,15 +123,31 @@ impl Block { variable_reference.1.clone(), ))); state.ok(res, variable_reference.2); + None } StmtKind::Import(_) => todo!(), StmtKind::Expression(expression) => { let res = expression.typecheck(&mut state, None); - state.ok(res, expression.1); + let res_t = state.or_else(res, Void, expression.1); + if let Ok((kind, _)) = expression.return_type() { + Some((kind, expression)) + } else { + None + } } + }; + + if let Some((ReturnKind::Hard, _)) = ret { + early_return = early_return.or(ret); } } + if let Some((ReturnKind::Hard, expr)) = early_return { + let hint = state.scope.return_type_hint; + let res = expr.typecheck(&mut state, hint); + return Ok(state.or_else(res, Vague(Unknown), expr.1)); + } + if let Some((return_kind, expr)) = &mut self.return_expression { // Use function return type as hint if return is hard. let ret_hint_t = match return_kind { diff --git a/reid/src/mir/types.rs b/reid/src/mir/types.rs index 253dea1..9b4f377 100644 --- a/reid/src/mir/types.rs +++ b/reid/src/mir/types.rs @@ -9,20 +9,20 @@ pub enum ReturnTypeOther { } pub trait ReturnType { - fn return_type(&self) -> Result; + fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther>; } impl ReturnType for Block { - fn return_type(&self) -> Result { + fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { self.return_expression .as_ref() .ok_or(ReturnTypeOther::NoBlockReturn(self.meta.range)) - .and_then(|(_, stmt)| stmt.return_type()) + .and_then(|(kind, stmt)| Ok((*kind, stmt.return_type()?.1))) } } impl ReturnType for Statement { - fn return_type(&self) -> Result { + fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { use StmtKind::*; match &self.0 { Expression(e) => e.return_type(), @@ -33,12 +33,21 @@ impl ReturnType for Statement { } impl ReturnType for Expression { - fn return_type(&self) -> Result { + fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { use ExprKind::*; match &self.0 { - Literal(lit) => Ok(lit.as_type()), + Literal(lit) => Ok((ReturnKind::Soft, lit.as_type())), Variable(var) => var.return_type(), - BinOp(_, expr, _) => expr.return_type(), + BinOp(_, then_e, else_e) => { + let then_r = then_e.return_type()?; + let else_e = else_e.return_type()?; + let kind = if then_r.0 == ReturnKind::Hard && else_e.0 == ReturnKind::Hard { + ReturnKind::Hard + } else { + ReturnKind::Hard + }; + Ok((kind, then_r.1)) + } Block(block) => block.return_type(), FunctionCall(fcall) => fcall.return_type(), If(expr) => expr.return_type(), @@ -47,19 +56,19 @@ impl ReturnType for Expression { } impl ReturnType for IfExpression { - fn return_type(&self) -> Result { + fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { self.1.return_type() } } impl ReturnType for VariableReference { - fn return_type(&self) -> Result { - Ok(self.0.clone()) + fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { + Ok((ReturnKind::Soft, self.0.clone())) } } impl ReturnType for FunctionCall { - fn return_type(&self) -> Result { - Ok(self.return_type.clone()) + fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { + Ok((ReturnKind::Soft, self.return_type.clone())) } }