
Commit 1c5aeb9 (parent e79ec74)

feat: add support for pyspark equivalents explode_outer and posexplode/posexplode_outer (via explode_with_index)

13 files changed: +873 -56 lines

protos/logical_plan/v1/plans.proto

Lines changed: 10 additions & 0 deletions
```diff
@@ -42,6 +42,7 @@ message LogicalPlan {
   Union union = 24;
   Limit limit = 25;
   Explode explode = 26;
+  ExplodeWithIndex explode_with_index = 34;
   DropDuplicates drop_duplicates = 27;
   Sort sort = 28;
   Unnest unnest = 29;
@@ -143,6 +144,15 @@ message Limit {
 message Explode {
   LogicalPlan input = 1;
   LogicalExpr expr = 2;
+  bool keep_null_and_empty = 3;
+}
+
+message ExplodeWithIndex {
+  LogicalPlan input = 1;
+  LogicalExpr expr = 2;
+  string index_name = 3;
+  optional string value_name = 4;
+  bool keep_null_and_empty = 5;
 }
 
 message DropDuplicates {
```
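
For orientation, a minimal sketch of populating the new message through the generated bindings (field values are illustrative; `input` and `expr` are omitted for brevity, and nothing here is prescribed by the diff itself):

```python
# Illustrative only: building an ExplodeWithIndex plan node via the
# generated protobuf bindings. Field values are made up.
from fenic._gen.protos.logical_plan.v1 import plans_pb2

plan = plans_pb2.LogicalPlan(
    explode_with_index=plans_pb2.ExplodeWithIndex(
        index_name="pos",          # name of the new position column
        value_name="tag",          # optional; defaults to the exploded column's name
        keep_null_and_empty=True,  # posexplode_outer semantics
    )
)
```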

src/fenic/_backends/local/physical_plan/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -37,6 +37,9 @@
     DropDuplicatesExec as DropDuplicatesExec,
 )
 from fenic._backends.local.physical_plan.transform import ExplodeExec as ExplodeExec
+from fenic._backends.local.physical_plan.transform import (
+    ExplodeWithIndexExec as ExplodeWithIndexExec,
+)
 from fenic._backends.local.physical_plan.transform import FilterExec as FilterExec
 from fenic._backends.local.physical_plan.transform import LimitExec as LimitExec
 from fenic._backends.local.physical_plan.transform import (
@@ -65,6 +68,7 @@
     "InMemorySourceExec",
     "DropDuplicatesExec",
     "ExplodeExec",
+    "ExplodeWithIndexExec",
     "FilterExec",
     "LimitExec",
     "ProjectionExec",
```

src/fenic/_backends/local/physical_plan/transform.py

Lines changed: 110 additions & 2 deletions
```diff
@@ -174,21 +174,25 @@ def __init__(
         child: PhysicalPlan,
         physical_expr: pl.Expr,
         col_name: str,
+        keep_null_and_empty: bool,
         cache_info: Optional[CacheInfo],
         session_state: LocalSessionState,
     ):
         super().__init__([child], cache_info=cache_info, session_state=session_state)
         self.physical_expr = physical_expr
         self.col_name = col_name
+        self.keep_null_and_empty = keep_null_and_empty
 
     def execute_node(self, child_dfs: List[pl.DataFrame]) -> pl.DataFrame:
         if len(child_dfs) != 1:
             raise ValueError("Unreachable: ExplodeExec expects 1 child")
         child_df = child_dfs[0]
         child_df = child_df.with_columns(self.physical_expr)
         exploded_df = child_df.explode(self.col_name)
-        # Optionally filter out rows where the exploded column is null.
-        return exploded_df.filter(pl.col(self.col_name).is_not_null())
+        # Filter out nulls unless keep_null_and_empty is True
+        if not self.keep_null_and_empty:
+            exploded_df = exploded_df.filter(pl.col(self.col_name).is_not_null())
+        return exploded_df
 
     def with_children(self, children: List[PhysicalPlan]) -> PhysicalPlan:
         if len(children) != 1:
@@ -197,6 +201,7 @@ def with_children(self, children: List[PhysicalPlan]) -> PhysicalPlan:
             child=children[0],
             physical_expr=self.physical_expr,
             col_name=self.col_name,
+            keep_null_and_empty=self.keep_null_and_empty,
             cache_info=self.cache_info,
             session_state=self.session_state,
         )
```
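
A note on why the single `keep_null_and_empty` flag is all `ExplodeExec` needs for explode_outer: polars' `explode` already emits one null row for an empty list and for a null list, so keeping or filtering those rows toggles between outer and inner semantics. A minimal standalone sketch in plain polars (not fenic code; column names are illustrative):

```python
import polars as pl

df = pl.DataFrame({"id": [1, 2, 3], "vals": [[10, 20], [], None]})

exploded = df.explode("vals")
# id=1 -> two rows (10, 20); id=2 (empty) and id=3 (null) -> one null row each,
# i.e. explode_outer semantics out of the box.
outer = exploded                                       # keep_null_and_empty=True
inner = exploded.filter(pl.col("vals").is_not_null())  # keep_null_and_empty=False
print(outer.shape, inner.shape)  # (4, 2) (2, 2)
```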
```diff
@@ -207,6 +212,109 @@ def build_node_lineage(
     ) -> Tuple[OperatorLineage, pl.DataFrame]:
         child_operator, child_df = self.children[0].build_node_lineage(leaf_nodes)
         exploded_df = child_df.explode(self.col_name)
+        # Filter out nulls unless keep_null_and_empty is True
+        if not self.keep_null_and_empty:
+            exploded_df = exploded_df.filter(pl.col(self.col_name).is_not_null())
+        exploded_df = exploded_df.with_columns(
+            pl.col("_uuid").alias("_backwards_uuid"),
+        )
+        exploded_df = _with_lineage_uuid(exploded_df)
+        backwards_df = exploded_df.select(["_uuid", "_backwards_uuid"])
+
+        materialize_df = exploded_df.drop("_backwards_uuid")
+
+        operator = self._build_unary_operator_lineage(
+            materialize_df=materialize_df,
+            child=(child_operator, backwards_df),
+        )
+
+        return operator, materialize_df
+
+
+class ExplodeWithIndexExec(PhysicalPlan):
+    def __init__(
+        self,
+        child: PhysicalPlan,
+        physical_expr: pl.Expr,
+        col_name: str,
+        index_name: str,
+        value_name: str,
+        keep_null_and_empty: bool,
+        cache_info: Optional[CacheInfo],
+        session_state: LocalSessionState,
+    ):
+        super().__init__([child], cache_info=cache_info, session_state=session_state)
+        self.physical_expr = physical_expr
+        self.col_name = col_name
+        self.index_name = index_name
+        self.value_name = value_name
+        self.keep_null_and_empty = keep_null_and_empty
+
+    def execute_node(self, child_dfs: List[pl.DataFrame]) -> pl.DataFrame:
+        if len(child_dfs) != 1:
+            raise ValueError("Unreachable: ExplodeWithIndexExec expects 1 child")
+        child_df = child_dfs[0]
+
+        # Add the array column if it's an expression
+        child_df = child_df.with_columns(self.physical_expr)
+
+        # Add a temporary row index to track original rows
+        child_df = child_df.with_row_index("__explode_row_id")
+
+        # Explode the array column
+        exploded_df = child_df.explode(self.col_name)
+
+        if self.keep_null_and_empty:
+            # For outer explode, we need to handle null/empty arrays specially:
+            # add the position column, but set it to null for null/empty arrays
+            exploded_df = exploded_df.with_columns(
+                pl.when(pl.col(self.col_name).is_not_null())
+                .then(pl.int_range(pl.len(), dtype=pl.Int64).over("__explode_row_id", mapping_strategy="group_to_rows"))
+                .otherwise(None)
+                .alias(self.index_name)
+            )
+        else:
+            # Filter out nulls in the exploded column for regular explode
+            exploded_df = exploded_df.filter(pl.col(self.col_name).is_not_null())
+            # Add the position column (0-based index within each original row)
+            exploded_df = exploded_df.with_columns(
+                pl.int_range(pl.len(), dtype=pl.Int64)
+                .over("__explode_row_id", mapping_strategy="group_to_rows")
+                .alias(self.index_name)
+            )
+
+        # Drop the temporary row index
+        exploded_df = exploded_df.drop("__explode_row_id")
+
+        # Rename the exploded column if needed
+        if self.value_name != self.col_name:
+            exploded_df = exploded_df.rename({self.col_name: self.value_name})
+
+        return exploded_df
+
+    def with_children(self, children: List[PhysicalPlan]) -> PhysicalPlan:
+        if len(children) != 1:
+            raise InternalError("Unreachable: ExplodeWithIndexExec expects 1 child")
+        return ExplodeWithIndexExec(
+            child=children[0],
+            physical_expr=self.physical_expr,
+            col_name=self.col_name,
+            index_name=self.index_name,
+            value_name=self.value_name,
+            keep_null_and_empty=self.keep_null_and_empty,
+            cache_info=self.cache_info,
+            session_state=self.session_state,
+        )
+
+    def build_node_lineage(
+        self,
+        leaf_nodes: List[OperatorLineage],
+    ) -> Tuple[OperatorLineage, pl.DataFrame]:
+        child_operator, child_df = self.children[0].build_node_lineage(leaf_nodes)
+        exploded_df = child_df.explode(self.col_name)
+        # Filter out nulls unless keep_null_and_empty is True
+        if not self.keep_null_and_empty:
+            exploded_df = exploded_df.filter(pl.col(self.col_name).is_not_null())
         exploded_df = exploded_df.with_columns(
             pl.col("_uuid").alias("_backwards_uuid"),
         )
```
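
The position-index trick in `ExplodeWithIndexExec.execute_node` is worth seeing in isolation: a temporary row index plus `pl.int_range(pl.len()).over(...)` with `mapping_strategy="group_to_rows"` yields a 0-based position that restarts for each source row. A minimal sketch of the same idiom in plain polars (illustrative column names; assumes a polars version with `with_row_index`):

```python
import polars as pl

df = pl.DataFrame({"id": [1, 2], "tags": [["a", "b", "c"], ["d"]]})

out = (
    df.with_row_index("__row")  # remember which source row each element came from
      .explode("tags")          # one output row per list element
      .with_columns(            # 0-based position, restarting per source row
          pl.int_range(pl.len(), dtype=pl.Int64)
            .over("__row", mapping_strategy="group_to_rows")
            .alias("pos")
      )
      .drop("__row")
)
# id=1 -> pos 0, 1, 2 for "a", "b", "c"; id=2 -> pos 0 for "d"
```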

src/fenic/_backends/local/transpiler/plan_converter.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
     DuckDBTableSinkExec,
     DuckDBTableSourceExec,
     ExplodeExec,
+    ExplodeWithIndexExec,
     FileSinkExec,
     FileSourceExec,
     FilterExec,
@@ -41,6 +42,7 @@
     DocSource,
     DropDuplicates,
     Explode,
+    ExplodeWithIndex,
     FileSink,
     FileSource,
     Filter,
@@ -335,6 +337,30 @@ def _convert_to_physical_plan(self, logical: LogicalPlan, cache_keys: set[str])
                 child_physical,
                 physical_expr,
                 target_field.name,
+                logical.keep_null_and_empty,
+                cache_info=logical.cache_info,
+                session_state=self.session_state,
+            )
+
+        elif isinstance(logical, ExplodeWithIndex):
+            child_logical = logical.children()[0]
+            physical_expr = self.expr_converter.convert(
+                logical._expr
+            )
+            child_physical = self._convert_to_physical_plan(
+                child_logical,
+                cache_keys,
+            )
+            target_field = logical._expr.to_column_field(child_logical, self.session_state)
+            # Determine the actual value name
+            actual_value_name = logical.value_name if logical.value_name is not None else target_field.name
+            return ExplodeWithIndexExec(
+                child_physical,
+                physical_expr,
+                target_field.name,
+                logical.index_name,
+                actual_value_name,
+                logical.keep_null_and_empty,
                 cache_info=logical.cache_info,
                 session_state=self.session_state,
             )
```
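
One behavioral detail in the converter: when `value_name` is omitted, the exploded values keep the source column's name rather than taking PySpark's fixed `col`. A tiny hypothetical helper (`resolve_value_name` is not part of the codebase) spelling out the same defaulting rule:

```python
from typing import Optional

def resolve_value_name(value_name: Optional[str], exploded_col: str) -> str:
    """Mirror of the defaulting rule above: keep the source column's name
    unless the caller provided an explicit value_name."""
    return value_name if value_name is not None else exploded_col

assert resolve_value_name(None, "tags") == "tags"
assert resolve_value_name("col", "tags") == "col"  # PySpark-style naming, opt-in
```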

src/fenic/_gen/protos/logical_plan/v1/plans_pb2.py

Lines changed: 46 additions & 44 deletions
(Generated file; diff not rendered.)

src/fenic/_gen/protos/logical_plan/v1/plans_pb2.pyi

Lines changed: 22 additions & 4 deletions
```diff
@@ -31,7 +31,7 @@ class CacheInfo(_message.Message):
     def __init__(self, cache_key: _Optional[str] = ...) -> None: ...
 
 class LogicalPlan(_message.Message):
-    __slots__ = ("schema", "cache_info", "in_memory_source", "file_source", "table_source", "doc_source", "projection", "filter", "join", "aggregate", "union", "limit", "explode", "drop_duplicates", "sort", "unnest", "sql", "semantic_cluster", "semantic_join", "semantic_similarity_join", "file_sink", "table_sink")
+    __slots__ = ("schema", "cache_info", "in_memory_source", "file_source", "table_source", "doc_source", "projection", "filter", "join", "aggregate", "union", "limit", "explode", "explode_with_index", "drop_duplicates", "sort", "unnest", "sql", "semantic_cluster", "semantic_join", "semantic_similarity_join", "file_sink", "table_sink")
     SCHEMA_FIELD_NUMBER: _ClassVar[int]
     CACHE_INFO_FIELD_NUMBER: _ClassVar[int]
     IN_MEMORY_SOURCE_FIELD_NUMBER: _ClassVar[int]
@@ -45,6 +45,7 @@ class LogicalPlan(_message.Message):
     UNION_FIELD_NUMBER: _ClassVar[int]
     LIMIT_FIELD_NUMBER: _ClassVar[int]
     EXPLODE_FIELD_NUMBER: _ClassVar[int]
+    EXPLODE_WITH_INDEX_FIELD_NUMBER: _ClassVar[int]
     DROP_DUPLICATES_FIELD_NUMBER: _ClassVar[int]
     SORT_FIELD_NUMBER: _ClassVar[int]
     UNNEST_FIELD_NUMBER: _ClassVar[int]
@@ -67,6 +68,7 @@ class LogicalPlan(_message.Message):
     union: Union
     limit: Limit
     explode: Explode
+    explode_with_index: ExplodeWithIndex
     drop_duplicates: DropDuplicates
     sort: Sort
     unnest: Unnest
@@ -76,7 +78,7 @@ class LogicalPlan(_message.Message):
     semantic_similarity_join: SemanticSimilarityJoin
     file_sink: FileSink
     table_sink: TableSink
-    def __init__(self, schema: _Optional[_Union[FenicSchema, _Mapping]] = ..., cache_info: _Optional[_Union[CacheInfo, _Mapping]] = ..., in_memory_source: _Optional[_Union[InMemorySource, _Mapping]] = ..., file_source: _Optional[_Union[FileSource, _Mapping]] = ..., table_source: _Optional[_Union[TableSource, _Mapping]] = ..., doc_source: _Optional[_Union[DocSource, _Mapping]] = ..., projection: _Optional[_Union[Projection, _Mapping]] = ..., filter: _Optional[_Union[Filter, _Mapping]] = ..., join: _Optional[_Union[Join, _Mapping]] = ..., aggregate: _Optional[_Union[Aggregate, _Mapping]] = ..., union: _Optional[_Union[Union, _Mapping]] = ..., limit: _Optional[_Union[Limit, _Mapping]] = ..., explode: _Optional[_Union[Explode, _Mapping]] = ..., drop_duplicates: _Optional[_Union[DropDuplicates, _Mapping]] = ..., sort: _Optional[_Union[Sort, _Mapping]] = ..., unnest: _Optional[_Union[Unnest, _Mapping]] = ..., sql: _Optional[_Union[SQL, _Mapping]] = ..., semantic_cluster: _Optional[_Union[SemanticCluster, _Mapping]] = ..., semantic_join: _Optional[_Union[SemanticJoin, _Mapping]] = ..., semantic_similarity_join: _Optional[_Union[SemanticSimilarityJoin, _Mapping]] = ..., file_sink: _Optional[_Union[FileSink, _Mapping]] = ..., table_sink: _Optional[_Union[TableSink, _Mapping]] = ...) -> None: ...
+    def __init__(self, schema: _Optional[_Union[FenicSchema, _Mapping]] = ..., cache_info: _Optional[_Union[CacheInfo, _Mapping]] = ..., in_memory_source: _Optional[_Union[InMemorySource, _Mapping]] = ..., file_source: _Optional[_Union[FileSource, _Mapping]] = ..., table_source: _Optional[_Union[TableSource, _Mapping]] = ..., doc_source: _Optional[_Union[DocSource, _Mapping]] = ..., projection: _Optional[_Union[Projection, _Mapping]] = ..., filter: _Optional[_Union[Filter, _Mapping]] = ..., join: _Optional[_Union[Join, _Mapping]] = ..., aggregate: _Optional[_Union[Aggregate, _Mapping]] = ..., union: _Optional[_Union[Union, _Mapping]] = ..., limit: _Optional[_Union[Limit, _Mapping]] = ..., explode: _Optional[_Union[Explode, _Mapping]] = ..., explode_with_index: _Optional[_Union[ExplodeWithIndex, _Mapping]] = ..., drop_duplicates: _Optional[_Union[DropDuplicates, _Mapping]] = ..., sort: _Optional[_Union[Sort, _Mapping]] = ..., unnest: _Optional[_Union[Unnest, _Mapping]] = ..., sql: _Optional[_Union[SQL, _Mapping]] = ..., semantic_cluster: _Optional[_Union[SemanticCluster, _Mapping]] = ..., semantic_join: _Optional[_Union[SemanticJoin, _Mapping]] = ..., semantic_similarity_join: _Optional[_Union[SemanticSimilarityJoin, _Mapping]] = ..., file_sink: _Optional[_Union[FileSink, _Mapping]] = ..., table_sink: _Optional[_Union[TableSink, _Mapping]] = ...) -> None: ...
 
 class InMemorySource(_message.Message):
     __slots__ = ("source",)
@@ -217,12 +219,28 @@ class Limit(_message.Message):
     def __init__(self, input: _Optional[_Union[LogicalPlan, _Mapping]] = ..., n: _Optional[int] = ...) -> None: ...
 
 class Explode(_message.Message):
-    __slots__ = ("input", "expr")
+    __slots__ = ("input", "expr", "keep_null_and_empty")
     INPUT_FIELD_NUMBER: _ClassVar[int]
     EXPR_FIELD_NUMBER: _ClassVar[int]
+    KEEP_NULL_AND_EMPTY_FIELD_NUMBER: _ClassVar[int]
     input: LogicalPlan
     expr: _expressions_pb2.LogicalExpr
-    def __init__(self, input: _Optional[_Union[LogicalPlan, _Mapping]] = ..., expr: _Optional[_Union[_expressions_pb2.LogicalExpr, _Mapping]] = ...) -> None: ...
+    keep_null_and_empty: bool
+    def __init__(self, input: _Optional[_Union[LogicalPlan, _Mapping]] = ..., expr: _Optional[_Union[_expressions_pb2.LogicalExpr, _Mapping]] = ..., keep_null_and_empty: bool = ...) -> None: ...
+
+class ExplodeWithIndex(_message.Message):
+    __slots__ = ("input", "expr", "index_name", "value_name", "keep_null_and_empty")
+    INPUT_FIELD_NUMBER: _ClassVar[int]
+    EXPR_FIELD_NUMBER: _ClassVar[int]
+    INDEX_NAME_FIELD_NUMBER: _ClassVar[int]
+    VALUE_NAME_FIELD_NUMBER: _ClassVar[int]
+    KEEP_NULL_AND_EMPTY_FIELD_NUMBER: _ClassVar[int]
+    input: LogicalPlan
+    expr: _expressions_pb2.LogicalExpr
+    index_name: str
+    value_name: str
+    keep_null_and_empty: bool
+    def __init__(self, input: _Optional[_Union[LogicalPlan, _Mapping]] = ..., expr: _Optional[_Union[_expressions_pb2.LogicalExpr, _Mapping]] = ..., index_name: _Optional[str] = ..., value_name: _Optional[str] = ..., keep_null_and_empty: bool = ...) -> None: ...
 
 class DropDuplicates(_message.Message):
     __slots__ = ("input", "subset")
```
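
For reference, these are the PySpark behaviors the new plan nodes emulate, shown in plain PySpark for comparison (this is not fenic code):

```python
# Reference semantics in PySpark itself, for comparison with the plan nodes above.
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, explode_outer, posexplode, posexplode_outer

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, ["a", "b"]), (2, []), (3, None)], ["id", "tags"])

df.select("id", explode("tags")).show()           # drops id=2 and id=3
df.select("id", explode_outer("tags")).show()     # keeps them, col=null
df.select("id", posexplode("tags")).show()        # adds 0-based pos column
df.select("id", posexplode_outer("tags")).show()  # pos and col null for id=2, id=3
```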
