#include "typechecker.h"
#include "ast.h"
#include "errors.h"
#include "result.h"

#include <cstddef>

namespace {
    // Outcome of comparing an expression's type with the type required where it is used.
    enum class TypecheckRes {
        // The types are equal; the expression can be used as-is.
        Ok,
        // The types differ but could be converted. check_type() below never produces this
        // yet; handle_res() reports it as "Casting not yet implemented".
        Castable,
    };

    /// Compares the type an expression produced (`checked`) with the type required (`target`).
    /// @return TypecheckRes::Ok when types::types_equal() reports them equal; otherwise an
    ///         error message naming both formatted types.
    Result check_type(std::shared_ptr checked, std::shared_ptr target) {
        if (types::types_equal(checked, target)) {
            return TypecheckRes::Ok;
        }
        return std::string{ "Types " + checked->formatted() + " and " + target->formatted() + " incompatible" };
    }

    /// Applies a check_type() outcome to `expr`.
    /// Any non-Ok outcome (an incompatible type, or a required cast) is pushed onto
    /// `state.errors` at `expr->m_meta`. The expression itself is always returned unchanged.
    /// Ownership is threaded through, presumably so a cast node can later be inserted
    /// here (NOTE(review): confirm).
    std::unique_ptr handle_res(
        std::unique_ptr expr,
        Result res,
        typecheck::State& state) {
        if (res.ok()) {
            auto result = res.unwrap();
            if (result == TypecheckRes::Ok) {
                return expr;
            }
            else {
                state.errors.push_back(CompileError("Casting not yet implemented", expr->m_meta));
                return expr;
            }
        }
        else {
            state.errors.push_back(CompileError(res.unwrap_err(), expr->m_meta));
            return expr;
        }
    }
}

namespace AST {
    /// Integer literals always have the fundamental Int type.
    std::shared_ptr IntLiteralExpression::typecheck(
        typecheck::State&,
        typecheck::Scope&,
        std::optional>
    ) {
        return std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Int } };
    }

    /// String literals are typed as a pointer to the fundamental Char type.
    std::shared_ptr StringLiteralExpression::typecheck(
        typecheck::State&,
        typecheck::Scope&,
        std::optional>
    ) {
        auto char_ty = std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Char } };
        auto ptr_ty = new types::PointerType{ char_ty };
        return std::shared_ptr{ptr_ty};
    }

    /// Looks the referenced name up in `scope.symbols`.
    /// An unknown name is reported as "Value <name> not defined" and typed as Void so that
    /// checking can continue.
    std::shared_ptr ValueReferenceExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        // Single lookup: reuse the iterator instead of find() followed by operator[].
        auto symbol = scope.symbols.find(this->m_name);
        if (symbol != scope.symbols.end()) {
            return symbol->second;
        }

        state.errors.push_back(CompileError("Value " + this->m_name + " not defined", this->m_meta));
        return std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Void } };
    }

    /// Type-checks both operands and resolves the operator.
    /// Assignment: the right-hand side is checked against the left-hand side's type and the
    /// result is the left-hand side's type. Any other operator is looked up with
    /// types::find_binop(); if no overload matches, an error is reported and Void is returned.
    std::shared_ptr BinaryOperationExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        auto lhs_ty = this->m_lhs->typecheck(state, scope, {});

        if (this->m_binop == types::BinOp::Assignment) {
            // The stored value must match the type of the target it is assigned to,
            // exactly like return values, conditions and call arguments are checked.
            auto rhs_ty = this->m_rhs->typecheck(state, scope, lhs_ty);
            auto check_res = check_type(rhs_ty, lhs_ty);
            this->m_rhs = handle_res(std::move(this->m_rhs), check_res, state);
            return lhs_ty;
        }

        auto rhs_ty = this->m_rhs->typecheck(state, scope, {});

        auto binop = types::find_binop(
            state.binops,
            lhs_ty,
            this->m_binop,
            rhs_ty
        );

        if (binop) {
            return binop->result;
        }

        // TODO check for binops that may be implicitly castable

        state.errors.push_back(CompileError(
            "No suitable binop between "
            + lhs_ty->formatted() + " "
            + types::format_operator(this->m_binop) + " "
            + rhs_ty->formatted(),
            this->m_meta));

        return std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Void } };
    }

    /// Type-checks a call: the callee must have function type, the argument count must fit
    /// the parameter list (extra arguments are allowed only for varargs functions), and each
    /// argument that has a matching parameter must have that parameter's type.
    /// Every argument is type-checked exactly once, even when the call itself is malformed,
    /// so that errors inside the arguments are still reported.
    /// @return the callee's return type, or Void if the callee is not a function.
    std::shared_ptr FunctionCallExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        auto expr_ty = this->m_fn_expr->typecheck(state, scope, {});
        if (expr_ty->m_kind != types::TypeKind::Function) {
            state.errors.push_back(CompileError("Tried calling a non-function", this->m_meta));
            // No parameter types are known; still visit the arguments for their own errors.
            for (auto& arg : this->m_args) {
                arg->typecheck(state, scope, {});
            }
            return std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Void } };
        }

        auto fn_ty = dynamic_cast(expr_ty.get());

        const std::size_t arg_count = this->m_args.size();
        const std::size_t param_count = fn_ty->m_param_tys.size();

        if (arg_count < param_count) {
            state.errors.push_back(CompileError("too few arguments", this->m_meta));
        }
        else if (arg_count > param_count && !fn_ty->m_vararg) {
            state.errors.push_back(CompileError("too many arguments", this->m_meta));
        }

        for (std::size_t i = 0; i < arg_count; ++i) {
            if (i < param_count) {
                auto expected_param_ty = fn_ty->m_param_tys[i];
                auto param_ty = this->m_args[i]->typecheck(state, scope, expected_param_ty);
                auto check_res = check_type(param_ty, expected_param_ty);
                this->m_args[i] = handle_res(std::move(this->m_args[i]), check_res, state);
            }
            else {
                // Extra (vararg or surplus) arguments have no declared type to match.
                this->m_args[i]->typecheck(state, scope, {});
            }
        }

        return fn_ty->m_ret_ty;
    }

    /// Type-checks the returned expression and, when the enclosing scope has a return type,
    /// requires the expression to have exactly that type.
    void ReturnStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        auto res_ty = this->m_expr->typecheck(state, scope, scope.return_ty);
        if (scope.return_ty) {
            auto check_res = check_type(res_ty, *scope.return_ty);
            this->m_expr = handle_res(std::move(this->m_expr), check_res, state);
        }
    }

    /// Type-checks the optional initializer against the declared type, then binds the
    /// declared name to the declared type in `scope.symbols`.
    void InitializationStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        if (this->m_expr) {
            auto init_ty = (*this->m_expr)->typecheck(state, scope, this->m_type);
            // Previously the initializer's type was discarded, so e.g. `int x = "str"`
            // was accepted; check it like return values and conditions.
            auto check_res = check_type(init_ty, this->m_type);
            *this->m_expr = handle_res(std::move(*this->m_expr), check_res, state);
        }
        scope.symbols[this->m_name] = this->m_type;
    }

    /// Type-checks the expression; its resulting type is not used.
    void ExpressionStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        this->m_expr->typecheck(state, scope, {});
    }

    /// Requires the condition to be of type Bool, then checks the then-branch and, if
    /// present, the else-branch.
    /// NOTE(review): both branches are checked in the enclosing `scope`, so declarations
    /// made inside them stay visible after the if — confirm this is intended.
    void IfStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        auto bool_ty = std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Bool } };
        auto expr_ty = this->m_condition->typecheck(state, scope, bool_ty);

        auto check_res = check_type(expr_ty, bool_ty);
        this->m_condition = handle_res(std::move(this->m_condition), check_res, state);

        this->m_then->typecheck(state, scope);
        if (this->m_else) {
            (*this->m_else)->typecheck(state, scope);
        }
    }

    /// Registers this function's type under its name in `scope`. The body is checked
    /// afterwards, so statements inside it can refer to the function itself.
    /// Each statement is checked in a Scope built from `scope` that additionally carries
    /// the return type and every named parameter; unnamed parameters are not bound.
    void Function::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        auto return_ty = this->m_return_ty;
        std::vector> param_tys{};
        for (auto& param : this->m_params) {
            param_tys.push_back(param.second);
        }

        auto function_ty = new types::FunctionType{ return_ty, param_tys, this->m_is_vararg };
        scope.symbols[this->m_name] = std::shared_ptr{ function_ty };

        typecheck::Scope inner{ scope };
        inner.return_ty = return_ty;

        for (auto& param : this->m_params) {
            if (param.first) {
                inner.symbols[*param.first] = param.second;
            }
        }

        if (this->m_statements) {
            for (auto& statement : *this->m_statements) {
                statement->typecheck(state, inner);
            }
        }
    }
}