diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index fb77993a3028..c7aea40926a6 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -241,20 +241,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (Struct(_), _) => false, (_, Struct(_)) => false, - (_, Boolean) => { - DataType::is_integer(from_type) - || DataType::is_floating(from_type) - || from_type == &Utf8View - || from_type == &Utf8 - || from_type == &LargeUtf8 - } - (Boolean, _) => { - DataType::is_integer(to_type) - || DataType::is_floating(to_type) - || to_type == &Utf8View - || to_type == &Utf8 - || to_type == &LargeUtf8 - } + + (_, Boolean) => from_type.is_integer() || from_type.is_floating() || from_type.is_string(), + (Boolean, _) => to_type.is_integer() || to_type.is_floating() || to_type.is_string(), (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View) => { true @@ -284,7 +273,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { ) => true, (Utf8 | LargeUtf8, Utf8View) => true, (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) => true, - (Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16, + (Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric(), (_, Utf8 | Utf8View | LargeUtf8) => from_type.is_primitive(), (_, Binary | LargeBinary) => from_type.is_integer(), @@ -1217,6 +1206,7 @@ pub fn cast_with_options( Int16 => parse_string::(array, cast_options), Int32 => parse_string::(array, cast_options), Int64 => parse_string::(array, cast_options), + Float16 => parse_string::(array, cast_options), Float32 => parse_string::(array, cast_options), Float64 => parse_string::(array, cast_options), Date32 => parse_string::(array, cast_options), @@ -1279,6 +1269,7 @@ pub fn cast_with_options( Int16 => parse_string_view::(array, cast_options), Int32 => parse_string_view::(array, cast_options), Int64 => parse_string_view::(array, cast_options), + Float16 => parse_string_view::(array, cast_options), Float32 => parse_string_view::(array, cast_options), Float64 => parse_string_view::(array, cast_options), Date32 => parse_string_view::(array, cast_options), @@ -1330,6 +1321,7 @@ pub fn cast_with_options( Int16 => parse_string::(array, cast_options), Int32 => parse_string::(array, cast_options), Int64 => parse_string::(array, cast_options), + Float16 => parse_string::(array, cast_options), Float32 => parse_string::(array, cast_options), Float64 => parse_string::(array, cast_options), Date32 => parse_string::(array, cast_options), @@ -4433,6 +4425,23 @@ mod tests { assert_eq!(8.9, c.value(3)); } + #[test] + fn test_cast_string_to_f16() { + let arrays = [ + Arc::new(StringViewArray::from(vec!["3", "4.56", "seven", "8.9"])) as ArrayRef, + Arc::new(StringArray::from(vec!["3", "4.56", "seven", "8.9"])), + Arc::new(LargeStringArray::from(vec!["3", "4.56", "seven", "8.9"])), + ]; + for array in arrays { + let b = cast(&array, &DataType::Float16).unwrap(); + let c = b.as_primitive::(); + assert_eq!(half::f16::from_f32(3.0), c.value(0)); + assert_eq!(half::f16::from_f32(4.56), c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(half::f16::from_f32(8.9), c.value(3)); + } + } + #[test] fn test_cast_utf8view_to_decimal128() { let array = StringViewArray::from(vec![None, Some("4"), Some("5.6"), Some("7.89")]);