Compare commits

...

4 Commits

Author SHA1 Message Date
f3f47831e9 Split type inference to it's very own pass 2025-07-13 15:55:14 +03:00
7d77e1df32 rename ScopeHint to TypeHint 2025-07-13 15:27:28 +03:00
0d631bfa89 Remove redundant TypeRef, add other optimizations 2025-07-13 15:26:36 +03:00
29e78cf1aa update errors 2025-07-13 13:58:31 +03:00
5 changed files with 337 additions and 256 deletions

View File

@ -39,7 +39,7 @@
//! - Loops
//! ```
use mir::typecheck::TypeCheck;
use mir::{scopehints::TypeHints, typecheck::TypeCheck, typeinference::TypeInference};
use reid_lib::Context;
use crate::{ast::TopLevelStatement, lexer::Token, token_stream::TokenStream};
@ -89,9 +89,14 @@ pub fn compile(source: &str) -> Result<String, ReidError> {
println!("{}", &mir_context);
let state = mir_context.pass(&mut TypeCheck);
dbg!(&state);
let hints = TypeHints::default();
let state = mir_context.pass(&mut TypeInference { hints: &hints });
dbg!(&state, &hints);
println!("{}", &mir_context);
let state = mir_context.pass(&mut TypeCheck { hints: &hints });
dbg!(&state);
println!("{}", &mir_context);
if !state.errors.is_empty() {

View File

@ -6,8 +6,9 @@ use crate::token_stream::TokenRange;
mod display;
pub mod pass;
mod scopehints;
pub mod scopehints;
pub mod typecheck;
pub mod typeinference;
pub mod types;
#[derive(Debug, Default, Clone, Copy)]

View File

@ -1,23 +1,24 @@
use std::{cell::RefCell, collections::HashMap, rc::Rc};
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
rc::Rc,
};
use super::{
typecheck::{Collapsable, ErrorKind},
BinaryOperator, Literal, TypeKind, VagueType,
BinaryOperator, TypeKind,
};
#[derive(Clone)]
pub struct ScopeHint<'scope>(TypeIdRef, &'scope ScopeHints<'scope>);
pub struct TypeHint<'scope>(TypeIdRef, &'scope ScopeHints<'scope>);
impl<'scope> ScopeHint<'scope> {
impl<'scope> TypeHint<'scope> {
pub unsafe fn resolve_type(&self) -> TypeKind {
unsafe { *self.1.types.hints.borrow().get_unchecked(*self.0.borrow()) }
}
pub fn narrow(&mut self, ty_ref: &TypeRef) -> Result<ScopeHint<'scope>, ErrorKind> {
match ty_ref {
TypeRef::Hint(other) => self.1.combine_vars(self, other),
TypeRef::Literal(ty) => self.1.narrow_to_type(self, ty),
}
pub fn narrow(&mut self, other: &TypeHint) -> Result<TypeHint<'scope>, ErrorKind> {
self.1.combine_vars(self, other)
}
pub fn as_type(&self) -> TypeKind {
@ -25,7 +26,7 @@ impl<'scope> ScopeHint<'scope> {
}
}
impl<'scope> std::fmt::Debug for ScopeHint<'scope> {
impl<'scope> std::fmt::Debug for TypeHint<'scope> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Hint")
.field(&self.0)
@ -52,6 +53,44 @@ impl TypeHints {
self.hints.borrow_mut().push(ty);
typecell
}
pub fn find(&self, ty: TypeKind) -> Option<TypeIdRef> {
if ty.known().is_err() {
// Only do this for non-vague types that can not be further narrowed
// down.
return None;
}
if let Some(idx) = self
.hints
.borrow_mut()
.iter()
.enumerate()
.find(|(_, t)| **t == ty)
.map(|(i, _)| i)
{
Some(Rc::new(RefCell::new(idx)))
} else {
None
}
}
unsafe fn recurse_type_ref(&self, mut idx: usize) -> TypeIdRef {
let refs = self.type_refs.borrow();
let mut inner_idx = refs.get_unchecked(idx);
let mut seen = HashSet::new();
while (*inner_idx.borrow()) != idx && !seen.contains(&idx) {
seen.insert(idx);
idx = *inner_idx.borrow();
inner_idx = refs.get_unchecked(idx);
}
return refs.get_unchecked(idx).clone();
}
pub fn retrieve_type(&self, idx: usize) -> Option<TypeKind> {
let inner_idx = unsafe { *self.recurse_type_ref(idx).borrow() };
self.hints.borrow().get(inner_idx).copied()
}
}
#[derive(Debug)]
@ -71,22 +110,12 @@ impl<'outer> ScopeHints<'outer> {
}
}
pub fn retrieve_type(&self, idx: usize) -> Option<TypeKind> {
let inner_idx = self
.types
.type_refs
.borrow()
.get(idx)
.map(|i| *i.borrow())?;
self.types.hints.borrow().get(inner_idx).copied()
}
pub fn new_var(
&'outer self,
name: String,
mutable: bool,
initial_ty: TypeKind,
) -> Result<ScopeHint<'outer>, ErrorKind> {
) -> Result<TypeHint<'outer>, ErrorKind> {
if self.variables.borrow().contains_key(&name) {
return Err(ErrorKind::VariableAlreadyDefined(name));
}
@ -94,32 +123,45 @@ impl<'outer> ScopeHints<'outer> {
self.variables
.borrow_mut()
.insert(name, (mutable, idx.clone()));
Ok(ScopeHint(idx, self))
Ok(TypeHint(idx, self))
}
fn new_vague(&'outer self, vague: &VagueType) -> ScopeHint<'outer> {
let idx = self.types.new(TypeKind::Vague(*vague));
ScopeHint(idx, self)
pub fn from_type(&'outer self, ty: &TypeKind) -> Option<TypeHint<'outer>> {
let idx = match ty {
TypeKind::Vague(super::VagueType::Hinted(idx)) => {
let inner_idx = unsafe { *self.types.recurse_type_ref(*idx).borrow() };
self.types.type_refs.borrow().get(inner_idx).cloned()?
}
TypeKind::Vague(_) => self.types.new(*ty),
_ => {
if let Some(ty_ref) = self.types.find(*ty) {
ty_ref
} else {
self.types.new(*ty)
}
}
};
Some(TypeHint(idx, self))
}
fn narrow_to_type(
&'outer self,
hint: &ScopeHint,
hint: &TypeHint,
ty: &TypeKind,
) -> Result<ScopeHint<'outer>, ErrorKind> {
) -> Result<TypeHint<'outer>, ErrorKind> {
unsafe {
let mut hints = self.types.hints.borrow_mut();
let existing = hints.get_unchecked_mut(*hint.0.borrow());
*existing = existing.collapse_into(&ty)?;
Ok(ScopeHint(hint.0.clone(), self))
Ok(TypeHint(hint.0.clone(), self))
}
}
fn combine_vars(
&'outer self,
hint1: &ScopeHint,
hint2: &ScopeHint,
) -> Result<ScopeHint<'outer>, ErrorKind> {
hint1: &TypeHint,
hint2: &TypeHint,
) -> Result<TypeHint<'outer>, ErrorKind> {
unsafe {
let ty = self
.types
@ -129,11 +171,11 @@ impl<'outer> ScopeHints<'outer> {
.clone();
self.narrow_to_type(&hint1, &ty)?;
for idx in self.types.type_refs.borrow_mut().iter_mut() {
if *idx == hint2.0 {
if *idx == hint2.0 && idx != &hint1.0 {
*idx.borrow_mut() = *hint1.0.borrow();
}
}
Ok(ScopeHint(hint1.0.clone(), self))
Ok(TypeHint(hint1.0.clone(), self))
}
}
@ -145,69 +187,27 @@ impl<'outer> ScopeHints<'outer> {
}
}
pub fn find_hint(&'outer self, name: &String) -> Option<(bool, ScopeHint<'outer>)> {
pub fn find_hint(&'outer self, name: &String) -> Option<(bool, TypeHint<'outer>)> {
self.variables
.borrow()
.get(name)
.map(|(mutable, idx)| (*mutable, ScopeHint(idx.clone(), self)))
.map(|(mutable, idx)| (*mutable, TypeHint(idx.clone(), self)))
.or(self.outer.map(|o| o.find_hint(name)).flatten())
}
pub fn binop(
&'outer self,
op: &BinaryOperator,
lhs: &mut TypeRef<'outer>,
rhs: &mut TypeRef<'outer>,
) -> Result<TypeRef<'outer>, ErrorKind> {
lhs: &mut TypeHint<'outer>,
rhs: &mut TypeHint<'outer>,
) -> Result<TypeHint<'outer>, ErrorKind> {
let ty = lhs.narrow(rhs)?;
Ok(match op {
BinaryOperator::Add => ty,
BinaryOperator::Minus => ty,
BinaryOperator::Mult => ty,
BinaryOperator::And => TypeRef::Literal(TypeKind::Bool),
BinaryOperator::Cmp(_) => TypeRef::Literal(TypeKind::Bool),
})
}
}
#[derive(Debug)]
pub enum TypeRef<'scope> {
Hint(ScopeHint<'scope>),
Literal(TypeKind),
}
impl<'scope> TypeRef<'scope> {
pub fn narrow(&mut self, other: &mut TypeRef<'scope>) -> Result<TypeRef<'scope>, ErrorKind> {
match (self, other) {
(TypeRef::Hint(hint), unk) | (unk, TypeRef::Hint(hint)) => {
Ok(TypeRef::Hint(hint.narrow(unk)?))
}
(TypeRef::Literal(lit1), TypeRef::Literal(lit2)) => {
Ok(TypeRef::Literal(lit1.collapse_into(lit2)?))
}
}
}
pub fn from_type(hints: &'scope ScopeHints<'scope>, ty: TypeKind) -> TypeRef<'scope> {
match &ty.known() {
Ok(ty) => TypeRef::Literal(*ty),
Err(vague) => match &vague {
super::VagueType::Hinted(idx) => TypeRef::Hint(ScopeHint(
unsafe { hints.types.type_refs.borrow().get_unchecked(*idx).clone() },
hints,
)),
_ => TypeRef::Hint(hints.new_vague(vague)),
},
}
}
pub fn from_literal(
hints: &'scope ScopeHints<'scope>,
lit: Literal,
) -> Result<TypeRef<'scope>, ErrorKind> {
Ok(match lit {
Literal::Vague(vague) => TypeRef::Hint(hints.new_vague(&vague.as_type())),
_ => TypeRef::Literal(lit.as_type()),
BinaryOperator::And => self.from_type(&TypeKind::Bool).unwrap(),
BinaryOperator::Cmp(_) => self.from_type(&TypeKind::Bool).unwrap(),
})
}
}

View File

@ -8,7 +8,7 @@ use VagueType::*;
use super::{
pass::{Pass, PassState, ScopeFunction, ScopeVariable},
scopehints::{ScopeHints, TypeHints, TypeRef},
scopehints::{ScopeHints, TypeHint, TypeHints},
types::{pick_return, ReturnType},
};
@ -18,8 +18,6 @@ pub enum ErrorKind {
Null,
#[error("Type is vague: {0}")]
TypeIsVague(VagueType),
#[error("Can not coerce {0} to vague type {1}")]
HintIsVague(TypeKind, VagueType),
#[error("Literal {0} can not be coerced to type {1}")]
LiteralIncompatible(Literal, TypeKind),
#[error("Types {0} and {1} are incompatible")]
@ -28,7 +26,7 @@ pub enum ErrorKind {
VariableNotDefined(String),
#[error("Function not defined: {0}")]
FunctionNotDefined(String),
#[error("Type is vague: {0}")]
#[error("Expected a return type of {0}, got {1} instead")]
ReturnTypeMismatch(TypeKind, TypeKind),
#[error("Function not defined: {0}")]
FunctionAlreadyDefined(String),
@ -44,21 +42,27 @@ pub enum ErrorKind {
/// Struct used to implement a type-checking pass that can be performed on the
/// MIR.
pub struct TypeCheck;
pub struct TypeCheck<'t> {
pub hints: &'t TypeHints,
}
impl Pass for TypeCheck {
impl<'t> Pass for TypeCheck<'t> {
type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) {
for function in &mut module.functions {
let res = function.typecheck(&mut state);
let res = function.typecheck(&self.hints, &mut state);
state.ok(res, function.block_meta());
}
}
}
impl FunctionDefinition {
fn typecheck(&mut self, state: &mut PassState<ErrorKind>) -> Result<TypeKind, ErrorKind> {
fn typecheck(
&mut self,
hints: &TypeHints,
state: &mut PassState<ErrorKind>,
) -> Result<TypeKind, ErrorKind> {
for param in &self.parameters {
let param_t = state.or_else(param.1.assert_known(), Vague(Unknown), self.signature());
let res = state
@ -79,15 +83,7 @@ impl FunctionDefinition {
let inferred = match &mut self.kind {
FunctionDefinitionKind::Local(block, _) => {
state.scope.return_type_hint = Some(self.return_type);
let types = TypeHints::default();
let hints = ScopeHints::from(&types);
if let Ok(_) = block.infer_hints(state, &hints) {
print!("{}", block);
block.typecheck(state, &hints, Some(return_type))
} else {
Ok(Vague(Unknown))
}
block.typecheck(state, &hints, Some(return_type))
}
FunctionDefinitionKind::Extern => Ok(Vague(Unknown)),
};
@ -102,76 +98,13 @@ impl FunctionDefinition {
}
impl Block {
fn infer_hints<'s>(
&mut self,
state: &mut PassState<ErrorKind>,
outer_hints: &'s ScopeHints,
) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> {
let mut state = state.inner();
let inner_hints = outer_hints.inner();
for statement in &mut self.statements {
match &mut statement.0 {
StmtKind::Let(var, mutable, expr) => {
let mut var_ref =
state.ok(inner_hints.new_var(var.1.clone(), *mutable, var.0), var.2);
if let Some(var_ref) = &var_ref {
var.0 = var_ref.as_type();
}
let inferred = expr.infer_hints(&mut state, &inner_hints);
let mut expr_ty_ref = state.ok(inferred, expr.1);
if let (Some(var_ref), Some(expr_ty_ref)) =
(var_ref.as_mut(), expr_ty_ref.as_mut())
{
state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1);
}
}
StmtKind::Set(var, expr) => {
let var_ref = inner_hints.find_hint(&var.1);
if let Some((_, var_ref)) = &var_ref {
var.0 = var_ref.as_type()
}
let inferred = expr.infer_hints(&mut state, &inner_hints);
let expr_ty_ref = state.ok(inferred, expr.1);
if let (Some((_, mut var_ref)), Some(expr_ty_ref)) = (var_ref, expr_ty_ref) {
state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1);
}
}
StmtKind::Import(_) => todo!(),
StmtKind::Expression(expr) => {
let expr_res = expr.infer_hints(&mut state, &inner_hints);
state.ok(expr_res, expr.1);
}
};
}
if let Some(ret_expr) = &mut self.return_expression {
let ret_res = ret_expr.1.infer_hints(&mut state, &inner_hints);
state.ok(ret_res, ret_expr.1 .1);
}
let (kind, ty) = self.return_type().ok().unwrap_or((ReturnKind::Soft, Void));
let mut ret_type_ref = TypeRef::from_type(&outer_hints, ty);
if kind == ReturnKind::Hard {
if let Some(hint) = state.scope.return_type_hint {
state.ok(
ret_type_ref.narrow(&mut TypeRef::from_type(outer_hints, hint)),
self.meta,
);
}
}
Ok((kind, ret_type_ref))
}
fn typecheck(
&mut self,
state: &mut PassState<ErrorKind>,
hints: &ScopeHints,
hints: &TypeHints,
hint_t: Option<TypeKind>,
) -> Result<TypeKind, ErrorKind> {
let mut state = state.inner();
let hints = hints.inner();
let mut early_return = None;
@ -310,96 +243,10 @@ impl Block {
}
impl Expression {
fn infer_hints<'s>(
&mut self,
state: &mut PassState<ErrorKind>,
hints: &'s ScopeHints<'s>,
) -> Result<TypeRef<'s>, ErrorKind> {
match &mut self.0 {
ExprKind::Variable(var) => {
let hint = hints
.find_hint(&var.1)
.map(|(_, hint)| hint)
.ok_or(ErrorKind::VariableNotDefined(var.1.clone()));
if let Ok(hint) = &hint {
var.0 = hint.as_type()
}
Ok(TypeRef::Hint(hint?))
}
ExprKind::Literal(literal) => TypeRef::from_literal(hints, *literal),
ExprKind::BinOp(op, lhs, rhs) => {
let mut lhs_ref = lhs.infer_hints(state, hints)?;
let mut rhs_ref = rhs.infer_hints(state, hints)?;
hints.binop(op, &mut lhs_ref, &mut rhs_ref)
}
ExprKind::FunctionCall(function_call) => {
let fn_call = state
.scope
.function_returns
.get(&function_call.name)
.ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()))?
.clone();
let true_params_iter = fn_call.params.iter().chain(iter::repeat(&Vague(Unknown)));
for (param_expr, param_t) in
function_call.parameters.iter_mut().zip(true_params_iter)
{
let expr_res = param_expr.infer_hints(state, hints);
if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) {
state.ok(
param_ref.narrow(&mut TypeRef::from_type(hints, *param_t)),
param_expr.1,
);
}
}
Ok(TypeRef::from_type(hints, fn_call.ret))
}
ExprKind::If(IfExpression(cond, lhs, rhs)) => {
let cond_res = cond.infer_hints(state, hints);
let cond_hints = state.ok(cond_res, cond.1);
if let Some(mut cond_hints) = cond_hints {
state.ok(cond_hints.narrow(&mut TypeRef::Literal(Bool)), cond.1);
}
let lhs_res = lhs.infer_hints(state, hints);
let lhs_hints = state.ok(lhs_res, cond.1);
if let Some(rhs) = rhs {
let rhs_res = rhs.infer_hints(state, hints);
let rhs_hints = state.ok(rhs_res, cond.1);
if let (Some(mut lhs_hints), Some(mut rhs_hints)) = (lhs_hints, rhs_hints) {
state.ok(lhs_hints.1.narrow(&mut rhs_hints.1), self.1);
Ok(pick_return(lhs_hints, rhs_hints).1)
} else {
// Failed to retrieve types from either
Ok(TypeRef::from_type(hints, Vague(Unknown)))
}
} else {
if let Some((_, type_ref)) = lhs_hints {
Ok(type_ref)
} else {
Ok(TypeRef::from_type(hints, Vague(Unknown)))
}
}
}
ExprKind::Block(block) => {
let block_ref = block.infer_hints(state, hints)?;
match block_ref.0 {
ReturnKind::Hard => Ok(TypeRef::from_type(hints, Void)),
ReturnKind::Soft => Ok(block_ref.1),
}
}
}
}
fn typecheck(
&mut self,
state: &mut PassState<ErrorKind>,
hints: &ScopeHints,
hints: &TypeHints,
hint_t: Option<TypeKind>,
) -> Result<TypeKind, ErrorKind> {
match &mut self.0 {
@ -570,7 +417,7 @@ impl Literal {
impl TypeKind {
/// Assert that a type is already known and not vague. Return said type or
/// error.
fn assert_known(&self) -> Result<TypeKind, ErrorKind> {
pub fn assert_known(&self) -> Result<TypeKind, ErrorKind> {
self.known().map_err(ErrorKind::TypeIsVague)
}
@ -598,7 +445,7 @@ impl TypeKind {
})
}
fn resolve_hinted(&self, hints: &ScopeHints) -> TypeKind {
fn resolve_hinted(&self, hints: &TypeHints) -> TypeKind {
match self {
Vague(Hinted(idx)) => hints.retrieve_type(*idx).unwrap(),
_ => *self,

View File

@ -0,0 +1,228 @@
use std::iter;
use reid_lib::Function;
use super::{
pass::{Pass, PassState, ScopeVariable},
scopehints::{self, ScopeHints, TypeHint, TypeHints},
typecheck::ErrorKind,
types::{pick_return, ReturnType},
Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module,
ReturnKind, StmtKind,
TypeKind::*,
VagueType::*,
};
/// Struct used to implement a type-checking pass that can be performed on the
/// MIR.
pub struct TypeInference<'t> {
pub hints: &'t TypeHints,
}
impl<'t> Pass for TypeInference<'t> {
type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) {
for function in &mut module.functions {
let res = function.infer_hints(&self.hints, &mut state);
state.ok(res, function.block_meta());
}
}
}
impl FunctionDefinition {
fn infer_hints(
&mut self,
hints: &TypeHints,
state: &mut PassState<ErrorKind>,
) -> Result<(), ErrorKind> {
for param in &self.parameters {
let param_t = state.or_else(param.1.assert_known(), Vague(Unknown), self.signature());
let res = state
.scope
.variables
.set(
param.0.clone(),
ScopeVariable {
ty: param_t,
mutable: false,
},
)
.or(Err(ErrorKind::VariableAlreadyDefined(param.0.clone())));
state.ok(res, self.signature());
}
let scope_hints = ScopeHints::from(hints);
let return_type = self.return_type.clone();
let return_type_hint = scope_hints.from_type(&return_type).unwrap();
let mut ret = match &mut self.kind {
FunctionDefinitionKind::Local(block, _) => {
state.scope.return_type_hint = Some(self.return_type);
let block_res = block.infer_hints(state, &scope_hints);
state.ok(block_res.map(|(_, ty)| ty), self.block_meta())
}
FunctionDefinitionKind::Extern => Some(scope_hints.from_type(&Vague(Unknown)).unwrap()),
};
if let Some(ret) = &mut ret {
state.ok(ret.narrow(&return_type_hint), self.signature());
}
Ok(())
}
}
impl Block {
fn infer_hints<'s>(
&mut self,
state: &mut PassState<ErrorKind>,
outer_hints: &'s ScopeHints,
) -> Result<(ReturnKind, TypeHint<'s>), ErrorKind> {
let mut state = state.inner();
let inner_hints = outer_hints.inner();
for statement in &mut self.statements {
match &mut statement.0 {
StmtKind::Let(var, mutable, expr) => {
let mut var_ref =
state.ok(inner_hints.new_var(var.1.clone(), *mutable, var.0), var.2);
if let Some(var_ref) = &var_ref {
var.0 = var_ref.as_type();
}
let inferred = expr.infer_hints(&mut state, &inner_hints);
let mut expr_ty_ref = state.ok(inferred, expr.1);
if let (Some(var_ref), Some(expr_ty_ref)) =
(var_ref.as_mut(), expr_ty_ref.as_mut())
{
state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1);
}
}
StmtKind::Set(var, expr) => {
let var_ref = inner_hints.find_hint(&var.1);
if let Some((_, var_ref)) = &var_ref {
var.0 = var_ref.as_type()
}
let inferred = expr.infer_hints(&mut state, &inner_hints);
let expr_ty_ref = state.ok(inferred, expr.1);
if let (Some((_, mut var_ref)), Some(expr_ty_ref)) = (var_ref, expr_ty_ref) {
state.ok(var_ref.narrow(&expr_ty_ref), var.2 + expr.1);
}
}
StmtKind::Import(_) => todo!(),
StmtKind::Expression(expr) => {
let expr_res = expr.infer_hints(&mut state, &inner_hints);
state.ok(expr_res, expr.1);
}
};
}
if let Some(ret_expr) = &mut self.return_expression {
let ret_res = ret_expr.1.infer_hints(&mut state, &inner_hints);
state.ok(ret_res, ret_expr.1 .1);
}
let (kind, ty) = self.return_type().ok().unwrap_or((ReturnKind::Soft, Void));
let mut ret_type_ref = outer_hints.from_type(&ty).unwrap();
if kind == ReturnKind::Hard {
if let Some(hint) = state.scope.return_type_hint {
state.ok(
ret_type_ref.narrow(&mut outer_hints.from_type(&hint).unwrap()),
self.meta,
);
}
}
Ok((kind, ret_type_ref))
}
}
impl Expression {
fn infer_hints<'s>(
&mut self,
state: &mut PassState<ErrorKind>,
hints: &'s ScopeHints<'s>,
) -> Result<TypeHint<'s>, ErrorKind> {
match &mut self.0 {
ExprKind::Variable(var) => {
let hint = hints
.find_hint(&var.1)
.map(|(_, hint)| hint)
.ok_or(ErrorKind::VariableNotDefined(var.1.clone()));
if let Ok(hint) = &hint {
var.0 = hint.as_type()
}
hint
}
ExprKind::Literal(literal) => Ok(hints.from_type(&literal.as_type()).unwrap()),
ExprKind::BinOp(op, lhs, rhs) => {
let mut lhs_ref = lhs.infer_hints(state, hints)?;
let mut rhs_ref = rhs.infer_hints(state, hints)?;
hints.binop(op, &mut lhs_ref, &mut rhs_ref)
}
ExprKind::FunctionCall(function_call) => {
let fn_call = state
.scope
.function_returns
.get(&function_call.name)
.ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()))?
.clone();
let true_params_iter = fn_call.params.iter().chain(iter::repeat(&Vague(Unknown)));
for (param_expr, param_t) in
function_call.parameters.iter_mut().zip(true_params_iter)
{
let expr_res = param_expr.infer_hints(state, hints);
if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) {
state.ok(
param_ref.narrow(&mut hints.from_type(param_t).unwrap()),
param_expr.1,
);
}
}
Ok(hints.from_type(&fn_call.ret).unwrap())
}
ExprKind::If(IfExpression(cond, lhs, rhs)) => {
let cond_res = cond.infer_hints(state, hints);
let cond_hints = state.ok(cond_res, cond.1);
if let Some(mut cond_hints) = cond_hints {
state.ok(
cond_hints.narrow(&mut hints.from_type(&Bool).unwrap()),
cond.1,
);
}
let lhs_res = lhs.infer_hints(state, hints);
let lhs_hints = state.ok(lhs_res, cond.1);
if let Some(rhs) = rhs {
let rhs_res = rhs.infer_hints(state, hints);
let rhs_hints = state.ok(rhs_res, cond.1);
if let (Some(mut lhs_hints), Some(mut rhs_hints)) = (lhs_hints, rhs_hints) {
state.ok(lhs_hints.1.narrow(&mut rhs_hints.1), self.1);
Ok(pick_return(lhs_hints, rhs_hints).1)
} else {
// Failed to retrieve types from either
Ok(hints.from_type(&Vague(Unknown)).unwrap())
}
} else {
if let Some((_, type_ref)) = lhs_hints {
Ok(type_ref)
} else {
Ok(hints.from_type(&Vague(Unknown)).unwrap())
}
}
}
ExprKind::Block(block) => {
let block_ref = block.infer_hints(state, hints)?;
match block_ref.0 {
ReturnKind::Hard => Ok(hints.from_type(&Void).unwrap()),
ReturnKind::Soft => Ok(block_ref.1),
}
}
}
}
}