From 77439ee34ac38946b133dab2450b6962555a797f Mon Sep 17 00:00:00 2001 From: sofia Date: Wed, 16 Jul 2025 00:16:53 +0300 Subject: [PATCH] Implement typechecking for structs --- reid/src/mir/pass.rs | 4 - reid/src/mir/typecheck.rs | 173 ++++++++++++++++++++++++++++++++-- reid/src/mir/typeinference.rs | 22 ----- 3 files changed, 165 insertions(+), 34 deletions(-) diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index 397354e..abde672 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -103,10 +103,6 @@ impl Storage { pub fn get(&self, key: &String) -> Option<&T> { self.0.get(key) } - - pub fn get_mut(&mut self, key: &String) -> Option<&mut T> { - self.0.get_mut(key) - } } #[derive(Clone, Default, Debug)] diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index 3432974..16257d8 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -1,6 +1,6 @@ //! This module contains code relevant to doing a type checking pass on the MIR. //! During typechecking relevant types are also coerced if possible. -use std::{convert::Infallible, iter}; +use std::{collections::HashSet, convert::Infallible, iter}; use crate::{mir::*, util::try_all}; use VagueType as Vague; @@ -49,6 +49,12 @@ pub enum ErrorKind { TriedAccessingNonStruct(TypeKind), #[error("No such struct-field on type {0}")] NoSuchField(String), + #[error("Struct field declared twice {0}")] + DuplicateStructField(String), + #[error("Type declared twice {0}")] + DuplicateTypeName(String), + #[error("Recursive type definition: {0}.{1}")] + RecursiveTypeDefinition(String, String), } /// Struct used to implement a type-checking pass that can be performed on the @@ -57,10 +63,72 @@ pub struct TypeCheck<'t> { pub refs: &'t TypeRefs, } +fn check_typedefs_for_recursion<'a, 'b>( + defmap: &'b HashMap<&'a String, &'b TypeDefinition>, + typedef: &'b TypeDefinition, + mut seen: HashSet, + state: &mut PassState, +) { + match &typedef.kind { + TypeDefinitionKind::Struct(StructType(fields)) => { + for field_ty in fields.iter().map(|(_, ty)| ty) { + if let TypeKind::CustomType(name) = field_ty { + if seen.contains(name) { + state.ok::<_, Infallible>( + Err(ErrorKind::RecursiveTypeDefinition( + typedef.name.clone(), + name.clone(), + )), + typedef.meta, + ); + } else { + seen.insert(name.clone()); + if let Some(inner_typedef) = defmap.get(name) { + check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state) + } + } + } + } + } + } +} + impl<'t> Pass for TypeCheck<'t> { type TError = ErrorKind; fn module(&mut self, module: &mut Module, mut state: PassState) { + let mut defmap = HashMap::new(); + for typedef in &module.typedefs { + let TypeDefinition { name, kind, meta } = &typedef; + match kind { + TypeDefinitionKind::Struct(StructType(fields)) => { + let mut fieldmap = HashMap::new(); + for (name, field_ty) in fields { + if let Some(_) = fieldmap.insert(name, field_ty) { + state.ok::<_, Infallible>( + Err(ErrorKind::DuplicateStructField(name.clone())), + meta.clone(), + ); + } + } + } + } + + if let Some(_) = defmap.insert(&typedef.name, typedef) { + state.ok::<_, Infallible>( + Err(ErrorKind::DuplicateTypeName(name.clone())), + meta.clone(), + ); + } + } + + let seen = HashSet::new(); + for typedef in defmap.values() { + let mut curr = seen.clone(); + curr.insert(typedef.name.clone()); + check_typedefs_for_recursion(&defmap, typedef, HashSet::new(), &mut state); + } + for function in &mut module.functions { let res = function.typecheck(&self.refs, &mut state.inner()); state.ok(res, function.block_meta()); @@ -185,7 +253,8 @@ impl Block { StmtKind::Set(variable_reference, expression) => { if let Some(var) = state .ok( - variable_reference.get_variable(&state.scope.variables), + variable_reference + .get_variable(&state.scope.variables, &state.scope.types), variable_reference.meta, ) .flatten() @@ -493,10 +562,59 @@ impl Expression { } } } - ExprKind::StructIndex(expression, type_kind, _) => { - todo!("typechecking for struct index") + ExprKind::StructIndex(expression, type_kind, field_name) => { + // Resolve expected type + let expected_ty = type_kind.resolve_hinted(hints); + + // Typecheck expression + let expr_res = expression.typecheck(state, hints, Some(&expected_ty)); + let expr_ty = + state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), expression.1); + + if let TypeKind::CustomType(struct_name) = expr_ty { + let struct_type = state.scope.get_struct_type(&struct_name)?; + if let Some(expr_field_ty) = struct_type.get_field_ty(&field_name) { + // Make sure they are the same + let true_ty = state.or_else( + expr_field_ty.collapse_into(&expected_ty), + TypeKind::Vague(Vague::Unknown), + self.1, + ); + *type_kind = true_ty.clone(); + // Update possibly resolved type + Ok(true_ty) + } else { + Err(ErrorKind::NoSuchField(field_name.clone())) + } + } else { + Err(ErrorKind::TriedAccessingNonStruct(expr_ty)) + } + } + ExprKind::Struct(struct_name, items) => { + let struct_def = state.scope.get_struct_type(struct_name)?.clone(); + for (field_name, field_expr) in items { + // Get expected type, or error if field does not exist + let expected_ty = state.or_else( + struct_def + .get_field_ty(field_name) + .ok_or(ErrorKind::NoSuchField(format!( + "{}.{}", + struct_name, field_name + ))), + &TypeKind::Vague(VagueType::Unknown), + field_expr.1, + ); + + // Typecheck the actual expression + let expr_res = field_expr.typecheck(state, hints, Some(expected_ty)); + let expr_ty = + state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), field_expr.1); + + // Make sure both are the same type, report error if not + state.ok(expr_ty.collapse_into(&expr_ty), field_expr.1); + } + Ok(TypeKind::CustomType(struct_name.clone())) } - ExprKind::Struct(_, items) => todo!("typechecking for struct expression"), } } } @@ -505,13 +623,14 @@ impl IndexedVariableReference { fn get_variable( &self, storage: &Storage, + types: &Storage, ) -> Result, ErrorKind> { match &self.kind { IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => { Ok(storage.get(&name).cloned()) } IndexedVariableReferenceKind::ArrayIndex(inner_ref, _) => { - if let Some(var) = inner_ref.get_variable(storage)? { + if let Some(var) = inner_ref.get_variable(storage, types)? { match &var.ty { TypeKind::Array(inner_ty, _) => Ok(Some(ScopeVariable { ty: *inner_ty.clone(), @@ -523,8 +642,34 @@ impl IndexedVariableReference { Ok(None) } } - IndexedVariableReferenceKind::StructIndex(indexed_variable_reference, _) => { - todo!("struct index refrence typecheck") + IndexedVariableReferenceKind::StructIndex(var_ref, field_name) => { + if let Some(var) = var_ref.get_variable(storage, types)? { + match &var.ty { + TypeKind::CustomType(type_name) => { + if let Some(kind) = types.get(type_name) { + match &kind { + TypeDefinitionKind::Struct(struct_type) => { + if let Some((_, field_ty)) = + struct_type.0.iter().find(|(n, _)| n == field_name) + { + Ok(Some(ScopeVariable { + ty: field_ty.clone(), + mutable: var.mutable, + })) + } else { + Err(ErrorKind::NoSuchField(field_name.clone())) + } + } + } + } else { + Err(ErrorKind::NoSuchType(type_name.clone())) + } + } + _ => Err(ErrorKind::TriedAccessingNonStruct(var.ty.clone())), + } + } else { + Ok(None) + } } } } @@ -656,3 +801,15 @@ impl Collapsable for ScopeFunction { }) } } + +impl pass::Scope { + pub fn get_struct_type(&self, name: &String) -> Result<&StructType, ErrorKind> { + let ty = self + .types + .get(&name) + .ok_or(ErrorKind::NoSuchType(name.clone()))?; + match ty { + TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty), + } + } +} diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index b8c472c..493279a 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -412,25 +412,3 @@ impl Expression { } } } - -impl pass::Scope { - fn get_struct_type(&self, name: &String) -> Result<&StructType, ErrorKind> { - let ty = self - .types - .get(&name) - .ok_or(ErrorKind::NoSuchType(name.clone()))?; - match ty { - TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty), - } - } - - fn get_struct_type_mut(&mut self, name: &String) -> Result<&mut StructType, ErrorKind> { - let ty = self - .types - .get_mut(&name) - .ok_or(ErrorKind::NoSuchType(name.clone()))?; - match ty { - TypeDefinitionKind::Struct(struct_ty) => Ok(struct_ty), - } - } -}