diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index 7121ccdf954..de7d00b4e47 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -332,11 +332,6 @@ private void cleanupPusher() throws IOException { } private void close() throws IOException, InterruptedException { - // here we wait for all the in-flight batches to return which sent by dataPusher thread - dataPusher.waitOnTermination(); - sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue()); - shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId); - // merge and push residual data to reduce network traffic // NB: since dataPusher thread have no in-flight data at this point, // we now push merged data by task thread will not introduce any contention @@ -369,6 +364,8 @@ private void close() throws IOException, InterruptedException { sendOffsets = null; long waitStartTime = System.nanoTime(); + dataPusher.waitOnTermination(); + sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue()); shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions); writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);