diff --git a/src/ast.h b/src/ast.h index 6160efc..e322257 100644 --- a/src/ast.h +++ b/src/ast.h @@ -27,7 +27,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) = 0; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) = 0; virtual void typecheck_preprocess(typecheck::Scope& scope) = 0; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -60,7 +60,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -78,7 +78,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -96,7 +96,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -125,7 +125,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -151,7 +151,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -177,7 +177,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -200,7 +200,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -223,7 +223,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -249,7 +249,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -275,7 +275,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -301,7 +301,7 @@ namespace AST { virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; virtual void typecheck_preprocess(typecheck::Scope& scope) override; - virtual std::shared_ptr typecheck( + virtual typecheck::ExpressionType typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty diff --git a/src/typechecker.cpp b/src/typechecker.cpp index 4d040be..2eaeb08 100644 --- a/src/typechecker.cpp +++ b/src/typechecker.cpp @@ -105,7 +105,7 @@ namespace AST { this->m_ty = refresh_type(scope, this->m_ty); } - std::shared_ptr IntLiteralExpression::typecheck( + typecheck::ExpressionType IntLiteralExpression::typecheck( typecheck::State&, typecheck::Scope&, std::optional> expected_ty @@ -125,12 +125,12 @@ namespace AST { } } - return this->m_ty; + return { this->m_ty, false }; } void StringLiteralExpression::typecheck_preprocess(typecheck::Scope&) {} - std::shared_ptr StringLiteralExpression::typecheck( + typecheck::ExpressionType StringLiteralExpression::typecheck( typecheck::State&, typecheck::Scope&, std::optional> @@ -139,24 +139,24 @@ namespace AST { new types::FundamentalType{ types::FundamentalTypeKind::Char } }; auto ptr_ty = new types::ArrayType{ char_ty, static_cast(this->m_value.size()) + 1, true }; - return std::shared_ptr{ptr_ty}; + return { std::shared_ptr{ptr_ty}, true }; } void ValueReferenceExpression::typecheck_preprocess(typecheck::Scope&) {} - std::shared_ptr ValueReferenceExpression::typecheck( + typecheck::ExpressionType ValueReferenceExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> ) { if (scope.symbols.find(this->m_name) != scope.symbols.end()) { - return scope.symbols[this->m_name]; + return { scope.symbols[this->m_name], false }; } state.errors.push_back(CompileError("Value " + this->m_name + " not defined", this->m_meta)); - return std::shared_ptr{ + return { std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } void BinaryOperationExpression::typecheck_preprocess(typecheck::Scope& scope) { @@ -164,20 +164,20 @@ namespace AST { this->m_rhs->typecheck_preprocess(scope); } - std::shared_ptr BinaryOperationExpression::typecheck( + typecheck::ExpressionType BinaryOperationExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty ) { - auto lhs_ty = this->m_lhs->typecheck(state, scope, {}); - auto rhs_ty = this->m_rhs->typecheck(state, scope, {}); + auto lhs_ty = this->m_lhs->typecheck(state, scope, {}).type; + auto rhs_ty = this->m_rhs->typecheck(state, scope, {}).type; if (this->m_binop == types::BinOp::Assignment) { // Re-typecheck rhs to actually match lhs - auto rhs_ty = this->m_rhs->typecheck(state, scope, lhs_ty); + auto rhs_ty = this->m_rhs->typecheck(state, scope, lhs_ty).type; auto rhs_ty_res = check_type(state, rhs_ty, lhs_ty); this->m_rhs = handle_res(std::move(this->m_rhs), rhs_ty_res, state); - return lhs_ty; + return { lhs_ty, false }; } // Try to find a binop that matches exactly @@ -189,7 +189,7 @@ namespace AST { ); if (binop) { - return binop->result(*binop, lhs_ty, rhs_ty); + return { binop->result(*binop, lhs_ty, rhs_ty), false }; } // If that fails, try to find binop that matches on one side perfectly @@ -209,7 +209,7 @@ namespace AST { // Skip if not implicitly castable to lhs continue; this->m_rhs = handle_res(std::move(this->m_rhs), rhs_res, state); - return binop.result(binop, lhs_ty, rhs_ty); + return { binop.result(binop, lhs_ty, rhs_ty), false }; } else if (types::types_equal(binop.rhs, rhs_ty)) { auto lhs_res = check_type(state, lhs_ty, binop.lhs); @@ -217,7 +217,7 @@ namespace AST { // Skip if not implicitly castable to rhs continue; this->m_lhs = handle_res(std::move(this->m_lhs), lhs_res, state); - return binop.result(binop, lhs_ty, rhs_ty); + return { binop.result(binop, lhs_ty, rhs_ty), false }; } } @@ -235,7 +235,7 @@ namespace AST { auto rhs_result = check_type(state, rhs_ty, binop.rhs); this->m_lhs = handle_res(std::move(this->m_lhs), lhs_result, state); this->m_rhs = handle_res(std::move(this->m_rhs), lhs_result, state); - return binop.result(binop, lhs_ty, rhs_ty); + return { binop.result(binop, lhs_ty, rhs_ty), false }; } // No suitable binops found :( @@ -246,8 +246,8 @@ namespace AST { + rhs_ty->formatted(), this->m_meta)); - return std::shared_ptr{ - new types::FundamentalType{ types::FundamentalTypeKind::Void } }; + return { std::shared_ptr{ + new types::FundamentalType{ types::FundamentalTypeKind::Void } }, false }; } void FunctionCallExpression::typecheck_preprocess(typecheck::Scope& scope) { @@ -257,18 +257,18 @@ namespace AST { } } - std::shared_ptr FunctionCallExpression::typecheck( + typecheck::ExpressionType FunctionCallExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> ) { - auto expr_ty = this->m_fn_expr->typecheck(state, scope, {}); + auto expr_ty = this->m_fn_expr->typecheck(state, scope, {}).type; if (expr_ty->m_kind != types::TypeKind::Function) { state.errors.push_back(CompileError("Tried calling a non-function", this->m_meta)); - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } auto fn_ty = dynamic_cast(expr_ty.get()); @@ -283,7 +283,7 @@ namespace AST { for (int i = 0; i < static_cast(this->m_args.size()); i++) { if (i < static_cast(fn_ty->m_param_tys.size())) { auto expected_param_ty = fn_ty->m_param_tys[i]; - auto param_ty = this->m_args[i]->typecheck(state, scope, expected_param_ty); + auto param_ty = this->m_args[i]->typecheck(state, scope, expected_param_ty).type; auto check_res = check_type(state, param_ty, expected_param_ty); this->m_args[i] = handle_res(std::move(this->m_args[i]), check_res, state); @@ -294,7 +294,7 @@ namespace AST { } } - return fn_ty->m_ret_ty; + return { fn_ty->m_ret_ty, false }; } void CastExpression::typecheck_preprocess(typecheck::Scope& scope) { @@ -302,127 +302,127 @@ namespace AST { this->m_expr->typecheck_preprocess(scope); } - std::shared_ptr CastExpression::typecheck( + typecheck::ExpressionType CastExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> ) { - auto expr_ty = this->m_expr->typecheck(state, scope, {}); + auto expr_ty = this->m_expr->typecheck(state, scope, {}).type; auto cast = types::find_cast(state.casts, expr_ty, this->m_ty); if (cast) { - return cast->target_ty; + return { cast->target_ty, false }; } state.errors.push_back(CompileError("Cast from type " + expr_ty->formatted() + "to type " + this->m_ty->formatted() + " is not permitted", this->m_meta)); - return std::shared_ptr { new types::FundamentalType{ + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void - } }; + } }, false }; } void RefExpression::typecheck_preprocess(typecheck::Scope& scope) { this->m_expr->typecheck_preprocess(scope); } - std::shared_ptr RefExpression::typecheck( + typecheck::ExpressionType RefExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> ) { - auto expr_ty = this->m_expr->typecheck(state, scope, {}); - return std::shared_ptr { + auto expr_ty = this->m_expr->typecheck(state, scope, {}).type; + return { std::shared_ptr { new types::PointerType{ expr_ty } - }; + }, false }; } void DerefExpression::typecheck_preprocess(typecheck::Scope& scope) { this->m_expr->typecheck_preprocess(scope); } - std::shared_ptr DerefExpression::typecheck( + typecheck::ExpressionType DerefExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> ) { - auto expr_ty = this->m_expr->typecheck(state, scope, {}); + auto expr_ty = this->m_expr->typecheck(state, scope, {}).type; if (expr_ty->m_kind != types::TypeKind::Pointer) { state.errors.push_back( CompileError("Tried to deref " + expr_ty->formatted(), this->m_meta)); - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } auto ptr_ty = dynamic_cast(expr_ty.get()); - return ptr_ty->m_inner; + return { ptr_ty->m_inner, false }; } void IndexAccessExpression::typecheck_preprocess(typecheck::Scope& scope) { this->m_expr->typecheck_preprocess(scope); } - std::shared_ptr IndexAccessExpression::typecheck( + typecheck::ExpressionType IndexAccessExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> ) { - auto expr_ty = this->m_expr->typecheck(state, scope, {}); + auto expr_ty = this->m_expr->typecheck(state, scope, {}).type; if (expr_ty->m_kind != types::TypeKind::Pointer && expr_ty->m_kind != types::TypeKind::Array) { state.errors.push_back( CompileError("Tried to index " + expr_ty->formatted(), this->m_meta)); - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } if (expr_ty->m_kind == types::TypeKind::Pointer) { auto ptr_ty = dynamic_cast(expr_ty.get()); - return ptr_ty->m_inner; + return { ptr_ty->m_inner, false }; } else if (expr_ty->m_kind == types::TypeKind::Array) { auto ptr_ty = dynamic_cast(expr_ty.get()); - return ptr_ty->m_inner; + return { ptr_ty->m_inner, false }; } // Default return type - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } void FieldAccessExpression::typecheck_preprocess(typecheck::Scope& scope) { this->m_expr->typecheck_preprocess(scope); } - std::shared_ptr FieldAccessExpression::typecheck( + typecheck::ExpressionType FieldAccessExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> ) { - auto expr_ty = this->m_expr->typecheck(state, scope, {}); + auto expr_ty = this->m_expr->typecheck(state, scope, {}).type; if (expr_ty->m_kind != types::TypeKind::Struct) { state.errors.push_back( CompileError("Tried to access " + expr_ty->formatted() + "." + this->m_field, this->m_meta)); - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } auto struct_ty = dynamic_cast(expr_ty.get()); if (struct_ty->m_fields) { for (auto& field : *struct_ty->m_fields) { if (field.first == this->m_field) { - return field.second; + return { field.second, false }; } } state.errors.push_back(CompileError("No such field", this->m_meta)); - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } state.errors.push_back(CompileError("Cannot access fields of opaque struct", this->m_meta)); - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, false }; } void ListInitializerExpression::typecheck_preprocess(typecheck::Scope& scope) { @@ -432,7 +432,7 @@ namespace AST { } } - std::shared_ptr ListInitializerExpression::typecheck( + typecheck::ExpressionType ListInitializerExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, std::optional> expected_ty @@ -441,7 +441,7 @@ namespace AST { if ((*expected_ty)->m_kind == types::TypeKind::Array) { auto array_ty = dynamic_cast(expected_ty->get()); for (auto& expr : this->m_expressions) { - auto expr_ty = expr->typecheck(state, scope, array_ty->m_inner); + auto expr_ty = expr->typecheck(state, scope, array_ty->m_inner).type; auto expr_res = check_type(state, expr_ty, array_ty->m_inner); expr = handle_res(std::move(expr), expr_res, state); } @@ -462,40 +462,40 @@ namespace AST { state.errors.push_back(CompileError( "Too many initializer values for " + struct_ty->formatted(), this->m_meta)); - return *expected_ty; + return { *expected_ty, true }; } for (int i = 0; i < static_cast(this->m_expressions.size()); i++) { auto expected_field = (*struct_ty->m_fields)[i]; - auto expr_ty = this->m_expressions[i]->typecheck(state, scope, expected_field.second); + auto expr_ty = this->m_expressions[i]->typecheck(state, scope, expected_field.second).type; auto res = check_type(state, expr_ty, expected_field.second); this->m_expressions[i] = handle_res(std::move(this->m_expressions[i]), res, state); } this->m_ty = *expected_ty; - return this->m_ty; + return { this->m_ty, true }; } else { if (this->m_expressions.size() > 0) { state.errors.push_back(CompileError( "Too many initializer values for " + struct_ty->formatted(), this->m_meta)); - return *expected_ty; + return { *expected_ty, true }; } else { this->m_ty = *expected_ty; - return this->m_ty; + return { this->m_ty, true }; } } } else { - return std::shared_ptr { + return { std::shared_ptr { new types::FundamentalType{ types::FundamentalTypeKind::Void } - }; + }, true }; } - return this->m_ty; + return { this->m_ty, true }; } // No expected ty, try to infer array type from elements @@ -508,12 +508,12 @@ namespace AST { true } }; - return this->m_ty; + return { this->m_ty, true }; } else { - auto first_expr_ty = this->m_expressions[0]->typecheck(state, scope, {}); + auto first_expr_ty = this->m_expressions[0]->typecheck(state, scope, {}).type; for (int i = 1; i < static_cast(this->m_expressions.size()); i++) { - auto expr_ty = this->m_expressions[i]->typecheck(state, scope, first_expr_ty); + auto expr_ty = this->m_expressions[i]->typecheck(state, scope, first_expr_ty).type; auto expr_res = check_type(state, expr_ty, first_expr_ty); this->m_expressions[i] = handle_res(std::move(this->m_expressions[i]), expr_res, state); } @@ -524,7 +524,7 @@ namespace AST { true } }; - return this->m_ty; + return { this->m_ty, true }; } } @@ -533,7 +533,7 @@ namespace AST { } void ReturnStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { - auto res_ty = this->m_expr->typecheck(state, scope, scope.return_ty); + auto res_ty = this->m_expr->typecheck(state, scope, scope.return_ty).type; if (scope.return_ty) { auto check_res = check_type(state, res_ty, *scope.return_ty); this->m_expr = handle_res(std::move(this->m_expr), check_res, state); @@ -549,7 +549,12 @@ namespace AST { void InitializationStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { if (this->m_expr) { auto expr_ty = (*this->m_expr)->typecheck(state, scope, this->m_type); - auto check_res = check_type(state, expr_ty, this->m_type); + if (this->m_type->m_kind == types::TypeKind::Array && !expr_ty.array_initializer) { + state.errors.push_back( + CompileError("Arrays can only be initialized with list-initializers or strings", this->m_meta) + ); + } + auto check_res = check_type(state, expr_ty.type, this->m_type); this->m_expr = handle_res(std::move(*this->m_expr), check_res, state); } scope.symbols[this->m_name] = this->m_type; @@ -573,7 +578,7 @@ namespace AST { void IfStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { auto bool_ty = std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Bool } }; - auto expr_ty = this->m_condition->typecheck(state, scope, bool_ty); + auto expr_ty = this->m_condition->typecheck(state, scope, bool_ty).type; auto check_res = check_type(state, expr_ty, bool_ty); this->m_condition = handle_res(std::move(this->m_condition), check_res, state); diff --git a/src/typechecker.h b/src/typechecker.h index 401b9cf..8ffdd7e 100644 --- a/src/typechecker.h +++ b/src/typechecker.h @@ -20,6 +20,11 @@ namespace typecheck { std::vector casts; std::vector errors; }; + + struct ExpressionType { + std::shared_ptr type; + bool array_initializer; + }; } #endif \ No newline at end of file