Updated on 2025-12-19 GMT+08:00

UDAF

A user-defined aggregate function (UDAF) aggregates multiple input rows into a single set of values. It is particularly useful for summarizing statistical features from multimodal data, such as video, audio, or images, or consolidating multimodal labels. For detailed definitions and usage constraints, refer to UDAF.

UDAF Registration

To register a UDAF, provide a Python class that implements the instance methods listed in Table 1. Four of them are mandatory (aggregate_state, accumulate, merge, and finish), while __init__ and __del__ are optional. Each method has a predefined meaning, as outlined in the table below. Beyond these requirements, you are free to add other classes or instance methods without restriction.

Table 1 Specific instance methods when registering a UDAF

| Python Class Method | Mandatory | Parameter | Description | Use Case |
|---------------------|-----------|-----------|-------------|----------|
| __init__(self, *args, **kwargs) | No | Parameters are passed through the with_arguments method. Only scalar values are allowed. | Constructor of the UDAF. | Initializes UDAF attributes (such as saving parameters, opening files, and establishing connections). |
| aggregate_state(self) | Yes | None. | Serializable intermediate aggregation state for cross-partition transfer. | Sharing state between accumulate and merge. |
| accumulate(self, *args, **kwargs) | Yes | Parameters are passed through the UDAF operator. Both scalar values and column names are allowed. | Incrementally updates the intermediate aggregation state. | Processes data during the normal scanning phase. |
| merge(self, other_state) | Yes | Intermediate state from another partition. | Merges states from other partitions into the current intermediate state. | Combines results across multiple partitions during distributed reduction. |
| finish(self) | Yes | None. | Outputs the final aggregated result. | Finalizes the aggregation process. |
| __del__(self) | No | Does not support passing parameters. | Destructor of the UDAF. | Clears UDAF resources (for example, closing files and disconnecting networks). |
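The method contract in Table 1 can be tried out locally before anything is registered. The following is a minimal sketch in plain Python, not the engine's actual execution path: MeanState and MeanUDAF are hypothetical names, aggregate_state is written as a property (as in the example later on this page), and the call order (accumulate on each partition, then merge, then finish) is an assumption based on the table's descriptions.

```python
class MeanState:
    """Hypothetical intermediate state; plain attributes keep it easy to serialize."""

    def __init__(self):
        self.total = 0.0
        self.count = 0


class MeanUDAF:
    """Hypothetical UDAF that computes the mean of a numeric column."""

    def __init__(self, round_digits=None):
        # Scalar configuration only, mirroring the __init__ row of Table 1.
        self._round = round_digits
        self._state = MeanState()

    @property
    def aggregate_state(self):
        # The state object that would be shipped between partitions.
        return self._state

    def accumulate(self, value):
        # Called once per input row; NULL values are skipped.
        if value is None:
            return
        self._state.total += float(value)
        self._state.count += 1

    def merge(self, other_state):
        # Fold the state of another partition into this one.
        self._state.total += other_state.total
        self._state.count += other_state.count

    def finish(self):
        # Produce the final result from the merged state.
        if self._state.count == 0:
            return None
        mean = self._state.total / self._state.count
        return mean if self._round is None else round(mean, self._round)


# Simulate two partitions by hand (assumed call order, not engine code).
part_a = MeanUDAF(round_digits=2)
part_b = MeanUDAF(round_digits=2)
for v in [1.0, 2.0, None]:
    part_a.accumulate(v)
for v in [3.0, 6.0]:
    part_b.accumulate(v)

part_a.merge(part_b.aggregate_state)
result = part_a.finish()
print(result)  # 3.0
```

Because merge only adds sums and counts, the result does not depend on how rows are split across partitions, which is the property a distributed reduction relies on.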

Example

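import json
import datetime as dt
from dataclasses import dataclass
from typing import Any, Dict, Optional

import ibis
import pandas as pd
import fabric_data as fabric

# create session
con = ai_lake.connect(...)

# Intermediate aggregation state shared by accumulate, merge, and finish
# (fields inferred from how the methods below use them)
@dataclass
class _AggState:
    sum_px_qty: float = 0.0
    sum_qty: float = 0.0
    rows_seen: int = 0
    rows_used: int = 0

# UDAF definition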
class PythonVWAP:
    """A generic UDAF that computes VWAP (volume-weighted average price) and total quantity."""

    def __init__(
        self,
        *,
        symbol: Optional[str] = None,
        start_date: Optional[dt.date] = None,
        end_date: Optional[dt.date] = None,
        min_qty: float = 0.0,
        round_digits: Optional[int] = None,
        json_indent: Optional[int] = None,
        **kwargs: Any,
    ):
        self._symbol = symbol
        self._start = start_date
        self._end = end_date
        self._min_qty = float(min_qty)
        self._round = round_digits
        self._json_indent = json_indent
        self._agg_state = _AggState()

    @property
    def aggregate_state(self) -> _AggState:
        return self._agg_state

    def _date_in_range(self, d: Optional[dt.date]) -> bool:
        if d is None:
            return not (self._start or self._end)
        if self._start and d < self._start:
            return False
        if self._end and d > self._end:
            return False
        return True

    def accumulate(
        self,
        quantity: Optional[float],
        price: Optional[float],
        date: Optional[dt.date] = None,
        symbol: Optional[str] = None,
    ) -> None:
        st = self._agg_state
        st.rows_seen += 1

        if quantity is None or price is None:
            return
        try:
            q = float(quantity)
            p = float(price)
        except (TypeError, ValueError):
            return
        if q <= 0 or p <= 0 or q < self._min_qty:
            return
        if self._symbol is not None and symbol is not None and symbol != self._symbol:
            return
        if not self._date_in_range(date):
            return

        st.sum_px_qty += p * q
        st.sum_qty += q
        st.rows_used += 1

    def merge(self, other_agg_state: _AggState) -> None:
        st = self._agg_state
        st.sum_px_qty += other_agg_state.sum_px_qty
        st.sum_qty += other_agg_state.sum_qty
        st.rows_seen += other_agg_state.rows_seen
        st.rows_used += other_agg_state.rows_used

    def finish(self) -> str:
        st = self._agg_state
        if st.sum_qty <= 0:
            payload: Dict[str, Optional[float | int]] = {
                "vwap": None,
                "total_quantity": 0.0,
                "rows_seen": st.rows_seen,
                "rows_used": st.rows_used,
            }
        else:
            v = st.sum_px_qty / st.sum_qty
            if self._round is not None:
                v = round(v, self._round)
            payload = {
                "vwap": v,
                "total_quantity": st.sum_qty,
                "rows_seen": st.rows_seen,
                "rows_used": st.rows_used,
            }

        # Return JSON *string*
        return json.dumps(
            payload,
            indent=self._json_indent,
            ensure_ascii=False,
            sort_keys=True,
            separators=None if self._json_indent is not None else (",", ":"),
        )

# register udaf
con.create_agg_function(
    PythonVWAP, 
    database="test", 
    signature=fabric.Signature(
        parameters=[
            fabric.Parameter(name="quantity", annotation=int),
            fabric.Parameter(name="price", annotation=float),
            fabric.Parameter(name="date", annotation=dt.date),
            fabric.Parameter(name="symbol", annotation=str),
        ],
        return_annotation=str,
    ),
    volatility=fabric.udf.VolatilityType.IMMUTABLE,
    strict=False,
)

# use udaf
ds = con.load_dataset("stock_table", database="test")
VWAP_handler = con.get_function("PythonVWAP", database="test")

from fabric_data.multimodal.function import AggregateFnBuilder
udaf_builder = AggregateFnBuilder(
    fn=VWAP_handler,
    on=[ds.quantity, ds.price, ds.date, ds.symbol],
    as_col="VWAP_column",
    num_dpus=0.5,
)

ds = ds.aggregate(udaf_builder, by=[ds.symbol])

# trigger execution
res = ds.execute()

df = res.copy()
# VWAP_column is a JSON string; expand to columns
parsed = df["VWAP_column"].apply(json.loads).apply(pd.Series)
df = pd.concat([df.drop(columns=["VWAP_column"]), parsed], axis=1)

print(df)
| symbol | vwap   | total_quantity | rows_seen | rows_used |
|--------|--------|----------------|-----------|-----------|
| string | float64| float64        | int64     | int64     |
|--------|--------|----------------|-----------|-----------|
| AAPL   | 188.42 | 10000.0        | 120       | 90        |
| MSFT   | 321.15 |  8000.0        | 110       | 85        |
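To make the rows_seen and rows_used columns easier to interpret, the arithmetic behind one group can be reproduced by hand. This is a minimal sketch in plain Python with hypothetical sample rows (they are not taken from stock_table); it applies only the NULL and non-positive checks from accumulate, leaving out the symbol, date, and min_qty filters, and builds the same compact JSON string that finish returns when json_indent is not set.

```python
import json

# Hypothetical (quantity, price) rows for a single symbol.
rows = [(100, 10.0), (300, 12.0), (None, 11.0), (-5, 9.0)]

sum_px_qty = 0.0
sum_qty = 0.0
rows_seen = 0
rows_used = 0

for quantity, price in rows:
    rows_seen += 1  # every input row is counted, even if skipped below
    # Skip rows with missing or non-positive values, as accumulate does.
    if quantity is None or price is None or quantity <= 0 or price <= 0:
        continue
    sum_px_qty += price * quantity
    sum_qty += quantity
    rows_used += 1

payload = {
    "vwap": sum_px_qty / sum_qty,  # (100*10 + 300*12) / 400
    "total_quantity": sum_qty,
    "rows_seen": rows_seen,
    "rows_used": rows_used,
}

# Compact separators and sorted keys, matching the non-indented branch of finish.
result = json.dumps(payload, sort_keys=True, separators=(",", ":"))
print(result)
# {"rows_seen":4,"rows_used":2,"total_quantity":400.0,"vwap":11.5}
```

Here two of the four rows are rejected, so rows_seen is 4 while rows_used is 2, and the VWAP is computed over the remaining 400 units only.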