Compare commits

...

5 Commits

Author SHA1 Message Date
d2cf97af66 Move should_load to separate State 2025-07-16 23:22:06 +03:00
d034754202 Possibly fix array_structs 2025-07-16 23:09:36 +03:00
c41aab33a9 Add optional data to PassState Scope 2025-07-16 22:46:52 +03:00
c19384d77b Refactor a bit 2025-07-16 22:38:19 +03:00
3870b421a9 Refactor indexing/accessing a bit, no mutability 2025-07-16 22:04:11 +03:00
15 changed files with 602 additions and 778 deletions

View File

@ -320,6 +320,7 @@ impl Builder {
} }
Instr::GetStructElemPtr(ptr_val, idx) => { Instr::GetStructElemPtr(ptr_val, idx) => {
let ptr_ty = ptr_val.get_type(&self)?; let ptr_ty = ptr_val.get_type(&self)?;
dbg!(&ptr_ty);
if let Type::Ptr(ty) = ptr_ty { if let Type::Ptr(ty) = ptr_ty {
if let Type::CustomType(val) = *ty { if let Type::CustomType(val) = *ty {
match self.type_data(&val).kind { match self.type_data(&val).kind {

View File

@ -487,7 +487,7 @@ impl InstructionHolder {
let mut llvm_indices: Vec<_> = indices let mut llvm_indices: Vec<_> = indices
.iter() .iter()
.map(|idx| ConstValue::U32(*idx).as_llvm(module)) .map(|idx_elem| module.values.get(idx_elem).unwrap().value_ref)
.collect(); .collect();
LLVMBuildGEP2( LLVMBuildGEP2(

View File

@ -128,7 +128,7 @@ impl Debug for Instr {
instruction_value, instruction_value,
&items &items
.iter() .iter()
.map(|i| i.to_string()) .map(|expr| format!("{:?}", expr))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", "), .join(", "),
), ),

View File

@ -226,7 +226,7 @@ pub enum Instr {
Load(InstructionValue, Type), Load(InstructionValue, Type),
Store(InstructionValue, InstructionValue), Store(InstructionValue, InstructionValue),
ArrayAlloca(Type, u32), ArrayAlloca(Type, u32),
GetElemPtr(InstructionValue, Vec<u32>), GetElemPtr(InstructionValue, Vec<InstructionValue>),
GetStructElemPtr(InstructionValue, u32), GetStructElemPtr(InstructionValue, u32),
/// Integer Comparison /// Integer Comparison

View File

@ -44,8 +44,10 @@ pub enum ExpressionKind {
VariableName(String), VariableName(String),
Literal(Literal), Literal(Literal),
Array(Vec<Expression>), Array(Vec<Expression>),
ArrayIndex(Box<Expression>, u64), /// Array-indexed, e.g. <expr>[<expr>]
StructIndex(Box<Expression>, String), Indexed(Box<Expression>, Box<Expression>),
/// Struct-accessed, e.g. <expr>.<expr>
Accessed(Box<Expression>, String),
Binop(BinaryOperator, Box<Expression>, Box<Expression>), Binop(BinaryOperator, Box<Expression>, Box<Expression>),
FunctionCall(Box<FunctionCallExpression>), FunctionCall(Box<FunctionCallExpression>),
BlockExpr(Box<Block>), BlockExpr(Box<Block>),
@ -145,21 +147,11 @@ pub struct Block(
pub TokenRange, pub TokenRange,
); );
#[derive(Debug, Clone)]
pub struct VariableReference(pub VariableReferenceKind, pub TokenRange);
#[derive(Debug, Clone)]
pub enum VariableReferenceKind {
Name(String, TokenRange),
ArrayIndex(Box<VariableReference>, u64),
StructIndex(Box<VariableReference>, String),
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum BlockLevelStatement { pub enum BlockLevelStatement {
Let(LetStatement), Let(LetStatement),
/// Try to set a variable to a specified expression value /// Try to set a variable to a specified expression value
Set(VariableReference, Expression, TokenRange), Set(Expression, Expression, TokenRange),
Import { Import {
_i: ImportStatement, _i: ImportStatement,
}, },

View File

@ -136,15 +136,15 @@ impl Parse for PrimaryExpression {
while let Ok(index) = stream.parse::<ValueIndex>() { while let Ok(index) = stream.parse::<ValueIndex>() {
match index { match index {
ValueIndex::Array(ArrayValueIndex(idx)) => { ValueIndex::Array(ArrayValueIndex(idx_expr)) => {
expr = Expression( expr = Expression(
ExpressionKind::ArrayIndex(Box::new(expr), idx), ExpressionKind::Indexed(Box::new(expr), Box::new(idx_expr)),
stream.get_range().unwrap(), stream.get_range().unwrap(),
); );
} }
ValueIndex::Struct(StructValueIndex(name)) => { ValueIndex::Struct(StructValueIndex(name)) => {
expr = Expression( expr = Expression(
ExpressionKind::StructIndex(Box::new(expr), name), ExpressionKind::Accessed(Box::new(expr), name),
stream.get_range().unwrap(), stream.get_range().unwrap(),
); );
} }
@ -417,38 +417,6 @@ impl Parse for Block {
} }
} }
impl Parse for VariableReference {
fn parse(mut stream: TokenStream) -> Result<Self, Error> {
if let Some(Token::Identifier(ident)) = stream.next() {
let mut var_ref = VariableReference(
VariableReferenceKind::Name(ident, stream.get_one_token_range()),
stream.get_range().unwrap(),
);
while let Ok(val) = stream.parse::<ValueIndex>() {
match val {
ValueIndex::Array(ArrayValueIndex(idx)) => {
var_ref = VariableReference(
VariableReferenceKind::ArrayIndex(Box::new(var_ref), idx),
stream.get_range().unwrap(),
);
}
ValueIndex::Struct(StructValueIndex(name)) => {
var_ref = VariableReference(
VariableReferenceKind::StructIndex(Box::new(var_ref), name),
stream.get_range().unwrap(),
);
}
}
}
Ok(var_ref)
} else {
Err(stream.expected_err("identifier")?)?
}
}
}
impl Parse for StructExpression { impl Parse for StructExpression {
fn parse(mut stream: TokenStream) -> Result<Self, Error> { fn parse(mut stream: TokenStream) -> Result<Self, Error> {
let Some(Token::Identifier(name)) = stream.next() else { let Some(Token::Identifier(name)) = stream.next() else {
@ -515,18 +483,15 @@ impl Parse for ValueIndex {
} }
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone)]
pub struct ArrayValueIndex(u64); pub struct ArrayValueIndex(Expression);
impl Parse for ArrayValueIndex { impl Parse for ArrayValueIndex {
fn parse(mut stream: TokenStream) -> Result<Self, Error> { fn parse(mut stream: TokenStream) -> Result<Self, Error> {
stream.expect(Token::BracketOpen)?; stream.expect(Token::BracketOpen)?;
if let Some(Token::DecimalValue(idx)) = stream.next() { let expr = stream.parse()?;
stream.expect(Token::BracketClose)?; stream.expect(Token::BracketClose)?;
Ok(ArrayValueIndex(idx)) Ok(ArrayValueIndex(expr))
} else {
return Err(stream.expected_err("array index (number)")?);
}
} }
} }
@ -578,7 +543,7 @@ impl Parse for BlockLevelStatement {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct SetStatement(VariableReference, Expression, TokenRange); pub struct SetStatement(Expression, Expression, TokenRange);
impl Parse for SetStatement { impl Parse for SetStatement {
fn parse(mut stream: TokenStream) -> Result<Self, Error> { fn parse(mut stream: TokenStream) -> Result<Self, Error> {

View File

@ -157,38 +157,6 @@ impl From<ast::ReturnType> for mir::ReturnKind {
} }
} }
impl ast::VariableReference {
fn process(&self) -> mir::IndexedVariableReference {
mir::IndexedVariableReference {
kind: self.0.process(),
meta: self.1.into(),
}
}
}
impl ast::VariableReferenceKind {
fn process(&self) -> mir::IndexedVariableReferenceKind {
match &self {
ast::VariableReferenceKind::Name(name, range) => {
mir::IndexedVariableReferenceKind::Named(NamedVariableRef(
mir::TypeKind::Vague(mir::VagueType::Unknown),
name.clone(),
(*range).into(),
))
}
ast::VariableReferenceKind::ArrayIndex(var_ref, idx) => {
mir::IndexedVariableReferenceKind::ArrayIndex(Box::new(var_ref.process()), *idx)
}
ast::VariableReferenceKind::StructIndex(var_ref, name) => {
mir::IndexedVariableReferenceKind::StructIndex(
Box::new(var_ref.process()),
name.clone(),
)
}
}
}
}
impl ast::Expression { impl ast::Expression {
fn process(&self) -> mir::Expression { fn process(&self) -> mir::Expression {
let kind = match &self.0 { let kind = match &self.0 {
@ -224,10 +192,10 @@ impl ast::Expression {
ast::ExpressionKind::Array(expressions) => { ast::ExpressionKind::Array(expressions) => {
mir::ExprKind::Array(expressions.iter().map(|e| e.process()).collect()) mir::ExprKind::Array(expressions.iter().map(|e| e.process()).collect())
} }
ast::ExpressionKind::ArrayIndex(expression, idx) => mir::ExprKind::ArrayIndex( ast::ExpressionKind::Indexed(expression, idx_expr) => mir::ExprKind::Indexed(
Box::new(expression.process()), Box::new(expression.process()),
mir::TypeKind::Vague(mir::VagueType::Unknown), mir::TypeKind::Vague(mir::VagueType::Unknown),
*idx, Box::new(idx_expr.process()),
), ),
ast::ExpressionKind::StructExpression(struct_init) => mir::ExprKind::Struct( ast::ExpressionKind::StructExpression(struct_init) => mir::ExprKind::Struct(
struct_init.name.clone(), struct_init.name.clone(),
@ -237,7 +205,7 @@ impl ast::Expression {
.map(|(n, e)| (n.clone(), e.process())) .map(|(n, e)| (n.clone(), e.process()))
.collect(), .collect(),
), ),
ast::ExpressionKind::StructIndex(expression, name) => mir::ExprKind::StructIndex( ast::ExpressionKind::Accessed(expression, name) => mir::ExprKind::Accessed(
Box::new(expression.process()), Box::new(expression.process()),
mir::TypeKind::Vague(mir::VagueType::Unknown), mir::TypeKind::Vague(mir::VagueType::Unknown),
name.clone(), name.clone(),

View File

@ -8,8 +8,7 @@ use reid_lib::{
}; };
use crate::mir::{ use crate::mir::{
self, types::ReturnType, IndexedVariableReference, NamedVariableRef, StructField, StructType, self, NamedVariableRef, StructField, StructType, TypeDefinitionKind, TypeKind, VagueLiteral,
TypeDefinitionKind, TypeKind,
}; };
/// Context that contains all of the given modules as complete codegenerated /// Context that contains all of the given modules as complete codegenerated
@ -48,6 +47,83 @@ impl<'ctx> std::fmt::Debug for ModuleCodegen<'ctx> {
} }
} }
pub struct Scope<'ctx, 'a> {
context: &'ctx Context,
module: &'ctx Module<'ctx>,
function: &'ctx Function<'ctx>,
block: Block<'ctx>,
types: &'a HashMap<TypeValue, TypeDefinitionKind>,
type_values: &'a HashMap<String, TypeValue>,
functions: &'a HashMap<String, Function<'ctx>>,
stack_values: HashMap<String, StackValue>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StackValue(StackValueKind, Type);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StackValueKind {
Immutable(InstructionValue),
Mutable(InstructionValue),
}
impl StackValueKind {
unsafe fn get_instr(&self) -> &InstructionValue {
match self {
StackValueKind::Immutable(val) => val,
StackValueKind::Mutable(val) => val,
}
}
fn with_instr(&self, instr: InstructionValue) -> StackValueKind {
match self {
StackValueKind::Immutable(_) => StackValueKind::Immutable(instr),
StackValueKind::Mutable(_) => StackValueKind::Mutable(instr),
}
}
}
impl<'ctx, 'a> Scope<'ctx, 'a> {
fn with_block(&self, block: Block<'ctx>) -> Scope<'ctx, 'a> {
Scope {
block,
function: self.function,
context: self.context,
module: self.module,
functions: self.functions,
types: self.types,
type_values: self.type_values,
stack_values: self.stack_values.clone(),
}
}
/// Takes the block out from this scope, swaps the given block in it's place
/// and returns the old block.
fn swap_block(&mut self, block: Block<'ctx>) -> Block<'ctx> {
let mut old_block = block;
mem::swap(&mut self.block, &mut old_block);
old_block
}
fn get_typedef(&self, name: &String) -> Option<&TypeDefinitionKind> {
self.type_values.get(name).and_then(|v| self.types.get(v))
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Default, Clone, Copy)]
struct State {
should_load: bool,
}
impl State {
/// Sets should load, returning a new state
fn load(self, should: bool) -> State {
State {
should_load: should,
}
}
}
impl mir::Module { impl mir::Module {
fn codegen<'ctx>(&self, context: &'ctx Context) -> ModuleCodegen<'ctx> { fn codegen<'ctx>(&self, context: &'ctx Context) -> ModuleCodegen<'ctx> {
let mut module = context.module(&self.name, self.is_main); let mut module = context.module(&self.name, self.is_main);
@ -137,7 +213,8 @@ impl mir::Module {
}; };
match &mir_function.kind { match &mir_function.kind {
mir::FunctionDefinitionKind::Local(block, _) => { mir::FunctionDefinitionKind::Local(block, _) => {
if let Some(ret) = block.codegen(&mut scope) { let mut state = State::default();
if let Some(ret) = block.codegen(&mut scope, &mut state) {
scope.block.terminate(Term::Ret(ret)).unwrap(); scope.block.terminate(Term::Ret(ret)).unwrap();
} else { } else {
if !scope.block.delete_if_unused().unwrap() { if !scope.block.delete_if_unused().unwrap() {
@ -155,74 +232,40 @@ impl mir::Module {
} }
} }
pub struct Scope<'ctx, 'a> { impl mir::Block {
context: &'ctx Context, fn codegen<'ctx, 'a>(
module: &'ctx Module<'ctx>, &self,
function: &'ctx Function<'ctx>, mut scope: &mut Scope<'ctx, 'a>,
block: Block<'ctx>, state: &State,
types: &'a HashMap<TypeValue, TypeDefinitionKind>, ) -> Option<InstructionValue> {
type_values: &'a HashMap<String, TypeValue>, for stmt in &self.statements {
functions: &'a HashMap<String, Function<'ctx>>, stmt.codegen(&mut scope, state);
stack_values: HashMap<String, StackValue>,
} }
#[derive(Debug, Clone, PartialEq, Eq)] if let Some((kind, expr)) = &self.return_expression {
pub struct StackValue(StackValueKind, Type); match kind {
mir::ReturnKind::Hard => {
#[derive(Debug, Clone, Copy, PartialEq, Eq)] let ret = expr.codegen(&mut scope, &mut state.load(true))?;
pub enum StackValueKind { scope.block.terminate(Term::Ret(ret)).unwrap();
Immutable(InstructionValue), None
Mutable(InstructionValue),
} }
mir::ReturnKind::Soft => expr.codegen(&mut scope, state),
impl StackValueKind {
unsafe fn get_instr(&self) -> &InstructionValue {
match self {
StackValueKind::Immutable(val) => val,
StackValueKind::Mutable(val) => val,
} }
} else {
None
} }
fn with_instr(&self, instr: InstructionValue) -> StackValueKind {
match self {
StackValueKind::Immutable(_) => StackValueKind::Immutable(instr),
StackValueKind::Mutable(_) => StackValueKind::Mutable(instr),
}
}
}
impl<'ctx, 'a> Scope<'ctx, 'a> {
fn with_block(&self, block: Block<'ctx>) -> Scope<'ctx, 'a> {
Scope {
block,
function: self.function,
context: self.context,
module: self.module,
functions: self.functions,
types: self.types,
type_values: self.type_values,
stack_values: self.stack_values.clone(),
}
}
/// Takes the block out from this scope, swaps the given block in it's place
/// and returns the old block.
fn swap_block(&mut self, block: Block<'ctx>) -> Block<'ctx> {
let mut old_block = block;
mem::swap(&mut self.block, &mut old_block);
old_block
}
fn get_typedef(&self, name: &String) -> Option<&TypeDefinitionKind> {
self.type_values.get(name).and_then(|v| self.types.get(v))
} }
} }
impl mir::Statement { impl mir::Statement {
fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> { fn codegen<'ctx, 'a>(
&self,
scope: &mut Scope<'ctx, 'a>,
state: &State,
) -> Option<InstructionValue> {
match &self.0 { match &self.0 {
mir::StmtKind::Let(NamedVariableRef(ty, name, _), mutable, expression) => { mir::StmtKind::Let(NamedVariableRef(ty, name, _), mutable, expression) => {
let value = expression.codegen(scope).unwrap(); let value = expression.codegen(scope, state).unwrap();
scope.stack_values.insert( scope.stack_values.insert(
name.clone(), name.clone(),
StackValue( StackValue(
@ -257,31 +300,232 @@ impl mir::Statement {
); );
None None
} }
mir::StmtKind::Set(var, val) => { mir::StmtKind::Set(lhs, rhs) => {
if let Some(StackValue(kind, _)) = var.get_stack_value(scope, false) { let lhs_value = lhs
match kind { .codegen(scope, &mut state.load(false))
StackValueKind::Immutable(_) => { .expect("non-returning LHS snuck into codegen!");
panic!("Tried to mutate an immutable variable")
let rhs_value = rhs.codegen(scope, state)?;
Some(
scope
.block
.build(Instr::Store(lhs_value, rhs_value))
.unwrap(),
)
} }
StackValueKind::Mutable(ptr) => { mir::StmtKind::Import(_) => todo!(),
let expression = val.codegen(scope).unwrap(); mir::StmtKind::Expression(expression) => expression.codegen(scope, state),
Some(scope.block.build(Instr::Store(ptr, expression)).unwrap())
} }
} }
}
impl mir::Expression {
fn codegen<'ctx, 'a>(
&self,
scope: &mut Scope<'ctx, 'a>,
state: &State,
) -> Option<InstructionValue> {
match &self.0 {
mir::ExprKind::Variable(varref) => {
varref.0.known().expect("variable type unknown");
let v = scope
.stack_values
.get(&varref.1)
.expect("Variable reference not found?!");
Some(match v.0 {
StackValueKind::Immutable(val) => val.clone(),
StackValueKind::Mutable(val) => {
if state.should_load {
match v.1 {
// TODO probably wrong ..?
Type::Ptr(_) => val,
_ => scope.block.build(Instr::Load(val, v.1.clone())).unwrap(),
}
} else { } else {
panic!("") val
} }
} }
// mir::StmtKind::If(if_expression) => if_expression.codegen(scope), })
mir::StmtKind::Import(_) => todo!(), }
mir::StmtKind::Expression(expression) => expression.codegen(scope), mir::ExprKind::Literal(lit) => Some(lit.as_const(&mut scope.block)),
mir::ExprKind::BinOp(binop, lhs_exp, rhs_exp) => {
lhs_exp
.return_type()
.expect("No ret type in lhs?")
.1
.known()
.expect("lhs ret type is unknown");
rhs_exp
.return_type()
.expect("No ret type in rhs?")
.1
.known()
.expect("rhs ret type is unknown");
let lhs = lhs_exp
.codegen(scope, state)
.expect("lhs has no return value");
let rhs = rhs_exp
.codegen(scope, state)
.expect("rhs has no return value");
Some(match binop {
mir::BinaryOperator::Add => scope.block.build(Instr::Add(lhs, rhs)).unwrap(),
mir::BinaryOperator::Minus => scope.block.build(Instr::Sub(lhs, rhs)).unwrap(),
mir::BinaryOperator::Mult => scope.block.build(Instr::Mult(lhs, rhs)).unwrap(),
mir::BinaryOperator::And => scope.block.build(Instr::And(lhs, rhs)).unwrap(),
mir::BinaryOperator::Cmp(l) => scope
.block
.build(Instr::ICmp(l.int_predicate(), lhs, rhs))
.unwrap(),
})
}
mir::ExprKind::FunctionCall(call) => {
call.return_type
.known()
.expect("function return type unknown");
let params = call
.parameters
.iter()
.map(|e| e.codegen(scope, state).unwrap())
.collect();
let callee = scope
.functions
.get(&call.name)
.expect("function not found!");
Some(
scope
.block
.build(Instr::FunctionCall(callee.value(), params))
.unwrap(),
)
}
mir::ExprKind::If(if_expression) => if_expression.codegen(scope, state),
mir::ExprKind::Block(block) => {
let mut inner_scope = scope.with_block(scope.function.block("inner"));
if let Some(ret) = block.codegen(&mut inner_scope, state) {
inner_scope
.block
.terminate(Term::Br(scope.block.value()))
.unwrap();
Some(ret)
} else {
None
}
}
mir::ExprKind::Indexed(expression, val_t, idx_expr) => {
let array = expression.codegen(scope, state)?;
let idx = idx_expr.codegen(scope, state)?;
let mut ptr = scope
.block
.build(Instr::GetElemPtr(array, vec![idx]))
.unwrap();
if state.should_load {
ptr = scope
.block
.build(Instr::Load(
ptr,
val_t.get_type(scope.type_values, scope.types),
))
.unwrap();
}
Some(ptr)
}
mir::ExprKind::Array(expressions) => {
let instr_list = expressions
.iter()
.map(|e| e.codegen(scope, state).unwrap())
.collect::<Vec<_>>();
let instr_t = expressions
.iter()
.map(|e| e.return_type().unwrap().1)
.next()
.unwrap_or(TypeKind::Void);
let array = scope
.block
.build(Instr::ArrayAlloca(
instr_t.get_type(scope.type_values, scope.types),
instr_list.len() as u32,
))
.unwrap();
for (index, instr) in instr_list.iter().enumerate() {
let index_expr = scope
.block
.build(Instr::Constant(ConstValue::U32(index as u32)))
.unwrap();
let ptr = scope
.block
.build(Instr::GetElemPtr(array, vec![index_expr]))
.unwrap();
scope.block.build(Instr::Store(ptr, *instr)).unwrap();
}
Some(array)
}
mir::ExprKind::Accessed(expression, type_kind, field) => {
let struct_val = expression.codegen(scope, &mut state.load(true))?;
let struct_ty = expression.return_type().ok()?.1.known().ok()?;
let TypeKind::CustomType(name) = struct_ty else {
return None;
};
let TypeDefinitionKind::Struct(struct_ty) = scope.get_typedef(&name)?;
let idx = struct_ty.find_index(field)?;
let mut value = scope
.block
.build(Instr::GetStructElemPtr(struct_val, idx as u32))
.unwrap();
if state.should_load {
value = scope
.block
.build(Instr::Load(
value,
type_kind.get_type(scope.type_values, scope.types),
))
.unwrap();
}
Some(value)
}
mir::ExprKind::Struct(name, items) => {
let struct_ptr = scope
.block
.build(Instr::Alloca(
name.clone(),
Type::CustomType(*scope.type_values.get(name)?),
))
.unwrap();
for (i, (_, exp)) in items.iter().enumerate() {
let elem_ptr = scope
.block
.build(Instr::GetStructElemPtr(struct_ptr, i as u32))
.unwrap();
if let Some(val) = exp.codegen(scope, state) {
scope.block.build(Instr::Store(elem_ptr, val)).unwrap();
}
}
Some(struct_ptr)
}
} }
} }
} }
impl mir::IfExpression { impl mir::IfExpression {
fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> { fn codegen<'ctx, 'a>(
let condition = self.0.codegen(scope).unwrap(); &self,
scope: &mut Scope<'ctx, 'a>,
state: &State,
) -> Option<InstructionValue> {
let condition = self.0.codegen(scope, state).unwrap();
// Create blocks // Create blocks
let then_b = scope.function.block("then"); let then_b = scope.function.block("then");
@ -295,7 +539,7 @@ impl mir::IfExpression {
// Generate then-block content // Generate then-block content
let mut then_scope = scope.with_block(then_b); let mut then_scope = scope.with_block(then_b);
let then_res = self.1.codegen(&mut then_scope); let then_res = self.1.codegen(&mut then_scope, state);
then_scope.block.terminate(Term::Br(after_bb)).ok(); then_scope.block.terminate(Term::Br(after_bb)).ok();
let else_res = if let Some(else_block) = &self.2 { let else_res = if let Some(else_block) = &self.2 {
@ -305,7 +549,7 @@ impl mir::IfExpression {
.terminate(Term::CondBr(condition, then_bb, else_bb)) .terminate(Term::CondBr(condition, then_bb, else_bb))
.unwrap(); .unwrap();
let opt = else_block.codegen(&mut else_scope); let opt = else_block.codegen(&mut else_scope, state);
if let Some(ret) = opt { if let Some(ret) = opt {
else_scope.block.terminate(Term::Br(after_bb)).ok(); else_scope.block.terminate(Term::Br(after_bb)).ok();
@ -334,263 +578,6 @@ impl mir::IfExpression {
} }
} }
} }
impl mir::Expression {
fn codegen<'ctx, 'a>(&self, scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> {
match &self.0 {
mir::ExprKind::Variable(varref) => {
varref.0.known().expect("variable type unknown");
let v = scope
.stack_values
.get(&varref.1)
.expect("Variable reference not found?!");
Some(match v.0 {
StackValueKind::Immutable(val) => val.clone(),
StackValueKind::Mutable(val) => match v.1 {
// TODO probably wrong ..?
Type::Ptr(_) => val,
_ => scope.block.build(Instr::Load(val, v.1.clone())).unwrap(),
},
})
}
mir::ExprKind::Literal(lit) => Some(lit.as_const(&mut scope.block)),
mir::ExprKind::BinOp(binop, lhs_exp, rhs_exp) => {
lhs_exp
.return_type()
.expect("No ret type in lhs?")
.1
.known()
.expect("lhs ret type is unknown");
rhs_exp
.return_type()
.expect("No ret type in rhs?")
.1
.known()
.expect("rhs ret type is unknown");
let lhs = lhs_exp.codegen(scope).expect("lhs has no return value");
let rhs = rhs_exp.codegen(scope).expect("rhs has no return value");
Some(match binop {
mir::BinaryOperator::Add => scope.block.build(Instr::Add(lhs, rhs)).unwrap(),
mir::BinaryOperator::Minus => scope.block.build(Instr::Sub(lhs, rhs)).unwrap(),
mir::BinaryOperator::Mult => scope.block.build(Instr::Mult(lhs, rhs)).unwrap(),
mir::BinaryOperator::And => scope.block.build(Instr::And(lhs, rhs)).unwrap(),
mir::BinaryOperator::Cmp(l) => scope
.block
.build(Instr::ICmp(l.int_predicate(), lhs, rhs))
.unwrap(),
})
}
mir::ExprKind::FunctionCall(call) => {
call.return_type
.known()
.expect("function return type unknown");
let params = call
.parameters
.iter()
.map(|e| e.codegen(scope).unwrap())
.collect();
let callee = scope
.functions
.get(&call.name)
.expect("function not found!");
Some(
scope
.block
.build(Instr::FunctionCall(callee.value(), params))
.unwrap(),
)
}
mir::ExprKind::If(if_expression) => if_expression.codegen(scope),
mir::ExprKind::Block(block) => {
let mut inner_scope = scope.with_block(scope.function.block("inner"));
if let Some(ret) = block.codegen(&mut inner_scope) {
inner_scope
.block
.terminate(Term::Br(scope.block.value()))
.unwrap();
Some(ret)
} else {
None
}
}
mir::ExprKind::ArrayIndex(expression, val_t, idx) => {
let array = expression.codegen(scope)?;
let ptr = scope
.block
.build(Instr::GetElemPtr(array, vec![*idx as u32]))
.unwrap();
Some(
scope
.block
.build(Instr::Load(
ptr,
val_t.get_type(scope.type_values, scope.types),
))
.unwrap(),
)
}
mir::ExprKind::Array(expressions) => {
let instr_list = expressions
.iter()
.map(|e| e.codegen(scope).unwrap())
.collect::<Vec<_>>();
let instr_t = expressions
.iter()
.map(|e| e.return_type().unwrap().1)
.next()
.unwrap_or(TypeKind::Void);
let array = scope
.block
.build(Instr::ArrayAlloca(
instr_t.get_type(scope.type_values, scope.types),
instr_list.len() as u32,
))
.unwrap();
for (i, instr) in instr_list.iter().enumerate() {
let ptr = scope
.block
.build(Instr::GetElemPtr(array, vec![i as u32]))
.unwrap();
scope.block.build(Instr::Store(ptr, *instr)).unwrap();
}
Some(array)
}
mir::ExprKind::StructIndex(expression, type_kind, field) => {
let struct_val = expression.codegen(scope)?;
let struct_ty = expression.return_type().ok()?.1.known().ok()?;
let TypeKind::CustomType(name) = struct_ty else {
return None;
};
let TypeDefinitionKind::Struct(struct_ty) = scope.get_typedef(&name)?;
let idx = struct_ty.find_index(field)?;
let ptr = scope
.block
.build(Instr::GetStructElemPtr(struct_val, idx as u32))
.unwrap();
dbg!(&type_kind.get_type(scope.type_values, scope.types));
Some(
scope
.block
.build(Instr::Load(
ptr,
type_kind.get_type(scope.type_values, scope.types),
))
.unwrap(),
)
}
mir::ExprKind::Struct(name, items) => {
let struct_ptr = scope
.block
.build(Instr::Alloca(
name.clone(),
Type::CustomType(*scope.type_values.get(name)?),
))
.unwrap();
for (i, (_, exp)) in items.iter().enumerate() {
let elem_ptr = scope
.block
.build(Instr::GetStructElemPtr(struct_ptr, i as u32))
.unwrap();
if let Some(val) = exp.codegen(scope) {
scope.block.build(Instr::Store(elem_ptr, val)).unwrap();
}
}
Some(struct_ptr)
}
}
}
}
impl IndexedVariableReference {
fn get_stack_value(&self, scope: &mut Scope, load_after_gep: bool) -> Option<StackValue> {
match &self.kind {
mir::IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
scope.stack_values.get(name).cloned().map(|v| v)
}
mir::IndexedVariableReferenceKind::ArrayIndex(inner, idx) => {
let inner_stack_val = inner.get_stack_value(scope, true)?;
let mut gep_instr = scope
.block
.build(Instr::GetElemPtr(
unsafe { *inner_stack_val.0.get_instr() },
vec![*idx as u32],
))
.unwrap();
match &inner_stack_val.1 {
Type::Ptr(inner_ty) => {
if load_after_gep {
gep_instr = scope
.block
.build(Instr::Load(gep_instr, *inner_ty.clone()))
.unwrap()
}
Some(StackValue(
inner_stack_val.0.with_instr(gep_instr),
*inner_ty.clone(),
))
}
_ => panic!("Tried to codegen indexing a non-indexable value!"),
}
}
mir::IndexedVariableReferenceKind::StructIndex(inner, field) => {
let inner_stack_val = inner.get_stack_value(scope, true)?;
let (instr_value, inner_ty) = if let Type::Ptr(inner_ty) = inner_stack_val.1 {
if let Type::CustomType(ty_val) = *inner_ty {
match scope.types.get(&ty_val).unwrap() {
TypeDefinitionKind::Struct(struct_type) => {
let idx = struct_type.find_index(field)?;
let field_ty = struct_type
.get_field_ty(field)?
.get_type(scope.type_values, scope.types);
let mut gep_instr = scope
.block
.build(Instr::GetStructElemPtr(
unsafe { *inner_stack_val.0.get_instr() },
idx,
))
.unwrap();
if load_after_gep {
gep_instr = scope
.block
.build(Instr::Load(gep_instr, field_ty.clone()))
.unwrap()
}
Some((gep_instr, field_ty))
}
}
} else {
None
}
} else {
None
}?;
Some(StackValue(
inner_stack_val.0.with_instr(instr_value),
Type::Ptr(Box::new(inner_ty)),
))
}
}
}
}
impl mir::CmpOperator { impl mir::CmpOperator {
fn int_predicate(&self) -> CmpPredicate { fn int_predicate(&self) -> CmpPredicate {
match self { match self {
@ -604,30 +591,6 @@ impl mir::CmpOperator {
} }
} }
impl mir::Block {
fn codegen<'ctx, 'a>(&self, mut scope: &mut Scope<'ctx, 'a>) -> Option<InstructionValue> {
for stmt in &self.statements {
stmt.codegen(&mut scope);
}
if let Some((kind, expr)) = &self.return_expression {
if let Some(ret) = expr.codegen(&mut scope) {
match kind {
mir::ReturnKind::Hard => {
scope.block.terminate(Term::Ret(ret)).unwrap();
None
}
mir::ReturnKind::Soft => Some(ret),
}
} else {
None
}
} else {
None
}
}
}
impl mir::Literal { impl mir::Literal {
fn as_const(&self, block: &mut Block) -> InstructionValue { fn as_const(&self, block: &mut Block) -> InstructionValue {
block.build(self.as_const_kind()).unwrap() block.build(self.as_const_kind()).unwrap()
@ -647,7 +610,7 @@ impl mir::Literal {
mir::Literal::U128(val) => ConstValue::U128(val), mir::Literal::U128(val) => ConstValue::U128(val),
mir::Literal::Bool(val) => ConstValue::Bool(val), mir::Literal::Bool(val) => ConstValue::Bool(val),
mir::Literal::String(val) => ConstValue::StringPtr(val.clone()), mir::Literal::String(val) => ConstValue::StringPtr(val.clone()),
mir::Literal::Vague(_) => panic!("Got vague literal!"), mir::Literal::Vague(VagueLiteral::Number(val)) => ConstValue::I32(val as i32),
}) })
} }
} }

View File

@ -164,10 +164,10 @@ impl Display for ExprKind {
ExprKind::FunctionCall(fc) => Display::fmt(fc, f), ExprKind::FunctionCall(fc) => Display::fmt(fc, f),
ExprKind::If(if_exp) => Display::fmt(&if_exp, f), ExprKind::If(if_exp) => Display::fmt(&if_exp, f),
ExprKind::Block(block) => Display::fmt(block, f), ExprKind::Block(block) => Display::fmt(block, f),
ExprKind::ArrayIndex(expression, elem_ty, idx) => { ExprKind::Indexed(expression, elem_ty, idx_expr) => {
Display::fmt(&expression, f)?; Display::fmt(&expression, f)?;
write!(f, "<{}>", elem_ty)?; write!(f, "<{}>", elem_ty)?;
write_index(f, *idx) write_index(f, idx_expr)
} }
ExprKind::Array(expressions) => { ExprKind::Array(expressions) => {
f.write_char('[')?; f.write_char('[')?;
@ -203,7 +203,7 @@ impl Display for ExprKind {
} }
f.write_char('}') f.write_char('}')
} }
ExprKind::StructIndex(expression, type_kind, name) => { ExprKind::Accessed(expression, type_kind, name) => {
Display::fmt(&expression, f)?; Display::fmt(&expression, f)?;
write_access(f, name)?; write_access(f, name)?;
write!(f, "<{}>", type_kind) write!(f, "<{}>", type_kind)
@ -242,22 +242,6 @@ impl Display for NamedVariableRef {
} }
} }
impl Display for IndexedVariableReference {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.kind {
IndexedVariableReferenceKind::Named(name) => Display::fmt(name, f),
IndexedVariableReferenceKind::ArrayIndex(var_ref, idx) => {
Display::fmt(&var_ref, f)?;
write_index(f, *idx)
}
IndexedVariableReferenceKind::StructIndex(var_ref, name) => {
Display::fmt(&var_ref, f)?;
write_access(f, name)
}
}
}
}
impl Display for Literal { impl Display for Literal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
@ -309,7 +293,7 @@ impl Display for Metadata {
} }
} }
fn write_index(f: &mut std::fmt::Formatter<'_>, idx: u64) -> std::fmt::Result { fn write_index(f: &mut std::fmt::Formatter<'_>, idx: impl std::fmt::Display) -> std::fmt::Result {
f.write_char('[')?; f.write_char('[')?;
Display::fmt(&idx, f)?; Display::fmt(&idx, f)?;
f.write_char(']') f.write_char(']')

View File

@ -47,13 +47,33 @@ impl StructType {
} }
} }
pub trait ReturnType { enum BlockReturn<'b> {
/// Return the return type of this node Early(&'b Statement),
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther>; Normal(ReturnKind, &'b Expression),
} }
impl ReturnType for Block { impl Block {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { fn return_expr(&self) -> Result<BlockReturn, ReturnTypeOther> {
let mut early_return = None;
for statement in &self.statements {
let ret = statement.return_type();
if let Ok((ReturnKind::Hard, _)) = ret {
early_return = Some(statement);
}
}
if let Some(s) = early_return {
return Ok(BlockReturn::Early(s));
}
self.return_expression
.as_ref()
.map(|(r, e)| BlockReturn::Normal(*r, e))
.ok_or(ReturnTypeOther::NoBlockReturn(self.meta))
}
pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
let mut early_return = None; let mut early_return = None;
for statement in &self.statements { for statement in &self.statements {
@ -72,28 +92,47 @@ impl ReturnType for Block {
.ok_or(ReturnTypeOther::NoBlockReturn(self.meta)) .ok_or(ReturnTypeOther::NoBlockReturn(self.meta))
.and_then(|(kind, stmt)| Ok((*kind, stmt.return_type()?.1))) .and_then(|(kind, stmt)| Ok((*kind, stmt.return_type()?.1)))
} }
pub fn backing_var(&self) -> Option<&NamedVariableRef> {
match self.return_expr().ok()? {
BlockReturn::Early(statement) => statement.backing_var(),
BlockReturn::Normal(kind, expr) => {
if kind == ReturnKind::Soft {
expr.backing_var()
} else {
None
}
}
}
}
} }
impl ReturnType for Statement { impl Statement {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
use StmtKind::*; use StmtKind::*;
match &self.0 { match &self.0 {
Let(var, _, expr) => if_hard( Let(var, _, expr) => if_hard(
expr.return_type()?, expr.return_type()?,
Err(ReturnTypeOther::Let(var.2 + expr.1)), Err(ReturnTypeOther::Let(var.2 + expr.1)),
), ),
Set(var, expr) => if_hard( Set(lhs, rhs) => if_hard(rhs.return_type()?, Err(ReturnTypeOther::Set(lhs.1 + rhs.1))),
expr.return_type()?,
Err(ReturnTypeOther::Set(var.meta + expr.1)),
),
Import(_) => todo!(), Import(_) => todo!(),
Expression(expression) => expression.return_type(), Expression(expression) => expression.return_type(),
} }
} }
pub fn backing_var(&self) -> Option<&NamedVariableRef> {
match &self.0 {
StmtKind::Let(_, _, _) => None,
StmtKind::Set(_, _) => None,
StmtKind::Import(_) => None,
StmtKind::Expression(expr) => expr.backing_var(),
}
}
} }
impl ReturnType for Expression { impl Expression {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
use ExprKind::*; use ExprKind::*;
match &self.0 { match &self.0 {
Literal(lit) => Ok((ReturnKind::Soft, lit.as_type())), Literal(lit) => Ok((ReturnKind::Soft, lit.as_type())),
@ -107,7 +146,7 @@ impl ReturnType for Expression {
Block(block) => block.return_type(), Block(block) => block.return_type(),
FunctionCall(fcall) => fcall.return_type(), FunctionCall(fcall) => fcall.return_type(),
If(expr) => expr.return_type(), If(expr) => expr.return_type(),
ArrayIndex(expression, _, _) => { Indexed(expression, _, _) => {
let expr_type = expression.return_type()?; let expr_type = expression.return_type()?;
if let (_, TypeKind::Array(elem_ty, _)) = expr_type { if let (_, TypeKind::Array(elem_ty, _)) = expr_type {
Ok((ReturnKind::Soft, *elem_ty)) Ok((ReturnKind::Soft, *elem_ty))
@ -126,14 +165,29 @@ impl ReturnType for Expression {
TypeKind::Array(Box::new(first.1), expressions.len() as u64), TypeKind::Array(Box::new(first.1), expressions.len() as u64),
)) ))
} }
StructIndex(_, type_kind, _) => Ok((ReturnKind::Soft, type_kind.clone())), Accessed(_, type_kind, _) => Ok((ReturnKind::Soft, type_kind.clone())),
Struct(name, _) => Ok((ReturnKind::Soft, TypeKind::CustomType(name.clone()))), Struct(name, _) => Ok((ReturnKind::Soft, TypeKind::CustomType(name.clone()))),
} }
} }
pub fn backing_var(&self) -> Option<&NamedVariableRef> {
match &self.0 {
ExprKind::Variable(var_ref) => Some(var_ref),
ExprKind::Indexed(lhs, _, _) => lhs.backing_var(),
ExprKind::Accessed(lhs, _, _) => lhs.backing_var(),
ExprKind::Array(_) => None,
ExprKind::Struct(_, _) => None,
ExprKind::Literal(_) => None,
ExprKind::BinOp(_, _, _) => None,
ExprKind::FunctionCall(_) => None,
ExprKind::If(_) => None,
ExprKind::Block(block) => block.backing_var(),
}
}
} }
impl ReturnType for IfExpression { impl IfExpression {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
let then_r = self.1.return_type()?; let then_r = self.1.return_type()?;
if let Some(else_b) = &self.2 { if let Some(else_b) = &self.2 {
let else_r = else_b.return_type()?; let else_r = else_b.return_type()?;
@ -150,14 +204,14 @@ impl ReturnType for IfExpression {
} }
} }
impl ReturnType for NamedVariableRef { impl NamedVariableRef {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
Ok((ReturnKind::Soft, self.0.clone())) Ok((ReturnKind::Soft, self.0.clone()))
} }
} }
impl ReturnType for FunctionCall { impl FunctionCall {
fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> { pub fn return_type(&self) -> Result<(ReturnKind, TypeKind), ReturnTypeOther> {
Ok((ReturnKind::Soft, self.return_type.clone())) Ok((ReturnKind::Soft, self.return_type.clone()))
} }
} }
@ -215,72 +269,6 @@ impl TypeKind {
} }
} }
impl IndexedVariableReference {
pub fn get_name(&self) -> String {
match &self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => name.clone(),
IndexedVariableReferenceKind::ArrayIndex(inner, idx) => {
format!("{}[{}]", inner.get_name(), idx)
}
IndexedVariableReferenceKind::StructIndex(inner, name) => {
format!("{}.{}", inner.get_name(), name)
}
}
}
/// Retrieve the indexed type that this variable reference is pointing to
pub fn retrieve_type(&self, scope: &pass::Scope) -> Result<TypeKind, ErrorKind> {
match &self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(ty, _, _)) => Ok(ty.clone()),
IndexedVariableReferenceKind::ArrayIndex(inner, _) => {
let inner_ty = inner.retrieve_type(scope)?;
match inner_ty {
TypeKind::Array(type_kind, _) => Ok(*type_kind),
_ => Err(ErrorKind::TriedIndexingNonArray(inner_ty)),
}
}
IndexedVariableReferenceKind::StructIndex(inner, field_name) => {
let inner_ty = inner.retrieve_type(scope)?;
match inner_ty {
TypeKind::CustomType(struct_name) => {
let struct_ty = scope
.get_struct_type(&struct_name)
.ok_or(ErrorKind::NoSuchType(struct_name.clone()))?;
struct_ty
.get_field_ty(&field_name)
.ok_or(ErrorKind::NoSuchField(field_name.clone()))
.cloned()
}
_ => Err(ErrorKind::TriedAccessingNonStruct(inner_ty)),
}
}
}
}
pub fn into_typeref<'s>(&mut self, typerefs: &'s ScopeTypeRefs) -> Option<(bool, TypeRef<'s>)> {
match &mut self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(ty, name, _)) => {
let t = typerefs.find_var(name)?;
*ty = t.1.as_type();
Some(t)
}
IndexedVariableReferenceKind::ArrayIndex(inner, _) => inner.into_typeref(typerefs),
IndexedVariableReferenceKind::StructIndex(inner, _) => inner.into_typeref(typerefs),
}
}
pub fn resolve_ref<'s>(&mut self, typerefs: &'s TypeRefs) -> Result<TypeKind, ErrorKind> {
match &mut self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(ty, _, _)) => {
*ty = ty.resolve_ref(typerefs);
Ok(ty.clone())
}
IndexedVariableReferenceKind::ArrayIndex(inner, _) => inner.resolve_ref(typerefs),
IndexedVariableReferenceKind::StructIndex(inner, _) => inner.resolve_ref(typerefs),
}
}
}
#[derive(Debug, Clone, thiserror::Error)] #[derive(Debug, Clone, thiserror::Error)]
pub enum EqualsIssue { pub enum EqualsIssue {
#[error("Function is already defined locally at {:?}", (.0).range)] #[error("Function is already defined locally at {:?}", (.0).range)]

View File

@ -2,6 +2,7 @@ use std::{
cell::RefCell, cell::RefCell,
collections::HashMap, collections::HashMap,
convert::Infallible, convert::Infallible,
fmt::Error,
fs::{self}, fs::{self},
path::PathBuf, path::PathBuf,
rc::Rc, rc::Rc,
@ -11,7 +12,7 @@ use crate::{compile_module, ReidError};
use super::{ use super::{
pass::{Pass, PassState}, pass::{Pass, PassState},
types::EqualsIssue, r#impl::EqualsIssue,
Context, FunctionDefinition, Import, Metadata, Module, Context, FunctionDefinition, Import, Metadata, Module,
}; };
@ -54,9 +55,12 @@ pub fn compile_std() -> super::Module {
/// MIR. /// MIR.
pub struct LinkerPass; pub struct LinkerPass;
type LinkerPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
impl Pass for LinkerPass { impl Pass for LinkerPass {
type Data = ();
type TError = ErrorKind; type TError = ErrorKind;
fn context(&mut self, context: &mut Context, mut state: PassState<Self::TError>) { fn context(&mut self, context: &mut Context, mut state: LinkerPassState) {
let mains = context let mains = context
.modules .modules
.iter() .iter()

View File

@ -7,12 +7,12 @@ use std::{collections::HashMap, path::PathBuf};
use crate::token_stream::TokenRange; use crate::token_stream::TokenRange;
mod display; mod display;
pub mod r#impl;
pub mod linker; pub mod linker;
pub mod pass; pub mod pass;
pub mod typecheck; pub mod typecheck;
pub mod typeinference; pub mod typeinference;
pub mod typerefs; pub mod typerefs;
pub mod types;
#[derive(Debug, Default, Clone, Copy)] #[derive(Debug, Default, Clone, Copy)]
pub struct Metadata { pub struct Metadata {
@ -192,8 +192,8 @@ pub struct Import(pub Vec<String>, pub Metadata);
#[derive(Debug)] #[derive(Debug)]
pub enum ExprKind { pub enum ExprKind {
Variable(NamedVariableRef), Variable(NamedVariableRef),
ArrayIndex(Box<Expression>, TypeKind, u64), Indexed(Box<Expression>, TypeKind, Box<Expression>),
StructIndex(Box<Expression>, TypeKind, String), Accessed(Box<Expression>, TypeKind, String),
Array(Vec<Expression>), Array(Vec<Expression>),
Struct(String, Vec<(String, Expression)>), Struct(String, Vec<(String, Expression)>),
Literal(Literal), Literal(Literal),
@ -262,24 +262,11 @@ pub struct Block {
#[derive(Debug)] #[derive(Debug)]
pub struct Statement(pub StmtKind, pub Metadata); pub struct Statement(pub StmtKind, pub Metadata);
#[derive(Debug)]
pub struct IndexedVariableReference {
pub kind: IndexedVariableReferenceKind,
pub meta: Metadata,
}
#[derive(Debug)]
pub enum IndexedVariableReferenceKind {
Named(NamedVariableRef),
ArrayIndex(Box<IndexedVariableReference>, u64),
StructIndex(Box<IndexedVariableReference>, String),
}
#[derive(Debug)] #[derive(Debug)]
pub enum StmtKind { pub enum StmtKind {
/// Variable name++mutability+type, evaluation /// Variable name++mutability+type, evaluation
Let(NamedVariableRef, bool, Expression), Let(NamedVariableRef, bool, Expression),
Set(IndexedVariableReference, Expression), Set(Expression, Expression),
Import(Import), Import(Import),
Expression(Expression), Expression(Expression),
} }

View File

@ -106,21 +106,23 @@ impl<T: Clone + std::fmt::Debug> Storage<T> {
} }
#[derive(Clone, Default, Debug)] #[derive(Clone, Default, Debug)]
pub struct Scope { pub struct Scope<Data: Clone + Default> {
pub function_returns: Storage<ScopeFunction>, pub function_returns: Storage<ScopeFunction>,
pub variables: Storage<ScopeVariable>, pub variables: Storage<ScopeVariable>,
pub types: Storage<TypeDefinitionKind>, pub types: Storage<TypeDefinitionKind>,
/// Hard Return type of this scope, if inside a function /// Hard Return type of this scope, if inside a function
pub return_type_hint: Option<TypeKind>, pub return_type_hint: Option<TypeKind>,
pub data: Data,
} }
impl Scope { impl<Data: Clone + Default> Scope<Data> {
pub fn inner(&self) -> Scope { pub fn inner(&self) -> Scope<Data> {
Scope { Scope {
function_returns: self.function_returns.clone(), function_returns: self.function_returns.clone(),
variables: self.variables.clone(), variables: self.variables.clone(),
types: self.types.clone(), types: self.types.clone(),
return_type_hint: self.return_type_hint.clone(), return_type_hint: self.return_type_hint.clone(),
data: self.data.clone(),
} }
} }
@ -144,14 +146,14 @@ pub struct ScopeVariable {
pub mutable: bool, pub mutable: bool,
} }
pub struct PassState<'st, 'sc, TError: STDError + Clone> { pub struct PassState<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> {
state: &'st mut State<TError>, state: &'st mut State<TError>,
pub scope: &'sc mut Scope, pub scope: &'sc mut Scope<Data>,
inner: Vec<Scope>, inner: Vec<Scope<Data>>,
} }
impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> { impl<'st, 'sc, Data: Clone + Default, TError: STDError + Clone> PassState<'st, 'sc, Data, TError> {
fn from(state: &'st mut State<TError>, scope: &'sc mut Scope) -> Self { fn from(state: &'st mut State<TError>, scope: &'sc mut Scope<Data>) -> Self {
PassState { PassState {
state, state,
scope, scope,
@ -186,7 +188,7 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> {
} }
} }
pub fn inner(&mut self) -> PassState<TError> { pub fn inner(&mut self) -> PassState<Data, TError> {
self.inner.push(self.scope.inner()); self.inner.push(self.scope.inner());
let scope = self.inner.last_mut().unwrap(); let scope = self.inner.last_mut().unwrap();
PassState { PassState {
@ -198,19 +200,21 @@ impl<'st, 'sc, TError: STDError + Clone> PassState<'st, 'sc, TError> {
} }
pub trait Pass { pub trait Pass {
type Data: Clone + Default;
type TError: STDError + Clone; type TError: STDError + Clone;
fn context(&mut self, _context: &mut Context, mut _state: PassState<Self::TError>) {} fn context(&mut self, _context: &mut Context, mut _state: PassState<Self::Data, Self::TError>) {
fn module(&mut self, _module: &mut Module, mut _state: PassState<Self::TError>) {} }
fn module(&mut self, _module: &mut Module, mut _state: PassState<Self::Data, Self::TError>) {}
fn function( fn function(
&mut self, &mut self,
_function: &mut FunctionDefinition, _function: &mut FunctionDefinition,
mut _state: PassState<Self::TError>, mut _state: PassState<Self::Data, Self::TError>,
) { ) {
} }
fn block(&mut self, _block: &mut Block, mut _state: PassState<Self::TError>) {} fn block(&mut self, _block: &mut Block, mut _state: PassState<Self::Data, Self::TError>) {}
fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState<Self::TError>) {} fn stmt(&mut self, _stmt: &mut Statement, mut _state: PassState<Self::Data, Self::TError>) {}
fn expr(&mut self, _expr: &mut Expression, mut _state: PassState<Self::TError>) {} fn expr(&mut self, _expr: &mut Expression, mut _state: PassState<Self::Data, Self::TError>) {}
} }
impl Context { impl Context {
@ -226,7 +230,12 @@ impl Context {
} }
impl Module { impl Module {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
for typedef in &self.typedefs { for typedef in &self.typedefs {
let kind = match &typedef.kind { let kind = match &typedef.kind {
TypeDefinitionKind::Struct(fields) => TypeDefinitionKind::Struct(fields.clone()), TypeDefinitionKind::Struct(fields) => TypeDefinitionKind::Struct(fields.clone()),
@ -256,7 +265,12 @@ impl Module {
} }
impl FunctionDefinition { impl FunctionDefinition {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
for param in &self.parameters { for param in &self.parameters {
scope scope
.variables .variables
@ -283,7 +297,12 @@ impl FunctionDefinition {
} }
impl Block { impl Block {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
let mut scope = scope.inner(); let mut scope = scope.inner();
for statement in &mut self.statements { for statement in &mut self.statements {
@ -295,7 +314,12 @@ impl Block {
} }
impl Statement { impl Statement {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
match &mut self.0 { match &mut self.0 {
StmtKind::Let(_, _, expression) => { StmtKind::Let(_, _, expression) => {
expression.pass(pass, state, scope); expression.pass(pass, state, scope);
@ -332,7 +356,12 @@ impl Statement {
} }
impl Expression { impl Expression {
fn pass<T: Pass>(&mut self, pass: &mut T, state: &mut State<T::TError>, scope: &mut Scope) { fn pass<T: Pass>(
&mut self,
pass: &mut T,
state: &mut State<T::TError>,
scope: &mut Scope<T::Data>,
) {
pass.expr(self, PassState::from(state, scope)); pass.expr(self, PassState::from(state, scope));
} }
} }

View File

@ -8,7 +8,6 @@ use VagueType as Vague;
use super::{ use super::{
pass::{Pass, PassState, ScopeFunction, ScopeVariable, Storage}, pass::{Pass, PassState, ScopeFunction, ScopeVariable, Storage},
typerefs::TypeRefs, typerefs::TypeRefs,
types::ReturnType,
}; };
#[derive(thiserror::Error, Debug, Clone)] #[derive(thiserror::Error, Debug, Clone)]
@ -31,7 +30,7 @@ pub enum ErrorKind {
FunctionAlreadyDefined(String), FunctionAlreadyDefined(String),
#[error("Variable not defined: {0}")] #[error("Variable not defined: {0}")]
VariableAlreadyDefined(String), VariableAlreadyDefined(String),
#[error("Variable not mutable: {0}")] #[error("Variable {0} is not declared as mutable")]
VariableNotMutable(String), VariableNotMutable(String),
#[error("Function {0} was given {1} parameters, but {2} were expected")] #[error("Function {0} was given {1} parameters, but {2} were expected")]
InvalidAmountParameters(String, usize, usize), InvalidAmountParameters(String, usize, usize),
@ -55,6 +54,8 @@ pub enum ErrorKind {
DuplicateTypeName(String), DuplicateTypeName(String),
#[error("Recursive type definition: {0}.{1}")] #[error("Recursive type definition: {0}.{1}")]
RecursiveTypeDefinition(String, String), RecursiveTypeDefinition(String, String),
#[error("This type of expression can not be used for assignment")]
InvalidSetExpression,
} }
/// Struct used to implement a type-checking pass that can be performed on the /// Struct used to implement a type-checking pass that can be performed on the
@ -63,40 +64,13 @@ pub struct TypeCheck<'t> {
pub refs: &'t TypeRefs, pub refs: &'t TypeRefs,
} }
fn check_typedefs_for_recursion<'a, 'b>( type TypecheckPassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
defmap: &'b HashMap<&'a String, &'b TypeDefinition>,
typedef: &'b TypeDefinition,
mut seen: HashSet<String>,
state: &mut PassState<ErrorKind>,
) {
match &typedef.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) {
if let TypeKind::CustomType(name) = field_ty {
if seen.contains(name) {
state.ok::<_, Infallible>(
Err(ErrorKind::RecursiveTypeDefinition(
typedef.name.clone(),
name.clone(),
)),
typedef.meta,
);
} else {
seen.insert(name.clone());
if let Some(inner_typedef) = defmap.get(name) {
check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state)
}
}
}
}
}
}
}
impl<'t> Pass for TypeCheck<'t> { impl<'t> Pass for TypeCheck<'t> {
type Data = ();
type TError = ErrorKind; type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) { fn module(&mut self, module: &mut Module, mut state: TypecheckPassState) {
let mut defmap = HashMap::new(); let mut defmap = HashMap::new();
for typedef in &module.typedefs { for typedef in &module.typedefs {
let TypeDefinition { name, kind, meta } = &typedef; let TypeDefinition { name, kind, meta } = &typedef;
@ -136,11 +110,41 @@ impl<'t> Pass for TypeCheck<'t> {
} }
} }
fn check_typedefs_for_recursion<'a, 'b>(
defmap: &'b HashMap<&'a String, &'b TypeDefinition>,
typedef: &'b TypeDefinition,
mut seen: HashSet<String>,
state: &mut TypecheckPassState,
) {
match &typedef.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field_ty in fields.iter().map(|StructField(_, ty, _)| ty) {
if let TypeKind::CustomType(name) = field_ty {
if seen.contains(name) {
state.ok::<_, Infallible>(
Err(ErrorKind::RecursiveTypeDefinition(
typedef.name.clone(),
name.clone(),
)),
typedef.meta,
);
} else {
seen.insert(name.clone());
if let Some(inner_typedef) = defmap.get(name) {
check_typedefs_for_recursion(defmap, inner_typedef, seen.clone(), state)
}
}
}
}
}
}
}
impl FunctionDefinition { impl FunctionDefinition {
fn typecheck( fn typecheck(
&mut self, &mut self,
hints: &TypeRefs, hints: &TypeRefs,
state: &mut PassState<ErrorKind>, state: &mut TypecheckPassState,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {
for param in &self.parameters { for param in &self.parameters {
let param_t = state.or_else( let param_t = state.or_else(
@ -185,7 +189,7 @@ impl FunctionDefinition {
impl Block { impl Block {
fn typecheck( fn typecheck(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypecheckPassState,
typerefs: &TypeRefs, typerefs: &TypeRefs,
hint_t: Option<&TypeKind>, hint_t: Option<&TypeKind>,
) -> Result<(ReturnKind, TypeKind), ErrorKind> { ) -> Result<(ReturnKind, TypeKind), ErrorKind> {
@ -250,51 +254,46 @@ impl Block {
state.ok(res, variable_reference.2); state.ok(res, variable_reference.2);
None None
} }
StmtKind::Set(variable_reference, expression) => { StmtKind::Set(lhs, rhs) => {
// Update typing from reference // Typecheck expression and coerce to variable type
variable_reference.resolve_ref(&typerefs)?; let lhs_res = lhs.typecheck(&mut state, typerefs, None);
// If expression resolution itself was erronous, resolve as
if let Some(var) = state // Unknown.
.ok( let lhs_ty = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1);
variable_reference
.get_variable(&state.scope.variables, &state.scope.types),
variable_reference.meta,
)
.flatten()
{
let field_ty = variable_reference.retrieve_type(&state.scope)?;
dbg!(&field_ty);
// Typecheck expression and coerce to variable type // Typecheck expression and coerce to variable type
let res = expression.typecheck(&mut state, &typerefs, Some(&field_ty)); let res = rhs.typecheck(&mut state, &typerefs, Some(&lhs_ty));
// If expression resolution itself was erronous, resolve as // If expression resolution itself was erronous, resolve as
// Unknown. // Unknown.
let expr_ty = let rhs_ty = state.or_else(res, TypeKind::Vague(Vague::Unknown), rhs.1);
state.or_else(res, TypeKind::Vague(Vague::Unknown), expression.1);
// Make sure the expression and variable type to really // Make sure the expression and variable type to really
// be the same // be the same
state.ok( state.ok(lhs_ty.collapse_into(&rhs_ty), lhs.1 + rhs.1);
expr_ty.collapse_into(&field_ty),
variable_reference.meta + expression.1,
);
if !var.mutable { if let Some(named_var) = lhs.backing_var() {
if let Some(scope_var) = state.scope.variables.get(&named_var.1) {
if !scope_var.mutable {
state.ok::<_, Infallible>( state.ok::<_, Infallible>(
Err(ErrorKind::VariableNotMutable(variable_reference.get_name())), Err(ErrorKind::VariableNotMutable(named_var.1.clone())),
variable_reference.meta, lhs.1,
); );
} }
}
None
} else { } else {
state.ok::<_, Infallible>( state.ok::<_, Infallible>(Err(ErrorKind::InvalidSetExpression), lhs.1);
Err(ErrorKind::VariableNotDefined(variable_reference.get_name())),
variable_reference.meta,
);
None
} }
// TODO add error about variable mutability, need to check
// that the expression is based on a variable first though..
// if true {
// state.ok::<_, Infallible>(
// Err(ErrorKind::VariableNotMutable(variable_reference.get_name())),
// variable_reference.meta,
// );
// }
None
} }
StmtKind::Import(_) => todo!(), // TODO StmtKind::Import(_) => todo!(), // TODO
StmtKind::Expression(expression) => { StmtKind::Expression(expression) => {
@ -345,8 +344,8 @@ impl Block {
impl Expression { impl Expression {
fn typecheck( fn typecheck(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypecheckPassState,
hints: &TypeRefs, typerefs: &TypeRefs,
hint_t: Option<&TypeKind>, hint_t: Option<&TypeKind>,
) -> Result<TypeKind, ErrorKind> { ) -> Result<TypeKind, ErrorKind> {
match &mut self.0 { match &mut self.0 {
@ -363,11 +362,11 @@ impl Expression {
TypeKind::Vague(Vague::Unknown), TypeKind::Vague(Vague::Unknown),
var_ref.2, var_ref.2,
) )
.resolve_ref(hints); .resolve_ref(typerefs);
// Update typing to be more accurate // Update typing to be more accurate
var_ref.0 = state.or_else( var_ref.0 = state.or_else(
var_ref.0.resolve_ref(hints).collapse_into(&existing), var_ref.0.resolve_ref(typerefs).collapse_into(&existing),
TypeKind::Vague(Vague::Unknown), TypeKind::Vague(Vague::Unknown),
var_ref.2, var_ref.2,
); );
@ -381,15 +380,15 @@ impl Expression {
ExprKind::BinOp(op, lhs, rhs) => { ExprKind::BinOp(op, lhs, rhs) => {
// TODO make sure lhs and rhs can actually do this binary // TODO make sure lhs and rhs can actually do this binary
// operation once relevant // operation once relevant
let lhs_res = lhs.typecheck(state, &hints, None); let lhs_res = lhs.typecheck(state, &typerefs, None);
let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1); let lhs_type = state.or_else(lhs_res, TypeKind::Vague(Vague::Unknown), lhs.1);
let rhs_res = rhs.typecheck(state, &hints, Some(&lhs_type)); let rhs_res = rhs.typecheck(state, &typerefs, Some(&lhs_type));
let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1); let rhs_type = state.or_else(rhs_res, TypeKind::Vague(Vague::Unknown), rhs.1);
if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) { if let Some(collapsed) = state.ok(rhs_type.collapse_into(&rhs_type), self.1) {
// Try to coerce both sides again with collapsed type // Try to coerce both sides again with collapsed type
lhs.typecheck(state, &hints, Some(&collapsed)).ok(); lhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
rhs.typecheck(state, &hints, Some(&collapsed)).ok(); rhs.typecheck(state, &typerefs, Some(&collapsed)).ok();
} }
let both_t = lhs_type.collapse_into(&rhs_type)?; let both_t = lhs_type.collapse_into(&rhs_type)?;
@ -429,7 +428,7 @@ impl Expression {
function_call.parameters.iter_mut().zip(true_params_iter) function_call.parameters.iter_mut().zip(true_params_iter)
{ {
// Typecheck every param separately // Typecheck every param separately
let param_res = param.typecheck(state, &hints, Some(&true_param_t)); let param_res = param.typecheck(state, &typerefs, Some(&true_param_t));
let param_t = let param_t =
state.or_else(param_res, TypeKind::Vague(Vague::Unknown), param.1); state.or_else(param_res, TypeKind::Vague(Vague::Unknown), param.1);
state.ok(param_t.collapse_into(&true_param_t), param.1); state.ok(param_t.collapse_into(&true_param_t), param.1);
@ -439,29 +438,29 @@ impl Expression {
// return type // return type
let ret_t = f let ret_t = f
.ret .ret
.collapse_into(&function_call.return_type.resolve_ref(hints))?; .collapse_into(&function_call.return_type.resolve_ref(typerefs))?;
// Update typing to be more accurate // Update typing to be more accurate
function_call.return_type = ret_t.clone(); function_call.return_type = ret_t.clone();
Ok(ret_t.resolve_ref(hints)) Ok(ret_t.resolve_ref(typerefs))
} else { } else {
Ok(function_call.return_type.clone().resolve_ref(hints)) Ok(function_call.return_type.clone().resolve_ref(typerefs))
} }
} }
ExprKind::If(IfExpression(cond, lhs, rhs)) => { ExprKind::If(IfExpression(cond, lhs, rhs)) => {
let cond_res = cond.typecheck(state, &hints, Some(&TypeKind::Bool)); let cond_res = cond.typecheck(state, &typerefs, Some(&TypeKind::Bool));
let cond_t = state.or_else(cond_res, TypeKind::Vague(Vague::Unknown), cond.1); let cond_t = state.or_else(cond_res, TypeKind::Vague(Vague::Unknown), cond.1);
state.ok(cond_t.collapse_into(&TypeKind::Bool), cond.1); state.ok(cond_t.collapse_into(&TypeKind::Bool), cond.1);
// Typecheck then/else return types and make sure they are the // Typecheck then/else return types and make sure they are the
// same, if else exists. // same, if else exists.
let then_res = lhs.typecheck(state, &hints, hint_t); let then_res = lhs.typecheck(state, &typerefs, hint_t);
let (then_ret_kind, then_ret_t) = state.or_else( let (then_ret_kind, then_ret_t) = state.or_else(
then_res, then_res,
(ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)), (ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)),
lhs.meta, lhs.meta,
); );
let else_ret_t = if let Some(else_block) = rhs { let else_ret_t = if let Some(else_block) = rhs {
let res = else_block.typecheck(state, &hints, hint_t); let res = else_block.typecheck(state, &typerefs, hint_t);
let (else_ret_kind, else_ret_t) = state.or_else( let (else_ret_kind, else_ret_t) = state.or_else(
res, res,
(ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)), (ReturnKind::Soft, TypeKind::Vague(Vague::Unknown)),
@ -491,33 +490,32 @@ impl Expression {
if let Some(rhs) = rhs { if let Some(rhs) = rhs {
// If rhs existed, typecheck both sides to perform type // If rhs existed, typecheck both sides to perform type
// coercion. // coercion.
let lhs_res = lhs.typecheck(state, &hints, Some(&collapsed)); let lhs_res = lhs.typecheck(state, &typerefs, Some(&collapsed));
let rhs_res = rhs.typecheck(state, &hints, Some(&collapsed)); let rhs_res = rhs.typecheck(state, &typerefs, Some(&collapsed));
state.ok(lhs_res, lhs.meta); state.ok(lhs_res, lhs.meta);
state.ok(rhs_res, rhs.meta); state.ok(rhs_res, rhs.meta);
} }
Ok(collapsed) Ok(collapsed)
} }
ExprKind::Block(block) => match block.typecheck(state, &hints, hint_t) { ExprKind::Block(block) => match block.typecheck(state, &typerefs, hint_t) {
Ok((ReturnKind::Hard, _)) => Ok(TypeKind::Void), Ok((ReturnKind::Hard, _)) => Ok(TypeKind::Void),
Ok((_, ty)) => Ok(ty), Ok((_, ty)) => Ok(ty),
Err(e) => Err(e), Err(e) => Err(e),
}, },
ExprKind::ArrayIndex(expression, elem_ty, idx) => { ExprKind::Indexed(expression, elem_ty, _) => {
// Try to unwrap hint type from array if possible // Try to unwrap hint type from array if possible
let hint_t = hint_t.map(|t| match t { let hint_t = hint_t.map(|t| match t {
TypeKind::Array(type_kind, _) => &type_kind, TypeKind::Array(type_kind, _) => &type_kind,
_ => t, _ => t,
}); });
let expr_t = expression.typecheck(state, hints, hint_t)?; // TODO it could be possible to check length against constants..
if let TypeKind::Array(inferred_ty, len) = expr_t {
if len <= *idx { let expr_t = expression.typecheck(state, typerefs, hint_t)?;
return Err(ErrorKind::IndexOutOfBounds(*idx, len)); if let TypeKind::Array(inferred_ty, _) = expr_t {
}
let ty = state.or_else( let ty = state.or_else(
elem_ty.resolve_ref(hints).collapse_into(&inferred_ty), elem_ty.resolve_ref(typerefs).collapse_into(&inferred_ty),
TypeKind::Vague(Vague::Unknown), TypeKind::Vague(Vague::Unknown),
self.1, self.1,
); );
@ -537,7 +535,7 @@ impl Expression {
let mut expr_result = try_all( let mut expr_result = try_all(
expressions expressions
.iter_mut() .iter_mut()
.map(|e| e.typecheck(state, hints, hint_t)) .map(|e| e.typecheck(state, typerefs, hint_t))
.collect(), .collect(),
); );
match &mut expr_result { match &mut expr_result {
@ -564,12 +562,12 @@ impl Expression {
} }
} }
} }
ExprKind::StructIndex(expression, type_kind, field_name) => { ExprKind::Accessed(expression, type_kind, field_name) => {
// Resolve expected type // Resolve expected type
let expected_ty = type_kind.resolve_ref(hints); let expected_ty = type_kind.resolve_ref(typerefs);
// Typecheck expression // Typecheck expression
let expr_res = expression.typecheck(state, hints, Some(&expected_ty)); let expr_res = expression.typecheck(state, typerefs, Some(&expected_ty));
let expr_ty = let expr_ty =
state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), expression.1); state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), expression.1);
@ -615,7 +613,7 @@ impl Expression {
); );
// Typecheck the actual expression // Typecheck the actual expression
let expr_res = field_expr.typecheck(state, hints, Some(expected_ty)); let expr_res = field_expr.typecheck(state, typerefs, Some(expected_ty));
let expr_ty = let expr_ty =
state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), field_expr.1); state.or_else(expr_res, TypeKind::Vague(Vague::Unknown), field_expr.1);
@ -628,64 +626,6 @@ impl Expression {
} }
} }
impl IndexedVariableReference {
fn get_variable(
&self,
storage: &Storage<ScopeVariable>,
types: &Storage<TypeDefinitionKind>,
) -> Result<Option<ScopeVariable>, ErrorKind> {
match &self.kind {
IndexedVariableReferenceKind::Named(NamedVariableRef(_, name, _)) => {
Ok(storage.get(&name).cloned())
}
IndexedVariableReferenceKind::ArrayIndex(inner_ref, _) => {
if let Some(var) = inner_ref.get_variable(storage, types)? {
match &var.ty {
TypeKind::Array(inner_ty, _) => Ok(Some(ScopeVariable {
ty: *inner_ty.clone(),
mutable: var.mutable,
})),
_ => Err(ErrorKind::TriedIndexingNonArray(var.ty.clone())),
}
} else {
Ok(None)
}
}
IndexedVariableReferenceKind::StructIndex(var_ref, field_name) => {
if let Some(var) = var_ref.get_variable(storage, types)? {
match &var.ty {
TypeKind::CustomType(type_name) => {
if let Some(kind) = types.get(type_name) {
match &kind {
TypeDefinitionKind::Struct(struct_type) => {
if let Some(StructField(_, field_ty, _)) = struct_type
.0
.iter()
.find(|StructField(n, _, _)| n == field_name)
{
Ok(Some(ScopeVariable {
ty: field_ty.clone(),
mutable: var.mutable,
}))
} else {
Err(ErrorKind::NoSuchField(field_name.clone()))
}
}
}
} else {
Err(ErrorKind::NoSuchType(type_name.clone()))
}
}
_ => Err(ErrorKind::TriedAccessingNonStruct(var.ty.clone())),
}
} else {
Ok(None)
}
}
}
}
}
impl Literal { impl Literal {
/// Try to coerce this literal, ie. convert it to a more specific type in /// Try to coerce this literal, ie. convert it to a more specific type in
/// regards to the given hint if any. /// regards to the given hint if any.

View File

@ -10,9 +10,9 @@ use crate::{mir::TypeKind, util::try_all};
use super::{ use super::{
pass::{Pass, PassState}, pass::{Pass, PassState},
r#impl::pick_return,
typecheck::ErrorKind, typecheck::ErrorKind,
typerefs::{ScopeTypeRefs, TypeRef, TypeRefs}, typerefs::{ScopeTypeRefs, TypeRef, TypeRefs},
types::{pick_return, ReturnType},
Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module, Block, ExprKind, Expression, FunctionDefinition, FunctionDefinitionKind, IfExpression, Module,
ReturnKind, StmtKind, ReturnKind, StmtKind,
TypeKind::*, TypeKind::*,
@ -26,10 +26,13 @@ pub struct TypeInference<'t> {
pub refs: &'t TypeRefs, pub refs: &'t TypeRefs,
} }
type TypeInferencePassState<'st, 'sc> = PassState<'st, 'sc, (), ErrorKind>;
impl<'t> Pass for TypeInference<'t> { impl<'t> Pass for TypeInference<'t> {
type Data = ();
type TError = ErrorKind; type TError = ErrorKind;
fn module(&mut self, module: &mut Module, mut state: PassState<ErrorKind>) { fn module(&mut self, module: &mut Module, mut state: TypeInferencePassState) {
for function in &mut module.functions { for function in &mut module.functions {
let res = function.infer_types(&self.refs, &mut state.inner()); let res = function.infer_types(&self.refs, &mut state.inner());
state.ok(res, function.block_meta()); state.ok(res, function.block_meta());
@ -41,7 +44,7 @@ impl FunctionDefinition {
fn infer_types( fn infer_types(
&mut self, &mut self,
type_refs: &TypeRefs, type_refs: &TypeRefs,
state: &mut PassState<ErrorKind>, state: &mut TypeInferencePassState,
) -> Result<(), ErrorKind> { ) -> Result<(), ErrorKind> {
let scope_hints = ScopeTypeRefs::from(type_refs); let scope_hints = ScopeTypeRefs::from(type_refs);
for param in &self.parameters { for param in &self.parameters {
@ -74,7 +77,7 @@ impl FunctionDefinition {
impl Block { impl Block {
fn infer_types<'s>( fn infer_types<'s>(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypeInferencePassState,
outer_hints: &'s ScopeTypeRefs, outer_hints: &'s ScopeTypeRefs,
) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> { ) -> Result<(ReturnKind, TypeRef<'s>), ErrorKind> {
let mut state = state.inner(); let mut state = state.inner();
@ -104,18 +107,18 @@ impl Block {
var_ref.narrow(&expr_ty_ref); var_ref.narrow(&expr_ty_ref);
} }
} }
StmtKind::Set(var, expr) => { StmtKind::Set(lhs, rhs) => {
// Update this MIR type to its TypeRef // Infer hints for the expression itself
let var_ref = var.into_typeref(&inner_hints); let lhs_infer = lhs.infer_types(&mut state, &inner_hints);
let lhs_ref = state.ok(lhs_infer, rhs.1);
// Infer hints for the expression itself // Infer hints for the expression itself
let inferred = expr.infer_types(&mut state, &inner_hints); let rhs_infer = rhs.infer_types(&mut state, &inner_hints);
let expr_ty_ref = state.ok(inferred, expr.1); let rhs_ref = state.ok(rhs_infer, rhs.1);
// Try to narrow the variable type declaration with the // Try to narrow the lhs with rhs
// expression if let (Some(mut lhs_ref), Some(rhs_ref)) = (lhs_ref, rhs_ref) {
if let (Some((_, mut var_ref)), Some(expr_ty_ref)) = (var_ref, expr_ty_ref) { lhs_ref.narrow(&rhs_ref);
var_ref.narrow(&expr_ty_ref);
} }
} }
StmtKind::Import(_) => todo!(), StmtKind::Import(_) => todo!(),
@ -150,7 +153,7 @@ impl Block {
impl Expression { impl Expression {
fn infer_types<'s>( fn infer_types<'s>(
&mut self, &mut self,
state: &mut PassState<ErrorKind>, state: &mut TypeInferencePassState,
type_refs: &'s ScopeTypeRefs<'s>, type_refs: &'s ScopeTypeRefs<'s>,
) -> Result<TypeRef<'s>, ErrorKind> { ) -> Result<TypeRef<'s>, ErrorKind> {
match &mut self.0 { match &mut self.0 {
@ -249,7 +252,7 @@ impl Expression {
ReturnKind::Soft => Ok(block_ref.1), ReturnKind::Soft => Ok(block_ref.1),
} }
} }
ExprKind::ArrayIndex(expression, index_ty, _) => { ExprKind::Indexed(expression, index_ty, _) => {
let expr_ty = expression.infer_types(state, type_refs)?; let expr_ty = expression.infer_types(state, type_refs)?;
// Check that the resolved type is at least an array, no // Check that the resolved type is at least an array, no
@ -302,7 +305,7 @@ impl Expression {
} }
} }
} }
ExprKind::StructIndex(expression, type_kind, field_name) => { ExprKind::Accessed(expression, type_kind, field_name) => {
let expr_ty = expression.infer_types(state, type_refs)?; let expr_ty = expression.infer_types(state, type_refs)?;
// Check that the resolved type is at least a struct, no // Check that the resolved type is at least a struct, no