diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index fd861c9..2a0f1d7 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -401,7 +401,6 @@ impl mir::Module { binops.insert( ScopeBinopKey { params: (binop.lhs.1.clone(), binop.rhs.1.clone()), - commutative: mir::pass::CommutativeKind::True, operator: binop.op, }, StackBinopDefinition { diff --git a/reid/src/mir/implement.rs b/reid/src/mir/implement.rs index 6235e84..bbd16a2 100644 --- a/reid/src/mir/implement.rs +++ b/reid/src/mir/implement.rs @@ -289,6 +289,40 @@ impl TypeKind { }, } } + + pub fn try_collapse_two( + (lhs1, rhs1): (&TypeKind, &TypeKind), + (lhs2, rhs2): (&TypeKind, &TypeKind), + ) -> Option<(TypeKind, TypeKind)> { + if lhs1.collapse_into(&lhs2).is_ok() && rhs1.collapse_into(&rhs2).is_ok() { + Some((lhs1.clone(), rhs2.clone())) + } else if lhs1.collapse_into(&rhs2).is_ok() && rhs1.collapse_into(&lhs2).is_ok() { + Some((rhs1.clone(), lhs1.clone())) + } else { + None + } + } +} + +impl BinaryOperator { + pub fn is_commutative(&self) -> bool { + match self { + BinaryOperator::Add => true, + BinaryOperator::Minus => false, + BinaryOperator::Mult => true, + BinaryOperator::Div => false, + BinaryOperator::Mod => false, + BinaryOperator::And => true, + BinaryOperator::Cmp(cmp_operator) => match cmp_operator { + CmpOperator::LT => false, + CmpOperator::LE => false, + CmpOperator::GT => false, + CmpOperator::GE => false, + CmpOperator::EQ => true, + CmpOperator::NE => true, + }, + } + } } #[derive(PartialEq, Eq, PartialOrd, Ord)] diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index 3823d74..9595166 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -172,7 +172,6 @@ pub struct ScopeVariable { pub struct ScopeBinopKey { pub params: (TypeKind, TypeKind), pub operator: BinaryOperator, - pub commutative: CommutativeKind, } #[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] @@ -187,14 +186,12 @@ impl PartialEq for ScopeBinopKey { if self.operator != other.operator { return false; } - if self.commutative != CommutativeKind::Any && other.commutative != CommutativeKind::Any { - if self.commutative != other.commutative { - return false; - } + if self.operator.is_commutative() != other.operator.is_commutative() { + return false; } let operators_eq = self.params == other.params; let swapped_ops_eq = (self.params.1.clone(), self.params.0.clone()) == other.params; - if self.commutative == CommutativeKind::True || other.commutative == CommutativeKind::True { + if self.operator.is_commutative() { operators_eq || swapped_ops_eq } else { operators_eq @@ -204,7 +201,7 @@ impl PartialEq for ScopeBinopKey { impl std::hash::Hash for ScopeBinopKey { fn hash(&self, state: &mut H) { - if self.commutative == CommutativeKind::True { + if self.operator.is_commutative() { let mut sorted = vec![&self.params.0, &self.params.1]; sorted.sort(); sorted.hash(state); @@ -368,7 +365,6 @@ impl Module { .set( ScopeBinopKey { params: (binop.lhs.1.clone(), binop.rhs.1.clone()), - commutative: CommutativeKind::True, operator: binop.op, }, ScopeBinopDef { diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index acd103a..08d5ddc 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -508,34 +508,63 @@ impl Expression { Ok(literal.as_type()) } ExprKind::BinOp(op, lhs, rhs) => { - // TODO make sure lhs and rhs can actually do this binary - // operation once relevant - let lhs_res = lhs.typecheck( - state, - &typerefs, - hint_t.and_then(|t| t.simple_binop_hint(op)).as_ref(), - ); + // First find unfiltered parameters to binop + let lhs_res = lhs.typecheck(state, &typerefs, None); let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1); - let rhs_res = rhs.typecheck(state, &typerefs, Some(&lhs_type)); + let rhs_res = rhs.typecheck(state, &typerefs, None); let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1); - if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) { - // Try to coerce both sides again with collapsed type - lhs.typecheck(state, &typerefs, Some(&collapsed)).ok(); - rhs.typecheck(state, &typerefs, Some(&collapsed)).ok(); - } + let operator = state + .scope + .binops + .get(&pass::ScopeBinopKey { + params: (lhs_type.clone(), rhs_type.clone()), + operator: *op, + }) + .cloned(); - let both_t = lhs_type.collapse_into(&rhs_type)?; + if let Some(operator) = operator { + // Re-typecheck with found operator hints + let (lhs_ty, rhs_ty) = TypeKind::try_collapse_two( + (&lhs_type, &rhs_type), + (&operator.hands.0, &operator.hands.1), + ) + .unwrap(); + let lhs_res = lhs.typecheck(state, &typerefs, Some(&lhs_ty)); + let rhs_res = rhs.typecheck(state, &typerefs, Some(&rhs_ty)); + state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1); + state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1); + Ok(operator.return_ty) + } else { + // Re-typecheck with typical everyday binop + let lhs_res = lhs.typecheck( + state, + &typerefs, + hint_t.and_then(|t| t.simple_binop_hint(op)).as_ref(), + ); + let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1); + let rhs_res = rhs.typecheck(state, &typerefs, Some(&lhs_type)); + let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1); - if *op == BinaryOperator::Minus && !lhs_type.signed() { - if let (Some(lhs_val), Some(rhs_val)) = (lhs.num_value()?, rhs.num_value()?) { - if lhs_val < rhs_val { - return Err(ErrorKind::NegativeUnsignedValue(lhs_type)); + let both_t = lhs_type.collapse_into(&rhs_type)?; + + if *op == BinaryOperator::Minus && !lhs_type.signed() { + if let (Some(lhs_val), Some(rhs_val)) = (lhs.num_value()?, rhs.num_value()?) + { + if lhs_val < rhs_val { + return Err(ErrorKind::NegativeUnsignedValue(lhs_type)); + } } } - } - Ok(both_t.simple_binop_type(op)) + if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) { + // Try to coerce both sides again with collapsed type + lhs.typecheck(state, &typerefs, Some(&collapsed)).ok(); + rhs.typecheck(state, &typerefs, Some(&collapsed)).ok(); + } + + Ok(both_t.simple_binop_type(op)) + } } ExprKind::FunctionCall(function_call) => { let true_function = state @@ -814,7 +843,7 @@ impl Expression { Ok(*inner) } ExprKind::CastTo(expression, type_kind) => { - let expr = expression.typecheck(state, typerefs, None)?; + let expr = expression.typecheck(state, typerefs, Some(&type_kind))?; expr.resolve_ref(typerefs).cast_into(type_kind) } } diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index 1ece469..395c932 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -63,7 +63,6 @@ impl<'t> Pass for TypeInference<'t> { for binop in &module.binop_defs { let binop_key = ScopeBinopKey { params: (binop.lhs.1.clone(), binop.rhs.1.clone()), - commutative: pass::CommutativeKind::True, operator: binop.op, }; if seen_binops.contains(&binop_key) {