From ffb3515aaabc5b7b35b74722cefdc05642b99b28 Mon Sep 17 00:00:00 2001 From: sofia Date: Mon, 25 Aug 2025 20:49:38 +0300 Subject: [PATCH] Implement generic type swapping --- libtest.sh | 2 +- reid/src/ast/process.rs | 16 +++++-- reid/src/mir/fmt.rs | 7 ++- reid/src/mir/generics.rs | 95 +++++++++++++++++++++++++++++++++------- reid/src/mir/mod.rs | 36 ++++++++++++++- 5 files changed, 132 insertions(+), 24 deletions(-) diff --git a/libtest.sh b/libtest.sh index a013ade..7408c89 100755 --- a/libtest.sh +++ b/libtest.sh @@ -16,7 +16,7 @@ BINARY="$(echo $1 | cut -d'.' -f1)"".out" echo $1 -cargo run -p reid -- $@ && \ +cargo run -p reid -- run $@ && \ ./$BINARY ; echo "Return value: ""$?" ## Command from: clang -v hello.o -o test diff --git a/reid/src/ast/process.rs b/reid/src/ast/process.rs index bfb6687..d4d1c59 100644 --- a/reid/src/ast/process.rs +++ b/reid/src/ast/process.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, path::PathBuf}; use crate::{ - ast::{self, ReturnType}, + ast::{self, ReturnType, TypeKind}, mir::{ self, CustomTypeKey, FunctionParam, ModuleMap, NamedVariableRef, ReturnKind, SourceModuleId, StmtKind, StructField, StructType, WhileStatement, @@ -44,7 +44,11 @@ impl ast::Module { let def = mir::FunctionDefinition { name: signature.name.clone(), documentation: signature.documentation.clone(), - generics: signature.generics.clone(), + generics: signature + .generics + .iter() + .map(|g| (g.clone(), mir::TypeKind::Vague(mir::VagueType::Unknown))) + .collect(), linkage_name: None, is_pub: false, is_imported: false, @@ -176,11 +180,15 @@ impl ast::FunctionDefinition { ty: p.1 .0.into_mir(module_id), meta: p.2.as_meta(module_id), })); + mir::FunctionDefinition { name: signature.name.clone(), documentation: signature.documentation.clone(), - // TODO generics parsing - generics: Vec::new(), + generics: signature + .generics + .iter() + .map(|g| (g.clone(), mir::TypeKind::Vague(mir::VagueType::Unknown))) + .collect(), linkage_name: None, is_pub: *is_pub, is_imported: false, diff --git a/reid/src/mir/fmt.rs b/reid/src/mir/fmt.rs index f52631c..3699017 100644 --- a/reid/src/mir/fmt.rs +++ b/reid/src/mir/fmt.rs @@ -158,9 +158,14 @@ impl Display for FunctionDefinition { } write!( f, - "{}fn {}({}) -> {:#} ", + "{}fn {}<{}>({}) -> {:#} ", if self.is_pub { "pub " } else { "" }, self.name, + self.generics + .iter() + .map(|(n, t)| format!("{n} = {:?}", t)) + .collect::>() + .join(", "), self.parameters .iter() .map(|FunctionParam { name, ty, .. }| format!("{}: {:#}", name, ty)) diff --git a/reid/src/mir/generics.rs b/reid/src/mir/generics.rs index 5d07846..d1c57af 100644 --- a/reid/src/mir/generics.rs +++ b/reid/src/mir/generics.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, path::PathBuf}; use crate::mir::{ - self, FunctionCall, GlobalKind, GlobalValue, IfExpression, Literal, Module, SourceModuleId, TypeKind, - WhileStatement, + self, generics, CustomTypeKey, FunctionCall, FunctionDefinition, FunctionParam, GlobalKind, GlobalValue, + IfExpression, Literal, Module, SourceModuleId, TypeKind, WhileStatement, }; use super::pass::{Pass, PassResult, PassState}; @@ -43,18 +43,18 @@ impl Pass for GenericsPass { ); } - for module in &context.modules { + for module in &mut context.modules { let mut calls = HashMap::new(); let mut assoc_calls = HashMap::new(); - for function in &module.1.associated_functions { - match &function.1.kind { + for function in &mut module.1.associated_functions { + match &mut function.1.kind { mir::FunctionDefinitionKind::Local(block, _) => block.find_calls(&mut calls, &mut assoc_calls), mir::FunctionDefinitionKind::Extern(_) => {} mir::FunctionDefinitionKind::Intrinsic(_) => {} } } - for function in &module.1.functions { - match &function.kind { + for function in &mut module.1.functions { + match &mut function.kind { mir::FunctionDefinitionKind::Local(block, _) => block.find_calls(&mut calls, &mut assoc_calls), mir::FunctionDefinitionKind::Extern(_) => {} mir::FunctionDefinitionKind::Intrinsic(_) => {} @@ -80,7 +80,6 @@ impl Pass for GenericsPass { } } } - dbg!(&function_map); self.function_map = function_map; @@ -88,24 +87,61 @@ impl Pass for GenericsPass { } fn module(&mut self, module: &mut mir::Module, mut state: PassState) -> PassResult { + for function in module.functions.drain(..).collect::>() { + if let Some(source) = function.source { + let functions = self.function_map.get(&source).unwrap(); + let calls = functions.calls.get(&function.name).unwrap(); + + if function.generics.len() > 0 { + for call in calls { + if let Some(clone) = function.try_clone() { + let generics = function + .generics + .iter() + .zip(call) + .map(|((n, _), t)| (n.clone(), t.clone())) + .collect(); + module.functions.push(FunctionDefinition { + name: name_fmt(function.name.clone(), call.clone()), + return_type: function.return_type.replace_generic(&generics), + parameters: function + .parameters + .iter() + .map(|p| FunctionParam { + ty: p.ty.replace_generic(&generics), + ..p.clone() + }) + .collect(), + generics, + ..clone + }); + } + } + } else { + module.functions.push(function); + } + } else { + module.functions.push(function); + } + } Ok(()) } } impl mir::Block { - fn find_calls(&self, calls: &mut HashMap, assoc_calls: &mut HashMap<(TypeKind, String), Calls>) { - for statement in &self.statements { + fn find_calls(&mut self, calls: &mut HashMap, assoc_calls: &mut HashMap<(TypeKind, String), Calls>) { + for statement in &mut self.statements { statement.find_calls(calls, assoc_calls); } - if let Some((_, Some(e))) = &self.return_expression { + if let Some((_, Some(e))) = &mut self.return_expression { e.find_calls(calls, assoc_calls); } } } impl mir::Statement { - fn find_calls(&self, calls: &mut HashMap, assoc_calls: &mut HashMap<(TypeKind, String), Calls>) { - match &self.0 { + fn find_calls(&mut self, calls: &mut HashMap, assoc_calls: &mut HashMap<(TypeKind, String), Calls>) { + match &mut self.0 { mir::StmtKind::Let(_, _, expression) => expression.find_calls(calls, assoc_calls), mir::StmtKind::Set(expression, expression1) => { expression.find_calls(calls, assoc_calls); @@ -122,8 +158,8 @@ impl mir::Statement { } impl mir::Expression { - fn find_calls(&self, calls: &mut HashMap, assoc_calls: &mut HashMap<(TypeKind, String), Calls>) { - match &self.0 { + fn find_calls(&mut self, calls: &mut HashMap, assoc_calls: &mut HashMap<(TypeKind, String), Calls>) { + match &mut self.0 { mir::ExprKind::Variable(_) => {} mir::ExprKind::Indexed(expression, _, expression1) => { expression.find_calls(calls, assoc_calls); @@ -142,7 +178,7 @@ impl mir::Expression { item.1.find_calls(calls, assoc_calls); } } - mir::ExprKind::Literal(_) => todo!(), + mir::ExprKind::Literal(_) => {} mir::ExprKind::BinOp(_, lhs, rhs, _) => { lhs.find_calls(calls, assoc_calls); rhs.find_calls(calls, assoc_calls); @@ -153,6 +189,7 @@ impl mir::Expression { } else { calls.insert(function_call.name.clone(), vec![function_call.generics.clone()]); } + function_call.name = name_fmt(function_call.name.clone(), function_call.generics.clone()) } mir::ExprKind::AssociatedFunctionCall(ty, function_call) => { if let Some(calls) = assoc_calls.get_mut(&(ty.clone(), function_call.name.clone())) { @@ -163,11 +200,12 @@ impl mir::Expression { vec![function_call.generics.clone()], ); } + function_call.name = name_fmt(function_call.name.clone(), function_call.generics.clone()) } mir::ExprKind::If(IfExpression(cond, then_e, else_e)) => { cond.find_calls(calls, assoc_calls); then_e.find_calls(calls, assoc_calls); - if let Some(else_e) = else_e.as_ref() { + if let Some(else_e) = else_e.as_mut() { else_e.find_calls(calls, assoc_calls); } } @@ -179,3 +217,26 @@ impl mir::Expression { } } } + +fn name_fmt(name: String, generics: Vec) -> String { + format!( + "{}.{}", + name, + generics.iter().map(|t| t.to_string()).collect::>().join(".") + ) +} + +impl TypeKind { + fn replace_generic(&self, generics: &Vec<(String, TypeKind)>) -> TypeKind { + match self { + TypeKind::CustomType(CustomTypeKey(name, _)) => { + if let Some((_, inner)) = generics.iter().find(|(n, _)| n == name) { + inner.clone() + } else { + self.clone() + } + } + _ => self.clone(), + } + } +} diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index 3e4f960..ed37929 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -324,7 +324,7 @@ pub struct FunctionDefinition { pub name: String, pub documentation: Option, pub linkage_name: Option, - pub generics: Vec, + pub generics: Vec<(String, TypeKind)>, /// Whether this function is visible to outside modules pub is_pub: bool, /// Whether this module is from an external module, and has been imported @@ -336,6 +336,40 @@ pub struct FunctionDefinition { pub signature_meta: Metadata, } +impl FunctionDefinition { + pub fn try_clone(&self) -> Option { + match &self.kind { + FunctionDefinitionKind::Local(block, metadata) => Some(FunctionDefinition { + name: self.name.clone(), + documentation: self.documentation.clone(), + linkage_name: self.linkage_name.clone(), + generics: self.generics.clone(), + is_pub: self.is_pub.clone(), + is_imported: self.is_imported.clone(), + return_type: self.return_type.clone(), + parameters: self.parameters.clone(), + kind: FunctionDefinitionKind::Local(block.clone(), metadata.clone()), + source: self.source.clone(), + signature_meta: self.signature_meta.clone(), + }), + FunctionDefinitionKind::Extern(e) => Some(FunctionDefinition { + name: self.name.clone(), + documentation: self.documentation.clone(), + linkage_name: self.linkage_name.clone(), + generics: self.generics.clone(), + is_pub: self.is_pub.clone(), + is_imported: self.is_imported.clone(), + return_type: self.return_type.clone(), + parameters: self.parameters.clone(), + kind: FunctionDefinitionKind::Extern(*e), + source: self.source.clone(), + signature_meta: self.signature_meta.clone(), + }), + FunctionDefinitionKind::Intrinsic(intrinsic_function) => None, + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct FunctionParam { pub name: String,