Add optional data to PassState Scope

This commit is contained in:
Sofia 2025-07-16 22:46:52 +03:00
parent c19384d77b
commit c41aab33a9
4 changed files with 97 additions and 58 deletions

View File

@ -2,6 +2,7 @@ use std::{
cell::RefCell,
collections::HashMap,
convert::Infallible,
fmt::Error,
fs::{self},
path::PathBuf,
rc::Rc,
@ -54,9 +55,12 @@ pub fn compile_std() -> super::Module {
/// MIR.
pub struct LinkerPass;
type LinkerPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
impl Pass for LinkerPass {
type Data = ();
type TError = ErrorKind;
fn context(&mut self, context: &mut Context, mut state: PassState<Self::TError>) {
fn context(&mut self, context: &mut Context, mut state: LinkerPassState) {
let mains = context
.modules
.iter()

View File

@ -106,21 +106,23 @@ impl<T: Clone + std::fmt::Debug> Storage<T> {
}
#[derive(Clone, Default, Debug)]
pub struct Scope {
pub struct Scope<Data: Clone + Default> {
pub function_returns: Storage<ScopeFunction>,
pub variables: Storage<ScopeVariable>,
pub types: Storage<TypeDefinitionKind>,
/// Hard Return type of this scope, if inside a function
pub return_type_hint: Option<TypeKind>,
pub data: Data,
}
impl Scope {
pub fn inner(&self) -> Scope {
impl<Data: Clone + Default> Scope<Data> {
pub fn inner(&self) -> Scope<Data> {
Scope {
function_returns: self.function_returns.clone(),
variables: self.variables.clone(),
types: self.types.clone(),
return_type_hint: self.return_type_hint.clone(),
data: self.data.clone(),
}
}
@ -144,14 +146,14 @@ pub struct ScopeVariable {
pub mutable: bool,
}
pub struct PassState<'st, 'sc, TError: STDError + Clone> {
pub struct PassState<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> {
state: &'st mut State<TError>,
pub scope: &'sc mut Scope,
inner: Vec<Scope>,
pub scope: &'sc mut Scope<Data>,
inner: Vec<Scope<Data>>,
}
impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> {
fn from(state: &'st mut State<TError>, scope: &'sc mut Scope) -> Self {
impl<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> PassState<'st, 'sc, Data, TError> {
fn from(state: &'st mut State<TError>, scope: &'sc mut Scope<Data>) -> Self {
PassState {
state,
scope,
@ -186,7 +188,7 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> {
}
}
pub fn inner(&mut self) -> PassState<TError> {
pub fn inner(&mut self) -> PassState<Data, TError> {
self.inner.push(self.scope.inner());
let scope = self.inner.last_mut().unwrap();
PassState {
@ -198,19 +200,21 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> {
}
pub trait Pass {
type Data: Clone + Default;
type TError: STDError + Clone;
fn context(&mut self, _context: &mut Context, mut _state: PassState<Self::TError>) {}
fn module(&mut self, _module: &mut Module, mut _state: PassState<Self::TError>) {}
fn context(&mut self, _context: &mut Context, mut _state: PassState<Self::Data, Self::TError>) {
}
fn module(&mut self, _module: &mut Module, mut _state: PassState<Self::Data, Self::TError>) {}
fn function(
&mut self,
_function: &mut FunctionDefinition,
mut _state: PassState<Self::TError>,
mut _state: PassState<Self::Data, Self::TError>,
) {
}
fn block(&mut self, _block: &mut Block, mut _state: PassState<Self::TError>) {}
fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState<Self::TError>) {}
fn expr(&mut self, _expr: &mut Expression, mut _state: PassState<Self::TError>) {}
fn block(&mut self, _block: &mut Block, mut _state: PassState<Self::Data, Self::TError>) {}
fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState<Self::Data, Self::TError>) {}
fn expr(&mut self, _expr: &mut Expression, mut _state: PassState<Self::Data, Self::TError>) {}
}
impl Context {
@ -226,7 +230,12 @@ impl Context {
}
impl Module {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) {
fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
for typedef in &self.typedefs {
let kind = match &typedef.kind {
TypeDefinitionKind::Struct(fields) => TypeDefinitionKind::Struct(fields.clone()),
@ -256,7 +265,12 @@ impl Module {
}
impl FunctionDefinition {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) {
fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
for param in &self.parameters {
scope
.variables
@ -283,7 +297,12 @@ impl FunctionDefinition {
}
impl Block {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) {
fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
let mut scope = scope.inner();
for statement in &mut self.statements {
@ -295,7 +314,12 @@ impl Block {
}
impl Statement {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) {
fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
match &mut self.0 {
StmtKind::Let(_, _, expression) => {
expression.pass(pass, state, scope);
@ -332,7 +356,12 @@ impl Statement {
}
impl Expression {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) {
fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
pass.expr(self, PassState::from(state, scope));
}
}

View File

@ -64,40 +64,13 @@ pub struct TypeCheck<'t> {
pub refs: &'t TypeRefs,
}
fn check_typedefs_for_recursion<'a, 'b>(
defmap: &'b HashMap<&'a String, &'b TypeDefinition>,
typedef: &'b TypeDefinition,
mut seen: HashSet<String>,
state: &mut PassState<ErrorKind>,
) {
match &typedef.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) {
if let TypeKind::CustomType(name) = field_ty {
if seen.contains(name) {
state.ok::<_, Infallible>(
Err(ErrorKind::RecursiveTypeDefinition(
typedef.name.clone(),
name.clone(),
)),
typedef.meta,
);
} else {
seen.insert(name.clone());
if let Some(inner_typedef) = defmap.get(name) {
check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state)
}
}
}
}
}
}
}
type TypecheckPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
impl<'t> Pass for TypeCheck<'t> {
type Data = ();
type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) {
fn module(&mut self, module: &mut Module, mut state: TypecheckPassState) {
let mut defmap = HashMap::new();
for typedef in &module.typedefs {
let TypeDefinition { name, kind, meta } = &typedef;
@ -137,11 +110,41 @@ impl<'t> Pass for TypeCheck<'t> {
}
}
fn check_typedefs_for_recursion<'a, 'b>(
defmap: &'b HashMap<&'a String, &'b TypeDefinition>,
typedef: &'b TypeDefinition,
mut seen: HashSet<String>,
state: &mut TypecheckPassState,
) {
match &typedef.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) {
if let TypeKind::CustomType(name) = field_ty {
if seen.contains(name) {
state.ok::<_, Infallible>(
Err(ErrorKind::RecursiveTypeDefinition(
typedef.name.clone(),
name.clone(),
)),
typedef.meta,
);
} else {
seen.insert(name.clone());
if let Some(inner_typedef) = defmap.get(name) {
check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state)
}
}
}
}
}
}
}
impl FunctionDefinition {
fn typecheck(
&mut self,
hints: &TypeRefs,
state: &mut PassState<ErrorKind>,
state: &mut TypecheckPassState,
) -> Result<TypeKind, ErrorKind> {
for param in &self.parameters {
let param_t = state.or_else(
@ -186,7 +189,7 @@ impl FunctionDefinition {
impl Block {
fn typecheck(
&mut self,
state: &mut PassState<ErrorKind>,
state: &mut TypecheckPassState,
typerefs: &TypeRefs,
hint_t: Option<&TypeKind>,
) -> Result<(ReturnKind, TypeKind), ErrorKind> {
@ -341,7 +344,7 @@ impl Block {
impl Expression {
fn typecheck(
&mut self,
state: &mut PassState<ErrorKind>,
state: &mut TypecheckPassState,
typerefs: &TypeRefs,
hint_t: Option<&TypeKind>,
) -> Result<TypeKind, ErrorKind> {

View File

@ -26,10 +26,13 @@ pub struct TypeInference<'t> {
pub refs: &'t TypeRefs,
}
type TypeInferencePassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
impl<'t> Pass for TypeInference<'t> {
type Data = ();
type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) {
fn module(&mut self, module: &mut Module, mut state: TypeInferencePassState) {
for function in &mut module.functions {
let res = function.infer_types(&self.refs, &mut state.inner());
state.ok(res, function.block_meta());
@ -41,7 +44,7 @@ impl FunctionDefinition {
fn infer_types(
&mut self,
type_refs: &TypeRefs,
state: &mut PassState<ErrorKind>,
state: &mut TypeInferencePassState,
) -> Result<(), ErrorKind> {
let scope_hints = ScopeTypeRefs::from(type_refs);
for param in &self.parameters {
@ -74,7 +77,7 @@ impl FunctionDefinition {
impl Block {
fn infer_types<'s>(
&mut self,
state: &mut PassState<ErrorKind>,
state: &mut TypeInferencePassState,
outer_hints: &'s ScopeTypeRefs,
) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> {
let mut state = state.inner();
@ -150,7 +153,7 @@ impl Block {
impl Expression {
fn infer_types<'s>(
&mut self,
state: &mut PassState<ErrorKind>,
state: &mut TypeInferencePassState,
type_refs: &'s ScopeTypeRefs<'s>,
) -> Result<TypeRef<'s>, ErrorKind> {
match &mut self.0 {