From 0d631bfa89dc788b76222be51257611930923af7 Mon Sep 17 00:00:00 2001 From: sofia Date: Sun, 13 Jul 2025 15:26:36 +0300 Subject: [PATCH] Remove redundant TypeRef, add other optimizations --- reid/src/mir/scopehints.rs | 130 +++++++++++++++++++------------------ reid/src/mir/typecheck.rs | 30 +++++---- 2 files changed, 83 insertions(+), 77 deletions(-) diff --git a/reid/src/mir/scopehints.rs b/reid/src/mir/scopehints.rs index 9912611..b718d87 100644 --- a/reid/src/mir/scopehints.rs +++ b/reid/src/mir/scopehints.rs @@ -1,8 +1,14 @@ -use std::{cell::RefCell, collections::HashMap, rc::Rc}; +use std::{ + any::TypeId, + cell::RefCell, + collections::{HashMap, HashSet}, + hint::black_box, + rc::Rc, +}; use super::{ typecheck::{Collapsable, ErrorKind}, - BinaryOperator, Literal, TypeKind, VagueType, + BinaryOperator, TypeKind, }; #[derive(Clone)] @@ -13,11 +19,8 @@ impl<'scope> ScopeHint<'scope> { unsafe { *self.1.types.hints.borrow().get_unchecked(*self.0.borrow()) } } - pub fn narrow(&mut self, ty_ref: &TypeRef) -> Result, ErrorKind> { - match ty_ref { - TypeRef::Hint(other) => self.1.combine_vars(self, other), - TypeRef::Literal(ty) => self.1.narrow_to_type(self, ty), - } + pub fn narrow(&mut self, other: &ScopeHint) -> Result, ErrorKind> { + self.1.combine_vars(self, other) } pub fn as_type(&self) -> TypeKind { @@ -52,6 +55,39 @@ impl TypeHints { self.hints.borrow_mut().push(ty); typecell } + + pub fn find(&self, ty: TypeKind) -> Option { + if ty.known().is_err() { + // Only do this for non-vague types that can not be further narrowed + // down. + return None; + } + + if let Some(idx) = self + .hints + .borrow_mut() + .iter() + .enumerate() + .find(|(_, t)| **t == ty) + .map(|(i, _)| i) + { + Some(Rc::new(RefCell::new(idx))) + } else { + None + } + } + + unsafe fn recurse_type_ref(&self, mut idx: usize) -> TypeIdRef { + let refs = self.type_refs.borrow(); + let mut inner_idx = refs.get_unchecked(idx); + let mut seen = HashSet::new(); + while (*inner_idx.borrow()) != idx && !seen.contains(&idx) { + seen.insert(idx); + idx = *inner_idx.borrow(); + inner_idx = refs.get_unchecked(idx); + } + return refs.get_unchecked(idx).clone(); + } } #[derive(Debug)] @@ -72,12 +108,7 @@ impl<'outer> ScopeHints<'outer> { } pub fn retrieve_type(&self, idx: usize) -> Option { - let inner_idx = self - .types - .type_refs - .borrow() - .get(idx) - .map(|i| *i.borrow())?; + let inner_idx = unsafe { *self.types.recurse_type_ref(idx).borrow() }; self.types.hints.borrow().get(inner_idx).copied() } @@ -97,9 +128,22 @@ impl<'outer> ScopeHints<'outer> { Ok(ScopeHint(idx, self)) } - fn new_vague(&'outer self, vague: &VagueType) -> ScopeHint<'outer> { - let idx = self.types.new(TypeKind::Vague(*vague)); - ScopeHint(idx, self) + pub fn from_type(&'outer self, ty: &TypeKind) -> Option> { + let idx = match ty { + TypeKind::Vague(super::VagueType::Hinted(idx)) => { + let inner_idx = unsafe { *self.types.recurse_type_ref(*idx).borrow() }; + self.types.type_refs.borrow().get(inner_idx).cloned()? + } + TypeKind::Vague(_) => self.types.new(*ty), + _ => { + if let Some(ty_ref) = self.types.find(*ty) { + ty_ref + } else { + self.types.new(*ty) + } + } + }; + Some(ScopeHint(idx, self)) } fn narrow_to_type( @@ -129,7 +173,7 @@ impl<'outer> ScopeHints<'outer> { .clone(); self.narrow_to_type(&hint1, &ty)?; for idx in self.types.type_refs.borrow_mut().iter_mut() { - if *idx == hint2.0 { + if *idx == hint2.0 && idx != &hint1.0 { *idx.borrow_mut() = *hint1.0.borrow(); } } @@ -156,58 +200,16 @@ impl<'outer> ScopeHints<'outer> { pub fn binop( &'outer self, op: &BinaryOperator, - lhs: &mut TypeRef<'outer>, - rhs: &mut TypeRef<'outer>, - ) -> Result, ErrorKind> { + lhs: &mut ScopeHint<'outer>, + rhs: &mut ScopeHint<'outer>, + ) -> Result, ErrorKind> { let ty = lhs.narrow(rhs)?; Ok(match op { BinaryOperator::Add => ty, BinaryOperator::Minus => ty, BinaryOperator::Mult => ty, - BinaryOperator::And => TypeRef::Literal(TypeKind::Bool), - BinaryOperator::Cmp(_) => TypeRef::Literal(TypeKind::Bool), - }) - } -} - -#[derive(Debug)] -pub enum TypeRef<'scope> { - Hint(ScopeHint<'scope>), - Literal(TypeKind), -} - -impl<'scope> TypeRef<'scope> { - pub fn narrow(&mut self, other: &mut TypeRef<'scope>) -> Result, ErrorKind> { - match (self, other) { - (TypeRef::Hint(hint), unk) | (unk, TypeRef::Hint(hint)) => { - Ok(TypeRef::Hint(hint.narrow(unk)?)) - } - (TypeRef::Literal(lit1), TypeRef::Literal(lit2)) => { - Ok(TypeRef::Literal(lit1.collapse_into(lit2)?)) - } - } - } - - pub fn from_type(hints: &'scope ScopeHints<'scope>, ty: TypeKind) -> TypeRef<'scope> { - match &ty.known() { - Ok(ty) => TypeRef::Literal(*ty), - Err(vague) => match &vague { - super::VagueType::Hinted(idx) => TypeRef::Hint(ScopeHint( - unsafe { hints.types.type_refs.borrow().get_unchecked(*idx).clone() }, - hints, - )), - _ => TypeRef::Hint(hints.new_vague(vague)), - }, - } - } - - pub fn from_literal( - hints: &'scope ScopeHints<'scope>, - lit: Literal, - ) -> Result, ErrorKind> { - Ok(match lit { - Literal::Vague(vague) => TypeRef::Hint(hints.new_vague(&vague.as_type())), - _ => TypeRef::Literal(lit.as_type()), + BinaryOperator::And => self.from_type(&TypeKind::Bool).unwrap(), + BinaryOperator::Cmp(_) => self.from_type(&TypeKind::Bool).unwrap(), }) } } diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index a93161f..c609232 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -8,7 +8,7 @@ use VagueType::*; use super::{ pass::{Pass, PassState, ScopeFunction, ScopeVariable}, - scopehints::{ScopeHints, TypeHints, TypeRef}, + scopehints::{ScopeHint, ScopeHints, TypeHints}, types::{pick_return, ReturnType}, }; @@ -82,6 +82,7 @@ impl FunctionDefinition { let hints = ScopeHints::from(&types); if let Ok(_) = block.infer_hints(state, &hints) { print!("{}", block); + dbg!(&hints); block.typecheck(state, &hints, Some(return_type)) } else { Ok(Vague(Unknown)) @@ -104,7 +105,7 @@ impl Block { &mut self, state: &mut PassState, outer_hints: &'s ScopeHints, - ) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> { + ) -> Result<(ReturnKind, ScopeHint<'s>), ErrorKind> { let mut state = state.inner(); let inner_hints = outer_hints.inner(); @@ -149,12 +150,12 @@ impl Block { } let (kind, ty) = self.return_type().ok().unwrap_or((ReturnKind::Soft, Void)); - let mut ret_type_ref = TypeRef::from_type(&outer_hints, ty); + let mut ret_type_ref = outer_hints.from_type(&ty).unwrap(); if kind == ReturnKind::Hard { if let Some(hint) = state.scope.return_type_hint { state.ok( - ret_type_ref.narrow(&mut TypeRef::from_type(outer_hints, hint)), + ret_type_ref.narrow(&mut outer_hints.from_type(&hint).unwrap()), self.meta, ); } @@ -312,7 +313,7 @@ impl Expression { &mut self, state: &mut PassState, hints: &'s ScopeHints<'s>, - ) -> Result, ErrorKind> { + ) -> Result, ErrorKind> { match &mut self.0 { ExprKind::Variable(var) => { let hint = hints @@ -322,9 +323,9 @@ impl Expression { if let Ok(hint) = &hint { var.0 = hint.as_type() } - Ok(TypeRef::Hint(hint?)) + hint } - ExprKind::Literal(literal) => TypeRef::from_literal(hints, *literal), + ExprKind::Literal(literal) => Ok(hints.from_type(&literal.as_type()).unwrap()), ExprKind::BinOp(op, lhs, rhs) => { let mut lhs_ref = lhs.infer_hints(state, hints)?; let mut rhs_ref = rhs.infer_hints(state, hints)?; @@ -346,20 +347,23 @@ impl Expression { let expr_res = param_expr.infer_hints(state, hints); if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) { state.ok( - param_ref.narrow(&mut TypeRef::from_type(hints, *param_t)), + param_ref.narrow(&mut hints.from_type(param_t).unwrap()), param_expr.1, ); } } - Ok(TypeRef::from_type(hints, fn_call.ret)) + Ok(hints.from_type(&fn_call.ret).unwrap()) } ExprKind::If(IfExpression(cond, lhs, rhs)) => { let cond_res = cond.infer_hints(state, hints); let cond_hints = state.ok(cond_res, cond.1); if let Some(mut cond_hints) = cond_hints { - state.ok(cond_hints.narrow(&mut TypeRef::Literal(Bool)), cond.1); + state.ok( + cond_hints.narrow(&mut hints.from_type(&Bool).unwrap()), + cond.1, + ); } let lhs_res = lhs.infer_hints(state, hints); @@ -374,20 +378,20 @@ impl Expression { Ok(pick_return(lhs_hints, rhs_hints).1) } else { // Failed to retrieve types from either - Ok(TypeRef::from_type(hints, Vague(Unknown))) + Ok(hints.from_type(&Vague(Unknown)).unwrap()) } } else { if let Some((_, type_ref)) = lhs_hints { Ok(type_ref) } else { - Ok(TypeRef::from_type(hints, Vague(Unknown))) + Ok(hints.from_type(&Vague(Unknown)).unwrap()) } } } ExprKind::Block(block) => { let block_ref = block.infer_hints(state, hints)?; match block_ref.0 { - ReturnKind::Hard => Ok(TypeRef::from_type(hints, Void)), + ReturnKind::Hard => Ok(hints.from_type(&Void).unwrap()), ReturnKind::Soft => Ok(block_ref.1), } }