Skip to content

Commit

Permalink
docs(substrait): document SubstraitProducer
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarua committed Dec 27, 2024
1 parent 8a14160 commit 4a464cb
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 13 deletions.
3 changes: 3 additions & 0 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ use substrait::proto::{
/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans.
/// It can be implemented by users to allow for custom handling of relations, expressions, etc.
///
/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully
/// customizable Substrait serde.
///
/// # Example Usage
///
/// ```
Expand Down
105 changes: 93 additions & 12 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,89 @@ use substrait::{
version,
};

/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans.
/// It can be implemented by users to allow for custom handling of relations, expressions, etc.
///
/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully
/// customizable Substrait serde.
///
/// # Example Usage
///
/// ```
/// # use std::sync::Arc;
/// # use substrait::proto::{Expression, Rel};
/// # use substrait::proto::rel::RelType;
/// # use datafusion::common::DFSchemaRef;
/// # use datafusion::error::Result;
/// # use datafusion::execution::SessionState;
/// # use datafusion::logical_expr::{Between, Extension, Projection};
/// # use datafusion_substrait::extensions::Extensions;
/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer};
///
/// struct CustomSubstraitProducer {
/// extensions: Extensions,
/// state: Arc<SessionState>,
/// }
///
/// impl SubstraitProducer for CustomSubstraitProducer {
///
/// fn register_function(&mut self, signature: String) -> u32 {
/// self.extensions.register_function(signature)
/// }
///
/// fn get_extensions(self) -> Extensions {
/// self.extensions
/// }
///
/// // You can set additional metadata on the Rels you produce
/// fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> {
/// let mut rel = from_projection(self, plan)?;
/// match rel.rel_type {
/// Some(RelType::Project(mut project)) => {
/// let mut project = project.clone();
/// // set common metadata or advanced extension
/// project.common = None;
/// project.advanced_extension = None;
/// Ok(Box::new(Rel {
/// rel_type: Some(RelType::Project(project)),
/// }))
/// }
/// rel_type => Ok(Box::new(Rel { rel_type })),
/// }
/// }
///
/// // You can tweak how you convert expressions for your target system
/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result<Expression> {
/// // add your own encoding for Between
/// todo!()
/// }
///
/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait
/// fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> {
/// // implement your own serializer into Substrait
/// todo!()
/// }
/// }
/// ```
pub trait SubstraitProducer: Send + Sync + Sized {
/// Within a Substrait plan, functions are referenced using function anchors that are stored at
/// the top level of the [Plan] within
/// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction)
/// messages.
///
/// When given a function signature, this method should return the existing anchor for it if
/// there is one. Otherwise, it should generate a new anchor.
fn register_function(&mut self, signature: String) -> u32;

/// Consume the producer to generate the [Extensions] for the Substrait plan based on the
/// functions that have been registered
fn get_extensions(self) -> Extensions;

fn register_function(&mut self, signature: String) -> u32;
// Logical Plan Methods
// There is one method per LogicalPlan to allow for easy overriding of producer behaviour.
// These methods have default implementations calling the common handler code, to allow for users
// to re-use common handling logic.

// Logical Plans
fn consume_plan(&mut self, plan: &LogicalPlan) -> Result<Box<Rel>> {
to_substrait_rel(self, plan)
}
Expand Down Expand Up @@ -175,7 +252,11 @@ pub trait SubstraitProducer: Send + Sync + Sized {
substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait")
}

// Expressions
// Expression Methods
// There is one method per DataFusion Expr to allow for easy overriding of producer behaviour
// These methods have default implementations calling the common handler code, to allow for users
// to re-use common handling logic.

fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> {
to_substrait_rex(self, expr, schema)
}
Expand Down Expand Up @@ -212,7 +293,7 @@ pub trait SubstraitProducer: Send + Sync + Sized {
from_like(self, like, schema)
}

/// Handles: Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative
/// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative
fn consume_unary_expr(
&mut self,
expr: &Expr,
Expand Down Expand Up @@ -253,7 +334,7 @@ pub trait SubstraitProducer: Send + Sync + Sized {
from_scalar_function(self, scalar_fn, schema)
}

fn consume_agg_function(
fn consume_aggregate_function(
&mut self,
agg_fn: &expr::AggregateFunction,
schema: &DFSchemaRef,
Expand Down Expand Up @@ -301,14 +382,14 @@ impl<'a> DefaultSubstraitProducer<'a> {
}

impl SubstraitProducer for DefaultSubstraitProducer<'_> {
fn get_extensions(self) -> Extensions {
self.extensions
}

fn register_function(&mut self, fn_name: String) -> u32 {
self.extensions.register_function(fn_name)
}

fn get_extensions(self) -> Extensions {
self.extensions
}

fn consume_extension(&mut self, plan: &Extension) -> Result<Box<Rel>> {
let extension_bytes = self
.state
Expand Down Expand Up @@ -1164,7 +1245,7 @@ pub fn to_substrait_agg_measure(
) -> Result<Measure> {
match expr {
Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema),
Expr::Alias(Alias{expr,..}) => {
Expr::Alias(Alias { expr, .. }) => {
to_substrait_agg_measure(producer, expr, schema)
}
_ => internal_err!(
Expand Down Expand Up @@ -2631,7 +2712,7 @@ mod test {
],
false,
)
.into(),
.into(),
false,
))?;

Expand All @@ -2640,7 +2721,7 @@ mod test {
Field::new("c0", DataType::Int32, true),
Field::new("c1", DataType::Utf8, true),
]
.into(),
.into(),
))?;

round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ async fn self_join_introduces_aliases() -> Result<()> {
\n TableScan: data projection=[b, c]",
false,
)
.await
.await
}

#[tokio::test]
Expand Down

0 comments on commit 4a464cb

Please sign in to comment.