Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,22 @@ TypeHolder CommonBinary(const TypeHolder* begin, size_t count) {
return large_binary();
}

bool CastableToDecimal(const DataType& type) {
return is_numeric(type.id()) || is_decimal(type.id());
}

Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<TypeHolder>* types) {
const DataType& left_type = *(*types)[0];
const DataType& right_type = *(*types)[1];
DCHECK(is_decimal(left_type.id()) || is_decimal(right_type.id()));

if ((is_decimal(left_type.id()) && !CastableToDecimal(right_type)) ||
(is_decimal(right_type.id()) && !CastableToDecimal(left_type))) {
// If the other type is not castable to decimal, do not cast. The dispatch will
// gracefully fail by kernel selection.
return Status::OK();
}

// decimal + float64 = float64
// decimal + float32 is roughly float64 + float32 so we choose float64
if (is_floating(left_type.id()) || is_floating(right_type.id())) {
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ TEST(TestDispatchBest, CastBinaryDecimalArgs) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Decimals with negative scales not supported"),
CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));

// Non-castable -> unchanged
for (const auto promotion :
{DecimalPromotion::kAdd, DecimalPromotion::kMultiply, DecimalPromotion::kDivide}) {
for (const auto& args : std::vector<std::vector<TypeHolder>>{
{decimal128(3, 2), boolean()},
{boolean(), decimal128(3, 2)},
{decimal128(3, 2), utf8()},
{utf8(), decimal128(3, 2)},
}) {
auto args_copy = args;
ASSERT_OK(CastBinaryDecimalArgs(promotion, &args_copy));
AssertTypeEqual(*args_copy[0], *args[0]);
AssertTypeEqual(*args_copy[1], *args[1]);
}
}
}

TEST(TestDispatchBest, CastDecimalArgs) {
Expand Down
23 changes: 23 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,29 @@ TEST_F(TestBinaryArithmeticDecimal, Power) {
}
}

TEST_F(TestBinaryArithmeticDecimal, ErrorOnNonCastable) {
for (const auto& name : {"add", "subtract", "multiply", "divide"}) {
for (const auto& suffix : {"", "_checked"}) {
auto func = std::string(name) + suffix;
SCOPED_TRACE(func);
for (const auto& dec_ty : PositiveScaleTypes()) {
SCOPED_TRACE(dec_ty->ToString());
auto dec_arr = ArrayFromJSON(dec_ty, R"([])");
for (const auto& other_ty : {boolean(), fixed_size_binary(42), utf8()}) {
SCOPED_TRACE(other_ty->ToString());
auto other_arr = ArrayFromJSON(other_ty, R"([])");
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
::testing::HasSubstr("has no kernel matching"),
CallFunction(func, {dec_arr, other_arr}));
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
::testing::HasSubstr("has no kernel matching"),
CallFunction(func, {other_arr, dec_arr}));
}
}
}
}
}

TYPED_TEST(TestBinaryArithmeticIntegral, ShiftLeft) {
for (auto check_overflow : {false, true}) {
this->SetOverflowCheck(check_overflow);
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,26 @@ TYPED_TEST(TestCompareDecimal, DifferentParameters) {
}
}

TYPED_TEST(TestCompareDecimal, ErrorOnNonCastable) {
auto dec_ty = std::make_shared<TypeParam>(3, 2);
auto dec_arr = ArrayFromJSON(dec_ty, R"([])");

for (const auto& func :
{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) {
SCOPED_TRACE(func);
for (const auto& other_ty : {boolean(), fixed_size_binary(42), utf8()}) {
SCOPED_TRACE(other_ty->ToString());
auto other_arr = ArrayFromJSON(other_ty, R"([])");
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
::testing::HasSubstr("has no kernel matching"),
CallFunction(func, {dec_arr, other_arr}));
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
::testing::HasSubstr("has no kernel matching"),
CallFunction(func, {other_arr, dec_arr}));
}
}
}

// Helper to organize tests for fixed size binary comparisons
struct CompareCase {
std::shared_ptr<DataType> lhs_type;
Expand Down
Loading