Refactor a bit

This commit is contained in:
Sofia 2025-07-16 22:38:19 +03:00
parent 3870b421a9
commit c19384d77b
11 changed files with 375 additions and 547 deletions

View File

@ -147,21 +147,11 @@ pub struct Block(
pub TokenRange, pub TokenRange,
); );
#[derive(Debug, Clone)]
pub struct VariableReference(pub VariableReferenceKind, pub TokenRange);
#[derive(Debug, Clone)]
pub enum VariableReferenceKind {
Name(String, TokenRange),
ArrayIndex(Box<VariableReference>, u64),
StructIndex(Box<VariableReference>, String),
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum BlockLevelStatement { pub enum BlockLevelStatement {
Let(LetStatement), Let(LetStatement),
/// Try to set a variable to a specified expression value /// Try to set a variable to a specified expression value
Set(VariableReference, Expression, TokenRange), Set(Expression, Expression, TokenRange),
Import { Import {
_i: ImportStatement, _i: ImportStatement,
}, },

View File

@ -417,32 +417,6 @@ impl Parse for Block {
} }
} }
impl Parse for VariableReference {
fn parse(mut stream: TokenStream) -> Result<Self, Error> {
if let Some(Token::Identifier(ident)) = stream.next() {
let mut var_ref = VariableReference(
VariableReferenceKind::Name(ident, stream.get_one_token_range()),
stream.get_range().unwrap(),
);
while let Ok(val) = stream.parse::<ValueIndex>() {
match val {
ValueIndex::Array(ArrayValueIndex(idx)) => {
todo!();
}
ValueIndex::Struct(StructValueIndex(name)) => {
todo!();
}
}
}
Ok(var_ref)
} else {
Err(stream.expected_err("identifier")?)?
}
}
}
impl Parse for StructExpression { impl Parse for StructExpression {
fn parse(mut stream: TokenStream) -> Result<Self, Error> { fn parse(mut stream: TokenStream) -> Result<Self, Error> {
let Some(Token::Identifier(name)) = stream.next() else { let Some(Token::Identifier(name)) = stream.next() else {
@ -569,7 +543,7 @@ impl Parse for BlockLevelStatement {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct SetStatement(VariableReference, Expression, TokenRange); pub struct SetStatement(Expression, Expression, TokenRange);
impl Parse for SetStatement { impl Parse for SetStatement {
fn parse(mut stream: TokenStream) -> Result<Self, Error> { fn parse(mut stream: TokenStream) -> Result<Self, Error> {

View File

@ -157,38 +157,6 @@ impl From<ast::ReturnType> for mir::ReturnKind {
} }
} }
impl ast::VariableReference {
fn process(&self) -> mir::IndexedVariableReference {
mir::IndexedVariableReference {
kind: self.0.process(),
meta: self.1.into(),
}
}
}
impl ast::VariableReferenceKind {
fn process(&self) -> mir::IndexedVariableReferenceKind {
match &self {
ast::VariableReferenceKind::Name(name, range) => {
mir::IndexedVariableReferenceKind::Named(NamedVariableRef(
mir::TypeKind::Vague(mir::VagueType::Unknown),
name.clone(),
(*range).into(),
))
}
ast::VariableReferenceKind::ArrayIndex(var_ref, idx) => {
mir::IndexedVariableReferenceKind::ArrayIndex(Box::new(var_ref.process()), *idx)
}
ast::VariableReferenceKind::StructIndex(var_ref, name) => {
mir::IndexedVariableReferenceKind::StructIndex(
Box::new(var_ref.process()),
name.clone(),
)
}
}
}
}
impl ast::Expression { impl ast::Expression {
fn process(&self) -> mir::Expression { fn process(&self) -> mir::Expression {
let kind = match &self.0 { let kind = match &self.0 {

View File

@ -8,8 +8,7 @@ use reid_lib::{
}; };
use crate::mir::{ use crate::mir::{
self, types::ReturnType, IndexedVariableReference, NamedVariableRef, StructField, StructType, self, NamedVariableRef, StructField, StructType, TypeDefinitionKind, TypeKind, VagueLiteral,
TypeDefinitionKind, TypeKind, VagueLiteral,
}; };
/// Context that contains all of the given modules as complete codegenerated /// Context that contains all of the given modules as complete codegenerated
@ -48,6 +47,69 @@ impl<'ctx> std::fmt::Debug for ModuleCodegen<'ctx> {
} }
} }
pub struct Scope<'ctx, 'a> {
context: &'ctx Context,
module: &'ctx Module<'ctx>,
function: &'ctx Function<'ctx>,
block: Block<'ctx>,
types: &'a HashMap<TypeValue, TypeDefinitionKind>,
type_values: &'a HashMap<String, TypeValue>,
functions: &'a HashMap<String, Function<'ctx>>,
stack_values: HashMap<String, StackValue>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StackValue(StackValueKind, Type);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StackValueKind {
Immutable(InstructionValue),
Mutable(InstructionValue),
}
impl StackValueKind {
unsafe fn get_instr(&self) -> &InstructionValue {
match self {
StackValueKind::Immutable(val) => val,
StackValueKind::Mutable(val) => val,
}
}
fn with_instr(&self, instr: InstructionValue) -> StackValueKind {
match self {
StackValueKind::Immutable(_) => StackValueKind::Immutable(instr),
StackValueKind::Mutable(_) => StackValueKind::Mutable(instr),
}
}
}
impl<'ctx, 'a> Scope<'ctx, 'a> {
fn with_block(&self, block: Block<'ctx>) -> Scope<'ctx, 'a> {
Scope {
block,
function: self.function,
context: self.context,
module: self.module,
functions: self.functions,
types: self.types,
type_values: self.type_values,
stack_values: self.stack_values.clone(),
}
}
/// Takes the block out from this scope, swaps the given block in it's place
/// and returns the old block.
fn swap_block(&mut self, block: Block<'ctx>) -> Block<'ctx> {
let mut old_block = block;
mem::swap(&mut self.block, &mut old_block);
old_block
}
fn get_typedef(&self, name: &String) -> Option<&TypeDefinitionKind> {
self.type_values.get(name).and_then(|v| self.types.get(v))
}
}
impl mir::Module { impl mir::Module {
fn codegen<'ctx>(&self, context: &'ctx Context) -> ModuleCodegen<'ctx> { fn codegen<'ctx>(&self, context: &'ctx Context) -> ModuleCodegen<'ctx> {
let mut module = context.module(&self.name, self.is_main); let mut module = context.module(&self.name, self.is_main);
@ -155,66 +217,24 @@ impl mir::Module {
} }
} }
pub struct Scope<'ctx, 'a> { impl mir::Block {
context: &'ctx Context, fn codegen<'ctx, 'a>(&self, mut scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> {
module: &'ctx Module<'ctx>, for stmt in &self.statements {
function: &'ctx Function<'ctx>, stmt.codegen(&mut scope);
block: Block<'ctx>,
types: &'a HashMap<TypeValue, TypeDefinitionKind>,
type_values: &'a HashMap<String, TypeValue>,
functions: &'a HashMap<String, Function<'ctx>>,
stack_values: HashMap<String, StackValue>,
} }
#[derive(Debug, Clone, PartialEq, Eq)] if let Some((kind, expr)) = &self.return_expression {
pub struct StackValue(StackValueKind, Type); match kind {
mir::ReturnKind::Hard => {
#[derive(Debug, Clone, Copy, PartialEq, Eq)] let ret = expr.codegen(&mut scope)?;
pub enum StackValueKind { scope.block.terminate(Term::Ret(ret)).unwrap();
Immutable(InstructionValue), None
Mutable(InstructionValue),
} }
mir::ReturnKind::Soft => expr.codegen(&mut scope),
impl StackValueKind {
unsafe fn get_instr(&self) -> &InstructionValue {
match self {
StackValueKind::Immutable(val) => val,
StackValueKind::Mutable(val) => val,
} }
} else {
None
} }
fn with_instr(&self, instr: InstructionValue) -> StackValueKind {
match self {
StackValueKind::Immutable(_) => StackValueKind::Immutable(instr),
StackValueKind::Mutable(_) => StackValueKind::Mutable(instr),
}
}
}
impl<'ctx, 'a> Scope<'ctx, 'a> {
fn with_block(&self, block: Block<'ctx>) -> Scope<'ctx, 'a> {
Scope {
block,
function: self.function,
context: self.context,
module: self.module,
functions: self.functions,
types: self.types,
type_values: self.type_values,
stack_values: self.stack_values.clone(),
}
}
/// Takes the block out from this scope, swaps the given block in it's place
/// and returns the old block.
fn swap_block(&mut self, block: Block<'ctx>) -> Block<'ctx> {
let mut old_block = block;
mem::swap(&mut self.block, &mut old_block);
old_block
}
fn get_typedef(&self, name: &String) -> Option<&TypeDefinitionKind> {
self.type_values.get(name).and_then(|v| self.types.get(v))
} }
} }
@ -257,20 +277,22 @@ impl mir::Statement {
); );
None None
} }
mir::StmtKind::Set(var, val) => { mir::StmtKind::Set(lhs, rhs) => {
if let Some(StackValue(kind, _)) = var.get_stack_value(scope, false) { todo!("codegen!");
match kind {
StackValueKind::Immutable(_) => { // if let Some(StackValue(kind, _)) = var.get_stack_value(scope, false) {
panic!("Tried to mutate an immutable variable") // match kind {
} // StackValueKind::Immutable(_) => {
StackValueKind::Mutable(ptr) => { // panic!("Tried to mutate an immutable variable")
let expression = val.codegen(scope).unwrap(); // }
Some(scope.block.build(Instr::Store(ptr, expression)).unwrap()) // StackValueKind::Mutable(ptr) => {
} // let expression = val.codegen(scope).unwrap();
} // Some(scope.block.build(Instr::Store(ptr, expression)).unwrap())
} else { // }
panic!("") // }
} // } else {
// panic!("")
// }
} }
// mir::StmtKind::If(if_expression) => if_expression.codegen(scope), // mir::StmtKind::If(if_expression) => if_expression.codegen(scope),
mir::StmtKind::Import(_) => todo!(), mir::StmtKind::Import(_) => todo!(),
@ -279,62 +301,6 @@ impl mir::Statement {
} }
} }
impl mir::IfExpression {
fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> {
let condition = self.0.codegen(scope).unwrap();
// Create blocks
let then_b = scope.function.block("then");
let mut else_b = scope.function.block("else");
let after_b = scope.function.block("after");
// Store for convenience
let then_bb = then_b.value();
let else_bb = else_b.value();
let after_bb = after_b.value();
// Generate then-block content
let mut then_scope = scope.with_block(then_b);
let then_res = self.1.codegen(&mut then_scope);
then_scope.block.terminate(Term::Br(after_bb)).ok();
let else_res = if let Some(else_block) = &self.2 {
let mut else_scope = scope.with_block(else_b);
scope
.block
.terminate(Term::CondBr(condition, then_bb, else_bb))
.unwrap();
let opt = else_block.codegen(&mut else_scope);
if let Some(ret) = opt {
else_scope.block.terminate(Term::Br(after_bb)).ok();
Some(ret)
} else {
None
}
} else {
else_b.terminate(Term::Br(after_bb)).unwrap();
scope
.block
.terminate(Term::CondBr(condition, then_bb, after_bb))
.unwrap();
None
};
// Swap block to the after-block so that construction can continue correctly
scope.swap_block(after_b);
if then_res.is_none() && else_res.is_none() {
None
} else {
let mut incoming = Vec::from(then_res.as_slice());
incoming.extend(else_res);
Some(scope.block.build(Instr::Phi(incoming)).unwrap())
}
}
}
impl mir::Expression { impl mir::Expression {
fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> { fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> {
match &self.0 { match &self.0 {
@ -517,16 +483,71 @@ impl mir::Expression {
} }
} }
impl IndexedVariableReference { impl mir::IfExpression {
fn get_stack_value(&self, scope: &mut Scope, load_after_gep: bool) -> Option<StackValue> { fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> {
match &self.kind { let condition = self.0.codegen(scope).unwrap();
mir::IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
scope.stack_values.get(name).cloned().map(|v| v) // Create blocks
} let then_b = scope.function.block("then");
mir::IndexedVariableReferenceKind::ArrayIndex(inner, idx) => { let mut else_b = scope.function.block("else");
let inner_stack_val = inner.get_stack_value(scope, true)?; let after_b = scope.function.block("after");
// Store for convenience
let then_bb = then_b.value();
let else_bb = else_b.value();
let after_bb = after_b.value();
// Generate then-block content
let mut then_scope = scope.with_block(then_b);
let then_res = self.1.codegen(&mut then_scope);
then_scope.block.terminate(Term::Br(after_bb)).ok();
let else_res = if let Some(else_block) = &self.2 {
let mut else_scope = scope.with_block(else_b);
scope
.block
.terminate(Term::CondBr(condition, then_bb, else_bb))
.unwrap();
let opt = else_block.codegen(&mut else_scope);
if let Some(ret) = opt {
else_scope.block.terminate(Term::Br(after_bb)).ok();
Some(ret)
} else {
None
}
} else {
else_b.terminate(Term::Br(after_bb)).unwrap();
scope
.block
.terminate(Term::CondBr(condition, then_bb, after_bb))
.unwrap();
None
};
// Swap block to the after-block so that construction can continue correctly
scope.swap_block(after_b);
if then_res.is_none() && else_res.is_none() {
None
} else {
let mut incoming = Vec::from(then_res.as_slice());
incoming.extend(else_res);
Some(scope.block.build(Instr::Phi(incoming)).unwrap())
}
}
}
// impl IndexedVariableReference {
// fn get_stack_value(&self, scope: &mut Scope_after_gep: bool) -> Option<StackValue> {
// match &self.kind {
// mir::IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
// scope.stack_values.get(name).cloned().map(|v| v)
// }
// mir::IndexedVariableReferenceKind::ArrayIndex(inner, idx) => {
// let inner_stack_val = inner.get_stack_value(scope, true)?;
todo!();
// let mut gep_instr = scope // let mut gep_instr = scope
// .block // .block
// .build(Instr::GetElemPtr( // .build(Instr::GetElemPtr(
@ -550,52 +571,52 @@ impl IndexedVariableReference {
// } // }
// _ => panic!("Tried to codegen indexing a non-indexable value!"), // _ => panic!("Tried to codegen indexing a non-indexable value!"),
// } // }
} // }
mir::IndexedVariableReferenceKind::StructIndex(inner, field) => { // mir::IndexedVariableReferenceKind::StructIndex(inner, field) => {
let inner_stack_val = inner.get_stack_value(scope, true)?; // let inner_stack_val = inner.get_stack_value(scope, true)?;
let (instr_value, inner_ty) = if let Type::Ptr(inner_ty) = inner_stack_val.1 { // let (instr_value, inner_ty) = if let Type::Ptr(inner_ty) = inner_stack_val.1 {
if let Type::CustomType(ty_val) = *inner_ty { // if let Type::CustomType(ty_val) = *inner_ty {
match scope.types.get(&ty_val).unwrap() { // match scope.types.get(&ty_val).unwrap() {
TypeDefinitionKind::Struct(struct_type) => { // TypeDefinitionKind::Struct(struct_type) => {
let idx = struct_type.find_index(field)?; // let idx = struct_type.find_index(field)?;
let field_ty = struct_type // let field_ty = struct_type
.get_field_ty(field)? // .get_field_ty(field)?
.get_type(scope.type_values, scope.types); // .get_type(scope.type_values, scope.types);
let mut gep_instr = scope // let mut gep_instr = scope
.block // .block
.build(Instr::GetStructElemPtr( // .build(Instr::GetStructElemPtr(
unsafe { *inner_stack_val.0.get_instr() }, // unsafe { *inner_stack_val.0.get_instr() },
idx, // idx,
)) // ))
.unwrap(); // .unwrap();
if load_after_gep { // if load_after_gep {
gep_instr = scope // gep_instr = scope
.block // .block
.build(Instr::Load(gep_instr, field_ty.clone())) // .build(Instr::Load(gep_instr, field_ty.clone()))
.unwrap() // .unwrap()
} // }
Some((gep_instr, field_ty)) // Some((gep_instr, field_ty))
} // }
} // }
} else { // } else {
None // None
} // }
} else { // } else {
None // None
}?; // }?;
Some(StackValue( // Some(StackValue(
inner_stack_val.0.with_instr(instr_value), // inner_stack_val.0.with_instr(instr_value),
Type::Ptr(Box::new(inner_ty)), // Type::Ptr(Box::new(inner_ty)),
)) // ))
} // }
} // }
} // }
} // }
impl mir::CmpOperator { impl mir::CmpOperator {
fn int_predicate(&self) -> CmpPredicate { fn int_predicate(&self) -> CmpPredicate {
@ -610,30 +631,6 @@ impl mir::CmpOperator {
} }
} }
impl mir::Block {
fn codegen<'ctx, 'a>(&self, mut scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> {
for stmt in &self.statements {
stmt.codegen(&mut scope);
}
if let Some((kind, expr)) = &self.return_expression {
if let Some(ret) = expr.codegen(&mut scope) {
match kind {
mir::ReturnKind::Hard => {
scope.block.terminate(Term::Ret(ret)).unwrap();
None
}
mir::ReturnKind::Soft => Some(ret),
}
} else {
None
}
} else {
None
}
}
}
impl mir::Literal { impl mir::Literal {
fn as_const(&self, block: &mut Block) -> InstructionValue { fn as_const(&self, block: &mut Block) -> InstructionValue {
block.build(self.as_const_kind()).unwrap() block.build(self.as_const_kind()).unwrap()

View File

@ -242,22 +242,6 @@ impl Display for NamedVariableRef {
} }
} }
impl Display for IndexedVariableReference {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.kind {
IndexedVariableReferenceKind::Named(name) => Display::fmt(name, f),
IndexedVariableReferenceKind::ArrayIndex(var_ref, idx) => {
Display::fmt(&var_ref, f)?;
write_index(f, *idx)
}
IndexedVariableReferenceKind::StructIndex(var_ref, name) => {
Display::fmt(&var_ref, f)?;
write_access(f, name)
}
}
}
}
impl Display for Literal { impl Display for Literal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {

View File

@ -47,13 +47,33 @@ impl StructType {
} }
} }
pub trait ReturnType { enum BlockReturn<'b> {
/// Return the return type of this node Early(&'b Statement),
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther>; Normal(ReturnKind, &'b Expression),
} }
impl ReturnType for Block { impl Block {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { fn return_expr(&self) -> Result<BlockReturn, ReturnTypeOther> {
let mut early_return = None;
for statement in &self.statements {
let ret = statement.return_type();
if let Ok((ReturnKind::Hard, _)) = ret {
early_return = Some(statement);
}
}
if let Some(s) = early_return {
return Ok(BlockReturn::Early(s));
}
self.return_expression
.as_ref()
.map(|(r, e)| BlockReturn::Normal(*r, e))
.ok_or(ReturnTypeOther::NoBlockReturn(self.meta))
}
pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
let mut early_return = None; let mut early_return = None;
for statement in &self.statements { for statement in &self.statements {
@ -72,28 +92,47 @@ impl ReturnType for Block {
.ok_or(ReturnTypeOther::NoBlockReturn(self.meta)) .ok_or(ReturnTypeOther::NoBlockReturn(self.meta))
.and_then(|(kind, stmt)| Ok((*kind, stmt.return_type()?.1))) .and_then(|(kind, stmt)| Ok((*kind, stmt.return_type()?.1)))
} }
pub fn backing_var(&self) -> Option<&NamedVariableRef> {
match self.return_expr().ok()? {
BlockReturn::Early(statement) => statement.backing_var(),
BlockReturn::Normal(kind, expr) => {
if kind == ReturnKind::Soft {
expr.backing_var()
} else {
None
}
}
}
}
} }
impl ReturnType for Statement { impl Statement {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
use StmtKind::*; use StmtKind::*;
match &self.0 { match &self.0 {
Let(var, _, expr) => if_hard( Let(var, _, expr) => if_hard(
expr.return_type()?, expr.return_type()?,
Err(ReturnTypeOther::Let(var.2 + expr.1)), Err(ReturnTypeOther::Let(var.2 + expr.1)),
), ),
Set(var, expr) => if_hard( Set(lhs, rhs) => if_hard(rhs.return_type()?, Err(ReturnTypeOther::Set(lhs.1 + rhs.1))),
expr.return_type()?,
Err(ReturnTypeOther::Set(var.meta + expr.1)),
),
Import(_) => todo!(), Import(_) => todo!(),
Expression(expression) => expression.return_type(), Expression(expression) => expression.return_type(),
} }
} }
pub fn backing_var(&self) -> Option<&NamedVariableRef> {
match &self.0 {
StmtKind::Let(_, _, _) => None,
StmtKind::Set(_, _) => None,
StmtKind::Import(_) => None,
StmtKind::Expression(expr) => expr.backing_var(),
}
}
} }
impl ReturnType for Expression { impl Expression {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
use ExprKind::*; use ExprKind::*;
match &self.0 { match &self.0 {
Literal(lit) => Ok((ReturnKind::Soft, lit.as_type())), Literal(lit) => Ok((ReturnKind::Soft, lit.as_type())),
@ -130,10 +169,25 @@ impl ReturnType for Expression {
Struct(name, _) => Ok((ReturnKind::Soft, TypeKind::CustomType(name.clone()))), Struct(name, _) => Ok((ReturnKind::Soft, TypeKind::CustomType(name.clone()))),
} }
} }
pub fn backing_var(&self) -> Option<&NamedVariableRef> {
match &self.0 {
ExprKind::Variable(var_ref) => Some(var_ref),
ExprKind::Indexed(lhs, _, _) => lhs.backing_var(),
ExprKind::Accessed(lhs, _, _) => lhs.backing_var(),
ExprKind::Array(_) => None,
ExprKind::Struct(_, _) => None,
ExprKind::Literal(_) => None,
ExprKind::BinOp(_, _, _) => None,
ExprKind::FunctionCall(_) => None,
ExprKind::If(_) => None,
ExprKind::Block(block) => block.backing_var(),
}
}
} }
impl ReturnType for IfExpression { impl IfExpression {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
let then_r = self.1.return_type()?; let then_r = self.1.return_type()?;
if let Some(else_b) = &self.2 { if let Some(else_b) = &self.2 {
let else_r = else_b.return_type()?; let else_r = else_b.return_type()?;
@ -150,14 +204,14 @@ impl ReturnType for IfExpression {
} }
} }
impl ReturnType for NamedVariableRef { impl NamedVariableRef {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
Ok((ReturnKind::Soft, self.0.clone())) Ok((ReturnKind::Soft, self.0.clone()))
} }
} }
impl ReturnType for FunctionCall { impl FunctionCall {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
Ok((ReturnKind::Soft, self.return_type.clone())) Ok((ReturnKind::Soft, self.return_type.clone()))
} }
} }
@ -215,72 +269,6 @@ impl TypeKind {
} }
} }
impl IndexedVariableReference {
pub fn get_name(&self) -> String {
match &self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => name.clone(),
IndexedVariableReferenceKind::ArrayIndex(inner, idx) => {
format!("{}[{}]", inner.get_name(), idx)
}
IndexedVariableReferenceKind::StructIndex(inner, name) => {
format!("{}.{}", inner.get_name(), name)
}
}
}
/// Retrieve the indexed type that this variable reference is pointing to
pub fn retrieve_type(&self, scope: &pass::Scope) -> Result<TypeKind, ErrorKind> {
match &self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(ty, _, _)) => Ok(ty.clone()),
IndexedVariableReferenceKind::ArrayIndex(inner, _) => {
let inner_ty = inner.retrieve_type(scope)?;
match inner_ty {
TypeKind::Array(type_kind, _) => Ok(*type_kind),
_ => Err(ErrorKind::TriedIndexingNonArray(inner_ty)),
}
}
IndexedVariableReferenceKind::StructIndex(inner, field_name) => {
let inner_ty = inner.retrieve_type(scope)?;
match inner_ty {
TypeKind::CustomType(struct_name) => {
let struct_ty = scope
.get_struct_type(&struct_name)
.ok_or(ErrorKind::NoSuchType(struct_name.clone()))?;
struct_ty
.get_field_ty(&field_name)
.ok_or(ErrorKind::NoSuchField(field_name.clone()))
.cloned()
}
_ => Err(ErrorKind::TriedAccessingNonStruct(inner_ty)),
}
}
}
}
pub fn into_typeref<'s>(&mut self, typerefs: &'s ScopeTypeRefs) -> Option<(bool, TypeRef<'s>)> {
match &mut self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(ty, name, _)) => {
let t = typerefs.find_var(name)?;
*ty = t.1.as_type();
Some(t)
}
IndexedVariableReferenceKind::ArrayIndex(inner, _) => inner.into_typeref(typerefs),
IndexedVariableReferenceKind::StructIndex(inner, _) => inner.into_typeref(typerefs),
}
}
pub fn resolve_ref<'s>(&mut self, typerefs: &'s TypeRefs) -> Result<TypeKind, ErrorKind> {
match &mut self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(ty, _, _)) => {
*ty = ty.resolve_ref(typerefs);
Ok(ty.clone())
}
IndexedVariableReferenceKind::ArrayIndex(inner, _) => inner.resolve_ref(typerefs),
IndexedVariableReferenceKind::StructIndex(inner, _) => inner.resolve_ref(typerefs),
}
}
}
#[derive(Debug, Clone, thiserror::Error)] #[derive(Debug, Clone, thiserror::Error)]
pub enum EqualsIssue { pub enum EqualsIssue {
#[error("Function is already defined locally at {:?}", (.0).range)] #[error("Function is already defined locally at {:?}", (.0).range)]

View File

@ -11,7 +11,7 @@ use crate::{compile_module, ReidError};
use super::{ use super::{
pass::{Pass, PassState}, pass::{Pass, PassState},
types::EqualsIssue, r#impl::EqualsIssue,
Context, FunctionDefinition, Import, Metadata, Module, Context, FunctionDefinition, Import, Metadata, Module,
}; };

View File

@ -7,12 +7,12 @@ use std::{collections::HashMap, path::PathBuf};
use crate::token_stream::TokenRange; use crate::token_stream::TokenRange;
mod display; mod display;
pub mod r#impl;
pub mod linker; pub mod linker;
pub mod pass; pub mod pass;
pub mod typecheck; pub mod typecheck;
pub mod typeinference; pub mod typeinference;
pub mod typerefs; pub mod typerefs;
pub mod types;
#[derive(Debug, Default, Clone, Copy)] #[derive(Debug, Default, Clone, Copy)]
pub struct Metadata { pub struct Metadata {
@ -262,24 +262,11 @@ pub struct Block {
#[derive(Debug)] #[derive(Debug)]
pub struct Statement(pub StmtKind, pub Metadata); pub struct Statement(pub StmtKind, pub Metadata);
#[derive(Debug)]
pub struct IndexedVariableReference {
pub kind: IndexedVariableReferenceKind,
pub meta: Metadata,
}
#[derive(Debug)]
pub enum IndexedVariableReferenceKind {
Named(NamedVariableRef),
ArrayIndex(Box<IndexedVariableReference>, u64),
StructIndex(Box<IndexedVariableReference>, String),
}
#[derive(Debug)] #[derive(Debug)]
pub enum StmtKind { pub enum StmtKind {
/// Variable name++mutability+type, evaluation /// Variable name++mutability+type, evaluation
Let(NamedVariableRef, bool, Expression), Let(NamedVariableRef, bool, Expression),
Set(IndexedVariableReference, Expression), Set(Expression, Expression),
Import(Import), Import(Import),
Expression(Expression), Expression(Expression),
} }

View File

@ -8,7 +8,6 @@ use VagueType as Vague;
use super::{ use super::{
pass::{Pass, PassState, ScopeFunction, ScopeVariable, Storage}, pass::{Pass, PassState, ScopeFunction, ScopeVariable, Storage},
typerefs::TypeRefs, typerefs::TypeRefs,
types::ReturnType,
}; };
#[derive(thiserror::Error, Debug, Clone)] #[derive(thiserror::Error, Debug, Clone)]
@ -31,7 +30,7 @@ pub enum ErrorKind {
FunctionAlreadyDefined(String), FunctionAlreadyDefined(String),
#[error("Variable not defined: {0}")] #[error("Variable not defined: {0}")]
VariableAlreadyDefined(String), VariableAlreadyDefined(String),
#[error("Variable not mutable: {0}")] #[error("Variable {0} is not declared as mutable")]
VariableNotMutable(String), VariableNotMutable(String),
#[error("Function {0} was given {1} parameters, but {2} were expected")] #[error("Function {0} was given {1} parameters, but {2} were expected")]
InvalidAmountParameters(String, usize, usize), InvalidAmountParameters(String, usize, usize),
@ -55,6 +54,8 @@ pub enum ErrorKind {
DuplicateTypeName(String), DuplicateTypeName(String),
#[error("Recursive type definition: {0}.{1}")] #[error("Recursive type definition: {0}.{1}")]
RecursiveTypeDefinition(String, String), RecursiveTypeDefinition(String, String),
#[error("This type of expression can not be used for assignment")]
InvalidSetExpression,
} }
/// Struct used to implement a type-checking pass that can be performed on the /// Struct used to implement a type-checking pass that can be performed on the
@ -250,51 +251,46 @@ impl Block {
state.ok(res, variable_reference.2); state.ok(res, variable_reference.2);
None None
} }
StmtKind::Set(variable_reference, expression) => { StmtKind::Set(lhs, rhs) => {
// Update typing from reference // Typecheck expression and coerce to variable type
variable_reference.resolve_ref(&typerefs)?; let lhs_res = lhs.typecheck(&mut state, typerefs, None);
// If expression resolution itself was erronous, resolve as
if let Some(var) = state // Unknown.
.ok( let lhs_ty = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1);
variable_reference
.get_variable(&state.scope.variables, &state.scope.types),
variable_reference.meta,
)
.flatten()
{
let field_ty = variable_reference.retrieve_type(&state.scope)?;
dbg!(&field_ty);
// Typecheck expression and coerce to variable type // Typecheck expression and coerce to variable type
let res = expression.typecheck(&mut state, &typerefs, Some(&field_ty)); let res = rhs.typecheck(&mut state, &typerefs, Some(&lhs_ty));
// If expression resolution itself was erronous, resolve as // If expression resolution itself was erronous, resolve as
// Unknown. // Unknown.
let expr_ty = let rhs_ty = state.or_else(res, TypeKind::Vague(Vague::Unknown), rhs.1);
state.or_else(res, TypeKind::Vague(Vague::Unknown), expression.1);
// Make sure the expression and variable type to really // Make sure the expression and variable type to really
// be the same // be the same
state.ok( state.ok(lhs_ty.collapse_into(&rhs_ty), lhs.1 + rhs.1);
expr_ty.collapse_into(&field_ty),
variable_reference.meta + expression.1,
);
if !var.mutable { if let Some(named_var) = lhs.backing_var() {
if let Some(scope_var) = state.scope.variables.get(&named_var.1) {
if !scope_var.mutable {
state.ok::<_, Infallible>( state.ok::<_, Infallible>(
Err(ErrorKind::VariableNotMutable(variable_reference.get_name())), Err(ErrorKind::VariableNotMutable(named_var.1.clone())),
variable_reference.meta, lhs.1,
); );
} }
}
None
} else { } else {
state.ok::<_, Infallible>( state.ok::<_, Infallible>(Err(ErrorKind::InvalidSetExpression), lhs.1);
Err(ErrorKind::VariableNotDefined(variable_reference.get_name())),
variable_reference.meta,
);
None
} }
// TODO add error about variable mutability, need to check
// that the expression is based on a variable first though..
// if true {
// state.ok::<_, Infallible>(
// Err(ErrorKind::VariableNotMutable(variable_reference.get_name())),
// variable_reference.meta,
// );
// }
None
} }
StmtKind::Import(_) => todo!(), // TODO StmtKind::Import(_) => todo!(), // TODO
StmtKind::Expression(expression) => { StmtKind::Expression(expression) => {
@ -346,7 +342,7 @@ impl Expression {
fn typecheck( fn typecheck(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut PassState<ErrorKind>,
hints: &TypeRefs, typerefs: &TypeRefs,
hint_t: Option<&TypeKind>, hint_t: Option<&TypeKind>,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {
match &mut self.0 { match &mut self.0 {
@ -363,11 +359,11 @@ impl Expression {
TypeKind::Vague(Vague::Unknown), TypeKind::Vague(Vague::Unknown),
var_ref.2, var_ref.2,
) )
.resolve_ref(hints); .resolve_ref(typerefs);
// Update typing to be more accurate // Update typing to be more accurate
var_ref.0 = state.or_else( var_ref.0 = state.or_else(
var_ref.0.resolve_ref(hints).collapse_into(&existing), var_ref.0.resolve_ref(typerefs).collapse_into(&existing),
TypeKind::Vague(Vague::Unknown), TypeKind::Vague(Vague::Unknown),
var_ref.2, var_ref.2,
); );
@ -381,15 +377,15 @@ impl Expression {
ExprKind::BinOp(op, lhs, rhs) => { ExprKind::BinOp(op, lhs, rhs) => {
// TODO make sure lhs and rhs can actually do this binary // TODO make sure lhs and rhs can actually do this binary
// operation once relevant // operation once relevant
let lhs_res = lhs.typecheck(state, &hints, None); let lhs_res = lhs.typecheck(state, &typerefs, None);
let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1); let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1);
let rhs_res = rhs.typecheck(state, &hints, Some(&lhs_type)); let rhs_res = rhs.typecheck(state, &typerefs, Some(&lhs_type));
let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1); let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1);
if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) { if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) {
// Try to coerce both sides again with collapsed type // Try to coerce both sides again with collapsed type
lhs.typecheck(state, &hints, Some(&collapsed)).ok(); lhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
rhs.typecheck(state, &hints, Some(&collapsed)).ok(); rhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
} }
let both_t = lhs_type.collapse_into(&rhs_type)?; let both_t = lhs_type.collapse_into(&rhs_type)?;
@ -429,7 +425,7 @@ impl Expression {
function_call.parameters.iter_mut().zip(true_params_iter) function_call.parameters.iter_mut().zip(true_params_iter)
{ {
// Typecheck every param separately // Typecheck every param separately
let param_res = param.typecheck(state, &hints, Some(&true_param_t)); let param_res = param.typecheck(state, &typerefs, Some(&true_param_t));
let param_t = let param_t =
state.or_else(param_res, TypeKind::Vague(Vague::Unknown), param.1); state.or_else(param_res, TypeKind::Vague(Vague::Unknown), param.1);
state.ok(param_t.collapse_into(&true_param_t), param.1); state.ok(param_t.collapse_into(&true_param_t), param.1);
@ -439,29 +435,29 @@ impl Expression {
// return type // return type
let ret_t = f let ret_t = f
.ret .ret
.collapse_into(&function_call.return_type.resolve_ref(hints))?; .collapse_into(&function_call.return_type.resolve_ref(typerefs))?;
// Update typing to be more accurate // Update typing to be more accurate
function_call.return_type = ret_t.clone(); function_call.return_type = ret_t.clone();
Ok(ret_t.resolve_ref(hints)) Ok(ret_t.resolve_ref(typerefs))
} else { } else {
Ok(function_call.return_type.clone().resolve_ref(hints)) Ok(function_call.return_type.clone().resolve_ref(typerefs))
} }
} }
ExprKind::If(IfExpression(cond, lhs, rhs)) => { ExprKind::If(IfExpression(cond, lhs, rhs)) => {
let cond_res = cond.typecheck(state, &hints, Some(&TypeKind::Bool)); let cond_res = cond.typecheck(state, &typerefs, Some(&TypeKind::Bool));
let cond_t = state.or_else(cond_res, TypeKind::Vague(Vague::Unknown), cond.1); let cond_t = state.or_else(cond_res, TypeKind::Vague(Vague::Unknown), cond.1);
state.ok(cond_t.collapse_into(&TypeKind::Bool), cond.1); state.ok(cond_t.collapse_into(&TypeKind::Bool), cond.1);
// Typecheck then/else return types and make sure they are the // Typecheck then/else return types and make sure they are the
// same, if else exists. // same, if else exists.
let then_res = lhs.typecheck(state, &hints, hint_t); let then_res = lhs.typecheck(state, &typerefs, hint_t);
let (then_ret_kind, then_ret_t) = state.or_else( let (then_ret_kind, then_ret_t) = state.or_else(
then_res, then_res,
(ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)), (ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)),
lhs.meta, lhs.meta,
); );
let else_ret_t = if let Some(else_block) = rhs { let else_ret_t = if let Some(else_block) = rhs {
let res = else_block.typecheck(state, &hints, hint_t); let res = else_block.typecheck(state, &typerefs, hint_t);
let (else_ret_kind, else_ret_t) = state.or_else( let (else_ret_kind, else_ret_t) = state.or_else(
res, res,
(ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)), (ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)),
@ -491,30 +487,32 @@ impl Expression {
if let Some(rhs) = rhs { if let Some(rhs) = rhs {
// If rhs existed, typecheck both sides to perform type // If rhs existed, typecheck both sides to perform type
// coercion. // coercion.
let lhs_res = lhs.typecheck(state, &hints, Some(&collapsed)); let lhs_res = lhs.typecheck(state, &typerefs, Some(&collapsed));
let rhs_res = rhs.typecheck(state, &hints, Some(&collapsed)); let rhs_res = rhs.typecheck(state, &typerefs, Some(&collapsed));
state.ok(lhs_res, lhs.meta); state.ok(lhs_res, lhs.meta);
state.ok(rhs_res, rhs.meta); state.ok(rhs_res, rhs.meta);
} }
Ok(collapsed) Ok(collapsed)
} }
ExprKind::Block(block) => match block.typecheck(state, &hints, hint_t) { ExprKind::Block(block) => match block.typecheck(state, &typerefs, hint_t) {
Ok((ReturnKind::Hard, _)) => Ok(TypeKind::Void), Ok((ReturnKind::Hard, _)) => Ok(TypeKind::Void),
Ok((_, ty)) => Ok(ty), Ok((_, ty)) => Ok(ty),
Err(e) => Err(e), Err(e) => Err(e),
}, },
ExprKind::Indexed(expression, elem_ty, idx) => { ExprKind::Indexed(expression, elem_ty, _) => {
// Try to unwrap hint type from array if possible // Try to unwrap hint type from array if possible
let hint_t = hint_t.map(|t| match t { let hint_t = hint_t.map(|t| match t {
TypeKind::Array(type_kind, _) => &type_kind, TypeKind::Array(type_kind, _) => &type_kind,
_ => t, _ => t,
}); });
let expr_t = expression.typecheck(state, hints, hint_t)?; // TODO it could be possible to check length against constants..
if let TypeKind::Array(inferred_ty, len) = expr_t {
let expr_t = expression.typecheck(state, typerefs, hint_t)?;
if let TypeKind::Array(inferred_ty, _) = expr_t {
let ty = state.or_else( let ty = state.or_else(
elem_ty.resolve_ref(hints).collapse_into(&inferred_ty), elem_ty.resolve_ref(typerefs).collapse_into(&inferred_ty),
TypeKind::Vague(Vague::Unknown), TypeKind::Vague(Vague::Unknown),
self.1, self.1,
); );
@ -534,7 +532,7 @@ impl Expression {
let mut expr_result = try_all( let mut expr_result = try_all(
expressions expressions
.iter_mut() .iter_mut()
.map(|e| e.typecheck(state, hints, hint_t)) .map(|e| e.typecheck(state, typerefs, hint_t))
.collect(), .collect(),
); );
match &mut expr_result { match &mut expr_result {
@ -563,10 +561,10 @@ impl Expression {
} }
ExprKind::Accessed(expression, type_kind, field_name) => { ExprKind::Accessed(expression, type_kind, field_name) => {
// Resolve expected type // Resolve expected type
let expected_ty = type_kind.resolve_ref(hints); let expected_ty = type_kind.resolve_ref(typerefs);
// Typecheck expression // Typecheck expression
let expr_res = expression.typecheck(state, hints, Some(&expected_ty)); let expr_res = expression.typecheck(state, typerefs, Some(&expected_ty));
let expr_ty = let expr_ty =
state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), expression.1); state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), expression.1);
@ -612,7 +610,7 @@ impl Expression {
); );
// Typecheck the actual expression // Typecheck the actual expression
let expr_res = field_expr.typecheck(state, hints, Some(expected_ty)); let expr_res = field_expr.typecheck(state, typerefs, Some(expected_ty));
let expr_ty = let expr_ty =
state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), field_expr.1); state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), field_expr.1);
@ -625,64 +623,6 @@ impl Expression {
} }
} }
impl IndexedVariableReference {
fn get_variable(
&self,
storage: &Storage<ScopeVariable>,
types: &Storage<TypeDefinitionKind>,
) -> Result<Option<ScopeVariable>, ErrorKind> {
match &self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
Ok(storage.get(&name).cloned())
}
IndexedVariableReferenceKind::ArrayIndex(inner_ref, _) => {
if let Some(var) = inner_ref.get_variable(storage, types)? {
match &var.ty {
TypeKind::Array(inner_ty, _) => Ok(Some(ScopeVariable {
ty: *inner_ty.clone(),
mutable: var.mutable,
})),
_ => Err(ErrorKind::TriedIndexingNonArray(var.ty.clone())),
}
} else {
Ok(None)
}
}
IndexedVariableReferenceKind::StructIndex(var_ref, field_name) => {
if let Some(var) = var_ref.get_variable(storage, types)? {
match &var.ty {
TypeKind::CustomType(type_name) => {
if let Some(kind) = types.get(type_name) {
match &kind {
TypeDefinitionKind::Struct(struct_type) => {
if let Some(StructField(_, field_ty, _)) = struct_type
.0
.iter()
.find(|StructField(n, _, _)| n == field_name)
{
Ok(Some(ScopeVariable {
ty: field_ty.clone(),
mutable: var.mutable,
}))
} else {
Err(ErrorKind::NoSuchField(field_name.clone()))
}
}
}
} else {
Err(ErrorKind::NoSuchType(type_name.clone()))
}
}
_ => Err(ErrorKind::TriedAccessingNonStruct(var.ty.clone())),
}
} else {
Ok(None)
}
}
}
}
}
impl Literal { impl Literal {
/// Try to coerce this literal, ie. convert it to a more specific type in /// Try to coerce this literal, ie. convert it to a more specific type in
/// regards to the given hint if any. /// regards to the given hint if any.

View File

@ -10,9 +10,9 @@ use crate::{mir::TypeKind, util::try_all};
use super::{ use super::{
pass::{Pass, PassState}, pass::{Pass, PassState},
r#impl::pick_return,
typecheck::ErrorKind, typecheck::ErrorKind,
typerefs::{ScopeTypeRefs, TypeRef, TypeRefs}, typerefs::{ScopeTypeRefs, TypeRef, TypeRefs},
types::{pick_return, ReturnType},
Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module, Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module,
ReturnKind, StmtKind, ReturnKind, StmtKind,
TypeKind::*, TypeKind::*,
@ -104,18 +104,18 @@ impl Block {
var_ref.narrow(&expr_ty_ref); var_ref.narrow(&expr_ty_ref);
} }
} }
StmtKind::Set(var, expr) => { StmtKind::Set(lhs, rhs) => {
// Update this MIR type to its TypeRef // Infer hints for the expression itself
let var_ref = var.into_typeref(&inner_hints); let lhs_infer = lhs.infer_types(&mut state, &inner_hints);
let lhs_ref = state.ok(lhs_infer, rhs.1);
// Infer hints for the expression itself // Infer hints for the expression itself
let inferred = expr.infer_types(&mut state, &inner_hints); let rhs_infer = rhs.infer_types(&mut state, &inner_hints);
let expr_ty_ref = state.ok(inferred, expr.1); let rhs_ref = state.ok(rhs_infer, rhs.1);
// Try to narrow the variable type declaration with the // Try to narrow the lhs with rhs
// expression if let (Some(mut lhs_ref), Some(rhs_ref)) = (lhs_ref, rhs_ref) {
if let (Some((_, mut var_ref)), Some(expr_ty_ref)) = (var_ref, expr_ty_ref) { lhs_ref.narrow(&rhs_ref);
var_ref.narrow(&expr_ty_ref);
} }
} }
StmtKind::Import(_) => todo!(), StmtKind::Import(_) => todo!(),

View File

@ -11,7 +11,7 @@ fn main() -> u32 {
second: [6, 3, 17, 8], second: [6, 3, 17, 8],
}]; }];
// value[0].second[2] = 99; value[0].second[2] = 99;
return value[0].second[2]; return value[0].second[2];
} }