From 407c681cb69eb26e4e997c7a07da87fa06bc8679 Mon Sep 17 00:00:00 2001 From: sofia Date: Sat, 16 Aug 2025 17:36:37 +0300 Subject: [PATCH] Add ordering for how types are chosen for binops --- reid/src/ast/process.rs | 2 +- reid/src/codegen/intrinsics.rs | 2 +- reid/src/codegen/util.rs | 4 +- reid/src/mir/fmt.rs | 2 +- reid/src/mir/implement.rs | 80 +++++++++++++++++++++---- reid/src/mir/mod.rs | 6 +- reid/src/mir/typecheck/mod.rs | 6 +- reid/src/mir/typecheck/typecheck.rs | 4 +- reid/src/mir/typecheck/typeinference.rs | 8 ++- reid/src/mir/typecheck/typerefs.rs | 27 ++++++++- 10 files changed, 114 insertions(+), 27 deletions(-) diff --git a/reid/src/ast/process.rs b/reid/src/ast/process.rs index cc73355..59178e2 100644 --- a/reid/src/ast/process.rs +++ b/reid/src/ast/process.rs @@ -583,7 +583,7 @@ impl ast::TypeKind { } ast::TypeKind::Ptr(type_kind) => mir::TypeKind::UserPtr(Box::new(type_kind.clone().into_mir(source_mod))), ast::TypeKind::F16 => mir::TypeKind::F16, - ast::TypeKind::F32B => mir::TypeKind::F32B, + ast::TypeKind::F32B => mir::TypeKind::F16B, ast::TypeKind::F32 => mir::TypeKind::F32, ast::TypeKind::F64 => mir::TypeKind::F64, ast::TypeKind::F80 => mir::TypeKind::F80, diff --git a/reid/src/codegen/intrinsics.rs b/reid/src/codegen/intrinsics.rs index a944dff..3a06f20 100644 --- a/reid/src/codegen/intrinsics.rs +++ b/reid/src/codegen/intrinsics.rs @@ -26,7 +26,7 @@ const INTEGERS: [TypeKind; 10] = [ const FLOATS: [TypeKind; 7] = [ TypeKind::F16, TypeKind::F32, - TypeKind::F32B, + TypeKind::F16B, TypeKind::F64, TypeKind::F80, TypeKind::F128, diff --git a/reid/src/codegen/util.rs b/reid/src/codegen/util.rs index 8e84e51..d3200e3 100644 --- a/reid/src/codegen/util.rs +++ b/reid/src/codegen/util.rs @@ -79,7 +79,7 @@ impl TypeKind { TypeKind::U128 => Type::U128, TypeKind::Bool => Type::Bool, TypeKind::F16 => Type::F16, - TypeKind::F32B => Type::F32B, + TypeKind::F16B => Type::F32B, TypeKind::F32 => Type::F32, TypeKind::F64 => Type::F64, TypeKind::F128 => Type::F128, @@ -223,7 +223,7 @@ impl TypeKind { TypeKind::U16 | TypeKind::U32 | TypeKind::U64 | TypeKind::U128 => DwarfEncoding::Unsigned, TypeKind::F16 | TypeKind::F32 - | TypeKind::F32B + | TypeKind::F16B | TypeKind::F64 | TypeKind::F80 | TypeKind::F128 diff --git a/reid/src/mir/fmt.rs b/reid/src/mir/fmt.rs index f58cc74..bd23860 100644 --- a/reid/src/mir/fmt.rs +++ b/reid/src/mir/fmt.rs @@ -471,7 +471,7 @@ impl Display for TypeKind { } TypeKind::Vague(vague_type) => Display::fmt(vague_type, f), TypeKind::F16 => write!(f, "f16"), - TypeKind::F32B => write!(f, "f16b"), + TypeKind::F16B => write!(f, "f16b"), TypeKind::F32 => write!(f, "f32"), TypeKind::F64 => write!(f, "f64"), TypeKind::F128 => write!(f, "f128"), diff --git a/reid/src/mir/implement.rs b/reid/src/mir/implement.rs index edccf07..d44ab80 100644 --- a/reid/src/mir/implement.rs +++ b/reid/src/mir/implement.rs @@ -48,7 +48,7 @@ impl TypeKind { TypeKind::Borrow(..) => false, TypeKind::UserPtr(..) => false, TypeKind::F16 => true, - TypeKind::F32B => true, + TypeKind::F16B => true, TypeKind::F32 => true, TypeKind::F64 => true, TypeKind::F128 => true, @@ -73,21 +73,26 @@ impl TypeKind { TypeKind::Void => 0, TypeKind::Char => 8, TypeKind::Array(type_kind, len) => type_kind.size_of(map) * (*len as u64), - TypeKind::CustomType(key) => match &map.get(key).unwrap().kind { - TypeDefinitionKind::Struct(struct_type) => { - let mut size = 0; - for field in &struct_type.0 { - size += field.1.size_of(map) + TypeKind::CustomType(key) => match &map.get(key) { + Some(def) => match &def.kind { + TypeDefinitionKind::Struct(struct_type) => { + let mut size = 0; + for field in &struct_type.0 { + size += field.1.size_of(map) + } + size } - size - } + }, + // Easy to recognize default number. Used e.g. when sorting + // types by size + None => 404, }, TypeKind::CodegenPtr(_) => 64, TypeKind::Vague(_) => panic!("Tried to sizeof a vague type!"), TypeKind::Borrow(..) => 64, TypeKind::UserPtr(_) => 64, TypeKind::F16 => 16, - TypeKind::F32B => 16, + TypeKind::F16B => 16, TypeKind::F32 => 32, TypeKind::F64 => 64, TypeKind::F128 => 128, @@ -118,7 +123,7 @@ impl TypeKind { TypeKind::Borrow(_, _) => 64, TypeKind::UserPtr(_) => 64, TypeKind::F16 => 16, - TypeKind::F32B => 16, + TypeKind::F16B => 16, TypeKind::F32 => 32, TypeKind::F64 => 64, TypeKind::F128 => 128, @@ -148,7 +153,7 @@ impl TypeKind { | TypeKind::U128 | TypeKind::Char => TypeCategory::Integer, TypeKind::F16 - | TypeKind::F32B + | TypeKind::F16B | TypeKind::F32 | TypeKind::F64 | TypeKind::F128 @@ -197,6 +202,36 @@ impl TypeKind { } } +impl Ord for TypeKind { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + use std::cmp::*; + + let category_ord = self.category().partial_cmp(&other.category()); + + match category_ord { + Some(Ordering::Equal) | None => { + if !self.signed() && other.signed() { + return Ordering::Less; + } + if self.signed() && !other.signed() { + return Ordering::Greater; + } + + let self_size = self.size_of(&HashMap::new()); + let other_size = other.size_of(&HashMap::new()); + if self_size == 32 && other_size != 32 { + return Ordering::Less; + } else if self_size != 32 && other_size == 32 { + return Ordering::Greater; + } + + self_size.cmp(&self_size) + } + Some(ord) => ord, + } + } +} + impl BinaryOperator { pub fn is_commutative(&self) -> bool { match self { @@ -224,7 +259,28 @@ impl BinaryOperator { } } -#[derive(PartialEq, Eq, PartialOrd, Ord)] +const TYPE_CATEGORY_ORDER: [TypeCategory; 5] = [ + TypeCategory::Integer, + TypeCategory::Bool, + TypeCategory::Real, + TypeCategory::Other, + TypeCategory::TypeRef, +]; + +impl PartialOrd for TypeCategory { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::*; + let self_idx = TYPE_CATEGORY_ORDER.iter().enumerate().find(|s| s.1 == self); + let other_idx = TYPE_CATEGORY_ORDER.iter().enumerate().find(|s| s.1 == other); + if let (Some(self_idx), Some(other_idx)) = (self_idx, other_idx) { + Some(self_idx.cmp(&other_idx)) + } else { + None + } + } +} + +#[derive(PartialEq, Eq, Ord)] pub enum TypeCategory { Integer, Real, diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index a363a38..3713bae 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -105,7 +105,7 @@ impl TokenRange { #[derive(Hash, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct CustomTypeKey(pub String, pub SourceModuleId); -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum TypeKind { Bool, I8, @@ -120,7 +120,7 @@ pub enum TypeKind { U128, Void, F16, - F32B, + F16B, F32, F64, F128, @@ -220,7 +220,7 @@ impl Literal { Literal::Vague(VagueLiteral::Number(_)) => TypeKind::Vague(VagueType::Integer), Literal::Vague(VagueLiteral::Decimal(_)) => TypeKind::Vague(VagueType::Decimal), Literal::F16(_) => TypeKind::F16, - Literal::F32B(_) => TypeKind::F32B, + Literal::F32B(_) => TypeKind::F16B, Literal::F32(_) => TypeKind::F32, Literal::F64(_) => TypeKind::F64, Literal::F80(_) => TypeKind::F80, diff --git a/reid/src/mir/typecheck/mod.rs b/reid/src/mir/typecheck/mod.rs index 97a9831..502a913 100644 --- a/reid/src/mir/typecheck/mod.rs +++ b/reid/src/mir/typecheck/mod.rs @@ -141,7 +141,7 @@ impl TypeKind { | TypeKind::U64 | TypeKind::U128 | TypeKind::F16 - | TypeKind::F32B + | TypeKind::F16B | TypeKind::F32 | TypeKind::F64 | TypeKind::F80 @@ -153,7 +153,7 @@ impl TypeKind { TypeKind::Vague(Vague::Unknown) => Ok(TypeKind::Vague(Vague::Decimal)), TypeKind::Vague(Vague::Decimal) => Ok(TypeKind::Vague(Vague::Decimal)), TypeKind::F16 - | TypeKind::F32B + | TypeKind::F16B | TypeKind::F32 | TypeKind::F64 | TypeKind::F80 @@ -207,7 +207,7 @@ impl TypeKind { }, (TypeKind::Vague(Vague::Decimal), other) | (other, TypeKind::Vague(Vague::Decimal)) => match other { TypeKind::F16 - | TypeKind::F32B + | TypeKind::F16B | TypeKind::F32 | TypeKind::F64 | TypeKind::F80 diff --git a/reid/src/mir/typecheck/typecheck.rs b/reid/src/mir/typecheck/typecheck.rs index aa3a8a8..c69ea5c 100644 --- a/reid/src/mir/typecheck/typecheck.rs +++ b/reid/src/mir/typecheck/typecheck.rs @@ -806,14 +806,14 @@ impl Literal { (L::Vague(VagueL::Number(v)), TypeKind::U128) => L::U128(v as u128), (L::Vague(VagueL::Number(v)), TypeKind::F16) => L::F16(v as f32), (L::Vague(VagueL::Number(v)), TypeKind::F32) => L::F32(v as f32), - (L::Vague(VagueL::Number(v)), TypeKind::F32B) => L::F32B(v as f32), + (L::Vague(VagueL::Number(v)), TypeKind::F16B) => L::F32B(v as f32), (L::Vague(VagueL::Number(v)), TypeKind::F64) => L::F64(v as f64), (L::Vague(VagueL::Number(v)), TypeKind::F80) => L::F80(v as f64), (L::Vague(VagueL::Number(v)), TypeKind::F128) => L::F128(v as f64), (L::Vague(VagueL::Number(v)), TypeKind::F128PPC) => L::F128PPC(v as f64), (L::Vague(VagueL::Decimal(v)), TypeKind::F16) => L::F16(v as f32), (L::Vague(VagueL::Decimal(v)), TypeKind::F32) => L::F32(v as f32), - (L::Vague(VagueL::Decimal(v)), TypeKind::F32B) => L::F32B(v as f32), + (L::Vague(VagueL::Decimal(v)), TypeKind::F16B) => L::F32B(v as f32), (L::Vague(VagueL::Decimal(v)), TypeKind::F64) => L::F64(v as f64), (L::Vague(VagueL::Decimal(v)), TypeKind::F80) => L::F80(v as f64), (L::Vague(VagueL::Decimal(v)), TypeKind::F128) => L::F128(v as f64), diff --git a/reid/src/mir/typecheck/typeinference.rs b/reid/src/mir/typecheck/typeinference.rs index 3d07458..dc0a6f2 100644 --- a/reid/src/mir/typecheck/typeinference.rs +++ b/reid/src/mir/typecheck/typeinference.rs @@ -362,6 +362,8 @@ impl Expression { let mut lhs_ref = lhs.infer_types(state, type_refs)?; let mut rhs_ref = rhs.infer_types(state, type_refs)?; + dbg!(&lhs_ref, &rhs_ref); + let binops = if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_ref.resolve_deep(), rhs_ref.resolve_deep()) { let mut applying_binops = Vec::new(); for (_, binop) in state.scope.binops.iter() { @@ -383,6 +385,7 @@ impl Expression { } else { Vec::new() }; + if binops.len() > 0 { let binop = unsafe { binops.get_unchecked(0) }; let mut widened_lhs = binop.hands.0.clone(); @@ -394,10 +397,10 @@ impl Expression { let binop_res = type_refs.from_binop(*op, &lhs_ref, &rhs_ref); // dbg!(&return_ty); // dbg!(&binop_res); - // dbg!(&lhs_ref, &rhs_ref, &binops, &widened_lhs, &widened_rhs); lhs_ref.narrow(&type_refs.from_type(&widened_lhs).unwrap()); rhs_ref.narrow(&type_refs.from_type(&widened_rhs).unwrap()); *return_ty = binop_res.as_type(); + dbg!(&lhs_ref, &rhs_ref); Ok(binop_res) } else { Err(ErrorKind::InvalidBinop( @@ -438,7 +441,10 @@ impl Expression { // Try to narrow condition type to boolean if let Some(mut cond_hints) = cond_hints { + println!("before: {}", type_refs.types); + dbg!(&cond_hints); cond_hints.narrow(&mut type_refs.from_type(&Bool).unwrap()); + println!("after: {}", type_refs.types); } // Infer LHS return type diff --git a/reid/src/mir/typecheck/typerefs.rs b/reid/src/mir/typecheck/typerefs.rs index dc6b6d0..8863ee3 100644 --- a/reid/src/mir/typecheck/typerefs.rs +++ b/reid/src/mir/typecheck/typerefs.rs @@ -308,12 +308,23 @@ impl<'outer> ScopeTypeRefs<'outer> { let lhs_resolved = lhs.resolve_ref(self.types); let rhs_resolved = rhs.resolve_ref(self.types); - let binops = self + let mut binops = self .types .binop_types .iter() .filter(|b| b.1.operator == op && b.1.return_ty == *ty) .collect::>(); + + // Sort binops by lhs and then rhs + binops.sort_by(|a, b| { + let lhs = a.1.hands.0.cmp(&b.1.hands.0); + let rhs = a.1.hands.1.cmp(&b.1.hands.1); + match lhs { + std::cmp::Ordering::Equal => rhs, + _ => lhs, + } + }); + for binop in binops { if let (Ok(lhs_narrow), Ok(rhs_narrow)) = ( lhs_resolved.narrow_into(&binop.1.hands.0), @@ -375,6 +386,9 @@ impl<'outer> ScopeTypeRefs<'outer> { let hint1_typeref = self.types.retrieve_typeref(*hint1.0.borrow()).unwrap(); let hint2_typeref = self.types.retrieve_typeref(*hint2.0.borrow()).unwrap(); + dbg!(&hint1_typeref); + dbg!(&hint2_typeref); + match (&hint1_typeref, &hint2_typeref) { (TypeRefKind::Direct(ret_ty), TypeRefKind::BinOp(op, lhs, rhs)) => { let mut lhs_ref = self.from_type(&lhs).unwrap(); @@ -487,6 +501,17 @@ impl<'outer> ScopeTypeRefs<'outer> { } } } + + // Sort binops by lhs and then rhs + applying_binops.sort_by(|a, b| { + let lhs = a.hands.0.cmp(&b.hands.0); + let rhs = a.hands.1.cmp(&b.hands.1); + match lhs { + std::cmp::Ordering::Equal => rhs, + _ => lhs, + } + }); + applying_binops } }