From c41aab33a9015e20c70ea6178269211f07c7c977 Mon Sep 17 00:00:00 2001 From: sofia Date: Wed, 16 Jul 2025 22:46:52 +0300 Subject: [PATCH] Add optional data to PassState Scope --- reid/src/mir/linker.rs | 6 ++- reid/src/mir/pass.rs | 69 +++++++++++++++++++++++++---------- reid/src/mir/typecheck.rs | 69 ++++++++++++++++++----------------- reid/src/mir/typeinference.rs | 11 ++++-- 4 files changed, 97 insertions(+), 58 deletions(-) diff --git a/reid/src/mir/linker.rs b/reid/src/mir/linker.rs index 3d3419a..c212ef2 100644 --- a/reid/src/mir/linker.rs +++ b/reid/src/mir/linker.rs @@ -2,6 +2,7 @@ use std::{ cell::RefCell, collections::HashMap, convert::Infallible, + fmt::Error, fs::{self}, path::PathBuf, rc::Rc, @@ -54,9 +55,12 @@ pub fn compile_std() -> super::Module { /// MIR. pub struct LinkerPass; +type LinkerPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>; + impl Pass for LinkerPass { + type Data = (); type TError = ErrorKind; - fn context(&mut self, context: &mut Context, mut state: PassState) { + fn context(&mut self, context: &mut Context, mut state: LinkerPassState) { let mains = context .modules .iter() diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index a2cc8bd..3706015 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -106,21 +106,23 @@ impl Storage { } #[derive(Clone, Default, Debug)] -pub struct Scope { +pub struct Scope { pub function_returns: Storage, pub variables: Storage, pub types: Storage, /// Hard Return type of this scope, if inside a function pub return_type_hint: Option, + pub data: Data, } -impl Scope { - pub fn inner(&self) -> Scope { +impl Scope { + pub fn inner(&self) -> Scope { Scope { function_returns: self.function_returns.clone(), variables: self.variables.clone(), types: self.types.clone(), return_type_hint: self.return_type_hint.clone(), + data: self.data.clone(), } } @@ -144,14 +146,14 @@ pub struct ScopeVariable { pub mutable: bool, } -pub struct PassState<'st, 'sc, TError: STDError + Clone> { +pub struct PassState<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> { state: &'st mut State, - pub scope: &'sc mut Scope, - inner: Vec, + pub scope: &'sc mut Scope, + inner: Vec>, } -impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> { - fn from(state: &'st mut State, scope: &'sc mut Scope) -> Self { +impl<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> PassState<'st, 'sc, Data, TError> { + fn from(state: &'st mut State, scope: &'sc mut Scope) -> Self { PassState { state, scope, @@ -186,7 +188,7 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> { } } - pub fn inner(&mut self) -> PassState { + pub fn inner(&mut self) -> PassState { self.inner.push(self.scope.inner()); let scope = self.inner.last_mut().unwrap(); PassState { @@ -198,19 +200,21 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> { } pub trait Pass { + type Data: Clone + Default; type TError: STDError + Clone; - fn context(&mut self, _context: &mut Context, mut _state: PassState) {} - fn module(&mut self, _module: &mut Module, mut _state: PassState) {} + fn context(&mut self, _context: &mut Context, mut _state: PassState) { + } + fn module(&mut self, _module: &mut Module, mut _state: PassState) {} fn function( &mut self, _function: &mut FunctionDefinition, - mut _state: PassState, + mut _state: PassState, ) { } - fn block(&mut self, _block: &mut Block, mut _state: PassState) {} - fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState) {} - fn expr(&mut self, _expr: &mut Expression, mut _state: PassState) {} + fn block(&mut self, _block: &mut Block, mut _state: PassState) {} + fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState) {} + fn expr(&mut self, _expr: &mut Expression, mut _state: PassState) {} } impl Context { @@ -226,7 +230,12 @@ impl Context { } impl Module { - fn pass(&mut self, pass: &mut T, state: &mut State, scope: &mut Scope) { + fn pass( + &mut self, + pass: &mut T, + state: &mut State, + scope: &mut Scope, + ) { for typedef in &self.typedefs { let kind = match &typedef.kind { TypeDefinitionKind::Struct(fields) => TypeDefinitionKind::Struct(fields.clone()), @@ -256,7 +265,12 @@ impl Module { } impl FunctionDefinition { - fn pass(&mut self, pass: &mut T, state: &mut State, scope: &mut Scope) { + fn pass( + &mut self, + pass: &mut T, + state: &mut State, + scope: &mut Scope, + ) { for param in &self.parameters { scope .variables @@ -283,7 +297,12 @@ impl FunctionDefinition { } impl Block { - fn pass(&mut self, pass: &mut T, state: &mut State, scope: &mut Scope) { + fn pass( + &mut self, + pass: &mut T, + state: &mut State, + scope: &mut Scope, + ) { let mut scope = scope.inner(); for statement in &mut self.statements { @@ -295,7 +314,12 @@ impl Block { } impl Statement { - fn pass(&mut self, pass: &mut T, state: &mut State, scope: &mut Scope) { + fn pass( + &mut self, + pass: &mut T, + state: &mut State, + scope: &mut Scope, + ) { match &mut self.0 { StmtKind::Let(_, _, expression) => { expression.pass(pass, state, scope); @@ -332,7 +356,12 @@ impl Statement { } impl Expression { - fn pass(&mut self, pass: &mut T, state: &mut State, scope: &mut Scope) { + fn pass( + &mut self, + pass: &mut T, + state: &mut State, + scope: &mut Scope, + ) { pass.expr(self, PassState::from(state, scope)); } } diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index d73aa23..93edf9f 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -64,40 +64,13 @@ pub struct TypeCheck<'t> { pub refs: &'t TypeRefs, } -fn check_typedefs_for_recursion<'a, 'b>( - defmap: &'b HashMap<&'a String, &'b TypeDefinition>, - typedef: &'b TypeDefinition, - mut seen: HashSet, - state: &mut PassState, -) { - match &typedef.kind { - TypeDefinitionKind::Struct(StructType(fields)) => { - for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) { - if let TypeKind::CustomType(name) = field_ty { - if seen.contains(name) { - state.ok::<_, Infallible>( - Err(ErrorKind::RecursiveTypeDefinition( - typedef.name.clone(), - name.clone(), - )), - typedef.meta, - ); - } else { - seen.insert(name.clone()); - if let Some(inner_typedef) = defmap.get(name) { - check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state) - } - } - } - } - } - } -} +type TypecheckPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>; impl<'t> Pass for TypeCheck<'t> { + type Data = (); type TError = ErrorKind; - fn module(&mut self, module: &mut Module, mut state: PassState) { + fn module(&mut self, module: &mut Module, mut state: TypecheckPassState) { let mut defmap = HashMap::new(); for typedef in &module.typedefs { let TypeDefinition { name, kind, meta } = &typedef; @@ -137,11 +110,41 @@ impl<'t> Pass for TypeCheck<'t> { } } +fn check_typedefs_for_recursion<'a, 'b>( + defmap: &'b HashMap<&'a String, &'b TypeDefinition>, + typedef: &'b TypeDefinition, + mut seen: HashSet, + state: &mut TypecheckPassState, +) { + match &typedef.kind { + TypeDefinitionKind::Struct(StructType(fields)) => { + for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) { + if let TypeKind::CustomType(name) = field_ty { + if seen.contains(name) { + state.ok::<_, Infallible>( + Err(ErrorKind::RecursiveTypeDefinition( + typedef.name.clone(), + name.clone(), + )), + typedef.meta, + ); + } else { + seen.insert(name.clone()); + if let Some(inner_typedef) = defmap.get(name) { + check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state) + } + } + } + } + } + } +} + impl FunctionDefinition { fn typecheck( &mut self, hints: &TypeRefs, - state: &mut PassState, + state: &mut TypecheckPassState, ) -> Result { for param in &self.parameters { let param_t = state.or_else( @@ -186,7 +189,7 @@ impl FunctionDefinition { impl Block { fn typecheck( &mut self, - state: &mut PassState, + state: &mut TypecheckPassState, typerefs: &TypeRefs, hint_t: Option<&TypeKind>, ) -> Result<(ReturnKind, TypeKind), ErrorKind> { @@ -341,7 +344,7 @@ impl Block { impl Expression { fn typecheck( &mut self, - state: &mut PassState, + state: &mut TypecheckPassState, typerefs: &TypeRefs, hint_t: Option<&TypeKind>, ) -> Result { diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index 44eb951..dd06051 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -26,10 +26,13 @@ pub struct TypeInference<'t> { pub refs: &'t TypeRefs, } +type TypeInferencePassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>; + impl<'t> Pass for TypeInference<'t> { + type Data = (); type TError = ErrorKind; - fn module(&mut self, module: &mut Module, mut state: PassState) { + fn module(&mut self, module: &mut Module, mut state: TypeInferencePassState) { for function in &mut module.functions { let res = function.infer_types(&self.refs, &mut state.inner()); state.ok(res, function.block_meta()); @@ -41,7 +44,7 @@ impl FunctionDefinition { fn infer_types( &mut self, type_refs: &TypeRefs, - state: &mut PassState, + state: &mut TypeInferencePassState, ) -> Result<(), ErrorKind> { let scope_hints = ScopeTypeRefs::from(type_refs); for param in &self.parameters { @@ -74,7 +77,7 @@ impl FunctionDefinition { impl Block { fn infer_types<'s>( &mut self, - state: &mut PassState, + state: &mut TypeInferencePassState, outer_hints: &'s ScopeTypeRefs, ) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> { let mut state = state.inner(); @@ -150,7 +153,7 @@ impl Block { impl Expression { fn infer_types<'s>( &mut self, - state: &mut PassState, + state: &mut TypeInferencePassState, type_refs: &'s ScopeTypeRefs<'s>, ) -> Result, ErrorKind> { match &mut self.0 {