Implement typechecking for structs

This commit is contained in:
Sofia 2025-07-16 00:16:53 +03:00
parent aafab49f82
commit 77439ee34a
3 changed files with 165 additions and 34 deletions

View File

@ -103,10 +103,6 @@ impl<T: Clone + std::fmt::Debug> Storage<T> {
pub fn get(&self, key: &String) -> Option<&T> {
self.0.get(key)
}
pub fn get_mut(&mut self, key: &String) -> Option<&mut T> {
self.0.get_mut(key)
}
}
#[derive(Clone, Default, Debug)]

View File

@ -1,6 +1,6 @@
//! This module contains code relevant to doing a type checking pass on the MIR.
//! During typechecking relevant types are also coerced if possible.
use std::{convert::Infallible, iter};
use std::{collections::HashSet, convert::Infallible, iter};
use crate::{mir::*, util::try_all};
use VagueType as Vague;
@ -49,6 +49,12 @@ pub enum ErrorKind {
TriedAccessingNonStruct(TypeKind),
#[error("No such struct-field on type {0}")]
NoSuchField(String),
#[error("Struct field declared twice {0}")]
DuplicateStructField(String),
#[error("Type declared twice {0}")]
DuplicateTypeName(String),
#[error("Recursive type definition: {0}.{1}")]
RecursiveTypeDefinition(String, String),
}
/// Struct used to implement a type-checking pass that can be performed on the
@ -57,10 +63,72 @@ pub struct TypeCheck<'t> {
pub refs: &'t TypeRefs,
}
fn check_typedefs_for_recursion<'a, 'b>(
defmap: &'b HashMap<&'a String, &'b TypeDefinition>,
typedef: &'b TypeDefinition,
mut seen: HashSet<String>,
state: &mut PassState<ErrorKind>,
) {
match &typedef.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field_ty in fields.iter().map(|(_, ty)| ty) {
if let TypeKind::CustomType(name) = field_ty {
if seen.contains(name) {
state.ok::<_, Infallible>(
Err(ErrorKind::RecursiveTypeDefinition(
typedef.name.clone(),
name.clone(),
)),
typedef.meta,
);
} else {
seen.insert(name.clone());
if let Some(inner_typedef) = defmap.get(name) {
check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state)
}
}
}
}
}
}
}
impl<'t> Pass for TypeCheck<'t> {
type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) {
let mut defmap = HashMap::new();
for typedef in &module.typedefs {
let TypeDefinition { name, kind, meta } = &typedef;
match kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
let mut fieldmap = HashMap::new();
for (name, field_ty) in fields {
if let Some(_) = fieldmap.insert(name, field_ty) {
state.ok::<_, Infallible>(
Err(ErrorKind::DuplicateStructField(name.clone())),
meta.clone(),
);
}
}
}
}
if let Some(_) = defmap.insert(&typedef.name, typedef) {
state.ok::<_, Infallible>(
Err(ErrorKind::DuplicateTypeName(name.clone())),
meta.clone(),
);
}
}
let seen = HashSet::new();
for typedef in defmap.values() {
let mut curr = seen.clone();
curr.insert(typedef.name.clone());
check_typedefs_for_recursion(&defmap, typedef, HashSet::new(), &mut state);
}
for function in &mut module.functions {
let res = function.typecheck(&self.refs, &mut state.inner());
state.ok(res, function.block_meta());
@ -185,7 +253,8 @@ impl Block {
StmtKind::Set(variable_reference, expression) => {
if let Some(var) = state
.ok(
variable_reference.get_variable(&state.scope.variables),
variable_reference
.get_variable(&state.scope.variables, &state.scope.types),
variable_reference.meta,
)
.flatten()
@ -493,10 +562,59 @@ impl Expression {
}
}
}
ExprKind::StructIndex(expression, type_kind, _) => {
todo!("typechecking for struct index")
ExprKind::StructIndex(expression, type_kind, field_name) => {
// Resolve expected type
let expected_ty = type_kind.resolve_hinted(hints);
// Typecheck expression
let expr_res = expression.typecheck(state, hints, Some(&expected_ty));
let expr_ty =
state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), expression.1);
if let TypeKind::CustomType(struct_name) = expr_ty {
let struct_type = state.scope.get_struct_type(&struct_name)?;
if let Some(expr_field_ty) = struct_type.get_field_ty(&field_name) {
// Make sure they are the same
let true_ty = state.or_else(
expr_field_ty.collapse_into(&expected_ty),
TypeKind::Vague(Vague::Unknown),
self.1,
);
*type_kind = true_ty.clone();
// Update possibly resolved type
Ok(true_ty)
} else {
Err(ErrorKind::NoSuchField(field_name.clone()))
}
} else {
Err(ErrorKind::TriedAccessingNonStruct(expr_ty))
}
}
ExprKind::Struct(struct_name, items) => {
let struct_def = state.scope.get_struct_type(struct_name)?.clone();
for (field_name, field_expr) in items {
// Get expected type, or error if field does not exist
let expected_ty = state.or_else(
struct_def
.get_field_ty(field_name)
.ok_or(ErrorKind::NoSuchField(format!(
"{}.{}",
struct_name, field_name
))),
&TypeKind::Vague(VagueType::Unknown),
field_expr.1,
);
// Typecheck the actual expression
let expr_res = field_expr.typecheck(state, hints, Some(expected_ty));
let expr_ty =
state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), field_expr.1);
// Make sure both are the same type, report error if not
state.ok(expr_ty.collapse_into(&expr_ty), field_expr.1);
}
Ok(TypeKind::CustomType(struct_name.clone()))
}
ExprKind::Struct(_, items) => todo!("typechecking for struct expression"),
}
}
}
@ -505,13 +623,14 @@ impl IndexedVariableReference {
fn get_variable(
&self,
storage: &Storage<ScopeVariable>,
types: &Storage<TypeDefinitionKind>,
) -> Result<Option<ScopeVariable>, ErrorKind> {
match &self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
Ok(storage.get(&name).cloned())
}
IndexedVariableReferenceKind::ArrayIndex(inner_ref, _) => {
if let Some(var) = inner_ref.get_variable(storage)? {
if let Some(var) = inner_ref.get_variable(storage, types)? {
match &var.ty {
TypeKind::Array(inner_ty, _) => Ok(Some(ScopeVariable {
ty: *inner_ty.clone(),
@ -523,8 +642,34 @@ impl IndexedVariableReference {
Ok(None)
}
}
IndexedVariableReferenceKind::StructIndex(indexed_variable_reference, _) => {
todo!("struct index refrence typecheck")
IndexedVariableReferenceKind::StructIndex(var_ref, field_name) => {
if let Some(var) = var_ref.get_variable(storage, types)? {
match &var.ty {
TypeKind::CustomType(type_name) => {
if let Some(kind) = types.get(type_name) {
match &kind {
TypeDefinitionKind::Struct(struct_type) => {
if let Some((_, field_ty)) =
struct_type.0.iter().find(|(n, _)| n == field_name)
{
Ok(Some(ScopeVariable {
ty: field_ty.clone(),
mutable: var.mutable,
}))
} else {
Err(ErrorKind::NoSuchField(field_name.clone()))
}
}
}
} else {
Err(ErrorKind::NoSuchType(type_name.clone()))
}
}
_ => Err(ErrorKind::TriedAccessingNonStruct(var.ty.clone())),
}
} else {
Ok(None)
}
}
}
}
@ -656,3 +801,15 @@ impl Collapsable for ScopeFunction {
})
}
}
impl pass::Scope {
pub fn get_struct_type(&self, name: &String) -> Result<&StructType, ErrorKind> {
let ty = self
.types
.get(&name)
.ok_or(ErrorKind::NoSuchType(name.clone()))?;
match ty {
TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty),
}
}
}

View File

@ -412,25 +412,3 @@ impl Expression {
}
}
}
impl pass::Scope {
fn get_struct_type(&self, name: &String) -> Result<&StructType, ErrorKind> {
let ty = self
.types
.get(&name)
.ok_or(ErrorKind::NoSuchType(name.clone()))?;
match ty {
TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty),
}
}
fn get_struct_type_mut(&mut self, name: &String) -> Result<&mut StructType, ErrorKind> {
let ty = self
.types
.get_mut(&name)
.ok_or(ErrorKind::NoSuchType(name.clone()))?;
match ty {
TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty),
}
}
}