Refactor Type Hints -> Type Refs

This commit is contained in:
Sofia 2025-07-13 15:58:19 +03:00
parent f3f47831e9
commit 92736e392e
5 changed files with 74 additions and 74 deletions

View File

@ -39,7 +39,7 @@
//! - Loops //! - Loops
//! ``` //! ```
use mir::{scopehints::TypeHints, typecheck::TypeCheck, typeinference::TypeInference}; use mir::{typecheck::TypeCheck, typeinference::TypeInference, typerefs::TypeRefs};
use reid_lib::Context; use reid_lib::Context;
use crate::{ast::TopLevelStatement, lexer::Token, token_stream::TokenStream}; use crate::{ast::TopLevelStatement, lexer::Token, token_stream::TokenStream};
@ -89,13 +89,13 @@ pub fn compile(source: &str) -> Result<String, ReidError> {
println!("{}", &mir_context); println!("{}", &mir_context);
let hints = TypeHints::default(); let refs = TypeRefs::default();
let state = mir_context.pass(&mut TypeInference { hints: &hints }); let state = mir_context.pass(&mut TypeInference { refs: &refs });
dbg!(&state, &hints); dbg!(&state, &refs);
println!("{}", &mir_context); println!("{}", &mir_context);
let state = mir_context.pass(&mut TypeCheck { hints: &hints }); let state = mir_context.pass(&mut TypeCheck { refs: &refs });
dbg!(&state); dbg!(&state);
println!("{}", &mir_context); println!("{}", &mir_context);

View File

@ -6,9 +6,9 @@ use crate::token_stream::TokenRange;
mod display; mod display;
pub mod pass; pub mod pass;
pub mod scopehints;
pub mod typecheck; pub mod typecheck;
pub mod typeinference; pub mod typeinference;
pub mod typerefs;
pub mod types; pub mod types;
#[derive(Debug, Default, Clone, Copy)] #[derive(Debug, Default, Clone, Copy)]
@ -68,8 +68,8 @@ pub enum VagueType {
Unknown, Unknown,
#[error("Number")] #[error("Number")]
Number, Number,
#[error("Hinted({0})")] #[error("TypeRef({0})")]
Hinted(usize), TypeRef(usize),
} }
impl TypeKind { impl TypeKind {

View File

@ -8,7 +8,7 @@ use VagueType::*;
use super::{ use super::{
pass::{Pass, PassState, ScopeFunction, ScopeVariable}, pass::{Pass, PassState, ScopeFunction, ScopeVariable},
scopehints::{ScopeHints, TypeHint, TypeHints}, typerefs::{ScopeTypeRefs, TypeRef, TypeRefs},
types::{pick_return, ReturnType}, types::{pick_return, ReturnType},
}; };
@ -43,7 +43,7 @@ pub enum ErrorKind {
/// Struct used to implement a type-checking pass that can be performed on the /// Struct used to implement a type-checking pass that can be performed on the
/// MIR. /// MIR.
pub struct TypeCheck<'t> { pub struct TypeCheck<'t> {
pub hints: &'t TypeHints, pub refs: &'t TypeRefs,
} }
impl<'t> Pass for TypeCheck<'t> { impl<'t> Pass for TypeCheck<'t> {
@ -51,7 +51,7 @@ impl<'t> Pass for TypeCheck<'t> {
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) { fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) {
for function in &mut module.functions { for function in &mut module.functions {
let res = function.typecheck(&self.hints, &mut state); let res = function.typecheck(&self.refs, &mut state);
state.ok(res, function.block_meta()); state.ok(res, function.block_meta());
} }
} }
@ -60,7 +60,7 @@ impl<'t> Pass for TypeCheck<'t> {
impl FunctionDefinition { impl FunctionDefinition {
fn typecheck( fn typecheck(
&mut self, &mut self,
hints: &TypeHints, hints: &TypeRefs,
state: &mut PassState<ErrorKind>, state: &mut PassState<ErrorKind>,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {
for param in &self.parameters { for param in &self.parameters {
@ -101,7 +101,7 @@ impl Block {
fn typecheck( fn typecheck(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut PassState<ErrorKind>,
hints: &TypeHints, hints: &TypeRefs,
hint_t: Option<TypeKind>, hint_t: Option<TypeKind>,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {
let mut state = state.inner(); let mut state = state.inner();
@ -246,7 +246,7 @@ impl Expression {
fn typecheck( fn typecheck(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut PassState<ErrorKind>,
hints: &TypeHints, hints: &TypeRefs,
hint_t: Option<TypeKind>, hint_t: Option<TypeKind>,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {
match &mut self.0 { match &mut self.0 {
@ -428,7 +428,7 @@ impl TypeKind {
Vague(vague_type) => match vague_type { Vague(vague_type) => match vague_type {
Unknown => Err(ErrorKind::TypeIsVague(*vague_type)), Unknown => Err(ErrorKind::TypeIsVague(*vague_type)),
Number => Ok(TypeKind::I32), Number => Ok(TypeKind::I32),
Hinted(_) => panic!("Hinted default!"), TypeRef(_) => panic!("Hinted default!"),
}, },
_ => Ok(*self), _ => Ok(*self),
} }
@ -445,9 +445,9 @@ impl TypeKind {
}) })
} }
fn resolve_hinted(&self, hints: &TypeHints) -> TypeKind { fn resolve_hinted(&self, hints: &TypeRefs) -> TypeKind {
match self { match self {
Vague(Hinted(idx)) => hints.retrieve_type(*idx).unwrap(), Vague(TypeRef(idx)) => hints.retrieve_type(*idx).unwrap(),
_ => *self, _ => *self,
} }
} }

View File

@ -4,8 +4,8 @@ use reid_lib::Function;
use super::{ use super::{
pass::{Pass, PassState, ScopeVariable}, pass::{Pass, PassState, ScopeVariable},
scopehints::{self, ScopeHints, TypeHint, TypeHints},
typecheck::ErrorKind, typecheck::ErrorKind,
typerefs::{self, ScopeTypeRefs, TypeRef, TypeRefs},
types::{pick_return, ReturnType}, types::{pick_return, ReturnType},
Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module, Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module,
ReturnKind, StmtKind, ReturnKind, StmtKind,
@ -16,7 +16,7 @@ use super::{
/// Struct used to implement a type-checking pass that can be performed on the /// Struct used to implement a type-checking pass that can be performed on the
/// MIR. /// MIR.
pub struct TypeInference<'t> { pub struct TypeInference<'t> {
pub hints: &'t TypeHints, pub refs: &'t TypeRefs,
} }
impl<'t> Pass for TypeInference<'t> { impl<'t> Pass for TypeInference<'t> {
@ -24,7 +24,7 @@ impl<'t> Pass for TypeInference<'t> {
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) { fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) {
for function in &mut module.functions { for function in &mut module.functions {
let res = function.infer_hints(&self.hints, &mut state); let res = function.infer_hints(&self.refs, &mut state);
state.ok(res, function.block_meta()); state.ok(res, function.block_meta());
} }
} }
@ -33,7 +33,7 @@ impl<'t> Pass for TypeInference<'t> {
impl FunctionDefinition { impl FunctionDefinition {
fn infer_hints( fn infer_hints(
&mut self, &mut self,
hints: &TypeHints, type_refs: &TypeRefs,
state: &mut PassState<ErrorKind>, state: &mut PassState<ErrorKind>,
) -> Result<(), ErrorKind> { ) -> Result<(), ErrorKind> {
for param in &self.parameters { for param in &self.parameters {
@ -51,7 +51,7 @@ impl FunctionDefinition {
.or(Err(ErrorKind::VariableAlreadyDefined(param.0.clone()))); .or(Err(ErrorKind::VariableAlreadyDefined(param.0.clone())));
state.ok(res, self.signature()); state.ok(res, self.signature());
} }
let scope_hints = ScopeHints::from(hints); let scope_hints = ScopeTypeRefs::from(type_refs);
let return_type = self.return_type.clone(); let return_type = self.return_type.clone();
let return_type_hint = scope_hints.from_type(&return_type).unwrap(); let return_type_hint = scope_hints.from_type(&return_type).unwrap();
@ -76,8 +76,8 @@ impl Block {
fn infer_hints<'s>( fn infer_hints<'s>(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut PassState<ErrorKind>,
outer_hints: &'s ScopeHints, outer_hints: &'s ScopeTypeRefs,
) -> Result<(ReturnKind, TypeHint<'s>), ErrorKind> { ) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> {
let mut state = state.inner(); let mut state = state.inner();
let inner_hints = outer_hints.inner(); let inner_hints = outer_hints.inner();
@ -140,11 +140,11 @@ impl Expression {
fn infer_hints<'s>( fn infer_hints<'s>(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut PassState<ErrorKind>,
hints: &'s ScopeHints<'s>, type_refs: &'s ScopeTypeRefs<'s>,
) -> Result<TypeHint<'s>, ErrorKind> { ) -> Result<TypeRef<'s>, ErrorKind> {
match &mut self.0 { match &mut self.0 {
ExprKind::Variable(var) => { ExprKind::Variable(var) => {
let hint = hints let hint = type_refs
.find_hint(&var.1) .find_hint(&var.1)
.map(|(_, hint)| hint) .map(|(_, hint)| hint)
.ok_or(ErrorKind::VariableNotDefined(var.1.clone())); .ok_or(ErrorKind::VariableNotDefined(var.1.clone()));
@ -153,11 +153,11 @@ impl Expression {
} }
hint hint
} }
ExprKind::Literal(literal) => Ok(hints.from_type(&literal.as_type()).unwrap()), ExprKind::Literal(literal) => Ok(type_refs.from_type(&literal.as_type()).unwrap()),
ExprKind::BinOp(op, lhs, rhs) => { ExprKind::BinOp(op, lhs, rhs) => {
let mut lhs_ref = lhs.infer_hints(state, hints)?; let mut lhs_ref = lhs.infer_hints(state, type_refs)?;
let mut rhs_ref = rhs.infer_hints(state, hints)?; let mut rhs_ref = rhs.infer_hints(state, type_refs)?;
hints.binop(op, &mut lhs_ref, &mut rhs_ref) type_refs.binop(op, &mut lhs_ref, &mut rhs_ref)
} }
ExprKind::FunctionCall(function_call) => { ExprKind::FunctionCall(function_call) => {
let fn_call = state let fn_call = state
@ -172,33 +172,33 @@ impl Expression {
for (param_expr, param_t) in for (param_expr, param_t) in
function_call.parameters.iter_mut().zip(true_params_iter) function_call.parameters.iter_mut().zip(true_params_iter)
{ {
let expr_res = param_expr.infer_hints(state, hints); let expr_res = param_expr.infer_hints(state, type_refs);
if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) { if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) {
state.ok( state.ok(
param_ref.narrow(&mut hints.from_type(param_t).unwrap()), param_ref.narrow(&mut type_refs.from_type(param_t).unwrap()),
param_expr.1, param_expr.1,
); );
} }
} }
Ok(hints.from_type(&fn_call.ret).unwrap()) Ok(type_refs.from_type(&fn_call.ret).unwrap())
} }
ExprKind::If(IfExpression(cond, lhs, rhs)) => { ExprKind::If(IfExpression(cond, lhs, rhs)) => {
let cond_res = cond.infer_hints(state, hints); let cond_res = cond.infer_hints(state, type_refs);
let cond_hints = state.ok(cond_res, cond.1); let cond_hints = state.ok(cond_res, cond.1);
if let Some(mut cond_hints) = cond_hints { if let Some(mut cond_hints) = cond_hints {
state.ok( state.ok(
cond_hints.narrow(&mut hints.from_type(&Bool).unwrap()), cond_hints.narrow(&mut type_refs.from_type(&Bool).unwrap()),
cond.1, cond.1,
); );
} }
let lhs_res = lhs.infer_hints(state, hints); let lhs_res = lhs.infer_hints(state, type_refs);
let lhs_hints = state.ok(lhs_res, cond.1); let lhs_hints = state.ok(lhs_res, cond.1);
if let Some(rhs) = rhs { if let Some(rhs) = rhs {
let rhs_res = rhs.infer_hints(state, hints); let rhs_res = rhs.infer_hints(state, type_refs);
let rhs_hints = state.ok(rhs_res, cond.1); let rhs_hints = state.ok(rhs_res, cond.1);
if let (Some(mut lhs_hints), Some(mut rhs_hints)) = (lhs_hints, rhs_hints) { if let (Some(mut lhs_hints), Some(mut rhs_hints)) = (lhs_hints, rhs_hints) {
@ -206,20 +206,20 @@ impl Expression {
Ok(pick_return(lhs_hints, rhs_hints).1) Ok(pick_return(lhs_hints, rhs_hints).1)
} else { } else {
// Failed to retrieve types from either // Failed to retrieve types from either
Ok(hints.from_type(&Vague(Unknown)).unwrap()) Ok(type_refs.from_type(&Vague(Unknown)).unwrap())
} }
} else { } else {
if let Some((_, type_ref)) = lhs_hints { if let Some((_, type_ref)) = lhs_hints {
Ok(type_ref) Ok(type_ref)
} else { } else {
Ok(hints.from_type(&Vague(Unknown)).unwrap()) Ok(type_refs.from_type(&Vague(Unknown)).unwrap())
} }
} }
} }
ExprKind::Block(block) => { ExprKind::Block(block) => {
let block_ref = block.infer_hints(state, hints)?; let block_ref = block.infer_hints(state, type_refs)?;
match block_ref.0 { match block_ref.0 {
ReturnKind::Hard => Ok(hints.from_type(&Void).unwrap()), ReturnKind::Hard => Ok(type_refs.from_type(&Void).unwrap()),
ReturnKind::Soft => Ok(block_ref.1), ReturnKind::Soft => Ok(block_ref.1),
} }
} }

View File

@ -10,23 +10,23 @@ use super::{
}; };
#[derive(Clone)] #[derive(Clone)]
pub struct TypeHint<'scope>(TypeIdRef, &'scope ScopeHints<'scope>); pub struct TypeRef<'scope>(TypeIdRef, &'scope ScopeTypeRefs<'scope>);
impl<'scope> TypeHint<'scope> { impl<'scope> TypeRef<'scope> {
pub unsafe fn resolve_type(&self) -> TypeKind { pub unsafe fn resolve_type(&self) -> TypeKind {
unsafe { *self.1.types.hints.borrow().get_unchecked(*self.0.borrow()) } unsafe { *self.1.types.hints.borrow().get_unchecked(*self.0.borrow()) }
} }
pub fn narrow(&mut self, other: &TypeHint) -> Result<TypeHint<'scope>, ErrorKind> { pub fn narrow(&mut self, other: &TypeRef) -> Result<TypeRef<'scope>, ErrorKind> {
self.1.combine_vars(self, other) self.1.combine_vars(self, other)
} }
pub fn as_type(&self) -> TypeKind { pub fn as_type(&self) -> TypeKind {
TypeKind::Vague(super::VagueType::Hinted(*self.0.borrow())) TypeKind::Vague(super::VagueType::TypeRef(*self.0.borrow()))
} }
} }
impl<'scope> std::fmt::Debug for TypeHint<'scope> { impl<'scope> std::fmt::Debug for TypeRef<'scope> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Hint") f.debug_tuple("Hint")
.field(&self.0) .field(&self.0)
@ -38,14 +38,14 @@ impl<'scope> std::fmt::Debug for TypeHint<'scope> {
type TypeIdRef = Rc<RefCell<usize>>; type TypeIdRef = Rc<RefCell<usize>>;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct TypeHints { pub struct TypeRefs {
/// Simple list of types that variables can refrence /// Simple list of types that variables can refrence
hints: RefCell<Vec<TypeKind>>, hints: RefCell<Vec<TypeKind>>,
/// Indirect ID-references, referring to hints-vec /// Indirect ID-references, referring to hints-vec
type_refs: RefCell<Vec<TypeIdRef>>, type_refs: RefCell<Vec<TypeIdRef>>,
} }
impl TypeHints { impl TypeRefs {
pub fn new(&self, ty: TypeKind) -> TypeIdRef { pub fn new(&self, ty: TypeKind) -> TypeIdRef {
let idx = self.hints.borrow().len(); let idx = self.hints.borrow().len();
let typecell = Rc::new(RefCell::new(idx)); let typecell = Rc::new(RefCell::new(idx));
@ -94,16 +94,16 @@ impl TypeHints {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct ScopeHints<'outer> { pub struct ScopeTypeRefs<'outer> {
types: &'outer TypeHints, types: &'outer TypeRefs,
outer: Option<&'outer ScopeHints<'outer>>, outer: Option<&'outer ScopeTypeRefs<'outer>>,
/// Mapping of what types variables point to /// Mapping of what types variables point to
variables: RefCell<HashMap<String, (bool, TypeIdRef)>>, variables: RefCell<HashMap<String, (bool, TypeIdRef)>>,
} }
impl<'outer> ScopeHints<'outer> { impl<'outer> ScopeTypeRefs<'outer> {
pub fn from(types: &'outer TypeHints) -> ScopeHints<'outer> { pub fn from(types: &'outer TypeRefs) -> ScopeTypeRefs<'outer> {
ScopeHints { ScopeTypeRefs {
types, types,
outer: Default::default(), outer: Default::default(),
variables: Default::default(), variables: Default::default(),
@ -115,7 +115,7 @@ impl<'outer> ScopeHints<'outer> {
name: String, name: String,
mutable: bool, mutable: bool,
initial_ty: TypeKind, initial_ty: TypeKind,
) -> Result<TypeHint<'outer>, ErrorKind> { ) -> Result<TypeRef<'outer>, ErrorKind> {
if self.variables.borrow().contains_key(&name) { if self.variables.borrow().contains_key(&name) {
return Err(ErrorKind::VariableAlreadyDefined(name)); return Err(ErrorKind::VariableAlreadyDefined(name));
} }
@ -123,12 +123,12 @@ impl<'outer> ScopeHints<'outer> {
self.variables self.variables
.borrow_mut() .borrow_mut()
.insert(name, (mutable, idx.clone())); .insert(name, (mutable, idx.clone()));
Ok(TypeHint(idx, self)) Ok(TypeRef(idx, self))
} }
pub fn from_type(&'outer self, ty: &TypeKind) -> Option<TypeHint<'outer>> { pub fn from_type(&'outer self, ty: &TypeKind) -> Option<TypeRef<'outer>> {
let idx = match ty { let idx = match ty {
TypeKind::Vague(super::VagueType::Hinted(idx)) => { TypeKind::Vague(super::VagueType::TypeRef(idx)) => {
let inner_idx = unsafe { *self.types.recurse_type_ref(*idx).borrow() }; let inner_idx = unsafe { *self.types.recurse_type_ref(*idx).borrow() };
self.types.type_refs.borrow().get(inner_idx).cloned()? self.types.type_refs.borrow().get(inner_idx).cloned()?
} }
@ -141,27 +141,27 @@ impl<'outer> ScopeHints<'outer> {
} }
} }
}; };
Some(TypeHint(idx, self)) Some(TypeRef(idx, self))
} }
fn narrow_to_type( fn narrow_to_type(
&'outer self, &'outer self,
hint: &TypeHint, hint: &TypeRef,
ty: &TypeKind, ty: &TypeKind,
) -> Result<TypeHint<'outer>, ErrorKind> { ) -> Result<TypeRef<'outer>, ErrorKind> {
unsafe { unsafe {
let mut hints = self.types.hints.borrow_mut(); let mut hints = self.types.hints.borrow_mut();
let existing = hints.get_unchecked_mut(*hint.0.borrow()); let existing = hints.get_unchecked_mut(*hint.0.borrow());
*existing = existing.collapse_into(&ty)?; *existing = existing.collapse_into(&ty)?;
Ok(TypeHint(hint.0.clone(), self)) Ok(TypeRef(hint.0.clone(), self))
} }
} }
fn combine_vars( fn combine_vars(
&'outer self, &'outer self,
hint1: &TypeHint, hint1: &TypeRef,
hint2: &TypeHint, hint2: &TypeRef,
) -> Result<TypeHint<'outer>, ErrorKind> { ) -> Result<TypeRef<'outer>, ErrorKind> {
unsafe { unsafe {
let ty = self let ty = self
.types .types
@ -175,32 +175,32 @@ impl<'outer> ScopeHints<'outer> {
*idx.borrow_mut() = *hint1.0.borrow(); *idx.borrow_mut() = *hint1.0.borrow();
} }
} }
Ok(TypeHint(hint1.0.clone(), self)) Ok(TypeRef(hint1.0.clone(), self))
} }
} }
pub fn inner(&'outer self) -> ScopeHints<'outer> { pub fn inner(&'outer self) -> ScopeTypeRefs<'outer> {
ScopeHints { ScopeTypeRefs {
types: self.types, types: self.types,
outer: Some(self), outer: Some(self),
variables: Default::default(), variables: Default::default(),
} }
} }
pub fn find_hint(&'outer self, name: &String) -> Option<(bool, TypeHint<'outer>)> { pub fn find_hint(&'outer self, name: &String) -> Option<(bool, TypeRef<'outer>)> {
self.variables self.variables
.borrow() .borrow()
.get(name) .get(name)
.map(|(mutable, idx)| (*mutable, TypeHint(idx.clone(), self))) .map(|(mutable, idx)| (*mutable, TypeRef(idx.clone(), self)))
.or(self.outer.map(|o| o.find_hint(name)).flatten()) .or(self.outer.map(|o| o.find_hint(name)).flatten())
} }
pub fn binop( pub fn binop(
&'outer self, &'outer self,
op: &BinaryOperator, op: &BinaryOperator,
lhs: &mut TypeHint<'outer>, lhs: &mut TypeRef<'outer>,
rhs: &mut TypeHint<'outer>, rhs: &mut TypeRef<'outer>,
) -> Result<TypeHint<'outer>, ErrorKind> { ) -> Result<TypeRef<'outer>, ErrorKind> {
let ty = lhs.narrow(rhs)?; let ty = lhs.narrow(rhs)?;
Ok(match op { Ok(match op {
BinaryOperator::Add => ty, BinaryOperator::Add => ty,