Add codegen to custom binops

This commit is contained in:
Sofia 2025-07-24 14:43:56 +03:00
parent 4a7f27205c
commit aec7d55e9b
10 changed files with 157 additions and 52 deletions

View File

@ -5,8 +5,8 @@ impl binop (lhs: u16) + (rhs: u32) -> u32 {
} }
fn main() -> u32 { fn main() -> u32 {
let value = 6; let value = 6 as u16;
let other = 15; let other = 15 as u32;
return value * other + 7 * -value; return value + other;
} }

View File

@ -110,7 +110,7 @@ impl ast::Module {
lhs: (lhs.0.clone(), lhs.1 .0.into_mir(module_id)), lhs: (lhs.0.clone(), lhs.1 .0.into_mir(module_id)),
op: op.mir(), op: op.mir(),
rhs: (rhs.0.clone(), rhs.1 .0.into_mir(module_id)), rhs: (rhs.0.clone(), rhs.1 .0.into_mir(module_id)),
return_ty: return_ty.0.into_mir(module_id), return_type: return_ty.0.into_mir(module_id),
fn_kind: mir::FunctionDefinitionKind::Local( fn_kind: mir::FunctionDefinitionKind::Local(
block.into_mir(module_id), block.into_mir(module_id),
block.2.as_meta(module_id), block.2.as_meta(module_id),

View File

@ -78,11 +78,11 @@ pub struct Scope<'ctx, 'scope> {
tokens: &'ctx Vec<FullToken>, tokens: &'ctx Vec<FullToken>,
module: &'ctx Module<'ctx>, module: &'ctx Module<'ctx>,
pub(super) module_id: SourceModuleId, pub(super) module_id: SourceModuleId,
function: &'ctx StackFunction<'ctx>, function: &'ctx Function<'ctx>,
pub(super) block: Block<'ctx>, pub(super) block: Block<'ctx>,
pub(super) types: &'scope HashMap<TypeValue, TypeDefinition>, pub(super) types: &'scope HashMap<TypeValue, TypeDefinition>,
pub(super) type_values: &'scope HashMap<CustomTypeKey, TypeValue>, pub(super) type_values: &'scope HashMap<CustomTypeKey, TypeValue>,
functions: &'scope HashMap<String, StackFunction<'ctx>>, functions: &'scope HashMap<String, Function<'ctx>>,
stack_values: HashMap<String, StackValue>, stack_values: HashMap<String, StackValue>,
debug: Option<Debug<'ctx>>, debug: Option<Debug<'ctx>>,
allocator: Rc<RefCell<Allocator>>, allocator: Rc<RefCell<Allocator>>,
@ -131,10 +131,6 @@ pub struct Debug<'ctx> {
types: &'ctx HashMap<TypeKind, DebugTypeValue>, types: &'ctx HashMap<TypeKind, DebugTypeValue>,
} }
pub struct StackFunction<'ctx> {
ir: Function<'ctx>,
}
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct StackValue(StackValueKind, TypeKind); pub struct StackValue(StackValueKind, TypeKind);
@ -291,21 +287,6 @@ impl mir::Module {
insert_debug!(&TypeKind::CustomType(type_key.clone())); insert_debug!(&TypeKind::CustomType(type_key.clone()));
} }
// let mut binops = HashMap::new();
// for binop in &self.binop_defs {
// binops.insert(
// ScopeBinopKey {
// operators: (binop.lhs.1.clone(), binop.rhs.1.clone()),
// commutative: mir::pass::CommutativeKind::True,
// },
// StackBinopDefinition {
// parameters: (binop.lhs.clone(), binop.rhs.clone()),
// return_ty: binop.return_ty.clone(),
// ir: todo!(),
// },
// );
// }
let mut functions = HashMap::new(); let mut functions = HashMap::new();
for function in &self.functions { for function in &self.functions {
@ -348,12 +329,91 @@ impl mir::Module {
), ),
}; };
functions.insert(function.name.clone(), StackFunction { ir: func }); functions.insert(function.name.clone(), func);
}
let mut binops = HashMap::new();
for binop in &self.binop_defs {
let binop_fn_name = format!(
"binop.{}.{:?}.{}.{}",
binop.lhs.1, binop.op, binop.rhs.1, binop.return_type
);
let ir_function = module.function(
&binop_fn_name,
binop.return_type.get_type(&type_values),
vec![
binop.lhs.1.get_type(&type_values),
binop.rhs.1.get_type(&type_values),
],
FunctionFlags::default(),
);
let mut entry = ir_function.block("entry");
let allocator = Allocator::from(
&binop.fn_kind,
&vec![binop.lhs.clone(), binop.rhs.clone()],
&mut AllocatorScope {
block: &mut entry,
module_id: self.module_id,
type_values: &type_values,
},
);
let mut scope = Scope {
context,
modules: &modules,
tokens,
module: &module,
module_id: self.module_id,
function: &ir_function,
block: entry,
functions: &functions,
types: &types,
type_values: &type_values,
stack_values: HashMap::new(),
debug: Some(Debug {
info: &debug,
scope: compile_unit,
types: &debug_types,
}),
allocator: Rc::new(RefCell::new(allocator)),
};
binop
.fn_kind
.codegen(
binop_fn_name.clone(),
false,
&mut scope,
&vec![binop.lhs.clone(), binop.rhs.clone()],
&binop.return_type,
&ir_function,
match &binop.fn_kind {
FunctionDefinitionKind::Local(_, meta) => {
meta.into_debug(tokens, compile_unit)
}
FunctionDefinitionKind::Extern(_) => None,
FunctionDefinitionKind::Intrinsic(_) => None,
},
)
.unwrap();
binops.insert(
ScopeBinopKey {
operators: (binop.lhs.1.clone(), binop.rhs.1.clone()),
commutative: mir::pass::CommutativeKind::True,
},
StackBinopDefinition {
parameters: (binop.lhs.clone(), binop.rhs.clone()),
return_ty: binop.return_type.clone(),
ir: ir_function,
},
);
} }
for mir_function in &self.functions { for mir_function in &self.functions {
let function = functions.get(&mir_function.name).unwrap(); let function = functions.get(&mir_function.name).unwrap();
let mut entry = function.ir.block("entry"); let mut entry = function.block("entry");
let allocator = Allocator::from( let allocator = Allocator::from(
&mir_function.kind, &mir_function.kind,
@ -393,7 +453,7 @@ impl mir::Module {
&mut scope, &mut scope,
&mir_function.parameters, &mir_function.parameters,
&mir_function.return_type, &mir_function.return_type,
&function.ir, &function,
match &mir_function.kind { match &mir_function.kind {
FunctionDefinitionKind::Local(..) => { FunctionDefinitionKind::Local(..) => {
mir_function.signature().into_debug(tokens, compile_unit) mir_function.signature().into_debug(tokens, compile_unit)
@ -670,9 +730,9 @@ impl mir::Statement {
mir::StmtKind::While(WhileStatement { mir::StmtKind::While(WhileStatement {
condition, block, .. condition, block, ..
}) => { }) => {
let condition_block = scope.function.ir.block("while.cond"); let condition_block = scope.function.block("while.cond");
let condition_true_block = scope.function.ir.block("while.body"); let condition_true_block = scope.function.block("while.body");
let condition_failed_block = scope.function.ir.block("while.end"); let condition_failed_block = scope.function.block("while.end");
scope scope
.block .block
@ -881,7 +941,7 @@ impl mir::Expression {
.block .block
.build_named( .build_named(
call.name.clone(), call.name.clone(),
Instr::FunctionCall(callee.ir.value(), param_instrs), Instr::FunctionCall(callee.value(), param_instrs),
) )
.unwrap(); .unwrap();
@ -929,7 +989,7 @@ impl mir::Expression {
} }
mir::ExprKind::If(if_expression) => if_expression.codegen(scope, state)?, mir::ExprKind::If(if_expression) => if_expression.codegen(scope, state)?,
mir::ExprKind::Block(block) => { mir::ExprKind::Block(block) => {
let inner = scope.function.ir.block("inner"); let inner = scope.function.block("inner");
scope.block.terminate(Term::Br(inner.value())).unwrap(); scope.block.terminate(Term::Br(inner.value())).unwrap();
let mut inner_scope = scope.with_block(inner); let mut inner_scope = scope.with_block(inner);
@ -938,7 +998,7 @@ impl mir::Expression {
} else { } else {
None None
}; };
let outer = scope.function.ir.block("outer"); let outer = scope.function.block("outer");
inner_scope.block.terminate(Term::Br(outer.value())).ok(); inner_scope.block.terminate(Term::Br(outer.value())).ok();
scope.swap_block(outer); scope.swap_block(outer);
ret ret
@ -1341,9 +1401,9 @@ impl mir::IfExpression {
let condition = self.0.codegen(scope, state)?.unwrap(); let condition = self.0.codegen(scope, state)?.unwrap();
// Create blocks // Create blocks
let mut then_b = scope.function.ir.block("then"); let mut then_b = scope.function.block("then");
let mut else_b = scope.function.ir.block("else"); let mut else_b = scope.function.block("else");
let after_b = scope.function.ir.block("after"); let after_b = scope.function.block("after");
if let Some(debug) = &scope.debug { if let Some(debug) = &scope.debug {
let before_location = self.0 .1.into_debug(scope.tokens, debug.scope).unwrap(); let before_location = self.0 .1.into_debug(scope.tokens, debug.scope).unwrap();

View File

@ -64,7 +64,7 @@ impl Display for BinopDefinition {
write!( write!(
f, f,
"impl binop ({}: {:#}) {} ({}: {:#}) -> {:#} ", "impl binop ({}: {:#}) {} ({}: {:#}) -> {:#} ",
self.lhs.0, self.lhs.1, self.op, self.rhs.0, self.rhs.1, self.return_ty self.lhs.0, self.lhs.1, self.op, self.rhs.0, self.rhs.1, self.return_type
)?; )?;
Display::fmt(&self.fn_kind, f) Display::fmt(&self.fn_kind, f)
} }

View File

@ -1,4 +1,4 @@
use super::{typecheck::ErrorKind, typerefs::TypeRefs, VagueType as Vague, *}; use super::{pass::ScopeBinopDef, typecheck::ErrorKind, typerefs::TypeRefs, VagueType as Vague, *};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ReturnTypeOther { pub enum ReturnTypeOther {
@ -95,7 +95,7 @@ impl TypeKind {
/// Return the type that is the result of a binary operator between two /// Return the type that is the result of a binary operator between two
/// values of this type /// values of this type
pub fn binop_type(&self, op: &BinaryOperator) -> TypeKind { pub fn simple_binop_type(&self, op: &BinaryOperator) -> TypeKind {
// TODO make some type of mechanism that allows to binop two values of // TODO make some type of mechanism that allows to binop two values of
// differing types.. // differing types..
// TODO Return None for arrays later // TODO Return None for arrays later
@ -110,6 +110,20 @@ impl TypeKind {
} }
} }
pub fn binop_type<'o>(
lhs: &TypeKind,
rhs: &TypeKind,
binop: &ScopeBinopDef,
) -> Option<(TypeKind, TypeKind, TypeKind)> {
let lhs_ty = lhs.collapse_into(&binop.operators.0);
let rhs_ty = rhs.collapse_into(&binop.operators.1);
if let (Ok(lhs_ty), Ok(rhs_ty)) = (lhs_ty, rhs_ty) {
Some((lhs_ty, rhs_ty, binop.return_ty.clone()))
} else {
None
}
}
/// Reverse of binop_type, where the given hint is the known required output /// Reverse of binop_type, where the given hint is the known required output
/// type of the binop, and the output is the hint for the lhs/rhs type. /// type of the binop, and the output is the hint for the lhs/rhs type.
pub fn binop_hint(&self, op: &BinaryOperator) -> Option<TypeKind> { pub fn binop_hint(&self, op: &BinaryOperator) -> Option<TypeKind> {

View File

@ -370,7 +370,7 @@ pub struct BinopDefinition {
pub lhs: (String, TypeKind), pub lhs: (String, TypeKind),
pub op: BinaryOperator, pub op: BinaryOperator,
pub rhs: (String, TypeKind), pub rhs: (String, TypeKind),
pub return_ty: TypeKind, pub return_type: TypeKind,
pub fn_kind: FunctionDefinitionKind, pub fn_kind: FunctionDefinitionKind,
pub meta: Metadata, pub meta: Metadata,
} }

View File

@ -111,6 +111,10 @@ impl<Key: std::hash::Hash + Eq, T: Clone + std::fmt::Debug> Storage<Key, T> {
pub fn get(&self, key: &Key) -> Option<&T> { pub fn get(&self, key: &Key) -> Option<&T> {
self.0.get(key) self.0.get(key)
} }
pub fn iter(&self) -> impl Iterator<Item = (&Key, &T)> {
self.0.iter()
}
} }
#[derive(Clone, Default, Debug)] #[derive(Clone, Default, Debug)]
@ -362,7 +366,7 @@ impl Module {
ScopeBinopDef { ScopeBinopDef {
operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), operators: (binop.lhs.1.clone(), binop.rhs.1.clone()),
commutative: true, commutative: true,
return_ty: binop.return_ty.clone(), return_ty: binop.return_type.clone(),
}, },
); );
} }

View File

@ -203,9 +203,9 @@ impl BinopDefinition {
state.ok(res, self.signature()); state.ok(res, self.signature());
} }
let return_type = self.return_ty.clone().assert_known(typerefs, state)?; let return_type = self.return_type.clone().assert_known(typerefs, state)?;
state.scope.return_type_hint = Some(self.return_ty.clone()); state.scope.return_type_hint = Some(self.return_type.clone());
let inferred = let inferred =
self.fn_kind self.fn_kind
.typecheck(&typerefs, &mut state.inner(), Some(return_type.clone())); .typecheck(&typerefs, &mut state.inner(), Some(return_type.clone()));
@ -535,7 +535,7 @@ impl Expression {
} }
} }
Ok(both_t.binop_type(op)) Ok(both_t.simple_binop_type(op))
} }
ExprKind::FunctionCall(function_call) => { ExprKind::FunctionCall(function_call) => {
let true_function = state let true_function = state

View File

@ -124,11 +124,11 @@ impl BinopDefinition {
self.signature(), self.signature(),
); );
let ret_ty = self let ret_ty =
.fn_kind self.fn_kind
.infer_types(state, &scope_hints, Some(self.return_ty.clone()))?; .infer_types(state, &scope_hints, Some(self.return_type.clone()))?;
if let Some(mut ret_ty) = ret_ty { if let Some(mut ret_ty) = ret_ty {
ret_ty.narrow(&scope_hints.from_type(&self.return_ty).unwrap()); ret_ty.narrow(&scope_hints.from_type(&self.return_type).unwrap());
} }
Ok(()) Ok(())
@ -312,7 +312,7 @@ impl Expression {
let mut lhs_ref = lhs.infer_types(state, type_refs)?; let mut lhs_ref = lhs.infer_types(state, type_refs)?;
let mut rhs_ref = rhs.infer_types(state, type_refs)?; let mut rhs_ref = rhs.infer_types(state, type_refs)?;
type_refs type_refs
.binop(op, &mut lhs_ref, &mut rhs_ref) .binop(op, &mut lhs_ref, &mut rhs_ref, &state.scope.binops)
.ok_or(ErrorKind::TypesIncompatible( .ok_or(ErrorKind::TypesIncompatible(
lhs_ref.resolve_deep().unwrap(), lhs_ref.resolve_deep().unwrap(),
rhs_ref.resolve_deep().unwrap(), rhs_ref.resolve_deep().unwrap(),

View File

@ -6,7 +6,11 @@ use std::{
use crate::mir::VagueType; use crate::mir::VagueType;
use super::{typecheck::ErrorKind, BinaryOperator, TypeKind}; use super::{
pass::{ScopeBinopDef, ScopeBinopKey, Storage},
typecheck::ErrorKind,
BinaryOperator, TypeKind,
};
#[derive(Clone)] #[derive(Clone)]
pub struct TypeRef<'scope>( pub struct TypeRef<'scope>(
@ -227,8 +231,31 @@ impl<'outer> ScopeTypeRefs<'outer> {
op: &BinaryOperator, op: &BinaryOperator,
lhs: &mut TypeRef<'outer>, lhs: &mut TypeRef<'outer>,
rhs: &mut TypeRef<'outer>, rhs: &mut TypeRef<'outer>,
binops: &Storage<ScopeBinopKey, ScopeBinopDef>,
) -> Option<TypeRef<'outer>> { ) -> Option<TypeRef<'outer>> {
for (_, binop) in binops.iter() {
if let Some(ret) = try_binop(lhs, rhs, binop) {
return Some(ret);
}
if binop.commutative {
if let Some(ret) = try_binop(rhs, lhs, binop) {
return Some(ret);
}
}
}
let ty = lhs.narrow(rhs)?; let ty = lhs.narrow(rhs)?;
self.from_type(&ty.as_type().binop_type(op)) self.from_type(&ty.as_type().simple_binop_type(op))
} }
} }
fn try_binop<'o>(
lhs: &mut TypeRef<'o>,
rhs: &mut TypeRef<'o>,
binop: &ScopeBinopDef,
) -> Option<TypeRef<'o>> {
let (lhs_ty, rhs_ty, ret_ty) =
TypeKind::binop_type(&lhs.resolve_deep()?, &rhs.resolve_deep()?, binop)?;
lhs.narrow(&lhs.1.from_type(&lhs_ty).unwrap()).unwrap();
rhs.narrow(&rhs.1.from_type(&rhs_ty).unwrap()).unwrap();
lhs.1.from_type(&ret_ty)
}