diff --git a/python/pyarrow/src/arrow/python/inference.cc b/python/pyarrow/src/arrow/python/inference.cc index 1aa7915ba1e..5aed8453498 100644 --- a/python/pyarrow/src/arrow/python/inference.cc +++ b/python/pyarrow/src/arrow/python/inference.cc @@ -108,6 +108,17 @@ class NumPyDtypeUnifier { GetNumPyTypeName(new_dtype)); } + Status InvalidDatetimeUnitMix(PyArray_Descr* new_descr) { + auto new_meta = reinterpret_cast( + PyDataType_C_METADATA(new_descr)); + auto current_meta = reinterpret_cast( + PyDataType_C_METADATA(current_dtype_)); + + return Status::Invalid("Cannot mix NumPy datetime64 units ", + DatetimeUnitName(current_meta->meta.base), " and ", + DatetimeUnitName(new_meta->meta.base)); + } + int Observe_BOOL(PyArray_Descr* descr, int dtype) { return INVALID; } int Observe_INT8(PyArray_Descr* descr, int dtype) { @@ -255,7 +266,17 @@ class NumPyDtypeUnifier { } int Observe_DATETIME(PyArray_Descr* dtype_obj) { - // TODO: check that units are all the same + // Check that datetime units are consistent across all values + auto datetime_meta = reinterpret_cast( + PyDataType_C_METADATA(dtype_obj)); + auto current_meta = reinterpret_cast( + PyDataType_C_METADATA(current_dtype_)); + + if (datetime_meta->meta.base != current_meta->meta.base) { + // Units don't match - this is invalid + return INVALID; + } + return OK; } @@ -267,6 +288,13 @@ class NumPyDtypeUnifier { current_type_num_ = dtype; return Status::OK(); } else if (current_type_num_ == dtype) { + // Same type, but for datetime we still need to check units match + if (dtype == NPY_DATETIME) { + int action = Observe_DATETIME(descr); + if (action == INVALID) { + return InvalidDatetimeUnitMix(descr); + } + } return Status::OK(); } @@ -309,6 +337,41 @@ class NumPyDtypeUnifier { int current_type_num() const { return current_type_num_; } private: + static const char* DatetimeUnitName(NPY_DATETIMEUNIT unit) { + switch (unit) { + case NPY_FR_Y: + return "Y"; + case NPY_FR_M: + return "M"; + case NPY_FR_W: + return "W"; + case NPY_FR_D: + return "D"; + case NPY_FR_h: + return "h"; + case NPY_FR_m: + return "m"; + case NPY_FR_s: + return "s"; + case NPY_FR_ms: + return "ms"; + case NPY_FR_us: + return "us"; + case NPY_FR_ns: + return "ns"; + case NPY_FR_ps: + return "ps"; + case NPY_FR_fs: + return "fs"; + case NPY_FR_as: + return "as"; + case NPY_FR_GENERIC: + return "generic"; + default: + return "unknown"; + } + } + int current_type_num_; PyArray_Descr* current_dtype_; }; diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index ec361159c5f..3bce2d87c05 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -2489,7 +2489,30 @@ def test_array_from_different_numpy_datetime_units_raises(): ms = np.array(data, dtype='datetime64[ms]') data = list(s[:2]) + list(ms[2:]) - with pytest.raises(pa.ArrowNotImplementedError): + with pytest.raises(pa.ArrowInvalid, match="units s and ms"): + pa.array(data) + + +@pytest.mark.numpy +@pytest.mark.parametrize('unit', [ + 'Y', # year + 'M', # month + 'W', # week + 'h', # hour + 'm', # minute + 'ps', # picosecond + 'fs', # femtosecond + 'as', # attosecond +]) +def test_array_from_unsupported_numpy_datetime_unit_names(unit): + s_data = [np.datetime64('2020-01-01', 's')] + unsupported_data = [np.datetime64('2020', unit)] + + # Mix supported unit (s) with unsupported unit + data = s_data + unsupported_data + + with pytest.raises(pa.ArrowInvalid, + match=f"Cannot mix NumPy datetime64 units s and {unit}"): pa.array(data) @@ -2514,8 +2537,8 @@ def test_array_from_timestamp_with_generic_unit(): x = np.datetime64('2017-01-01 01:01:01.111111111') y = np.datetime64('2018-11-22 12:24:48.111111111') - with pytest.raises(pa.ArrowNotImplementedError, - match='Unbound or generic datetime64 time unit'): + with pytest.raises(pa.ArrowInvalid, + match='Cannot mix NumPy datetime64 units'): pa.array([n, x, y])