diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index 50a3257..d136df3 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -1,7 +1,7 @@ //! This module contains relevant code for [`Pass`] and shared code between //! passes. Passes can be performed on Reid MIR to e.g. typecheck the code. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::convert::Infallible; use std::error::Error as STDError; @@ -115,6 +115,7 @@ impl Storage { #[derive(Clone, Default, Debug)] pub struct Scope { + pub binops: Storage, pub function_returns: Storage, pub variables: Storage, pub types: Storage, @@ -128,6 +129,7 @@ impl Scope { Scope { function_returns: self.function_returns.clone(), variables: self.variables.clone(), + binops: self.binops.clone(), types: self.types.clone(), return_type_hint: self.return_type_hint.clone(), data: self.data.clone(), @@ -162,6 +164,56 @@ pub struct ScopeVariable { pub mutable: bool, } +#[derive(Clone, Debug, Eq)] +pub struct ScopeBinopKey { + pub operators: (TypeKind, TypeKind), + pub commutative: CommutativeKind, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum CommutativeKind { + True, + False, + Any, +} + +impl PartialEq for ScopeBinopKey { + fn eq(&self, other: &Self) -> bool { + if self.commutative != CommutativeKind::Any && other.commutative != CommutativeKind::Any { + if self.commutative != other.commutative { + return false; + } + } + let operators_eq = self.operators == other.operators; + let swapped_ops_eq = + (self.operators.1.clone(), self.operators.0.clone()) == other.operators; + if self.commutative == CommutativeKind::True || other.commutative == CommutativeKind::True { + operators_eq || swapped_ops_eq + } else { + operators_eq + } + } +} + +impl std::hash::Hash for ScopeBinopKey { + fn hash(&self, state: &mut H) { + if self.commutative == CommutativeKind::True { + let mut sorted = vec![&self.operators.0, &self.operators.1]; + sorted.sort(); + sorted.hash(state); + } else { + self.operators.hash(state); + } + } +} + +#[derive(Clone, Debug)] +pub struct ScopeBinopDef { + pub operators: (TypeKind, TypeKind), + pub commutative: bool, + pub return_ty: TypeKind, +} + pub struct PassState<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> { state: &'st mut State, pub scope: &'sc mut Scope, @@ -301,6 +353,20 @@ impl Module { .ok(); } + for binop in &self.binop_defs { + scope.binops.set( + ScopeBinopKey { + operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), + commutative: CommutativeKind::True, + }, + ScopeBinopDef { + operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), + commutative: true, + return_ty: binop.return_ty.clone(), + }, + ); + } + for function in &self.functions { scope .function_returns diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs index 842a494..3d70126 100644 --- a/reid/src/mir/typecheck.rs +++ b/reid/src/mir/typecheck.rs @@ -70,6 +70,8 @@ pub enum ErrorKind { NotCastableTo(TypeKind, TypeKind), #[error("Cannot divide by zero")] DivideZero, + #[error("Binary operation between {0} and {1} is already defined!")] + BinaryOpAlreadyDefined(TypeKind, TypeKind), } /// Struct used to implement a type-checking pass that can be performed on the diff --git a/reid/src/mir/typeinference.rs b/reid/src/mir/typeinference.rs index 99f637c..e831c64 100644 --- a/reid/src/mir/typeinference.rs +++ b/reid/src/mir/typeinference.rs @@ -4,12 +4,16 @@ //! must then be passed through TypeCheck with the same [`TypeRefs`] in order to //! place the correct types from the IDs and check that there are no issues. -use std::{collections::HashMap, convert::Infallible, iter}; +use std::{ + collections::{HashMap, HashSet}, + convert::Infallible, + iter, +}; use crate::{mir::TypeKind, util::try_all}; use super::{ - pass::{Pass, PassResult, PassState}, + pass::{self, Pass, PassResult, PassState, ScopeBinopDef, ScopeBinopKey}, typecheck::{ErrorKind, ErrorTypedefKind}, typerefs::{ScopeTypeRefs, TypeRef, TypeRefs}, BinopDefinition, Block, CustomTypeKey, ExprKind, Expression, FunctionDefinition, @@ -55,6 +59,25 @@ impl<'t> Pass for TypeInference<'t> { } } + let mut seen_binops = HashSet::new(); + for binop in &module.binop_defs { + let binop_key = ScopeBinopKey { + operators: (binop.lhs.1.clone(), binop.rhs.1.clone()), + commutative: pass::CommutativeKind::True, + }; + if seen_binops.contains(&binop_key) { + state.note_errors( + &vec![ErrorKind::BinaryOpAlreadyDefined( + binop.lhs.1.clone(), + binop.rhs.1.clone(), + )], + binop.signature(), + ); + } else { + seen_binops.insert(binop_key); + } + } + for binop in &mut module.binop_defs { let res = binop.infer_types(&self.refs, &mut state.inner()); state.ok(res, binop.block_meta());