Skip to content

Commit 84df67e

Browse files
committed
FIXUP 2 revert DistributedPlan.unmarshal
1 parent 4dbbfcf commit 84df67e

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

src/dataframe.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,20 @@ use datafusion::execution::SendableRecordBatchStream;
4141
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
4242
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
4343
use datafusion::prelude::*;
44-
44+
use datafusion_proto::physical_plan::AsExecutionPlan;
45+
use datafusion_proto::protobuf::PhysicalPlanNode;
46+
use prost::Message;
4547
use pyo3::exceptions::PyValueError;
4648
use pyo3::prelude::*;
4749
use pyo3::pybacked::PyBackedStr;
48-
use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
50+
use pyo3::types::{PyBytes, PyCapsule, PyDict, PyTuple, PyTupleMethods};
4951
use tokio::task::JoinHandle;
5052

5153
use crate::catalog::PyTable;
5254
use crate::common::df_schema::PyDFSchema;
5355
use crate::errors::{py_datafusion_err, PyDataFusionError};
5456
use crate::expr::sort_expr::to_sort_expressions;
55-
use crate::physical_plan::PyExecutionPlan;
57+
use crate::physical_plan::{codec, PyExecutionPlan};
5658
use crate::record_batch::PyRecordBatchStream;
5759
use crate::sql::logical::PyLogicalPlan;
5860
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
@@ -723,10 +725,20 @@ pub struct DistributedPlan {
723725
#[pymethods]
724726
impl DistributedPlan {
725727
#[new]
726-
fn new(physical_plan: PyExecutionPlan, min_size: usize) -> PyResult<Self> {
728+
fn unmarshal(state: Bound<PyDict>) -> PyResult<Self> {
729+
let ctx = SessionContext::new();
730+
let serialized_plan = state
731+
.get_item("plan")?
732+
.expect("missing key `plan` from state");
733+
let serialized_plan = serialized_plan.downcast::<PyBytes>()?.as_bytes();
734+
let min_size = state
735+
.get_item("min_size")?
736+
.expect("missing key `min_size` from state")
737+
.extract::<usize>()?;
738+
let plan = deserialize_plan(serialized_plan, &ctx)?;
727739
Ok(Self {
728740
min_size,
729-
physical_plan,
741+
physical_plan: PyExecutionPlan::new(plan),
730742
})
731743
}
732744

@@ -835,6 +847,20 @@ impl DistributedPlan {
835847
}
836848
}
837849

850+
fn deserialize_plan(
851+
serialized_plan: &[u8],
852+
ctx: &SessionContext,
853+
) -> PyResult<Arc<dyn ExecutionPlan>> {
854+
deltalake::ensure_initialized();
855+
let node = PhysicalPlanNode::decode(serialized_plan)
856+
.map_err(|e| DataFusionError::External(Box::new(e)))
857+
.map_err(py_datafusion_err)?;
858+
let plan = node
859+
.try_into_physical_plan(ctx, ctx.runtime_env().as_ref(), codec())
860+
.map_err(py_datafusion_err)?;
861+
Ok(plan)
862+
}
863+
838864
/// Print DataFrame
839865
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
840866
// Get string representation of record batches

0 commit comments

Comments
 (0)