Add ordering for how types are chosen for binops
This commit is contained in:
		
							parent
							
								
									2dd482c9c2
								
							
						
					
					
						commit
						407c681cb6
					
				| @ -583,7 +583,7 @@ impl ast::TypeKind { | ||||
|             } | ||||
|             ast::TypeKind::Ptr(type_kind) => mir::TypeKind::UserPtr(Box::new(type_kind.clone().into_mir(source_mod))), | ||||
|             ast::TypeKind::F16 => mir::TypeKind::F16, | ||||
|             ast::TypeKind::F32B => mir::TypeKind::F32B, | ||||
|             ast::TypeKind::F32B => mir::TypeKind::F16B, | ||||
|             ast::TypeKind::F32 => mir::TypeKind::F32, | ||||
|             ast::TypeKind::F64 => mir::TypeKind::F64, | ||||
|             ast::TypeKind::F80 => mir::TypeKind::F80, | ||||
|  | ||||
| @ -26,7 +26,7 @@ const INTEGERS: [TypeKind; 10] = [ | ||||
| const FLOATS: [TypeKind; 7] = [ | ||||
|     TypeKind::F16, | ||||
|     TypeKind::F32, | ||||
|     TypeKind::F32B, | ||||
|     TypeKind::F16B, | ||||
|     TypeKind::F64, | ||||
|     TypeKind::F80, | ||||
|     TypeKind::F128, | ||||
|  | ||||
| @ -79,7 +79,7 @@ impl TypeKind { | ||||
|             TypeKind::U128 => Type::U128, | ||||
|             TypeKind::Bool => Type::Bool, | ||||
|             TypeKind::F16 => Type::F16, | ||||
|             TypeKind::F32B => Type::F32B, | ||||
|             TypeKind::F16B => Type::F32B, | ||||
|             TypeKind::F32 => Type::F32, | ||||
|             TypeKind::F64 => Type::F64, | ||||
|             TypeKind::F128 => Type::F128, | ||||
| @ -223,7 +223,7 @@ impl TypeKind { | ||||
|                     TypeKind::U16 | TypeKind::U32 | TypeKind::U64 | TypeKind::U128 => DwarfEncoding::Unsigned, | ||||
|                     TypeKind::F16 | ||||
|                     | TypeKind::F32 | ||||
|                     | TypeKind::F32B | ||||
|                     | TypeKind::F16B | ||||
|                     | TypeKind::F64 | ||||
|                     | TypeKind::F80 | ||||
|                     | TypeKind::F128 | ||||
|  | ||||
| @ -471,7 +471,7 @@ impl Display for TypeKind { | ||||
|             } | ||||
|             TypeKind::Vague(vague_type) => Display::fmt(vague_type, f), | ||||
|             TypeKind::F16 => write!(f, "f16"), | ||||
|             TypeKind::F32B => write!(f, "f16b"), | ||||
|             TypeKind::F16B => write!(f, "f16b"), | ||||
|             TypeKind::F32 => write!(f, "f32"), | ||||
|             TypeKind::F64 => write!(f, "f64"), | ||||
|             TypeKind::F128 => write!(f, "f128"), | ||||
|  | ||||
| @ -48,7 +48,7 @@ impl TypeKind { | ||||
|             TypeKind::Borrow(..) => false, | ||||
|             TypeKind::UserPtr(..) => false, | ||||
|             TypeKind::F16 => true, | ||||
|             TypeKind::F32B => true, | ||||
|             TypeKind::F16B => true, | ||||
|             TypeKind::F32 => true, | ||||
|             TypeKind::F64 => true, | ||||
|             TypeKind::F128 => true, | ||||
| @ -73,21 +73,26 @@ impl TypeKind { | ||||
|             TypeKind::Void => 0, | ||||
|             TypeKind::Char => 8, | ||||
|             TypeKind::Array(type_kind, len) => type_kind.size_of(map) * (*len as u64), | ||||
|             TypeKind::CustomType(key) => match &map.get(key).unwrap().kind { | ||||
|                 TypeDefinitionKind::Struct(struct_type) => { | ||||
|                     let mut size = 0; | ||||
|                     for field in &struct_type.0 { | ||||
|                         size += field.1.size_of(map) | ||||
|             TypeKind::CustomType(key) => match &map.get(key) { | ||||
|                 Some(def) => match &def.kind { | ||||
|                     TypeDefinitionKind::Struct(struct_type) => { | ||||
|                         let mut size = 0; | ||||
|                         for field in &struct_type.0 { | ||||
|                             size += field.1.size_of(map) | ||||
|                         } | ||||
|                         size | ||||
|                     } | ||||
|                     size | ||||
|                 } | ||||
|                 }, | ||||
|                 // Easy to recognize default number. Used e.g. when sorting
 | ||||
|                 // types by size
 | ||||
|                 None => 404, | ||||
|             }, | ||||
|             TypeKind::CodegenPtr(_) => 64, | ||||
|             TypeKind::Vague(_) => panic!("Tried to sizeof a vague type!"), | ||||
|             TypeKind::Borrow(..) => 64, | ||||
|             TypeKind::UserPtr(_) => 64, | ||||
|             TypeKind::F16 => 16, | ||||
|             TypeKind::F32B => 16, | ||||
|             TypeKind::F16B => 16, | ||||
|             TypeKind::F32 => 32, | ||||
|             TypeKind::F64 => 64, | ||||
|             TypeKind::F128 => 128, | ||||
| @ -118,7 +123,7 @@ impl TypeKind { | ||||
|             TypeKind::Borrow(_, _) => 64, | ||||
|             TypeKind::UserPtr(_) => 64, | ||||
|             TypeKind::F16 => 16, | ||||
|             TypeKind::F32B => 16, | ||||
|             TypeKind::F16B => 16, | ||||
|             TypeKind::F32 => 32, | ||||
|             TypeKind::F64 => 64, | ||||
|             TypeKind::F128 => 128, | ||||
| @ -148,7 +153,7 @@ impl TypeKind { | ||||
|             | TypeKind::U128 | ||||
|             | TypeKind::Char => TypeCategory::Integer, | ||||
|             TypeKind::F16 | ||||
|             | TypeKind::F32B | ||||
|             | TypeKind::F16B | ||||
|             | TypeKind::F32 | ||||
|             | TypeKind::F64 | ||||
|             | TypeKind::F128 | ||||
| @ -197,6 +202,36 @@ impl TypeKind { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Ord for TypeKind { | ||||
|     fn cmp(&self, other: &Self) -> std::cmp::Ordering { | ||||
|         use std::cmp::*; | ||||
| 
 | ||||
|         let category_ord = self.category().partial_cmp(&other.category()); | ||||
| 
 | ||||
|         match category_ord { | ||||
|             Some(Ordering::Equal) | None => { | ||||
|                 if !self.signed() && other.signed() { | ||||
|                     return Ordering::Less; | ||||
|                 } | ||||
|                 if self.signed() && !other.signed() { | ||||
|                     return Ordering::Greater; | ||||
|                 } | ||||
| 
 | ||||
|                 let self_size = self.size_of(&HashMap::new()); | ||||
|                 let other_size = other.size_of(&HashMap::new()); | ||||
|                 if self_size == 32 && other_size != 32 { | ||||
|                     return Ordering::Less; | ||||
|                 } else if self_size != 32 && other_size == 32 { | ||||
|                     return Ordering::Greater; | ||||
|                 } | ||||
| 
 | ||||
|                 self_size.cmp(&self_size) | ||||
|             } | ||||
|             Some(ord) => ord, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl BinaryOperator { | ||||
|     pub fn is_commutative(&self) -> bool { | ||||
|         match self { | ||||
| @ -224,7 +259,28 @@ impl BinaryOperator { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(PartialEq, Eq, PartialOrd, Ord)] | ||||
| const TYPE_CATEGORY_ORDER: [TypeCategory; 5] = [ | ||||
|     TypeCategory::Integer, | ||||
|     TypeCategory::Bool, | ||||
|     TypeCategory::Real, | ||||
|     TypeCategory::Other, | ||||
|     TypeCategory::TypeRef, | ||||
| ]; | ||||
| 
 | ||||
| impl PartialOrd for TypeCategory { | ||||
|     fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { | ||||
|         use std::cmp::*; | ||||
|         let self_idx = TYPE_CATEGORY_ORDER.iter().enumerate().find(|s| s.1 == self); | ||||
|         let other_idx = TYPE_CATEGORY_ORDER.iter().enumerate().find(|s| s.1 == other); | ||||
|         if let (Some(self_idx), Some(other_idx)) = (self_idx, other_idx) { | ||||
|             Some(self_idx.cmp(&other_idx)) | ||||
|         } else { | ||||
|             None | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(PartialEq, Eq, Ord)] | ||||
| pub enum TypeCategory { | ||||
|     Integer, | ||||
|     Real, | ||||
|  | ||||
| @ -105,7 +105,7 @@ impl TokenRange { | ||||
| #[derive(Hash, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] | ||||
| pub struct CustomTypeKey(pub String, pub SourceModuleId); | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] | ||||
| #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] | ||||
| pub enum TypeKind { | ||||
|     Bool, | ||||
|     I8, | ||||
| @ -120,7 +120,7 @@ pub enum TypeKind { | ||||
|     U128, | ||||
|     Void, | ||||
|     F16, | ||||
|     F32B, | ||||
|     F16B, | ||||
|     F32, | ||||
|     F64, | ||||
|     F128, | ||||
| @ -220,7 +220,7 @@ impl Literal { | ||||
|             Literal::Vague(VagueLiteral::Number(_)) => TypeKind::Vague(VagueType::Integer), | ||||
|             Literal::Vague(VagueLiteral::Decimal(_)) => TypeKind::Vague(VagueType::Decimal), | ||||
|             Literal::F16(_) => TypeKind::F16, | ||||
|             Literal::F32B(_) => TypeKind::F32B, | ||||
|             Literal::F32B(_) => TypeKind::F16B, | ||||
|             Literal::F32(_) => TypeKind::F32, | ||||
|             Literal::F64(_) => TypeKind::F64, | ||||
|             Literal::F80(_) => TypeKind::F80, | ||||
|  | ||||
| @ -141,7 +141,7 @@ impl TypeKind { | ||||
|                 | TypeKind::U64 | ||||
|                 | TypeKind::U128 | ||||
|                 | TypeKind::F16 | ||||
|                 | TypeKind::F32B | ||||
|                 | TypeKind::F16B | ||||
|                 | TypeKind::F32 | ||||
|                 | TypeKind::F64 | ||||
|                 | TypeKind::F80 | ||||
| @ -153,7 +153,7 @@ impl TypeKind { | ||||
|                 TypeKind::Vague(Vague::Unknown) => Ok(TypeKind::Vague(Vague::Decimal)), | ||||
|                 TypeKind::Vague(Vague::Decimal) => Ok(TypeKind::Vague(Vague::Decimal)), | ||||
|                 TypeKind::F16 | ||||
|                 | TypeKind::F32B | ||||
|                 | TypeKind::F16B | ||||
|                 | TypeKind::F32 | ||||
|                 | TypeKind::F64 | ||||
|                 | TypeKind::F80 | ||||
| @ -207,7 +207,7 @@ impl TypeKind { | ||||
|             }, | ||||
|             (TypeKind::Vague(Vague::Decimal), other) | (other, TypeKind::Vague(Vague::Decimal)) => match other { | ||||
|                 TypeKind::F16 | ||||
|                 | TypeKind::F32B | ||||
|                 | TypeKind::F16B | ||||
|                 | TypeKind::F32 | ||||
|                 | TypeKind::F64 | ||||
|                 | TypeKind::F80 | ||||
|  | ||||
| @ -806,14 +806,14 @@ impl Literal { | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::U128) => L::U128(v as u128), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F16) => L::F16(v as f32), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F32) => L::F32(v as f32), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F32B) => L::F32B(v as f32), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F16B) => L::F32B(v as f32), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F64) => L::F64(v as f64), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F80) => L::F80(v as f64), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F128) => L::F128(v as f64), | ||||
|                 (L::Vague(VagueL::Number(v)), TypeKind::F128PPC) => L::F128PPC(v as f64), | ||||
|                 (L::Vague(VagueL::Decimal(v)), TypeKind::F16) => L::F16(v as f32), | ||||
|                 (L::Vague(VagueL::Decimal(v)), TypeKind::F32) => L::F32(v as f32), | ||||
|                 (L::Vague(VagueL::Decimal(v)), TypeKind::F32B) => L::F32B(v as f32), | ||||
|                 (L::Vague(VagueL::Decimal(v)), TypeKind::F16B) => L::F32B(v as f32), | ||||
|                 (L::Vague(VagueL::Decimal(v)), TypeKind::F64) => L::F64(v as f64), | ||||
|                 (L::Vague(VagueL::Decimal(v)), TypeKind::F80) => L::F80(v as f64), | ||||
|                 (L::Vague(VagueL::Decimal(v)), TypeKind::F128) => L::F128(v as f64), | ||||
|  | ||||
| @ -362,6 +362,8 @@ impl Expression { | ||||
|                 let mut lhs_ref = lhs.infer_types(state, type_refs)?; | ||||
|                 let mut rhs_ref = rhs.infer_types(state, type_refs)?; | ||||
| 
 | ||||
|                 dbg!(&lhs_ref, &rhs_ref); | ||||
| 
 | ||||
|                 let binops = if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_ref.resolve_deep(), rhs_ref.resolve_deep()) { | ||||
|                     let mut applying_binops = Vec::new(); | ||||
|                     for (_, binop) in state.scope.binops.iter() { | ||||
| @ -383,6 +385,7 @@ impl Expression { | ||||
|                 } else { | ||||
|                     Vec::new() | ||||
|                 }; | ||||
| 
 | ||||
|                 if binops.len() > 0 { | ||||
|                     let binop = unsafe { binops.get_unchecked(0) }; | ||||
|                     let mut widened_lhs = binop.hands.0.clone(); | ||||
| @ -394,10 +397,10 @@ impl Expression { | ||||
|                     let binop_res = type_refs.from_binop(*op, &lhs_ref, &rhs_ref); | ||||
|                     // dbg!(&return_ty);
 | ||||
|                     // dbg!(&binop_res);
 | ||||
|                     // dbg!(&lhs_ref, &rhs_ref, &binops, &widened_lhs, &widened_rhs);
 | ||||
|                     lhs_ref.narrow(&type_refs.from_type(&widened_lhs).unwrap()); | ||||
|                     rhs_ref.narrow(&type_refs.from_type(&widened_rhs).unwrap()); | ||||
|                     *return_ty = binop_res.as_type(); | ||||
|                     dbg!(&lhs_ref, &rhs_ref); | ||||
|                     Ok(binop_res) | ||||
|                 } else { | ||||
|                     Err(ErrorKind::InvalidBinop( | ||||
| @ -438,7 +441,10 @@ impl Expression { | ||||
| 
 | ||||
|                 // Try to narrow condition type to boolean
 | ||||
|                 if let Some(mut cond_hints) = cond_hints { | ||||
|                     println!("before: {}", type_refs.types); | ||||
|                     dbg!(&cond_hints); | ||||
|                     cond_hints.narrow(&mut type_refs.from_type(&Bool).unwrap()); | ||||
|                     println!("after: {}", type_refs.types); | ||||
|                 } | ||||
| 
 | ||||
|                 // Infer LHS return type
 | ||||
|  | ||||
| @ -308,12 +308,23 @@ impl<'outer> ScopeTypeRefs<'outer> { | ||||
|                     let lhs_resolved = lhs.resolve_ref(self.types); | ||||
|                     let rhs_resolved = rhs.resolve_ref(self.types); | ||||
| 
 | ||||
|                     let binops = self | ||||
|                     let mut binops = self | ||||
|                         .types | ||||
|                         .binop_types | ||||
|                         .iter() | ||||
|                         .filter(|b| b.1.operator == op && b.1.return_ty == *ty) | ||||
|                         .collect::<Vec<_>>(); | ||||
| 
 | ||||
|                     // Sort binops by lhs and then rhs
 | ||||
|                     binops.sort_by(|a, b| { | ||||
|                         let lhs = a.1.hands.0.cmp(&b.1.hands.0); | ||||
|                         let rhs = a.1.hands.1.cmp(&b.1.hands.1); | ||||
|                         match lhs { | ||||
|                             std::cmp::Ordering::Equal => rhs, | ||||
|                             _ => lhs, | ||||
|                         } | ||||
|                     }); | ||||
| 
 | ||||
|                     for binop in binops { | ||||
|                         if let (Ok(lhs_narrow), Ok(rhs_narrow)) = ( | ||||
|                             lhs_resolved.narrow_into(&binop.1.hands.0), | ||||
| @ -375,6 +386,9 @@ impl<'outer> ScopeTypeRefs<'outer> { | ||||
|             let hint1_typeref = self.types.retrieve_typeref(*hint1.0.borrow()).unwrap(); | ||||
|             let hint2_typeref = self.types.retrieve_typeref(*hint2.0.borrow()).unwrap(); | ||||
| 
 | ||||
|             dbg!(&hint1_typeref); | ||||
|             dbg!(&hint2_typeref); | ||||
| 
 | ||||
|             match (&hint1_typeref, &hint2_typeref) { | ||||
|                 (TypeRefKind::Direct(ret_ty), TypeRefKind::BinOp(op, lhs, rhs)) => { | ||||
|                     let mut lhs_ref = self.from_type(&lhs).unwrap(); | ||||
| @ -487,6 +501,17 @@ impl<'outer> ScopeTypeRefs<'outer> { | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // Sort binops by lhs and then rhs
 | ||||
|         applying_binops.sort_by(|a, b| { | ||||
|             let lhs = a.hands.0.cmp(&b.hands.0); | ||||
|             let rhs = a.hands.1.cmp(&b.hands.1); | ||||
|             match lhs { | ||||
|                 std::cmp::Ordering::Equal => rhs, | ||||
|                 _ => lhs, | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         applying_binops | ||||
|     } | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user