Add optional data to PassState Scope

This commit is contained in:
Sofia 2025-07-16 22:46:52 +03:00
parent c19384d77b
commit c41aab33a9
4 changed files with 97 additions and 58 deletions

View File

@ -2,6 +2,7 @@ use std::{
cell::RefCell, cell::RefCell,
collections::HashMap, collections::HashMap,
convert::Infallible, convert::Infallible,
fmt::Error,
fs::{self}, fs::{self},
path::PathBuf, path::PathBuf,
rc::Rc, rc::Rc,
@ -54,9 +55,12 @@ pub fn compile_std() -> super::Module {
/// MIR. /// MIR.
pub struct LinkerPass; pub struct LinkerPass;
type LinkerPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
impl Pass for LinkerPass { impl Pass for LinkerPass {
type Data = ();
type TError = ErrorKind; type TError = ErrorKind;
fn context(&mut self, context: &mut Context, mut state: PassState<Self::TError>) { fn context(&mut self, context: &mut Context, mut state: LinkerPassState) {
let mains = context let mains = context
.modules .modules
.iter() .iter()

View File

@ -106,21 +106,23 @@ impl<T: Clone + std::fmt::Debug> Storage<T> {
} }
#[derive(Clone, Default, Debug)] #[derive(Clone, Default, Debug)]
pub struct Scope { pub struct Scope<Data: Clone + Default> {
pub function_returns: Storage<ScopeFunction>, pub function_returns: Storage<ScopeFunction>,
pub variables: Storage<ScopeVariable>, pub variables: Storage<ScopeVariable>,
pub types: Storage<TypeDefinitionKind>, pub types: Storage<TypeDefinitionKind>,
/// Hard Return type of this scope, if inside a function /// Hard Return type of this scope, if inside a function
pub return_type_hint: Option<TypeKind>, pub return_type_hint: Option<TypeKind>,
pub data: Data,
} }
impl Scope { impl<Data: Clone + Default> Scope<Data> {
pub fn inner(&self) -> Scope { pub fn inner(&self) -> Scope<Data> {
Scope { Scope {
function_returns: self.function_returns.clone(), function_returns: self.function_returns.clone(),
variables: self.variables.clone(), variables: self.variables.clone(),
types: self.types.clone(), types: self.types.clone(),
return_type_hint: self.return_type_hint.clone(), return_type_hint: self.return_type_hint.clone(),
data: self.data.clone(),
} }
} }
@ -144,14 +146,14 @@ pub struct ScopeVariable {
pub mutable: bool, pub mutable: bool,
} }
pub struct PassState<'st, 'sc, TError: STDError + Clone> { pub struct PassState<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> {
state: &'st mut State<TError>, state: &'st mut State<TError>,
pub scope: &'sc mut Scope, pub scope: &'sc mut Scope<Data>,
inner: Vec<Scope>, inner: Vec<Scope<Data>>,
} }
impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> { impl<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> PassState<'st, 'sc, Data, TError> {
fn from(state: &'st mut State<TError>, scope: &'sc mut Scope) -> Self { fn from(state: &'st mut State<TError>, scope: &'sc mut Scope<Data>) -> Self {
PassState { PassState {
state, state,
scope, scope,
@ -186,7 +188,7 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> {
} }
} }
pub fn inner(&mut self) -> PassState<TError> { pub fn inner(&mut self) -> PassState<Data, TError> {
self.inner.push(self.scope.inner()); self.inner.push(self.scope.inner());
let scope = self.inner.last_mut().unwrap(); let scope = self.inner.last_mut().unwrap();
PassState { PassState {
@ -198,19 +200,21 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> {
} }
pub trait Pass { pub trait Pass {
type Data: Clone + Default;
type TError: STDError + Clone; type TError: STDError + Clone;
fn context(&mut self, _context: &mut Context, mut _state: PassState<Self::TError>) {} fn context(&mut self, _context: &mut Context, mut _state: PassState<Self::Data, Self::TError>) {
fn module(&mut self, _module: &mut Module, mut _state: PassState<Self::TError>) {} }
fn module(&mut self, _module: &mut Module, mut _state: PassState<Self::Data, Self::TError>) {}
fn function( fn function(
&mut self, &mut self,
_function: &mut FunctionDefinition, _function: &mut FunctionDefinition,
mut _state: PassState<Self::TError>, mut _state: PassState<Self::Data, Self::TError>,
) { ) {
} }
fn block(&mut self, _block: &mut Block, mut _state: PassState<Self::TError>) {} fn block(&mut self, _block: &mut Block, mut _state: PassState<Self::Data, Self::TError>) {}
fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState<Self::TError>) {} fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState<Self::Data, Self::TError>) {}
fn expr(&mut self, _expr: &mut Expression, mut _state: PassState<Self::TError>) {} fn expr(&mut self, _expr: &mut Expression, mut _state: PassState<Self::Data, Self::TError>) {}
} }
impl Context { impl Context {
@ -226,7 +230,12 @@ impl Context {
} }
impl Module { impl Module {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
for typedef in &self.typedefs { for typedef in &self.typedefs {
let kind = match &typedef.kind { let kind = match &typedef.kind {
TypeDefinitionKind::Struct(fields) => TypeDefinitionKind::Struct(fields.clone()), TypeDefinitionKind::Struct(fields) => TypeDefinitionKind::Struct(fields.clone()),
@ -256,7 +265,12 @@ impl Module {
} }
impl FunctionDefinition { impl FunctionDefinition {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
for param in &self.parameters { for param in &self.parameters {
scope scope
.variables .variables
@ -283,7 +297,12 @@ impl FunctionDefinition {
} }
impl Block { impl Block {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
let mut scope = scope.inner(); let mut scope = scope.inner();
for statement in &mut self.statements { for statement in &mut self.statements {
@ -295,7 +314,12 @@ impl Block {
} }
impl Statement { impl Statement {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
match &mut self.0 { match &mut self.0 {
StmtKind::Let(_, _, expression) => { StmtKind::Let(_, _, expression) => {
expression.pass(pass, state, scope); expression.pass(pass, state, scope);
@ -332,7 +356,12 @@ impl Statement {
} }
impl Expression { impl Expression {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
pass.expr(self, PassState::from(state, scope)); pass.expr(self, PassState::from(state, scope));
} }
} }

View File

@ -64,40 +64,13 @@ pub struct TypeCheck<'t> {
pub refs: &'t TypeRefs, pub refs: &'t TypeRefs,
} }
fn check_typedefs_for_recursion<'a, 'b>( type TypecheckPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
defmap: &'b HashMap<&'a String, &'b TypeDefinition>,
typedef: &'b TypeDefinition,
mut seen: HashSet<String>,
state: &mut PassState<ErrorKind>,
) {
match &typedef.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) {
if let TypeKind::CustomType(name) = field_ty {
if seen.contains(name) {
state.ok::<_, Infallible>(
Err(ErrorKind::RecursiveTypeDefinition(
typedef.name.clone(),
name.clone(),
)),
typedef.meta,
);
} else {
seen.insert(name.clone());
if let Some(inner_typedef) = defmap.get(name) {
check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state)
}
}
}
}
}
}
}
impl<'t> Pass for TypeCheck<'t> { impl<'t> Pass for TypeCheck<'t> {
type Data = ();
type TError = ErrorKind; type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) { fn module(&mut self, module: &mut Module, mut state: TypecheckPassState) {
let mut defmap = HashMap::new(); let mut defmap = HashMap::new();
for typedef in &module.typedefs { for typedef in &module.typedefs {
let TypeDefinition { name, kind, meta } = &typedef; let TypeDefinition { name, kind, meta } = &typedef;
@ -137,11 +110,41 @@ impl<'t> Pass for TypeCheck<'t> {
} }
} }
fn check_typedefs_for_recursion<'a, 'b>(
defmap: &'b HashMap<&'a String, &'b TypeDefinition>,
typedef: &'b TypeDefinition,
mut seen: HashSet<String>,
state: &mut TypecheckPassState,
) {
match &typedef.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) {
if let TypeKind::CustomType(name) = field_ty {
if seen.contains(name) {
state.ok::<_, Infallible>(
Err(ErrorKind::RecursiveTypeDefinition(
typedef.name.clone(),
name.clone(),
)),
typedef.meta,
);
} else {
seen.insert(name.clone());
if let Some(inner_typedef) = defmap.get(name) {
check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state)
}
}
}
}
}
}
}
impl FunctionDefinition { impl FunctionDefinition {
fn typecheck( fn typecheck(
&mut self, &mut self,
hints: &TypeRefs, hints: &TypeRefs,
state: &mut PassState<ErrorKind>, state: &mut TypecheckPassState,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {
for param in &self.parameters { for param in &self.parameters {
let param_t = state.or_else( let param_t = state.or_else(
@ -186,7 +189,7 @@ impl FunctionDefinition {
impl Block { impl Block {
fn typecheck( fn typecheck(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypecheckPassState,
typerefs: &TypeRefs, typerefs: &TypeRefs,
hint_t: Option<&TypeKind>, hint_t: Option<&TypeKind>,
) -> Result<(ReturnKind, TypeKind), ErrorKind> { ) -> Result<(ReturnKind, TypeKind), ErrorKind> {
@ -341,7 +344,7 @@ impl Block {
impl Expression { impl Expression {
fn typecheck( fn typecheck(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypecheckPassState,
typerefs: &TypeRefs, typerefs: &TypeRefs,
hint_t: Option<&TypeKind>, hint_t: Option<&TypeKind>,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {

View File

@ -26,10 +26,13 @@ pub struct TypeInference<'t> {
pub refs: &'t TypeRefs, pub refs: &'t TypeRefs,
} }
type TypeInferencePassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
impl<'t> Pass for TypeInference<'t> { impl<'t> Pass for TypeInference<'t> {
type Data = ();
type TError = ErrorKind; type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) { fn module(&mut self, module: &mut Module, mut state: TypeInferencePassState) {
for function in &mut module.functions { for function in &mut module.functions {
let res = function.infer_types(&self.refs, &mut state.inner()); let res = function.infer_types(&self.refs, &mut state.inner());
state.ok(res, function.block_meta()); state.ok(res, function.block_meta());
@ -41,7 +44,7 @@ impl FunctionDefinition {
fn infer_types( fn infer_types(
&mut self, &mut self,
type_refs: &TypeRefs, type_refs: &TypeRefs,
state: &mut PassState<ErrorKind>, state: &mut TypeInferencePassState,
) -> Result<(), ErrorKind> { ) -> Result<(), ErrorKind> {
let scope_hints = ScopeTypeRefs::from(type_refs); let scope_hints = ScopeTypeRefs::from(type_refs);
for param in &self.parameters { for param in &self.parameters {
@ -74,7 +77,7 @@ impl FunctionDefinition {
impl Block { impl Block {
fn infer_types<'s>( fn infer_types<'s>(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypeInferencePassState,
outer_hints: &'s ScopeTypeRefs, outer_hints: &'s ScopeTypeRefs,
) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> { ) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> {
let mut state = state.inner(); let mut state = state.inner();
@ -150,7 +153,7 @@ impl Block {
impl Expression { impl Expression {
fn infer_types<'s>( fn infer_types<'s>(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypeInferencePassState,
type_refs: &'s ScopeTypeRefs<'s>, type_refs: &'s ScopeTypeRefs<'s>,
) -> Result<TypeRef<'s>, ErrorKind> { ) -> Result<TypeRef<'s>, ErrorKind> {
match &mut self.0 { match &mut self.0 {