#include "typechecker.h"
#include "ast.h"
#include "errors.h"
#include "result.h"

#include <cstddef>

namespace {
    /// Outcome of a successful check_type call.
    enum class TypecheckResKind {
        /// The checked type is identical to the target type.
        Ok,
        /// The checked type differs, but an implicit cast to the target exists.
        Castable,
    };

    /// Successful result of check_type: how the value relates to the target,
    /// plus the target type itself (check_type always stores `target` here).
    struct TypecheckRes {
        TypecheckResKind kind;
        std::shared_ptr result;
    };

    /// Checks whether a value of type `checked` may be used where `target`
    /// is expected.
    ///
    /// Returns a TypecheckRes of kind Ok when the types are equal, of kind
    /// Castable when `state.casts` holds a cast from `checked` to `target`
    /// that allows implicit use, and otherwise an error message string.
    Result check_type(
        typecheck::State& state,
        std::shared_ptr checked,
        std::shared_ptr target)
    {
        if (types::types_equal(checked, target)) {
            return TypecheckRes{ TypecheckResKind::Ok, target };
        }

        // Only look up casts when the types actually differ.
        auto potential_cast = types::find_cast(state.casts, checked, target);
        if (potential_cast.has_value()) {
            if (potential_cast->allow_implicit)
                return TypecheckRes{ TypecheckResKind::Castable, target };

            return std::string{ "Type " + checked->formatted() + " not implicitly castable to " + target->formatted() };
        }

        return std::string{ "Types " + checked->formatted() + " and " + target->formatted() + " incompatible" };
    }

    /// Applies the outcome of check_type to `expr`.
    ///
    /// - Ok:       `expr` is returned unchanged.
    /// - Castable: `expr` is wrapped in a CastExpression to the target type.
    /// - Error:    a CompileError carrying the message is pushed to
    ///             `state.errors` and `expr` is returned unchanged.
    std::unique_ptr handle_res(
        std::unique_ptr expr,
        Result res,
        typecheck::State& state)
    {
        if (!res.ok()) {
            state.errors.push_back(CompileError(res.unwrap_err(), expr->m_meta));
            return expr;
        }

        auto result = res.unwrap();
        if (result.kind == TypecheckResKind::Ok) {
            return expr;
        }

        // Clauses of a braced initializer are evaluated left to right, so the
        // metadata is read before `expr` is moved into the cast node.
        return std::unique_ptr {
            new AST::CastExpression{ expr->m_meta, result.result, std::move(expr) }
        };
    }
}

namespace AST {
    /// Integer literals adopt the expected type when it is a fundamental
    /// Bool, Char or Int; otherwise they keep their current `m_ty`.
    std::shared_ptr IntLiteralExpression::typecheck(
        typecheck::State&,
        typecheck::Scope&,
        std::optional> expected_ty
    ) {
        // Allow implicitly converting IntLiteralExpression to other types
        // representable by integers.
        if (expected_ty) {
            if ((*expected_ty)->m_kind == types::TypeKind::Fundamental) {
                auto ty = dynamic_cast((*expected_ty).get());
                if (
                    ty->m_ty == types::FundamentalTypeKind::Bool
                    || ty->m_ty == types::FundamentalTypeKind::Char
                    || ty->m_ty == types::FundamentalTypeKind::Int
                    ) {
                    this->m_ty = *expected_ty;
                }
            }
        }

        return this->m_ty;
    }

    /// String literals are typed as a Char array of length `m_value.size() + 1`
    /// (the extra element is presumably the NUL terminator).
    std::shared_ptr StringLiteralExpression::typecheck(
        typecheck::State&,
        typecheck::Scope&,
        std::optional>
    ) {
        auto char_ty = std::shared_ptr{
            new types::FundamentalType{ types::FundamentalTypeKind::Char }
        };
        auto ptr_ty = new types::ArrayType{ char_ty, static_cast(this->m_value.size()) + 1 };
        return std::shared_ptr{ptr_ty};
    }

    /// Resolves a name through `scope.symbols`. Unknown names produce a
    /// CompileError and a Void placeholder type.
    std::shared_ptr ValueReferenceExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        // Single lookup; reuse the iterator instead of a second operator[].
        auto symbol = scope.symbols.find(this->m_name);
        if (symbol != scope.symbols.end()) {
            return symbol->second;
        }

        state.errors.push_back(CompileError("Value " + this->m_name + " not defined", this->m_meta));
        return std::shared_ptr{
            new types::FundamentalType{ types::FundamentalTypeKind::Void }
        };
    }

    /// Typechecks a binary operation.
    ///
    /// Assignment: the rhs is checked (once) against the lhs type and the lhs
    /// type is returned.
    ///
    /// Otherwise a binop from `state.binops` is chosen in three passes:
    ///   1. an exact match on both operand types;
    ///   2. a binop matching one operand exactly while the other operand is
    ///      implicitly castable, and whose result equals `expected_ty` (if any);
    ///   3. any binop whose operands are both implicitly castable and whose
    ///      result is implicitly castable to `expected_ty` (if any).
    /// Operands are wrapped in CastExpressions as needed. If nothing matches, a
    /// CompileError is pushed and Void is returned.
    std::shared_ptr BinaryOperationExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional> expected_ty
    ) {
        auto lhs_ty = this->m_lhs->typecheck(state, scope, {});

        if (this->m_binop == types::BinOp::Assignment) {
            // Typecheck rhs exactly once, directly against the lhs type, so
            // diagnostics inside rhs are not reported twice and rhs is not
            // rewritten by an unconstrained first pass.
            auto rhs_ty = this->m_rhs->typecheck(state, scope, lhs_ty);
            auto rhs_ty_res = check_type(state, rhs_ty, lhs_ty);
            this->m_rhs = handle_res(std::move(this->m_rhs), rhs_ty_res, state);
            return lhs_ty;
        }

        auto rhs_ty = this->m_rhs->typecheck(state, scope, {});

        // Try to find a binop that matches exactly
        auto binop = types::find_binop(
            state.binops,
            lhs_ty,
            this->m_binop,
            rhs_ty
        );

        if (binop) {
            return binop->result(*binop, lhs_ty, rhs_ty);
        }

        // If that fails, try to find binop that matches on one side perfectly
        // and is castable on the other side, and would also be perfectly
        // assignable to the expected value.
        for (auto& binop : state.binops) {
            if (expected_ty) {
                // Skip any binops that would not be immediately assignable to
                // the expected type
                if (!types::types_equal(binop.result(binop, lhs_ty, rhs_ty), *expected_ty)) {
                    continue;
                }
            }

            if (types::types_equal(binop.lhs, lhs_ty)) {
                auto rhs_res = check_type(state, rhs_ty, binop.rhs);
                // Skip if rhs is not implicitly castable to this binop's rhs type
                if (!rhs_res.ok())
                    continue;
                this->m_rhs = handle_res(std::move(this->m_rhs), rhs_res, state);
                return binop.result(binop, lhs_ty, rhs_ty);
            }
            else if (types::types_equal(binop.rhs, rhs_ty)) {
                auto lhs_res = check_type(state, lhs_ty, binop.lhs);
                // Skip if lhs is not implicitly castable to this binop's lhs type
                if (!lhs_res.ok())
                    continue;
                this->m_lhs = handle_res(std::move(this->m_lhs), lhs_res, state);
                return binop.result(binop, lhs_ty, rhs_ty);
            }
        }

        // Finally check for any binop whose operands are both implicitly
        // castable and whose result can be implicitly cast to the expected type
        for (auto& binop : state.binops) {
            if (expected_ty) {
                // Skip any binops that would not even be implicitly castable to
                // the expected result
                auto result_res = check_type(state, binop.result(binop, lhs_ty, rhs_ty), *expected_ty);
                if (!result_res.ok())
                    continue;
            }

            auto lhs_result = check_type(state, lhs_ty, binop.lhs);
            auto rhs_result = check_type(state, rhs_ty, binop.rhs);
            // Only commit to this binop when both operands fit; otherwise keep
            // searching so the "No suitable binop" diagnostic stays reachable.
            if (!lhs_result.ok() || !rhs_result.ok())
                continue;

            this->m_lhs = handle_res(std::move(this->m_lhs), lhs_result, state);
            this->m_rhs = handle_res(std::move(this->m_rhs), rhs_result, state);
            return binop.result(binop, lhs_ty, rhs_ty);
        }

        // No suitable binops found :(
        state.errors.push_back(CompileError(
            "No suitable binop between "
            + lhs_ty->formatted() + " "
            + types::format_operator(this->m_binop) + " "
            + rhs_ty->formatted(),
            this->m_meta));

        return std::shared_ptr{
            new types::FundamentalType{ types::FundamentalTypeKind::Void }
        };
    }

    /// Typechecks a call: the callee must have Function type, and the argument
    /// count must match the parameter count (extra arguments are allowed only
    /// for vararg functions). Each argument with a declared parameter is
    /// checked (and implicitly cast) against it; vararg extras are typechecked
    /// without an expected type. Returns the function's return type.
    /// NOTE(review): on an arity error the arguments are not typechecked at all.
    std::shared_ptr FunctionCallExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        auto expr_ty = this->m_fn_expr->typecheck(state, scope, {});
        if (expr_ty->m_kind != types::TypeKind::Function) {
            state.errors.push_back(CompileError("Tried calling a non-function", this->m_meta));
            return std::shared_ptr {
                new types::FundamentalType{ types::FundamentalTypeKind::Void }
            };
        }

        auto fn_ty = dynamic_cast(expr_ty.get());

        if (this->m_args.size() < fn_ty->m_param_tys.size()) {
            state.errors.push_back(CompileError("too few arguments", this->m_meta));
        }
        else if (this->m_args.size() > fn_ty->m_param_tys.size() && !fn_ty->m_vararg) {
            state.errors.push_back(CompileError("too many arguments", this->m_meta));
        }
        else {
            for (std::size_t i = 0; i < this->m_args.size(); i++) {
                if (i < fn_ty->m_param_tys.size()) {
                    auto expected_param_ty = fn_ty->m_param_tys[i];
                    auto param_ty = this->m_args[i]->typecheck(state, scope, expected_param_ty);
                    auto check_res = check_type(state, param_ty, expected_param_ty);
                    this->m_args[i] = handle_res(std::move(this->m_args[i]), check_res, state);
                }
                else {
                    // Vararg extras have no declared type to check against.
                    this->m_args[i]->typecheck(state, scope, {});
                }
            }
        }

        return fn_ty->m_ret_ty;
    }

    /// Explicit casts are permitted whenever `state.casts` contains a cast from
    /// the operand type to `m_ty` (the `allow_implicit` flag is not consulted).
    std::shared_ptr CastExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        auto expr_ty = this->m_expr->typecheck(state, scope, {});
        auto cast = types::find_cast(state.casts, expr_ty, this->m_ty);
        if (cast) {
            return cast->target_ty;
        }
        state.errors.push_back(CompileError("Cast from type "
            + expr_ty->formatted() + " to type " + this->m_ty->formatted()
            + " is not permitted", this->m_meta));
        return std::shared_ptr { new types::FundamentalType{
            types::FundamentalTypeKind::Void
        } };
    }

    /// Taking a reference yields a pointer to the operand's type.
    std::shared_ptr RefExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        auto expr_ty = this->m_expr->typecheck(state, scope, {});
        return std::shared_ptr {
            new types::PointerType{ expr_ty }
        };
    }

    /// Dereferencing requires a pointer operand and yields its pointee type;
    /// anything else reports a CompileError and yields Void.
    std::shared_ptr DerefExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        auto expr_ty = this->m_expr->typecheck(state, scope, {});
        if (expr_ty->m_kind != types::TypeKind::Pointer) {
            state.errors.push_back(
                CompileError("Tried to deref " + expr_ty->formatted(), this->m_meta));
            return std::shared_ptr {
                new types::FundamentalType{ types::FundamentalTypeKind::Void }
            };
        }

        auto ptr_ty = dynamic_cast(expr_ty.get());
        return ptr_ty->m_inner;
    }

    /// Indexing is allowed on pointers and arrays and yields the element
    /// (`m_inner`) type; anything else reports a CompileError and yields Void.
    std::shared_ptr IndexAccessExpression::typecheck(
        typecheck::State& state,
        typecheck::Scope& scope,
        std::optional>
    ) {
        auto expr_ty = this->m_expr->typecheck(state, scope, {});
        if (expr_ty->m_kind != types::TypeKind::Pointer && expr_ty->m_kind != types::TypeKind::Array) {
            state.errors.push_back(
                CompileError("Tried to index " + expr_ty->formatted(), this->m_meta));
            return std::shared_ptr {
                new types::FundamentalType{ types::FundamentalTypeKind::Void }
            };
        }

        if (expr_ty->m_kind == types::TypeKind::Pointer) {
            auto ptr_ty = dynamic_cast(expr_ty.get());
            return ptr_ty->m_inner;
        }
        else if (expr_ty->m_kind == types::TypeKind::Array) {
            auto array_ty = dynamic_cast(expr_ty.get());
            return array_ty->m_inner;
        }

        // Default return type (unreachable given the kind check above)
        return std::shared_ptr {
            new types::FundamentalType{ types::FundamentalTypeKind::Void }
        };
    }

    /// Checks the returned expression against the enclosing function's return
    /// type when the scope has one, inserting an implicit cast if needed.
    void ReturnStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        auto res_ty = this->m_expr->typecheck(state, scope, scope.return_ty);
        if (scope.return_ty) {
            auto check_res = check_type(state, res_ty, *scope.return_ty);
            this->m_expr = handle_res(std::move(this->m_expr), check_res, state);
        }
    }

    /// Checks the optional initializer against the declared type, then binds
    /// the name in `scope` (after the initializer has been checked).
    void InitializationStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        if (this->m_expr) {
            auto expr_ty = (*this->m_expr)->typecheck(state, scope, this->m_type);
            auto check_res = check_type(state, expr_ty, this->m_type);
            this->m_expr = handle_res(std::move(*this->m_expr), check_res, state);
        }
        scope.symbols[this->m_name] = this->m_type;
    }

    /// Typechecks the expression for its diagnostics; its type is discarded.
    void ExpressionStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        this->m_expr->typecheck(state, scope, {});
    }

    /// The condition is checked (and implicitly cast) against Bool; both
    /// branches are typechecked in the same `scope`.
    void IfStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        auto bool_ty = std::shared_ptr{
            new types::FundamentalType{ types::FundamentalTypeKind::Bool }
        };
        auto expr_ty = this->m_condition->typecheck(state, scope, bool_ty);
        auto check_res = check_type(state, expr_ty, bool_ty);
        this->m_condition = handle_res(std::move(this->m_condition), check_res, state);

        this->m_then->typecheck(state, scope);
        if (this->m_else) {
            (*this->m_else)->typecheck(state, scope);
        }
    }

    /// Registers the function's type in `scope`, then typechecks the body (if
    /// any) in an inner scope built from `scope` with the named parameters
    /// bound and `return_ty` set. The function symbol is added before the inner
    /// scope is built, presumably so the body can refer to the function itself.
    void Function::typecheck(typecheck::State& state, typecheck::Scope& scope) {
        auto return_ty = this->m_return_ty;
        std::vector> param_tys{};
        for (auto& param : this->m_params) {
            param_tys.push_back(param.second);
        }

        auto function_ty = new types::FunctionType{ return_ty, param_tys, this->m_is_vararg };
        scope.symbols[this->m_name] = std::shared_ptr{ function_ty };

        typecheck::Scope inner{ scope };
        inner.return_ty = return_ty;

        // Unnamed parameters (empty `first`) are not bound in the body scope.
        for (auto& param : this->m_params) {
            if (param.first) {
                inner.symbols[*param.first] = param.second;
            }
        }

        if (this->m_statements) {
            for (auto& statement : *this->m_statements) {
                statement->typecheck(state, inner);
            }
        }
    }
}