Add codegen for assoc functions

This commit is contained in:
Sofia 2025-07-27 18:32:17 +03:00
parent 4d7c17a854
commit 24f11a77d2
2 changed files with 178 additions and 4 deletions

View File

@ -15,8 +15,11 @@ use scope::*;
use crate::{
mir::{
self, implement::TypeCategory, pass::BinopKey, CustomTypeKey, FunctionDefinitionKind, NamedVariableRef,
SourceModuleId, StructField, StructType, TypeDefinition, TypeDefinitionKind, TypeKind, WhileStatement,
self,
implement::TypeCategory,
pass::{AssociatedFunctionKey, BinopKey},
CustomTypeKey, FunctionDefinitionKind, NamedVariableRef, SourceModuleId, StructField, StructType,
TypeDefinition, TypeDefinitionKind, TypeKind, WhileStatement,
},
util::try_all,
};
@ -212,6 +215,46 @@ impl mir::Module {
}
}
let mut associated_functions = HashMap::new();
for (ty, function) in &self.associated_functions {
let param_types: Vec<Type> = function
.parameters
.iter()
.map(|(_, p)| p.get_type(&type_values))
.collect();
let is_main = self.is_main && function.name == "main";
let func = match &function.kind {
mir::FunctionDefinitionKind::Local(_, _) => Some(module.function(
&format!("{}::{}", ty, function.name),
function.return_type.get_type(&type_values),
param_types,
FunctionFlags {
is_pub: function.is_pub || is_main,
is_main,
is_imported: function.is_imported,
..FunctionFlags::default()
},
)),
mir::FunctionDefinitionKind::Extern(imported) => Some(module.function(
&function.name,
function.return_type.get_type(&type_values),
param_types,
FunctionFlags {
is_extern: true,
is_imported: *imported,
..FunctionFlags::default()
},
)),
mir::FunctionDefinitionKind::Intrinsic(_) => None,
};
if let Some(func) = func {
associated_functions.insert(AssociatedFunctionKey(ty.clone(), function.name.clone()), func);
}
}
let mut binops = HashMap::new();
for binop in &self.binop_defs {
let binop_fn_name = format!(
@ -258,6 +301,7 @@ impl mir::Module {
module_id: self.module_id,
function: &ir_function,
block: entry,
assoc_functions: &associated_functions,
functions: &functions,
types: &types,
type_values: &type_values,
@ -331,6 +375,62 @@ impl mir::Module {
module_id: self.module_id,
function,
block: entry,
assoc_functions: &associated_functions,
functions: &functions,
types: &types,
type_values: &type_values,
stack_values: HashMap::new(),
debug: Some(Debug {
info: &debug,
scope: compile_unit,
types: &debug_types,
}),
binops: &binops,
allocator: Rc::new(RefCell::new(allocator)),
};
mir_function
.kind
.codegen(
mir_function.name.clone(),
mir_function.is_pub,
&mut scope,
&mir_function.parameters,
&mir_function.return_type,
&function,
match &mir_function.kind {
FunctionDefinitionKind::Local(..) => mir_function.signature().into_debug(tokens, compile_unit),
FunctionDefinitionKind::Extern(_) => None,
FunctionDefinitionKind::Intrinsic(_) => None,
},
)
.unwrap();
}
for (ty, mir_function) in &self.associated_functions {
let function = associated_functions
.get(&AssociatedFunctionKey(ty.clone(), mir_function.name.clone()))
.unwrap();
let mut entry = function.block("entry");
let allocator = Allocator::from(
&mir_function.kind,
&mir_function.parameters,
&mut AllocatorScope {
block: &mut entry,
type_values: &type_values,
},
);
let mut scope = Scope {
context,
modules: &modules,
tokens,
module: &module,
module_id: self.module_id,
function,
block: entry,
assoc_functions: &associated_functions,
functions: &functions,
types: &types,
type_values: &type_values,
@ -1215,7 +1315,76 @@ impl mir::Expression {
}
}
}
mir::ExprKind::AssociatedFunctionCall(..) => todo!(),
mir::ExprKind::AssociatedFunctionCall(ty, call) => {
let ret_type_kind = call.return_type.known().expect("function return type unknown");
let call_name = format!("{}::{}", ty, call.name);
let ret_type = ret_type_kind.get_type(scope.type_values);
let params = try_all(
call.parameters
.iter()
.map(|e| e.codegen(scope, state))
.collect::<Vec<_>>(),
)
.map_err(|e| e.first().cloned().unwrap())?
.into_iter()
.map(|v| v.unwrap())
.collect::<Vec<_>>();
let param_instrs = params.iter().map(|e| e.instr()).collect();
let callee = scope
.assoc_functions
.get(&AssociatedFunctionKey(ty.clone(), call.name.clone()))
.expect("function not found!");
let val = scope
.block
.build_named(&call_name, Instr::FunctionCall(callee.value(), param_instrs))
.unwrap();
if let Some(debug) = &scope.debug {
let location = call.meta.into_debug(scope.tokens, debug.scope).unwrap();
let location_val = debug.info.location(&debug.scope, location);
val.with_location(&mut scope.block, location_val);
}
let ptr = if ret_type_kind != TypeKind::Void {
let ptr = scope
.block
.build_named(&call_name, Instr::Alloca(ret_type.clone()))
.unwrap();
scope
.block
.build_named(format!("{}.store", call_name), Instr::Store(ptr, val))
.unwrap();
Some(ptr)
} else {
None
};
if let Some(ptr) = ptr {
if state.should_load {
Some(StackValue(
StackValueKind::Immutable(
scope
.block
.build_named(call.name.clone(), Instr::Load(ptr, ret_type))
.unwrap(),
),
ret_type_kind,
))
} else {
Some(StackValue(
StackValueKind::Immutable(ptr),
TypeKind::CodegenPtr(Box::new(ret_type_kind)),
))
}
} else {
None
}
}
};
if let Some(value) = &value {
value.instr().maybe_location(&mut scope.block, location);

View File

@ -8,7 +8,10 @@ use reid_lib::{
use crate::{
lexer::FullToken,
mir::{pass::BinopKey, CustomTypeKey, SourceModuleId, TypeDefinition, TypeKind},
mir::{
pass::{AssociatedFunctionKey, BinopKey},
CustomTypeKey, SourceModuleId, TypeDefinition, TypeKind,
},
};
use super::{allocator::Allocator, ErrorKind, IntrinsicFunction, ModuleCodegen};
@ -23,6 +26,7 @@ pub struct Scope<'ctx, 'scope> {
pub(super) block: Block<'ctx>,
pub(super) types: &'scope HashMap<TypeValue, TypeDefinition>,
pub(super) type_values: &'scope HashMap<CustomTypeKey, TypeValue>,
pub(super) assoc_functions: &'scope HashMap<AssociatedFunctionKey, Function<'ctx>>,
pub(super) functions: &'scope HashMap<String, Function<'ctx>>,
pub(super) binops: &'scope HashMap<BinopKey, StackBinopDefinition<'ctx>>,
pub(super) stack_values: HashMap<String, StackValue>,
@ -40,6 +44,7 @@ impl<'ctx, 'a> Scope<'ctx, 'a> {
context: self.context,
module: self.module,
module_id: self.module_id,
assoc_functions: self.assoc_functions,
functions: self.functions,
types: self.types,
type_values: self.type_values,