From 9f7022b4c0448c2e08342a11310290760c1aa8e2 Mon Sep 17 00:00:00 2001 From: sofia Date: Thu, 24 Jul 2025 15:09:27 +0300 Subject: [PATCH] Add operator to scopebinop, add some typechecking for binops --- reid/src/codegen.rs | 3 ++- reid/src/mir/implement.rs | 24 +++++++++++++++--- reid/src/mir/mod.rs | 4 +-- reid/src/mir/pass.rs | 46 +++++++++++++++++++++-------------- reid/src/mir/typecheck.rs | 2 +- reid/src/mir/typeinference.rs | 3 ++- 6 files changed, 55 insertions(+), 27 deletions(-) diff --git a/reid/src/codegen.rs b/reid/src/codegen.rs index 9eb3ef9..fd861c9 100644 --- a/reid/src/codegen.rs +++ b/reid/src/codegen.rs @@ -400,8 +400,9 @@ impl mir::Module { binops.insert( ScopeBinopKey { - operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), + params: (binop.lhs.1.clone(), binop.rhs.1.clone()), commutative: mir::pass::CommutativeKind::True, + operator: binop.op, }, StackBinopDefinition { parameters: (binop.lhs.clone(), binop.rhs.clone()), diff --git a/reid/src/mir/implement.rs b/reid/src/mir/implement.rs index f8951eb..6235e84 100644 --- a/reid/src/mir/implement.rs +++ b/reid/src/mir/implement.rs @@ -110,13 +110,13 @@ impl TypeKind { } } - pub fn binop_type<'o>( + pub fn binop_type( lhs: &TypeKind, rhs: &TypeKind, binop: &ScopeBinopDef, ) -> Option<(TypeKind, TypeKind, TypeKind)> { - let lhs_ty = lhs.collapse_into(&binop.operators.0); - let rhs_ty = rhs.collapse_into(&binop.operators.1); + let lhs_ty = lhs.collapse_into(&binop.hands.0); + let rhs_ty = rhs.collapse_into(&binop.hands.1); if let (Ok(lhs_ty), Ok(rhs_ty)) = (lhs_ty, rhs_ty) { Some((lhs_ty, rhs_ty, binop.return_ty.clone())) } else { @@ -126,7 +126,7 @@ impl TypeKind { /// Reverse of binop_type, where the given hint is the known required output /// type of the binop, and the output is the hint for the lhs/rhs type. - pub fn binop_hint(&self, op: &BinaryOperator) -> Option { + pub fn simple_binop_hint(&self, op: &BinaryOperator) -> Option { match op { BinaryOperator::Add | BinaryOperator::Minus @@ -138,6 +138,22 @@ impl TypeKind { } } + pub fn binop_hint( + &self, + lhs: &TypeKind, + rhs: &TypeKind, + binop: &ScopeBinopDef, + ) -> Option<(TypeKind, TypeKind)> { + self.collapse_into(&binop.return_ty).ok()?; + let lhs_ty = lhs.collapse_into(&binop.hands.0); + let rhs_ty = rhs.collapse_into(&binop.hands.1); + if let (Ok(lhs_ty), Ok(rhs_ty)) = (lhs_ty, rhs_ty) { + Some((lhs_ty, rhs_ty)) + } else { + None + } + } + pub fn signed(&self) -> bool { match self { TypeKind::Bool => false, diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index 2e3055a..09652a7 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -214,7 +214,7 @@ impl VagueLiteral { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum BinaryOperator { Add, Minus, @@ -226,7 +226,7 @@ pub enum BinaryOperator { } /// Specifically the operators that LLVM likes to take in as "icmp" parameters -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum CmpOperator { LT, LE, diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index 64b29e6..3823d74 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -170,7 +170,8 @@ pub struct ScopeVariable { #[derive(Clone, Debug, Eq)] pub struct ScopeBinopKey { - pub operators: (TypeKind, TypeKind), + pub params: (TypeKind, TypeKind), + pub operator: BinaryOperator, pub commutative: CommutativeKind, } @@ -183,14 +184,16 @@ pub enum CommutativeKind { impl PartialEq for ScopeBinopKey { fn eq(&self, other: &Self) -> bool { + if self.operator != other.operator { + return false; + } if self.commutative != CommutativeKind::Any && other.commutative != CommutativeKind::Any { if self.commutative != other.commutative { return false; } } - let operators_eq = self.operators == other.operators; - let swapped_ops_eq = - (self.operators.1.clone(), self.operators.0.clone()) == other.operators; + let operators_eq = self.params == other.params; + let swapped_ops_eq = (self.params.1.clone(), self.params.0.clone()) == other.params; if self.commutative == CommutativeKind::True || other.commutative == CommutativeKind::True { operators_eq || swapped_ops_eq } else { @@ -202,18 +205,20 @@ impl PartialEq for ScopeBinopKey { impl std::hash::Hash for ScopeBinopKey { fn hash(&self, state: &mut H) { if self.commutative == CommutativeKind::True { - let mut sorted = vec![&self.operators.0, &self.operators.1]; + let mut sorted = vec![&self.params.0, &self.params.1]; sorted.sort(); sorted.hash(state); + self.operator.hash(state); } else { - self.operators.hash(state); + self.params.hash(state); } } } #[derive(Clone, Debug)] pub struct ScopeBinopDef { - pub operators: (TypeKind, TypeKind), + pub hands: (TypeKind, TypeKind), + pub operator: BinaryOperator, pub commutative: bool, pub return_ty: TypeKind, } @@ -358,17 +363,22 @@ impl Module { } for binop in &self.binop_defs { - scope.binops.set( - ScopeBinopKey { - operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), - commutative: CommutativeKind::True, - }, - ScopeBinopDef { - operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), - commutative: true, - return_ty: binop.return_type.clone(), - }, - ); + scope + .binops + .set( + ScopeBinopKey { + params: (binop.lhs.1.clone(), binop.rhs.1.clone()), + commutative: CommutativeKind::True, + operator: binop.op, + }, + ScopeBinopDef { + hands: (binop.lhs.1.clone(), binop.rhs.1.clone()), + operator: binop.op, + commutative: true, + return_ty: binop.return_type.clone(), + }, + ) + .ok(); } for function in &self.functions { diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index a44801f..acd103a 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -513,7 +513,7 @@ impl Expression { let lhs_res = lhs.typecheck( state, &typerefs, - hint_t.and_then(|t| t.binop_hint(op)).as_ref(), + hint_t.and_then(|t| t.simple_binop_hint(op)).as_ref(), ); let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1); let rhs_res = rhs.typecheck(state, &typerefs, Some(&lhs_type)); diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index c42104e..1ece469 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -62,8 +62,9 @@ impl<'t> Pass for TypeInference<'t> { let mut seen_binops = HashSet::new(); for binop in &module.binop_defs { let binop_key = ScopeBinopKey { - operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), + params: (binop.lhs.1.clone(), binop.rhs.1.clone()), commutative: pass::CommutativeKind::True, + operator: binop.op, }; if seen_binops.contains(&binop_key) { state.note_errors(