From 04a8a9ccdbd836ac612900e85c698d0bc2911f74 Mon Sep 17 00:00:00 2001
From: Kristin Cowalcijk
Date: Wed, 17 Dec 2025 12:07:40 +0800
Subject: [PATCH 1/2] Fix null handling for aggregation functions

---
 .../expressions/AggregateFunctions.scala      | 129 ++++++------------
 .../sql/aggregateFunctionTestScala.scala      | 127 +++++++++++++++++
 2 files changed, 171 insertions(+), 85 deletions(-)

diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index fc0cab6260..d2d83b7fc7 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.sedona_sql.expressions
 
 import org.apache.sedona.common.Functions
+import org.apache.spark.sql.{Encoder, Encoders}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
-import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.locationtech.jts.geom.{Coordinate, Envelope, Geometry, GeometryFactory}
 import org.locationtech.jts.operation.overlayng.OverlayNGRobust
 
 import scala.collection.JavaConverters._
@@ -32,18 +33,7 @@ import scala.collection.mutable.ListBuffer
  */
 trait TraitSTAggregateExec {
-  val initialGeometry: Geometry = {
-    // dummy value for initial value(polygon but )
-    // any other value is ok.
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    coordinates(0) = new Coordinate(-999999999, -999999999)
-    coordinates(1) = new Coordinate(-999999999, -999999999)
-    coordinates(2) = new Coordinate(-999999999, -999999999)
-    coordinates(3) = new Coordinate(-999999999, -999999999)
-    coordinates(4) = coordinates(0)
-    val geometryFactory = new GeometryFactory()
-    geometryFactory.createPolygon(coordinates)
-  }
+  val initialGeometry: Geometry = null
 
   val serde = ExpressionEncoder[Geometry]()
   def zero: Geometry = initialGeometry
@@ -62,7 +52,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
   val bufferSerde = ExpressionEncoder[ListBuffer[Geometry]]()
 
   override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
-    buffer += input
+    if (input != null) {
+      buffer += input
+    }
     if (buffer.size >= bufferSize) {
       // Perform the union when buffer size is reached
       val unionGeometry = OverlayNGRobust.union(buffer.asJava)
@@ -86,6 +78,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
   }
 
   override def finish(reduction: ListBuffer[Geometry]): Geometry = {
+    if (reduction.isEmpty) {
+      return null
+    }
     OverlayNGRobust.union(reduction.asJava)
   }
 
@@ -99,79 +94,33 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
 /**
  * Return the envelope boundary of the entire column
  */
-private[apache] class ST_Envelope_Aggr
-    extends Aggregator[Geometry, Geometry, Geometry]
-    with TraitSTAggregateExec {
+private[apache] class ST_Envelope_Aggr extends Aggregator[Geometry, Envelope, Geometry] {
 
-  def reduce(buffer: Geometry, input: Geometry): Geometry = {
-    val accumulateEnvelope = buffer.getEnvelopeInternal
-    val newEnvelope = input.getEnvelopeInternal
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    var minX = 0.0
-    var minY = 0.0
-    var maxX = 0.0
-    var maxY = 0.0
-    if (accumulateEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      // Found the accumulateEnvelope is the initial value
-      minX = newEnvelope.getMinX
-      minY = newEnvelope.getMinY
-      maxX = newEnvelope.getMaxX
-      maxY = newEnvelope.getMaxY
-    } else if (newEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      minX = accumulateEnvelope.getMinX
-      minY = accumulateEnvelope.getMinY
-      maxX = accumulateEnvelope.getMaxX
-      maxY = accumulateEnvelope.getMaxY
-    } else {
-      minX = Math.min(accumulateEnvelope.getMinX, newEnvelope.getMinX)
-      minY = Math.min(accumulateEnvelope.getMinY, newEnvelope.getMinY)
-      maxX = Math.max(accumulateEnvelope.getMaxX, newEnvelope.getMaxX)
-      maxY = Math.max(accumulateEnvelope.getMaxY, newEnvelope.getMaxY)
+  def reduce(buffer: Envelope, input: Geometry): Envelope = {
+    if (input != null) {
+      buffer.expandToInclude(input.getEnvelopeInternal)
     }
-    coordinates(0) = new Coordinate(minX, minY)
-    coordinates(1) = new Coordinate(minX, maxY)
-    coordinates(2) = new Coordinate(maxX, maxY)
-    coordinates(3) = new Coordinate(maxX, minY)
-    coordinates(4) = coordinates(0)
-    val geometryFactory = new GeometryFactory()
-    geometryFactory.createPolygon(coordinates)
+    buffer
+  }
+
+  def merge(buffer1: Envelope, buffer2: Envelope): Envelope = {
+    buffer1.expandToInclude(buffer2)
+    buffer1
   }
 
-  def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
-    val leftEnvelope = buffer1.getEnvelopeInternal
-    val rightEnvelope = buffer2.getEnvelopeInternal
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    var minX = 0.0
-    var minY = 0.0
-    var maxX = 0.0
-    var maxY = 0.0
-    if (leftEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      minX = rightEnvelope.getMinX
-      minY = rightEnvelope.getMinY
-      maxX = rightEnvelope.getMaxX
-      maxY = rightEnvelope.getMaxY
-    } else if (rightEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
-      minX = leftEnvelope.getMinX
-      minY = leftEnvelope.getMinY
-      maxX = leftEnvelope.getMaxX
-      maxY = leftEnvelope.getMaxY
+  def finish(reduction: Envelope): Geometry = {
+    if (reduction.isNull) {
+      null
     } else {
-      minX = Math.min(leftEnvelope.getMinX, rightEnvelope.getMinX)
-      minY = Math.min(leftEnvelope.getMinY, rightEnvelope.getMinY)
-      maxX = Math.max(leftEnvelope.getMaxX, rightEnvelope.getMaxX)
-      maxY = Math.max(leftEnvelope.getMaxY, rightEnvelope.getMaxY)
+      new GeometryFactory().toGeometry(reduction)
     }
-
-    coordinates(0) = new Coordinate(minX, minY)
-    coordinates(1) = new Coordinate(minX, maxY)
-    coordinates(2) = new Coordinate(maxX, maxY)
-    coordinates(3) = new Coordinate(maxX, minY)
-    coordinates(4) = coordinates(0)
-    val geometryFactory = new GeometryFactory()
-    geometryFactory.createPolygon(coordinates)
   }
+
+  def bufferEncoder: Encoder[Envelope] = Encoders.javaSerialization(classOf[Envelope])
+
+  def outputEncoder: ExpressionEncoder[Geometry] = ExpressionEncoder[Geometry]()
+
+  def zero: Envelope = new Envelope()
 }
 
 /**
@@ -181,16 +130,26 @@ private[apache] class ST_Intersection_Aggr
     extends Aggregator[Geometry, Geometry, Geometry]
     with TraitSTAggregateExec {
   def reduce(buffer: Geometry, input: Geometry): Geometry = {
-    if (buffer.isEmpty) input
-    else if (buffer.equalsExact(initialGeometry)) input
-    else buffer.intersection(input)
+    if (input == null) {
+      return buffer
+    }
+    if (buffer == null) {
+      return input
+    }
+    buffer.intersection(input)
   }
 
   def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
-    if (buffer1.equalsExact(initialGeometry)) buffer2
-    else if (buffer2.equalsExact(initialGeometry)) buffer1
-    else buffer1.intersection(buffer2)
+    if (buffer1 == null) {
+      return buffer2
+    }
+    if (buffer2 == null) {
+      return buffer1
+    }
+    buffer1.intersection(buffer2)
   }
+
+  override def finish(out: Geometry): Geometry = out
 }
 
 /**
@@ -219,7 +178,7 @@ private[apache] class ST_Collect_Agg
 
   override def finish(reduction: ListBuffer[Geometry]): Geometry = {
     if (reduction.isEmpty) {
-      new GeometryFactory().createGeometryCollection()
+      null
    } else {
       Functions.createMultiGeometry(reduction.toArray)
     }
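
All three aggregators now share the same null-handling contract: the zero value (a null buffer, or an empty buffer for ST_Union_Aggr) means "nothing aggregated yet", and a null input row is skipped rather than folded in, so an all-null column falls through to the zero value and finishes to null. A minimal sketch of that contract using plain JTS outside Spark (the object name and WKT literals are illustrative, not from the patch):

    import org.locationtech.jts.geom.Geometry
    import org.locationtech.jts.io.WKTReader

    object NullSkippingSketch extends App {
      // Same shape as the reduce/merge logic above: a null buffer means
      // "nothing aggregated yet" and a null input is simply skipped.
      def intersect(buffer: Geometry, input: Geometry): Geometry =
        if (input == null) buffer
        else if (buffer == null) input
        else buffer.intersection(input)

      val reader = new WKTReader()
      val rows = Seq(
        reader.read("POLYGON ((0 0, 4 0, 4 4, 0 4, 0 0))"),
        null, // what ST_GeomFromWKT(NULL) yields
        reader.read("POLYGON ((2 2, 6 2, 6 6, 2 6, 2 2))"))

      println(rows.foldLeft(null: Geometry)(intersect).getArea) // 4.0
      println(Seq[Geometry](null, null).foldLeft(null: Geometry)(intersect)) // null
    }

The tests added below exercise exactly these two cases through SQL.
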
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 911769e2ac..81b98c4ba2 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -245,6 +245,133 @@ class aggregateFunctionTestScala extends TestBaseScala {
       // Should only have 2 points (nulls are skipped)
       assert(result.getNumGeometries == 2)
     }
+
+    it("ST_Union_Aggr should handle null values") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
+               |  ST_GeomFromWKT(NULL),
+               |  ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("polygons_with_null_for_union")
+
+      val unionDF =
+        sparkSession.sql("SELECT ST_Union_Aggr(geom) FROM polygons_with_null_for_union")
+      val result = unionDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should union the 2 non-null polygons (total area = 2.0)
+      assert(result.getArea == 2.0)
+    }
+
+    it("ST_Envelope_Aggr should handle null values") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  ST_GeomFromWKT('POINT(1 2)'),
+               |  ST_GeomFromWKT(NULL),
+               |  ST_GeomFromWKT('POINT(3 4)')
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("points_with_null_for_envelope")
+
+      val envelopeDF =
+        sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM points_with_null_for_envelope")
+      val result = envelopeDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should create envelope from the 2 non-null points
+      assert(result.getGeometryType == "Polygon")
+      val envelope = result.getEnvelopeInternal
+      assert(envelope.getMinX == 1.0)
+      assert(envelope.getMinY == 2.0)
+      assert(envelope.getMaxX == 3.0)
+      assert(envelope.getMaxY == 4.0)
+    }
+
+    it("ST_Intersection_Aggr should handle null values") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  ST_GeomFromWKT('POLYGON((0 0, 4 0, 4 4, 0 4, 0 0))'),
+               |  ST_GeomFromWKT(NULL),
+               |  ST_GeomFromWKT('POLYGON((2 2, 6 2, 6 6, 2 6, 2 2))')
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("polygons_with_null_for_intersection")
+
+      val intersectionDF = sparkSession.sql(
+        "SELECT ST_Intersection_Aggr(geom) FROM polygons_with_null_for_intersection")
+      val result = intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry]
+
+      // Should intersect the 2 non-null polygons (intersection area = 4.0)
+      assert(result.getArea == 4.0)
+    }
+
+    it("ST_Union_Aggr should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  ST_GeomFromWKT(NULL),
+               |  ST_GeomFromWKT(NULL)
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("all_null_union")
+
+      val unionDF = sparkSession.sql("SELECT ST_Union_Aggr(geom) FROM all_null_union")
+      val result = unionDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
+
+    it("ST_Envelope_Aggr should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  ST_GeomFromWKT(NULL),
+               |  ST_GeomFromWKT(NULL)
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("all_null_envelope")
+
+      val envelopeDF = sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM all_null_envelope")
+      val result = envelopeDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
+
+    it("ST_Intersection_Aggr should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  ST_GeomFromWKT(NULL),
+               |  ST_GeomFromWKT(NULL)
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("all_null_intersection")
+
+      val intersectionDF =
+        sparkSession.sql("SELECT ST_Intersection_Aggr(geom) FROM all_null_intersection")
+      val result = intersectionDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
+
+    it("ST_Collect_Agg should return null if all inputs are null") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  ST_GeomFromWKT(NULL),
+               |  ST_GeomFromWKT(NULL)
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("all_null_collect")
+
+      val collectDF = sparkSession.sql("SELECT ST_Collect_Agg(geom) FROM all_null_collect")
+      val result = collectDF.take(1)(0).get(0)
+
+      assert(result == null)
+    }
   }
 
   def generateRandomPolygon(index: Int): String = {

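
The first commit leaves one problem in ST_Envelope_Aggr: its buffer is a raw JTS Envelope serialized through Encoders.javaSerialization, which Spark treats as an opaque blob, and a "null" envelope cannot distinguish a column with no rows from one containing only empty geometries. The second commit below replaces the buffer with a small case class wrapped in Option, which Spark can encode structurally. A sketch of the two encoder styles (the object and value names are illustrative; only the Encoders calls are real Spark API):

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.locationtech.jts.geom.Envelope

    object EncoderSketch {
      // Commit 1: the buffer is an opaque Java-serialized blob.
      val javaEnc: Encoder[Envelope] = Encoders.javaSerialization(classOf[Envelope])

      // Commit 2: the buffer is a case class (a Product), so Spark derives a
      // structured per-field encoder; Option is itself a Product, so it can
      // serve as the aggregator's zero value.
      case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double)
      val productEnc: Encoder[Option[EnvelopeBuffer]] =
        Encoders.product[Option[EnvelopeBuffer]]
    }
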
From cfb6c9d603b2d31a2df9339cffe601b4485f977d Mon Sep 17 00:00:00 2001
From: Kristin Cowalcijk
Date: Wed, 17 Dec 2025 15:22:56 +0800
Subject: [PATCH 2/2] Fix envelope aggr

---
 .../expressions/AggregateFunctions.scala      | 73 +++++++++++++++----
 .../sql/aggregateFunctionTestScala.scala      | 22 ++++++
 2 files changed, 79 insertions(+), 16 deletions(-)

diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
index d2d83b7fc7..ca169a2598 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala
@@ -91,36 +91,77 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
   override def zero: ListBuffer[Geometry] = ListBuffer.empty
 }
 
+/**
+ * A helper class to store the envelope boundary during aggregation. We use this custom case
+ * class instead of JTS Envelope to work with the Spark Encoder.
+ */
+case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double) {
+  def isNull: Boolean = minX > maxX
+
+  def toEnvelope: Envelope = {
+    if (isNull) {
+      new Envelope()
+    } else {
+      new Envelope(minX, maxX, minY, maxY)
+    }
+  }
+
+  def merge(other: EnvelopeBuffer): EnvelopeBuffer = {
+    if (this.isNull) {
+      other
+    } else if (other.isNull) {
+      this
+    } else {
+      EnvelopeBuffer(
+        math.min(this.minX, other.minX),
+        math.max(this.maxX, other.maxX),
+        math.min(this.minY, other.minY),
+        math.max(this.maxY, other.maxY))
+    }
+  }
+}
+
 /**
  * Return the envelope boundary of the entire column
  */
-private[apache] class ST_Envelope_Aggr extends Aggregator[Geometry, Envelope, Geometry] {
+private[apache] class ST_Envelope_Aggr
+    extends Aggregator[Geometry, Option[EnvelopeBuffer], Geometry] {
 
-  def reduce(buffer: Envelope, input: Geometry): Envelope = {
-    if (input != null) {
-      buffer.expandToInclude(input.getEnvelopeInternal)
+  val serde = ExpressionEncoder[Geometry]()
+
+  def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] = {
+    if (input == null) return buffer
+    val env = input.getEnvelopeInternal
+    val envBuffer = EnvelopeBuffer(env.getMinX, env.getMaxX, env.getMinY, env.getMaxY)
+    buffer match {
+      case Some(b) => Some(b.merge(envBuffer))
+      case None => Some(envBuffer)
     }
-    buffer
   }
 
-  def merge(buffer1: Envelope, buffer2: Envelope): Envelope = {
-    buffer1.expandToInclude(buffer2)
-    buffer1
+  def merge(
+      buffer1: Option[EnvelopeBuffer],
+      buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
+    (buffer1, buffer2) match {
+      case (Some(b1), Some(b2)) => Some(b1.merge(b2))
+      case (Some(_), None) => buffer1
+      case (None, Some(_)) => buffer2
+      case (None, None) => None
+    }
   }
 
-  def finish(reduction: Envelope): Geometry = {
-    if (reduction.isNull) {
-      null
-    } else {
-      new GeometryFactory().toGeometry(reduction)
+  def finish(reduction: Option[EnvelopeBuffer]): Geometry = {
+    reduction match {
+      case Some(b) => new GeometryFactory().toGeometry(b.toEnvelope)
+      case None => null
     }
   }
 
-  def bufferEncoder: Encoder[Envelope] = Encoders.javaSerialization(classOf[Envelope])
+  def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]
 
-  def outputEncoder: ExpressionEncoder[Geometry] = ExpressionEncoder[Geometry]()
+  def outputEncoder: ExpressionEncoder[Geometry] = serde
 
-  def zero: Envelope = new Envelope()
+  def zero: Option[EnvelopeBuffer] = None
 }
 
 /**
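
With the Option buffer, None means "no non-null rows seen" and finishes to null, while Some of a still-null EnvelopeBuffer means "only empty geometries seen" and finishes to an empty geometry (JTS's GeometryFactory.toGeometry returns an empty Point for a null envelope). A dependency-free check of the merge semantics, with the case class restated locally so the sketch compiles on its own:

    object EnvelopeBufferSketch extends App {
      // Restated from the patch: minX > maxX encodes the "null" envelope,
      // matching JTS's convention for new Envelope().
      case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double) {
        def isNull: Boolean = minX > maxX
        def merge(other: EnvelopeBuffer): EnvelopeBuffer =
          if (isNull) other
          else if (other.isNull) this
          else EnvelopeBuffer(
            math.min(minX, other.minX),
            math.max(maxX, other.maxX),
            math.min(minY, other.minY),
            math.max(maxY, other.maxY))
      }

      val empty = EnvelopeBuffer(0.0, -1.0, 0.0, -1.0) // like new Envelope()
      val a = EnvelopeBuffer(1.0, 3.0, 2.0, 4.0)
      val b = EnvelopeBuffer(2.0, 6.0, 0.0, 5.0)

      assert(empty.merge(a) == a) // a null envelope is the merge identity
      assert(a.merge(b) == EnvelopeBuffer(1.0, 6.0, 0.0, 5.0)) // bounds widen
      println("merge semantics hold")
    }

The test added below pins down the behavioral consequence: an all-null column yields SQL NULL, while a null-plus-empty column yields an empty but non-null geometry.
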
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
index 81b98c4ba2..4485f9fcfe 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/aggregateFunctionTestScala.scala
@@ -372,6 +372,28 @@ class aggregateFunctionTestScala extends TestBaseScala {
 
       assert(result == null)
     }
+
+    it(
+      "ST_Envelope_Aggr should return empty geometry if inputs are mixed with null and empty geometries") {
+      sparkSession
+        .sql("""
+               |SELECT explode(array(
+               |  NULL,
+               |  NULL,
+               |  ST_GeomFromWKT('POINT EMPTY'),
+               |  NULL,
+               |  ST_GeomFromWKT('POLYGON EMPTY')
+               |)) AS geom
+          """.stripMargin)
+        .createOrReplaceTempView("mixed_null_empty_envelope")
+
+      val envelopeDF =
+        sparkSession.sql("SELECT ST_Envelope_Aggr(geom) FROM mixed_null_empty_envelope")
+      val result = envelopeDF.take(1)(0).get(0)
+
+      assert(result != null)
+      assert(result.asInstanceOf[Geometry].isEmpty)
+    }
   }
 
   def generateRandomPolygon(index: Int): String = {
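
Taken together, the two patches make ST_Union_Aggr, ST_Envelope_Aggr, ST_Intersection_Aggr, and ST_Collect_Agg skip null geometries and return SQL NULL when every input is null, instead of surfacing the old placeholder polygon. An end-to-end sketch, assuming a Sedona-enabled session created through SedonaContext (the object name and SQL are illustrative, not part of the patches):

    import org.apache.sedona.spark.SedonaContext

    object NullAggDemo extends App {
      val config = SedonaContext.builder().master("local[*]").getOrCreate()
      val sedona = SedonaContext.create(config)

      sedona
        .sql("""
               |SELECT ST_Union_Aggr(geom)        AS union_result,
               |       ST_Envelope_Aggr(geom)     AS envelope_result,
               |       ST_Intersection_Aggr(geom) AS intersection_result
               |FROM (SELECT explode(array(ST_GeomFromWKT(NULL), ST_GeomFromWKT(NULL))) AS geom)
             """.stripMargin)
        .show()
      // With both patches applied, all three columns are NULL.
    }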