diff --git a/reid/src/mir/pass.rs b/reid/src/mir/pass.rs index 69537a8..e13ad49 100644 --- a/reid/src/mir/pass.rs +++ b/reid/src/mir/pass.rs @@ -5,6 +5,7 @@ use std::collections::HashMap; use std::convert::Infallible; use std::error::Error as STDError; +use crate::codegen::intrinsics::form_intrinsic_binops; use crate::error_raporting::ReidError; use super::*; @@ -53,12 +54,7 @@ impl State { } } - fn or_else + Clone + Copy>( - &mut self, - result: Result, - default: U, - meta: T, - ) -> U { + fn or_else + Clone + Copy>(&mut self, result: Result, default: U, meta: T) -> U { match result { Ok(t) => t, Err(e) => { @@ -71,11 +67,7 @@ impl State { } } - fn ok + Clone + Copy, U>( - &mut self, - result: Result, - meta: T, - ) -> Option { + fn ok + Clone + Copy, U>(&mut self, result: Result, meta: T) -> Option { match result { Ok(v) => Some(v), Err(e) => { @@ -200,10 +192,10 @@ impl PartialEq for ScopeBinopKey { return false; } - let operators_eq = self.params.0.narrow_into(&other.params.0).is_ok() - && self.params.1.narrow_into(&other.params.1).is_ok(); - let swapped_ops_eq = self.params.0.narrow_into(&other.params.1).is_ok() - && self.params.1.narrow_into(&other.params.0).is_ok(); + let operators_eq = + self.params.0.narrow_into(&other.params.0).is_ok() && self.params.1.narrow_into(&other.params.1).is_ok(); + let swapped_ops_eq = + self.params.0.narrow_into(&other.params.1).is_ok() && self.params.1.narrow_into(&other.params.0).is_ok(); if self.operator.is_commutative() { operators_eq || swapped_ops_eq @@ -253,11 +245,7 @@ pub struct PassState<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> } impl<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> PassState<'st, 'sc, Data, TError> { - fn from( - state: &'st mut State, - scope: &'sc mut Scope, - module_id: Option, - ) -> Self { + fn from(state: &'st mut State, scope: &'sc mut Scope, module_id: Option) -> Self { PassState { state, scope, @@ -275,19 +263,11 @@ impl<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> PassState<'st, ' self.state.or_else(result, default, meta) } - pub fn ok + Clone + Copy, U>( - &mut self, - result: Result, - meta: TMeta, - ) -> Option { + pub fn ok + Clone + Copy, U>(&mut self, result: Result, meta: TMeta) -> Option { self.state.ok(result, meta) } - pub fn note_errors + Clone>( - &mut self, - errors: &Vec, - meta: TMeta, - ) { + pub fn note_errors + Clone>(&mut self, errors: &Vec, meta: TMeta) { for error in errors { self.ok::<_, Infallible>(Err(error.clone()), meta.clone().into()); } @@ -311,18 +291,10 @@ pub trait Pass { type Data: Clone + Default; type TError: STDError + Clone; - fn context( - &mut self, - _context: &mut Context, - mut _state: PassState, - ) -> PassResult { + fn context(&mut self, _context: &mut Context, mut _state: PassState) -> PassResult { Ok(()) } - fn module( - &mut self, - _module: &mut Module, - mut _state: PassState, - ) -> PassResult { + fn module(&mut self, _module: &mut Module, mut _state: PassState) -> PassResult { Ok(()) } fn function( @@ -332,25 +304,13 @@ pub trait Pass { ) -> PassResult { Ok(()) } - fn block( - &mut self, - _block: &mut Block, - mut _state: PassState, - ) -> PassResult { + fn block(&mut self, _block: &mut Block, mut _state: PassState) -> PassResult { Ok(()) } - fn stmt( - &mut self, - _stmt: &mut Statement, - mut _state: PassState, - ) -> PassResult { + fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState) -> PassResult { Ok(()) } - fn expr( - &mut self, - _expr: &mut Expression, - mut _state: PassState, - ) -> PassResult { + fn expr(&mut self, _expr: &mut Expression, mut _state: PassState) -> PassResult { Ok(()) } } @@ -360,6 +320,24 @@ impl Context { let mut state = State::new(); let mut scope = Scope::default(); pass.context(self, PassState::from(&mut state, &mut scope, None))?; + + for intrinsic in form_intrinsic_binops() { + scope + .binops + .set( + ScopeBinopKey { + params: (intrinsic.lhs.1.clone(), intrinsic.rhs.1.clone()), + operator: intrinsic.op, + }, + ScopeBinopDef { + hands: (intrinsic.lhs.1.clone(), intrinsic.rhs.1.clone()), + operator: intrinsic.op, + return_ty: intrinsic.return_type.clone(), + }, + ) + .ok(); + } + for (_, module) in &mut self.modules { module.pass(pass, &mut state, &mut scope.inner())?; } @@ -368,12 +346,7 @@ impl Context { } impl Module { - fn pass( - &mut self, - pass: &mut T, - state: &mut State, - scope: &mut Scope, - ) -> PassResult { + fn pass(&mut self, pass: &mut T, state: &mut State, scope: &mut Scope) -> PassResult { for typedef in &self.typedefs { scope .types diff --git a/reid/src/mir/typecheck/typecheck.rs b/reid/src/mir/typecheck/typecheck.rs index 9106acf..ec6488f 100644 --- a/reid/src/mir/typecheck/typecheck.rs +++ b/reid/src/mir/typecheck/typecheck.rs @@ -418,13 +418,13 @@ impl Expression { let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1); let expected_return_ty = ret_ty.resolve_ref(typerefs); - let binops = typerefs.binop_types.filter(&pass::ScopeBinopKey { + let binops = state.scope.binops.filter(&pass::ScopeBinopKey { params: (lhs_type.clone(), rhs_type.clone()), operator: *op, }); if let Some(binop) = binops .iter() - .filter(|f| f.1.return_ty == expected_return_ty) + .filter(|f| f.1.return_ty.narrow_into(&expected_return_ty).is_ok()) .map(|v| (v.1.clone())) .next() { diff --git a/reid/src/mir/typecheck/typeinference.rs b/reid/src/mir/typecheck/typeinference.rs index 9ed66bc..cdc6010 100644 --- a/reid/src/mir/typecheck/typeinference.rs +++ b/reid/src/mir/typecheck/typeinference.rs @@ -119,6 +119,7 @@ impl BinopDefinition { .fn_kind .infer_types(state, &scope_hints, Some(self.return_type.clone()))?; if let Some(mut ret_ty) = ret_ty { + dbg!(&ret_ty, &self.return_type); ret_ty.narrow(&scope_hints.from_type(&self.return_type).unwrap()); } @@ -296,7 +297,27 @@ impl Expression { let mut lhs_ref = lhs.infer_types(state, type_refs)?; let mut rhs_ref = rhs.infer_types(state, type_refs)?; - let binops = type_refs.available_binops(op, &mut lhs_ref, &mut rhs_ref); + let binops = if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_ref.resolve_deep(), rhs_ref.resolve_deep()) { + let mut applying_binops = Vec::new(); + for (_, binop) in state.scope.binops.iter() { + if binop.operator != *op { + continue; + } + if let Some(_) = binop.narrow(&lhs_ty, &rhs_ty) { + applying_binops.push(binop.clone()); + continue; + } + if binop.operator.is_commutative() { + if let Some(_) = binop.narrow(&lhs_ty, &rhs_ty) { + applying_binops.push(binop.clone()); + continue; + } + } + } + applying_binops + } else { + Vec::new() + }; if binops.len() > 0 { let binop = unsafe { binops.get_unchecked(0) }; @@ -307,9 +328,15 @@ impl Expression { widened_rhs = widened_rhs.widen_into(&binop.hands.1); } let binop_res = type_refs.from_binop(*op, &lhs_ref, &rhs_ref); + dbg!(&type_refs.types.type_refs); + dbg!(&type_refs.types.hints); lhs_ref.narrow(&type_refs.from_type(&widened_lhs).unwrap()); rhs_ref.narrow(&type_refs.from_type(&widened_rhs).unwrap()); + dbg!(&lhs_ref, &rhs_ref); *return_ty = binop_res.as_type(); + dbg!(&type_refs.types.hints, &type_refs.types.type_refs); + dbg!(&return_ty); + dbg!(&type_refs.from_type(&return_ty)); Ok(binop_res) } else { Err(ErrorKind::InvalidBinop( diff --git a/reid/src/mir/typecheck/typerefs.rs b/reid/src/mir/typecheck/typerefs.rs index c870399..c4865cd 100644 --- a/reid/src/mir/typecheck/typerefs.rs +++ b/reid/src/mir/typecheck/typerefs.rs @@ -283,14 +283,14 @@ impl<'outer> ScopeTypeRefs<'outer> { match &lhs { TypeKind::Vague(VagueType::TypeRef(idx)) => { let mut lhs_ref = TypeRef(Rc::new(RefCell::new(*idx)), self); - let narrowed = self.narrow_to_type(&mut lhs_ref, &lhs_narrow).unwrap_or(lhs_ref); + self.narrow_to_type(&mut lhs_ref, &lhs_narrow).unwrap_or(lhs_ref); } _ => {} }; match &rhs { TypeKind::Vague(VagueType::TypeRef(idx)) => { let mut rhs_ref = TypeRef(Rc::new(RefCell::new(*idx)), self); - let narrowed = self.narrow_to_type(&mut rhs_ref, &rhs_narrow).unwrap_or(rhs_ref); + self.narrow_to_type(&mut rhs_ref, &rhs_narrow).unwrap_or(rhs_ref); } _ => {} } @@ -312,9 +312,20 @@ impl<'outer> ScopeTypeRefs<'outer> { .clone() .widen(self.types); self.narrow_to_type(&hint1, &ty)?; + let hint1_typeref = self.types.retrieve_typeref(*hint1.0.borrow()).unwrap(); for idx in self.types.type_refs.borrow_mut().iter_mut() { - if *idx == hint2.0 && idx != &hint1.0 { - *idx.borrow_mut() = *hint1.0.borrow(); + match hint1_typeref { + TypeRefKind::Direct(_) => { + if *idx == hint2.0 && idx != &hint1.0 { + *idx.borrow_mut() = *hint1.0.borrow(); + } + } + TypeRefKind::BinOp(_, _, _) => { + // TODO may not be good ? + // if *idx == hint2.0 && idx != &hint1.0 { + // *idx.borrow_mut() = *hint1.0.borrow(); + // } + } } } Some(TypeRef(hint1.0.clone(), self))