19 changes: 19 additions & 0 deletions cpp/src/arrow/csv/column_builder_test.cc
@@ -342,6 +342,25 @@ TEST_F(InferringColumnBuilderTest, SingleChunkInteger) {
{ArrayFromJSON(int64(), "[null, 123, 456]")});
}

TEST_F(InferringColumnBuilderTest, SingleChunkDefaultColumnTypeDoesNotOverrideInference) {
auto options = ConvertOptions::Defaults();
options.default_column_type = utf8();
auto tg = TaskGroup::MakeSerial();

CheckInferred(tg, {{"0000404", "0000505", "0000606"}}, options,
{ArrayFromJSON(int64(), "[404, 505, 606]")});
}

TEST_F(InferringColumnBuilderTest,
MultipleChunkDefaultColumnTypeDoesNotOverrideInference) {
auto options = ConvertOptions::Defaults();
options.default_column_type = utf8();
auto tg = TaskGroup::MakeSerial();

CheckInferred(tg, {{"0000404"}, {"0000505", "0000606"}}, options,
{ArrayFromJSON(int64(), "[404]"), ArrayFromJSON(int64(), "[505, 606]")});
}

TEST_F(InferringColumnBuilderTest, MultipleChunkInteger) {
auto options = ConvertOptions::Defaults();
auto tg = TaskGroup::MakeSerial();
1 change: 1 addition & 0 deletions cpp/src/arrow/csv/options.cc
@@ -43,6 +43,7 @@ ConvertOptions ConvertOptions::Defaults() {
"NULL", "NaN", "n/a", "nan", "null"};
options.true_values = {"1", "True", "TRUE", "true"};
options.false_values = {"0", "False", "FALSE", "false"};
options.default_column_type = nullptr;
return options;
}

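As a side note, a minimal sketch (not part of this PR) of what this default means for existing callers: Defaults() leaves the field null, so behavior is unchanged until a user opts in.

```cpp
#include <cassert>

#include "arrow/csv/options.h"

int main() {
  // Sketch only: Defaults() leaves default_column_type null, so existing
  // callers keep full type inference on columns without an explicit type.
  auto options = arrow::csv::ConvertOptions::Defaults();
  assert(options.default_column_type == nullptr);
  return 0;
}
```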
2 changes: 2 additions & 0 deletions cpp/src/arrow/csv/options.h
@@ -76,6 +76,8 @@ struct ARROW_EXPORT ConvertOptions {
bool check_utf8 = true;
/// Optional per-column types (disabling type inference on those columns)
std::unordered_map<std::string, std::shared_ptr<DataType>> column_types;
/// Default type to use for columns not in `column_types`
std::shared_ptr<DataType> default_column_type;
/// Recognized spellings for null values
std::vector<std::string> null_values;
/// Recognized spellings for boolean true values
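To make the interplay of the two options concrete, here is a hedged usage sketch; the file path and helper name are illustrative, not from this PR. The "id" column keeps its explicit int64 type while every remaining column falls back to utf8 instead of being inferred.

```cpp
#include <memory>
#include <string>

#include "arrow/csv/api.h"
#include "arrow/io/api.h"
#include "arrow/result.h"
#include "arrow/table.h"

// Illustrative helper (not from this PR): type "id" explicitly and let
// every remaining column fall back to utf8 instead of type inference.
arrow::Result<std::shared_ptr<arrow::Table>> ReadWithDefaultUtf8(
    const std::string& path) {
  ARROW_ASSIGN_OR_RAISE(auto input, arrow::io::ReadableFile::Open(path));
  auto read_options = arrow::csv::ReadOptions::Defaults();
  auto parse_options = arrow::csv::ParseOptions::Defaults();
  auto convert_options = arrow::csv::ConvertOptions::Defaults();
  convert_options.column_types["id"] = arrow::int64();  // explicit type wins
  convert_options.default_column_type = arrow::utf8();  // fallback for the rest
  ARROW_ASSIGN_OR_RAISE(
      auto reader, arrow::csv::TableReader::Make(
                       arrow::io::default_io_context(), input, read_options,
                       parse_options, convert_options));
  return reader->Read();
}
```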
10 changes: 8 additions & 2 deletions cpp/src/arrow/csv/reader.cc
@@ -674,8 +674,14 @@ class ReaderMixin {
// Does the named column have a fixed type?
auto it = convert_options_.column_types.find(col_name);
if (it == convert_options_.column_types.end()) {
conversion_schema_.columns.push_back(
ConversionSchema::InferredColumn(std::move(col_name), col_index));
// If not explicitly typed, respect default_column_type when provided
if (convert_options_.default_column_type != nullptr) {
conversion_schema_.columns.push_back(ConversionSchema::TypedColumn(
std::move(col_name), col_index, convert_options_.default_column_type));
} else {
conversion_schema_.columns.push_back(
ConversionSchema::InferredColumn(std::move(col_name), col_index));
}
} else {
conversion_schema_.columns.push_back(
ConversionSchema::TypedColumn(std::move(col_name), col_index, it->second));
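Restated as a standalone helper (illustrative only, not in the PR), the precedence this hunk implements is: an explicit column_types entry wins, then a non-null default_column_type, and only when both are absent does the column stay inferred.

```cpp
#include <memory>
#include <string>

#include "arrow/csv/options.h"
#include "arrow/type.h"

// Illustrative helper restating the column-kind decision: returns the
// fixed type for a column, or nullptr to mean "infer this column".
std::shared_ptr<arrow::DataType> ResolveColumnType(
    const arrow::csv::ConvertOptions& options, const std::string& col_name) {
  auto it = options.column_types.find(col_name);
  if (it != options.column_types.end()) {
    return it->second;  // 1. explicit per-column type always wins
  }
  // 2. fall back to default_column_type; 3. nullptr keeps inference
  return options.default_column_type;
}
```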
87 changes: 87 additions & 0 deletions cpp/src/arrow/csv/reader_test.cc
@@ -488,5 +488,92 @@ TEST(CountRowsAsync, Errors) {
internal::GetCpuThreadPool(), read_options, parse_options));
}

TEST(ReaderTests, DefaultColumnTypePartialDefault) {
auto table_buffer = std::make_shared<Buffer>(
"id,name,value,date\n"
"0000101,apple,0003.1400,2024-01-15\n"
"00102,banana,001.6180,2024-02-20\n"
"0003,cherry,02.71800,2024-03-25\n");

auto input = std::make_shared<io::BufferReader>(table_buffer);
auto read_options = ReadOptions::Defaults();
auto parse_options = ParseOptions::Defaults();
auto convert_options = ConvertOptions::Defaults();
convert_options.column_types["id"] = int64();
convert_options.default_column_type = utf8();

ASSERT_OK_AND_ASSIGN(auto reader,
TableReader::Make(io::default_io_context(), input, read_options,
parse_options, convert_options));
ASSERT_OK_AND_ASSIGN(auto table, reader->Read());

auto expected_schema = schema({field("id", int64()), field("name", utf8()),
field("value", utf8()), field("date", utf8())});
AssertSchemaEqual(expected_schema, table->schema());

auto expected_table = TableFromJSON(
expected_schema,
{R"([{"id":101, "name":"apple", "value":"0003.1400", "date":"2024-01-15"},
{"id":102, "name":"banana", "value":"001.6180", "date":"2024-02-20"},
{"id":3, "name":"cherry", "value":"02.71800", "date":"2024-03-25"}])"});
ASSERT_TRUE(table->Equals(*expected_table));
}

TEST(ReaderTests, DefaultColumnTypeForcesTypedColumns) {
auto table_buffer = std::make_shared<Buffer>(
"id,amount,code\n"
"0000404,000045.6700,001\n"
"0000505,000000.10,010\n");

auto input = std::make_shared<io::BufferReader>(table_buffer);
auto read_options = ReadOptions::Defaults();
auto parse_options = ParseOptions::Defaults();
auto convert_options = ConvertOptions::Defaults();
convert_options.default_column_type = utf8();

ASSERT_OK_AND_ASSIGN(auto reader,
TableReader::Make(io::default_io_context(), input, read_options,
parse_options, convert_options));
ASSERT_OK_AND_ASSIGN(auto table, reader->Read());

auto expected_schema =
schema({field("id", utf8()), field("amount", utf8()), field("code", utf8())});
AssertSchemaEqual(expected_schema, table->schema());

auto expected_table = TableFromJSON(
expected_schema, {R"([{"id":"0000404", "amount":"000045.6700", "code":"001"},
{"id":"0000505", "amount":"000000.10", "code":"010"}])"});
ASSERT_TRUE(table->Equals(*expected_table));
}

TEST(ReaderTests, DefaultColumnTypeAllStringsNoHeader) {
// Input without header; autogenerate column names and default all to strings
auto table_buffer = std::make_shared<Buffer>("AB|000388907|000045.6700\n");

auto input = std::make_shared<io::BufferReader>(table_buffer);
auto read_options = ReadOptions::Defaults();
read_options.autogenerate_column_names = true; // treat first row as data
auto parse_options = ParseOptions::Defaults();
parse_options.delimiter = '|';
auto convert_options = ConvertOptions::Defaults();
convert_options.default_column_type = utf8();

ASSERT_OK_AND_ASSIGN(auto reader,
TableReader::Make(io::default_io_context(), input, read_options,
parse_options, convert_options));
ASSERT_OK_AND_ASSIGN(auto table, reader->Read());

auto expected_schema =
schema({field("f0", utf8()), field("f1", utf8()), field("f2", utf8())});
AssertSchemaEqual(expected_schema, table->schema());

auto expected_table = TableFromJSON(expected_schema, {R"([{
"f0":"AB",
"f1":"000388907",
"f2":"000045.6700"
}])"});
ASSERT_TRUE(table->Equals(*expected_table));
}

} // namespace csv
} // namespace arrow
1 change: 1 addition & 0 deletions docs/source/python/csv.rst
@@ -136,6 +136,7 @@ Available convert options are:

~ConvertOptions.check_utf8
~ConvertOptions.column_types
~ConvertOptions.default_column_type
~ConvertOptions.null_values
~ConvertOptions.true_values
~ConvertOptions.false_values
81 changes: 71 additions & 10 deletions python/pyarrow/_csv.pyx
@@ -613,6 +613,9 @@ cdef class ConvertOptions(_Weakrefable):
column_types : pyarrow.Schema or dict, optional
Explicitly map column names to column types. Passing this argument
disables type inference on the defined columns.
default_column_type : pyarrow.DataType, optional
Explicitly map columns not specified in column_types to a default type.
Passing this argument disables type inference on all columns not listed in column_types.
null_values : list, optional
A sequence of strings that denote nulls in the data
(defaults are appropriate in most cases). Note that by default,
@@ -807,6 +810,40 @@
fast: bool
----
fast: [[true,true,false,false,null]]

Set a default column type for all columns (disables type inference):

>>> convert_options = csv.ConvertOptions(default_column_type=pa.string())
>>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
pyarrow.Table
animals: string
n_legs: string
entry: string
fast: string
----
animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
n_legs: [["2","4","5","100","6"]]
entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]]
fast: [["Yes","Yes","No","No",""]]

Combine default_column_type with column_types (specific column types override default):

>>> convert_options = csv.ConvertOptions(
... column_types={"n_legs": pa.int64(), "fast": pa.bool_()},
... default_column_type=pa.string(),
... true_values=["Yes"],
... false_values=["No"])
>>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options)
pyarrow.Table
animals: string
n_legs: int64
entry: string
fast: bool
----
animals: [["Flamingo","Horse","Brittle stars","Centipede",""]]
n_legs: [[2,4,5,100,6]]
entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]]
fast: [[true,true,false,false,null]]
"""

# Avoid mistakingly creating attributes
@@ -816,7 +853,7 @@ cdef class ConvertOptions(_Weakrefable):
self.options.reset(
new CCSVConvertOptions(CCSVConvertOptions.Defaults()))

def __init__(self, *, check_utf8=None, column_types=None, null_values=None,
def __init__(self, *, check_utf8=None, column_types=None, default_column_type=None, null_values=None,
true_values=None, false_values=None, decimal_point=None,
strings_can_be_null=None, quoted_strings_can_be_null=None,
include_columns=None, include_missing_columns=None,
@@ -826,6 +863,8 @@
self.check_utf8 = check_utf8
if column_types is not None:
self.column_types = column_types
if default_column_type is not None:
self.default_column_type = default_column_type
if null_values is not None:
self.null_values = null_values
if true_values is not None:
@@ -910,6 +949,27 @@
assert typ != NULL
deref(self.options).column_types[tobytes(k)] = typ

@property
def default_column_type(self):
"""
Explicitly map columns not specified in column_types to a default type.
"""
if deref(self.options).default_column_type != NULL:
return pyarrow_wrap_data_type(deref(self.options).default_column_type)
else:
return None

@default_column_type.setter
def default_column_type(self, value):
cdef:
shared_ptr[CDataType] typ
if value is not None:
typ = pyarrow_unwrap_data_type(ensure_type(value))
assert typ != NULL
deref(self.options).default_column_type = typ
else:
deref(self.options).default_column_type.reset()

@property
def null_values(self):
"""
@@ -1071,6 +1131,7 @@
return (
self.check_utf8 == other.check_utf8 and
self.column_types == other.column_types and
self.default_column_type == other.default_column_type and
self.null_values == other.null_values and
self.true_values == other.true_values and
self.false_values == other.false_values and
@@ -1087,17 +1148,17 @@
)

def __getstate__(self):
return (self.check_utf8, self.column_types, self.null_values,
self.true_values, self.false_values, self.decimal_point,
self.timestamp_parsers, self.strings_can_be_null,
self.quoted_strings_can_be_null, self.auto_dict_encode,
self.auto_dict_max_cardinality, self.include_columns,
self.include_missing_columns)
return (self.check_utf8, self.column_types, self.default_column_type,
self.null_values, self.true_values, self.false_values,
self.decimal_point, self.timestamp_parsers,
self.strings_can_be_null, self.quoted_strings_can_be_null,
self.auto_dict_encode, self.auto_dict_max_cardinality,
self.include_columns, self.include_missing_columns)

def __setstate__(self, state):
(self.check_utf8, self.column_types, self.null_values,
self.true_values, self.false_values, self.decimal_point,
self.timestamp_parsers, self.strings_can_be_null,
(self.check_utf8, self.column_types, self.default_column_type,
self.null_values, self.true_values, self.false_values,
self.decimal_point, self.timestamp_parsers, self.strings_can_be_null,
self.quoted_strings_can_be_null, self.auto_dict_encode,
self.auto_dict_max_cardinality, self.include_columns,
self.include_missing_columns) = state
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -2104,6 +2104,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions":
c_bool check_utf8
unordered_map[c_string, shared_ptr[CDataType]] column_types
shared_ptr[CDataType] default_column_type
vector[c_string] null_values
vector[c_string] true_values
vector[c_string] false_values
65 changes: 64 additions & 1 deletion python/pyarrow/tests/test_csv.py
@@ -297,7 +297,8 @@ def test_convert_options(pickle_module):
include_columns=['def', 'abc'],
include_missing_columns=False,
auto_dict_encode=True,
timestamp_parsers=[ISO8601, '%y-%m'])
timestamp_parsers=[ISO8601, '%y-%m'],
default_column_type=pa.int16())

with pytest.raises(ValueError):
opts.decimal_point = '..'
@@ -325,6 +326,17 @@
with pytest.raises(TypeError):
opts.column_types = 0

assert opts.default_column_type is None
opts.default_column_type = pa.string()
assert opts.default_column_type == pa.string()
opts.default_column_type = 'int32'
assert opts.default_column_type == pa.int32()
opts.default_column_type = None
assert opts.default_column_type is None

with pytest.raises(TypeError, match='DataType expected'):
opts.default_column_type = 123

assert isinstance(opts.null_values, list)
assert '' in opts.null_values
assert 'N/A' in opts.null_values
@@ -1331,6 +1343,57 @@ def test_column_types_with_column_names(self):
'y': ['b', 'd', 'f'],
}

def test_default_column_type(self):
rows = b"a,b,c,d\n001,2.5,hello,true\n4,3.14,world,false\n"

# Test with default_column_type only - all columns should use the specified type.
opts = ConvertOptions(default_column_type=pa.string())
table = self.read_bytes(rows, convert_options=opts)
schema = pa.schema([('a', pa.string()),
('b', pa.string()),
('c', pa.string()),
('d', pa.string())])
assert table.schema == schema
assert table.to_pydict() == {
'a': ["001", "4"],
'b': ["2.5", "3.14"],
'c': ["hello", "world"],
'd': ["true", "false"],
}

# Test with both column_types and default_column_type
# Columns specified in column_types should override default_column_type
opts = ConvertOptions(
column_types={'b': pa.float64(), 'd': pa.bool_()},
default_column_type=pa.string()
)
table = self.read_bytes(rows, convert_options=opts)
schema = pa.schema([('a', pa.string()),
('b', pa.float64()),
('c', pa.string()),
('d', pa.bool_())])
assert table.schema == schema
assert table.to_pydict() == {
'a': ["001", "4"],
'b': [2.5, 3.14],
'c': ["hello", "world"],
'd': [True, False],
}

# Test that default_column_type disables type inference
opts_no_default = ConvertOptions(column_types={'b': pa.float64()})
table_no_default = self.read_bytes(rows, convert_options=opts_no_default)

opts_with_default = ConvertOptions(
column_types={'b': pa.float64()},
default_column_type=pa.string()
)
table_with_default = self.read_bytes(rows, convert_options=opts_with_default)

# Column 'a' should be int64 without default, string with default
assert table_no_default.schema.field('a').type == pa.int64()
assert table_with_default.schema.field('a').type == pa.string()

def test_no_ending_newline(self):
# No \n after last line
rows = b"a,b,c\n1,2,3\n4,5,6"