Add typechecking

This commit is contained in:
Sofia 2025-07-07 23:03:21 +03:00
parent 12dc457b99
commit a366d22470
8 changed files with 460 additions and 169 deletions

View File

@ -1,5 +1,5 @@
// Main
fn main() {
fn main() -> i32 {
return fibonacci(3);
}

View File

@ -7,6 +7,7 @@ fn main() {
let fibonacci = FunctionDefinition {
name: fibonacci_name.clone(),
return_type: TypeKind::I32,
parameters: vec![(fibonacci_n.clone(), TypeKind::I32)],
kind: FunctionDefinitionKind::Local(
Block {
@ -126,6 +127,7 @@ fn main() {
let main = FunctionDefinition {
name: "main".to_owned(),
return_type: TypeKind::I32,
parameters: vec![],
kind: FunctionDefinitionKind::Local(
Block {

View File

@ -12,56 +12,34 @@ pub enum InferredType {
Static(mir::TypeKind),
OneOf(Vec<InferredType>),
Void,
Unknown,
}
impl InferredType {
fn collapse(&self, scope: &VirtualScope) -> mir::TypeKind {
fn collapse(&self) -> mir::TypeKind {
match self {
InferredType::FromVariable(name) => {
if let Some(inferred) = scope.get_var(name) {
inferred.collapse(scope)
} else {
mir::TypeKind::Vague(mir::VagueType::Unknown)
}
}
InferredType::FunctionReturn(name) => {
if let Some(type_kind) = scope.get_return_type(name) {
type_kind.clone()
} else {
mir::TypeKind::Vague(mir::VagueType::Unknown)
}
}
InferredType::FromVariable(_) => mir::TypeKind::Vague(mir::VagueType::Unknown),
InferredType::FunctionReturn(_) => mir::TypeKind::Vague(mir::VagueType::Unknown),
InferredType::Static(type_kind) => type_kind.clone(),
InferredType::OneOf(inferred_types) => {
let list: Vec<mir::TypeKind> =
inferred_types.iter().map(|t| t.collapse(scope)).collect();
inferred_types.iter().map(|t| t.collapse()).collect();
if let Some(first) = list.first() {
if list.iter().all(|i| i == first) {
first.clone().into()
} else {
// IntoMIRError::ConflictingType(self.get_range())
mir::TypeKind::Void
mir::TypeKind::Vague(mir::VagueType::Unknown)
}
} else {
mir::TypeKind::Void
}
}
InferredType::Void => mir::TypeKind::Void,
InferredType::Unknown => mir::TypeKind::Vague(mir::VagueType::Unknown),
}
}
}
pub struct VirtualVariable {
name: String,
inferred: InferredType,
}
pub struct VirtualFunctionSignature {
name: String,
return_type: mir::TypeKind,
parameter_types: Vec<mir::TypeKind>,
}
pub struct VirtualStorage<T> {
storage: HashMap<String, Vec<T>>,
}
@ -88,73 +66,8 @@ impl<T> Default for VirtualStorage<T> {
}
}
pub struct VirtualScope {
variables: VirtualStorage<VirtualVariable>,
functions: VirtualStorage<VirtualFunctionSignature>,
}
impl VirtualScope {
pub fn set_var(&mut self, variable: VirtualVariable) {
self.variables.set(variable.name.clone(), variable);
}
pub fn set_fun(&mut self, function: VirtualFunctionSignature) {
self.functions.set(function.name.clone(), function)
}
pub fn get_var(&self, name: &String) -> Option<InferredType> {
self.variables.get(name).and_then(|v| {
if v.len() > 1 {
Some(InferredType::OneOf(
v.iter().map(|v| v.inferred.clone()).collect(),
))
} else if let Some(v) = v.first() {
Some(v.inferred.clone())
} else {
None
}
})
}
pub fn get_return_type(&self, name: &String) -> Option<mir::TypeKind> {
self.functions.get(name).and_then(|v| {
if v.len() > 1 {
Some(mir::TypeKind::Vague(mir::VagueType::Unknown))
} else if let Some(v) = v.first() {
Some(v.return_type.clone())
} else {
None
}
})
}
}
impl Default for VirtualScope {
fn default() -> Self {
Self {
variables: Default::default(),
functions: Default::default(),
}
}
}
impl ast::Module {
pub fn process(&self) -> mir::Module {
let mut scope = VirtualScope::default();
for stmt in &self.top_level_statements {
match stmt {
FunctionDefinition(ast::FunctionDefinition(signature, _, _)) => {
scope.set_fun(VirtualFunctionSignature {
name: signature.name.clone(),
return_type: signature.return_type.into(),
parameter_types: signature.args.iter().map(|p| p.1.into()).collect(),
});
}
_ => {}
}
}
let mut imports = Vec::new();
let mut functions = Vec::new();
@ -167,13 +80,6 @@ impl ast::Module {
}
}
FunctionDefinition(ast::FunctionDefinition(signature, block, range)) => {
for (name, ptype) in &signature.args {
scope.set_var(VirtualVariable {
name: name.clone(),
inferred: InferredType::Static((*ptype).into()),
});
}
let def = mir::FunctionDefinition {
name: signature.name.clone(),
return_type: signature
@ -186,10 +92,7 @@ impl ast::Module {
.cloned()
.map(|p| (p.0, p.1.into()))
.collect(),
kind: mir::FunctionDefinitionKind::Local(
block.into_mir(&mut scope),
(*range).into(),
),
kind: mir::FunctionDefinitionKind::Local(block.into_mir(), (*range).into()),
};
functions.push(def);
}
@ -207,41 +110,33 @@ impl ast::Module {
}
impl ast::Block {
pub fn into_mir(&self, scope: &mut VirtualScope) -> mir::Block {
pub fn into_mir(&self) -> mir::Block {
let mut mir_statements = Vec::new();
for statement in &self.0 {
let (kind, range) = match statement {
ast::BlockLevelStatement::Let(s_let) => {
let t = s_let.1.infer_return_type().collapse(scope);
let t = s_let.1.infer_return_type().collapse();
let inferred = InferredType::Static(t.clone());
scope.set_var(VirtualVariable {
name: s_let.0.clone(),
inferred,
});
(
mir::StmtKind::Let(
mir::VariableReference(t, s_let.0.clone(), s_let.2.into()),
s_let.1.process(scope),
s_let.1.process(),
),
s_let.2,
)
}
ast::BlockLevelStatement::Import(_) => todo!(),
ast::BlockLevelStatement::Expression(e) => {
(StmtKind::Expression(e.process(scope)), e.1)
}
ast::BlockLevelStatement::Return(_, e) => {
(StmtKind::Expression(e.process(scope)), e.1)
}
ast::BlockLevelStatement::Expression(e) => (StmtKind::Expression(e.process()), e.1),
ast::BlockLevelStatement::Return(_, e) => (StmtKind::Expression(e.process()), e.1),
};
mir_statements.push(mir::Statement(kind, range.into()));
}
let return_expression = if let Some(r) = &self.1 {
Some((r.0.into(), Box::new(r.1.process(scope))))
Some((r.0.into(), Box::new(r.1.process())))
} else {
None
};
@ -271,40 +166,32 @@ impl From<ast::ReturnType> for mir::ReturnKind {
}
impl ast::Expression {
fn process(&self, scope: &mut VirtualScope) -> mir::Expression {
fn process(&self) -> mir::Expression {
let kind = match &self.0 {
ast::ExpressionKind::VariableName(name) => mir::ExprKind::Variable(VariableReference(
if let Some(ty) = scope.get_var(name) {
ty.collapse(scope)
} else {
mir::TypeKind::Vague(mir::VagueType::Unknown)
},
mir::TypeKind::Vague(mir::VagueType::Unknown),
name.clone(),
self.1.into(),
)),
ast::ExpressionKind::Literal(literal) => mir::ExprKind::Literal(literal.mir()),
ast::ExpressionKind::Binop(binary_operator, lhs, rhs) => mir::ExprKind::BinOp(
binary_operator.mir(),
Box::new(lhs.process(scope)),
Box::new(rhs.process(scope)),
Box::new(lhs.process()),
Box::new(rhs.process()),
),
ast::ExpressionKind::FunctionCall(fn_call_expr) => {
mir::ExprKind::FunctionCall(mir::FunctionCall {
name: fn_call_expr.0.clone(),
return_type: if let Some(r_type) = scope.get_return_type(&fn_call_expr.0) {
r_type
} else {
mir::TypeKind::Vague(mir::VagueType::Unknown)
},
parameters: fn_call_expr.1.iter().map(|e| e.process(scope)).collect(),
return_type: mir::TypeKind::Vague(mir::VagueType::Unknown),
parameters: fn_call_expr.1.iter().map(|e| e.process()).collect(),
})
}
ast::ExpressionKind::BlockExpr(block) => mir::ExprKind::Block(block.into_mir(scope)),
ast::ExpressionKind::BlockExpr(block) => mir::ExprKind::Block(block.into_mir()),
ast::ExpressionKind::IfExpr(if_expression) => {
let cond = if_expression.0.process(scope);
let then_block = if_expression.1.into_mir(scope);
let cond = if_expression.0.process();
let then_block = if_expression.1.into_mir();
let else_block = if let Some(el) = &if_expression.2 {
Some(el.into_mir(scope))
Some(el.into_mir())
} else {
None
};
@ -371,12 +258,3 @@ impl From<ast::Type> for mir::TypeKind {
value.0.into()
}
}
impl From<Option<ast::Type>> for mir::TypeKind {
fn from(value: Option<ast::Type>) -> Self {
match value {
Some(v) => v.into(),
None => mir::TypeKind::Void,
}
}
}

View File

@ -7,6 +7,7 @@ mod codegen;
mod lexer;
pub mod mir;
mod token_stream;
mod util;
// TODO:
// 1. Make it so that TopLevelStatement can only be import or function def
@ -20,6 +21,8 @@ pub enum ReidError {
LexerError(#[from] lexer::Error),
#[error(transparent)]
ParserError(#[from] token_stream::Error),
#[error("Errors during typecheck: {0:?}")]
TypeCheckErrors(Vec<mir::typecheck::Error>),
// #[error(transparent)]
// CodegenError(#[from] codegen::Error),
}
@ -49,6 +52,14 @@ pub fn compile(source: &str) -> Result<String, ReidError> {
dbg!(&mir_module);
let state = mir_module.typecheck();
dbg!(&state);
if !state.errors.is_empty() {
return Err(ReidError::TypeCheckErrors(state.errors));
}
dbg!(&mir_module);
let mut context = Context::new();
let codegen_module = mir_module.codegen(&mut context);

View File

@ -3,37 +3,51 @@
/// type-checked beforehand.
use crate::token_stream::TokenRange;
pub mod typecheck;
pub mod types;
#[derive(Debug, Clone, Copy)]
#[derive(Default, Debug, Clone, Copy)]
pub struct Metadata {
pub range: TokenRange,
}
impl std::fmt::Display for Metadata {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.range)
}
}
impl std::ops::Add for Metadata {
type Output = Metadata;
fn add(self, rhs: Self) -> Self::Output {
Metadata {
range: self.range + rhs.range,
}
}
}
impl From<TokenRange> for Metadata {
fn from(value: TokenRange) -> Self {
Metadata { range: value }
}
}
impl Default for Metadata {
fn default() -> Self {
Metadata {
range: Default::default(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum TypeKind {
#[error("i32")]
I32,
#[error("i16")]
I16,
#[error("void")]
Void,
Vague(VagueType),
#[error(transparent)]
Vague(#[from] VagueType),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum VagueType {
#[error("Unknown")]
Unknown,
}
@ -126,6 +140,22 @@ pub enum FunctionDefinitionKind {
Extern,
}
impl FunctionDefinition {
fn block_meta(&self) -> Metadata {
match &self.kind {
FunctionDefinitionKind::Local(block, _) => block.meta,
FunctionDefinitionKind::Extern => Metadata::default(),
}
}
fn signature(&self) -> Metadata {
match &self.kind {
FunctionDefinitionKind::Local(_, metadata) => *metadata,
FunctionDefinitionKind::Extern => Metadata::default(),
}
}
}
#[derive(Debug)]
pub struct Block {
/// List of non-returning statements

362
reid/src/mir/typecheck.rs Normal file
View File

@ -0,0 +1,362 @@
use std::{collections::HashMap, convert::Infallible, iter};
/// This module contains code relevant to doing a type checking pass on the MIR.
use crate::{mir::*, util::try_all};
use TypeKind::*;
use VagueType::*;
#[derive(Debug, Clone)]
pub struct Error {
metadata: Metadata,
kind: ErrorKind,
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Error at {}: {}", self.metadata, self.kind)
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.kind.source()
}
}
#[derive(thiserror::Error, Debug, Clone)]
pub enum ErrorKind {
#[error("NULL error, should never occur!")]
Null,
#[error("Type is vague: {0}")]
TypeIsVague(VagueType),
#[error("Types {0} and {1} are incompatible")]
TypesIncompatible(TypeKind, TypeKind),
#[error("Variable not defined: {0}")]
VariableNotDefined(String),
#[error("Function not defined: {0}")]
FunctionNotDefined(String),
#[error("Type is vague: {0}")]
ReturnTypeMismatch(TypeKind, TypeKind),
}
#[derive(Clone)]
pub struct TypeStorage<T>(HashMap<String, T>);
impl<T> Default for TypeStorage<T> {
fn default() -> Self {
Self(Default::default())
}
}
impl<T: Collapsable> TypeStorage<T> {
fn set(&mut self, key: String, value: T) -> Result<T, ErrorKind> {
if let Some(inner) = self.0.get(&key) {
match value.collapse_into(inner) {
Ok(collapsed) => {
self.0.insert(key, collapsed.clone());
Ok(collapsed)
}
Err(e) => Err(e),
}
} else {
self.0.insert(key, value.clone());
Ok(value)
}
}
fn get(&self, key: &String) -> Option<&T> {
self.0.get(key)
}
}
#[derive(Debug)]
pub struct State {
pub errors: Vec<Error>,
}
impl State {
fn new() -> State {
State {
errors: Default::default(),
}
}
fn or_else<T: Into<Metadata> + Clone + Copy>(
&mut self,
result: Result<TypeKind, ErrorKind>,
default: TypeKind,
meta: T,
) -> TypeKind {
match result {
Ok(t) => t,
Err(e) => {
self.errors.push(Error {
metadata: meta.into(),
kind: e,
});
default
}
}
}
fn ok<T: Into<Metadata> + Clone + Copy, U>(&mut self, result: Result<U, ErrorKind>, meta: T) {
if let Err(e) = result {
self.errors.push(Error {
metadata: meta.into(),
kind: e,
});
}
}
}
#[derive(Clone, Default)]
pub struct Scope {
function_returns: TypeStorage<ScopeFunction>,
variables: TypeStorage<TypeKind>,
}
#[derive(Clone)]
pub struct ScopeFunction {
ret: TypeKind,
params: Vec<TypeKind>,
}
impl Scope {
fn inner(&self) -> Scope {
Scope {
function_returns: self.function_returns.clone(),
variables: self.variables.clone(),
}
}
}
#[derive(Clone)]
pub enum Inferred {
Type(TypeKind),
Unresolved(u32),
}
impl Module {
pub fn typecheck(&self) -> State {
let mut state = State::new();
let mut scope = Scope::default();
for function in &self.functions {
let r = scope.function_returns.set(
function.name.clone(),
ScopeFunction {
ret: function.return_type,
params: function.parameters.iter().map(|v| v.1).collect(),
},
);
}
for function in &self.functions {
let res = function.typecheck(&mut state, &mut scope);
state.ok(res, function.block_meta());
}
state
}
}
impl FunctionDefinition {
fn typecheck(&self, state: &mut State, scope: &mut Scope) -> Result<TypeKind, ErrorKind> {
for param in &self.parameters {
let param_t = state.or_else(param.1.assert_known(), Vague(Unknown), self.signature());
state.ok(
scope.variables.set(param.0.clone(), param_t),
self.signature(),
);
}
let return_type = self.return_type.clone();
dbg!(&return_type);
let inferred = match &self.kind {
FunctionDefinitionKind::Local(block, _) => block.typecheck(state, scope),
FunctionDefinitionKind::Extern => Ok(Vague(Unknown)),
};
dbg!(&inferred);
match inferred {
Ok(t) => try_collapse(&return_type, &t)
.or(Err(ErrorKind::ReturnTypeMismatch(return_type, t))),
Err(e) => Ok(state.or_else(Err(e), return_type, self.block_meta())),
}
}
}
impl Block {
fn typecheck(&self, state: &mut State, scope: &mut Scope) -> Result<TypeKind, ErrorKind> {
let mut scope = scope.inner();
for statement in &self.statements {
match &statement.0 {
StmtKind::Let(variable_reference, expression) => {
let res = expression.typecheck(state, &mut scope);
// If expression resolution itself was erronous, resolve as
// Unknown.
let res = state.or_else(res, Vague(Unknown), expression.1);
// Make sure the expression and variable type really is the same
state.ok(
res.collapse_into(&variable_reference.0),
variable_reference.2 + expression.1,
);
// TODO make sure expression/variable type is NOT vague anymore
// Variable might already be defined, note error
state.ok(
scope
.variables
.set(variable_reference.1.clone(), variable_reference.0),
variable_reference.2,
);
}
StmtKind::Import(_) => todo!(),
StmtKind::Expression(expression) => {
let res = expression.typecheck(state, &mut scope);
state.ok(res, expression.1);
}
}
}
if let Some((_, expr)) = &self.return_expression {
let res = expr.typecheck(state, &mut scope);
Ok(state.or_else(res, Vague(Unknown), expr.1))
} else {
Ok(Void)
}
}
}
impl Expression {
fn typecheck(&self, state: &mut State, scope: &mut Scope) -> Result<TypeKind, ErrorKind> {
match &self.0 {
ExprKind::Variable(var_ref) => {
let existing = state.or_else(
scope
.variables
.get(&var_ref.1)
.copied()
.ok_or(ErrorKind::VariableNotDefined(var_ref.1.clone())),
Vague(Unknown),
var_ref.2,
);
Ok(state.or_else(
var_ref.0.collapse_into(&existing),
Vague(Unknown),
var_ref.2,
))
}
ExprKind::Literal(literal) => Ok(literal.as_type()),
ExprKind::BinOp(_, lhs, rhs) => {
// TODO make sure lhs and rhs can actually do this binary
// operation once relevant
let lhs_res = lhs.typecheck(state, scope);
let rhs_res = rhs.typecheck(state, scope);
let lhs_type = state.or_else(lhs_res, Vague(Unknown), lhs.1);
let rhs_type = state.or_else(rhs_res, Vague(Unknown), rhs.1);
lhs_type.collapse_into(&rhs_type)
}
ExprKind::FunctionCall(function_call) => {
let true_function = scope
.function_returns
.get(&function_call.name)
.cloned()
.ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone()));
if let Ok(f) = true_function {
if function_call.parameters.len() != f.params.len() {
state.ok::<_, Infallible>(Err(ErrorKind::Null), self.1);
}
let true_params_iter = f.params.into_iter().chain(iter::repeat(Vague(Unknown)));
for (param, true_param_t) in
function_call.parameters.iter().zip(true_params_iter)
{
let param_res = param.typecheck(state, scope);
let param_t = state.or_else(param_res, Vague(Unknown), param.1);
state.ok(param_t.collapse_into(&true_param_t), param.1);
}
// Make sure function return type is the same as the claimed
// return type
// TODO: Set return type here actually
try_collapse(&f.ret, &function_call.return_type)
} else {
Ok(function_call.return_type)
}
}
ExprKind::If(IfExpression(cond, lhs, rhs)) => {
// TODO make sure cond_res is Boolean here
let cond_res = cond.typecheck(state, scope);
state.ok(cond_res, cond.1);
let lhs_res = lhs.typecheck(state, scope);
let lhs_type = state.or_else(lhs_res, Vague(Unknown), lhs.meta);
let rhs_type = if let Some(rhs) = rhs {
let res = rhs.typecheck(state, scope);
state.or_else(res, Vague(Unknown), rhs.meta)
} else {
Vague(Unknown)
};
lhs_type.collapse_into(&rhs_type)
}
ExprKind::Block(block) => block.typecheck(state, scope),
}
}
}
impl TypeKind {
fn assert_known(&self) -> Result<TypeKind, ErrorKind> {
if let Vague(vague) = self {
Err(ErrorKind::TypeIsVague(*vague))
} else {
Ok(*self)
}
}
}
fn try_collapse(lhs: &TypeKind, rhs: &TypeKind) -> Result<TypeKind, ErrorKind> {
lhs.collapse_into(rhs)
.or(rhs.collapse_into(lhs))
.or(Err(ErrorKind::TypesIncompatible(*lhs, *rhs)))
}
trait Collapsable: Sized + Clone {
fn collapse_into(&self, other: &Self) -> Result<Self, ErrorKind>;
}
impl Collapsable for TypeKind {
fn collapse_into(&self, other: &TypeKind) -> Result<TypeKind, ErrorKind> {
if self == other {
return Ok(self.clone());
}
match (self, other) {
(Vague(Unknown), other) | (other, Vague(Unknown)) => Ok(other.clone()),
_ => Err(ErrorKind::TypesIncompatible(*self, *other)),
}
}
}
impl Collapsable for ScopeFunction {
fn collapse_into(&self, other: &ScopeFunction) -> Result<ScopeFunction, ErrorKind> {
Ok(ScopeFunction {
ret: self.ret.collapse_into(&other.ret)?,
params: try_all(
self.params
.iter()
.zip(&other.params)
.map(|(p1, p2)| p1.collapse_into(&p2))
.collect(),
)
.map_err(|e| e.first().unwrap().clone())?,
})
}
}

View File

@ -156,7 +156,7 @@ impl Drop for TokenStream<'_, '_> {
}
}
#[derive(Clone, Copy)]
#[derive(Default, Clone, Copy)]
pub struct TokenRange {
pub start: usize,
pub end: usize,
@ -168,15 +168,6 @@ impl std::fmt::Debug for TokenRange {
}
}
impl Default for TokenRange {
fn default() -> Self {
Self {
start: Default::default(),
end: Default::default(),
}
}
}
impl std::ops::Add for TokenRange {
type Output = TokenRange;

17
reid/src/util.rs Normal file
View File

@ -0,0 +1,17 @@
pub fn try_all<U, E>(list: Vec<Result<U, E>>) -> Result<Vec<U>, Vec<E>> {
let mut successes = Vec::with_capacity(list.len());
let mut failures = Vec::with_capacity(list.len());
for item in list {
match item {
Ok(s) => successes.push(s),
Err(e) => failures.push(e),
}
}
if failures.len() > 0 {
Err(failures)
} else {
Ok(successes)
}
}