diff --git a/src/casting.cpp b/src/casting.cpp index 9c43c86..dc27e14 100644 --- a/src/casting.cpp +++ b/src/casting.cpp @@ -12,12 +12,28 @@ namespace types { auto bool_ty = std::shared_ptr{ new FundamentalType{ FundamentalTypeKind::Bool } }; - for (auto& source_ty : { int_ty, char_ty, bool_ty }) { - for (auto& target_ty : { int_ty, char_ty, bool_ty }) { - casts.push_back(CastDefinition{ source_ty, target_ty, false, - [](codegen::Builder& builder, std::shared_ptr target, llvm::Value* value) { - return builder.builder->CreateSExtOrTrunc(value, target->codegen(builder), "cast"); - } }); + auto numerical_types = { int_ty, char_ty, bool_ty }; + + for (auto& source_ty : numerical_types) { + for (auto& target_ty : numerical_types) { + if (types::types_equal(source_ty, target_ty)) { + casts.push_back(CastDefinition{ source_ty, target_ty, true, + [](codegen::Builder&, std::shared_ptr, llvm::Value* value) { + return value; + } }); + } + else if (target_ty->is_signed()) { + casts.push_back(CastDefinition{ source_ty, target_ty, false, + [](codegen::Builder& builder, std::shared_ptr target, llvm::Value* value) { + return builder.builder->CreateSExtOrTrunc(value, target->codegen(builder), "cast"); + } }); + } + else { + casts.push_back(CastDefinition{ source_ty, target_ty, false, + [](codegen::Builder& builder, std::shared_ptr target, llvm::Value* value) { + return builder.builder->CreateZExtOrTrunc(value, target->codegen(builder), "cast"); + } }); + } } } diff --git a/src/types.cpp b/src/types.cpp index acca953..e715457 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -75,6 +75,10 @@ namespace types { return {}; } + bool Type::is_signed() { + false; + } + std::pair> FundamentalType::load(codegen::Builder&, llvm::Value* ptr) { auto self = std::make_shared(*this); return std::pair(ptr, self); @@ -124,6 +128,18 @@ namespace types { } } + bool FundamentalType::is_signed() { + switch (this->m_ty) { + case FundamentalTypeKind::Int: + return true; + case FundamentalTypeKind::Bool: + case FundamentalTypeKind::Char: + return false; + default: + throw std::runtime_error("Invalid type"); + } + } + std::string FunctionType::formatted() { std::stringstream out{ "" }; out << "("; diff --git a/src/types.h b/src/types.h index 95d3769..0b5bd72 100644 --- a/src/types.h +++ b/src/types.h @@ -33,6 +33,7 @@ namespace types { virtual llvm::Value* sub(codegen::Builder& builder, llvm::Value* lhs, llvm::Value* rhs); virtual llvm::Value* lt(codegen::Builder& builder, llvm::Value* lhs, llvm::Value* rhs); virtual llvm::Value* gt(codegen::Builder& builder, llvm::Value* lhs, llvm::Value* rhs); + virtual bool is_signed(); }; class FundamentalType : public Type { @@ -47,6 +48,7 @@ namespace types { virtual llvm::Value* sub(codegen::Builder& builder, llvm::Value* lhs, llvm::Value* rhs) override; virtual llvm::Value* lt(codegen::Builder& builder, llvm::Value* lhs, llvm::Value* rhs) override; virtual llvm::Value* gt(codegen::Builder& builder, llvm::Value* lhs, llvm::Value* rhs) override; + virtual bool is_signed() override; };