Allow forward-declaration of structs

This commit is contained in:
Sofia 2026-04-16 01:39:07 +03:00
parent 21d17bb02d
commit f0607a2310
7 changed files with 213 additions and 16 deletions

View File

@ -26,6 +26,7 @@ namespace AST {
Expression(token::Metadata meta) : Node{ meta } {}
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) = 0;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) = 0;
virtual void typecheck_preprocess(typecheck::Scope& scope) = 0;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -37,6 +38,7 @@ namespace AST {
public:
Statement(token::Metadata meta) : Node{ meta } {}
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) = 0;
virtual void typecheck_preprocess(typecheck::Scope& scope) = 0;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) = 0;
};
@ -57,6 +59,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -74,6 +77,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -91,6 +95,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -119,6 +124,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -144,6 +150,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -169,6 +176,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -191,6 +199,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -213,6 +222,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -238,6 +248,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -263,6 +274,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -288,6 +300,7 @@ namespace AST {
virtual std::string formatted() override;
virtual codegen::StackValue codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual std::shared_ptr<types::Type> get_codegen_type(codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual std::shared_ptr<types::Type> typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -306,6 +319,7 @@ namespace AST {
virtual ~ReturnStatement() override = default;
virtual std::string formatted() override;
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override;
};
@ -328,6 +342,7 @@ namespace AST {
virtual ~InitializationStatement() override = default;
virtual std::string formatted() override;
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override;
};
@ -341,6 +356,7 @@ namespace AST {
virtual ~ExpressionStatement() override = default;
virtual std::string formatted() override;
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override;
};
@ -362,6 +378,7 @@ namespace AST {
virtual ~IfStatement() override = default;
virtual std::string formatted() override;
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override;
};
@ -369,6 +386,7 @@ namespace AST {
public:
TopLevelStatement(token::Metadata meta) : Node{ meta } {}
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) = 0;
virtual void typecheck_preprocess(typecheck::Scope& scope) = 0;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) = 0;
};
@ -397,6 +415,7 @@ namespace AST {
virtual ~Function() override = default;
virtual std::string formatted() override;
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override;
};
@ -413,6 +432,7 @@ namespace AST {
virtual ~TopLevelTypedef() override = default;
virtual std::string formatted() override;
virtual void codegen(codegen::Builder& builder, codegen::Scope& scope) override;
virtual void typecheck_preprocess(typecheck::Scope& scope) override;
virtual void typecheck(typecheck::State& state, typecheck::Scope& scope) override;
};
}

View File

@ -99,16 +99,27 @@ std::optional<CompileOutput> compile(std::string_view in_filename) {
// Perform static analysis
typecheck::Scope preprocess_scope{};
// Preprocess
for (auto& tls : statements) {
std::cout << tls->formatted() << std::endl;
tls->typecheck_preprocess(preprocess_scope);
}
// Actual typechecking
typecheck::State typecheck_state{};
typecheck_state.binops = types::create_binops();
typecheck_state.casts = types::create_casts();
typecheck::Scope typecheck_scope{};
for (auto& tls : statements) {
std::cout << tls->formatted() << std::endl;
tls->typecheck(typecheck_state, typecheck_scope);
}
// Error checking
if (typecheck_state.errors.size() > 0) {
std::cerr << "Errors while typechecking:" << std::endl;
for (auto& error : typecheck_state.errors) {

View File

@ -5,6 +5,8 @@
namespace parsing {
namespace {
static uint32_t struct_id_counter = 0;
Result<std::unique_ptr<AST::Expression>, std::string> parse_expression(token::TokenStream& stream, Scope& scope);
Result<std::shared_ptr<types::Type>, std::string> parse_type(token::TokenStream& stream, Scope& scope) {
@ -54,13 +56,27 @@ namespace parsing {
if (struct_name && !maybe_fields && scope.structs.find(*struct_name) != scope.structs.end()) {
auto original_ty = scope.structs[*struct_name];
auto original_struct_ty = dynamic_cast<types::StructType*>(original_ty.get());
auto ty = new types::StructType{ struct_name, original_struct_ty->m_fields, true };
auto ty = new types::StructType{ struct_name, original_struct_ty->m_fields, true, false, original_struct_ty->m_id };
returned = std::shared_ptr<types::Type>{ ty };
}
else {
auto ty = new types::StructType{ struct_name, maybe_fields, false };
if (scope.structs.find(*struct_name) != scope.structs.end()) {
auto original_ty = scope.structs[*struct_name];
auto original_struct_ty = dynamic_cast<types::StructType*>(original_ty.get());
if (!original_struct_ty->m_fields.has_value()) {
auto ty = new types::StructType{ struct_name, maybe_fields, false, true, original_struct_ty->m_id };
returned = std::shared_ptr<types::Type>{ ty };
}
else {
auto ty = new types::StructType{ struct_name, maybe_fields, false, false, struct_id_counter++ };
returned = std::shared_ptr<types::Type>{ ty };
}
}
else {
auto ty = new types::StructType{ struct_name, maybe_fields, false, false, struct_id_counter++ };
returned = std::shared_ptr<types::Type>{ ty };
}
}
}
else {
// TODO eventually make this be potentially more than one word
@ -560,9 +576,15 @@ namespace parsing {
if (ty->m_kind == types::TypeKind::Struct) {
auto struct_ty = dynamic_cast<types::StructType*>(ty.get());
if (!struct_ty->m_is_ref && struct_ty->m_name) {
if (scope.structs.find(*struct_ty->m_name) != scope.structs.end() && struct_ty->m_is_def) {
auto true_ty = dynamic_cast<types::StructType*>(scope.structs[*struct_ty->m_name].get());
true_ty->m_fields = struct_ty->m_fields;
}
else {
scope.structs[*struct_ty->m_name] = ty;
}
}
}
stream.m_position = inner.m_position;
auto tl_typedef = new AST::TopLevelTypedef{

View File

@ -55,9 +55,56 @@ namespace {
return expr;
}
}
std::shared_ptr<types::Type> refresh_type(typecheck::Scope& scope, std::shared_ptr<types::Type> ty) {
if (ty->m_kind == types::TypeKind::Fundamental) {
return ty;
}
else if (ty->m_kind == types::TypeKind::Array) {
auto array_ty = dynamic_cast<types::ArrayType*>(ty.get());
array_ty->m_inner = refresh_type(scope, array_ty->m_inner);
return ty;
}
else if (ty->m_kind == types::TypeKind::Function) {
auto function_ty = dynamic_cast<types::FunctionType*>(ty.get());
function_ty->m_ret_ty = refresh_type(scope, function_ty->m_ret_ty);
for (int i = 0; i < static_cast<int>(function_ty->m_param_tys.size()); i++) {
function_ty->m_param_tys[i] = refresh_type(scope, function_ty->m_param_tys[i]);
}
return ty;
}
else if (ty->m_kind == types::TypeKind::Pointer) {
auto ptr_ty = dynamic_cast<types::PointerType*>(ty.get());
ptr_ty->m_inner = refresh_type(scope, ptr_ty->m_inner);
return ty;
}
else if (ty->m_kind == types::TypeKind::Struct) {
auto struct_ty = dynamic_cast<types::StructType*>(ty.get());
if (struct_ty->m_is_ref || struct_ty->m_is_def) {
if (scope.structs.find(*struct_ty->m_name) != scope.structs.end()) {
auto pre_existing = dynamic_cast<types::StructType*>(scope.structs[*struct_ty->m_name].get());
struct_ty->m_fields = pre_existing->m_fields;
}
}
if (struct_ty->m_fields) {
for (int i = 0; i < static_cast<int>((*struct_ty->m_fields).size()); i++) {
(*struct_ty->m_fields)[i].second = refresh_type(scope, (*struct_ty->m_fields)[i].second);
}
}
return ty;
}
else {
return ty;
}
}
}
namespace AST {
void IntLiteralExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_ty = refresh_type(scope, this->m_ty);
}
std::shared_ptr<types::Type> IntLiteralExpression::typecheck(
typecheck::State&,
typecheck::Scope&,
@ -81,6 +128,8 @@ namespace AST {
return this->m_ty;
}
void StringLiteralExpression::typecheck_preprocess(typecheck::Scope&) {}
std::shared_ptr<types::Type> StringLiteralExpression::typecheck(
typecheck::State&,
typecheck::Scope&,
@ -93,6 +142,8 @@ namespace AST {
return std::shared_ptr<types::Type>{ptr_ty};
}
void ValueReferenceExpression::typecheck_preprocess(typecheck::Scope&) {}
std::shared_ptr<types::Type> ValueReferenceExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -108,6 +159,11 @@ namespace AST {
};
}
void BinaryOperationExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_lhs->typecheck_preprocess(scope);
this->m_rhs->typecheck_preprocess(scope);
}
std::shared_ptr<types::Type> BinaryOperationExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -194,6 +250,13 @@ namespace AST {
new types::FundamentalType{ types::FundamentalTypeKind::Void } };
}
void FunctionCallExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_fn_expr->typecheck_preprocess(scope);
for (auto& expr : this->m_args) {
expr->typecheck_preprocess(scope);
}
}
std::shared_ptr<types::Type> FunctionCallExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -234,6 +297,10 @@ namespace AST {
return fn_ty->m_ret_ty;
}
void CastExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_ty = refresh_type(scope, this->m_ty);
this->m_expr->typecheck_preprocess(scope);
}
std::shared_ptr<types::Type> CastExpression::typecheck(
typecheck::State& state,
@ -253,6 +320,10 @@ namespace AST {
} };
}
void RefExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_expr->typecheck_preprocess(scope);
}
std::shared_ptr<types::Type> RefExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -264,6 +335,10 @@ namespace AST {
};
}
void DerefExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_expr->typecheck_preprocess(scope);
}
std::shared_ptr<types::Type> DerefExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -281,6 +356,10 @@ namespace AST {
return ptr_ty->m_inner;
}
void IndexAccessExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_expr->typecheck_preprocess(scope);
}
std::shared_ptr<types::Type> IndexAccessExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -309,6 +388,10 @@ namespace AST {
};
}
void FieldAccessExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_expr->typecheck_preprocess(scope);
}
std::shared_ptr<types::Type> FieldAccessExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -342,6 +425,13 @@ namespace AST {
};
}
void ListInitializerExpression::typecheck_preprocess(typecheck::Scope& scope) {
this->m_ty = refresh_type(scope, this->m_ty);
for (auto& expr : this->m_expressions) {
expr->typecheck_preprocess(scope);
}
}
std::shared_ptr<types::Type> ListInitializerExpression::typecheck(
typecheck::State& state,
typecheck::Scope& scope,
@ -435,6 +525,9 @@ namespace AST {
}
}
void ReturnStatement::typecheck_preprocess(typecheck::Scope& scope) {
this->m_expr->typecheck_preprocess(scope);
}
void ReturnStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
auto res_ty = this->m_expr->typecheck(state, scope, scope.return_ty);
@ -444,6 +537,12 @@ namespace AST {
}
}
void InitializationStatement::typecheck_preprocess(typecheck::Scope& scope) {
this->m_type = refresh_type(scope, this->m_type);
if (this->m_expr)
(*this->m_expr)->typecheck_preprocess(scope);
}
void InitializationStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
if (this->m_expr) {
auto expr_ty = (*this->m_expr)->typecheck(state, scope, this->m_type);
@ -453,10 +552,21 @@ namespace AST {
scope.symbols[this->m_name] = this->m_type;
}
void ExpressionStatement::typecheck_preprocess(typecheck::Scope& scope) {
this->m_expr->typecheck_preprocess(scope);
}
void ExpressionStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
this->m_expr->typecheck(state, scope, {});
}
void IfStatement::typecheck_preprocess(typecheck::Scope& scope) {
this->m_condition->typecheck_preprocess(scope);
this->m_then->typecheck_preprocess(scope);
if (this->m_else)
(*this->m_else)->typecheck_preprocess(scope);
}
void IfStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
auto bool_ty = std::shared_ptr<types::Type>{
new types::FundamentalType{ types::FundamentalTypeKind::Bool } };
@ -471,6 +581,19 @@ namespace AST {
}
}
void Function::typecheck_preprocess(typecheck::Scope& scope) {
this->m_return_ty = refresh_type(scope, this->m_return_ty);
for (auto& param : this->m_params) {
param.second = refresh_type(scope, param.second);
}
if (this->m_statements) {
for (auto& statement : *this->m_statements) {
statement->typecheck_preprocess(scope);
}
}
}
void Function::typecheck(typecheck::State& state, typecheck::Scope& scope) {
auto return_ty = this->m_return_ty;
std::vector<std::shared_ptr<types::Type>> param_tys{};
@ -497,20 +620,27 @@ namespace AST {
}
}
void TopLevelTypedef::typecheck(typecheck::State& state, typecheck::Scope& scope) {
void TopLevelTypedef::typecheck_preprocess(typecheck::Scope& scope) {
if (this->m_ty->m_kind == types::TypeKind::Struct) {
auto struct_ty = dynamic_cast<types::StructType*>(this->m_ty.get());
if (struct_ty->m_is_ref) {
if (struct_ty->m_is_ref || struct_ty->m_is_def) {
return;
}
if (struct_ty->m_name) {
if (scope.structs.find(*struct_ty->m_name) == scope.structs.end()) {
scope.structs[*struct_ty->m_name] = this->m_ty;
}
else {
state.errors.push_back(CompileError("Struct " + *struct_ty->m_name + " declared twice!", this->m_meta));
}
}
void TopLevelTypedef::typecheck(typecheck::State&, typecheck::Scope& scope) {
if (this->m_ty->m_kind == types::TypeKind::Struct) {
auto struct_ty = dynamic_cast<types::StructType*>(this->m_ty.get());
if (struct_ty->m_is_ref || struct_ty->m_is_def) {
return;
}
if (struct_ty->m_name) {
scope.structs[*struct_ty->m_name] = this->m_ty;
}
}
}
}

View File

@ -163,7 +163,7 @@ namespace types {
std::string StructType::formatted() {
std::stringstream out{ "" };
out << "struct";
out << "struct(" << this->m_id << ")";
if (this->m_is_ref)
out << "(ref)";
out << " ";
@ -171,6 +171,7 @@ namespace types {
if (this->m_name) {
out << *this->m_name << " ";
}
if (this->m_fields) {
out << "{ ";
int counter = 0;
@ -262,9 +263,13 @@ namespace types {
auto ty1 = dynamic_cast<StructType*>(type1.get());
auto ty2 = dynamic_cast<StructType*>(type2.get());
if (ty1->m_is_ref || ty2->m_is_ref)
return ty1->m_id == ty2->m_id;
if (ty1->m_fields.has_value() != ty2->m_fields.has_value())
return false;
if (ty1->m_fields) {
if (ty1->m_fields->size() != ty2->m_fields->size())
return false;

View File

@ -105,9 +105,11 @@ namespace types {
std::optional<std::string> m_name;
std::optional<std::vector<StructField>> m_fields;
bool m_is_ref;
bool m_is_def;
uint32_t m_id;
StructType(std::optional<std::string> name, std::optional<std::vector<StructField>> fields, bool is_ref)
: Type(TypeKind::Struct), m_name{ name }, m_fields{ fields }, m_is_ref{ is_ref } {
StructType(std::optional<std::string> name, std::optional<std::vector<StructField>> fields, bool is_ref, bool is_def, uint32_t id)
: Type(TypeKind::Struct), m_name{ name }, m_fields{ fields }, m_is_ref{ is_ref }, m_is_def{ is_def }, m_id{ id } {
}
virtual ~StructType() override = default;
virtual std::string formatted() override;

9
test.c
View File

@ -10,8 +10,14 @@ void change_first(char otus[5]) {
otus[0] = 115;
}
struct Otus;
void update(struct Otus potus) {
potus.field = 20;
}
struct Otus {
int field
int field;
};
int main() {
@ -25,6 +31,7 @@ int main() {
printf(" first element: %d!", somelist[0]);
struct Otus otus = { 5 };
update(otus);
printf(" first field: %d!", otus.field);