diff --git a/reid/examples/reid/fibonacci.reid b/reid/examples/reid/fibonacci.reid index 30160b7..b287c6f 100644 --- a/reid/examples/reid/fibonacci.reid +++ b/reid/examples/reid/fibonacci.reid @@ -1,5 +1,5 @@ // Main -fn main() { +fn main() -> i32 { return fibonacci(3); } diff --git a/reid/examples/testcodegen.rs b/reid/examples/testcodegen.rs index 9dc3c51..940edf1 100644 --- a/reid/examples/testcodegen.rs +++ b/reid/examples/testcodegen.rs @@ -7,6 +7,7 @@ fn main() { let fibonacci = FunctionDefinition { name: fibonacci_name.clone(), + return_type: TypeKind::I32, parameters: vec![(fibonacci_n.clone(), TypeKind::I32)], kind: FunctionDefinitionKind::Local( Block { @@ -126,6 +127,7 @@ fn main() { let main = FunctionDefinition { name: "main".to_owned(), + return_type: TypeKind::I32, parameters: vec![], kind: FunctionDefinitionKind::Local( Block { diff --git a/reid/src/ast/process.rs b/reid/src/ast/process.rs index bf6ae42..a963274 100644 --- a/reid/src/ast/process.rs +++ b/reid/src/ast/process.rs @@ -12,56 +12,34 @@ pub enum InferredType { Static(mir::TypeKind), OneOf(Vec), Void, + Unknown, } impl InferredType { - fn collapse(&self, scope: &VirtualScope) -> mir::TypeKind { + fn collapse(&self) -> mir::TypeKind { match self { - InferredType::FromVariable(name) => { - if let Some(inferred) = scope.get_var(name) { - inferred.collapse(scope) - } else { - mir::TypeKind::Vague(mir::VagueType::Unknown) - } - } - InferredType::FunctionReturn(name) => { - if let Some(type_kind) = scope.get_return_type(name) { - type_kind.clone() - } else { - mir::TypeKind::Vague(mir::VagueType::Unknown) - } - } + InferredType::FromVariable(_) => mir::TypeKind::Vague(mir::VagueType::Unknown), + InferredType::FunctionReturn(_) => mir::TypeKind::Vague(mir::VagueType::Unknown), InferredType::Static(type_kind) => type_kind.clone(), InferredType::OneOf(inferred_types) => { let list: Vec = - inferred_types.iter().map(|t| t.collapse(scope)).collect(); + inferred_types.iter().map(|t| t.collapse()).collect(); if let Some(first) = list.first() { if list.iter().all(|i| i == first) { first.clone().into() } else { - // IntoMIRError::ConflictingType(self.get_range()) - mir::TypeKind::Void + mir::TypeKind::Vague(mir::VagueType::Unknown) } } else { mir::TypeKind::Void } } InferredType::Void => mir::TypeKind::Void, + InferredType::Unknown => mir::TypeKind::Vague(mir::VagueType::Unknown), } } } -pub struct VirtualVariable { - name: String, - inferred: InferredType, -} - -pub struct VirtualFunctionSignature { - name: String, - return_type: mir::TypeKind, - parameter_types: Vec, -} - pub struct VirtualStorage { storage: HashMap>, } @@ -88,73 +66,8 @@ impl Default for VirtualStorage { } } -pub struct VirtualScope { - variables: VirtualStorage, - functions: VirtualStorage, -} - -impl VirtualScope { - pub fn set_var(&mut self, variable: VirtualVariable) { - self.variables.set(variable.name.clone(), variable); - } - - pub fn set_fun(&mut self, function: VirtualFunctionSignature) { - self.functions.set(function.name.clone(), function) - } - - pub fn get_var(&self, name: &String) -> Option { - self.variables.get(name).and_then(|v| { - if v.len() > 1 { - Some(InferredType::OneOf( - v.iter().map(|v| v.inferred.clone()).collect(), - )) - } else if let Some(v) = v.first() { - Some(v.inferred.clone()) - } else { - None - } - }) - } - - pub fn get_return_type(&self, name: &String) -> Option { - self.functions.get(name).and_then(|v| { - if v.len() > 1 { - Some(mir::TypeKind::Vague(mir::VagueType::Unknown)) - } else if let Some(v) = v.first() { - Some(v.return_type.clone()) - } else { - None - } - }) - } -} - -impl Default for VirtualScope { - fn default() -> Self { - Self { - variables: Default::default(), - functions: Default::default(), - } - } -} - impl ast::Module { pub fn process(&self) -> mir::Module { - let mut scope = VirtualScope::default(); - - for stmt in &self.top_level_statements { - match stmt { - FunctionDefinition(ast::FunctionDefinition(signature, _, _)) => { - scope.set_fun(VirtualFunctionSignature { - name: signature.name.clone(), - return_type: signature.return_type.into(), - parameter_types: signature.args.iter().map(|p| p.1.into()).collect(), - }); - } - _ => {} - } - } - let mut imports = Vec::new(); let mut functions = Vec::new(); @@ -167,13 +80,6 @@ impl ast::Module { } } FunctionDefinition(ast::FunctionDefinition(signature, block, range)) => { - for (name, ptype) in &signature.args { - scope.set_var(VirtualVariable { - name: name.clone(), - inferred: InferredType::Static((*ptype).into()), - }); - } - let def = mir::FunctionDefinition { name: signature.name.clone(), return_type: signature @@ -186,10 +92,7 @@ impl ast::Module { .cloned() .map(|p| (p.0, p.1.into())) .collect(), - kind: mir::FunctionDefinitionKind::Local( - block.into_mir(&mut scope), - (*range).into(), - ), + kind: mir::FunctionDefinitionKind::Local(block.into_mir(), (*range).into()), }; functions.push(def); } @@ -207,41 +110,33 @@ impl ast::Module { } impl ast::Block { - pub fn into_mir(&self, scope: &mut VirtualScope) -> mir::Block { + pub fn into_mir(&self) -> mir::Block { let mut mir_statements = Vec::new(); for statement in &self.0 { let (kind, range) = match statement { ast::BlockLevelStatement::Let(s_let) => { - let t = s_let.1.infer_return_type().collapse(scope); + let t = s_let.1.infer_return_type().collapse(); let inferred = InferredType::Static(t.clone()); - scope.set_var(VirtualVariable { - name: s_let.0.clone(), - inferred, - }); ( mir::StmtKind::Let( mir::VariableReference(t, s_let.0.clone(), s_let.2.into()), - s_let.1.process(scope), + s_let.1.process(), ), s_let.2, ) } ast::BlockLevelStatement::Import(_) => todo!(), - ast::BlockLevelStatement::Expression(e) => { - (StmtKind::Expression(e.process(scope)), e.1) - } - ast::BlockLevelStatement::Return(_, e) => { - (StmtKind::Expression(e.process(scope)), e.1) - } + ast::BlockLevelStatement::Expression(e) => (StmtKind::Expression(e.process()), e.1), + ast::BlockLevelStatement::Return(_, e) => (StmtKind::Expression(e.process()), e.1), }; mir_statements.push(mir::Statement(kind, range.into())); } let return_expression = if let Some(r) = &self.1 { - Some((r.0.into(), Box::new(r.1.process(scope)))) + Some((r.0.into(), Box::new(r.1.process()))) } else { None }; @@ -271,40 +166,32 @@ impl From for mir::ReturnKind { } impl ast::Expression { - fn process(&self, scope: &mut VirtualScope) -> mir::Expression { + fn process(&self) -> mir::Expression { let kind = match &self.0 { ast::ExpressionKind::VariableName(name) => mir::ExprKind::Variable(VariableReference( - if let Some(ty) = scope.get_var(name) { - ty.collapse(scope) - } else { - mir::TypeKind::Vague(mir::VagueType::Unknown) - }, + mir::TypeKind::Vague(mir::VagueType::Unknown), name.clone(), self.1.into(), )), ast::ExpressionKind::Literal(literal) => mir::ExprKind::Literal(literal.mir()), ast::ExpressionKind::Binop(binary_operator, lhs, rhs) => mir::ExprKind::BinOp( binary_operator.mir(), - Box::new(lhs.process(scope)), - Box::new(rhs.process(scope)), + Box::new(lhs.process()), + Box::new(rhs.process()), ), ast::ExpressionKind::FunctionCall(fn_call_expr) => { mir::ExprKind::FunctionCall(mir::FunctionCall { name: fn_call_expr.0.clone(), - return_type: if let Some(r_type) = scope.get_return_type(&fn_call_expr.0) { - r_type - } else { - mir::TypeKind::Vague(mir::VagueType::Unknown) - }, - parameters: fn_call_expr.1.iter().map(|e| e.process(scope)).collect(), + return_type: mir::TypeKind::Vague(mir::VagueType::Unknown), + parameters: fn_call_expr.1.iter().map(|e| e.process()).collect(), }) } - ast::ExpressionKind::BlockExpr(block) => mir::ExprKind::Block(block.into_mir(scope)), + ast::ExpressionKind::BlockExpr(block) => mir::ExprKind::Block(block.into_mir()), ast::ExpressionKind::IfExpr(if_expression) => { - let cond = if_expression.0.process(scope); - let then_block = if_expression.1.into_mir(scope); + let cond = if_expression.0.process(); + let then_block = if_expression.1.into_mir(); let else_block = if let Some(el) = &if_expression.2 { - Some(el.into_mir(scope)) + Some(el.into_mir()) } else { None }; @@ -371,12 +258,3 @@ impl From for mir::TypeKind { value.0.into() } } - -impl From> for mir::TypeKind { - fn from(value: Option) -> Self { - match value { - Some(v) => v.into(), - None => mir::TypeKind::Void, - } - } -} diff --git a/reid/src/lib.rs b/reid/src/lib.rs index 95ec4d4..167bd2d 100644 --- a/reid/src/lib.rs +++ b/reid/src/lib.rs @@ -7,6 +7,7 @@ mod codegen; mod lexer; pub mod mir; mod token_stream; +mod util; // TODO: // 1. Make it so that TopLevelStatement can only be import or function def @@ -20,6 +21,8 @@ pub enum ReidError { LexerError(#[from] lexer::Error), #[error(transparent)] ParserError(#[from] token_stream::Error), + #[error("Errors during typecheck: {0:?}")] + TypeCheckErrors(Vec), // #[error(transparent)] // CodegenError(#[from] codegen::Error), } @@ -49,6 +52,14 @@ pub fn compile(source: &str) -> Result { dbg!(&mir_module); + let state = mir_module.typecheck(); + dbg!(&state); + if !state.errors.is_empty() { + return Err(ReidError::TypeCheckErrors(state.errors)); + } + + dbg!(&mir_module); + let mut context = Context::new(); let codegen_module = mir_module.codegen(&mut context); diff --git a/reid/src/mir/mod.rs b/reid/src/mir/mod.rs index 4b11592..8af706d 100644 --- a/reid/src/mir/mod.rs +++ b/reid/src/mir/mod.rs @@ -3,37 +3,51 @@ /// type-checked beforehand. use crate::token_stream::TokenRange; +pub mod typecheck; pub mod types; -#[derive(Debug, Clone, Copy)] +#[derive(Default, Debug, Clone, Copy)] pub struct Metadata { pub range: TokenRange, } +impl std::fmt::Display for Metadata { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.range) + } +} + +impl std::ops::Add for Metadata { + type Output = Metadata; + + fn add(self, rhs: Self) -> Self::Output { + Metadata { + range: self.range + rhs.range, + } + } +} + impl From for Metadata { fn from(value: TokenRange) -> Self { Metadata { range: value } } } -impl Default for Metadata { - fn default() -> Self { - Metadata { - range: Default::default(), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] pub enum TypeKind { + #[error("i32")] I32, + #[error("i16")] I16, + #[error("void")] Void, - Vague(VagueType), + #[error(transparent)] + Vague(#[from] VagueType), } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] pub enum VagueType { + #[error("Unknown")] Unknown, } @@ -126,6 +140,22 @@ pub enum FunctionDefinitionKind { Extern, } +impl FunctionDefinition { + fn block_meta(&self) -> Metadata { + match &self.kind { + FunctionDefinitionKind::Local(block, _) => block.meta, + FunctionDefinitionKind::Extern => Metadata::default(), + } + } + + fn signature(&self) -> Metadata { + match &self.kind { + FunctionDefinitionKind::Local(_, metadata) => *metadata, + FunctionDefinitionKind::Extern => Metadata::default(), + } + } +} + #[derive(Debug)] pub struct Block { /// List of non-returning statements diff --git a/reid/src/mir/typecheck.rs b/reid/src/mir/typecheck.rs new file mode 100644 index 0000000..fc3e43e --- /dev/null +++ b/reid/src/mir/typecheck.rs @@ -0,0 +1,362 @@ +use std::{collections::HashMap, convert::Infallible, iter}; + +/// This module contains code relevant to doing a type checking pass on the MIR. +use crate::{mir::*, util::try_all}; +use TypeKind::*; +use VagueType::*; + +#[derive(Debug, Clone)] +pub struct Error { + metadata: Metadata, + kind: ErrorKind, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error at {}: {}", self.metadata, self.kind) + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.kind.source() + } +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum ErrorKind { + #[error("NULL error, should never occur!")] + Null, + #[error("Type is vague: {0}")] + TypeIsVague(VagueType), + #[error("Types {0} and {1} are incompatible")] + TypesIncompatible(TypeKind, TypeKind), + #[error("Variable not defined: {0}")] + VariableNotDefined(String), + #[error("Function not defined: {0}")] + FunctionNotDefined(String), + #[error("Type is vague: {0}")] + ReturnTypeMismatch(TypeKind, TypeKind), +} + +#[derive(Clone)] +pub struct TypeStorage(HashMap); + +impl Default for TypeStorage { + fn default() -> Self { + Self(Default::default()) + } +} + +impl TypeStorage { + fn set(&mut self, key: String, value: T) -> Result { + if let Some(inner) = self.0.get(&key) { + match value.collapse_into(inner) { + Ok(collapsed) => { + self.0.insert(key, collapsed.clone()); + Ok(collapsed) + } + Err(e) => Err(e), + } + } else { + self.0.insert(key, value.clone()); + Ok(value) + } + } + + fn get(&self, key: &String) -> Option<&T> { + self.0.get(key) + } +} + +#[derive(Debug)] +pub struct State { + pub errors: Vec, +} + +impl State { + fn new() -> State { + State { + errors: Default::default(), + } + } + + fn or_else + Clone + Copy>( + &mut self, + result: Result, + default: TypeKind, + meta: T, + ) -> TypeKind { + match result { + Ok(t) => t, + Err(e) => { + self.errors.push(Error { + metadata: meta.into(), + kind: e, + }); + default + } + } + } + + fn ok + Clone + Copy, U>(&mut self, result: Result, meta: T) { + if let Err(e) = result { + self.errors.push(Error { + metadata: meta.into(), + kind: e, + }); + } + } +} + +#[derive(Clone, Default)] +pub struct Scope { + function_returns: TypeStorage, + variables: TypeStorage, +} + +#[derive(Clone)] +pub struct ScopeFunction { + ret: TypeKind, + params: Vec, +} + +impl Scope { + fn inner(&self) -> Scope { + Scope { + function_returns: self.function_returns.clone(), + variables: self.variables.clone(), + } + } +} + +#[derive(Clone)] +pub enum Inferred { + Type(TypeKind), + Unresolved(u32), +} + +impl Module { + pub fn typecheck(&self) -> State { + let mut state = State::new(); + let mut scope = Scope::default(); + + for function in &self.functions { + let r = scope.function_returns.set( + function.name.clone(), + ScopeFunction { + ret: function.return_type, + params: function.parameters.iter().map(|v| v.1).collect(), + }, + ); + } + + for function in &self.functions { + let res = function.typecheck(&mut state, &mut scope); + state.ok(res, function.block_meta()); + } + + state + } +} + +impl FunctionDefinition { + fn typecheck(&self, state: &mut State, scope: &mut Scope) -> Result { + for param in &self.parameters { + let param_t = state.or_else(param.1.assert_known(), Vague(Unknown), self.signature()); + state.ok( + scope.variables.set(param.0.clone(), param_t), + self.signature(), + ); + } + + let return_type = self.return_type.clone(); + dbg!(&return_type); + let inferred = match &self.kind { + FunctionDefinitionKind::Local(block, _) => block.typecheck(state, scope), + FunctionDefinitionKind::Extern => Ok(Vague(Unknown)), + }; + dbg!(&inferred); + + match inferred { + Ok(t) => try_collapse(&return_type, &t) + .or(Err(ErrorKind::ReturnTypeMismatch(return_type, t))), + Err(e) => Ok(state.or_else(Err(e), return_type, self.block_meta())), + } + } +} + +impl Block { + fn typecheck(&self, state: &mut State, scope: &mut Scope) -> Result { + let mut scope = scope.inner(); + + for statement in &self.statements { + match &statement.0 { + StmtKind::Let(variable_reference, expression) => { + let res = expression.typecheck(state, &mut scope); + + // If expression resolution itself was erronous, resolve as + // Unknown. + let res = state.or_else(res, Vague(Unknown), expression.1); + + // Make sure the expression and variable type really is the same + state.ok( + res.collapse_into(&variable_reference.0), + variable_reference.2 + expression.1, + ); + + // TODO make sure expression/variable type is NOT vague anymore + + // Variable might already be defined, note error + state.ok( + scope + .variables + .set(variable_reference.1.clone(), variable_reference.0), + variable_reference.2, + ); + } + StmtKind::Import(_) => todo!(), + StmtKind::Expression(expression) => { + let res = expression.typecheck(state, &mut scope); + state.ok(res, expression.1); + } + } + } + + if let Some((_, expr)) = &self.return_expression { + let res = expr.typecheck(state, &mut scope); + Ok(state.or_else(res, Vague(Unknown), expr.1)) + } else { + Ok(Void) + } + } +} + +impl Expression { + fn typecheck(&self, state: &mut State, scope: &mut Scope) -> Result { + match &self.0 { + ExprKind::Variable(var_ref) => { + let existing = state.or_else( + scope + .variables + .get(&var_ref.1) + .copied() + .ok_or(ErrorKind::VariableNotDefined(var_ref.1.clone())), + Vague(Unknown), + var_ref.2, + ); + + Ok(state.or_else( + var_ref.0.collapse_into(&existing), + Vague(Unknown), + var_ref.2, + )) + } + ExprKind::Literal(literal) => Ok(literal.as_type()), + ExprKind::BinOp(_, lhs, rhs) => { + // TODO make sure lhs and rhs can actually do this binary + // operation once relevant + let lhs_res = lhs.typecheck(state, scope); + let rhs_res = rhs.typecheck(state, scope); + let lhs_type = state.or_else(lhs_res, Vague(Unknown), lhs.1); + let rhs_type = state.or_else(rhs_res, Vague(Unknown), rhs.1); + lhs_type.collapse_into(&rhs_type) + } + ExprKind::FunctionCall(function_call) => { + let true_function = scope + .function_returns + .get(&function_call.name) + .cloned() + .ok_or(ErrorKind::FunctionNotDefined(function_call.name.clone())); + + if let Ok(f) = true_function { + if function_call.parameters.len() != f.params.len() { + state.ok::<_, Infallible>(Err(ErrorKind::Null), self.1); + } + + let true_params_iter = f.params.into_iter().chain(iter::repeat(Vague(Unknown))); + + for (param, true_param_t) in + function_call.parameters.iter().zip(true_params_iter) + { + let param_res = param.typecheck(state, scope); + let param_t = state.or_else(param_res, Vague(Unknown), param.1); + state.ok(param_t.collapse_into(&true_param_t), param.1); + } + + // Make sure function return type is the same as the claimed + // return type + // TODO: Set return type here actually + try_collapse(&f.ret, &function_call.return_type) + } else { + Ok(function_call.return_type) + } + } + ExprKind::If(IfExpression(cond, lhs, rhs)) => { + // TODO make sure cond_res is Boolean here + let cond_res = cond.typecheck(state, scope); + state.ok(cond_res, cond.1); + + let lhs_res = lhs.typecheck(state, scope); + let lhs_type = state.or_else(lhs_res, Vague(Unknown), lhs.meta); + let rhs_type = if let Some(rhs) = rhs { + let res = rhs.typecheck(state, scope); + state.or_else(res, Vague(Unknown), rhs.meta) + } else { + Vague(Unknown) + }; + lhs_type.collapse_into(&rhs_type) + } + ExprKind::Block(block) => block.typecheck(state, scope), + } + } +} + +impl TypeKind { + fn assert_known(&self) -> Result { + if let Vague(vague) = self { + Err(ErrorKind::TypeIsVague(*vague)) + } else { + Ok(*self) + } + } +} + +fn try_collapse(lhs: &TypeKind, rhs: &TypeKind) -> Result { + lhs.collapse_into(rhs) + .or(rhs.collapse_into(lhs)) + .or(Err(ErrorKind::TypesIncompatible(*lhs, *rhs))) +} + +trait Collapsable: Sized + Clone { + fn collapse_into(&self, other: &Self) -> Result; +} + +impl Collapsable for TypeKind { + fn collapse_into(&self, other: &TypeKind) -> Result { + if self == other { + return Ok(self.clone()); + } + + match (self, other) { + (Vague(Unknown), other) | (other, Vague(Unknown)) => Ok(other.clone()), + _ => Err(ErrorKind::TypesIncompatible(*self, *other)), + } + } +} + +impl Collapsable for ScopeFunction { + fn collapse_into(&self, other: &ScopeFunction) -> Result { + Ok(ScopeFunction { + ret: self.ret.collapse_into(&other.ret)?, + params: try_all( + self.params + .iter() + .zip(&other.params) + .map(|(p1, p2)| p1.collapse_into(&p2)) + .collect(), + ) + .map_err(|e| e.first().unwrap().clone())?, + }) + } +} diff --git a/reid/src/token_stream.rs b/reid/src/token_stream.rs index bf0ddd8..d35d6f9 100644 --- a/reid/src/token_stream.rs +++ b/reid/src/token_stream.rs @@ -156,7 +156,7 @@ impl Drop for TokenStream<'_, '_> { } } -#[derive(Clone, Copy)] +#[derive(Default, Clone, Copy)] pub struct TokenRange { pub start: usize, pub end: usize, @@ -168,15 +168,6 @@ impl std::fmt::Debug for TokenRange { } } -impl Default for TokenRange { - fn default() -> Self { - Self { - start: Default::default(), - end: Default::default(), - } - } -} - impl std::ops::Add for TokenRange { type Output = TokenRange; diff --git a/reid/src/util.rs b/reid/src/util.rs new file mode 100644 index 0000000..cf45009 --- /dev/null +++ b/reid/src/util.rs @@ -0,0 +1,17 @@ +pub fn try_all(list: Vec>) -> Result, Vec> { + let mut successes = Vec::with_capacity(list.len()); + let mut failures = Vec::with_capacity(list.len()); + + for item in list { + match item { + Ok(s) => successes.push(s), + Err(e) => failures.push(e), + } + } + + if failures.len() > 0 { + Err(failures) + } else { + Ok(successes) + } +}