Add ordering for how types are chosen for binops

This commit is contained in:
Sofia 2025-08-16 17:36:37 +03:00
parent 2dd482c9c2
commit 407c681cb6
10 changed files with 114 additions and 27 deletions

View File

@ -583,7 +583,7 @@ impl ast::TypeKind {
}
ast::TypeKind::Ptr(type_kind) => mir::TypeKind::UserPtr(Box::new(type_kind.clone().into_mir(source_mod))),
ast::TypeKind::F16 => mir::TypeKind::F16,
ast::TypeKind::F32B => mir::TypeKind::F32B,
ast::TypeKind::F32B => mir::TypeKind::F16B,
ast::TypeKind::F32 => mir::TypeKind::F32,
ast::TypeKind::F64 => mir::TypeKind::F64,
ast::TypeKind::F80 => mir::TypeKind::F80,

View File

@ -26,7 +26,7 @@ const INTEGERS: [TypeKind; 10] = [
const FLOATS: [TypeKind; 7] = [
TypeKind::F16,
TypeKind::F32,
TypeKind::F32B,
TypeKind::F16B,
TypeKind::F64,
TypeKind::F80,
TypeKind::F128,

View File

@ -79,7 +79,7 @@ impl TypeKind {
TypeKind::U128 => Type::U128,
TypeKind::Bool => Type::Bool,
TypeKind::F16 => Type::F16,
TypeKind::F32B => Type::F32B,
TypeKind::F16B => Type::F32B,
TypeKind::F32 => Type::F32,
TypeKind::F64 => Type::F64,
TypeKind::F128 => Type::F128,
@ -223,7 +223,7 @@ impl TypeKind {
TypeKind::U16 | TypeKind::U32 | TypeKind::U64 | TypeKind::U128 => DwarfEncoding::Unsigned,
TypeKind::F16
| TypeKind::F32
| TypeKind::F32B
| TypeKind::F16B
| TypeKind::F64
| TypeKind::F80
| TypeKind::F128

View File

@ -471,7 +471,7 @@ impl Display for TypeKind {
}
TypeKind::Vague(vague_type) => Display::fmt(vague_type, f),
TypeKind::F16 => write!(f, "f16"),
TypeKind::F32B => write!(f, "f16b"),
TypeKind::F16B => write!(f, "f16b"),
TypeKind::F32 => write!(f, "f32"),
TypeKind::F64 => write!(f, "f64"),
TypeKind::F128 => write!(f, "f128"),

View File

@ -48,7 +48,7 @@ impl TypeKind {
TypeKind::Borrow(..) => false,
TypeKind::UserPtr(..) => false,
TypeKind::F16 => true,
TypeKind::F32B => true,
TypeKind::F16B => true,
TypeKind::F32 => true,
TypeKind::F64 => true,
TypeKind::F128 => true,
@ -73,21 +73,26 @@ impl TypeKind {
TypeKind::Void => 0,
TypeKind::Char => 8,
TypeKind::Array(type_kind, len) => type_kind.size_of(map) * (*len as u64),
TypeKind::CustomType(key) => match &map.get(key).unwrap().kind {
TypeDefinitionKind::Struct(struct_type) => {
let mut size = 0;
for field in &struct_type.0 {
size += field.1.size_of(map)
TypeKind::CustomType(key) => match &map.get(key) {
Some(def) => match &def.kind {
TypeDefinitionKind::Struct(struct_type) => {
let mut size = 0;
for field in &struct_type.0 {
size += field.1.size_of(map)
}
size
}
size
}
},
// Easy to recognize default number. Used e.g. when sorting
// types by size
None => 404,
},
TypeKind::CodegenPtr(_) => 64,
TypeKind::Vague(_) => panic!("Tried to sizeof a vague type!"),
TypeKind::Borrow(..) => 64,
TypeKind::UserPtr(_) => 64,
TypeKind::F16 => 16,
TypeKind::F32B => 16,
TypeKind::F16B => 16,
TypeKind::F32 => 32,
TypeKind::F64 => 64,
TypeKind::F128 => 128,
@ -118,7 +123,7 @@ impl TypeKind {
TypeKind::Borrow(_, _) => 64,
TypeKind::UserPtr(_) => 64,
TypeKind::F16 => 16,
TypeKind::F32B => 16,
TypeKind::F16B => 16,
TypeKind::F32 => 32,
TypeKind::F64 => 64,
TypeKind::F128 => 128,
@ -148,7 +153,7 @@ impl TypeKind {
| TypeKind::U128
| TypeKind::Char => TypeCategory::Integer,
TypeKind::F16
| TypeKind::F32B
| TypeKind::F16B
| TypeKind::F32
| TypeKind::F64
| TypeKind::F128
@ -197,6 +202,36 @@ impl TypeKind {
}
}
impl Ord for TypeKind {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
use std::cmp::*;
let category_ord = self.category().partial_cmp(&other.category());
match category_ord {
Some(Ordering::Equal) | None => {
if !self.signed() && other.signed() {
return Ordering::Less;
}
if self.signed() && !other.signed() {
return Ordering::Greater;
}
let self_size = self.size_of(&HashMap::new());
let other_size = other.size_of(&HashMap::new());
if self_size == 32 && other_size != 32 {
return Ordering::Less;
} else if self_size != 32 && other_size == 32 {
return Ordering::Greater;
}
self_size.cmp(&self_size)
}
Some(ord) => ord,
}
}
}
impl BinaryOperator {
pub fn is_commutative(&self) -> bool {
match self {
@ -224,7 +259,28 @@ impl BinaryOperator {
}
}
#[derive(PartialEq, Eq, PartialOrd, Ord)]
const TYPE_CATEGORY_ORDER: [TypeCategory; 5] = [
TypeCategory::Integer,
TypeCategory::Bool,
TypeCategory::Real,
TypeCategory::Other,
TypeCategory::TypeRef,
];
impl PartialOrd for TypeCategory {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
use std::cmp::*;
let self_idx = TYPE_CATEGORY_ORDER.iter().enumerate().find(|s| s.1 == self);
let other_idx = TYPE_CATEGORY_ORDER.iter().enumerate().find(|s| s.1 == other);
if let (Some(self_idx), Some(other_idx)) = (self_idx, other_idx) {
Some(self_idx.cmp(&other_idx))
} else {
None
}
}
}
#[derive(PartialEq, Eq, Ord)]
pub enum TypeCategory {
Integer,
Real,

View File

@ -105,7 +105,7 @@ impl TokenRange {
#[derive(Hash, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct CustomTypeKey(pub String, pub SourceModuleId);
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum TypeKind {
Bool,
I8,
@ -120,7 +120,7 @@ pub enum TypeKind {
U128,
Void,
F16,
F32B,
F16B,
F32,
F64,
F128,
@ -220,7 +220,7 @@ impl Literal {
Literal::Vague(VagueLiteral::Number(_)) => TypeKind::Vague(VagueType::Integer),
Literal::Vague(VagueLiteral::Decimal(_)) => TypeKind::Vague(VagueType::Decimal),
Literal::F16(_) => TypeKind::F16,
Literal::F32B(_) => TypeKind::F32B,
Literal::F32B(_) => TypeKind::F16B,
Literal::F32(_) => TypeKind::F32,
Literal::F64(_) => TypeKind::F64,
Literal::F80(_) => TypeKind::F80,

View File

@ -141,7 +141,7 @@ impl TypeKind {
| TypeKind::U64
| TypeKind::U128
| TypeKind::F16
| TypeKind::F32B
| TypeKind::F16B
| TypeKind::F32
| TypeKind::F64
| TypeKind::F80
@ -153,7 +153,7 @@ impl TypeKind {
TypeKind::Vague(Vague::Unknown) => Ok(TypeKind::Vague(Vague::Decimal)),
TypeKind::Vague(Vague::Decimal) => Ok(TypeKind::Vague(Vague::Decimal)),
TypeKind::F16
| TypeKind::F32B
| TypeKind::F16B
| TypeKind::F32
| TypeKind::F64
| TypeKind::F80
@ -207,7 +207,7 @@ impl TypeKind {
},
(TypeKind::Vague(Vague::Decimal), other) | (other, TypeKind::Vague(Vague::Decimal)) => match other {
TypeKind::F16
| TypeKind::F32B
| TypeKind::F16B
| TypeKind::F32
| TypeKind::F64
| TypeKind::F80

View File

@ -806,14 +806,14 @@ impl Literal {
(L::Vague(VagueL::Number(v)), TypeKind::U128) => L::U128(v as u128),
(L::Vague(VagueL::Number(v)), TypeKind::F16) => L::F16(v as f32),
(L::Vague(VagueL::Number(v)), TypeKind::F32) => L::F32(v as f32),
(L::Vague(VagueL::Number(v)), TypeKind::F32B) => L::F32B(v as f32),
(L::Vague(VagueL::Number(v)), TypeKind::F16B) => L::F32B(v as f32),
(L::Vague(VagueL::Number(v)), TypeKind::F64) => L::F64(v as f64),
(L::Vague(VagueL::Number(v)), TypeKind::F80) => L::F80(v as f64),
(L::Vague(VagueL::Number(v)), TypeKind::F128) => L::F128(v as f64),
(L::Vague(VagueL::Number(v)), TypeKind::F128PPC) => L::F128PPC(v as f64),
(L::Vague(VagueL::Decimal(v)), TypeKind::F16) => L::F16(v as f32),
(L::Vague(VagueL::Decimal(v)), TypeKind::F32) => L::F32(v as f32),
(L::Vague(VagueL::Decimal(v)), TypeKind::F32B) => L::F32B(v as f32),
(L::Vague(VagueL::Decimal(v)), TypeKind::F16B) => L::F32B(v as f32),
(L::Vague(VagueL::Decimal(v)), TypeKind::F64) => L::F64(v as f64),
(L::Vague(VagueL::Decimal(v)), TypeKind::F80) => L::F80(v as f64),
(L::Vague(VagueL::Decimal(v)), TypeKind::F128) => L::F128(v as f64),

View File

@ -362,6 +362,8 @@ impl Expression {
let mut lhs_ref = lhs.infer_types(state, type_refs)?;
let mut rhs_ref = rhs.infer_types(state, type_refs)?;
dbg!(&lhs_ref, &rhs_ref);
let binops = if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_ref.resolve_deep(), rhs_ref.resolve_deep()) {
let mut applying_binops = Vec::new();
for (_, binop) in state.scope.binops.iter() {
@ -383,6 +385,7 @@ impl Expression {
} else {
Vec::new()
};
if binops.len() > 0 {
let binop = unsafe { binops.get_unchecked(0) };
let mut widened_lhs = binop.hands.0.clone();
@ -394,10 +397,10 @@ impl Expression {
let binop_res = type_refs.from_binop(*op, &lhs_ref, &rhs_ref);
// dbg!(&return_ty);
// dbg!(&binop_res);
// dbg!(&lhs_ref, &rhs_ref, &binops, &widened_lhs, &widened_rhs);
lhs_ref.narrow(&type_refs.from_type(&widened_lhs).unwrap());
rhs_ref.narrow(&type_refs.from_type(&widened_rhs).unwrap());
*return_ty = binop_res.as_type();
dbg!(&lhs_ref, &rhs_ref);
Ok(binop_res)
} else {
Err(ErrorKind::InvalidBinop(
@ -438,7 +441,10 @@ impl Expression {
// Try to narrow condition type to boolean
if let Some(mut cond_hints) = cond_hints {
println!("before: {}", type_refs.types);
dbg!(&cond_hints);
cond_hints.narrow(&mut type_refs.from_type(&Bool).unwrap());
println!("after: {}", type_refs.types);
}
// Infer LHS return type

View File

@ -308,12 +308,23 @@ impl<'outer> ScopeTypeRefs<'outer> {
let lhs_resolved = lhs.resolve_ref(self.types);
let rhs_resolved = rhs.resolve_ref(self.types);
let binops = self
let mut binops = self
.types
.binop_types
.iter()
.filter(|b| b.1.operator == op && b.1.return_ty == *ty)
.collect::<Vec<_>>();
// Sort binops by lhs and then rhs
binops.sort_by(|a, b| {
let lhs = a.1.hands.0.cmp(&b.1.hands.0);
let rhs = a.1.hands.1.cmp(&b.1.hands.1);
match lhs {
std::cmp::Ordering::Equal => rhs,
_ => lhs,
}
});
for binop in binops {
if let (Ok(lhs_narrow), Ok(rhs_narrow)) = (
lhs_resolved.narrow_into(&binop.1.hands.0),
@ -375,6 +386,9 @@ impl<'outer> ScopeTypeRefs<'outer> {
let hint1_typeref = self.types.retrieve_typeref(*hint1.0.borrow()).unwrap();
let hint2_typeref = self.types.retrieve_typeref(*hint2.0.borrow()).unwrap();
dbg!(&hint1_typeref);
dbg!(&hint2_typeref);
match (&hint1_typeref, &hint2_typeref) {
(TypeRefKind::Direct(ret_ty), TypeRefKind::BinOp(op, lhs, rhs)) => {
let mut lhs_ref = self.from_type(&lhs).unwrap();
@ -487,6 +501,17 @@ impl<'outer> ScopeTypeRefs<'outer> {
}
}
}
// Sort binops by lhs and then rhs
applying_binops.sort_by(|a, b| {
let lhs = a.hands.0.cmp(&b.hands.0);
let rhs = a.hands.1.cmp(&b.hands.1);
match lhs {
std::cmp::Ordering::Equal => rhs,
_ => lhs,
}
});
applying_binops
}
}