@@ -399,15 +399,18 @@ def pct_change(block: blocks.Block, periods: int = 1) -> blocks.Block:
399399 window_spec = windows.unbound()
400400
401401 original_columns = block.value_columns
402- block, shift_columns = block.multi_apply_window_op(
403- original_columns, agg_ops.ShiftOp(periods), window_spec=window_spec
404- )
405402 exprs = []
406- for original_col, shifted_col in zip(original_columns, shift_columns):
407- change_expr = ops.sub_op.as_expr(original_col, shifted_col)
408- pct_change_expr = ops.div_op.as_expr(change_expr, shifted_col)
403+ for original_col in original_columns:
404+ shift_expr = agg_expressions.WindowExpression(
405+ agg_expressions.UnaryAggregation(
406+ agg_ops.ShiftOp(periods), ex.deref(original_col)
407+ ),
408+ window_spec,
409+ )
410+ change_expr = ops.sub_op.as_expr(original_col, shift_expr)
411+ pct_change_expr = ops.div_op.as_expr(change_expr, shift_expr)
409412 exprs.append(pct_change_expr)
410-    return block.project_exprs(exprs, labels=column_labels, drop=True)
413+    return block.project_block_exprs(exprs, labels=column_labels, drop=True)
411414
412415
413416def rank(
@@ -428,16 +431,11 @@ def rank(
428431
429432 columns = columns or tuple(col for col in block.value_columns)
430433 labels = [block.col_id_to_label[id] for id in columns]
431- # Step 1: Calculate row numbers for each row
432- # Identify null values to be treated according to na_option param
433- rownum_col_ids = []
434- nullity_col_ids = []
434+
435+ result_exprs = []
435436 for col in columns:
436- block, nullity_col_id = block.apply_unary_op(
437- col,
438- ops.isnull_op,
439- )
440- nullity_col_ids.append(nullity_col_id)
437+ # Step 1: Calculate row numbers for each row
438+ # Identify null values to be treated according to na_option param
441439 window_ordering = (
442440 ordering.OrderingExpression(
443441 ex.deref(col),
@@ -448,87 +446,66 @@ def rank(
448446 ),
449447 )
450448 # Count_op ignores nulls, so if na_option is "top" or "bottom", we instead count the nullity columns, where nulls have been mapped to bools
451-        block, rownum_id = block.apply_window_op(
452- col if na_option == "keep" else nullity_col_id,
453- agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op,
454- window_spec=windows.unbound(
455- grouping_keys=grouping_cols, ordering=window_ordering
456- )
449+ target_expr = (
450+            ex.deref(col) if na_option == "keep" else ops.isnull_op.as_expr(col)
451+ )
452+ window_op = agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op
453+ window_spec = (
454+            windows.unbound(grouping_keys=grouping_cols, ordering=window_ordering)
457455 if method == "dense"
458456 else windows.rows(
459457 end=0, ordering=window_ordering, grouping_keys=grouping_cols
460- ),
461- skip_reproject_unsafe=(col != columns[-1]),
458+ )
459+ )
460+ result_expr: ex.Expression = agg_expressions.WindowExpression(
461+ agg_expressions.UnaryAggregation(window_op, target_expr), window_spec
462462 )
463463 if pct:
464- block, max_id = block.apply_window_op(
465- rownum_id, agg_ops.max_op, windows.unbound(grouping_keys=grouping_cols)
464+ result_expr = ops.div_op.as_expr(
465+ result_expr,
466+ agg_expressions.WindowExpression(
467+ agg_expressions.UnaryAggregation(agg_ops.max_op, result_expr),
468+ windows.unbound(grouping_keys=grouping_cols),
469+ ),
466470 )
467- block, rownum_id = block.project_expr(ops.div_op.as_expr(rownum_id, max_id))
468-
469- rownum_col_ids.append(rownum_id)
470-
471- # Step 2: Apply aggregate to groups of like input values.
472- # This step is skipped for method=='first' or 'dense'
473- if method in ["average", "min", "max"]:
474- agg_op = {
475- "average": agg_ops.mean_op,
476- "min": agg_ops.min_op,
477- "max": agg_ops.max_op,
478- }[method]
479- post_agg_rownum_col_ids = []
480- for i in range(len(columns)):
481- block, result_id = block.apply_window_op(
482- rownum_col_ids[i],
483- agg_op,
484- window_spec=windows.unbound(grouping_keys=(columns[i], *grouping_cols)),
485- skip_reproject_unsafe=(i < (len(columns) - 1)),
471+ # Step 2: Apply aggregate to groups of like input values.
472+ # This step is skipped for method=='first' or 'dense'
473+ if method in ["average", "min", "max"]:
474+ agg_op = {
475+ "average": agg_ops.mean_op,
476+ "min": agg_ops.min_op,
477+ "max": agg_ops.max_op,
478+ }[method]
479+ result_expr = agg_expressions.WindowExpression(
480+ agg_expressions.UnaryAggregation(agg_op, result_expr),
481+ windows.unbound(grouping_keys=(col, *grouping_cols)),
486482 )
487- post_agg_rownum_col_ids.append(result_id)
488- rownum_col_ids = post_agg_rownum_col_ids
489-
490- # Pandas masks all values where any grouping column is null
491- # Note: we use pd.NA instead of float('nan')
492- if grouping_cols:
493- predicate = functools.reduce(
494- ops.and_op.as_expr,
495- [ops.notnull_op.as_expr(column_id) for column_id in grouping_cols],
496- )
497- block = block.project_exprs(
498- [
499- ops.where_op.as_expr(
500- ex.deref(col),
501- predicate,
502- ex.const(None),
503- )
504- for col in rownum_col_ids
505- ],
506- labels=labels,
507- )
508- rownum_col_ids = list(block.value_columns[-len(rownum_col_ids) :])
509-
510- # Step 3: post processing: mask null values and cast to float
511- if method in ["min", "max", "first", "dense"]:
512- # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
513- return (
514- block.select_columns(rownum_col_ids)
515- .multi_apply_unary_op(ops.AsTypeOp(pd.Float64Dtype()))
516- .with_column_labels(labels)
517- )
518- if na_option == "keep":
519- # For na_option "keep", null inputs must produce null outputs
520- exprs = []
521- for i in range(len(columns)):
522- exprs.append(
523- ops.where_op.as_expr(
524- ex.const(pd.NA, dtype=pd.Float64Dtype()),
525- nullity_col_ids[i],
526- rownum_col_ids[i],
527- )
483+ # Pandas masks all values where any grouping column is null
484+ # Note: we use pd.NA instead of float('nan')
485+ if grouping_cols:
486+ predicate = functools.reduce(
487+ ops.and_op.as_expr,
488+ [ops.notnull_op.as_expr(column_id) for column_id in grouping_cols],
489+ )
490+ result_expr = ops.where_op.as_expr(
491+ result_expr,
492+ predicate,
493+ ex.const(None),
528494 )
529- return block.project_exprs(exprs, labels=labels, drop=True)
530495
531- return block.select_columns(rownum_col_ids).with_column_labels(labels)
496+ # Step 3: post processing: mask null values and cast to float
497+ if method in ["min", "max", "first", "dense"]:
498+ # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
499+ result_expr = ops.AsTypeOp(pd.Float64Dtype()).as_expr(result_expr)
500+ elif na_option == "keep":
501+ # For na_option "keep", null inputs must produce null outputs
502+ result_expr = ops.where_op.as_expr(
503+ ex.const(pd.NA, dtype=pd.Float64Dtype()),
504+ ops.isnull_op.as_expr(col),
505+ result_expr,
506+ )
507+ result_exprs.append(result_expr)
508+ return block.project_block_exprs(result_exprs, labels=labels, drop=True)
532509
533510
534511def dropna(
0 commit comments