Add typeinference and typechecking for Associated Functions

This commit is contained in:
Sofia 2025-07-27 18:24:49 +03:00
parent 46668b7099
commit 4d7c17a854
6 changed files with 133 additions and 13 deletions

View File

@ -9,6 +9,6 @@ impl Otus {
}
fn main() -> u32 {
let otus = Otus {};
let otus = Otus { field: 17 };
return Otus::test(&otus);
}

View File

@ -443,7 +443,7 @@ impl Expression {
},
Err(_) => Ok((ReturnKind::Soft, type_kind.clone())),
},
AssociatedFunctionCall(type_kind, function_call) => todo!(),
AssociatedFunctionCall(_, fcall) => fcall.return_type(),
}
}
@ -462,7 +462,7 @@ impl Expression {
ExprKind::FunctionCall(_) => None,
ExprKind::If(_) => None,
ExprKind::CastTo(expression, _) => expression.backing_var(),
ExprKind::AssociatedFunctionCall(type_kind, function_call) => None,
ExprKind::AssociatedFunctionCall(..) => None,
}
}

View File

@ -123,7 +123,8 @@ pub type BinopMap = Storage<BinopKey, ScopeBinopDef>;
pub struct Scope<Data: Clone + Default> {
pub module_id: Option<SourceModuleId>,
pub binops: BinopMap,
pub function_returns: Storage<String, ScopeFunction>,
pub associated_functions: Storage<AssociatedFunctionKey, ScopeFunction>,
pub functions: Storage<String, ScopeFunction>,
pub variables: Storage<String, ScopeVariable>,
pub types: Storage<CustomTypeKey, TypeDefinition>,
/// Hard Return type of this scope, if inside a function
@ -135,7 +136,8 @@ impl<Data: Clone + Default> Scope<Data> {
pub fn inner(&self) -> Scope<Data> {
Scope {
module_id: self.module_id,
function_returns: self.function_returns.clone(),
associated_functions: self.associated_functions.clone(),
functions: self.functions.clone(),
variables: self.variables.clone(),
binops: self.binops.clone(),
types: self.types.clone(),
@ -181,6 +183,9 @@ pub struct ScopeVariable {
pub mutable: bool,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AssociatedFunctionKey(pub TypeKind, pub String);
#[derive(Clone, Debug, Eq)]
pub struct BinopKey {
pub params: (TypeKind, TypeKind),
@ -389,7 +394,7 @@ impl Module {
for function in &self.functions {
scope
.function_returns
.functions
.set(
function.name.clone(),
ScopeFunction {
@ -400,6 +405,19 @@ impl Module {
.ok();
}
for (ty, function) in &self.associated_functions {
scope
.associated_functions
.set(
AssociatedFunctionKey(ty.clone(), function.name.clone()),
ScopeFunction {
ret: function.return_type.clone(),
params: function.parameters.iter().cloned().map(|v| v.1).collect(),
},
)
.ok();
}
pass.module(self, PassState::from(state, scope, Some(self.module_id)))?;
for function in &mut self.functions {

View File

@ -24,12 +24,16 @@ pub enum ErrorKind {
TypesIncompatible(TypeKind, TypeKind),
#[error("Variable not defined: {0}")]
VariableNotDefined(String),
#[error("Function not defined: {0}")]
#[error("Function {0} not defined")]
FunctionNotDefined(String),
#[error("Function {0} not defined for type {1}")]
AssocFunctionNotDefined(String, TypeKind),
#[error("Expected a return type of {0}, got {1} instead")]
ReturnTypeMismatch(TypeKind, TypeKind),
#[error("Function {0} already defined {1}")]
FunctionAlreadyDefined(String, ErrorTypedefKind),
#[error("Function {0}::{1} already defined {2}")]
AssocFunctionAlreadyDefined(TypeKind, String, ErrorTypedefKind),
#[error("Variable already defined: {0}")]
VariableAlreadyDefined(String),
#[error("Variable {0} is not declared as mutable")]

View File

@ -443,7 +443,7 @@ impl Expression {
ExprKind::FunctionCall(function_call) => {
let true_function = state
.scope
.function_returns
.functions
.get(&function_call.name)
.cloned()
.ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()));
@ -724,7 +724,56 @@ impl Expression {
let expr = expression.typecheck(state, typerefs, HintKind::Default)?;
expr.resolve_ref(typerefs).cast_into(type_kind)
}
ExprKind::AssociatedFunctionCall(type_kind, function_call) => todo!(),
ExprKind::AssociatedFunctionCall(type_kind, function_call) => {
let true_function = state
.scope
.associated_functions
.get(&pass::AssociatedFunctionKey(
type_kind.clone(),
function_call.name.clone(),
))
.cloned()
.ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()));
if let Some(f) = state.ok(true_function, self.1) {
let param_len_given = function_call.parameters.len();
let param_len_expected = f.params.len();
// Check that there are the same number of parameters given
// as expected
if param_len_given != param_len_expected {
state.ok::<_, Infallible>(
Err(ErrorKind::InvalidAmountParameters(
function_call.name.clone(),
param_len_given,
param_len_expected,
)),
self.1,
);
}
let true_params_iter = f
.params
.into_iter()
.chain(iter::repeat(TypeKind::Vague(Vague::Unknown)));
for (param, true_param_t) in function_call.parameters.iter_mut().zip(true_params_iter) {
// Typecheck every param separately
let param_res = param.typecheck(state, &typerefs, HintKind::Coerce(true_param_t.clone()));
let param_t = state.or_else(param_res, TypeKind::Vague(Vague::Unknown), param.1);
state.ok(param_t.narrow_into(&true_param_t), param.1);
}
// Make sure function return type is the same as the claimed
// return type
let ret_t = f.ret.narrow_into(&function_call.return_type.resolve_ref(typerefs))?;
// Update typing to be more accurate
function_call.return_type = ret_t.clone();
Ok(ret_t.resolve_ref(typerefs))
} else {
Ok(function_call.return_type.clone().resolve_ref(typerefs))
}
}
}
}
}

View File

@ -12,8 +12,8 @@ use std::{
use crate::{
mir::{
BinopDefinition, Block, CustomTypeKey, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind,
IfExpression, Module, ReturnKind, StmtKind, TypeKind, WhileStatement,
pass::AssociatedFunctionKey, BinopDefinition, Block, CustomTypeKey, ExprKind, Expression, FunctionDefinition,
FunctionDefinitionKind, IfExpression, Module, ReturnKind, StmtKind, TypeKind, WhileStatement,
},
util::try_all,
};
@ -60,6 +60,29 @@ impl<'t> Pass for TypeInference<'t> {
}
}
let mut seen_assoc_functions = HashMap::new();
for (ty, function) in &mut module.associated_functions {
if let Some(kind) = seen_assoc_functions.get(&(ty.clone(), function.name.clone())) {
state.note_errors(
&vec![ErrorKind::AssocFunctionAlreadyDefined(
ty.clone(),
function.name.clone(),
*kind,
)],
function.signature(),
);
} else {
seen_assoc_functions.insert(
(ty.clone(), function.name.clone()),
match function.kind {
FunctionDefinitionKind::Local(..) => ErrorTypedefKind::Local,
FunctionDefinitionKind::Extern(..) => ErrorTypedefKind::Extern,
FunctionDefinitionKind::Intrinsic(..) => ErrorTypedefKind::Intrinsic,
},
);
}
}
let mut seen_binops = HashSet::new();
for binop in &module.binop_defs {
let binop_key = BinopKey {
@ -365,7 +388,7 @@ impl Expression {
// Get function definition and types
let fn_call = state
.scope
.function_returns
.functions
.get(&function_call.name)
.ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()))?
.clone();
@ -566,7 +589,33 @@ impl Expression {
expression.infer_types(state, type_refs)?;
Ok(type_refs.from_type(type_kind).unwrap())
}
ExprKind::AssociatedFunctionCall(type_kind, function_call) => todo!(),
ExprKind::AssociatedFunctionCall(type_kind, function_call) => {
// Get function definition and types
let fn_call = state
.scope
.associated_functions
.get(&AssociatedFunctionKey(type_kind.clone(), function_call.name.clone()))
.ok_or(ErrorKind::AssocFunctionNotDefined(
function_call.name.clone(),
type_kind.clone(),
))?
.clone();
// Infer param expression types and narrow them to the
// expected function parameters (or Unknown types if too
// many were provided)
let true_params_iter = fn_call.params.iter().chain(iter::repeat(&Vague(Unknown)));
for (param_expr, param_t) in function_call.parameters.iter_mut().zip(true_params_iter) {
let expr_res = param_expr.infer_types(state, type_refs);
if let Some(mut param_ref) = state.ok(expr_res, param_expr.1) {
param_ref.narrow(&mut type_refs.from_type(param_t).unwrap());
}
}
// Provide function return type
Ok(type_refs.from_type(&fn_call.ret).unwrap())
}
}
}
}