diff --git a/src/ast.h b/src/ast.h index 13f0d35..6160efc 100644 --- a/src/ast.h +++ b/src/ast.h @@ -26,6 +26,7 @@ namespace AST { Expression(token::Metadata meta) : Node{ meta } {} virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) = 0; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) = 0; + virtual void typecheck_preprocess(typecheck::Scope& scope) = 0; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -37,6 +38,7 @@ namespace AST { public: Statement(token::Metadata meta) : Node{ meta } {} virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) = 0; + virtual void typecheck_preprocess(typecheck::Scope& scope) = 0; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) = 0; }; @@ -57,6 +59,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -74,6 +77,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -91,6 +95,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -119,6 +124,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -144,6 +150,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -169,6 +176,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -191,6 +199,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -213,6 +222,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -238,6 +248,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -263,6 +274,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -288,6 +300,7 @@ namespace AST { virtual std::string formatted() override; virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override; virtual std::shared_ptr get_codegen_type(codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual std::shared_ptr typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -306,6 +319,7 @@ namespace AST { virtual ~ReturnStatement() override = default; virtual std::string formatted() override; virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override; }; @@ -328,6 +342,7 @@ namespace AST { virtual ~InitializationStatement() override = default; virtual std::string formatted() override; virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override; }; @@ -341,6 +356,7 @@ namespace AST { virtual ~ExpressionStatement() override = default; virtual std::string formatted() override; virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override; }; @@ -362,6 +378,7 @@ namespace AST { virtual ~IfStatement() override = default; virtual std::string formatted() override; virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override; }; @@ -369,6 +386,7 @@ namespace AST { public: TopLevelStatement(token::Metadata meta) : Node{ meta } {} virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) = 0; + virtual void typecheck_preprocess(typecheck::Scope& scope) = 0; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) = 0; }; @@ -397,6 +415,7 @@ namespace AST { virtual ~Function() override = default; virtual std::string formatted() override; virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override; }; @@ -413,6 +432,7 @@ namespace AST { virtual ~TopLevelTypedef() override = default; virtual std::string formatted() override; virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override; + virtual void typecheck_preprocess(typecheck::Scope& scope) override; virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override; }; } diff --git a/src/main.cpp b/src/main.cpp index 490c78a..0b3f096 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -99,16 +99,27 @@ std::optional compile(std::string_view in_filename) { // Perform static analysis + typecheck::Scope preprocess_scope{}; + + // Preprocess + for (auto& tls : statements) { + std::cout << tls->formatted() << std::endl; + tls->typecheck_preprocess(preprocess_scope); + } + + // Actual typechecking typecheck::State typecheck_state{}; typecheck_state.binops = types::create_binops(); typecheck_state.casts = types::create_casts(); typecheck::Scope typecheck_scope{}; + for (auto& tls : statements) { std::cout << tls->formatted() << std::endl; tls->typecheck(typecheck_state, typecheck_scope); } + // Error checking if (typecheck_state.errors.size() > 0) { std::cerr << "Errors while typechecking:" << std::endl; for (auto& error : typecheck_state.errors) { diff --git a/src/parsing.cpp b/src/parsing.cpp index c98df24..f27f90a 100644 --- a/src/parsing.cpp +++ b/src/parsing.cpp @@ -5,6 +5,8 @@ namespace parsing { namespace { + static uint32_t struct_id_counter = 0; + Result, std::string> parse_expression(token::TokenStream& stream, Scope& scope); Result, std::string> parse_type(token::TokenStream& stream, Scope& scope) { @@ -54,12 +56,26 @@ namespace parsing { if (struct_name && !maybe_fields && scope.structs.find(*struct_name) != scope.structs.end()) { auto original_ty = scope.structs[*struct_name]; auto original_struct_ty = dynamic_cast(original_ty.get()); - auto ty = new types::StructType{ struct_name, original_struct_ty->m_fields, true }; + auto ty = new types::StructType{ struct_name, original_struct_ty->m_fields, true, false, original_struct_ty->m_id }; returned = std::shared_ptr{ ty }; } else { - auto ty = new types::StructType{ struct_name, maybe_fields, false }; - returned = std::shared_ptr{ ty }; + if (scope.structs.find(*struct_name) != scope.structs.end()) { + auto original_ty = scope.structs[*struct_name]; + auto original_struct_ty = dynamic_cast(original_ty.get()); + if (!original_struct_ty->m_fields.has_value()) { + auto ty = new types::StructType{ struct_name, maybe_fields, false, true, original_struct_ty->m_id }; + returned = std::shared_ptr{ ty }; + } + else { + auto ty = new types::StructType{ struct_name, maybe_fields, false, false, struct_id_counter++ }; + returned = std::shared_ptr{ ty }; + } + } + else { + auto ty = new types::StructType{ struct_name, maybe_fields, false, false, struct_id_counter++ }; + returned = std::shared_ptr{ ty }; + } } } else { @@ -560,7 +576,13 @@ namespace parsing { if (ty->m_kind == types::TypeKind::Struct) { auto struct_ty = dynamic_cast(ty.get()); if (!struct_ty->m_is_ref && struct_ty->m_name) { - scope.structs[*struct_ty->m_name] = ty; + if (scope.structs.find(*struct_ty->m_name) != scope.structs.end() && struct_ty->m_is_def) { + auto true_ty = dynamic_cast(scope.structs[*struct_ty->m_name].get()); + true_ty->m_fields = struct_ty->m_fields; + } + else { + scope.structs[*struct_ty->m_name] = ty; + } } } diff --git a/src/typechecker.cpp b/src/typechecker.cpp index c0fb498..9c88c5a 100644 --- a/src/typechecker.cpp +++ b/src/typechecker.cpp @@ -55,9 +55,56 @@ namespace { return expr; } } + + std::shared_ptr refresh_type(typecheck::Scope& scope, std::shared_ptr ty) { + if (ty->m_kind == types::TypeKind::Fundamental) { + return ty; + } + else if (ty->m_kind == types::TypeKind::Array) { + auto array_ty = dynamic_cast(ty.get()); + array_ty->m_inner = refresh_type(scope, array_ty->m_inner); + return ty; + } + else if (ty->m_kind == types::TypeKind::Function) { + auto function_ty = dynamic_cast(ty.get()); + function_ty->m_ret_ty = refresh_type(scope, function_ty->m_ret_ty); + for (int i = 0; i < static_cast(function_ty->m_param_tys.size()); i++) { + function_ty->m_param_tys[i] = refresh_type(scope, function_ty->m_param_tys[i]); + } + return ty; + } + else if (ty->m_kind == types::TypeKind::Pointer) { + auto ptr_ty = dynamic_cast(ty.get()); + ptr_ty->m_inner = refresh_type(scope, ptr_ty->m_inner); + return ty; + } + else if (ty->m_kind == types::TypeKind::Struct) { + auto struct_ty = dynamic_cast(ty.get()); + if (struct_ty->m_is_ref || struct_ty->m_is_def) { + if (scope.structs.find(*struct_ty->m_name) != scope.structs.end()) { + auto pre_existing = dynamic_cast(scope.structs[*struct_ty->m_name].get()); + struct_ty->m_fields = pre_existing->m_fields; + } + } + if (struct_ty->m_fields) { + for (int i = 0; i < static_cast((*struct_ty->m_fields).size()); i++) { + (*struct_ty->m_fields)[i].second = refresh_type(scope, (*struct_ty->m_fields)[i].second); + } + } + return ty; + } + else { + return ty; + } + } } namespace AST { + + void IntLiteralExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_ty = refresh_type(scope, this->m_ty); + } + std::shared_ptr IntLiteralExpression::typecheck( typecheck::State&, typecheck::Scope&, @@ -81,6 +128,8 @@ namespace AST { return this->m_ty; } + void StringLiteralExpression::typecheck_preprocess(typecheck::Scope&) {} + std::shared_ptr StringLiteralExpression::typecheck( typecheck::State&, typecheck::Scope&, @@ -93,6 +142,8 @@ namespace AST { return std::shared_ptr{ptr_ty}; } + void ValueReferenceExpression::typecheck_preprocess(typecheck::Scope&) {} + std::shared_ptr ValueReferenceExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -108,6 +159,11 @@ namespace AST { }; } + void BinaryOperationExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_lhs->typecheck_preprocess(scope); + this->m_rhs->typecheck_preprocess(scope); + } + std::shared_ptr BinaryOperationExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -194,6 +250,13 @@ namespace AST { new types::FundamentalType{ types::FundamentalTypeKind::Void } }; } + void FunctionCallExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_fn_expr->typecheck_preprocess(scope); + for (auto& expr : this->m_args) { + expr->typecheck_preprocess(scope); + } + } + std::shared_ptr FunctionCallExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -234,6 +297,10 @@ namespace AST { return fn_ty->m_ret_ty; } + void CastExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_ty = refresh_type(scope, this->m_ty); + this->m_expr->typecheck_preprocess(scope); + } std::shared_ptr CastExpression::typecheck( typecheck::State& state, @@ -253,6 +320,10 @@ namespace AST { } }; } + void RefExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_expr->typecheck_preprocess(scope); + } + std::shared_ptr RefExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -264,6 +335,10 @@ namespace AST { }; } + void DerefExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_expr->typecheck_preprocess(scope); + } + std::shared_ptr DerefExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -281,6 +356,10 @@ namespace AST { return ptr_ty->m_inner; } + void IndexAccessExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_expr->typecheck_preprocess(scope); + } + std::shared_ptr IndexAccessExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -309,6 +388,10 @@ namespace AST { }; } + void FieldAccessExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_expr->typecheck_preprocess(scope); + } + std::shared_ptr FieldAccessExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -342,6 +425,13 @@ namespace AST { }; } + void ListInitializerExpression::typecheck_preprocess(typecheck::Scope& scope) { + this->m_ty = refresh_type(scope, this->m_ty); + for (auto& expr : this->m_expressions) { + expr->typecheck_preprocess(scope); + } + } + std::shared_ptr ListInitializerExpression::typecheck( typecheck::State& state, typecheck::Scope& scope, @@ -435,6 +525,9 @@ namespace AST { } } + void ReturnStatement::typecheck_preprocess(typecheck::Scope& scope) { + this->m_expr->typecheck_preprocess(scope); + } void ReturnStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { auto res_ty = this->m_expr->typecheck(state, scope, scope.return_ty); @@ -444,6 +537,12 @@ namespace AST { } } + void InitializationStatement::typecheck_preprocess(typecheck::Scope& scope) { + this->m_type = refresh_type(scope, this->m_type); + if (this->m_expr) + (*this->m_expr)->typecheck_preprocess(scope); + } + void InitializationStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { if (this->m_expr) { auto expr_ty = (*this->m_expr)->typecheck(state, scope, this->m_type); @@ -453,10 +552,21 @@ namespace AST { scope.symbols[this->m_name] = this->m_type; } + void ExpressionStatement::typecheck_preprocess(typecheck::Scope& scope) { + this->m_expr->typecheck_preprocess(scope); + } + void ExpressionStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { this->m_expr->typecheck(state, scope, {}); } + void IfStatement::typecheck_preprocess(typecheck::Scope& scope) { + this->m_condition->typecheck_preprocess(scope); + this->m_then->typecheck_preprocess(scope); + if (this->m_else) + (*this->m_else)->typecheck_preprocess(scope); + } + void IfStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { auto bool_ty = std::shared_ptr{ new types::FundamentalType{ types::FundamentalTypeKind::Bool } }; @@ -471,6 +581,19 @@ namespace AST { } } + void Function::typecheck_preprocess(typecheck::Scope& scope) { + this->m_return_ty = refresh_type(scope, this->m_return_ty); + for (auto& param : this->m_params) { + param.second = refresh_type(scope, param.second); + } + + if (this->m_statements) { + for (auto& statement : *this->m_statements) { + statement->typecheck_preprocess(scope); + } + } + } + void Function::typecheck(typecheck::State& state, typecheck::Scope& scope) { auto return_ty = this->m_return_ty; std::vector> param_tys{}; @@ -497,19 +620,26 @@ namespace AST { } } - void TopLevelTypedef::typecheck(typecheck::State& state, typecheck::Scope& scope) { + void TopLevelTypedef::typecheck_preprocess(typecheck::Scope& scope) { if (this->m_ty->m_kind == types::TypeKind::Struct) { auto struct_ty = dynamic_cast(this->m_ty.get()); - if (struct_ty->m_is_ref) { + if (struct_ty->m_is_ref || struct_ty->m_is_def) { return; } if (struct_ty->m_name) { - if (scope.structs.find(*struct_ty->m_name) == scope.structs.end()) { - scope.structs[*struct_ty->m_name] = this->m_ty; - } - else { - state.errors.push_back(CompileError("Struct " + *struct_ty->m_name + " declared twice!", this->m_meta)); - } + scope.structs[*struct_ty->m_name] = this->m_ty; + } + } + } + + void TopLevelTypedef::typecheck(typecheck::State&, typecheck::Scope& scope) { + if (this->m_ty->m_kind == types::TypeKind::Struct) { + auto struct_ty = dynamic_cast(this->m_ty.get()); + if (struct_ty->m_is_ref || struct_ty->m_is_def) { + return; + } + if (struct_ty->m_name) { + scope.structs[*struct_ty->m_name] = this->m_ty; } } } diff --git a/src/types.cpp b/src/types.cpp index e643b6e..d8b0ffd 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -163,7 +163,7 @@ namespace types { std::string StructType::formatted() { std::stringstream out{ "" }; - out << "struct"; + out << "struct(" << this->m_id << ")"; if (this->m_is_ref) out << "(ref)"; out << " "; @@ -171,6 +171,7 @@ namespace types { if (this->m_name) { out << *this->m_name << " "; } + if (this->m_fields) { out << "{ "; int counter = 0; @@ -262,9 +263,13 @@ namespace types { auto ty1 = dynamic_cast(type1.get()); auto ty2 = dynamic_cast(type2.get()); + if (ty1->m_is_ref || ty2->m_is_ref) + return ty1->m_id == ty2->m_id; + if (ty1->m_fields.has_value() != ty2->m_fields.has_value()) return false; + if (ty1->m_fields) { if (ty1->m_fields->size() != ty2->m_fields->size()) return false; diff --git a/src/types.h b/src/types.h index e9f923a..c649cd4 100644 --- a/src/types.h +++ b/src/types.h @@ -105,9 +105,11 @@ namespace types { std::optional m_name; std::optional> m_fields; bool m_is_ref; + bool m_is_def; + uint32_t m_id; - StructType(std::optional name, std::optional> fields, bool is_ref) - : Type(TypeKind::Struct), m_name{ name }, m_fields{ fields }, m_is_ref{ is_ref } { + StructType(std::optional name, std::optional> fields, bool is_ref, bool is_def, uint32_t id) + : Type(TypeKind::Struct), m_name{ name }, m_fields{ fields }, m_is_ref{ is_ref }, m_is_def{ is_def }, m_id{ id } { } virtual ~StructType() override = default; virtual std::string formatted() override; diff --git a/test.c b/test.c index 04f99bc..59f59f1 100644 --- a/test.c +++ b/test.c @@ -10,8 +10,14 @@ void change_first(char otus[5]) { otus[0] = 115; } + +struct Otus; + +void update(struct Otus potus) { + potus.field = 20; +} struct Otus { - int field + int field; }; int main() { @@ -25,6 +31,7 @@ int main() { printf(" first element: %d!", somelist[0]); struct Otus otus = { 5 }; + update(otus); printf(" first field: %d!", otus.field);