From e064b3f04dfcf60b3f1a91895ade2e7d4eaf924d Mon Sep 17 00:00:00 2001 From: Sofia Date: Sun, 10 May 2026 18:50:46 +0300 Subject: [PATCH] Improve typechecking for binops --- src/ast.cpp | 1 + src/ast.h | 2 +- src/binops.cpp | 27 +++++++++++++++++++++++++-- src/casting.cpp | 9 +++++++++ src/typechecker.cpp | 9 +++++++++ src/types.cpp | 3 +++ src/types.h | 3 +++ test.c | 2 +- 8 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/ast.cpp b/src/ast.cpp index 14e10e7..72cf9eb 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -6,6 +6,7 @@ namespace AST { std::string IntLiteralExpression::formatted() { std::stringstream out{ "" }; out << this->m_value; + out << this->m_ty->formatted(); return out.str(); } diff --git a/src/ast.h b/src/ast.h index 1b47b72..18840bb 100644 --- a/src/ast.h +++ b/src/ast.h @@ -55,7 +55,7 @@ namespace AST { : Expression{ meta } , m_value{ value } , m_ty{ { std::shared_ptr{ - new types::FundamentalType{true, types::FundamentalTypeKind::Int} + new types::FundamentalType{true, types::FundamentalTypeKind::AnyInt} } } } { } virtual ~IntLiteralExpression() override = default; diff --git a/src/binops.cpp b/src/binops.cpp index bb1934e..e83109d 100644 --- a/src/binops.cpp +++ b/src/binops.cpp @@ -107,13 +107,32 @@ namespace types { auto int_ty = std::shared_ptr{ new types::FundamentalType{ false, types::FundamentalTypeKind::Int } }; + auto uint_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::UInt } }; auto char_ty = std::shared_ptr{ new types::FundamentalType{ false, types::FundamentalTypeKind::Char } }; + auto uchar_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::UChar } }; + auto short_int_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::ShortInt } }; + auto ushort_int_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::UShortInt } }; + auto long_int_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::LongInt } }; + auto ulong_int_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::ULongInt } }; + auto long_long_int_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::LongLongInt } }; + auto ulong_long_int_ty = std::shared_ptr{ + new types::FundamentalType{ false, types::FundamentalTypeKind::ULongLongInt } }; auto bool_ty = std::shared_ptr{ new types::FundamentalType{ false, types::FundamentalTypeKind::Bool } }; // Integer Increment/Decrement unaries - for (auto& ty : { int_ty, char_ty }) { + for (auto& ty : { + short_int_ty, int_ty, long_int_ty, long_long_int_ty, char_ty, + ushort_int_ty, uint_ty, ulong_int_ty, ulong_long_int_ty, uchar_ty, + }) { definitions.push_back(UnopDefinition{ ty, types::Unary::AddPostfix, ty, [](codegen::Builder& builder, std::shared_ptr ty, llvm::Value* ptr) { @@ -168,7 +187,11 @@ namespace types { } // Not & Negation - for (auto& ty : { int_ty, char_ty, bool_ty }) { + for (auto& ty : { + short_int_ty, int_ty, long_int_ty, long_long_int_ty, char_ty, + ushort_int_ty, uint_ty, ulong_int_ty, ulong_long_int_ty, uchar_ty, + bool_ty + }) { definitions.push_back(UnopDefinition{ ty, types::Unary::Not, ty, [](codegen::Builder& builder, std::shared_ptr ty, llvm::Value* value) { diff --git a/src/casting.cpp b/src/casting.cpp index 568571d..b2e5595 100644 --- a/src/casting.cpp +++ b/src/casting.cpp @@ -28,6 +28,8 @@ namespace types { new FundamentalType{ false, FundamentalTypeKind::ULongLongInt } }; auto bool_ty = std::shared_ptr{ new FundamentalType{ false, FundamentalTypeKind::Bool } }; + auto any_int_ty = std::shared_ptr{ + new FundamentalType{ false, FundamentalTypeKind::AnyInt } }; auto numerical_types = { short_int_ty, int_ty, long_int_ty, long_long_int_ty, char_ty, @@ -35,6 +37,13 @@ namespace types { bool_ty }; + for (auto& target_ty : numerical_types) { + casts.push_back(CastDefinition{ any_int_ty, target_ty, true, + [](codegen::Builder&, std::shared_ptr, llvm::Value* value) { + return value; + } }); + } + for (auto& source_ty : numerical_types) { for (auto& target_ty : numerical_types) { if (types::types_equal(source_ty, target_ty)) { diff --git a/src/typechecker.cpp b/src/typechecker.cpp index 2b326c7..236e620 100644 --- a/src/typechecker.cpp +++ b/src/typechecker.cpp @@ -224,6 +224,10 @@ namespace AST { if (!rhs_res.ok()) // Skip if not implicitly castable to lhs continue; + + rhs_ty = this->m_rhs->typecheck(state, scope, binop.rhs).type; + rhs_res = check_type(state, rhs_ty, binop.rhs); + this->m_rhs = handle_res(std::move(this->m_rhs), rhs_res, state); return { binop.result(binop, lhs_ty, rhs_ty), false, false }; } @@ -232,6 +236,9 @@ namespace AST { if (!lhs_res.ok()) // Skip if not implicitly castable to rhs continue; + + lhs_ty = this->m_lhs->typecheck(state, scope, binop.lhs).type; + lhs_res = check_type(state, lhs_ty, binop.lhs); this->m_lhs = handle_res(std::move(this->m_lhs), lhs_res, state); return { binop.result(binop, lhs_ty, rhs_ty), false, false }; } @@ -247,6 +254,8 @@ namespace AST { if (!result_res.ok()) continue; } + lhs_ty = this->m_lhs->typecheck(state, scope, binop.lhs).type; + rhs_ty = this->m_rhs->typecheck(state, scope, binop.rhs).type; auto lhs_result = check_type(state, lhs_ty, binop.lhs); auto rhs_result = check_type(state, rhs_ty, binop.rhs); this->m_lhs = handle_res(std::move(this->m_lhs), lhs_result, state); diff --git a/src/types.cpp b/src/types.cpp index 6438902..512c607 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -70,6 +70,9 @@ namespace types { case FundamentalTypeKind::ULongLongInt: out << "ULongLongInt"; break; + case FundamentalTypeKind::AnyInt: + out << "AnyInt"; + break; case FundamentalTypeKind::Bool: out << "Bool"; break; diff --git a/src/types.h b/src/types.h index 1bd4845..b891f77 100644 --- a/src/types.h +++ b/src/types.h @@ -26,6 +26,9 @@ namespace types { ULongInt, ULongLongInt, + /// @brief stand-in type for integer literals + AnyInt, + Bool, Char, UChar, diff --git a/test.c b/test.c index d3ac9ca..f7af84c 100644 --- a/test.c +++ b/test.c @@ -71,7 +71,7 @@ long long int main() { printf("while-counter: %d\n", counter++); } - short int sh = 123; + short int sh = 123 + 5; long int lg = 456; long long int longer = 789;