Skip to content

Commit d605d59

Browse files
committed
Building blocks for Ray DataFusionDatasource
1 parent 6badcbd commit d605d59

File tree

3 files changed

+167
-4
lines changed

3 files changed

+167
-4
lines changed

python/datafusion/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,9 @@ def count(self) -> int:
708708
"""
709709
return self.df.count()
710710

711+
def distributed_plan(self, num_shards: int):
    """Split this DataFrame's physical plan into independently executable shards.

    Args:
        num_shards: Requested number of shards; must be at least 1.

    Returns:
        The plan object produced by the underlying DataFrame's
        ``distributed_plan``, exposing ``shards``, ``schema`` and ``stats``.

    Raises:
        ValueError: If ``num_shards`` is smaller than 1.
    """
    # Fail early with a clear message instead of handing a meaningless shard
    # count to the native planner.
    if num_shards < 1:
        raise ValueError(f"num_shards must be at least 1, got {num_shards}")
    return self.df.distributed_plan(num_shards)
713+
711714
@deprecated("Use :py:func:`unnest_columns` instead.")
712715
def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame:
713716
"""See :py:func:`unnest_columns`."""

src/dataframe.rs

Lines changed: 161 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,27 @@ use arrow::util::display::{ArrayFormatter, FormatOptions};
2727
use datafusion::arrow::datatypes::Schema;
2828
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2929
use datafusion::arrow::util::pretty;
30-
use datafusion::common::UnnestOptions;
31-
use datafusion::config::{CsvOptions, TableParquetOptions};
30+
use datafusion::common::stats::Precision;
31+
use datafusion::common::{DFSchema, UnnestOptions};
32+
use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
3233
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
33-
use datafusion::execution::SendableRecordBatchStream;
34+
use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
35+
use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
36+
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
37+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
3438
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
39+
use datafusion::physical_plan::{displayable, execute_stream, ExecutionPlan};
3540
use datafusion::prelude::*;
41+
use datafusion_expr::registry::MemoryFunctionRegistry;
42+
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
43+
use datafusion_proto::protobuf::PhysicalPlanNode;
44+
use deltalake::delta_datafusion::DeltaPhysicalCodec;
45+
use prost::Message;
3646
use pyo3::exceptions::{PyTypeError, PyValueError};
3747
use pyo3::prelude::*;
3848
use pyo3::pybacked::PyBackedStr;
3949
use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
4050
use tokio::task::JoinHandle;
41-
4251
use crate::errors::py_datafusion_err;
4352
use crate::expr::sort_expr::to_sort_expressions;
4453
use crate::physical_plan::PyExecutionPlan;
@@ -49,6 +58,7 @@ use crate::{
4958
errors::DataFusionError,
5059
expr::{sort_expr::PySortExpr, PyExpr},
5160
};
61+
use crate::common::df_schema::PyDFSchema;
5262

5363
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
5464
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -650,6 +660,153 @@ impl PyDataFrame {
650660
fn count(&self, py: Python) -> PyResult<usize> {
651661
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
652662
}
663+
664+
fn distributed_plan(&self, num_shards: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
665+
let distributed_plan = wait_for_future(py, split_physical_plan(&self.df, num_shards))
666+
.map_err(PyErr::from)?;
667+
Ok(distributed_plan)
668+
}
669+
670+
}
671+
672+
#[pyclass(get_all)]
673+
#[derive(Debug, Clone)]
674+
pub struct Statistics {
675+
num_bytes: Option<usize>,
676+
num_rows: Option<usize>,
677+
}
678+
679+
impl Statistics {
680+
fn new(plan: &dyn ExecutionPlan) -> Self {
681+
fn extract(prec: Precision<usize>) -> Option<usize> {
682+
match prec {
683+
Precision::Exact(n) | Precision::Inexact(n) => Some(n),
684+
Precision::Absent => None,
685+
}
686+
}
687+
if let Ok(stats) = plan.statistics() {
688+
let num_bytes = extract(stats.total_byte_size);
689+
let num_rows = extract(stats.num_rows);
690+
Statistics { num_bytes, num_rows}
691+
} else {
692+
Statistics { num_bytes: None, num_rows: None }
693+
}
694+
}
695+
}
696+
697+
#[pyclass(get_all)]
698+
#[derive(Debug, Clone)]
699+
pub struct Shard {
700+
stats: Statistics,
701+
serialized_plan: Vec<u8>,
702+
}
703+
704+
impl Shard {
705+
pub fn try_new(plan: &Arc<dyn ExecutionPlan>) -> Result<Self, DataFusionError> {
706+
let stats = Statistics::new(plan.as_ref());
707+
let serialized_plan = PhysicalPlanNode::try_from_physical_plan(plan.clone(), Self::codec())?
708+
.encode_to_vec();
709+
Ok(Self { stats, serialized_plan })
710+
}
711+
712+
fn codec() -> &'static dyn PhysicalExtensionCodec {
713+
static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
714+
&CODEC
715+
}
716+
}
717+
#[pymethods]
718+
impl Shard {
719+
pub fn stream(&self) -> PyResult<PyRecordBatchStream> {
720+
shard_stream(self.serialized_plan.as_ref())
721+
}
722+
}
723+
724+
#[pyclass(get_all)]
725+
#[derive(Debug, Clone)]
726+
pub struct DistributedPlan {
727+
shards: Vec<Shard>,
728+
schema: PyDFSchema,
729+
stats: Statistics,
730+
}
731+
732+
async fn split_physical_plan(df: &DataFrame, num_shards: usize) -> Result<DistributedPlan, DataFusionError> {
733+
fn split(plan: &Arc<dyn ExecutionPlan>, num_shards: usize) -> Vec<Arc<dyn ExecutionPlan>> {
734+
if let Some(parquet) = plan.as_any().downcast_ref::<ParquetExec>() {
735+
let parquet = if let Ok(Some(repartitioned)) = parquet.repartitioned(num_shards, &ConfigOptions::default()) {
736+
repartitioned.as_any().downcast_ref::<ParquetExec>()
737+
.expect("repartitioned parquet is no longer parquet")
738+
.clone()
739+
} else { // repartition failed
740+
parquet.clone()
741+
};
742+
let config = parquet.base_config();
743+
config
744+
.file_groups
745+
.iter()
746+
.map(|shard| {
747+
FileScanConfig {
748+
object_store_url: config.object_store_url.clone(),
749+
file_schema: config.file_schema.clone(),
750+
file_groups: shard.iter().map(|file| vec![file.to_owned()]).collect(), // one partition per file
751+
statistics: config.statistics.clone(),
752+
projection: config.projection.clone(),
753+
projection_deep: config.projection_deep.clone(),
754+
limit: config.limit,
755+
table_partition_cols: config.table_partition_cols.clone(),
756+
output_ordering: config.output_ordering.clone(),
757+
}
758+
})
759+
.map(|config| {
760+
let mut builder = ParquetExecBuilder::new(config)
761+
.with_table_parquet_options(parquet.table_parquet_options().clone());
762+
if let Some(predicate) = parquet.predicate() {
763+
builder = builder.with_predicate(predicate.clone());
764+
}
765+
builder.build_arc()
766+
})
767+
.map(|shard| shard as Arc<dyn ExecutionPlan>)
768+
.collect()
769+
} else if plan.children().len() == 0 { // TODO: split leaf nodes other than parquet?
770+
vec![plan.clone()]
771+
} else if plan.children().len() == 1 {
772+
plan.children().into_iter()
773+
.flat_map(|child| {
774+
split(child, num_shards)
775+
.into_iter()
776+
.map(|shard| plan.clone().with_new_children(vec![shard]))
777+
})
778+
.collect::<Result<Vec<_>, _>>()
779+
.expect("Unable to split plan")
780+
} else {
781+
panic!(
782+
"Only leaf or single-child plans are supported, found {}",
783+
displayable(plan.as_ref()).one_line()
784+
)
785+
}
786+
}
787+
let plan = df.clone().create_physical_plan().await?;
788+
let shards = split(&plan, num_shards)
789+
.iter()
790+
.map(Shard::try_new)
791+
.collect::<Result<Vec<_>, _>>()?;
792+
let schema = DFSchema::try_from(plan.schema().as_ref().to_owned())?.into();
793+
let stats = Statistics::new(plan.as_ref());
794+
Ok(DistributedPlan { shards, schema, stats })
795+
}
796+
797+
#[pyfunction]
798+
pub fn shard_stream(serialized_shard_plan: &[u8]) -> PyResult<PyRecordBatchStream> {
799+
deltalake::ensure_initialized();
800+
let registry = MemoryFunctionRegistry::default();
801+
let runtime = RuntimeEnvBuilder::new().build()?;
802+
let codec = DeltaPhysicalCodec {};
803+
let node = PhysicalPlanNode::decode(serialized_shard_plan)
804+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))
805+
.map_err(PyErr::from)?;
806+
let plan = node.try_into_physical_plan(&registry, &runtime, &codec)?;
807+
println!("Shard plan: {}", displayable(plan.as_ref()).one_line());
808+
let ctx = TaskContext::default();
809+
execute_stream(plan, Arc::new(ctx)).map(PyRecordBatchStream::new).map_err(PyErr::from)
653810
}
654811

655812
/// Print DataFrame

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
115115
#[cfg(feature = "substrait")]
116116
setup_substrait_module(py, &m)?;
117117

118+
m.add_class::<dataframe::Shard>()?;
119+
m.add_class::<dataframe::DistributedPlan>()?;
120+
m.add_wrapped(wrap_pyfunction!(dataframe::shard_stream))?;
118121
Ok(())
119122
}
120123

0 commit comments

Comments
 (0)