Implement type inference for structs

This commit is contained in:
Sofia 2025-07-15 23:16:37 +03:00
parent e13b6349f0
commit 1d1e574136
10 changed files with 156 additions and 59 deletions

View File

@ -2,7 +2,7 @@ use std::path::PathBuf;
use crate::{
ast::{self},
mir::{self, NamedVariableRef, StmtKind},
mir::{self, NamedVariableRef, StmtKind, StructType},
};
impl mir::Context {
@ -68,12 +68,12 @@ impl ast::Module {
name: name.clone(),
kind: match kind {
ast::TypeDefinitionKind::Struct(struct_definition_fields) => {
mir::TypeDefinitionKind::Struct(
mir::TypeDefinitionKind::Struct(StructType(
struct_definition_fields
.iter()
.map(|s| (s.name.clone(), s.ty.clone().into()))
.collect(),
)
))
}
},
meta: (*range).into(),

View File

@ -518,7 +518,7 @@ impl TypeKind {
TypeKind::Array(elem_t, _) => Type::Ptr(Box::new(elem_t.get_type())),
TypeKind::Void => Type::Void,
TypeKind::Vague(_) => panic!("Tried to compile a vague type!"),
TypeKind::CustomType(_, custom_type_kind) => todo!("codegen for custom type"),
TypeKind::CustomType(_) => todo!("codegen for custom type"),
}
}
}

View File

@ -55,7 +55,7 @@ impl Display for TypeDefinitionKind {
writeln!(f)?;
let mut state = Default::default();
let mut inner_f = PadAdapter::wrap(f, &mut state);
for (field_name, field_ty) in items {
for (field_name, field_ty) in &items.0 {
writeln!(inner_f, "{}: {:?},", field_name, field_ty)?;
}
f.write_char('}')

View File

@ -2,9 +2,9 @@
//! Reid. It contains a simplified version of Reid which can be e.g.
//! typechecked.
use std::path::PathBuf;
use std::{collections::HashMap, path::PathBuf};
use crate::{ast::Type, token_stream::TokenRange};
use crate::token_stream::TokenRange;
mod display;
pub mod linker;
@ -65,8 +65,8 @@ pub enum TypeKind {
StringPtr,
#[error("[{0}; {1}]")]
Array(Box<TypeKind>, u64),
#[error("{0} ({1})")]
CustomType(String, CustomTypeKind),
#[error("{0}")]
CustomType(String),
#[error(transparent)]
Vague(#[from] VagueType),
}
@ -81,13 +81,10 @@ pub enum VagueType {
TypeRef(usize),
}
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum CustomTypeKind {
#[error("struct({0:?})")]
Struct(Vec<TypeKind>),
#[error("CustomType")]
Unknown,
}
#[derive(Clone, Debug)]
pub struct StructType(pub Vec<(String, TypeKind)>);
pub type TypedefMap = HashMap<String, TypeDefinitionKind>;
impl TypeKind {
pub fn known(&self) -> Result<TypeKind, VagueType> {
@ -100,7 +97,7 @@ impl TypeKind {
}
impl TypeKind {
pub fn signed(&self) -> bool {
pub fn signed(&self, typedefs: &TypedefMap) -> bool {
match self {
TypeKind::Void => false,
TypeKind::Vague(_) => false,
@ -117,11 +114,13 @@ impl TypeKind {
TypeKind::U128 => false,
TypeKind::StringPtr => false,
TypeKind::Array(_, _) => false,
TypeKind::CustomType(_, _) => false,
TypeKind::CustomType(name) => match typedefs.get(name).unwrap() {
TypeDefinitionKind::Struct(_) => false,
},
}
}
pub fn is_maths(&self) -> bool {
pub fn is_maths(&self, typedefs: &TypedefMap) -> bool {
use TypeKind::*;
match &self {
I8 => true,
@ -139,7 +138,9 @@ impl TypeKind {
Void => false,
StringPtr => false,
Array(_, _) => false,
TypeKind::CustomType(_, _) => false,
TypeKind::CustomType(name) => match typedefs.get(name).unwrap() {
TypeDefinitionKind::Struct(_) => false,
},
}
}
}
@ -328,9 +329,9 @@ pub struct TypeDefinition {
pub meta: Metadata,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum TypeDefinitionKind {
Struct(Vec<(String, TypeKind)>),
Struct(StructType),
}
#[derive(Debug)]

View File

@ -103,13 +103,17 @@ impl<T: Clone + std::fmt::Debug> Storage<T> {
pub fn get(&self, key: &String) -> Option<&T> {
self.0.get(key)
}
pub fn get_mut(&mut self, key: &String) -> Option<&mut T> {
self.0.get_mut(key)
}
}
#[derive(Clone, Default, Debug)]
pub struct Scope {
pub function_returns: Storage<ScopeFunction>,
pub variables: Storage<ScopeVariable>,
pub types: Storage<ScopeTypedefKind>,
pub types: Storage<TypeDefinitionKind>,
/// Hard Return type of this scope, if inside a function
pub return_type_hint: Option<TypeKind>,
}
@ -126,11 +130,6 @@ pub struct ScopeVariable {
pub mutable: bool,
}
#[derive(Clone, Debug)]
pub enum ScopeTypedefKind {
Struct(Vec<(String, TypeKind)>),
}
impl Scope {
pub fn inner(&self) -> Scope {
Scope {
@ -140,6 +139,10 @@ impl Scope {
return_type_hint: self.return_type_hint.clone(),
}
}
pub fn get_typedefs(&self) -> &TypedefMap {
&self.types.0
}
}
pub struct PassState<'st, 'sc, TError: STDError + Clone> {
@ -227,7 +230,7 @@ impl Module {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) {
for typedef in &self.typedefs {
let kind = match &typedef.kind {
TypeDefinitionKind::Struct(fields) => ScopeTypedefKind::Struct(fields.clone()),
TypeDefinitionKind::Struct(fields) => TypeDefinitionKind::Struct(fields.clone()),
};
scope.types.set(typedef.name.clone(), kind).ok();
}

View File

@ -43,6 +43,12 @@ pub enum ErrorKind {
TriedIndexingNonArray(TypeKind),
#[error("Index {0} out of bounds ({1})")]
IndexOutOfBounds(u64, u64),
#[error("No such type {0} could be found")]
NoSuchType(String),
#[error("Attempted to access field of non-struct type of {0}")]
TriedAccessingNonStruct(TypeKind),
#[error("No such struct-field on type {0}")]
NoSuchField(String),
}
/// Struct used to implement a type-checking pass that can be performed on the

View File

@ -4,17 +4,18 @@
//! must then be passed through TypeCheck with the same [`TypeRefs`] in order to
//! place the correct types from the IDs and check that there are no issues.
use std::iter;
use std::{convert::Infallible, iter};
use crate::{mir::TypeKind, util::try_all};
use super::{
pass::{Pass, PassState},
pass::{self, Pass, PassState},
typecheck::ErrorKind,
typerefs::{ScopeTypeRefs, TypeRef, TypeRefs},
types::{pick_return, ReturnType},
Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression,
IndexedVariableReference, Module, NamedVariableRef, ReturnKind, StmtKind,
IndexedVariableReference, IndexedVariableReferenceKind, Module, NamedVariableRef, ReturnKind,
StmtKind, StructType, TypeDefinitionKind,
TypeKind::*,
VagueType::*,
};
@ -106,7 +107,7 @@ impl Block {
}
StmtKind::Set(var, expr) => {
// Get the TypeRef for this variable declaration
let var_ref = var.find_hint(&inner_hints)?;
let var_ref = var.find_hint(&state, &inner_hints)?;
// If ok, update the MIR type to this TypeRef
if let Some((_, var_ref)) = &var_ref {
@ -155,14 +156,15 @@ impl Block {
impl IndexedVariableReference {
fn find_hint<'s>(
&self,
state: &PassState<ErrorKind>,
hints: &'s ScopeTypeRefs,
) -> Result<Option<(bool, TypeRef<'s>)>, ErrorKind> {
match &self.kind {
super::IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
Ok(hints.find_hint(&name))
IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
Ok(hints.find_var(&name))
}
super::IndexedVariableReferenceKind::ArrayIndex(inner, _) => {
if let Some((mutable, inner_ref)) = inner.find_hint(hints)? {
IndexedVariableReferenceKind::ArrayIndex(inner, _) => {
if let Some((mutable, inner_ref)) = inner.find_hint(state, hints)? {
// Check that the resolved type is at least an array, no
// need for further resolution.
let inner_ty = inner_ref.resolve_weak().unwrap();
@ -177,8 +179,30 @@ impl IndexedVariableReference {
Ok(None)
}
}
super::IndexedVariableReferenceKind::StructIndex(indexed_variable_reference, _) => {
todo!("struct index refrence type inference")
IndexedVariableReferenceKind::StructIndex(inner, field_name) => {
if let Some((mutable, inner_ref)) = inner.find_hint(state, hints)? {
// Check that the resolved type is at least an array, no
// need for further resolution.
let inner_ty = inner_ref.resolve_weak().unwrap();
match &inner_ty {
CustomType(struct_name) => match state.scope.types.get(&struct_name) {
Some(kind) => match kind {
TypeDefinitionKind::Struct(struct_ty) => Ok(hints
.from_type(
&struct_ty
.get_field_ty(field_name)
.cloned()
.ok_or(ErrorKind::NoSuchField(self.get_name()))?,
)
.map(|v| (mutable, v))),
},
None => Err(ErrorKind::TriedAccessingNonStruct(inner_ty.clone())),
},
_ => Err(ErrorKind::TriedAccessingNonStruct(inner_ty)),
}
} else {
Ok(None)
}
}
}
}
@ -194,7 +218,7 @@ impl Expression {
ExprKind::Variable(var) => {
// Find variable type
let type_ref = type_refs
.find_hint(&var.1)
.find_var(&var.1)
.map(|(_, hint)| hint)
.ok_or(ErrorKind::VariableNotDefined(var.1.clone()));
@ -339,10 +363,71 @@ impl Expression {
}
}
}
ExprKind::StructIndex(expression, type_kind, _) => {
todo!("type inference for struct indexes")
ExprKind::StructIndex(expression, type_kind, field_name) => {
let expr_ty = expression.infer_types(state, type_refs)?;
// Check that the resolved type is at least a struct, no
// need for further resolution.
let kind = expr_ty.resolve_weak().unwrap();
match kind {
CustomType(name) => {
let struct_ty = state.scope.get_struct_type_mut(&name)?;
match struct_ty.get_field_ty_mut(&field_name) {
Some(field_ty) => {
let elem_ty = type_refs.from_type(&type_kind).unwrap();
*field_ty = elem_ty.as_type().clone();
Ok(elem_ty)
}
None => Err(ErrorKind::NoSuchField(field_name.clone())),
}
}
_ => Err(ErrorKind::TriedAccessingNonStruct(kind)),
}
}
ExprKind::Struct(struct_name, fields) => {
let expected_struct_ty = state.scope.get_struct_type(&struct_name)?.clone();
for field in fields {
if let Some(expected_field_ty) = expected_struct_ty.get_field_ty(&field.0) {
let field_ty = field.1.infer_types(state, type_refs);
if let Some(mut field_ty) = state.ok(field_ty, field.1 .1) {
field_ty.narrow(&type_refs.from_type(&expected_field_ty).unwrap());
}
} else {
state.ok::<_, Infallible>(
Err(ErrorKind::NoSuchField(format!(
"{}.{}",
struct_name, field.0
))),
field.1 .1,
);
}
}
Ok(type_refs
.from_type(&TypeKind::CustomType(struct_name.clone()))
.unwrap())
}
ExprKind::Struct(_, items) => todo!("type inference for struct expression"),
}
}
}
impl pass::Scope {
fn get_struct_type(&self, name: &String) -> Result<&StructType, ErrorKind> {
let ty = self
.types
.get(&name)
.ok_or(ErrorKind::NoSuchType(name.clone()))?;
match ty {
TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty),
}
}
fn get_struct_type_mut(&mut self, name: &String) -> Result<&mut StructType, ErrorKind> {
let ty = self
.types
.get_mut(&name)
.ok_or(ErrorKind::NoSuchType(name.clone()))?;
match ty {
TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty),
}
}
}

View File

@ -8,7 +8,7 @@ use crate::mir::VagueType;
use super::{
typecheck::{Collapsable, ErrorKind},
BinaryOperator, TypeKind,
BinaryOperator, TypeDefinition, TypeKind,
};
#[derive(Clone)]
@ -209,12 +209,12 @@ impl<'outer> ScopeTypeRefs<'outer> {
}
}
pub fn find_hint(&'outer self, name: &String) -> Option<(bool, TypeRef<'outer>)> {
pub fn find_var(&'outer self, name: &String) -> Option<(bool, TypeRef<'outer>)> {
self.variables
.borrow()
.get(name)
.map(|(mutable, idx)| (*mutable, TypeRef(idx.clone(), self)))
.or(self.outer.map(|o| o.find_hint(name)).flatten())
.or(self.outer.map(|o| o.find_var(name)).flatten())
}
pub fn binop(

View File

@ -1,3 +1,5 @@
use std::collections::HashMap;
use crate::util::try_all;
use super::*;
@ -29,6 +31,16 @@ impl TypeKind {
}
}
impl StructType {
pub fn get_field_ty(&self, name: &String) -> Option<&TypeKind> {
self.0.iter().find(|(n, _)| n == name).map(|(_, ty)| ty)
}
pub fn get_field_ty_mut(&mut self, name: &String) -> Option<&mut TypeKind> {
self.0.iter_mut().find(|(n, _)| n == name).map(|(_, ty)| ty)
}
}
pub trait ReturnType {
/// Return the return type of this node
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther>;
@ -108,18 +120,8 @@ impl ReturnType for Expression {
TypeKind::Array(Box::new(first.1), expressions.len() as u64),
))
}
StructIndex(expression, type_kind, _) => todo!("todo return type for struct index"),
Struct(name, items) => {
let f_types = try_all(items.iter().map(|e| e.1.return_type()).collect())
.map_err(|e| unsafe { e.get_unchecked(0).clone() })?
.iter()
.map(|r| r.1.clone())
.collect();
Ok((
ReturnKind::Soft,
TypeKind::CustomType(name.clone(), CustomTypeKind::Struct(f_types)),
))
}
StructIndex(_, type_kind, _) => Ok((ReturnKind::Soft, type_kind.clone())),
Struct(name, _) => Ok((ReturnKind::Soft, TypeKind::CustomType(name.clone()))),
}
}
}

View File

@ -11,5 +11,5 @@ fn main() -> u32 {
second: 3,
};
return Test.second;
return value.second;
}