Add typechecking for custom binops

This commit is contained in:
Sofia 2025-07-24 15:38:08 +03:00
parent 9f7022b4c0
commit b12e0a18a9
5 changed files with 88 additions and 31 deletions

View File

@ -401,7 +401,6 @@ impl mir::Module {
binops.insert(
ScopeBinopKey {
params: (binop.lhs.1.clone(), binop.rhs.1.clone()),
commutative: mir::pass::CommutativeKind::True,
operator: binop.op,
},
StackBinopDefinition {

View File

@ -289,6 +289,40 @@ impl TypeKind {
},
}
}
pub fn try_collapse_two(
(lhs1, rhs1): (&TypeKind, &TypeKind),
(lhs2, rhs2): (&TypeKind, &TypeKind),
) -> Option<(TypeKind, TypeKind)> {
if lhs1.collapse_into(&lhs2).is_ok() && rhs1.collapse_into(&rhs2).is_ok() {
Some((lhs1.clone(), rhs2.clone()))
} else if lhs1.collapse_into(&rhs2).is_ok() && rhs1.collapse_into(&lhs2).is_ok() {
Some((rhs1.clone(), lhs1.clone()))
} else {
None
}
}
}
impl BinaryOperator {
pub fn is_commutative(&self) -> bool {
match self {
BinaryOperator::Add => true,
BinaryOperator::Minus => false,
BinaryOperator::Mult => true,
BinaryOperator::Div => false,
BinaryOperator::Mod => false,
BinaryOperator::And => true,
BinaryOperator::Cmp(cmp_operator) => match cmp_operator {
CmpOperator::LT => false,
CmpOperator::LE => false,
CmpOperator::GT => false,
CmpOperator::GE => false,
CmpOperator::EQ => true,
CmpOperator::NE => true,
},
}
}
}
#[derive(PartialEq, Eq, PartialOrd, Ord)]

View File

@ -172,7 +172,6 @@ pub struct ScopeVariable {
pub struct ScopeBinopKey {
pub params: (TypeKind, TypeKind),
pub operator: BinaryOperator,
pub commutative: CommutativeKind,
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
@ -187,14 +186,12 @@ impl PartialEq for ScopeBinopKey {
if self.operator != other.operator {
return false;
}
if self.commutative != CommutativeKind::Any && other.commutative != CommutativeKind::Any {
if self.commutative != other.commutative {
return false;
}
if self.operator.is_commutative() != other.operator.is_commutative() {
return false;
}
let operators_eq = self.params == other.params;
let swapped_ops_eq = (self.params.1.clone(), self.params.0.clone()) == other.params;
if self.commutative == CommutativeKind::True || other.commutative == CommutativeKind::True {
if self.operator.is_commutative() {
operators_eq || swapped_ops_eq
} else {
operators_eq
@ -204,7 +201,7 @@ impl PartialEq for ScopeBinopKey {
impl std::hash::Hash for ScopeBinopKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
if self.commutative == CommutativeKind::True {
if self.operator.is_commutative() {
let mut sorted = vec![&self.params.0, &self.params.1];
sorted.sort();
sorted.hash(state);
@ -368,7 +365,6 @@ impl Module {
.set(
ScopeBinopKey {
params: (binop.lhs.1.clone(), binop.rhs.1.clone()),
commutative: CommutativeKind::True,
operator: binop.op,
},
ScopeBinopDef {

View File

@ -508,34 +508,63 @@ impl Expression {
Ok(literal.as_type())
}
ExprKind::BinOp(op, lhs, rhs) => {
// TODO make sure lhs and rhs can actually do this binary
// operation once relevant
let lhs_res = lhs.typecheck(
state,
&typerefs,
hint_t.and_then(|t| t.simple_binop_hint(op)).as_ref(),
);
// First find unfiltered parameters to binop
let lhs_res = lhs.typecheck(state, &typerefs, None);
let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1);
let rhs_res = rhs.typecheck(state, &typerefs, Some(&lhs_type));
let rhs_res = rhs.typecheck(state, &typerefs, None);
let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1);
if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) {
// Try to coerce both sides again with collapsed type
lhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
rhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
}
let operator = state
.scope
.binops
.get(&pass::ScopeBinopKey {
params: (lhs_type.clone(), rhs_type.clone()),
operator: *op,
})
.cloned();
let both_t = lhs_type.collapse_into(&rhs_type)?;
if let Some(operator) = operator {
// Re-typecheck with found operator hints
let (lhs_ty, rhs_ty) = TypeKind::try_collapse_two(
(&lhs_type, &rhs_type),
(&operator.hands.0, &operator.hands.1),
)
.unwrap();
let lhs_res = lhs.typecheck(state, &typerefs, Some(&lhs_ty));
let rhs_res = rhs.typecheck(state, &typerefs, Some(&rhs_ty));
state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1);
state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1);
Ok(operator.return_ty)
} else {
// Re-typecheck with typical everyday binop
let lhs_res = lhs.typecheck(
state,
&typerefs,
hint_t.and_then(|t| t.simple_binop_hint(op)).as_ref(),
);
let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1);
let rhs_res = rhs.typecheck(state, &typerefs, Some(&lhs_type));
let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1);
if *op == BinaryOperator::Minus && !lhs_type.signed() {
if let (Some(lhs_val), Some(rhs_val)) = (lhs.num_value()?, rhs.num_value()?) {
if lhs_val < rhs_val {
return Err(ErrorKind::NegativeUnsignedValue(lhs_type));
let both_t = lhs_type.collapse_into(&rhs_type)?;
if *op == BinaryOperator::Minus && !lhs_type.signed() {
if let (Some(lhs_val), Some(rhs_val)) = (lhs.num_value()?, rhs.num_value()?)
{
if lhs_val < rhs_val {
return Err(ErrorKind::NegativeUnsignedValue(lhs_type));
}
}
}
}
Ok(both_t.simple_binop_type(op))
if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) {
// Try to coerce both sides again with collapsed type
lhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
rhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
}
Ok(both_t.simple_binop_type(op))
}
}
ExprKind::FunctionCall(function_call) => {
let true_function = state
@ -814,7 +843,7 @@ impl Expression {
Ok(*inner)
}
ExprKind::CastTo(expression, type_kind) => {
let expr = expression.typecheck(state, typerefs, None)?;
let expr = expression.typecheck(state, typerefs, Some(&type_kind))?;
expr.resolve_ref(typerefs).cast_into(type_kind)
}
}

View File

@ -63,7 +63,6 @@ impl<'t> Pass for TypeInference<'t> {
for binop in &module.binop_defs {
let binop_key = ScopeBinopKey {
params: (binop.lhs.1.clone(), binop.rhs.1.clone()),
commutative: pass::CommutativeKind::True,
operator: binop.op,
};
if seen_binops.contains(&binop_key) {