From 0cac4576e629db6a974cfe1fc45840a00c456834 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 8 Oct 2025 15:55:13 +0300 Subject: [PATCH 01/36] Initial test worked --- datafusion/common/src/config.rs | 5 + datafusion/execution/src/config.rs | 11 + .../src/enforce_distribution.rs | 27 + .../physical-optimizer/src/join_selection.rs | 36 +- .../physical-plan/src/joins/hash_join/exec.rs | 156 +++ .../physical-plan/src/joins/hash_join/mod.rs | 1 + .../src/joins/hash_join/partitioned.rs | 1154 +++++++++++++++++ .../src/joins/hash_join/shared_bounds.rs | 4 + datafusion/physical-plan/src/joins/mod.rs | 2 + datafusion/proto/src/physical_plan/mod.rs | 1 + 10 files changed, 1393 insertions(+), 4 deletions(-) create mode 100644 datafusion/physical-plan/src/joins/hash_join/partitioned.rs diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 6abb2f5c6d3ca..cd862830cf0ba 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -761,6 +761,11 @@ config_namespace! { /// using the provided `target_partitions` level pub repartition_joins: bool, default = true + /// Should DataFusion use spillable partitioned hash joins instead of regular partitioned joins + /// when repartitioning is enabled. This allows handling larger datasets by spilling to disk + /// when memory pressure occurs during join execution. + pub enable_spillable_hash_join: bool, default = false + /// Should DataFusion allow symmetric hash joins for unbounded data sources even when /// its inputs do not have any ordering or filtering If the flag is not enabled, /// the SymmetricHashJoin operator will be unable to prune its internal buffers, diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 491b1aca69ea1..e959b5684f813 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -235,6 +235,11 @@ impl SessionConfig { self.options.optimizer.repartition_joins } + /// Are spillable partitioned hash joins enabled? + pub fn enable_spillable_hash_join(&self) -> bool { + self.options.optimizer.enable_spillable_hash_join + } + /// Are aggregates repartitioned during execution? 
pub fn repartition_aggregations(&self) -> bool { self.options.optimizer.repartition_aggregations @@ -298,6 +303,12 @@ impl SessionConfig { self } + /// Enables or disables spillable partitioned hash joins for handling larger datasets + pub fn with_enable_spillable_hash_join(mut self, enabled: bool) -> Self { + self.options_mut().optimizer.enable_spillable_hash_join = enabled; + self + } + /// Enables or disables the use of repartitioning for aggregations to improve parallelism pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self { self.options_mut().optimizer.repartition_aggregations = enabled; diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 898386e2f9880..af7d7d0ce414b 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -348,6 +348,33 @@ pub fn adjust_input_keys_ordering( // Can not satisfy, clear the current requirements and generate new empty requirements requirements.data.clear(); } + PartitionMode::PartitionedSpillable => { + // For partitioned spillable, use the same logic as regular partitioned + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + HashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + new_conditions.0, + filter.clone(), + join_type, + // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. + projection.clone(), + PartitionMode::PartitionedSpillable, + *null_equality, + ) + .map(|e| Arc::new(e) as _) + }; + return reorder_partitioned_join_keys( + requirements, + on, + &[], + &join_constructor, + ) + .map(Transformed::yes); + } } } else if let Some(CrossJoinExec { left, .. }) = plan.as_any().downcast_ref::() diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index c2cfca681f667..8785c20329edb 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -134,12 +134,14 @@ impl PhysicalOptimizerRule for JoinSelection { let config = &config.optimizer; let collect_threshold_byte_size = config.hash_join_single_partition_threshold; let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; + let enable_spillable = config.enable_spillable_hash_join; new_plan .transform_up(|plan| { statistical_join_selection_subrule( plan, collect_threshold_byte_size, collect_threshold_num_rows, + enable_spillable, ) }) .data() @@ -229,12 +231,19 @@ pub(crate) fn try_collect_left( /// creates a standard partitioned hash join. pub(crate) fn partitioned_hash_join( hash_join: &HashJoinExec, + enable_spillable: bool, ) -> Result> { let left = hash_join.left(); let right = hash_join.right(); + let partition_mode = if enable_spillable { + PartitionMode::PartitionedSpillable + } else { + PartitionMode::Partitioned + }; + if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? 
{ - hash_join.swap_inputs(PartitionMode::Partitioned) + hash_join.swap_inputs(partition_mode) } else { Ok(Arc::new(HashJoinExec::try_new( Arc::clone(left), @@ -243,7 +252,7 @@ pub(crate) fn partitioned_hash_join( hash_join.filter().cloned(), hash_join.join_type(), hash_join.projection.clone(), - PartitionMode::Partitioned, + partition_mode, hash_join.null_equality(), )?)) } @@ -255,6 +264,7 @@ fn statistical_join_selection_subrule( plan: Arc, collect_threshold_byte_size: usize, collect_threshold_num_rows: usize, + enable_spillable: bool, ) -> Result>> { let transformed = if let Some(hash_join) = plan.as_any().downcast_ref::() { @@ -266,12 +276,12 @@ fn statistical_join_selection_subrule( collect_threshold_num_rows, )? .map_or_else( - || partitioned_hash_join(hash_join).map(Some), + || partitioned_hash_join(hash_join, enable_spillable).map(Some), |v| Ok(Some(v)), )?, PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? .map_or_else( - || partitioned_hash_join(hash_join).map(Some), + || partitioned_hash_join(hash_join, enable_spillable).map(Some), |v| Ok(Some(v)), )?, PartitionMode::Partitioned => { @@ -287,6 +297,21 @@ fn statistical_join_selection_subrule( None } } + PartitionMode::PartitionedSpillable => { + println!("Using PartitionMode::PartitionedSpillable"); + // For partitioned spillable, use the same logic as regular partitioned + let left = hash_join.left(); + let right = hash_join.right(); + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + hash_join + .swap_inputs(PartitionMode::PartitionedSpillable) + .map(Some)? + } else { + None + } + } } } else if let Some(cross_join) = plan.as_any().downcast_ref::() { let left = cross_join.left(); @@ -522,6 +547,9 @@ pub(crate) fn swap_join_according_to_unboundedness( (PartitionMode::CollectLeft, _) => { hash_join.swap_inputs(PartitionMode::CollectLeft) } + (PartitionMode::PartitionedSpillable, _) => { + hash_join.swap_inputs(PartitionMode::PartitionedSpillable) + } (PartitionMode::Auto, _) => { // Use `PartitionMode::Partitioned` as default if `Auto` is selected. hash_join.swap_inputs(PartitionMode::Partitioned) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index cb697d4609953..574c9607d1be8 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -36,6 +36,7 @@ use crate::joins::utils::{ update_hash, OnceAsync, OnceFut, }; use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; +use crate::coalesce_partitions::CoalescePartitionsExec; use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, @@ -592,6 +593,10 @@ impl HashJoinExec { PartitionMode::Partitioned => { symmetric_join_output_partitioning(left, right, &join_type)? } + PartitionMode::PartitionedSpillable => { + // For partitioned spillable, use the same partitioning as regular partitioned + symmetric_join_output_partitioning(left, right, &join_type)? 
+ } }; let emission_type = if left.boundedness().is_unbounded() { @@ -797,6 +802,18 @@ impl ExecutionPlan for HashJoinExec { Distribution::UnspecifiedDistribution, Distribution::UnspecifiedDistribution, ], + PartitionMode::PartitionedSpillable => { + // For partitioned spillable, use the same distribution as regular partitioned + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } } } @@ -888,6 +905,7 @@ impl ExecutionPlan for HashJoinExec { partition: usize, context: Arc, ) -> Result { + println!("Executing HashJoinExec"); let on_left = self .on .iter() @@ -956,6 +974,80 @@ impl ExecutionPlan for HashJoinExec { PartitionMode::Auto ); } + PartitionMode::PartitionedSpillable => { + // For partitioned spillable mode, we need to collect the left side + // and then create a partitioned hash join stream + println!("PartitionedSpillable mode"); + // Coalesce left partitions to get the full build side in a single stream + let left_plan: Arc = if self.left.output_partitioning().partition_count() == 1 { + Arc::clone(&self.left) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.left))) + }; + let left_stream = left_plan.execute(0, Arc::clone(&context))?; + let reservation = MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); + + let left_fut = self.left_fut.try_once(|| { + Ok(collect_left_input( + self.random_state.clone(), + left_stream, + on_left.clone(), + join_metrics.clone(), + reservation, + need_produce_result_in_final(self.join_type), + self.right().output_partitioning().partition_count(), + enable_dynamic_filter_pushdown, + )) + })?; + + // Re-enable spillable stream with single-partition direct-probe for now + use crate::joins::hash_join::partitioned::PartitionedHashJoinStream; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + let batch_size = context.session_config().batch_size(); + let num_partitions = 1; // Start with single partition correctness + let memory_threshold = { + let bytes = context + .session_config() + .options() + .execution + .sort_spill_reservation_bytes; + if bytes == 0 { 1024 * 1024 * 1024 } else { bytes } + }; + let partitioned_reservation = MemoryConsumer::new("PartitionedHashJoin") + .register(context.memory_pool()); + let partitioned_stream = PartitionedHashJoinStream::new( + partition, + self.schema(), + on_left, + on_right, + self.filter.clone(), + self.join_type, + right_stream, + left_fut, + self.random_state.clone(), + join_metrics, + column_indices_after_projection, + self.null_equality, + batch_size, + num_partitions, + memory_threshold, + partitioned_reservation, + context.runtime_env(), + )?; + return Ok(Box::pin(partitioned_stream)); + } }; let batch_size = context.session_config().batch_size(); @@ -1638,6 +1730,10 @@ mod tests { PartitionMode::Auto => { return internal_err!("Unexpected PartitionMode::Auto in join tests") } + PartitionMode::PartitionedSpillable => Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), }; let right_repartitioned: Arc = match partition_mode { @@ 
-1659,6 +1755,10 @@ mod tests { PartitionMode::Auto => { return internal_err!("Unexpected PartitionMode::Auto in join tests") } + PartitionMode::PartitionedSpillable => Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), }; let join = HashJoinExec::try_new( @@ -1788,6 +1888,62 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_spillable_join_inner_one(batch_size: usize) -> Result<()> { + // Configure tiny spill reservation to force spill and 4 partitions + let session_config = SessionConfig::default() + .with_batch_size(batch_size) + .with_target_partitions(4) + .with_sort_spill_reservation_bytes(1) + .with_spill_compression(datafusion_common::config::SpillCompression::Uncompressed); + let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (columns, batches, metrics) = join_collect_with_partition_mode( + Arc::clone(&left), + Arc::clone(&right), + on, + &JoinType::Inner, + PartitionMode::PartitionedSpillable, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + assert_join_metrics!(metrics, 3); + Ok(()) + } + #[tokio::test] async fn join_inner_one_no_shared_column_names() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 7f1e5cae13a3e..7c4a297414f3c 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -20,5 +20,6 @@ pub use exec::HashJoinExec; mod exec; +mod partitioned; mod shared_bounds; mod stream; diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs new file mode 100644 index 0000000000000..19814d0654992 --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -0,0 +1,1154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partitioned Hash Join implementation +//! +//! This module implements a partitioned hash join that can handle large datasets +//! by partitioning both build and probe sides into multiple partitions and +//! processing them sequentially. This approach is similar to sort-merge join +//! but uses hash-based partitioning instead of sorting. +//! +//! # State Machine Overview +//! +//! The partitioned hash join follows this state machine pattern: +//! +//! ```text +//! PartitionBuildSide → ProcessPartitions(i) → Done +//! ``` +//! +//! ## PartitionBuildSide State +//! - Partitions build-side data into multiple partitions based on hash values +//! - Keeps one partition resident in memory (partition 0) +//! - Spills other partitions to disk when memory pressure occurs +//! - Uses consistent hashing to ensure same keys go to same partition +//! +//! ## ProcessPartitions State +//! - Processes each partition sequentially +//! - Loads build-side hash map for current partition (from memory or disk) +//! - Probes all probe batches for this partition against the hash map +//! - Generates join results and handles unmatched rows for outer joins +//! - Tracks matched rows for proper outer join semantics + +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::mem; + +use crate::joins::hash_join::exec::JoinLeftData; +use crate::joins::join_hash_map::JoinHashMapType; +use crate::joins::utils::{ + build_batch_from_indices, equal_rows_arr, get_final_indices_from_bit_map, + need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + OnceFut, StatefulStreamResult, +}; +use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; +use crate::spill::spill_manager::SpillManager; +use crate::{RecordBatchStream, SendableRecordBatchStream}; + +use arrow::array::{Array, ArrayRef, BooleanBufferBuilder, UInt32Array, UInt64Array}; +use arrow::compute::take; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{ + hash_utils::create_hashes, internal_datafusion_err, internal_err, DataFusionError, + JoinSide, JoinType, NullEquality, Result, +}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_physical_expr::PhysicalExprRef; + +use ahash::RandomState; +use futures::{ready, Stream, StreamExt}; + + +/// State of the partitioned hash join stream +#[derive(Debug, Clone)] +pub(super) enum PartitionedHashJoinState { + /// Initial state - partitioning build side + PartitionBuildSide, + /// Processing a specific partition + ProcessPartition(ProcessPartitionState), + /// All partitions processed, handling unmatched rows for outer joins + HandleUnmatchedRows, + /// Join completed + Completed, +} + +/// State for processing a specific partition +#[derive(Debug, Clone)] +pub(super) struct ProcessPartitionState { + /// Current partition being processed + pub partition_id: usize, + /// Total number of partitions + pub total_partitions: usize, + /// Whether we're processing the last partition + pub is_last_partition: bool, +} + +/// Represents a partition of build-side data +pub(super) enum BuildPartition { + /// Partition data in memory + InMemory { + /// Hash map for this partition + hash_map: Box, + /// Build-side batch data + batch: RecordBatch, + /// Join key values + values: Vec, + /// Memory 
reservation for this partition + reservation: MemoryReservation, + }, + /// Partition data spilled to disk + Spilled { + /// Spill file containing the partition data (taken on load) + spill_file: Option, + /// Memory reservation (released when spilled) + reservation: MemoryReservation, + }, +} + +/// Represents a partition of probe-side data +#[derive(Debug)] +pub(super) struct ProbePartition { + /// Batches in this partition + pub batches: Vec, + /// Join key values for each batch + pub values: Vec>, + /// Hash values for each batch + pub hashes: Vec>, +} + +// Use RefCountedTempFile from datafusion_execution::disk_manager + +/// Partitioned Hash Join stream that can handle large datasets by partitioning +/// both build and probe sides and processing them sequentially. +pub(super) struct PartitionedHashJoinStream { + // ======================================================================== + // PROPERTIES: + // These fields are initialized at the start and remain constant throughout + // the execution. + // ======================================================================== + /// Partition identifier for debugging and determinism + pub partition: usize, + /// Output schema + pub schema: SchemaRef, + /// Join key columns from the right (probe side) + pub on_right: Vec, + /// Join key columns from the left (build side) + pub on_left: Vec, + /// Optional join filter + pub filter: Option, + /// Type of the join (left, right, semi, etc) + pub join_type: JoinType, + /// Right (probe) input stream + pub right: SendableRecordBatchStream, + /// Future that yields the collected build-side data + pub left_fut: OnceFut, + /// Random state used for hashing initialization + pub random_state: RandomState, + /// Metrics + pub join_metrics: BuildProbeJoinMetrics, + /// Information of index and left / right placement of columns + pub column_indices: Vec, + /// Defines the null equality for the join + pub null_equality: NullEquality, + /// Maximum output batch size + pub batch_size: usize, + /// Number of partitions to use + pub num_partitions: usize, + /// Memory threshold for spilling (in bytes) + pub memory_threshold: usize, + + // ======================================================================== + // STATE: + // These fields track the execution state and are updated during execution. 
+ // ======================================================================== + /// Current state of the stream + pub state: PartitionedHashJoinState, + /// Build-side partitions + pub build_partitions: Vec, + /// Probe-side partitions + pub probe_partitions: Vec, + /// Current partition being processed + pub current_partition: Option, + /// Manages the process of spilling and reading back intermediate data + pub spill_manager: SpillManager, + /// Memory reservation for the entire operation + pub memory_reservation: MemoryReservation, + /// Runtime environment + pub runtime_env: Arc, + /// Scratch space for computing hashes + pub hashes_buffer: Vec, + /// Whether the right side has an ordering to potentially preserve + pub right_side_ordered: bool, + /// Shared bounds accumulator for coordinating dynamic filter updates (optional) + pub bounds_accumulator: Option>, + /// Current probe batch (filtered to the active partition), if any + pub current_probe_batch: Option, + /// Current probe values for ON expressions + pub current_probe_values: Vec, + /// Current probe hashes (filtered to the active partition) + pub current_probe_hashes: Vec, + /// Current lookup offset within the join hash map + pub current_offset: crate::joins::join_hash_map::JoinHashMapOffset, + /// Bitmaps to track matched build-side rows for outer joins (one per partition) + pub matched_build_rows_per_partition: Vec, + /// Current partition being processed for unmatched rows + pub unmatched_partition: usize, + /// Cached unmatched build/probe indices for current partition (chunked emission) + pub unmatched_left_indices_cache: Option, + pub unmatched_right_indices_cache: Option, + pub unmatched_offset: usize, + /// Whether we've buffered the entire probe side into per-partition batches + pub probes_buffered: bool, + /// Current read position per partition within buffered probe batches + pub probe_batch_positions: Vec, +} + +impl PartitionedHashJoinStream { + /// Ensure the build partition is loaded in-memory (reload if spilled) + fn ensure_build_partition_loaded(&mut self, part_id: usize) -> Result<()> { + let needs_reload = matches!( + self.build_partitions.get(part_id), + Some(BuildPartition::Spilled { .. }) + ); + if !needs_reload { + return Ok(()); + } + + if let Some(BuildPartition::Spilled { spill_file, .. }) = + self.build_partitions.get_mut(part_id) + { + let spill_file = spill_file + .take() + .ok_or_else(|| internal_datafusion_err!("spill file already consumed for this partition"))?; + + let mut stream = self.spill_manager.read_spill_as_stream(spill_file)?; + let batch = futures::executor::block_on(async { + use futures::StreamExt; + stream.next().await.transpose() + })? 
+ .ok_or_else(|| internal_datafusion_err!("empty spilled partition"))?; + + println!( + "Reloaded spilled build partition {} for probing (rows={})", + part_id, + batch.num_rows() + ); + + // Reconstruct join values from on_left expressions + let mut values: Vec = Vec::with_capacity(self.on_left.len()); + for c in &self.on_left { + values.push(c.evaluate(&batch)?.into_array(batch.num_rows())?); + } + + // Rebuild the hash map from the reloaded batch + let mut hash_map: Box = Box::new( + crate::joins::join_hash_map::JoinHashMapU32::with_capacity(batch.num_rows()), + ); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(batch.num_rows(), 0); + crate::joins::utils::update_hash( + &self.on_left, + &batch, + &mut *hash_map, + 0, + &self.random_state, + &mut self.hashes_buffer, + 0, + true, + )?; + + let new_reservation = MemoryConsumer::new("partition_reload") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + self.build_partitions[part_id] = BuildPartition::InMemory { + hash_map, + batch, + values, + reservation: new_reservation, + }; + } + Ok(()) + } + /// Create a new partitioned hash join stream + pub fn new( + partition: usize, + schema: SchemaRef, + on_left: Vec, + on_right: Vec, + filter: Option, + join_type: JoinType, + right: SendableRecordBatchStream, + left_fut: OnceFut, + random_state: RandomState, + join_metrics: BuildProbeJoinMetrics, + column_indices: Vec, + null_equality: NullEquality, + batch_size: usize, + num_partitions: usize, + memory_threshold: usize, + memory_reservation: MemoryReservation, + runtime_env: Arc, + ) -> Result { + let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), partition); + let spill_manager = SpillManager::new( + runtime_env.clone(), + spill_metrics, + schema.clone(), + ); + + println!( + "PartitionedHashJoinStream created: partition={}, num_partitions={}, memory_threshold={} bytes", + partition, num_partitions, memory_threshold + ); + + Ok(Self { + partition, + schema, + on_left, + on_right, + filter, + join_type, + right, + left_fut, + random_state, + join_metrics, + column_indices, + null_equality, + batch_size, + num_partitions, + memory_threshold, + state: PartitionedHashJoinState::PartitionBuildSide, + build_partitions: Vec::new(), + probe_partitions: Vec::new(), + current_partition: None, + spill_manager, + memory_reservation, + runtime_env, + hashes_buffer: Vec::new(), + right_side_ordered: false, + bounds_accumulator: None, + current_probe_batch: None, + current_probe_values: vec![], + current_probe_hashes: vec![], + current_offset: (0, None), + matched_build_rows_per_partition: Vec::new(), + unmatched_partition: 0, + unmatched_left_indices_cache: None, + unmatched_right_indices_cache: None, + unmatched_offset: 0, + probes_buffered: false, + probe_batch_positions: vec![], + }) + } + + /// Buffer the entire probe side stream into per-partition batches. + /// Returns Pending until the right stream is fully consumed. 
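+    /// Rows are routed with `hash(on_right keys) % num_partitions`, mirroring how the
+    /// build side is split, so matching keys from both inputs land in the same
+    /// partition and can be joined without consulting any other partition.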
+ fn buffer_probe_side( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + if self.probe_partitions.is_empty() { + self.probe_partitions = (0..self.num_partitions) + .map(|_| ProbePartition { + batches: Vec::new(), + values: Vec::new(), + hashes: Vec::new(), + }) + .collect(); + } + loop { + match self.right.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + // Compute ON values for the full batch + let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; + keys_values.push(v); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + // For each partition, select rows and push filtered batch + for part_id in 0..self.num_partitions { + let indices: Vec = hashes + .iter() + .enumerate() + .filter_map(|(i, &h)| ((h as usize) % self.num_partitions == part_id).then_some(i as u32)) + .collect(); + if indices.is_empty() { + continue; + } + let indices_arr: UInt32Array = indices.clone().into(); + let mut filtered_columns: Vec = Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push(take(col, &indices_arr, None).map_err(DataFusionError::from)?); + } + let filtered_batch = RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + + // Filtered ON values for this partition's batch + let mut filtered_on_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c.evaluate(&filtered_batch)?.into_array(filtered_batch.num_rows())?; + filtered_on_values.push(v); + } + let filtered_hashes: Vec = indices + .iter() + .map(|&i| hashes[i as usize]) + .collect(); + + self.probe_partitions[part_id].batches.push(filtered_batch); + self.probe_partitions[part_id].values.push(filtered_on_values); + self.probe_partitions[part_id].hashes.push(filtered_hashes); + } + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + // Finished buffering + self.probes_buffered = true; + self.probe_batch_positions = vec![0; self.num_partitions]; + println!( + "Buffered probe side: per-partition batch counts = {:?}", + self.probe_partitions.iter().map(|p| p.batches.len()).collect::>() + ); + return Poll::Ready(Ok(())); + } + Poll::Pending => return Poll::Pending, + } + } + } + + /// Partition build-side data into multiple partitions + fn partition_build_side( + &mut self, + build_data: Arc, + ) -> Result>> { + println!("Partitioning build side data into {} partitions", self.num_partitions); + // Initialize partitions + self.build_partitions = Vec::with_capacity(self.num_partitions); + // Initialize per-partition matched rows bitmaps + self.matched_build_rows_per_partition = Vec::with_capacity(self.num_partitions); + + // Extract build-side data + let batch = build_data.batch(); + let values = build_data.values(); + + // Compute hash values for all rows in the build-side batch + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(values, &self.random_state, &mut hashes)?; + + // Partition the data based on hash values + let mut partition_batches: Vec> = vec![Vec::new(); self.num_partitions]; + + for (row_idx, &hash) in hashes.iter().enumerate() { + let partition_id = (hash as usize) % self.num_partitions; + partition_batches[partition_id].push(row_idx); + } + + // Create partitions; spill when memory_threshold is exceeded + for partition_id in 0..self.num_partitions { + let row_indices = 
&partition_batches[partition_id]; + if row_indices.is_empty() { + // Empty partition - create empty hash map + let empty_hash_map: Box = + Box::new(crate::joins::join_hash_map::JoinHashMapU32::with_capacity(0)); + let empty_batch = batch.slice(0, 0); + let empty_values: Vec = values.iter().map(|arr| arr.slice(0, 0)).collect(); + + // Initialize empty matched rows bitmap for this partition + let matched_bitmap = BooleanBufferBuilder::new(0); + self.matched_build_rows_per_partition.push(matched_bitmap); + + self.build_partitions.push(BuildPartition::InMemory { + hash_map: empty_hash_map, + batch: empty_batch, + values: empty_values, + reservation: MemoryConsumer::new("empty_partition").with_can_spill(true).register(&self.runtime_env.memory_pool), + }); + continue; + } + + // Create batch slice for this partition + let partition_batch = self.take_rows(batch, row_indices)?; + let partition_values: Vec = values.iter() + .map(|arr| self.take_rows_from_array(arr, row_indices)) + .collect::>>()?; + + // Estimate memory for this partition + let estimated_size = partition_batch.get_array_memory_size() + + partition_values + .iter() + .map(|a| a.get_array_memory_size()) + .sum::(); + + // Decide spilling using global reservation (per DF best practice) + let mut will_spill = false; + match self.memory_reservation.try_grow(estimated_size) { + Ok(_) => { + if self.memory_reservation.size() > self.memory_threshold { + // Exceeds threshold: roll back and spill + let _ = self.memory_reservation.try_shrink(estimated_size); + will_spill = true; + } + } + Err(_) => { + will_spill = true; + } + } + + if will_spill && self.runtime_env.disk_manager.tmp_files_enabled() { + println!( + "Spilling build partition {} (rows={}) due to memory threshold (threshold={} bytes, current={})", + partition_id, + row_indices.len(), + self.memory_threshold, + self.memory_reservation.size() + ); + // Spill this partition to disk and do not keep it in memory + let spill_file = self + .spill_manager + .spill_record_batch_and_finish(&[partition_batch.clone()], "hash_join_build_partition")? 
+ .ok_or_else(|| internal_datafusion_err!("expected spill file"))?; + + // Initialize matched rows bitmap for this partition + let mut matched_bitmap = BooleanBufferBuilder::new(row_indices.len()); + matched_bitmap.append_n(row_indices.len(), false); + self.matched_build_rows_per_partition.push(matched_bitmap); + + // Per-partition reservation kept as zero-sized placeholder + let reservation = MemoryConsumer::new("partition_spilled").with_can_spill(true).register(&self.runtime_env.memory_pool); + + self.build_partitions.push(BuildPartition::Spilled { + spill_file: Some(spill_file), + reservation, + }); + continue; + } + + // Create hash map for this partition + let partition_hash_map: Box = + Box::new(crate::joins::join_hash_map::JoinHashMapU32::with_capacity(row_indices.len())); + + // Build the hash map for this partition using existing utilities + let mut partition_hash_map = partition_hash_map; + self.hashes_buffer.clear(); + self.hashes_buffer.resize(partition_batch.num_rows(), 0); + crate::joins::utils::update_hash( + &self.on_left, + &partition_batch, + &mut *partition_hash_map, + 0, + &self.random_state, + &mut self.hashes_buffer, + 0, + true, + )?; + + println!( + "Built in-memory hash map for partition {} (rows={})", + partition_id, + row_indices.len() + ); + + // Initialize matched rows bitmap for this partition + let mut matched_bitmap = BooleanBufferBuilder::new(row_indices.len()); + matched_bitmap.append_n(row_indices.len(), false); + self.matched_build_rows_per_partition.push(matched_bitmap); + + // Per-partition reservation: zero-sized placeholder; global reservation tracks memory + let reservation = MemoryConsumer::new("partition_memory").with_can_spill(true).register(&self.runtime_env.memory_pool); + + self.build_partitions.push(BuildPartition::InMemory { + hash_map: partition_hash_map, + batch: partition_batch, + values: partition_values, + reservation, + }); + } + + // Start processing the first partition + println!( + "Partitioning complete. Created {} partitions. Starting to process partition 0", + self.build_partitions.len() + ); + + self.state = PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + partition_id: 0, + total_partitions: self.num_partitions, + is_last_partition: self.num_partitions == 1, + }); + + Ok(StatefulStreamResult::Continue) + } + + /// Take specific rows from a RecordBatch + fn take_rows(&self, batch: &RecordBatch, indices: &[usize]) -> Result { + use arrow::compute::take; + use arrow::array::UInt32Array; + + let indices_array = UInt32Array::from( + indices.iter().map(|&i| i as u32).collect::>() + ); + + let columns: Result, DataFusionError> = batch.columns().iter() + .map(|col| take(col, &indices_array, None).map_err(|e| e.into())) + .collect(); + + Ok(RecordBatch::try_new(batch.schema(), columns?)?) + } + + /// Take specific rows from an ArrayRef + fn take_rows_from_array(&self, array: &ArrayRef, indices: &[usize]) -> Result { + use arrow::compute::take; + use arrow::array::UInt32Array; + + let indices_array = UInt32Array::from( + indices.iter().map(|&i| i as u32).collect::>() + ); + + Ok(take(array, &indices_array, None).map_err(DataFusionError::from)?) + } + + /// Release resources associated with a finished partition when safe to do so. + /// Only releases memory eagerly when we don't need unmatched rows in the final phase. 
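+    /// Outer joins still need the build-side batches to emit unmatched rows in the
+    /// final phase, so this is a no-op when `need_produce_result_in_final` is true;
+    /// otherwise the partition is swapped for an empty placeholder and the global
+    /// memory reservation is shrunk by the partition's estimated size.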
+ fn release_partition_resources(&mut self, partition_id: usize) { + if need_produce_result_in_final(self.join_type) { + return; + } + + if partition_id >= self.build_partitions.len() { + return; + } + + // Take ownership of the old partition to drop heavy resources + let placeholder_reservation = MemoryConsumer::new("partition_released_placeholder") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let old_partition = mem::replace( + &mut self.build_partitions[partition_id], + BuildPartition::Spilled { + spill_file: None, + reservation: placeholder_reservation, + }, + ); + + match old_partition { + BuildPartition::InMemory { batch, values, reservation, .. } => { + // Estimate memory held by this partition and shrink global reservation + let mut estimated_size = batch.get_array_memory_size(); + estimated_size += values.iter().map(|a| a.get_array_memory_size()).sum::(); + let _ = self.memory_reservation.try_shrink(estimated_size); + + // Replace with an empty in-memory partition to keep indexing stable + let empty_batch = RecordBatch::new_empty(batch.schema()); + let empty_values: Vec = self + .on_left + .iter() + .filter_map(|expr| expr.evaluate(&empty_batch).ok()) + .filter_map(|v| v.into_array(empty_batch.num_rows()).ok()) + .collect(); + let empty_hash_map: Box = Box::new( + crate::joins::join_hash_map::JoinHashMapU32::with_capacity(0), + ); + + self.build_partitions[partition_id] = BuildPartition::InMemory { + hash_map: empty_hash_map, + batch: empty_batch, + values: empty_values, + reservation, + }; + } + BuildPartition::Spilled { reservation, .. } => { + // Keep as empty spilled (no further action needed) + self.build_partitions[partition_id] = BuildPartition::Spilled { + spill_file: None, + reservation, + }; + } + } + } + + /// Process a specific partition + fn process_partition( + &mut self, + cx: &mut Context<'_>, + partition_state: &ProcessPartitionState, + ) -> Poll>>> { + // Guard against invalid partition ids (off-by-one protection) + if partition_state.partition_id >= partition_state.total_partitions { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + println!( + "Processing partition {} (total_partitions={}), build_partitions.len()={}", + partition_state.partition_id, + partition_state.total_partitions, + self.build_partitions.len() + ); + + // Do not buffer probe side here; selection happens below depending on num_partitions + + // (Spill reload handled by ensure_build_partition_loaded earlier if needed) + + // (Build partition will be immutably borrowed later within a narrower scope) + + // Ensure the build partition is ready (reload if spilled) BEFORE any immutable borrows + self.ensure_build_partition_loaded(partition_state.partition_id)?; + + // If only 1 partition, stream the probe side directly (simpler and correct across executor partitions) + if self.num_partitions == 1 { + if self.current_probe_batch.is_none() { + match ready!(self.right.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + // Compute hashes for the full batch + let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; + keys_values.push(v); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + // No filtering needed when only one partition + self.current_probe_hashes = hashes; + self.current_probe_values = keys_values; + self.current_probe_batch 
= Some(batch); + self.current_offset = (0, None); + + if let Some(pb) = self.current_probe_batch.as_ref() { + println!( + "[spill-join] Direct probe batch rows={} (partitions=1)", + pb.num_rows() + ); + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(pb.num_rows()); + } + } + Some(Err(e)) => return Poll::Ready(Err(e)), + None => { + // No more probe data for this partition, release and advance + self.release_partition_resources(partition_state.partition_id); + if partition_state.is_last_partition { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + } else { + self.state = PartitionedHashJoinState::ProcessPartition( + ProcessPartitionState { + partition_id: partition_state.partition_id + 1, + total_partitions: partition_state.total_partitions, + is_last_partition: partition_state.partition_id + 1 + == partition_state.total_partitions, + }, + ); + } + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + } + } + } else { + // For multiple inner partitions, buffer the probe side once and consume per partition + if !self.probes_buffered { + ready!(self.buffer_probe_side(cx))?; + } + if self.current_probe_batch.is_none() { + let part_id = partition_state.partition_id; + let pos = *self.probe_batch_positions.get(part_id).unwrap_or(&0); + if let Some(probe_part) = self.probe_partitions.get(part_id) { + if pos < probe_part.batches.len() { + let filtered_batch = probe_part.batches[pos].clone(); + let filtered_on_values = probe_part.values[pos].clone(); + let filtered_hashes = probe_part.hashes[pos].clone(); + + self.current_probe_hashes = filtered_hashes; + self.current_probe_values = filtered_on_values; + self.current_probe_batch = Some(filtered_batch); + self.current_offset = (0, None); + self.probe_batch_positions[part_id] = pos + 1; + } else { + // No more probe data for this partition, release and advance + self.release_partition_resources(partition_state.partition_id); + if partition_state.is_last_partition { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + } else { + self.state = PartitionedHashJoinState::ProcessPartition( + ProcessPartitionState { + partition_id: partition_state.partition_id + 1, + total_partitions: partition_state.total_partitions, + is_last_partition: partition_state.partition_id + 1 + == partition_state.total_partitions, + }, + ); + } + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + } + } + } + + // At this point we have a current probe batch for this partition + let (result, build_ids_to_mark, next_offset) = { + let probe_batch = self + .current_probe_batch + .as_ref() + .ok_or_else(|| internal_datafusion_err!("expected probe batch"))?; + + let (build_hashmap, build_batch, build_values) = match self + .build_partitions + .get(partition_state.partition_id) + { + Some(BuildPartition::InMemory { + hash_map, + batch, + values, + .. 
+ }) => (&**hash_map, batch, values as &Vec), + _ => return Poll::Ready(internal_err!("Missing or invalid build partition")), + }; + + // Lookup against hash map with limit + let (probe_indices, build_indices, next_offset) = build_hashmap + .get_matched_indices_with_limit_offset( + &self.current_probe_hashes, + self.batch_size, + self.current_offset, + ); + + let build_indices: UInt64Array = build_indices.into(); + let probe_indices: UInt32Array = probe_indices.into(); + + println!( + "[spill-join] Candidates before equality: build_ids={}, probe_ids={}, build_rows={}, probe_rows={}", + build_indices.len(), + probe_indices.len(), + build_batch.num_rows(), + probe_batch.num_rows() + ); + + // Resolve hash collisions + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, + build_values, + &self.current_probe_values, + self.null_equality, + )?; + + println!( + "[spill-join] Matched after equality: {}", + build_indices.len() + ); + + // Prepare ids for marking after we release borrows + let build_ids_to_mark: Vec = build_indices.values().to_vec(); + + // Build output batch (Left side is build) + let result = build_batch_from_indices( + &self.schema, + build_batch, + probe_batch, + &build_indices, + &probe_indices, + &self.column_indices, + JoinSide::Left, + )?; + + (result, build_ids_to_mark, next_offset) + }; + + // Mark matched build-side rows for outer joins (use current partition's bitmap) + if let Some(bitmap) = self.matched_build_rows_per_partition.get_mut(partition_state.partition_id) { + for build_idx in build_ids_to_mark { + bitmap.set_bit(build_idx as usize, true); + } + } + + // Update offset or fetch a new probe batch + if let Some(offset) = next_offset { + self.current_offset = offset; + } else { + // Finished this probe batch + self.current_probe_batch = None; + self.current_probe_values.clear(); + self.current_probe_hashes.clear(); + self.current_offset = (0, None); + } + + if result.num_rows() == 0 { + println!( + "[spill-join] Skipping empty batch emission (partition={})", + partition_state.partition_id + ); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + self.join_metrics.output_batches.add(1); + self.join_metrics.baseline.record_output(result.num_rows()); + println!( + "[spill-join] Emitting batch: rows={} (partition={})", + result.num_rows(), + partition_state.partition_id + ); + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))) + } + + /// Handle unmatched rows for outer joins + fn handle_unmatched_rows(&mut self) -> Result>> { + if !need_produce_result_in_final(self.join_type) { + self.state = PartitionedHashJoinState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + } + + // If we have cached unmatched indices for current partition, emit them chunk-by-chunk + if let (Some(left_all), Some(right_all)) = ( + self.unmatched_left_indices_cache.as_ref(), + self.unmatched_right_indices_cache.as_ref(), + ) { + let total = left_all.len(); + if self.unmatched_offset < total { + let remaining = total - self.unmatched_offset; + let to_emit = remaining.min(self.batch_size); + + let left_chunk_ref = left_all.slice(self.unmatched_offset, to_emit); + let right_chunk_ref = right_all.slice(self.unmatched_offset, to_emit); + let left_chunk = left_chunk_ref + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("failed to downcast left indices chunk"))?; + let right_chunk = right_chunk_ref + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("failed to downcast right indices 
chunk"))?; + + // Use current partition's build batch + let partition = self + .build_partitions + .get(self.unmatched_partition) + .ok_or_else(|| internal_datafusion_err!("missing build partition during unmatched cached emission"))?; + let build_batch = match partition { + BuildPartition::InMemory { batch, .. } => batch, + BuildPartition::Spilled { .. } => { + // Should not happen because we only cache after loading InMemory indices + return Ok(StatefulStreamResult::Continue); + } + }; + + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + println!( + "Emitting unmatched rows chunk: partition={}, offset={}, size={} (total={})", + self.unmatched_partition, + self.unmatched_offset, + to_emit, + total + ); + + let result = build_batch_from_indices( + &self.schema, + build_batch, + &empty_right_batch, + left_chunk, + right_chunk, + &self.column_indices, + JoinSide::Left, + )?; + + self.unmatched_offset += to_emit; + if self.unmatched_offset >= total { + // finished this partition's unmatched rows + self.unmatched_left_indices_cache = None; + self.unmatched_right_indices_cache = None; + self.unmatched_offset = 0; + println!( + "Finished emitting unmatched rows for partition {}", + self.unmatched_partition + ); + self.unmatched_partition += 1; + } + + return Ok(StatefulStreamResult::Ready(Some(result))); + } else { + // Safety: should not reach here; reset caches + self.unmatched_left_indices_cache = None; + self.unmatched_right_indices_cache = None; + self.unmatched_offset = 0; + } + } + + // Process unmatched rows for the current partition + if self.unmatched_partition < self.build_partitions.len() { + let partition = self.build_partitions.get_mut(self.unmatched_partition) + .ok_or_else(|| internal_datafusion_err!("missing build partition during unmatched processing"))?; + + match partition { + BuildPartition::InMemory { batch, .. } => { + // Get unmatched indices for this partition using its bitmap + let (left_indices, right_indices) = if let Some(bitmap) = self.matched_build_rows_per_partition.get(self.unmatched_partition) { + get_final_indices_from_bit_map( + bitmap, + self.join_type, + ) + } else { + // If no bitmap, skip this partition + self.unmatched_partition += 1; + return Ok(StatefulStreamResult::Continue); + }; + + println!( + "Unmatched calculation for partition {} -> {} rows", + self.unmatched_partition, + left_indices.len() + ); + + if left_indices.len() > 0 { + // Cache the full indices and emit first chunk via cached path next call + self.unmatched_left_indices_cache = Some(left_indices.clone()); + self.unmatched_right_indices_cache = Some(right_indices.clone()); + self.unmatched_offset = 0; + // Fall-through into cached emission on next invocation + return Ok(StatefulStreamResult::Continue); + } else { + // No unmatched rows in this partition, move to next + self.unmatched_partition += 1; + return Ok(StatefulStreamResult::Continue); + } + } + BuildPartition::Spilled { spill_file, .. } => { + // Reload the spilled partition to produce unmatched rows + let spill_file = spill_file.take().ok_or_else(|| internal_datafusion_err!("spill file already consumed for unmatched"))?; + + let mut stream = self.spill_manager.read_spill_as_stream(spill_file)?; + let batch = futures::executor::block_on(async { + use futures::StreamExt; + stream.next().await.transpose() + })? 
+ .ok_or_else(|| internal_datafusion_err!("empty spilled partition for unmatched"))?; + + println!( + "Reloaded spilled build partition {} for unmatched rows (rows={})", + self.unmatched_partition, + batch.num_rows() + ); + + // Replace with in-memory (no need to rebuild hash map here) + let new_reservation = MemoryConsumer::new("partition_reload_unmatched") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let mut values: Vec = Vec::with_capacity(self.on_left.len()); + for c in &self.on_left { + values.push(c.evaluate(&batch)?.into_array(batch.num_rows())?); + } + let hash_map: Box = Box::new( + crate::joins::join_hash_map::JoinHashMapU32::with_capacity(batch.num_rows()), + ); + self.build_partitions[self.unmatched_partition] = BuildPartition::InMemory { + hash_map, + batch, + values, + reservation: new_reservation, + }; + println!( + "Prepared spilled partition {} as InMemory for unmatched emission", + self.unmatched_partition + ); + // Continue; next iteration will handle InMemory branch to emit unmatched + return Ok(StatefulStreamResult::Continue); + } + } + } else { + // All partitions processed + self.state = PartitionedHashJoinState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + } + } +} + +impl RecordBatchStream for PartitionedHashJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl Stream for PartitionedHashJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match self.state.clone() { + PartitionedHashJoinState::PartitionBuildSide => { + // Collect build side and partition it + let left_data = { + let fut = &mut self.left_fut; + ready!(fut.get_shared(cx))? + }; + match self.partition_build_side(left_data) { + Ok(StatefulStreamResult::Continue) => continue, + Ok(StatefulStreamResult::Ready(Some(batch))) => { + println!( + "[spill-join] poll_next yielding initial batch: rows={}", + batch.num_rows() + ); + return Poll::Ready(Some(Ok(batch))); + } + Ok(StatefulStreamResult::Ready(None)) => return Poll::Ready(None), + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + PartitionedHashJoinState::ProcessPartition(partition_state) => { + match self.process_partition(cx, &partition_state) { + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { + println!( + "[spill-join] poll_next yielding process batch: rows={} (state partition={})", + batch.num_rows(), partition_state.partition_id + ); + return Poll::Ready(Some(Ok(batch))); + } + Poll::Ready(Ok(StatefulStreamResult::Ready(None))) => { + return Poll::Ready(None); + } + Poll::Ready(Ok(StatefulStreamResult::Continue)) => { + continue; + } + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => return Poll::Pending, + } + } + PartitionedHashJoinState::HandleUnmatchedRows => { + match self.handle_unmatched_rows() { + Ok(StatefulStreamResult::Ready(Some(batch))) => { + println!( + "[spill-join] poll_next yielding unmatched batch: rows={}", + batch.num_rows() + ); + return Poll::Ready(Some(Ok(batch))); + } + Ok(StatefulStreamResult::Ready(None)) => { + return Poll::Ready(None); + } + Ok(StatefulStreamResult::Continue) => { + continue; + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + PartitionedHashJoinState::Completed => return Poll::Ready(None), + } + } + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 40dc4ac2e5d10..335ead1dedf02 100644 --- 
a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -160,6 +160,10 @@ impl SharedBoundsAccumulator { PartitionMode::Partitioned => { left_child.output_partitioning().partition_count() } + // For partitioned spillable, use the same logic as regular partitioned + PartitionMode::PartitionedSpillable => { + left_child.output_partitioning().partition_count() + } // Default value, will be resolved during optimization (does not exist once `execute()` is called; will be replaced by one of the other two) PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), }; diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 1d36db996434e..e30fcf763b0e3 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -56,6 +56,8 @@ pub enum PartitionMode { /// mode(Partitioned/CollectLeft) is optimal based on statistics. It will /// also consider swapping the left and right inputs for the Join Auto, + /// Partitioned hash join that can spill to disk for large datasets + PartitionedSpillable, } /// Partitioning mode to use for symmetric hash join diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index e577de5b1d0e0..bfa9cf0425887 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -2202,6 +2202,7 @@ impl protobuf::PhysicalPlanNode { PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft, PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned, PartitionMode::Auto => protobuf::PartitionMode::Auto, + PartitionMode::PartitionedSpillable => protobuf::PartitionMode::Partitioned, }; Ok(protobuf::PhysicalPlanNode { From de04ffd9dbc210e27427d6fa944939fb11e5045a Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Fri, 10 Oct 2025 13:00:49 +0300 Subject: [PATCH 02/36] Temporary unknown partitioning to make it work --- datafusion/common/src/config.rs | 2 +- .../physical-plan/src/joins/hash_join/exec.rs | 91 ++++++++++++++++--- .../src/joins/hash_join/partitioned.rs | 2 +- 3 files changed, 78 insertions(+), 17 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index cd862830cf0ba..54895bc56e049 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -764,7 +764,7 @@ config_namespace! { /// Should DataFusion use spillable partitioned hash joins instead of regular partitioned joins /// when repartitioning is enabled. This allows handling larger datasets by spilling to disk /// when memory pressure occurs during join execution. 
- pub enable_spillable_hash_join: bool, default = false + pub enable_spillable_hash_join: bool, default = true /// Should DataFusion allow symmetric hash joins for unbounded data sources even when /// its inputs do not have any ordering or filtering If the flag is not enabled, diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 574c9607d1be8..a41654af9e762 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -50,7 +50,7 @@ use crate::{ need_produce_result_in_final, symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, }, - metrics::{ExecutionPlanMetricsSet, MetricsSet}, + metrics::{ExecutionPlanMetricsSet, MetricsSet, SpillMetrics}, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, }; @@ -594,8 +594,9 @@ impl HashJoinExec { symmetric_join_output_partitioning(left, right, &join_type)? } PartitionMode::PartitionedSpillable => { - // For partitioned spillable, use the same partitioning as regular partitioned - symmetric_join_output_partitioning(left, right, &join_type)? + // While stabilizing spillable join, advertise single output partition to + // match the current execution behavior and avoid downstream partition fanout. + Partitioning::UnknownPartitioning(1) } }; @@ -802,18 +803,11 @@ impl ExecutionPlan for HashJoinExec { Distribution::UnspecifiedDistribution, Distribution::UnspecifiedDistribution, ], - PartitionMode::PartitionedSpillable => { - // For partitioned spillable, use the same distribution as regular partitioned - let (left_expr, right_expr) = self - .on - .iter() - .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) - .unzip(); - vec![ - Distribution::HashPartitioned(left_expr), - Distribution::HashPartitioned(right_expr), - ] - } + PartitionMode::PartitionedSpillable => vec![ + // While stabilizing, do not require specific input distributions + Distribution::UnspecifiedDistribution, + Distribution::UnspecifiedDistribution, + ], } } @@ -1027,6 +1021,8 @@ impl ExecutionPlan for HashJoinExec { }; let partitioned_reservation = MemoryConsumer::new("PartitionedHashJoin") .register(context.memory_pool()); + // Reuse this operator's metrics set for spill metrics visibility + let spill_metrics = SpillMetrics::new(&self.metrics, partition); let partitioned_stream = PartitionedHashJoinStream::new( partition, self.schema(), @@ -1038,6 +1034,7 @@ impl ExecutionPlan for HashJoinExec { left_fut, self.random_state.clone(), join_metrics, + spill_metrics, column_indices_after_projection, self.null_equality, batch_size, @@ -4656,4 +4653,68 @@ mod tests { fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() } + + #[tokio::test] + async fn partitioned_spillable_spills_to_disk() -> Result<()> { + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + // Force spilling with very low reservation; single partition correctness path + let session_config = SessionConfig::default() + .with_batch_size(1024) + .with_target_partitions(1) + .with_sort_spill_reservation_bytes(1) + .with_spill_compression(datafusion_common::config::SpillCompression::Uncompressed); + let runtime = RuntimeEnvBuilder::new().build_arc()?; + let task_ctx = Arc::new(TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime)); + + // Build left/right to ensure build side has more than 1 row to 
trigger spill partitioning + let left = build_table( + ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8]), + ("b1", &vec![1, 1, 1, 1, 1, 1, 1, 1]), + ("c1", &vec![0, 0, 0, 0, 0, 0, 0, 0]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![1, 1, 1, 2]), + ("c2", &vec![0, 0, 0, 0]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + // Execute with PartitionedSpillable + let join = HashJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + on, + None, + &JoinType::Inner, + None, + PartitionMode::PartitionedSpillable, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, Arc::clone(&task_ctx))?; + // Collect all batches to drive execution and spill + let _ = common::collect(stream).await?; + + // Assert that spilling occurred by inspecting metrics on the operator + let metrics = join.metrics().unwrap(); + // Find any spill metrics in the tree and ensure spilled_rows > 0 + let mut spilled_any = false; + for m in metrics.iter() { + let name = m.value().name(); + let v = m.value().as_usize(); + if (name == "spilled_rows" || name == "spilled_bytes" || name == "spill_count") && v > 0 { + spilled_any = true; + break; + } + } + assert!(spilled_any, "expected spilling to occur in PartitionedSpillable mode"); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 19814d0654992..601eb40bdd490 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -298,6 +298,7 @@ impl PartitionedHashJoinStream { left_fut: OnceFut, random_state: RandomState, join_metrics: BuildProbeJoinMetrics, + spill_metrics: SpillMetrics, column_indices: Vec, null_equality: NullEquality, batch_size: usize, @@ -306,7 +307,6 @@ impl PartitionedHashJoinStream { memory_reservation: MemoryReservation, runtime_env: Arc, ) -> Result { - let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), partition); let spill_manager = SpillManager::new( runtime_env.clone(), spill_metrics, From 9d8d3b3e9fbd02eb1249886c67a1ca37dd4fb307 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Tue, 14 Oct 2025 12:12:04 +0300 Subject: [PATCH 03/36] Add multi-part spill reload --- .../src/joins/hash_join/partitioned.rs | 270 +++++++++++------- 1 file changed, 163 insertions(+), 107 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 601eb40bdd490..aac2c23ed7d6c 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -54,12 +54,12 @@ use crate::joins::utils::{ need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceFut, StatefulStreamResult, }; -use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; +use crate::metrics::{SpillMetrics}; use crate::spill::spill_manager::SpillManager; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::{Array, ArrayRef, BooleanBufferBuilder, UInt32Array, UInt64Array}; -use arrow::compute::take; +use arrow::compute::{take, concat_batches}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{ @@ -217,74 +217,97 @@ pub(super) struct PartitionedHashJoinStream { pub probes_buffered: bool, /// 
Current read position per partition within buffered probe batches pub probe_batch_positions: Vec, + /// Pending async spill reload stream for build partitions + pub pending_reload_stream: Option, + /// Accumulated batches for pending reload + pub pending_reload_batches: Vec, + /// Target partition id for pending reload + pub pending_reload_partition: Option, } impl PartitionedHashJoinStream { /// Ensure the build partition is loaded in-memory (reload if spilled) - fn ensure_build_partition_loaded(&mut self, part_id: usize) -> Result<()> { + fn ensure_build_partition_loaded(&mut self, cx: &mut Context<'_>, part_id: usize) -> Poll> { let needs_reload = matches!( self.build_partitions.get(part_id), Some(BuildPartition::Spilled { .. }) ); if !needs_reload { - return Ok(()); + return Poll::Ready(Ok(())); } - if let Some(BuildPartition::Spilled { spill_file, .. }) = - self.build_partitions.get_mut(part_id) - { - let spill_file = spill_file - .take() - .ok_or_else(|| internal_datafusion_err!("spill file already consumed for this partition"))?; - - let mut stream = self.spill_manager.read_spill_as_stream(spill_file)?; - let batch = futures::executor::block_on(async { - use futures::StreamExt; - stream.next().await.transpose() - })? - .ok_or_else(|| internal_datafusion_err!("empty spilled partition"))?; - - println!( - "Reloaded spilled build partition {} for probing (rows={})", - part_id, - batch.num_rows() - ); - - // Reconstruct join values from on_left expressions - let mut values: Vec = Vec::with_capacity(self.on_left.len()); - for c in &self.on_left { - values.push(c.evaluate(&batch)?.into_array(batch.num_rows())?); + // Kick off reload if needed + if self.pending_reload_partition.is_none() { + if let Some(BuildPartition::Spilled { spill_file, .. }) = self.build_partitions.get_mut(part_id) { + let spill_file = spill_file.take().ok_or_else(|| internal_datafusion_err!("spill file already consumed for this partition"))?; + let stream = self.spill_manager.read_spill_as_stream(spill_file)?; + self.pending_reload_stream = Some(stream); + self.pending_reload_batches.clear(); + self.pending_reload_partition = Some(part_id); } + } - // Rebuild the hash map from the reloaded batch - let mut hash_map: Box = Box::new( - crate::joins::join_hash_map::JoinHashMapU32::with_capacity(batch.num_rows()), - ); - self.hashes_buffer.clear(); - self.hashes_buffer.resize(batch.num_rows(), 0); - crate::joins::utils::update_hash( - &self.on_left, - &batch, - &mut *hash_map, - 0, - &self.random_state, - &mut self.hashes_buffer, - 0, - true, - )?; + // Drive stream forward + if self.pending_reload_partition == Some(part_id) { + if let Some(stream) = self.pending_reload_stream.as_mut() { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + self.pending_reload_batches.push(batch); + return Poll::Pending; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + // Concatenate + let first_schema = self.pending_reload_batches.get(0) + .ok_or_else(|| internal_datafusion_err!("empty spilled partition"))? 
+ .schema(); + let concatenated = concat_batches(&first_schema, self.pending_reload_batches.as_slice()) + .map_err(DataFusionError::from)?; - let new_reservation = MemoryConsumer::new("partition_reload") - .with_can_spill(true) - .register(&self.runtime_env.memory_pool); + println!("Reloaded spilled build partition {} for probing (rows={})", part_id, concatenated.num_rows()); - self.build_partitions[part_id] = BuildPartition::InMemory { - hash_map, - batch, - values, - reservation: new_reservation, - }; + // Recompute values and hashmap + let mut values: Vec = Vec::with_capacity(self.on_left.len()); + for c in &self.on_left { + values.push(c.evaluate(&concatenated)?.into_array(concatenated.num_rows())?); + } + + let mut hash_map: Box = Box::new( + crate::joins::join_hash_map::JoinHashMapU32::with_capacity(concatenated.num_rows()), + ); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(concatenated.num_rows(), 0); + crate::joins::utils::update_hash( + &self.on_left, + &concatenated, + &mut *hash_map, + 0, + &self.random_state, + &mut self.hashes_buffer, + 0, + true, + )?; + + let new_reservation = MemoryConsumer::new("partition_reload").with_can_spill(true).register(&self.runtime_env.memory_pool); + + self.build_partitions[part_id] = BuildPartition::InMemory { + hash_map, + batch: concatenated, + values, + reservation: new_reservation, + }; + + self.pending_reload_stream = None; + self.pending_reload_batches.clear(); + self.pending_reload_partition = None; + return Poll::Ready(Ok(())); + } + Poll::Pending => return Poll::Pending, + } + } } - Ok(()) + + Poll::Pending } /// Create a new partitioned hash join stream pub fn new( @@ -355,6 +378,9 @@ impl PartitionedHashJoinStream { unmatched_offset: 0, probes_buffered: false, probe_batch_positions: vec![], + pending_reload_stream: None, + pending_reload_batches: Vec::new(), + pending_reload_partition: None, }) } @@ -710,7 +736,11 @@ impl PartitionedHashJoinStream { // (Build partition will be immutably borrowed later within a narrower scope) // Ensure the build partition is ready (reload if spilled) BEFORE any immutable borrows - self.ensure_build_partition_loaded(partition_state.partition_id)?; + match self.ensure_build_partition_loaded(cx, partition_state.partition_id) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } // If only 1 partition, stream the probe side directly (simpler and correct across executor partitions) if self.num_partitions == 1 { @@ -906,11 +936,11 @@ impl PartitionedHashJoinStream { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))) } - /// Handle unmatched rows for outer joins - fn handle_unmatched_rows(&mut self) -> Result>> { + /// Handle unmatched rows for outer joins (poll-based, non-blocking spill reload) + fn handle_unmatched_rows(&mut self, cx: &mut Context<'_>) -> Poll>>> { if !need_produce_result_in_final(self.join_type) { self.state = PartitionedHashJoinState::Completed; - return Ok(StatefulStreamResult::Ready(None)); + return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); } // If we have cached unmatched indices for current partition, emit them chunk-by-chunk @@ -943,7 +973,7 @@ impl PartitionedHashJoinStream { BuildPartition::InMemory { batch, .. } => batch, BuildPartition::Spilled { .. 
} => { // Should not happen because we only cache after loading InMemory indices - return Ok(StatefulStreamResult::Continue); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); } }; @@ -979,7 +1009,7 @@ impl PartitionedHashJoinStream { self.unmatched_partition += 1; } - return Ok(StatefulStreamResult::Ready(Some(result))); + return Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))); } else { // Safety: should not reach here; reset caches self.unmatched_left_indices_cache = None; @@ -994,7 +1024,7 @@ impl PartitionedHashJoinStream { .ok_or_else(|| internal_datafusion_err!("missing build partition during unmatched processing"))?; match partition { - BuildPartition::InMemory { batch, .. } => { + BuildPartition::InMemory { batch: _batch, .. } => { // Get unmatched indices for this partition using its bitmap let (left_indices, right_indices) = if let Some(bitmap) = self.matched_build_rows_per_partition.get(self.unmatched_partition) { get_final_indices_from_bit_map( @@ -1004,7 +1034,7 @@ impl PartitionedHashJoinStream { } else { // If no bitmap, skip this partition self.unmatched_partition += 1; - return Ok(StatefulStreamResult::Continue); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); }; println!( @@ -1019,59 +1049,84 @@ impl PartitionedHashJoinStream { self.unmatched_right_indices_cache = Some(right_indices.clone()); self.unmatched_offset = 0; // Fall-through into cached emission on next invocation - return Ok(StatefulStreamResult::Continue); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); } else { // No unmatched rows in this partition, move to next self.unmatched_partition += 1; - return Ok(StatefulStreamResult::Continue); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); } } BuildPartition::Spilled { spill_file, .. } => { - // Reload the spilled partition to produce unmatched rows - let spill_file = spill_file.take().ok_or_else(|| internal_datafusion_err!("spill file already consumed for unmatched"))?; - - let mut stream = self.spill_manager.read_spill_as_stream(spill_file)?; - let batch = futures::executor::block_on(async { - use futures::StreamExt; - stream.next().await.transpose() - })? 
- .ok_or_else(|| internal_datafusion_err!("empty spilled partition for unmatched"))?; - - println!( - "Reloaded spilled build partition {} for unmatched rows (rows={})", - self.unmatched_partition, - batch.num_rows() - ); + // Non-blocking reload of spilled partition for unmatched rows + if self.pending_reload_partition.is_none() { + let taken = spill_file.take().ok_or_else(|| internal_datafusion_err!("spill file already consumed for unmatched"))?; + let stream = self.spill_manager.read_spill_as_stream(taken)?; + self.pending_reload_stream = Some(stream); + self.pending_reload_batches.clear(); + self.pending_reload_partition = Some(self.unmatched_partition); + } - // Replace with in-memory (no need to rebuild hash map here) - let new_reservation = MemoryConsumer::new("partition_reload_unmatched") - .with_can_spill(true) - .register(&self.runtime_env.memory_pool); - let mut values: Vec = Vec::with_capacity(self.on_left.len()); - for c in &self.on_left { - values.push(c.evaluate(&batch)?.into_array(batch.num_rows())?); + if self.pending_reload_partition == Some(self.unmatched_partition) { + if let Some(stream) = self.pending_reload_stream.as_mut() { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + self.pending_reload_batches.push(batch); + return Poll::Pending; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + let first_schema = self.pending_reload_batches.get(0) + .ok_or_else(|| internal_datafusion_err!("empty spilled partition for unmatched"))? + .schema(); + let concatenated = concat_batches(&first_schema, self.pending_reload_batches.as_slice()) + .map_err(DataFusionError::from)?; + + println!( + "Reloaded spilled build partition {} for unmatched rows (rows={})", + self.unmatched_partition, + concatenated.num_rows() + ); + + let new_reservation = MemoryConsumer::new("partition_reload_unmatched") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let mut values: Vec = Vec::with_capacity(self.on_left.len()); + for c in &self.on_left { + values.push(c.evaluate(&concatenated)?.into_array(concatenated.num_rows())?); + } + let hash_map: Box = Box::new( + crate::joins::join_hash_map::JoinHashMapU32::with_capacity(concatenated.num_rows()), + ); + self.build_partitions[self.unmatched_partition] = BuildPartition::InMemory { + hash_map, + batch: concatenated, + values, + reservation: new_reservation, + }; + println!( + "Prepared spilled partition {} as InMemory for unmatched emission", + self.unmatched_partition + ); + + // Clear pending + self.pending_reload_stream = None; + self.pending_reload_batches.clear(); + self.pending_reload_partition = None; + + // Continue; next iteration will handle InMemory branch + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Poll::Pending => return Poll::Pending, + } + } } - let hash_map: Box = Box::new( - crate::joins::join_hash_map::JoinHashMapU32::with_capacity(batch.num_rows()), - ); - self.build_partitions[self.unmatched_partition] = BuildPartition::InMemory { - hash_map, - batch, - values, - reservation: new_reservation, - }; - println!( - "Prepared spilled partition {} as InMemory for unmatched emission", - self.unmatched_partition - ); - // Continue; next iteration will handle InMemory branch to emit unmatched - return Ok(StatefulStreamResult::Continue); + Poll::Pending } } } else { // All partitions processed self.state = PartitionedHashJoinState::Completed; - return Ok(StatefulStreamResult::Ready(None)); + return 
Poll::Ready(Ok(StatefulStreamResult::Ready(None))); } } } @@ -1130,21 +1185,22 @@ impl Stream for PartitionedHashJoinStream { } } PartitionedHashJoinState::HandleUnmatchedRows => { - match self.handle_unmatched_rows() { - Ok(StatefulStreamResult::Ready(Some(batch))) => { + match self.handle_unmatched_rows(cx) { + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { println!( "[spill-join] poll_next yielding unmatched batch: rows={}", batch.num_rows() ); return Poll::Ready(Some(Ok(batch))); } - Ok(StatefulStreamResult::Ready(None)) => { + Poll::Ready(Ok(StatefulStreamResult::Ready(None))) => { return Poll::Ready(None); } - Ok(StatefulStreamResult::Continue) => { + Poll::Ready(Ok(StatefulStreamResult::Continue)) => { continue; } - Err(e) => return Poll::Ready(Some(Err(e))), + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => return Poll::Pending, } } PartitionedHashJoinState::Completed => return Poll::Ready(None), From 622fbc1b78561a7862b1ea771f75a61c7f043d03 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Thu, 16 Oct 2025 12:41:28 +0300 Subject: [PATCH 04/36] Some weird AI fixes. Better revert to proceed --- .../physical-plan/src/joins/hash_join/exec.rs | 90 ++++++- .../src/joins/hash_join/partitioned.rs | 231 ++++++++++++++---- .../src/joins/hash_join/stream.rs | 43 ++++ 3 files changed, 303 insertions(+), 61 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index a41654af9e762..a999a79d19b8b 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -594,8 +594,7 @@ impl HashJoinExec { symmetric_join_output_partitioning(left, right, &join_type)? } PartitionMode::PartitionedSpillable => { - // While stabilizing spillable join, advertise single output partition to - // match the current execution behavior and avoid downstream partition fanout. 
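A short note on the placeholder partitioning this hunk keeps: parent operators size their execution fanout from `partition_count()`, so advertising a single unknown partition makes consumers poll exactly one output stream while the spillable path is still single-stream. Illustrative sketch only (the `Partitioning` path is the usual physical-expr re-export), not the patch's own code:

// Illustrative only: one output partition of unknown (non-hash) distribution,
// so consumers call execute(0, ..) once and cannot assume hash placement of rows.
use datafusion_physical_expr::Partitioning;

fn spillable_output_partitioning() -> Partitioning {
    Partitioning::UnknownPartitioning(1)
}
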
+ // Stabilize: single output partition to avoid downstream fanout and repartition panics Partitioning::UnknownPartitioning(1) } }; @@ -969,10 +968,32 @@ impl ExecutionPlan for HashJoinExec { ); } PartitionMode::PartitionedSpillable => { - // For partitioned spillable mode, we need to collect the left side - // and then create a partitioned hash join stream println!("PartitionedSpillable mode"); - // Coalesce left partitions to get the full build side in a single stream + let enable_spillable = context + .session_config() + .options() + .optimizer + .enable_spillable_hash_join; + if !enable_spillable { + println!( + "PartitionedSpillable disabled by optimizer.enable_spillable_hash_join=false; using legacy Partitioned semantics" + ); + // Legacy fallback: behave like Partitioned + let left_stream = self.left.execute(partition, Arc::clone(&context))?; + let reservation = MemoryConsumer::new(format!("HashJoinInput[{partition}]")) + .register(context.memory_pool()); + OnceFut::new(collect_left_input( + self.random_state.clone(), + left_stream, + on_left.clone(), + join_metrics.clone(), + reservation, + need_produce_result_in_final(self.join_type), + 1, + enable_dynamic_filter_pushdown, + )) + } else { + // Spillable enabled: coalesce left to a single stream let left_plan: Arc = if self.left.output_partitioning().partition_count() == 1 { Arc::clone(&self.left) } else { @@ -980,7 +1001,6 @@ impl ExecutionPlan for HashJoinExec { }; let left_stream = left_plan.execute(0, Arc::clone(&context))?; let reservation = MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); - let left_fut = self.left_fut.try_once(|| { Ok(collect_left_input( self.random_state.clone(), @@ -993,10 +1013,58 @@ impl ExecutionPlan for HashJoinExec { enable_dynamic_filter_pushdown, )) })?; + // For Right-side oriented joins, fall back to standard HashJoinStream for correctness + if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark) { + // Fall back to standard HashJoinStream but ensure the probe side is a single coalesced stream + let right_plan: Arc = if self.right.output_partitioning().partition_count() == 1 { + Arc::clone(&self.right) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.right))) + }; + let right_stream = right_plan.execute(0, Arc::clone(&context))?; + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + // Classic HashJoinStream constructor + return Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), + join_metrics, + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial( + BuildSideInitialState { left_fut } + ), + context.session_config().batch_size(), + vec![], + self.right.output_ordering().is_some(), + None, + ))); + } - // Re-enable spillable stream with single-partition direct-probe for now + // Enable spillable stream; coalesce right to a single stream to avoid downstream fanout use crate::joins::hash_join::partitioned::PartitionedHashJoinStream; - let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let right_plan: Arc = if self.right.output_partitioning().partition_count() == 1 { + 
Arc::clone(&self.right) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.right))) + }; + let right_stream = right_plan.execute(0, Arc::clone(&context))?; let column_indices_after_projection = match &self.projection { Some(projection) => projection .iter() @@ -1010,7 +1078,8 @@ impl ExecutionPlan for HashJoinExec { .map(|(_, right_expr)| Arc::clone(right_expr)) .collect::>(); let batch_size = context.session_config().batch_size(); - let num_partitions = 1; // Start with single partition correctness + // Single output partition execution + let num_partitions = 1; let memory_threshold = { let bytes = context .session_config() @@ -1044,6 +1113,7 @@ impl ExecutionPlan for HashJoinExec { context.runtime_env(), )?; return Ok(Box::pin(partitioned_stream)); + } } }; @@ -1075,7 +1145,7 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. - let right_stream = self.right.execute(partition, context)?; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; // update column indices to reflect the projection let column_indices_after_projection = match &self.projection { diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index aac2c23ed7d6c..bad116487585a 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -51,7 +51,8 @@ use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::join_hash_map::JoinHashMapType; use crate::joins::utils::{ build_batch_from_indices, equal_rows_arr, get_final_indices_from_bit_map, - need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + need_produce_result_in_final, apply_join_filter_to_indices, adjust_indices_by_join_type, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceFut, StatefulStreamResult, }; use crate::metrics::{SpillMetrics}; @@ -195,6 +196,10 @@ pub(super) struct PartitionedHashJoinStream { pub hashes_buffer: Vec, /// Whether the right side has an ordering to potentially preserve pub right_side_ordered: bool, + /// Whether this stream has emitted a placeholder batch for downstream scheduling + pub placeholder_emitted: bool, + /// Running alignment start for right indices across probe batches (for semi/anti/mark) + pub right_alignment_start: usize, /// Shared bounds accumulator for coordinating dynamic filter updates (optional) pub bounds_accumulator: Option>, /// Current probe batch (filtered to the active partition), if any @@ -297,6 +302,15 @@ impl PartitionedHashJoinStream { reservation: new_reservation, }; + if let Some(BuildPartition::InMemory { hash_map, batch, .. }) = self.build_partitions.get(part_id) { + println!( + "Reloaded partition {} hashmap empty? 
{} rows={}", + part_id, + hash_map.is_empty(), + batch.num_rows() + ); + } + self.pending_reload_stream = None; self.pending_reload_batches.clear(); self.pending_reload_partition = None; @@ -366,6 +380,8 @@ impl PartitionedHashJoinStream { runtime_env, hashes_buffer: Vec::new(), right_side_ordered: false, + placeholder_emitted: false, + right_alignment_start: 0, bounds_accumulator: None, current_probe_batch: None, current_probe_values: vec![], @@ -539,6 +555,11 @@ impl PartitionedHashJoinStream { } } + // Disable spilling in single-partition mode to avoid reload deadlocks and ensure progress + if self.num_partitions == 1 { + will_spill = false; + } + if will_spill && self.runtime_env.disk_manager.tmp_files_enabled() { println!( "Spilling build partition {} (rows={}) due to memory threshold (threshold={} bytes, current={})", @@ -601,6 +622,13 @@ impl PartitionedHashJoinStream { // Per-partition reservation: zero-sized placeholder; global reservation tracks memory let reservation = MemoryConsumer::new("partition_memory").with_can_spill(true).register(&self.runtime_env.memory_pool); + let is_empty_after = partition_hash_map.is_empty(); + println!( + "Partition {} hashmap empty after build? {}", + partition_id, + is_empty_after + ); + self.build_partitions.push(BuildPartition::InMemory { hash_map: partition_hash_map, batch: partition_batch, @@ -609,16 +637,17 @@ impl PartitionedHashJoinStream { }); } - // Start processing the first partition + // Start processing at the stream's assigned output partition + let start_partition = self.partition.min(self.num_partitions.saturating_sub(1)); println!( - "Partitioning complete. Created {} partitions. Starting to process partition 0", - self.build_partitions.len() + "Partitioning complete. Created {} partitions. Starting to process partition {}", + self.build_partitions.len(), start_partition ); - + self.state = PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { - partition_id: 0, + partition_id: start_partition, total_partitions: self.num_partitions, - is_last_partition: self.num_partitions == 1, + is_last_partition: start_partition + 1 == self.num_partitions, }); Ok(StatefulStreamResult::Continue) @@ -743,7 +772,7 @@ impl PartitionedHashJoinStream { } // If only 1 partition, stream the probe side directly (simpler and correct across executor partitions) - if self.num_partitions == 1 { + if self.num_partitions == 1 || self.partition >= self.num_partitions { if self.current_probe_batch.is_none() { match ready!(self.right.poll_next_unpin(cx)) { Some(Ok(batch)) => { @@ -792,41 +821,33 @@ impl PartitionedHashJoinStream { } } } else { - // For multiple inner partitions, buffer the probe side once and consume per partition - if !self.probes_buffered { - ready!(self.buffer_probe_side(cx))?; - } + // Multi-partition execution: assume the scheduler already partitioned the right stream. 
if self.current_probe_batch.is_none() { - let part_id = partition_state.partition_id; - let pos = *self.probe_batch_positions.get(part_id).unwrap_or(&0); - if let Some(probe_part) = self.probe_partitions.get(part_id) { - if pos < probe_part.batches.len() { - let filtered_batch = probe_part.batches[pos].clone(); - let filtered_on_values = probe_part.values[pos].clone(); - let filtered_hashes = probe_part.hashes[pos].clone(); - - self.current_probe_hashes = filtered_hashes; - self.current_probe_values = filtered_on_values; - self.current_probe_batch = Some(filtered_batch); + match self.right.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + // Compute ON values and hashes for the full batch; no extra filtering here + let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; + keys_values.push(v); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + self.current_probe_hashes = hashes; + self.current_probe_values = keys_values; + self.current_probe_batch = Some(batch); self.current_offset = (0, None); - self.probe_batch_positions[part_id] = pos + 1; - } else { - // No more probe data for this partition, release and advance + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + // End of right stream for this partition: transition to unmatched rows for THIS partition only self.release_partition_resources(partition_state.partition_id); - if partition_state.is_last_partition { - self.state = PartitionedHashJoinState::HandleUnmatchedRows; - } else { - self.state = PartitionedHashJoinState::ProcessPartition( - ProcessPartitionState { - partition_id: partition_state.partition_id + 1, - total_partitions: partition_state.total_partitions, - is_last_partition: partition_state.partition_id + 1 - == partition_state.total_partitions, - }, - ); - } + self.unmatched_partition = partition_state.partition_id; + self.state = PartitionedHashJoinState::HandleUnmatchedRows; return Poll::Ready(Ok(StatefulStreamResult::Continue)); } + Poll::Pending => return Poll::Pending, } } } @@ -850,6 +871,11 @@ impl PartitionedHashJoinStream { }) => (&**hash_map, batch, values as &Vec), _ => return Poll::Ready(internal_err!("Missing or invalid build partition")), }; + println!( + "[spill-join] Partition {} build hashmap empty? 
{}", + partition_state.partition_id, + build_hashmap.is_empty() + ); // Lookup against hash map with limit let (probe_indices, build_indices, next_offset) = build_hashmap @@ -879,24 +905,103 @@ impl PartitionedHashJoinStream { self.null_equality, )?; + // Debug: log key data types and sample matched pairs + if !build_indices.is_empty() { + let build_key0 = build_values + .get(0) + .map(|a| a.data_type().clone()) + .unwrap_or_else(|| build_batch.schema().field(0).data_type().clone()); + let probe_key0 = self + .current_probe_values + .get(0) + .map(|a| a.data_type().clone()) + .unwrap_or_else(|| probe_batch.schema().field(0).data_type().clone()); + println!( + "[spill-join] Key types: build={:?}, probe={:?}, null_equality={:?}", + build_key0, + probe_key0, + self.null_equality + ); + let sample = build_indices.len().min(5); + let mut pairs = Vec::new(); + for i in 0..sample { + let b = build_indices.value(i) as usize; + let p = probe_indices.value(i) as usize; + let b_slice = build_values[0].as_ref().slice(b, 1); + let p_slice = self.current_probe_values[0].as_ref().slice(p, 1); + pairs.push(format!("({},{})", + b_slice.get_array_memory_size(), + p_slice.get_array_memory_size() + )); + } + println!("[spill-join] Sample pairs (mem sizes) {} -> {}: {}", sample, build_indices.len(), pairs.join(", ")); + } + + // Apply residual join filter if present + let (build_indices, probe_indices) = if let Some(filter) = &self.filter { + apply_join_filter_to_indices( + build_batch, + probe_batch, + build_indices, + probe_indices, + filter, + JoinSide::Left, + None, + )? + } else { + (build_indices, probe_indices) + }; + let (build_indices, probe_indices) = adjust_indices_by_join_type( + build_indices, + probe_indices, + 0..probe_batch.num_rows(), + self.join_type, + self.right_side_ordered, + )?; + println!( - "[spill-join] Matched after equality: {}", + "[spill-join] Matched after equality{}: {}", + if self.filter.is_some() { "+filter" } else { "" }, build_indices.len() ); // Prepare ids for marking after we release borrows let build_ids_to_mark: Vec = build_indices.values().to_vec(); - // Build output batch (Left side is build) - let result = build_batch_from_indices( - &self.schema, - build_batch, - probe_batch, - &build_indices, - &probe_indices, - &self.column_indices, - JoinSide::Left, - )?; + // Build output batch depending on join side semantics + let result = if matches!(self.join_type, JoinType::RightMark) { + println!("[spill-join] Building output with JoinSide::Right (RightMark)"); + build_batch_from_indices( + &self.schema, + probe_batch, + build_batch, + &build_indices, + &probe_indices, + &self.column_indices, + JoinSide::Right, + )? + } else if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { + println!("[spill-join] Building output with JoinSide::Right ({:?})", self.join_type); + build_batch_from_indices( + &self.schema, + probe_batch, + build_batch, + &build_indices, + &probe_indices, + &self.column_indices, + JoinSide::Right, + )? + } else { + build_batch_from_indices( + &self.schema, + build_batch, + probe_batch, + &build_indices, + &probe_indices, + &self.column_indices, + JoinSide::Left, + )? 
+ }; (result, build_ids_to_mark, next_offset) }; @@ -917,6 +1022,7 @@ impl PartitionedHashJoinStream { self.current_probe_values.clear(); self.current_probe_hashes.clear(); self.current_offset = (0, None); + // Alignment is batch-local for semi/anti/mark in spillable path; do not carry across batches } if result.num_rows() == 0 { @@ -1069,7 +1175,12 @@ impl PartitionedHashJoinStream { if self.pending_reload_partition == Some(self.unmatched_partition) { if let Some(stream) = self.pending_reload_stream.as_mut() { match stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { + Poll::Ready(Some(Ok(batch))) => { + println!( + "Reload stream yielded batch for build partition {} (rows={})", + self.unmatched_partition, + batch.num_rows() + ); self.pending_reload_batches.push(batch); return Poll::Pending; } @@ -1116,7 +1227,15 @@ impl PartitionedHashJoinStream { // Continue; next iteration will handle InMemory branch return Poll::Ready(Ok(StatefulStreamResult::Continue)); } - Poll::Pending => return Poll::Pending, + Poll::Pending => { + // Yield until more data is available from reload stream + println!( + "Reload stream pending for build partition {} (accumulated_batches={})", + self.unmatched_partition, + self.pending_reload_batches.len() + ); + return Poll::Pending; + } } } } @@ -1166,6 +1285,16 @@ impl Stream for PartitionedHashJoinStream { } } PartitionedHashJoinState::ProcessPartition(partition_state) => { + // Emit a zero-row placeholder once in multi-output mode to satisfy downstream schedulers + if self.num_partitions > 1 && !self.placeholder_emitted { + self.placeholder_emitted = true; + let empty = RecordBatch::new_empty(self.schema.clone()); + println!( + "[spill-join] Emitting placeholder empty batch for partition {}", + partition_state.partition_id + ); + return Poll::Ready(Some(Ok(empty))); + } match self.process_partition(cx, &partition_state) { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { println!( diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 4484eeabd3264..ce5892849aa6c 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -556,6 +556,17 @@ impl HashJoinStream { last_joined_right_idx.map_or(0, |v| v + 1) }; + if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark) { + println!( + "[hash-join] Align {:?}: pre-adjust right_indices={}, range={}..{} (next_offset_present={})", + self.join_type, + right_indices.len(), + index_alignment_range_start, + index_alignment_range_end, + next_offset.is_some() + ); + } + let (left_indices, right_indices) = adjust_indices_by_join_type( left_indices, right_indices, @@ -564,6 +575,27 @@ impl HashJoinStream { self.right_side_ordered, )?; + if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark) { + println!( + "[hash-join] Align {:?}: post-adjust unique_right_indices={} (range={}..{})", + self.join_type, + right_indices.len(), + index_alignment_range_start, + index_alignment_range_end + ); + } + + if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { + println!( + "[hash-join] Right {:?}: probe_batch_rows={}, unique_matched_right_indices={} (range={}..{})", + self.join_type, + state.batch.num_rows(), + right_indices.len(), + index_alignment_range_start, + index_alignment_range_end + ); + } + let result = if self.join_type == JoinType::RightMark { build_batch_from_indices( 
&self.schema, @@ -574,6 +606,17 @@ impl HashJoinStream { &self.column_indices, JoinSide::Right, )? + } else if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { + // Emit probe-side rows for right-oriented joins + build_batch_from_indices( + &self.schema, + &state.batch, + build_side.left_data.batch(), + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Right, + )? } else { build_batch_from_indices( &self.schema, From fcb7da65474b276d96b2540a6c43b9a7a1e6a318 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Mon, 27 Oct 2025 16:28:54 +0300 Subject: [PATCH 05/36] Basic memory + disk spilling --- .../src/spill/in_memory_spill_buffer.rs | 47 ++++++++++++++ datafusion/physical-plan/src/spill/mod.rs | 65 ++++++++++++++++++- .../physical-plan/src/spill/spill_manager.rs | 58 ++++++++++++++++- 3 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs diff --git a/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs b/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs new file mode 100644 index 0000000000000..81c8e594af094 --- /dev/null +++ b/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs @@ -0,0 +1,47 @@ +use crate::memory::MemoryStream; +use crate::spill::spill_manager::GetSlicedSize; +use arrow::array::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::SendableRecordBatchStream; +use std::sync::Arc; + +#[derive(Debug)] +pub struct InMemorySpillBuffer { + batches: Vec, + total_bytes: usize, +} + +impl InMemorySpillBuffer { + pub fn from_batch(batch: &RecordBatch) -> Result { + Ok(Self { + batches: vec![batch.clone()], + total_bytes: batch.get_sliced_size()?, + }) + } + + pub fn from_batches(batches: &[RecordBatch]) -> Result { + let mut total_bytes = 0; + let mut owned = Vec::with_capacity(batches.len()); + for b in batches { + total_bytes += b.get_sliced_size()?; + owned.push(b.clone()); + } + Ok(Self { + batches: owned, + total_bytes, + }) + } + + /// return FIFO stream of batches + pub fn as_stream( + self: Arc, + schema: Arc, + ) -> Result { + let stream = MemoryStream::try_new(self.batches.clone(), schema, None)?; + Ok(Box::pin(stream)) + } + + pub fn size(&self) -> usize { + self.total_bytes + } +} diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index fab62bff840f6..270b3654b2bad 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -19,6 +19,7 @@ pub(crate) mod in_progress_spill_file; pub(crate) mod spill_manager; +pub(crate) mod in_memory_spill_buffer; use std::fs::File; use std::io::BufReader; @@ -376,17 +377,18 @@ mod tests { use crate::common::collect; use crate::metrics::ExecutionPlanMetricsSet; use crate::metrics::SpillMetrics; - use crate::spill::spill_manager::SpillManager; + use crate::spill::spill_manager::{SpillLocation, SpillManager}; use crate::test::build_table_i32; use arrow::array::{ArrayRef, Float64Array, Int32Array, ListArray, StringArray}; use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; - use datafusion_execution::runtime_env::RuntimeEnv; + use datafusion_execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; use futures::StreamExt as _; use std::sync::Arc; + use datafusion_execution::memory_pool::{FairSpillPool, MemoryPool}; #[tokio::test] async fn test_batch_spill_and_read() -> Result<()> { @@ -426,6 
+428,65 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_batch_spill_to_memory_and_disk_and_read() -> Result<()> { + let schema: SchemaRef = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from_iter_values(0..1000)), + Arc::new(Int32Array::from_iter_values(1000..2000)), + ], + )?; + + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from_iter_values(2000..4000)), + Arc::new(Int32Array::from_iter_values(4000..6000)), + ], + )?; + + let num_rows = batch1.num_rows() + batch2.num_rows(); + let batches = vec![batch1, batch2]; + + // --- create small memory pool (simulate memory pressure) --- + let memory_limit_bytes = 20 * 1024; // 20KB + let memory_pool: Arc = Arc::new(FairSpillPool::new(memory_limit_bytes)); + + + // Construct SpillManager + let env = RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build_arc()?; + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + + let results = spill_manager.spill_batches_auto(&batches, "TestAutoSpill")?; + assert_eq!(results.len(), 2); + + let mem_count = results.iter().filter(|r| matches!(r, SpillLocation::Memory(_))).count(); + let disk_count = results.iter().filter(|r| matches!(r, SpillLocation::Disk(_))).count(); + assert!(mem_count >= 1); + assert!(disk_count >= 1); + + let spilled_rows = spill_manager.metrics.spilled_rows.value(); + assert_eq!(spilled_rows, num_rows); + + for spill in results { + let stream = spill_manager.load_spilled_batch(spill)?; + let collected = collect(stream).await?; + assert!(!collected.is_empty()); + assert_eq!(collected[0].schema(), schema); + } + + Ok(()) + } + #[tokio::test] async fn test_batch_spill_and_read_dictionary_arrays() -> Result<()> { // See https://github.com/apache/datafusion/issues/4658 diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index ad23bd66a021a..1ac6d072f45d4 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -21,15 +21,17 @@ use arrow::array::StringViewArray; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_execution::runtime_env::RuntimeEnv; +use std::slice; use std::sync::Arc; -use datafusion_common::{config::SpillCompression, Result}; +use datafusion_common::{config::SpillCompression, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::SendableRecordBatchStream; use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; use crate::coop::cooperative; use crate::{common::spawn_buffered, metrics::SpillMetrics}; +use crate::spill::in_memory_spill_buffer::InMemorySpillBuffer; /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. 
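The next hunk introduces `spill_batch_auto`, which keeps a batch in memory only while it fits under the memory-pool limit with a 50% safety margin, and otherwise writes it to a temp file. A worked sketch of the same inequality, using the roughly 8 KB and 16 KB batches from the test above against its 20 KB pool:

// Sketch of the auto-spill decision (same check as the code that follows).
fn keep_in_memory(reserved: usize, batch_bytes: usize, limit: usize) -> bool {
    reserved + batch_bytes * 3 / 2 <= limit
}

// ~8 KB batch:  0 + 12_000 <= 20_480 -> stays in memory.
// ~16 KB batch: 0 + 24_000 >  20_480 -> spills to disk.
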
@@ -168,6 +170,43 @@ impl SpillManager { Ok(file.map(|f| (f, max_record_batch_size))) } + pub(crate) fn spill_batch_auto(&self, batch: &RecordBatch, request_msg: &str) -> Result { + let size = batch.get_sliced_size()?; + + // check pool limit + let used = self.env.memory_pool.reserved(); + let limit = match self.env.memory_pool.memory_limit() { + datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, + _ => usize::MAX, + }; + + if used + size * 3 / 2 <= limit { + let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); + self.metrics.spilled_bytes.add(size); + self.metrics.spilled_rows.add(batch.num_rows()); + Ok(SpillLocation::Memory(buf)) + } else { + let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { + return Err(DataFusionError::Execution( + "failed to spill batch to disk".into(), + )); + }; + Ok(SpillLocation::Disk(file)) + } + } + + pub fn spill_batches_auto( + &self, + batches: &[RecordBatch], + request_msg: &str, + ) -> Result> { + let mut result = Vec::with_capacity(batches.len()); + for batch in batches { + result.push(self.spill_batch_auto(batch, request_msg)?); + } + Ok(result) + } + /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. /// This method will generate output in FIFO order: the batch appended first /// will be read first. @@ -182,8 +221,25 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + + pub fn load_spilled_batch( + &self, + spill: SpillLocation, + ) -> Result { + match spill { + SpillLocation::Memory(buf) => Ok(buf.as_stream(Arc::clone(&self.schema))?), + SpillLocation::Disk(file) => self.read_spill_as_stream(file), + } + } +} + +#[derive(Debug)] +pub enum SpillLocation { + Memory(Arc), + Disk(RefCountedTempFile), } + pub(crate) trait GetSlicedSize { /// Returns the size of the `RecordBatch` when sliced. /// Note: if multiple arrays or even a single array share the same data buffers, we may double count each buffer. From e9441e209334b2eaea468d6a6c8d51a77dae3d47 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Mon, 27 Oct 2025 16:29:27 +0300 Subject: [PATCH 06/36] Basic memory + disk spilling --- datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs b/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs index 81c8e594af094..bba0f6f95625f 100644 --- a/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs +++ b/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs @@ -32,7 +32,6 @@ impl InMemorySpillBuffer { }) } - /// return FIFO stream of batches pub fn as_stream( self: Arc, schema: Arc, From c1d66ff83ebe9adf46b964b800ec898faddfafa4 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Mon, 27 Oct 2025 16:31:54 +0300 Subject: [PATCH 07/36] Basic memory + disk spilling --- datafusion/physical-plan/src/spill/spill_manager.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 1ac6d072f45d4..e8f5037f764bc 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -170,22 +170,26 @@ impl SpillManager { Ok(file.map(|f| (f, max_record_batch_size))) } + /// Automatically decides whether to spill the given RecordBatch to memory or disk, + /// depending on available memory pool capacity. 
pub(crate) fn spill_batch_auto(&self, batch: &RecordBatch, request_msg: &str) -> Result { let size = batch.get_sliced_size()?; - // check pool limit + // Check current memory usage and total limit from the runtime memory pool let used = self.env.memory_pool.reserved(); let limit = match self.env.memory_pool.memory_limit() { datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, _ => usize::MAX, }; + // If there's enough memory (with a small safety margin), keep it in memory if used + size * 3 / 2 <= limit { let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); self.metrics.spilled_bytes.add(size); self.metrics.spilled_rows.add(batch.num_rows()); Ok(SpillLocation::Memory(buf)) } else { + // Otherwise spill to disk using the existing SpillManager logic let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { return Err(DataFusionError::Execution( "failed to spill batch to disk".into(), From 7833aab1f5a1e21d5dca95635bea8cbf4a1f2426 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Mon, 3 Nov 2025 17:16:15 +0200 Subject: [PATCH 08/36] Inner Join works --- .../physical-optimizer/src/join_selection.rs | 2 +- .../physical-plan/src/joins/hash_join/exec.rs | 317 ++-- .../src/joins/hash_join/partitioned.rs | 1535 +++++++++++++---- .../src/joins/hash_join/stream.rs | 173 +- 4 files changed, 1600 insertions(+), 427 deletions(-) diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 8785c20329edb..976096e7761fc 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -240,7 +240,7 @@ pub(crate) fn partitioned_hash_join( } else { PartitionMode::Partitioned }; - + if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { hash_join.swap_inputs(partition_mode) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index a999a79d19b8b..97ffe3d2f0745 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -21,6 +21,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; use std::{any::Any, vec}; +use crate::coalesce_partitions::CoalescePartitionsExec; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, @@ -36,7 +37,6 @@ use crate::joins::utils::{ update_hash, OnceAsync, OnceFut, }; use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; -use crate::coalesce_partitions::CoalescePartitionsExec; use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, @@ -572,13 +572,19 @@ impl HashJoinExec { mode: PartitionMode, projection: Option<&Vec>, ) -> Result { - // Calculate equivalence properties: + // Calculate equivalence properties. For the spillable path, do not claim + // any input order preservation to avoid incorrect planner assumptions + // (e.g., SortPreservingMerge) when the operator may perturb order. 
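Directly below, the spillable mode reports that it preserves neither input's row order. The two booleans are indexed as [left/build, right/probe]; returning `false` for both keeps the planner from stacking order-dependent operators (such as SortPreservingMerge) on the join without a re-sort. A minimal illustrative sketch of what that vector encodes:

// Illustrative only: the spillable join claims no order preservation.
fn spillable_maintains_input_order() -> Vec<bool> {
    vec![false, false] // [build side, probe side]
}
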
+ let maintains = match mode { + PartitionMode::PartitionedSpillable => vec![false, false], + _ => Self::maintains_input_order(join_type), + }; let mut eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), right.equivalence_properties().clone(), &join_type, Arc::clone(&schema), - &Self::maintains_input_order(join_type), + &maintains, Some(Self::probe_side()), on, )?; @@ -594,7 +600,8 @@ impl HashJoinExec { symmetric_join_output_partitioning(left, right, &join_type)? } PartitionMode::PartitionedSpillable => { - // Stabilize: single output partition to avoid downstream fanout and repartition panics + // Report output partitions consistent with the right side to enable + // proper upstream planning (e.g., repartitioning and aggregations) Partitioning::UnknownPartitioning(1) } }; @@ -968,20 +975,19 @@ impl ExecutionPlan for HashJoinExec { ); } PartitionMode::PartitionedSpillable => { - println!("PartitionedSpillable mode"); let enable_spillable = context .session_config() .options() .optimizer .enable_spillable_hash_join; + if !enable_spillable { - println!( - "PartitionedSpillable disabled by optimizer.enable_spillable_hash_join=false; using legacy Partitioned semantics" - ); // Legacy fallback: behave like Partitioned - let left_stream = self.left.execute(partition, Arc::clone(&context))?; - let reservation = MemoryConsumer::new(format!("HashJoinInput[{partition}]")) - .register(context.memory_pool()); + let left_stream = + self.left.execute(partition, Arc::clone(&context))?; + let reservation = + MemoryConsumer::new(format!("HashJoinInput[{partition}]")) + .register(context.memory_pool()); OnceFut::new(collect_left_input( self.random_state.clone(), left_stream, @@ -993,35 +999,111 @@ impl ExecutionPlan for HashJoinExec { enable_dynamic_filter_pushdown, )) } else { - // Spillable enabled: coalesce left to a single stream - let left_plan: Arc = if self.left.output_partitioning().partition_count() == 1 { - Arc::clone(&self.left) - } else { - Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.left))) - }; - let left_stream = left_plan.execute(0, Arc::clone(&context))?; - let reservation = MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); - let left_fut = self.left_fut.try_once(|| { - Ok(collect_left_input( - self.random_state.clone(), - left_stream, - on_left.clone(), - join_metrics.clone(), - reservation, - need_produce_result_in_final(self.join_type), - self.right().output_partitioning().partition_count(), - enable_dynamic_filter_pushdown, - )) - })?; - // For Right-side oriented joins, fall back to standard HashJoinStream for correctness - if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark) { - // Fall back to standard HashJoinStream but ensure the probe side is a single coalesced stream - let right_plan: Arc = if self.right.output_partitioning().partition_count() == 1 { - Arc::clone(&self.right) - } else { - Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.right))) + // Spillable enabled: coalesce left to a single stream + let left_plan: Arc = + if self.left.output_partitioning().partition_count() == 1 { + Arc::clone(&self.left) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.left))) + }; + let build_schema = left_plan.schema(); + let left_stream = left_plan.execute(0, Arc::clone(&context))?; + let reservation = MemoryConsumer::new("HashJoinInput") + .register(context.memory_pool()); + let left_fut = self.left_fut.try_once(|| { + Ok(collect_left_input( + 
self.random_state.clone(), + left_stream, + on_left.clone(), + join_metrics.clone(), + reservation, + need_produce_result_in_final(self.join_type), + self.right().output_partitioning().partition_count(), + enable_dynamic_filter_pushdown, + )) + })?; + + let make_bounds_accumulator = |right_plan: &Arc| { + if enable_dynamic_filter_pushdown { + self.dynamic_filter.as_ref().map(|df| { + let filter = Arc::clone(&df.filter); + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + Arc::clone(df.bounds_accumulator.get_or_init(|| { + Arc::new(SharedBoundsAccumulator::new_from_partition_mode( + self.mode, + left_plan.as_ref(), + right_plan.as_ref(), + filter, + on_right, + )) + })) + }) + } else { + None + } }; + + // For Right-side oriented joins, fall back to standard HashJoinStream for correctness + if matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ) { + let right_plan: Arc = + if self.right.output_partitioning().partition_count() == 1 { + Arc::clone(&self.right) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone( + &self.right, + ))) + }; + let right_stream = right_plan.execute(0, Arc::clone(&context))?; + let shared_bounds_accumulator = make_bounds_accumulator(&right_plan); + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + return Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), + join_metrics, + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), + context.session_config().batch_size(), + vec![], + self.right.output_ordering().is_some(), + shared_bounds_accumulator, + ))); + } + + use crate::joins::hash_join::partitioned::PartitionedHashJoinStream; + let right_plan: Arc = + if self.right.output_partitioning().partition_count() == 1 { + Arc::clone(&self.right) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.right))) + }; let right_stream = right_plan.execute(0, Arc::clone(&context))?; + let shared_bounds_accumulator = make_bounds_accumulator(&right_plan); + let probe_schema = right_plan.schema(); let column_indices_after_projection = match &self.projection { Some(projection) => projection .iter() @@ -1034,85 +1116,58 @@ impl ExecutionPlan for HashJoinExec { .iter() .map(|(_, right_expr)| Arc::clone(right_expr)) .collect::>(); - // Classic HashJoinStream constructor - return Ok(Box::pin(HashJoinStream::new( + let batch_size = context.session_config().batch_size(); + let mut num_partitions = context.session_config().target_partitions(); + if num_partitions == 0 { + num_partitions = 1; + } + let np2 = num_partitions.next_power_of_two(); + let num_partitions = np2.max(1); + let memory_threshold = { + let bytes = context + .session_config() + .options() + .execution + .sort_spill_reservation_bytes; + if bytes == 0 { + 1024 * 1024 * 1024 + } else { + bytes + } + }; + let partitioned_reservation = + MemoryConsumer::new("PartitionedHashJoin") + .register(context.memory_pool()); + let probe_spill_metrics = + SpillMetrics::new(&self.metrics, partition); + let build_spill_metrics = + 
SpillMetrics::new(&self.metrics, partition); + let partitioned_stream = PartitionedHashJoinStream::new( partition, self.schema(), + on_left.clone(), on_right, self.filter.clone(), self.join_type, right_stream, + left_fut, self.random_state.clone(), join_metrics, + probe_spill_metrics, + build_spill_metrics, column_indices_after_projection, self.null_equality, - HashJoinStreamState::WaitBuildSide, - BuildSide::Initial( - BuildSideInitialState { left_fut } - ), - context.session_config().batch_size(), - vec![], + batch_size, + num_partitions, + memory_threshold, + partitioned_reservation, + context.runtime_env(), + build_schema, + probe_schema, self.right.output_ordering().is_some(), - None, - ))); - } - - // Enable spillable stream; coalesce right to a single stream to avoid downstream fanout - use crate::joins::hash_join::partitioned::PartitionedHashJoinStream; - let right_plan: Arc = if self.right.output_partitioning().partition_count() == 1 { - Arc::clone(&self.right) - } else { - Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.right))) - }; - let right_stream = right_plan.execute(0, Arc::clone(&context))?; - let column_indices_after_projection = match &self.projection { - Some(projection) => projection - .iter() - .map(|i| self.column_indices[*i].clone()) - .collect(), - None => self.column_indices.clone(), - }; - let on_right = self - .on - .iter() - .map(|(_, right_expr)| Arc::clone(right_expr)) - .collect::>(); - let batch_size = context.session_config().batch_size(); - // Single output partition execution - let num_partitions = 1; - let memory_threshold = { - let bytes = context - .session_config() - .options() - .execution - .sort_spill_reservation_bytes; - if bytes == 0 { 1024 * 1024 * 1024 } else { bytes } - }; - let partitioned_reservation = MemoryConsumer::new("PartitionedHashJoin") - .register(context.memory_pool()); - // Reuse this operator's metrics set for spill metrics visibility - let spill_metrics = SpillMetrics::new(&self.metrics, partition); - let partitioned_stream = PartitionedHashJoinStream::new( - partition, - self.schema(), - on_left, - on_right, - self.filter.clone(), - self.join_type, - right_stream, - left_fut, - self.random_state.clone(), - join_metrics, - spill_metrics, - column_indices_after_projection, - self.null_equality, - batch_size, - num_partitions, - memory_threshold, - partitioned_reservation, - context.runtime_env(), - )?; - return Ok(Box::pin(partitioned_stream)); + shared_bounds_accumulator, + )?; + return Ok(Box::pin(partitioned_stream)); } } }; @@ -1963,8 +2018,11 @@ mod tests { .with_batch_size(batch_size) .with_target_partitions(4) .with_sort_spill_reservation_bytes(1) - .with_spill_compression(datafusion_common::config::SpillCompression::Uncompressed); - let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + .with_spill_compression( + datafusion_common::config::SpillCompression::Uncompressed, + ); + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); let left = build_table( ("a1", &vec![1, 2, 3]), @@ -1976,12 +2034,10 @@ mod tests { ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; let (columns, batches, metrics) = join_collect_with_partition_mode( Arc::clone(&left), @@ -4732,11 +4788,15 @@ mod tests { .with_batch_size(1024) .with_target_partitions(1) .with_sort_spill_reservation_bytes(1) - .with_spill_compression(datafusion_common::config::SpillCompression::Uncompressed); + .with_spill_compression( + datafusion_common::config::SpillCompression::Uncompressed, + ); let runtime = RuntimeEnvBuilder::new().build_arc()?; - let task_ctx = Arc::new(TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime)); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); // Build left/right to ensure build side has more than 1 row to trigger spill partitioning let left = build_table( @@ -4749,12 +4809,10 @@ mod tests { ("b1", &vec![1, 1, 1, 2]), ("c2", &vec![0, 0, 0, 0]), ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; // Execute with PartitionedSpillable let join = HashJoinExec::try_new( @@ -4779,12 +4837,19 @@ mod tests { for m in metrics.iter() { let name = m.value().name(); let v = m.value().as_usize(); - if (name == "spilled_rows" || name == "spilled_bytes" || name == "spill_count") && v > 0 { + if (name == "spilled_rows" + || name == "spilled_bytes" + || name == "spill_count") + && v > 0 + { spilled_any = true; break; } } - assert!(spilled_any, "expected spilling to occur in PartitionedSpillable mode"); + assert!( + spilled_any, + "expected spilling to occur in PartitionedSpillable mode" + ); Ok(()) } } diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index bad116487585a..879be939572f8 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -43,38 +43,37 @@ //! - Generates join results and handles unmatched rows for outer joins //! 
- Tracks matched rows for proper outer join semantics +use std::mem; use std::sync::Arc; use std::task::{Context, Poll}; -use std::mem; use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::join_hash_map::JoinHashMapType; use crate::joins::utils::{ - build_batch_from_indices, equal_rows_arr, get_final_indices_from_bit_map, - need_produce_result_in_final, apply_join_filter_to_indices, adjust_indices_by_join_type, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - OnceFut, StatefulStreamResult, + adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, + equal_rows_arr, get_final_indices_from_bit_map, need_produce_result_in_final, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceFut, StatefulStreamResult, }; -use crate::metrics::{SpillMetrics}; +use crate::metrics::SpillMetrics; +use crate::spill::in_progress_spill_file::InProgressSpillFile; use crate::spill::spill_manager::SpillManager; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::{Array, ArrayRef, BooleanBufferBuilder, UInt32Array, UInt64Array}; -use arrow::compute::{take, concat_batches}; +use arrow::compute::{concat_batches, take}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{ hash_utils::create_hashes, internal_datafusion_err, internal_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, }; +use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; use futures::{ready, Stream, StreamExt}; - /// State of the partitioned hash join stream #[derive(Debug, Clone)] @@ -120,6 +119,13 @@ pub(super) enum BuildPartition { /// Memory reservation (released when spilled) reservation: MemoryReservation, }, + /// Partition resources released and not available + Released { + /// Placeholder reservation + reservation: MemoryReservation, + }, + /// Empty partition (no rows) + Empty, } /// Represents a partition of probe-side data @@ -186,8 +192,10 @@ pub(super) struct PartitionedHashJoinStream { pub probe_partitions: Vec, /// Current partition being processed pub current_partition: Option, - /// Manages the process of spilling and reading back intermediate data - pub spill_manager: SpillManager, + /// Spill manager for probe-side (right) batches + pub probe_spill_manager: SpillManager, + /// Spill manager for build-side (left) batches + pub build_spill_manager: SpillManager, /// Memory reservation for the entire operation pub memory_reservation: MemoryReservation, /// Runtime environment @@ -201,7 +209,12 @@ pub(super) struct PartitionedHashJoinStream { /// Running alignment start for right indices across probe batches (for semi/anti/mark) pub right_alignment_start: usize, /// Shared bounds accumulator for coordinating dynamic filter updates (optional) - pub bounds_accumulator: Option>, + pub bounds_accumulator: + Option>, + /// Future used to synchronize dynamic filter updates across partitions + pub bounds_waiter: Option>, + /// Cached probe-side schema + pub probe_schema: SchemaRef, /// Current probe batch (filtered to the active partition), if any pub current_probe_batch: Option, /// Current probe values for ON expressions @@ -210,6 +223,8 @@ pub(super) struct PartitionedHashJoinStream { pub current_probe_hashes: Vec, /// Current lookup offset within the 
join hash map pub current_offset: crate::joins::join_hash_map::JoinHashMapOffset, + /// Max joined probe-side index from current batch (for Right/Semi/Anti alignment) + pub joined_probe_idx: Option, /// Bitmaps to track matched build-side rows for outer joins (one per partition) pub matched_build_rows_per_partition: Vec, /// Current partition being processed for unmatched rows @@ -222,17 +237,103 @@ pub(super) struct PartitionedHashJoinStream { pub probes_buffered: bool, /// Current read position per partition within buffered probe batches pub probe_batch_positions: Vec, + /// Metrics: total probe rows buffered per partition (RAM) + pub probe_buffered_rows_per_part: Vec, + /// Metrics: total probe rows spilled per partition (disk) + pub probe_spilled_rows_per_part: Vec, + /// Metrics: total probe rows consumed during probing per partition + pub probe_consumed_rows_per_part: Vec, + /// Metrics: total matches after equality per partition + pub matched_rows_per_part: Vec, + /// Metrics: total rows emitted per partition + pub emitted_rows_per_part: Vec, + /// Metrics: total candidate pairs before equality per partition + pub candidate_pairs_per_part: Vec, + /// One-time flag to run shadow verification per partition + pub verify_once_per_part: Vec, + /// One-time flag for filter debug logging per partition + pub filter_debug_once_per_part: Vec, /// Pending async spill reload stream for build partitions pub pending_reload_stream: Option, /// Accumulated batches for pending reload pub pending_reload_batches: Vec, /// Target partition id for pending reload pub pending_reload_partition: Option, + /// In-progress probe spill writers, one per partition (used when corresponding build is spilled) + pub probe_spill_in_progress: Vec>, + /// Finalized probe spill files per partition (set after buffering probe side) + pub probe_spill_files: Vec>, + /// Pending probe stream for the current partition's probe spill file + pub pending_probe_stream: Option, + /// Target partition id for pending probe stream + pub pending_probe_partition: Option, } impl PartitionedHashJoinStream { + /// Compute partition id for a given hash using radix mask when possible + #[inline] + fn partition_for_hash(&self, hash: u64) -> usize { + if self.num_partitions.is_power_of_two() { + (hash as usize) & (self.num_partitions - 1) + } else { + // Fallback when num_partitions is not a power of two + (hash as usize) % self.num_partitions + } + } + + /// Report build-side bounds to the shared accumulator when dynamic filtering is enabled + fn poll_bounds_update( + &mut self, + cx: &mut Context<'_>, + build_data: &Arc, + ) -> Poll> { + if let Some(ref accumulator) = self.bounds_accumulator { + if self.bounds_waiter.is_none() { + println!( + "[spill-join] partition={} reporting build bounds (rows={})", + self.partition, + build_data.batch().num_rows() + ); + let accumulator = Arc::clone(accumulator); + let partition = self.partition; + let bounds = build_data.bounds.clone(); + self.bounds_waiter = Some(OnceFut::new(async move { + accumulator + .report_partition_bounds(partition, bounds) + .await + })); + } + + if let Some(waiter) = self.bounds_waiter.as_mut() { + match waiter.get(cx) { + Poll::Ready(Ok(_)) => { + println!( + "[spill-join] partition={} build bounds reported", + self.partition + ); + self.bounds_waiter = None; + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + println!( + "[spill-join] partition={} waiting on shared bounds barrier", + self.partition + ); + return Poll::Pending; + } + } + } + } + + 
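+        // Editor's note (descriptive comment, not part of this patch): reaching this
+        // point means dynamic filter pushdown is disabled or the shared bounds barrier
+        // has completed for this partition; an in-flight `report_partition_bounds` call
+        // is cached in `bounds_waiter`, so a Pending poll resumes the same future
+        // rather than reporting the same bounds again.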
Poll::Ready(Ok(())) + } + /// Ensure the build partition is loaded in-memory (reload if spilled) - fn ensure_build_partition_loaded(&mut self, cx: &mut Context<'_>, part_id: usize) -> Poll> { + fn ensure_build_partition_loaded( + &mut self, + cx: &mut Context<'_>, + part_id: usize, + ) -> Poll> { let needs_reload = matches!( self.build_partitions.get(part_id), Some(BuildPartition::Spilled { .. }) @@ -243,12 +344,22 @@ impl PartitionedHashJoinStream { // Kick off reload if needed if self.pending_reload_partition.is_none() { - if let Some(BuildPartition::Spilled { spill_file, .. }) = self.build_partitions.get_mut(part_id) { - let spill_file = spill_file.take().ok_or_else(|| internal_datafusion_err!("spill file already consumed for this partition"))?; - let stream = self.spill_manager.read_spill_as_stream(spill_file)?; + if let Some(BuildPartition::Spilled { spill_file, .. }) = + self.build_partitions.get_mut(part_id) + { + let spill_file = spill_file.take().ok_or_else(|| { + internal_datafusion_err!( + "spill file already consumed for this partition" + ) + })?; + let stream = self.build_spill_manager.read_spill_as_stream(spill_file)?; self.pending_reload_stream = Some(stream); self.pending_reload_batches.clear(); self.pending_reload_partition = Some(part_id); + println!( + "[spill-join][reload] start partition {}", + part_id + ); } } @@ -257,43 +368,71 @@ impl PartitionedHashJoinStream { if let Some(stream) = self.pending_reload_stream.as_mut() { match stream.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { + println!( + "[spill-join][reload] partition {} batch rows={}", + part_id, + batch.num_rows() + ); self.pending_reload_batches.push(batch); return Poll::Pending; - } + }, Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Ready(None) => { // Concatenate - let first_schema = self.pending_reload_batches.get(0) - .ok_or_else(|| internal_datafusion_err!("empty spilled partition"))? + let first_schema = self + .pending_reload_batches + .get(0) + .ok_or_else(|| { + internal_datafusion_err!("empty spilled partition") + })? .schema(); - let concatenated = concat_batches(&first_schema, self.pending_reload_batches.as_slice()) - .map_err(DataFusionError::from)?; + let concatenated = concat_batches( + &first_schema, + self.pending_reload_batches.as_slice(), + ) + .map_err(DataFusionError::from)?; + + println!( + "Reloaded spilled build partition {} for probing (rows={})", + part_id, + concatenated.num_rows() + ); - println!("Reloaded spilled build partition {} for probing (rows={})", part_id, concatenated.num_rows()); + // Grow global reservation conservatively by concatenated batch size + let concat_size = concatenated.get_array_memory_size(); + let _ = self.memory_reservation.try_grow(concat_size); // Recompute values and hashmap - let mut values: Vec = Vec::with_capacity(self.on_left.len()); + let mut values: Vec = + Vec::with_capacity(self.on_left.len()); for c in &self.on_left { - values.push(c.evaluate(&concatenated)?.into_array(concatenated.num_rows())?); + values.push( + c.evaluate(&concatenated)? 
+ .into_array(concatenated.num_rows())?, + ); } let mut hash_map: Box = Box::new( - crate::joins::join_hash_map::JoinHashMapU32::with_capacity(concatenated.num_rows()), + crate::joins::join_hash_map::JoinHashMapU32::with_capacity( + concatenated.num_rows(), + ), ); self.hashes_buffer.clear(); self.hashes_buffer.resize(concatenated.num_rows(), 0); - crate::joins::utils::update_hash( - &self.on_left, - &concatenated, - &mut *hash_map, - 0, + // Build HT for reloaded partition from precomputed key arrays (no re-eval) + create_hashes( + &values, &self.random_state, &mut self.hashes_buffer, - 0, - true, )?; + hash_map.extend_zero(concatenated.num_rows()); + let iter = + self.hashes_buffer.iter().enumerate().map(|(i, h)| (i, h)); + hash_map.update_from_iter(Box::new(iter), 0); - let new_reservation = MemoryConsumer::new("partition_reload").with_can_spill(true).register(&self.runtime_env.memory_pool); + let new_reservation = MemoryConsumer::new("partition_reload") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); self.build_partitions[part_id] = BuildPartition::InMemory { hash_map, @@ -302,21 +441,33 @@ impl PartitionedHashJoinStream { reservation: new_reservation, }; - if let Some(BuildPartition::InMemory { hash_map, batch, .. }) = self.build_partitions.get(part_id) { - println!( - "Reloaded partition {} hashmap empty? {} rows={}", - part_id, - hash_map.is_empty(), - batch.num_rows() - ); - } + if let Some(BuildPartition::InMemory { + hash_map, batch, .. + }) = self.build_partitions.get(part_id) + { + println!( + "Reloaded partition {} hashmap empty? {} rows={}", + part_id, + hash_map.is_empty(), + batch.num_rows() + ); + } self.pending_reload_stream = None; self.pending_reload_batches.clear(); self.pending_reload_partition = None; + // Shrink global reservation now that partition is resident with per-partition reservation + let _ = self.memory_reservation.try_shrink(concat_size); return Poll::Ready(Ok(())); } - Poll::Pending => return Poll::Pending, + Poll::Pending => { + println!( + "[spill-join][reload] partition {} pending batches={}", + part_id, + self.pending_reload_batches.len() + ); + return Poll::Pending; + } } } } @@ -335,7 +486,8 @@ impl PartitionedHashJoinStream { left_fut: OnceFut, random_state: RandomState, join_metrics: BuildProbeJoinMetrics, - spill_metrics: SpillMetrics, + probe_spill_metrics: SpillMetrics, + build_spill_metrics: SpillMetrics, column_indices: Vec, null_equality: NullEquality, batch_size: usize, @@ -343,16 +495,23 @@ impl PartitionedHashJoinStream { memory_threshold: usize, memory_reservation: MemoryReservation, runtime_env: Arc, + build_schema: SchemaRef, + probe_schema: SchemaRef, + right_side_ordered: bool, + bounds_accumulator: Option< + Arc, + >, ) -> Result { - let spill_manager = SpillManager::new( + let probe_spill_manager = SpillManager::new( runtime_env.clone(), - spill_metrics, - schema.clone(), + probe_spill_metrics, + Arc::clone(&probe_schema), ); - println!( - "PartitionedHashJoinStream created: partition={}, num_partitions={}, memory_threshold={} bytes", - partition, num_partitions, memory_threshold + let build_spill_manager = SpillManager::new( + runtime_env.clone(), + build_spill_metrics, + Arc::clone(&build_schema), ); Ok(Self { @@ -375,18 +534,22 @@ impl PartitionedHashJoinStream { build_partitions: Vec::new(), probe_partitions: Vec::new(), current_partition: None, - spill_manager, + probe_spill_manager, + build_spill_manager, memory_reservation, runtime_env, hashes_buffer: Vec::new(), - right_side_ordered: false, + 
right_side_ordered, placeholder_emitted: false, right_alignment_start: 0, - bounds_accumulator: None, + bounds_accumulator, + bounds_waiter: None, + probe_schema, current_probe_batch: None, current_probe_values: vec![], current_probe_hashes: vec![], current_offset: (0, None), + joined_probe_idx: None, matched_build_rows_per_partition: Vec::new(), unmatched_partition: 0, unmatched_left_indices_cache: None, @@ -397,15 +560,24 @@ impl PartitionedHashJoinStream { pending_reload_stream: None, pending_reload_batches: Vec::new(), pending_reload_partition: None, + probe_spill_in_progress: (0..num_partitions).map(|_| None).collect(), + probe_spill_files: (0..num_partitions).map(|_| None).collect(), + pending_probe_stream: None, + pending_probe_partition: None, + probe_buffered_rows_per_part: vec![0; num_partitions], + probe_spilled_rows_per_part: vec![0; num_partitions], + probe_consumed_rows_per_part: vec![0; num_partitions], + matched_rows_per_part: vec![0; num_partitions], + emitted_rows_per_part: vec![0; num_partitions], + candidate_pairs_per_part: vec![0; num_partitions], + verify_once_per_part: vec![false; num_partitions], + filter_debug_once_per_part: vec![false; num_partitions], }) } /// Buffer the entire probe side stream into per-partition batches. /// Returns Pending until the right stream is fully consumed. - fn buffer_probe_side( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { + fn buffer_probe_side(&mut self, cx: &mut Context<'_>) -> Poll> { if self.probe_partitions.is_empty() { self.probe_partitions = (0..self.num_partitions) .map(|_| ProbePartition { @@ -418,47 +590,121 @@ impl PartitionedHashJoinStream { loop { match self.right.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { - // Compute ON values for the full batch - let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + // Compute ON values for the full batch (once) + println!( + "[spill-join] probe batch rows={} schema={:?}", + batch.num_rows(), + batch.schema().fields().len() + ); + let mut keys_values: Vec = + Vec::with_capacity(self.on_right.len()); for c in &self.on_right { let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; keys_values.push(v); } + // Compute hashes (once) let mut hashes = vec![0u64; batch.num_rows()]; create_hashes(&keys_values, &self.random_state, &mut hashes)?; - // For each partition, select rows and push filtered batch + // Build per-partition row indices in one pass + let mut indices_per_part: Vec> = + vec![Vec::new(); self.num_partitions]; + for (row_idx, &hash) in hashes.iter().enumerate() { + let pid = self.partition_for_hash(hash) as usize; + indices_per_part[pid].push(row_idx as u32); + } + + // For each non-empty partition, slice both data columns and already computed key values for part_id in 0..self.num_partitions { - let indices: Vec = hashes - .iter() - .enumerate() - .filter_map(|(i, &h)| ((h as usize) % self.num_partitions == part_id).then_some(i as u32)) - .collect(); - if indices.is_empty() { + let part_indices = &indices_per_part[part_id]; + if part_indices.is_empty() { continue; } - let indices_arr: UInt32Array = indices.clone().into(); - let mut filtered_columns: Vec = Vec::with_capacity(batch.num_columns()); + let indices_arr: UInt32Array = part_indices.clone().into(); + if self.probe_partitions[part_id].batches.is_empty() { + println!( + "[spill-join] probe partition {} first rows {:?}", + part_id, + &part_indices[..part_indices.len().min(10)] + ); + } + + // Take data columns + let mut filtered_columns: Vec = + 
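+                            // Editor's note (descriptive comment, not part of this patch):
+                            // each probe batch is bucketed once by `partition_for_hash`, and
+                            // every per-partition batch is then materialized with
+                            // `arrow::compute::take` over those row indices; the ON-key
+                            // arrays computed above are sliced with the same `take` call so
+                            // the join expressions are evaluated only once per input batch.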
Vec::with_capacity(batch.num_columns()); for col in batch.columns() { - filtered_columns.push(take(col, &indices_arr, None).map_err(DataFusionError::from)?); + filtered_columns.push( + take(col, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); } - let filtered_batch = RecordBatch::try_new(batch.schema(), filtered_columns) - .map_err(DataFusionError::from)?; - - // Filtered ON values for this partition's batch - let mut filtered_on_values: Vec = Vec::with_capacity(self.on_right.len()); - for c in &self.on_right { - let v = c.evaluate(&filtered_batch)?.into_array(filtered_batch.num_rows())?; - filtered_on_values.push(v); + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + + // Take ON key values using precomputed arrays (no re-eval) + let mut filtered_on_values: Vec = + Vec::with_capacity(self.on_right.len()); + for arr in &keys_values { + filtered_on_values.push( + take(arr, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); } - let filtered_hashes: Vec = indices - .iter() - .map(|&i| hashes[i as usize]) - .collect(); - self.probe_partitions[part_id].batches.push(filtered_batch); - self.probe_partitions[part_id].values.push(filtered_on_values); - self.probe_partitions[part_id].hashes.push(filtered_hashes); + // Slice hashes + let mut filtered_hashes: Vec = + Vec::with_capacity(part_indices.len()); + for &i in part_indices.iter() { + filtered_hashes.push(hashes[i as usize]); + } + + // If corresponding build partition is spilled, stream this partition's probe to disk + match self.build_partitions.get_mut(part_id) { + Some(BuildPartition::Spilled { .. }) => { + // Lazily create in-progress file + if self.probe_spill_in_progress[part_id].is_none() { + let ipf = self + .probe_spill_manager + .create_in_progress_file( + "hash_join_probe_partition", + )?; + self.probe_spill_in_progress[part_id] = Some(ipf); + } + if let Some(ref mut ipf) = + self.probe_spill_in_progress[part_id] + { + ipf.append_batch(&filtered_batch)?; + println!( + "[spill-join][probe-spill] write partition={} rows={}", + part_id, + filtered_batch.num_rows() + ); + } + self.probe_spilled_rows_per_part[part_id] += + filtered_batch.num_rows(); + // Do not RAM-buffer spilled probe partitions + } + _ => { + // Keep in memory for in-memory build partitions + self.probe_partitions[part_id] + .batches + .push(filtered_batch); + self.probe_partitions[part_id] + .values + .push(filtered_on_values); + self.probe_partitions[part_id] + .hashes + .push(filtered_hashes); + // Track buffered rows + let last = self.probe_partitions[part_id] + .batches + .last() + .unwrap(); + self.probe_buffered_rows_per_part[part_id] += + last.num_rows(); + } + } } } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), @@ -467,12 +713,38 @@ impl PartitionedHashJoinStream { self.probes_buffered = true; self.probe_batch_positions = vec![0; self.num_partitions]; println!( - "Buffered probe side: per-partition batch counts = {:?}", - self.probe_partitions.iter().map(|p| p.batches.len()).collect::>() + "[spill-join] probe buffered rows per partition = {:?}", + self.probe_partitions + .iter() + .enumerate() + .map(|(i, p)| (i, p.batches.iter().map(|b| b.num_rows()).sum::())) + .collect::>() ); + // Finalize any in-progress probe spill files + for part_id in 0..self.num_partitions { + if let Some(mut ipf) = + self.probe_spill_in_progress[part_id].take() + { + if let Some(file) = ipf.finish()? 
{ + println!( + "[spill-join][probe-spill] finalize partition={} rows_spilled={}", + part_id, + self.probe_spilled_rows_per_part[part_id] + ); + self.probe_spill_files[part_id] = Some(file); + } + } + } return Poll::Ready(Ok(())); } - Poll::Pending => return Poll::Pending, + Poll::Pending => { + println!( + "[spill-join][probe-buffer] pending batches buffered={:?} spilled_rows={:?}", + self.probe_buffered_rows_per_part, + self.probe_spilled_rows_per_part + ); + return Poll::Pending; + }, } } } @@ -482,54 +754,74 @@ impl PartitionedHashJoinStream { &mut self, build_data: Arc, ) -> Result>> { - println!("Partitioning build side data into {} partitions", self.num_partitions); + println!( + "Partitioning build side data into {} partitions", + self.num_partitions + ); + // Metrics: record build input + self.join_metrics.build_input_batches.add(1); + self.join_metrics + .build_input_rows + .add(build_data.batch().num_rows()); // Initialize partitions self.build_partitions = Vec::with_capacity(self.num_partitions); // Initialize per-partition matched rows bitmaps self.matched_build_rows_per_partition = Vec::with_capacity(self.num_partitions); - + // Extract build-side data let batch = build_data.batch(); let values = build_data.values(); - + // Compute hash values for all rows in the build-side batch let mut hashes = vec![0u64; batch.num_rows()]; create_hashes(values, &self.random_state, &mut hashes)?; - + // Partition the data based on hash values - let mut partition_batches: Vec> = vec![Vec::new(); self.num_partitions]; - + let mut partition_batches: Vec> = + vec![Vec::new(); self.num_partitions]; + for (row_idx, &hash) in hashes.iter().enumerate() { - let partition_id = (hash as usize) % self.num_partitions; + let partition_id = self.partition_for_hash(hash); + if row_idx < 10 { + println!( + "[spill-join] build row {} hash={} -> partition {}", + row_idx, hash, partition_id + ); + } partition_batches[partition_id].push(row_idx); } - + // Create partitions; spill when memory_threshold is exceeded for partition_id in 0..self.num_partitions { let row_indices = &partition_batches[partition_id]; if row_indices.is_empty() { // Empty partition - create empty hash map - let empty_hash_map: Box = - Box::new(crate::joins::join_hash_map::JoinHashMapU32::with_capacity(0)); + let empty_hash_map: Box = Box::new( + crate::joins::join_hash_map::JoinHashMapU32::with_capacity(0), + ); let empty_batch = batch.slice(0, 0); - let empty_values: Vec = values.iter().map(|arr| arr.slice(0, 0)).collect(); - + let empty_values: Vec = + values.iter().map(|arr| arr.slice(0, 0)).collect(); + // Initialize empty matched rows bitmap for this partition let matched_bitmap = BooleanBufferBuilder::new(0); self.matched_build_rows_per_partition.push(matched_bitmap); - + self.build_partitions.push(BuildPartition::InMemory { hash_map: empty_hash_map, batch: empty_batch, values: empty_values, - reservation: MemoryConsumer::new("empty_partition").with_can_spill(true).register(&self.runtime_env.memory_pool), + reservation: MemoryConsumer::new("empty_partition") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool), }); continue; } - + // Create batch slice for this partition let partition_batch = self.take_rows(batch, row_indices)?; - let partition_values: Vec = values.iter() + let partition_values: Vec = values + .iter() .map(|arr| self.take_rows_from_array(arr, row_indices)) .collect::>>()?; @@ -570,8 +862,11 @@ impl PartitionedHashJoinStream { ); // Spill this partition to disk and do not keep it in memory let 
spill_file = self - .spill_manager - .spill_record_batch_and_finish(&[partition_batch.clone()], "hash_join_build_partition")? + .build_spill_manager + .spill_record_batch_and_finish( + &[partition_batch.clone()], + "hash_join_build_partition", + )? .ok_or_else(|| internal_datafusion_err!("expected spill file"))?; // Initialize matched rows bitmap for this partition @@ -580,7 +875,9 @@ impl PartitionedHashJoinStream { self.matched_build_rows_per_partition.push(matched_bitmap); // Per-partition reservation kept as zero-sized placeholder - let reservation = MemoryConsumer::new("partition_spilled").with_can_spill(true).register(&self.runtime_env.memory_pool); + let reservation = MemoryConsumer::new("partition_spilled") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); self.build_partitions.push(BuildPartition::Spilled { spill_file: Some(spill_file), @@ -588,45 +885,55 @@ impl PartitionedHashJoinStream { }); continue; } - + // Create hash map for this partition - let partition_hash_map: Box = - Box::new(crate::joins::join_hash_map::JoinHashMapU32::with_capacity(row_indices.len())); - - // Build the hash map for this partition using existing utilities + let partition_hash_map: Box = + Box::new(crate::joins::join_hash_map::JoinHashMapU32::with_capacity( + row_indices.len(), + )); + + // Build the hash map for this partition from pre-sliced key arrays let mut partition_hash_map = partition_hash_map; self.hashes_buffer.clear(); self.hashes_buffer.resize(partition_batch.num_rows(), 0); - crate::joins::utils::update_hash( - &self.on_left, - &partition_batch, - &mut *partition_hash_map, - 0, + create_hashes( + &partition_values, &self.random_state, &mut self.hashes_buffer, - 0, - true, )?; + partition_hash_map.extend_zero(partition_batch.num_rows()); + let iter = self.hashes_buffer.iter().enumerate().map(|(i, h)| (i, h)); + partition_hash_map.update_from_iter(Box::new(iter), 0); println!( "Built in-memory hash map for partition {} (rows={})", partition_id, row_indices.len() ); - + // Metrics: approximate build memory used (batch + values) + let approx = partition_batch.get_array_memory_size() + + partition_values + .iter() + .map(|a| a.get_array_memory_size()) + .sum::(); + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size().saturating_add(approx)); + // Initialize matched rows bitmap for this partition let mut matched_bitmap = BooleanBufferBuilder::new(row_indices.len()); matched_bitmap.append_n(row_indices.len(), false); self.matched_build_rows_per_partition.push(matched_bitmap); - + // Per-partition reservation: zero-sized placeholder; global reservation tracks memory - let reservation = MemoryConsumer::new("partition_memory").with_can_spill(true).register(&self.runtime_env.memory_pool); + let reservation = MemoryConsumer::new("partition_memory") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); let is_empty_after = partition_hash_map.is_empty(); println!( "Partition {} hashmap empty after build? {}", - partition_id, - is_empty_after + partition_id, is_empty_after ); self.build_partitions.push(BuildPartition::InMemory { @@ -636,9 +943,11 @@ impl PartitionedHashJoinStream { reservation, }); } - - // Start processing at the stream's assigned output partition - let start_partition = self.partition.min(self.num_partitions.saturating_sub(1)); + + // Start processing from the first radix partition and iterate sequentially + // This ensures a single stream can process all partitions when the + // operator reports a single output partition. 
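+        // Editor's sketch (illustrative only, not part of this patch): the state machine
+        // seeded below visits every radix partition in order; conceptually it behaves
+        // like the following loop, with the probe and release work folded into each
+        // iteration:
+        //
+        //     for partition_id in 0..num_partitions {
+        //         let is_last_partition = partition_id + 1 == num_partitions;
+        //         // probe the buffered (or spilled) right-side rows against this build
+        //         // partition, emit matches, then release the partition's memory
+        //     }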
+ let start_partition = 0; println!( "Partitioning complete. Created {} partitions. Starting to process partition {}", self.build_partitions.len(), start_partition @@ -649,35 +958,39 @@ impl PartitionedHashJoinStream { total_partitions: self.num_partitions, is_last_partition: start_partition + 1 == self.num_partitions, }); - + Ok(StatefulStreamResult::Continue) } - + /// Take specific rows from a RecordBatch fn take_rows(&self, batch: &RecordBatch, indices: &[usize]) -> Result { - use arrow::compute::take; use arrow::array::UInt32Array; - - let indices_array = UInt32Array::from( - indices.iter().map(|&i| i as u32).collect::>() - ); - - let columns: Result, DataFusionError> = batch.columns().iter() + use arrow::compute::take; + + let indices_array = + UInt32Array::from(indices.iter().map(|&i| i as u32).collect::>()); + + let columns: Result, DataFusionError> = batch + .columns() + .iter() .map(|col| take(col, &indices_array, None).map_err(|e| e.into())) .collect(); - + Ok(RecordBatch::try_new(batch.schema(), columns?)?) } - + /// Take specific rows from an ArrayRef - fn take_rows_from_array(&self, array: &ArrayRef, indices: &[usize]) -> Result { - use arrow::compute::take; + fn take_rows_from_array( + &self, + array: &ArrayRef, + indices: &[usize], + ) -> Result { use arrow::array::UInt32Array; - - let indices_array = UInt32Array::from( - indices.iter().map(|&i| i as u32).collect::>() - ); - + use arrow::compute::take; + + let indices_array = + UInt32Array::from(indices.iter().map(|&i| i as u32).collect::>()); + Ok(take(array, &indices_array, None).map_err(DataFusionError::from)?) } @@ -693,22 +1006,30 @@ impl PartitionedHashJoinStream { } // Take ownership of the old partition to drop heavy resources - let placeholder_reservation = MemoryConsumer::new("partition_released_placeholder") - .with_can_spill(true) - .register(&self.runtime_env.memory_pool); + let placeholder_reservation = + MemoryConsumer::new("partition_released_placeholder") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); let old_partition = mem::replace( &mut self.build_partitions[partition_id], - BuildPartition::Spilled { - spill_file: None, + BuildPartition::Released { reservation: placeholder_reservation, }, ); match old_partition { - BuildPartition::InMemory { batch, values, reservation, .. } => { + BuildPartition::InMemory { + batch, + values, + reservation, + .. + } => { // Estimate memory held by this partition and shrink global reservation let mut estimated_size = batch.get_array_memory_size(); - estimated_size += values.iter().map(|a| a.get_array_memory_size()).sum::(); + estimated_size += values + .iter() + .map(|a| a.get_array_memory_size()) + .sum::(); let _ = self.memory_reservation.try_shrink(estimated_size); // Replace with an empty in-memory partition to keep indexing stable @@ -731,11 +1052,16 @@ impl PartitionedHashJoinStream { }; } BuildPartition::Spilled { reservation, .. 
} => { - // Keep as empty spilled (no further action needed) - self.build_partitions[partition_id] = BuildPartition::Spilled { - spill_file: None, - reservation, - }; + // Transition to Released; no files remain + self.build_partitions[partition_id] = + BuildPartition::Released { reservation }; + } + BuildPartition::Released { reservation } => { + self.build_partitions[partition_id] = + BuildPartition::Released { reservation }; + } + BuildPartition::Empty => { + // no-op } } } @@ -757,7 +1083,7 @@ impl PartitionedHashJoinStream { partition_state.total_partitions, self.build_partitions.len() ); - + // Do not buffer probe side here; selection happens below depending on num_partitions // (Spill reload handled by ensure_build_partition_loaded earlier if needed) @@ -771,38 +1097,181 @@ impl PartitionedHashJoinStream { Poll::Pending => return Poll::Pending, } - // If only 1 partition, stream the probe side directly (simpler and correct across executor partitions) - if self.num_partitions == 1 || self.partition >= self.num_partitions { - if self.current_probe_batch.is_none() { - match ready!(self.right.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - // Compute hashes for the full batch - let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); - for c in &self.on_right { - let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; - keys_values.push(v); - } - let mut hashes = vec![0u64; batch.num_rows()]; - create_hashes(&keys_values, &self.random_state, &mut hashes)?; + // Ensure probe side is fully buffered into per-partition containers + if !self.probes_buffered { + match self.buffer_probe_side(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } - // No filtering needed when only one partition - self.current_probe_hashes = hashes; - self.current_probe_values = keys_values; - self.current_probe_batch = Some(batch); - self.current_offset = (0, None); + // Select next probe batch for current partition + if self.current_probe_batch.is_none() { + // Decide probe source based on whether we spilled probe for this partition + let has_spilled_probe = self + .probe_spill_in_progress + .get(partition_state.partition_id) + .and_then(|o| o.as_ref()) + .is_some() + || self + .probe_spill_files + .get(partition_state.partition_id) + .and_then(|o| o.as_ref()) + .is_some() + || self + .pending_probe_partition + .is_some_and(|p| p == partition_state.partition_id); + let has_buffered_probe = self + .probe_partitions + .get(partition_state.partition_id) + .map(|p| !p.batches.is_empty()) + .unwrap_or(false); - if let Some(pb) = self.current_probe_batch.as_ref() { - println!( - "[spill-join] Direct probe batch rows={} (partitions=1)", - pb.num_rows() - ); - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(pb.num_rows()); - } + // Prefer buffered probe batches first; when exhausted, consume spilled probe stream + let pos = self.probe_batch_positions[partition_state.partition_id]; + let buffered_len = self + .probe_partitions + .get(partition_state.partition_id) + .map(|p| p.batches.len()) + .unwrap_or(0); + if has_buffered_probe && pos < buffered_len { + let part = &self.probe_partitions[partition_state.partition_id]; + // Take buffered batch/values/hashes + let batch = part.batches[pos].clone(); + let values = part.values[pos].clone(); + let hashes = part.hashes[pos].clone(); + self.probe_batch_positions[partition_state.partition_id] = pos + 1; + + self.current_probe_batch = Some(batch); + 
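+                    // Editor's note (descriptive comment, not part of this patch):
+                    // buffered probe batches already carry the ON-key arrays and hashes
+                    // captured in buffer_probe_side, so nothing is re-evaluated here;
+                    // only probe rows replayed from a spill file (the branch below)
+                    // recompute their key values and hashes before probing.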
self.current_probe_values = values; + self.current_probe_hashes = hashes; + self.current_offset = (0, None); + if let Some(b) = &self.current_probe_batch { + self.probe_consumed_rows_per_part[partition_state.partition_id] = + self.probe_consumed_rows_per_part[partition_state.partition_id] + .saturating_add(b.num_rows()); + } + } else if has_spilled_probe { + // Stream from probe spill file for this partition + if self.pending_probe_partition.is_none() { + let file = self + .probe_spill_files + .get_mut(partition_state.partition_id) + .and_then(|o| o.take()); + if let Some(file) = file { + let stream = + self.probe_spill_manager.read_spill_as_stream(file)?; + self.pending_probe_stream = Some(stream); + self.pending_probe_partition = Some(partition_state.partition_id); + } else { + // Spilled probe indicated but file not yet finalized: wait + println!( + "[spill-join] Waiting for spilled probe file for partition {}", + partition_state.partition_id + ); + return Poll::Pending; } - Some(Err(e)) => return Poll::Ready(Err(e)), - None => { - // No more probe data for this partition, release and advance + } + if self.pending_probe_partition == Some(partition_state.partition_id) { + if let Some(stream) = self.pending_probe_stream.as_mut() { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + // Compute ON values and hashes for this filtered batch + let mut keys_values: Vec = + Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c + .evaluate(&batch)? + .into_array(batch.num_rows())?; + keys_values.push(v); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes( + &keys_values, + &self.random_state, + &mut hashes, + )?; + + self.current_probe_batch = Some(batch); + self.current_probe_values = keys_values; + self.current_probe_hashes = hashes; + self.current_offset = (0, None); + if let Some(b) = &self.current_probe_batch { + self.probe_consumed_rows_per_part + [partition_state.partition_id] = self + .probe_consumed_rows_per_part + [partition_state.partition_id] + .saturating_add(b.num_rows()); + } + println!( + "[spill-join][probe-spill] partition={} batch rows={}", + partition_state.partition_id, + self.current_probe_batch + .as_ref() + .map(|b| b.num_rows()) + .unwrap_or(0) + ); + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + // Finished probe for this partition; advance + self.pending_probe_stream = None; + self.pending_probe_partition = None; + println!( + "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + partition_state.partition_id, + self.probe_buffered_rows_per_part[partition_state.partition_id], + self.probe_spilled_rows_per_part[partition_state.partition_id], + self.probe_consumed_rows_per_part[partition_state.partition_id], + self.candidate_pairs_per_part[partition_state.partition_id], + self.matched_rows_per_part[partition_state.partition_id], + self.emitted_rows_per_part[partition_state.partition_id] + ); + println!( + "[spill-join][probe-spill] partition={} stream complete", + partition_state.partition_id + ); + self.release_partition_resources( + partition_state.partition_id, + ); + if partition_state.is_last_partition { + self.state = + PartitionedHashJoinState::HandleUnmatchedRows; + } else { + self.state = + PartitionedHashJoinState::ProcessPartition( + ProcessPartitionState { + partition_id: partition_state + .partition_id + + 1, + total_partitions: partition_state + .total_partitions, + is_last_partition: partition_state 
+ .partition_id + + 1 + == partition_state.total_partitions, + }, + ); + } + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Poll::Pending => return Poll::Pending, + } + } else { + // No stream available; nothing to read, advance + self.pending_probe_stream = None; + self.pending_probe_partition = None; + println!( + "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + partition_state.partition_id, + self.probe_buffered_rows_per_part[partition_state.partition_id], + self.probe_spilled_rows_per_part[partition_state.partition_id], + self.probe_consumed_rows_per_part[partition_state.partition_id], + self.candidate_pairs_per_part[partition_state.partition_id], + self.matched_rows_per_part[partition_state.partition_id], + self.emitted_rows_per_part[partition_state.partition_id] + ); self.release_partition_resources(partition_state.partition_id); if partition_state.is_last_partition { self.state = PartitionedHashJoinState::HandleUnmatchedRows; @@ -819,37 +1288,60 @@ impl PartitionedHashJoinStream { return Poll::Ready(Ok(StatefulStreamResult::Continue)); } } + } else { + // Neither spilled nor buffered probe for this partition: advance + println!( + "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + partition_state.partition_id, + self.probe_buffered_rows_per_part[partition_state.partition_id], + self.probe_spilled_rows_per_part[partition_state.partition_id], + self.probe_consumed_rows_per_part[partition_state.partition_id], + self.candidate_pairs_per_part[partition_state.partition_id], + self.matched_rows_per_part[partition_state.partition_id], + self.emitted_rows_per_part[partition_state.partition_id] + ); + self.release_partition_resources(partition_state.partition_id); + if partition_state.is_last_partition { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + } else { + self.state = PartitionedHashJoinState::ProcessPartition( + ProcessPartitionState { + partition_id: partition_state.partition_id + 1, + total_partitions: partition_state.total_partitions, + is_last_partition: partition_state.partition_id + 1 + == partition_state.total_partitions, + }, + ); + } + return Poll::Ready(Ok(StatefulStreamResult::Continue)); } - } else { - // Multi-partition execution: assume the scheduler already partitioned the right stream. 
- if self.current_probe_batch.is_none() { - match self.right.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - // Compute ON values and hashes for the full batch; no extra filtering here - let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); - for c in &self.on_right { - let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; - keys_values.push(v); - } - let mut hashes = vec![0u64; batch.num_rows()]; - create_hashes(&keys_values, &self.random_state, &mut hashes)?; + } - self.current_probe_hashes = hashes; - self.current_probe_values = keys_values; - self.current_probe_batch = Some(batch); - self.current_offset = (0, None); - } - Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(None) => { - // End of right stream for this partition: transition to unmatched rows for THIS partition only - self.release_partition_resources(partition_state.partition_id); - self.unmatched_partition = partition_state.partition_id; - self.state = PartitionedHashJoinState::HandleUnmatchedRows; - return Poll::Ready(Ok(StatefulStreamResult::Continue)); - } - Poll::Pending => return Poll::Pending, - } + // If no probe batch selected, advance to next partition (no probe rows here) + if self.current_probe_batch.is_none() { + println!( + "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + partition_state.partition_id, + self.probe_buffered_rows_per_part[partition_state.partition_id], + self.probe_spilled_rows_per_part[partition_state.partition_id], + self.probe_consumed_rows_per_part[partition_state.partition_id], + self.candidate_pairs_per_part[partition_state.partition_id], + self.matched_rows_per_part[partition_state.partition_id], + self.emitted_rows_per_part[partition_state.partition_id] + ); + self.release_partition_resources(partition_state.partition_id); + if partition_state.is_last_partition { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + } else { + self.state = + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + partition_id: partition_state.partition_id + 1, + total_partitions: partition_state.total_partitions, + is_last_partition: partition_state.partition_id + 1 + == partition_state.total_partitions, + }); } + return Poll::Ready(Ok(StatefulStreamResult::Continue)); } // At this point we have a current probe batch for this partition @@ -859,18 +1351,73 @@ impl PartitionedHashJoinStream { .as_ref() .ok_or_else(|| internal_datafusion_err!("expected probe batch"))?; - let (build_hashmap, build_batch, build_values) = match self - .build_partitions - .get(partition_state.partition_id) - { - Some(BuildPartition::InMemory { - hash_map, - batch, - values, - .. - }) => (&**hash_map, batch, values as &Vec), - _ => return Poll::Ready(internal_err!("Missing or invalid build partition")), - }; + let (build_hashmap, build_batch, build_values) = + match self.build_partitions.get(partition_state.partition_id) { + Some(BuildPartition::InMemory { + hash_map, + batch, + values, + .. + }) => (&**hash_map, batch, values as &Vec), + Some(BuildPartition::Spilled { .. }) => { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Some(BuildPartition::Released { .. 
}) + | Some(BuildPartition::Empty) + | None => { + return Poll::Ready(internal_err!( + "Missing or invalid build partition" + )); + } + }; + // Debug: log ON expressions and output mapping once we have both sides + let on_left_desc = self + .on_left + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", "); + let on_right_desc = self + .on_right + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", "); + let mapping_desc = self + .column_indices + .iter() + .map(|ci| { + let side = match ci.side { + JoinSide::Left => "L", + JoinSide::Right => "R", + JoinSide::None => "M", + }; + format!("{}@{}", side, ci.index) + }) + .collect::>() + .join(", "); + println!( + "[spill-join] ON build=[{}] | probe=[{}] | out=[{}]", + on_left_desc, on_right_desc, mapping_desc + ); + + // Log resolved output column names for the current mapping + let out_names = self + .column_indices + .iter() + .map(|ci| match ci.side { + JoinSide::Left => { + format!("L:{}", build_batch.schema().field(ci.index).name()) + } + JoinSide::Right => { + format!("R:{}", probe_batch.schema().field(ci.index).name()) + } + JoinSide::None => "M:mark".to_string(), + }) + .collect::>() + .join(", "); + println!("[spill-join] OUT columns: {}", out_names); + println!( "[spill-join] Partition {} build hashmap empty? {}", partition_state.partition_id, @@ -888,6 +1435,10 @@ impl PartitionedHashJoinStream { let build_indices: UInt64Array = build_indices.into(); let probe_indices: UInt32Array = probe_indices.into(); + // Track candidate pairs before equality + self.candidate_pairs_per_part[partition_state.partition_id] = self + .candidate_pairs_per_part[partition_state.partition_id] + .saturating_add(build_indices.len()); println!( "[spill-join] Candidates before equality: build_ids={}, probe_ids={}, build_rows={}, probe_rows={}", build_indices.len(), @@ -905,40 +1456,102 @@ impl PartitionedHashJoinStream { self.null_equality, )?; + // Shadow verify on INNER join with single Int64 key (first 50k rows) + if matches!(self.join_type, JoinType::Inner) + && build_values.len() == 1 + && self.current_probe_values.len() == 1 + && build_values[0].data_type() == &arrow::datatypes::DataType::Int64 + && self.current_probe_values[0].data_type() + == &arrow::datatypes::DataType::Int64 + && !self.verify_once_per_part[partition_state.partition_id] + { + use arrow::array::Int64Array; + use std::collections::HashMap; + let bcol = build_values[0] + .as_any() + .downcast_ref::() + .unwrap(); + let pcol = self.current_probe_values[0] + .as_any() + .downcast_ref::() + .unwrap(); + let mut map: HashMap = HashMap::new(); + let max_b = bcol.len().min(50_000); + for i in 0..max_b { + if bcol.is_null(i) { + continue; + } + let k = bcol.value(i); + *map.entry(k).or_insert(0) += 1; + } + let mut expect = 0usize; + let max_p = pcol.len().min(50_000); + for i in 0..max_p { + if pcol.is_null(i) { + continue; + } + let k = pcol.value(i); + if let Some(&c) = map.get(&k) { + expect += c; + } + } + println!( + "[spill-join][verify] part={} expect_pairs~{} vs actual_after_eq={}", + partition_state.partition_id, + expect, + build_indices.len() + ); + self.verify_once_per_part[partition_state.partition_id] = true; + } + // Debug: log key data types and sample matched pairs if !build_indices.is_empty() { - let build_key0 = build_values - .get(0) - .map(|a| a.data_type().clone()) - .unwrap_or_else(|| build_batch.schema().field(0).data_type().clone()); - let probe_key0 = self + let build_types = build_values + .iter() + .map(|a| format!("{:?}", a.data_type())) + 
.collect::>() + .join(", "); + let probe_types = self .current_probe_values - .get(0) - .map(|a| a.data_type().clone()) - .unwrap_or_else(|| probe_batch.schema().field(0).data_type().clone()); + .iter() + .map(|a| format!("{:?}", a.data_type())) + .collect::>() + .join(", "); println!( - "[spill-join] Key types: build={:?}, probe={:?}, null_equality={:?}", - build_key0, - probe_key0, - self.null_equality + "[spill-join] Key types: build=[{}], probe=[{}], null_equality={:?}", + build_types, probe_types, self.null_equality ); let sample = build_indices.len().min(5); let mut pairs = Vec::new(); for i in 0..sample { let b = build_indices.value(i) as usize; let p = probe_indices.value(i) as usize; - let b_slice = build_values[0].as_ref().slice(b, 1); - let p_slice = self.current_probe_values[0].as_ref().slice(p, 1); - pairs.push(format!("({},{})", - b_slice.get_array_memory_size(), - p_slice.get_array_memory_size() - )); + // Include actual first-key values for sanity checks + let bk = &build_values[0]; + let pk = &self.current_probe_values[0]; + let bv = arrow::util::display::array_value_to_string(bk.as_ref(), b) + .unwrap_or_else(|_| "".to_string()); + let pv = arrow::util::display::array_value_to_string(pk.as_ref(), p) + .unwrap_or_else(|_| "".to_string()); + pairs.push(format!("({},{})", bv, pv)); } - println!("[spill-join] Sample pairs (mem sizes) {} -> {}: {}", sample, build_indices.len(), pairs.join(", ")); + println!( + "[spill-join] Sample key pairs {} -> {}: {}", + sample, + build_indices.len(), + pairs.join(", ") + ); } // Apply residual join filter if present - let (build_indices, probe_indices) = if let Some(filter) = &self.filter { + let mut build_indices = build_indices; + let mut probe_indices = probe_indices; + if let Some(filter) = &self.filter { + let before_len = build_indices.len(); + let before_build_indices = build_indices.clone(); + let before_probe_indices = probe_indices.clone(); + + let (filtered_build_indices, filtered_probe_indices) = apply_join_filter_to_indices( build_batch, probe_batch, @@ -947,26 +1560,278 @@ impl PartitionedHashJoinStream { filter, JoinSide::Left, None, - )? 
+ )?; + + if !self.filter_debug_once_per_part[partition_state.partition_id] { + println!( + "[spill-join][filter-debug] part={} filter_before={} filter_after={}", + partition_state.partition_id, + before_len, + filtered_build_indices.len() + ); + + let sample = filtered_build_indices.len().min(5); + for i in 0..sample { + let build_row = filtered_build_indices.value(i) as usize; + let probe_row = filtered_probe_indices.value(i) as usize; + + let build_schema = build_batch.schema(); + let build_vals = (0..build_batch.num_columns()) + .map(|col| { + let name = build_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + build_batch.column(col).as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + let probe_schema = probe_batch.schema(); + let probe_vals = (0..probe_batch.num_columns()) + .map(|col| { + let name = probe_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + probe_batch.column(col).as_ref(), + probe_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + println!( + "[spill-join][filter-debug] sample {} build {{{}}} probe {{{}}}", + i, build_vals, probe_vals + ); + } + + if filtered_build_indices.len() == 0 { + let sample_removed = before_build_indices.len().min(5); + for i in 0..sample_removed { + let build_row = before_build_indices.value(i) as usize; + let probe_row = before_probe_indices.value(i) as usize; + + let build_schema = build_batch.schema(); + let build_vals = (0..build_batch.num_columns()) + .map(|col| { + let name = build_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + build_batch.column(col).as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + let probe_schema = probe_batch.schema(); + let probe_vals = (0..probe_batch.num_columns()) + .map(|col| { + let name = probe_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + probe_batch.column(col).as_ref(), + probe_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + println!( + "[spill-join][filter-debug] removed sample {} build {{{}}} probe {{{}}}", + i, build_vals, probe_vals + ); + } + } + + self.filter_debug_once_per_part[partition_state.partition_id] = true; + } + + if before_len != filtered_build_indices.len() { + println!( + "[spill-join][filter-debug] part={} filter removed {} rows", + partition_state.partition_id, + before_len - filtered_build_indices.len() + ); + } + + build_indices = filtered_build_indices; + probe_indices = filtered_probe_indices; + } + + // Log sample matches even if no residual filter remains, to debug equality behavior + if !self.filter_debug_once_per_part[partition_state.partition_id] + || build_indices.len() != probe_indices.len() + { + let sample = build_indices.len().min(5); + for i in 0..sample { + let build_row = build_indices.value(i) as usize; + let probe_row = probe_indices.value(i) as usize; + + let build_schema = build_batch.schema(); + let build_vals = (0..build_batch.num_columns()) + .map(|col| { + let name = build_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + build_batch.column(col).as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + 
.join(", "); + + let probe_schema = probe_batch.schema(); + let probe_vals = (0..probe_batch.num_columns()) + .map(|col| { + let name = probe_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + probe_batch.column(col).as_ref(), + probe_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + println!( + "[spill-join][match-debug] part={} pair {} build {{{}}} probe {{{}}}", + partition_state.partition_id, + i, + build_vals, + probe_vals + ); + } + + if build_indices.len() != probe_indices.len() { + println!( + "[spill-join][match-debug] part={} MISMATCH len build={} probe={}", + partition_state.partition_id, + build_indices.len(), + probe_indices.len() + ); + } + + self.filter_debug_once_per_part[partition_state.partition_id] = true; + } + + // Debug counter: post-equality (before any alignment) + println!( + "[spill-join] After equality{} (pre-align): {}", + if self.filter.is_some() { "+filter" } else { "" }, + build_indices.len() + ); + // Shadow verify for two-key joins (stringified) to catch type coercion issues + if matches!(self.join_type, JoinType::Inner) + && build_values.len() == 2 + && self.current_probe_values.len() == 2 + && !self.verify_once_per_part[partition_state.partition_id] + { + use std::collections::HashMap; + let mut map: HashMap = HashMap::new(); + let max_b = build_batch.num_rows().min(50_000); + for i in 0..max_b { + let k0 = arrow::util::display::array_value_to_string( + build_values[0].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let k1 = arrow::util::display::array_value_to_string( + build_values[1].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let key = format!("{}|{}", k0, k1); + *map.entry(key).or_insert(0) += 1; + } + let mut expect = 0usize; + let max_p = probe_batch.num_rows().min(50_000); + for i in 0..max_p { + let k0 = arrow::util::display::array_value_to_string( + self.current_probe_values[0].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let k1 = arrow::util::display::array_value_to_string( + self.current_probe_values[1].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let key = format!("{}|{}", k0, k1); + if let Some(&c) = map.get(&key) { + expect += c; + } + } + println!( + "[spill-join][verify2] part={} expect_pairs~{} vs actual_after_eq={}", + partition_state.partition_id, + expect, + build_indices.len() + ); + self.verify_once_per_part[partition_state.partition_id] = true; + } + // Accumulate matched rows per partition + self.matched_rows_per_part[partition_state.partition_id] = self + .matched_rows_per_part[partition_state.partition_id] + .saturating_add(build_indices.len()); + + // Only apply alignment to right-oriented joins (RightSemi/RightAnti/RightMark) + let needs_alignment = matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ); + + // Compute alignment window and adjust indices if needed + let (build_indices, probe_indices, last_joined_right_idx) = if needs_alignment + { + // Compute last joined probe idx from matched pairs BEFORE alignment adjustments + let last_joined_right_idx = match probe_indices.len() { + 0 => None, + n => Some(probe_indices.value(n - 1) as usize), + }; + let index_alignment_range_start = + self.joined_probe_idx.map_or(0, |v| v + 1); + let index_alignment_range_end = if next_offset.is_none() { + probe_batch.num_rows() + } else { + last_joined_right_idx.map_or(index_alignment_range_start, |v| v + 1) + }; + let (build_indices, 
probe_indices) = adjust_indices_by_join_type( + build_indices, + probe_indices, + index_alignment_range_start..index_alignment_range_end, + self.join_type, + self.right_side_ordered, + )?; + (build_indices, probe_indices, last_joined_right_idx) } else { - (build_indices, probe_indices) + // Skip alignment for INNER/LEFT (and others not listed above) + (build_indices, probe_indices, None) }; - let (build_indices, probe_indices) = adjust_indices_by_join_type( - build_indices, - probe_indices, - 0..probe_batch.num_rows(), - self.join_type, - self.right_side_ordered, - )?; + // Debug counter: after alignment (or skipped) println!( - "[spill-join] Matched after equality{}: {}", - if self.filter.is_some() { "+filter" } else { "" }, + "[spill-join] After alignment{}: {}", + if needs_alignment { "" } else { " (skipped)" }, build_indices.len() ); // Prepare ids for marking after we release borrows let build_ids_to_mark: Vec = build_indices.values().to_vec(); + // Track last joined probe row only for right-oriented joins; otherwise clear it + self.joined_probe_idx = if needs_alignment && next_offset.is_some() { + last_joined_right_idx + } else { + None + }; // Build output batch depending on join side semantics let result = if matches!(self.join_type, JoinType::RightMark) { @@ -980,8 +1845,12 @@ impl PartitionedHashJoinStream { &self.column_indices, JoinSide::Right, )? - } else if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { - println!("[spill-join] Building output with JoinSide::Right ({:?})", self.join_type); + } else if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) + { + println!( + "[spill-join] Building output with JoinSide::Right ({:?})", + self.join_type + ); build_batch_from_indices( &self.schema, probe_batch, @@ -1003,11 +1872,18 @@ impl PartitionedHashJoinStream { )? 
}; + let emitted_rows = result.num_rows(); + self.emitted_rows_per_part[partition_state.partition_id] = self + .emitted_rows_per_part[partition_state.partition_id] + .saturating_add(emitted_rows); (result, build_ids_to_mark, next_offset) }; // Mark matched build-side rows for outer joins (use current partition's bitmap) - if let Some(bitmap) = self.matched_build_rows_per_partition.get_mut(partition_state.partition_id) { + if let Some(bitmap) = self + .matched_build_rows_per_partition + .get_mut(partition_state.partition_id) + { for build_idx in build_ids_to_mark { bitmap.set_bit(build_idx as usize, true); } @@ -1022,6 +1898,7 @@ impl PartitionedHashJoinStream { self.current_probe_values.clear(); self.current_probe_hashes.clear(); self.current_offset = (0, None); + self.joined_probe_idx = None; // Alignment is batch-local for semi/anti/mark in spillable path; do not carry across batches } @@ -1043,7 +1920,10 @@ impl PartitionedHashJoinStream { } /// Handle unmatched rows for outer joins (poll-based, non-blocking spill reload) - fn handle_unmatched_rows(&mut self, cx: &mut Context<'_>) -> Poll>>> { + fn handle_unmatched_rows( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { if !need_produce_result_in_final(self.join_type) { self.state = PartitionedHashJoinState::Completed; return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); @@ -1064,26 +1944,41 @@ impl PartitionedHashJoinStream { let left_chunk = left_chunk_ref .as_any() .downcast_ref::() - .ok_or_else(|| internal_datafusion_err!("failed to downcast left indices chunk"))?; + .ok_or_else(|| { + internal_datafusion_err!("failed to downcast left indices chunk") + })?; let right_chunk = right_chunk_ref .as_any() .downcast_ref::() - .ok_or_else(|| internal_datafusion_err!("failed to downcast right indices chunk"))?; + .ok_or_else(|| { + internal_datafusion_err!("failed to downcast right indices chunk") + })?; // Use current partition's build batch let partition = self .build_partitions .get(self.unmatched_partition) - .ok_or_else(|| internal_datafusion_err!("missing build partition during unmatched cached emission"))?; + .ok_or_else(|| { + internal_datafusion_err!( + "missing build partition during unmatched cached emission" + ) + })?; let build_batch = match partition { BuildPartition::InMemory { batch, .. } => batch, BuildPartition::Spilled { .. } => { // Should not happen because we only cache after loading InMemory indices return Poll::Ready(Ok(StatefulStreamResult::Continue)); } + BuildPartition::Released { .. } => { + return Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + BuildPartition::Empty => { + return Poll::Ready(Ok(StatefulStreamResult::Continue)) + } }; - let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + let empty_right_batch = + RecordBatch::new_empty(Arc::clone(&self.probe_schema)); println!( "Emitting unmatched rows chunk: partition={}, offset={}, size={} (total={})", self.unmatched_partition, @@ -1126,23 +2021,29 @@ impl PartitionedHashJoinStream { // Process unmatched rows for the current partition if self.unmatched_partition < self.build_partitions.len() { - let partition = self.build_partitions.get_mut(self.unmatched_partition) - .ok_or_else(|| internal_datafusion_err!("missing build partition during unmatched processing"))?; + let partition = self + .build_partitions + .get_mut(self.unmatched_partition) + .ok_or_else(|| { + internal_datafusion_err!( + "missing build partition during unmatched processing" + ) + })?; match partition { BuildPartition::InMemory { batch: _batch, .. 
} => { // Get unmatched indices for this partition using its bitmap - let (left_indices, right_indices) = if let Some(bitmap) = self.matched_build_rows_per_partition.get(self.unmatched_partition) { - get_final_indices_from_bit_map( - bitmap, - self.join_type, - ) + let (left_indices, right_indices) = if let Some(bitmap) = self + .matched_build_rows_per_partition + .get(self.unmatched_partition) + { + get_final_indices_from_bit_map(bitmap, self.join_type) } else { // If no bitmap, skip this partition self.unmatched_partition += 1; return Poll::Ready(Ok(StatefulStreamResult::Continue)); }; - + println!( "Unmatched calculation for partition {} -> {} rows", self.unmatched_partition, @@ -1165,8 +2066,13 @@ impl PartitionedHashJoinStream { BuildPartition::Spilled { spill_file, .. } => { // Non-blocking reload of spilled partition for unmatched rows if self.pending_reload_partition.is_none() { - let taken = spill_file.take().ok_or_else(|| internal_datafusion_err!("spill file already consumed for unmatched"))?; - let stream = self.spill_manager.read_spill_as_stream(taken)?; + let taken = spill_file.take().ok_or_else(|| { + internal_datafusion_err!( + "spill file already consumed for unmatched" + ) + })?; + let stream = + self.build_spill_manager.read_spill_as_stream(taken)?; self.pending_reload_stream = Some(stream); self.pending_reload_batches.clear(); self.pending_reload_partition = Some(self.unmatched_partition); @@ -1175,8 +2081,8 @@ impl PartitionedHashJoinStream { if self.pending_reload_partition == Some(self.unmatched_partition) { if let Some(stream) = self.pending_reload_stream.as_mut() { match stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - println!( + Poll::Ready(Some(Ok(batch))) => { + println!( "Reload stream yielded batch for build partition {} (rows={})", self.unmatched_partition, batch.num_rows() @@ -1186,11 +2092,20 @@ impl PartitionedHashJoinStream { } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Ready(None) => { - let first_schema = self.pending_reload_batches.get(0) - .ok_or_else(|| internal_datafusion_err!("empty spilled partition for unmatched"))? + let first_schema = self + .pending_reload_batches + .get(0) + .ok_or_else(|| { + internal_datafusion_err!( + "empty spilled partition for unmatched" + ) + })? .schema(); - let concatenated = concat_batches(&first_schema, self.pending_reload_batches.as_slice()) - .map_err(DataFusionError::from)?; + let concatenated = concat_batches( + &first_schema, + self.pending_reload_batches.as_slice(), + ) + .map_err(DataFusionError::from)?; println!( "Reloaded spilled build partition {} for unmatched rows (rows={})", @@ -1198,22 +2113,28 @@ impl PartitionedHashJoinStream { concatenated.num_rows() ); - let new_reservation = MemoryConsumer::new("partition_reload_unmatched") - .with_can_spill(true) - .register(&self.runtime_env.memory_pool); - let mut values: Vec = Vec::with_capacity(self.on_left.len()); + let new_reservation = + MemoryConsumer::new("partition_reload_unmatched") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let mut values: Vec = + Vec::with_capacity(self.on_left.len()); for c in &self.on_left { - values.push(c.evaluate(&concatenated)?.into_array(concatenated.num_rows())?); + values.push( + c.evaluate(&concatenated)? 
+ .into_array(concatenated.num_rows())?, + ); } let hash_map: Box = Box::new( crate::joins::join_hash_map::JoinHashMapU32::with_capacity(concatenated.num_rows()), ); - self.build_partitions[self.unmatched_partition] = BuildPartition::InMemory { - hash_map, - batch: concatenated, - values, - reservation: new_reservation, - }; + self.build_partitions[self.unmatched_partition] = + BuildPartition::InMemory { + hash_map, + batch: concatenated, + values, + reservation: new_reservation, + }; println!( "Prepared spilled partition {} as InMemory for unmatched emission", self.unmatched_partition @@ -1225,22 +2146,33 @@ impl PartitionedHashJoinStream { self.pending_reload_partition = None; // Continue; next iteration will handle InMemory branch - return Poll::Ready(Ok(StatefulStreamResult::Continue)); + return Poll::Ready(Ok( + StatefulStreamResult::Continue, + )); } - Poll::Pending => { - // Yield until more data is available from reload stream - println!( + Poll::Pending => { + // Yield until more data is available from reload stream + println!( "Reload stream pending for build partition {} (accumulated_batches={})", self.unmatched_partition, self.pending_reload_batches.len() ); - return Poll::Pending; - } + return Poll::Pending; + } } } } Poll::Pending } + BuildPartition::Released { .. } => { + // Nothing to emit; advance + self.unmatched_partition += 1; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + BuildPartition::Empty => { + self.unmatched_partition += 1; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } } } else { // All partitions processed @@ -1271,6 +2203,11 @@ impl Stream for PartitionedHashJoinStream { let fut = &mut self.left_fut; ready!(fut.get_shared(cx))? }; + match self.poll_bounds_update(cx, &left_data) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => return Poll::Pending, + } match self.partition_build_side(left_data) { Ok(StatefulStreamResult::Continue) => continue, Ok(StatefulStreamResult::Ready(Some(batch))) => { @@ -1280,7 +2217,9 @@ impl Stream for PartitionedHashJoinStream { ); return Poll::Ready(Some(Ok(batch))); } - Ok(StatefulStreamResult::Ready(None)) => return Poll::Ready(None), + Ok(StatefulStreamResult::Ready(None)) => { + return Poll::Ready(None) + } Err(e) => return Poll::Ready(Some(Err(e))), } } diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index ce5892849aa6c..33ec3347e445a 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -290,6 +290,50 @@ pub(super) fn lookup_join_hashmap( null_equality, )?; + // Shadow verify for two-key INNER joins to catch coercion issues in classic path + if build_side_values.len() == 2 && probe_side_values.len() == 2 { + use std::collections::HashMap; + let mut map: HashMap = HashMap::new(); + let max_b = build_side_values[0].len().min(50_000); + for i in 0..max_b { + let k0 = arrow::util::display::array_value_to_string( + build_side_values[0].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let k1 = arrow::util::display::array_value_to_string( + build_side_values[1].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let key = format!("{}|{}", k0, k1); + *map.entry(key).or_insert(0) += 1; + } + let mut expect = 0usize; + let max_p = probe_side_values[0].len().min(50_000); + for i in 0..max_p { + let k0 = arrow::util::display::array_value_to_string( + 
probe_side_values[0].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let k1 = arrow::util::display::array_value_to_string( + probe_side_values[1].as_ref(), + i, + ) + .unwrap_or_else(|_| "".to_string()); + let key = format!("{}|{}", k0, k1); + if let Some(&c) = map.get(&key) { + expect += c; + } + } + println!( + "[hash-join][verify2] expect_pairs~{} vs actual_after_eq={}", + expect, + build_indices.len() + ); + } + Ok((build_indices, probe_indices, next_offset)) } @@ -556,7 +600,10 @@ impl HashJoinStream { last_joined_right_idx.map_or(0, |v| v + 1) }; - if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark) { + if matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ) { println!( "[hash-join] Align {:?}: pre-adjust right_indices={}, range={}..{} (next_offset_present={})", self.join_type, @@ -575,7 +622,10 @@ impl HashJoinStream { self.right_side_ordered, )?; - if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark) { + if matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ) { println!( "[hash-join] Align {:?}: post-adjust unique_right_indices={} (range={}..{})", self.join_type, @@ -596,6 +646,125 @@ impl HashJoinStream { ); } + // Log some matched pairs for debugging + let build_schema = build_side.left_data.batch().schema(); + let probe_schema = state.batch.schema(); + + let sample = left_indices.len().min(5); + if sample > 0 { + for i in 0..sample { + let build_row = left_indices.value(i) as usize; + let probe_row = right_indices.value(i) as usize; + + let build_vals = (0..build_schema.fields().len()) + .map(|col| { + let name = build_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + build_side.left_data.batch().column(col).as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + let probe_vals = (0..probe_schema.fields().len()) + .map(|col| { + let name = probe_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + state.batch.column(col).as_ref(), + probe_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + println!( + "[hash-join][match-debug] partition={} pair {} build {{{}}} probe {{{}}}", + self.partition, + i, + build_vals, + probe_vals + ); + } + } + + let build_supply_idx = build_schema + .fields() + .iter() + .enumerate() + .find_map(|(idx, f)| { + if f.name().to_ascii_lowercase().contains("ps_supplycost") { + Some(idx) + } else { + None + } + }); + + let probe_min_idx = probe_schema + .fields() + .iter() + .enumerate() + .find_map(|(idx, f)| { + if f.name().to_ascii_lowercase().contains("min(") + || f.name().to_ascii_lowercase().contains("min_") + { + Some(idx) + } else { + None + } + }); + + if let (Some(build_supply_idx), Some(probe_min_idx)) = (build_supply_idx, probe_min_idx) { + let build_array = build_side.left_data.batch().column(build_supply_idx); + let probe_array = state.batch.column(probe_min_idx); + + for j in 0..left_indices.len() { + let build_row = left_indices.value(j) as usize; + let probe_row = right_indices.value(j) as usize; + + let build_value = arrow::util::display::array_value_to_string( + build_array.as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + let probe_value = arrow::util::display::array_value_to_string( + probe_array.as_ref(), + probe_row, + 
) + .unwrap_or_else(|_| "".to_string()); + + if build_value != probe_value { + println!( + "[hash-join][mismatch] partition={} build_row={} ps_supplycost={} min_cost={}", + self.partition, + build_row, + build_value, + probe_value + ); + break; + } + } + } else { + println!( + "[hash-join][mismatch-debug] partition={} build_fields={:?} probe_fields={:?}", + self.partition, + build_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect::>(), + probe_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect::>() + ); + } + let result = if self.join_type == JoinType::RightMark { build_batch_from_indices( &self.schema, From c5949e3c326b8307f413d4cde2cf81e80503a267 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Mon, 3 Nov 2025 19:00:33 +0200 Subject: [PATCH 09/36] Fix RightSemi Join --- datafusion/execution/src/disk_manager.rs | 2 +- .../src/joins/hash_join/partitioned.rs | 60 +++++++++++-------- .../src/joins/hash_join/stream.rs | 40 +++++++------ datafusion/physical-plan/src/joins/utils.rs | 16 +++++ 4 files changed, 73 insertions(+), 45 deletions(-) diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 82f2d75ac1b57..05fd0a9a3b27b 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -35,7 +35,7 @@ const DEFAULT_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; // 100GB /// Builder pattern for the [DiskManager] structure #[derive(Clone, Debug)] pub struct DiskManagerBuilder { - /// The storage mode of the disk manager +/// The storage mode of the disk manager mode: DiskManagerMode, /// The maximum amount of data (in bytes) stored inside the temporary directories. /// Default to 100GB diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 879be939572f8..8d085d239c2d9 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -52,7 +52,8 @@ use crate::joins::join_hash_map::JoinHashMapType; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, equal_rows_arr, get_final_indices_from_bit_map, need_produce_result_in_final, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceFut, StatefulStreamResult, + uint32_to_uint64_indices, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceFut, + StatefulStreamResult, }; use crate::metrics::SpillMetrics; use crate::spill::in_progress_spill_file::InProgressSpillFile; @@ -1797,13 +1798,24 @@ impl PartitionedHashJoinStream { 0 => None, n => Some(probe_indices.value(n - 1) as usize), }; - let index_alignment_range_start = - self.joined_probe_idx.map_or(0, |v| v + 1); - let index_alignment_range_end = if next_offset.is_none() { - probe_batch.num_rows() - } else { - last_joined_right_idx.map_or(index_alignment_range_start, |v| v + 1) - }; + let probe_num_rows = probe_batch.num_rows(); + let mut index_alignment_range_start = + self.joined_probe_idx.map_or(0, |v| v + 1); + let mut index_alignment_range_end = if next_offset.is_none() { + probe_num_rows + } else { + last_joined_right_idx.map_or(index_alignment_range_start, |v| v + 1) + }; + + if index_alignment_range_start > probe_num_rows { + index_alignment_range_start = probe_num_rows; + } + if index_alignment_range_end > probe_num_rows { + index_alignment_range_end = probe_num_rows; + } + if index_alignment_range_end < index_alignment_range_start { + 
index_alignment_range_end = index_alignment_range_start; + } let (build_indices, probe_indices) = adjust_indices_by_join_type( build_indices, probe_indices, @@ -1834,28 +1846,24 @@ impl PartitionedHashJoinStream { }; // Build output batch depending on join side semantics - let result = if matches!(self.join_type, JoinType::RightMark) { - println!("[spill-join] Building output with JoinSide::Right (RightMark)"); - build_batch_from_indices( - &self.schema, - probe_batch, - build_batch, - &build_indices, - &probe_indices, - &self.column_indices, - JoinSide::Right, - )? - } else if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) - { - println!( - "[spill-join] Building output with JoinSide::Right ({:?})", - self.join_type - ); + let result = if matches!( + self.join_type, + JoinType::RightMark | JoinType::RightSemi | JoinType::RightAnti + ) { + if matches!(self.join_type, JoinType::RightMark) { + println!("[spill-join] Building output with JoinSide::Right (RightMark)"); + } else { + println!( + "[spill-join] Building output with JoinSide::Right ({:?})", + self.join_type + ); + } + let right_indices_u64 = uint32_to_uint64_indices(&probe_indices); build_batch_from_indices( &self.schema, probe_batch, build_batch, - &build_indices, + &right_indices_u64, &probe_indices, &self.column_indices, JoinSide::Right, diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 33ec3347e445a..68b28a61a09f1 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -35,8 +35,8 @@ use crate::{ joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_empty_build_side, build_batch_from_indices, - need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - JoinHashMapType, StatefulStreamResult, + need_produce_result_in_final, uint32_to_uint64_indices, BuildProbeJoinMetrics, + ColumnIndex, JoinFilter, JoinHashMapType, StatefulStreamResult, }, RecordBatchStream, SendableRecordBatchStream, }; @@ -593,13 +593,24 @@ impl HashJoinStream { // Calculate range and perform alignment. // In case probe batch has been processed -- align all remaining rows. - let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); - let index_alignment_range_end = if next_offset.is_none() { - state.batch.num_rows() + let batch_num_rows = state.batch.num_rows(); + let mut index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); + let mut index_alignment_range_end = if next_offset.is_none() { + batch_num_rows } else { last_joined_right_idx.map_or(0, |v| v + 1) }; + if index_alignment_range_start > batch_num_rows { + index_alignment_range_start = batch_num_rows; + } + if index_alignment_range_end > batch_num_rows { + index_alignment_range_end = batch_num_rows; + } + if index_alignment_range_end < index_alignment_range_start { + index_alignment_range_end = index_alignment_range_start; + } + if matches!( self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark @@ -765,23 +776,16 @@ impl HashJoinStream { ); } - let result = if self.join_type == JoinType::RightMark { - build_batch_from_indices( - &self.schema, - &state.batch, - build_side.left_data.batch(), - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Right, - )? 
- } else if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { - // Emit probe-side rows for right-oriented joins + let result = if matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ) { + let right_indices_u64 = uint32_to_uint64_indices(&right_indices); build_batch_from_indices( &self.schema, &state.batch, build_side.left_data.batch(), - &left_indices, + &right_indices_u64, &right_indices, &self.column_indices, JoinSide::Right, diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d392650f88dda..3a0847754a147 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1230,6 +1230,22 @@ where ) } +pub(crate) fn uint32_to_uint64_indices(indices: &UInt32Array) -> UInt64Array { + if indices.null_count() == 0 { + UInt64Array::from_iter_values(indices.values().iter().map(|v| *v as u64)) + } else { + let mut builder = UInt64Builder::with_capacity(indices.len()); + for i in 0..indices.len() { + if indices.is_null(i) { + builder.append_null(); + } else { + builder.append_value(indices.value(i) as u64); + } + } + builder.finish() + } +} + fn build_range_bitmap( range: &Range, input: &PrimitiveArray, From ca73155faf2a877b53f4cb1e4301044e119498e3 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Mon, 3 Nov 2025 20:39:45 +0200 Subject: [PATCH 10/36] Fix LeftAnti --- .../src/joins/hash_join/partitioned.rs | 68 ++++++++++--------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 8d085d239c2d9..0611117a40954 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -1666,6 +1666,14 @@ impl PartitionedHashJoinStream { probe_indices = filtered_probe_indices; } + // Capture matched build indices prior to alignment so we can mark bitmaps even if + // the join type drops them (e.g. LeftAnti emits matches only in the final phase). 
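+        // (Cloning is cheap here: Arrow arrays are reference-counted, so this
+        // copies buffer pointers rather than the index values themselves.)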
+ let build_indices_for_marking = if need_produce_result_in_final(self.join_type) { + Some(build_indices.clone()) + } else { + None + }; + // Log sample matches even if no residual filter remains, to debug equality behavior if !self.filter_debug_once_per_part[partition_state.partition_id] || build_indices.len() != probe_indices.len() @@ -1784,20 +1792,11 @@ impl PartitionedHashJoinStream { .matched_rows_per_part[partition_state.partition_id] .saturating_add(build_indices.len()); - // Only apply alignment to right-oriented joins (RightSemi/RightAnti/RightMark) - let needs_alignment = matches!( - self.join_type, - JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark - ); - - // Compute alignment window and adjust indices if needed - let (build_indices, probe_indices, last_joined_right_idx) = if needs_alignment - { - // Compute last joined probe idx from matched pairs BEFORE alignment adjustments - let last_joined_right_idx = match probe_indices.len() { - 0 => None, - n => Some(probe_indices.value(n - 1) as usize), - }; + // Compute alignment window (used by adjust_indices for all join types) + let last_joined_right_idx = match probe_indices.len() { + 0 => None, + n => Some(probe_indices.value(n - 1) as usize), + }; let probe_num_rows = probe_batch.num_rows(); let mut index_alignment_range_start = self.joined_probe_idx.map_or(0, |v| v + 1); @@ -1816,28 +1815,31 @@ impl PartitionedHashJoinStream { if index_alignment_range_end < index_alignment_range_start { index_alignment_range_end = index_alignment_range_start; } - let (build_indices, probe_indices) = adjust_indices_by_join_type( - build_indices, - probe_indices, - index_alignment_range_start..index_alignment_range_end, - self.join_type, - self.right_side_ordered, - )?; - (build_indices, probe_indices, last_joined_right_idx) - } else { - // Skip alignment for INNER/LEFT (and others not listed above) - (build_indices, probe_indices, None) - }; - // Debug counter: after alignment (or skipped) - println!( - "[spill-join] After alignment{}: {}", - if needs_alignment { "" } else { " (skipped)" }, - build_indices.len() + let (build_indices, probe_indices) = adjust_indices_by_join_type( + build_indices, + probe_indices, + index_alignment_range_start..index_alignment_range_end, + self.join_type, + self.right_side_ordered, + )?; + + // Only right-oriented joins need to preserve alignment state across batches + let needs_alignment = matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark ); - // Prepare ids for marking after we release borrows - let build_ids_to_mark: Vec = build_indices.values().to_vec(); + // Debug counter: after alignment (or effective no-op for other join types) + println!("[spill-join] After alignment: {}", build_indices.len()); + + // Prepare ids for marking after we release borrows. Prefer the pre-alignment + // matches (for join types like LeftAnti) so bitmap tracking remains accurate. 
+ let build_ids_to_mark: Vec = if let Some(indices) = build_indices_for_marking { + indices.values().to_vec() + } else { + build_indices.values().to_vec() + }; // Track last joined probe row only for right-oriented joins; otherwise clear it self.joined_probe_idx = if needs_alignment && next_offset.is_some() { last_joined_right_idx From 5712df8a018fcee23d3e46c52149594f34ac1016 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Tue, 4 Nov 2025 01:01:10 +0300 Subject: [PATCH 11/36] DRAFT GraceHashJoin with disk spilling --- datafusion/execution/src/disk_manager.rs | 13 +- .../src/joins/grace_hash_join/exec.rs | 1219 +++++++++++++++++ .../src/joins/grace_hash_join/mod.rs | 23 + .../src/joins/grace_hash_join/stream.rs | 406 ++++++ .../physical-plan/src/joins/hash_join/mod.rs | 2 +- datafusion/physical-plan/src/joins/mod.rs | 2 + datafusion/physical-plan/src/joins/utils.rs | 4 +- datafusion/physical-plan/src/spill/mod.rs | 2 +- .../physical-plan/src/spill/spill_manager.rs | 75 +- 9 files changed, 1713 insertions(+), 33 deletions(-) create mode 100644 datafusion/physical-plan/src/joins/grace_hash_join/exec.rs create mode 100644 datafusion/physical-plan/src/joins/grace_hash_join/mod.rs create mode 100644 datafusion/physical-plan/src/joins/grace_hash_join/stream.rs diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 82f2d75ac1b57..67251088af120 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -26,7 +26,7 @@ use rand::{rng, Rng}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use tempfile::{Builder, NamedTempFile, TempDir}; +use tempfile::{Builder, NamedTempFile, TempDir, TempPath}; use crate::memory_pool::human_readable_size; @@ -370,6 +370,17 @@ impl RefCountedTempFile { pub fn current_disk_usage(&self) -> u64 { self.current_file_disk_usage } + + pub fn clone_refcounted(&self) -> Result { + let reopened = std::fs::File::open(self.path())?; + let temp_path = TempPath::from_path(self.path()); + Ok(Self { + _parent_temp_dir: Arc::clone(&self._parent_temp_dir), + tempfile: NamedTempFile::from_parts(reopened, temp_path), + current_file_disk_usage: self.current_file_disk_usage, + disk_manager: Arc::clone(&self.disk_manager), + }) + } } /// When the temporary file is dropped, subtract its disk usage from the disk manager's total diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs new file mode 100644 index 0000000000000..e25407c084bbb --- /dev/null +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -0,0 +1,1219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
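+
+//! Draft Grace hash join operator with disk spilling.
+//!
+//! Rough sketch of the flow in this draft: each execution partition hashes the
+//! join keys of its build (left) and probe (right) input, scatters rows into
+//! `partition_count` buckets (`hash % partition_count`), and spills every bucket
+//! to disk through a `SpillManager` (see `partition_and_spill` and
+//! `PartitionWriter` below). The resulting spill files are reported to the shared
+//! `GraceAccumulator`, and `GraceHashJoinStream` is then expected to join the
+//! matching left/right buckets one pair at a time, so only a single bucket needs
+//! to be resident in memory.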
+ +use std::fmt; +use std::mem::size_of; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, OnceLock}; +use std::{any::Any, vec}; +use std::fmt::{format, Formatter}; +use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use crate::joins::utils::{ + asymmetric_join_output_partitioning, reorder_output_after_swap, swap_join_projection, + update_hash, OnceAsync, OnceFut, +}; +use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; +use crate::projection::{ + try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, + ProjectionExec, +}; +use crate::spill::get_record_batch_memory_size; +use crate::{displayable, ExecutionPlanProperties, SpillManager}; +use crate::{ + common::can_project, + joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + need_produce_result_in_final, symmetric_join_output_partitioning, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, + }, + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PlanProperties, SendableRecordBatchStream, Statistics, +}; + +use arrow::array::{ArrayRef, BooleanBufferBuilder, UInt32Array}; +use arrow::compute::{concat, concat_batches, take}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use arrow::util::bit_util; +use arrow_schema::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::utils::memory::estimate_memory_size; +use datafusion_common::{internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, JoinSide, JoinType, NullEquality, Result}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::TaskContext; +use datafusion_expr::{Accumulator, UserDefinedLogicalNode}; +use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; +use datafusion_physical_expr::equivalence::{ + join_equivalence_properties, ProjectionMapping, +}; +use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; + +use ahash::RandomState; +use arrow_ord::partition::partition; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use futures::{StreamExt, TryStreamExt}; +use futures::executor::block_on; +use parking_lot::Mutex; +use datafusion_common::hash_utils::create_hashes; +use datafusion_execution::runtime_env::RuntimeEnv; +use crate::empty::EmptyExec; +use crate::joins::grace_hash_join::stream::{GraceAccumulator, GraceHashJoinStream, SpillFut}; +use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; +use crate::metrics::SpillMetrics; +use crate::spill::spill_manager::{GetSlicedSize, SpillLocation}; + +/// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. 
+const HASH_JOIN_SEED: RandomState = + RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); + +pub struct GraceHashJoinExec { + /// left (build) side which gets hashed + pub left: Arc, + /// right (probe) side which are filtered by the hash table + pub right: Arc, + /// Set of equijoin columns from the relations: `(left_col, right_col)` + pub on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + /// Filters which are applied while finding matching rows + pub filter: Option, + /// How the join is performed (`OUTER`, `INNER`, etc) + pub join_type: JoinType, + /// The schema after join. Please be careful when using this schema, + /// if there is a projection, the schema isn't the same as the output schema. + join_schema: SchemaRef, + /// Shared the `RandomState` for the hashing algorithm + random_state: RandomState, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// The projection indices of the columns in the output schema of join + pub projection: Option>, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// The equality null-handling behavior of the join algorithm. + pub null_equality: NullEquality, + /// Cache holding plan properties like equivalences, output partitioning etc. + cache: PlanProperties, + /// Dynamic filter for pushing down to the probe side + /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. + /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. + dynamic_filter: Option, + accumulator: Arc, + spill_left: Arc, + spill_right: Arc, +} + +#[derive(Clone)] +struct HashJoinExecDynamicFilter { + /// Dynamic filter that we'll update with the results of the build side once that is done. + filter: Arc, + /// Bounds accumulator to keep track of the min/max bounds on the join keys for each partition. + /// It is lazily initialized during execution to make sure we use the actual execution time partition counts. + bounds_accumulator: OnceLock>, +} + +impl fmt::Debug for GraceHashJoinExec { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("HashJoinExec") + .field("left", &self.left) + .field("right", &self.right) + .field("on", &self.on) + .field("filter", &self.filter) + .field("join_type", &self.join_type) + .field("join_schema", &self.join_schema) + .field("random_state", &self.random_state) + .field("metrics", &self.metrics) + .field("projection", &self.projection) + .field("column_indices", &self.column_indices) + .field("null_equality", &self.null_equality) + .field("cache", &self.cache) + // Explicitly exclude dynamic_filter to avoid runtime state differences in tests + .finish() + } +} + + +impl EmbeddedProjection for GraceHashJoinExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } +} + +impl GraceHashJoinExec { + /// Tries to create a new [GraceHashJoinExec]. + /// + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. 
+ #[allow(clippy::too_many_arguments)] + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, + projection: Option>, + null_equality: NullEquality, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + if on.is_empty() { + return plan_err!("On constraints in HashJoinExec should be non-empty"); + } + check_join_is_valid(&left_schema, &right_schema, &on)?; + + let (join_schema, column_indices) = + build_join_schema(&left_schema, &right_schema, join_type); + + let random_state = HASH_JOIN_SEED; + + let join_schema = Arc::new(join_schema); + + // check if the projection is valid + can_project(&join_schema, projection.as_ref())?; + + let cache = Self::compute_properties( + &left, + &right, + Arc::clone(&join_schema), + *join_type, + &on, + projection.as_ref(), + )?; + let partitions = left.output_partitioning().partition_count(); + let accumulator = GraceAccumulator::new(partitions); + + + let metrics = ExecutionPlanMetricsSet::new(); + let runtime = Arc::new(RuntimeEnv::default()); + let spill_left = Arc::new(SpillManager::new( + Arc::clone(&runtime), + SpillMetrics::new(&metrics, 0), + Arc::clone(&left_schema), + )); + let spill_right = Arc::new(SpillManager::new( + Arc::clone(&runtime), + SpillMetrics::new(&metrics, 0), + Arc::clone(&right_schema), + )); + + // Initialize both dynamic filter and bounds accumulator to None + // They will be set later if dynamic filtering is enabled + Ok(GraceHashJoinExec { + left, + right, + on, + filter, + join_type: *join_type, + join_schema, + random_state, + metrics, + projection, + column_indices, + null_equality, + cache, + dynamic_filter: None, + accumulator, + spill_left, + spill_right, + }) + } + + fn create_dynamic_filter(on: &JoinOn) -> Arc { + // Extract the right-side keys (probe side keys) from the `on` clauses + // Dynamic filter will be created from build side values (left side) and applied to probe side (right side) + let right_keys: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect(); + // Initialize with a placeholder expression (true) that will be updated when the hash table is built + Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true))) + } + + /// left (build) side which gets hashed + pub fn left(&self) -> &Arc { + &self.left + } + + /// right (probe) side which are filtered by the hash table + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { + &self.on + } + + /// Filters applied before join output + pub fn filter(&self) -> Option<&JoinFilter> { + self.filter.as_ref() + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// The schema after join. Please be careful when using this schema, + /// if there is a projection, the schema isn't the same as the output schema. + pub fn join_schema(&self) -> &SchemaRef { + &self.join_schema + } + + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality + } + + /// Calculate order preservation flags for this hash join. + fn maintains_input_order(join_type: JoinType) -> Vec { + vec![ + false, + matches!( + join_type, + JoinType::Inner + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightSemi + | JoinType::RightMark + ), + ] + } + + /// Get probe side information for the hash join. + pub fn probe_side() -> JoinSide { + // In current implementation right side is always probe side. 
+ JoinSide::Right + } + + /// Return whether the join contains a projection + pub fn contains_projection(&self) -> bool { + self.projection.is_some() + } + + /// Return new instance of [HashJoinExec] with the given projection. + pub fn with_projection(&self, projection: Option>) -> Result { + // check if the projection is valid + can_project(&self.schema(), projection.as_ref())?; + let projection = match projection { + Some(projection) => match &self.projection { + Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), + None => Some(projection), + }, + None => None, + }; + Self::try_new( + Arc::clone(&self.left), + Arc::clone(&self.right), + self.on.clone(), + self.filter.clone(), + &self.join_type, + projection, + self.null_equality, + ) + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties( + left: &Arc, + right: &Arc, + schema: SchemaRef, + join_type: JoinType, + on: JoinOnRef, + projection: Option<&Vec>, + ) -> Result { + // Calculate equivalence properties: + let mut eq_properties = join_equivalence_properties( + left.equivalence_properties().clone(), + right.equivalence_properties().clone(), + &join_type, + Arc::clone(&schema), + &Self::maintains_input_order(join_type), + Some(Self::probe_side()), + on, + )?; + + let mut output_partitioning = symmetric_join_output_partitioning(left, right, &join_type)?; + let emission_type = if left.boundedness().is_unbounded() { + EmissionType::Final + } else if right.pipeline_behavior() == EmissionType::Incremental { + match join_type { + // If we only need to generate matched rows from the probe side, + // we can emit rows incrementally. + JoinType::Inner + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, + // If we need to generate unmatched rows from the *build side*, + // we need to emit them at the end. + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftMark + | JoinType::Full => EmissionType::Both, + } + } else { + right.pipeline_behavior() + }; + + // If contains projection, update the PlanProperties. + if let Some(projection) = projection { + // construct a map from the input expressions to the output expression of the Projection + let projection_mapping = + ProjectionMapping::from_indices(projection, &schema)?; + let out_schema = project_schema(&schema, Some(projection))?; + output_partitioning = + output_partitioning.project(&projection_mapping, &eq_properties); + eq_properties = eq_properties.project(&projection_mapping, out_schema); + } + + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + emission_type, + boundedness_from_children([left, right]), + )) + } + + /// Returns a new `ExecutionPlan` that computes the same join as this one, + /// with the left and right inputs swapped using the specified + /// `partition_mode`. + /// + /// # Notes: + /// + /// This function is public so other downstream projects can use it to + /// construct `HashJoinExec` with right side as the build side. + /// + /// For using this interface directly, please refer to below: + /// + /// Hash join execution may require specific input partitioning (for example, + /// the left child may have a single partition while the right child has multiple). 
+ /// + /// Calling this function on join nodes whose children have already been repartitioned + /// (e.g., after a `RepartitionExec` has been inserted) may break the partitioning + /// requirements of the hash join. Therefore, ensure you call this function + /// before inserting any repartitioning operators on the join's children. + /// + /// In DataFusion's default SQL interface, this function is used by the `JoinSelection` + /// physical optimizer rule to determine a good join order, which is + /// executed before the `EnforceDistribution` rule (the rule that may + /// insert `RepartitionExec` operators). + pub fn swap_inputs( + &self, + partition_mode: PartitionMode, + ) -> Result> { + let left = self.left(); + let right = self.right(); + let new_join = GraceHashJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect(), + self.filter().map(JoinFilter::swap), + &self.join_type().swap(), + swap_join_projection( + left.schema().fields().len(), + right.schema().fields().len(), + self.projection.as_ref(), + self.join_type(), + ), + self.null_equality(), + )?; + // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again + if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) || self.projection.is_some() + { + Ok(Arc::new(new_join)) + } else { + reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) + } + } +} + +impl DisplayAs for GraceHashJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let display_filter = self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()), + ); + let display_projections = if self.contains_projection() { + format!( + ", projection=[{}]", + self.projection + .as_ref() + .unwrap() + .iter() + .map(|index| format!( + "{}@{}", + self.join_schema.fields().get(*index).unwrap().name(), + index + )) + .collect::>() + .join(", ") + ) + } else { + "".to_string() + }; + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({c1}, {c2})")) + .collect::>() + .join(", "); + write!( + f, + "GraceHashJoinExec: join_type={:?}, on=[{}]{}{}", + self.join_type, on, display_filter, display_projections, + ) + } + DisplayFormatType::TreeRender => { + let on = self + .on + .iter() + .map(|(c1, c2)| { + format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref())) + }) + .collect::>() + .join(", "); + + if *self.join_type() != JoinType::Inner { + writeln!(f, "join_type={:?}", self.join_type)?; + } + + writeln!(f, "on={on}")?; + + if let Some(filter) = self.filter.as_ref() { + writeln!(f, "filter={filter}")?; + } + + Ok(()) + } + } + } +} + +impl ExecutionPlan for GraceHashJoinExec { + fn name(&self) -> &'static str { + "GraceHashJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + + // For [JoinType::Inner] and [JoinType::RightSemi] in hash joins, the probe phase initiates by + // applying the hash function to convert the 
join key(s) in each row into a hash value from the + // probe side table in the order they're arranged. The hash value is used to look up corresponding + // entries in the hash table that was constructed from the build side table during the build phase. + // + // Because of the immediate generation of result rows once a match is found, + // the output of the join tends to follow the order in which the rows were read from + // the probe side table. This is simply due to the sequence in which the rows were processed. + // Hence, it appears that the hash join is preserving the order of the probe side. + // + // Meanwhile, in the case of a [JoinType::RightAnti] hash join, + // the unmatched rows from the probe side are also kept in order. + // This is because the **`RightAnti`** join is designed to return rows from the right + // (probe side) table that have no match in the left (build side) table. Because the rows + // are processed sequentially in the probe phase, and unmatched rows are directly output + // as results, these results tend to retain the order of the probe side table. + fn maintains_input_order(&self) -> Vec { + Self::maintains_input_order(self.join_type) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + /// Creates a new HashJoinExec with different children while preserving configuration. + /// + /// This method is called during query optimization when the optimizer creates new + /// plan nodes. Importantly, it creates a fresh bounds_accumulator via `try_new` + /// rather than cloning the existing one because partitioning may have changed. + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let partitions = children[0].output_partitioning().partition_count(); + Ok(Arc::new(GraceHashJoinExec { + left: Arc::clone(&children[0]), + right: Arc::clone(&children[1]), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + random_state: self.random_state.clone(), + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: Self::compute_properties( + &children[0], + &children[1], + Arc::clone(&self.join_schema), + self.join_type, + &self.on, + self.projection.as_ref(), + )?, + // Keep the dynamic filter, bounds accumulator will be reset + dynamic_filter: self.dynamic_filter.clone(), + accumulator: Arc::clone(&self.accumulator), + spill_left: Arc::clone(&self.spill_left), + spill_right: Arc::clone(&self.spill_right), + })) + } + + fn reset_state(self: Arc) -> Result> { + Ok(Arc::new(GraceHashJoinExec { + left: Arc::clone(&self.left), + right: Arc::clone(&self.right), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + random_state: self.random_state.clone(), + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: self.cache.clone(), + // Reset dynamic filter and bounds accumulator to initial state + dynamic_filter: None, + accumulator: Arc::clone(&self.accumulator), + spill_left: Arc::clone(&self.spill_left), + spill_right: Arc::clone(&self.spill_right), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let left_partitions = self.left.output_partitioning().partition_count(); + let right_partitions = 
self.right.output_partitioning().partition_count(); + + if left_partitions != right_partitions { + return internal_err!( + "Invalid GraceHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ + consider using RepartitionExec" + ); + } + + let enable_dynamic_filter_pushdown = self.dynamic_filter.is_some(); + + let join_metrics = Arc::new(BuildProbeJoinMetrics::new(partition, &self.metrics)); + + let left = self.left.execute(partition, Arc::clone(&context))?; + let left_schema = Arc::clone(&self.left.schema()); + let on_left = self + .on + .iter() + .map(|(left_expr, _)| Arc::clone(left_expr)) + .collect::>(); + + let right = self.right.execute(partition, Arc::clone(&context))?; + let right_schema = Arc::clone(&self.right.schema()); + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + + let spill_left = Arc::new(SpillManager::new( + Arc::clone(&context.runtime_env()), + SpillMetrics::new(&self.metrics, partition), + Arc::clone(&left_schema), + )); + let spill_right = Arc::new(SpillManager::new( + Arc::clone(&context.runtime_env()), + SpillMetrics::new(&self.metrics, partition), + Arc::clone(&right_schema), + )); + + // update column indices to reflect the projection + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + + let random_state = self.random_state.clone(); + let on = self.on.clone(); + let spill_left_clone = Arc::clone(&spill_left); + let spill_right_clone = Arc::clone(&spill_right); + let accumulator_clone = Arc::clone(&self.accumulator); + let join_metrics_clone = Arc::clone(&join_metrics); + let spill_fut = OnceFut::new(async move { + let (left_idx, right_idx) = partition_and_spill( + random_state, + on, + left, + right, + join_metrics_clone, + enable_dynamic_filter_pushdown, + left_partitions, + spill_left_clone, + spill_right_clone, + partition, + ).await?; + accumulator_clone.report_partition(partition, left_idx.clone(), right_idx.clone()).await; + Ok(SpillFut::new(partition, left_idx, right_idx)) + }); + + Ok(Box::pin(GraceHashJoinStream::new( + self.schema(), + spill_fut, + spill_left, + spill_right, + on_left, + on_right, + self.random_state.clone(), + self.join_type, + column_indices_after_projection, + join_metrics, + context, + Arc::clone(&self.accumulator), + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } + // TODO stats: it is not possible in general to know the output size of joins + // There are some special cases though, for example: + // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` + let stats = estimate_join_statistics( + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, + self.on.clone(), + &self.join_type, + &self.join_schema, + )?; + // Project statistics if there is a projection + Ok(stats.project(self.projection.as_ref())) + } + + /// Tries to push `projection` down through `hash_join`. If possible, performs the + /// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections + /// as its children. Otherwise, returns `None`. 
+ fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + // TODO: currently if there is projection in GraceHashJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later. + if self.contains_projection() { + return Ok(None); + } + + if let Some(JoinData { + projected_left_child, + projected_right_child, + join_filter, + join_on, + }) = try_pushdown_through_join( + projection, + self.left(), + self.right(), + self.on(), + self.schema(), + self.filter(), + )? { + Ok(Some(Arc::new(GraceHashJoinExec::try_new( + Arc::new(projected_left_child), + Arc::new(projected_right_child), + join_on, + join_filter, + self.join_type(), + // Returned early if projection is not None + None, + self.null_equality, + )?))) + } else { + try_embed_projection(projection, self) + } + } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + config: &ConfigOptions, + ) -> Result { + // Other types of joins can support *some* filters, but restrictions are complex and error prone. + // For now we don't support them. + // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs + // See https://github.com/apache/datafusion/issues/16973 for tracking. + if self.join_type != JoinType::Inner { + return Ok(FilterDescription::all_unsupported( + &parent_filters, + &self.children(), + )); + } + + // Get basic filter descriptions for both children + let left_child = crate::filter_pushdown::ChildFilterDescription::from_child( + &parent_filters, + self.left(), + )?; + let mut right_child = crate::filter_pushdown::ChildFilterDescription::from_child( + &parent_filters, + self.right(), + )?; + + // Add dynamic filters in Post phase if enabled + if matches!(phase, FilterPushdownPhase::Post) + && config.optimizer.enable_dynamic_filter_pushdown + { + // Add actual dynamic filter to right side (probe side) + let dynamic_filter = Self::create_dynamic_filter(&self.on); + right_child = right_child.with_self_filter(dynamic_filter); + } + + Ok(FilterDescription::new() + .with_child(left_child) + .with_child(right_child)) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // Note: this check shouldn't be necessary because we already marked all parent filters as unsupported for + // non-inner joins in `gather_filters_for_pushdown`. + // However it's a cheap check and serves to inform future devs touching this function that they need to be really + // careful pushing down filters through non-inner joins. + if self.join_type != JoinType::Inner { + // Other types of joins can support *some* filters, but restrictions are complex and error prone. + // For now we don't support them. 
+            // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs
+            return Ok(FilterPushdownPropagation::all_unsupported(
+                child_pushdown_result,
+            ));
+        }
+
+        let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());
+        assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children
+        let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child
+        // We expect 0 or 1 self filters
+        if let Some(filter) = right_child_self_filters.first() {
+            // Note that we don't check `PushedDownPredicate::discriminant` because even if nothing said
+            // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating
+            let predicate = Arc::clone(&filter.predicate);
+            if let Ok(dynamic_filter) =
+                Arc::downcast::<DynamicFilterPhysicalExpr>(predicate)
+            {
+                // We successfully pushed down our self filter - we need to make a new node with the dynamic filter
+                let new_node = Arc::new(GraceHashJoinExec {
+                    left: Arc::clone(&self.left),
+                    right: Arc::clone(&self.right),
+                    on: self.on.clone(),
+                    filter: self.filter.clone(),
+                    join_type: self.join_type,
+                    join_schema: Arc::clone(&self.join_schema),
+                    random_state: self.random_state.clone(),
+                    metrics: ExecutionPlanMetricsSet::new(),
+                    projection: self.projection.clone(),
+                    column_indices: self.column_indices.clone(),
+                    null_equality: self.null_equality,
+                    cache: self.cache.clone(),
+                    dynamic_filter: Some(HashJoinExecDynamicFilter {
+                        filter: dynamic_filter,
+                        bounds_accumulator: OnceLock::new(),
+                    }),
+                    spill_left: Arc::clone(&self.spill_left),
+                    spill_right: Arc::clone(&self.spill_right),
+                    accumulator: Arc::clone(&self.accumulator),
+                });
+                result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>);
+            }
+        }
+        Ok(result)
+    }
+}
+
+/// Accumulator for collecting min/max bounds from build-side data during hash join.
+///
+/// This struct encapsulates the logic for progressively computing column bounds
+/// (minimum and maximum values) for a specific join key expression as batches
+/// are processed during the build phase of a hash join.
+///
+/// The bounds are used for dynamic filter pushdown optimization, where filters
+/// based on the actual data ranges can be pushed down to the probe side to
+/// eliminate unnecessary data early.
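+///
+/// A minimal usage sketch (assuming the `MinAccumulator`/`MaxAccumulator` API from
+/// `datafusion_functions_aggregate::min_max`): for each build-side batch, evaluate the
+/// key expression and feed the resulting array into both accumulators, then read the
+/// final (min, max) pair once the build side is exhausted:
+///
+/// ```ignore
+/// let keys = acc.expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+/// acc.min.update_batch(std::slice::from_ref(&keys))?;
+/// acc.max.update_batch(std::slice::from_ref(&keys))?;
+/// // ... after the last batch:
+/// let bounds = (acc.min.evaluate()?, acc.max.evaluate()?);
+/// ```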
+struct CollectLeftAccumulator { + /// The physical expression to evaluate for each batch + expr: Arc, + /// Accumulator for tracking the minimum value across all batches + min: MinAccumulator, + /// Accumulator for tracking the maximum value across all batches + max: MaxAccumulator, +} + +pub async fn partition_and_spill( + random_state: RandomState, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + mut left_stream: SendableRecordBatchStream, + mut right_stream: SendableRecordBatchStream, + join_metrics: Arc, + enable_dynamic_filter_pushdown: bool, + partition_count: usize, + spill_left: Arc, + spill_right: Arc, + partition: usize, +) -> Result<(Vec, Vec)> { + let on_left: Vec<_> = on.iter().map(|(l, _)| Arc::clone(l)).collect(); + let on_right: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect(); + + // === LEFT side partitioning === + let left_index = partition_and_spill_one_side( + &mut left_stream, + &on_left, + &random_state, + partition_count, + spill_left, + &join_metrics, + enable_dynamic_filter_pushdown, + &format!("left_{partition}"), + ) + .await?; + + // === RIGHT side partitioning === + let right_index = partition_and_spill_one_side( + &mut right_stream, + &on_right, + &random_state, + partition_count, + spill_right, + &join_metrics, + enable_dynamic_filter_pushdown, + &format!("right_{partition}"), + ) + .await?; + Ok((left_index, right_index)) +} + +async fn partition_and_spill_one_side( + input: &mut SendableRecordBatchStream, + on_exprs: &[PhysicalExprRef], + random_state: &RandomState, + partition_count: usize, + spill_manager: Arc, + join_metrics: &BuildProbeJoinMetrics, + enable_dynamic_filter_pushdown: bool, + file_request_msg: &str, +) -> Result> { + let mut partitions: Vec = (0..partition_count) + .map(|_| PartitionWriter::new(Arc::clone(&spill_manager))) + .collect(); + + let mut total_rows = 0usize; + + while let Some(batch) = input.next().await { + let batch = batch?; + let num_rows = batch.num_rows(); + if num_rows == 0 { + continue; + } + + total_rows += num_rows; + join_metrics.build_input_batches.add(1); + join_metrics.build_input_rows.add(num_rows); + + // Calculate hashes + let keys = on_exprs + .iter() + .map(|c| c.evaluate(&batch)?.into_array(num_rows)) + .collect::>>()?; + + let mut hashes = vec![0u64; num_rows]; + create_hashes(&keys, random_state, &mut hashes)?; + + // Spread to partitions + let mut indices: Vec> = vec![Vec::new(); partition_count]; + for (row, h) in hashes.iter().enumerate() { + let bucket = (*h as usize) % partition_count; + indices[bucket].push(row as u32); + } + + // Collect and spill + for (i, idxs) in indices.into_iter().enumerate() { + if idxs.is_empty() { + continue; + } + let idx_array = UInt32Array::from(idxs); + let taken = batch + .columns() + .iter() + .map(|c| take(c.as_ref(), &idx_array, None)) + .collect::>>()?; + let part_batch = RecordBatch::try_new(batch.schema(), taken)?; + let request_msg = format!("grace_partition_{file_request_msg}_{i}"); + partitions[i].spill_batch_auto(&part_batch, &request_msg)?; + } + } + + // Prepare indexes + let mut result = Vec::with_capacity(partitions.len()); + for (i, writer) in partitions.into_iter().enumerate() { + result.push(writer.finish(i)?); + } + + Ok(result) +} + +#[derive(Debug)] +pub struct PartitionWriter { + spill_manager: Arc, + total_rows: usize, + total_bytes: usize, + chunks: Vec, +} + +impl PartitionWriter { + pub fn new(spill_manager: Arc) -> Self { + Self { + spill_manager, + total_rows: 0, + total_bytes: 0, + chunks: vec![], + } + } + + pub fn 
spill_batch_auto(&mut self, batch: &RecordBatch, request_msg: &str) -> Result<()> {
+        let loc = self.spill_manager.spill_batch_auto(batch, request_msg)?;
+        self.total_rows += batch.num_rows();
+        self.total_bytes += get_record_batch_memory_size(batch);
+        self.chunks.push(loc);
+        Ok(())
+    }
+
+    pub fn finish(self, part_id: usize) -> Result<PartitionIndex> {
+        Ok(PartitionIndex {
+            part_id,
+            chunks: self.chunks,
+            total_rows: self.total_rows,
+            total_bytes: self.total_bytes,
+        })
+    }
+}
+
+/// Describes a single partition of spilled data (used in GraceHashJoin).
+///
+/// Each partition can consist of one or multiple chunks (batches)
+/// that were spilled either to memory or to disk.
+/// These chunks are later reloaded during the join phase.
+///
+/// Example:
+/// Partition 3 -> [ spill_chunk_3_0.arrow, spill_chunk_3_1.arrow ]
+#[derive(Debug, Clone)]
+pub struct PartitionIndex {
+    /// Unique partition identifier (0..N-1)
+    pub part_id: usize,
+
+    /// Total number of rows in this partition
+    pub total_rows: usize,
+
+    /// Total size in bytes of all batches in this partition
+    pub total_bytes: usize,
+
+    /// Collection of spill locations (each corresponds to one batch written
+    /// by [`PartitionWriter::spill_batch_auto`])
+    pub chunks: Vec<SpillLocation>,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::coalesce_partitions::CoalescePartitionsExec;
+    use crate::test::{assert_join_metrics, TestMemoryExec};
+    use crate::{
+        common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
+        test::exec::MockExec,
+    };
+
+    use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array};
+    use arrow::buffer::NullBuffer;
+    use arrow::datatypes::{DataType, Field};
+    use arrow::util::pretty::print_batches;
+    use arrow_schema::Schema;
+    use futures::future;
+    use datafusion_common::hash_utils::create_hashes;
+    use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
+    use datafusion_common::{
+        assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err,
+        ScalarValue,
+    };
+    use datafusion_execution::config::SessionConfig;
+    use hashbrown::HashTable;
+    use insta::{allow_duplicates, assert_snapshot};
+    use rstest::*;
+    use rstest_reuse::*;
+    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+
+    fn build_large_table(
+        a_name: &str,
+        b_name: &str,
+        c_name: &str,
+        n: usize,
+    ) -> Arc<dyn ExecutionPlan> {
+        let a: ArrayRef = Arc::new(Int32Array::from_iter_values(1..=n as i32));
+        let b: ArrayRef = Arc::new(Int32Array::from_iter_values(1..=n as i32));
+        let c: ArrayRef = Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 10)));
+
+        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
+            arrow::datatypes::Field::new(a_name, arrow::datatypes::DataType::Int32, false),
+            arrow::datatypes::Field::new(b_name, arrow::datatypes::DataType::Int32, false),
+            arrow::datatypes::Field::new(c_name, arrow::datatypes::DataType::Int32, false),
+        ]));
+
+        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap();
+
+        // MemoryExec requires a list of partitions: Vec<Vec<RecordBatch>>
+        Arc::new(TestMemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+    }
+
+    fn build_table(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+        c: (&str, &Vec<i32>),
+    ) -> Arc<dyn ExecutionPlan> {
+        let batch = build_table_i32(a, b, c);
+        let schema = batch.schema();
+        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
+    }
+
+    #[tokio::test]
+    async fn single_partition_join_overallocation() -> Result<()> {
+        // let left = build_table(
+        //     ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+        //     ("b1", 
&vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + // ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + // ); + // let right = build_table( + // ("a2", &vec![1, 2]), + // ("b2", &vec![1, 2]), + // ("c2", &vec![14, 15]), + // ); + let left = build_large_table("a1", "b1", "c1", 100_000); + let right = build_large_table("a2", "b2", "c2", 50_000); + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, + )]; + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + let left_repartitioned: Arc = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, 32), + )?); + let right_repartitioned: Arc = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, 32), + )?); + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50_000_000_000, 1.0) + .build_arc()?; + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); + + let join = GraceHashJoinExec::try_new( + Arc::clone(&left_repartitioned), + Arc::clone(&right_repartitioned), + on.clone(), + None, + &JoinType::Inner, + None, + NullEquality::NullEqualsNothing, + )?; + + let partition_count = right_repartitioned.output_partitioning().partition_count(); + println!("partition_count {partition_count}"); + + let tasks: Vec<_> = (0..partition_count) + .map(|i| { + let ctx = Arc::clone(&task_ctx); + let s = join.execute(i, ctx).unwrap(); + async move { common::collect(s).await } + }) + .collect(); + + let results = future::join_all(tasks).await; + let mut batches = Vec::new(); + for r in results { + let mut v = r?; + v.retain(|b| b.num_rows() > 0); + batches.extend(v); + } + + print_batches(&*batches).unwrap(); + // Asserting that operator-level reservation attempting to overallocate + // assert_contains!( + // err.to_string(), + // "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput" + // ); + // + // assert_contains!( + // err.to_string(), + // "Failed to allocate additional 120.0 B for HashJoinInput" + // ); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/mod.rs b/datafusion/physical-plan/src/joins/grace_hash_join/mod.rs new file mode 100644 index 0000000000000..55d7e2035e6c7 --- /dev/null +++ b/datafusion/physical-plan/src/joins/grace_hash_join/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`GraceHashJoinExec`] Partitioned Hash Join Operator
+
+pub use exec::GraceHashJoinExec;
+
+mod exec;
+mod stream;
diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs
new file mode 100644
index 0000000000000..417990da43f11
--- /dev/null
+++ b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs
@@ -0,0 +1,406 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Stream implementation for Grace Hash Join
+//!
+//! This module implements [`GraceHashJoinStream`], the streaming engine for
+//! [`super::GraceHashJoinExec`]. See comments in [`GraceHashJoinStream`] for more details.
+
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use crate::joins::utils::{
+    equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut,
+};
+use crate::{handle_state, hash_utils::create_hashes, joins::join_hash_map::JoinHashMapOffset, joins::utils::{
+    adjust_indices_by_join_type, apply_join_filter_to_indices,
+    build_batch_empty_build_side, build_batch_from_indices,
+    need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
+    JoinHashMapType, StatefulStreamResult,
+}, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, SpillManager};
+
+use arrow::array::{ArrayRef, UInt32Array, UInt64Array};
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use datafusion_common::{
+    internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result,
+};
+use datafusion_physical_expr::PhysicalExprRef;
+use ahash::RandomState;
+use futures::{ready, FutureExt, Stream, StreamExt};
+use tokio::sync::Mutex;
+use datafusion_execution::TaskContext;
+use crate::empty::EmptyExec;
+use crate::joins::grace_hash_join::exec::PartitionIndex;
+use crate::joins::{HashJoinExec, PartitionMode};
+use crate::memory::MemoryStream;
+use crate::stream::RecordBatchStreamAdapter;
+use crate::test::TestMemoryExec;
+
+enum GraceJoinState {
+    /// Waiting for the partitioning phase (Phase 1) to finish
+    WaitPartitioning,
+
+    WaitAllPartitions {
+        wait_all_fut: Option<OnceFut<Vec<SpillFut>>>,
+    },
+
+    /// Currently joining partition `current`
+    JoinPartition {
+        current: usize,
+        all_parts: Arc<Vec<SpillFut>>,
+        current_stream: Option<SendableRecordBatchStream>,
+        left_fut: Option<OnceFut<Vec<RecordBatch>>>,
+        right_fut: Option<OnceFut<Vec<RecordBatch>>>,
+    },
+
+    Done,
+}
+
+/// Container for HashJoinStreamState::ProcessProbeBatch related data
+#[derive(Debug, Clone)]
+pub(super) struct ProcessProbeBatchState {
+    /// Current probe-side batch
+    batch: RecordBatch,
+    /// Probe-side on expressions values
+    values: Vec<ArrayRef>,
+    /// Starting offset for JoinHashMap lookups
+    offset: JoinHashMapOffset,
+    /// Max joined probe-side index from current batch
+    joined_probe_idx: Option<usize>,
+}
+
+pub struct GraceHashJoinStream {
+    schema: SchemaRef,
+    spill_fut: 
OnceFut, + spill_left: Arc, + spill_right: Arc, + on_left: Vec, + on_right: Vec, + random_state: RandomState, + join_type: JoinType, + column_indices: Vec, + join_metrics: Arc, + context: Arc, + accumulator: Arc, + state: GraceJoinState, +} + +#[derive(Debug, Clone)] +pub struct SpillFut { + partition: usize, + left: Vec, + right: Vec +} +impl SpillFut { + pub(crate) fn new(partition: usize, left: Vec, right: Vec) -> Self { + SpillFut { + partition, + left, + right, + } + } +} + +impl RecordBatchStream for GraceHashJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl GraceHashJoinStream { + pub fn new( + schema: SchemaRef, + spill_fut: OnceFut, + spill_left: Arc, + spill_right: Arc, + on_left: Vec, + on_right: Vec, + random_state: RandomState, + join_type: JoinType, + column_indices: Vec, + join_metrics: Arc, + context: Arc, + accumulator: Arc, + ) -> Self { + Self { + schema, + spill_fut, + spill_left, + spill_right, + on_left, + on_right, + random_state, + join_type, + column_indices, + join_metrics, + context, + accumulator, + state: GraceJoinState::WaitPartitioning, + } + } + + /// Core state machine logic (poll implementation) + fn poll_next_impl(&mut self, cx: &mut Context<'_>) -> Poll>> { + loop { + match &mut self.state { + GraceJoinState::WaitPartitioning => { + let shared = ready!(self.spill_fut.get_shared(cx))?; + + let acc = Arc::clone(&self.accumulator); + let left = shared.left.clone(); + let right = shared.right.clone(); + let wait_all_fut = if shared.partition == 0 { + OnceFut::new(async move { + acc.report_partition(shared.partition, left, right).await; + let all = acc.wait_all().await; + Ok(all) + }) + } else { + OnceFut::new(async move { + acc.report_partition(shared.partition, left, right).await; + acc.wait_ready().await; + Ok(vec![]) + }) + }; + self.state = GraceJoinState::WaitAllPartitions { wait_all_fut: Some(wait_all_fut) }; + continue; + } + GraceJoinState::WaitAllPartitions { wait_all_fut } => { + if let Some(fut) = wait_all_fut { + let all_arc = ready!(fut.get_shared(cx))?; + let mut all = (*all_arc).clone(); + all.sort_by_key(|s| s.partition); + + self.state = GraceJoinState::JoinPartition { + current: 0, + all_parts: Arc::from(all), + current_stream: None, + left_fut: None, + right_fut: None, + }; + continue; + } else { + return Poll::Pending; + } + } + GraceJoinState::JoinPartition { + current, + all_parts, + current_stream, + left_fut, + right_fut, + } => { + if *current >= all_parts.len() { + self.state = GraceJoinState::Done; + continue; + } + + // If we don't have a stream yet, create one for the current partition pair + if current_stream.is_none() { + if left_fut.is_none() && right_fut.is_none() { + let spill_fut = &all_parts[*current]; + *left_fut = Some(load_partition_async(Arc::clone(&self.spill_left), spill_fut.left.clone())); + *right_fut = Some(load_partition_async(Arc::clone(&self.spill_right), spill_fut.right.clone())); + } + + let left_batches = (*ready!(left_fut.as_mut().unwrap().get_shared(cx))?).clone(); + let right_batches = (*ready!(right_fut.as_mut().unwrap().get_shared(cx))?).clone(); + + let stream = build_in_memory_join_stream( + Arc::clone(&self.schema), + left_batches, + right_batches, + &self.on_left, + &self.on_right, + self.random_state.clone(), + self.join_type, + &self.column_indices, + &self.join_metrics, + &self.context, + )?; + + *current_stream = Some(stream); + *left_fut = None; + *right_fut = None; + } + + // Drive current stream forward + if let Some(stream) = current_stream { + match 
ready!(stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => return Poll::Ready(Some(Ok(batch))), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => { + *current += 1; + *current_stream = None; + continue; + } + } + } + } + GraceJoinState::Done => return Poll::Ready(None), + } + } + } +} + +fn load_partition_async( + spill_manager: Arc, + partitions: Vec, +) -> OnceFut> { + OnceFut::new(async move { + let mut all_batches = Vec::new(); + for p in partitions { + for chunk in p.chunks { + let mut reader = spill_manager.load_spilled_batch(&chunk)?; + while let Some(batch_result) = reader.next().await { + let batch = batch_result?; + all_batches.push(batch); + } + } + } + Ok(all_batches) + }) +} + +/// Build an in-memory HashJoinExec for one pair of spilled partitions +fn build_in_memory_join_stream( + output_schema: SchemaRef, + left_batches: Vec, + right_batches: Vec, + on_left: &[PhysicalExprRef], + on_right: &[PhysicalExprRef], + random_state: RandomState, + join_type: JoinType, + column_indices: &[ColumnIndex], + join_metrics: &BuildProbeJoinMetrics, + context: &Arc, +) -> Result { + if left_batches.is_empty() && right_batches.is_empty() { + return EmptyExec::new(output_schema).execute(0, Arc::clone(context)); + } + + let left_schema = left_batches + .first() + .map(|b| b.schema()) + .unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty())); + + let right_schema = right_batches + .first() + .map(|b| b.schema()) + .unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty())); + + // Build memory execution nodes for each side + let left_plan: Arc = + Arc::new(TestMemoryExec::try_new(&[left_batches], left_schema, None)?); + let right_plan: Arc = + Arc::new(TestMemoryExec::try_new(&[right_batches], right_schema, None)?); + + // Combine join expressions into pairs + let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = on_left + .iter() + .cloned() + .zip(on_right.iter().cloned()) + .collect(); + + // For one partition pair: always CollectLeft (build left, stream right) + let join_exec = HashJoinExec::try_new( + left_plan, + right_plan, + on, + None::, + &join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + + // Each join executes locally with the same context + join_exec.execute(0, Arc::clone(context)) +} + +impl Stream for GraceHashJoinStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +#[derive(Debug)] +pub struct GraceAccumulator { + expected: usize, + collected: Mutex>, + notify: tokio::sync::Notify, +} + +impl GraceAccumulator { + pub fn new(expected: usize) -> Arc { + Arc::new(Self { + expected, + collected: Mutex::new(vec![]), + notify: tokio::sync::Notify::new(), + }) + } + + pub async fn report_partition( + &self, + part_id: usize, + left_idx: Vec, + right_idx: Vec, + ) { + let mut guard = self.collected.lock().await; + if let Some(pos) = guard.iter().position(|s| s.partition == part_id) { + guard[pos] = SpillFut::new(part_id, left_idx, right_idx); + } else { + guard.push(SpillFut::new(part_id, left_idx, right_idx)); + } + + if guard.len() == self.expected { + self.notify.notify_waiters(); + } + } + + pub async fn wait_all( + &self, + ) -> Vec { + loop { + { + let guard = self.collected.lock().await; + if guard.len() == self.expected { + return guard.clone(); + } + } + self.notify.notified().await; + } + } + pub async fn wait_ready(&self) { + loop { + { + let guard = self.collected.lock().await; + if guard.len() == self.expected { + return; + } 
+ } + self.notify.notified().await; + } + } +} \ No newline at end of file diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 7f1e5cae13a3e..612134604c7b2 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -20,5 +20,5 @@ pub use exec::HashJoinExec; mod exec; -mod shared_bounds; +pub mod shared_bounds; mod stream; diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 1d36db996434e..14429ec55182a 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -28,6 +28,8 @@ pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; +mod grace_hash_join; + mod nested_loop_join; mod sort_merge_join; mod stream_join_utils; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d392650f88dda..1c49454f1a038 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1663,7 +1663,7 @@ pub fn update_hash( hashes_buffer: &mut Vec, deleted_offset: usize, fifo_hashmap: bool, -) -> Result<()> { +) -> Result> { // evaluate the keys let keys_values = on .iter() @@ -1688,7 +1688,7 @@ pub fn update_hash( hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset); } - Ok(()) + Ok(keys_values) } pub(super) fn equal_rows_arr( diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 270b3654b2bad..9d19e72cb379a 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -478,7 +478,7 @@ mod tests { assert_eq!(spilled_rows, num_rows); for spill in results { - let stream = spill_manager.load_spilled_batch(spill)?; + let stream = spill_manager.load_spilled_batch(&spill)?; let collected = collect(stream).await?; assert!(!collected.is_empty()); assert_eq!(collected[0].schema(), schema); diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index e8f5037f764bc..2c4d6b7844cd9 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -173,30 +173,37 @@ impl SpillManager { /// Automatically decides whether to spill the given RecordBatch to memory or disk, /// depending on available memory pool capacity. pub(crate) fn spill_batch_auto(&self, batch: &RecordBatch, request_msg: &str) -> Result { - let size = batch.get_sliced_size()?; - - // Check current memory usage and total limit from the runtime memory pool - let used = self.env.memory_pool.reserved(); - let limit = match self.env.memory_pool.memory_limit() { - datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, - _ => usize::MAX, + let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? 
else { + return Err(DataFusionError::Execution( + "failed to spill batch to disk".into(), + )); }; - - // If there's enough memory (with a small safety margin), keep it in memory - if used + size * 3 / 2 <= limit { - let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); - self.metrics.spilled_bytes.add(size); - self.metrics.spilled_rows.add(batch.num_rows()); - Ok(SpillLocation::Memory(buf)) - } else { - // Otherwise spill to disk using the existing SpillManager logic - let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { - return Err(DataFusionError::Execution( - "failed to spill batch to disk".into(), - )); - }; - Ok(SpillLocation::Disk(file)) - } + Ok(SpillLocation::Disk(Arc::new(file))) + // + // let size = batch.get_sliced_size()?; + // + // // Check current memory usage and total limit from the runtime memory pool + // let used = self.env.memory_pool.reserved(); + // let limit = match self.env.memory_pool.memory_limit() { + // datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, + // _ => usize::MAX, + // }; + // + // // If there's enough memory (with a small safety margin), keep it in memory + // if used + size * 3 / 2 <= limit { + // let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); + // self.metrics.spilled_bytes.add(size); + // self.metrics.spilled_rows.add(batch.num_rows()); + // Ok(SpillLocation::Memory(buf)) + // } else { + // // Otherwise spill to disk using the existing SpillManager logic + // let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { + // return Err(DataFusionError::Execution( + // "failed to spill batch to disk".into(), + // )); + // }; + // Ok(SpillLocation::Disk(Arc::new(file))) + // } } pub fn spill_batches_auto( @@ -226,21 +233,33 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + pub fn read_spill_as_stream_ref( + &self, + spill_file_path: &RefCountedTempFile, + ) -> Result { + let stream = Box::pin(cooperative(SpillReaderStream::new( + Arc::clone(&self.schema), + spill_file_path.clone_refcounted()?, + ))); + + Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + } + pub fn load_spilled_batch( &self, - spill: SpillLocation, + spill: &SpillLocation, ) -> Result { match spill { - SpillLocation::Memory(buf) => Ok(buf.as_stream(Arc::clone(&self.schema))?), - SpillLocation::Disk(file) => self.read_spill_as_stream(file), + SpillLocation::Memory(buf) => Ok(Arc::clone(&buf).as_stream(Arc::clone(&self.schema))?), + SpillLocation::Disk(file) => self.read_spill_as_stream_ref(file), } } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum SpillLocation { Memory(Arc), - Disk(RefCountedTempFile), + Disk(Arc), } From aa4b43901597fbb764598489b2ffa8d1b901c1d1 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Tue, 4 Nov 2025 13:33:37 +0200 Subject: [PATCH 12/36] Comment out printlns --- datafusion/execution/src/disk_manager.rs | 2 +- .../physical-plan/src/joins/hash_join/exec.rs | 31 +- .../src/joins/hash_join/partitioned.rs | 888 +++++++++--------- .../src/joins/hash_join/stream.rs | 193 ++-- 4 files changed, 561 insertions(+), 553 deletions(-) diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 05fd0a9a3b27b..82f2d75ac1b57 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -35,7 +35,7 @@ const DEFAULT_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; // 100GB /// Builder pattern for the [DiskManager] 
structure #[derive(Clone, Debug)] pub struct DiskManagerBuilder { -/// The storage mode of the disk manager + /// The storage mode of the disk manager mode: DiskManagerMode, /// The maximum amount of data (in bytes) stored inside the temporary directories. /// Default to 100GB diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 97ffe3d2f0745..c05fb1cc14be7 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -905,7 +905,7 @@ impl ExecutionPlan for HashJoinExec { partition: usize, context: Arc, ) -> Result { - println!("Executing HashJoinExec"); + //println!("Executing HashJoinExec"); let on_left = self .on .iter() @@ -1023,7 +1023,9 @@ impl ExecutionPlan for HashJoinExec { )) })?; - let make_bounds_accumulator = |right_plan: &Arc| { + let make_bounds_accumulator = |right_plan: &Arc< + dyn ExecutionPlan, + >| { if enable_dynamic_filter_pushdown { self.dynamic_filter.as_ref().map(|df| { let filter = Arc::clone(&df.filter); @@ -1033,13 +1035,15 @@ impl ExecutionPlan for HashJoinExec { .map(|(_, right_expr)| Arc::clone(right_expr)) .collect::>(); Arc::clone(df.bounds_accumulator.get_or_init(|| { - Arc::new(SharedBoundsAccumulator::new_from_partition_mode( - self.mode, - left_plan.as_ref(), - right_plan.as_ref(), - filter, - on_right, - )) + Arc::new( + SharedBoundsAccumulator::new_from_partition_mode( + self.mode, + left_plan.as_ref(), + right_plan.as_ref(), + filter, + on_right, + ), + ) })) }) } else { @@ -1061,7 +1065,8 @@ impl ExecutionPlan for HashJoinExec { ))) }; let right_stream = right_plan.execute(0, Arc::clone(&context))?; - let shared_bounds_accumulator = make_bounds_accumulator(&right_plan); + let shared_bounds_accumulator = + make_bounds_accumulator(&right_plan); let column_indices_after_projection = match &self.projection { Some(projection) => projection .iter() @@ -1138,10 +1143,8 @@ impl ExecutionPlan for HashJoinExec { let partitioned_reservation = MemoryConsumer::new("PartitionedHashJoin") .register(context.memory_pool()); - let probe_spill_metrics = - SpillMetrics::new(&self.metrics, partition); - let build_spill_metrics = - SpillMetrics::new(&self.metrics, partition); + let probe_spill_metrics = SpillMetrics::new(&self.metrics, partition); + let build_spill_metrics = SpillMetrics::new(&self.metrics, partition); let partitioned_stream = PartitionedHashJoinStream::new( partition, self.schema(), diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 0611117a40954..34a3bd962031c 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -290,36 +290,34 @@ impl PartitionedHashJoinStream { ) -> Poll> { if let Some(ref accumulator) = self.bounds_accumulator { if self.bounds_waiter.is_none() { - println!( - "[spill-join] partition={} reporting build bounds (rows={})", - self.partition, - build_data.batch().num_rows() - ); + // println!( + // "[spill-join] partition={} reporting build bounds (rows={})", + // self.partition, + // build_data.batch().num_rows() + // ); let accumulator = Arc::clone(accumulator); let partition = self.partition; let bounds = build_data.bounds.clone(); self.bounds_waiter = Some(OnceFut::new(async move { - accumulator - .report_partition_bounds(partition, bounds) - .await + accumulator.report_partition_bounds(partition, bounds).await })); } if 
let Some(waiter) = self.bounds_waiter.as_mut() { match waiter.get(cx) { Poll::Ready(Ok(_)) => { - println!( - "[spill-join] partition={} build bounds reported", - self.partition - ); + // println!( + // "[spill-join] partition={} build bounds reported", + // self.partition + // ); self.bounds_waiter = None; } Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => { - println!( - "[spill-join] partition={} waiting on shared bounds barrier", - self.partition - ); + // println!( + // "[spill-join] partition={} waiting on shared bounds barrier", + // self.partition + // ); return Poll::Pending; } } @@ -357,10 +355,10 @@ impl PartitionedHashJoinStream { self.pending_reload_stream = Some(stream); self.pending_reload_batches.clear(); self.pending_reload_partition = Some(part_id); - println!( - "[spill-join][reload] start partition {}", - part_id - ); + // println!( + // "[spill-join][reload] start partition {}", + // part_id + // ); } } @@ -369,14 +367,14 @@ impl PartitionedHashJoinStream { if let Some(stream) = self.pending_reload_stream.as_mut() { match stream.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { - println!( - "[spill-join][reload] partition {} batch rows={}", - part_id, - batch.num_rows() - ); + // println!( + // "[spill-join][reload] partition {} batch rows={}", + // part_id, + // batch.num_rows() + // ); self.pending_reload_batches.push(batch); return Poll::Pending; - }, + } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Ready(None) => { // Concatenate @@ -393,11 +391,11 @@ impl PartitionedHashJoinStream { ) .map_err(DataFusionError::from)?; - println!( - "Reloaded spilled build partition {} for probing (rows={})", - part_id, - concatenated.num_rows() - ); + // println!( + // "Reloaded spilled build partition {} for probing (rows={})", + // part_id, + // concatenated.num_rows() + // ); // Grow global reservation conservatively by concatenated batch size let concat_size = concatenated.get_array_memory_size(); @@ -442,17 +440,17 @@ impl PartitionedHashJoinStream { reservation: new_reservation, }; - if let Some(BuildPartition::InMemory { + /*if let Some(BuildPartition::InMemory { hash_map, batch, .. }) = self.build_partitions.get(part_id) { - println!( - "Reloaded partition {} hashmap empty? {} rows={}", - part_id, - hash_map.is_empty(), - batch.num_rows() - ); - } + // println!( + // "Reloaded partition {} hashmap empty? 
{} rows={}", + // part_id, + // hash_map.is_empty(), + // batch.num_rows() + // ); + }*/ self.pending_reload_stream = None; self.pending_reload_batches.clear(); @@ -462,11 +460,11 @@ impl PartitionedHashJoinStream { return Poll::Ready(Ok(())); } Poll::Pending => { - println!( - "[spill-join][reload] partition {} pending batches={}", - part_id, - self.pending_reload_batches.len() - ); + // println!( + // "[spill-join][reload] partition {} pending batches={}", + // part_id, + // self.pending_reload_batches.len() + // ); return Poll::Pending; } } @@ -592,11 +590,11 @@ impl PartitionedHashJoinStream { match self.right.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { // Compute ON values for the full batch (once) - println!( - "[spill-join] probe batch rows={} schema={:?}", - batch.num_rows(), - batch.schema().fields().len() - ); + // println!( + // "[spill-join] probe batch rows={} schema={:?}", + // batch.num_rows(), + // batch.schema().fields().len() + // ); let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); for c in &self.on_right { @@ -623,11 +621,11 @@ impl PartitionedHashJoinStream { } let indices_arr: UInt32Array = part_indices.clone().into(); if self.probe_partitions[part_id].batches.is_empty() { - println!( - "[spill-join] probe partition {} first rows {:?}", - part_id, - &part_indices[..part_indices.len().min(10)] - ); + // println!( + // "[spill-join] probe partition {} first rows {:?}", + // part_id, + // &part_indices[..part_indices.len().min(10)] + // ); } // Take data columns @@ -676,11 +674,11 @@ impl PartitionedHashJoinStream { self.probe_spill_in_progress[part_id] { ipf.append_batch(&filtered_batch)?; - println!( - "[spill-join][probe-spill] write partition={} rows={}", - part_id, - filtered_batch.num_rows() - ); + // println!( + // "[spill-join][probe-spill] write partition={} rows={}", + // part_id, + // filtered_batch.num_rows() + // ); } self.probe_spilled_rows_per_part[part_id] += filtered_batch.num_rows(); @@ -713,25 +711,25 @@ impl PartitionedHashJoinStream { // Finished buffering self.probes_buffered = true; self.probe_batch_positions = vec![0; self.num_partitions]; - println!( - "[spill-join] probe buffered rows per partition = {:?}", - self.probe_partitions - .iter() - .enumerate() - .map(|(i, p)| (i, p.batches.iter().map(|b| b.num_rows()).sum::())) - .collect::>() - ); + // println!( + // "[spill-join] probe buffered rows per partition = {:?}", + // self.probe_partitions + // .iter() + // .enumerate() + // .map(|(i, p)| (i, p.batches.iter().map(|b| b.num_rows()).sum::())) + // .collect::>() + // ); // Finalize any in-progress probe spill files for part_id in 0..self.num_partitions { if let Some(mut ipf) = self.probe_spill_in_progress[part_id].take() { if let Some(file) = ipf.finish()? 
{ - println!( - "[spill-join][probe-spill] finalize partition={} rows_spilled={}", - part_id, - self.probe_spilled_rows_per_part[part_id] - ); + // println!( + // "[spill-join][probe-spill] finalize partition={} rows_spilled={}", + // part_id, + // self.probe_spilled_rows_per_part[part_id] + // ); self.probe_spill_files[part_id] = Some(file); } } @@ -739,13 +737,13 @@ impl PartitionedHashJoinStream { return Poll::Ready(Ok(())); } Poll::Pending => { - println!( - "[spill-join][probe-buffer] pending batches buffered={:?} spilled_rows={:?}", - self.probe_buffered_rows_per_part, - self.probe_spilled_rows_per_part - ); + // println!( + // "[spill-join][probe-buffer] pending batches buffered={:?} spilled_rows={:?}", + // self.probe_buffered_rows_per_part, + // self.probe_spilled_rows_per_part + // ); return Poll::Pending; - }, + } } } } @@ -755,10 +753,10 @@ impl PartitionedHashJoinStream { &mut self, build_data: Arc, ) -> Result>> { - println!( - "Partitioning build side data into {} partitions", - self.num_partitions - ); + // println!( + // "Partitioning build side data into {} partitions", + // self.num_partitions + // ); // Metrics: record build input self.join_metrics.build_input_batches.add(1); self.join_metrics @@ -784,10 +782,10 @@ impl PartitionedHashJoinStream { for (row_idx, &hash) in hashes.iter().enumerate() { let partition_id = self.partition_for_hash(hash); if row_idx < 10 { - println!( - "[spill-join] build row {} hash={} -> partition {}", - row_idx, hash, partition_id - ); + // println!( + // "[spill-join] build row {} hash={} -> partition {}", + // row_idx, hash, partition_id + // ); } partition_batches[partition_id].push(row_idx); } @@ -854,13 +852,13 @@ impl PartitionedHashJoinStream { } if will_spill && self.runtime_env.disk_manager.tmp_files_enabled() { - println!( - "Spilling build partition {} (rows={}) due to memory threshold (threshold={} bytes, current={})", - partition_id, - row_indices.len(), - self.memory_threshold, - self.memory_reservation.size() - ); + // println!( + // "Spilling build partition {} (rows={}) due to memory threshold (threshold={} bytes, current={})", + // partition_id, + // row_indices.len(), + // self.memory_threshold, + // self.memory_reservation.size() + // ); // Spill this partition to disk and do not keep it in memory let spill_file = self .build_spill_manager @@ -906,11 +904,11 @@ impl PartitionedHashJoinStream { let iter = self.hashes_buffer.iter().enumerate().map(|(i, h)| (i, h)); partition_hash_map.update_from_iter(Box::new(iter), 0); - println!( - "Built in-memory hash map for partition {} (rows={})", - partition_id, - row_indices.len() - ); + // println!( + // "Built in-memory hash map for partition {} (rows={})", + // partition_id, + // row_indices.len() + // ); // Metrics: approximate build memory used (batch + values) let approx = partition_batch.get_array_memory_size() + partition_values @@ -931,11 +929,11 @@ impl PartitionedHashJoinStream { .with_can_spill(true) .register(&self.runtime_env.memory_pool); - let is_empty_after = partition_hash_map.is_empty(); - println!( - "Partition {} hashmap empty after build? {}", - partition_id, is_empty_after - ); + //let is_empty_after = partition_hash_map.is_empty(); + // println!( + // "Partition {} hashmap empty after build? 
{}", + // partition_id, is_empty_after + // ); self.build_partitions.push(BuildPartition::InMemory { hash_map: partition_hash_map, @@ -949,10 +947,10 @@ impl PartitionedHashJoinStream { // This ensures a single stream can process all partitions when the // operator reports a single output partition. let start_partition = 0; - println!( - "Partitioning complete. Created {} partitions. Starting to process partition {}", - self.build_partitions.len(), start_partition - ); + // println!( + // "Partitioning complete. Created {} partitions. Starting to process partition {}", + // self.build_partitions.len(), start_partition + // ); self.state = PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { partition_id: start_partition, @@ -1078,12 +1076,12 @@ impl PartitionedHashJoinStream { self.state = PartitionedHashJoinState::HandleUnmatchedRows; return Poll::Ready(Ok(StatefulStreamResult::Continue)); } - println!( - "Processing partition {} (total_partitions={}), build_partitions.len()={}", - partition_state.partition_id, - partition_state.total_partitions, - self.build_partitions.len() - ); + // println!( + // "Processing partition {} (total_partitions={}), build_partitions.len()={}", + // partition_state.partition_id, + // partition_state.total_partitions, + // self.build_partitions.len() + // ); // Do not buffer probe side here; selection happens below depending on num_partitions @@ -1167,10 +1165,10 @@ impl PartitionedHashJoinStream { self.pending_probe_partition = Some(partition_state.partition_id); } else { // Spilled probe indicated but file not yet finalized: wait - println!( - "[spill-join] Waiting for spilled probe file for partition {}", - partition_state.partition_id - ); + // println!( + // "[spill-join] Waiting for spilled probe file for partition {}", + // partition_state.partition_id + // ); return Poll::Pending; } } @@ -1205,34 +1203,34 @@ impl PartitionedHashJoinStream { [partition_state.partition_id] .saturating_add(b.num_rows()); } - println!( - "[spill-join][probe-spill] partition={} batch rows={}", - partition_state.partition_id, - self.current_probe_batch - .as_ref() - .map(|b| b.num_rows()) - .unwrap_or(0) - ); + // println!( + // "[spill-join][probe-spill] partition={} batch rows={}", + // partition_state.partition_id, + // self.current_probe_batch + // .as_ref() + // .map(|b| b.num_rows()) + // .unwrap_or(0) + // ); } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Ready(None) => { // Finished probe for this partition; advance self.pending_probe_stream = None; self.pending_probe_partition = None; - println!( - "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - partition_state.partition_id, - self.probe_buffered_rows_per_part[partition_state.partition_id], - self.probe_spilled_rows_per_part[partition_state.partition_id], - self.probe_consumed_rows_per_part[partition_state.partition_id], - self.candidate_pairs_per_part[partition_state.partition_id], - self.matched_rows_per_part[partition_state.partition_id], - self.emitted_rows_per_part[partition_state.partition_id] - ); - println!( - "[spill-join][probe-spill] partition={} stream complete", - partition_state.partition_id - ); + // println!( + // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + // partition_state.partition_id, + // self.probe_buffered_rows_per_part[partition_state.partition_id], + // self.probe_spilled_rows_per_part[partition_state.partition_id], + // 
self.probe_consumed_rows_per_part[partition_state.partition_id], + // self.candidate_pairs_per_part[partition_state.partition_id], + // self.matched_rows_per_part[partition_state.partition_id], + // self.emitted_rows_per_part[partition_state.partition_id] + // ); + // println!( + // "[spill-join][probe-spill] partition={} stream complete", + // partition_state.partition_id + // ); self.release_partition_resources( partition_state.partition_id, ); @@ -1263,16 +1261,16 @@ impl PartitionedHashJoinStream { // No stream available; nothing to read, advance self.pending_probe_stream = None; self.pending_probe_partition = None; - println!( - "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - partition_state.partition_id, - self.probe_buffered_rows_per_part[partition_state.partition_id], - self.probe_spilled_rows_per_part[partition_state.partition_id], - self.probe_consumed_rows_per_part[partition_state.partition_id], - self.candidate_pairs_per_part[partition_state.partition_id], - self.matched_rows_per_part[partition_state.partition_id], - self.emitted_rows_per_part[partition_state.partition_id] - ); + // println!( + // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + // partition_state.partition_id, + // self.probe_buffered_rows_per_part[partition_state.partition_id], + // self.probe_spilled_rows_per_part[partition_state.partition_id], + // self.probe_consumed_rows_per_part[partition_state.partition_id], + // self.candidate_pairs_per_part[partition_state.partition_id], + // self.matched_rows_per_part[partition_state.partition_id], + // self.emitted_rows_per_part[partition_state.partition_id] + // ); self.release_partition_resources(partition_state.partition_id); if partition_state.is_last_partition { self.state = PartitionedHashJoinState::HandleUnmatchedRows; @@ -1291,16 +1289,16 @@ impl PartitionedHashJoinStream { } } else { // Neither spilled nor buffered probe for this partition: advance - println!( - "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - partition_state.partition_id, - self.probe_buffered_rows_per_part[partition_state.partition_id], - self.probe_spilled_rows_per_part[partition_state.partition_id], - self.probe_consumed_rows_per_part[partition_state.partition_id], - self.candidate_pairs_per_part[partition_state.partition_id], - self.matched_rows_per_part[partition_state.partition_id], - self.emitted_rows_per_part[partition_state.partition_id] - ); + // println!( + // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + // partition_state.partition_id, + // self.probe_buffered_rows_per_part[partition_state.partition_id], + // self.probe_spilled_rows_per_part[partition_state.partition_id], + // self.probe_consumed_rows_per_part[partition_state.partition_id], + // self.candidate_pairs_per_part[partition_state.partition_id], + // self.matched_rows_per_part[partition_state.partition_id], + // self.emitted_rows_per_part[partition_state.partition_id] + // ); self.release_partition_resources(partition_state.partition_id); if partition_state.is_last_partition { self.state = PartitionedHashJoinState::HandleUnmatchedRows; @@ -1320,16 +1318,16 @@ impl PartitionedHashJoinStream { // If no probe batch selected, advance to next partition (no probe rows here) if self.current_probe_batch.is_none() { - println!( - "[spill-join][summary] part={} buffered={} spilled={} consumed={} 
candidates={} matched={} emitted={}", - partition_state.partition_id, - self.probe_buffered_rows_per_part[partition_state.partition_id], - self.probe_spilled_rows_per_part[partition_state.partition_id], - self.probe_consumed_rows_per_part[partition_state.partition_id], - self.candidate_pairs_per_part[partition_state.partition_id], - self.matched_rows_per_part[partition_state.partition_id], - self.emitted_rows_per_part[partition_state.partition_id] - ); + // println!( + // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", + // partition_state.partition_id, + // self.probe_buffered_rows_per_part[partition_state.partition_id], + // self.probe_spilled_rows_per_part[partition_state.partition_id], + // self.probe_consumed_rows_per_part[partition_state.partition_id], + // self.candidate_pairs_per_part[partition_state.partition_id], + // self.matched_rows_per_part[partition_state.partition_id], + // self.emitted_rows_per_part[partition_state.partition_id] + // ); self.release_partition_resources(partition_state.partition_id); if partition_state.is_last_partition { self.state = PartitionedHashJoinState::HandleUnmatchedRows; @@ -1372,7 +1370,7 @@ impl PartitionedHashJoinStream { } }; // Debug: log ON expressions and output mapping once we have both sides - let on_left_desc = self + /* let on_left_desc = self .on_left .iter() .map(|e| format!("{}", e)) @@ -1396,14 +1394,14 @@ impl PartitionedHashJoinStream { format!("{}@{}", side, ci.index) }) .collect::>() - .join(", "); - println!( - "[spill-join] ON build=[{}] | probe=[{}] | out=[{}]", - on_left_desc, on_right_desc, mapping_desc - ); + .join(", ");*/ + // println!( + // "[spill-join] ON build=[{}] | probe=[{}] | out=[{}]", + // on_left_desc, on_right_desc, mapping_desc + // ); // Log resolved output column names for the current mapping - let out_names = self + /*let out_names = self .column_indices .iter() .map(|ci| match ci.side { @@ -1417,13 +1415,13 @@ impl PartitionedHashJoinStream { }) .collect::>() .join(", "); - println!("[spill-join] OUT columns: {}", out_names); + // println!("[spill-join] OUT columns: {}", out_names); - println!( - "[spill-join] Partition {} build hashmap empty? {}", - partition_state.partition_id, - build_hashmap.is_empty() - ); + // println!( + // "[spill-join] Partition {} build hashmap empty? 
{}", + // partition_state.partition_id, + // build_hashmap.is_empty() + // );*/ // Lookup against hash map with limit let (probe_indices, build_indices, next_offset) = build_hashmap @@ -1440,13 +1438,13 @@ impl PartitionedHashJoinStream { self.candidate_pairs_per_part[partition_state.partition_id] = self .candidate_pairs_per_part[partition_state.partition_id] .saturating_add(build_indices.len()); - println!( - "[spill-join] Candidates before equality: build_ids={}, probe_ids={}, build_rows={}, probe_rows={}", - build_indices.len(), - probe_indices.len(), - build_batch.num_rows(), - probe_batch.num_rows() - ); + // println!( + // "[spill-join] Candidates before equality: build_ids={}, probe_ids={}, build_rows={}, probe_rows={}", + // build_indices.len(), + // probe_indices.len(), + // build_batch.num_rows(), + // probe_batch.num_rows() + // ); // Resolve hash collisions let (build_indices, probe_indices) = equal_rows_arr( @@ -1458,7 +1456,7 @@ impl PartitionedHashJoinStream { )?; // Shadow verify on INNER join with single Int64 key (first 50k rows) - if matches!(self.join_type, JoinType::Inner) + /*if matches!(self.join_type, JoinType::Inner) && build_values.len() == 1 && self.current_probe_values.len() == 1 && build_values[0].data_type() == &arrow::datatypes::DataType::Int64 @@ -1485,7 +1483,7 @@ impl PartitionedHashJoinStream { let k = bcol.value(i); *map.entry(k).or_insert(0) += 1; } - let mut expect = 0usize; + /*let mut expect = 0usize; let max_p = pcol.len().min(50_000); for i in 0..max_p { if pcol.is_null(i) { @@ -1496,18 +1494,18 @@ impl PartitionedHashJoinStream { expect += c; } } - println!( - "[spill-join][verify] part={} expect_pairs~{} vs actual_after_eq={}", - partition_state.partition_id, - expect, - build_indices.len() - ); + // println!( + // "[spill-join][verify] part={} expect_pairs~{} vs actual_after_eq={}", + // partition_state.partition_id, + // expect, + // build_indices.len() + // );*/ self.verify_once_per_part[partition_state.partition_id] = true; - } + }*/ // Debug: log key data types and sample matched pairs - if !build_indices.is_empty() { - let build_types = build_values + /*if !build_indices.is_empty() { + /*let build_types = build_values .iter() .map(|a| format!("{:?}", a.data_type())) .collect::>() @@ -1517,11 +1515,11 @@ impl PartitionedHashJoinStream { .iter() .map(|a| format!("{:?}", a.data_type())) .collect::>() - .join(", "); - println!( - "[spill-join] Key types: build=[{}], probe=[{}], null_equality={:?}", - build_types, probe_types, self.null_equality - ); + .join(", ");*/ + // println!( + // "[spill-join] Key types: build=[{}], probe=[{}], null_equality={:?}", + // build_types, probe_types, self.null_equality + // ); let sample = build_indices.len().min(5); let mut pairs = Vec::new(); for i in 0..sample { @@ -1536,85 +1534,46 @@ impl PartitionedHashJoinStream { .unwrap_or_else(|_| "".to_string()); pairs.push(format!("({},{})", bv, pv)); } - println!( - "[spill-join] Sample key pairs {} -> {}: {}", - sample, - build_indices.len(), - pairs.join(", ") - ); - } + // println!( + // "[spill-join] Sample key pairs {} -> {}: {}", + // sample, + // build_indices.len(), + // pairs.join(", ") + // ); + }*/ // Apply residual join filter if present - let mut build_indices = build_indices; - let mut probe_indices = probe_indices; - if let Some(filter) = &self.filter { - let before_len = build_indices.len(); - let before_build_indices = build_indices.clone(); - let before_probe_indices = probe_indices.clone(); - - let (filtered_build_indices, 
filtered_probe_indices) = - apply_join_filter_to_indices( - build_batch, - probe_batch, - build_indices, - probe_indices, - filter, - JoinSide::Left, - None, - )?; - - if !self.filter_debug_once_per_part[partition_state.partition_id] { - println!( - "[spill-join][filter-debug] part={} filter_before={} filter_after={}", - partition_state.partition_id, - before_len, - filtered_build_indices.len() - ); - - let sample = filtered_build_indices.len().min(5); - for i in 0..sample { - let build_row = filtered_build_indices.value(i) as usize; - let probe_row = filtered_probe_indices.value(i) as usize; - - let build_schema = build_batch.schema(); - let build_vals = (0..build_batch.num_columns()) - .map(|col| { - let name = build_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - build_batch.column(col).as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); - - let probe_schema = probe_batch.schema(); - let probe_vals = (0..probe_batch.num_columns()) - .map(|col| { - let name = probe_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - probe_batch.column(col).as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); - - println!( - "[spill-join][filter-debug] sample {} build {{{}}} probe {{{}}}", - i, build_vals, probe_vals - ); - } - - if filtered_build_indices.len() == 0 { - let sample_removed = before_build_indices.len().min(5); - for i in 0..sample_removed { - let build_row = before_build_indices.value(i) as usize; - let probe_row = before_probe_indices.value(i) as usize; + let mut build_indices = build_indices; + let mut probe_indices = probe_indices; + if let Some(filter) = &self.filter { + let before_len = build_indices.len(); + // let before_build_indices = build_indices.clone(); + //let before_probe_indices = probe_indices.clone(); + + let (filtered_build_indices, filtered_probe_indices) = + apply_join_filter_to_indices( + build_batch, + probe_batch, + build_indices, + probe_indices, + filter, + JoinSide::Left, + None, + )?; + + if !self.filter_debug_once_per_part[partition_state.partition_id] { + /* + // println!( + // "[spill-join][filter-debug] part={} filter_before={} filter_after={}", + // partition_state.partition_id, + // before_len, + // filtered_build_indices.len() + // ); + + let sample = filtered_build_indices.len().min(5); + for i in 0..sample { + let build_row = filtered_build_indices.value(i) as usize; + let probe_row = filtered_probe_indices.value(i) as usize; let build_schema = build_batch.schema(); let build_vals = (0..build_batch.num_columns()) @@ -1644,102 +1603,145 @@ impl PartitionedHashJoinStream { .collect::>() .join(", "); - println!( - "[spill-join][filter-debug] removed sample {} build {{{}}} probe {{{}}}", - i, build_vals, probe_vals - ); + // println!( + // "[spill-join][filter-debug] sample {} build {{{}}} probe {{{}}}", + // i, build_vals, probe_vals + // ); } + + if filtered_build_indices.len() == 0 { + let sample_removed = before_build_indices.len().min(5); + for i in 0..sample_removed { + let build_row = before_build_indices.value(i) as usize; + let probe_row = before_probe_indices.value(i) as usize; + + let build_schema = build_batch.schema(); + let build_vals = (0..build_batch.num_columns()) + .map(|col| { + let name = build_schema.field(col).name(); + let value = + arrow::util::display::array_value_to_string( + 
build_batch.column(col).as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + + let probe_schema = probe_batch.schema(); + /*let probe_vals = (0..probe_batch.num_columns()) + .map(|col| { + let name = probe_schema.field(col).name(); + let value = + arrow::util::display::array_value_to_string( + probe_batch.column(col).as_ref(), + probe_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", ");*/ + + // println!( + // "[spill-join][filter-debug] removed sample {} build {{{}}} probe {{{}}}", + // i, build_vals, probe_vals + // ); + } + }*/ + + self.filter_debug_once_per_part[partition_state.partition_id] = true; } - self.filter_debug_once_per_part[partition_state.partition_id] = true; - } + if before_len != filtered_build_indices.len() { + // println!( + // "[spill-join][filter-debug] part={} filter removed {} rows", + // partition_state.partition_id, + // before_len - filtered_build_indices.len() + // ); + } - if before_len != filtered_build_indices.len() { - println!( - "[spill-join][filter-debug] part={} filter removed {} rows", - partition_state.partition_id, - before_len - filtered_build_indices.len() - ); + build_indices = filtered_build_indices; + probe_indices = filtered_probe_indices; } - build_indices = filtered_build_indices; - probe_indices = filtered_probe_indices; - } - - // Capture matched build indices prior to alignment so we can mark bitmaps even if - // the join type drops them (e.g. LeftAnti emits matches only in the final phase). - let build_indices_for_marking = if need_produce_result_in_final(self.join_type) { - Some(build_indices.clone()) - } else { - None - }; + // Capture matched build indices prior to alignment so we can mark bitmaps even if + // the join type drops them (e.g. LeftAnti emits matches only in the final phase). 
+ let build_indices_for_marking = + if need_produce_result_in_final(self.join_type) { + Some(build_indices.clone()) + } else { + None + }; - // Log sample matches even if no residual filter remains, to debug equality behavior - if !self.filter_debug_once_per_part[partition_state.partition_id] - || build_indices.len() != probe_indices.len() - { - let sample = build_indices.len().min(5); - for i in 0..sample { - let build_row = build_indices.value(i) as usize; - let probe_row = probe_indices.value(i) as usize; - - let build_schema = build_batch.schema(); - let build_vals = (0..build_batch.num_columns()) - .map(|col| { - let name = build_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - build_batch.column(col).as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); + // Log sample matches even if no residual filter remains, to debug equality behavior + /*if !self.filter_debug_once_per_part[partition_state.partition_id] + || build_indices.len() != probe_indices.len() + { + let sample = build_indices.len().min(5); + for i in 0..sample { + let build_row = build_indices.value(i) as usize; + let probe_row = probe_indices.value(i) as usize; - let probe_schema = probe_batch.schema(); - let probe_vals = (0..probe_batch.num_columns()) - .map(|col| { - let name = probe_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - probe_batch.column(col).as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); + let build_schema = build_batch.schema(); + let build_vals = (0..build_batch.num_columns()) + .map(|col| { + let name = build_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + build_batch.column(col).as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); - println!( - "[spill-join][match-debug] part={} pair {} build {{{}}} probe {{{}}}", - partition_state.partition_id, - i, - build_vals, - probe_vals - ); - } + let probe_schema = probe_batch.schema(); + /* let probe_vals = (0..probe_batch.num_columns()) + .map(|col| { + let name = probe_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + probe_batch.column(col).as_ref(), + probe_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", ");*/ + + // println!( + // "[spill-join][match-debug] part={} pair {} build {{{}}} probe {{{}}}", + // partition_state.partition_id, + // i, + // build_vals, + // probe_vals + // ); + } - if build_indices.len() != probe_indices.len() { - println!( - "[spill-join][match-debug] part={} MISMATCH len build={} probe={}", - partition_state.partition_id, - build_indices.len(), - probe_indices.len() - ); - } + if build_indices.len() != probe_indices.len() { + // println!( + // "[spill-join][match-debug] part={} MISMATCH len build={} probe={}", + // partition_state.partition_id, + // build_indices.len(), + // probe_indices.len() + // ); + } - self.filter_debug_once_per_part[partition_state.partition_id] = true; - } + self.filter_debug_once_per_part[partition_state.partition_id] = true; + }*/ // Debug counter: post-equality (before any alignment) - println!( - "[spill-join] After equality{} (pre-align): {}", - if self.filter.is_some() { "+filter" } else { "" }, - build_indices.len() - ); + // println!( + // 
"[spill-join] After equality{} (pre-align): {}", + // if self.filter.is_some() { "+filter" } else { "" }, + // build_indices.len() + // ); // Shadow verify for two-key joins (stringified) to catch type coercion issues - if matches!(self.join_type, JoinType::Inner) + /*if matches!(self.join_type, JoinType::Inner) && build_values.len() == 2 && self.current_probe_values.len() == 2 && !self.verify_once_per_part[partition_state.partition_id] @@ -1761,7 +1763,6 @@ impl PartitionedHashJoinStream { let key = format!("{}|{}", k0, k1); *map.entry(key).or_insert(0) += 1; } - let mut expect = 0usize; let max_p = probe_batch.num_rows().min(50_000); for i in 0..max_p { let k0 = arrow::util::display::array_value_to_string( @@ -1779,14 +1780,14 @@ impl PartitionedHashJoinStream { expect += c; } } - println!( - "[spill-join][verify2] part={} expect_pairs~{} vs actual_after_eq={}", - partition_state.partition_id, - expect, - build_indices.len() - ); + // println!( + // "[spill-join][verify2] part={} expect_pairs~{} vs actual_after_eq={}", + // partition_state.partition_id, + // expect, + // build_indices.len() + // ); self.verify_once_per_part[partition_state.partition_id] = true; - } + }*/ // Accumulate matched rows per partition self.matched_rows_per_part[partition_state.partition_id] = self .matched_rows_per_part[partition_state.partition_id] @@ -1831,15 +1832,16 @@ impl PartitionedHashJoinStream { ); // Debug counter: after alignment (or effective no-op for other join types) - println!("[spill-join] After alignment: {}", build_indices.len()); + // println!("[spill-join] After alignment: {}", build_indices.len()); // Prepare ids for marking after we release borrows. Prefer the pre-alignment // matches (for join types like LeftAnti) so bitmap tracking remains accurate. 
- let build_ids_to_mark: Vec = if let Some(indices) = build_indices_for_marking { - indices.values().to_vec() - } else { - build_indices.values().to_vec() - }; + let build_ids_to_mark: Vec = + if let Some(indices) = build_indices_for_marking { + indices.values().to_vec() + } else { + build_indices.values().to_vec() + }; // Track last joined probe row only for right-oriented joins; otherwise clear it self.joined_probe_idx = if needs_alignment && next_offset.is_some() { last_joined_right_idx @@ -1853,12 +1855,12 @@ impl PartitionedHashJoinStream { JoinType::RightMark | JoinType::RightSemi | JoinType::RightAnti ) { if matches!(self.join_type, JoinType::RightMark) { - println!("[spill-join] Building output with JoinSide::Right (RightMark)"); + // println!("[spill-join] Building output with JoinSide::Right (RightMark)"); } else { - println!( - "[spill-join] Building output with JoinSide::Right ({:?})", - self.join_type - ); + // println!( + // "[spill-join] Building output with JoinSide::Right ({:?})", + // self.join_type + // ); } let right_indices_u64 = uint32_to_uint64_indices(&probe_indices); build_batch_from_indices( @@ -1913,19 +1915,19 @@ impl PartitionedHashJoinStream { } if result.num_rows() == 0 { - println!( - "[spill-join] Skipping empty batch emission (partition={})", - partition_state.partition_id - ); + // println!( + // "[spill-join] Skipping empty batch emission (partition={})", + // partition_state.partition_id + // ); return Poll::Ready(Ok(StatefulStreamResult::Continue)); } self.join_metrics.output_batches.add(1); self.join_metrics.baseline.record_output(result.num_rows()); - println!( - "[spill-join] Emitting batch: rows={} (partition={})", - result.num_rows(), - partition_state.partition_id - ); + // println!( + // "[spill-join] Emitting batch: rows={} (partition={})", + // result.num_rows(), + // partition_state.partition_id + // ); Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))) } @@ -1989,13 +1991,13 @@ impl PartitionedHashJoinStream { let empty_right_batch = RecordBatch::new_empty(Arc::clone(&self.probe_schema)); - println!( - "Emitting unmatched rows chunk: partition={}, offset={}, size={} (total={})", - self.unmatched_partition, - self.unmatched_offset, - to_emit, - total - ); + // println!( + // "Emitting unmatched rows chunk: partition={}, offset={}, size={} (total={})", + // self.unmatched_partition, + // self.unmatched_offset, + // to_emit, + // total + // ); let result = build_batch_from_indices( &self.schema, @@ -2013,10 +2015,10 @@ impl PartitionedHashJoinStream { self.unmatched_left_indices_cache = None; self.unmatched_right_indices_cache = None; self.unmatched_offset = 0; - println!( - "Finished emitting unmatched rows for partition {}", - self.unmatched_partition - ); + // println!( + // "Finished emitting unmatched rows for partition {}", + // self.unmatched_partition + // ); self.unmatched_partition += 1; } @@ -2054,11 +2056,11 @@ impl PartitionedHashJoinStream { return Poll::Ready(Ok(StatefulStreamResult::Continue)); }; - println!( - "Unmatched calculation for partition {} -> {} rows", - self.unmatched_partition, - left_indices.len() - ); + // println!( + // "Unmatched calculation for partition {} -> {} rows", + // self.unmatched_partition, + // left_indices.len() + // ); if left_indices.len() > 0 { // Cache the full indices and emit first chunk via cached path next call @@ -2092,11 +2094,11 @@ impl PartitionedHashJoinStream { if let Some(stream) = self.pending_reload_stream.as_mut() { match stream.poll_next_unpin(cx) { 
Poll::Ready(Some(Ok(batch))) => { - println!( - "Reload stream yielded batch for build partition {} (rows={})", - self.unmatched_partition, - batch.num_rows() - ); + // println!( + // "Reload stream yielded batch for build partition {} (rows={})", + // self.unmatched_partition, + // batch.num_rows() + // ); self.pending_reload_batches.push(batch); return Poll::Pending; } @@ -2117,11 +2119,11 @@ impl PartitionedHashJoinStream { ) .map_err(DataFusionError::from)?; - println!( - "Reloaded spilled build partition {} for unmatched rows (rows={})", - self.unmatched_partition, - concatenated.num_rows() - ); + // println!( + // "Reloaded spilled build partition {} for unmatched rows (rows={})", + // self.unmatched_partition, + // concatenated.num_rows() + // ); let new_reservation = MemoryConsumer::new("partition_reload_unmatched") @@ -2145,10 +2147,10 @@ impl PartitionedHashJoinStream { values, reservation: new_reservation, }; - println!( - "Prepared spilled partition {} as InMemory for unmatched emission", - self.unmatched_partition - ); + // println!( + // "Prepared spilled partition {} as InMemory for unmatched emission", + // self.unmatched_partition + // ); // Clear pending self.pending_reload_stream = None; @@ -2162,11 +2164,11 @@ impl PartitionedHashJoinStream { } Poll::Pending => { // Yield until more data is available from reload stream - println!( - "Reload stream pending for build partition {} (accumulated_batches={})", - self.unmatched_partition, - self.pending_reload_batches.len() - ); + // println!( + // "Reload stream pending for build partition {} (accumulated_batches={})", + // self.unmatched_partition, + // self.pending_reload_batches.len() + // ); return Poll::Pending; } } @@ -2221,10 +2223,10 @@ impl Stream for PartitionedHashJoinStream { match self.partition_build_side(left_data) { Ok(StatefulStreamResult::Continue) => continue, Ok(StatefulStreamResult::Ready(Some(batch))) => { - println!( - "[spill-join] poll_next yielding initial batch: rows={}", - batch.num_rows() - ); + // println!( + // "[spill-join] poll_next yielding initial batch: rows={}", + // batch.num_rows() + // ); return Poll::Ready(Some(Ok(batch))); } Ok(StatefulStreamResult::Ready(None)) => { @@ -2238,18 +2240,18 @@ impl Stream for PartitionedHashJoinStream { if self.num_partitions > 1 && !self.placeholder_emitted { self.placeholder_emitted = true; let empty = RecordBatch::new_empty(self.schema.clone()); - println!( - "[spill-join] Emitting placeholder empty batch for partition {}", - partition_state.partition_id - ); + // println!( + // "[spill-join] Emitting placeholder empty batch for partition {}", + // partition_state.partition_id + // ); return Poll::Ready(Some(Ok(empty))); } match self.process_partition(cx, &partition_state) { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { - println!( - "[spill-join] poll_next yielding process batch: rows={} (state partition={})", - batch.num_rows(), partition_state.partition_id - ); + // println!( + // "[spill-join] poll_next yielding process batch: rows={} (state partition={})", + // batch.num_rows(), partition_state.partition_id + // ); return Poll::Ready(Some(Ok(batch))); } Poll::Ready(Ok(StatefulStreamResult::Ready(None))) => { @@ -2265,10 +2267,10 @@ impl Stream for PartitionedHashJoinStream { PartitionedHashJoinState::HandleUnmatchedRows => { match self.handle_unmatched_rows(cx) { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { - println!( - "[spill-join] poll_next yielding unmatched batch: rows={}", - batch.num_rows() - ); + // 
println!( + // "[spill-join] poll_next yielding unmatched batch: rows={}", + // batch.num_rows() + // ); return Poll::Ready(Some(Ok(batch))); } Poll::Ready(Ok(StatefulStreamResult::Ready(None))) => { diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 68b28a61a09f1..f2abe2b942c11 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -291,7 +291,7 @@ pub(super) fn lookup_join_hashmap( )?; // Shadow verify for two-key INNER joins to catch coercion issues in classic path - if build_side_values.len() == 2 && probe_side_values.len() == 2 { + /*if build_side_values.len() == 2 && probe_side_values.len() == 2 { use std::collections::HashMap; let mut map: HashMap = HashMap::new(); let max_b = build_side_values[0].len().min(50_000); @@ -309,7 +309,6 @@ pub(super) fn lookup_join_hashmap( let key = format!("{}|{}", k0, k1); *map.entry(key).or_insert(0) += 1; } - let mut expect = 0usize; let max_p = probe_side_values[0].len().min(50_000); for i in 0..max_p { let k0 = arrow::util::display::array_value_to_string( @@ -327,12 +326,12 @@ pub(super) fn lookup_join_hashmap( expect += c; } } - println!( - "[hash-join][verify2] expect_pairs~{} vs actual_after_eq={}", - expect, - build_indices.len() - ); - } + // println!( + // "[hash-join][verify2] expect_pairs~{} vs actual_after_eq={}", + // expect, + // build_indices.len() + // ); + }*/ Ok((build_indices, probe_indices, next_offset)) } @@ -611,19 +610,19 @@ impl HashJoinStream { index_alignment_range_end = index_alignment_range_start; } - if matches!( + /* if matches!( self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark ) { - println!( - "[hash-join] Align {:?}: pre-adjust right_indices={}, range={}..{} (next_offset_present={})", - self.join_type, - right_indices.len(), - index_alignment_range_start, - index_alignment_range_end, - next_offset.is_some() - ); - } + // println!( + // "[hash-join] Align {:?}: pre-adjust right_indices={}, range={}..{} (next_offset_present={})", + // self.join_type, + // right_indices.len(), + // index_alignment_range_start, + // index_alignment_range_end, + // next_offset.is_some() + // ); + }*/ let (left_indices, right_indices) = adjust_indices_by_join_type( left_indices, @@ -633,28 +632,28 @@ impl HashJoinStream { self.right_side_ordered, )?; - if matches!( + /* if matches!( self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark ) { - println!( - "[hash-join] Align {:?}: post-adjust unique_right_indices={} (range={}..{})", - self.join_type, - right_indices.len(), - index_alignment_range_start, - index_alignment_range_end - ); + // println!( + // "[hash-join] Align {:?}: post-adjust unique_right_indices={} (range={}..{})", + // self.join_type, + // right_indices.len(), + // index_alignment_range_start, + // index_alignment_range_end + // ); } if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { - println!( - "[hash-join] Right {:?}: probe_batch_rows={}, unique_matched_right_indices={} (range={}..{})", - self.join_type, - state.batch.num_rows(), - right_indices.len(), - index_alignment_range_start, - index_alignment_range_end - ); + // println!( + // "[hash-join] Right {:?}: probe_batch_rows={}, unique_matched_right_indices={} (range={}..{})", + // self.join_type, + // state.batch.num_rows(), + // right_indices.len(), + // index_alignment_range_start, + // index_alignment_range_end + // ); } // Log some 
matched pairs for debugging @@ -667,7 +666,7 @@ impl HashJoinStream { let build_row = left_indices.value(i) as usize; let probe_row = right_indices.value(i) as usize; - let build_vals = (0..build_schema.fields().len()) + /*let build_vals = (0..build_schema.fields().len()) .map(|col| { let name = build_schema.field(col).name(); let value = arrow::util::display::array_value_to_string( @@ -679,8 +678,8 @@ impl HashJoinStream { }) .collect::>() .join(", "); - - let probe_vals = (0..probe_schema.fields().len()) +*/ + /*let probe_vals = (0..probe_schema.fields().len()) .map(|col| { let name = probe_schema.field(col).name(); let value = arrow::util::display::array_value_to_string( @@ -691,45 +690,49 @@ impl HashJoinStream { format!("{}={}", name, value) }) .collect::>() - .join(", "); - - println!( - "[hash-join][match-debug] partition={} pair {} build {{{}}} probe {{{}}}", - self.partition, - i, - build_vals, - probe_vals - ); + .join(", ");*/ + + // println!( + // "[hash-join][match-debug] partition={} pair {} build {{{}}} probe {{{}}}", + // self.partition, + // i, + // build_vals, + // probe_vals + // ); } } - let build_supply_idx = build_schema - .fields() - .iter() - .enumerate() - .find_map(|(idx, f)| { - if f.name().to_ascii_lowercase().contains("ps_supplycost") { - Some(idx) - } else { - None - } - }); - - let probe_min_idx = probe_schema - .fields() - .iter() - .enumerate() - .find_map(|(idx, f)| { - if f.name().to_ascii_lowercase().contains("min(") - || f.name().to_ascii_lowercase().contains("min_") - { - Some(idx) - } else { - None - } - }); - - if let (Some(build_supply_idx), Some(probe_min_idx)) = (build_supply_idx, probe_min_idx) { + let build_supply_idx = + build_schema + .fields() + .iter() + .enumerate() + .find_map(|(idx, f)| { + if f.name().to_ascii_lowercase().contains("ps_supplycost") { + Some(idx) + } else { + None + } + }); + + let probe_min_idx = + probe_schema + .fields() + .iter() + .enumerate() + .find_map(|(idx, f)| { + if f.name().to_ascii_lowercase().contains("min(") + || f.name().to_ascii_lowercase().contains("min_") + { + Some(idx) + } else { + None + } + }); + + if let (Some(build_supply_idx), Some(probe_min_idx)) = + (build_supply_idx, probe_min_idx) + { let build_array = build_side.left_data.batch().column(build_supply_idx); let probe_array = state.batch.column(probe_min_idx); @@ -749,32 +752,32 @@ impl HashJoinStream { .unwrap_or_else(|_| "".to_string()); if build_value != probe_value { - println!( - "[hash-join][mismatch] partition={} build_row={} ps_supplycost={} min_cost={}", - self.partition, - build_row, - build_value, - probe_value - ); + // println!( + // "[hash-join][mismatch] partition={} build_row={} ps_supplycost={} min_cost={}", + // self.partition, + // build_row, + // build_value, + // probe_value + // ); break; } } } else { - println!( - "[hash-join][mismatch-debug] partition={} build_fields={:?} probe_fields={:?}", - self.partition, - build_schema - .fields() - .iter() - .map(|f| f.name().clone()) - .collect::>(), - probe_schema - .fields() - .iter() - .map(|f| f.name().clone()) - .collect::>() - ); - } + // println!( + // "[hash-join][mismatch-debug] partition={} build_fields={:?} probe_fields={:?}", + // self.partition, + // build_schema + // .fields() + // .iter() + // .map(|f| f.name().clone()) + // .collect::>(), + // probe_schema + // .fields() + // .iter() + // .map(|f| f.name().clone()) + // .collect::>() + // ); + }*/ let result = if matches!( self.join_type, From 07dc4d0d0a944220295c5c49395f235cbad7de8b Mon Sep 17 00:00:00 2001 From: 
Denys Tsomenko Date: Tue, 4 Nov 2025 16:44:30 +0200 Subject: [PATCH 13/36] Comment2 remove --- .../physical-plan/src/joins/hash_join/exec.rs | 37 ++- .../src/joins/hash_join/partitioned.rs | 4 +- .../src/joins/hash_join/stream.rs | 268 +++++++++--------- 3 files changed, 154 insertions(+), 155 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index c05fb1cc14be7..2e313ce251802 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -1023,19 +1023,18 @@ impl ExecutionPlan for HashJoinExec { )) })?; - let make_bounds_accumulator = |right_plan: &Arc< - dyn ExecutionPlan, - >| { - if enable_dynamic_filter_pushdown { - self.dynamic_filter.as_ref().map(|df| { - let filter = Arc::clone(&df.filter); - let on_right = self - .on - .iter() - .map(|(_, right_expr)| Arc::clone(right_expr)) - .collect::>(); - Arc::clone(df.bounds_accumulator.get_or_init(|| { - Arc::new( + let make_bounds_accumulator = + |right_plan: &Arc| { + if enable_dynamic_filter_pushdown { + self.dynamic_filter.as_ref().map(|df| { + let filter = Arc::clone(&df.filter); + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + Arc::clone(df.bounds_accumulator.get_or_init(|| { + Arc::new( SharedBoundsAccumulator::new_from_partition_mode( self.mode, left_plan.as_ref(), @@ -1044,12 +1043,12 @@ impl ExecutionPlan for HashJoinExec { on_right, ), ) - })) - }) - } else { - None - } - }; + })) + }) + } else { + None + } + }; // For Right-side oriented joins, fall back to standard HashJoinStream for correctness if matches!( diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 34a3bd962031c..db2d425639c3e 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -1370,7 +1370,7 @@ impl PartitionedHashJoinStream { } }; // Debug: log ON expressions and output mapping once we have both sides - /* let on_left_desc = self + /* let on_left_desc = self .on_left .iter() .map(|e| format!("{}", e)) @@ -1547,7 +1547,7 @@ impl PartitionedHashJoinStream { let mut probe_indices = probe_indices; if let Some(filter) = &self.filter { let before_len = build_indices.len(); - // let before_build_indices = build_indices.clone(); + // let before_build_indices = build_indices.clone(); //let before_probe_indices = probe_indices.clone(); let (filtered_build_indices, filtered_probe_indices) = diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index f2abe2b942c11..affbe2495bacf 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -610,7 +610,7 @@ impl HashJoinStream { index_alignment_range_end = index_alignment_range_start; } - /* if matches!( + /* if matches!( self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark ) { @@ -632,152 +632,152 @@ impl HashJoinStream { self.right_side_ordered, )?; - /* if matches!( - self.join_type, - JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark - ) { - // println!( - // "[hash-join] Align {:?}: post-adjust unique_right_indices={} (range={}..{})", - // self.join_type, - // right_indices.len(), - // index_alignment_range_start, - // index_alignment_range_end - // ); - } + /* if 
matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ) { + // println!( + // "[hash-join] Align {:?}: post-adjust unique_right_indices={} (range={}..{})", + // self.join_type, + // right_indices.len(), + // index_alignment_range_start, + // index_alignment_range_end + // ); + } - if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { - // println!( - // "[hash-join] Right {:?}: probe_batch_rows={}, unique_matched_right_indices={} (range={}..{})", - // self.join_type, - // state.batch.num_rows(), - // right_indices.len(), - // index_alignment_range_start, - // index_alignment_range_end - // ); - } + if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { + // println!( + // "[hash-join] Right {:?}: probe_batch_rows={}, unique_matched_right_indices={} (range={}..{})", + // self.join_type, + // state.batch.num_rows(), + // right_indices.len(), + // index_alignment_range_start, + // index_alignment_range_end + // ); + } - // Log some matched pairs for debugging - let build_schema = build_side.left_data.batch().schema(); - let probe_schema = state.batch.schema(); - - let sample = left_indices.len().min(5); - if sample > 0 { - for i in 0..sample { - let build_row = left_indices.value(i) as usize; - let probe_row = right_indices.value(i) as usize; - - /*let build_vals = (0..build_schema.fields().len()) - .map(|col| { - let name = build_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - build_side.left_data.batch().column(col).as_ref(), + // Log some matched pairs for debugging + let build_schema = build_side.left_data.batch().schema(); + let probe_schema = state.batch.schema(); + + let sample = left_indices.len().min(5); + if sample > 0 { + for i in 0..sample { + let build_row = left_indices.value(i) as usize; + let probe_row = right_indices.value(i) as usize; + + /*let build_vals = (0..build_schema.fields().len()) + .map(|col| { + let name = build_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + build_side.left_data.batch().column(col).as_ref(), + build_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", "); + */ + /*let probe_vals = (0..probe_schema.fields().len()) + .map(|col| { + let name = probe_schema.field(col).name(); + let value = arrow::util::display::array_value_to_string( + state.batch.column(col).as_ref(), + probe_row, + ) + .unwrap_or_else(|_| "".to_string()); + format!("{}={}", name, value) + }) + .collect::>() + .join(", ");*/ + + // println!( + // "[hash-join][match-debug] partition={} pair {} build {{{}}} probe {{{}}}", + // self.partition, + // i, + // build_vals, + // probe_vals + // ); + } + } + + let build_supply_idx = + build_schema + .fields() + .iter() + .enumerate() + .find_map(|(idx, f)| { + if f.name().to_ascii_lowercase().contains("ps_supplycost") { + Some(idx) + } else { + None + } + }); + + let probe_min_idx = + probe_schema + .fields() + .iter() + .enumerate() + .find_map(|(idx, f)| { + if f.name().to_ascii_lowercase().contains("min(") + || f.name().to_ascii_lowercase().contains("min_") + { + Some(idx) + } else { + None + } + }); + + if let (Some(build_supply_idx), Some(probe_min_idx)) = + (build_supply_idx, probe_min_idx) + { + let build_array = build_side.left_data.batch().column(build_supply_idx); + let probe_array = state.batch.column(probe_min_idx); + + for j in 0..left_indices.len() { + let build_row = left_indices.value(j) as usize; + let probe_row = 
right_indices.value(j) as usize; + + let build_value = arrow::util::display::array_value_to_string( + build_array.as_ref(), build_row, ) .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); -*/ - /*let probe_vals = (0..probe_schema.fields().len()) - .map(|col| { - let name = probe_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - state.batch.column(col).as_ref(), + let probe_value = arrow::util::display::array_value_to_string( + probe_array.as_ref(), probe_row, ) .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", ");*/ - - // println!( - // "[hash-join][match-debug] partition={} pair {} build {{{}}} probe {{{}}}", - // self.partition, - // i, - // build_vals, - // probe_vals - // ); - } - } - let build_supply_idx = - build_schema - .fields() - .iter() - .enumerate() - .find_map(|(idx, f)| { - if f.name().to_ascii_lowercase().contains("ps_supplycost") { - Some(idx) - } else { - None - } - }); - - let probe_min_idx = - probe_schema - .fields() - .iter() - .enumerate() - .find_map(|(idx, f)| { - if f.name().to_ascii_lowercase().contains("min(") - || f.name().to_ascii_lowercase().contains("min_") - { - Some(idx) - } else { - None + if build_value != probe_value { + // println!( + // "[hash-join][mismatch] partition={} build_row={} ps_supplycost={} min_cost={}", + // self.partition, + // build_row, + // build_value, + // probe_value + // ); + break; + } } - }); - - if let (Some(build_supply_idx), Some(probe_min_idx)) = - (build_supply_idx, probe_min_idx) - { - let build_array = build_side.left_data.batch().column(build_supply_idx); - let probe_array = state.batch.column(probe_min_idx); - - for j in 0..left_indices.len() { - let build_row = left_indices.value(j) as usize; - let probe_row = right_indices.value(j) as usize; - - let build_value = arrow::util::display::array_value_to_string( - build_array.as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - let probe_value = arrow::util::display::array_value_to_string( - probe_array.as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - - if build_value != probe_value { + } else { // println!( - // "[hash-join][mismatch] partition={} build_row={} ps_supplycost={} min_cost={}", + // "[hash-join][mismatch-debug] partition={} build_fields={:?} probe_fields={:?}", // self.partition, - // build_row, - // build_value, - // probe_value + // build_schema + // .fields() + // .iter() + // .map(|f| f.name().clone()) + // .collect::>(), + // probe_schema + // .fields() + // .iter() + // .map(|f| f.name().clone()) + // .collect::>() // ); - break; - } - } - } else { - // println!( - // "[hash-join][mismatch-debug] partition={} build_fields={:?} probe_fields={:?}", - // self.partition, - // build_schema - // .fields() - // .iter() - // .map(|f| f.name().clone()) - // .collect::>(), - // probe_schema - // .fields() - // .iter() - // .map(|f| f.name().clone()) - // .collect::>() - // ); - }*/ + }*/ let result = if matches!( self.join_type, From c0832821934db4d28942070cb812d1fa907da77c Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 5 Nov 2025 13:54:09 +0200 Subject: [PATCH 14/36] Spill less in case we fit in memory --- .../physical-plan/src/joins/hash_join/exec.rs | 156 +++- .../src/joins/hash_join/partitioned.rs | 761 +++++++++++------- 2 files changed, 635 insertions(+), 282 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs 
b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 2e313ce251802..c6eaa3489cbaa 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -85,12 +85,25 @@ use parking_lot::Mutex; const HASH_JOIN_SEED: RandomState = RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); +/// Maximum number of partitions allowed when recursively repartitioning during hybrid hash join. +const HYBRID_HASH_MAX_PARTITIONS: usize = 1 << 16; +/// Upper bound multiplier applied to the initial partition fanout when searching for additional partitions. +const HYBRID_HASH_PARTITION_GROWTH_FACTOR: usize = 16; +/// Approximate number of probe batches worth of rows we target per partition when statistics are available. +const HYBRID_HASH_ROWS_PER_PARTITION_BATCH_MULTIPLIER: usize = 8; +/// Minimum number of bytes we aim to keep in memory per partition when deriving the initial fanout. +const HYBRID_HASH_MIN_BYTES_PER_PARTITION: usize = 8 * 1024 * 1024; +/// Minimum number of rows per partition when statistics are available to avoid extreme fan-out. +const HYBRID_HASH_MIN_ROWS_PER_PARTITION: usize = 1_024; + /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` pub(super) hash_map: Box, /// The input rows for the build side batch: RecordBatch, + /// Original build-side batches before concatenation + original_batches: Arc>, /// The build side on expressions values values: Vec, /// Shared bitmap builder for visited left indices @@ -112,6 +125,7 @@ impl JoinLeftData { pub(super) fn new( hash_map: Box, batch: RecordBatch, + original_batches: Arc>, values: Vec, visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, @@ -121,6 +135,7 @@ impl JoinLeftData { Self { hash_map, batch, + original_batches, values, visited_indices_bitmap, probe_threads_counter, @@ -139,6 +154,10 @@ impl JoinLeftData { &self.batch } + pub(super) fn original_batches(&self) -> &[RecordBatch] { + &self.original_batches + } + /// returns a reference to the build side expressions values pub(super) fn values(&self) -> &[ArrayRef] { &self.values @@ -1120,16 +1139,10 @@ impl ExecutionPlan for HashJoinExec { .iter() .map(|(_, right_expr)| Arc::clone(right_expr)) .collect::>(); - let batch_size = context.session_config().batch_size(); - let mut num_partitions = context.session_config().target_partitions(); - if num_partitions == 0 { - num_partitions = 1; - } - let np2 = num_partitions.next_power_of_two(); - let num_partitions = np2.max(1); + let session_config = context.session_config(); + let batch_size = session_config.batch_size(); let memory_threshold = { - let bytes = context - .session_config() + let bytes = session_config .options() .execution .sort_spill_reservation_bytes; @@ -1139,6 +1152,126 @@ impl ExecutionPlan for HashJoinExec { bytes } }; + let existing_partitions = std::cmp::max( + 1, + right_plan.output_partitioning().partition_count(), + ); + let target_partitions = std::cmp::max( + existing_partitions, + session_config.target_partitions(), + ); + let mut num_partitions = target_partitions; + let mut bytes_per_partition = memory_threshold / 2; + bytes_per_partition = bytes_per_partition + .max(HYBRID_HASH_MIN_BYTES_PER_PARTITION) + .min(memory_threshold.max(HYBRID_HASH_MIN_BYTES_PER_PARTITION)); + + let initial_cap = target_partitions + .saturating_mul(HYBRID_HASH_PARTITION_GROWTH_FACTOR) + .min(HYBRID_HASH_MAX_PARTITIONS); + + let mut 
build_size_bytes: Option = None; + if let Ok(left_stats) = self.left.partition_statistics(None) { + if let Some(total_bytes) = left_stats.total_byte_size.get_value() + { + let total_bytes = *total_bytes; + build_size_bytes = Some(total_bytes); + if total_bytes <= memory_threshold { + num_partitions = 1; + } else if bytes_per_partition > 0 { + let required = total_bytes + .saturating_add(bytes_per_partition - 1) + / bytes_per_partition; + if required > 0 { + num_partitions = + std::cmp::max(num_partitions, required); + } + } + } + + if let Some(num_rows) = left_stats.num_rows.get_value() { + let num_rows = *num_rows; + let min_rows = session_config + .batch_size() + .saturating_mul( + HYBRID_HASH_ROWS_PER_PARTITION_BATCH_MULTIPLIER, + ) + .max(HYBRID_HASH_MIN_ROWS_PER_PARTITION); + if num_partitions > 1 && min_rows > 0 && num_rows > min_rows { + let required = + num_rows.saturating_add(min_rows - 1) / min_rows; + if required > 0 { + num_partitions = + std::cmp::max(num_partitions, required); + } + } + } + } + + if build_size_bytes + .map(|b| b <= memory_threshold) + .unwrap_or(false) + { + return Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), + join_metrics, + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), + batch_size, + vec![], + self.right.output_ordering().is_some(), + shared_bounds_accumulator, + ))); + } + + num_partitions = num_partitions + .min(initial_cap) + .clamp(1, HYBRID_HASH_MAX_PARTITIONS); + if num_partitions > 1 && !num_partitions.is_power_of_two() { + num_partitions = num_partitions + .checked_next_power_of_two() + .unwrap_or(num_partitions) + .max(1); + } + + if num_partitions == 1 { + return Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), + join_metrics, + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), + batch_size, + vec![], + self.right.output_ordering().is_some(), + shared_bounds_accumulator, + ))); + } + + let mut max_partition_count = if num_partitions == 1 { + 1 + } else { + num_partitions + .saturating_mul(HYBRID_HASH_PARTITION_GROWTH_FACTOR) + .min(HYBRID_HASH_MAX_PARTITIONS) + }; + max_partition_count = + max_partition_count.max(initial_cap).max(num_partitions); let partitioned_reservation = MemoryConsumer::new("PartitionedHashJoin") .register(context.memory_pool()); @@ -1161,6 +1294,7 @@ impl ExecutionPlan for HashJoinExec { self.null_equality, batch_size, num_partitions, + max_partition_count, memory_threshold, partitioned_reservation, context.runtime_env(), @@ -1607,6 +1741,7 @@ async fn collect_left_input( mut reservation, bounds_accumulators, } = state; + let batches_arc = Arc::new(batches); // Estimation of memory size, required for hashtable, prior to allocation. 
// Final result can be verified using `RawTable.allocation_info()` @@ -1633,7 +1768,7 @@ async fn collect_left_input( let mut offset = 0; // Updating hashmap starting from the last batch - let batches_iter = batches.iter().rev(); + let batches_iter = batches_arc.iter().rev(); for batch in batches_iter.clone() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); @@ -1688,6 +1823,7 @@ async fn collect_left_input( let data = JoinLeftData::new( hashmap, single_batch, + Arc::clone(&batches_arc), left_values.clone(), Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index db2d425639c3e..6eda45d38b24e 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -43,12 +43,12 @@ //! - Generates join results and handles unmatched rows for outer joins //! - Tracks matched rows for proper outer join semantics -use std::mem; +use std::mem::{self, size_of}; use std::sync::Arc; use std::task::{Context, Poll}; use crate::joins::hash_join::exec::JoinLeftData; -use crate::joins::join_hash_map::JoinHashMapType; +use crate::joins::join_hash_map::{JoinHashMapType, JoinHashMapU32, JoinHashMapU64}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, equal_rows_arr, get_final_indices_from_bit_map, need_produce_result_in_final, @@ -64,6 +64,7 @@ use arrow::array::{Array, ArrayRef, BooleanBufferBuilder, UInt32Array, UInt64Arr use arrow::compute::{concat_batches, take}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ hash_utils::create_hashes, internal_datafusion_err, internal_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, @@ -140,6 +141,35 @@ pub(super) struct ProbePartition { pub hashes: Vec>, } +enum PartitionBuildStatus { + Ready(StatefulStreamResult>), + NeedMorePartitions { next_count: usize }, +} + +struct PartitionAccumulator { + buffered_batches: Vec, + buffered_bytes: usize, + total_rows: usize, + spill_writer: Option, +} + +impl PartitionAccumulator { + fn new() -> Self { + Self { + buffered_batches: Vec::new(), + buffered_bytes: 0, + total_rows: 0, + spill_writer: None, + } + } +} + +impl Default for PartitionAccumulator { + fn default() -> Self { + Self::new() + } +} + // Use RefCountedTempFile from datafusion_execution::disk_manager /// Partitioned Hash Join stream that can handle large datasets by partitioning @@ -178,6 +208,8 @@ pub(super) struct PartitionedHashJoinStream { pub batch_size: usize, /// Number of partitions to use pub num_partitions: usize, + /// Maximum partition fanout allowed when recursively repartitioning + pub max_partition_count: usize, /// Memory threshold for spilling (in bytes) pub memory_threshold: usize, @@ -199,6 +231,8 @@ pub(super) struct PartitionedHashJoinStream { pub build_spill_manager: SpillManager, /// Memory reservation for the entire operation pub memory_reservation: MemoryReservation, + /// Tracks how many repartition passes have been attempted + pub partition_pass: usize, /// Runtime environment pub runtime_env: Arc, /// Scratch space for computing hashes @@ -214,6 +248,8 @@ pub(super) struct PartitionedHashJoinStream { Option>, /// Future used to synchronize dynamic filter updates across partitions pub bounds_waiter: Option>, 
+ /// Cached build-side schema + pub build_schema: SchemaRef, /// Cached probe-side schema pub probe_schema: SchemaRef, /// Current probe batch (filtered to the active partition), if any @@ -282,6 +318,126 @@ impl PartitionedHashJoinStream { } } + fn resize_partition_vectors(&mut self) { + let n = self.num_partitions; + self.probe_spill_in_progress = (0..n).map(|_| None).collect(); + self.probe_spill_files = (0..n).map(|_| None).collect(); + self.probe_batch_positions = vec![0; n]; + self.probe_buffered_rows_per_part = vec![0; n]; + self.probe_spilled_rows_per_part = vec![0; n]; + self.probe_consumed_rows_per_part = vec![0; n]; + self.matched_rows_per_part = vec![0; n]; + self.emitted_rows_per_part = vec![0; n]; + self.candidate_pairs_per_part = vec![0; n]; + self.verify_once_per_part = vec![false; n]; + self.filter_debug_once_per_part = vec![false; n]; + } + + fn ensure_build_spill_writer<'a>( + &self, + accum: &'a mut PartitionAccumulator, + ) -> Result<&'a mut InProgressSpillFile> { + if accum.spill_writer.is_none() { + accum.spill_writer = Some( + self.build_spill_manager + .create_in_progress_file("hash_join_build_partition")?, + ); + } + Ok(accum.spill_writer.as_mut().unwrap()) + } + + fn spill_partition( + &mut self, + partition_id: usize, + accum: &mut PartitionAccumulator, + ) -> Result<()> { + let buffered_batches = mem::take(&mut accum.buffered_batches); + if buffered_batches.is_empty() { + return Ok(()); + } + + let writer = self.ensure_build_spill_writer(accum)?; + for batch in buffered_batches { + writer.append_batch(&batch)?; + } + if accum.buffered_bytes > 0 { + let _ = self.memory_reservation.try_shrink(accum.buffered_bytes); + accum.buffered_bytes = 0; + } + Ok(()) + } + + fn append_spilled_batch( + &self, + accum: &mut PartitionAccumulator, + batch: RecordBatch, + ) -> Result<()> { + let writer = self.ensure_build_spill_writer(accum)?; + writer.append_batch(&batch)?; + Ok(()) + } + + fn reset_partition_state(&mut self) { + for writer in self.probe_spill_in_progress.iter_mut() { + if let Some(mut writer) = writer.take() { + let _ = writer.finish(); + } + } + + self.build_partitions.clear(); + self.matched_build_rows_per_partition.clear(); + self.current_partition = None; + self.current_probe_batch = None; + self.current_probe_values.clear(); + self.current_probe_hashes.clear(); + self.current_offset = (0, None); + self.joined_probe_idx = None; + self.placeholder_emitted = false; + self.right_alignment_start = 0; + self.unmatched_partition = 0; + self.unmatched_left_indices_cache = None; + self.unmatched_right_indices_cache = None; + self.unmatched_offset = 0; + self.probe_partitions.clear(); + self.probe_batch_positions.clear(); + self.probes_buffered = false; + self.pending_reload_stream = None; + self.pending_reload_batches.clear(); + self.pending_reload_partition = None; + self.pending_probe_stream = None; + self.pending_probe_partition = None; + self.probe_spill_files.clear(); + self.bounds_waiter = None; + + self.resize_partition_vectors(); + + let reserved = self.memory_reservation.size(); + if reserved > 0 { + let _ = self.memory_reservation.try_shrink(reserved); + } + + self.state = PartitionedHashJoinState::PartitionBuildSide; + } + + fn next_partition_count(&self) -> Option { + if self.num_partitions >= self.max_partition_count { + return None; + } + + let mut next = self.num_partitions.saturating_mul(2); + if next <= self.num_partitions { + next = self.num_partitions.saturating_add(1); + } + if next > self.max_partition_count { + next = 
self.max_partition_count; + } + if next > self.num_partitions { + Some(next) + } else { + None + } + } + /// Report build-side bounds to the shared accumulator when dynamic filtering is enabled fn poll_bounds_update( &mut self, @@ -355,117 +511,98 @@ impl PartitionedHashJoinStream { self.pending_reload_stream = Some(stream); self.pending_reload_batches.clear(); self.pending_reload_partition = Some(part_id); - // println!( - // "[spill-join][reload] start partition {}", - // part_id - // ); } } // Drive stream forward if self.pending_reload_partition == Some(part_id) { if let Some(stream) = self.pending_reload_stream.as_mut() { - match stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - // println!( - // "[spill-join][reload] partition {} batch rows={}", - // part_id, - // batch.num_rows() - // ); - self.pending_reload_batches.push(batch); - return Poll::Pending; - } - Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(None) => { - // Concatenate - let first_schema = self - .pending_reload_batches - .get(0) - .ok_or_else(|| { - internal_datafusion_err!("empty spilled partition") - })? - .schema(); - let concatenated = concat_batches( - &first_schema, - self.pending_reload_batches.as_slice(), - ) - .map_err(DataFusionError::from)?; - - // println!( - // "Reloaded spilled build partition {} for probing (rows={})", - // part_id, - // concatenated.num_rows() - // ); - - // Grow global reservation conservatively by concatenated batch size - let concat_size = concatenated.get_array_memory_size(); - let _ = self.memory_reservation.try_grow(concat_size); - - // Recompute values and hashmap - let mut values: Vec = - Vec::with_capacity(self.on_left.len()); - for c in &self.on_left { - values.push( - c.evaluate(&concatenated)? - .into_array(concatenated.num_rows())?, - ); + loop { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + self.pending_reload_batches.push(batch); + // Continue draining ready batches without yielding. + continue; } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + // Concatenate + let first_schema = self + .pending_reload_batches + .get(0) + .ok_or_else(|| { + internal_datafusion_err!("empty spilled partition") + })? + .schema(); + let concatenated = concat_batches( + &first_schema, + self.pending_reload_batches.as_slice(), + ) + .map_err(DataFusionError::from)?; - let mut hash_map: Box = Box::new( - crate::joins::join_hash_map::JoinHashMapU32::with_capacity( - concatenated.num_rows(), - ), - ); - self.hashes_buffer.clear(); - self.hashes_buffer.resize(concatenated.num_rows(), 0); - // Build HT for reloaded partition from precomputed key arrays (no re-eval) - create_hashes( - &values, - &self.random_state, - &mut self.hashes_buffer, - )?; - hash_map.extend_zero(concatenated.num_rows()); - let iter = - self.hashes_buffer.iter().enumerate().map(|(i, h)| (i, h)); - hash_map.update_from_iter(Box::new(iter), 0); - - let new_reservation = MemoryConsumer::new("partition_reload") - .with_can_spill(true) - .register(&self.runtime_env.memory_pool); - - self.build_partitions[part_id] = BuildPartition::InMemory { - hash_map, - batch: concatenated, - values, - reservation: new_reservation, - }; - - /*if let Some(BuildPartition::InMemory { - hash_map, batch, .. - }) = self.build_partitions.get(part_id) - { // println!( - // "Reloaded partition {} hashmap empty? 
{} rows={}", + // "Reloaded spilled build partition {} for probing (rows={})", // part_id, - // hash_map.is_empty(), - // batch.num_rows() + // concatenated.num_rows() // ); - }*/ - self.pending_reload_stream = None; - self.pending_reload_batches.clear(); - self.pending_reload_partition = None; - // Shrink global reservation now that partition is resident with per-partition reservation - let _ = self.memory_reservation.try_shrink(concat_size); - return Poll::Ready(Ok(())); - } - Poll::Pending => { - // println!( - // "[spill-join][reload] partition {} pending batches={}", - // part_id, - // self.pending_reload_batches.len() - // ); - return Poll::Pending; + // Grow global reservation conservatively by concatenated batch size + let concat_size = concatenated.get_array_memory_size(); + let _ = self.memory_reservation.try_grow(concat_size); + + // Recompute values and hashmap + let mut values: Vec = + Vec::with_capacity(self.on_left.len()); + for c in &self.on_left { + values.push( + c.evaluate(&concatenated)? + .into_array(concatenated.num_rows())?, + ); + } + + let mut hash_map: Box = Box::new( + JoinHashMapU32::with_capacity(concatenated.num_rows()), + ); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(concatenated.num_rows(), 0); + // Build HT for reloaded partition from precomputed key arrays (no re-eval) + create_hashes( + &values, + &self.random_state, + &mut self.hashes_buffer, + )?; + hash_map.extend_zero(concatenated.num_rows()); + let iter = self + .hashes_buffer + .iter() + .enumerate() + .map(|(i, h)| (i, h)); + hash_map.update_from_iter(Box::new(iter), 0); + + let new_reservation = MemoryConsumer::new("partition_reload") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + let concat_rows = concatenated.num_rows(); + + self.build_partitions[part_id] = BuildPartition::InMemory { + hash_map, + batch: concatenated, + values, + reservation: new_reservation, + }; + + self.pending_reload_stream = None; + self.pending_reload_batches.clear(); + self.pending_reload_partition = None; + // Shrink global reservation now that partition is resident with per-partition reservation + let _ = self.memory_reservation.try_shrink(concat_size); + return Poll::Ready(Ok(())); + } + Poll::Pending => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } } } } @@ -491,6 +628,7 @@ impl PartitionedHashJoinStream { null_equality: NullEquality, batch_size: usize, num_partitions: usize, + max_partition_count: usize, memory_threshold: usize, memory_reservation: MemoryReservation, runtime_env: Arc, @@ -528,6 +666,7 @@ impl PartitionedHashJoinStream { null_equality, batch_size, num_partitions, + max_partition_count, memory_threshold, state: PartitionedHashJoinState::PartitionBuildSide, build_partitions: Vec::new(), @@ -536,6 +675,7 @@ impl PartitionedHashJoinStream { probe_spill_manager, build_spill_manager, memory_reservation, + partition_pass: 0, runtime_env, hashes_buffer: Vec::new(), right_side_ordered, @@ -543,6 +683,7 @@ impl PartitionedHashJoinStream { right_alignment_start: 0, bounds_accumulator, bounds_waiter: None, + build_schema, probe_schema, current_probe_batch: None, current_probe_values: vec![], @@ -577,6 +718,14 @@ impl PartitionedHashJoinStream { /// Buffer the entire probe side stream into per-partition batches. /// Returns Pending until the right stream is fully consumed. 
    fn buffer_probe_side(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
+        if self.probe_spill_in_progress.len() != self.num_partitions
+            || self.probe_spill_files.len() != self.num_partitions
+        {
+            self.resize_partition_vectors();
+        }
+        if self.probe_partitions.len() != self.num_partitions {
+            self.probe_partitions.clear();
+        }
         if self.probe_partitions.is_empty() {
             self.probe_partitions = (0..self.num_partitions)
                 .map(|_| ProbePartition {
@@ -753,214 +902,280 @@ impl PartitionedHashJoinStream {
         &mut self,
         build_data: Arc<JoinLeftData>,
     ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
-        // println!(
-        //     "Partitioning build side data into {} partitions",
-        //     self.num_partitions
-        // );
-        // Metrics: record build input
-        self.join_metrics.build_input_batches.add(1);
-        self.join_metrics
-            .build_input_rows
-            .add(build_data.batch().num_rows());
-        // Initialize partitions
-        self.build_partitions = Vec::with_capacity(self.num_partitions);
-        // Initialize per-partition matched rows bitmaps
-        self.matched_build_rows_per_partition = Vec::with_capacity(self.num_partitions);
-
-        // Extract build-side data
-        let batch = build_data.batch();
-        let values = build_data.values();
+        if self.partition_pass == 0 {
+            self.join_metrics.build_input_batches.add(1);
+            let total_rows: usize = build_data
+                .original_batches()
+                .iter()
+                .map(|b| b.num_rows())
+                .sum();
+            self.join_metrics.build_input_rows.add(total_rows);
+        }
 
-        // Compute hash values for all rows in the build-side batch
-        let mut hashes = vec![0u64; batch.num_rows()];
-        create_hashes(values, &self.random_state, &mut hashes)?;
+        let build_total_size: usize = build_data
+            .original_batches()
+            .iter()
+            .map(|batch| batch.get_array_memory_size())
+            .sum();
+        if build_total_size <= self.memory_threshold {
+            self.num_partitions = 1;
+            self.max_partition_count = 1;
+        }
 
-        // Partition the data based on hash values
-        let mut partition_batches: Vec<Vec<usize>> =
-            vec![Vec::new(); self.num_partitions];
+        let mut allow_repartition = true;
+        loop {
+            self.reset_partition_state();
+
+            match self.try_partition_build_side(&build_data, allow_repartition)? 
 {
+                PartitionBuildStatus::Ready(result) => return Ok(result),
+                PartitionBuildStatus::NeedMorePartitions { next_count } => {
+                    if next_count <= self.num_partitions
+                        || next_count == 0
+                        || next_count > self.max_partition_count
+                    {
+                        allow_repartition = false;
+                        continue;
+                    }
 
-        for (row_idx, &hash) in hashes.iter().enumerate() {
-            let partition_id = self.partition_for_hash(hash);
-            if row_idx < 10 {
-                // println!(
-                //     "[spill-join] build row {} hash={} -> partition {}",
-                //     row_idx, hash, partition_id
-                // );
+                    self.num_partitions = next_count;
+                    self.partition_pass += 1;
+                    allow_repartition = true;
+                }
             }
-            partition_batches[partition_id].push(row_idx);
         }
+    }
 
-        // Create partitions; spill when memory_threshold is exceeded
-        for partition_id in 0..self.num_partitions {
-            let row_indices = &partition_batches[partition_id];
-            if row_indices.is_empty() {
-                // Empty partition - create empty hash map
-                let empty_hash_map: Box<dyn JoinHashMapType> = Box::new(
-                    crate::joins::join_hash_map::JoinHashMapU32::with_capacity(0),
-                );
-                let empty_batch = batch.slice(0, 0);
-                let empty_values: Vec<ArrayRef> =
-                    values.iter().map(|arr| arr.slice(0, 0)).collect();
+    fn try_partition_build_side(
+        &mut self,
+        build_data: &Arc<JoinLeftData>,
+        allow_repartition: bool,
+    ) -> Result<PartitionBuildStatus> {
+        self.build_partitions = Vec::with_capacity(self.num_partitions);
+        self.matched_build_rows_per_partition = Vec::with_capacity(self.num_partitions);
 
-                // Initialize empty matched rows bitmap for this partition
-                let matched_bitmap = BooleanBufferBuilder::new(0);
-                self.matched_build_rows_per_partition.push(matched_bitmap);
+        let mut partition_accumulators = (0..self.num_partitions)
+            .map(|_| PartitionAccumulator::new())
+            .collect::<Vec<_>>();
+        let mut repartition_request: Option<usize> = None;
 
-                self.build_partitions.push(BuildPartition::InMemory {
-                    hash_map: empty_hash_map,
-                    batch: empty_batch,
-                    values: empty_values,
-                    reservation: MemoryConsumer::new("empty_partition")
-                        .with_can_spill(true)
-                        .register(&self.runtime_env.memory_pool),
-                });
-                continue;
+        for (batch_index, batch) in build_data.original_batches().iter().enumerate() {
+            let mut keys_values: Vec<ArrayRef> = Vec::with_capacity(self.on_left.len());
+            for expr in &self.on_left {
+                keys_values.push(expr.evaluate(batch)?.into_array(batch.num_rows())?);
+            }
+            let mut hashes = vec![0u64; batch.num_rows()];
+            create_hashes(&keys_values, &self.random_state, &mut hashes)?;
+
+            let mut indices_per_part: Vec<Vec<u32>> =
+                vec![Vec::new(); self.num_partitions];
+            for (row_idx, hash) in hashes.iter().enumerate() {
+                let partition_id = self.partition_for_hash(*hash);
+                indices_per_part[partition_id].push(row_idx as u32);
             }
 
-            // Create batch slice for this partition
-            let partition_batch = self.take_rows(batch, row_indices)?;
-            let partition_values: Vec<ArrayRef> = values
-                .iter()
-                .map(|arr| self.take_rows_from_array(arr, row_indices))
-                .collect::<Result<Vec<_>>>()?;
+            for (partition_id, indices) in indices_per_part.into_iter().enumerate() {
+                if indices.is_empty() {
+                    continue;
+                }
 
-            // Estimate memory for this partition
-            let estimated_size = partition_batch.get_array_memory_size()
-                + partition_values
-                    .iter()
-                    .map(|a| a.get_array_memory_size())
-                    .sum::<usize>();
+                let idx_array = UInt32Array::from(indices);
+                let mut filtered_columns: Vec<ArrayRef> =
+                    Vec::with_capacity(batch.num_columns());
+                for col in batch.columns() {
+                    filtered_columns.push(
+                        take(col, &idx_array, None).map_err(DataFusionError::from)?,
+                    );
+                }
+                let filtered_batch =
+                    RecordBatch::try_new(batch.schema(), filtered_columns)
+                        .map_err(DataFusionError::from)?;
+                let batch_size = 
filtered_batch.get_array_memory_size(); + let accum = &mut partition_accumulators[partition_id]; + accum.total_rows += filtered_batch.num_rows(); + + if accum.spill_writer.is_some() { + self.append_spilled_batch(accum, filtered_batch)?; + continue; + } - // Decide spilling using global reservation (per DF best practice) - let mut will_spill = false; - match self.memory_reservation.try_grow(estimated_size) { - Ok(_) => { - if self.memory_reservation.size() > self.memory_threshold { - // Exceeds threshold: roll back and spill - let _ = self.memory_reservation.try_shrink(estimated_size); - will_spill = true; + match self.memory_reservation.try_grow(batch_size) { + Ok(_) => { + accum.buffered_bytes += batch_size; + accum.buffered_batches.push(filtered_batch); + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + if self.memory_reservation.size() > self.memory_threshold { + if !self.runtime_env.disk_manager.tmp_files_enabled() { + if allow_repartition { + if let Some(next_count) = self.next_partition_count() + { + repartition_request = Some(next_count); + break; + } + } + return Err(internal_datafusion_err!( + "Insufficient memory for build partitioning and spilling is disabled" + )); + } + self.spill_partition(partition_id, accum)?; + } + } + Err(_) => { + if !self.runtime_env.disk_manager.tmp_files_enabled() { + if allow_repartition { + if let Some(next_count) = self.next_partition_count() { + repartition_request = Some(next_count); + break; + } + } + return Err(internal_datafusion_err!( + "Unable to allocate memory for build partition" + )); + } + self.spill_partition(partition_id, accum)?; + self.append_spilled_batch(accum, filtered_batch)?; } } - Err(_) => { - will_spill = true; + + if repartition_request.is_some() { + break; } } - // Disable spilling in single-partition mode to avoid reload deadlocks and ensure progress - if self.num_partitions == 1 { - will_spill = false; + if repartition_request.is_some() { + break; + } + } + + if let Some(next_count) = repartition_request { + return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + } + + self.build_partitions.reserve(self.num_partitions); + self.matched_build_rows_per_partition + .reserve(self.num_partitions); + + let mut partitions_in_memory = 0usize; + let mut partitions_spilled = 0usize; + let mut partitions_empty = 0usize; + for part_id in 0..self.num_partitions { + let mut accum = mem::take(&mut partition_accumulators[part_id]); + if accum.spill_writer.is_some() { + if !accum.buffered_batches.is_empty() { + self.spill_partition(part_id, &mut accum)?; + } + if let Some(mut writer) = accum.spill_writer.take() { + let spill_file = writer + .finish()? 
+ .ok_or_else(|| internal_datafusion_err!("expected spill file"))?; + let mut matched_bitmap = BooleanBufferBuilder::new(accum.total_rows); + matched_bitmap.append_n(accum.total_rows, false); + self.matched_build_rows_per_partition.push(matched_bitmap); + let reservation = MemoryConsumer::new("partition_spilled") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + partitions_spilled += 1; + self.build_partitions.push(BuildPartition::Spilled { + spill_file: Some(spill_file), + reservation, + }); + } + continue; } - if will_spill && self.runtime_env.disk_manager.tmp_files_enabled() { - // println!( - // "Spilling build partition {} (rows={}) due to memory threshold (threshold={} bytes, current={})", - // partition_id, - // row_indices.len(), - // self.memory_threshold, - // self.memory_reservation.size() - // ); - // Spill this partition to disk and do not keep it in memory - let spill_file = self - .build_spill_manager - .spill_record_batch_and_finish( - &[partition_batch.clone()], - "hash_join_build_partition", - )? - .ok_or_else(|| internal_datafusion_err!("expected spill file"))?; - - // Initialize matched rows bitmap for this partition - let mut matched_bitmap = BooleanBufferBuilder::new(row_indices.len()); - matched_bitmap.append_n(row_indices.len(), false); - self.matched_build_rows_per_partition.push(matched_bitmap); - - // Per-partition reservation kept as zero-sized placeholder - let reservation = MemoryConsumer::new("partition_spilled") - .with_can_spill(true) - .register(&self.runtime_env.memory_pool); - - self.build_partitions.push(BuildPartition::Spilled { - spill_file: Some(spill_file), - reservation, - }); + if accum.buffered_batches.is_empty() { + self.matched_build_rows_per_partition + .push(BooleanBufferBuilder::new(0)); + partitions_empty += 1; + self.build_partitions.push(BuildPartition::Empty); continue; } - // Create hash map for this partition - let partition_hash_map: Box = - Box::new(crate::joins::join_hash_map::JoinHashMapU32::with_capacity( - row_indices.len(), - )); + let mut buffered_batches = accum.buffered_batches; + let partition_batch = if buffered_batches.len() == 1 { + buffered_batches.pop().unwrap() + } else { + let batch_refs: Vec<_> = buffered_batches.iter().collect(); + concat_batches(&self.build_schema, batch_refs)? 
+ }; + let num_rows = partition_batch.num_rows(); + let partition_values = self + .on_left + .iter() + .map(|expr| expr.evaluate(&partition_batch)?.into_array(num_rows)) + .collect::>>()?; + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + let mut hash_map: Box = if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; - // Build the hash map for this partition from pre-sliced key arrays - let mut partition_hash_map = partition_hash_map; self.hashes_buffer.clear(); - self.hashes_buffer.resize(partition_batch.num_rows(), 0); + self.hashes_buffer.resize(num_rows, 0); create_hashes( &partition_values, &self.random_state, &mut self.hashes_buffer, )?; - partition_hash_map.extend_zero(partition_batch.num_rows()); - let iter = self.hashes_buffer.iter().enumerate().map(|(i, h)| (i, h)); - partition_hash_map.update_from_iter(Box::new(iter), 0); - - // println!( - // "Built in-memory hash map for partition {} (rows={})", - // partition_id, - // row_indices.len() - // ); - // Metrics: approximate build memory used (batch + values) - let approx = partition_batch.get_array_memory_size() - + partition_values - .iter() - .map(|a| a.get_array_memory_size()) - .sum::(); - self.join_metrics - .build_mem_used - .set_max(self.memory_reservation.size().saturating_add(approx)); + hash_map.extend_zero(num_rows); + let iter = self + .hashes_buffer + .iter() + .enumerate() + .map(|(idx, hash)| (idx, hash)); + hash_map.update_from_iter(Box::new(iter), 0); - // Initialize matched rows bitmap for this partition - let mut matched_bitmap = BooleanBufferBuilder::new(row_indices.len()); - matched_bitmap.append_n(row_indices.len(), false); + let mut matched_bitmap = BooleanBufferBuilder::new(num_rows); + matched_bitmap.append_n(num_rows, false); self.matched_build_rows_per_partition.push(matched_bitmap); - // Per-partition reservation: zero-sized placeholder; global reservation tracks memory let reservation = MemoryConsumer::new("partition_memory") .with_can_spill(true) .register(&self.runtime_env.memory_pool); - //let is_empty_after = partition_hash_map.is_empty(); - // println!( - // "Partition {} hashmap empty after build? {}", - // partition_id, is_empty_after - // ); + let approx_partition_size = partition_batch.get_array_memory_size() + + partition_values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::(); + self.join_metrics.build_mem_used.set_max( + self.memory_reservation + .size() + .saturating_add(approx_partition_size), + ); + partitions_in_memory += 1; self.build_partitions.push(BuildPartition::InMemory { - hash_map: partition_hash_map, + hash_map, batch: partition_batch, values: partition_values, reservation, }); } - // Start processing from the first radix partition and iterate sequentially - // This ensures a single stream can process all partitions when the - // operator reports a single output partition. - let start_partition = 0; - // println!( - // "Partitioning complete. Created {} partitions. 
Starting to process partition {}", - // self.build_partitions.len(), start_partition - // ); - self.state = PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { - partition_id: start_partition, + partition_id: 0, total_partitions: self.num_partitions, - is_last_partition: start_partition + 1 == self.num_partitions, + is_last_partition: self.num_partitions == 1, }); - Ok(StatefulStreamResult::Continue) + Ok(PartitionBuildStatus::Ready(StatefulStreamResult::Continue)) } - /// Take specific rows from a RecordBatch fn take_rows(&self, batch: &RecordBatch, indices: &[usize]) -> Result { use arrow::array::UInt32Array; @@ -1039,9 +1254,8 @@ impl PartitionedHashJoinStream { .filter_map(|expr| expr.evaluate(&empty_batch).ok()) .filter_map(|v| v.into_array(empty_batch.num_rows()).ok()) .collect(); - let empty_hash_map: Box = Box::new( - crate::joins::join_hash_map::JoinHashMapU32::with_capacity(0), - ); + let empty_hash_map: Box = + Box::new(JoinHashMapU32::with_capacity(0)); self.build_partitions[partition_id] = BuildPartition::InMemory { hash_map: empty_hash_map, @@ -1076,12 +1290,10 @@ impl PartitionedHashJoinStream { self.state = PartitionedHashJoinState::HandleUnmatchedRows; return Poll::Ready(Ok(StatefulStreamResult::Continue)); } - // println!( - // "Processing partition {} (total_partitions={}), build_partitions.len()={}", - // partition_state.partition_id, - // partition_state.total_partitions, - // self.build_partitions.len() - // ); + + if self.current_partition != Some(partition_state.partition_id) { + self.current_partition = Some(partition_state.partition_id); + } // Do not buffer probe side here; selection happens below depending on num_partitions @@ -1235,9 +1447,11 @@ impl PartitionedHashJoinStream { partition_state.partition_id, ); if partition_state.is_last_partition { + self.current_partition = None; self.state = PartitionedHashJoinState::HandleUnmatchedRows; } else { + self.current_partition = None; self.state = PartitionedHashJoinState::ProcessPartition( ProcessPartitionState { @@ -1330,8 +1544,10 @@ impl PartitionedHashJoinStream { // ); self.release_partition_resources(partition_state.partition_id); if partition_state.is_last_partition { + self.current_partition = None; self.state = PartitionedHashJoinState::HandleUnmatchedRows; } else { + self.current_partition = None; self.state = PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { partition_id: partition_state.partition_id + 1, @@ -2137,9 +2353,10 @@ impl PartitionedHashJoinStream { .into_array(concatenated.num_rows())?, ); } - let hash_map: Box = Box::new( - crate::joins::join_hash_map::JoinHashMapU32::with_capacity(concatenated.num_rows()), - ); + let hash_map: Box = + Box::new(JoinHashMapU32::with_capacity( + concatenated.num_rows(), + )); self.build_partitions[self.unmatched_partition] = BuildPartition::InMemory { hash_map, From 072f56a07547a575de13f8419bff972bea32e1a7 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 5 Nov 2025 17:48:17 +0200 Subject: [PATCH 15/36] tracking spilled bytes --- .../src/joins/hash_join/partitioned.rs | 90 +++++++++++++------ datafusion/physical-plan/src/joins/utils.rs | 30 +++++++ 2 files changed, 92 insertions(+), 28 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 6eda45d38b24e..b8974bb4db79b 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -120,6 +120,10 
@@ pub(super) enum BuildPartition { spill_file: Option, /// Memory reservation (released when spilled) reservation: MemoryReservation, + /// Total bytes written for this spill partition + spilled_bytes: usize, + /// Total rows written for this spill partition + spilled_rows: usize, }, /// Partition resources released and not available Released { @@ -151,6 +155,7 @@ struct PartitionAccumulator { buffered_bytes: usize, total_rows: usize, spill_writer: Option, + spilled_bytes: usize, } impl PartitionAccumulator { @@ -160,6 +165,7 @@ impl PartitionAccumulator { buffered_bytes: 0, total_rows: 0, spill_writer: None, + spilled_bytes: 0, } } } @@ -348,7 +354,7 @@ impl PartitionedHashJoinStream { fn spill_partition( &mut self, - partition_id: usize, + _partition_id: usize, accum: &mut PartitionAccumulator, ) -> Result<()> { let buffered_batches = mem::take(&mut accum.buffered_batches); @@ -356,10 +362,22 @@ impl PartitionedHashJoinStream { return Ok(()); } - let writer = self.ensure_build_spill_writer(accum)?; - for batch in buffered_batches { - writer.append_batch(&batch)?; + let created_writer = accum.spill_writer.is_none(); + let mut total_spilled_bytes = 0usize; + { + let writer = self.ensure_build_spill_writer(accum)?; + if created_writer { + self.join_metrics.build_spill_count.add(1); + } + for batch in buffered_batches { + let batch_size = batch.get_array_memory_size(); + total_spilled_bytes = total_spilled_bytes.saturating_add(batch_size); + self.join_metrics.build_spilled_rows.add(batch.num_rows()); + self.join_metrics.build_spilled_bytes.add(batch_size); + writer.append_batch(&batch)?; + } } + accum.spilled_bytes = accum.spilled_bytes.saturating_add(total_spilled_bytes); if accum.buffered_bytes > 0 { let _ = self.memory_reservation.try_shrink(accum.buffered_bytes); accum.buffered_bytes = 0; @@ -372,8 +390,14 @@ impl PartitionedHashJoinStream { accum: &mut PartitionAccumulator, batch: RecordBatch, ) -> Result<()> { - let writer = self.ensure_build_spill_writer(accum)?; - writer.append_batch(&batch)?; + let batch_size = batch.get_array_memory_size(); + self.join_metrics.build_spilled_rows.add(batch.num_rows()); + self.join_metrics.build_spilled_bytes.add(batch_size); + { + let writer = self.ensure_build_spill_writer(accum)?; + writer.append_batch(&batch)?; + } + accum.spilled_bytes = accum.spilled_bytes.saturating_add(batch_size); Ok(()) } @@ -583,8 +607,6 @@ impl PartitionedHashJoinStream { .with_can_spill(true) .register(&self.runtime_env.memory_pool); - let concat_rows = concatenated.num_rows(); - self.build_partitions[part_id] = BuildPartition::InMemory { hash_map, batch: concatenated, @@ -818,11 +840,18 @@ impl PartitionedHashJoinStream { "hash_join_probe_partition", )?; self.probe_spill_in_progress[part_id] = Some(ipf); + self.join_metrics.probe_spill_count.add(1); } if let Some(ref mut ipf) = self.probe_spill_in_progress[part_id] { ipf.append_batch(&filtered_batch)?; + self.join_metrics + .probe_spilled_rows + .add(filtered_batch.num_rows()); + self.join_metrics + .probe_spilled_bytes + .add(filtered_batch.get_array_memory_size()); // println!( // "[spill-join][probe-spill] write partition={} rows={}", // part_id, @@ -957,8 +986,10 @@ impl PartitionedHashJoinStream { .map(|_| PartitionAccumulator::new()) .collect::>(); let mut repartition_request: Option = None; + let mut max_spilled_bytes: usize = 0; + let mut any_spilled = false; - for (batch_index, batch) in build_data.original_batches().iter().enumerate() { + for batch in build_data.original_batches() { let mut keys_values: Vec = 
Vec::with_capacity(self.on_left.len()); for expr in &self.on_left { keys_values.push(expr.evaluate(batch)?.into_array(batch.num_rows())?); @@ -1006,14 +1037,13 @@ impl PartitionedHashJoinStream { .build_mem_used .set_max(self.memory_reservation.size()); if self.memory_reservation.size() > self.memory_threshold { - if !self.runtime_env.disk_manager.tmp_files_enabled() { - if allow_repartition { - if let Some(next_count) = self.next_partition_count() - { - repartition_request = Some(next_count); - break; - } + if allow_repartition { + if let Some(next_count) = self.next_partition_count() { + repartition_request = Some(next_count); + break; } + } + if !self.runtime_env.disk_manager.tmp_files_enabled() { return Err(internal_datafusion_err!( "Insufficient memory for build partitioning and spilling is disabled" )); @@ -1022,13 +1052,13 @@ impl PartitionedHashJoinStream { } } Err(_) => { - if !self.runtime_env.disk_manager.tmp_files_enabled() { - if allow_repartition { - if let Some(next_count) = self.next_partition_count() { - repartition_request = Some(next_count); - break; - } + if allow_repartition { + if let Some(next_count) = self.next_partition_count() { + repartition_request = Some(next_count); + break; } + } + if !self.runtime_env.disk_manager.tmp_files_enabled() { return Err(internal_datafusion_err!( "Unable to allocate memory for build partition" )); @@ -1056,11 +1086,9 @@ impl PartitionedHashJoinStream { self.matched_build_rows_per_partition .reserve(self.num_partitions); - let mut partitions_in_memory = 0usize; - let mut partitions_spilled = 0usize; - let mut partitions_empty = 0usize; for part_id in 0..self.num_partitions { let mut accum = mem::take(&mut partition_accumulators[part_id]); + max_spilled_bytes = max_spilled_bytes.max(accum.spilled_bytes); if accum.spill_writer.is_some() { if !accum.buffered_batches.is_empty() { self.spill_partition(part_id, &mut accum)?; @@ -1075,10 +1103,12 @@ impl PartitionedHashJoinStream { let reservation = MemoryConsumer::new("partition_spilled") .with_can_spill(true) .register(&self.runtime_env.memory_pool); - partitions_spilled += 1; + any_spilled = true; self.build_partitions.push(BuildPartition::Spilled { spill_file: Some(spill_file), reservation, + spilled_bytes: accum.spilled_bytes, + spilled_rows: accum.total_rows, }); } continue; @@ -1087,7 +1117,6 @@ impl PartitionedHashJoinStream { if accum.buffered_batches.is_empty() { self.matched_build_rows_per_partition .push(BooleanBufferBuilder::new(0)); - partitions_empty += 1; self.build_partitions.push(BuildPartition::Empty); continue; } @@ -1159,7 +1188,6 @@ impl PartitionedHashJoinStream { .saturating_add(approx_partition_size), ); - partitions_in_memory += 1; self.build_partitions.push(BuildPartition::InMemory { hash_map, batch: partition_batch, @@ -1168,6 +1196,12 @@ impl PartitionedHashJoinStream { }); } + if (max_spilled_bytes > self.memory_threshold || any_spilled) && allow_repartition { + if let Some(next_count) = self.next_partition_count() { + return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + } + } + self.state = PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { partition_id: 0, total_partitions: self.num_partitions, diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 3a0847754a147..1b4d8474fed0b 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1329,6 +1329,18 @@ pub(crate) struct BuildProbeJoinMetrics { pub(crate) build_input_rows: 
metrics::Count, /// Memory used by build-side in bytes pub(crate) build_mem_used: metrics::Gauge, + /// Number of spill files produced for the build side + pub(crate) build_spill_count: metrics::Count, + /// Total build-side bytes written to spill + pub(crate) build_spilled_bytes: metrics::Count, + /// Total build-side rows written to spill + pub(crate) build_spilled_rows: metrics::Count, + /// Number of spill files produced for the probe side + pub(crate) probe_spill_count: metrics::Count, + /// Total probe-side bytes written to spill + pub(crate) probe_spilled_bytes: metrics::Count, + /// Total probe-side rows written to spill + pub(crate) probe_spilled_rows: metrics::Count, /// Total time for joining probe-side batches to the build-side batches pub(crate) join_time: metrics::Time, /// Number of batches consumed by probe-side of this operator @@ -1374,6 +1386,18 @@ impl BuildProbeJoinMetrics { let build_mem_used = MetricBuilder::new(metrics).gauge("build_mem_used", partition); + let build_spill_count = + MetricBuilder::new(metrics).counter("build_spill_count", partition); + let build_spilled_bytes = + MetricBuilder::new(metrics).counter("build_spilled_bytes", partition); + let build_spilled_rows = + MetricBuilder::new(metrics).counter("build_spilled_rows", partition); + let probe_spill_count = + MetricBuilder::new(metrics).counter("probe_spill_count", partition); + let probe_spilled_bytes = + MetricBuilder::new(metrics).counter("probe_spilled_bytes", partition); + let probe_spilled_rows = + MetricBuilder::new(metrics).counter("probe_spilled_rows", partition); let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); @@ -1388,6 +1412,12 @@ impl BuildProbeJoinMetrics { build_input_batches, build_input_rows, build_mem_used, + build_spill_count, + build_spilled_bytes, + build_spilled_rows, + probe_spill_count, + probe_spilled_bytes, + probe_spilled_rows, join_time, input_batches, input_rows, From 0932d547c7f739265b38d9eae3ff7307b5d79367 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Thu, 6 Nov 2025 20:15:19 +0300 Subject: [PATCH 16/36] Upd --- .../src/joins/grace_hash_join/exec.rs | 165 +++++++++++------- .../src/joins/grace_hash_join/stream.rs | 1 + 2 files changed, 103 insertions(+), 63 deletions(-) diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs index e25407c084bbb..0fca133f79d79 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -117,8 +117,6 @@ pub struct GraceHashJoinExec { /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. 
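A note on the spill counters added to `BuildProbeJoinMetrics` above: once registered through `MetricBuilder` they surface in the operator's `MetricsSet`, so tests can assert on them after draining the join output. A minimal sketch, assuming DataFusion's usual metrics accessors (`ExecutionPlan::metrics`, `MetricsSet::sum_by_name`, `MetricValue::as_usize`); the helper name is ours:

```rust
use datafusion_physical_plan::ExecutionPlan;

/// Sum a named counter (e.g. "build_spilled_bytes") across all partitions of
/// an operator. Returns 0 if the plan has not executed or never emitted it.
fn spill_counter(plan: &dyn ExecutionPlan, name: &str) -> usize {
    plan.metrics()
        .and_then(|set| set.sum_by_name(name))
        .map(|value| value.as_usize())
        .unwrap_or(0)
}

// e.g. after collecting the join's output stream:
// assert!(spill_counter(join.as_ref(), "build_spilled_bytes") > 0);
```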
dynamic_filter: Option, accumulator: Arc, - spill_left: Arc, - spill_right: Arc, } #[derive(Clone)] @@ -200,20 +198,7 @@ impl GraceHashJoinExec { let partitions = left.output_partitioning().partition_count(); let accumulator = GraceAccumulator::new(partitions); - let metrics = ExecutionPlanMetricsSet::new(); - let runtime = Arc::new(RuntimeEnv::default()); - let spill_left = Arc::new(SpillManager::new( - Arc::clone(&runtime), - SpillMetrics::new(&metrics, 0), - Arc::clone(&left_schema), - )); - let spill_right = Arc::new(SpillManager::new( - Arc::clone(&runtime), - SpillMetrics::new(&metrics, 0), - Arc::clone(&right_schema), - )); - // Initialize both dynamic filter and bounds accumulator to None // They will be set later if dynamic filtering is enabled Ok(GraceHashJoinExec { @@ -231,8 +216,6 @@ impl GraceHashJoinExec { cache, dynamic_filter: None, accumulator, - spill_left, - spill_right, }) } @@ -575,7 +558,6 @@ impl ExecutionPlan for GraceHashJoinExec { self: Arc, children: Vec>, ) -> Result> { - let partitions = children[0].output_partitioning().partition_count(); Ok(Arc::new(GraceHashJoinExec { left: Arc::clone(&children[0]), right: Arc::clone(&children[1]), @@ -599,8 +581,6 @@ impl ExecutionPlan for GraceHashJoinExec { // Keep the dynamic filter, bounds accumulator will be reset dynamic_filter: self.dynamic_filter.clone(), accumulator: Arc::clone(&self.accumulator), - spill_left: Arc::clone(&self.spill_left), - spill_right: Arc::clone(&self.spill_right), })) } @@ -621,8 +601,6 @@ impl ExecutionPlan for GraceHashJoinExec { // Reset dynamic filter and bounds accumulator to initial state dynamic_filter: None, accumulator: Arc::clone(&self.accumulator), - spill_left: Arc::clone(&self.spill_left), - spill_right: Arc::clone(&self.spill_right), })) } @@ -875,8 +853,6 @@ impl ExecutionPlan for GraceHashJoinExec { filter: dynamic_filter, bounds_accumulator: OnceLock::new(), }), - spill_left: Arc::clone(&self.spill_left), - spill_right: Arc::clone(&self.spill_right), accumulator: Arc::clone(&self.accumulator), }); result = result.with_updated_node(new_node as Arc); @@ -961,50 +937,61 @@ async fn partition_and_spill_one_side( .map(|_| PartitionWriter::new(Arc::clone(&spill_manager))) .collect(); + let mut buffered_batches = Vec::new(); let mut total_rows = 0usize; - + let schema = input.schema(); while let Some(batch) = input.next().await { let batch = batch?; - let num_rows = batch.num_rows(); - if num_rows == 0 { + if batch.num_rows() == 0 { continue; } - - total_rows += num_rows; + total_rows += batch.num_rows(); join_metrics.build_input_batches.add(1); - join_metrics.build_input_rows.add(num_rows); + join_metrics.build_input_rows.add(batch.num_rows()); + buffered_batches.push(batch); + } + if buffered_batches.is_empty() { + return Ok(Vec::new()); + } + // Create single batch to reduce number of spilled files + let single_batch = concat_batches(&schema, &buffered_batches)?; + let num_rows = single_batch.num_rows(); + if num_rows == 0 { + return Ok(Vec::new()); + } - // Calculate hashes - let keys = on_exprs - .iter() - .map(|c| c.evaluate(&batch)?.into_array(num_rows)) - .collect::>>()?; + // Calculate hashes + let keys = on_exprs + .iter() + .map(|c| c.evaluate(&single_batch)?.into_array(num_rows)) + .collect::>>()?; - let mut hashes = vec![0u64; num_rows]; - create_hashes(&keys, random_state, &mut hashes)?; + let mut hashes = vec![0u64; num_rows]; + create_hashes(&keys, random_state, &mut hashes)?; - // Spread to partitions - let mut indices: Vec> = vec![Vec::new(); partition_count]; - for 
(row, h) in hashes.iter().enumerate() { - let bucket = (*h as usize) % partition_count; - indices[bucket].push(row as u32); - } + // Spread to partitions + let mut indices: Vec> = vec![Vec::new(); partition_count]; + for (row, h) in hashes.iter().enumerate() { + let bucket = (*h as usize) % partition_count; + indices[bucket].push(row as u32); + } - // Collect and spill - for (i, idxs) in indices.into_iter().enumerate() { - if idxs.is_empty() { - continue; - } - let idx_array = UInt32Array::from(idxs); - let taken = batch - .columns() - .iter() - .map(|c| take(c.as_ref(), &idx_array, None)) - .collect::>>()?; - let part_batch = RecordBatch::try_new(batch.schema(), taken)?; - let request_msg = format!("grace_partition_{file_request_msg}_{i}"); - partitions[i].spill_batch_auto(&part_batch, &request_msg)?; + // Collect and spill + for (i, idxs) in indices.into_iter().enumerate() { + if idxs.is_empty() { + continue; } + + let idx_array = UInt32Array::from(idxs); + let taken = single_batch + .columns() + .iter() + .map(|c| take(c.as_ref(), &idx_array, None)) + .collect::>>()?; + + let part_batch = RecordBatch::try_new(single_batch.schema(), taken)?; + let request_msg = format!("grace_partition_{file_request_msg}_{i}"); + partitions[i].spill_batch_auto(&part_batch, &request_msg)?; } // Prepare indexes @@ -1012,7 +999,7 @@ async fn partition_and_spill_one_side( for (i, writer) in partitions.into_iter().enumerate() { result.push(writer.finish(i)?); } - + println!("spill_manager {:?}", spill_manager.metrics); Ok(result) } @@ -1104,6 +1091,7 @@ mod tests { use rstest::*; use rstest_reuse::*; use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use crate::joins::HashJoinExec; fn build_large_table( a_name: &str, @@ -1112,7 +1100,7 @@ mod tests { n: usize, ) -> Arc { let a: ArrayRef = Arc::new(Int32Array::from_iter_values(1..=n as i32)); - let b: ArrayRef = Arc::new(Int32Array::from_iter_values(1..=n as i32)); + let b: ArrayRef = Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 2))); let c: ArrayRef = Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 10))); let schema = Arc::new(arrow::datatypes::Schema::new(vec![ @@ -1138,7 +1126,8 @@ mod tests { } #[tokio::test] - async fn single_partition_join_overallocation() -> Result<()> { + async fn single_partition_join_overallocation() -> Result<()> + { // let left = build_table( // ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), // ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), @@ -1149,13 +1138,13 @@ mod tests { // ("b2", &vec![1, 2]), // ("c2", &vec![14, 15]), // ); - let left = build_large_table("a1", "b1", "c1", 100_000); - let right = build_large_table("a2", "b2", "c2", 50_000); + let left = build_large_table("a1", "b1", "c1", 200); + let right = build_large_table("a2", "b2", "c2", 500); let on = vec![( Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let (left_expr, right_expr) = on + let (left_expr, right_expr): (Vec>, Vec>) = on .iter() .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) .unzip(); @@ -1216,4 +1205,54 @@ mod tests { // ); Ok(()) } + + #[tokio::test] + async fn single_partition_join_overallocation_f() -> Result<()> { + + let left = build_large_table("a1", "b1", "c1", 200); + let right = build_large_table("a2", "b2", "c2", 500); + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, + )]; + let (left_expr, 
right_expr): (Vec>, Vec>) = on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + let left_repartitioned: Arc = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, 32), + )?); + let right_repartitioned: Arc = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, 32), + )?); + + let join = HashJoinExec::try_new( + left_repartitioned, + right_repartitioned, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + )?; + + let task_ctx = Arc::new(TaskContext::default()); + let mut batches = vec![]; + for i in 0..32 { + let stream = join.execute(i, Arc::clone(&task_ctx))?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + print_batches(&*batches).unwrap(); + Ok(()) + } + } diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs index 417990da43f11..97e48fc81a362 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs @@ -265,6 +265,7 @@ fn load_partition_async( ) -> OnceFut> { OnceFut::new(async move { let mut all_batches = Vec::new(); + println!("partitions {:?}", partitions); for p in partitions { for chunk in p.chunks { let mut reader = spill_manager.load_spilled_batch(&chunk)?; From c87377dffea935657a639d108f8a78b040b77ce7 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Thu, 6 Nov 2025 22:09:27 +0200 Subject: [PATCH 17/36] Spilled probe path --- .../src/joins/hash_join/partitioned.rs | 1539 ++++++++++++----- datafusion/physical-plan/src/joins/utils.rs | 20 + 2 files changed, 1100 insertions(+), 459 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index b8974bb4db79b..ccfb72689517f 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -43,6 +43,7 @@ //! - Generates join results and handles unmatched rows for outer joins //! 
- Tracks matched rows for proper outer join semantics +use std::collections::VecDeque; use std::mem::{self, size_of}; use std::sync::Arc; use std::task::{Context, Poll}; @@ -75,7 +76,25 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; -use futures::{ready, Stream, StreamExt}; +use futures::{executor::block_on, ready, Stream, StreamExt}; + +const HYBRID_HASH_MAX_REPARTITION_DEPTH: usize = 6; +const HYBRID_HASH_MIN_FANOUT: usize = 2; +const HYBRID_HASH_MIN_PARTITION_BYTES: usize = 8 * 1024 * 1024; +const HYBRID_HASH_ROWS_PER_PARTITION_TARGET_MULTIPLIER: usize = 8; +const HYBRID_HASH_ROWS_PER_PARTITION_MIN: usize = 32 * 1024; + +fn highest_power_of_two_leq(n: usize) -> usize { + if n <= 1 { + 1 + } else { + let mut power = 1usize; + while (power << 1) <= n { + power <<= 1; + } + power + } +} /// State of the partitioned hash join stream #[derive(Debug, Clone)] @@ -93,12 +112,8 @@ pub(super) enum PartitionedHashJoinState { /// State for processing a specific partition #[derive(Debug, Clone)] pub(super) struct ProcessPartitionState { - /// Current partition being processed - pub partition_id: usize, - /// Total number of partitions - pub total_partitions: usize, - /// Whether we're processing the last partition - pub is_last_partition: bool, + /// Descriptor for the partition currently being processed + descriptor: PartitionDescriptor, } /// Represents a partition of build-side data @@ -145,6 +160,16 @@ pub(super) struct ProbePartition { pub hashes: Vec>, } +impl ProbePartition { + fn new() -> Self { + Self { + batches: Vec::new(), + values: Vec::new(), + hashes: Vec::new(), + } + } +} + enum PartitionBuildStatus { Ready(StatefulStreamResult>), NeedMorePartitions { next_count: usize }, @@ -176,6 +201,24 @@ impl Default for PartitionAccumulator { } } +#[derive(Debug, Clone)] +pub(super) struct PartitionDescriptor { + /// Index into build/probe storage vectors + build_index: usize, + /// Index of the original (generation 0) partition + root_index: usize, + /// Number of refinement passes applied so far + generation: usize, + /// Total number of radix bits used to identify this partition + radix_bits: usize, + /// Hash prefix (lower `radix_bits`) identifying this partition + hash_prefix: u64, + /// Latest spilled byte estimate for this partition + spilled_bytes: usize, + /// Latest spilled row estimate for this partition + spilled_rows: usize, +} + // Use RefCountedTempFile from datafusion_execution::disk_manager /// Partitioned Hash Join stream that can handle large datasets by partitioning @@ -231,6 +274,8 @@ pub(super) struct PartitionedHashJoinStream { pub probe_partitions: Vec, /// Current partition being processed pub current_partition: Option, + /// Queue of pending partitions to process (supports recursive fan-out) + pub pending_partitions: VecDeque, /// Spill manager for probe-side (right) batches pub probe_spill_manager: SpillManager, /// Spill manager for build-side (left) batches @@ -276,8 +321,8 @@ pub(super) struct PartitionedHashJoinStream { pub unmatched_left_indices_cache: Option, pub unmatched_right_indices_cache: Option, pub unmatched_offset: usize, - /// Whether we've buffered the entire probe side into per-partition batches - pub probes_buffered: bool, + /// Whether the probe stream has reached EOF + pub probe_stream_finished: bool, /// Current read position per partition within buffered probe batches pub probe_batch_positions: Vec, /// Metrics: total probe rows buffered per partition (RAM) @@ -304,12 
+349,16 @@ pub(super) struct PartitionedHashJoinStream { pub pending_reload_partition: Option, /// In-progress probe spill writers, one per partition (used when corresponding build is spilled) pub probe_spill_in_progress: Vec>, - /// Finalized probe spill files per partition (set after buffering probe side) - pub probe_spill_files: Vec>, + /// Finalized probe spill files per partition (queue of ready-to-read files) + pub probe_spill_files: Vec>, /// Pending probe stream for the current partition's probe spill file pub pending_probe_stream: Option, /// Target partition id for pending probe stream pub pending_probe_partition: Option, + /// Whether a partition is currently queued for processing + pub partition_pending: Vec, + /// Latest descriptor metadata per partition + pub partition_descriptors: Vec>, } impl PartitionedHashJoinStream { @@ -324,19 +373,778 @@ impl PartitionedHashJoinStream { } } - fn resize_partition_vectors(&mut self) { - let n = self.num_partitions; - self.probe_spill_in_progress = (0..n).map(|_| None).collect(); - self.probe_spill_files = (0..n).map(|_| None).collect(); - self.probe_batch_positions = vec![0; n]; - self.probe_buffered_rows_per_part = vec![0; n]; - self.probe_spilled_rows_per_part = vec![0; n]; - self.probe_consumed_rows_per_part = vec![0; n]; - self.matched_rows_per_part = vec![0; n]; - self.emitted_rows_per_part = vec![0; n]; - self.candidate_pairs_per_part = vec![0; n]; - self.verify_once_per_part = vec![false; n]; - self.filter_debug_once_per_part = vec![false; n]; + fn resize_partition_vectors(&mut self) { + let n = self.num_partitions; + self.probe_spill_in_progress = (0..n).map(|_| None).collect(); + self.probe_spill_files = (0..n).map(|_| VecDeque::new()).collect(); + self.probe_batch_positions = vec![0; n]; + self.probe_buffered_rows_per_part = vec![0; n]; + self.probe_spilled_rows_per_part = vec![0; n]; + self.probe_consumed_rows_per_part = vec![0; n]; + self.matched_rows_per_part = vec![0; n]; + self.emitted_rows_per_part = vec![0; n]; + self.candidate_pairs_per_part = vec![0; n]; + self.verify_once_per_part = vec![false; n]; + self.filter_debug_once_per_part = vec![false; n]; + self.partition_pending = vec![false; n]; + self.partition_descriptors = (0..n).map(|_| None).collect(); + } + + fn allocate_partition_slot(&mut self) -> usize { + let idx = self.build_partitions.len(); + self.build_partitions.push(BuildPartition::Empty); + self.matched_build_rows_per_partition + .push(BooleanBufferBuilder::new(0)); + self.probe_partitions.push(ProbePartition::new()); + self.probe_batch_positions.push(0); + self.probe_spill_in_progress.push(None); + self.probe_spill_files.push(VecDeque::new()); + self.probe_buffered_rows_per_part.push(0); + self.probe_spilled_rows_per_part.push(0); + self.probe_consumed_rows_per_part.push(0); + self.matched_rows_per_part.push(0); + self.emitted_rows_per_part.push(0); + self.candidate_pairs_per_part.push(0); + self.verify_once_per_part.push(false); + self.filter_debug_once_per_part.push(false); + self.partition_pending.push(false); + self.partition_descriptors.push(None); + idx + } + + fn schedule_partition(&mut self, part_id: usize) -> Result<()> { + if part_id >= self.partition_pending.len() { + let new_len = part_id + 1; + self.partition_pending.resize(new_len, false); + self.partition_descriptors.resize_with(new_len, || None); + } + + if self.current_partition == Some(part_id) { + return Ok(()); + } + + if self.partition_pending[part_id] { + return Ok(()); + } + + if let Some(desc) = self + .partition_descriptors + 
.get(part_id) + .and_then(|d| d.clone()) + { + self.pending_partitions.push_back(desc); + self.partition_pending[part_id] = true; + } + + Ok(()) + } + + fn flush_probe_writer( + &mut self, + part_id: usize, + ) -> Result> { + if part_id >= self.probe_spill_in_progress.len() { + return Ok(None); + } + if let Some(mut writer) = self.probe_spill_in_progress[part_id].take() { + let file = writer.finish()?; + return Ok(file); + } + Ok(None) + } + + fn finalize_spilled_partition(&mut self, part_id: usize) -> Result { + if part_id >= self.probe_spill_in_progress.len() { + return Ok(false); + } + if let Some(file) = self.flush_probe_writer(part_id)? { + if part_id >= self.probe_spill_files.len() { + self.probe_spill_files + .resize_with(part_id + 1, VecDeque::new); + } + self.probe_spill_files[part_id].push_back(file); + self.schedule_partition(part_id)?; + return Ok(true); + } + Ok(false) + } + + fn compute_recursive_fanout( + &self, + descriptor: &PartitionDescriptor, + ) -> Option<(usize, usize)> { + if descriptor.generation >= HYBRID_HASH_MAX_REPARTITION_DEPTH { + return None; + } + if self.max_partition_count == 0 { + return None; + } + let current_total = self.build_partitions.len(); + if current_total == 0 { + return None; + } + + let max_fanout_allowed = self + .max_partition_count + .saturating_sub(current_total.saturating_sub(1)); + if max_fanout_allowed < HYBRID_HASH_MIN_FANOUT { + return None; + } + + let mut per_partition_budget = self + .memory_threshold + .checked_div(self.max_partition_count.max(1)) + .unwrap_or(self.memory_threshold); + if per_partition_budget == 0 { + per_partition_budget = HYBRID_HASH_MIN_PARTITION_BYTES; + } + per_partition_budget = per_partition_budget.max(HYBRID_HASH_MIN_PARTITION_BYTES); + + let rows_budget = self + .batch_size + .saturating_mul(HYBRID_HASH_ROWS_PER_PARTITION_TARGET_MULTIPLIER) + .max(HYBRID_HASH_ROWS_PER_PARTITION_MIN); + + let should_repartition_bytes = descriptor.spilled_bytes > per_partition_budget; + let should_repartition_rows = descriptor.spilled_rows > rows_budget; + + if !should_repartition_bytes && !should_repartition_rows { + return None; + } + + let mut required = HYBRID_HASH_MIN_FANOUT; + + if should_repartition_bytes { + let budget = per_partition_budget.max(1); + let needed = descriptor.spilled_bytes.saturating_add(budget - 1) / budget; + required = required.max(needed); + } + + if should_repartition_rows { + let budget = rows_budget.max(1); + let needed = descriptor.spilled_rows.saturating_add(budget - 1) / budget; + required = required.max(needed); + } + + let mut fanout = required.next_power_of_two(); + if fanout == 0 { + fanout = HYBRID_HASH_MIN_FANOUT; + } + if fanout > max_fanout_allowed { + fanout = highest_power_of_two_leq(max_fanout_allowed); + } + if fanout < HYBRID_HASH_MIN_FANOUT { + return None; + } + + let additional_bits = fanout.trailing_zeros() as usize; + if additional_bits == 0 { + return None; + } + Some((additional_bits, fanout)) + } + + fn repartition_spilled_partition( + &mut self, + descriptor: &PartitionDescriptor, + additional_bits: usize, + fanout: usize, + ) -> Result> { + let build_index = descriptor.build_index; + if build_index >= self.build_partitions.len() { + return Ok(vec![]); + } + + let placeholder_reservation = + MemoryConsumer::new("partition_repartition_placeholder") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + let old_partition = mem::replace( + &mut self.build_partitions[build_index], + BuildPartition::Released { + reservation: placeholder_reservation, + }, + 
); + + let (spill_file, _spilled_bytes, _spilled_rows) = match old_partition { + BuildPartition::Spilled { + spill_file, + spilled_bytes, + spilled_rows, + .. + } => ( + spill_file.ok_or_else(|| { + internal_datafusion_err!( + "spill file already consumed for partition {}", + build_index + ) + })?, + spilled_bytes, + spilled_rows, + ), + other => { + self.build_partitions[build_index] = other; + return Ok(vec![]); + } + }; + + // Collect spilled build batches + let mut build_batches = block_on(async { + let mut stream = self.build_spill_manager.read_spill_as_stream(spill_file)?; + let mut batches = Vec::new(); + while let Some(batch) = stream.next().await { + batches.push(batch?); + } + Result::>::Ok(batches) + })?; + + if build_batches.is_empty() { + // Nothing to repartition; keep placeholder as empty partition + let mut new_descriptor = descriptor.clone(); + new_descriptor.spilled_bytes = 0; + new_descriptor.spilled_rows = 0; + self.matched_build_rows_per_partition[build_index] = + BooleanBufferBuilder::new(0); + self.build_partitions[build_index] = BuildPartition::Empty; + return Ok(vec![new_descriptor]); + } + + let shift_bits = descriptor.radix_bits; + let mask = (fanout - 1) as u64; + let mut sub_accumulators = (0..fanout) + .map(|_| PartitionAccumulator::new()) + .collect::>(); + + self.join_metrics.recursive_repartition_events.add(1); + self.join_metrics.recursive_partitions_created.add(fanout); + self.join_metrics + .recursive_partition_depth + .set_max(descriptor.generation.saturating_add(1)); + self.join_metrics + .recursive_repartition_fanout + .set_max(fanout); + + for batch in build_batches.drain(..) { + let mut keys_values: Vec = Vec::with_capacity(self.on_left.len()); + for expr in &self.on_left { + keys_values.push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + let mut indices_per_part: Vec> = vec![Vec::new(); fanout]; + for (row_idx, hash) in hashes.iter().enumerate() { + let sub_idx = (((*hash >> shift_bits) as usize) & mask as usize) % fanout; + indices_per_part[sub_idx].push(row_idx as u32); + } + + for (sub_idx, indices) in indices_per_part.into_iter().enumerate() { + if indices.is_empty() { + continue; + } + let idx_array = UInt32Array::from(indices); + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &idx_array, None).map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + let batch_size = filtered_batch.get_array_memory_size(); + + let accum = &mut sub_accumulators[sub_idx]; + accum.total_rows += filtered_batch.num_rows(); + + match self.memory_reservation.try_grow(batch_size) { + Ok(_) => { + accum.buffered_bytes += batch_size; + accum.buffered_batches.push(filtered_batch); + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + if self.memory_reservation.size() > self.memory_threshold { + self.spill_partition(sub_idx, accum)?; + } + } + Err(_) => { + self.spill_partition(sub_idx, accum)?; + self.append_spilled_batch(accum, filtered_batch)?; + } + } + } + } + + // Finalize sub partitions + let new_radix_bits = descriptor.radix_bits + additional_bits; + let mut new_descriptors = Vec::with_capacity(fanout); + let mut partition_indices = Vec::with_capacity(fanout); + + for sub_idx in 0..fanout { + let accum = &mut 
sub_accumulators[sub_idx]; + let mut matched_bitmap = BooleanBufferBuilder::new(accum.total_rows); + matched_bitmap.append_n(accum.total_rows, false); + + let new_index = if sub_idx == 0 { + build_index + } else { + self.allocate_partition_slot() + }; + partition_indices.push(new_index); + + self.matched_build_rows_per_partition[new_index] = matched_bitmap; + + if accum.spill_writer.is_some() || !accum.buffered_batches.is_empty() { + if accum.spill_writer.is_some() { + if !accum.buffered_batches.is_empty() { + self.spill_partition(sub_idx, accum)?; + } + let mut writer = accum.spill_writer.take().ok_or_else(|| { + internal_datafusion_err!("missing spill writer") + })?; + let spill_file = writer.finish()?.ok_or_else(|| { + internal_datafusion_err!("expected spill file after repartition") + })?; + let reservation = MemoryConsumer::new("partition_spilled") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + self.build_partitions[new_index] = BuildPartition::Spilled { + spill_file: Some(spill_file), + reservation, + spilled_bytes: accum.spilled_bytes, + spilled_rows: accum.total_rows, + }; + } else { + let mut buffered_batches = mem::take(&mut accum.buffered_batches); + let partition_batch = if buffered_batches.len() == 1 { + buffered_batches.pop().unwrap() + } else { + let batch_refs: Vec<_> = buffered_batches.iter().collect(); + concat_batches(&self.build_schema, batch_refs)? + }; + let num_rows = partition_batch.num_rows(); + let partition_values = self + .on_left + .iter() + .map(|expr| expr.evaluate(&partition_batch)?.into_array(num_rows)) + .collect::>>()?; + + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + let mut hash_map: Box = if num_rows + > u32::MAX as usize + { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; + + self.hashes_buffer.clear(); + self.hashes_buffer.resize(num_rows, 0); + create_hashes( + &partition_values, + &self.random_state, + &mut self.hashes_buffer, + )?; + hash_map.extend_zero(num_rows); + let iter = self + .hashes_buffer + .iter() + .enumerate() + .map(|(idx, hash)| (idx, hash)); + hash_map.update_from_iter(Box::new(iter), 0); + + let reservation = MemoryConsumer::new("partition_memory") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + self.build_partitions[new_index] = BuildPartition::InMemory { + hash_map, + batch: partition_batch, + values: partition_values, + reservation, + }; + accum.spilled_bytes = 0; + } + } else { + self.build_partitions[new_index] = BuildPartition::Empty; + } + + let hash_prefix = + (descriptor.hash_prefix << additional_bits) | (sub_idx as u64); + new_descriptors.push(PartitionDescriptor { + build_index: new_index, + root_index: descriptor.root_index, + generation: descriptor.generation + 1, + radix_bits: new_radix_bits, + hash_prefix, + spilled_bytes: accum.spilled_bytes, + spilled_rows: accum.total_rows, + }); + } + + self.repartition_probe_partition(descriptor, fanout, &partition_indices)?; + + Ok(new_descriptors) + } + + 
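The fan-out choice driving `compute_recursive_fanout` / `repartition_spilled_partition` above boils down to: split a still-too-large spilled partition into the smallest power-of-two number of children whose per-child share fits the budget, clamped to the partition slots still available, and address each child by the next `fanout.trailing_zeros()` hash bits above the parent's prefix. A standalone sketch of the byte-budget half (the patch applies the same logic to a row budget); the free function is illustrative, not part of the patch:

```rust
/// Returns `None` when the partition already fits or no further split is possible.
fn fanout_for(spilled_bytes: usize, byte_budget: usize, max_fanout: usize) -> Option<usize> {
    const MIN_FANOUT: usize = 2;
    if spilled_bytes <= byte_budget || max_fanout < MIN_FANOUT {
        return None;
    }
    // ceil(spilled_bytes / byte_budget), rounded up to a power of two
    let mut fanout = spilled_bytes
        .div_ceil(byte_budget.max(1))
        .max(MIN_FANOUT)
        .next_power_of_two();
    // never exceed the remaining slots: fall back to the largest power of two that fits
    while fanout > max_fanout {
        fanout >>= 1;
    }
    (fanout >= MIN_FANOUT).then_some(fanout)
}

// 100 MiB spilled against an 8 MiB per-child budget with 32 free slots -> 16 children,
// i.e. 4 extra radix bits; a row lands in child ((hash >> parent_bits) & 15).
// assert_eq!(fanout_for(100 << 20, 8 << 20, 32), Some(16));
```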
fn repartition_probe_partition( + &mut self, + descriptor: &PartitionDescriptor, + fanout: usize, + partition_indices: &[usize], + ) -> Result<()> { + let parent_index = descriptor.build_index; + if parent_index >= self.probe_partitions.len() { + return Ok(()); + } + + // Reset parent metrics + self.probe_buffered_rows_per_part + .get_mut(parent_index) + .map(|v| *v = 0); + self.probe_spilled_rows_per_part + .get_mut(parent_index) + .map(|v| *v = 0); + self.probe_consumed_rows_per_part + .get_mut(parent_index) + .map(|v| *v = 0); + if parent_index < self.probe_batch_positions.len() { + self.probe_batch_positions[parent_index] = 0; + } + if parent_index < self.probe_spill_in_progress.len() { + self.probe_spill_in_progress[parent_index] = None; + } + + let shift_bits = descriptor.radix_bits; + let mask = (fanout - 1) as u64; + + if let Some(file) = self + .probe_spill_files + .get_mut(parent_index) + .and_then(|queue| queue.pop_front()) + { + let mut writers = Vec::with_capacity(fanout); + for _ in 0..fanout { + let writer = self + .probe_spill_manager + .create_in_progress_file("hash_join_probe_repartition")?; + writers.push(writer); + } + + let mut file_opt = Some(file); + block_on(async { + let mut stream = self + .probe_spill_manager + .read_spill_as_stream(file_opt.take().unwrap())?; + while let Some(batch) = stream.next().await { + let batch = batch?; + let mut key_arrays: Vec = + Vec::with_capacity(self.on_right.len()); + for expr in &self.on_right { + key_arrays + .push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&key_arrays, &self.random_state, &mut hashes)?; + + let mut indices_per_part: Vec> = vec![Vec::new(); fanout]; + for (row_idx, hash) in hashes.iter().enumerate() { + let sub_idx = + (((*hash >> shift_bits) as usize) & mask as usize) % fanout; + indices_per_part[sub_idx].push(row_idx as u32); + } + + for (sub_idx, indices) in indices_per_part.into_iter().enumerate() { + if indices.is_empty() { + continue; + } + let indices_arr = UInt32Array::from(indices); + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + let writer = writers + .get_mut(sub_idx) + .ok_or_else(|| internal_datafusion_err!("missing writer"))?; + writer.append_batch(&filtered_batch)?; + self.join_metrics + .probe_spilled_rows + .add(filtered_batch.num_rows()); + self.join_metrics + .probe_spilled_bytes + .add(filtered_batch.get_array_memory_size()); + } + } + Result::<()>::Ok(()) + })?; + + for (sub_idx, mut writer) in writers.into_iter().enumerate() { + let file = writer.finish()?.ok_or_else(|| { + internal_datafusion_err!("expected probe spill file") + })?; + let partitions_idx = partition_indices[sub_idx]; + self.probe_spill_files[partitions_idx].push_back(file); + self.probe_spilled_rows_per_part[partitions_idx] = 0; + self.probe_buffered_rows_per_part[partitions_idx] = 0; + self.probe_consumed_rows_per_part[partitions_idx] = 0; + } + return Ok(()); + } + + // In-memory probe data + let parent_partition = mem::replace( + &mut self.probe_partitions[parent_index], + ProbePartition::new(), + ); + for idx in 0..parent_partition.batches.len() { + let batch = &parent_partition.batches[idx]; + let values = &parent_partition.values[idx]; + let hashes = 
&parent_partition.hashes[idx]; + let mut indices_per_part: Vec> = vec![Vec::new(); fanout]; + for (row_idx, hash) in hashes.iter().enumerate() { + let sub_idx = (((*hash >> shift_bits) as usize) & mask as usize) % fanout; + indices_per_part[sub_idx].push(row_idx as u32); + } + + for (sub_idx, indices) in indices_per_part.into_iter().enumerate() { + if indices.is_empty() { + continue; + } + let indices_arr = UInt32Array::from(indices); + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &indices_arr, None).map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + + let mut filtered_values: Vec = Vec::with_capacity(values.len()); + for arr in values.iter() { + filtered_values.push( + take(arr, &indices_arr, None).map_err(DataFusionError::from)?, + ); + } + + let mut filtered_hashes: Vec = Vec::with_capacity(indices_arr.len()); + for i in indices_arr.values().iter() { + filtered_hashes.push(hashes[*i as usize]); + } + + let idx = partition_indices[sub_idx]; + let part = self + .probe_partitions + .get_mut(idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + part.batches.push(filtered_batch); + part.values.push(filtered_values); + part.hashes.push(filtered_hashes); + let buffered = part + .batches + .last() + .map(|b| b.num_rows()) + .unwrap_or_default(); + self.probe_buffered_rows_per_part[idx] = + self.probe_buffered_rows_per_part[idx].saturating_add(buffered); + } + } + + Ok(()) + } + + fn buffer_probe_side(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.probe_spill_in_progress.len() != self.num_partitions + || self.probe_spill_files.len() != self.num_partitions + { + self.resize_partition_vectors(); + } + if self.probe_partitions.len() != self.num_partitions { + self.probe_partitions.clear(); + } + if self.probe_partitions.is_empty() { + self.probe_partitions = (0..self.num_partitions) + .map(|_| ProbePartition::new()) + .collect(); + } + loop { + match self.right.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + let mut keys_values: Vec = + Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; + keys_values.push(v); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + let mut indices_per_part: Vec> = + vec![Vec::new(); self.num_partitions]; + for (row_idx, &hash) in hashes.iter().enumerate() { + let pid = self.partition_for_hash(hash) as usize; + indices_per_part[pid].push(row_idx as u32); + } + + for part_id in 0..self.num_partitions { + let part_indices = &indices_per_part[part_id]; + if part_indices.is_empty() { + continue; + } + + let indices_arr: UInt32Array = part_indices.clone().into(); + + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + + let mut filtered_on_values: Vec = + Vec::with_capacity(self.on_right.len()); + for arr in &keys_values { + filtered_on_values.push( + take(arr, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); + } + + let mut filtered_hashes: Vec = + Vec::with_capacity(part_indices.len()); + for &i in 
part_indices.iter() { + filtered_hashes.push(hashes[i as usize]); + } + + match self.build_partitions.get_mut(part_id) { + Some(BuildPartition::Spilled { .. }) => { + if self.probe_spill_in_progress[part_id].is_none() { + let ipf = self + .probe_spill_manager + .create_in_progress_file( + "hash_join_probe_partition", + )?; + self.probe_spill_in_progress[part_id] = Some(ipf); + self.join_metrics.probe_spill_count.add(1); + } + if let Some(ref mut ipf) = + self.probe_spill_in_progress[part_id] + { + ipf.append_batch(&filtered_batch)?; + self.join_metrics + .probe_spilled_rows + .add(filtered_batch.num_rows()); + self.join_metrics + .probe_spilled_bytes + .add(filtered_batch.get_array_memory_size()); + } + self.probe_spilled_rows_per_part[part_id] += + filtered_batch.num_rows(); + let queue_ready = self + .probe_spill_files + .get(part_id) + .map(|q| !q.is_empty()) + .unwrap_or(false); + let stream_active = self + .pending_probe_partition + .is_some_and(|p| p == part_id); + if !queue_ready && !stream_active { + self.finalize_spilled_partition(part_id)?; + } + } + _ => { + self.probe_partitions[part_id] + .batches + .push(filtered_batch); + self.probe_partitions[part_id] + .values + .push(filtered_on_values); + self.probe_partitions[part_id] + .hashes + .push(filtered_hashes); + let last = self.probe_partitions[part_id] + .batches + .last() + .unwrap(); + self.probe_buffered_rows_per_part[part_id] += + last.num_rows(); + } + } + } + + return Poll::Ready(Ok(())); + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + self.probe_stream_finished = true; + self.probe_batch_positions = vec![0; self.num_partitions]; + for part_id in 0..self.num_partitions { + self.finalize_spilled_partition(part_id)?; + } + return Poll::Ready(Ok(())); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } + + fn maybe_recursive_repartition( + &mut self, + descriptor: &PartitionDescriptor, + ) -> Result { + if descriptor.build_index >= self.build_partitions.len() { + return Ok(false); + } + match self.build_partitions.get(descriptor.build_index) { + Some(BuildPartition::Spilled { .. 
}) => {} + _ => return Ok(false), + } + let Some((additional_bits, fanout)) = self.compute_recursive_fanout(descriptor) + else { + return Ok(false); + }; + let new_descriptors = + self.repartition_spilled_partition(descriptor, additional_bits, fanout)?; + if new_descriptors.is_empty() { + return Ok(false); + } + // Enqueue new descriptors in order + for desc in new_descriptors.into_iter().rev() { + self.pending_partitions.push_front(desc); + } + Ok(true) } fn ensure_build_spill_writer<'a>( @@ -354,7 +1162,7 @@ impl PartitionedHashJoinStream { fn spill_partition( &mut self, - _partition_id: usize, + _build_index: usize, accum: &mut PartitionAccumulator, ) -> Result<()> { let buffered_batches = mem::take(&mut accum.buffered_batches); @@ -411,6 +1219,7 @@ impl PartitionedHashJoinStream { self.build_partitions.clear(); self.matched_build_rows_per_partition.clear(); self.current_partition = None; + self.pending_partitions.clear(); self.current_probe_batch = None; self.current_probe_values.clear(); self.current_probe_hashes.clear(); @@ -424,13 +1233,17 @@ impl PartitionedHashJoinStream { self.unmatched_offset = 0; self.probe_partitions.clear(); self.probe_batch_positions.clear(); - self.probes_buffered = false; + self.probe_stream_finished = false; self.pending_reload_stream = None; self.pending_reload_batches.clear(); self.pending_reload_partition = None; self.pending_probe_stream = None; self.pending_probe_partition = None; - self.probe_spill_files.clear(); + for queue in self.probe_spill_files.iter_mut() { + queue.clear(); + } + self.partition_pending.clear(); + self.partition_descriptors.clear(); self.bounds_waiter = None; self.resize_partition_vectors(); @@ -462,6 +1275,69 @@ impl PartitionedHashJoinStream { } } + fn prepare_partition_queue(&mut self) { + self.pending_partitions.clear(); + let radix_bits = + self.num_partitions.next_power_of_two().trailing_zeros() as usize; + for part_id in 0..self.build_partitions.len() { + let (spilled_bytes, spilled_rows) = match &self.build_partitions[part_id] { + BuildPartition::Spilled { + spilled_bytes, + spilled_rows, + .. 
+ } => (*spilled_bytes, *spilled_rows), + _ => (0, 0), + }; + if self.partition_descriptors.len() <= part_id { + self.partition_descriptors.resize_with(part_id + 1, || None); + } + if self.partition_pending.len() <= part_id { + self.partition_pending.resize(part_id + 1, false); + } + self.pending_partitions.push_back(PartitionDescriptor { + build_index: part_id, + root_index: part_id, + generation: self.partition_pass, + radix_bits, + hash_prefix: part_id as u64, + spilled_bytes, + spilled_rows, + }); + if let Some(desc) = self.pending_partitions.back() { + self.partition_descriptors[part_id] = Some(desc.clone()); + self.partition_pending[part_id] = true; + } + } + } + + fn transition_to_next_partition(&mut self) { + if let Some(descriptor) = self.pending_partitions.pop_front() { + let build_index = descriptor.build_index; + if self.partition_descriptors.len() <= build_index { + self.partition_descriptors + .resize_with(build_index + 1, || None); + } + if self.partition_pending.len() <= build_index { + self.partition_pending.resize(build_index + 1, false); + } + self.partition_descriptors[build_index] = Some(descriptor.clone()); + self.partition_pending[build_index] = false; + self.current_partition = Some(build_index); + self.state = + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor, + }); + } else { + self.current_partition = None; + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + } + } + + fn advance_to_next_partition(&mut self) { + self.current_partition = None; + self.transition_to_next_partition(); + } + /// Report build-side bounds to the shared accumulator when dynamic filtering is enabled fn poll_bounds_update( &mut self, @@ -694,6 +1570,7 @@ impl PartitionedHashJoinStream { build_partitions: Vec::new(), probe_partitions: Vec::new(), current_partition: None, + pending_partitions: VecDeque::new(), probe_spill_manager, build_spill_manager, memory_reservation, @@ -717,13 +1594,13 @@ impl PartitionedHashJoinStream { unmatched_left_indices_cache: None, unmatched_right_indices_cache: None, unmatched_offset: 0, - probes_buffered: false, + probe_stream_finished: false, probe_batch_positions: vec![], pending_reload_stream: None, pending_reload_batches: Vec::new(), pending_reload_partition: None, probe_spill_in_progress: (0..num_partitions).map(|_| None).collect(), - probe_spill_files: (0..num_partitions).map(|_| None).collect(), + probe_spill_files: (0..num_partitions).map(|_| VecDeque::new()).collect(), pending_probe_stream: None, pending_probe_partition: None, probe_buffered_rows_per_part: vec![0; num_partitions], @@ -734,198 +1611,11 @@ impl PartitionedHashJoinStream { candidate_pairs_per_part: vec![0; num_partitions], verify_once_per_part: vec![false; num_partitions], filter_debug_once_per_part: vec![false; num_partitions], + partition_pending: vec![false; num_partitions], + partition_descriptors: (0..num_partitions).map(|_| None).collect(), }) } - /// Buffer the entire probe side stream into per-partition batches. - /// Returns Pending until the right stream is fully consumed. 
- fn buffer_probe_side(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.probe_spill_in_progress.len() != self.num_partitions - || self.probe_spill_files.len() != self.num_partitions - { - self.resize_partition_vectors(); - } - if self.probe_partitions.len() != self.num_partitions { - self.probe_partitions.clear(); - } - if self.probe_partitions.is_empty() { - self.probe_partitions = (0..self.num_partitions) - .map(|_| ProbePartition { - batches: Vec::new(), - values: Vec::new(), - hashes: Vec::new(), - }) - .collect(); - } - loop { - match self.right.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - // Compute ON values for the full batch (once) - // println!( - // "[spill-join] probe batch rows={} schema={:?}", - // batch.num_rows(), - // batch.schema().fields().len() - // ); - let mut keys_values: Vec = - Vec::with_capacity(self.on_right.len()); - for c in &self.on_right { - let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; - keys_values.push(v); - } - // Compute hashes (once) - let mut hashes = vec![0u64; batch.num_rows()]; - create_hashes(&keys_values, &self.random_state, &mut hashes)?; - - // Build per-partition row indices in one pass - let mut indices_per_part: Vec> = - vec![Vec::new(); self.num_partitions]; - for (row_idx, &hash) in hashes.iter().enumerate() { - let pid = self.partition_for_hash(hash) as usize; - indices_per_part[pid].push(row_idx as u32); - } - - // For each non-empty partition, slice both data columns and already computed key values - for part_id in 0..self.num_partitions { - let part_indices = &indices_per_part[part_id]; - if part_indices.is_empty() { - continue; - } - let indices_arr: UInt32Array = part_indices.clone().into(); - if self.probe_partitions[part_id].batches.is_empty() { - // println!( - // "[spill-join] probe partition {} first rows {:?}", - // part_id, - // &part_indices[..part_indices.len().min(10)] - // ); - } - - // Take data columns - let mut filtered_columns: Vec = - Vec::with_capacity(batch.num_columns()); - for col in batch.columns() { - filtered_columns.push( - take(col, &indices_arr, None) - .map_err(DataFusionError::from)?, - ); - } - let filtered_batch = - RecordBatch::try_new(batch.schema(), filtered_columns) - .map_err(DataFusionError::from)?; - - // Take ON key values using precomputed arrays (no re-eval) - let mut filtered_on_values: Vec = - Vec::with_capacity(self.on_right.len()); - for arr in &keys_values { - filtered_on_values.push( - take(arr, &indices_arr, None) - .map_err(DataFusionError::from)?, - ); - } - - // Slice hashes - let mut filtered_hashes: Vec = - Vec::with_capacity(part_indices.len()); - for &i in part_indices.iter() { - filtered_hashes.push(hashes[i as usize]); - } - - // If corresponding build partition is spilled, stream this partition's probe to disk - match self.build_partitions.get_mut(part_id) { - Some(BuildPartition::Spilled { .. 
}) => { - // Lazily create in-progress file - if self.probe_spill_in_progress[part_id].is_none() { - let ipf = self - .probe_spill_manager - .create_in_progress_file( - "hash_join_probe_partition", - )?; - self.probe_spill_in_progress[part_id] = Some(ipf); - self.join_metrics.probe_spill_count.add(1); - } - if let Some(ref mut ipf) = - self.probe_spill_in_progress[part_id] - { - ipf.append_batch(&filtered_batch)?; - self.join_metrics - .probe_spilled_rows - .add(filtered_batch.num_rows()); - self.join_metrics - .probe_spilled_bytes - .add(filtered_batch.get_array_memory_size()); - // println!( - // "[spill-join][probe-spill] write partition={} rows={}", - // part_id, - // filtered_batch.num_rows() - // ); - } - self.probe_spilled_rows_per_part[part_id] += - filtered_batch.num_rows(); - // Do not RAM-buffer spilled probe partitions - } - _ => { - // Keep in memory for in-memory build partitions - self.probe_partitions[part_id] - .batches - .push(filtered_batch); - self.probe_partitions[part_id] - .values - .push(filtered_on_values); - self.probe_partitions[part_id] - .hashes - .push(filtered_hashes); - // Track buffered rows - let last = self.probe_partitions[part_id] - .batches - .last() - .unwrap(); - self.probe_buffered_rows_per_part[part_id] += - last.num_rows(); - } - } - } - } - Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(None) => { - // Finished buffering - self.probes_buffered = true; - self.probe_batch_positions = vec![0; self.num_partitions]; - // println!( - // "[spill-join] probe buffered rows per partition = {:?}", - // self.probe_partitions - // .iter() - // .enumerate() - // .map(|(i, p)| (i, p.batches.iter().map(|b| b.num_rows()).sum::())) - // .collect::>() - // ); - // Finalize any in-progress probe spill files - for part_id in 0..self.num_partitions { - if let Some(mut ipf) = - self.probe_spill_in_progress[part_id].take() - { - if let Some(file) = ipf.finish()? 
{ - // println!( - // "[spill-join][probe-spill] finalize partition={} rows_spilled={}", - // part_id, - // self.probe_spilled_rows_per_part[part_id] - // ); - self.probe_spill_files[part_id] = Some(file); - } - } - } - return Poll::Ready(Ok(())); - } - Poll::Pending => { - // println!( - // "[spill-join][probe-buffer] pending batches buffered={:?} spilled_rows={:?}", - // self.probe_buffered_rows_per_part, - // self.probe_spilled_rows_per_part - // ); - return Poll::Pending; - } - } - } - } - /// Partition build-side data into multiple partitions fn partition_build_side( &mut self, @@ -1000,11 +1690,11 @@ impl PartitionedHashJoinStream { let mut indices_per_part: Vec> = vec![Vec::new(); self.num_partitions]; for (row_idx, hash) in hashes.iter().enumerate() { - let partition_id = self.partition_for_hash(*hash); - indices_per_part[partition_id].push(row_idx as u32); + let build_index = self.partition_for_hash(*hash); + indices_per_part[build_index].push(row_idx as u32); } - for (partition_id, indices) in indices_per_part.into_iter().enumerate() { + for (build_index, indices) in indices_per_part.into_iter().enumerate() { if indices.is_empty() { continue; } @@ -1021,7 +1711,7 @@ impl PartitionedHashJoinStream { RecordBatch::try_new(batch.schema(), filtered_columns) .map_err(DataFusionError::from)?; let batch_size = filtered_batch.get_array_memory_size(); - let accum = &mut partition_accumulators[partition_id]; + let accum = &mut partition_accumulators[build_index]; accum.total_rows += filtered_batch.num_rows(); if accum.spill_writer.is_some() { @@ -1048,7 +1738,7 @@ impl PartitionedHashJoinStream { "Insufficient memory for build partitioning and spilling is disabled" )); } - self.spill_partition(partition_id, accum)?; + self.spill_partition(build_index, accum)?; } } Err(_) => { @@ -1063,7 +1753,7 @@ impl PartitionedHashJoinStream { "Unable to allocate memory for build partition" )); } - self.spill_partition(partition_id, accum)?; + self.spill_partition(build_index, accum)?; self.append_spilled_batch(accum, filtered_batch)?; } } @@ -1196,60 +1886,26 @@ impl PartitionedHashJoinStream { }); } - if (max_spilled_bytes > self.memory_threshold || any_spilled) && allow_repartition { + if (max_spilled_bytes > self.memory_threshold || any_spilled) && allow_repartition + { if let Some(next_count) = self.next_partition_count() { return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); } } - self.state = PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { - partition_id: 0, - total_partitions: self.num_partitions, - is_last_partition: self.num_partitions == 1, - }); + self.prepare_partition_queue(); + self.transition_to_next_partition(); Ok(PartitionBuildStatus::Ready(StatefulStreamResult::Continue)) } - /// Take specific rows from a RecordBatch - fn take_rows(&self, batch: &RecordBatch, indices: &[usize]) -> Result { - use arrow::array::UInt32Array; - use arrow::compute::take; - - let indices_array = - UInt32Array::from(indices.iter().map(|&i| i as u32).collect::>()); - - let columns: Result, DataFusionError> = batch - .columns() - .iter() - .map(|col| take(col, &indices_array, None).map_err(|e| e.into())) - .collect(); - - Ok(RecordBatch::try_new(batch.schema(), columns?)?) 
- } - - /// Take specific rows from an ArrayRef - fn take_rows_from_array( - &self, - array: &ArrayRef, - indices: &[usize], - ) -> Result { - use arrow::array::UInt32Array; - use arrow::compute::take; - - let indices_array = - UInt32Array::from(indices.iter().map(|&i| i as u32).collect::>()); - - Ok(take(array, &indices_array, None).map_err(DataFusionError::from)?) - } - /// Release resources associated with a finished partition when safe to do so. /// Only releases memory eagerly when we don't need unmatched rows in the final phase. - fn release_partition_resources(&mut self, partition_id: usize) { + fn release_partition_resources(&mut self, build_index: usize) { if need_produce_result_in_final(self.join_type) { return; } - if partition_id >= self.build_partitions.len() { + if build_index >= self.build_partitions.len() { return; } @@ -1259,7 +1915,7 @@ impl PartitionedHashJoinStream { .with_can_spill(true) .register(&self.runtime_env.memory_pool); let old_partition = mem::replace( - &mut self.build_partitions[partition_id], + &mut self.build_partitions[build_index], BuildPartition::Released { reservation: placeholder_reservation, }, @@ -1291,7 +1947,7 @@ impl PartitionedHashJoinStream { let empty_hash_map: Box = Box::new(JoinHashMapU32::with_capacity(0)); - self.build_partitions[partition_id] = BuildPartition::InMemory { + self.build_partitions[build_index] = BuildPartition::InMemory { hash_map: empty_hash_map, batch: empty_batch, values: empty_values, @@ -1300,11 +1956,11 @@ impl PartitionedHashJoinStream { } BuildPartition::Spilled { reservation, .. } => { // Transition to Released; no files remain - self.build_partitions[partition_id] = + self.build_partitions[build_index] = BuildPartition::Released { reservation }; } BuildPartition::Released { reservation } => { - self.build_partitions[partition_id] = + self.build_partitions[build_index] = BuildPartition::Released { reservation }; } BuildPartition::Empty => { @@ -1313,20 +1969,65 @@ impl PartitionedHashJoinStream { } } + fn partition_has_pending_probe(&self, part_id: usize) -> bool { + if part_id < self.probe_partitions.len() + && part_id < self.probe_batch_positions.len() + && self.probe_batch_positions[part_id] + < self.probe_partitions[part_id].batches.len() + { + return true; + } + + if self.current_partition.is_some_and(|idx| idx == part_id) + && self.current_probe_batch.is_some() + { + return true; + } + + if part_id < self.probe_spill_files.len() + && !self.probe_spill_files[part_id].is_empty() + { + return true; + } + + if self + .pending_probe_partition + .is_some_and(|idx| idx == part_id) + { + return true; + } + + if part_id < self.probe_spill_in_progress.len() + && self.probe_spill_in_progress[part_id].is_some() + { + return true; + } + + false + } + /// Process a specific partition fn process_partition( &mut self, cx: &mut Context<'_>, partition_state: &ProcessPartitionState, ) -> Poll>>> { + let build_index = partition_state.descriptor.build_index; + // Guard against invalid partition ids (off-by-one protection) - if partition_state.partition_id >= partition_state.total_partitions { + if build_index >= self.build_partitions.len() { self.state = PartitionedHashJoinState::HandleUnmatchedRows; return Poll::Ready(Ok(StatefulStreamResult::Continue)); } - if self.current_partition != Some(partition_state.partition_id) { - self.current_partition = Some(partition_state.partition_id); + if self.maybe_recursive_repartition(&partition_state.descriptor)? 
{ + self.current_partition = None; + self.transition_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + if self.current_partition != Some(build_index) { + self.current_partition = Some(build_index); } // Do not buffer probe side here; selection happens below depending on num_partitions @@ -1336,18 +2037,24 @@ impl PartitionedHashJoinStream { // (Build partition will be immutably borrowed later within a narrower scope) // Ensure the build partition is ready (reload if spilled) BEFORE any immutable borrows - match self.ensure_build_partition_loaded(cx, partition_state.partition_id) { + match self.ensure_build_partition_loaded(cx, build_index) { Poll::Ready(Ok(())) => {} Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } // Ensure probe side is fully buffered into per-partition containers - if !self.probes_buffered { + if !self.probe_stream_finished { match self.buffer_probe_side(cx) { Poll::Ready(Ok(())) => {} Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, + Poll::Pending => { + let no_current_data = !self.partition_has_pending_probe(build_index); + let no_other_pending = self.pending_partitions.is_empty(); + if no_current_data && no_other_pending { + return Poll::Pending; + } + } } } @@ -1356,73 +2063,88 @@ impl PartitionedHashJoinStream { // Decide probe source based on whether we spilled probe for this partition let has_spilled_probe = self .probe_spill_in_progress - .get(partition_state.partition_id) + .get(build_index) .and_then(|o| o.as_ref()) .is_some() || self .probe_spill_files - .get(partition_state.partition_id) - .and_then(|o| o.as_ref()) - .is_some() + .get(build_index) + .map(|queue| !queue.is_empty()) + .unwrap_or(false) || self .pending_probe_partition - .is_some_and(|p| p == partition_state.partition_id); + .is_some_and(|p| p == build_index); let has_buffered_probe = self .probe_partitions - .get(partition_state.partition_id) + .get(build_index) .map(|p| !p.batches.is_empty()) .unwrap_or(false); // Prefer buffered probe batches first; when exhausted, consume spilled probe stream - let pos = self.probe_batch_positions[partition_state.partition_id]; + let pos = self.probe_batch_positions[build_index]; let buffered_len = self .probe_partitions - .get(partition_state.partition_id) + .get(build_index) .map(|p| p.batches.len()) .unwrap_or(0); if has_buffered_probe && pos < buffered_len { - let part = &self.probe_partitions[partition_state.partition_id]; + let part = &self.probe_partitions[build_index]; // Take buffered batch/values/hashes let batch = part.batches[pos].clone(); let values = part.values[pos].clone(); let hashes = part.hashes[pos].clone(); - self.probe_batch_positions[partition_state.partition_id] = pos + 1; + self.probe_batch_positions[build_index] = pos + 1; self.current_probe_batch = Some(batch); self.current_probe_values = values; self.current_probe_hashes = hashes; self.current_offset = (0, None); if let Some(b) = &self.current_probe_batch { - self.probe_consumed_rows_per_part[partition_state.partition_id] = - self.probe_consumed_rows_per_part[partition_state.partition_id] - .saturating_add(b.num_rows()); + self.probe_consumed_rows_per_part[build_index] = self + .probe_consumed_rows_per_part[build_index] + .saturating_add(b.num_rows()); } } else if has_spilled_probe { - // Stream from probe spill file for this partition - if self.pending_probe_partition.is_none() { - let file = self - .probe_spill_files - 
.get_mut(partition_state.partition_id) - .and_then(|o| o.take()); - if let Some(file) = file { - let stream = - self.probe_spill_manager.read_spill_as_stream(file)?; - self.pending_probe_stream = Some(stream); - self.pending_probe_partition = Some(partition_state.partition_id); - } else { - // Spilled probe indicated but file not yet finalized: wait - // println!( - // "[spill-join] Waiting for spilled probe file for partition {}", - // partition_state.partition_id - // ); - return Poll::Pending; + loop { + if self.pending_probe_partition != Some(build_index) { + let mut next_file = self + .probe_spill_files + .get_mut(build_index) + .and_then(|queue| queue.pop_front()); + if next_file.is_none() + && self.finalize_spilled_partition(build_index)? + { + next_file = self + .probe_spill_files + .get_mut(build_index) + .and_then(|queue| queue.pop_front()); + } + if let Some(file) = next_file { + let stream = + self.probe_spill_manager.read_spill_as_stream(file)?; + self.pending_probe_stream = Some(stream); + self.pending_probe_partition = Some(build_index); + } else { + let writer_open = self + .probe_spill_in_progress + .get(build_index) + .and_then(|o| o.as_ref()) + .is_some(); + if self.probe_stream_finished && !writer_open { + self.pending_probe_stream = None; + self.pending_probe_partition = None; + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } else { + return Poll::Pending; + } + } } - } - if self.pending_probe_partition == Some(partition_state.partition_id) { + if let Some(stream) = self.pending_probe_stream.as_mut() { match stream.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { - // Compute ON values and hashes for this filtered batch let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); for c in &self.on_right { @@ -1443,123 +2165,38 @@ impl PartitionedHashJoinStream { self.current_probe_hashes = hashes; self.current_offset = (0, None); if let Some(b) = &self.current_probe_batch { - self.probe_consumed_rows_per_part - [partition_state.partition_id] = self - .probe_consumed_rows_per_part - [partition_state.partition_id] + self.probe_consumed_rows_per_part[build_index] = self + .probe_consumed_rows_per_part[build_index] .saturating_add(b.num_rows()); } - // println!( - // "[spill-join][probe-spill] partition={} batch rows={}", - // partition_state.partition_id, - // self.current_probe_batch - // .as_ref() - // .map(|b| b.num_rows()) - // .unwrap_or(0) - // ); + break; } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Ready(None) => { - // Finished probe for this partition; advance self.pending_probe_stream = None; self.pending_probe_partition = None; - // println!( - // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - // partition_state.partition_id, - // self.probe_buffered_rows_per_part[partition_state.partition_id], - // self.probe_spilled_rows_per_part[partition_state.partition_id], - // self.probe_consumed_rows_per_part[partition_state.partition_id], - // self.candidate_pairs_per_part[partition_state.partition_id], - // self.matched_rows_per_part[partition_state.partition_id], - // self.emitted_rows_per_part[partition_state.partition_id] - // ); - // println!( - // "[spill-join][probe-spill] partition={} stream complete", - // partition_state.partition_id - // ); - self.release_partition_resources( - partition_state.partition_id, - ); - if partition_state.is_last_partition { - 
self.current_partition = None; - self.state = - PartitionedHashJoinState::HandleUnmatchedRows; - } else { - self.current_partition = None; - self.state = - PartitionedHashJoinState::ProcessPartition( - ProcessPartitionState { - partition_id: partition_state - .partition_id - + 1, - total_partitions: partition_state - .total_partitions, - is_last_partition: partition_state - .partition_id - + 1 - == partition_state.total_partitions, - }, - ); - } - return Poll::Ready(Ok(StatefulStreamResult::Continue)); + continue; } Poll::Pending => return Poll::Pending, } } else { - // No stream available; nothing to read, advance - self.pending_probe_stream = None; - self.pending_probe_partition = None; - // println!( - // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - // partition_state.partition_id, - // self.probe_buffered_rows_per_part[partition_state.partition_id], - // self.probe_spilled_rows_per_part[partition_state.partition_id], - // self.probe_consumed_rows_per_part[partition_state.partition_id], - // self.candidate_pairs_per_part[partition_state.partition_id], - // self.matched_rows_per_part[partition_state.partition_id], - // self.emitted_rows_per_part[partition_state.partition_id] - // ); - self.release_partition_resources(partition_state.partition_id); - if partition_state.is_last_partition { - self.state = PartitionedHashJoinState::HandleUnmatchedRows; - } else { - self.state = PartitionedHashJoinState::ProcessPartition( - ProcessPartitionState { - partition_id: partition_state.partition_id + 1, - total_partitions: partition_state.total_partitions, - is_last_partition: partition_state.partition_id + 1 - == partition_state.total_partitions, - }, - ); - } - return Poll::Ready(Ok(StatefulStreamResult::Continue)); + return Poll::Pending; } } } else { // Neither spilled nor buffered probe for this partition: advance // println!( // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - // partition_state.partition_id, - // self.probe_buffered_rows_per_part[partition_state.partition_id], - // self.probe_spilled_rows_per_part[partition_state.partition_id], - // self.probe_consumed_rows_per_part[partition_state.partition_id], - // self.candidate_pairs_per_part[partition_state.partition_id], - // self.matched_rows_per_part[partition_state.partition_id], - // self.emitted_rows_per_part[partition_state.partition_id] + // build_index, + // self.probe_buffered_rows_per_part[build_index], + // self.probe_spilled_rows_per_part[build_index], + // self.probe_consumed_rows_per_part[build_index], + // self.candidate_pairs_per_part[build_index], + // self.matched_rows_per_part[build_index], + // self.emitted_rows_per_part[build_index] // ); - self.release_partition_resources(partition_state.partition_id); - if partition_state.is_last_partition { - self.state = PartitionedHashJoinState::HandleUnmatchedRows; - } else { - self.state = PartitionedHashJoinState::ProcessPartition( - ProcessPartitionState { - partition_id: partition_state.partition_id + 1, - total_partitions: partition_state.total_partitions, - is_last_partition: partition_state.partition_id + 1 - == partition_state.total_partitions, - }, - ); - } + self.release_partition_resources(build_index); + self.advance_to_next_partition(); return Poll::Ready(Ok(StatefulStreamResult::Continue)); } } @@ -1568,28 +2205,16 @@ impl PartitionedHashJoinStream { if self.current_probe_batch.is_none() { // println!( // "[spill-join][summary] part={} 
buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - // partition_state.partition_id, - // self.probe_buffered_rows_per_part[partition_state.partition_id], - // self.probe_spilled_rows_per_part[partition_state.partition_id], - // self.probe_consumed_rows_per_part[partition_state.partition_id], - // self.candidate_pairs_per_part[partition_state.partition_id], - // self.matched_rows_per_part[partition_state.partition_id], - // self.emitted_rows_per_part[partition_state.partition_id] + // build_index, + // self.probe_buffered_rows_per_part[build_index], + // self.probe_spilled_rows_per_part[build_index], + // self.probe_consumed_rows_per_part[build_index], + // self.candidate_pairs_per_part[build_index], + // self.matched_rows_per_part[build_index], + // self.emitted_rows_per_part[build_index] // ); - self.release_partition_resources(partition_state.partition_id); - if partition_state.is_last_partition { - self.current_partition = None; - self.state = PartitionedHashJoinState::HandleUnmatchedRows; - } else { - self.current_partition = None; - self.state = - PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { - partition_id: partition_state.partition_id + 1, - total_partitions: partition_state.total_partitions, - is_last_partition: partition_state.partition_id + 1 - == partition_state.total_partitions, - }); - } + self.release_partition_resources(build_index); + self.advance_to_next_partition(); return Poll::Ready(Ok(StatefulStreamResult::Continue)); } @@ -1601,7 +2226,7 @@ impl PartitionedHashJoinStream { .ok_or_else(|| internal_datafusion_err!("expected probe batch"))?; let (build_hashmap, build_batch, build_values) = - match self.build_partitions.get(partition_state.partition_id) { + match self.build_partitions.get(build_index) { Some(BuildPartition::InMemory { hash_map, batch, @@ -1669,7 +2294,7 @@ impl PartitionedHashJoinStream { // println!( // "[spill-join] Partition {} build hashmap empty? 
{}", - // partition_state.partition_id, + // build_index, // build_hashmap.is_empty() // );*/ @@ -1685,8 +2310,8 @@ impl PartitionedHashJoinStream { let probe_indices: UInt32Array = probe_indices.into(); // Track candidate pairs before equality - self.candidate_pairs_per_part[partition_state.partition_id] = self - .candidate_pairs_per_part[partition_state.partition_id] + self.candidate_pairs_per_part[build_index] = self.candidate_pairs_per_part + [build_index] .saturating_add(build_indices.len()); // println!( // "[spill-join] Candidates before equality: build_ids={}, probe_ids={}, build_rows={}, probe_rows={}", @@ -1712,7 +2337,7 @@ impl PartitionedHashJoinStream { && build_values[0].data_type() == &arrow::datatypes::DataType::Int64 && self.current_probe_values[0].data_type() == &arrow::datatypes::DataType::Int64 - && !self.verify_once_per_part[partition_state.partition_id] + && !self.verify_once_per_part[build_index] { use arrow::array::Int64Array; use std::collections::HashMap; @@ -1746,11 +2371,11 @@ impl PartitionedHashJoinStream { } // println!( // "[spill-join][verify] part={} expect_pairs~{} vs actual_after_eq={}", - // partition_state.partition_id, + // build_index, // expect, // build_indices.len() // );*/ - self.verify_once_per_part[partition_state.partition_id] = true; + self.verify_once_per_part[build_index] = true; }*/ // Debug: log key data types and sample matched pairs @@ -1811,11 +2436,11 @@ impl PartitionedHashJoinStream { None, )?; - if !self.filter_debug_once_per_part[partition_state.partition_id] { + if !self.filter_debug_once_per_part[build_index] { /* // println!( // "[spill-join][filter-debug] part={} filter_before={} filter_after={}", - // partition_state.partition_id, + // build_index, // before_len, // filtered_build_indices.len() // ); @@ -1902,13 +2527,13 @@ impl PartitionedHashJoinStream { } }*/ - self.filter_debug_once_per_part[partition_state.partition_id] = true; + self.filter_debug_once_per_part[build_index] = true; } if before_len != filtered_build_indices.len() { // println!( // "[spill-join][filter-debug] part={} filter removed {} rows", - // partition_state.partition_id, + // build_index, // before_len - filtered_build_indices.len() // ); } @@ -1927,7 +2552,7 @@ impl PartitionedHashJoinStream { }; // Log sample matches even if no residual filter remains, to debug equality behavior - /*if !self.filter_debug_once_per_part[partition_state.partition_id] + /*if !self.filter_debug_once_per_part[build_index] || build_indices.len() != probe_indices.len() { let sample = build_indices.len().min(5); @@ -1965,7 +2590,7 @@ impl PartitionedHashJoinStream { // println!( // "[spill-join][match-debug] part={} pair {} build {{{}}} probe {{{}}}", - // partition_state.partition_id, + // build_index, // i, // build_vals, // probe_vals @@ -1975,13 +2600,13 @@ impl PartitionedHashJoinStream { if build_indices.len() != probe_indices.len() { // println!( // "[spill-join][match-debug] part={} MISMATCH len build={} probe={}", - // partition_state.partition_id, + // build_index, // build_indices.len(), // probe_indices.len() // ); } - self.filter_debug_once_per_part[partition_state.partition_id] = true; + self.filter_debug_once_per_part[build_index] = true; }*/ // Debug counter: post-equality (before any alignment) @@ -1994,7 +2619,7 @@ impl PartitionedHashJoinStream { /*if matches!(self.join_type, JoinType::Inner) && build_values.len() == 2 && self.current_probe_values.len() == 2 - && !self.verify_once_per_part[partition_state.partition_id] + && 
!self.verify_once_per_part[build_index] { use std::collections::HashMap; let mut map: HashMap = HashMap::new(); @@ -2032,15 +2657,15 @@ impl PartitionedHashJoinStream { } // println!( // "[spill-join][verify2] part={} expect_pairs~{} vs actual_after_eq={}", - // partition_state.partition_id, + // build_index, // expect, // build_indices.len() // ); - self.verify_once_per_part[partition_state.partition_id] = true; + self.verify_once_per_part[build_index] = true; }*/ // Accumulate matched rows per partition - self.matched_rows_per_part[partition_state.partition_id] = self - .matched_rows_per_part[partition_state.partition_id] + self.matched_rows_per_part[build_index] = self.matched_rows_per_part + [build_index] .saturating_add(build_indices.len()); // Compute alignment window (used by adjust_indices for all join types) @@ -2135,17 +2760,13 @@ impl PartitionedHashJoinStream { }; let emitted_rows = result.num_rows(); - self.emitted_rows_per_part[partition_state.partition_id] = self - .emitted_rows_per_part[partition_state.partition_id] - .saturating_add(emitted_rows); + self.emitted_rows_per_part[build_index] = + self.emitted_rows_per_part[build_index].saturating_add(emitted_rows); (result, build_ids_to_mark, next_offset) }; // Mark matched build-side rows for outer joins (use current partition's bitmap) - if let Some(bitmap) = self - .matched_build_rows_per_partition - .get_mut(partition_state.partition_id) - { + if let Some(bitmap) = self.matched_build_rows_per_partition.get_mut(build_index) { for build_idx in build_ids_to_mark { bitmap.set_bit(build_idx as usize, true); } @@ -2167,7 +2788,7 @@ impl PartitionedHashJoinStream { if result.num_rows() == 0 { // println!( // "[spill-join] Skipping empty batch emission (partition={})", - // partition_state.partition_id + // build_index // ); return Poll::Ready(Ok(StatefulStreamResult::Continue)); } @@ -2176,7 +2797,7 @@ impl PartitionedHashJoinStream { // println!( // "[spill-join] Emitting batch: rows={} (partition={})", // result.num_rows(), - // partition_state.partition_id + // build_index // ); Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))) } @@ -2493,7 +3114,7 @@ impl Stream for PartitionedHashJoinStream { let empty = RecordBatch::new_empty(self.schema.clone()); // println!( // "[spill-join] Emitting placeholder empty batch for partition {}", - // partition_state.partition_id + // build_index // ); return Poll::Ready(Some(Ok(empty))); } @@ -2501,7 +3122,7 @@ impl Stream for PartitionedHashJoinStream { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { // println!( // "[spill-join] poll_next yielding process batch: rows={} (state partition={})", - // batch.num_rows(), partition_state.partition_id + // batch.num_rows(), build_index // ); return Poll::Ready(Some(Ok(batch))); } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 1b4d8474fed0b..25bbcdbcbecac 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1341,6 +1341,14 @@ pub(crate) struct BuildProbeJoinMetrics { pub(crate) probe_spilled_bytes: metrics::Count, /// Total probe-side rows written to spill pub(crate) probe_spilled_rows: metrics::Count, + /// Number of times recursive repartitioning was triggered + pub(crate) recursive_repartition_events: metrics::Count, + /// Total number of child partitions materialized by recursion + pub(crate) recursive_partitions_created: metrics::Count, + /// Maximum recursion depth reached + pub(crate) 
recursive_partition_depth: metrics::Gauge, + /// Maximum fan-out applied during recursive repartitioning + pub(crate) recursive_repartition_fanout: metrics::Gauge, /// Total time for joining probe-side batches to the build-side batches pub(crate) join_time: metrics::Time, /// Number of batches consumed by probe-side of this operator @@ -1398,6 +1406,14 @@ impl BuildProbeJoinMetrics { MetricBuilder::new(metrics).counter("probe_spilled_bytes", partition); let probe_spilled_rows = MetricBuilder::new(metrics).counter("probe_spilled_rows", partition); + let recursive_repartition_events = MetricBuilder::new(metrics) + .counter("recursive_repartition_events", partition); + let recursive_partitions_created = MetricBuilder::new(metrics) + .counter("recursive_partitions_created", partition); + let recursive_partition_depth = + MetricBuilder::new(metrics).gauge("recursive_partition_depth", partition); + let recursive_repartition_fanout = + MetricBuilder::new(metrics).gauge("recursive_repartition_fanout", partition); let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); @@ -1418,6 +1434,10 @@ impl BuildProbeJoinMetrics { probe_spill_count, probe_spilled_bytes, probe_spilled_rows, + recursive_repartition_events, + recursive_partitions_created, + recursive_partition_depth, + recursive_repartition_fanout, join_time, input_batches, input_rows, From ddc6a34526aad8eb6258ae3e46ce86d4ea1e1d28 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Fri, 7 Nov 2025 16:21:52 +0200 Subject: [PATCH 18/36] scheduler init --- datafusion/physical-plan/Cargo.toml | 1 + .../physical-plan/src/joins/hash_join/mod.rs | 2 + .../src/joins/hash_join/partitioned.rs | 22 ++ .../src/joins/hash_join/scheduler.rs | 274 ++++++++++++++++++ 4 files changed, 299 insertions(+) create mode 100644 datafusion/physical-plan/src/joins/hash_join/scheduler.rs diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index a21d91c219aa1..110f02bf22de3 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -38,6 +38,7 @@ workspace = true force_hash_collisions = [] tokio_coop = [] tokio_coop_fallback = [] +hybrid_hash_join_scheduler = [] [lib] name = "datafusion_physical_plan" diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 7c4a297414f3c..1a803684852ff 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -21,5 +21,7 @@ pub use exec::HashJoinExec; mod exec; mod partitioned; +#[cfg(feature = "hybrid_hash_join_scheduler")] +mod scheduler; mod shared_bounds; mod stream; diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index ccfb72689517f..e0cd5b97cbb6d 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -48,6 +48,8 @@ use std::mem::{self, size_of}; use std::sync::Arc; use std::task::{Context, Poll}; +#[cfg(feature = "hybrid_hash_join_scheduler")] +use super::scheduler::{HybridTaskScheduler, SchedulerConfig}; use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::join_hash_map::{JoinHashMapType, JoinHashMapU32, JoinHashMapU64}; use crate::joins::utils::{ @@ -1617,9 +1619,29 @@ impl PartitionedHashJoinStream { } /// Partition build-side data into multiple partitions + #[cfg(feature = "hybrid_hash_join_scheduler")] fn 
partition_build_side( &mut self, build_data: Arc, + ) -> Result>> { + let config = SchedulerConfig::from_stream(self); + HybridTaskScheduler::with_build_task(config, build_data) + .run_until_build_finished(self) + } + + /// Partition build-side data into multiple partitions (legacy serial path) + #[cfg(not(feature = "hybrid_hash_join_scheduler"))] + fn partition_build_side( + &mut self, + build_data: Arc, + ) -> Result>> { + self.partition_build_side_serial(build_data) + } + + /// Legacy build partitioning logic shared with the experimental scheduler. + pub(super) fn partition_build_side_serial( + &mut self, + build_data: Arc, ) -> Result>> { if self.partition_pass == 0 { self.join_metrics.build_input_batches.add(1); diff --git a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs new file mode 100644 index 0000000000000..29d9614447cfd --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs @@ -0,0 +1,274 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Experimental Hybrid Hash Join scheduler abstractions. + +#![cfg(feature = "hybrid_hash_join_scheduler")] + +use std::collections::VecDeque; +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; + +use crate::joins::hash_join::exec::JoinLeftData; +use crate::joins::hash_join::partitioned::{ + PartitionDescriptor, PartitionedHashJoinStream, +}; +use crate::joins::utils::StatefulStreamResult; + +use datafusion_common::{internal_datafusion_err, Result}; + +/// Configuration shared across scheduler components. +#[derive(Clone, Debug)] +pub(super) struct SchedulerConfig { + pub memory_threshold: usize, + pub batch_size: usize, + pub max_partition_count: usize, +} + +impl SchedulerConfig { + pub fn from_stream(stream: &PartitionedHashJoinStream) -> Self { + Self { + memory_threshold: stream.memory_threshold, + batch_size: stream.batch_size, + max_partition_count: stream.max_partition_count, + } + } +} + +/// Minimal scheduler capable of running build / probe / finalize tasks. 
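+///
+/// A minimal usage sketch, assuming the surrounding `PartitionedHashJoinStream`
+/// (`stream`) and a loaded build side (`build_data`); only the build stage is
+/// wired up in this patch, so probe/finalize tasks are still placeholders:
+///
+/// ```ignore
+/// let config = SchedulerConfig::from_stream(&stream);
+/// let mut scheduler = HybridTaskScheduler::with_build_task(config, build_data);
+/// // Drains the ready queue until the build task reports its result.
+/// let build_result = scheduler.run_until_build_finished(&mut stream)?;
+/// ```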
+pub(super) struct HybridTaskScheduler { + config: SchedulerConfig, + ready_queue: VecDeque, +} + +impl HybridTaskScheduler { + pub fn new(config: SchedulerConfig) -> Self { + Self { + config, + ready_queue: VecDeque::new(), + } + } + + pub fn with_build_task( + config: SchedulerConfig, + build_data: Arc, + ) -> Self { + let mut scheduler = Self::new(config.clone()); + scheduler + .ready_queue + .push_back(SchedulerTask::Build(BuildStageTask::new( + config, build_data, + ))); + scheduler + } + + pub fn enqueue_probe_task(&mut self, descriptor: PartitionDescriptor) { + self.ready_queue + .push_back(SchedulerTask::Probe(ProbeStageTask::new( + self.config.clone(), + descriptor, + ))); + } + + pub fn enqueue_finalize_task(&mut self, descriptor: PartitionDescriptor) { + self.ready_queue + .push_back(SchedulerTask::Finalize(FinalizeStageTask::new( + self.config.clone(), + descriptor, + ))); + } + + pub fn run_until_build_finished( + &mut self, + stream: &mut PartitionedHashJoinStream, + ) -> Result>> { + while let Some(task) = self.ready_queue.pop_front() { + match task.poll(stream)? { + TaskPoll::Pending(task) => self.ready_queue.push_back(task), + TaskPoll::BuildFinished(result) => return Ok(result), + TaskPoll::YieldProbe(task) => self.ready_queue.push_back(task), + TaskPoll::YieldFinalize(task) => self.ready_queue.push_back(task), + TaskPoll::ProbeFinished | TaskPoll::FinalizeFinished => continue, + } + } + Err(internal_datafusion_err!( + "scheduler queue exhausted without producing build output" + )) + } +} + +enum SchedulerTask { + Build(BuildStageTask), + Probe(ProbeStageTask), + Finalize(FinalizeStageTask), +} + +enum TaskPoll { + Pending(SchedulerTask), + BuildFinished(StatefulStreamResult>), + /// Probe task yielded without producing output (to be expanded later). + YieldProbe(SchedulerTask), + /// Finalize task yielded without producing output. + YieldFinalize(SchedulerTask), + ProbeFinished, + FinalizeFinished, +} + +impl SchedulerTask { + fn poll(self, stream: &mut PartitionedHashJoinStream) -> Result { + match self { + SchedulerTask::Build(task) => match task.poll(stream)? { + BuildTaskEvent::Pending(next_state) => { + Ok(TaskPoll::Pending(SchedulerTask::Build(next_state))) + } + BuildTaskEvent::Finished(result) => Ok(TaskPoll::BuildFinished(result)), + }, + SchedulerTask::Probe(task) => match task.poll(stream)? { + ProbeTaskEvent::Pending(next_task) => { + Ok(TaskPoll::YieldProbe(SchedulerTask::Probe(next_task))) + } + ProbeTaskEvent::Finished => Ok(TaskPoll::ProbeFinished), + }, + SchedulerTask::Finalize(task) => match task.poll(stream)? { + FinalizeTaskEvent::Pending(next_task) => { + Ok(TaskPoll::YieldFinalize(SchedulerTask::Finalize(next_task))) + } + FinalizeTaskEvent::Finished => Ok(TaskPoll::FinalizeFinished), + }, + } + } +} + +/// Build stage broken into multiple cooperative steps so the scheduler can interleave it. 
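+///
+/// The task advances through `Init` (a couple of cooperative warm-up yields),
+/// `Partitioning` (delegates to `partition_build_side_serial`), and `Finished`.
+/// A rough sketch of how the scheduler drives it, for illustration only:
+///
+/// ```ignore
+/// let mut task = BuildStageTask::new(config, build_data);
+/// let result = loop {
+///     match task.poll(stream)? {
+///         BuildTaskEvent::Pending(next) => task = next, // re-queue and retry
+///         BuildTaskEvent::Finished(result) => break result,
+///     }
+/// };
+/// ```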
+struct BuildStageTask { + config: SchedulerConfig, + build_data: Option>, + step: BuildTaskStep, + warmup_remaining: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BuildTaskStep { + Init, + Partitioning, + Finished, +} + +impl BuildStageTask { + fn new(config: SchedulerConfig, build_data: Arc) -> Self { + Self { + config, + build_data: Some(build_data), + step: BuildTaskStep::Init, + warmup_remaining: 2, // allow a couple of yields before heavy work + } + } + + fn poll(mut self, stream: &mut PartitionedHashJoinStream) -> Result { + match self.step { + BuildTaskStep::Init => { + if self.warmup_remaining > 0 { + self.warmup_remaining -= 1; + return Ok(BuildTaskEvent::Pending(self)); + } + self.step = BuildTaskStep::Partitioning; + Ok(BuildTaskEvent::Pending(self)) + } + BuildTaskStep::Partitioning => { + let build_data = self.build_data.take().ok_or_else(|| { + internal_datafusion_err!("build task missing input data") + })?; + let result = stream.partition_build_side_serial(build_data)?; + self.step = BuildTaskStep::Finished; + Ok(BuildTaskEvent::Finished(result)) + } + BuildTaskStep::Finished => { + Err(internal_datafusion_err!("build task already finished")) + } + } + } +} + +enum BuildTaskEvent { + Pending(BuildStageTask), + Finished(StatefulStreamResult>), +} + +struct ProbeStageTask { + config: SchedulerConfig, + descriptor: PartitionDescriptor, + yielded_once: bool, +} + +impl ProbeStageTask { + fn new(config: SchedulerConfig, descriptor: PartitionDescriptor) -> Self { + Self { + config, + descriptor, + yielded_once: false, + } + } + + fn poll(self, _stream: &mut PartitionedHashJoinStream) -> Result { + if self.yielded_once { + Ok(ProbeTaskEvent::Finished) + } else { + Ok(ProbeTaskEvent::Pending(Self { + yielded_once: true, + ..self + })) + } + } +} + +enum ProbeTaskEvent { + Pending(ProbeStageTask), + Finished, +} + +struct FinalizeStageTask { + config: SchedulerConfig, + descriptor: PartitionDescriptor, + yielded_once: bool, +} + +impl FinalizeStageTask { + fn new(config: SchedulerConfig, descriptor: PartitionDescriptor) -> Self { + Self { + config, + descriptor, + yielded_once: false, + } + } + + fn poll(self, _stream: &mut PartitionedHashJoinStream) -> Result { + if self.yielded_once { + Ok(FinalizeTaskEvent::Finished) + } else { + Ok(FinalizeTaskEvent::Pending(Self { + yielded_once: true, + ..self + })) + } + } +} + +enum FinalizeTaskEvent { + Pending(FinalizeStageTask), + Finished, +} From 2e85309b75659088f2989f6213164aedfe0df97b Mon Sep 17 00:00:00 2001 From: osipovartem Date: Fri, 7 Nov 2025 21:01:59 +0300 Subject: [PATCH 19/36] upd --- .../src/joins/grace_hash_join/exec.rs | 231 +++++++++--------- .../src/joins/grace_hash_join/stream.rs | 102 ++++---- datafusion/physical-plan/src/joins/mod.rs | 2 +- datafusion/physical-plan/src/spill/mod.rs | 18 +- .../physical-plan/src/spill/spill_manager.rs | 6 +- 5 files changed, 179 insertions(+), 180 deletions(-) diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs index 0fca133f79d79..0782d9c84fc1f 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -15,52 +15,47 @@ // specific language governing permissions and limitations // under the License. 
-use std::fmt; -use std::mem::size_of; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, OnceLock}; -use std::{any::Any, vec}; -use std::fmt::{format, Formatter}; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; use crate::joins::utils::{ - asymmetric_join_output_partitioning, reorder_output_after_swap, swap_join_projection, - update_hash, OnceAsync, OnceFut, + reorder_output_after_swap, swap_join_projection, OnceFut, }; -use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; +use crate::joins::{JoinOn, JoinOnRef, PartitionMode}; use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; use crate::spill::get_record_batch_memory_size; -use crate::{displayable, ExecutionPlanProperties, SpillManager}; use crate::{ common::can_project, joins::utils::{ build_join_schema, check_join_is_valid, estimate_join_statistics, - need_produce_result_in_final, symmetric_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, + symmetric_join_output_partitioning, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties, SendableRecordBatchStream, Statistics, }; +use crate::{ExecutionPlanProperties, SpillManager}; +use std::fmt; +use std::fmt::Formatter; +use std::sync::{Arc, OnceLock}; +use std::{any::Any, vec}; -use arrow::array::{ArrayRef, BooleanBufferBuilder, UInt32Array}; -use arrow::compute::{concat, concat_batches, take}; +use arrow::array::UInt32Array; +use arrow::compute::{concat_batches, take}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use arrow::util::bit_util; -use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; -use datafusion_common::utils::memory::estimate_memory_size; -use datafusion_common::{internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, JoinSide, JoinType, NullEquality, Result}; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_common::{ + internal_err, plan_err, project_schema, JoinSide, JoinType, + NullEquality, Result, +}; use datafusion_execution::TaskContext; -use datafusion_expr::{Accumulator, UserDefinedLogicalNode}; use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, @@ -68,19 +63,16 @@ use datafusion_physical_expr::equivalence::{ use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; -use ahash::RandomState; -use arrow_ord::partition::partition; -use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{StreamExt, TryStreamExt}; -use futures::executor::block_on; -use parking_lot::Mutex; -use datafusion_common::hash_utils::create_hashes; -use datafusion_execution::runtime_env::RuntimeEnv; -use crate::empty::EmptyExec; -use crate::joins::grace_hash_join::stream::{GraceAccumulator, GraceHashJoinStream, SpillFut}; +use crate::joins::grace_hash_join::stream::{ + GraceAccumulator, GraceHashJoinStream, SpillFut, +}; use 
crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; use crate::metrics::SpillMetrics; -use crate::spill::spill_manager::{GetSlicedSize, SpillLocation}; +use crate::spill::spill_manager::SpillLocation; +use ahash::RandomState; +use datafusion_common::hash_utils::create_hashes; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use futures::StreamExt; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. const HASH_JOIN_SEED: RandomState = @@ -148,7 +140,6 @@ impl fmt::Debug for GraceHashJoinExec { } } - impl EmbeddedProjection for GraceHashJoinExec { fn with_projection(&self, projection: Option>) -> Result { self.with_projection(projection) @@ -331,7 +322,8 @@ impl GraceHashJoinExec { on, )?; - let mut output_partitioning = symmetric_join_output_partitioning(left, right, &join_type)?; + let mut output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type)?; let emission_type = if left.boundedness().is_unbounded() { EmissionType::Final } else if right.pipeline_behavior() == EmissionType::Incremental { @@ -399,7 +391,7 @@ impl GraceHashJoinExec { /// insert `RepartitionExec` operators). pub fn swap_inputs( &self, - partition_mode: PartitionMode, + _partition_mode: PartitionMode, ) -> Result> { let left = self.left(); let right = self.right(); @@ -677,8 +669,11 @@ impl ExecutionPlan for GraceHashJoinExec { spill_left_clone, spill_right_clone, partition, - ).await?; - accumulator_clone.report_partition(partition, left_idx.clone(), right_idx.clone()).await; + ) + .await?; + accumulator_clone + .report_partition(partition, left_idx.clone(), right_idx.clone()) + .await; Ok(SpillFut::new(partition, left_idx, right_idx)) }); @@ -710,9 +705,6 @@ impl ExecutionPlan for GraceHashJoinExec { if partition.is_some() { return Ok(Statistics::new_unknown(&self.schema())); } - // TODO stats: it is not possible in general to know the output size of joins - // There are some special cases though, for example: - // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` let stats = estimate_join_statistics( self.left.partition_statistics(None)?, self.right.partition_statistics(None)?, @@ -895,31 +887,31 @@ pub async fn partition_and_spill( let on_left: Vec<_> = on.iter().map(|(l, _)| Arc::clone(l)).collect(); let on_right: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect(); - // === LEFT side partitioning === + // LEFT side partitioning let left_index = partition_and_spill_one_side( &mut left_stream, &on_left, &random_state, partition_count, spill_left, + &format!("left_{partition}"), &join_metrics, enable_dynamic_filter_pushdown, - &format!("left_{partition}"), ) - .await?; + .await?; - // === RIGHT side partitioning === + // RIGHT side partitioning let right_index = partition_and_spill_one_side( &mut right_stream, &on_right, &random_state, partition_count, spill_right, + &format!("right_{partition}"), &join_metrics, enable_dynamic_filter_pushdown, - &format!("right_{partition}"), ) - .await?; + .await?; Ok((left_index, right_index)) } @@ -929,23 +921,22 @@ async fn partition_and_spill_one_side( random_state: &RandomState, partition_count: usize, spill_manager: Arc, + spilling_request_msg: &str, join_metrics: &BuildProbeJoinMetrics, - enable_dynamic_filter_pushdown: bool, - file_request_msg: &str, + _enable_dynamic_filter_pushdown: bool, ) -> Result> { let mut partitions: Vec = (0..partition_count) .map(|_| PartitionWriter::new(Arc::clone(&spill_manager))) .collect(); let mut 
buffered_batches = Vec::new(); - let mut total_rows = 0usize; + let schema = input.schema(); while let Some(batch) = input.next().await { let batch = batch?; if batch.num_rows() == 0 { continue; } - total_rows += batch.num_rows(); join_metrics.build_input_batches.add(1); join_metrics.build_input_rows.add(batch.num_rows()); buffered_batches.push(batch); @@ -990,7 +981,8 @@ async fn partition_and_spill_one_side( .collect::>>()?; let part_batch = RecordBatch::try_new(single_batch.schema(), taken)?; - let request_msg = format!("grace_partition_{file_request_msg}_{i}"); + // We need unique name for spilling + let request_msg = format!("grace_partition_{spilling_request_msg}_{i}"); partitions[i].spill_batch_auto(&part_batch, &request_msg)?; } @@ -999,7 +991,7 @@ async fn partition_and_spill_one_side( for (i, writer) in partitions.into_iter().enumerate() { result.push(writer.finish(i)?); } - println!("spill_manager {:?}", spill_manager.metrics); + // println!("spill_manager {:?}", spill_manager.metrics); Ok(result) } @@ -1021,7 +1013,11 @@ impl PartitionWriter { } } - pub fn spill_batch_auto(&mut self, batch: &RecordBatch, request_msg: &str) -> Result<()> { + pub fn spill_batch_auto( + &mut self, + batch: &RecordBatch, + request_msg: &str, + ) -> Result<()> { let loc = self.spill_manager.spill_batch_auto(batch, request_msg)?; self.total_rows += batch.num_rows(); self.total_bytes += get_record_batch_memory_size(batch); @@ -1066,32 +1062,16 @@ pub struct PartitionIndex { #[cfg(test)] mod tests { use super::*; - use crate::coalesce_partitions::CoalescePartitionsExec; - use crate::test::{assert_join_metrics, TestMemoryExec}; + use crate::test::TestMemoryExec; use crate::{ common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, - test::exec::MockExec, }; - use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array}; - use arrow::buffer::NullBuffer; - use arrow::datatypes::{DataType, Field}; - use arrow::util::pretty::print_batches; - use arrow_schema::Schema; - use futures::future; - use datafusion_common::hash_utils::create_hashes; - use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; - use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, - ScalarValue, - }; - use datafusion_execution::config::SessionConfig; - use hashbrown::HashTable; - use insta::{allow_duplicates, assert_snapshot}; - use rstest::*; - use rstest_reuse::*; - use datafusion_execution::runtime_env::RuntimeEnvBuilder; use crate::joins::HashJoinExec; + use arrow::array::{ArrayRef, Int32Array}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_physical_expr::Partitioning; + use futures::future; fn build_large_table( a_name: &str, @@ -1100,18 +1080,30 @@ mod tests { n: usize, ) -> Arc { let a: ArrayRef = Arc::new(Int32Array::from_iter_values(1..=n as i32)); - let b: ArrayRef = Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 2))); - let c: ArrayRef = Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 10))); + let b: ArrayRef = + Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 2))); + let c: ArrayRef = + Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 10))); let schema = Arc::new(arrow::datatypes::Schema::new(vec![ - arrow::datatypes::Field::new(a_name, arrow::datatypes::DataType::Int32, false), - arrow::datatypes::Field::new(b_name, arrow::datatypes::DataType::Int32, false), - arrow::datatypes::Field::new(c_name, 
arrow::datatypes::DataType::Int32, false), + arrow::datatypes::Field::new( + a_name, + arrow::datatypes::DataType::Int32, + false, + ), + arrow::datatypes::Field::new( + b_name, + arrow::datatypes::DataType::Int32, + false, + ), + arrow::datatypes::Field::new( + c_name, + arrow::datatypes::DataType::Int32, + false, + ), ])); let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap(); - - // MemoryExec требует список партиций: Vec> Arc::new(TestMemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) } @@ -1126,8 +1118,7 @@ mod tests { } #[tokio::test] - async fn single_partition_join_overallocation() -> Result<()> - { + async fn simple_grace_hash_join() -> Result<()> { // let left = build_table( // ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), // ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), @@ -1138,27 +1129,28 @@ mod tests { // ("b2", &vec![1, 2]), // ("c2", &vec![14, 15]), // ); - let left = build_large_table("a1", "b1", "c1", 200); - let right = build_large_table("a2", "b2", "c2", 500); + let left = build_large_table("a1", "b1", "c1", 2000000); + let right = build_large_table("a2", "b2", "c2", 5000000); let on = vec![( - Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (left_expr, right_expr): (Vec>, Vec>) = on + let (left_expr, right_expr): ( + Vec>, + Vec>, + ) = on .iter() .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) .unzip(); - let left_repartitioned: Arc = Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(left_expr, 32), - )?); - let right_repartitioned: Arc = Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(right_expr, 32), - )?); + let left_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(left, Partitioning::Hash(left_expr, 32))?, + ); + let right_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(right, Partitioning::Hash(right_expr, 32))?, + ); let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(50_000_000_000, 1.0) + .with_memory_limit(500_000_000, 1.0) .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); @@ -1174,8 +1166,6 @@ mod tests { )?; let partition_count = right_repartitioned.output_partitioning().partition_count(); - println!("partition_count {partition_count}"); - let tasks: Vec<_> = (0..partition_count) .map(|i| { let ctx = Arc::clone(&task_ctx); @@ -1191,8 +1181,10 @@ mod tests { v.retain(|b| b.num_rows() > 0); batches.extend(v); } + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + println!("TOTAL ROWS = {}", total_rows); - print_batches(&*batches).unwrap(); + // print_batches(&*batches).unwrap(); // Asserting that operator-level reservation attempting to overallocate // assert_contains!( // err.to_string(), @@ -1207,26 +1199,27 @@ mod tests { } #[tokio::test] - async fn single_partition_join_overallocation_f() -> Result<()> { - - let left = build_large_table("a1", "b1", "c1", 200); - let right = build_large_table("a2", "b2", "c2", 500); + async fn simple_hash_join() -> Result<()> { + let left = build_large_table("a1", "b1", "c1", 2000000); + let right = build_large_table("a2", "b2", "c2", 5000000); let on = vec![( - Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, + 
Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (left_expr, right_expr): (Vec>, Vec>) = on + let (left_expr, right_expr): ( + Vec>, + Vec>, + ) = on .iter() .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) .unzip(); - let left_repartitioned: Arc = Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(left_expr, 32), - )?); - let right_repartitioned: Arc = Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(right_expr, 32), - )?); + let left_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(left, Partitioning::Hash(left_expr, 32))?, + ); + let right_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(right, Partitioning::Hash(right_expr, 32))?, + ); + let partition_count = left_repartitioned.output_partitioning().partition_count(); let join = HashJoinExec::try_new( left_repartitioned, @@ -1241,7 +1234,7 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let mut batches = vec![]; - for i in 0..32 { + for i in 0..partition_count { let stream = join.execute(i, Arc::clone(&task_ctx))?; let more_batches = common::collect(stream).await?; batches.extend( @@ -1251,8 +1244,10 @@ mod tests { .collect::>(), ); } - print_batches(&*batches).unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + println!("TOTAL ROWS = {}", total_rows); + + // print_batches(&*batches).unwrap(); Ok(()) } - } diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs index 97e48fc81a362..a6afdef1029d1 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs @@ -24,33 +24,24 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::joins::utils::{ - equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut, +use crate::joins::utils::OnceFut; +use crate::{ + joins::utils::{BuildProbeJoinMetrics, ColumnIndex, JoinFilter}, + ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, SpillManager, }; -use crate::{handle_state, hash_utils::create_hashes, joins::join_hash_map::JoinHashMapOffset, joins::utils::{ - adjust_indices_by_join_type, apply_join_filter_to_indices, - build_batch_empty_build_side, build_batch_from_indices, - need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - JoinHashMapType, StatefulStreamResult, -}, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, SpillManager}; - -use arrow::array::{ArrayRef, UInt32Array, UInt64Array}; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use datafusion_common::{ - internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result, -}; -use datafusion_physical_expr::PhysicalExprRef; -use ahash::RandomState; -use futures::{ready, FutureExt, Stream, StreamExt}; -use tokio::sync::Mutex; -use datafusion_execution::TaskContext; + use crate::empty::EmptyExec; use crate::joins::grace_hash_join::exec::PartitionIndex; use crate::joins::{HashJoinExec, PartitionMode}; -use crate::memory::MemoryStream; -use crate::stream::RecordBatchStreamAdapter; use crate::test::TestMemoryExec; +use ahash::RandomState; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{JoinType, NullEquality, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExprRef; +use 
futures::{ready, Stream, StreamExt}; +use tokio::sync::Mutex; enum GraceJoinState { /// Waiting for the partitioning phase (Phase 1) to finish @@ -72,19 +63,6 @@ enum GraceJoinState { Done, } -/// Container for HashJoinStreamState::ProcessProbeBatch related data -#[derive(Debug, Clone)] -pub(super) struct ProcessProbeBatchState { - /// Current probe-side batch - batch: RecordBatch, - /// Probe-side on expressions values - values: Vec, - /// Starting offset for JoinHashMap lookups - offset: JoinHashMapOffset, - /// Max joined probe-side index from current batch - joined_probe_idx: Option, -} - pub struct GraceHashJoinStream { schema: SchemaRef, spill_fut: OnceFut, @@ -105,10 +83,14 @@ pub struct GraceHashJoinStream { pub struct SpillFut { partition: usize, left: Vec, - right: Vec + right: Vec, } impl SpillFut { - pub(crate) fn new(partition: usize, left: Vec, right: Vec) -> Self { + pub(crate) fn new( + partition: usize, + left: Vec, + right: Vec, + ) -> Self { SpillFut { partition, left, @@ -156,7 +138,10 @@ impl GraceHashJoinStream { } /// Core state machine logic (poll implementation) - fn poll_next_impl(&mut self, cx: &mut Context<'_>) -> Poll>> { + fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { loop { match &mut self.state { GraceJoinState::WaitPartitioning => { @@ -165,6 +150,7 @@ impl GraceHashJoinStream { let acc = Arc::clone(&self.accumulator); let left = shared.left.clone(); let right = shared.right.clone(); + // Use 0 partition as the main let wait_all_fut = if shared.partition == 0 { OnceFut::new(async move { acc.report_partition(shared.partition, left, right).await; @@ -178,7 +164,9 @@ impl GraceHashJoinStream { Ok(vec![]) }) }; - self.state = GraceJoinState::WaitAllPartitions { wait_all_fut: Some(wait_all_fut) }; + self.state = GraceJoinState::WaitAllPartitions { + wait_all_fut: Some(wait_all_fut), + }; continue; } GraceJoinState::WaitAllPartitions { wait_all_fut } => { @@ -215,12 +203,21 @@ impl GraceHashJoinStream { if current_stream.is_none() { if left_fut.is_none() && right_fut.is_none() { let spill_fut = &all_parts[*current]; - *left_fut = Some(load_partition_async(Arc::clone(&self.spill_left), spill_fut.left.clone())); - *right_fut = Some(load_partition_async(Arc::clone(&self.spill_right), spill_fut.right.clone())); + *left_fut = Some(load_partition_async( + Arc::clone(&self.spill_left), + spill_fut.left.clone(), + )); + *right_fut = Some(load_partition_async( + Arc::clone(&self.spill_right), + spill_fut.right.clone(), + )); } - let left_batches = (*ready!(left_fut.as_mut().unwrap().get_shared(cx))?).clone(); - let right_batches = (*ready!(right_fut.as_mut().unwrap().get_shared(cx))?).clone(); + let left_batches = + (*ready!(left_fut.as_mut().unwrap().get_shared(cx))?).clone(); + let right_batches = + (*ready!(right_fut.as_mut().unwrap().get_shared(cx))?) 
+ .clone(); let stream = build_in_memory_join_stream( Arc::clone(&self.schema), @@ -265,7 +262,7 @@ fn load_partition_async( ) -> OnceFut> { OnceFut::new(async move { let mut all_batches = Vec::new(); - println!("partitions {:?}", partitions); + for p in partitions { for chunk in p.chunks { let mut reader = spill_manager.load_spilled_batch(&chunk)?; @@ -299,18 +296,21 @@ fn build_in_memory_join_stream( let left_schema = left_batches .first() .map(|b| b.schema()) - .unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty())); + .unwrap_or_else(|| Arc::new(Schema::empty())); let right_schema = right_batches .first() .map(|b| b.schema()) - .unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty())); + .unwrap_or_else(|| Arc::new(Schema::empty())); // Build memory execution nodes for each side let left_plan: Arc = Arc::new(TestMemoryExec::try_new(&[left_batches], left_schema, None)?); - let right_plan: Arc = - Arc::new(TestMemoryExec::try_new(&[right_batches], right_schema, None)?); + let right_plan: Arc = Arc::new(TestMemoryExec::try_new( + &[right_batches], + right_schema, + None, + )?); // Combine join expressions into pairs let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = on_left @@ -380,9 +380,7 @@ impl GraceAccumulator { } } - pub async fn wait_all( - &self, - ) -> Vec { + pub async fn wait_all(&self) -> Vec { loop { { let guard = self.collected.lock().await; @@ -404,4 +402,4 @@ impl GraceAccumulator { self.notify.notified().await; } } -} \ No newline at end of file +} diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 14429ec55182a..8ee4c3de430a3 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -27,8 +27,8 @@ use parking_lot::Mutex; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; -mod hash_join; mod grace_hash_join; +mod hash_join; mod nested_loop_join; mod sort_merge_join; diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 9d19e72cb379a..782100e6d4cf1 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -17,9 +17,9 @@ //! 
Defines the spilling functions +pub(crate) mod in_memory_spill_buffer; pub(crate) mod in_progress_spill_file; pub(crate) mod spill_manager; -pub(crate) mod in_memory_spill_buffer; use std::fs::File; use std::io::BufReader; @@ -387,8 +387,8 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; use futures::StreamExt as _; - use std::sync::Arc; use datafusion_execution::memory_pool::{FairSpillPool, MemoryPool}; + use std::sync::Arc; #[tokio::test] async fn test_batch_spill_and_read() -> Result<()> { @@ -456,8 +456,8 @@ mod tests { // --- create small memory pool (simulate memory pressure) --- let memory_limit_bytes = 20 * 1024; // 20KB - let memory_pool: Arc = Arc::new(FairSpillPool::new(memory_limit_bytes)); - + let memory_pool: Arc = + Arc::new(FairSpillPool::new(memory_limit_bytes)); // Construct SpillManager let env = RuntimeEnvBuilder::new() @@ -469,8 +469,14 @@ mod tests { let results = spill_manager.spill_batches_auto(&batches, "TestAutoSpill")?; assert_eq!(results.len(), 2); - let mem_count = results.iter().filter(|r| matches!(r, SpillLocation::Memory(_))).count(); - let disk_count = results.iter().filter(|r| matches!(r, SpillLocation::Disk(_))).count(); + let mem_count = results + .iter() + .filter(|r| matches!(r, SpillLocation::Memory(_))) + .count(); + let disk_count = results + .iter() + .filter(|r| matches!(r, SpillLocation::Disk(_))) + .count(); assert!(mem_count >= 1); assert!(disk_count >= 1); diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 2c4d6b7844cd9..10f90c6cb3492 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -179,7 +179,7 @@ impl SpillManager { )); }; Ok(SpillLocation::Disk(Arc::new(file))) - // + // // // let size = batch.get_sliced_size()?; // // // Check current memory usage and total limit from the runtime memory pool @@ -188,9 +188,9 @@ impl SpillManager { // datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, // _ => usize::MAX, // }; - // + // println!("size {size} used {used}"); // // If there's enough memory (with a small safety margin), keep it in memory - // if used + size * 3 / 2 <= limit { + // if used + size * 3 * 64 / 2 <= limit { // let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); // self.metrics.spilled_bytes.add(size); // self.metrics.spilled_rows.add(batch.num_rows()); From c78d2e9091b96f48e5406601e88525850ee16457 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Sat, 8 Nov 2025 18:13:20 +0200 Subject: [PATCH 20/36] Partly working scheduler (incorrect results) --- .../src/joins/hash_join/partitioned.rs | 1724 +++++++++++++---- .../src/joins/hash_join/scheduler.rs | 184 +- 2 files changed, 1523 insertions(+), 385 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index e0cd5b97cbb6d..7651a67128fcf 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -47,9 +47,13 @@ use std::collections::VecDeque; use std::mem::{self, size_of}; use std::sync::Arc; use std::task::{Context, Poll}; +use std::time::SystemTime; #[cfg(feature = "hybrid_hash_join_scheduler")] -use super::scheduler::{HybridTaskScheduler, SchedulerConfig}; +use super::scheduler::{ + HybridTaskScheduler, ProbeDataPoll, ProbePartitionState, ProbeStageTask, + SchedulerConfig, SchedulerTask, TaskPoll, +}; 
use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::join_hash_map::{JoinHashMapType, JoinHashMapU32, JoinHashMapU64}; use crate::joins::utils::{ @@ -98,6 +102,27 @@ fn highest_power_of_two_leq(n: usize) -> usize { } } +fn max_partitions_allowed_for_memory(memory_threshold: usize) -> usize { + let mut slots = memory_threshold + .checked_div(HYBRID_HASH_MIN_PARTITION_BYTES) + .unwrap_or(usize::MAX); + if slots == 0 { + slots = 1; + } + highest_power_of_two_leq(slots) +} + +#[inline] +fn hhj_debug String>(builder: F) { + if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() { + let ts = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map(|d| d.as_millis()) + .unwrap_or(0); + println!("[hhj-debug {ts}] {}", builder()); + } +} + /// State of the partitioned hash join stream #[derive(Debug, Clone)] pub(super) enum PartitionedHashJoinState { @@ -105,6 +130,9 @@ pub(super) enum PartitionedHashJoinState { PartitionBuildSide, /// Processing a specific partition ProcessPartition(ProcessPartitionState), + /// Waiting for partitions that are throttled on probe IO to resume + #[cfg(feature = "hybrid_hash_join_scheduler")] + WaitingForProbe, /// All partitions processed, handling unmatched rows for outer joins HandleUnmatchedRows, /// Join completed @@ -163,7 +191,7 @@ pub(super) struct ProbePartition { } impl ProbePartition { - fn new() -> Self { + pub(super) fn new() -> Self { Self { batches: Vec::new(), values: Vec::new(), @@ -172,6 +200,63 @@ impl ProbePartition { } } +/// Runtime state tracked per probe partition. +#[cfg(not(feature = "hybrid_hash_join_scheduler"))] +pub(super) struct ProbePartitionState { + buffered: ProbePartition, + batch_position: usize, + buffered_rows: usize, + spilled_rows: usize, + consumed_rows: usize, + spill_in_progress: Option, + spill_files: VecDeque, + pending_stream: Option, + active_batch: Option, + active_values: Vec, + active_hashes: Vec, + active_offset: crate::joins::join_hash_map::JoinHashMapOffset, + joined_probe_idx: Option, +} + +#[cfg(not(feature = "hybrid_hash_join_scheduler"))] +impl ProbePartitionState { + fn new() -> Self { + Self { + buffered: ProbePartition::new(), + batch_position: 0, + buffered_rows: 0, + spilled_rows: 0, + consumed_rows: 0, + spill_in_progress: None, + spill_files: VecDeque::new(), + pending_stream: None, + active_batch: None, + active_values: Vec::new(), + active_hashes: Vec::new(), + active_offset: (0, None), + joined_probe_idx: None, + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn prepare_probe_values( + &self, + batch: &RecordBatch, + ) -> Result<(Vec, Vec)> { + let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + keys_values.push(c.evaluate(batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + Ok((keys_values, hashes)) + } + + fn reset(&mut self) { + *self = Self::new(); + } +} + enum PartitionBuildStatus { Ready(StatefulStreamResult>), NeedMorePartitions { next_count: usize }, @@ -206,7 +291,7 @@ impl Default for PartitionAccumulator { #[derive(Debug, Clone)] pub(super) struct PartitionDescriptor { /// Index into build/probe storage vectors - build_index: usize, + pub(super) build_index: usize, /// Index of the original (generation 0) partition root_index: usize, /// Number of refinement passes applied so far @@ -273,7 +358,19 @@ pub(super) struct PartitionedHashJoinStream { /// Build-side partitions pub build_partitions: Vec, /// 
Probe-side partitions - pub probe_partitions: Vec, + pub probe_states: Vec, + /// Scheduler used to coordinate probe tasks + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_task_scheduler: HybridTaskScheduler, + /// Whether a scheduler task is currently in-flight per partition + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_inflight: Vec, + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_waiting_for_stream: VecDeque, + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_active_streams: usize, + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_max_streams: usize, /// Current partition being processed pub current_partition: Option, /// Queue of pending partitions to process (supports recursive fan-out) @@ -305,16 +402,6 @@ pub(super) struct PartitionedHashJoinStream { pub build_schema: SchemaRef, /// Cached probe-side schema pub probe_schema: SchemaRef, - /// Current probe batch (filtered to the active partition), if any - pub current_probe_batch: Option, - /// Current probe values for ON expressions - pub current_probe_values: Vec, - /// Current probe hashes (filtered to the active partition) - pub current_probe_hashes: Vec, - /// Current lookup offset within the join hash map - pub current_offset: crate::joins::join_hash_map::JoinHashMapOffset, - /// Max joined probe-side index from current batch (for Right/Semi/Anti alignment) - pub joined_probe_idx: Option, /// Bitmaps to track matched build-side rows for outer joins (one per partition) pub matched_build_rows_per_partition: Vec, /// Current partition being processed for unmatched rows @@ -325,14 +412,6 @@ pub(super) struct PartitionedHashJoinStream { pub unmatched_offset: usize, /// Whether the probe stream has reached EOF pub probe_stream_finished: bool, - /// Current read position per partition within buffered probe batches - pub probe_batch_positions: Vec, - /// Metrics: total probe rows buffered per partition (RAM) - pub probe_buffered_rows_per_part: Vec, - /// Metrics: total probe rows spilled per partition (disk) - pub probe_spilled_rows_per_part: Vec, - /// Metrics: total probe rows consumed during probing per partition - pub probe_consumed_rows_per_part: Vec, /// Metrics: total matches after equality per partition pub matched_rows_per_part: Vec, /// Metrics: total rows emitted per partition @@ -349,20 +428,21 @@ pub(super) struct PartitionedHashJoinStream { pub pending_reload_batches: Vec, /// Target partition id for pending reload pub pending_reload_partition: Option, - /// In-progress probe spill writers, one per partition (used when corresponding build is spilled) - pub probe_spill_in_progress: Vec>, - /// Finalized probe spill files per partition (queue of ready-to-read files) - pub probe_spill_files: Vec>, - /// Pending probe stream for the current partition's probe spill file - pub pending_probe_stream: Option, - /// Target partition id for pending probe stream - pub pending_probe_partition: Option, /// Whether a partition is currently queued for processing pub partition_pending: Vec, /// Latest descriptor metadata per partition pub partition_descriptors: Vec>, } +#[cfg(feature = "hybrid_hash_join_scheduler")] +#[derive(Debug)] +enum ProbeTaskStatus { + Ready, + Pending, + WaitingForStream, + Finished, +} + impl PartitionedHashJoinStream { /// Compute partition id for a given hash using radix mask when possible #[inline] @@ -377,12 +457,16 @@ impl PartitionedHashJoinStream { fn resize_partition_vectors(&mut self) { let n = 
self.num_partitions; - self.probe_spill_in_progress = (0..n).map(|_| None).collect(); - self.probe_spill_files = (0..n).map(|_| VecDeque::new()).collect(); - self.probe_batch_positions = vec![0; n]; - self.probe_buffered_rows_per_part = vec![0; n]; - self.probe_spilled_rows_per_part = vec![0; n]; - self.probe_consumed_rows_per_part = vec![0; n]; + self.probe_states = (0..n).map(|_| ProbePartitionState::new()).collect(); + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + self.probe_scheduler_inflight = vec![false; n]; + self.probe_scheduler_waiting_for_stream = VecDeque::new(); + self.probe_scheduler_active_streams = 0; + self.probe_scheduler_max_streams = std::cmp::max(1, std::cmp::min(4, n)); + self.probe_task_scheduler = + HybridTaskScheduler::new(SchedulerConfig::from_stream(self)); + } self.matched_rows_per_part = vec![0; n]; self.emitted_rows_per_part = vec![0; n]; self.candidate_pairs_per_part = vec![0; n]; @@ -392,18 +476,28 @@ impl PartitionedHashJoinStream { self.partition_descriptors = (0..n).map(|_| None).collect(); } + fn probe_state(&self, idx: usize) -> Result<&ProbePartitionState> { + self.probe_states + .get(idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition")) + } + + fn probe_state_mut(&mut self, idx: usize) -> Result<&mut ProbePartitionState> { + self.probe_states + .get_mut(idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition")) + } + fn allocate_partition_slot(&mut self) -> usize { let idx = self.build_partitions.len(); self.build_partitions.push(BuildPartition::Empty); self.matched_build_rows_per_partition .push(BooleanBufferBuilder::new(0)); - self.probe_partitions.push(ProbePartition::new()); - self.probe_batch_positions.push(0); - self.probe_spill_in_progress.push(None); - self.probe_spill_files.push(VecDeque::new()); - self.probe_buffered_rows_per_part.push(0); - self.probe_spilled_rows_per_part.push(0); - self.probe_consumed_rows_per_part.push(0); + self.probe_states.push(ProbePartitionState::new()); + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + self.probe_scheduler_inflight.push(false); + } self.matched_rows_per_part.push(0); self.emitted_rows_per_part.push(0); self.candidate_pairs_per_part.push(0); @@ -434,8 +528,10 @@ impl PartitionedHashJoinStream { .get(part_id) .and_then(|d| d.clone()) { - self.pending_partitions.push_back(desc); + self.pending_partitions.push_back(desc.clone()); self.partition_pending[part_id] = true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&desc); } Ok(()) @@ -445,26 +541,46 @@ impl PartitionedHashJoinStream { &mut self, part_id: usize, ) -> Result> { - if part_id >= self.probe_spill_in_progress.len() { - return Ok(None); - } - if let Some(mut writer) = self.probe_spill_in_progress[part_id].take() { - let file = writer.finish()?; - return Ok(file); + if let Some(state) = self.probe_states.get_mut(part_id) { + if let Some(mut writer) = state.spill_in_progress.take() { + return writer.finish(); + } } Ok(None) } + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn ensure_probe_scheduler_capacity(&mut self, part_id: usize) { + if self.probe_scheduler_inflight.len() <= part_id { + self.probe_scheduler_inflight.resize(part_id + 1, false); + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn schedule_probe_task(&mut self, descriptor: &PartitionDescriptor) { + let part_id = descriptor.build_index; + self.ensure_probe_scheduler_capacity(part_id); + if self.probe_scheduler_inflight[part_id] { + hhj_debug(|| format!("schedule_probe_task skip part 
{part_id} (inflight)")); + return; + } + let task = SchedulerTask::Probe(ProbeStageTask::new( + SchedulerConfig::from_stream(self), + descriptor.clone(), + )); + self.probe_task_scheduler.push_task(task); + self.probe_scheduler_inflight[part_id] = true; + hhj_debug(|| format!("schedule_probe_task queued part {part_id}")); + } + fn finalize_spilled_partition(&mut self, part_id: usize) -> Result { - if part_id >= self.probe_spill_in_progress.len() { + if part_id >= self.probe_states.len() { return Ok(false); } if let Some(file) = self.flush_probe_writer(part_id)? { - if part_id >= self.probe_spill_files.len() { - self.probe_spill_files - .resize_with(part_id + 1, VecDeque::new); + if let Some(state) = self.probe_states.get_mut(part_id) { + state.spill_files.push_back(file); } - self.probe_spill_files[part_id].push_back(file); self.schedule_partition(part_id)?; return Ok(true); } @@ -812,35 +928,32 @@ impl PartitionedHashJoinStream { partition_indices: &[usize], ) -> Result<()> { let parent_index = descriptor.build_index; - if parent_index >= self.probe_partitions.len() { + if parent_index >= self.probe_states.len() { return Ok(()); } - // Reset parent metrics - self.probe_buffered_rows_per_part - .get_mut(parent_index) - .map(|v| *v = 0); - self.probe_spilled_rows_per_part - .get_mut(parent_index) - .map(|v| *v = 0); - self.probe_consumed_rows_per_part - .get_mut(parent_index) - .map(|v| *v = 0); - if parent_index < self.probe_batch_positions.len() { - self.probe_batch_positions[parent_index] = 0; - } - if parent_index < self.probe_spill_in_progress.len() { - self.probe_spill_in_progress[parent_index] = None; - } - let shift_bits = descriptor.radix_bits; let mask = (fanout - 1) as u64; - if let Some(file) = self - .probe_spill_files - .get_mut(parent_index) - .and_then(|queue| queue.pop_front()) - { + let spill_file = { + let state = self + .probe_states + .get_mut(parent_index) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.batch_position = 0; + state.buffered_rows = 0; + state.spilled_rows = 0; + state.consumed_rows = 0; + state.active_batch = None; + state.active_values.clear(); + state.active_hashes.clear(); + state.active_offset = (0, None); + state.joined_probe_idx = None; + state.pending_stream = None; + state.spill_files.pop_front() + }; + + if let Some(file) = spill_file { let mut writers = Vec::with_capacity(fanout); for _ in 0..fanout { let writer = self @@ -908,19 +1021,33 @@ impl PartitionedHashJoinStream { internal_datafusion_err!("expected probe spill file") })?; let partitions_idx = partition_indices[sub_idx]; - self.probe_spill_files[partitions_idx].push_back(file); - self.probe_spilled_rows_per_part[partitions_idx] = 0; - self.probe_buffered_rows_per_part[partitions_idx] = 0; - self.probe_consumed_rows_per_part[partitions_idx] = 0; + let state = self + .probe_states + .get_mut(partitions_idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.spill_files.push_back(file); + state.spilled_rows = 0; + state.buffered_rows = 0; + state.consumed_rows = 0; + state.batch_position = 0; + state.pending_stream = None; + state.active_batch = None; + state.active_values.clear(); + state.active_hashes.clear(); + state.active_offset = (0, None); + state.joined_probe_idx = None; } return Ok(()); } // In-memory probe data - let parent_partition = mem::replace( - &mut self.probe_partitions[parent_index], - ProbePartition::new(), - ); + let parent_partition = { + let state = self + .probe_states + .get_mut(parent_index) + 
.ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + mem::replace(&mut state.buffered, ProbePartition::new()) + }; for idx in 0..parent_partition.batches.len() { let batch = &parent_partition.batches[idx]; let values = &parent_partition.values[idx]; @@ -960,20 +1087,20 @@ impl PartitionedHashJoinStream { } let idx = partition_indices[sub_idx]; - let part = self - .probe_partitions + let state = self + .probe_states .get_mut(idx) .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - part.batches.push(filtered_batch); - part.values.push(filtered_values); - part.hashes.push(filtered_hashes); - let buffered = part + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_values); + state.buffered.hashes.push(filtered_hashes); + let buffered = state + .buffered .batches .last() .map(|b| b.num_rows()) .unwrap_or_default(); - self.probe_buffered_rows_per_part[idx] = - self.probe_buffered_rows_per_part[idx].saturating_add(buffered); + state.buffered_rows = state.buffered_rows.saturating_add(buffered); } } @@ -981,19 +1108,10 @@ impl PartitionedHashJoinStream { } fn buffer_probe_side(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.probe_spill_in_progress.len() != self.num_partitions - || self.probe_spill_files.len() != self.num_partitions - { + if self.probe_states.len() != self.num_partitions { self.resize_partition_vectors(); } - if self.probe_partitions.len() != self.num_partitions { - self.probe_partitions.clear(); - } - if self.probe_partitions.is_empty() { - self.probe_partitions = (0..self.num_partitions) - .map(|_| ProbePartition::new()) - .collect(); - } + loop { match self.right.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { @@ -1048,20 +1166,29 @@ impl PartitionedHashJoinStream { filtered_hashes.push(hashes[i as usize]); } - match self.build_partitions.get_mut(part_id) { - Some(BuildPartition::Spilled { .. }) => { - if self.probe_spill_in_progress[part_id].is_none() { + if matches!( + self.build_partitions.get(part_id), + Some(BuildPartition::Spilled { .. 
}) + ) { + let (queue_ready, stream_active) = { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| { + internal_datafusion_err!( + "missing probe partition" + ) + })?; + if state.spill_in_progress.is_none() { let ipf = self .probe_spill_manager .create_in_progress_file( "hash_join_probe_partition", )?; - self.probe_spill_in_progress[part_id] = Some(ipf); + state.spill_in_progress = Some(ipf); self.join_metrics.probe_spill_count.add(1); } - if let Some(ref mut ipf) = - self.probe_spill_in_progress[part_id] - { + if let Some(ref mut ipf) = state.spill_in_progress { ipf.append_batch(&filtered_batch)?; self.join_metrics .probe_spilled_rows @@ -1070,36 +1197,28 @@ impl PartitionedHashJoinStream { .probe_spilled_bytes .add(filtered_batch.get_array_memory_size()); } - self.probe_spilled_rows_per_part[part_id] += - filtered_batch.num_rows(); - let queue_ready = self - .probe_spill_files - .get(part_id) - .map(|q| !q.is_empty()) - .unwrap_or(false); - let stream_active = self - .pending_probe_partition - .is_some_and(|p| p == part_id); - if !queue_ready && !stream_active { - self.finalize_spilled_partition(part_id)?; - } + state.spilled_rows = state + .spilled_rows + .saturating_add(filtered_batch.num_rows()); + ( + !state.spill_files.is_empty(), + state.pending_stream.is_some(), + ) + }; + if !queue_ready && !stream_active { + self.finalize_spilled_partition(part_id)?; } - _ => { - self.probe_partitions[part_id] - .batches - .push(filtered_batch); - self.probe_partitions[part_id] - .values - .push(filtered_on_values); - self.probe_partitions[part_id] - .hashes - .push(filtered_hashes); - let last = self.probe_partitions[part_id] - .batches - .last() - .unwrap(); - self.probe_buffered_rows_per_part[part_id] += - last.num_rows(); + } else { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing probe partition") + })?; + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_on_values); + state.buffered.hashes.push(filtered_hashes); + if let Some(last) = state.buffered.batches.last() { + state.buffered_rows = + state.buffered_rows.saturating_add(last.num_rows()); } } } @@ -1109,8 +1228,10 @@ impl PartitionedHashJoinStream { Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Ready(None) => { self.probe_stream_finished = true; - self.probe_batch_positions = vec![0; self.num_partitions]; for part_id in 0..self.num_partitions { + if let Some(state) = self.probe_states.get_mut(part_id) { + state.batch_position = 0; + } self.finalize_spilled_partition(part_id)?; } return Poll::Ready(Ok(())); @@ -1144,6 +1265,8 @@ impl PartitionedHashJoinStream { } // Enqueue new descriptors in order for desc in new_descriptors.into_iter().rev() { + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&desc); self.pending_partitions.push_front(desc); } Ok(true) @@ -1212,38 +1335,53 @@ impl PartitionedHashJoinStream { } fn reset_partition_state(&mut self) { - for writer in self.probe_spill_in_progress.iter_mut() { - if let Some(mut writer) = writer.take() { + for state in self.probe_states.iter_mut() { + if let Some(mut writer) = state.spill_in_progress.take() { let _ = writer.finish(); } + state.reset(); + } + self.probe_states.clear(); + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + self.probe_task_scheduler = + HybridTaskScheduler::new(SchedulerConfig::from_stream(self)); + self.probe_scheduler_inflight.clear(); + self.probe_scheduler_waiting_for_stream.clear(); + 
self.probe_scheduler_active_streams = 0; + } + + for partition in self.build_partitions.iter_mut() { + if let BuildPartition::Spilled { + spill_file, + reservation, + .. + } = partition + { + if let Some(file) = spill_file.take() { + drop(file); + } + let placeholder = MemoryConsumer::new("released_build_partition") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let _ = mem::replace(reservation, placeholder); + } } self.build_partitions.clear(); self.matched_build_rows_per_partition.clear(); self.current_partition = None; self.pending_partitions.clear(); - self.current_probe_batch = None; - self.current_probe_values.clear(); - self.current_probe_hashes.clear(); - self.current_offset = (0, None); - self.joined_probe_idx = None; self.placeholder_emitted = false; self.right_alignment_start = 0; self.unmatched_partition = 0; self.unmatched_left_indices_cache = None; self.unmatched_right_indices_cache = None; self.unmatched_offset = 0; - self.probe_partitions.clear(); - self.probe_batch_positions.clear(); self.probe_stream_finished = false; self.pending_reload_stream = None; self.pending_reload_batches.clear(); self.pending_reload_partition = None; - self.pending_probe_stream = None; - self.pending_probe_partition = None; - for queue in self.probe_spill_files.iter_mut() { - queue.clear(); - } self.partition_pending.clear(); self.partition_descriptors.clear(); self.bounds_waiter = None; @@ -1277,6 +1415,17 @@ impl PartitionedHashJoinStream { } } + fn repartition_worthwhile(&self, max_spilled_bytes: usize) -> bool { + let partitions = self.num_partitions.max(1); + let per_partition_budget = self.memory_threshold / partitions; + if per_partition_budget == 0 { + return false; + } + let cutoff = + std::cmp::max(per_partition_budget / 2, HYBRID_HASH_MIN_PARTITION_BYTES); + max_spilled_bytes > cutoff + } + fn prepare_partition_queue(&mut self) { self.pending_partitions.clear(); let radix_bits = @@ -1305,9 +1454,11 @@ impl PartitionedHashJoinStream { spilled_bytes, spilled_rows, }); - if let Some(desc) = self.pending_partitions.back() { + if let Some(desc) = self.pending_partitions.back().cloned() { self.partition_descriptors[part_id] = Some(desc.clone()); self.partition_pending[part_id] = true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&desc); } } } @@ -1331,6 +1482,16 @@ impl PartitionedHashJoinStream { }); } else { self.current_partition = None; + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + if !self.probe_scheduler_waiting_for_stream.is_empty() { + hhj_debug(|| { + "transition_to_next_partition -> WaitingForProbe".to_string() + }); + self.state = PartitionedHashJoinState::WaitingForProbe; + return; + } + } self.state = PartitionedHashJoinState::HandleUnmatchedRows; } } @@ -1527,8 +1688,8 @@ impl PartitionedHashJoinStream { column_indices: Vec, null_equality: NullEquality, batch_size: usize, - num_partitions: usize, - max_partition_count: usize, + mut num_partitions: usize, + mut max_partition_count: usize, memory_threshold: usize, memory_reservation: MemoryReservation, runtime_env: Arc, @@ -1551,6 +1712,23 @@ impl PartitionedHashJoinStream { Arc::clone(&build_schema), ); + let mem_limit = max_partitions_allowed_for_memory(memory_threshold) + .max(HYBRID_HASH_MIN_FANOUT); + max_partition_count = max_partition_count + .max(HYBRID_HASH_MIN_FANOUT) + .min(mem_limit); + num_partitions = num_partitions + .max(HYBRID_HASH_MIN_FANOUT) + .min(max_partition_count); + + #[cfg(feature = "hybrid_hash_join_scheduler")] + let scheduler_config = 
SchedulerConfig { + memory_threshold, + batch_size, + max_partition_count, + max_probe_streams: std::cmp::max(1, std::cmp::min(4, num_partitions)), + }; + Ok(Self { partition, schema, @@ -1570,7 +1748,19 @@ impl PartitionedHashJoinStream { memory_threshold, state: PartitionedHashJoinState::PartitionBuildSide, build_partitions: Vec::new(), - probe_partitions: Vec::new(), + probe_states: (0..num_partitions) + .map(|_| ProbePartitionState::new()) + .collect(), + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_task_scheduler: HybridTaskScheduler::new(scheduler_config.clone()), + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_inflight: vec![false; num_partitions], + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_waiting_for_stream: VecDeque::new(), + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_active_streams: 0, + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_max_streams: scheduler_config.max_probe_streams, current_partition: None, pending_partitions: VecDeque::new(), probe_spill_manager, @@ -1586,28 +1776,15 @@ impl PartitionedHashJoinStream { bounds_waiter: None, build_schema, probe_schema, - current_probe_batch: None, - current_probe_values: vec![], - current_probe_hashes: vec![], - current_offset: (0, None), - joined_probe_idx: None, matched_build_rows_per_partition: Vec::new(), unmatched_partition: 0, unmatched_left_indices_cache: None, unmatched_right_indices_cache: None, unmatched_offset: 0, probe_stream_finished: false, - probe_batch_positions: vec![], pending_reload_stream: None, pending_reload_batches: Vec::new(), pending_reload_partition: None, - probe_spill_in_progress: (0..num_partitions).map(|_| None).collect(), - probe_spill_files: (0..num_partitions).map(|_| VecDeque::new()).collect(), - pending_probe_stream: None, - pending_probe_partition: None, - probe_buffered_rows_per_part: vec![0; num_partitions], - probe_spilled_rows_per_part: vec![0; num_partitions], - probe_consumed_rows_per_part: vec![0; num_partitions], matched_rows_per_part: vec![0; num_partitions], emitted_rows_per_part: vec![0; num_partitions], candidate_pairs_per_part: vec![0; num_partitions], @@ -1665,15 +1842,41 @@ impl PartitionedHashJoinStream { let mut allow_repartition = true; loop { + hhj_debug(|| { + format!( + "partition_build_side pass={} num_partitions={} allow_repartition={}", + self.partition_pass, self.num_partitions, allow_repartition + ) + }); self.reset_partition_state(); match self.try_partition_build_side(&build_data, allow_repartition)? 
{ - PartitionBuildStatus::Ready(result) => return Ok(result), + PartitionBuildStatus::Ready(result) => { + hhj_debug(|| { + format!( + "partition_build_side pass {} completed (num_partitions={})", + self.partition_pass, self.num_partitions + ) + }); + return Ok(result); + } PartitionBuildStatus::NeedMorePartitions { next_count } => { + hhj_debug(|| { + format!( + "partition_build_side requesting repartition to {} (current={})", + next_count, self.num_partitions + ) + }); if next_count <= self.num_partitions || next_count == 0 || next_count > self.max_partition_count { + hhj_debug(|| { + format!( + "repartition request invalid (max={} current={}); forcing spill", + self.max_partition_count, self.num_partitions + ) + }); allow_repartition = false; continue; } @@ -1750,9 +1953,26 @@ impl PartitionedHashJoinStream { .set_max(self.memory_reservation.size()); if self.memory_reservation.size() > self.memory_threshold { if allow_repartition { - if let Some(next_count) = self.next_partition_count() { - repartition_request = Some(next_count); - break; + let partition_estimate = accum.buffered_bytes; + if self.repartition_worthwhile(partition_estimate) { + if let Some(next_count) = self.next_partition_count() + { + hhj_debug(|| { + format!( + "partition {} exceeded budget (bytes={}) -> requesting repartition to {}", + build_index, partition_estimate, next_count + ) + }); + repartition_request = Some(next_count); + break; + } + } else { + hhj_debug(|| { + format!( + "partition {} exceeded global mem but bytes={} under per-part budget; skipping repartition", + build_index, partition_estimate + ) + }); } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -1765,9 +1985,26 @@ impl PartitionedHashJoinStream { } Err(_) => { if allow_repartition { - if let Some(next_count) = self.next_partition_count() { - repartition_request = Some(next_count); - break; + let partition_estimate = + accum.buffered_bytes.saturating_add(batch_size); + if self.repartition_worthwhile(partition_estimate) { + if let Some(next_count) = self.next_partition_count() { + hhj_debug(|| { + format!( + "allocation failure for partition {} (bytes={}) -> requesting repartition to {}", + build_index, partition_estimate, next_count + ) + }); + repartition_request = Some(next_count); + break; + } + } else { + hhj_debug(|| { + format!( + "allocation failure for partition {} but bytes={} under per-part budget; spilling without repartition", + build_index, partition_estimate + ) + }); } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -1791,6 +2028,11 @@ impl PartitionedHashJoinStream { } if let Some(next_count) = repartition_request { + hhj_debug(|| { + format!( + "try_partition_build_side early repartition request next_count={next_count}" + ) + }); return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); } @@ -1910,8 +2152,31 @@ impl PartitionedHashJoinStream { if (max_spilled_bytes > self.memory_threshold || any_spilled) && allow_repartition { - if let Some(next_count) = self.next_partition_count() { + if !self.repartition_worthwhile(max_spilled_bytes) { + hhj_debug(|| { + format!( + "spilled partitions already near budget (max_spilled_bytes={} bytes, memory_threshold={} partitions={}, budget≈{})", + max_spilled_bytes, + self.memory_threshold, + self.num_partitions, + (self.memory_threshold / self.num_partitions.max(1)).max(1) + ) + }); + } else if let Some(next_count) = self.next_partition_count() { + hhj_debug(|| { + format!( + "try_partition_build_side repartition due to spill (max_spilled_bytes={} 
threshold={} any_spilled={}) next_count={}", + max_spilled_bytes, + self.memory_threshold, + any_spilled, + next_count + ) + }); return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + } else { + hhj_debug(|| { + "spill detected but no further repartition possible".to_string() + }); } } @@ -1992,40 +2257,371 @@ impl PartitionedHashJoinStream { } fn partition_has_pending_probe(&self, part_id: usize) -> bool { - if part_id < self.probe_partitions.len() - && part_id < self.probe_batch_positions.len() - && self.probe_batch_positions[part_id] - < self.probe_partitions[part_id].batches.len() - { - return true; + if let Some(state) = self.probe_states.get(part_id) { + if state.batch_position < state.buffered.batches.len() { + return true; + } + if state.active_batch.is_some() { + return true; + } + if !state.spill_files.is_empty() { + return true; + } + if state.pending_stream.is_some() { + return true; + } + if state.spill_in_progress.is_some() { + return true; + } } + false + } - if self.current_partition.is_some_and(|idx| idx == part_id) - && self.current_probe_batch.is_some() - { - return true; + /// Attempts to load the next buffered probe batch for `part_id`. + pub(super) fn take_buffered_probe_batch( + &mut self, + part_id: usize, + ) -> Result> { + if let Some(state) = self.probe_states.get_mut(part_id) { + if state.batch_position < state.buffered.batches.len() { + let pos = state.batch_position; + let batch = state.buffered.batches[pos].clone(); + let values = state.buffered.values[pos].clone(); + let hashes = state.buffered.hashes[pos].clone(); + state.batch_position = state.batch_position.saturating_add(1); + state.active_batch = Some(batch.clone()); + state.active_values = values; + state.active_hashes = hashes; + state.active_offset = (0, None); + if let Some(b) = state.active_batch.as_ref() { + state.consumed_rows = + state.consumed_rows.saturating_add(b.num_rows()); + } + return Ok(Some(batch)); + } } + Ok(None) + } - if part_id < self.probe_spill_files.len() - && !self.probe_spill_files[part_id].is_empty() - { - return true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn try_acquire_probe_stream_slot(&mut self) -> bool { + if self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { + self.probe_scheduler_active_streams += 1; + true + } else { + false + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn release_probe_stream_slot(&mut self) { + if self.probe_scheduler_active_streams > 0 { + self.probe_scheduler_active_streams -= 1; } + self.wake_stream_waiter(); + } + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn enqueue_stream_waiter(&mut self, part_id: usize) { + if part_id >= self.partition_pending.len() { + return; + } if self - .pending_probe_partition - .is_some_and(|idx| idx == part_id) + .probe_scheduler_waiting_for_stream + .iter() + .any(|&v| v == part_id) { - return true; + return; } + self.probe_scheduler_waiting_for_stream.push_back(part_id); + } - if part_id < self.probe_spill_in_progress.len() - && self.probe_spill_in_progress[part_id].is_some() - { - return true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn wake_stream_waiter(&mut self) { + while self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { + if let Some(next_part) = self.probe_scheduler_waiting_for_stream.pop_front() { + hhj_debug(|| format!("wake_stream_waiter considering part {next_part}")); + if next_part >= self.partition_pending.len() { + continue; + } + if self.partition_pending[next_part] { + hhj_debug(|| { + 
format!("wake_stream_waiter skipping part {next_part} (already pending)") + }); + continue; + } + if let Some(Some(desc)) = + self.partition_descriptors.get(next_part).map(|d| d.clone()) + { + self.partition_pending[next_part] = true; + let waiting_for_probe = + matches!(self.state, PartitionedHashJoinState::WaitingForProbe); + self.pending_partitions.push_back(desc); + hhj_debug(|| { + format!( + "wake_stream_waiter scheduled part {next_part}, waiting_for_probe={waiting_for_probe}" + ) + }); + if waiting_for_probe { + self.transition_to_next_partition(); + } + break; + } + } else { + hhj_debug(|| "wake_stream_waiter nothing to wake".to_string()); + break; + } } + } - false + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn poll_probe_stage_task( + &mut self, + cx: &mut Context<'_>, + descriptor: &PartitionDescriptor, + ) -> Result { + let part_id = descriptor.build_index; + self.schedule_probe_task(descriptor); + hhj_debug(|| { + format!( + "poll_probe_stage_task part {part_id} start, queue_len={}", + self.probe_task_scheduler.len() + ) + }); + + let mut iterations = self.probe_task_scheduler.len(); + while iterations > 0 { + iterations -= 1; + let Some(task) = self.probe_task_scheduler.pop_task() else { + break; + }; + match task { + SchedulerTask::Probe(probe_task) => { + match SchedulerTask::Probe(probe_task).poll(self, Some(cx))? { + TaskPoll::ProbeReady(desc) => { + let ready_part = desc.build_index; + hhj_debug(|| { + format!("probe task ready for part {ready_part}") + }); + if ready_part >= self.probe_scheduler_inflight.len() { + self.probe_scheduler_inflight + .resize(ready_part + 1, false); + } + self.probe_scheduler_inflight[ready_part] = false; + if ready_part == part_id { + return Ok(ProbeTaskStatus::Ready); + } else { + if ready_part >= self.partition_pending.len() { + self.partition_pending.resize(ready_part + 1, false); + } + if !self.partition_pending[ready_part] { + self.pending_partitions.push_back(desc.clone()); + self.partition_pending[ready_part] = true; + } + } + } + TaskPoll::Pending(next_task) => { + hhj_debug(|| "probe task pending, requeue".to_string()); + self.probe_task_scheduler.push_task(next_task); + } + TaskPoll::YieldProbe { + task: next_task, + descriptor: desc, + } => { + let wait_part = desc.build_index; + if wait_part == part_id { + self.probe_task_scheduler.push_task(next_task); + return Ok(ProbeTaskStatus::WaitingForStream); + } else { + self.probe_task_scheduler.push_task(next_task); + self.enqueue_stream_waiter(wait_part); + } + } + TaskPoll::ProbeFinished(desc) => { + let finished_part = desc.build_index; + hhj_debug(|| { + format!("probe task finished for part {finished_part}") + }); + if finished_part >= self.probe_scheduler_inflight.len() { + self.probe_scheduler_inflight + .resize(finished_part + 1, false); + } + self.probe_scheduler_inflight[finished_part] = false; + if finished_part == part_id { + return Ok(ProbeTaskStatus::Finished); + } else { + if finished_part >= self.partition_pending.len() { + self.partition_pending + .resize(finished_part + 1, false); + } + if !self.partition_pending[finished_part] { + self.pending_partitions.push_back(desc.clone()); + self.partition_pending[finished_part] = true; + } + } + } + TaskPoll::YieldFinalize(task) => { + hhj_debug(|| "finalize task yielded".to_string()); + self.probe_task_scheduler.push_task(task); + } + TaskPoll::Ready(_) => { + // Build/finalize ready events are ignored in probe context. 
+ } + TaskPoll::BuildFinished(_) => {} + TaskPoll::FinalizeFinished => {} + } + } + other_task => { + hhj_debug(|| { + "non-probe task encountered in probe scheduler".to_string() + }); + // Unexpected task type for probe scheduling; push back to preserve semantics. + self.probe_task_scheduler.push_task(other_task); + } + } + } + + let queue_len = self.probe_task_scheduler.len(); + hhj_debug(|| { + format!( + "poll_probe_stage_task part {part_id} returning Pending (queue_len={})", + queue_len + ) + }); + if queue_len > 0 { + cx.waker().wake_by_ref(); + } + Ok(ProbeTaskStatus::Pending) + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub(super) fn poll_probe_data_for_partition( + &mut self, + part_id: usize, + cx: &mut Context<'_>, + ) -> Result { + if self.take_buffered_probe_batch(part_id)?.is_some() { + return Ok(ProbeDataPoll::Ready); + } + + let has_spilled_probe = { + let state = self.probe_state(part_id)?; + state.spill_in_progress.is_some() + || !state.spill_files.is_empty() + || state.pending_stream.is_some() + }; + + if !has_spilled_probe { + return Ok(ProbeDataPoll::Finished); + } + + loop { + let needs_stream = { + let state = self.probe_state(part_id)?; + state.pending_stream.is_none() + }; + if needs_stream { + let mut next_file = { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state.spill_files.pop_front() + }; + if next_file.is_none() && self.finalize_spilled_partition(part_id)? { + next_file = { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing partition") + })?; + state.spill_files.pop_front() + }; + } + if let Some(file) = next_file { + if !self.try_acquire_probe_stream_slot() { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing partition") + })?; + state.spill_files.push_front(file); + return Ok(ProbeDataPoll::NeedStream); + } + let stream = self.probe_spill_manager.read_spill_as_stream(file)?; + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state.pending_stream = Some(stream); + } else { + let writer_open = { + let state = self.probe_state(part_id)?; + state.spill_in_progress.is_some() + }; + if self.probe_stream_finished && !writer_open { + return Ok(ProbeDataPoll::Finished); + } else { + return Ok(ProbeDataPoll::Pending); + } + } + } + + let poll_result = { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state + .pending_stream + .as_mut() + .map(|stream| stream.poll_next_unpin(cx)) + }; + + match poll_result { + Some(Poll::Ready(Some(Ok(batch)))) => { + let (values, hashes) = self.prepare_probe_values(&batch)?; + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state.active_batch = Some(batch); + state.active_values = values; + state.active_hashes = hashes; + state.active_offset = (0, None); + if let Some(b) = state.active_batch.as_ref() { + state.consumed_rows = + state.consumed_rows.saturating_add(b.num_rows()); + } + return Ok(ProbeDataPoll::Ready); + } + Some(Poll::Ready(Some(Err(e)))) => return Err(e), + Some(Poll::Ready(None)) => { + { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing partition") + })?; + state.pending_stream = None; + } + self.release_probe_stream_slot(); + continue; + } + 
Some(Poll::Pending) | None => return Ok(ProbeDataPoll::Pending), + } + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn prepare_probe_values( + &self, + batch: &RecordBatch, + ) -> Result<(Vec, Vec)> { + let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + keys_values.push(c.evaluate(batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + Ok((keys_values, hashes)) } /// Process a specific partition @@ -2035,6 +2631,7 @@ impl PartitionedHashJoinStream { partition_state: &ProcessPartitionState, ) -> Poll>>> { let build_index = partition_state.descriptor.build_index; + hhj_debug(|| format!("process_partition enter part {build_index}")); // Guard against invalid partition ids (off-by-one protection) if build_index >= self.build_partitions.len() { @@ -2081,91 +2678,132 @@ impl PartitionedHashJoinStream { } // Select next probe batch for current partition - if self.current_probe_batch.is_none() { - // Decide probe source based on whether we spilled probe for this partition - let has_spilled_probe = self - .probe_spill_in_progress - .get(build_index) - .and_then(|o| o.as_ref()) - .is_some() - || self - .probe_spill_files - .get(build_index) - .map(|queue| !queue.is_empty()) - .unwrap_or(false) - || self - .pending_probe_partition - .is_some_and(|p| p == build_index); - let has_buffered_probe = self - .probe_partitions - .get(build_index) - .map(|p| !p.batches.is_empty()) - .unwrap_or(false); - - // Prefer buffered probe batches first; when exhausted, consume spilled probe stream - let pos = self.probe_batch_positions[build_index]; - let buffered_len = self - .probe_partitions - .get(build_index) - .map(|p| p.batches.len()) - .unwrap_or(0); - if has_buffered_probe && pos < buffered_len { - let part = &self.probe_partitions[build_index]; - // Take buffered batch/values/hashes - let batch = part.batches[pos].clone(); - let values = part.values[pos].clone(); - let hashes = part.hashes[pos].clone(); - self.probe_batch_positions[build_index] = pos + 1; - - self.current_probe_batch = Some(batch); - self.current_probe_values = values; - self.current_probe_hashes = hashes; - self.current_offset = (0, None); - if let Some(b) = &self.current_probe_batch { - self.probe_consumed_rows_per_part[build_index] = self - .probe_consumed_rows_per_part[build_index] - .saturating_add(b.num_rows()); + let mut has_active_batch = match self.probe_state(build_index) { + Ok(state) => state.active_batch.is_some(), + Err(e) => return Poll::Ready(Err(e)), + }; + + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + if !has_active_batch { + match self.poll_probe_stage_task(cx, &partition_state.descriptor)? 
{ + ProbeTaskStatus::Ready => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Ready") + }); + has_active_batch = true; + } + ProbeTaskStatus::Pending => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Pending") + }); + return Poll::Pending; + } + ProbeTaskStatus::WaitingForStream => { + hhj_debug(|| { + format!("process_partition part {build_index} -> WaitingForStream") + }); + self.enqueue_stream_waiter(build_index); + self.current_partition = None; + self.transition_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + ProbeTaskStatus::Finished => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Finished") + }); + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } } - } else if has_spilled_probe { - loop { - if self.pending_probe_partition != Some(build_index) { - let mut next_file = self - .probe_spill_files - .get_mut(build_index) - .and_then(|queue| queue.pop_front()); - if next_file.is_none() - && self.finalize_spilled_partition(build_index)? - { - next_file = self - .probe_spill_files - .get_mut(build_index) - .and_then(|queue| queue.pop_front()); + } + } + + #[cfg(not(feature = "hybrid_hash_join_scheduler"))] + { + if !has_active_batch { + if self.take_buffered_probe_batch(build_index)?.is_some() { + has_active_batch = true; + } + } + + if !has_active_batch { + let has_spilled_probe = match self.probe_state(build_index) { + Ok(state) => { + state.spill_in_progress.is_some() + || !state.spill_files.is_empty() + || state.pending_stream.is_some() + } + Err(e) => return Poll::Ready(Err(e)), + }; + + if has_spilled_probe { + loop { + let needs_stream = match self.probe_state(build_index) { + Ok(state) => state.pending_stream.is_none(), + Err(e) => return Poll::Ready(Err(e)), + }; + + if needs_stream { + let mut next_file = match self.probe_state_mut(build_index) { + Ok(state) => state.spill_files.pop_front(), + Err(e) => return Poll::Ready(Err(e)), + }; + if next_file.is_none() + && self.finalize_spilled_partition(build_index)? 
+ { + next_file = match self.probe_state_mut(build_index) { + Ok(state) => state.spill_files.pop_front(), + Err(e) => return Poll::Ready(Err(e)), + }; + } + if let Some(file) = next_file { + let stream = self + .probe_spill_manager + .read_spill_as_stream(file)?; + match self.probe_state_mut(build_index) { + Ok(state) => state.pending_stream = Some(stream), + Err(e) => return Poll::Ready(Err(e)), + } + } else { + let should_release = match self.probe_state(build_index) { + Ok(state) => { + self.probe_stream_finished + && state.spill_in_progress.is_none() + && state.pending_stream.is_none() + } + Err(e) => return Poll::Ready(Err(e)), + }; + if should_release { + match self.probe_state_mut(build_index) { + Ok(state) => state.pending_stream = None, + Err(e) => return Poll::Ready(Err(e)), + } + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok( + StatefulStreamResult::Continue, + )); + } else { + return Poll::Pending; + } + } } - if let Some(file) = next_file { - let stream = - self.probe_spill_manager.read_spill_as_stream(file)?; - self.pending_probe_stream = Some(stream); - self.pending_probe_partition = Some(build_index); - } else { - let writer_open = self - .probe_spill_in_progress - .get(build_index) - .and_then(|o| o.as_ref()) - .is_some(); - if self.probe_stream_finished && !writer_open { - self.pending_probe_stream = None; - self.pending_probe_partition = None; - self.release_partition_resources(build_index); - self.advance_to_next_partition(); - return Poll::Ready(Ok(StatefulStreamResult::Continue)); + + let poll_result = { + let state = match self.probe_state_mut(build_index) { + Ok(state) => state, + Err(e) => return Poll::Ready(Err(e)), + }; + if let Some(stream) = state.pending_stream.as_mut() { + stream.poll_next_unpin(cx) } else { return Poll::Pending; } - } - } + }; - if let Some(stream) = self.pending_probe_stream.as_mut() { - match stream.poll_next_unpin(cx) { + match poll_result { Poll::Ready(Some(Ok(batch))) => { let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); @@ -2182,70 +2820,74 @@ impl PartitionedHashJoinStream { &mut hashes, )?; - self.current_probe_batch = Some(batch); - self.current_probe_values = keys_values; - self.current_probe_hashes = hashes; - self.current_offset = (0, None); - if let Some(b) = &self.current_probe_batch { - self.probe_consumed_rows_per_part[build_index] = self - .probe_consumed_rows_per_part[build_index] - .saturating_add(b.num_rows()); + let state = match self.probe_state_mut(build_index) { + Ok(state) => state, + Err(e) => return Poll::Ready(Err(e)), + }; + state.active_batch = Some(batch); + state.active_values = keys_values; + state.active_hashes = hashes; + state.active_offset = (0, None); + if let Some(b) = state.active_batch.as_ref() { + state.consumed_rows = + state.consumed_rows.saturating_add(b.num_rows()); } + has_active_batch = true; break; } Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), Poll::Ready(None) => { - self.pending_probe_stream = None; - self.pending_probe_partition = None; + match self.probe_state_mut(build_index) { + Ok(state) => state.pending_stream = None, + Err(e) => return Poll::Ready(Err(e)), + } continue; } Poll::Pending => return Poll::Pending, } - } else { - return Poll::Pending; } + } else { + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); } - } else { - // Neither spilled nor buffered probe for this partition: advance - // println!( 
- // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - // build_index, - // self.probe_buffered_rows_per_part[build_index], - // self.probe_spilled_rows_per_part[build_index], - // self.probe_consumed_rows_per_part[build_index], - // self.candidate_pairs_per_part[build_index], - // self.matched_rows_per_part[build_index], - // self.emitted_rows_per_part[build_index] - // ); - self.release_partition_resources(build_index); - self.advance_to_next_partition(); - return Poll::Ready(Ok(StatefulStreamResult::Continue)); } } - // If no probe batch selected, advance to next partition (no probe rows here) - if self.current_probe_batch.is_none() { - // println!( - // "[spill-join][summary] part={} buffered={} spilled={} consumed={} candidates={} matched={} emitted={}", - // build_index, - // self.probe_buffered_rows_per_part[build_index], - // self.probe_spilled_rows_per_part[build_index], - // self.probe_consumed_rows_per_part[build_index], - // self.candidate_pairs_per_part[build_index], - // self.matched_rows_per_part[build_index], - // self.emitted_rows_per_part[build_index] - // ); + if !has_active_batch { self.release_partition_resources(build_index); self.advance_to_next_partition(); return Poll::Ready(Ok(StatefulStreamResult::Continue)); } // At this point we have a current probe batch for this partition - let (result, build_ids_to_mark, next_offset) = { - let probe_batch = self - .current_probe_batch - .as_ref() - .ok_or_else(|| internal_datafusion_err!("expected probe batch"))?; + let (result, build_ids_to_mark, next_offset, next_joined_idx) = { + let ( + probe_batch, + probe_values, + probe_hashes, + current_offset, + prev_joined_idx, + ) = { + let state = match self.probe_state(build_index) { + Ok(state) => state, + Err(e) => return Poll::Ready(Err(e)), + }; + let batch = state + .active_batch + .as_ref() + .ok_or_else(|| internal_datafusion_err!("expected probe batch"))? 
+ .clone(); + let values = state.active_values.clone(); + let hashes = state.active_hashes.clone(); + ( + batch, + values, + hashes, + state.active_offset, + state.joined_probe_idx, + ) + }; let (build_hashmap, build_batch, build_values) = match self.build_partitions.get(build_index) { @@ -2323,9 +2965,9 @@ impl PartitionedHashJoinStream { // Lookup against hash map with limit let (probe_indices, build_indices, next_offset) = build_hashmap .get_matched_indices_with_limit_offset( - &self.current_probe_hashes, + &probe_hashes, self.batch_size, - self.current_offset, + current_offset, ); let build_indices: UInt64Array = build_indices.into(); @@ -2348,16 +2990,16 @@ impl PartitionedHashJoinStream { &build_indices, &probe_indices, build_values, - &self.current_probe_values, + &probe_values, self.null_equality, )?; // Shadow verify on INNER join with single Int64 key (first 50k rows) /*if matches!(self.join_type, JoinType::Inner) && build_values.len() == 1 - && self.current_probe_values.len() == 1 + && probe_values.len() == 1 && build_values[0].data_type() == &arrow::datatypes::DataType::Int64 - && self.current_probe_values[0].data_type() + && probe_values[0].data_type() == &arrow::datatypes::DataType::Int64 && !self.verify_once_per_part[build_index] { @@ -2367,7 +3009,7 @@ impl PartitionedHashJoinStream { .as_any() .downcast_ref::() .unwrap(); - let pcol = self.current_probe_values[0] + let pcol = probe_values[0] .as_any() .downcast_ref::() .unwrap(); @@ -2407,8 +3049,7 @@ impl PartitionedHashJoinStream { .map(|a| format!("{:?}", a.data_type())) .collect::>() .join(", "); - let probe_types = self - .current_probe_values + let probe_types = probe_values .iter() .map(|a| format!("{:?}", a.data_type())) .collect::>() @@ -2424,7 +3065,7 @@ impl PartitionedHashJoinStream { let p = probe_indices.value(i) as usize; // Include actual first-key values for sanity checks let bk = &build_values[0]; - let pk = &self.current_probe_values[0]; + let pk = &probe_values[0]; let bv = arrow::util::display::array_value_to_string(bk.as_ref(), b) .unwrap_or_else(|_| "".to_string()); let pv = arrow::util::display::array_value_to_string(pk.as_ref(), p) @@ -2450,7 +3091,7 @@ impl PartitionedHashJoinStream { let (filtered_build_indices, filtered_probe_indices) = apply_join_filter_to_indices( build_batch, - probe_batch, + &probe_batch, build_indices, probe_indices, filter, @@ -2640,7 +3281,7 @@ impl PartitionedHashJoinStream { // Shadow verify for two-key joins (stringified) to catch type coercion issues /*if matches!(self.join_type, JoinType::Inner) && build_values.len() == 2 - && self.current_probe_values.len() == 2 + && probe_values.len() == 2 && !self.verify_once_per_part[build_index] { use std::collections::HashMap; @@ -2663,12 +3304,12 @@ impl PartitionedHashJoinStream { let max_p = probe_batch.num_rows().min(50_000); for i in 0..max_p { let k0 = arrow::util::display::array_value_to_string( - self.current_probe_values[0].as_ref(), + probe_values[0].as_ref(), i, ) .unwrap_or_else(|_| "".to_string()); let k1 = arrow::util::display::array_value_to_string( - self.current_probe_values[1].as_ref(), + probe_values[1].as_ref(), i, ) .unwrap_or_else(|_| "".to_string()); @@ -2696,8 +3337,7 @@ impl PartitionedHashJoinStream { n => Some(probe_indices.value(n - 1) as usize), }; let probe_num_rows = probe_batch.num_rows(); - let mut index_alignment_range_start = - self.joined_probe_idx.map_or(0, |v| v + 1); + let mut index_alignment_range_start = prev_joined_idx.map_or(0, |v| v + 1); let mut index_alignment_range_end = if 
next_offset.is_none() { probe_num_rows } else { @@ -2740,7 +3380,7 @@ impl PartitionedHashJoinStream { build_indices.values().to_vec() }; // Track last joined probe row only for right-oriented joins; otherwise clear it - self.joined_probe_idx = if needs_alignment && next_offset.is_some() { + let next_joined_idx = if needs_alignment && next_offset.is_some() { last_joined_right_idx } else { None @@ -2762,7 +3402,7 @@ impl PartitionedHashJoinStream { let right_indices_u64 = uint32_to_uint64_indices(&probe_indices); build_batch_from_indices( &self.schema, - probe_batch, + &probe_batch, build_batch, &right_indices_u64, &probe_indices, @@ -2773,7 +3413,7 @@ impl PartitionedHashJoinStream { build_batch_from_indices( &self.schema, build_batch, - probe_batch, + &probe_batch, &build_indices, &probe_indices, &self.column_indices, @@ -2784,7 +3424,7 @@ impl PartitionedHashJoinStream { let emitted_rows = result.num_rows(); self.emitted_rows_per_part[build_index] = self.emitted_rows_per_part[build_index].saturating_add(emitted_rows); - (result, build_ids_to_mark, next_offset) + (result, build_ids_to_mark, next_offset, next_joined_idx) }; // Mark matched build-side rows for outer joins (use current partition's bitmap) @@ -2795,16 +3435,22 @@ impl PartitionedHashJoinStream { } // Update offset or fetch a new probe batch - if let Some(offset) = next_offset { - self.current_offset = offset; - } else { - // Finished this probe batch - self.current_probe_batch = None; - self.current_probe_values.clear(); - self.current_probe_hashes.clear(); - self.current_offset = (0, None); - self.joined_probe_idx = None; - // Alignment is batch-local for semi/anti/mark in spillable path; do not carry across batches + match self.probe_state_mut(build_index) { + Ok(state) => { + if let Some(offset) = next_offset { + state.active_offset = offset; + state.joined_probe_idx = next_joined_idx; + } else { + state.active_batch = None; + state.active_values.clear(); + state.active_hashes.clear(); + state.active_offset = (0, None); + state.joined_probe_idx = None; + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&partition_state.descriptor); + } + } + Err(e) => return Poll::Ready(Err(e)), } if result.num_rows() == 0 { @@ -3102,6 +3748,7 @@ impl Stream for PartitionedHashJoinStream { cx: &mut Context<'_>, ) -> Poll> { loop { + hhj_debug(|| format!("poll_next state {:?}", self.state)); match self.state.clone() { PartitionedHashJoinState::PartitionBuildSide => { // Collect build side and partition it @@ -3114,6 +3761,7 @@ impl Stream for PartitionedHashJoinStream { Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), Poll::Pending => return Poll::Pending, } + hhj_debug(|| format!("restarting build pass state={:?}", self.state)); match self.partition_build_side(left_data) { Ok(StatefulStreamResult::Continue) => continue, Ok(StatefulStreamResult::Ready(Some(batch))) => { @@ -3158,6 +3806,32 @@ impl Stream for PartitionedHashJoinStream { Poll::Pending => return Poll::Pending, } } + #[cfg(feature = "hybrid_hash_join_scheduler")] + PartitionedHashJoinState::WaitingForProbe => { + if self.pending_partitions.is_empty() { + if self.probe_scheduler_waiting_for_stream.is_empty() { + hhj_debug(|| { + "WaitingForProbe -> HandleUnmatchedRows (no waiters)" + .to_string() + }); + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + continue; + } + hhj_debug(|| { + "WaitingForProbe pending=0 waiters>0, parking".to_string() + }); + return Poll::Pending; + } else { + hhj_debug(|| { + format!( + "WaitingForProbe woke 
with {} pending partitions", + self.pending_partitions.len() + ) + }); + self.transition_to_next_partition(); + continue; + } + } PartitionedHashJoinState::HandleUnmatchedRows => { match self.handle_unmatched_rows(cx) { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { @@ -3182,3 +3856,343 @@ impl Stream for PartitionedHashJoinStream { } } } + +#[cfg(all(test, feature = "hybrid_hash_join_scheduler"))] +mod scheduler_tests { + use super::*; + use crate::metrics::ExecutionPlanMetricsSet; + use crate::stream::RecordBatchStreamAdapter; + use arrow::array::{ArrayRef, Int32Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_execution::memory_pool::MemoryConsumer; + use datafusion_execution::runtime_env::RuntimeEnv; + use futures::{stream, task::noop_waker}; + use parking_lot::Mutex; + use std::sync::atomic::AtomicUsize; + use std::sync::Arc; + use std::task::Context as StdContext; + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])) + } + + fn test_batch(schema: &SchemaRef, values: &[i32]) -> RecordBatch { + let array: ArrayRef = Arc::new(Int32Array::from(values.to_vec())); + RecordBatch::try_new(schema.clone(), vec![array]).unwrap() + } + + fn build_join_left_data( + batch: RecordBatch, + runtime_env: &Arc, + ) -> JoinLeftData { + let hash_map: Box = + Box::new(JoinHashMapU32::with_capacity(0)); + let reservation = MemoryConsumer::new("left") + .with_can_spill(true) + .register(&runtime_env.memory_pool); + JoinLeftData::new( + hash_map, + batch.clone(), + Arc::new(vec![batch]), + vec![], + Mutex::new(BooleanBufferBuilder::new(0)), + AtomicUsize::new(0), + reservation, + None, + ) + } + + fn make_test_stream( + num_partitions: usize, + max_streams: usize, + ) -> PartitionedHashJoinStream { + let runtime_env = Arc::new(RuntimeEnv::default()); + let schema = test_schema(); + let build_batch = RecordBatch::new_empty(schema.clone()); + let left_data = build_join_left_data(build_batch, &runtime_env); + let left_fut = OnceFut::new(async move { Ok(left_data) }); + let metrics = ExecutionPlanMetricsSet::new(); + let join_metrics = BuildProbeJoinMetrics::new(0, &metrics); + let probe_spill_metrics = SpillMetrics::new(&metrics, 0); + let build_spill_metrics = SpillMetrics::new(&metrics, 0); + let right_stream: SendableRecordBatchStream = Box::pin( + RecordBatchStreamAdapter::new(schema.clone(), stream::empty()), + ); + let memory_reservation = MemoryConsumer::new("top") + .with_can_spill(true) + .register(&runtime_env.memory_pool); + + let mut stream = PartitionedHashJoinStream::new( + 0, + schema.clone(), + vec![], + vec![], + None, + JoinType::Inner, + right_stream, + left_fut, + RandomState::with_seeds(0, 0, 0, 0), + join_metrics, + probe_spill_metrics, + build_spill_metrics, + vec![], + NullEquality::NullEqualsNothing, + 1024, + num_partitions, + num_partitions, + 1024, + memory_reservation, + runtime_env, + schema.clone(), + schema, + false, + None, + ) + .unwrap(); + stream.probe_scheduler_max_streams = max_streams; + stream.pending_partitions.clear(); + for pending in stream.partition_pending.iter_mut() { + *pending = false; + } + stream + } + + fn add_spill_file( + stream: &mut PartitionedHashJoinStream, + part_id: usize, + batch: &RecordBatch, + ) -> Result<()> { + let mut writer = stream + .probe_spill_manager + .create_in_progress_file("test_spill")?; + writer.append_batch(batch)?; + let file = writer.finish()?.expect("spill file"); + 
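+        // Queue the finished spill file on the partition's probe state so the test can
+        // exercise the spilled-probe read path.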
stream.probe_states[part_id].spill_files.push_back(file); + Ok(()) + } + + fn descriptor_for(partition: usize) -> PartitionDescriptor { + PartitionDescriptor { + build_index: partition, + root_index: partition, + generation: 0, + radix_bits: 0, + hash_prefix: partition as u64, + spilled_bytes: 0, + spilled_rows: 0, + } + } + + async fn poll_task_status( + stream: &mut PartitionedHashJoinStream, + desc: &PartitionDescriptor, + ) -> ProbeTaskStatus { + let waker = noop_waker(); + for _ in 0..4096 { + let mut cx = StdContext::from_waker(&waker); + let status = stream + .poll_probe_stage_task(&mut cx, desc) + .expect("poll should succeed"); + if matches!(status, ProbeTaskStatus::Pending) { + tokio::task::yield_now().await; + continue; + } + return status; + } + panic!("probe task stuck in pending state"); + } + + async fn poll_probe_data_until_ready( + stream: &mut PartitionedHashJoinStream, + part_id: usize, + ) -> ProbeDataPoll { + let waker = noop_waker(); + for _ in 0..4096 { + let mut cx = StdContext::from_waker(&waker); + let status = stream + .poll_probe_data_for_partition(part_id, &mut cx) + .expect("poll probe data"); + if matches!(status, ProbeDataPoll::Pending) { + tokio::task::yield_now().await; + continue; + } + return status; + } + panic!("probe data did not become ready"); + } + + #[tokio::test] + async fn probe_tasks_wait_for_stream_slots() -> Result<()> { + let mut stream = make_test_stream(2, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[1]); + add_spill_file(&mut stream, 0, &batch)?; + add_spill_file(&mut stream, 1, &batch)?; + + let desc1 = descriptor_for(1); + stream.partition_descriptors[0] = Some(descriptor_for(0)); + stream.partition_descriptors[1] = Some(desc1.clone()); + stream.partition_pending[0] = false; + stream.partition_pending[1] = false; + stream.current_partition = Some(1); + + // Simulate another partition already holding the single stream slot. + stream.probe_scheduler_active_streams = stream.probe_scheduler_max_streams; + + let status = poll_task_status(&mut stream, &desc1).await; + assert!(matches!(status, ProbeTaskStatus::WaitingForStream)); + stream.enqueue_stream_waiter(desc1.build_index); + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 1); + + stream.probe_states[0].pending_stream = None; + stream.release_probe_stream_slot(); + assert_eq!(stream.probe_scheduler_active_streams, 0); + assert!(stream.probe_scheduler_waiting_for_stream.is_empty()); + let desc = stream.pending_partitions.pop_front().unwrap(); + assert_eq!(desc.build_index, 1); + stream.partition_pending[desc.build_index] = false; + Ok(()) + } + + #[tokio::test] + async fn probe_task_resumes_after_slot_available() -> Result<()> { + let mut stream = make_test_stream(2, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[10, 20]); + add_spill_file(&mut stream, 1, &batch)?; + + let desc1 = descriptor_for(1); + stream.partition_descriptors[1] = Some(desc1.clone()); + stream.partition_pending[1] = false; + stream.current_partition = Some(1); + + // Ensure there's no active stream yet. + assert_eq!(stream.probe_scheduler_active_streams, 0); + + let status = poll_task_status(&mut stream, &desc1).await; + assert!(matches!(status, ProbeTaskStatus::Ready)); + assert!(stream.probe_states[1].active_batch.is_some()); + assert_eq!(stream.probe_scheduler_active_streams, 1); + + // Mark the active batch as consumed and continue polling to drain the spill stream. 
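+        // The stream may surface one more batch before reporting Finished, so a second
+        // drain pass is attempted below.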
+ stream.probe_states[1].active_batch = None; + let mut status = poll_probe_data_until_ready(&mut stream, 1).await; + if matches!(status, ProbeDataPoll::Ready) { + stream.probe_states[1].active_batch = None; + status = poll_probe_data_until_ready(&mut stream, 1).await; + } + assert!(matches!(status, ProbeDataPoll::Finished)); + assert_eq!(stream.probe_scheduler_active_streams, 0); + Ok(()) + } + + #[tokio::test] + async fn probe_tasks_wait_queue_multiple() -> Result<()> { + let mut stream = make_test_stream(3, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[5]); + for part in 0..3 { + add_spill_file(&mut stream, part, &batch)?; + let desc = descriptor_for(part); + stream.partition_descriptors[part] = Some(desc); + stream.partition_pending[part] = false; + } + + // Partition 0 currently holds the only stream slot. + stream.probe_scheduler_active_streams = stream.probe_scheduler_max_streams; + + // Partitions 1 and 2 must wait for a stream slot. + for part in [1, 2] { + stream.enqueue_stream_waiter(part); + } + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 2); + + // Releasing the stream should enqueue partition 1 for processing. + stream.release_probe_stream_slot(); + assert_eq!(stream.probe_scheduler_active_streams, 0); + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 1); + let desc = stream.pending_partitions.pop_front().unwrap(); + assert_eq!(desc.build_index, 1); + stream.partition_pending[desc.build_index] = false; + + // Simulate partition 1 holding the stream slot and then finishing. + stream.probe_scheduler_active_streams = stream.probe_scheduler_max_streams; + stream.probe_states[1].pending_stream = None; + stream.release_probe_stream_slot(); + let desc = stream.pending_partitions.pop_front().unwrap(); + assert_eq!(desc.build_index, 2); + stream.partition_pending[desc.build_index] = false; + Ok(()) + } + + #[tokio::test] + async fn wait_queue_blocks_state_progression() -> Result<()> { + let mut stream = make_test_stream(2, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[7]); + for part in 0..2 { + add_spill_file(&mut stream, part, &batch)?; + let desc = descriptor_for(part); + stream.partition_descriptors[part] = Some(desc.clone()); + stream.pending_partitions.push_back(desc); + stream.partition_pending[part] = true; + } + + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 0 + )); + + // Both partitions end up waiting on a limited stream slot. + stream.enqueue_stream_waiter(0); + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 1 + )); + + stream.enqueue_stream_waiter(1); + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::WaitingForProbe + )); + assert!(stream.pending_partitions.is_empty()); + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 2); + + // Releasing a stream slot wakes the earliest waiter and resumes partition 0. 
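+        // wake_stream_waiter() pops the front waiter and is expected to move the state
+        // machine from WaitingForProbe back to ProcessPartition.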
+ stream.probe_scheduler_active_streams = 0; + stream.wake_stream_waiter(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 0 + )); + + // Simulate finishing partition 0, which should put the stream back into waiting mode + // because partition 1 is still throttled. + stream.current_partition = None; + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::WaitingForProbe + )); + + // Another wake picks up the remaining partition. + stream.wake_stream_waiter(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 1 + )); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs index 29d9614447cfd..8a979713adee2 100644 --- a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs +++ b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs @@ -21,16 +21,22 @@ use std::collections::VecDeque; use std::sync::Arc; +use std::task::Context; -use arrow::record_batch::RecordBatch; +use arrow::{array::ArrayRef, record_batch::RecordBatch}; use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::hash_join::partitioned::{ - PartitionDescriptor, PartitionedHashJoinStream, + PartitionDescriptor, PartitionedHashJoinStream, ProbePartition, }; use crate::joins::utils::StatefulStreamResult; +use crate::SendableRecordBatchStream; use datafusion_common::{internal_datafusion_err, Result}; +use datafusion_execution::disk_manager::RefCountedTempFile; + +use crate::joins::join_hash_map::JoinHashMapOffset; +use crate::spill::in_progress_spill_file::InProgressSpillFile; /// Configuration shared across scheduler components. #[derive(Clone, Debug)] @@ -38,6 +44,7 @@ pub(super) struct SchedulerConfig { pub memory_threshold: usize, pub batch_size: usize, pub max_partition_count: usize, + pub max_probe_streams: usize, } impl SchedulerConfig { @@ -46,6 +53,7 @@ impl SchedulerConfig { memory_threshold: stream.memory_threshold, batch_size: stream.batch_size, max_partition_count: stream.max_partition_count, + max_probe_streams: std::cmp::max(1, std::cmp::min(4, stream.num_partitions)), } } } @@ -64,6 +72,22 @@ impl HybridTaskScheduler { } } + pub fn push_task(&mut self, task: SchedulerTask) { + self.ready_queue.push_back(task); + } + + pub fn pop_task(&mut self) -> Option { + self.ready_queue.pop_front() + } + + pub fn len(&self) -> usize { + self.ready_queue.len() + } + + pub fn is_empty(&self) -> bool { + self.ready_queue.is_empty() + } + pub fn with_build_task( config: SchedulerConfig, build_data: Arc, @@ -98,12 +122,14 @@ impl HybridTaskScheduler { stream: &mut PartitionedHashJoinStream, ) -> Result>> { while let Some(task) = self.ready_queue.pop_front() { - match task.poll(stream)? { + match task.poll(stream, None)? { + TaskPoll::Ready(_) => continue, + TaskPoll::ProbeReady(_) => continue, TaskPoll::Pending(task) => self.ready_queue.push_back(task), TaskPoll::BuildFinished(result) => return Ok(result), - TaskPoll::YieldProbe(task) => self.ready_queue.push_back(task), + TaskPoll::YieldProbe { task, .. 
} => self.ready_queue.push_back(task), TaskPoll::YieldFinalize(task) => self.ready_queue.push_back(task), - TaskPoll::ProbeFinished | TaskPoll::FinalizeFinished => continue, + TaskPoll::ProbeFinished(_) | TaskPoll::FinalizeFinished => continue, } } Err(internal_datafusion_err!( @@ -112,25 +138,34 @@ impl HybridTaskScheduler { } } -enum SchedulerTask { +pub(super) enum SchedulerTask { Build(BuildStageTask), Probe(ProbeStageTask), Finalize(FinalizeStageTask), } -enum TaskPoll { +pub(super) enum TaskPoll { + Ready(Option), + ProbeReady(PartitionDescriptor), Pending(SchedulerTask), BuildFinished(StatefulStreamResult>), - /// Probe task yielded without producing output (to be expanded later). - YieldProbe(SchedulerTask), + /// Probe task yielded without producing output (e.g. waiting on IO). + YieldProbe { + task: SchedulerTask, + descriptor: PartitionDescriptor, + }, /// Finalize task yielded without producing output. YieldFinalize(SchedulerTask), - ProbeFinished, + ProbeFinished(PartitionDescriptor), FinalizeFinished, } impl SchedulerTask { - fn poll(self, stream: &mut PartitionedHashJoinStream) -> Result { + pub(super) fn poll( + self, + stream: &mut PartitionedHashJoinStream, + cx: Option<&mut Context<'_>>, + ) -> Result { match self { SchedulerTask::Build(task) => match task.poll(stream)? { BuildTaskEvent::Pending(next_state) => { @@ -138,12 +173,24 @@ impl SchedulerTask { } BuildTaskEvent::Finished(result) => Ok(TaskPoll::BuildFinished(result)), }, - SchedulerTask::Probe(task) => match task.poll(stream)? { - ProbeTaskEvent::Pending(next_task) => { - Ok(TaskPoll::YieldProbe(SchedulerTask::Probe(next_task))) + SchedulerTask::Probe(task) => { + let cx = cx.expect("probe task requires runtime context"); + let descriptor = task.descriptor().clone(); + match task.poll(stream, cx)? { + ProbeTaskEvent::Pending(next_task) => { + Ok(TaskPoll::Pending(SchedulerTask::Probe(next_task))) + } + ProbeTaskEvent::Ready => Ok(TaskPoll::ProbeReady(descriptor)), + ProbeTaskEvent::NeedStream(next_task) => { + let wait_descriptor = next_task.descriptor().clone(); + Ok(TaskPoll::YieldProbe { + task: SchedulerTask::Probe(next_task), + descriptor: wait_descriptor, + }) + } + ProbeTaskEvent::Finished => Ok(TaskPoll::ProbeFinished(descriptor)), } - ProbeTaskEvent::Finished => Ok(TaskPoll::ProbeFinished), - }, + } SchedulerTask::Finalize(task) => match task.poll(stream)? 
{ FinalizeTaskEvent::Pending(next_task) => { Ok(TaskPoll::YieldFinalize(SchedulerTask::Finalize(next_task))) @@ -209,35 +256,112 @@ enum BuildTaskEvent { Finished(StatefulStreamResult>), } -struct ProbeStageTask { - config: SchedulerConfig, +pub(super) struct ProbePartitionState { + pub buffered: ProbePartition, + pub batch_position: usize, + pub buffered_rows: usize, + pub spilled_rows: usize, + pub consumed_rows: usize, + pub spill_in_progress: Option, + pub spill_files: VecDeque, + pub pending_stream: Option, + pub active_batch: Option, + pub active_values: Vec, + pub active_hashes: Vec, + pub active_offset: JoinHashMapOffset, + pub joined_probe_idx: Option, +} + +impl ProbePartitionState { + pub fn new() -> Self { + Self { + buffered: ProbePartition::new(), + batch_position: 0, + buffered_rows: 0, + spilled_rows: 0, + consumed_rows: 0, + spill_in_progress: None, + spill_files: VecDeque::new(), + pending_stream: None, + active_batch: None, + active_values: Vec::new(), + active_hashes: Vec::new(), + active_offset: (0, None), + joined_probe_idx: None, + } + } + + pub fn reset(&mut self) { + *self = Self::new(); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum ProbeDataPoll { + Ready, + Pending, + NeedStream, + Finished, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ProbeTaskState { + Init, + Ready, + Finished, +} + +pub(super) struct ProbeStageTask { + _config: SchedulerConfig, descriptor: PartitionDescriptor, - yielded_once: bool, + state: ProbeTaskState, } impl ProbeStageTask { - fn new(config: SchedulerConfig, descriptor: PartitionDescriptor) -> Self { + pub fn new(config: SchedulerConfig, descriptor: PartitionDescriptor) -> Self { Self { - config, + _config: config, descriptor, - yielded_once: false, + state: ProbeTaskState::Init, } } - fn poll(self, _stream: &mut PartitionedHashJoinStream) -> Result { - if self.yielded_once { - Ok(ProbeTaskEvent::Finished) - } else { - Ok(ProbeTaskEvent::Pending(Self { - yielded_once: true, - ..self - })) + pub fn descriptor(&self) -> &PartitionDescriptor { + &self.descriptor + } + + fn poll( + mut self, + stream: &mut PartitionedHashJoinStream, + cx: &mut Context<'_>, + ) -> Result { + match self.state { + ProbeTaskState::Init => { + self.state = ProbeTaskState::Ready; + Ok(ProbeTaskEvent::Pending(self)) + } + ProbeTaskState::Ready => { + match stream + .poll_probe_data_for_partition(self.descriptor.build_index, cx)? 
+ { + ProbeDataPoll::Ready => Ok(ProbeTaskEvent::Ready), + ProbeDataPoll::Pending => Ok(ProbeTaskEvent::Pending(self)), + ProbeDataPoll::NeedStream => Ok(ProbeTaskEvent::NeedStream(self)), + ProbeDataPoll::Finished => { + self.state = ProbeTaskState::Finished; + Ok(ProbeTaskEvent::Finished) + } + } + } + ProbeTaskState::Finished => Ok(ProbeTaskEvent::Finished), } } } enum ProbeTaskEvent { Pending(ProbeStageTask), + Ready, + NeedStream(ProbeStageTask), Finished, } From 6d14d4bbc0cc28f3bdda088b1b9788df05723a61 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Sun, 9 Nov 2025 21:45:35 +0200 Subject: [PATCH 21/36] Fix duplicating rows with reseting a state --- .../src/joins/hash_join/partitioned.rs | 71 ++++++++----------- 1 file changed, 29 insertions(+), 42 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 7651a67128fcf..f6b3a3c2c1fd1 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -112,6 +112,17 @@ fn max_partitions_allowed_for_memory(memory_threshold: usize) -> usize { highest_power_of_two_leq(slots) } +fn per_partition_budget_bytes(memory_threshold: usize, partitions: usize) -> usize { + let partitions = partitions.max(1); + let mut budget = memory_threshold + .checked_div(partitions) + .unwrap_or(memory_threshold); + if budget == 0 { + budget = HYBRID_HASH_MIN_PARTITION_BYTES; + } + budget.max(HYBRID_HASH_MIN_PARTITION_BYTES) +} + #[inline] fn hhj_debug String>(builder: F) { if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() { @@ -202,6 +213,7 @@ impl ProbePartition { /// Runtime state tracked per probe partition. #[cfg(not(feature = "hybrid_hash_join_scheduler"))] +#[cfg(not(feature = "hybrid_hash_join_scheduler"))] pub(super) struct ProbePartitionState { buffered: ProbePartition, batch_position: usize, @@ -383,6 +395,8 @@ pub(super) struct PartitionedHashJoinStream { pub memory_reservation: MemoryReservation, /// Tracks how many repartition passes have been attempted pub partition_pass: usize, + /// Indicates whether the current pass has already prepared partitions for output + pub partition_pass_output_started: bool, /// Runtime environment pub runtime_env: Arc, /// Scratch space for computing hashes @@ -609,14 +623,8 @@ impl PartitionedHashJoinStream { return None; } - let mut per_partition_budget = self - .memory_threshold - .checked_div(self.max_partition_count.max(1)) - .unwrap_or(self.memory_threshold); - if per_partition_budget == 0 { - per_partition_budget = HYBRID_HASH_MIN_PARTITION_BYTES; - } - per_partition_budget = per_partition_budget.max(HYBRID_HASH_MIN_PARTITION_BYTES); + let mut per_partition_budget = + per_partition_budget_bytes(self.memory_threshold, self.num_partitions); let rows_budget = self .batch_size @@ -1229,9 +1237,6 @@ impl PartitionedHashJoinStream { Poll::Ready(None) => { self.probe_stream_finished = true; for part_id in 0..self.num_partitions { - if let Some(state) = self.probe_states.get_mut(part_id) { - state.batch_position = 0; - } self.finalize_spilled_partition(part_id)?; } return Poll::Ready(Ok(())); @@ -1767,6 +1772,7 @@ impl PartitionedHashJoinStream { build_spill_manager, memory_reservation, partition_pass: 0, + partition_pass_output_started: false, runtime_env, hashes_buffer: Vec::new(), right_side_ordered, @@ -1840,7 +1846,7 @@ impl PartitionedHashJoinStream { self.max_partition_count = 1; } - let mut allow_repartition = true; + let mut 
allow_repartition = !self.partition_pass_output_started; loop { hhj_debug(|| { format!( @@ -1883,6 +1889,7 @@ impl PartitionedHashJoinStream { self.num_partitions = next_count; self.partition_pass += 1; + self.partition_pass_output_started = false; allow_repartition = true; } } @@ -1966,13 +1973,6 @@ impl PartitionedHashJoinStream { repartition_request = Some(next_count); break; } - } else { - hhj_debug(|| { - format!( - "partition {} exceeded global mem but bytes={} under per-part budget; skipping repartition", - build_index, partition_estimate - ) - }); } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -1998,13 +1998,6 @@ impl PartitionedHashJoinStream { repartition_request = Some(next_count); break; } - } else { - hhj_debug(|| { - format!( - "allocation failure for partition {} but bytes={} under per-part budget; spilling without repartition", - build_index, partition_estimate - ) - }); } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -2150,19 +2143,11 @@ impl PartitionedHashJoinStream { }); } - if (max_spilled_bytes > self.memory_threshold || any_spilled) && allow_repartition + if allow_repartition + && (max_spilled_bytes > self.memory_threshold || any_spilled) + && self.repartition_worthwhile(max_spilled_bytes) { - if !self.repartition_worthwhile(max_spilled_bytes) { - hhj_debug(|| { - format!( - "spilled partitions already near budget (max_spilled_bytes={} bytes, memory_threshold={} partitions={}, budget≈{})", - max_spilled_bytes, - self.memory_threshold, - self.num_partitions, - (self.memory_threshold / self.num_partitions.max(1)).max(1) - ) - }); - } else if let Some(next_count) = self.next_partition_count() { + if let Some(next_count) = self.next_partition_count() { hhj_debug(|| { format!( "try_partition_build_side repartition due to spill (max_spilled_bytes={} threshold={} any_spilled={}) next_count={}", @@ -2173,14 +2158,11 @@ impl PartitionedHashJoinStream { ) }); return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); - } else { - hhj_debug(|| { - "spill detected but no further repartition possible".to_string() - }); } } self.prepare_partition_queue(); + self.partition_pass_output_started = true; self.transition_to_next_partition(); Ok(PartitionBuildStatus::Ready(StatefulStreamResult::Continue)) @@ -2293,6 +2275,11 @@ impl PartitionedHashJoinStream { state.active_values = values; state.active_hashes = hashes; state.active_offset = (0, None); + if state.batch_position >= state.buffered.batches.len() { + state.buffered = ProbePartition::new(); + state.batch_position = 0; + state.buffered_rows = 0; + } if let Some(b) = state.active_batch.as_ref() { state.consumed_rows = state.consumed_rows.saturating_add(b.num_rows()); From 4b0a66ba9f39c3307dd0d4adc7b6eb617fab104b Mon Sep 17 00:00:00 2001 From: osipovartem Date: Mon, 10 Nov 2025 12:58:02 +0300 Subject: [PATCH 22/36] Pass projection and filter --- .../src/joins/grace_hash_join/exec.rs | 22 ++----- .../src/joins/grace_hash_join/stream.rs | 24 ++++---- .../physical-plan/src/spill/spill_manager.rs | 60 +++++++++---------- 3 files changed, 48 insertions(+), 58 deletions(-) diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs index 0782d9c84fc1f..2c9482f93f892 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -684,7 +684,8 @@ impl ExecutionPlan for GraceHashJoinExec { spill_right, on_left, on_right, - self.random_state.clone(), + 
self.projection.clone(), + self.filter.clone(), self.join_type, column_indices_after_projection, join_metrics, @@ -854,24 +855,8 @@ impl ExecutionPlan for GraceHashJoinExec { } } -/// Accumulator for collecting min/max bounds from build-side data during hash join. -/// -/// This struct encapsulates the logic for progressively computing column bounds -/// (minimum and maximum values) for a specific join key expression as batches -/// are processed during the build phase of a hash join. -/// -/// The bounds are used for dynamic filter pushdown optimization, where filters -/// based on the actual data ranges can be pushed down to the probe side to -/// eliminate unnecessary data early. -struct CollectLeftAccumulator { - /// The physical expression to evaluate for each batch - expr: Arc, - /// Accumulator for tracking the minimum value across all batches - min: MinAccumulator, - /// Accumulator for tracking the maximum value across all batches - max: MaxAccumulator, -} +#[allow(clippy::too_many_arguments)] pub async fn partition_and_spill( random_state: RandomState, on: Vec<(PhysicalExprRef, PhysicalExprRef)>, @@ -915,6 +900,7 @@ pub async fn partition_and_spill( Ok((left_index, right_index)) } +#[allow(clippy::too_many_arguments)] async fn partition_and_spill_one_side( input: &mut SendableRecordBatchStream, on_exprs: &[PhysicalExprRef], diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs index a6afdef1029d1..d028b0e8bcf87 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs @@ -34,7 +34,6 @@ use crate::empty::EmptyExec; use crate::joins::grace_hash_join::exec::PartitionIndex; use crate::joins::{HashJoinExec, PartitionMode}; use crate::test::TestMemoryExec; -use ahash::RandomState; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{JoinType, NullEquality, Result}; @@ -70,7 +69,8 @@ pub struct GraceHashJoinStream { spill_right: Arc, on_left: Vec, on_right: Vec, - random_state: RandomState, + projection: Option>, + filter: Option, join_type: JoinType, column_indices: Vec, join_metrics: Arc, @@ -113,7 +113,8 @@ impl GraceHashJoinStream { spill_right: Arc, on_left: Vec, on_right: Vec, - random_state: RandomState, + projection: Option>, + filter: Option, join_type: JoinType, column_indices: Vec, join_metrics: Arc, @@ -127,7 +128,8 @@ impl GraceHashJoinStream { spill_right, on_left, on_right, - random_state, + projection, + filter, join_type, column_indices, join_metrics, @@ -225,7 +227,8 @@ impl GraceHashJoinStream { right_batches, &self.on_left, &self.on_right, - self.random_state.clone(), + self.projection.clone(), + self.filter.clone(), self.join_type, &self.column_indices, &self.join_metrics, @@ -283,10 +286,11 @@ fn build_in_memory_join_stream( right_batches: Vec, on_left: &[PhysicalExprRef], on_right: &[PhysicalExprRef], - random_state: RandomState, + projection: Option>, + filter: Option, join_type: JoinType, - column_indices: &[ColumnIndex], - join_metrics: &BuildProbeJoinMetrics, + _column_indices: &[ColumnIndex], + _join_metrics: &BuildProbeJoinMetrics, context: &Arc, ) -> Result { if left_batches.is_empty() && right_batches.is_empty() { @@ -324,9 +328,9 @@ fn build_in_memory_join_stream( left_plan, right_plan, on, - None::, + filter, &join_type, - None, + projection, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, )?; diff --git 
a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 10f90c6cb3492..d3c38aba45989 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -173,37 +173,37 @@ impl SpillManager { /// Automatically decides whether to spill the given RecordBatch to memory or disk, /// depending on available memory pool capacity. pub(crate) fn spill_batch_auto(&self, batch: &RecordBatch, request_msg: &str) -> Result { - let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { - return Err(DataFusionError::Execution( - "failed to spill batch to disk".into(), - )); - }; - Ok(SpillLocation::Disk(Arc::new(file))) - // // - // let size = batch.get_sliced_size()?; - // - // // Check current memory usage and total limit from the runtime memory pool - // let used = self.env.memory_pool.reserved(); - // let limit = match self.env.memory_pool.memory_limit() { - // datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, - // _ => usize::MAX, + // let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { + // return Err(DataFusionError::Execution( + // "failed to spill batch to disk".into(), + // )); // }; - // println!("size {size} used {used}"); - // // If there's enough memory (with a small safety margin), keep it in memory - // if used + size * 3 * 64 / 2 <= limit { - // let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); - // self.metrics.spilled_bytes.add(size); - // self.metrics.spilled_rows.add(batch.num_rows()); - // Ok(SpillLocation::Memory(buf)) - // } else { - // // Otherwise spill to disk using the existing SpillManager logic - // let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { - // return Err(DataFusionError::Execution( - // "failed to spill batch to disk".into(), - // )); - // }; - // Ok(SpillLocation::Disk(Arc::new(file))) - // } + // Ok(SpillLocation::Disk(Arc::new(file))) + // // + let size = batch.get_sliced_size()?; + + // Check current memory usage and total limit from the runtime memory pool + let used = self.env.memory_pool.reserved(); + let limit = match self.env.memory_pool.memory_limit() { + datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, + _ => usize::MAX, + }; + + // If there's enough memory (with a safety margin), keep it in memory + if used + size * 3 / 2 <= limit { + let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); + self.metrics.spilled_bytes.add(size); + self.metrics.spilled_rows.add(batch.num_rows()); + Ok(SpillLocation::Memory(buf)) + } else { + // Otherwise spill to disk using the existing SpillManager logic + let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? 
else { + return Err(DataFusionError::Execution( + "failed to spill batch to disk".into(), + )); + }; + Ok(SpillLocation::Disk(Arc::new(file))) + } } pub fn spill_batches_auto( From 80f2c34459ec0327d2ad2336a132994679c93f65 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Mon, 10 Nov 2025 16:10:18 +0200 Subject: [PATCH 23/36] Spill manager support for sharing spill file handle --- .../physical-plan/src/joins/hash_join/exec.rs | 74 ++++++++++++++++--- .../src/joins/hash_join/partitioned.rs | 28 +++---- datafusion/physical-plan/src/spill/mod.rs | 10 ++- .../physical-plan/src/spill/spill_manager.rs | 12 +++ 4 files changed, 95 insertions(+), 29 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index c6eaa3489cbaa..7f5083ea354a9 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -41,7 +41,7 @@ use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; -use crate::spill::get_record_batch_memory_size; +use crate::spill::{get_record_batch_memory_size, spill_manager::SpillManager}; use crate::ExecutionPlanProperties; use crate::{ common::can_project, @@ -66,6 +66,7 @@ use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, }; +use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; @@ -78,7 +79,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::TryStreamExt; +use futures::{executor::block_on, StreamExt, TryStreamExt}; use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. 
@@ -100,10 +101,14 @@ const HYBRID_HASH_MIN_ROWS_PER_PARTITION: usize = 1_024; pub(super) struct JoinLeftData { /// The hash table with indices into `batch` pub(super) hash_map: Box, - /// The input rows for the build side + /// The input rows for the build side (may be empty when spilled) batch: RecordBatch, - /// Original build-side batches before concatenation - original_batches: Arc>, + /// Build-side input backing storage + original_input: OriginalBuildInput, + /// Total rows collected from the build side + total_rows: usize, + /// Estimated bytes for the original build input + total_input_size: usize, /// The build side on expressions values values: Vec, /// Shared bitmap builder for visited left indices @@ -120,6 +125,14 @@ pub(super) struct JoinLeftData { pub(super) bounds: Option>, } +enum OriginalBuildInput { + InMemory(Arc>), + Spilled { + spill_manager: Arc, + spill_file: Arc, + }, +} + impl JoinLeftData { /// Create a new `JoinLeftData` from its parts pub(super) fn new( @@ -132,10 +145,17 @@ impl JoinLeftData { reservation: MemoryReservation, bounds: Option>, ) -> Self { + let total_rows = original_batches.iter().map(|b| b.num_rows()).sum(); + let total_input_size = original_batches + .iter() + .map(|batch| batch.get_array_memory_size()) + .sum(); Self { hash_map, batch, - original_batches, + original_input: OriginalBuildInput::InMemory(original_batches), + total_rows, + total_input_size, values, visited_indices_bitmap, probe_threads_counter, @@ -154,8 +174,12 @@ impl JoinLeftData { &self.batch } - pub(super) fn original_batches(&self) -> &[RecordBatch] { - &self.original_batches + pub(super) fn total_rows(&self) -> usize { + self.total_rows + } + + pub(super) fn total_input_size(&self) -> usize { + self.total_input_size } /// returns a reference to the build side expressions values @@ -173,6 +197,37 @@ impl JoinLeftData { pub(super) fn report_probe_completed(&self) -> bool { self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 } + + pub(super) fn for_each_original_batch(&self, mut f: F) -> Result<()> + where + F: FnMut(RecordBatch) -> Result, + { + match &self.original_input { + OriginalBuildInput::InMemory(batches) => { + for batch in batches.iter() { + if !f(batch.clone())? { + break; + } + } + Ok(()) + } + OriginalBuildInput::Spilled { + spill_manager, + spill_file, + } => { + let mut stream = + spill_manager.read_spill_as_stream_shared(Arc::clone(spill_file))?; + block_on(async { + while let Some(batch) = stream.next().await.transpose()? { + if !f(batch)? { + break; + } + } + Ok(()) + }) + } + } + } } #[allow(rustdoc::private_intra_doc_links)] @@ -1742,7 +1797,6 @@ async fn collect_left_input( bounds_accumulators, } = state; let batches_arc = Arc::new(batches); - // Estimation of memory size, required for hashtable, prior to allocation. 
// Final result can be verified using `RawTable.allocation_info()` let fixed_size_u32 = size_of::(); @@ -1824,7 +1878,7 @@ async fn collect_left_input( hashmap, single_batch, Arc::clone(&batches_arc), - left_values.clone(), + left_values, Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), reservation, diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index f6b3a3c2c1fd1..b0df425fa4d65 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -1828,19 +1828,11 @@ impl PartitionedHashJoinStream { ) -> Result>> { if self.partition_pass == 0 { self.join_metrics.build_input_batches.add(1); - let total_rows: usize = build_data - .original_batches() - .iter() - .map(|b| b.num_rows()) - .sum(); + let total_rows = build_data.total_rows(); self.join_metrics.build_input_rows.add(total_rows); } - let build_total_size: usize = build_data - .original_batches() - .iter() - .map(|batch| batch.get_array_memory_size()) - .sum(); + let build_total_size = build_data.total_input_size(); if build_total_size <= self.memory_threshold { self.num_partitions = 1; self.max_partition_count = 1; @@ -1911,10 +1903,10 @@ impl PartitionedHashJoinStream { let mut max_spilled_bytes: usize = 0; let mut any_spilled = false; - for batch in build_data.original_batches() { + build_data.for_each_original_batch(|batch| { let mut keys_values: Vec = Vec::with_capacity(self.on_left.len()); for expr in &self.on_left { - keys_values.push(expr.evaluate(batch)?.into_array(batch.num_rows())?); + keys_values.push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); } let mut hashes = vec![0u64; batch.num_rows()]; create_hashes(&keys_values, &self.random_state, &mut hashes)?; @@ -1971,7 +1963,7 @@ impl PartitionedHashJoinStream { ) }); repartition_request = Some(next_count); - break; + return Ok(false); } } } @@ -1996,7 +1988,7 @@ impl PartitionedHashJoinStream { ) }); repartition_request = Some(next_count); - break; + return Ok(false); } } } @@ -2011,14 +2003,16 @@ impl PartitionedHashJoinStream { } if repartition_request.is_some() { - break; + return Ok(false); } } if repartition_request.is_some() { - break; + Ok(false) + } else { + Ok(true) } - } + })?; if let Some(next_count) = repartition_request { hhj_debug(|| { diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index fab62bff840f6..5c220f9c0340c 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -51,6 +51,7 @@ use futures::{FutureExt as _, Stream}; /// file read (instead of each batch). This approach does not work because when /// the number of concurrent reads exceeds the Tokio thread pool limit, /// deadlocks can occur and block progress. + struct SpillReaderStream { schema: SchemaRef, state: SpillReaderStreamState, @@ -63,7 +64,7 @@ type NextRecordBatchResult = Result<(StreamReader>, Option), /// A read is in progress in a spawned blocking task for which we hold the handle. 
ReadInProgress(SpawnedTask), @@ -77,6 +78,10 @@ enum SpillReaderStreamState { impl SpillReaderStream { fn new(schema: SchemaRef, spill_file: RefCountedTempFile) -> Self { + Self::new_from_shared(schema, Arc::new(spill_file)) + } + + fn new_from_shared(schema: SchemaRef, spill_file: Arc) -> Self { Self { schema, state: SpillReaderStreamState::Uninitialized(spill_file), @@ -96,8 +101,9 @@ impl SpillReaderStream { unreachable!() }; + let file_ref = spill_file.clone(); let task = SpawnedTask::spawn_blocking(move || { - let file = BufReader::new(File::open(spill_file.path())?); + let file = BufReader::new(File::open(file_ref.path())?); // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications // with validated schemas and buffers. Skip redundant validation during read // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written. diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index ad23bd66a021a..f605544c82318 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -182,6 +182,18 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + + pub fn read_spill_as_stream_shared( + &self, + spill_file_path: Arc, + ) -> Result { + let stream = Box::pin(cooperative(SpillReaderStream::new_from_shared( + Arc::clone(&self.schema), + spill_file_path, + ))); + + Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + } } pub(crate) trait GetSlicedSize { From d36af9dd4f6530bcf797ca8d95f3ac76e2bc96bd Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Mon, 10 Nov 2025 16:42:31 +0200 Subject: [PATCH 24/36] Add spilltodisk collectMode --- .../physical-plan/src/joins/hash_join/exec.rs | 271 +++++++++++------- .../src/joins/hash_join/partitioned.rs | 12 +- 2 files changed, 180 insertions(+), 103 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 7f5083ea354a9..4bdd8e41e1f85 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -41,7 +41,10 @@ use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; -use crate::spill::{get_record_batch_memory_size, spill_manager::SpillManager}; +use crate::spill::{ + get_record_batch_memory_size, in_progress_spill_file::InProgressSpillFile, + spill_manager::SpillManager, +}; use crate::ExecutionPlanProperties; use crate::{ common::can_project, @@ -64,7 +67,8 @@ use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, + internal_datafusion_err, internal_err, plan_err, project_schema, JoinSide, JoinType, + NullEquality, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -79,7 +83,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{executor::block_on, StreamExt, TryStreamExt}; +use futures::{executor::block_on, pin_mut, StreamExt}; use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from 
`RepartitionExec`, avoiding collisions. @@ -125,7 +129,7 @@ pub(super) struct JoinLeftData { pub(super) bounds: Option>, } -enum OriginalBuildInput { +pub(super) enum OriginalBuildInput { InMemory(Arc>), Spilled { spill_manager: Arc, @@ -133,27 +137,29 @@ enum OriginalBuildInput { }, } +enum CollectLeftMode { + InMemory, + SpillToDisk { spill_manager: Arc }, +} + impl JoinLeftData { /// Create a new `JoinLeftData` from its parts pub(super) fn new( hash_map: Box, batch: RecordBatch, - original_batches: Arc>, + original_input: OriginalBuildInput, + total_rows: usize, + total_input_size: usize, values: Vec, visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, reservation: MemoryReservation, bounds: Option>, ) -> Self { - let total_rows = original_batches.iter().map(|b| b.num_rows()).sum(); - let total_input_size = original_batches - .iter() - .map(|batch| batch.get_array_memory_size()) - .sum(); Self { hash_map, batch, - original_input: OriginalBuildInput::InMemory(original_batches), + original_input, total_rows, total_input_size, values, @@ -1022,6 +1028,7 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), enable_dynamic_filter_pushdown, + CollectLeftMode::InMemory, )) })?, PartitionMode::Partitioned => { @@ -1040,6 +1047,7 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), 1, enable_dynamic_filter_pushdown, + CollectLeftMode::InMemory, )) } PartitionMode::Auto => { @@ -1071,6 +1079,7 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), 1, enable_dynamic_filter_pushdown, + CollectLeftMode::InMemory, )) } else { // Spillable enabled: coalesce left to a single stream @@ -1084,6 +1093,11 @@ impl ExecutionPlan for HashJoinExec { let left_stream = left_plan.execute(0, Arc::clone(&context))?; let reservation = MemoryConsumer::new("HashJoinInput") .register(context.memory_pool()); + let build_input_spill_manager = Arc::new(SpillManager::new( + Arc::clone(&context.runtime_env()), + SpillMetrics::new(&self.metrics, partition), + build_schema.clone(), + )); let left_fut = self.left_fut.try_once(|| { Ok(collect_left_input( self.random_state.clone(), @@ -1094,6 +1108,9 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), enable_dynamic_filter_pushdown, + CollectLeftMode::SpillToDisk { + spill_manager: Arc::clone(&build_input_spill_manager), + }, )) })?; @@ -1749,44 +1766,53 @@ async fn collect_left_input( with_visited_indices_bitmap: bool, probe_threads_count: usize, should_compute_bounds: bool, + collect_mode: CollectLeftMode, ) -> Result { let schema = left_stream.schema(); - // This operation performs 2 steps at once: - // 1. creates a [JoinHashMap] of all batches from the stream - // 2. stores the batches in a vector. 
- let initial = BuildSideState::try_new( + let mut state = BuildSideState::try_new( metrics, reservation, on_left.clone(), &schema, should_compute_bounds, )?; + let mut total_input_size = 0usize; + let mut spill_writer: Option = None; + + pin_mut!(left_stream); + while let Some(batch) = left_stream.next().await { + let batch = batch?; + if let Some(ref mut accumulators) = state.bounds_accumulators { + for accumulator in accumulators { + accumulator.update_batch(&batch)?; + } + } - let state = left_stream - .try_fold(initial, |mut state, batch| async move { - // Update accumulators if computing bounds - if let Some(ref mut accumulators) = state.bounds_accumulators { - for accumulator in accumulators { - accumulator.update_batch(&batch)?; + let batch_size = get_record_batch_memory_size(&batch); + state.reservation.try_grow(batch_size)?; + state.metrics.build_mem_used.add(batch_size); + state.metrics.build_input_batches.add(1); + state.metrics.build_input_rows.add(batch.num_rows()); + state.num_rows += batch.num_rows(); + total_input_size = total_input_size.saturating_add(batch_size); + + match &collect_mode { + CollectLeftMode::InMemory => { + state.batches.push(batch); + } + CollectLeftMode::SpillToDisk { spill_manager } => { + if spill_writer.is_none() { + spill_writer = Some( + spill_manager.create_in_progress_file("hash_join_build_input")?, + ); + } + if let Some(writer) = spill_writer.as_mut() { + writer.append_batch(&batch)?; } } - - // Decide if we spill or not - let batch_size = get_record_batch_memory_size(&batch); - // Reserve memory for incoming batch - state.reservation.try_grow(batch_size)?; - // Update metrics - state.metrics.build_mem_used.add(batch_size); - state.metrics.build_input_batches.add(1); - state.metrics.build_input_rows.add(batch.num_rows()); - // Update row count - state.num_rows += batch.num_rows(); - // Push batch to output - state.batches.push(batch); - Ok(state) - }) - .await?; + } + } // Extract fields from state let BuildSideState { @@ -1797,70 +1823,113 @@ async fn collect_left_input( bounds_accumulators, } = state; let batches_arc = Arc::new(batches); - // Estimation of memory size, required for hashtable, prior to allocation. 
- // Final result can be verified using `RawTable.allocation_info()` - let fixed_size_u32 = size_of::(); - let fixed_size_u64 = size_of::(); - - // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the - // `u64` indice variant - let mut hashmap: Box = if num_rows > u32::MAX as usize { - let estimated_hashtable_size = - estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU64::with_capacity(num_rows)) - } else { - let estimated_hashtable_size = - estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU32::with_capacity(num_rows)) - }; - - let mut hashes_buffer = Vec::new(); - let mut offset = 0; + let (hashmap, single_batch, left_values, visited_indices_bitmap, original_input) = + match collect_mode { + CollectLeftMode::InMemory => { + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + let mut hashmap: Box = + if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; - // Updating hashmap starting from the last batch - let batches_iter = batches_arc.iter().rev(); - for batch in batches_iter.clone() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( - &on_left, - batch, - &mut *hashmap, - offset, - &random_state, - &mut hashes_buffer, - 0, - true, - )?; - offset += batch.num_rows(); - } - // Merge all batches into a single batch, so we can directly index into the arrays - let single_batch = concat_batches(&schema, batches_iter)?; - - // Reserve additional memory for visited indices bitmap and create shared builder - let visited_indices_bitmap = if with_visited_indices_bitmap { - let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); - reservation.try_grow(bitmap_size)?; - metrics.build_mem_used.add(bitmap_size); - - let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); - bitmap_buffer.append_n(num_rows, false); - bitmap_buffer - } else { - BooleanBufferBuilder::new(0) - }; + let mut hashes_buffer = Vec::new(); + let mut offset = 0; + let batches_iter = batches_arc.iter().rev(); + for batch in batches_iter.clone() { + hashes_buffer.clear(); + hashes_buffer.resize(batch.num_rows(), 0); + update_hash( + &on_left, + batch, + &mut *hashmap, + offset, + &random_state, + &mut hashes_buffer, + 0, + true, + )?; + offset += batch.num_rows(); + } + let single_batch = concat_batches(&schema, batches_iter)?; + let visited_indices_bitmap = if with_visited_indices_bitmap { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = + BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; - let left_values = on_left 
-        .iter()
-        .map(|c| {
-            c.evaluate(&single_batch)?
-                .into_array(single_batch.num_rows())
-        })
-        .collect::<Result<Vec<_>>>()?;
+                let left_values = on_left
+                    .iter()
+                    .map(|c| {
+                        c.evaluate(&single_batch)?
+                            .into_array(single_batch.num_rows())
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                (
+                    hashmap,
+                    single_batch,
+                    left_values,
+                    visited_indices_bitmap,
+                    OriginalBuildInput::InMemory(Arc::clone(&batches_arc)),
+                )
+            }
+            CollectLeftMode::SpillToDisk { spill_manager } => {
+                if num_rows == 0 {
+                    (
+                        Box::new(JoinHashMapU32::with_capacity(0))
+                            as Box<dyn JoinHashMapType>,
+                        RecordBatch::new_empty(schema.clone()),
+                        Vec::new(),
+                        BooleanBufferBuilder::new(0),
+                        OriginalBuildInput::InMemory(Arc::new(vec![])),
+                    )
+                } else {
+                    let mut writer = spill_writer.ok_or_else(|| {
+                        internal_datafusion_err!("missing build spill writer")
+                    })?;
+                    let spill_file = writer.finish()?.ok_or_else(|| {
+                        internal_datafusion_err!(
+                            "expected spill file when spilling build input"
+                        )
+                    })?;
+                    metrics.build_spill_count.add(1);
+                    metrics.build_spilled_rows.add(num_rows);
+                    metrics
+                        .build_spilled_bytes
+                        .add(spill_file.current_disk_usage() as usize);
+                    let _ = reservation.try_shrink(reservation.size());
+                    (
+                        Box::new(JoinHashMapU32::with_capacity(0))
+                            as Box<dyn JoinHashMapType>,
+                        RecordBatch::new_empty(schema.clone()),
+                        Vec::new(),
+                        BooleanBufferBuilder::new(0),
+                        OriginalBuildInput::Spilled {
+                            spill_manager,
+                            spill_file: Arc::new(spill_file),
+                        },
+                    )
+                }
+            }
+        };
 
     // Compute bounds for dynamic filter if enabled
     let bounds = match bounds_accumulators {
@@ -1877,7 +1946,9 @@ async fn collect_left_input(
     let data = JoinLeftData::new(
         hashmap,
         single_batch,
-        Arc::clone(&batches_arc),
+        original_input,
+        num_rows,
+        total_input_size,
         left_values,
         Mutex::new(visited_indices_bitmap),
         AtomicUsize::new(probe_threads_count),
diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs
index b0df425fa4d65..3cf813d0cc656 100644
--- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs
@@ -623,7 +623,7 @@ impl PartitionedHashJoinStream {
             return None;
         }
 
-        let mut per_partition_budget =
+        let per_partition_budget =
             per_partition_budget_bytes(self.memory_threshold, self.num_partitions);
 
         let rows_budget = self
@@ -3872,10 +3872,16 @@ mod scheduler_tests {
         let reservation = MemoryConsumer::new("left")
             .with_can_spill(true)
            .register(&runtime_env.memory_pool);
+        let arc_batches = Arc::new(vec![batch.clone()]);
+        let total_rows = arc_batches.iter().map(|b| b.num_rows()).sum();
+        let total_input_size =
+            arc_batches.iter().map(|b| b.get_array_memory_size()).sum();
         JoinLeftData::new(
             hash_map,
-            batch.clone(),
-            Arc::new(vec![batch]),
+            batch,
+            OriginalBuildInput::InMemory(Arc::clone(&arc_batches)),
+            total_rows,
+            total_input_size,
             vec![],
             Mutex::new(BooleanBufferBuilder::new(0)),
             AtomicUsize::new(0),

From d0a48e0cfc873ccf3e1dc27bb428ca3adda9f518 Mon Sep 17 00:00:00 2001
From: Denys Tsomenko
Date: Mon, 10 Nov 2025 17:26:02 +0200
Subject: [PATCH 25/36] Try fixing frozen state

---
 .../src/joins/hash_join/partitioned.rs | 167 +++++++++++++++---
 .../src/joins/hash_join/scheduler.rs   |   2 +
 2 files changed, 146 insertions(+), 23 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs
index 3cf813d0cc656..dcdf79c44c3f9 100644
--- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs
@@ -123,6 +123,21 @@ fn per_partition_budget_bytes(memory_threshold: usize, partitions: usize) -> usi budget.max(HYBRID_HASH_MIN_PARTITION_BYTES) } +fn estimate_probe_buffer_bytes( + batch: &RecordBatch, + values: &[ArrayRef], + hashes: &[u64], +) -> usize { + let batch_bytes = batch.get_array_memory_size(); + let values_bytes = values.iter().fold(0usize, |acc, arr| { + acc.saturating_add(arr.get_array_memory_size()) + }); + let hashes_bytes = hashes.len().saturating_mul(size_of::()); + batch_bytes + .saturating_add(values_bytes) + .saturating_add(hashes_bytes) +} + #[inline] fn hhj_debug String>(builder: F) { if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() { @@ -218,6 +233,7 @@ pub(super) struct ProbePartitionState { buffered: ProbePartition, batch_position: usize, buffered_rows: usize, + buffered_bytes: usize, spilled_rows: usize, consumed_rows: usize, spill_in_progress: Option, @@ -237,6 +253,7 @@ impl ProbePartitionState { buffered: ProbePartition::new(), batch_position: 0, buffered_rows: 0, + buffered_bytes: 0, spilled_rows: 0, consumed_rows: 0, spill_in_progress: None, @@ -601,6 +618,71 @@ impl PartitionedHashJoinStream { Ok(false) } + fn flush_probe_buffer_to_spill(&mut self, part_id: usize) -> Result<()> { + if !self.runtime_env.disk_manager.tmp_files_enabled() { + return Err(internal_datafusion_err!( + "Insufficient memory for buffering probe partitions and spilling is disabled" + )); + } + if part_id >= self.probe_states.len() { + return Ok(()); + } + + let (buffered, queue_ready, stream_active, mut writer_opt) = { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + if state.buffered.batches.is_empty() { + state.buffered_bytes = 0; + return Ok(()); + } + let queue_ready = !state.spill_files.is_empty(); + let stream_active = state.pending_stream.is_some(); + let buffered = mem::replace(&mut state.buffered, ProbePartition::new()); + state.buffered_rows = 0; + state.buffered_bytes = 0; + state.batch_position = 0; + let writer_opt = state.spill_in_progress.take(); + (buffered, queue_ready, stream_active, writer_opt) + }; + + if writer_opt.is_none() { + writer_opt = Some( + self.probe_spill_manager + .create_in_progress_file("hash_join_probe_partition")?, + ); + self.join_metrics.probe_spill_count.add(1); + } + + let mut writer = writer_opt.ok_or_else(|| { + internal_datafusion_err!("expected probe spill writer for partition") + })?; + + let mut spilled_rows = 0usize; + for batch in buffered.batches { + let batch_size = batch.get_array_memory_size(); + writer.append_batch(&batch)?; + self.join_metrics.probe_spilled_rows.add(batch.num_rows()); + self.join_metrics.probe_spilled_bytes.add(batch_size); + spilled_rows = spilled_rows.saturating_add(batch.num_rows()); + } + + { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.spill_in_progress = Some(writer); + state.spilled_rows = state.spilled_rows.saturating_add(spilled_rows); + } + + if !queue_ready && !stream_active { + self.finalize_spilled_partition(part_id)?; + } + Ok(()) + } + fn compute_recursive_fanout( &self, descriptor: &PartitionDescriptor, @@ -942,6 +1024,8 @@ impl PartitionedHashJoinStream { let shift_bits = descriptor.radix_bits; let mask = (fanout - 1) as u64; + let probe_partition_budget = + per_partition_budget_bytes(self.memory_threshold, self.num_partitions); let spill_file = { let state = self @@ -950,6 +1034,7 @@ impl PartitionedHashJoinStream { 
.ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; state.batch_position = 0; state.buffered_rows = 0; + state.buffered_bytes = 0; state.spilled_rows = 0; state.consumed_rows = 0; state.active_batch = None; @@ -1036,6 +1121,7 @@ impl PartitionedHashJoinStream { state.spill_files.push_back(file); state.spilled_rows = 0; state.buffered_rows = 0; + state.buffered_bytes = 0; state.consumed_rows = 0; state.batch_position = 0; state.pending_stream = None; @@ -1054,6 +1140,8 @@ impl PartitionedHashJoinStream { .probe_states .get_mut(parent_index) .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.buffered_bytes = 0; + state.batch_position = 0; mem::replace(&mut state.buffered, ProbePartition::new()) }; for idx in 0..parent_partition.batches.len() { @@ -1095,20 +1183,30 @@ impl PartitionedHashJoinStream { } let idx = partition_indices[sub_idx]; - let state = self - .probe_states - .get_mut(idx) - .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - state.buffered.batches.push(filtered_batch); - state.buffered.values.push(filtered_values); - state.buffered.hashes.push(filtered_hashes); - let buffered = state - .buffered - .batches - .last() - .map(|b| b.num_rows()) - .unwrap_or_default(); - state.buffered_rows = state.buffered_rows.saturating_add(buffered); + let row_count = filtered_batch.num_rows(); + let buffered_bytes_delta = estimate_probe_buffer_bytes( + &filtered_batch, + &filtered_values, + &filtered_hashes, + ); + let mut should_flush = false; + { + let state = self.probe_states.get_mut(idx).ok_or_else(|| { + internal_datafusion_err!("missing probe partition") + })?; + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_values); + state.buffered.hashes.push(filtered_hashes); + state.buffered_rows = state.buffered_rows.saturating_add(row_count); + state.buffered_bytes = + state.buffered_bytes.saturating_add(buffered_bytes_delta); + if state.buffered_bytes >= probe_partition_budget { + should_flush = true; + } + } + if should_flush { + self.flush_probe_buffer_to_spill(idx)?; + } } } @@ -1119,6 +1217,8 @@ impl PartitionedHashJoinStream { if self.probe_states.len() != self.num_partitions { self.resize_partition_vectors(); } + let probe_partition_budget = + per_partition_budget_bytes(self.memory_threshold, self.num_partitions); loop { match self.right.poll_next_unpin(cx) { @@ -1217,16 +1317,36 @@ impl PartitionedHashJoinStream { self.finalize_spilled_partition(part_id)?; } } else { - let state = - self.probe_states.get_mut(part_id).ok_or_else(|| { - internal_datafusion_err!("missing probe partition") - })?; - state.buffered.batches.push(filtered_batch); - state.buffered.values.push(filtered_on_values); - state.buffered.hashes.push(filtered_hashes); - if let Some(last) = state.buffered.batches.last() { + let row_count = filtered_batch.num_rows(); + let buffered_bytes_delta = estimate_probe_buffer_bytes( + &filtered_batch, + &filtered_on_values, + &filtered_hashes, + ); + let mut should_flush = false; + { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| { + internal_datafusion_err!( + "missing probe partition" + ) + })?; + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_on_values); + state.buffered.hashes.push(filtered_hashes); state.buffered_rows = - state.buffered_rows.saturating_add(last.num_rows()); + state.buffered_rows.saturating_add(row_count); + state.buffered_bytes = state + .buffered_bytes + .saturating_add(buffered_bytes_delta); + if 
state.buffered_bytes >= probe_partition_budget { + should_flush = true; + } + } + if should_flush { + self.flush_probe_buffer_to_spill(part_id)?; } } } @@ -2273,6 +2393,7 @@ impl PartitionedHashJoinStream { state.buffered = ProbePartition::new(); state.batch_position = 0; state.buffered_rows = 0; + state.buffered_bytes = 0; } if let Some(b) = state.active_batch.as_ref() { state.consumed_rows = diff --git a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs index 8a979713adee2..b2a04d7736242 100644 --- a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs +++ b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs @@ -260,6 +260,7 @@ pub(super) struct ProbePartitionState { pub buffered: ProbePartition, pub batch_position: usize, pub buffered_rows: usize, + pub buffered_bytes: usize, pub spilled_rows: usize, pub consumed_rows: usize, pub spill_in_progress: Option, @@ -278,6 +279,7 @@ impl ProbePartitionState { buffered: ProbePartition::new(), batch_position: 0, buffered_rows: 0, + buffered_bytes: 0, spilled_rows: 0, consumed_rows: 0, spill_in_progress: None, From 505b1cba12b98f12ef508e4d884b26c34a4a35b3 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Tue, 11 Nov 2025 20:18:50 +0200 Subject: [PATCH 26/36] Add flag for scheduler --- .../physical-plan/src/joins/hash_join/exec.rs | 367 +++++++--- .../src/joins/hash_join/partitioned.rs | 656 +++++++++++++++--- 2 files changed, 824 insertions(+), 199 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 4bdd8e41e1f85..56b7ae899a949 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -16,6 +16,8 @@ // under the License. 
use std::fmt; +use std::fs::File; +use std::io::BufReader; use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; @@ -33,8 +35,7 @@ use crate::joins::hash_join::stream::{ }; use crate::joins::join_hash_map::{JoinHashMapU32, JoinHashMapU64}; use crate::joins::utils::{ - asymmetric_join_output_partitioning, reorder_output_after_swap, swap_join_projection, - update_hash, OnceAsync, OnceFut, + reorder_output_after_swap, swap_join_projection, update_hash, OnceAsync, OnceFut, }; use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; use crate::projection::{ @@ -61,14 +62,15 @@ use crate::{ use arrow::array::{ArrayRef, BooleanBufferBuilder}; use arrow::compute::concat_batches; use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::StreamReader; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - internal_datafusion_err, internal_err, plan_err, project_schema, JoinSide, JoinType, - NullEquality, Result, + internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, + JoinSide, JoinType, NullEquality, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -83,7 +85,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{executor::block_on, pin_mut, StreamExt}; +use futures::{pin_mut, StreamExt}; use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. @@ -91,7 +93,7 @@ const HASH_JOIN_SEED: RandomState = RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); /// Maximum number of partitions allowed when recursively repartitioning during hybrid hash join. -const HYBRID_HASH_MAX_PARTITIONS: usize = 1 << 16; +pub(crate) const HYBRID_HASH_MAX_PARTITIONS: usize = 1 << 16; /// Upper bound multiplier applied to the initial partition fanout when searching for additional partitions. const HYBRID_HASH_PARTITION_GROWTH_FACTOR: usize = 16; /// Approximate number of probe batches worth of rows we target per partition when statistics are available. @@ -101,6 +103,8 @@ const HYBRID_HASH_MIN_BYTES_PER_PARTITION: usize = 8 * 1024 * 1024; /// Minimum number of rows per partition when statistics are available to avoid extreme fan-out. const HYBRID_HASH_MIN_ROWS_PER_PARTITION: usize = 1_024; +static NEXT_HHJ_STREAM_ID: AtomicUsize = AtomicUsize::new(0); + /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` @@ -132,7 +136,7 @@ pub(super) struct JoinLeftData { pub(super) enum OriginalBuildInput { InMemory(Arc>), Spilled { - spill_manager: Arc, + _spill_manager: Arc, spill_file: Arc, }, } @@ -217,20 +221,20 @@ impl JoinLeftData { } Ok(()) } - OriginalBuildInput::Spilled { - spill_manager, - spill_file, - } => { - let mut stream = - spill_manager.read_spill_as_stream_shared(Arc::clone(spill_file))?; - block_on(async { - while let Some(batch) = stream.next().await.transpose()? { - if !f(batch)? { - break; - } + OriginalBuildInput::Spilled { spill_file, .. 
} => { + let file = File::open(spill_file.path())?; + let reader = BufReader::new(file); + // SAFETY: spill files are generated by DataFusion with validated schema/buffers. + let mut reader = unsafe { + StreamReader::try_new(reader, None)?.with_skip_validation(true) + }; + while let Some(batch) = reader.next() { + let batch = batch.map_err(DataFusionError::from)?; + if !f(batch)? { + break; } - Ok(()) - }) + } + Ok(()) } } } @@ -670,9 +674,7 @@ impl HashJoinExec { )?; let mut output_partitioning = match mode { - PartitionMode::CollectLeft => { - asymmetric_join_output_partitioning(left, right, &join_type)? - } + PartitionMode::CollectLeft => Partitioning::UnknownPartitioning(1), PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), @@ -1005,7 +1007,23 @@ impl ExecutionPlan for HashJoinExec { if self.mode == PartitionMode::CollectLeft && left_partitions != 1 { return internal_err!( "Invalid HashJoinExec, the output partition count of the left child must be 1 in CollectLeft mode,\ - consider using CoalescePartitionsExec or the EnforceDistribution rule" + consider using CoalescePartitionsExec or the EnforceDistribution rule" + ); + } + + let enable_spillable_mode = context + .session_config() + .options() + .optimizer + .enable_spillable_hash_join; + let single_stream_mode = matches!(self.mode, PartitionMode::CollectLeft) + || (self.mode == PartitionMode::PartitionedSpillable && enable_spillable_mode); + + if single_stream_mode && partition > 0 { + return plan_err!( + "HashJoinExec in {:?} mode produces a single output partition but partition {partition} was requested. \ + Insert a RepartitionExec above the join when additional parallelism is required.", + self.mode ); } @@ -1057,13 +1075,7 @@ impl ExecutionPlan for HashJoinExec { ); } PartitionMode::PartitionedSpillable => { - let enable_spillable = context - .session_config() - .options() - .optimizer - .enable_spillable_hash_join; - - if !enable_spillable { + if !enable_spillable_mode { // Legacy fallback: behave like Partitioned let left_stream = self.left.execute(partition, Arc::clone(&context))?; @@ -1349,6 +1361,19 @@ impl ExecutionPlan for HashJoinExec { .register(context.memory_pool()); let probe_spill_metrics = SpillMetrics::new(&self.metrics, partition); let build_spill_metrics = SpillMetrics::new(&self.metrics, partition); + let stream_id = + NEXT_HHJ_STREAM_ID.fetch_add(1, Ordering::Relaxed); + #[cfg(feature = "hybrid_hash_join_scheduler")] + let probe_scheduler_enabled = std::env::var( + "DATAFUSION_HHJ_ENABLE_SCHEDULER", + ) + .map(|v| { + matches!( + v.to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) + .unwrap_or(false); let partitioned_stream = PartitionedHashJoinStream::new( partition, self.schema(), @@ -1374,6 +1399,9 @@ impl ExecutionPlan for HashJoinExec { probe_schema, self.right.output_ordering().is_some(), shared_bounds_accumulator, + stream_id, + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_enabled, )?; return Ok(Box::pin(partitioned_stream)); } @@ -1822,67 +1850,21 @@ async fn collect_left_input( mut reservation, bounds_accumulators, } = state; - let batches_arc = Arc::new(batches); let (hashmap, single_batch, left_values, visited_indices_bitmap, original_input) = match collect_mode { CollectLeftMode::InMemory => { - let fixed_size_u32 = size_of::(); - let fixed_size_u64 = size_of::(); - let mut hashmap: Box = - if num_rows > u32::MAX as usize { - let estimated_hashtable_size = - estimate_memory_size::<(u64, u64)>(num_rows, 
fixed_size_u64)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU64::with_capacity(num_rows)) - } else { - let estimated_hashtable_size = - estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU32::with_capacity(num_rows)) - }; - - let mut hashes_buffer = Vec::new(); - let mut offset = 0; - let batches_iter = batches_arc.iter().rev(); - for batch in batches_iter.clone() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( + let batches_arc = Arc::new(batches); + let (hashmap, single_batch, left_values, visited_indices_bitmap) = + build_join_data_from_batches( + batches_arc.as_slice(), + num_rows, + &schema, &on_left, - batch, - &mut *hashmap, - offset, &random_state, - &mut hashes_buffer, - 0, - true, + &mut reservation, + &metrics, + with_visited_indices_bitmap, )?; - offset += batch.num_rows(); - } - let single_batch = concat_batches(&schema, batches_iter)?; - let visited_indices_bitmap = if with_visited_indices_bitmap { - let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); - reservation.try_grow(bitmap_size)?; - metrics.build_mem_used.add(bitmap_size); - - let mut bitmap_buffer = - BooleanBufferBuilder::new(single_batch.num_rows()); - bitmap_buffer.append_n(num_rows, false); - bitmap_buffer - } else { - BooleanBufferBuilder::new(0) - }; - - let left_values = on_left - .iter() - .map(|c| { - c.evaluate(&single_batch)? - .into_array(single_batch.num_rows()) - }) - .collect::>>()?; - ( hashmap, single_batch, @@ -1905,26 +1887,40 @@ async fn collect_left_input( let mut writer = spill_writer.ok_or_else(|| { internal_datafusion_err!("missing build spill writer") })?; - let spill_file = writer.finish()?.ok_or_else(|| { + let spill_file = Arc::new(writer.finish()?.ok_or_else(|| { internal_datafusion_err!( "expected spill file when spilling build input" ) - })?; + })?); metrics.build_spill_count.add(1); metrics.build_spilled_rows.add(num_rows); metrics .build_spilled_bytes .add(spill_file.current_disk_usage() as usize); let _ = reservation.try_shrink(reservation.size()); + + let reloaded_batches = reload_spilled_batches(&spill_file)?; + let (hashmap, single_batch, left_values, visited_indices_bitmap) = + build_join_data_from_batches( + &reloaded_batches, + num_rows, + &schema, + &on_left, + &random_state, + &mut reservation, + &metrics, + with_visited_indices_bitmap, + )?; + drop(reloaded_batches); + ( - Box::new(JoinHashMapU32::with_capacity(0)) - as Box, - RecordBatch::new_empty(schema.clone()), - Vec::new(), - BooleanBufferBuilder::new(0), + hashmap, + single_batch, + left_values, + visited_indices_bitmap, OriginalBuildInput::Spilled { - spill_manager, - spill_file: Arc::new(spill_file), + _spill_manager: spill_manager, + spill_file, }, ) } @@ -1959,6 +1955,94 @@ async fn collect_left_input( Ok(data) } +fn build_join_data_from_batches( + batches: &[RecordBatch], + num_rows: usize, + schema: &SchemaRef, + on_left: &[PhysicalExprRef], + random_state: &RandomState, + reservation: &mut MemoryReservation, + metrics: &BuildProbeJoinMetrics, + with_visited_indices_bitmap: bool, +) -> Result<( + Box, + RecordBatch, + Vec, + BooleanBufferBuilder, +)> { + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + let mut hashmap: Box = if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + 
estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; + + let mut hashes_buffer = Vec::new(); + let mut offset = 0usize; + let batches_iter = batches.iter().rev(); + for batch in batches_iter.clone() { + hashes_buffer.clear(); + hashes_buffer.resize(batch.num_rows(), 0); + update_hash( + on_left, + batch, + &mut *hashmap, + offset, + random_state, + &mut hashes_buffer, + 0, + true, + )?; + offset += batch.num_rows(); + } + let single_batch = concat_batches(schema, batches_iter)?; + let visited_indices_bitmap = if with_visited_indices_bitmap { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let left_values = on_left + .iter() + .map(|c| { + c.evaluate(&single_batch)? + .into_array(single_batch.num_rows()) + }) + .collect::>>()?; + + Ok((hashmap, single_batch, left_values, visited_indices_bitmap)) +} + +fn reload_spilled_batches( + spill_file: &Arc, +) -> Result> { + let file = File::open(spill_file.path())?; + let reader = BufReader::new(file); + // SAFETY: spill files are generated by DataFusion with validated schema/buffers. + let mut reader = + unsafe { StreamReader::try_new(reader, None)?.with_skip_validation(true) }; + let mut batches = Vec::new(); + while let Some(batch) = reader.next() { + batches.push(batch.map_err(DataFusionError::from)?); + } + Ok(batches) +} + #[cfg(test)] mod tests { use super::*; @@ -1966,8 +2050,9 @@ mod tests { use crate::joins::hash_join::stream::lookup_join_hashmap; use crate::test::{assert_join_metrics, TestMemoryExec}; use crate::{ - common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, - test::exec::MockExec, + common, expressions::Column, metrics::ExecutionPlanMetricsSet, + repartition::RepartitionExec, stream::RecordBatchStreamAdapter, + test::build_table_i32, test::exec::MockExec, }; use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array}; @@ -1978,13 +2063,15 @@ mod tests { use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, - ScalarValue, + DataFusionError, ScalarValue, }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; + use datafusion_execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::PhysicalExpr; + use futures::stream; use hashbrown::HashTable; use insta::{allow_duplicates, assert_snapshot}; use rstest::*; @@ -2013,6 +2100,74 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + #[tokio::test] + async fn collect_left_input_spill_path_rebuilds_state() -> 
Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![4, 5])), + Arc::new(Int32Array::from(vec![40, 50])), + ], + )?; + + let stream = RecordBatchStreamAdapter::new( + Arc::clone(&schema), + stream::iter(vec![ + Ok::<_, DataFusionError>(batch1.clone()), + Ok::<_, DataFusionError>(batch2.clone()), + ]), + ); + let left_stream: SendableRecordBatchStream = Box::pin(stream); + + let metrics_set = ExecutionPlanMetricsSet::new(); + let join_metrics = BuildProbeJoinMetrics::new(0, &metrics_set); + let pool: Arc = Arc::new(UnboundedMemoryPool::default()); + let reservation = MemoryConsumer::new("collect_left_input_spill").register(&pool); + + let spill_manager = Arc::new(SpillManager::new( + Arc::new(RuntimeEnv::default()), + SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0), + Arc::clone(&schema), + )); + + let join_data = collect_left_input( + RandomState::new(), + left_stream, + vec![Arc::new(Column::new("a", 0))], + join_metrics, + reservation, + true, + 1, + false, + CollectLeftMode::SpillToDisk { spill_manager }, + ) + .await?; + + let expected_rows = batch1.num_rows() + batch2.num_rows(); + assert_eq!(join_data.total_rows(), expected_rows); + assert_eq!(join_data.batch().num_rows(), expected_rows); + assert!(!join_data.hash_map().is_empty()); + + let mut counted_rows = 0usize; + join_data.for_each_original_batch(|batch| { + counted_rows += batch.num_rows(); + Ok(true) + })?; + assert_eq!(counted_rows, expected_rows); + Ok(()) + } + fn join( left: Arc, right: Arc, @@ -2157,10 +2312,22 @@ mod tests { null_equality, )?; + let spillable_enabled = context + .session_config() + .options() + .optimizer + .enable_spillable_hash_join; + let single_stream_output = matches!(partition_mode, PartitionMode::CollectLeft) + || (partition_mode == PartitionMode::PartitionedSpillable && spillable_enabled); + let requested_partitions = if single_stream_output { + 1 + } else { + partition_count + }; let columns = columns(&join.schema()); let mut batches = vec![]; - for i in 0..partition_count { + for i in 0..requested_partitions { let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index dcdf79c44c3f9..4082ec2d87c7d 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -83,6 +83,7 @@ use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; use futures::{executor::block_on, ready, Stream, StreamExt}; +use crate::joins::hash_join::exec::HYBRID_HASH_MAX_PARTITIONS; const HYBRID_HASH_MAX_REPARTITION_DEPTH: usize = 6; const HYBRID_HASH_MIN_FANOUT: usize = 2; @@ -375,6 +376,7 @@ pub(super) struct PartitionedHashJoinStream { pub num_partitions: usize, /// Maximum partition fanout allowed when recursively repartitioning pub max_partition_count: usize, + pub mem_limit_partitions: usize, /// Memory threshold for spilling (in bytes) pub memory_threshold: usize, @@ -463,6 +465,13 @@ pub(super) struct PartitionedHashJoinStream { pub partition_pending: Vec, /// 
Latest descriptor metadata per partition pub partition_descriptors: Vec>, + /// Whether the async probe scheduler is enabled + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_enabled: bool, + /// Whether this stream has already completed (prevents restart) + pub stream_completed: bool, + /// Unique identifier for debugging/logging + pub stream_id: usize, } #[cfg(feature = "hybrid_hash_join_scheduler")] @@ -475,6 +484,13 @@ enum ProbeTaskStatus { } impl PartitionedHashJoinStream { + #[inline] + fn log_stream String>(&self, builder: F) { + if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() { + let stream_id = self.stream_id; + hhj_debug(|| format!("[stream={stream_id}] {}", builder())); + } + } /// Compute partition id for a given hash using radix mask when possible #[inline] fn partition_for_hash(&self, hash: u64) -> usize { @@ -562,7 +578,9 @@ impl PartitionedHashJoinStream { self.pending_partitions.push_back(desc.clone()); self.partition_pending[part_id] = true; #[cfg(feature = "hybrid_hash_join_scheduler")] - self.schedule_probe_task(&desc); + if self.probe_scheduler_enabled { + self.schedule_probe_task(&desc); + } } Ok(()) @@ -582,6 +600,9 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn ensure_probe_scheduler_capacity(&mut self, part_id: usize) { + if !self.probe_scheduler_enabled { + return; + } if self.probe_scheduler_inflight.len() <= part_id { self.probe_scheduler_inflight.resize(part_id + 1, false); } @@ -589,6 +610,9 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn schedule_probe_task(&mut self, descriptor: &PartitionDescriptor) { + if !self.probe_scheduler_enabled { + return; + } let part_id = descriptor.build_index; self.ensure_probe_scheduler_capacity(part_id); if self.probe_scheduler_inflight[part_id] { @@ -613,6 +637,16 @@ impl PartitionedHashJoinStream { state.spill_files.push_back(file); } self.schedule_partition(part_id)?; + hhj_debug(|| { + format!( + "finalize_spilled_partition enqueued partition {} (spill_files={})", + part_id, + self.probe_states + .get(part_id) + .map(|s| s.spill_files.len()) + .unwrap_or(0) + ) + }); return Ok(true); } Ok(false) @@ -628,7 +662,7 @@ impl PartitionedHashJoinStream { return Ok(()); } - let (buffered, queue_ready, stream_active, mut writer_opt) = { + let (mut buffered, queue_ready, stream_active, mut writer_opt) = { let state = self .probe_states .get_mut(part_id) @@ -653,14 +687,16 @@ impl PartitionedHashJoinStream { .create_in_progress_file("hash_join_probe_partition")?, ); self.join_metrics.probe_spill_count.add(1); + hhj_debug(|| { + format!("flush_probe_buffer_to_spill opening writer for part {part_id}") + }); } - let mut writer = writer_opt.ok_or_else(|| { - internal_datafusion_err!("expected probe spill writer for partition") - })?; + let mut writer = writer_opt + .ok_or_else(|| internal_datafusion_err!("missing probe spill writer"))?; let mut spilled_rows = 0usize; - for batch in buffered.batches { + for batch in buffered.batches.drain(..) 
{ let batch_size = batch.get_array_memory_size(); writer.append_batch(&batch)?; self.join_metrics.probe_spilled_rows.add(batch.num_rows()); @@ -668,17 +704,32 @@ impl PartitionedHashJoinStream { spilled_rows = spilled_rows.saturating_add(batch.num_rows()); } - { - let state = self - .probe_states - .get_mut(part_id) - .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - state.spill_in_progress = Some(writer); - state.spilled_rows = state.spilled_rows.saturating_add(spilled_rows); - } + let spill_file = writer.finish()?.ok_or_else(|| { + internal_datafusion_err!("expected probe spill file after flush") + })?; + hhj_debug(|| { + format!( + "flush_probe_buffer_to_spill part {} flushed_rows={} queue_ready={} stream_active={}", + part_id, spilled_rows, queue_ready, stream_active + ) + }); + + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.spill_files.push_back(spill_file); + state.spill_in_progress = None; + state.spilled_rows = state.spilled_rows.saturating_add(spilled_rows); if !queue_ready && !stream_active { - self.finalize_spilled_partition(part_id)?; + hhj_debug(|| { + format!( + "flush_probe_buffer_to_spill scheduling partition {} after flush (queue empty)", + part_id + ) + }); + self.schedule_partition(part_id)?; } Ok(()) } @@ -713,7 +764,9 @@ impl PartitionedHashJoinStream { .saturating_mul(HYBRID_HASH_ROWS_PER_PARTITION_TARGET_MULTIPLIER) .max(HYBRID_HASH_ROWS_PER_PARTITION_MIN); - let should_repartition_bytes = descriptor.spilled_bytes > per_partition_budget; + let spilled_any = descriptor.spilled_bytes > 0; + let should_repartition_bytes = + spilled_any || descriptor.spilled_bytes > per_partition_budget; let should_repartition_rows = descriptor.spilled_rows > rows_budget; if !should_repartition_bytes && !should_repartition_rows { @@ -721,6 +774,12 @@ impl PartitionedHashJoinStream { } let mut required = HYBRID_HASH_MIN_FANOUT; + hhj_debug(|| { + format!( + "compute_recursive_fanout part={} spilled_bytes={} spilled_rows={} budget_bytes={} rows_budget={}", + descriptor.build_index, descriptor.spilled_bytes, descriptor.spilled_rows, per_partition_budget, rows_budget + ) + }); if should_repartition_bytes { let budget = per_partition_budget.max(1); @@ -749,6 +808,12 @@ impl PartitionedHashJoinStream { if additional_bits == 0 { return None; } + hhj_debug(|| { + format!( + "compute_recursive_fanout part={} -> additional_bits={} fanout={}", + descriptor.build_index, additional_bits, fanout + ) + }); Some((additional_bits, fanout)) } @@ -758,6 +823,12 @@ impl PartitionedHashJoinStream { additional_bits: usize, fanout: usize, ) -> Result> { + hhj_debug(|| { + format!( + "repartition_spilled_partition start part={} gen={} additional_bits={} fanout={}", + descriptor.build_index, descriptor.generation, additional_bits, fanout + ) + }); let build_index = descriptor.build_index; if build_index >= self.build_partitions.len() { return Ok(vec![]); @@ -812,6 +883,12 @@ impl PartitionedHashJoinStream { let mut new_descriptor = descriptor.clone(); new_descriptor.spilled_bytes = 0; new_descriptor.spilled_rows = 0; + hhj_debug(|| { + format!( + "repartition_spilled_partition part={} had empty spill batches; marking empty descriptor", + build_index + ) + }); self.matched_build_rows_per_partition[build_index] = BooleanBufferBuilder::new(0); self.build_partitions[build_index] = BuildPartition::Empty; @@ -1027,11 +1104,17 @@ impl PartitionedHashJoinStream { let probe_partition_budget = 
per_partition_budget_bytes(self.memory_threshold, self.num_partitions); - let spill_file = { + let mut spill_files = { let state = self .probe_states .get_mut(parent_index) .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + // Ensure any in-progress writer contributes its rows + if let Some(mut writer) = state.spill_in_progress.take() { + if let Some(file) = writer.finish()? { + state.spill_files.push_back(file); + } + } state.batch_position = 0; state.buffered_rows = 0; state.buffered_bytes = 0; @@ -1043,10 +1126,10 @@ impl PartitionedHashJoinStream { state.active_offset = (0, None); state.joined_probe_idx = None; state.pending_stream = None; - state.spill_files.pop_front() + mem::take(&mut state.spill_files) }; - if let Some(file) = spill_file { + while let Some(file) = spill_files.pop_front() { let mut writers = Vec::with_capacity(fanout); for _ in 0..fanout { let writer = self @@ -1131,7 +1214,6 @@ impl PartitionedHashJoinStream { state.active_offset = (0, None); state.joined_probe_idx = None; } - return Ok(()); } // In-memory probe data @@ -1373,27 +1455,67 @@ impl PartitionedHashJoinStream { descriptor: &PartitionDescriptor, ) -> Result { if descriptor.build_index >= self.build_partitions.len() { + hhj_debug(|| { + format!( + "maybe_recursive_repartition skipping part {}: build index out of range", + descriptor.build_index + ) + }); return Ok(false); } match self.build_partitions.get(descriptor.build_index) { Some(BuildPartition::Spilled { .. }) => {} - _ => return Ok(false), + _ => { + hhj_debug(|| { + format!( + "maybe_recursive_repartition skipping part {}: not spilled", + descriptor.build_index + ) + }); + return Ok(false); + } } let Some((additional_bits, fanout)) = self.compute_recursive_fanout(descriptor) else { + hhj_debug(|| { + format!( + "maybe_recursive_repartition skipping part {}: compute_recursive_fanout returned None", + descriptor.build_index + ) + }); return Ok(false); }; let new_descriptors = self.repartition_spilled_partition(descriptor, additional_bits, fanout)?; if new_descriptors.is_empty() { + hhj_debug(|| { + format!( + "maybe_recursive_repartition part {} produced no new descriptors", + descriptor.build_index + ) + }); return Ok(false); } // Enqueue new descriptors in order for desc in new_descriptors.into_iter().rev() { #[cfg(feature = "hybrid_hash_join_scheduler")] - self.schedule_probe_task(&desc); + if self.probe_scheduler_enabled { + self.schedule_probe_task(&desc); + } + hhj_debug(|| { + format!( + "maybe_recursive_repartition enqueued new part {} (root={} gen={})", + desc.build_index, desc.root_index, desc.generation + ) + }); self.pending_partitions.push_front(desc); } + hhj_debug(|| { + format!( + "maybe_recursive_repartition succeeded for original part {}", + descriptor.build_index + ) + }); Ok(true) } @@ -1521,7 +1643,8 @@ impl PartitionedHashJoinStream { self.state = PartitionedHashJoinState::PartitionBuildSide; } - fn next_partition_count(&self) -> Option { + fn next_partition_count(&mut self) -> Option { + self.ensure_partition_headroom(); if self.num_partitions >= self.max_partition_count { return None; } @@ -1540,6 +1663,32 @@ impl PartitionedHashJoinStream { } } + fn ensure_partition_headroom(&mut self) { + if self.num_partitions < self.max_partition_count { + return; + } + if self.max_partition_count >= self.mem_limit_partitions { + return; + } + let mut new_max = self.max_partition_count.saturating_mul(2); + if new_max <= self.max_partition_count { + new_max = self.max_partition_count.saturating_add(1); + } + new_max = 
new_max + .max(HYBRID_HASH_MIN_FANOUT) + .min(self.mem_limit_partitions); + if new_max > self.max_partition_count { + let old = self.max_partition_count; + self.max_partition_count = new_max; + self.log_stream(|| { + format!( + "expanded max_partition_count from {} to {} (mem_limit={})", + old, new_max, self.mem_limit_partitions + ) + }); + } + } + fn repartition_worthwhile(&self, max_spilled_bytes: usize) -> bool { let partitions = self.num_partitions.max(1); let per_partition_budget = self.memory_threshold / partitions; @@ -1551,10 +1700,12 @@ impl PartitionedHashJoinStream { max_spilled_bytes > cutoff } - fn prepare_partition_queue(&mut self) { + fn prepare_partition_queue(&mut self) -> Result<()> { self.pending_partitions.clear(); let radix_bits = self.num_partitions.next_power_of_two().trailing_zeros() as usize; + + let mut descriptors: VecDeque = VecDeque::new(); for part_id in 0..self.build_partitions.len() { let (spilled_bytes, spilled_rows) = match &self.build_partitions[part_id] { BuildPartition::Spilled { @@ -1564,13 +1715,7 @@ impl PartitionedHashJoinStream { } => (*spilled_bytes, *spilled_rows), _ => (0, 0), }; - if self.partition_descriptors.len() <= part_id { - self.partition_descriptors.resize_with(part_id + 1, || None); - } - if self.partition_pending.len() <= part_id { - self.partition_pending.resize(part_id + 1, false); - } - self.pending_partitions.push_back(PartitionDescriptor { + descriptors.push_back(PartitionDescriptor { build_index: part_id, root_index: part_id, generation: self.partition_pass, @@ -1579,13 +1724,55 @@ impl PartitionedHashJoinStream { spilled_bytes, spilled_rows, }); - if let Some(desc) = self.pending_partitions.back().cloned() { - self.partition_descriptors[part_id] = Some(desc.clone()); - self.partition_pending[part_id] = true; - #[cfg(feature = "hybrid_hash_join_scheduler")] - self.schedule_probe_task(&desc); + } + + while let Some(descriptor) = descriptors.pop_front() { + let build_index = descriptor.build_index; + let needs_recursive = matches!( + self.build_partitions.get(build_index), + Some(BuildPartition::Spilled { .. 
}) + ); + + if needs_recursive { + if let Some((additional_bits, fanout)) = + self.compute_recursive_fanout(&descriptor) + { + hhj_debug(|| { + format!( + "prepare_partition_queue repartitioning part {} fanout={}", + build_index, fanout + ) + }); + let new_descriptors = self.repartition_spilled_partition( + &descriptor, + additional_bits, + fanout, + )?; + for desc in new_descriptors.into_iter().rev() { + descriptors.push_front(desc); + } + continue; + } + } + + let build_index = descriptor.build_index; + if self.partition_descriptors.len() <= build_index { + self.partition_descriptors + .resize_with(build_index + 1, || None); + } + if self.partition_pending.len() <= build_index { + self.partition_pending.resize(build_index + 1, false); + } + + self.pending_partitions.push_back(descriptor.clone()); + self.partition_descriptors[build_index] = Some(descriptor.clone()); + self.partition_pending[build_index] = true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + if self.probe_scheduler_enabled { + self.schedule_probe_task(&descriptor); } } + Ok(()) } fn transition_to_next_partition(&mut self) { @@ -1609,7 +1796,9 @@ impl PartitionedHashJoinStream { self.current_partition = None; #[cfg(feature = "hybrid_hash_join_scheduler")] { - if !self.probe_scheduler_waiting_for_stream.is_empty() { + if self.probe_scheduler_enabled + && !self.probe_scheduler_waiting_for_stream.is_empty() + { hhj_debug(|| { "transition_to_next_partition -> WaitingForProbe".to_string() }); @@ -1824,6 +2013,9 @@ impl PartitionedHashJoinStream { bounds_accumulator: Option< Arc, >, + stream_id: usize, + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_enabled: bool, ) -> Result { let probe_spill_manager = SpillManager::new( runtime_env.clone(), @@ -1870,6 +2062,7 @@ impl PartitionedHashJoinStream { batch_size, num_partitions, max_partition_count, + mem_limit_partitions: mem_limit, memory_threshold, state: PartitionedHashJoinState::PartitionBuildSide, build_partitions: Vec::new(), @@ -1918,6 +2111,10 @@ impl PartitionedHashJoinStream { filter_debug_once_per_part: vec![false; num_partitions], partition_pending: vec![false; num_partitions], partition_descriptors: (0..num_partitions).map(|_| None).collect(), + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_enabled, + stream_completed: false, + stream_id, }) } @@ -1955,12 +2152,15 @@ impl PartitionedHashJoinStream { let build_total_size = build_data.total_input_size(); if build_total_size <= self.memory_threshold { self.num_partitions = 1; - self.max_partition_count = 1; + self.max_partition_count = self + .mem_limit_partitions + .max(HYBRID_HASH_MIN_FANOUT) + .min(HYBRID_HASH_MAX_PARTITIONS); } let mut allow_repartition = !self.partition_pass_output_started; loop { - hhj_debug(|| { + self.log_stream(|| { format!( "partition_build_side pass={} num_partitions={} allow_repartition={}", self.partition_pass, self.num_partitions, allow_repartition @@ -1970,7 +2170,7 @@ impl PartitionedHashJoinStream { match self.try_partition_build_side(&build_data, allow_repartition)? 
{ PartitionBuildStatus::Ready(result) => { - hhj_debug(|| { + self.log_stream(|| { format!( "partition_build_side pass {} completed (num_partitions={})", self.partition_pass, self.num_partitions @@ -1979,7 +2179,7 @@ impl PartitionedHashJoinStream { return Ok(result); } PartitionBuildStatus::NeedMorePartitions { next_count } => { - hhj_debug(|| { + self.log_stream(|| { format!( "partition_build_side requesting repartition to {} (current={})", next_count, self.num_partitions @@ -1989,7 +2189,7 @@ impl PartitionedHashJoinStream { || next_count == 0 || next_count > self.max_partition_count { - hhj_debug(|| { + self.log_stream(|| { format!( "repartition request invalid (max={} current={}); forcing spill", self.max_partition_count, self.num_partitions @@ -2013,6 +2213,12 @@ impl PartitionedHashJoinStream { build_data: &Arc, allow_repartition: bool, ) -> Result { + hhj_debug(|| { + format!( + "try_partition_build_side start allow_repartition={} num_partitions={}", + allow_repartition, self.num_partitions + ) + }); self.build_partitions = Vec::with_capacity(self.num_partitions); self.matched_build_rows_per_partition = Vec::with_capacity(self.num_partitions); @@ -2024,6 +2230,13 @@ impl PartitionedHashJoinStream { let mut any_spilled = false; build_data.for_each_original_batch(|batch| { + hhj_debug(|| { + format!( + "partition_build_side processing source batch rows={} num_partitions={}", + batch.num_rows(), + self.num_partitions + ) + }); let mut keys_values: Vec = Vec::with_capacity(self.on_left.len()); for expr in &self.on_left { keys_values.push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); @@ -2085,6 +2298,19 @@ impl PartitionedHashJoinStream { repartition_request = Some(next_count); return Ok(false); } + } else if repartition_request.is_none() { + if let Some(next_count) = + self.next_partition_count() + { + hhj_debug(|| { + format!( + "partition {} spilled during partitioning -> requesting repartition to {}", + build_index, next_count + ) + }); + repartition_request = Some(next_count); + return Ok(false); + } } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -2092,6 +2318,7 @@ impl PartitionedHashJoinStream { "Insufficient memory for build partitioning and spilling is disabled" )); } + any_spilled = true; self.spill_partition(build_index, accum)?; } } @@ -2110,6 +2337,17 @@ impl PartitionedHashJoinStream { repartition_request = Some(next_count); return Ok(false); } + } else if repartition_request.is_none() { + if let Some(next_count) = self.next_partition_count() { + hhj_debug(|| { + format!( + "partition {} spilled during partitioning -> requesting repartition to {}", + build_index, next_count + ) + }); + repartition_request = Some(next_count); + return Ok(false); + } } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -2118,6 +2356,7 @@ impl PartitionedHashJoinStream { )); } self.spill_partition(build_index, accum)?; + any_spilled = true; self.append_spilled_batch(accum, filtered_batch)?; } } @@ -2134,14 +2373,12 @@ impl PartitionedHashJoinStream { } })?; - if let Some(next_count) = repartition_request { - hhj_debug(|| { - format!( - "try_partition_build_side early repartition request next_count={next_count}" - ) - }); - return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); - } + self.log_stream(|| { + format!( + "partition_build_side processed all batches repartition_request={:?} any_spilled={} max_spilled_bytes={}", + repartition_request, any_spilled, max_spilled_bytes + ) + }); self.build_partitions.reserve(self.num_partitions); 
self.matched_build_rows_per_partition @@ -2152,6 +2389,7 @@ impl PartitionedHashJoinStream { max_spilled_bytes = max_spilled_bytes.max(accum.spilled_bytes); if accum.spill_writer.is_some() { if !accum.buffered_batches.is_empty() { + any_spilled = true; self.spill_partition(part_id, &mut accum)?; } if let Some(mut writer) = accum.spill_writer.take() { @@ -2165,6 +2403,23 @@ impl PartitionedHashJoinStream { .with_can_spill(true) .register(&self.runtime_env.memory_pool); any_spilled = true; + if allow_repartition && repartition_request.is_none() { + if let Some(next_count) = self.next_partition_count() { + self.log_stream(|| { + format!( + "partition {} finalized spill -> requesting repartition to {}", + part_id, next_count + ) + }); + repartition_request = Some(next_count); + } + } + self.log_stream(|| { + format!( + "partition_build_side finalized spill for part {} bytes={} rows={}", + part_id, accum.spilled_bytes, accum.total_rows + ) + }); self.build_partitions.push(BuildPartition::Spilled { spill_file: Some(spill_file), reservation, @@ -2255,27 +2510,26 @@ impl PartitionedHashJoinStream { values: partition_values, reservation, }); + hhj_debug(|| { + format!( + "partition_build_side kept partition {} in-memory (rows={}, approx_bytes={})", + part_id, + num_rows, + approx_partition_size + ) + }); } - if allow_repartition - && (max_spilled_bytes > self.memory_threshold || any_spilled) - && self.repartition_worthwhile(max_spilled_bytes) - { - if let Some(next_count) = self.next_partition_count() { - hhj_debug(|| { - format!( - "try_partition_build_side repartition due to spill (max_spilled_bytes={} threshold={} any_spilled={}) next_count={}", - max_spilled_bytes, - self.memory_threshold, - any_spilled, - next_count - ) - }); - return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); - } + if let Some(next_count) = repartition_request { + hhj_debug(|| { + format!( + "try_partition_build_side early repartition request next_count={next_count}" + ) + }); + return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); } - self.prepare_partition_queue(); + self.prepare_partition_queue()?; self.partition_pass_output_started = true; self.transition_to_next_partition(); @@ -2350,6 +2604,13 @@ impl PartitionedHashJoinStream { // no-op } } + + if build_index < self.partition_descriptors.len() { + self.partition_descriptors[build_index] = None; + } + if build_index < self.partition_pending.len() { + self.partition_pending[build_index] = false; + } } fn partition_has_pending_probe(&self, part_id: usize) -> bool { @@ -2407,6 +2668,9 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn try_acquire_probe_stream_slot(&mut self) -> bool { + if !self.probe_scheduler_enabled { + return true; + } if self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { self.probe_scheduler_active_streams += 1; true @@ -2417,6 +2681,9 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn release_probe_stream_slot(&mut self) { + if !self.probe_scheduler_enabled { + return; + } if self.probe_scheduler_active_streams > 0 { self.probe_scheduler_active_streams -= 1; } @@ -2425,6 +2692,9 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn enqueue_stream_waiter(&mut self, part_id: usize) { + if !self.probe_scheduler_enabled { + return; + } if part_id >= self.partition_pending.len() { return; } @@ -2440,6 +2710,9 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn 
wake_stream_waiter(&mut self) { + if !self.probe_scheduler_enabled { + return; + } while self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { if let Some(next_part) = self.probe_scheduler_waiting_for_stream.pop_front() { hhj_debug(|| format!("wake_stream_waiter considering part {next_part}")); @@ -2482,6 +2755,9 @@ impl PartitionedHashJoinStream { cx: &mut Context<'_>, descriptor: &PartitionDescriptor, ) -> Result { + if !self.probe_scheduler_enabled { + return Ok(ProbeTaskStatus::Finished); + } let part_id = descriptor.build_index; self.schedule_probe_task(descriptor); hhj_debug(|| { @@ -2614,6 +2890,12 @@ impl PartitionedHashJoinStream { }; if !has_spilled_probe { + hhj_debug(|| { + format!( + "poll_probe_data_for_partition part {} -> Finished (no spill state)", + part_id + ) + }); return Ok(ProbeDataPoll::Finished); } @@ -2646,6 +2928,14 @@ impl PartitionedHashJoinStream { internal_datafusion_err!("missing partition") })?; state.spill_files.push_front(file); + hhj_debug(|| { + format!( + "poll_probe_data_for_partition part {} NeedStream (active_streams={} files_pending={})", + part_id, + self.probe_scheduler_active_streams, + state.spill_files.len() + ) + }); return Ok(ProbeDataPoll::NeedStream); } let stream = self.probe_spill_manager.read_spill_as_stream(file)?; @@ -2660,8 +2950,17 @@ impl PartitionedHashJoinStream { state.spill_in_progress.is_some() }; if self.probe_stream_finished && !writer_open { + hhj_debug(|| { + format!("poll_probe_data_for_partition part {} -> Finished (all files drained)", part_id) + }); return Ok(ProbeDataPoll::Finished); } else { + hhj_debug(|| { + format!( + "poll_probe_data_for_partition part {} pending writer_open={} finished={}", + part_id, writer_open, self.probe_stream_finished + ) + }); return Ok(ProbeDataPoll::Pending); } } @@ -2693,6 +2992,15 @@ impl PartitionedHashJoinStream { state.consumed_rows = state.consumed_rows.saturating_add(b.num_rows()); } + hhj_debug(|| { + format!( + "poll_probe_data_for_partition part {} -> Ready (batch_rows={}, files_pending={}, pending_stream={})", + part_id, + state.active_batch.as_ref().map(|b| b.num_rows()).unwrap_or(0), + state.spill_files.len(), + state.pending_stream.is_some() + ) + }); return Ok(ProbeDataPoll::Ready); } Some(Poll::Ready(Some(Err(e)))) => return Err(e), @@ -2705,9 +3013,31 @@ impl PartitionedHashJoinStream { state.pending_stream = None; } self.release_probe_stream_slot(); + hhj_debug(|| { + format!( + "poll_probe_data_for_partition part {} drained stream (files_left={})", + part_id, + self.probe_states + .get(part_id) + .map(|s| s.spill_files.len()) + .unwrap_or(0) + ) + }); continue; } - Some(Poll::Pending) | None => return Ok(ProbeDataPoll::Pending), + Some(Poll::Pending) | None => { + hhj_debug(|| { + format!( + "poll_probe_data_for_partition part {} pending on stream (pending_stream={})", + part_id, + self.probe_states + .get(part_id) + .map(|s| s.pending_stream.is_some()) + .unwrap_or(false) + ) + }); + return Ok(ProbeDataPoll::Pending); + } } } } @@ -2787,36 +3117,138 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] { - if !has_active_batch { - match self.poll_probe_stage_task(cx, &partition_state.descriptor)? 
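// The scheduler-gating methods above (try_acquire_probe_stream_slot,
// release_probe_stream_slot, enqueue_stream_waiter, wake_stream_waiter) amount to
// a small bounded-slot counter with a waiter queue. A rough sketch of that
// accounting, with illustrative names rather than the stream's actual fields:
use std::collections::VecDeque;

struct StreamSlots {
    active: usize,
    max_streams: usize,
    waiting: VecDeque<usize>, // partition ids waiting for a spill-read slot
}

impl StreamSlots {
    fn try_acquire(&mut self, part_id: usize) -> bool {
        if self.active < self.max_streams {
            self.active += 1;
            true
        } else {
            self.waiting.push_back(part_id);
            false
        }
    }

    // Returns the partition (if any) the caller should re-schedule, mirroring how
    // wake_stream_waiter hands a freed slot to the next waiting partition.
    fn release(&mut self) -> Option<usize> {
        self.active = self.active.saturating_sub(1);
        self.waiting.pop_front()
    }
}

fn main() {
    let mut slots = StreamSlots { active: 0, max_streams: 1, waiting: VecDeque::new() };
    assert!(slots.try_acquire(0));
    assert!(!slots.try_acquire(1)); // partition 1 is parked as a waiter
    assert_eq!(slots.release(), Some(1)); // freeing the slot wakes partition 1
}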
{ - ProbeTaskStatus::Ready => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Ready") - }); - has_active_batch = true; - } - ProbeTaskStatus::Pending => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Pending") - }); - return Poll::Pending; + if self.probe_scheduler_enabled { + if !has_active_batch { + match self.poll_probe_stage_task(cx, &partition_state.descriptor)? { + ProbeTaskStatus::Ready => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Ready") + }); + has_active_batch = true; + } + ProbeTaskStatus::Pending => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Pending") + }); + return Poll::Pending; + } + ProbeTaskStatus::WaitingForStream => { + hhj_debug(|| { + format!( + "process_partition part {build_index} -> WaitingForStream" + ) + }); + self.enqueue_stream_waiter(build_index); + self.current_partition = None; + self.transition_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + ProbeTaskStatus::Finished => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Finished") + }); + if self.probe_scheduler_inflight.len() > build_index { + self.probe_scheduler_inflight[build_index] = false; + } + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } } - ProbeTaskStatus::WaitingForStream => { - hhj_debug(|| { - format!("process_partition part {build_index} -> WaitingForStream") - }); - self.enqueue_stream_waiter(build_index); - self.current_partition = None; - self.transition_to_next_partition(); - return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + } else { + if !has_active_batch { + if self.take_buffered_probe_batch(build_index)?.is_some() { + has_active_batch = true; } - ProbeTaskStatus::Finished => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Finished") - }); + } + + if !has_active_batch { + let has_spilled_probe = match self.probe_state(build_index) { + Ok(state) => { + state.spill_in_progress.is_some() + || !state.spill_files.is_empty() + || state.pending_stream.is_some() + } + Err(e) => return Poll::Ready(Err(e)), + }; + + if has_spilled_probe { + loop { + let needs_stream = match self.probe_state(build_index) { + Ok(state) => state.pending_stream.is_none(), + Err(e) => return Poll::Ready(Err(e)), + }; + + if needs_stream { + let mut next_file = match self.probe_state_mut(build_index) { + Ok(state) => state.spill_files.pop_front(), + Err(e) => return Poll::Ready(Err(e)), + }; + if next_file.is_none() + && self.finalize_spilled_partition(build_index)? + { + next_file = match self.probe_state_mut(build_index) { + Ok(state) => state.spill_files.pop_front(), + Err(e) => return Poll::Ready(Err(e)), + }; + } + if let Some(file) = next_file { + let stream = self + .probe_spill_manager + .read_spill_as_stream(file)?; + match self.probe_state_mut(build_index) { + Ok(state) => state.pending_stream = Some(stream), + Err(e) => return Poll::Ready(Err(e)), + } + } else { + let should_release = match self.probe_state(build_index) { + Ok(state) => { + state.buffered.batches.is_empty() + && state.pending_stream.is_none() + } + Err(e) => return Poll::Ready(Err(e)), + }; + if should_release { + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok( + StatefulStreamResult::Continue + )); + } + break; + } + } else { + match self.poll_probe_data_for_partition( + build_index, + cx, + )? 
{ + ProbeDataPoll::Ready => { + has_active_batch = true; + break; + } + ProbeDataPoll::Pending => return Poll::Pending, + ProbeDataPoll::NeedStream + | ProbeDataPoll::Finished => { + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok( + StatefulStreamResult::Continue + )); + } + } + } + } + } else if self.probe_stream_finished { self.release_partition_resources(build_index); self.advance_to_next_partition(); return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } else { + match self.buffer_probe_side(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } } } } @@ -3523,12 +3955,13 @@ impl PartitionedHashJoinStream { )? }; - let emitted_rows = result.num_rows(); - self.emitted_rows_per_part[build_index] = - self.emitted_rows_per_part[build_index].saturating_add(emitted_rows); (result, build_ids_to_mark, next_offset, next_joined_idx) }; + let emitted_rows = result.num_rows(); + self.emitted_rows_per_part[build_index] = + self.emitted_rows_per_part[build_index].saturating_add(emitted_rows); + // Mark matched build-side rows for outer joins (use current partition's bitmap) if let Some(bitmap) = self.matched_build_rows_per_partition.get_mut(build_index) { for build_idx in build_ids_to_mark { @@ -3549,7 +3982,9 @@ impl PartitionedHashJoinStream { state.active_offset = (0, None); state.joined_probe_idx = None; #[cfg(feature = "hybrid_hash_join_scheduler")] - self.schedule_probe_task(&partition_state.descriptor); + if self.probe_scheduler_enabled { + self.schedule_probe_task(&partition_state.descriptor); + } } } Err(e) => return Poll::Ready(Err(e)), @@ -3563,6 +3998,14 @@ impl PartitionedHashJoinStream { return Poll::Ready(Ok(StatefulStreamResult::Continue)); } self.join_metrics.output_batches.add(1); + self.log_stream(|| { + format!( + "process_partition part {} emitted {} rows (cumulative={})", + build_index, + emitted_rows, + self.emitted_rows_per_part[build_index] + ) + }); self.join_metrics.baseline.record_output(result.num_rows()); // println!( // "[spill-join] Emitting batch: rows={} (partition={})", @@ -3579,6 +4022,7 @@ impl PartitionedHashJoinStream { ) -> Poll>>> { if !need_produce_result_in_final(self.join_type) { self.state = PartitionedHashJoinState::Completed; + self.stream_completed = true; return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); } @@ -3831,6 +4275,7 @@ impl PartitionedHashJoinStream { } else { // All partitions processed self.state = PartitionedHashJoinState::Completed; + self.stream_completed = true; return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); } } @@ -3849,8 +4294,11 @@ impl Stream for PartitionedHashJoinStream { mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + if self.stream_completed { + return Poll::Ready(None); + } loop { - hhj_debug(|| format!("poll_next state {:?}", self.state)); + self.log_stream(|| format!("poll_next state {:?}", self.state)); match self.state.clone() { PartitionedHashJoinState::PartitionBuildSide => { // Collect build side and partition it @@ -3863,7 +4311,9 @@ impl Stream for PartitionedHashJoinStream { Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), Poll::Pending => return Poll::Pending, } - hhj_debug(|| format!("restarting build pass state={:?}", self.state)); + self.log_stream(|| { + format!("restarting build pass state={:?}", self.state) + }); match self.partition_build_side(left_data) { Ok(StatefulStreamResult::Continue) => continue, 
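// The bitmap bookkeeping above ("Mark matched build-side rows for outer joins")
// keeps one boolean per build row per partition. A hedged, standalone sketch of
// that pattern using arrow's BooleanBufferBuilder: bits set while probing are
// later inverted to emit the unmatched build rows required by LEFT/FULL-style
// joins.
use arrow::array::BooleanBufferBuilder;

fn mark_matched(bitmap: &mut BooleanBufferBuilder, matched_build_indices: &[usize]) {
    for &idx in matched_build_indices {
        bitmap.set_bit(idx, true);
    }
}

fn unmatched_build_rows(bitmap: &BooleanBufferBuilder) -> Vec<usize> {
    (0..bitmap.len()).filter(|&i| !bitmap.get_bit(i)).collect()
}

fn main() {
    let mut bitmap = BooleanBufferBuilder::new(8);
    bitmap.append_n(8, false); // one unset bit per build row in this partition
    mark_matched(&mut bitmap, &[1, 3, 4]);
    assert_eq!(unmatched_build_rows(&bitmap), vec![0, 2, 5, 6, 7]);
}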
Ok(StatefulStreamResult::Ready(Some(batch))) => { @@ -3874,6 +4324,7 @@ impl Stream for PartitionedHashJoinStream { return Poll::Ready(Some(Ok(batch))); } Ok(StatefulStreamResult::Ready(None)) => { + self.stream_completed = true; return Poll::Ready(None) } Err(e) => return Poll::Ready(Some(Err(e))), @@ -3910,6 +4361,10 @@ impl Stream for PartitionedHashJoinStream { } #[cfg(feature = "hybrid_hash_join_scheduler")] PartitionedHashJoinState::WaitingForProbe => { + if !self.probe_scheduler_enabled { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + continue; + } if self.pending_partitions.is_empty() { if self.probe_scheduler_waiting_for_stream.is_empty() { hhj_debug(|| { @@ -4056,6 +4511,9 @@ mod scheduler_tests { schema, false, None, + 0, + #[cfg(feature = "hybrid_hash_join_scheduler")] + true, ) .unwrap(); stream.probe_scheduler_max_streams = max_streams; From fe8e49473305a5582986e44cfcae876aabeae9ab Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 12 Nov 2025 21:06:36 +0200 Subject: [PATCH 27/36] Add instrumentation to support grace hash join --- datafusion/common/src/config.rs | 5 + .../physical_optimizer/join_selection.rs | 74 +++++- datafusion/execution/src/config.rs | 11 + .../src/coalesce_batches.rs | 8 +- .../src/enforce_distribution.rs | 49 +++- .../src/enforce_sorting/sort_pushdown.rs | 81 ++++++- .../physical-optimizer/src/join_selection.rs | 155 ++++++++----- .../src/joins/grace_hash_join/exec.rs | 22 +- .../physical-plan/src/joins/hash_join/mod.rs | 1 - datafusion/physical-plan/src/joins/mod.rs | 1 + datafusion/proto/proto/datafusion.proto | 11 + datafusion/proto/src/generated/pbjson.rs | 216 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 21 +- datafusion/proto/src/physical_plan/mod.rs | 197 +++++++++++++++- .../tests/cases/roundtrip_physical_plan.rs | 39 +++- 15 files changed, 798 insertions(+), 93 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 54895bc56e049..7397bf374a84e 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -766,6 +766,11 @@ config_namespace! { /// when memory pressure occurs during join execution. pub enable_spillable_hash_join: bool, default = true + /// When set to true, spillable partitioned hash joins will be replaced with the experimental + /// Grace hash join operator which repartitions both inputs to disk before performing the join. + /// This trades additional IO for predictable memory usage on very large joins. 
+ pub enable_grace_hash_join: bool, default = false + /// Should DataFusion allow symmetric hash joins for unbounded data sources even when /// its inputs do not have any ordering or filtering If the flag is not enabled, /// the SymmetricHashJoin operator will be unable to prune its internal buffers, diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index 7ae1d6e50dc3f..a68e465c85688 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -40,7 +40,9 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::displayable; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::utils::JoinFilter; -use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; +use datafusion_physical_plan::joins::{ + GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, +}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_physical_plan::{ @@ -266,6 +268,76 @@ async fn test_join_with_swap() { ); } +#[tokio::test] +async fn test_grace_hash_join_enabled() { + let (big, small) = create_big_and_small(); + let join = Arc::new( + HashJoinExec::try_new( + Arc::clone(&small), + Arc::clone(&big), + vec![( + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + )], + None, + &JoinType::Inner, + None, + PartitionMode::Auto, + NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_grace_hash_join = true; + config.optimizer.enable_spillable_hash_join = true; + config.optimizer.hash_join_single_partition_threshold = 1; + config.optimizer.hash_join_single_partition_threshold_rows = 1; + + let optimized = JoinSelection::new().optimize(join, &config).unwrap(); + assert!( + optimized.as_any().is::(), + "expected GraceHashJoinExec when grace hash join is enabled" + ); +} + +#[tokio::test] +async fn test_grace_hash_join_disabled() { + let (big, small) = create_big_and_small(); + let join = Arc::new( + HashJoinExec::try_new( + Arc::clone(&small), + Arc::clone(&big), + vec![( + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + )], + None, + &JoinType::Inner, + None, + PartitionMode::Auto, + NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_grace_hash_join = false; + config.optimizer.enable_spillable_hash_join = true; + config.optimizer.hash_join_single_partition_threshold = 1; + config.optimizer.hash_join_single_partition_threshold_rows = 1; + + let optimized = JoinSelection::new().optimize(join, &config).unwrap(); + let hash_join = optimized + .as_any() + .downcast_ref::() + .expect("Grace disabled should keep HashJoinExec"); + assert_eq!( + hash_join.partition_mode(), + &PartitionMode::PartitionedSpillable + ); +} + #[tokio::test] async fn test_left_join_no_swap() { let (big, small) = create_big_and_small(); diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index e959b5684f813..47f2fc43236e1 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -240,6 +240,11 @@ impl SessionConfig 
{ self.options.optimizer.enable_spillable_hash_join } + /// Should spillable hash joins be executed via the Grace hash join operator? + pub fn enable_grace_hash_join(&self) -> bool { + self.options.optimizer.enable_grace_hash_join + } + /// Are aggregates repartitioned during execution? pub fn repartition_aggregations(&self) -> bool { self.options.optimizer.repartition_aggregations @@ -309,6 +314,12 @@ impl SessionConfig { self } + /// Enables or disables the Grace hash join operator for spillable hash joins + pub fn with_enable_grace_hash_join(mut self, enabled: bool) -> Self { + self.options_mut().optimizer.enable_grace_hash_join = enabled; + self + } + /// Enables or disables the use of repartitioning for aggregations to improve parallelism pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self { self.options_mut().optimizer.repartition_aggregations = enabled; diff --git a/datafusion/physical-optimizer/src/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs index 5cf2c877c61a4..481b63b7a134e 100644 --- a/datafusion/physical-optimizer/src/coalesce_batches.rs +++ b/datafusion/physical-optimizer/src/coalesce_batches.rs @@ -26,8 +26,11 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_physical_expr::Partitioning; use datafusion_physical_plan::{ - coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, ExecutionPlan, + coalesce_batches::CoalesceBatchesExec, + filter::FilterExec, + joins::{GraceHashJoinExec, HashJoinExec}, + repartition::RepartitionExec, + ExecutionPlan, }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -62,6 +65,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { // See https://github.com/apache/datafusion/issues/139 let wrap_in_coalesce = plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() + || plan_any.downcast_ref::().is_some() // Don't need to add CoalesceBatchesExec after a round robin RepartitionExec || plan_any .downcast_ref::() diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index af7d7d0ce414b..4173afbdc3383 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -48,7 +48,7 @@ use datafusion_physical_plan::aggregates::{ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::{ - CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, + CrossJoinExec, GraceHashJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, }; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -376,6 +376,29 @@ pub fn adjust_input_keys_ordering( .map(Transformed::yes); } } + } else if let Some(grace_join) = plan.as_any().downcast_ref::() { + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + GraceHashJoinExec::try_new( + Arc::clone(grace_join.left()), + Arc::clone(grace_join.right()), + new_conditions.0, + grace_join.filter().cloned(), + grace_join.join_type(), + grace_join.projection.clone(), + grace_join.null_equality(), + ) + .map(|e| Arc::new(e) as _) + }; + return reorder_partitioned_join_keys( + requirements, + grace_join.on(), + &[], + &join_constructor, + ) + 
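// Together with the config option above, the SessionConfig accessors added in
// this patch make the Grace operator opt-in per session. A minimal usage sketch,
// assuming the datafusion_execution::config::SessionConfig shown here
// (applications would normally reach it through their usual SessionConfig import):
use datafusion_execution::config::SessionConfig;

fn grace_join_session_config() -> SessionConfig {
    let mut config = SessionConfig::new().with_enable_grace_hash_join(true);
    // Grace is only considered for joins that would otherwise run in the
    // spillable partitioned mode, so keep that flag enabled as well.
    config.options_mut().optimizer.enable_spillable_hash_join = true;
    config
}

fn main() {
    let config = grace_join_session_config();
    assert!(config.enable_grace_hash_join());
    assert!(config.enable_spillable_hash_join());
}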
.map(Transformed::yes); } else if let Some(CrossJoinExec { left, .. }) = plan.as_any().downcast_ref::() { @@ -683,6 +706,30 @@ pub fn reorder_join_keys_to_inputs( )?)); } } + } else if let Some(grace_join) = plan_any.downcast_ref::() { + let (join_keys, positions) = reorder_current_join_keys( + extract_join_keys(grace_join.on()), + Some(grace_join.left().output_partitioning()), + Some(grace_join.right().output_partitioning()), + grace_join.left().equivalence_properties(), + grace_join.right().equivalence_properties(), + ); + if positions.is_some_and(|idxs| !idxs.is_empty()) { + let JoinKeyPairs { + left_keys, + right_keys, + } = join_keys; + let new_join_on = new_join_conditions(&left_keys, &right_keys); + return Ok(Arc::new(GraceHashJoinExec::try_new( + Arc::clone(grace_join.left()), + Arc::clone(grace_join.right()), + new_join_on, + grace_join.filter().cloned(), + grace_join.join_type(), + grace_join.projection.clone(), + grace_join.null_equality(), + )?)); + } } else if let Some(SortMergeJoinExec { left, right, diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 6e4e784866129..ce1137d296f15 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -40,7 +40,9 @@ use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ calculate_join_output_ordering, ColumnIndex, }; -use datafusion_physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; +use datafusion_physical_plan::joins::{ + GraceHashJoinExec, HashJoinExec, SortMergeJoinExec, +}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; @@ -381,7 +383,9 @@ fn pushdown_requirement_to_children( Ok(None) } } else if let Some(hash_join) = plan.as_any().downcast_ref::() { - handle_hash_join(hash_join, parent_required) + handle_hash_like_join(hash_join, parent_required) + } else if let Some(grace_join) = plan.as_any().downcast_ref::() { + handle_hash_like_join(grace_join, parent_required) } else { handle_custom_pushdown(plan, parent_required, maintains_input_order) } @@ -698,10 +702,71 @@ fn handle_custom_pushdown( } } -// For hash join we only maintain the input order for the right child +trait HashJoinLike { + fn maintains_input_order(&self) -> Vec; + fn projection(&self) -> &Option>; + fn children(&self) -> Vec<&Arc>; + fn join_type(&self) -> &JoinType; + fn left(&self) -> &Arc; + fn right(&self) -> &Arc; +} + +impl HashJoinLike for HashJoinExec { + fn maintains_input_order(&self) -> Vec { + ExecutionPlan::maintains_input_order(self) + } + + fn projection(&self) -> &Option> { + &self.projection + } + + fn children(&self) -> Vec<&Arc> { + ExecutionPlan::children(self) + } + + fn join_type(&self) -> &JoinType { + self.join_type() + } + + fn left(&self) -> &Arc { + self.left() + } + + fn right(&self) -> &Arc { + self.right() + } +} + +impl HashJoinLike for GraceHashJoinExec { + fn maintains_input_order(&self) -> Vec { + ExecutionPlan::maintains_input_order(self) + } + + fn projection(&self) -> &Option> { + &self.projection + } + + fn children(&self) -> Vec<&Arc> { + ExecutionPlan::children(self) + } + + fn join_type(&self) -> &JoinType { + self.join_type() + } + + fn left(&self) -> &Arc { + self.left() + } + + fn right(&self) -> &Arc { + self.right() + } +} + +// For hash-based joins we only 
maintain the input order for the right child // for join type: Inner, Right, RightSemi, RightAnti -fn handle_hash_join( - plan: &HashJoinExec, +fn handle_hash_like_join( + plan: &J, parent_required: OrderingRequirements, ) -> Result>>> { // If the plan has no children or does not maintain the right side ordering, @@ -723,7 +788,7 @@ fn handle_hash_join( .collect(); let column_indices = build_join_column_index(plan); - let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + let projected_indices: Vec<_> = if let Some(projection) = plan.projection() { projection.iter().map(|&i| &column_indices[i]).collect() } else { column_indices.iter().collect() @@ -770,9 +835,9 @@ fn handle_hash_join( } } -// this function is used to build the column index for the hash join +// this function is used to build the column index for hash-based joins so we can // push down sort requirements to the right child -fn build_join_column_index(plan: &HashJoinExec) -> Vec { +fn build_join_column_index(plan: &J) -> Vec { let map_fields = |schema: SchemaRef, side: JoinSide| { schema .fields() diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 976096e7761fc..111e9d1837df7 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -34,7 +34,7 @@ use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::{ - CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, + CrossJoinExec, GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -135,6 +135,7 @@ impl PhysicalOptimizerRule for JoinSelection { let collect_threshold_byte_size = config.hash_join_single_partition_threshold; let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; let enable_spillable = config.enable_spillable_hash_join; + let enable_grace = config.enable_grace_hash_join; new_plan .transform_up(|plan| { statistical_join_selection_subrule( @@ -142,6 +143,7 @@ impl PhysicalOptimizerRule for JoinSelection { collect_threshold_byte_size, collect_threshold_num_rows, enable_spillable, + enable_grace, ) }) .data() @@ -232,6 +234,7 @@ pub(crate) fn try_collect_left( pub(crate) fn partitioned_hash_join( hash_join: &HashJoinExec, enable_spillable: bool, + enable_grace: bool, ) -> Result> { let left = hash_join.left(); let right = hash_join.right(); @@ -241,8 +244,26 @@ pub(crate) fn partitioned_hash_join( PartitionMode::Partitioned }; - if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? 
- { + let should_swap = hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)?; + if enable_grace && matches!(partition_mode, PartitionMode::PartitionedSpillable) { + let grace = Arc::new(GraceHashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + hash_join.on().to_vec(), + hash_join.filter().cloned(), + hash_join.join_type(), + hash_join.projection.clone(), + hash_join.null_equality(), + )?); + return if should_swap { + grace.swap_inputs(partition_mode) + } else { + Ok(grace) + }; + } + + if should_swap { hash_join.swap_inputs(partition_mode) } else { Ok(Arc::new(HashJoinExec::try_new( @@ -265,75 +286,83 @@ fn statistical_join_selection_subrule( collect_threshold_byte_size: usize, collect_threshold_num_rows: usize, enable_spillable: bool, + enable_grace: bool, ) -> Result>> { - let transformed = - if let Some(hash_join) = plan.as_any().downcast_ref::() { - match hash_join.partition_mode() { - PartitionMode::Auto => try_collect_left( - hash_join, - false, - collect_threshold_byte_size, - collect_threshold_num_rows, - )? + let transformed = if let Some(hash_join) = + plan.as_any().downcast_ref::() + { + match hash_join.partition_mode() { + PartitionMode::Auto => try_collect_left( + hash_join, + false, + collect_threshold_byte_size, + collect_threshold_num_rows, + )? + .map_or_else( + || { + partitioned_hash_join(hash_join, enable_spillable, enable_grace) + .map(Some) + }, + |v| Ok(Some(v)), + )?, + PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? .map_or_else( - || partitioned_hash_join(hash_join, enable_spillable).map(Some), + || { + partitioned_hash_join(hash_join, enable_spillable, enable_grace) + .map(Some) + }, |v| Ok(Some(v)), )?, - PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? - .map_or_else( - || partitioned_hash_join(hash_join, enable_spillable).map(Some), - |v| Ok(Some(v)), - )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if hash_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - hash_join - .swap_inputs(PartitionMode::Partitioned) - .map(Some)? - } else { - None - } - } - PartitionMode::PartitionedSpillable => { - println!("Using PartitionMode::PartitionedSpillable"); - // For partitioned spillable, use the same logic as regular partitioned - let left = hash_join.left(); - let right = hash_join.right(); - if hash_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - hash_join - .swap_inputs(PartitionMode::PartitionedSpillable) - .map(Some)? - } else { - None - } + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + hash_join + .swap_inputs(PartitionMode::Partitioned) + .map(Some)? + } else { + None } } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right)? { - cross_join.swap_inputs().map(Some)? - } else { - None - } - } else if let Some(nl_join) = plan.as_any().downcast_ref::() { - let left = nl_join.left(); - let right = nl_join.right(); - if nl_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - nl_join.swap_inputs().map(Some)? 
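// A hedged restatement of the choice partitioned_hash_join now makes when the
// planner falls back from Auto/CollectLeft, written as an illustrative free
// function rather than the optimizer's code: Grace requires the spillable flag,
// otherwise the join stays on HashJoinExec in the corresponding partition mode.
use datafusion_physical_plan::joins::PartitionMode;

enum ChosenJoin {
    Hash(PartitionMode),
    Grace,
}

fn choose_join(enable_spillable: bool, enable_grace: bool) -> ChosenJoin {
    match (enable_spillable, enable_grace) {
        (true, true) => ChosenJoin::Grace,
        (true, false) => ChosenJoin::Hash(PartitionMode::PartitionedSpillable),
        (false, _) => ChosenJoin::Hash(PartitionMode::Partitioned),
    }
}

fn main() {
    assert!(matches!(choose_join(true, true), ChosenJoin::Grace));
    assert!(matches!(
        choose_join(false, true),
        ChosenJoin::Hash(PartitionMode::Partitioned)
    ));
}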
- } else { - None + PartitionMode::PartitionedSpillable => { + println!("Using PartitionMode::PartitionedSpillable"); + // For partitioned spillable, use the same logic as regular partitioned + let left = hash_join.left(); + let right = hash_join.right(); + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + hash_join + .swap_inputs(PartitionMode::PartitionedSpillable) + .map(Some)? + } else { + None + } } + } + } else if let Some(cross_join) = plan.as_any().downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right)? { + cross_join.swap_inputs().map(Some)? } else { None - }; + } + } else if let Some(nl_join) = plan.as_any().downcast_ref::() { + let left = nl_join.left(); + let right = nl_join.right(); + if nl_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + nl_join.swap_inputs().map(Some)? + } else { + None + } + } else { + None + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs index 2c9482f93f892..1e85ae5c58775 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -20,9 +20,7 @@ use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; -use crate::joins::utils::{ - reorder_output_after_swap, swap_join_projection, OnceFut, -}; +use crate::joins::utils::{reorder_output_after_swap, swap_join_projection, OnceFut}; use crate::joins::{JoinOn, JoinOnRef, PartitionMode}; use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, @@ -33,12 +31,12 @@ use crate::{ common::can_project, joins::utils::{ build_join_schema, check_join_is_valid, estimate_join_statistics, - symmetric_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, + JoinFilter, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, - PlanProperties, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties, + SendableRecordBatchStream, Statistics, }; use crate::{ExecutionPlanProperties, SpillManager}; use std::fmt; @@ -52,8 +50,7 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - internal_err, plan_err, project_schema, JoinSide, JoinType, - NullEquality, Result, + internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, }; use datafusion_execution::TaskContext; use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; @@ -550,6 +547,7 @@ impl ExecutionPlan for GraceHashJoinExec { self: Arc, children: Vec>, ) -> Result> { + let new_partition_count = children[0].output_partitioning().partition_count(); Ok(Arc::new(GraceHashJoinExec { left: Arc::clone(&children[0]), right: Arc::clone(&children[1]), @@ -572,11 +570,12 @@ impl ExecutionPlan for GraceHashJoinExec { )?, // Keep the dynamic filter, bounds accumulator will be reset dynamic_filter: self.dynamic_filter.clone(), - accumulator: Arc::clone(&self.accumulator), + accumulator: GraceAccumulator::new(new_partition_count), })) } fn 
reset_state(self: Arc) -> Result> { + let partition_count = self.left.output_partitioning().partition_count(); Ok(Arc::new(GraceHashJoinExec { left: Arc::clone(&self.left), right: Arc::clone(&self.right), @@ -592,7 +591,7 @@ impl ExecutionPlan for GraceHashJoinExec { cache: self.cache.clone(), // Reset dynamic filter and bounds accumulator to initial state dynamic_filter: None, - accumulator: Arc::clone(&self.accumulator), + accumulator: GraceAccumulator::new(partition_count), })) } @@ -855,7 +854,6 @@ impl ExecutionPlan for GraceHashJoinExec { } } - #[allow(clippy::too_many_arguments)] pub async fn partition_and_spill( random_state: RandomState, diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 9177954564046..9d70ca3e1ac63 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -23,6 +23,5 @@ mod exec; mod partitioned; #[cfg(feature = "hybrid_hash_join_scheduler")] mod scheduler; -mod shared_bounds; pub mod shared_bounds; mod stream; diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 914fdc75537bb..52d084bad4d26 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -20,6 +20,7 @@ use arrow::array::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; use datafusion_physical_expr::PhysicalExprRef; +pub use grace_hash_join::GraceHashJoinExec; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 70d6caf7642bc..07b71ff159ca1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -732,6 +732,7 @@ message PhysicalPlanNode { GenerateSeriesNode generate_series = 33; SortMergeJoinExecNode sort_merge_join = 34; MemoryScanExecNode memory_scan = 35; + GraceHashJoinExecNode grace_hash_join = 36; } } @@ -1074,6 +1075,16 @@ message HashJoinExecNode { repeated uint32 projection = 9; } +message GraceHashJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + datafusion_common.JoinType join_type = 4; + datafusion_common.NullEquality null_equality = 5; + JoinFilter filter = 6; + repeated uint32 projection = 7; +} + enum StreamPartitionMode { SINGLE_PARTITION = 0; PARTITIONED_EXEC = 1; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 83f662e611120..4d7f8241c7100 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7717,6 +7717,208 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for GraceHashJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.null_equality != 0 { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + if !self.projection.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GraceHashJoinExecNode", len)?; + if 
let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + if !self.projection.is_empty() { + struct_ser.serialize_field("projection", &self.projection)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for GraceHashJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "on", + "join_type", + "joinType", + "null_equality", + "nullEquality", + "filter", + "projection", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + On, + JoinType, + NullEquality, + Filter, + Projection, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), + "filter" => Ok(GeneratedField::Filter), + "projection" => Ok(GeneratedField::Projection), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GraceHashJoinExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.GraceHashJoinExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut null_equality__ = None; + let mut filter__ = None; + let mut projection__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); + } + null_equality__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; + } + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + projection__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(GraceHashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), + filter: filter__, + projection: projection__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.GraceHashJoinExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for GroupingSetNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17022,6 +17224,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::MemoryScan(v) => { struct_ser.serialize_field("memoryScan", v)?; } + physical_plan_node::PhysicalPlanType::GraceHashJoin(v) => { + struct_ser.serialize_field("graceHashJoin", v)?; + } } } struct_ser.end() @@ -17087,6 +17292,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortMergeJoin", "memory_scan", "memoryScan", + "grace_hash_join", + "graceHashJoin", ]; #[allow(clippy::enum_variant_names)] @@ -17125,6 +17332,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { GenerateSeries, SortMergeJoin, MemoryScan, + GraceHashJoin, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17180,6 +17388,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "generateSeries" | "generate_series" => Ok(GeneratedField::GenerateSeries), "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), + "graceHashJoin" | "grace_hash_join" => Ok(GeneratedField::GraceHashJoin), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17438,6 +17647,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("memoryScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::MemoryScan) +; + } + GeneratedField::GraceHashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("graceHashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GraceHashJoin) ; } } diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs index cc19add6fbe9e..d5520c6843b69 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1055,7 +1055,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" )] pub physical_plan_type: ::core::option::Option, } @@ -1133,6 +1133,8 @@ pub mod physical_plan_node { SortMergeJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "35")] MemoryScan(super::MemoryScanExecNode), + #[prost(message, tag = "36")] + GraceHashJoin(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1636,6 +1638,23 @@ pub struct HashJoinExecNode { pub projection: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct GraceHashJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub on: ::prost::alloc::vec::Vec, + #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")] + pub join_type: i32, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "5")] + pub null_equality: i32, + #[prost(message, optional, tag = "6")] + pub filter: ::core::option::Option, + #[prost(uint32, repeated, tag = "7")] + pub projection: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct SymmetricHashJoinExecNode { #[prost(message, optional, boxed, tag = "1")] pub left: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index bfa9cf0425887..adcc2d2dffff3 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -77,8 +77,8 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion::physical_plan::joins::{ - CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode, - SymmetricHashJoinExec, + CrossJoinExec, GraceHashJoinExec, NestedLoopJoinExec, SortMergeJoinExec, + StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; @@ -214,6 +214,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { runtime, extension_codec, ), + PhysicalPlanType::GraceHashJoin(grace_hash_join) => self + .try_into_grace_hash_join_physical_plan( + grace_hash_join, + ctx, + runtime, + extension_codec, + ), PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, @@ -365,6 +372,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_grace_hash_join_exec( + exec, + extension_codec, + ); + } + if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec( 
exec, @@ -1266,6 +1280,116 @@ impl protobuf::PhysicalPlanNode { )?)) } + fn try_into_grace_hash_join_physical_plan( + &self, + grace_join: &protobuf::GraceHashJoinExecNode, + ctx: &SessionContext, + runtime: &RuntimeEnv, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let left = into_physical_plan(&grace_join.left, ctx, runtime, extension_codec)?; + let right = into_physical_plan(&grace_join.right, ctx, runtime, extension_codec)?; + let left_schema = left.schema(); + let right_schema = right.schema(); + let on = grace_join + .on + .iter() + .map(|col| { + let left_expr = parse_physical_expr( + col.left.as_ref().ok_or_else(|| { + proto_error("GraceHashJoinExecNode missing left expr") + })?, + ctx, + left_schema.as_ref(), + extension_codec, + )?; + let right_expr = parse_physical_expr( + col.right.as_ref().ok_or_else(|| { + proto_error("GraceHashJoinExecNode missing right expr") + })?, + ctx, + right_schema.as_ref(), + extension_codec, + )?; + Ok((left_expr, right_expr)) + }) + .collect::>()?; + let join_type = + protobuf::JoinType::try_from(grace_join.join_type).map_err(|_| { + proto_error(format!( + "Received a GraceHashJoinExecNode with unknown JoinType {}", + grace_join.join_type + )) + })?; + let null_equality = protobuf::NullEquality::try_from(grace_join.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a GraceHashJoinExecNode with unknown NullEquality {}", + grace_join.null_equality + )) + })?; + let filter = grace_join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + ctx, + &schema, + extension_codec, + )?; + let column_indices = f + .column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side).map_err(|_| { + proto_error(format!( + "Received a GraceHashJoinExecNode message with JoinSide in Filter {}", + i.side + )) + })?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>>()?; + + Ok(JoinFilter::new(expression, column_indices, Arc::new(schema))) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + let projection = if !grace_join.projection.is_empty() { + Some( + grace_join + .projection + .iter() + .map(|i| *i as usize) + .collect::>(), + ) + } else { + None + }; + + Ok(Arc::new(GraceHashJoinExec::try_new( + left, + right, + on, + filter, + &join_type.into(), + projection, + null_equality.into(), + )?)) + } + fn try_into_symmetric_hash_join_physical_plan( &self, sym_join: &protobuf::SymmetricHashJoinExecNode, @@ -2223,6 +2347,75 @@ impl protobuf::PhysicalPlanNode { }) } + fn try_from_grace_hash_join_exec( + exec: &GraceHashJoinExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + let on: Vec = exec + .on() + .iter() + .map(|tuple| { + let l = serialize_physical_expr(&tuple.0, extension_codec)?; + let r = serialize_physical_expr(&tuple.1, extension_codec)?; + Ok::<_, DataFusionError>(protobuf::JoinOn { + left: Some(l), + right: Some(r), + }) + }) + .collect::>()?; + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = 
exec.null_equality().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().as_ref().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), + }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::GraceHashJoin(Box::new( + protobuf::GraceHashJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + null_equality: null_equality.into(), + filter, + projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { + v.iter().map(|x| *x as u32).collect::>() + }), + }, + ))), + }) + } + fn try_from_symmetric_hash_join_exec( exec: &SymmetricHashJoinExec, extension_codec: &dyn PhysicalExtensionCodec, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index a5357a132eef2..c4fdaa78dba28 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -75,8 +75,8 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::{ - HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, - StreamJoinPartitionMode, SymmetricHashJoinExec, + GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, + SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; @@ -284,6 +284,41 @@ fn roundtrip_hash_join() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_grace_hash_join() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let on = vec![( + Arc::new(Column::new("col", schema_left.index_of("col")?)) as _, + Arc::new(Column::new("col", schema_right.index_of("col")?)) as _, + )]; + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + roundtrip_test(Arc::new(GraceHashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + None, + NullEquality::NullEqualsNothing, + )?))?; + } + Ok(()) +} + #[test] fn roundtrip_nested_loop_join() -> Result<()> { let field_a = Field::new("col", DataType::Int64, false); From 51654a4ab813c8a161d89474012ba93ad5aaffda Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 12 Nov 2025 21:07:02 +0200 Subject: [PATCH 28/36] Revert "Add flag for scheduler" This reverts commit 505b1cba12b98f12ef508e4d884b26c34a4a35b3. 
--- .../physical-plan/src/joins/hash_join/exec.rs | 367 +++------- .../src/joins/hash_join/partitioned.rs | 656 +++--------------- 2 files changed, 199 insertions(+), 824 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 56b7ae899a949..4bdd8e41e1f85 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -16,8 +16,6 @@ // under the License. use std::fmt; -use std::fs::File; -use std::io::BufReader; use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; @@ -35,7 +33,8 @@ use crate::joins::hash_join::stream::{ }; use crate::joins::join_hash_map::{JoinHashMapU32, JoinHashMapU64}; use crate::joins::utils::{ - reorder_output_after_swap, swap_join_projection, update_hash, OnceAsync, OnceFut, + asymmetric_join_output_partitioning, reorder_output_after_swap, swap_join_projection, + update_hash, OnceAsync, OnceFut, }; use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; use crate::projection::{ @@ -62,15 +61,14 @@ use crate::{ use arrow::array::{ArrayRef, BooleanBufferBuilder}; use arrow::compute::concat_batches; use arrow::datatypes::SchemaRef; -use arrow::ipc::reader::StreamReader; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, - JoinSide, JoinType, NullEquality, Result, + internal_datafusion_err, internal_err, plan_err, project_schema, JoinSide, JoinType, + NullEquality, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -85,7 +83,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{pin_mut, StreamExt}; +use futures::{executor::block_on, pin_mut, StreamExt}; use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. @@ -93,7 +91,7 @@ const HASH_JOIN_SEED: RandomState = RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); /// Maximum number of partitions allowed when recursively repartitioning during hybrid hash join. -pub(crate) const HYBRID_HASH_MAX_PARTITIONS: usize = 1 << 16; +const HYBRID_HASH_MAX_PARTITIONS: usize = 1 << 16; /// Upper bound multiplier applied to the initial partition fanout when searching for additional partitions. const HYBRID_HASH_PARTITION_GROWTH_FACTOR: usize = 16; /// Approximate number of probe batches worth of rows we target per partition when statistics are available. @@ -103,8 +101,6 @@ const HYBRID_HASH_MIN_BYTES_PER_PARTITION: usize = 8 * 1024 * 1024; /// Minimum number of rows per partition when statistics are available to avoid extreme fan-out. 
const HYBRID_HASH_MIN_ROWS_PER_PARTITION: usize = 1_024; -static NEXT_HHJ_STREAM_ID: AtomicUsize = AtomicUsize::new(0); - /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` @@ -136,7 +132,7 @@ pub(super) struct JoinLeftData { pub(super) enum OriginalBuildInput { InMemory(Arc>), Spilled { - _spill_manager: Arc, + spill_manager: Arc, spill_file: Arc, }, } @@ -221,20 +217,20 @@ impl JoinLeftData { } Ok(()) } - OriginalBuildInput::Spilled { spill_file, .. } => { - let file = File::open(spill_file.path())?; - let reader = BufReader::new(file); - // SAFETY: spill files are generated by DataFusion with validated schema/buffers. - let mut reader = unsafe { - StreamReader::try_new(reader, None)?.with_skip_validation(true) - }; - while let Some(batch) = reader.next() { - let batch = batch.map_err(DataFusionError::from)?; - if !f(batch)? { - break; + OriginalBuildInput::Spilled { + spill_manager, + spill_file, + } => { + let mut stream = + spill_manager.read_spill_as_stream_shared(Arc::clone(spill_file))?; + block_on(async { + while let Some(batch) = stream.next().await.transpose()? { + if !f(batch)? { + break; + } } - } - Ok(()) + Ok(()) + }) } } } @@ -674,7 +670,9 @@ impl HashJoinExec { )?; let mut output_partitioning = match mode { - PartitionMode::CollectLeft => Partitioning::UnknownPartitioning(1), + PartitionMode::CollectLeft => { + asymmetric_join_output_partitioning(left, right, &join_type)? + } PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), @@ -1007,23 +1005,7 @@ impl ExecutionPlan for HashJoinExec { if self.mode == PartitionMode::CollectLeft && left_partitions != 1 { return internal_err!( "Invalid HashJoinExec, the output partition count of the left child must be 1 in CollectLeft mode,\ - consider using CoalescePartitionsExec or the EnforceDistribution rule" - ); - } - - let enable_spillable_mode = context - .session_config() - .options() - .optimizer - .enable_spillable_hash_join; - let single_stream_mode = matches!(self.mode, PartitionMode::CollectLeft) - || (self.mode == PartitionMode::PartitionedSpillable && enable_spillable_mode); - - if single_stream_mode && partition > 0 { - return plan_err!( - "HashJoinExec in {:?} mode produces a single output partition but partition {partition} was requested. 
\ - Insert a RepartitionExec above the join when additional parallelism is required.", - self.mode + consider using CoalescePartitionsExec or the EnforceDistribution rule" ); } @@ -1075,7 +1057,13 @@ impl ExecutionPlan for HashJoinExec { ); } PartitionMode::PartitionedSpillable => { - if !enable_spillable_mode { + let enable_spillable = context + .session_config() + .options() + .optimizer + .enable_spillable_hash_join; + + if !enable_spillable { // Legacy fallback: behave like Partitioned let left_stream = self.left.execute(partition, Arc::clone(&context))?; @@ -1361,19 +1349,6 @@ impl ExecutionPlan for HashJoinExec { .register(context.memory_pool()); let probe_spill_metrics = SpillMetrics::new(&self.metrics, partition); let build_spill_metrics = SpillMetrics::new(&self.metrics, partition); - let stream_id = - NEXT_HHJ_STREAM_ID.fetch_add(1, Ordering::Relaxed); - #[cfg(feature = "hybrid_hash_join_scheduler")] - let probe_scheduler_enabled = std::env::var( - "DATAFUSION_HHJ_ENABLE_SCHEDULER", - ) - .map(|v| { - matches!( - v.to_ascii_lowercase().as_str(), - "1" | "true" | "yes" | "on" - ) - }) - .unwrap_or(false); let partitioned_stream = PartitionedHashJoinStream::new( partition, self.schema(), @@ -1399,9 +1374,6 @@ impl ExecutionPlan for HashJoinExec { probe_schema, self.right.output_ordering().is_some(), shared_bounds_accumulator, - stream_id, - #[cfg(feature = "hybrid_hash_join_scheduler")] - probe_scheduler_enabled, )?; return Ok(Box::pin(partitioned_stream)); } @@ -1850,21 +1822,67 @@ async fn collect_left_input( mut reservation, bounds_accumulators, } = state; + let batches_arc = Arc::new(batches); let (hashmap, single_batch, left_values, visited_indices_bitmap, original_input) = match collect_mode { CollectLeftMode::InMemory => { - let batches_arc = Arc::new(batches); - let (hashmap, single_batch, left_values, visited_indices_bitmap) = - build_join_data_from_batches( - batches_arc.as_slice(), - num_rows, - &schema, + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + let mut hashmap: Box = + if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; + + let mut hashes_buffer = Vec::new(); + let mut offset = 0; + let batches_iter = batches_arc.iter().rev(); + for batch in batches_iter.clone() { + hashes_buffer.clear(); + hashes_buffer.resize(batch.num_rows(), 0); + update_hash( &on_left, + batch, + &mut *hashmap, + offset, &random_state, - &mut reservation, - &metrics, - with_visited_indices_bitmap, + &mut hashes_buffer, + 0, + true, )?; + offset += batch.num_rows(); + } + let single_batch = concat_batches(&schema, batches_iter)?; + let visited_indices_bitmap = if with_visited_indices_bitmap { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = + BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let left_values = on_left + .iter() + 
.map(|c| { + c.evaluate(&single_batch)? + .into_array(single_batch.num_rows()) + }) + .collect::>>()?; + ( hashmap, single_batch, @@ -1887,40 +1905,26 @@ async fn collect_left_input( let mut writer = spill_writer.ok_or_else(|| { internal_datafusion_err!("missing build spill writer") })?; - let spill_file = Arc::new(writer.finish()?.ok_or_else(|| { + let spill_file = writer.finish()?.ok_or_else(|| { internal_datafusion_err!( "expected spill file when spilling build input" ) - })?); + })?; metrics.build_spill_count.add(1); metrics.build_spilled_rows.add(num_rows); metrics .build_spilled_bytes .add(spill_file.current_disk_usage() as usize); let _ = reservation.try_shrink(reservation.size()); - - let reloaded_batches = reload_spilled_batches(&spill_file)?; - let (hashmap, single_batch, left_values, visited_indices_bitmap) = - build_join_data_from_batches( - &reloaded_batches, - num_rows, - &schema, - &on_left, - &random_state, - &mut reservation, - &metrics, - with_visited_indices_bitmap, - )?; - drop(reloaded_batches); - ( - hashmap, - single_batch, - left_values, - visited_indices_bitmap, + Box::new(JoinHashMapU32::with_capacity(0)) + as Box, + RecordBatch::new_empty(schema.clone()), + Vec::new(), + BooleanBufferBuilder::new(0), OriginalBuildInput::Spilled { - _spill_manager: spill_manager, - spill_file, + spill_manager, + spill_file: Arc::new(spill_file), }, ) } @@ -1955,94 +1959,6 @@ async fn collect_left_input( Ok(data) } -fn build_join_data_from_batches( - batches: &[RecordBatch], - num_rows: usize, - schema: &SchemaRef, - on_left: &[PhysicalExprRef], - random_state: &RandomState, - reservation: &mut MemoryReservation, - metrics: &BuildProbeJoinMetrics, - with_visited_indices_bitmap: bool, -) -> Result<( - Box, - RecordBatch, - Vec, - BooleanBufferBuilder, -)> { - let fixed_size_u32 = size_of::(); - let fixed_size_u64 = size_of::(); - let mut hashmap: Box = if num_rows > u32::MAX as usize { - let estimated_hashtable_size = - estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU64::with_capacity(num_rows)) - } else { - let estimated_hashtable_size = - estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU32::with_capacity(num_rows)) - }; - - let mut hashes_buffer = Vec::new(); - let mut offset = 0usize; - let batches_iter = batches.iter().rev(); - for batch in batches_iter.clone() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( - on_left, - batch, - &mut *hashmap, - offset, - random_state, - &mut hashes_buffer, - 0, - true, - )?; - offset += batch.num_rows(); - } - let single_batch = concat_batches(schema, batches_iter)?; - let visited_indices_bitmap = if with_visited_indices_bitmap { - let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); - reservation.try_grow(bitmap_size)?; - metrics.build_mem_used.add(bitmap_size); - - let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); - bitmap_buffer.append_n(num_rows, false); - bitmap_buffer - } else { - BooleanBufferBuilder::new(0) - }; - - let left_values = on_left - .iter() - .map(|c| { - c.evaluate(&single_batch)? 
- .into_array(single_batch.num_rows()) - }) - .collect::>>()?; - - Ok((hashmap, single_batch, left_values, visited_indices_bitmap)) -} - -fn reload_spilled_batches( - spill_file: &Arc, -) -> Result> { - let file = File::open(spill_file.path())?; - let reader = BufReader::new(file); - // SAFETY: spill files are generated by DataFusion with validated schema/buffers. - let mut reader = - unsafe { StreamReader::try_new(reader, None)?.with_skip_validation(true) }; - let mut batches = Vec::new(); - while let Some(batch) = reader.next() { - batches.push(batch.map_err(DataFusionError::from)?); - } - Ok(batches) -} - #[cfg(test)] mod tests { use super::*; @@ -2050,9 +1966,8 @@ mod tests { use crate::joins::hash_join::stream::lookup_join_hashmap; use crate::test::{assert_join_metrics, TestMemoryExec}; use crate::{ - common, expressions::Column, metrics::ExecutionPlanMetricsSet, - repartition::RepartitionExec, stream::RecordBatchStreamAdapter, - test::build_table_i32, test::exec::MockExec, + common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, + test::exec::MockExec, }; use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array}; @@ -2063,15 +1978,13 @@ mod tests { use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, - DataFusionError, ScalarValue, + ScalarValue, }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; - use datafusion_execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::PhysicalExpr; - use futures::stream; use hashbrown::HashTable; use insta::{allow_duplicates, assert_snapshot}; use rstest::*; @@ -2100,74 +2013,6 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } - #[tokio::test] - async fn collect_left_input_spill_path_rebuilds_state() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ])); - let batch1 = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(Int32Array::from(vec![10, 20, 30])), - ], - )?; - let batch2 = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![4, 5])), - Arc::new(Int32Array::from(vec![40, 50])), - ], - )?; - - let stream = RecordBatchStreamAdapter::new( - Arc::clone(&schema), - stream::iter(vec![ - Ok::<_, DataFusionError>(batch1.clone()), - Ok::<_, DataFusionError>(batch2.clone()), - ]), - ); - let left_stream: SendableRecordBatchStream = Box::pin(stream); - - let metrics_set = ExecutionPlanMetricsSet::new(); - let join_metrics = BuildProbeJoinMetrics::new(0, &metrics_set); - let pool: Arc = Arc::new(UnboundedMemoryPool::default()); - let reservation = MemoryConsumer::new("collect_left_input_spill").register(&pool); - - let spill_manager = Arc::new(SpillManager::new( - Arc::new(RuntimeEnv::default()), - SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0), - Arc::clone(&schema), - )); - - let join_data = collect_left_input( - RandomState::new(), - left_stream, - vec![Arc::new(Column::new("a", 0))], - join_metrics, - reservation, - true, - 1, - false, - CollectLeftMode::SpillToDisk { spill_manager }, - ) - 
.await?; - - let expected_rows = batch1.num_rows() + batch2.num_rows(); - assert_eq!(join_data.total_rows(), expected_rows); - assert_eq!(join_data.batch().num_rows(), expected_rows); - assert!(!join_data.hash_map().is_empty()); - - let mut counted_rows = 0usize; - join_data.for_each_original_batch(|batch| { - counted_rows += batch.num_rows(); - Ok(true) - })?; - assert_eq!(counted_rows, expected_rows); - Ok(()) - } - fn join( left: Arc, right: Arc, @@ -2312,22 +2157,10 @@ mod tests { null_equality, )?; - let spillable_enabled = context - .session_config() - .options() - .optimizer - .enable_spillable_hash_join; - let single_stream_output = matches!(partition_mode, PartitionMode::CollectLeft) - || (partition_mode == PartitionMode::PartitionedSpillable && spillable_enabled); - let requested_partitions = if single_stream_output { - 1 - } else { - partition_count - }; let columns = columns(&join.schema()); let mut batches = vec![]; - for i in 0..requested_partitions { + for i in 0..partition_count { let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 4082ec2d87c7d..dcdf79c44c3f9 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -83,7 +83,6 @@ use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; use futures::{executor::block_on, ready, Stream, StreamExt}; -use crate::joins::hash_join::exec::HYBRID_HASH_MAX_PARTITIONS; const HYBRID_HASH_MAX_REPARTITION_DEPTH: usize = 6; const HYBRID_HASH_MIN_FANOUT: usize = 2; @@ -376,7 +375,6 @@ pub(super) struct PartitionedHashJoinStream { pub num_partitions: usize, /// Maximum partition fanout allowed when recursively repartitioning pub max_partition_count: usize, - pub mem_limit_partitions: usize, /// Memory threshold for spilling (in bytes) pub memory_threshold: usize, @@ -465,13 +463,6 @@ pub(super) struct PartitionedHashJoinStream { pub partition_pending: Vec, /// Latest descriptor metadata per partition pub partition_descriptors: Vec>, - /// Whether the async probe scheduler is enabled - #[cfg(feature = "hybrid_hash_join_scheduler")] - pub probe_scheduler_enabled: bool, - /// Whether this stream has already completed (prevents restart) - pub stream_completed: bool, - /// Unique identifier for debugging/logging - pub stream_id: usize, } #[cfg(feature = "hybrid_hash_join_scheduler")] @@ -484,13 +475,6 @@ enum ProbeTaskStatus { } impl PartitionedHashJoinStream { - #[inline] - fn log_stream String>(&self, builder: F) { - if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() { - let stream_id = self.stream_id; - hhj_debug(|| format!("[stream={stream_id}] {}", builder())); - } - } /// Compute partition id for a given hash using radix mask when possible #[inline] fn partition_for_hash(&self, hash: u64) -> usize { @@ -578,9 +562,7 @@ impl PartitionedHashJoinStream { self.pending_partitions.push_back(desc.clone()); self.partition_pending[part_id] = true; #[cfg(feature = "hybrid_hash_join_scheduler")] - if self.probe_scheduler_enabled { - self.schedule_probe_task(&desc); - } + self.schedule_probe_task(&desc); } Ok(()) @@ -600,9 +582,6 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn ensure_probe_scheduler_capacity(&mut self, part_id: usize) { - if !self.probe_scheduler_enabled { - return; - } if 
self.probe_scheduler_inflight.len() <= part_id { self.probe_scheduler_inflight.resize(part_id + 1, false); } @@ -610,9 +589,6 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn schedule_probe_task(&mut self, descriptor: &PartitionDescriptor) { - if !self.probe_scheduler_enabled { - return; - } let part_id = descriptor.build_index; self.ensure_probe_scheduler_capacity(part_id); if self.probe_scheduler_inflight[part_id] { @@ -637,16 +613,6 @@ impl PartitionedHashJoinStream { state.spill_files.push_back(file); } self.schedule_partition(part_id)?; - hhj_debug(|| { - format!( - "finalize_spilled_partition enqueued partition {} (spill_files={})", - part_id, - self.probe_states - .get(part_id) - .map(|s| s.spill_files.len()) - .unwrap_or(0) - ) - }); return Ok(true); } Ok(false) @@ -662,7 +628,7 @@ impl PartitionedHashJoinStream { return Ok(()); } - let (mut buffered, queue_ready, stream_active, mut writer_opt) = { + let (buffered, queue_ready, stream_active, mut writer_opt) = { let state = self .probe_states .get_mut(part_id) @@ -687,16 +653,14 @@ impl PartitionedHashJoinStream { .create_in_progress_file("hash_join_probe_partition")?, ); self.join_metrics.probe_spill_count.add(1); - hhj_debug(|| { - format!("flush_probe_buffer_to_spill opening writer for part {part_id}") - }); } - let mut writer = writer_opt - .ok_or_else(|| internal_datafusion_err!("missing probe spill writer"))?; + let mut writer = writer_opt.ok_or_else(|| { + internal_datafusion_err!("expected probe spill writer for partition") + })?; let mut spilled_rows = 0usize; - for batch in buffered.batches.drain(..) { + for batch in buffered.batches { let batch_size = batch.get_array_memory_size(); writer.append_batch(&batch)?; self.join_metrics.probe_spilled_rows.add(batch.num_rows()); @@ -704,32 +668,17 @@ impl PartitionedHashJoinStream { spilled_rows = spilled_rows.saturating_add(batch.num_rows()); } - let spill_file = writer.finish()?.ok_or_else(|| { - internal_datafusion_err!("expected probe spill file after flush") - })?; - hhj_debug(|| { - format!( - "flush_probe_buffer_to_spill part {} flushed_rows={} queue_ready={} stream_active={}", - part_id, spilled_rows, queue_ready, stream_active - ) - }); - - let state = self - .probe_states - .get_mut(part_id) - .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - state.spill_files.push_back(spill_file); - state.spill_in_progress = None; - state.spilled_rows = state.spilled_rows.saturating_add(spilled_rows); + { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.spill_in_progress = Some(writer); + state.spilled_rows = state.spilled_rows.saturating_add(spilled_rows); + } if !queue_ready && !stream_active { - hhj_debug(|| { - format!( - "flush_probe_buffer_to_spill scheduling partition {} after flush (queue empty)", - part_id - ) - }); - self.schedule_partition(part_id)?; + self.finalize_spilled_partition(part_id)?; } Ok(()) } @@ -764,9 +713,7 @@ impl PartitionedHashJoinStream { .saturating_mul(HYBRID_HASH_ROWS_PER_PARTITION_TARGET_MULTIPLIER) .max(HYBRID_HASH_ROWS_PER_PARTITION_MIN); - let spilled_any = descriptor.spilled_bytes > 0; - let should_repartition_bytes = - spilled_any || descriptor.spilled_bytes > per_partition_budget; + let should_repartition_bytes = descriptor.spilled_bytes > per_partition_budget; let should_repartition_rows = descriptor.spilled_rows > rows_budget; if !should_repartition_bytes && !should_repartition_rows { @@ -774,12 
+721,6 @@ impl PartitionedHashJoinStream { } let mut required = HYBRID_HASH_MIN_FANOUT; - hhj_debug(|| { - format!( - "compute_recursive_fanout part={} spilled_bytes={} spilled_rows={} budget_bytes={} rows_budget={}", - descriptor.build_index, descriptor.spilled_bytes, descriptor.spilled_rows, per_partition_budget, rows_budget - ) - }); if should_repartition_bytes { let budget = per_partition_budget.max(1); @@ -808,12 +749,6 @@ impl PartitionedHashJoinStream { if additional_bits == 0 { return None; } - hhj_debug(|| { - format!( - "compute_recursive_fanout part={} -> additional_bits={} fanout={}", - descriptor.build_index, additional_bits, fanout - ) - }); Some((additional_bits, fanout)) } @@ -823,12 +758,6 @@ impl PartitionedHashJoinStream { additional_bits: usize, fanout: usize, ) -> Result> { - hhj_debug(|| { - format!( - "repartition_spilled_partition start part={} gen={} additional_bits={} fanout={}", - descriptor.build_index, descriptor.generation, additional_bits, fanout - ) - }); let build_index = descriptor.build_index; if build_index >= self.build_partitions.len() { return Ok(vec![]); @@ -883,12 +812,6 @@ impl PartitionedHashJoinStream { let mut new_descriptor = descriptor.clone(); new_descriptor.spilled_bytes = 0; new_descriptor.spilled_rows = 0; - hhj_debug(|| { - format!( - "repartition_spilled_partition part={} had empty spill batches; marking empty descriptor", - build_index - ) - }); self.matched_build_rows_per_partition[build_index] = BooleanBufferBuilder::new(0); self.build_partitions[build_index] = BuildPartition::Empty; @@ -1104,17 +1027,11 @@ impl PartitionedHashJoinStream { let probe_partition_budget = per_partition_budget_bytes(self.memory_threshold, self.num_partitions); - let mut spill_files = { + let spill_file = { let state = self .probe_states .get_mut(parent_index) .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - // Ensure any in-progress writer contributes its rows - if let Some(mut writer) = state.spill_in_progress.take() { - if let Some(file) = writer.finish()? { - state.spill_files.push_back(file); - } - } state.batch_position = 0; state.buffered_rows = 0; state.buffered_bytes = 0; @@ -1126,10 +1043,10 @@ impl PartitionedHashJoinStream { state.active_offset = (0, None); state.joined_probe_idx = None; state.pending_stream = None; - mem::take(&mut state.spill_files) + state.spill_files.pop_front() }; - while let Some(file) = spill_files.pop_front() { + if let Some(file) = spill_file { let mut writers = Vec::with_capacity(fanout); for _ in 0..fanout { let writer = self @@ -1214,6 +1131,7 @@ impl PartitionedHashJoinStream { state.active_offset = (0, None); state.joined_probe_idx = None; } + return Ok(()); } // In-memory probe data @@ -1455,67 +1373,27 @@ impl PartitionedHashJoinStream { descriptor: &PartitionDescriptor, ) -> Result { if descriptor.build_index >= self.build_partitions.len() { - hhj_debug(|| { - format!( - "maybe_recursive_repartition skipping part {}: build index out of range", - descriptor.build_index - ) - }); return Ok(false); } match self.build_partitions.get(descriptor.build_index) { Some(BuildPartition::Spilled { .. 
}) => {} - _ => { - hhj_debug(|| { - format!( - "maybe_recursive_repartition skipping part {}: not spilled", - descriptor.build_index - ) - }); - return Ok(false); - } + _ => return Ok(false), } let Some((additional_bits, fanout)) = self.compute_recursive_fanout(descriptor) else { - hhj_debug(|| { - format!( - "maybe_recursive_repartition skipping part {}: compute_recursive_fanout returned None", - descriptor.build_index - ) - }); return Ok(false); }; let new_descriptors = self.repartition_spilled_partition(descriptor, additional_bits, fanout)?; if new_descriptors.is_empty() { - hhj_debug(|| { - format!( - "maybe_recursive_repartition part {} produced no new descriptors", - descriptor.build_index - ) - }); return Ok(false); } // Enqueue new descriptors in order for desc in new_descriptors.into_iter().rev() { #[cfg(feature = "hybrid_hash_join_scheduler")] - if self.probe_scheduler_enabled { - self.schedule_probe_task(&desc); - } - hhj_debug(|| { - format!( - "maybe_recursive_repartition enqueued new part {} (root={} gen={})", - desc.build_index, desc.root_index, desc.generation - ) - }); + self.schedule_probe_task(&desc); self.pending_partitions.push_front(desc); } - hhj_debug(|| { - format!( - "maybe_recursive_repartition succeeded for original part {}", - descriptor.build_index - ) - }); Ok(true) } @@ -1643,8 +1521,7 @@ impl PartitionedHashJoinStream { self.state = PartitionedHashJoinState::PartitionBuildSide; } - fn next_partition_count(&mut self) -> Option { - self.ensure_partition_headroom(); + fn next_partition_count(&self) -> Option { if self.num_partitions >= self.max_partition_count { return None; } @@ -1663,32 +1540,6 @@ impl PartitionedHashJoinStream { } } - fn ensure_partition_headroom(&mut self) { - if self.num_partitions < self.max_partition_count { - return; - } - if self.max_partition_count >= self.mem_limit_partitions { - return; - } - let mut new_max = self.max_partition_count.saturating_mul(2); - if new_max <= self.max_partition_count { - new_max = self.max_partition_count.saturating_add(1); - } - new_max = new_max - .max(HYBRID_HASH_MIN_FANOUT) - .min(self.mem_limit_partitions); - if new_max > self.max_partition_count { - let old = self.max_partition_count; - self.max_partition_count = new_max; - self.log_stream(|| { - format!( - "expanded max_partition_count from {} to {} (mem_limit={})", - old, new_max, self.mem_limit_partitions - ) - }); - } - } - fn repartition_worthwhile(&self, max_spilled_bytes: usize) -> bool { let partitions = self.num_partitions.max(1); let per_partition_budget = self.memory_threshold / partitions; @@ -1700,12 +1551,10 @@ impl PartitionedHashJoinStream { max_spilled_bytes > cutoff } - fn prepare_partition_queue(&mut self) -> Result<()> { + fn prepare_partition_queue(&mut self) { self.pending_partitions.clear(); let radix_bits = self.num_partitions.next_power_of_two().trailing_zeros() as usize; - - let mut descriptors: VecDeque = VecDeque::new(); for part_id in 0..self.build_partitions.len() { let (spilled_bytes, spilled_rows) = match &self.build_partitions[part_id] { BuildPartition::Spilled { @@ -1715,7 +1564,13 @@ impl PartitionedHashJoinStream { } => (*spilled_bytes, *spilled_rows), _ => (0, 0), }; - descriptors.push_back(PartitionDescriptor { + if self.partition_descriptors.len() <= part_id { + self.partition_descriptors.resize_with(part_id + 1, || None); + } + if self.partition_pending.len() <= part_id { + self.partition_pending.resize(part_id + 1, false); + } + self.pending_partitions.push_back(PartitionDescriptor { build_index: part_id, 
root_index: part_id, generation: self.partition_pass, @@ -1724,55 +1579,13 @@ impl PartitionedHashJoinStream { spilled_bytes, spilled_rows, }); - } - - while let Some(descriptor) = descriptors.pop_front() { - let build_index = descriptor.build_index; - let needs_recursive = matches!( - self.build_partitions.get(build_index), - Some(BuildPartition::Spilled { .. }) - ); - - if needs_recursive { - if let Some((additional_bits, fanout)) = - self.compute_recursive_fanout(&descriptor) - { - hhj_debug(|| { - format!( - "prepare_partition_queue repartitioning part {} fanout={}", - build_index, fanout - ) - }); - let new_descriptors = self.repartition_spilled_partition( - &descriptor, - additional_bits, - fanout, - )?; - for desc in new_descriptors.into_iter().rev() { - descriptors.push_front(desc); - } - continue; - } - } - - let build_index = descriptor.build_index; - if self.partition_descriptors.len() <= build_index { - self.partition_descriptors - .resize_with(build_index + 1, || None); - } - if self.partition_pending.len() <= build_index { - self.partition_pending.resize(build_index + 1, false); - } - - self.pending_partitions.push_back(descriptor.clone()); - self.partition_descriptors[build_index] = Some(descriptor.clone()); - self.partition_pending[build_index] = true; - #[cfg(feature = "hybrid_hash_join_scheduler")] - if self.probe_scheduler_enabled { - self.schedule_probe_task(&descriptor); + if let Some(desc) = self.pending_partitions.back().cloned() { + self.partition_descriptors[part_id] = Some(desc.clone()); + self.partition_pending[part_id] = true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&desc); } } - Ok(()) } fn transition_to_next_partition(&mut self) { @@ -1796,9 +1609,7 @@ impl PartitionedHashJoinStream { self.current_partition = None; #[cfg(feature = "hybrid_hash_join_scheduler")] { - if self.probe_scheduler_enabled - && !self.probe_scheduler_waiting_for_stream.is_empty() - { + if !self.probe_scheduler_waiting_for_stream.is_empty() { hhj_debug(|| { "transition_to_next_partition -> WaitingForProbe".to_string() }); @@ -2013,9 +1824,6 @@ impl PartitionedHashJoinStream { bounds_accumulator: Option< Arc, >, - stream_id: usize, - #[cfg(feature = "hybrid_hash_join_scheduler")] - probe_scheduler_enabled: bool, ) -> Result { let probe_spill_manager = SpillManager::new( runtime_env.clone(), @@ -2062,7 +1870,6 @@ impl PartitionedHashJoinStream { batch_size, num_partitions, max_partition_count, - mem_limit_partitions: mem_limit, memory_threshold, state: PartitionedHashJoinState::PartitionBuildSide, build_partitions: Vec::new(), @@ -2111,10 +1918,6 @@ impl PartitionedHashJoinStream { filter_debug_once_per_part: vec![false; num_partitions], partition_pending: vec![false; num_partitions], partition_descriptors: (0..num_partitions).map(|_| None).collect(), - #[cfg(feature = "hybrid_hash_join_scheduler")] - probe_scheduler_enabled, - stream_completed: false, - stream_id, }) } @@ -2152,15 +1955,12 @@ impl PartitionedHashJoinStream { let build_total_size = build_data.total_input_size(); if build_total_size <= self.memory_threshold { self.num_partitions = 1; - self.max_partition_count = self - .mem_limit_partitions - .max(HYBRID_HASH_MIN_FANOUT) - .min(HYBRID_HASH_MAX_PARTITIONS); + self.max_partition_count = 1; } let mut allow_repartition = !self.partition_pass_output_started; loop { - self.log_stream(|| { + hhj_debug(|| { format!( "partition_build_side pass={} num_partitions={} allow_repartition={}", self.partition_pass, self.num_partitions, allow_repartition @@ 
-2170,7 +1970,7 @@ impl PartitionedHashJoinStream { match self.try_partition_build_side(&build_data, allow_repartition)? { PartitionBuildStatus::Ready(result) => { - self.log_stream(|| { + hhj_debug(|| { format!( "partition_build_side pass {} completed (num_partitions={})", self.partition_pass, self.num_partitions @@ -2179,7 +1979,7 @@ impl PartitionedHashJoinStream { return Ok(result); } PartitionBuildStatus::NeedMorePartitions { next_count } => { - self.log_stream(|| { + hhj_debug(|| { format!( "partition_build_side requesting repartition to {} (current={})", next_count, self.num_partitions @@ -2189,7 +1989,7 @@ impl PartitionedHashJoinStream { || next_count == 0 || next_count > self.max_partition_count { - self.log_stream(|| { + hhj_debug(|| { format!( "repartition request invalid (max={} current={}); forcing spill", self.max_partition_count, self.num_partitions @@ -2213,12 +2013,6 @@ impl PartitionedHashJoinStream { build_data: &Arc, allow_repartition: bool, ) -> Result { - hhj_debug(|| { - format!( - "try_partition_build_side start allow_repartition={} num_partitions={}", - allow_repartition, self.num_partitions - ) - }); self.build_partitions = Vec::with_capacity(self.num_partitions); self.matched_build_rows_per_partition = Vec::with_capacity(self.num_partitions); @@ -2230,13 +2024,6 @@ impl PartitionedHashJoinStream { let mut any_spilled = false; build_data.for_each_original_batch(|batch| { - hhj_debug(|| { - format!( - "partition_build_side processing source batch rows={} num_partitions={}", - batch.num_rows(), - self.num_partitions - ) - }); let mut keys_values: Vec = Vec::with_capacity(self.on_left.len()); for expr in &self.on_left { keys_values.push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); @@ -2298,19 +2085,6 @@ impl PartitionedHashJoinStream { repartition_request = Some(next_count); return Ok(false); } - } else if repartition_request.is_none() { - if let Some(next_count) = - self.next_partition_count() - { - hhj_debug(|| { - format!( - "partition {} spilled during partitioning -> requesting repartition to {}", - build_index, next_count - ) - }); - repartition_request = Some(next_count); - return Ok(false); - } } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -2318,7 +2092,6 @@ impl PartitionedHashJoinStream { "Insufficient memory for build partitioning and spilling is disabled" )); } - any_spilled = true; self.spill_partition(build_index, accum)?; } } @@ -2337,17 +2110,6 @@ impl PartitionedHashJoinStream { repartition_request = Some(next_count); return Ok(false); } - } else if repartition_request.is_none() { - if let Some(next_count) = self.next_partition_count() { - hhj_debug(|| { - format!( - "partition {} spilled during partitioning -> requesting repartition to {}", - build_index, next_count - ) - }); - repartition_request = Some(next_count); - return Ok(false); - } } } if !self.runtime_env.disk_manager.tmp_files_enabled() { @@ -2356,7 +2118,6 @@ impl PartitionedHashJoinStream { )); } self.spill_partition(build_index, accum)?; - any_spilled = true; self.append_spilled_batch(accum, filtered_batch)?; } } @@ -2373,12 +2134,14 @@ impl PartitionedHashJoinStream { } })?; - self.log_stream(|| { - format!( - "partition_build_side processed all batches repartition_request={:?} any_spilled={} max_spilled_bytes={}", - repartition_request, any_spilled, max_spilled_bytes - ) - }); + if let Some(next_count) = repartition_request { + hhj_debug(|| { + format!( + "try_partition_build_side early repartition request next_count={next_count}" + ) + }); + return 
Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + } self.build_partitions.reserve(self.num_partitions); self.matched_build_rows_per_partition @@ -2389,7 +2152,6 @@ impl PartitionedHashJoinStream { max_spilled_bytes = max_spilled_bytes.max(accum.spilled_bytes); if accum.spill_writer.is_some() { if !accum.buffered_batches.is_empty() { - any_spilled = true; self.spill_partition(part_id, &mut accum)?; } if let Some(mut writer) = accum.spill_writer.take() { @@ -2403,23 +2165,6 @@ impl PartitionedHashJoinStream { .with_can_spill(true) .register(&self.runtime_env.memory_pool); any_spilled = true; - if allow_repartition && repartition_request.is_none() { - if let Some(next_count) = self.next_partition_count() { - self.log_stream(|| { - format!( - "partition {} finalized spill -> requesting repartition to {}", - part_id, next_count - ) - }); - repartition_request = Some(next_count); - } - } - self.log_stream(|| { - format!( - "partition_build_side finalized spill for part {} bytes={} rows={}", - part_id, accum.spilled_bytes, accum.total_rows - ) - }); self.build_partitions.push(BuildPartition::Spilled { spill_file: Some(spill_file), reservation, @@ -2510,26 +2255,27 @@ impl PartitionedHashJoinStream { values: partition_values, reservation, }); - hhj_debug(|| { - format!( - "partition_build_side kept partition {} in-memory (rows={}, approx_bytes={})", - part_id, - num_rows, - approx_partition_size - ) - }); } - if let Some(next_count) = repartition_request { - hhj_debug(|| { - format!( - "try_partition_build_side early repartition request next_count={next_count}" - ) - }); - return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + if allow_repartition + && (max_spilled_bytes > self.memory_threshold || any_spilled) + && self.repartition_worthwhile(max_spilled_bytes) + { + if let Some(next_count) = self.next_partition_count() { + hhj_debug(|| { + format!( + "try_partition_build_side repartition due to spill (max_spilled_bytes={} threshold={} any_spilled={}) next_count={}", + max_spilled_bytes, + self.memory_threshold, + any_spilled, + next_count + ) + }); + return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + } } - self.prepare_partition_queue()?; + self.prepare_partition_queue(); self.partition_pass_output_started = true; self.transition_to_next_partition(); @@ -2604,13 +2350,6 @@ impl PartitionedHashJoinStream { // no-op } } - - if build_index < self.partition_descriptors.len() { - self.partition_descriptors[build_index] = None; - } - if build_index < self.partition_pending.len() { - self.partition_pending[build_index] = false; - } } fn partition_has_pending_probe(&self, part_id: usize) -> bool { @@ -2668,9 +2407,6 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn try_acquire_probe_stream_slot(&mut self) -> bool { - if !self.probe_scheduler_enabled { - return true; - } if self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { self.probe_scheduler_active_streams += 1; true @@ -2681,9 +2417,6 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn release_probe_stream_slot(&mut self) { - if !self.probe_scheduler_enabled { - return; - } if self.probe_scheduler_active_streams > 0 { self.probe_scheduler_active_streams -= 1; } @@ -2692,9 +2425,6 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn enqueue_stream_waiter(&mut self, part_id: usize) { - if !self.probe_scheduler_enabled { - return; - } if part_id >= self.partition_pending.len() { 
return; } @@ -2710,9 +2440,6 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] fn wake_stream_waiter(&mut self) { - if !self.probe_scheduler_enabled { - return; - } while self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { if let Some(next_part) = self.probe_scheduler_waiting_for_stream.pop_front() { hhj_debug(|| format!("wake_stream_waiter considering part {next_part}")); @@ -2755,9 +2482,6 @@ impl PartitionedHashJoinStream { cx: &mut Context<'_>, descriptor: &PartitionDescriptor, ) -> Result { - if !self.probe_scheduler_enabled { - return Ok(ProbeTaskStatus::Finished); - } let part_id = descriptor.build_index; self.schedule_probe_task(descriptor); hhj_debug(|| { @@ -2890,12 +2614,6 @@ impl PartitionedHashJoinStream { }; if !has_spilled_probe { - hhj_debug(|| { - format!( - "poll_probe_data_for_partition part {} -> Finished (no spill state)", - part_id - ) - }); return Ok(ProbeDataPoll::Finished); } @@ -2928,14 +2646,6 @@ impl PartitionedHashJoinStream { internal_datafusion_err!("missing partition") })?; state.spill_files.push_front(file); - hhj_debug(|| { - format!( - "poll_probe_data_for_partition part {} NeedStream (active_streams={} files_pending={})", - part_id, - self.probe_scheduler_active_streams, - state.spill_files.len() - ) - }); return Ok(ProbeDataPoll::NeedStream); } let stream = self.probe_spill_manager.read_spill_as_stream(file)?; @@ -2950,17 +2660,8 @@ impl PartitionedHashJoinStream { state.spill_in_progress.is_some() }; if self.probe_stream_finished && !writer_open { - hhj_debug(|| { - format!("poll_probe_data_for_partition part {} -> Finished (all files drained)", part_id) - }); return Ok(ProbeDataPoll::Finished); } else { - hhj_debug(|| { - format!( - "poll_probe_data_for_partition part {} pending writer_open={} finished={}", - part_id, writer_open, self.probe_stream_finished - ) - }); return Ok(ProbeDataPoll::Pending); } } @@ -2992,15 +2693,6 @@ impl PartitionedHashJoinStream { state.consumed_rows = state.consumed_rows.saturating_add(b.num_rows()); } - hhj_debug(|| { - format!( - "poll_probe_data_for_partition part {} -> Ready (batch_rows={}, files_pending={}, pending_stream={})", - part_id, - state.active_batch.as_ref().map(|b| b.num_rows()).unwrap_or(0), - state.spill_files.len(), - state.pending_stream.is_some() - ) - }); return Ok(ProbeDataPoll::Ready); } Some(Poll::Ready(Some(Err(e)))) => return Err(e), @@ -3013,31 +2705,9 @@ impl PartitionedHashJoinStream { state.pending_stream = None; } self.release_probe_stream_slot(); - hhj_debug(|| { - format!( - "poll_probe_data_for_partition part {} drained stream (files_left={})", - part_id, - self.probe_states - .get(part_id) - .map(|s| s.spill_files.len()) - .unwrap_or(0) - ) - }); continue; } - Some(Poll::Pending) | None => { - hhj_debug(|| { - format!( - "poll_probe_data_for_partition part {} pending on stream (pending_stream={})", - part_id, - self.probe_states - .get(part_id) - .map(|s| s.pending_stream.is_some()) - .unwrap_or(false) - ) - }); - return Ok(ProbeDataPoll::Pending); - } + Some(Poll::Pending) | None => return Ok(ProbeDataPoll::Pending), } } } @@ -3117,138 +2787,36 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] { - if self.probe_scheduler_enabled { - if !has_active_batch { - match self.poll_probe_stage_task(cx, &partition_state.descriptor)? 
{ - ProbeTaskStatus::Ready => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Ready") - }); - has_active_batch = true; - } - ProbeTaskStatus::Pending => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Pending") - }); - return Poll::Pending; - } - ProbeTaskStatus::WaitingForStream => { - hhj_debug(|| { - format!( - "process_partition part {build_index} -> WaitingForStream" - ) - }); - self.enqueue_stream_waiter(build_index); - self.current_partition = None; - self.transition_to_next_partition(); - return Poll::Ready(Ok(StatefulStreamResult::Continue)); - } - ProbeTaskStatus::Finished => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Finished") - }); - if self.probe_scheduler_inflight.len() > build_index { - self.probe_scheduler_inflight[build_index] = false; - } - self.release_partition_resources(build_index); - self.advance_to_next_partition(); - return Poll::Ready(Ok(StatefulStreamResult::Continue)); - } - } - } - } else { - if !has_active_batch { - if self.take_buffered_probe_batch(build_index)?.is_some() { + if !has_active_batch { + match self.poll_probe_stage_task(cx, &partition_state.descriptor)? { + ProbeTaskStatus::Ready => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Ready") + }); has_active_batch = true; } - } - - if !has_active_batch { - let has_spilled_probe = match self.probe_state(build_index) { - Ok(state) => { - state.spill_in_progress.is_some() - || !state.spill_files.is_empty() - || state.pending_stream.is_some() - } - Err(e) => return Poll::Ready(Err(e)), - }; - - if has_spilled_probe { - loop { - let needs_stream = match self.probe_state(build_index) { - Ok(state) => state.pending_stream.is_none(), - Err(e) => return Poll::Ready(Err(e)), - }; - - if needs_stream { - let mut next_file = match self.probe_state_mut(build_index) { - Ok(state) => state.spill_files.pop_front(), - Err(e) => return Poll::Ready(Err(e)), - }; - if next_file.is_none() - && self.finalize_spilled_partition(build_index)? - { - next_file = match self.probe_state_mut(build_index) { - Ok(state) => state.spill_files.pop_front(), - Err(e) => return Poll::Ready(Err(e)), - }; - } - if let Some(file) = next_file { - let stream = self - .probe_spill_manager - .read_spill_as_stream(file)?; - match self.probe_state_mut(build_index) { - Ok(state) => state.pending_stream = Some(stream), - Err(e) => return Poll::Ready(Err(e)), - } - } else { - let should_release = match self.probe_state(build_index) { - Ok(state) => { - state.buffered.batches.is_empty() - && state.pending_stream.is_none() - } - Err(e) => return Poll::Ready(Err(e)), - }; - if should_release { - self.release_partition_resources(build_index); - self.advance_to_next_partition(); - return Poll::Ready(Ok( - StatefulStreamResult::Continue - )); - } - break; - } - } else { - match self.poll_probe_data_for_partition( - build_index, - cx, - )? 
{ - ProbeDataPoll::Ready => { - has_active_batch = true; - break; - } - ProbeDataPoll::Pending => return Poll::Pending, - ProbeDataPoll::NeedStream - | ProbeDataPoll::Finished => { - self.release_partition_resources(build_index); - self.advance_to_next_partition(); - return Poll::Ready(Ok( - StatefulStreamResult::Continue - )); - } - } - } - } - } else if self.probe_stream_finished { + ProbeTaskStatus::Pending => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Pending") + }); + return Poll::Pending; + } + ProbeTaskStatus::WaitingForStream => { + hhj_debug(|| { + format!("process_partition part {build_index} -> WaitingForStream") + }); + self.enqueue_stream_waiter(build_index); + self.current_partition = None; + self.transition_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + ProbeTaskStatus::Finished => { + hhj_debug(|| { + format!("process_partition part {build_index} -> Finished") + }); self.release_partition_resources(build_index); self.advance_to_next_partition(); return Poll::Ready(Ok(StatefulStreamResult::Continue)); - } else { - match self.buffer_probe_side(cx) { - Poll::Ready(Ok(())) => {} - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - } } } } @@ -3955,13 +3523,12 @@ impl PartitionedHashJoinStream { )? }; + let emitted_rows = result.num_rows(); + self.emitted_rows_per_part[build_index] = + self.emitted_rows_per_part[build_index].saturating_add(emitted_rows); (result, build_ids_to_mark, next_offset, next_joined_idx) }; - let emitted_rows = result.num_rows(); - self.emitted_rows_per_part[build_index] = - self.emitted_rows_per_part[build_index].saturating_add(emitted_rows); - // Mark matched build-side rows for outer joins (use current partition's bitmap) if let Some(bitmap) = self.matched_build_rows_per_partition.get_mut(build_index) { for build_idx in build_ids_to_mark { @@ -3982,9 +3549,7 @@ impl PartitionedHashJoinStream { state.active_offset = (0, None); state.joined_probe_idx = None; #[cfg(feature = "hybrid_hash_join_scheduler")] - if self.probe_scheduler_enabled { - self.schedule_probe_task(&partition_state.descriptor); - } + self.schedule_probe_task(&partition_state.descriptor); } } Err(e) => return Poll::Ready(Err(e)), @@ -3998,14 +3563,6 @@ impl PartitionedHashJoinStream { return Poll::Ready(Ok(StatefulStreamResult::Continue)); } self.join_metrics.output_batches.add(1); - self.log_stream(|| { - format!( - "process_partition part {} emitted {} rows (cumulative={})", - build_index, - emitted_rows, - self.emitted_rows_per_part[build_index] - ) - }); self.join_metrics.baseline.record_output(result.num_rows()); // println!( // "[spill-join] Emitting batch: rows={} (partition={})", @@ -4022,7 +3579,6 @@ impl PartitionedHashJoinStream { ) -> Poll>>> { if !need_produce_result_in_final(self.join_type) { self.state = PartitionedHashJoinState::Completed; - self.stream_completed = true; return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); } @@ -4275,7 +3831,6 @@ impl PartitionedHashJoinStream { } else { // All partitions processed self.state = PartitionedHashJoinState::Completed; - self.stream_completed = true; return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); } } @@ -4294,11 +3849,8 @@ impl Stream for PartitionedHashJoinStream { mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - if self.stream_completed { - return Poll::Ready(None); - } loop { - self.log_stream(|| format!("poll_next state {:?}", self.state)); + hhj_debug(|| format!("poll_next 
state {:?}", self.state)); match self.state.clone() { PartitionedHashJoinState::PartitionBuildSide => { // Collect build side and partition it @@ -4311,9 +3863,7 @@ impl Stream for PartitionedHashJoinStream { Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), Poll::Pending => return Poll::Pending, } - self.log_stream(|| { - format!("restarting build pass state={:?}", self.state) - }); + hhj_debug(|| format!("restarting build pass state={:?}", self.state)); match self.partition_build_side(left_data) { Ok(StatefulStreamResult::Continue) => continue, Ok(StatefulStreamResult::Ready(Some(batch))) => { @@ -4324,7 +3874,6 @@ impl Stream for PartitionedHashJoinStream { return Poll::Ready(Some(Ok(batch))); } Ok(StatefulStreamResult::Ready(None)) => { - self.stream_completed = true; return Poll::Ready(None) } Err(e) => return Poll::Ready(Some(Err(e))), @@ -4361,10 +3910,6 @@ impl Stream for PartitionedHashJoinStream { } #[cfg(feature = "hybrid_hash_join_scheduler")] PartitionedHashJoinState::WaitingForProbe => { - if !self.probe_scheduler_enabled { - self.state = PartitionedHashJoinState::HandleUnmatchedRows; - continue; - } if self.pending_partitions.is_empty() { if self.probe_scheduler_waiting_for_stream.is_empty() { hhj_debug(|| { @@ -4511,9 +4056,6 @@ mod scheduler_tests { schema, false, None, - 0, - #[cfg(feature = "hybrid_hash_join_scheduler")] - true, ) .unwrap(); stream.probe_scheduler_max_streams = max_streams; From dc461cad60750213928035d4790c8a5377bef303 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 12 Nov 2025 21:07:13 +0200 Subject: [PATCH 29/36] Revert "Try fixing freezzed state" This reverts commit d0a48e0cfc873ccf3e1dc27bb428ca3adda9f518. --- .../src/joins/hash_join/partitioned.rs | 167 +++--------------- .../src/joins/hash_join/scheduler.rs | 2 - 2 files changed, 23 insertions(+), 146 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index dcdf79c44c3f9..3cf813d0cc656 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -123,21 +123,6 @@ fn per_partition_budget_bytes(memory_threshold: usize, partitions: usize) -> usi budget.max(HYBRID_HASH_MIN_PARTITION_BYTES) } -fn estimate_probe_buffer_bytes( - batch: &RecordBatch, - values: &[ArrayRef], - hashes: &[u64], -) -> usize { - let batch_bytes = batch.get_array_memory_size(); - let values_bytes = values.iter().fold(0usize, |acc, arr| { - acc.saturating_add(arr.get_array_memory_size()) - }); - let hashes_bytes = hashes.len().saturating_mul(size_of::()); - batch_bytes - .saturating_add(values_bytes) - .saturating_add(hashes_bytes) -} - #[inline] fn hhj_debug String>(builder: F) { if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() { @@ -233,7 +218,6 @@ pub(super) struct ProbePartitionState { buffered: ProbePartition, batch_position: usize, buffered_rows: usize, - buffered_bytes: usize, spilled_rows: usize, consumed_rows: usize, spill_in_progress: Option, @@ -253,7 +237,6 @@ impl ProbePartitionState { buffered: ProbePartition::new(), batch_position: 0, buffered_rows: 0, - buffered_bytes: 0, spilled_rows: 0, consumed_rows: 0, spill_in_progress: None, @@ -618,71 +601,6 @@ impl PartitionedHashJoinStream { Ok(false) } - fn flush_probe_buffer_to_spill(&mut self, part_id: usize) -> Result<()> { - if !self.runtime_env.disk_manager.tmp_files_enabled() { - return Err(internal_datafusion_err!( - "Insufficient memory for buffering probe partitions and 
spilling is disabled" - )); - } - if part_id >= self.probe_states.len() { - return Ok(()); - } - - let (buffered, queue_ready, stream_active, mut writer_opt) = { - let state = self - .probe_states - .get_mut(part_id) - .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - if state.buffered.batches.is_empty() { - state.buffered_bytes = 0; - return Ok(()); - } - let queue_ready = !state.spill_files.is_empty(); - let stream_active = state.pending_stream.is_some(); - let buffered = mem::replace(&mut state.buffered, ProbePartition::new()); - state.buffered_rows = 0; - state.buffered_bytes = 0; - state.batch_position = 0; - let writer_opt = state.spill_in_progress.take(); - (buffered, queue_ready, stream_active, writer_opt) - }; - - if writer_opt.is_none() { - writer_opt = Some( - self.probe_spill_manager - .create_in_progress_file("hash_join_probe_partition")?, - ); - self.join_metrics.probe_spill_count.add(1); - } - - let mut writer = writer_opt.ok_or_else(|| { - internal_datafusion_err!("expected probe spill writer for partition") - })?; - - let mut spilled_rows = 0usize; - for batch in buffered.batches { - let batch_size = batch.get_array_memory_size(); - writer.append_batch(&batch)?; - self.join_metrics.probe_spilled_rows.add(batch.num_rows()); - self.join_metrics.probe_spilled_bytes.add(batch_size); - spilled_rows = spilled_rows.saturating_add(batch.num_rows()); - } - - { - let state = self - .probe_states - .get_mut(part_id) - .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - state.spill_in_progress = Some(writer); - state.spilled_rows = state.spilled_rows.saturating_add(spilled_rows); - } - - if !queue_ready && !stream_active { - self.finalize_spilled_partition(part_id)?; - } - Ok(()) - } - fn compute_recursive_fanout( &self, descriptor: &PartitionDescriptor, @@ -1024,8 +942,6 @@ impl PartitionedHashJoinStream { let shift_bits = descriptor.radix_bits; let mask = (fanout - 1) as u64; - let probe_partition_budget = - per_partition_budget_bytes(self.memory_threshold, self.num_partitions); let spill_file = { let state = self @@ -1034,7 +950,6 @@ impl PartitionedHashJoinStream { .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; state.batch_position = 0; state.buffered_rows = 0; - state.buffered_bytes = 0; state.spilled_rows = 0; state.consumed_rows = 0; state.active_batch = None; @@ -1121,7 +1036,6 @@ impl PartitionedHashJoinStream { state.spill_files.push_back(file); state.spilled_rows = 0; state.buffered_rows = 0; - state.buffered_bytes = 0; state.consumed_rows = 0; state.batch_position = 0; state.pending_stream = None; @@ -1140,8 +1054,6 @@ impl PartitionedHashJoinStream { .probe_states .get_mut(parent_index) .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; - state.buffered_bytes = 0; - state.batch_position = 0; mem::replace(&mut state.buffered, ProbePartition::new()) }; for idx in 0..parent_partition.batches.len() { @@ -1183,30 +1095,20 @@ impl PartitionedHashJoinStream { } let idx = partition_indices[sub_idx]; - let row_count = filtered_batch.num_rows(); - let buffered_bytes_delta = estimate_probe_buffer_bytes( - &filtered_batch, - &filtered_values, - &filtered_hashes, - ); - let mut should_flush = false; - { - let state = self.probe_states.get_mut(idx).ok_or_else(|| { - internal_datafusion_err!("missing probe partition") - })?; - state.buffered.batches.push(filtered_batch); - state.buffered.values.push(filtered_values); - state.buffered.hashes.push(filtered_hashes); - state.buffered_rows = 
state.buffered_rows.saturating_add(row_count); - state.buffered_bytes = - state.buffered_bytes.saturating_add(buffered_bytes_delta); - if state.buffered_bytes >= probe_partition_budget { - should_flush = true; - } - } - if should_flush { - self.flush_probe_buffer_to_spill(idx)?; - } + let state = self + .probe_states + .get_mut(idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_values); + state.buffered.hashes.push(filtered_hashes); + let buffered = state + .buffered + .batches + .last() + .map(|b| b.num_rows()) + .unwrap_or_default(); + state.buffered_rows = state.buffered_rows.saturating_add(buffered); } } @@ -1217,8 +1119,6 @@ impl PartitionedHashJoinStream { if self.probe_states.len() != self.num_partitions { self.resize_partition_vectors(); } - let probe_partition_budget = - per_partition_budget_bytes(self.memory_threshold, self.num_partitions); loop { match self.right.poll_next_unpin(cx) { @@ -1317,36 +1217,16 @@ impl PartitionedHashJoinStream { self.finalize_spilled_partition(part_id)?; } } else { - let row_count = filtered_batch.num_rows(); - let buffered_bytes_delta = estimate_probe_buffer_bytes( - &filtered_batch, - &filtered_on_values, - &filtered_hashes, - ); - let mut should_flush = false; - { - let state = self - .probe_states - .get_mut(part_id) - .ok_or_else(|| { - internal_datafusion_err!( - "missing probe partition" - ) - })?; - state.buffered.batches.push(filtered_batch); - state.buffered.values.push(filtered_on_values); - state.buffered.hashes.push(filtered_hashes); + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing probe partition") + })?; + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_on_values); + state.buffered.hashes.push(filtered_hashes); + if let Some(last) = state.buffered.batches.last() { state.buffered_rows = - state.buffered_rows.saturating_add(row_count); - state.buffered_bytes = state - .buffered_bytes - .saturating_add(buffered_bytes_delta); - if state.buffered_bytes >= probe_partition_budget { - should_flush = true; - } - } - if should_flush { - self.flush_probe_buffer_to_spill(part_id)?; + state.buffered_rows.saturating_add(last.num_rows()); } } } @@ -2393,7 +2273,6 @@ impl PartitionedHashJoinStream { state.buffered = ProbePartition::new(); state.batch_position = 0; state.buffered_rows = 0; - state.buffered_bytes = 0; } if let Some(b) = state.active_batch.as_ref() { state.consumed_rows = diff --git a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs index b2a04d7736242..8a979713adee2 100644 --- a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs +++ b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs @@ -260,7 +260,6 @@ pub(super) struct ProbePartitionState { pub buffered: ProbePartition, pub batch_position: usize, pub buffered_rows: usize, - pub buffered_bytes: usize, pub spilled_rows: usize, pub consumed_rows: usize, pub spill_in_progress: Option, @@ -279,7 +278,6 @@ impl ProbePartitionState { buffered: ProbePartition::new(), batch_position: 0, buffered_rows: 0, - buffered_bytes: 0, spilled_rows: 0, consumed_rows: 0, spill_in_progress: None, From 3bc28f0317b0540397b95f299ec1add2967ecba0 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 12 Nov 2025 21:07:26 +0200 Subject: [PATCH 30/36] Revert "Add spilltodisk collectMode" This reverts commit 
d36af9dd4f6530bcf797ca8d95f3ac76e2bc96bd. --- .../physical-plan/src/joins/hash_join/exec.rs | 271 +++++++----------- .../src/joins/hash_join/partitioned.rs | 12 +- 2 files changed, 103 insertions(+), 180 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 4bdd8e41e1f85..7f5083ea354a9 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -41,10 +41,7 @@ use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; -use crate::spill::{ - get_record_batch_memory_size, in_progress_spill_file::InProgressSpillFile, - spill_manager::SpillManager, -}; +use crate::spill::{get_record_batch_memory_size, spill_manager::SpillManager}; use crate::ExecutionPlanProperties; use crate::{ common::can_project, @@ -67,8 +64,7 @@ use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - internal_datafusion_err, internal_err, plan_err, project_schema, JoinSide, JoinType, - NullEquality, Result, + internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -83,7 +79,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{executor::block_on, pin_mut, StreamExt}; +use futures::{executor::block_on, StreamExt, TryStreamExt}; use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. 
@@ -129,7 +125,7 @@ pub(super) struct JoinLeftData { pub(super) bounds: Option>, } -pub(super) enum OriginalBuildInput { +enum OriginalBuildInput { InMemory(Arc>), Spilled { spill_manager: Arc, @@ -137,29 +133,27 @@ pub(super) enum OriginalBuildInput { }, } -enum CollectLeftMode { - InMemory, - SpillToDisk { spill_manager: Arc }, -} - impl JoinLeftData { /// Create a new `JoinLeftData` from its parts pub(super) fn new( hash_map: Box, batch: RecordBatch, - original_input: OriginalBuildInput, - total_rows: usize, - total_input_size: usize, + original_batches: Arc>, values: Vec, visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, reservation: MemoryReservation, bounds: Option>, ) -> Self { + let total_rows = original_batches.iter().map(|b| b.num_rows()).sum(); + let total_input_size = original_batches + .iter() + .map(|batch| batch.get_array_memory_size()) + .sum(); Self { hash_map, batch, - original_input, + original_input: OriginalBuildInput::InMemory(original_batches), total_rows, total_input_size, values, @@ -1028,7 +1022,6 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), enable_dynamic_filter_pushdown, - CollectLeftMode::InMemory, )) })?, PartitionMode::Partitioned => { @@ -1047,7 +1040,6 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), 1, enable_dynamic_filter_pushdown, - CollectLeftMode::InMemory, )) } PartitionMode::Auto => { @@ -1079,7 +1071,6 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), 1, enable_dynamic_filter_pushdown, - CollectLeftMode::InMemory, )) } else { // Spillable enabled: coalesce left to a single stream @@ -1093,11 +1084,6 @@ impl ExecutionPlan for HashJoinExec { let left_stream = left_plan.execute(0, Arc::clone(&context))?; let reservation = MemoryConsumer::new("HashJoinInput") .register(context.memory_pool()); - let build_input_spill_manager = Arc::new(SpillManager::new( - Arc::clone(&context.runtime_env()), - SpillMetrics::new(&self.metrics, partition), - build_schema.clone(), - )); let left_fut = self.left_fut.try_once(|| { Ok(collect_left_input( self.random_state.clone(), @@ -1108,9 +1094,6 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), enable_dynamic_filter_pushdown, - CollectLeftMode::SpillToDisk { - spill_manager: Arc::clone(&build_input_spill_manager), - }, )) })?; @@ -1766,53 +1749,44 @@ async fn collect_left_input( with_visited_indices_bitmap: bool, probe_threads_count: usize, should_compute_bounds: bool, - collect_mode: CollectLeftMode, ) -> Result { let schema = left_stream.schema(); - let mut state = BuildSideState::try_new( + // This operation performs 2 steps at once: + // 1. creates a [JoinHashMap] of all batches from the stream + // 2. stores the batches in a vector. 
+ let initial = BuildSideState::try_new( metrics, reservation, on_left.clone(), &schema, should_compute_bounds, )?; - let mut total_input_size = 0usize; - let mut spill_writer: Option = None; - - pin_mut!(left_stream); - while let Some(batch) = left_stream.next().await { - let batch = batch?; - if let Some(ref mut accumulators) = state.bounds_accumulators { - for accumulator in accumulators { - accumulator.update_batch(&batch)?; - } - } - let batch_size = get_record_batch_memory_size(&batch); - state.reservation.try_grow(batch_size)?; - state.metrics.build_mem_used.add(batch_size); - state.metrics.build_input_batches.add(1); - state.metrics.build_input_rows.add(batch.num_rows()); - state.num_rows += batch.num_rows(); - total_input_size = total_input_size.saturating_add(batch_size); - - match &collect_mode { - CollectLeftMode::InMemory => { - state.batches.push(batch); - } - CollectLeftMode::SpillToDisk { spill_manager } => { - if spill_writer.is_none() { - spill_writer = Some( - spill_manager.create_in_progress_file("hash_join_build_input")?, - ); - } - if let Some(writer) = spill_writer.as_mut() { - writer.append_batch(&batch)?; + let state = left_stream + .try_fold(initial, |mut state, batch| async move { + // Update accumulators if computing bounds + if let Some(ref mut accumulators) = state.bounds_accumulators { + for accumulator in accumulators { + accumulator.update_batch(&batch)?; } } - } - } + + // Decide if we spill or not + let batch_size = get_record_batch_memory_size(&batch); + // Reserve memory for incoming batch + state.reservation.try_grow(batch_size)?; + // Update metrics + state.metrics.build_mem_used.add(batch_size); + state.metrics.build_input_batches.add(1); + state.metrics.build_input_rows.add(batch.num_rows()); + // Update row count + state.num_rows += batch.num_rows(); + // Push batch to output + state.batches.push(batch); + Ok(state) + }) + .await?; // Extract fields from state let BuildSideState { @@ -1823,113 +1797,70 @@ async fn collect_left_input( bounds_accumulators, } = state; let batches_arc = Arc::new(batches); - let (hashmap, single_batch, left_values, visited_indices_bitmap, original_input) = - match collect_mode { - CollectLeftMode::InMemory => { - let fixed_size_u32 = size_of::(); - let fixed_size_u64 = size_of::(); - let mut hashmap: Box = - if num_rows > u32::MAX as usize { - let estimated_hashtable_size = - estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU64::with_capacity(num_rows)) - } else { - let estimated_hashtable_size = - estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU32::with_capacity(num_rows)) - }; + // Estimation of memory size, required for hashtable, prior to allocation. 
+ // Final result can be verified using `RawTable.allocation_info()` + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + + // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the + // `u64` indice variant + let mut hashmap: Box = if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; - let mut hashes_buffer = Vec::new(); - let mut offset = 0; - let batches_iter = batches_arc.iter().rev(); - for batch in batches_iter.clone() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( - &on_left, - batch, - &mut *hashmap, - offset, - &random_state, - &mut hashes_buffer, - 0, - true, - )?; - offset += batch.num_rows(); - } - let single_batch = concat_batches(&schema, batches_iter)?; - let visited_indices_bitmap = if with_visited_indices_bitmap { - let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); - reservation.try_grow(bitmap_size)?; - metrics.build_mem_used.add(bitmap_size); - - let mut bitmap_buffer = - BooleanBufferBuilder::new(single_batch.num_rows()); - bitmap_buffer.append_n(num_rows, false); - bitmap_buffer - } else { - BooleanBufferBuilder::new(0) - }; + let mut hashes_buffer = Vec::new(); + let mut offset = 0; - let left_values = on_left - .iter() - .map(|c| { - c.evaluate(&single_batch)? 
- .into_array(single_batch.num_rows()) - }) - .collect::>>()?; - - ( - hashmap, - single_batch, - left_values, - visited_indices_bitmap, - OriginalBuildInput::InMemory(Arc::clone(&batches_arc)), - ) - } - CollectLeftMode::SpillToDisk { spill_manager } => { - if num_rows == 0 { - ( - Box::new(JoinHashMapU32::with_capacity(0)) - as Box, - RecordBatch::new_empty(schema.clone()), - Vec::new(), - BooleanBufferBuilder::new(0), - OriginalBuildInput::InMemory(Arc::new(vec![])), - ) - } else { - let mut writer = spill_writer.ok_or_else(|| { - internal_datafusion_err!("missing build spill writer") - })?; - let spill_file = writer.finish()?.ok_or_else(|| { - internal_datafusion_err!( - "expected spill file when spilling build input" - ) - })?; - metrics.build_spill_count.add(1); - metrics.build_spilled_rows.add(num_rows); - metrics - .build_spilled_bytes - .add(spill_file.current_disk_usage() as usize); - let _ = reservation.try_shrink(reservation.size()); - ( - Box::new(JoinHashMapU32::with_capacity(0)) - as Box, - RecordBatch::new_empty(schema.clone()), - Vec::new(), - BooleanBufferBuilder::new(0), - OriginalBuildInput::Spilled { - spill_manager, - spill_file: Arc::new(spill_file), - }, - ) - } - } - }; + // Updating hashmap starting from the last batch + let batches_iter = batches_arc.iter().rev(); + for batch in batches_iter.clone() { + hashes_buffer.clear(); + hashes_buffer.resize(batch.num_rows(), 0); + update_hash( + &on_left, + batch, + &mut *hashmap, + offset, + &random_state, + &mut hashes_buffer, + 0, + true, + )?; + offset += batch.num_rows(); + } + // Merge all batches into a single batch, so we can directly index into the arrays + let single_batch = concat_batches(&schema, batches_iter)?; + + // Reserve additional memory for visited indices bitmap and create shared builder + let visited_indices_bitmap = if with_visited_indices_bitmap { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let left_values = on_left + .iter() + .map(|c| { + c.evaluate(&single_batch)? 
+ .into_array(single_batch.num_rows()) + }) + .collect::>>()?; // Compute bounds for dynamic filter if enabled let bounds = match bounds_accumulators { @@ -1946,9 +1877,7 @@ async fn collect_left_input( let data = JoinLeftData::new( hashmap, single_batch, - original_input, - num_rows, - total_input_size, + Arc::clone(&batches_arc), left_values, Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 3cf813d0cc656..b0df425fa4d65 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -623,7 +623,7 @@ impl PartitionedHashJoinStream { return None; } - let per_partition_budget = + let mut per_partition_budget = per_partition_budget_bytes(self.memory_threshold, self.num_partitions); let rows_budget = self @@ -3872,16 +3872,10 @@ mod scheduler_tests { let reservation = MemoryConsumer::new("left") .with_can_spill(true) .register(&runtime_env.memory_pool); - let arc_batches = Arc::new(vec![batch.clone()]); - let total_rows = arc_batches.iter().map(|b| b.num_rows()).sum(); - let total_input_size = - arc_batches.iter().map(|b| b.get_array_memory_size()).sum(); JoinLeftData::new( hash_map, - batch, - OriginalBuildInput::InMemory(Arc::clone(&arc_batches)), - total_rows, - total_input_size, + batch.clone(), + Arc::new(vec![batch]), vec![], Mutex::new(BooleanBufferBuilder::new(0)), AtomicUsize::new(0), From c333d2dccc12ec88c92bc26ff003dd245f6dd1ef Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Wed, 12 Nov 2025 21:08:21 +0200 Subject: [PATCH 31/36] Revert "Spill manager support for sharing spill file handle" This reverts commit 80f2c34459ec0327d2ad2336a132994679c93f65. --- .../physical-plan/src/joins/hash_join/exec.rs | 74 +++---------------- .../src/joins/hash_join/partitioned.rs | 28 ++++--- datafusion/physical-plan/src/spill/mod.rs | 10 +-- .../physical-plan/src/spill/spill_manager.rs | 22 ------ 4 files changed, 29 insertions(+), 105 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 7f5083ea354a9..c6eaa3489cbaa 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -41,7 +41,7 @@ use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; -use crate::spill::{get_record_batch_memory_size, spill_manager::SpillManager}; +use crate::spill::get_record_batch_memory_size; use crate::ExecutionPlanProperties; use crate::{ common::can_project, @@ -66,7 +66,6 @@ use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, }; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; @@ -79,7 +78,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{executor::block_on, StreamExt, TryStreamExt}; +use futures::TryStreamExt; use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. 
@@ -101,14 +100,10 @@ const HYBRID_HASH_MIN_ROWS_PER_PARTITION: usize = 1_024; pub(super) struct JoinLeftData { /// The hash table with indices into `batch` pub(super) hash_map: Box, - /// The input rows for the build side (may be empty when spilled) + /// The input rows for the build side batch: RecordBatch, - /// Build-side input backing storage - original_input: OriginalBuildInput, - /// Total rows collected from the build side - total_rows: usize, - /// Estimated bytes for the original build input - total_input_size: usize, + /// Original build-side batches before concatenation + original_batches: Arc>, /// The build side on expressions values values: Vec, /// Shared bitmap builder for visited left indices @@ -125,14 +120,6 @@ pub(super) struct JoinLeftData { pub(super) bounds: Option>, } -enum OriginalBuildInput { - InMemory(Arc>), - Spilled { - spill_manager: Arc, - spill_file: Arc, - }, -} - impl JoinLeftData { /// Create a new `JoinLeftData` from its parts pub(super) fn new( @@ -145,17 +132,10 @@ impl JoinLeftData { reservation: MemoryReservation, bounds: Option>, ) -> Self { - let total_rows = original_batches.iter().map(|b| b.num_rows()).sum(); - let total_input_size = original_batches - .iter() - .map(|batch| batch.get_array_memory_size()) - .sum(); Self { hash_map, batch, - original_input: OriginalBuildInput::InMemory(original_batches), - total_rows, - total_input_size, + original_batches, values, visited_indices_bitmap, probe_threads_counter, @@ -174,12 +154,8 @@ impl JoinLeftData { &self.batch } - pub(super) fn total_rows(&self) -> usize { - self.total_rows - } - - pub(super) fn total_input_size(&self) -> usize { - self.total_input_size + pub(super) fn original_batches(&self) -> &[RecordBatch] { + &self.original_batches } /// returns a reference to the build side expressions values @@ -197,37 +173,6 @@ impl JoinLeftData { pub(super) fn report_probe_completed(&self) -> bool { self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 } - - pub(super) fn for_each_original_batch(&self, mut f: F) -> Result<()> - where - F: FnMut(RecordBatch) -> Result, - { - match &self.original_input { - OriginalBuildInput::InMemory(batches) => { - for batch in batches.iter() { - if !f(batch.clone())? { - break; - } - } - Ok(()) - } - OriginalBuildInput::Spilled { - spill_manager, - spill_file, - } => { - let mut stream = - spill_manager.read_spill_as_stream_shared(Arc::clone(spill_file))?; - block_on(async { - while let Some(batch) = stream.next().await.transpose()? { - if !f(batch)? { - break; - } - } - Ok(()) - }) - } - } - } } #[allow(rustdoc::private_intra_doc_links)] @@ -1797,6 +1742,7 @@ async fn collect_left_input( bounds_accumulators, } = state; let batches_arc = Arc::new(batches); + // Estimation of memory size, required for hashtable, prior to allocation. 
// Final result can be verified using `RawTable.allocation_info()` let fixed_size_u32 = size_of::(); @@ -1878,7 +1824,7 @@ async fn collect_left_input( hashmap, single_batch, Arc::clone(&batches_arc), - left_values, + left_values.clone(), Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), reservation, diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index b0df425fa4d65..f6b3a3c2c1fd1 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -1828,11 +1828,19 @@ impl PartitionedHashJoinStream { ) -> Result>> { if self.partition_pass == 0 { self.join_metrics.build_input_batches.add(1); - let total_rows = build_data.total_rows(); + let total_rows: usize = build_data + .original_batches() + .iter() + .map(|b| b.num_rows()) + .sum(); self.join_metrics.build_input_rows.add(total_rows); } - let build_total_size = build_data.total_input_size(); + let build_total_size: usize = build_data + .original_batches() + .iter() + .map(|batch| batch.get_array_memory_size()) + .sum(); if build_total_size <= self.memory_threshold { self.num_partitions = 1; self.max_partition_count = 1; @@ -1903,10 +1911,10 @@ impl PartitionedHashJoinStream { let mut max_spilled_bytes: usize = 0; let mut any_spilled = false; - build_data.for_each_original_batch(|batch| { + for batch in build_data.original_batches() { let mut keys_values: Vec = Vec::with_capacity(self.on_left.len()); for expr in &self.on_left { - keys_values.push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); + keys_values.push(expr.evaluate(batch)?.into_array(batch.num_rows())?); } let mut hashes = vec![0u64; batch.num_rows()]; create_hashes(&keys_values, &self.random_state, &mut hashes)?; @@ -1963,7 +1971,7 @@ impl PartitionedHashJoinStream { ) }); repartition_request = Some(next_count); - return Ok(false); + break; } } } @@ -1988,7 +1996,7 @@ impl PartitionedHashJoinStream { ) }); repartition_request = Some(next_count); - return Ok(false); + break; } } } @@ -2003,16 +2011,14 @@ impl PartitionedHashJoinStream { } if repartition_request.is_some() { - return Ok(false); + break; } } if repartition_request.is_some() { - Ok(false) - } else { - Ok(true) + break; } - })?; + } if let Some(next_count) = repartition_request { hhj_debug(|| { diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 08056f9e122ba..782100e6d4cf1 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -52,7 +52,6 @@ use futures::{FutureExt as _, Stream}; /// file read (instead of each batch). This approach does not work because when /// the number of concurrent reads exceeds the Tokio thread pool limit, /// deadlocks can occur and block progress. - struct SpillReaderStream { schema: SchemaRef, state: SpillReaderStreamState, @@ -65,7 +64,7 @@ type NextRecordBatchResult = Result<(StreamReader>, Option), + Uninitialized(RefCountedTempFile), /// A read is in progress in a spawned blocking task for which we hold the handle. 
ReadInProgress(SpawnedTask), @@ -79,10 +78,6 @@ enum SpillReaderStreamState { impl SpillReaderStream { fn new(schema: SchemaRef, spill_file: RefCountedTempFile) -> Self { - Self::new_from_shared(schema, Arc::new(spill_file)) - } - - fn new_from_shared(schema: SchemaRef, spill_file: Arc) -> Self { Self { schema, state: SpillReaderStreamState::Uninitialized(spill_file), @@ -102,9 +97,8 @@ impl SpillReaderStream { unreachable!() }; - let file_ref = spill_file.clone(); let task = SpawnedTask::spawn_blocking(move || { - let file = BufReader::new(File::open(file_ref.path())?); + let file = BufReader::new(File::open(spill_file.path())?); // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications // with validated schemas and buffers. Skip redundant validation during read // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written. diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index d3c38aba45989..de11fb2905fe1 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -232,28 +232,6 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } - - pub fn read_spill_as_stream_ref( - &self, - spill_file_path: &RefCountedTempFile, - ) -> Result { - let stream = Box::pin(cooperative(SpillReaderStream::new( - Arc::clone(&self.schema), - spill_file_path.clone_refcounted()?, - ))); - - Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) - } - - pub fn load_spilled_batch( - &self, - spill: &SpillLocation, - ) -> Result { - match spill { - SpillLocation::Memory(buf) => Ok(Arc::clone(&buf).as_stream(Arc::clone(&self.schema))?), - SpillLocation::Disk(file) => self.read_spill_as_stream_ref(file), - } - } } #[derive(Debug, Clone)] From 8a46c67450b22df118f3b318d82a8ee757a55db0 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Thu, 13 Nov 2025 15:26:42 +0200 Subject: [PATCH 32/36] Fix RightSemi joins --- .../physical-optimizer/src/join_selection.rs | 74 ++++++--- .../src/joins/grace_hash_join/exec.rs | 25 +-- .../src/joins/grace_hash_join/stream.rs | 151 +++--------------- .../physical-plan/src/spill/spill_manager.rs | 39 ++++- 4 files changed, 116 insertions(+), 173 deletions(-) diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 111e9d1837df7..a587ec027d89b 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -27,12 +27,12 @@ use crate::PhysicalOptimizerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, JoinSide, JoinType}; +use datafusion_common::{internal_err, DataFusionError, JoinSide, JoinType}; use datafusion_expr_common::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::execution_plan::EmissionType; -use datafusion_physical_plan::joins::utils::ColumnIndex; +use datafusion_physical_plan::joins::utils::{check_join_is_valid, ColumnIndex}; use datafusion_physical_plan::joins::{ CrossJoinExec, GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -173,6 +173,13 @@ pub(crate) fn try_collect_left( let left = 
hash_join.left(); let right = hash_join.right(); + // Skip collect-left rewrite if the join currently has inconsistent schemas (e.g. required + // columns were projected away temporarily). This mirrors the legacy hash join behavior where + // collect-left is only attempted once the join inputs are fully valid. + if check_join_is_valid(&left.schema(), &right.schema(), hash_join.on()).is_err() { + return Ok(None); + } + let left_can_collect = ignore_threshold || supports_collect_by_thresholds( &**left, @@ -191,33 +198,23 @@ pub(crate) fn try_collect_left( if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { - Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?)) + match hash_join.swap_inputs(PartitionMode::CollectLeft) { + Ok(plan) => Ok(Some(plan)), + Err(err) if is_missing_join_columns(&err) => Ok(None), + Err(err) => Err(err), + } } else { - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))) + build_collect_left_exec(hash_join, left, right) } } - (true, false) => Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))), + (true, false) => build_collect_left_exec(hash_join, left, right), (false, true) => { if hash_join.join_type().supports_swap() { - hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some) + match hash_join.swap_inputs(PartitionMode::CollectLeft) { + Ok(plan) => Ok(Some(plan)), + Err(err) if is_missing_join_columns(&err) => Ok(None), + Err(err) => Err(err), + } } else { Ok(None) } @@ -226,6 +223,35 @@ pub(crate) fn try_collect_left( } } +fn is_missing_join_columns(err: &DataFusionError) -> bool { + matches!( + err, + DataFusionError::Plan(msg) + if msg.contains("The left or right side of the join does not have all columns") + ) +} + +fn build_collect_left_exec( + hash_join: &HashJoinExec, + left: &Arc, + right: &Arc, +) -> Result>> { + match HashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + hash_join.on().to_vec(), + hash_join.filter().cloned(), + hash_join.join_type(), + hash_join.projection.clone(), + PartitionMode::CollectLeft, + hash_join.null_equality(), + ) { + Ok(exec) => Ok(Some(Arc::new(exec))), + Err(err) if is_missing_join_columns(&err) => Ok(None), + Err(err) => Err(err), + } +} + /// Creates a partitioned hash join execution plan, swapping inputs if beneficial. /// /// Checks if the join order should be swapped based on the join type and input statistics. 
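
Note on the join_selection.rs hunk above: the new `is_missing_join_columns` / `build_collect_left_exec` helpers downgrade one specific planning failure ("the left or right side of the join does not have all columns") to "skip the CollectLeft rewrite" instead of aborting optimization. A minimal, self-contained sketch of that pattern follows; the `PlanError` enum and `try_rewrite` function are simplified stand-ins for DataFusion's `DataFusionError::Plan` and the real rewrite path, not the project's actual API.

// Simplified stand-in for the planner error; the patch matches on
// DataFusionError::Plan and inspects its message in the same way.
#[derive(Debug)]
enum PlanError {
    Plan(String),
    Other(String),
}

type PlanResult<T> = Result<T, PlanError>;

fn is_missing_join_columns(err: &PlanError) -> bool {
    matches!(
        err,
        PlanError::Plan(msg)
            if msg.contains("does not have all columns")
    )
}

// Attempt the rewrite; treat the "missing join columns" failure as
// "no rewrite" (Ok(None)) and propagate every other error unchanged.
fn try_rewrite(build: impl Fn() -> PlanResult<String>) -> PlanResult<Option<String>> {
    match build() {
        Ok(plan) => Ok(Some(plan)),
        Err(err) if is_missing_join_columns(&err) => Ok(None),
        Err(err) => Err(err),
    }
}

fn main() -> PlanResult<()> {
    // Recoverable case: the rewrite is skipped and the original plan is kept.
    let skipped = try_rewrite(|| {
        Err(PlanError::Plan(
            "The left or right side of the join does not have all columns".into(),
        ))
    })?;
    assert!(skipped.is_none());

    // Unrelated errors still bubble up to the caller.
    let failed = try_rewrite(|| Err(PlanError::Other("io".into())));
    assert!(failed.is_err());
    Ok(())
}
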
diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs index 1e85ae5c58775..deae45ef51e06 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -53,16 +53,13 @@ use datafusion_common::{ internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, }; use datafusion_execution::TaskContext; -use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, }; use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; -use crate::joins::grace_hash_join::stream::{ - GraceAccumulator, GraceHashJoinStream, SpillFut, -}; +use crate::joins::grace_hash_join::stream::{GraceHashJoinStream, SpillFut}; use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; use crate::metrics::SpillMetrics; use crate::spill::spill_manager::SpillLocation; @@ -105,7 +102,6 @@ pub struct GraceHashJoinExec { /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. dynamic_filter: Option, - accumulator: Arc, } #[derive(Clone)] @@ -183,9 +179,6 @@ impl GraceHashJoinExec { &on, projection.as_ref(), )?; - let partitions = left.output_partitioning().partition_count(); - let accumulator = GraceAccumulator::new(partitions); - let metrics = ExecutionPlanMetricsSet::new(); // Initialize both dynamic filter and bounds accumulator to None // They will be set later if dynamic filtering is enabled @@ -203,7 +196,6 @@ impl GraceHashJoinExec { null_equality, cache, dynamic_filter: None, - accumulator, }) } @@ -547,7 +539,6 @@ impl ExecutionPlan for GraceHashJoinExec { self: Arc, children: Vec>, ) -> Result> { - let new_partition_count = children[0].output_partitioning().partition_count(); Ok(Arc::new(GraceHashJoinExec { left: Arc::clone(&children[0]), right: Arc::clone(&children[1]), @@ -570,12 +561,10 @@ impl ExecutionPlan for GraceHashJoinExec { )?, // Keep the dynamic filter, bounds accumulator will be reset dynamic_filter: self.dynamic_filter.clone(), - accumulator: GraceAccumulator::new(new_partition_count), })) } fn reset_state(self: Arc) -> Result> { - let partition_count = self.left.output_partitioning().partition_count(); Ok(Arc::new(GraceHashJoinExec { left: Arc::clone(&self.left), right: Arc::clone(&self.right), @@ -591,7 +580,6 @@ impl ExecutionPlan for GraceHashJoinExec { cache: self.cache.clone(), // Reset dynamic filter and bounds accumulator to initial state dynamic_filter: None, - accumulator: GraceAccumulator::new(partition_count), })) } @@ -654,7 +642,6 @@ impl ExecutionPlan for GraceHashJoinExec { let on = self.on.clone(); let spill_left_clone = Arc::clone(&spill_left); let spill_right_clone = Arc::clone(&spill_right); - let accumulator_clone = Arc::clone(&self.accumulator); let join_metrics_clone = Arc::clone(&join_metrics); let spill_fut = OnceFut::new(async move { let (left_idx, right_idx) = partition_and_spill( @@ -670,14 +657,16 @@ impl ExecutionPlan for GraceHashJoinExec { partition, ) .await?; - accumulator_clone - .report_partition(partition, left_idx.clone(), right_idx.clone()) - .await; Ok(SpillFut::new(partition, left_idx, right_idx)) }); + let left_input_schema = self.left.schema(); + let right_input_schema = 
self.right.schema(); + Ok(Box::pin(GraceHashJoinStream::new( self.schema(), + left_input_schema, + right_input_schema, spill_fut, spill_left, spill_right, @@ -689,7 +678,6 @@ impl ExecutionPlan for GraceHashJoinExec { column_indices_after_projection, join_metrics, context, - Arc::clone(&self.accumulator), ))) } @@ -845,7 +833,6 @@ impl ExecutionPlan for GraceHashJoinExec { filter: dynamic_filter, bounds_accumulator: OnceLock::new(), }), - accumulator: Arc::clone(&self.accumulator), }); result = result.with_updated_node(new_node as Arc); } diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs index d028b0e8bcf87..da0de10a51889 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs @@ -34,22 +34,17 @@ use crate::empty::EmptyExec; use crate::joins::grace_hash_join::exec::PartitionIndex; use crate::joins::{HashJoinExec, PartitionMode}; use crate::test::TestMemoryExec; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{JoinType, NullEquality, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExprRef; use futures::{ready, Stream, StreamExt}; -use tokio::sync::Mutex; enum GraceJoinState { /// Waiting for the partitioning phase (Phase 1) to finish WaitPartitioning, - WaitAllPartitions { - wait_all_fut: Option>>, - }, - /// Currently joining partition `current` JoinPartition { current: usize, @@ -64,6 +59,8 @@ enum GraceJoinState { pub struct GraceHashJoinStream { schema: SchemaRef, + left_input_schema: SchemaRef, + right_input_schema: SchemaRef, spill_fut: OnceFut, spill_left: Arc, spill_right: Arc, @@ -75,27 +72,21 @@ pub struct GraceHashJoinStream { column_indices: Vec, join_metrics: Arc, context: Arc, - accumulator: Arc, state: GraceJoinState, } #[derive(Debug, Clone)] pub struct SpillFut { - partition: usize, left: Vec, right: Vec, } impl SpillFut { pub(crate) fn new( - partition: usize, + _partition: usize, left: Vec, right: Vec, ) -> Self { - SpillFut { - partition, - left, - right, - } + SpillFut { left, right } } } @@ -108,6 +99,8 @@ impl RecordBatchStream for GraceHashJoinStream { impl GraceHashJoinStream { pub fn new( schema: SchemaRef, + left_input_schema: SchemaRef, + right_input_schema: SchemaRef, spill_fut: OnceFut, spill_left: Arc, spill_right: Arc, @@ -119,10 +112,11 @@ impl GraceHashJoinStream { column_indices: Vec, join_metrics: Arc, context: Arc, - accumulator: Arc, ) -> Self { Self { schema, + left_input_schema, + right_input_schema, spill_fut, spill_left, spill_right, @@ -134,7 +128,6 @@ impl GraceHashJoinStream { column_indices, join_metrics, context, - accumulator, state: GraceJoinState::WaitPartitioning, } } @@ -148,47 +141,16 @@ impl GraceHashJoinStream { match &mut self.state { GraceJoinState::WaitPartitioning => { let shared = ready!(self.spill_fut.get_shared(cx))?; - - let acc = Arc::clone(&self.accumulator); - let left = shared.left.clone(); - let right = shared.right.clone(); - // Use 0 partition as the main - let wait_all_fut = if shared.partition == 0 { - OnceFut::new(async move { - acc.report_partition(shared.partition, left, right).await; - let all = acc.wait_all().await; - Ok(all) - }) - } else { - OnceFut::new(async move { - acc.report_partition(shared.partition, left, right).await; - acc.wait_ready().await; - Ok(vec![]) - }) - }; - self.state = 
GraceJoinState::WaitAllPartitions { - wait_all_fut: Some(wait_all_fut), + let parts = Arc::new(vec![(*shared).clone()]); + self.state = GraceJoinState::JoinPartition { + current: 0, + all_parts: parts, + current_stream: None, + left_fut: None, + right_fut: None, }; continue; } - GraceJoinState::WaitAllPartitions { wait_all_fut } => { - if let Some(fut) = wait_all_fut { - let all_arc = ready!(fut.get_shared(cx))?; - let mut all = (*all_arc).clone(); - all.sort_by_key(|s| s.partition); - - self.state = GraceJoinState::JoinPartition { - current: 0, - all_parts: Arc::from(all), - current_stream: None, - left_fut: None, - right_fut: None, - }; - continue; - } else { - return Poll::Pending; - } - } GraceJoinState::JoinPartition { current, all_parts, @@ -223,6 +185,8 @@ impl GraceHashJoinStream { let stream = build_in_memory_join_stream( Arc::clone(&self.schema), + Arc::clone(&self.left_input_schema), + Arc::clone(&self.right_input_schema), left_batches, right_batches, &self.on_left, @@ -282,6 +246,8 @@ fn load_partition_async( /// Build an in-memory HashJoinExec for one pair of spilled partitions fn build_in_memory_join_stream( output_schema: SchemaRef, + left_input_schema: SchemaRef, + right_input_schema: SchemaRef, left_batches: Vec, right_batches: Vec, on_left: &[PhysicalExprRef], @@ -297,22 +263,15 @@ fn build_in_memory_join_stream( return EmptyExec::new(output_schema).execute(0, Arc::clone(context)); } - let left_schema = left_batches - .first() - .map(|b| b.schema()) - .unwrap_or_else(|| Arc::new(Schema::empty())); - - let right_schema = right_batches - .first() - .map(|b| b.schema()) - .unwrap_or_else(|| Arc::new(Schema::empty())); - // Build memory execution nodes for each side - let left_plan: Arc = - Arc::new(TestMemoryExec::try_new(&[left_batches], left_schema, None)?); + let left_plan: Arc = Arc::new(TestMemoryExec::try_new( + &[left_batches], + left_input_schema, + None, + )?); let right_plan: Arc = Arc::new(TestMemoryExec::try_new( &[right_batches], - right_schema, + right_input_schema, None, )?); @@ -349,61 +308,3 @@ impl Stream for GraceHashJoinStream { self.poll_next_impl(cx) } } - -#[derive(Debug)] -pub struct GraceAccumulator { - expected: usize, - collected: Mutex>, - notify: tokio::sync::Notify, -} - -impl GraceAccumulator { - pub fn new(expected: usize) -> Arc { - Arc::new(Self { - expected, - collected: Mutex::new(vec![]), - notify: tokio::sync::Notify::new(), - }) - } - - pub async fn report_partition( - &self, - part_id: usize, - left_idx: Vec, - right_idx: Vec, - ) { - let mut guard = self.collected.lock().await; - if let Some(pos) = guard.iter().position(|s| s.partition == part_id) { - guard[pos] = SpillFut::new(part_id, left_idx, right_idx); - } else { - guard.push(SpillFut::new(part_id, left_idx, right_idx)); - } - - if guard.len() == self.expected { - self.notify.notify_waiters(); - } - } - - pub async fn wait_all(&self) -> Vec { - loop { - { - let guard = self.collected.lock().await; - if guard.len() == self.expected { - return guard.clone(); - } - } - self.notify.notified().await; - } - } - pub async fn wait_ready(&self) { - loop { - { - let guard = self.collected.lock().await; - if guard.len() == self.expected { - return; - } - } - self.notify.notified().await; - } - } -} diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index de11fb2905fe1..27176ff02527e 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -30,8 +30,8 @@ 
use datafusion_execution::SendableRecordBatchStream; use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; use crate::coop::cooperative; -use crate::{common::spawn_buffered, metrics::SpillMetrics}; use crate::spill::in_memory_spill_buffer::InMemorySpillBuffer; +use crate::{common::spawn_buffered, metrics::SpillMetrics}; /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. @@ -172,7 +172,11 @@ impl SpillManager { /// Automatically decides whether to spill the given RecordBatch to memory or disk, /// depending on available memory pool capacity. - pub(crate) fn spill_batch_auto(&self, batch: &RecordBatch, request_msg: &str) -> Result { + pub(crate) fn spill_batch_auto( + &self, + batch: &RecordBatch, + request_msg: &str, + ) -> Result { // let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { // return Err(DataFusionError::Execution( // "failed to spill batch to disk".into(), @@ -190,14 +194,16 @@ impl SpillManager { }; // If there's enough memory (with a safety margin), keep it in memory - if used + size * 3 / 2 <= limit { + if used + size * 3 / 2 <= limit { let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); self.metrics.spilled_bytes.add(size); self.metrics.spilled_rows.add(batch.num_rows()); Ok(SpillLocation::Memory(buf)) } else { // Otherwise spill to disk using the existing SpillManager logic - let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { + let Some(file) = + self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? + else { return Err(DataFusionError::Execution( "failed to spill batch to disk".into(), )); @@ -232,6 +238,30 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + + pub fn read_spill_as_stream_ref( + &self, + spill_file_path: &RefCountedTempFile, + ) -> Result { + let stream = Box::pin(cooperative(SpillReaderStream::new( + Arc::clone(&self.schema), + spill_file_path.clone_refcounted()?, + ))); + + Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + } + + pub fn load_spilled_batch( + &self, + spill: &SpillLocation, + ) -> Result { + match spill { + SpillLocation::Memory(buf) => { + Ok(Arc::clone(&buf).as_stream(Arc::clone(&self.schema))?) + } + SpillLocation::Disk(file) => self.read_spill_as_stream_ref(file), + } + } } #[derive(Debug, Clone)] @@ -240,7 +270,6 @@ pub enum SpillLocation { Disk(Arc), } - pub(crate) trait GetSlicedSize { /// Returns the size of the `RecordBatch` when sliced. /// Note: if multiple arrays or even a single array share the same data buffers, we may double count each buffer. 
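
The spill_manager.rs hunk above reformats `spill_batch_auto`, whose rule is: keep the batch in the memory pool only if current usage plus roughly 1.5x the batch size still fits under the limit, otherwise write it to disk. The sketch below illustrates that headroom rule with plain byte counts; the `SpillTarget` enum and usize inputs are illustrative stand-ins for the real `MemoryPool` / `SpillLocation` types, and the 3/2 factor is taken directly from the patch.

// Minimal sketch of the headroom heuristic, assuming plain usize inputs.
#[derive(Debug, PartialEq)]
enum SpillTarget {
    Memory,
    Disk,
}

fn choose_spill_target(used: usize, limit: usize, batch_bytes: usize) -> SpillTarget {
    // `used + size * 3 / 2 <= limit` from the patch, written with saturating
    // arithmetic so an oversized batch cannot overflow the addition.
    if used.saturating_add(batch_bytes.saturating_mul(3) / 2) <= limit {
        SpillTarget::Memory
    } else {
        SpillTarget::Disk
    }
}

fn main() {
    // 60 used, 100 limit: a 20-byte batch stays in memory (60 + 30 <= 100)...
    assert_eq!(choose_spill_target(60, 100, 20), SpillTarget::Memory);
    // ...but a 30-byte batch spills to disk (60 + 45 > 100).
    assert_eq!(choose_spill_target(60, 100, 30), SpillTarget::Disk);
}
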
From 293a93300d30643069b97bfe4cb973f9d8032bcb Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Thu, 13 Nov 2025 17:43:28 +0200 Subject: [PATCH 33/36] Cleanup --- .../src/joins/grace_hash_join/exec.rs | 93 ++++--------------- .../src/joins/hash_join/partitioned.rs | 27 +++--- 2 files changed, 32 insertions(+), 88 deletions(-) diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs index deae45ef51e06..3e81004dadc28 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -26,7 +26,6 @@ use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; -use crate::spill::get_record_batch_memory_size; use crate::{ common::can_project, joins::utils::{ @@ -41,7 +40,7 @@ use crate::{ use crate::{ExecutionPlanProperties, SpillManager}; use std::fmt; use std::fmt::Formatter; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use std::{any::Any, vec}; use arrow::array::UInt32Array; @@ -98,19 +97,8 @@ pub struct GraceHashJoinExec { pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, - /// Dynamic filter for pushing down to the probe side - /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. - /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. - dynamic_filter: Option, -} - -#[derive(Clone)] -struct HashJoinExecDynamicFilter { - /// Dynamic filter that we'll update with the results of the build side once that is done. - filter: Arc, - /// Bounds accumulator to keep track of the min/max bounds on the join keys for each partition. - /// It is lazily initialized during execution to make sure we use the actual execution time partition counts. - bounds_accumulator: OnceLock>, + /// Indicates whether dynamic filter pushdown is enabled for this join. 
+ dynamic_filter_enabled: bool, } impl fmt::Debug for GraceHashJoinExec { @@ -128,7 +116,7 @@ impl fmt::Debug for GraceHashJoinExec { .field("column_indices", &self.column_indices) .field("null_equality", &self.null_equality) .field("cache", &self.cache) - // Explicitly exclude dynamic_filter to avoid runtime state differences in tests + // Intentionally omit dynamic_filter_enabled to keep debug output stable .finish() } } @@ -195,7 +183,7 @@ impl GraceHashJoinExec { column_indices, null_equality, cache, - dynamic_filter: None, + dynamic_filter_enabled: false, }) } @@ -559,8 +547,8 @@ impl ExecutionPlan for GraceHashJoinExec { &self.on, self.projection.as_ref(), )?, - // Keep the dynamic filter, bounds accumulator will be reset - dynamic_filter: self.dynamic_filter.clone(), + // Preserve dynamic filter enablement; state will be refreshed as needed + dynamic_filter_enabled: self.dynamic_filter_enabled, })) } @@ -578,8 +566,8 @@ impl ExecutionPlan for GraceHashJoinExec { column_indices: self.column_indices.clone(), null_equality: self.null_equality, cache: self.cache.clone(), - // Reset dynamic filter and bounds accumulator to initial state - dynamic_filter: None, + // Reset dynamic filter state to initial configuration + dynamic_filter_enabled: false, })) } @@ -598,7 +586,7 @@ impl ExecutionPlan for GraceHashJoinExec { ); } - let enable_dynamic_filter_pushdown = self.dynamic_filter.is_some(); + let enable_dynamic_filter_pushdown = self.dynamic_filter_enabled; let join_metrics = Arc::new(BuildProbeJoinMetrics::new(partition, &self.metrics)); @@ -812,9 +800,7 @@ impl ExecutionPlan for GraceHashJoinExec { // Note that we don't check PushdDownPredicate::discrimnant because even if nothing said // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating let predicate = Arc::clone(&filter.predicate); - if let Ok(dynamic_filter) = - Arc::downcast::(predicate) - { + if Arc::downcast::(predicate).is_ok() { // We successfully pushed down our self filter - we need to make a new node with the dynamic filter let new_node = Arc::new(GraceHashJoinExec { left: Arc::clone(&self.left), @@ -829,10 +815,7 @@ impl ExecutionPlan for GraceHashJoinExec { column_indices: self.column_indices.clone(), null_equality: self.null_equality, cache: self.cache.clone(), - dynamic_filter: Some(HashJoinExecDynamicFilter { - filter: dynamic_filter, - bounds_accumulator: OnceLock::new(), - }), + dynamic_filter_enabled: true, }); result = result.with_updated_node(new_node as Arc); } @@ -959,8 +942,8 @@ async fn partition_and_spill_one_side( // Prepare indexes let mut result = Vec::with_capacity(partitions.len()); - for (i, writer) in partitions.into_iter().enumerate() { - result.push(writer.finish(i)?); + for writer in partitions.into_iter() { + result.push(writer.finish()?); } // println!("spill_manager {:?}", spill_manager.metrics); Ok(result) @@ -969,8 +952,6 @@ async fn partition_and_spill_one_side( #[derive(Debug)] pub struct PartitionWriter { spill_manager: Arc, - total_rows: usize, - total_bytes: usize, chunks: Vec, } @@ -978,8 +959,6 @@ impl PartitionWriter { pub fn new(spill_manager: Arc) -> Self { Self { spill_manager, - total_rows: 0, - total_bytes: 0, chunks: vec![], } } @@ -990,18 +969,13 @@ impl PartitionWriter { request_msg: &str, ) -> Result<()> { let loc = self.spill_manager.spill_batch_auto(batch, request_msg)?; - self.total_rows += batch.num_rows(); - self.total_bytes += get_record_batch_memory_size(batch); self.chunks.push(loc); Ok(()) } - pub fn finish(self, 
part_id: usize) -> Result { + pub fn finish(self) -> Result { Ok(PartitionIndex { - part_id, chunks: self.chunks, - total_rows: self.total_rows, - total_bytes: self.total_bytes, }) } } @@ -1016,15 +990,6 @@ impl PartitionWriter { /// Partition 3 -> [ spill_chunk_3_0.arrow, spill_chunk_3_1.arrow ] #[derive(Debug, Clone)] pub struct PartitionIndex { - /// Unique partition identifier (0..N-1) - pub part_id: usize, - - /// Total number of rows in this partition - pub total_rows: usize, - - /// Total size in bytes of all batches in this partition - pub total_bytes: usize, - /// Collection of spill locations (each corresponds to one batch written /// by [`PartitionWriter::spill_batch_auto`]) pub chunks: Vec, @@ -1034,9 +999,7 @@ pub struct PartitionIndex { mod tests { use super::*; use crate::test::TestMemoryExec; - use crate::{ - common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, - }; + use crate::{common, expressions::Column, repartition::RepartitionExec}; use crate::joins::HashJoinExec; use arrow::array::{ArrayRef, Int32Array}; @@ -1078,28 +1041,8 @@ mod tests { Arc::new(TestMemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) } - fn build_table( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> Arc { - let batch = build_table_i32(a, b, c); - let schema = batch.schema(); - TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() - } - #[tokio::test] async fn simple_grace_hash_join() -> Result<()> { - // let left = build_table( - // ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), - // ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), - // ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), - // ); - // let right = build_table( - // ("a2", &vec![1, 2]), - // ("b2", &vec![1, 2]), - // ("c2", &vec![14, 15]), - // ); let left = build_large_table("a1", "b1", "c1", 2000000); let right = build_large_table("a2", "b2", "c2", 5000000); let on = vec![( @@ -1153,7 +1096,7 @@ mod tests { batches.extend(v); } let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - println!("TOTAL ROWS = {}", total_rows); + assert_eq!(total_rows, 1_000_000); // print_batches(&*batches).unwrap(); // Asserting that operator-level reservation attempting to overallocate @@ -1216,7 +1159,7 @@ mod tests { ); } let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - println!("TOTAL ROWS = {}", total_rows); + assert_eq!(total_rows, 1_000_000); // print_batches(&*batches).unwrap(); Ok(()) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index f6b3a3c2c1fd1..83e82c9c29a77 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -43,12 +43,6 @@ //! - Generates join results and handles unmatched rows for outer joins //! 
- Tracks matched rows for proper outer join semantics -use std::collections::VecDeque; -use std::mem::{self, size_of}; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::time::SystemTime; - #[cfg(feature = "hybrid_hash_join_scheduler")] use super::scheduler::{ HybridTaskScheduler, ProbeDataPoll, ProbePartitionState, ProbeStageTask, @@ -66,6 +60,10 @@ use crate::metrics::SpillMetrics; use crate::spill::in_progress_spill_file::InProgressSpillFile; use crate::spill::spill_manager::SpillManager; use crate::{RecordBatchStream, SendableRecordBatchStream}; +use std::collections::VecDeque; +use std::mem::{self, size_of}; +use std::sync::Arc; +use std::task::{Context, Poll}; use arrow::array::{Array, ArrayRef, BooleanBufferBuilder, UInt32Array, UInt64Array}; use arrow::compute::{concat_batches, take}; @@ -125,12 +123,15 @@ fn per_partition_budget_bytes(memory_threshold: usize, partitions: usize) -> usi #[inline] fn hhj_debug String>(builder: F) { - if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() { - let ts = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_millis()) - .unwrap_or(0); - println!("[hhj-debug {ts}] {}", builder()); + if log::log_enabled!( + target: "datafusion::physical_plan::hash_join::partitioned", + log::Level::Debug + ) { + log::debug!( + target: "datafusion::physical_plan::hash_join::partitioned", + "{}", + builder() + ); } } @@ -623,7 +624,7 @@ impl PartitionedHashJoinStream { return None; } - let mut per_partition_budget = + let per_partition_budget = per_partition_budget_bytes(self.memory_threshold, self.num_partitions); let rows_budget = self From 50d36b3c126d6e007e2edce9df39ad6121b00d8e Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Thu, 13 Nov 2025 19:53:15 +0200 Subject: [PATCH 34/36] Remove HHJ debug --- .../src/joins/hash_join/partitioned.rs | 130 ------------------ 1 file changed, 130 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 83e82c9c29a77..40c8571df2810 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -121,20 +121,6 @@ fn per_partition_budget_bytes(memory_threshold: usize, partitions: usize) -> usi budget.max(HYBRID_HASH_MIN_PARTITION_BYTES) } -#[inline] -fn hhj_debug String>(builder: F) { - if log::log_enabled!( - target: "datafusion::physical_plan::hash_join::partitioned", - log::Level::Debug - ) { - log::debug!( - target: "datafusion::physical_plan::hash_join::partitioned", - "{}", - builder() - ); - } -} - /// State of the partitioned hash join stream #[derive(Debug, Clone)] pub(super) enum PartitionedHashJoinState { @@ -576,7 +562,6 @@ impl PartitionedHashJoinStream { let part_id = descriptor.build_index; self.ensure_probe_scheduler_capacity(part_id); if self.probe_scheduler_inflight[part_id] { - hhj_debug(|| format!("schedule_probe_task skip part {part_id} (inflight)")); return; } let task = SchedulerTask::Probe(ProbeStageTask::new( @@ -585,7 +570,6 @@ impl PartitionedHashJoinStream { )); self.probe_task_scheduler.push_task(task); self.probe_scheduler_inflight[part_id] = true; - hhj_debug(|| format!("schedule_probe_task queued part {part_id}")); } fn finalize_spilled_partition(&mut self, part_id: usize) -> Result { @@ -1491,9 +1475,6 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] { if !self.probe_scheduler_waiting_for_stream.is_empty() { - hhj_debug(|| { - 
"transition_to_next_partition -> WaitingForProbe".to_string() - }); self.state = PartitionedHashJoinState::WaitingForProbe; return; } @@ -1849,41 +1830,17 @@ impl PartitionedHashJoinStream { let mut allow_repartition = !self.partition_pass_output_started; loop { - hhj_debug(|| { - format!( - "partition_build_side pass={} num_partitions={} allow_repartition={}", - self.partition_pass, self.num_partitions, allow_repartition - ) - }); self.reset_partition_state(); match self.try_partition_build_side(&build_data, allow_repartition)? { PartitionBuildStatus::Ready(result) => { - hhj_debug(|| { - format!( - "partition_build_side pass {} completed (num_partitions={})", - self.partition_pass, self.num_partitions - ) - }); return Ok(result); } PartitionBuildStatus::NeedMorePartitions { next_count } => { - hhj_debug(|| { - format!( - "partition_build_side requesting repartition to {} (current={})", - next_count, self.num_partitions - ) - }); if next_count <= self.num_partitions || next_count == 0 || next_count > self.max_partition_count { - hhj_debug(|| { - format!( - "repartition request invalid (max={} current={}); forcing spill", - self.max_partition_count, self.num_partitions - ) - }); allow_repartition = false; continue; } @@ -1965,12 +1922,6 @@ impl PartitionedHashJoinStream { if self.repartition_worthwhile(partition_estimate) { if let Some(next_count) = self.next_partition_count() { - hhj_debug(|| { - format!( - "partition {} exceeded budget (bytes={}) -> requesting repartition to {}", - build_index, partition_estimate, next_count - ) - }); repartition_request = Some(next_count); break; } @@ -1990,12 +1941,6 @@ impl PartitionedHashJoinStream { accum.buffered_bytes.saturating_add(batch_size); if self.repartition_worthwhile(partition_estimate) { if let Some(next_count) = self.next_partition_count() { - hhj_debug(|| { - format!( - "allocation failure for partition {} (bytes={}) -> requesting repartition to {}", - build_index, partition_estimate, next_count - ) - }); repartition_request = Some(next_count); break; } @@ -2022,11 +1967,6 @@ impl PartitionedHashJoinStream { } if let Some(next_count) = repartition_request { - hhj_debug(|| { - format!( - "try_partition_build_side early repartition request next_count={next_count}" - ) - }); return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); } @@ -2149,15 +2089,6 @@ impl PartitionedHashJoinStream { && self.repartition_worthwhile(max_spilled_bytes) { if let Some(next_count) = self.next_partition_count() { - hhj_debug(|| { - format!( - "try_partition_build_side repartition due to spill (max_spilled_bytes={} threshold={} any_spilled={}) next_count={}", - max_spilled_bytes, - self.memory_threshold, - any_spilled, - next_count - ) - }); return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); } } @@ -2328,14 +2259,10 @@ impl PartitionedHashJoinStream { fn wake_stream_waiter(&mut self) { while self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { if let Some(next_part) = self.probe_scheduler_waiting_for_stream.pop_front() { - hhj_debug(|| format!("wake_stream_waiter considering part {next_part}")); if next_part >= self.partition_pending.len() { continue; } if self.partition_pending[next_part] { - hhj_debug(|| { - format!("wake_stream_waiter skipping part {next_part} (already pending)") - }); continue; } if let Some(Some(desc)) = @@ -2345,18 +2272,12 @@ impl PartitionedHashJoinStream { let waiting_for_probe = matches!(self.state, PartitionedHashJoinState::WaitingForProbe); self.pending_partitions.push_back(desc); - 
hhj_debug(|| { - format!( - "wake_stream_waiter scheduled part {next_part}, waiting_for_probe={waiting_for_probe}" - ) - }); if waiting_for_probe { self.transition_to_next_partition(); } break; } } else { - hhj_debug(|| "wake_stream_waiter nothing to wake".to_string()); break; } } @@ -2370,12 +2291,6 @@ impl PartitionedHashJoinStream { ) -> Result { let part_id = descriptor.build_index; self.schedule_probe_task(descriptor); - hhj_debug(|| { - format!( - "poll_probe_stage_task part {part_id} start, queue_len={}", - self.probe_task_scheduler.len() - ) - }); let mut iterations = self.probe_task_scheduler.len(); while iterations > 0 { @@ -2388,9 +2303,6 @@ impl PartitionedHashJoinStream { match SchedulerTask::Probe(probe_task).poll(self, Some(cx))? { TaskPoll::ProbeReady(desc) => { let ready_part = desc.build_index; - hhj_debug(|| { - format!("probe task ready for part {ready_part}") - }); if ready_part >= self.probe_scheduler_inflight.len() { self.probe_scheduler_inflight .resize(ready_part + 1, false); @@ -2409,7 +2321,6 @@ impl PartitionedHashJoinStream { } } TaskPoll::Pending(next_task) => { - hhj_debug(|| "probe task pending, requeue".to_string()); self.probe_task_scheduler.push_task(next_task); } TaskPoll::YieldProbe { @@ -2427,9 +2338,6 @@ impl PartitionedHashJoinStream { } TaskPoll::ProbeFinished(desc) => { let finished_part = desc.build_index; - hhj_debug(|| { - format!("probe task finished for part {finished_part}") - }); if finished_part >= self.probe_scheduler_inflight.len() { self.probe_scheduler_inflight .resize(finished_part + 1, false); @@ -2449,7 +2357,6 @@ impl PartitionedHashJoinStream { } } TaskPoll::YieldFinalize(task) => { - hhj_debug(|| "finalize task yielded".to_string()); self.probe_task_scheduler.push_task(task); } TaskPoll::Ready(_) => { @@ -2460,9 +2367,6 @@ impl PartitionedHashJoinStream { } } other_task => { - hhj_debug(|| { - "non-probe task encountered in probe scheduler".to_string() - }); // Unexpected task type for probe scheduling; push back to preserve semantics. self.probe_task_scheduler.push_task(other_task); } @@ -2470,12 +2374,6 @@ impl PartitionedHashJoinStream { } let queue_len = self.probe_task_scheduler.len(); - hhj_debug(|| { - format!( - "poll_probe_stage_task part {part_id} returning Pending (queue_len={})", - queue_len - ) - }); if queue_len > 0 { cx.waker().wake_by_ref(); } @@ -2619,7 +2517,6 @@ impl PartitionedHashJoinStream { partition_state: &ProcessPartitionState, ) -> Poll>>> { let build_index = partition_state.descriptor.build_index; - hhj_debug(|| format!("process_partition enter part {build_index}")); // Guard against invalid partition ids (off-by-one protection) if build_index >= self.build_partitions.len() { @@ -2676,30 +2573,18 @@ impl PartitionedHashJoinStream { if !has_active_batch { match self.poll_probe_stage_task(cx, &partition_state.descriptor)? 
{ ProbeTaskStatus::Ready => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Ready") - }); has_active_batch = true; } ProbeTaskStatus::Pending => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Pending") - }); return Poll::Pending; } ProbeTaskStatus::WaitingForStream => { - hhj_debug(|| { - format!("process_partition part {build_index} -> WaitingForStream") - }); self.enqueue_stream_waiter(build_index); self.current_partition = None; self.transition_to_next_partition(); return Poll::Ready(Ok(StatefulStreamResult::Continue)); } ProbeTaskStatus::Finished => { - hhj_debug(|| { - format!("process_partition part {build_index} -> Finished") - }); self.release_partition_resources(build_index); self.advance_to_next_partition(); return Poll::Ready(Ok(StatefulStreamResult::Continue)); @@ -3736,7 +3621,6 @@ impl Stream for PartitionedHashJoinStream { cx: &mut Context<'_>, ) -> Poll> { loop { - hhj_debug(|| format!("poll_next state {:?}", self.state)); match self.state.clone() { PartitionedHashJoinState::PartitionBuildSide => { // Collect build side and partition it @@ -3749,7 +3633,6 @@ impl Stream for PartitionedHashJoinStream { Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), Poll::Pending => return Poll::Pending, } - hhj_debug(|| format!("restarting build pass state={:?}", self.state)); match self.partition_build_side(left_data) { Ok(StatefulStreamResult::Continue) => continue, Ok(StatefulStreamResult::Ready(Some(batch))) => { @@ -3798,24 +3681,11 @@ impl Stream for PartitionedHashJoinStream { PartitionedHashJoinState::WaitingForProbe => { if self.pending_partitions.is_empty() { if self.probe_scheduler_waiting_for_stream.is_empty() { - hhj_debug(|| { - "WaitingForProbe -> HandleUnmatchedRows (no waiters)" - .to_string() - }); self.state = PartitionedHashJoinState::HandleUnmatchedRows; continue; } - hhj_debug(|| { - "WaitingForProbe pending=0 waiters>0, parking".to_string() - }); return Poll::Pending; } else { - hhj_debug(|| { - format!( - "WaitingForProbe woke with {} pending partitions", - self.pending_partitions.len() - ) - }); self.transition_to_next_partition(); continue; } From 29ecfa2de3536d35a0b893201d540982750b34d6 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Thu, 13 Nov 2025 20:26:26 +0200 Subject: [PATCH 35/36] Cleanup helpers and commented code --- .../src/joins/hash_join/partitioned.rs | 391 ------------------ .../src/joins/hash_join/stream.rs | 204 --------- 2 files changed, 595 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 40c8571df2810..9cdb010358713 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -419,10 +419,6 @@ pub(super) struct PartitionedHashJoinStream { pub emitted_rows_per_part: Vec, /// Metrics: total candidate pairs before equality per partition pub candidate_pairs_per_part: Vec, - /// One-time flag to run shadow verification per partition - pub verify_once_per_part: Vec, - /// One-time flag for filter debug logging per partition - pub filter_debug_once_per_part: Vec, /// Pending async spill reload stream for build partitions pub pending_reload_stream: Option, /// Accumulated batches for pending reload @@ -471,8 +467,6 @@ impl PartitionedHashJoinStream { self.matched_rows_per_part = vec![0; n]; self.emitted_rows_per_part = vec![0; n]; self.candidate_pairs_per_part = vec![0; n]; - self.verify_once_per_part = 
vec![false; n]; - self.filter_debug_once_per_part = vec![false; n]; self.partition_pending = vec![false; n]; self.partition_descriptors = (0..n).map(|_| None).collect(); } @@ -502,8 +496,6 @@ impl PartitionedHashJoinStream { self.matched_rows_per_part.push(0); self.emitted_rows_per_part.push(0); self.candidate_pairs_per_part.push(0); - self.verify_once_per_part.push(false); - self.filter_debug_once_per_part.push(false); self.partition_pending.push(false); self.partition_descriptors.push(None); idx @@ -1496,7 +1488,6 @@ impl PartitionedHashJoinStream { ) -> Poll> { if let Some(ref accumulator) = self.bounds_accumulator { if self.bounds_waiter.is_none() { - // println!( // "[spill-join] partition={} reporting build bounds (rows={})", // self.partition, // build_data.batch().num_rows() @@ -1512,7 +1503,6 @@ impl PartitionedHashJoinStream { if let Some(waiter) = self.bounds_waiter.as_mut() { match waiter.get(cx) { Poll::Ready(Ok(_)) => { - // println!( // "[spill-join] partition={} build bounds reported", // self.partition // ); @@ -1520,7 +1510,6 @@ impl PartitionedHashJoinStream { } Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => { - // println!( // "[spill-join] partition={} waiting on shared bounds barrier", // self.partition // ); @@ -1589,8 +1578,6 @@ impl PartitionedHashJoinStream { self.pending_reload_batches.as_slice(), ) .map_err(DataFusionError::from)?; - - // println!( // "Reloaded spilled build partition {} for probing (rows={})", // part_id, // concatenated.num_rows() @@ -1776,8 +1763,6 @@ impl PartitionedHashJoinStream { matched_rows_per_part: vec![0; num_partitions], emitted_rows_per_part: vec![0; num_partitions], candidate_pairs_per_part: vec![0; num_partitions], - verify_once_per_part: vec![false; num_partitions], - filter_debug_once_per_part: vec![false; num_partitions], partition_pending: vec![false; num_partitions], partition_descriptors: (0..num_partitions).map(|_| None).collect(), }) @@ -2782,59 +2767,6 @@ impl PartitionedHashJoinStream { } }; // Debug: log ON expressions and output mapping once we have both sides - /* let on_left_desc = self - .on_left - .iter() - .map(|e| format!("{}", e)) - .collect::>() - .join(", "); - let on_right_desc = self - .on_right - .iter() - .map(|e| format!("{}", e)) - .collect::>() - .join(", "); - let mapping_desc = self - .column_indices - .iter() - .map(|ci| { - let side = match ci.side { - JoinSide::Left => "L", - JoinSide::Right => "R", - JoinSide::None => "M", - }; - format!("{}@{}", side, ci.index) - }) - .collect::>() - .join(", ");*/ - // println!( - // "[spill-join] ON build=[{}] | probe=[{}] | out=[{}]", - // on_left_desc, on_right_desc, mapping_desc - // ); - - // Log resolved output column names for the current mapping - /*let out_names = self - .column_indices - .iter() - .map(|ci| match ci.side { - JoinSide::Left => { - format!("L:{}", build_batch.schema().field(ci.index).name()) - } - JoinSide::Right => { - format!("R:{}", probe_batch.schema().field(ci.index).name()) - } - JoinSide::None => "M:mark".to_string(), - }) - .collect::>() - .join(", "); - // println!("[spill-join] OUT columns: {}", out_names); - - // println!( - // "[spill-join] Partition {} build hashmap empty? 
{}", - // build_index, - // build_hashmap.is_empty() - // );*/ - // Lookup against hash map with limit let (probe_indices, build_indices, next_offset) = build_hashmap .get_matched_indices_with_limit_offset( @@ -2850,7 +2782,6 @@ impl PartitionedHashJoinStream { self.candidate_pairs_per_part[build_index] = self.candidate_pairs_per_part [build_index] .saturating_add(build_indices.len()); - // println!( // "[spill-join] Candidates before equality: build_ids={}, probe_ids={}, build_rows={}, probe_rows={}", // build_indices.len(), // probe_indices.len(), @@ -2867,100 +2798,10 @@ impl PartitionedHashJoinStream { self.null_equality, )?; - // Shadow verify on INNER join with single Int64 key (first 50k rows) - /*if matches!(self.join_type, JoinType::Inner) - && build_values.len() == 1 - && probe_values.len() == 1 - && build_values[0].data_type() == &arrow::datatypes::DataType::Int64 - && probe_values[0].data_type() - == &arrow::datatypes::DataType::Int64 - && !self.verify_once_per_part[build_index] - { - use arrow::array::Int64Array; - use std::collections::HashMap; - let bcol = build_values[0] - .as_any() - .downcast_ref::() - .unwrap(); - let pcol = probe_values[0] - .as_any() - .downcast_ref::() - .unwrap(); - let mut map: HashMap = HashMap::new(); - let max_b = bcol.len().min(50_000); - for i in 0..max_b { - if bcol.is_null(i) { - continue; - } - let k = bcol.value(i); - *map.entry(k).or_insert(0) += 1; - } - /*let mut expect = 0usize; - let max_p = pcol.len().min(50_000); - for i in 0..max_p { - if pcol.is_null(i) { - continue; - } - let k = pcol.value(i); - if let Some(&c) = map.get(&k) { - expect += c; - } - } - // println!( - // "[spill-join][verify] part={} expect_pairs~{} vs actual_after_eq={}", - // build_index, - // expect, - // build_indices.len() - // );*/ - self.verify_once_per_part[build_index] = true; - }*/ - - // Debug: log key data types and sample matched pairs - /*if !build_indices.is_empty() { - /*let build_types = build_values - .iter() - .map(|a| format!("{:?}", a.data_type())) - .collect::>() - .join(", "); - let probe_types = probe_values - .iter() - .map(|a| format!("{:?}", a.data_type())) - .collect::>() - .join(", ");*/ - // println!( - // "[spill-join] Key types: build=[{}], probe=[{}], null_equality={:?}", - // build_types, probe_types, self.null_equality - // ); - let sample = build_indices.len().min(5); - let mut pairs = Vec::new(); - for i in 0..sample { - let b = build_indices.value(i) as usize; - let p = probe_indices.value(i) as usize; - // Include actual first-key values for sanity checks - let bk = &build_values[0]; - let pk = &probe_values[0]; - let bv = arrow::util::display::array_value_to_string(bk.as_ref(), b) - .unwrap_or_else(|_| "".to_string()); - let pv = arrow::util::display::array_value_to_string(pk.as_ref(), p) - .unwrap_or_else(|_| "".to_string()); - pairs.push(format!("({},{})", bv, pv)); - } - // println!( - // "[spill-join] Sample key pairs {} -> {}: {}", - // sample, - // build_indices.len(), - // pairs.join(", ") - // ); - }*/ - // Apply residual join filter if present let mut build_indices = build_indices; let mut probe_indices = probe_indices; if let Some(filter) = &self.filter { - let before_len = build_indices.len(); - // let before_build_indices = build_indices.clone(); - //let before_probe_indices = probe_indices.clone(); - let (filtered_build_indices, filtered_probe_indices) = apply_join_filter_to_indices( build_batch, @@ -2972,108 +2813,6 @@ impl PartitionedHashJoinStream { None, )?; - if !self.filter_debug_once_per_part[build_index] 
{ - /* - // println!( - // "[spill-join][filter-debug] part={} filter_before={} filter_after={}", - // build_index, - // before_len, - // filtered_build_indices.len() - // ); - - let sample = filtered_build_indices.len().min(5); - for i in 0..sample { - let build_row = filtered_build_indices.value(i) as usize; - let probe_row = filtered_probe_indices.value(i) as usize; - - let build_schema = build_batch.schema(); - let build_vals = (0..build_batch.num_columns()) - .map(|col| { - let name = build_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - build_batch.column(col).as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); - - let probe_schema = probe_batch.schema(); - let probe_vals = (0..probe_batch.num_columns()) - .map(|col| { - let name = probe_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - probe_batch.column(col).as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); - - // println!( - // "[spill-join][filter-debug] sample {} build {{{}}} probe {{{}}}", - // i, build_vals, probe_vals - // ); - } - - if filtered_build_indices.len() == 0 { - let sample_removed = before_build_indices.len().min(5); - for i in 0..sample_removed { - let build_row = before_build_indices.value(i) as usize; - let probe_row = before_probe_indices.value(i) as usize; - - let build_schema = build_batch.schema(); - let build_vals = (0..build_batch.num_columns()) - .map(|col| { - let name = build_schema.field(col).name(); - let value = - arrow::util::display::array_value_to_string( - build_batch.column(col).as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); - - let probe_schema = probe_batch.schema(); - /*let probe_vals = (0..probe_batch.num_columns()) - .map(|col| { - let name = probe_schema.field(col).name(); - let value = - arrow::util::display::array_value_to_string( - probe_batch.column(col).as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", ");*/ - - // println!( - // "[spill-join][filter-debug] removed sample {} build {{{}}} probe {{{}}}", - // i, build_vals, probe_vals - // ); - } - }*/ - - self.filter_debug_once_per_part[build_index] = true; - } - - if before_len != filtered_build_indices.len() { - // println!( - // "[spill-join][filter-debug] part={} filter removed {} rows", - // build_index, - // before_len - filtered_build_indices.len() - // ); - } - build_indices = filtered_build_indices; probe_indices = filtered_probe_indices; } @@ -3087,118 +2826,6 @@ impl PartitionedHashJoinStream { None }; - // Log sample matches even if no residual filter remains, to debug equality behavior - /*if !self.filter_debug_once_per_part[build_index] - || build_indices.len() != probe_indices.len() - { - let sample = build_indices.len().min(5); - for i in 0..sample { - let build_row = build_indices.value(i) as usize; - let probe_row = probe_indices.value(i) as usize; - - let build_schema = build_batch.schema(); - let build_vals = (0..build_batch.num_columns()) - .map(|col| { - let name = build_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - build_batch.column(col).as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); 
- - let probe_schema = probe_batch.schema(); - /* let probe_vals = (0..probe_batch.num_columns()) - .map(|col| { - let name = probe_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - probe_batch.column(col).as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", ");*/ - - // println!( - // "[spill-join][match-debug] part={} pair {} build {{{}}} probe {{{}}}", - // build_index, - // i, - // build_vals, - // probe_vals - // ); - } - - if build_indices.len() != probe_indices.len() { - // println!( - // "[spill-join][match-debug] part={} MISMATCH len build={} probe={}", - // build_index, - // build_indices.len(), - // probe_indices.len() - // ); - } - - self.filter_debug_once_per_part[build_index] = true; - }*/ - - // Debug counter: post-equality (before any alignment) - // println!( - // "[spill-join] After equality{} (pre-align): {}", - // if self.filter.is_some() { "+filter" } else { "" }, - // build_indices.len() - // ); - // Shadow verify for two-key joins (stringified) to catch type coercion issues - /*if matches!(self.join_type, JoinType::Inner) - && build_values.len() == 2 - && probe_values.len() == 2 - && !self.verify_once_per_part[build_index] - { - use std::collections::HashMap; - let mut map: HashMap = HashMap::new(); - let max_b = build_batch.num_rows().min(50_000); - for i in 0..max_b { - let k0 = arrow::util::display::array_value_to_string( - build_values[0].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let k1 = arrow::util::display::array_value_to_string( - build_values[1].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let key = format!("{}|{}", k0, k1); - *map.entry(key).or_insert(0) += 1; - } - let max_p = probe_batch.num_rows().min(50_000); - for i in 0..max_p { - let k0 = arrow::util::display::array_value_to_string( - probe_values[0].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let k1 = arrow::util::display::array_value_to_string( - probe_values[1].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let key = format!("{}|{}", k0, k1); - if let Some(&c) = map.get(&key) { - expect += c; - } - } - // println!( - // "[spill-join][verify2] part={} expect_pairs~{} vs actual_after_eq={}", - // build_index, - // expect, - // build_indices.len() - // ); - self.verify_once_per_part[build_index] = true; - }*/ // Accumulate matched rows per partition self.matched_rows_per_part[build_index] = self.matched_rows_per_part [build_index] @@ -3242,7 +2869,6 @@ impl PartitionedHashJoinStream { ); // Debug counter: after alignment (or effective no-op for other join types) - // println!("[spill-join] After alignment: {}", build_indices.len()); // Prepare ids for marking after we release borrows. Prefer the pre-alignment // matches (for join types like LeftAnti) so bitmap tracking remains accurate. 
@@ -3265,9 +2891,7 @@ impl PartitionedHashJoinStream { JoinType::RightMark | JoinType::RightSemi | JoinType::RightAnti ) { if matches!(self.join_type, JoinType::RightMark) { - // println!("[spill-join] Building output with JoinSide::Right (RightMark)"); } else { - // println!( // "[spill-join] Building output with JoinSide::Right ({:?})", // self.join_type // ); @@ -3327,7 +2951,6 @@ impl PartitionedHashJoinStream { } if result.num_rows() == 0 { - // println!( // "[spill-join] Skipping empty batch emission (partition={})", // build_index // ); @@ -3335,7 +2958,6 @@ impl PartitionedHashJoinStream { } self.join_metrics.output_batches.add(1); self.join_metrics.baseline.record_output(result.num_rows()); - // println!( // "[spill-join] Emitting batch: rows={} (partition={})", // result.num_rows(), // build_index @@ -3403,7 +3025,6 @@ impl PartitionedHashJoinStream { let empty_right_batch = RecordBatch::new_empty(Arc::clone(&self.probe_schema)); - // println!( // "Emitting unmatched rows chunk: partition={}, offset={}, size={} (total={})", // self.unmatched_partition, // self.unmatched_offset, @@ -3427,7 +3048,6 @@ impl PartitionedHashJoinStream { self.unmatched_left_indices_cache = None; self.unmatched_right_indices_cache = None; self.unmatched_offset = 0; - // println!( // "Finished emitting unmatched rows for partition {}", // self.unmatched_partition // ); @@ -3467,8 +3087,6 @@ impl PartitionedHashJoinStream { self.unmatched_partition += 1; return Poll::Ready(Ok(StatefulStreamResult::Continue)); }; - - // println!( // "Unmatched calculation for partition {} -> {} rows", // self.unmatched_partition, // left_indices.len() @@ -3506,7 +3124,6 @@ impl PartitionedHashJoinStream { if let Some(stream) = self.pending_reload_stream.as_mut() { match stream.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { - // println!( // "Reload stream yielded batch for build partition {} (rows={})", // self.unmatched_partition, // batch.num_rows() @@ -3530,8 +3147,6 @@ impl PartitionedHashJoinStream { self.pending_reload_batches.as_slice(), ) .map_err(DataFusionError::from)?; - - // println!( // "Reloaded spilled build partition {} for unmatched rows (rows={})", // self.unmatched_partition, // concatenated.num_rows() @@ -3560,7 +3175,6 @@ impl PartitionedHashJoinStream { values, reservation: new_reservation, }; - // println!( // "Prepared spilled partition {} as InMemory for unmatched emission", // self.unmatched_partition // ); @@ -3577,7 +3191,6 @@ impl PartitionedHashJoinStream { } Poll::Pending => { // Yield until more data is available from reload stream - // println!( // "Reload stream pending for build partition {} (accumulated_batches={})", // self.unmatched_partition, // self.pending_reload_batches.len() @@ -3636,7 +3249,6 @@ impl Stream for PartitionedHashJoinStream { match self.partition_build_side(left_data) { Ok(StatefulStreamResult::Continue) => continue, Ok(StatefulStreamResult::Ready(Some(batch))) => { - // println!( // "[spill-join] poll_next yielding initial batch: rows={}", // batch.num_rows() // ); @@ -3653,7 +3265,6 @@ impl Stream for PartitionedHashJoinStream { if self.num_partitions > 1 && !self.placeholder_emitted { self.placeholder_emitted = true; let empty = RecordBatch::new_empty(self.schema.clone()); - // println!( // "[spill-join] Emitting placeholder empty batch for partition {}", // build_index // ); @@ -3661,7 +3272,6 @@ impl Stream for PartitionedHashJoinStream { } match self.process_partition(cx, &partition_state) { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => 
{ - // println!( // "[spill-join] poll_next yielding process batch: rows={} (state partition={})", // batch.num_rows(), build_index // ); @@ -3693,7 +3303,6 @@ impl Stream for PartitionedHashJoinStream { PartitionedHashJoinState::HandleUnmatchedRows => { match self.handle_unmatched_rows(cx) { Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { - // println!( // "[spill-join] poll_next yielding unmatched batch: rows={}", // batch.num_rows() // ); diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index affbe2495bacf..d840f31fcb44a 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -290,49 +290,6 @@ pub(super) fn lookup_join_hashmap( null_equality, )?; - // Shadow verify for two-key INNER joins to catch coercion issues in classic path - /*if build_side_values.len() == 2 && probe_side_values.len() == 2 { - use std::collections::HashMap; - let mut map: HashMap = HashMap::new(); - let max_b = build_side_values[0].len().min(50_000); - for i in 0..max_b { - let k0 = arrow::util::display::array_value_to_string( - build_side_values[0].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let k1 = arrow::util::display::array_value_to_string( - build_side_values[1].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let key = format!("{}|{}", k0, k1); - *map.entry(key).or_insert(0) += 1; - } - let max_p = probe_side_values[0].len().min(50_000); - for i in 0..max_p { - let k0 = arrow::util::display::array_value_to_string( - probe_side_values[0].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let k1 = arrow::util::display::array_value_to_string( - probe_side_values[1].as_ref(), - i, - ) - .unwrap_or_else(|_| "".to_string()); - let key = format!("{}|{}", k0, k1); - if let Some(&c) = map.get(&key) { - expect += c; - } - } - // println!( - // "[hash-join][verify2] expect_pairs~{} vs actual_after_eq={}", - // expect, - // build_indices.len() - // ); - }*/ - Ok((build_indices, probe_indices, next_offset)) } @@ -610,20 +567,6 @@ impl HashJoinStream { index_alignment_range_end = index_alignment_range_start; } - /* if matches!( - self.join_type, - JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark - ) { - // println!( - // "[hash-join] Align {:?}: pre-adjust right_indices={}, range={}..{} (next_offset_present={})", - // self.join_type, - // right_indices.len(), - // index_alignment_range_start, - // index_alignment_range_end, - // next_offset.is_some() - // ); - }*/ - let (left_indices, right_indices) = adjust_indices_by_join_type( left_indices, right_indices, @@ -632,153 +575,6 @@ impl HashJoinStream { self.right_side_ordered, )?; - /* if matches!( - self.join_type, - JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark - ) { - // println!( - // "[hash-join] Align {:?}: post-adjust unique_right_indices={} (range={}..{})", - // self.join_type, - // right_indices.len(), - // index_alignment_range_start, - // index_alignment_range_end - // ); - } - - if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) { - // println!( - // "[hash-join] Right {:?}: probe_batch_rows={}, unique_matched_right_indices={} (range={}..{})", - // self.join_type, - // state.batch.num_rows(), - // right_indices.len(), - // index_alignment_range_start, - // index_alignment_range_end - // ); - } - - // Log some matched pairs for debugging - let build_schema = build_side.left_data.batch().schema(); - let 
probe_schema = state.batch.schema(); - - let sample = left_indices.len().min(5); - if sample > 0 { - for i in 0..sample { - let build_row = left_indices.value(i) as usize; - let probe_row = right_indices.value(i) as usize; - - /*let build_vals = (0..build_schema.fields().len()) - .map(|col| { - let name = build_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - build_side.left_data.batch().column(col).as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", "); - */ - /*let probe_vals = (0..probe_schema.fields().len()) - .map(|col| { - let name = probe_schema.field(col).name(); - let value = arrow::util::display::array_value_to_string( - state.batch.column(col).as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - format!("{}={}", name, value) - }) - .collect::>() - .join(", ");*/ - - // println!( - // "[hash-join][match-debug] partition={} pair {} build {{{}}} probe {{{}}}", - // self.partition, - // i, - // build_vals, - // probe_vals - // ); - } - } - - let build_supply_idx = - build_schema - .fields() - .iter() - .enumerate() - .find_map(|(idx, f)| { - if f.name().to_ascii_lowercase().contains("ps_supplycost") { - Some(idx) - } else { - None - } - }); - - let probe_min_idx = - probe_schema - .fields() - .iter() - .enumerate() - .find_map(|(idx, f)| { - if f.name().to_ascii_lowercase().contains("min(") - || f.name().to_ascii_lowercase().contains("min_") - { - Some(idx) - } else { - None - } - }); - - if let (Some(build_supply_idx), Some(probe_min_idx)) = - (build_supply_idx, probe_min_idx) - { - let build_array = build_side.left_data.batch().column(build_supply_idx); - let probe_array = state.batch.column(probe_min_idx); - - for j in 0..left_indices.len() { - let build_row = left_indices.value(j) as usize; - let probe_row = right_indices.value(j) as usize; - - let build_value = arrow::util::display::array_value_to_string( - build_array.as_ref(), - build_row, - ) - .unwrap_or_else(|_| "".to_string()); - let probe_value = arrow::util::display::array_value_to_string( - probe_array.as_ref(), - probe_row, - ) - .unwrap_or_else(|_| "".to_string()); - - if build_value != probe_value { - // println!( - // "[hash-join][mismatch] partition={} build_row={} ps_supplycost={} min_cost={}", - // self.partition, - // build_row, - // build_value, - // probe_value - // ); - break; - } - } - } else { - // println!( - // "[hash-join][mismatch-debug] partition={} build_fields={:?} probe_fields={:?}", - // self.partition, - // build_schema - // .fields() - // .iter() - // .map(|f| f.name().clone()) - // .collect::>(), - // probe_schema - // .fields() - // .iter() - // .map(|f| f.name().clone()) - // .collect::>() - // ); - }*/ - let result = if matches!( self.join_type, JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark From 71a51cbaeb59c8f9e05e72019c10efb52b6e3b43 Mon Sep 17 00:00:00 2001 From: Denys Tsomenko Date: Thu, 13 Nov 2025 22:08:07 +0200 Subject: [PATCH 36/36] Redundant structures cleanup --- .../src/joins/grace_hash_join/exec.rs | 1 - .../src/joins/hash_join/partitioned.rs | 36 ++---- .../src/joins/hash_join/scheduler.rs | 110 ++---------------- 3 files changed, 16 insertions(+), 131 deletions(-) diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs index 3e81004dadc28..47530688c0b98 100644 --- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs +++ 
b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -59,7 +59,6 @@ use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use crate::joins::grace_hash_join::stream::{GraceHashJoinStream, SpillFut}; -use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; use crate::metrics::SpillMetrics; use crate::spill::spill_manager::SpillLocation; use ahash::RandomState; diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs index 9cdb010358713..a4187c9d355b0 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -46,7 +46,7 @@ #[cfg(feature = "hybrid_hash_join_scheduler")] use super::scheduler::{ HybridTaskScheduler, ProbeDataPoll, ProbePartitionState, ProbeStageTask, - SchedulerConfig, SchedulerTask, TaskPoll, + SchedulerTask, TaskPoll, }; use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::join_hash_map::{JoinHashMapType, JoinHashMapU32, JoinHashMapU64}; @@ -461,8 +461,7 @@ impl PartitionedHashJoinStream { self.probe_scheduler_waiting_for_stream = VecDeque::new(); self.probe_scheduler_active_streams = 0; self.probe_scheduler_max_streams = std::cmp::max(1, std::cmp::min(4, n)); - self.probe_task_scheduler = - HybridTaskScheduler::new(SchedulerConfig::from_stream(self)); + self.probe_task_scheduler = HybridTaskScheduler::new(); } self.matched_rows_per_part = vec![0; n]; self.emitted_rows_per_part = vec![0; n]; @@ -556,10 +555,7 @@ impl PartitionedHashJoinStream { if self.probe_scheduler_inflight[part_id] { return; } - let task = SchedulerTask::Probe(ProbeStageTask::new( - SchedulerConfig::from_stream(self), - descriptor.clone(), - )); + let task = SchedulerTask::Probe(ProbeStageTask::new(descriptor.clone())); self.probe_task_scheduler.push_task(task); self.probe_scheduler_inflight[part_id] = true; } @@ -1326,8 +1322,7 @@ impl PartitionedHashJoinStream { self.probe_states.clear(); #[cfg(feature = "hybrid_hash_join_scheduler")] { - self.probe_task_scheduler = - HybridTaskScheduler::new(SchedulerConfig::from_stream(self)); + self.probe_task_scheduler = HybridTaskScheduler::new(); self.probe_scheduler_inflight.clear(); self.probe_scheduler_waiting_for_stream.clear(); self.probe_scheduler_active_streams = 0; @@ -1696,12 +1691,8 @@ impl PartitionedHashJoinStream { .min(max_partition_count); #[cfg(feature = "hybrid_hash_join_scheduler")] - let scheduler_config = SchedulerConfig { - memory_threshold, - batch_size, - max_partition_count, - max_probe_streams: std::cmp::max(1, std::cmp::min(4, num_partitions)), - }; + let scheduler_max_probe_streams = + std::cmp::max(1, std::cmp::min(4, num_partitions)); Ok(Self { partition, @@ -1726,7 +1717,7 @@ impl PartitionedHashJoinStream { .map(|_| ProbePartitionState::new()) .collect(), #[cfg(feature = "hybrid_hash_join_scheduler")] - probe_task_scheduler: HybridTaskScheduler::new(scheduler_config.clone()), + probe_task_scheduler: HybridTaskScheduler::new(), #[cfg(feature = "hybrid_hash_join_scheduler")] probe_scheduler_inflight: vec![false; num_partitions], #[cfg(feature = "hybrid_hash_join_scheduler")] @@ -1734,7 +1725,7 @@ impl PartitionedHashJoinStream { #[cfg(feature = "hybrid_hash_join_scheduler")] probe_scheduler_active_streams: 0, #[cfg(feature = "hybrid_hash_join_scheduler")] - probe_scheduler_max_streams: scheduler_config.max_probe_streams, + probe_scheduler_max_streams: 
scheduler_max_probe_streams, current_partition: None, pending_partitions: VecDeque::new(), probe_spill_manager, @@ -1774,9 +1765,7 @@ impl PartitionedHashJoinStream { &mut self, build_data: Arc, ) -> Result>> { - let config = SchedulerConfig::from_stream(self); - HybridTaskScheduler::with_build_task(config, build_data) - .run_until_build_finished(self) + HybridTaskScheduler::with_build_task(build_data).run_until_build_finished(self) } /// Partition build-side data into multiple partitions (legacy serial path) @@ -2341,14 +2330,7 @@ impl PartitionedHashJoinStream { } } } - TaskPoll::YieldFinalize(task) => { - self.probe_task_scheduler.push_task(task); - } - TaskPoll::Ready(_) => { - // Build/finalize ready events are ignored in probe context. - } TaskPoll::BuildFinished(_) => {} - TaskPoll::FinalizeFinished => {} } } other_task => { diff --git a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs index 8a979713adee2..79217e759f751 100644 --- a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs +++ b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs @@ -38,36 +38,14 @@ use datafusion_execution::disk_manager::RefCountedTempFile; use crate::joins::join_hash_map::JoinHashMapOffset; use crate::spill::in_progress_spill_file::InProgressSpillFile; -/// Configuration shared across scheduler components. -#[derive(Clone, Debug)] -pub(super) struct SchedulerConfig { - pub memory_threshold: usize, - pub batch_size: usize, - pub max_partition_count: usize, - pub max_probe_streams: usize, -} - -impl SchedulerConfig { - pub fn from_stream(stream: &PartitionedHashJoinStream) -> Self { - Self { - memory_threshold: stream.memory_threshold, - batch_size: stream.batch_size, - max_partition_count: stream.max_partition_count, - max_probe_streams: std::cmp::max(1, std::cmp::min(4, stream.num_partitions)), - } - } -} - /// Minimal scheduler capable of running build / probe / finalize tasks. pub(super) struct HybridTaskScheduler { - config: SchedulerConfig, ready_queue: VecDeque, } impl HybridTaskScheduler { - pub fn new(config: SchedulerConfig) -> Self { + pub fn new() -> Self { Self { - config, ready_queue: VecDeque::new(), } } @@ -84,52 +62,25 @@ impl HybridTaskScheduler { self.ready_queue.len() } - pub fn is_empty(&self) -> bool { - self.ready_queue.is_empty() - } - - pub fn with_build_task( - config: SchedulerConfig, - build_data: Arc, - ) -> Self { - let mut scheduler = Self::new(config.clone()); + pub fn with_build_task(build_data: Arc) -> Self { + let mut scheduler = Self::new(); scheduler .ready_queue - .push_back(SchedulerTask::Build(BuildStageTask::new( - config, build_data, - ))); + .push_back(SchedulerTask::Build(BuildStageTask::new(build_data))); scheduler } - pub fn enqueue_probe_task(&mut self, descriptor: PartitionDescriptor) { - self.ready_queue - .push_back(SchedulerTask::Probe(ProbeStageTask::new( - self.config.clone(), - descriptor, - ))); - } - - pub fn enqueue_finalize_task(&mut self, descriptor: PartitionDescriptor) { - self.ready_queue - .push_back(SchedulerTask::Finalize(FinalizeStageTask::new( - self.config.clone(), - descriptor, - ))); - } - pub fn run_until_build_finished( &mut self, stream: &mut PartitionedHashJoinStream, ) -> Result>> { while let Some(task) = self.ready_queue.pop_front() { match task.poll(stream, None)? 
{ - TaskPoll::Ready(_) => continue, TaskPoll::ProbeReady(_) => continue, TaskPoll::Pending(task) => self.ready_queue.push_back(task), TaskPoll::BuildFinished(result) => return Ok(result), TaskPoll::YieldProbe { task, .. } => self.ready_queue.push_back(task), - TaskPoll::YieldFinalize(task) => self.ready_queue.push_back(task), - TaskPoll::ProbeFinished(_) | TaskPoll::FinalizeFinished => continue, + TaskPoll::ProbeFinished(_) => continue, } } Err(internal_datafusion_err!( @@ -141,11 +92,9 @@ impl HybridTaskScheduler { pub(super) enum SchedulerTask { Build(BuildStageTask), Probe(ProbeStageTask), - Finalize(FinalizeStageTask), } pub(super) enum TaskPoll { - Ready(Option), ProbeReady(PartitionDescriptor), Pending(SchedulerTask), BuildFinished(StatefulStreamResult>), @@ -154,10 +103,7 @@ pub(super) enum TaskPoll { task: SchedulerTask, descriptor: PartitionDescriptor, }, - /// Finalize task yielded without producing output. - YieldFinalize(SchedulerTask), ProbeFinished(PartitionDescriptor), - FinalizeFinished, } impl SchedulerTask { @@ -191,19 +137,12 @@ impl SchedulerTask { ProbeTaskEvent::Finished => Ok(TaskPoll::ProbeFinished(descriptor)), } } - SchedulerTask::Finalize(task) => match task.poll(stream)? { - FinalizeTaskEvent::Pending(next_task) => { - Ok(TaskPoll::YieldFinalize(SchedulerTask::Finalize(next_task))) - } - FinalizeTaskEvent::Finished => Ok(TaskPoll::FinalizeFinished), - }, } } } /// Build stage broken into multiple cooperative steps so the scheduler can interleave it. struct BuildStageTask { - config: SchedulerConfig, build_data: Option>, step: BuildTaskStep, warmup_remaining: usize, @@ -217,9 +156,8 @@ enum BuildTaskStep { } impl BuildStageTask { - fn new(config: SchedulerConfig, build_data: Arc) -> Self { + fn new(build_data: Arc) -> Self { Self { - config, build_data: Some(build_data), step: BuildTaskStep::Init, warmup_remaining: 2, // allow a couple of yields before heavy work @@ -312,15 +250,13 @@ enum ProbeTaskState { } pub(super) struct ProbeStageTask { - _config: SchedulerConfig, descriptor: PartitionDescriptor, state: ProbeTaskState, } impl ProbeStageTask { - pub fn new(config: SchedulerConfig, descriptor: PartitionDescriptor) -> Self { + pub fn new(descriptor: PartitionDescriptor) -> Self { Self { - _config: config, descriptor, state: ProbeTaskState::Init, } @@ -364,35 +300,3 @@ enum ProbeTaskEvent { NeedStream(ProbeStageTask), Finished, } - -struct FinalizeStageTask { - config: SchedulerConfig, - descriptor: PartitionDescriptor, - yielded_once: bool, -} - -impl FinalizeStageTask { - fn new(config: SchedulerConfig, descriptor: PartitionDescriptor) -> Self { - Self { - config, - descriptor, - yielded_once: false, - } - } - - fn poll(self, _stream: &mut PartitionedHashJoinStream) -> Result { - if self.yielded_once { - Ok(FinalizeTaskEvent::Finished) - } else { - Ok(FinalizeTaskEvent::Pending(Self { - yielded_once: true, - ..self - })) - } - } -} - -enum FinalizeTaskEvent { - Pending(FinalizeStageTask), - Finished, -}
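
Note: the block below is a standalone, simplified sketch (not DataFusion code) of roughly the shape the hybrid scheduler takes after this cleanup — a plain FIFO of build/probe tasks, with no SchedulerConfig and no finalize stage. All names here (Task, TaskPoll, Scheduler, poll) are illustrative stand-ins for the real types, and the toy polling logic is assumed only for demonstration.

use std::collections::VecDeque;

// Illustrative stand-ins for the real scheduler types.
enum Task {
    Build { steps_left: u32 },
    Probe { id: usize, steps_left: u32 },
}

enum TaskPoll {
    Pending(Task),         // task yielded; requeue it
    BuildFinished(String), // build stage done; the scheduler loop returns
    ProbeFinished(usize),  // probe task for one partition completed
}

struct Scheduler {
    ready_queue: VecDeque<Task>,
}

impl Scheduler {
    fn new() -> Self {
        Self { ready_queue: VecDeque::new() }
    }

    // Analogous to with_build_task: seed the queue with the build task only.
    fn with_build_task() -> Self {
        let mut scheduler = Self::new();
        scheduler.ready_queue.push_back(Task::Build { steps_left: 2 });
        scheduler
    }

    fn push_task(&mut self, task: Task) {
        self.ready_queue.push_back(task);
    }

    // Analogous to run_until_build_finished: pop tasks until the build task
    // reports completion; tasks that merely yield go back on the queue.
    fn run_until_build_finished(&mut self) -> Result<String, &'static str> {
        while let Some(task) = self.ready_queue.pop_front() {
            match poll(task) {
                TaskPoll::Pending(next) => self.ready_queue.push_back(next),
                TaskPoll::BuildFinished(result) => return Ok(result),
                TaskPoll::ProbeFinished(_) => continue,
            }
        }
        Err("queue drained before the build stage finished")
    }
}

// Toy polling function: each task yields a few times before finishing.
fn poll(task: Task) -> TaskPoll {
    match task {
        Task::Build { steps_left: 0 } => {
            TaskPoll::BuildFinished("partitions ready".to_string())
        }
        Task::Build { steps_left } => {
            TaskPoll::Pending(Task::Build { steps_left: steps_left - 1 })
        }
        Task::Probe { id, steps_left: 0 } => TaskPoll::ProbeFinished(id),
        Task::Probe { id, steps_left } => {
            TaskPoll::Pending(Task::Probe { id, steps_left: steps_left - 1 })
        }
    }
}

fn main() {
    let mut scheduler = Scheduler::with_build_task();
    scheduler.push_task(Task::Probe { id: 0, steps_left: 1 });
    // The probe task is interleaved with the build task until the build finishes.
    println!("{}", scheduler.run_until_build_finished().unwrap());
}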