Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.sedona.common.Functions
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
import org.locationtech.jts.geom.{Coordinate, Envelope, Geometry, GeometryFactory}
import org.locationtech.jts.operation.overlayng.OverlayNGRobust

import scala.collection.JavaConverters._
Expand All @@ -32,18 +33,7 @@ import scala.collection.mutable.ListBuffer
*/

trait TraitSTAggregateExec {
val initialGeometry: Geometry = {
// dummy value for initial value(polygon but )
// any other value is ok.
val coordinates: Array[Coordinate] = new Array[Coordinate](5)
coordinates(0) = new Coordinate(-999999999, -999999999)
coordinates(1) = new Coordinate(-999999999, -999999999)
coordinates(2) = new Coordinate(-999999999, -999999999)
coordinates(3) = new Coordinate(-999999999, -999999999)
coordinates(4) = coordinates(0)
val geometryFactory = new GeometryFactory()
geometryFactory.createPolygon(coordinates)
}
val initialGeometry: Geometry = null
val serde = ExpressionEncoder[Geometry]()

def zero: Geometry = initialGeometry
Expand All @@ -62,7 +52,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
val bufferSerde = ExpressionEncoder[ListBuffer[Geometry]]()

override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
buffer += input
if (input != null) {
buffer += input
}
if (buffer.size >= bufferSize) {
// Perform the union when buffer size is reached
val unionGeometry = OverlayNGRobust.union(buffer.asJava)
Expand All @@ -86,6 +78,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
}

/**
 * Produce the final result from the geometries still held in the buffer.
 *
 * @param reduction
 *   geometries accumulated so far
 * @return
 *   the union of the buffered geometries, or null when the buffer is empty
 */
override def finish(reduction: ListBuffer[Geometry]): Geometry =
  if (reduction.nonEmpty) OverlayNGRobust.union(reduction.asJava)
  else null

Expand All @@ -97,81 +92,76 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
}

/**
* Return the envelope boundary of the entire column
* A helper class to store envelope boundary during aggregation. We use this custom case class
* instead of JTS Envelope to work with the Spark Encoder.
*/
private[apache] class ST_Envelope_Aggr
extends Aggregator[Geometry, Geometry, Geometry]
with TraitSTAggregateExec {
case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double) {
def isNull: Boolean = minX > maxX

def reduce(buffer: Geometry, input: Geometry): Geometry = {
val accumulateEnvelope = buffer.getEnvelopeInternal
val newEnvelope = input.getEnvelopeInternal
val coordinates: Array[Coordinate] = new Array[Coordinate](5)
var minX = 0.0
var minY = 0.0
var maxX = 0.0
var maxY = 0.0
if (accumulateEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
// Found the accumulateEnvelope is the initial value
minX = newEnvelope.getMinX
minY = newEnvelope.getMinY
maxX = newEnvelope.getMaxX
maxY = newEnvelope.getMaxY
} else if (newEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
minX = accumulateEnvelope.getMinX
minY = accumulateEnvelope.getMinY
maxX = accumulateEnvelope.getMaxX
maxY = accumulateEnvelope.getMaxY
def toEnvelope: Envelope = {
if (isNull) {
new Envelope()
} else {
minX = Math.min(accumulateEnvelope.getMinX, newEnvelope.getMinX)
minY = Math.min(accumulateEnvelope.getMinY, newEnvelope.getMinY)
maxX = Math.max(accumulateEnvelope.getMaxX, newEnvelope.getMaxX)
maxY = Math.max(accumulateEnvelope.getMaxY, newEnvelope.getMaxY)
new Envelope(minX, maxX, minY, maxY)
}
coordinates(0) = new Coordinate(minX, minY)
coordinates(1) = new Coordinate(minX, maxY)
coordinates(2) = new Coordinate(maxX, maxY)
coordinates(3) = new Coordinate(maxX, minY)
coordinates(4) = coordinates(0)
val geometryFactory = new GeometryFactory()
geometryFactory.createPolygon(coordinates)

}

def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
val leftEnvelope = buffer1.getEnvelopeInternal
val rightEnvelope = buffer2.getEnvelopeInternal
val coordinates: Array[Coordinate] = new Array[Coordinate](5)
var minX = 0.0
var minY = 0.0
var maxX = 0.0
var maxY = 0.0
if (leftEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
minX = rightEnvelope.getMinX
minY = rightEnvelope.getMinY
maxX = rightEnvelope.getMaxX
maxY = rightEnvelope.getMaxY
} else if (rightEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
minX = leftEnvelope.getMinX
minY = leftEnvelope.getMinY
maxX = leftEnvelope.getMaxX
maxY = leftEnvelope.getMaxY
def merge(other: EnvelopeBuffer): EnvelopeBuffer = {
if (this.isNull) {
other
} else if (other.isNull) {
this
} else {
minX = Math.min(leftEnvelope.getMinX, rightEnvelope.getMinX)
minY = Math.min(leftEnvelope.getMinY, rightEnvelope.getMinY)
maxX = Math.max(leftEnvelope.getMaxX, rightEnvelope.getMaxX)
maxY = Math.max(leftEnvelope.getMaxY, rightEnvelope.getMaxY)
EnvelopeBuffer(
math.min(this.minX, other.minX),
math.max(this.maxX, other.maxX),
math.min(this.minY, other.minY),
math.max(this.maxY, other.maxY))
}
}
}

coordinates(0) = new Coordinate(minX, minY)
coordinates(1) = new Coordinate(minX, maxY)
coordinates(2) = new Coordinate(maxX, maxY)
coordinates(3) = new Coordinate(maxX, minY)
coordinates(4) = coordinates(0)
val geometryFactory = new GeometryFactory()
geometryFactory.createPolygon(coordinates)
/**
 * Aggregate that computes the bounding envelope of every geometry in a column.
 *
 * Null inputs are ignored. The intermediate state is an optional [[EnvelopeBuffer]]: `None`
 * means no non-null geometry has been seen yet, and in that case the final result is null.
 */
private[apache] class ST_Envelope_Aggr
    extends Aggregator[Geometry, Option[EnvelopeBuffer], Geometry] {

  val serde = ExpressionEncoder[Geometry]()

  /** Initial state: nothing has been accumulated. */
  def zero: Option[EnvelopeBuffer] = None

  /**
   * Fold one input geometry into the running envelope.
   *
   * @return
   *   the unchanged buffer for a null input, otherwise the buffer extended by the input's
   *   envelope
   */
  def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] =
    if (input == null) {
      buffer
    } else {
      val bounds = input.getEnvelopeInternal
      val incoming =
        EnvelopeBuffer(bounds.getMinX, bounds.getMaxX, bounds.getMinY, bounds.getMaxY)
      Some(buffer.fold(incoming)(_.merge(incoming)))
    }

  /** Combine two partial states; a missing side simply yields the other side. */
  def merge(
      buffer1: Option[EnvelopeBuffer],
      buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
    val bothPresent = for {
      left <- buffer1
      right <- buffer2
    } yield left.merge(right)
    // When at most one side is defined, fall back to whichever one it is (or None).
    bothPresent.orElse(buffer1).orElse(buffer2)
  }

  /** Convert the accumulated envelope into a geometry, or null when nothing was accumulated. */
  def finish(reduction: Option[EnvelopeBuffer]): Geometry =
    reduction.map(acc => new GeometryFactory().toGeometry(acc.toEnvelope)).orNull

  def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]

  def outputEncoder: ExpressionEncoder[Geometry] = serde
}

/**
Expand All @@ -181,16 +171,26 @@ private[apache] class ST_Intersection_Aggr
extends Aggregator[Geometry, Geometry, Geometry]
with TraitSTAggregateExec {
def reduce(buffer: Geometry, input: Geometry): Geometry = {
if (buffer.isEmpty) input
else if (buffer.equalsExact(initialGeometry)) input
else buffer.intersection(input)
if (input == null) {
return buffer
}
if (buffer == null) {
return input
}
buffer.intersection(input)
}

def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
if (buffer1.equalsExact(initialGeometry)) buffer2
else if (buffer2.equalsExact(initialGeometry)) buffer1
else buffer1.intersection(buffer2)
if (buffer1 == null) {
return buffer2
}
if (buffer2 == null) {
return buffer1
}
buffer1.intersection(buffer2)
}

override def finish(out: Geometry): Geometry = out
}

/**
Expand Down Expand Up @@ -219,7 +219,7 @@ private[apache] class ST_Collect_Agg

override def finish(reduction: ListBuffer[Geometry]): Geometry = {
if (reduction.isEmpty) {
new GeometryFactory().createGeometryCollection()
null
} else {
Functions.createMultiGeometry(reduction.toArray)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,155 @@ class aggregateFunctionTestScala extends TestBaseScala {
// Should only have 2 points (nulls are skipped)
assert(result.getNumGeometries == 2)
}

// Null rows in the input column must be skipped by ST_Union_Aggr.
it("ST_Union_Aggr should handle null values") {
  val inputSql =
    """
      |SELECT explode(array(
      | ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'),
      | ST_GeomFromWKT(NULL),
      | ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))')
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("polygons_with_null_for_union")

  val firstRow = sparkSession
    .sql("SELECT ST_Union_Aggr(geom) FROM polygons_with_null_for_union")
    .take(1)(0)
  val unioned = firstRow.get(0).asInstanceOf[Geometry]

  // Two disjoint unit squares remain once the null is skipped, so the area is 2.0
  assert(unioned.getArea == 2.0)
}

// ST_Envelope_Aggr must skip null rows and bound only the real points.
it("ST_Envelope_Aggr should handle null values") {
  val inputSql =
    """
      |SELECT explode(array(
      | ST_GeomFromWKT('POINT(1 2)'),
      | ST_GeomFromWKT(NULL),
      | ST_GeomFromWKT('POINT(3 4)')
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("points_with_null_for_envelope")

  val aggregated = sparkSession
    .sql("SELECT ST_Envelope_Aggr(geom) FROM points_with_null_for_envelope")
    .take(1)(0)
    .get(0)
    .asInstanceOf[Geometry]

  // The envelope spans the two non-null points: (1, 2) to (3, 4)
  assert(aggregated.getGeometryType == "Polygon")
  val bounds = aggregated.getEnvelopeInternal
  assert(bounds.getMinX == 1.0)
  assert(bounds.getMinY == 2.0)
  assert(bounds.getMaxX == 3.0)
  assert(bounds.getMaxY == 4.0)
}

// Null rows in the input column must be skipped by ST_Intersection_Aggr.
it("ST_Intersection_Aggr should handle null values") {
  val inputSql =
    """
      |SELECT explode(array(
      | ST_GeomFromWKT('POLYGON((0 0, 4 0, 4 4, 0 4, 0 0))'),
      | ST_GeomFromWKT(NULL),
      | ST_GeomFromWKT('POLYGON((2 2, 6 2, 6 6, 2 6, 2 2))')
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("polygons_with_null_for_intersection")

  val firstRow = sparkSession
    .sql("SELECT ST_Intersection_Aggr(geom) FROM polygons_with_null_for_intersection")
    .take(1)(0)
  val overlap = firstRow.get(0).asInstanceOf[Geometry]

  // The two squares overlap in the 2x2 region from (2, 2) to (4, 4), so the area is 4.0
  assert(overlap.getArea == 4.0)
}

// With nothing but null rows, ST_Union_Aggr has no geometry to union and yields null.
it("ST_Union_Aggr should return null if all inputs are null") {
  val inputSql =
    """
      |SELECT explode(array(
      | ST_GeomFromWKT(NULL),
      | ST_GeomFromWKT(NULL)
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("all_null_union")

  val aggregated = sparkSession
    .sql("SELECT ST_Union_Aggr(geom) FROM all_null_union")
    .take(1)(0)
    .get(0)

  assert(aggregated == null)
}

// With nothing but null rows, ST_Envelope_Aggr has no envelope to report and yields null.
it("ST_Envelope_Aggr should return null if all inputs are null") {
  val inputSql =
    """
      |SELECT explode(array(
      | ST_GeomFromWKT(NULL),
      | ST_GeomFromWKT(NULL)
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("all_null_envelope")

  val aggregated = sparkSession
    .sql("SELECT ST_Envelope_Aggr(geom) FROM all_null_envelope")
    .take(1)(0)
    .get(0)

  assert(aggregated == null)
}

// With nothing but null rows, ST_Intersection_Aggr has nothing to intersect and yields null.
it("ST_Intersection_Aggr should return null if all inputs are null") {
  val inputSql =
    """
      |SELECT explode(array(
      | ST_GeomFromWKT(NULL),
      | ST_GeomFromWKT(NULL)
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("all_null_intersection")

  val aggregated = sparkSession
    .sql("SELECT ST_Intersection_Aggr(geom) FROM all_null_intersection")
    .take(1)(0)
    .get(0)

  assert(aggregated == null)
}

// With nothing but null rows, ST_Collect_Agg has nothing to collect and yields null.
it("ST_Collect_Agg should return null if all inputs are null") {
  val inputSql =
    """
      |SELECT explode(array(
      | ST_GeomFromWKT(NULL),
      | ST_GeomFromWKT(NULL)
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("all_null_collect")

  val aggregated = sparkSession
    .sql("SELECT ST_Collect_Agg(geom) FROM all_null_collect")
    .take(1)(0)
    .get(0)

  assert(aggregated == null)
}

// Empty (but non-null) geometries still count as input: the aggregate is an empty geometry,
// not null, even when the remaining rows are null.
it(
  "ST_Envelope_Aggr should return empty geometry if inputs are mixed with null and empty geometries") {
  val inputSql =
    """
      |SELECT explode(array(
      | NULL,
      | NULL,
      | ST_GeomFromWKT('POINT EMPTY'),
      | NULL,
      | ST_GeomFromWKT('POLYGON EMPTY')
      |)) AS geom
      |""".stripMargin
  sparkSession.sql(inputSql).createOrReplaceTempView("mixed_null_empty_envelope")

  val aggregated = sparkSession
    .sql("SELECT ST_Envelope_Aggr(geom) FROM mixed_null_empty_envelope")
    .take(1)(0)
    .get(0)

  assert(aggregated != null)
  assert(aggregated.asInstanceOf[Geometry].isEmpty)
}
}

def generateRandomPolygon(index: Int): String = {
Expand Down
Loading