diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index 80989448657..c5eac14c390 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -2976,7 +2976,7 @@ public void testBuffer() { actual = Functions.asWKT(Functions.buffer(lineString, 5, false, "side=both endcap=square")); expected = "POLYGON ((45.93133264396633 72.90619096859548, 46.607775042289525 73.6732560265092, 47.42614312881237 74.2866374708669, 48.35219759000809 74.72067232686456, 49.34719367328685 74.9572012163925, 50.36950221028991 74.9863281196278, 51.37635131988457 74.80683440990555, 52.32561592057597 74.4262298392609, 53.17758018177616 73.86043834148188, 53.896599175569165 73.13313179820986, 54.452590207913495 72.27473964233116, 54.82229143418446 71.321175735393, 74.82229143418445 -1.6788242646069989, 76.14346716957745 -6.501115698791454, 66.49888430120855 -9.143467169577455, 47.95639322710423 58.5366252509033, 4.068667356033674 -2.9061909685954816, 1.162476387438193 -6.974858324629158, -6.974858324629155 -1.1624763874381943, 45.93133264396633 72.90619096859548))"; - assertEquals(expected, actual); + assertGeometryEquals(expected, actual); } @Test diff --git a/docs/api/flink/Aggregator.md b/docs/api/flink/Aggregator.md index 87b2a0f8be8..94f252f6174 100644 --- a/docs/api/flink/Aggregator.md +++ b/docs/api/flink/Aggregator.md @@ -17,18 +17,21 @@ under the License. --> -## ST_Envelope_Aggr +## ST_Envelope_Agg Introduction: Return the entire envelope boundary of all geometries in A -Format: `ST_Envelope_Aggr (A: geometryColumn)` +Format: `ST_Envelope_Agg (A: geometryColumn)` Since: `v1.3.0` +!!!note + This function was previously named `ST_Envelope_Aggr`, which is deprecated since `v1.8.1`. 
+ Example: ```sql -SELECT ST_Envelope_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Envelope_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -37,18 +40,21 @@ Output: POLYGON ((1.1 101.1, 1.1 120.1, 20.1 120.1, 20.1 101.1, 1.1 101.1)) ``` -## ST_Intersection_Aggr +## ST_Intersection_Agg Introduction: Return the polygon intersection of all polygons in A -Format: `ST_Intersection_Aggr (A: geometryColumn)` +Format: `ST_Intersection_Agg (A: geometryColumn)` Since: `v1.5.0` +!!!note + This function was previously named `ST_Intersection_Aggr`, which is deprecated since `v1.8.1`. + Example: ```sql -SELECT ST_Intersection_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Intersection_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -57,18 +63,21 @@ Output: MULTIPOINT ((1.1 101.1), (2.1 102.1), (3.1 103.1), (4.1 104.1), (5.1 105.1), (6.1 106.1), (7.1 107.1), (8.1 108.1), (9.1 109.1), (10.1 110.1)) ``` -## ST_Union_Aggr +## ST_Union_Agg Introduction: Return the polygon union of all polygons in A. All inputs must be polygons. -Format: `ST_Union_Aggr (A: geometryColumn)` +Format: `ST_Union_Agg (A: geometryColumn)` Since: `v1.3.0` +!!!note + This function was previously named `ST_Union_Aggr`, which is deprecated since `v1.8.1`. 
+ Example: ```sql -SELECT ST_Union_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Union_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: diff --git a/docs/api/snowflake/vector-data/AggregateFunction.md b/docs/api/snowflake/vector-data/AggregateFunction.md index 7669819d99b..b8ba8ee4e28 100644 --- a/docs/api/snowflake/vector-data/AggregateFunction.md +++ b/docs/api/snowflake/vector-data/AggregateFunction.md @@ -20,11 +20,14 @@ !!!note Please always keep the schema name `SEDONA` (e.g., `SEDONA.ST_GeomFromWKT`) when you use Sedona functions to avoid conflicting with Snowflake's built-in functions. -## ST_Envelope_Aggr +## ST_Envelope_Agg Introduction: Return the entire envelope boundary of all geometries in A -Format: `ST_Envelope_Aggr (A:geometryColumn)` +Format: `ST_Envelope_Agg (A:geometryColumn)` + +!!!note + This function was previously named `ST_Envelope_Aggr`, which is deprecated since `v1.8.1`. SQL example: @@ -36,7 +39,7 @@ WITH src_tbl AS ( ) SELECT sedona.ST_AsText(envelope) FROM src_tbl, - TABLE(sedona.ST_Envelope_Aggr(src_tbl.geom) OVER (PARTITION BY 1)); + TABLE(sedona.ST_Envelope_Agg(src_tbl.geom) OVER (PARTITION BY 1)); ``` Output: @@ -45,11 +48,14 @@ Output: POLYGON ((0 0, 0 1.5, 1.5 1.5, 1.5 0, 0 0)) ``` -## ST_Intersection_Aggr +## ST_Intersection_Agg Introduction: Return the polygon intersection of all polygons in A -Format: `ST_Intersection_Aggr (A:geometryColumn)` +Format: `ST_Intersection_Agg (A:geometryColumn)` + +!!!note + This function was previously named `ST_Intersection_Aggr`, which is deprecated since `v1.8.1`. 
SQL example: @@ -61,7 +67,7 @@ WITH src_tbl AS ( ) SELECT sedona.ST_AsText(intersected) FROM src_tbl, - TABLE(sedona.ST_Intersection_Aggr(src_tbl.geom) OVER (PARTITION BY 1)); + TABLE(sedona.ST_Intersection_Agg(src_tbl.geom) OVER (PARTITION BY 1)); ``` Output: @@ -70,11 +76,14 @@ Output: POLYGON ((0.5 1, 1 1, 1 0.5, 0.5 0.5, 0.5 1)) ``` -## ST_Union_Aggr +## ST_Union_Agg Introduction: Return the polygon union of all polygons in A -Format: `ST_Union_Aggr (A:geometryColumn)` +Format: `ST_Union_Agg (A:geometryColumn)` + +!!!note + This function was previously named `ST_Union_Aggr`, which is deprecated since `v1.8.1`. SQL example: @@ -86,7 +95,7 @@ WITH src_tbl AS ( ) SELECT sedona.ST_AsText(unioned) FROM src_tbl, - TABLE(sedona.ST_Union_Aggr(src_tbl.geom) OVER (PARTITION BY 1)); + TABLE(sedona.ST_Union_Agg(src_tbl.geom) OVER (PARTITION BY 1)); ``` Output: diff --git a/docs/api/sql/AggregateFunction.md b/docs/api/sql/AggregateFunction.md index 8918c6d428f..4d0165f2b3c 100644 --- a/docs/api/sql/AggregateFunction.md +++ b/docs/api/sql/AggregateFunction.md @@ -19,7 +19,7 @@ ## ST_Collect_Agg -Introduction: Collects all geometries in a geometry column into a single multi-geometry (MultiPoint, MultiLineString, MultiPolygon, or GeometryCollection). Unlike `ST_Union_Aggr`, this function does not dissolve boundaries between geometries - it simply collects them into a multi-geometry. +Introduction: Collects all geometries in a geometry column into a single multi-geometry (MultiPoint, MultiLineString, MultiPolygon, or GeometryCollection). Unlike `ST_Union_Agg`, this function does not dissolve boundaries between geometries - it simply collects them into a multi-geometry. 
Format: `ST_Collect_Agg (A: geometryColumn)` @@ -49,18 +49,21 @@ SQL Example with GROUP BY SELECT category, ST_Collect_Agg(geom) FROM geometries GROUP BY category ``` -## ST_Envelope_Aggr +## ST_Envelope_Agg Introduction: Return the entire envelope boundary of all geometries in A -Format: `ST_Envelope_Aggr (A: geometryColumn)` +Format: `ST_Envelope_Agg (A: geometryColumn)` Since: `v1.0.0` +!!!note + This function was previously named `ST_Envelope_Aggr`, which is deprecated since `v1.8.1`. + SQL Example ```sql -SELECT ST_Envelope_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Envelope_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -69,18 +72,21 @@ Output: POLYGON ((1.1 101.1, 1.1 120.1, 20.1 120.1, 20.1 101.1, 1.1 101.1)) ``` -## ST_Intersection_Aggr +## ST_Intersection_Agg Introduction: Return the polygon intersection of all polygons in A -Format: `ST_Intersection_Aggr (A: geometryColumn)` +Format: `ST_Intersection_Agg (A: geometryColumn)` Since: `v1.0.0` +!!!note + This function was previously named `ST_Intersection_Aggr`, which is deprecated since `v1.8.1`. 
+ SQL Example ```sql -SELECT ST_Intersection_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Intersection_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: @@ -89,18 +95,21 @@ Output: MULTIPOINT ((1.1 101.1), (2.1 102.1), (3.1 103.1), (4.1 104.1), (5.1 105.1), (6.1 106.1), (7.1 107.1), (8.1 108.1), (9.1 109.1), (10.1 110.1)) ``` -## ST_Union_Aggr +## ST_Union_Agg Introduction: Return the polygon union of all polygons in A -Format: `ST_Union_Aggr (A: geometryColumn)` +Format: `ST_Union_Agg (A: geometryColumn)` Since: `v1.0.0` +!!!note + This function was previously named `ST_Union_Aggr`, which is deprecated since `v1.8.1`. + SQL Example ```sql -SELECT ST_Union_Aggr(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) +SELECT ST_Union_Agg(ST_GeomFromText('MULTIPOINT(1.1 101.1,2.1 102.1,3.1 103.1,4.1 104.1,5.1 105.1,6.1 106.1,7.1 107.1,8.1 108.1,9.1 109.1,10.1 110.1)')) ``` Output: diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java b/flink/src/main/java/org/apache/sedona/flink/Catalog.java index dcb593e9c32..0ee24de7995 100644 --- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java +++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java @@ -27,6 +27,10 @@ public static UserDefinedFunction[] getFuncs() { new Aggregators.ST_Envelope_Aggr(), new Aggregators.ST_Intersection_Aggr(), new Aggregators.ST_Union_Aggr(), + // Aliases for *_Aggr functions with *_Agg suffix + new Aggregators.ST_Envelope_Agg(), + new Aggregators.ST_Intersection_Agg(), + new Aggregators.ST_Union_Agg(), new Constructors.ST_Point(), new Constructors.ST_PointZ(), new Constructors.ST_PointM(), diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java 
b/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java index edffbadb9ba..84cebd6adcb 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Aggregators.java @@ -242,4 +242,23 @@ public void resetAccumulator(Accumulators.AccGeometry acc) { acc.geom = null; } } + + // Aliases for *_Aggr functions with *_Agg suffix + @DataTypeHint( + value = "RAW", + rawSerializer = GeometryTypeSerializer.class, + bridgedTo = Geometry.class) + public static class ST_Envelope_Agg extends ST_Envelope_Aggr {} + + @DataTypeHint( + value = "RAW", + rawSerializer = GeometryTypeSerializer.class, + bridgedTo = Geometry.class) + public static class ST_Intersection_Agg extends ST_Intersection_Aggr {} + + @DataTypeHint( + value = "RAW", + rawSerializer = GeometryTypeSerializer.class, + bridgedTo = Geometry.class) + public static class ST_Union_Agg extends ST_Union_Aggr {} } diff --git a/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java b/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java index 258927622a0..0220ff434dc 100644 --- a/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/AggregatorTest.java @@ -93,4 +93,38 @@ public void testUnion_Aggr() { Row last = last(result); assertEquals(1001, ((Polygon) last.getField(0)).getArea(), 0); } + + // Test aliases for *_Aggr functions with *_Agg suffix + @Test + public void testEnvelop_Agg_Alias() { + Table pointTable = createPointTable(testDataSize); + Table result = pointTable.select(call("ST_Envelope_Agg", $(pointColNames[0]))); + Row last = last(result); + assertEquals( + String.format( + "POLYGON ((0 0, 0 %s, %s %s, %s 0, 0 0))", + testDataSize - 1, testDataSize - 1, testDataSize - 1, testDataSize - 1), + last.getField(0).toString()); + } + + @Test + public void testIntersection_Agg_Alias() { + Table polygonTable = 
createPolygonOverlappingTable(testDataSize); + Table result = polygonTable.select(call("ST_Intersection_Agg", $(polygonColNames[0]))); + Row last = last(result); + assertEquals("LINESTRING EMPTY", last.getField(0).toString()); + + polygonTable = createPolygonOverlappingTable(3); + result = polygonTable.select(call("ST_Intersection_Agg", $(polygonColNames[0]))); + last = last(result); + assertEquals("LINESTRING (1 1, 1 0)", last.getField(0).toString()); + } + + @Test + public void testUnion_Agg_Alias() { + Table polygonTable = createPolygonOverlappingTable(testDataSize); + Table result = polygonTable.select(call("ST_Union_Agg", $(polygonColNames[0]))); + Row last = last(result); + assertEquals(1001, ((Polygon) last.getField(0)).getArea(), 0); + } } diff --git a/python/sedona/spark/sql/st_aggregates.py b/python/sedona/spark/sql/st_aggregates.py index ec20e643075..c85a3983baa 100644 --- a/python/sedona/spark/sql/st_aggregates.py +++ b/python/sedona/spark/sql/st_aggregates.py @@ -81,6 +81,49 @@ def ST_Collect_Agg(geometry: ColumnOrName) -> Column: return _call_aggregate_function("ST_Collect_Agg", geometry) +# Aliases for *_Aggr functions with *_Agg suffix +@validate_argument_types +def ST_Envelope_Agg(geometry: ColumnOrName) -> Column: + """Aggregate Function: Get the aggregate envelope of a geometry column. + + This is an alias for ST_Envelope_Aggr. + + :param geometry: Geometry column to aggregate. + :type geometry: ColumnOrName + :return: Geometry representing the aggregate envelope of the geometry column. + :rtype: Column + """ + return ST_Envelope_Aggr(geometry) + + +@validate_argument_types +def ST_Intersection_Agg(geometry: ColumnOrName) -> Column: + """Aggregate Function: Get the aggregate intersection of a geometry column. + + This is an alias for ST_Intersection_Aggr. + + :param geometry: Geometry column to aggregate. + :type geometry: ColumnOrName + :return: Geometry representing the aggregate intersection of the geometry column. 
+ :rtype: Column + """ + return ST_Intersection_Aggr(geometry) + + +@validate_argument_types +def ST_Union_Agg(geometry: ColumnOrName) -> Column: + """Aggregate Function: Get the aggregate union of a geometry column. + + This is an alias for ST_Union_Aggr. + + :param geometry: Geometry column to aggregate. + :type geometry: ColumnOrName + :return: Geometry representing the aggregate union of the geometry column. + :rtype: Column + """ + return ST_Union_Aggr(geometry) + + # Automatically populate __all__ __all__ = [ name diff --git a/python/tests/sql/test_aggregate_functions.py b/python/tests/sql/test_aggregate_functions.py index 4df0e3e230c..5ea7946e299 100644 --- a/python/tests/sql/test_aggregate_functions.py +++ b/python/tests/sql/test_aggregate_functions.py @@ -145,3 +145,67 @@ def test_st_collect_aggr_preserves_duplicates(self): assert len(result.geoms) == 2 # Area should be 2 because it doesn't merge overlapping areas assert result.area == 2.0 + + # Test aliases for *_Aggr functions with *_Agg suffix + def test_st_envelope_agg_alias(self): + self.spark.sql( + """ + SELECT explode(array( + ST_GeomFromWKT('POINT(1.1 101.1)'), + ST_GeomFromWKT('POINT(1.1 1100.1)'), + ST_GeomFromWKT('POINT(1000.1 1100.1)'), + ST_GeomFromWKT('POINT(1000.1 101.1)') + )) AS arealandmark + """ + ).createOrReplaceTempView("pointdf_alias") + + boundary = self.spark.sql( + "SELECT ST_Envelope_Agg(pointdf_alias.arealandmark) FROM pointdf_alias" + ) + + coordinates = [ + (1.1, 101.1), + (1.1, 1100.1), + (1000.1, 1100.1), + (1000.1, 101.1), + (1.1, 101.1), + ] + + polygon = Polygon(coordinates) + assert boundary.take(1)[0][0].equals(polygon) + + def test_st_intersection_agg_alias(self): + self.spark.sql( + """ + SELECT explode(array( + ST_GeomFromWKT('POLYGON((0 0, 4 0, 4 4, 0 4, 0 0))'), + ST_GeomFromWKT('POLYGON((2 2, 6 2, 6 6, 2 6, 2 2))') + )) AS countyshape + """ + ).createOrReplaceTempView("polygondf_alias") + + intersection = self.spark.sql( + "SELECT 
ST_Intersection_Agg(polygondf_alias.countyshape) FROM polygondf_alias" + ) + + result = intersection.take(1)[0][0] + # The intersection of the two polygons should be a square from (2,2) to (4,4) with area 4 + assert result.area == 4.0 + + def test_st_union_agg_alias(self): + self.spark.sql( + """ + SELECT explode(array( + ST_GeomFromWKT('POLYGON((0 0, 2 0, 2 2, 0 2, 0 0))'), + ST_GeomFromWKT('POLYGON((1 1, 3 1, 3 3, 1 3, 1 1))') + )) AS countyshape + """ + ).createOrReplaceTempView("polygondf_union_alias") + + union = self.spark.sql( + "SELECT ST_Union_Agg(polygondf_union_alias.countyshape) FROM polygondf_union_alias" + ) + + result = union.take(1)[0][0] + # Two overlapping 2x2 squares with 1x1 overlap: area = 4 + 4 - 1 = 7 + assert result.area == 7.0 diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 3ea7ab2076f..9629a7ca55e 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -1237,6 +1237,28 @@ "", "MULTIPOINT ((0 0), (1 1))", ), + # Test aliases for *_Aggr functions with *_Agg suffix + ( + sta.ST_Envelope_Agg, + ("geom",), + "exploded_points", + "", + "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))", + ), + ( + sta.ST_Intersection_Agg, + ("geom",), + "exploded_polys", + "", + "LINESTRING (1 0, 1 1)", + ), + ( + sta.ST_Union_Agg, + ("geom",), + "exploded_polys", + "", + "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))", + ), ] wrong_type_configurations = [ diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestTableFunctions.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestTableFunctions.java index 0f6db4a4f44..e6beecd95dc 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestTableFunctions.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestTableFunctions.java @@ -124,6 +124,47 @@ public void test_ST_Collect() throws ParseException { 
Constructors.geomFromWKT("GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (0 5, 0 10))", 0)); } + // Test aliases for *_Aggr functions with *_Agg suffix + @Test + public void test_ST_Envelope_Agg() throws ParseException { + registerUDTF(ST_Envelope_Agg.class); + verifySqlSingleRes( + "with src_tbl as (\n" + + "select sedona.ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))') geom\n" + + "union\n" + + "select sedona.ST_GeomFromText('POLYGON ((0.5 0.5, 0.5 1.5, 1.5 1.5, 1.5 0.5, 0.5 0.5))') geom\n" + + ")\n" + + "select sedona.ST_AsText(envelope) from src_tbl, table(sedona.ST_Envelope_Agg(src_tbl.geom) OVER (PARTITION BY 1));", + Constructors.geomFromWKT("POLYGON ((0 0, 0 1.5, 1.5 1.5, 1.5 0, 0 0))", 0)); + } + + @Test + public void test_ST_Intersection_Agg() throws ParseException { + registerUDTF(ST_Intersection_Agg.class); + verifySqlSingleRes( + "with src_tbl as (\n" + + "select sedona.ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))') geom\n" + + "union\n" + + "select sedona.ST_GeomFromText('POLYGON ((0.5 0.5, 0.5 1.5, 1.5 1.5, 1.5 0.5, 0.5 0.5))') geom\n" + + ")\n" + + "select sedona.ST_AsText(intersected) from src_tbl, table(sedona.ST_Intersection_Agg(src_tbl.geom) OVER (PARTITION BY 1));", + Constructors.geomFromWKT("POLYGON ((0.5 1, 1 1, 1 0.5, 0.5 0.5, 0.5 1))", 0)); + } + + @Test + public void test_ST_Union_Agg() throws ParseException { + registerUDTF(ST_Union_Agg.class); + verifySqlSingleRes( + "with src_tbl as (\n" + + "select sedona.ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))') geom\n" + + "union\n" + + "select sedona.ST_GeomFromText('POLYGON ((0.5 0.5, 0.5 1.5, 1.5 1.5, 1.5 0.5, 0.5 0.5))') geom\n" + + ")\n" + + "select sedona.ST_AsText(unioned) from src_tbl, table(sedona.ST_Union_Agg(src_tbl.geom) OVER (PARTITION BY 1));", + Constructors.geomFromWKT( + "POLYGON ((0 0, 0 1, 0.5 1, 0.5 1.5, 1.5 1.5, 1.5 0.5, 1 0.5, 1 0, 0 0))", 0)); + } + @Test public void test_ST_DumpExplode() { registerUDTF(ST_Dump.class); diff --git 
a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/ddl/UDTFDDLGenerator.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/ddl/UDTFDDLGenerator.java index 4350d1aded5..2930b166bce 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/ddl/UDTFDDLGenerator.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/ddl/UDTFDDLGenerator.java @@ -38,6 +38,9 @@ public class UDTFDDLGenerator { ST_Union_Aggr.class, ST_Collect.class, ST_Dump.class, + ST_Envelope_Agg.class, + ST_Intersection_Agg.class, + ST_Union_Agg.class, // ST_SubDivideExplodeV2 is not supported in Snowflake. // The error message is "java.lang.RuntimeException: // net.snowflake.client.jdbc.SnowflakeSQLException: Data type GEOMETRY is not supported in diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Envelope_Agg.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Envelope_Agg.java new file mode 100644 index 00000000000..f02cbe10687 --- /dev/null +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Envelope_Agg.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.snowflake.snowsql.udtfs; + +import java.util.stream.Stream; +import org.apache.sedona.snowflake.snowsql.GeometrySerde; +import org.apache.sedona.snowflake.snowsql.annotations.UDTFAnnotations; +import org.locationtech.jts.geom.*; +import org.locationtech.jts.io.ParseException; + +@UDTFAnnotations.TabularFunc( + name = "ST_Envelope_Agg", + argNames = {"geom"}) +public class ST_Envelope_Agg { + + public static final GeometryFactory geometryFactory = new GeometryFactory(); + + Envelope buffer = null; + + public static class OutputRow { + + public byte[] envelope; + + public OutputRow(byte[] envelopePolygon) { + this.envelope = envelopePolygon; + } + } + + public static Class getOutputClass() { + return OutputRow.class; + } + + public ST_Envelope_Agg() {} + + public Stream process(byte[] geom) throws ParseException { + Geometry geometry = GeometrySerde.deserialize(geom); + if (buffer == null) { + buffer = geometry.getEnvelopeInternal(); + } else { + buffer.expandToInclude(geometry.getEnvelopeInternal()); + } + return Stream.empty(); + } + + public Stream endPartition() { + // Emits one row: the envelope accumulated by process(), as a serialized closed rectangle polygon.
+ Polygon poly = + geometryFactory.createPolygon( + geometryFactory.createLinearRing( + new Coordinate[] { + new Coordinate(buffer.getMinX(), buffer.getMinY()), + new Coordinate(buffer.getMinX(), buffer.getMaxY()), + new Coordinate(buffer.getMaxX(), buffer.getMaxY()), + new Coordinate(buffer.getMaxX(), buffer.getMinY()), + new Coordinate(buffer.getMinX(), buffer.getMinY()) + })); + return Stream.of(new OutputRow(GeometrySerde.serialize(poly))); + } +} diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Intersection_Agg.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Intersection_Agg.java new file mode 100644 index 00000000000..1512b5ccd3d --- /dev/null +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Intersection_Agg.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.snowflake.snowsql.udtfs; + +import java.util.stream.Stream; +import org.apache.sedona.snowflake.snowsql.GeometrySerde; +import org.apache.sedona.snowflake.snowsql.annotations.UDTFAnnotations; +import org.locationtech.jts.geom.*; +import org.locationtech.jts.io.ParseException; + +@UDTFAnnotations.TabularFunc( + name = "ST_Intersection_Agg", + argNames = {"geom"}) +public class ST_Intersection_Agg { + Geometry buffer = null; + + public static class OutputRow { + + public byte[] intersected; + + public OutputRow(byte[] intersected) { + this.intersected = intersected; + } + } + + public static Class getOutputClass() { + return OutputRow.class; + } + + public ST_Intersection_Agg() {} + + public Stream process(byte[] geom) throws ParseException { + Geometry geometry = GeometrySerde.deserialize(geom); + if (buffer == null) { + buffer = geometry; + } else if (!buffer.equalsExact(geometry)) { + buffer = buffer.intersection(geometry); + } + return Stream.empty(); + } + + public Stream endPartition() { + // Emits one row: the serialized running intersection accumulated by process(). + return Stream.of(new OutputRow(GeometrySerde.serialize(buffer))); + } +} diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Union_Agg.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Union_Agg.java new file mode 100644 index 00000000000..14bd9da6feb --- /dev/null +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/udtfs/ST_Union_Agg.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.snowflake.snowsql.udtfs; + +import java.util.stream.Stream; +import org.apache.sedona.snowflake.snowsql.GeometrySerde; +import org.apache.sedona.snowflake.snowsql.annotations.UDTFAnnotations; +import org.locationtech.jts.geom.Geometry; +import org.locationtech.jts.geom.GeometryFactory; +import org.locationtech.jts.io.ParseException; + +@UDTFAnnotations.TabularFunc( + name = "ST_Union_Agg", + argNames = {"geom"}) +public class ST_Union_Agg { + public static final GeometryFactory geometryFactory = new GeometryFactory(); + + Geometry buffer = null; + + public static class OutputRow { + + public byte[] unioned; + + public OutputRow(byte[] unioned) { + this.unioned = unioned; + } + } + + public static Class getOutputClass() { + return OutputRow.class; + } + + public ST_Union_Agg() {} + + public Stream process(byte[] geom) throws ParseException { + Geometry geometry = GeometrySerde.deserialize(geom); + if (buffer == null) { + buffer = geometry; + } else if (!buffer.equalsExact(geometry)) { + buffer = buffer.union(geometry); + } + return Stream.empty(); + } + + public Stream endPartition() { + // Emits one row: the serialized running union accumulated by process().
+ return Stream.of(new OutputRow(GeometrySerde.serialize(buffer))); + } +} diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala index 6e6d13e023c..fc15570d162 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal} import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.sedona_sql.expressions.{ST_Envelope_Aggr, ST_Intersection_Aggr, ST_Union_Aggr} import org.locationtech.jts.geom.Geometry import scala.reflect.ClassTag @@ -93,14 +94,25 @@ abstract class AbstractCatalog { functionBuilder) } aggregateExpressions.foreach { f => - sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)) - FunctionRegistry.builtin.registerFunction( - FunctionIdentifier(f.getClass.getSimpleName), - new ExpressionInfo(f.getClass.getCanonicalName, null, f.getClass.getSimpleName), - (_: Seq[Expression]) => - throw new UnsupportedOperationException( - s"Aggregate function ${f.getClass.getSimpleName} cannot be used as a regular function")) + registerAggregateFunction(sparkSession, f.getClass.getSimpleName, f) } + // Register aliases for *_Aggr functions with *_Agg suffix + registerAggregateFunction(sparkSession, "ST_Envelope_Agg", new ST_Envelope_Aggr) + registerAggregateFunction(sparkSession, "ST_Intersection_Agg", new ST_Intersection_Aggr) + registerAggregateFunction(sparkSession, "ST_Union_Agg", new ST_Union_Aggr()) + } + + private def registerAggregateFunction( + sparkSession: SparkSession, + functionName: String, + aggregator: Aggregator[Geometry, _, _]): Unit = { + 
sparkSession.udf.register(functionName, functions.udaf(aggregator)) + FunctionRegistry.builtin.registerFunction( + FunctionIdentifier(functionName), + new ExpressionInfo(aggregator.getClass.getCanonicalName, null, functionName), + (_: Seq[Expression]) => + throw new UnsupportedOperationException( + s"Aggregate function $functionName cannot be used as a regular function")) } def dropAll(sparkSession: SparkSession): Unit = { @@ -110,5 +122,9 @@ abstract class AbstractCatalog { aggregateExpressions.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction( FunctionIdentifier(f.getClass.getSimpleName))) + // Drop aliases for *_Aggr functions + Seq("ST_Envelope_Agg", "ST_Intersection_Agg", "ST_Union_Agg").foreach { aliasName => + sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(aliasName)) + } } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala index fe0dc8d7140..2befcee1ad0 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_aggregates.scala @@ -62,4 +62,17 @@ object st_aggregates { val aggrFunc = udaf(new ST_Collect_Agg) aggrFunc(col(geometry)) } + + // Aliases for *_Aggr functions with *_Agg suffix + def ST_Envelope_Agg(geometry: Column): Column = ST_Envelope_Aggr(geometry) + + def ST_Envelope_Agg(geometry: String): Column = ST_Envelope_Aggr(geometry) + + def ST_Intersection_Agg(geometry: Column): Column = ST_Intersection_Aggr(geometry) + + def ST_Intersection_Agg(geometry: String): Column = ST_Intersection_Aggr(geometry) + + def ST_Union_Agg(geometry: Column): Column = ST_Union_Aggr(geometry) + + def ST_Union_Agg(geometry: String): Column = ST_Union_Aggr(geometry) } diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala index 4485f9fcfe5..cd9991b6579 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala @@ -246,6 +246,56 @@ class aggregateFunctionTestScala extends TestBaseScala { assert(result.getNumGeometries == 2) } + // Test aliases for *_Aggr functions with *_Agg suffix + it("Passed ST_Envelope_Agg alias") { + var pointCsvDF = sparkSession.read + .format("csv") + .option("delimiter", ",") + .option("header", "false") + .load(csvPointInputLocation) + pointCsvDF.createOrReplaceTempView("pointtable_alias") + var pointDf = sparkSession.sql( + "select ST_Point(cast(pointtable_alias._c0 as Decimal(24,20)), cast(pointtable_alias._c1 as Decimal(24,20))) as arealandmark from pointtable_alias") + pointDf.createOrReplaceTempView("pointdf_alias") + var boundary = + sparkSession.sql("select ST_Envelope_Agg(pointdf_alias.arealandmark) from pointdf_alias") + val coordinates: Array[Coordinate] = new Array[Coordinate](5) + coordinates(0) = new Coordinate(1.1, 101.1) + coordinates(1) = new Coordinate(1.1, 1100.1) + coordinates(2) = new Coordinate(1000.1, 1100.1) + coordinates(3) = new Coordinate(1000.1, 101.1) + coordinates(4) = coordinates(0) + val geometryFactory = new GeometryFactory() + geometryFactory.createPolygon(coordinates) + assert(boundary.take(1)(0).get(0) == geometryFactory.createPolygon(coordinates)) + } + + it("Passed ST_Intersection_Agg alias") { + sparkSession + .sql(""" + |SELECT explode(array( + | ST_GeomFromWKT('POLYGON ((0 0, 0 2, 2 2, 2 0, 0 0))'), + | ST_GeomFromWKT('POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))') + |)) AS geom + """.stripMargin) + .createOrReplaceTempView("intersecting_polygons") + + val intersectionDF = + sparkSession.sql("SELECT ST_Intersection_Agg(geom) FROM 
intersecting_polygons") + val result = intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry] + + // The intersection of the two squares should be a 1x1 square with area 1.0 + assertResult(1.0)(result.getArea) + } + + it("Passed ST_Union_Agg alias") { + val polygonDf = createPolygonDataFrame(100) + polygonDf.createOrReplaceTempView("polygondf_union_alias") + val unionDF = sparkSession.sql("SELECT ST_Union_Agg(geom) FROM polygondf_union_alias") + val result = unionDF.take(1)(0).get(0).asInstanceOf[Geometry] + assert(result.getArea > 0) + } + it("ST_Union_Aggr should handle null values") { sparkSession .sql(""" diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index 8b8c8ca20c1..2fb0d5b5c56 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -1832,6 +1832,34 @@ class dataFrameAPITestScala extends TestBaseScala { assert(actualResult.getNumGeometries == 3) } + // Test aliases for *_Aggr functions with *_Agg suffix + it("Passed ST_Envelope_Agg alias") { + val baseDf = + sparkSession.sql("SELECT explode(array(ST_Point(0.0, 0.0), ST_Point(1.0, 1.0))) AS geom") + val df = baseDf.select(ST_Envelope_Agg("geom")) + val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() + val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))" + assert(actualResult == expectedResult) + } + + it("Passed ST_Union_Agg alias") { + val baseDf = sparkSession.sql( + "SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))'))) AS geom") + val df = baseDf.select(ST_Union_Agg("geom")) + val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() + val expectedResult = "POLYGON ((0 0, 0 1, 1 1, 2 1, 2 0, 1 0, 0 0))" + assert(actualResult == expectedResult) + 
} + + it("Passed ST_Intersection_Agg alias") { + val baseDf = sparkSession.sql( + "SELECT explode(array(ST_GeomFromWKT('POLYGON ((0 0, 2 0, 2 1, 0 1, 0 0))'), ST_GeomFromWKT('POLYGON ((1 0, 3 0, 3 1, 1 1, 1 0))'))) AS geom") + val df = baseDf.select(ST_Intersection_Agg("geom")) + val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() + val expectedResult = "POLYGON ((2 0, 1 0, 1 1, 2 1, 2 0))" + assert(actualResult == expectedResult) + } + it("Passed ST_LineFromMultiPoint") { val baseDf = sparkSession.sql( "SELECT ST_GeomFromWKT('MULTIPOINT((10 40), (40 30), (20 20), (30 10))') AS multipoint")