Updated: 2025-12-10 GMT+08:00

UDAF

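A UDAF (user-defined aggregate function) takes multiple rows as input and returns a set of aggregated values. UDAFs suit typical multimodal data processing scenarios, such as aggregating statistical features of video, audio, and images, or summarizing multimodal labels. For the detailed definition and usage constraints of UDAFs, see UDAF.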

Registering a UDAF

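The object passed when registering a UDAF must be a Python class. Six instance methods have special meaning: `__init__`, `aggregate_state`, `accumulate`, `merge`, `finish`, and `__del__`. See the table below for details. All other classes and instance methods are unrestricted, and you can add whatever you need.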
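The sketch below shows the minimal shape of a class that provides the special methods described in Table 1. It rests on several assumptions:

- `NonNullCount`, `CountState`, the `ignore_empty_str` option, and the `_is_valid` helper are all hypothetical names.
- `aggregate_state` is written as a property, following the example on this page.
- Nothing is registered. The methods are called directly, by hand, to mimic two partitions.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CountState:
    """Serializable intermediate state: a running count."""
    count: int = 0


class NonNullCount:
    """UDAF-shaped class that counts non-null inputs (illustration only)."""

    def __init__(self, *, ignore_empty_str: bool = False, **kwargs: Any):
        # Scalar options only, as with_arguments would supply them.
        self._ignore_empty_str = ignore_empty_str
        self._state = CountState()

    @property
    def aggregate_state(self) -> CountState:
        # The state another instance receives in merge().
        return self._state

    def _is_valid(self, value: Optional[Any]) -> bool:
        # Ordinary helper method: no special meaning to the framework.
        if value is None:
            return False
        if self._ignore_empty_str and value == "":
            return False
        return True

    def accumulate(self, value: Optional[Any]) -> None:
        # Called once per input row.
        if self._is_valid(value):
            self._state.count += 1

    def merge(self, other_state: CountState) -> None:
        # Fold another partition's partial count into this one.
        self._state.count += other_state.count

    def finish(self) -> int:
        # Final result for the group.
        return self._state.count

    def __del__(self) -> None:
        # Nothing to release; a real UDAF might close files or connections.
        pass


# Mimic two partitions by hand.
a = NonNullCount(ignore_empty_str=True)
for v in ["x", None, "", "y"]:
    a.accumulate(v)
b = NonNullCount(ignore_empty_str=True)
b.accumulate("z")
a.merge(b.aggregate_state)
print(a.finish())  # 3
```

`_is_valid` shows that extra methods can live alongside the special ones without affecting how the class is treated.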

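Table 1 Special instance methods for UDAF registration

| Python class method | Required | Parameters | Meaning | Use case |
|---|---|---|---|---|
| `__init__(self, *args, **kwargs)` | | Passed in through the `with_arguments` method; only scalar values are allowed. | Constructor of the UDAF. | Initializes UDAF attributes (for example, saving parameters, opening files, or establishing connections). |
| `aggregate_state(self)` | | None. | Serializable intermediate aggregation state, used to pass state across partitions. | State shared between `accumulate` and `merge`. |
| `accumulate(self, *args, **kwargs)` | | Passed in through the UDAF operator; can be scalars or columns (Column). | Incrementally updates the intermediate aggregation state. | Normal data scan phase. |
| `merge(self, other_state)` | | Intermediate state from another partition. | Merges another partition's state into the current intermediate state. | Distributed reduce phase, where aggregation results from multiple partitions are combined. |
| `finish(self)` | | None. | Outputs the final aggregation result. | End of aggregation. |
| `__del__(self)` | | No parameters are supported. | Destructor of the UDAF. | Cleans up UDAF resources (for example, closing files or disconnecting network connections). |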
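To see how the phases in Table 1 fit together, here is a single-process simulation. It is a sketch under stated assumptions, not how the engine actually schedules work:

- `Mean`, `MeanState`, and `run_udaf_locally` are hypothetical.
- Each inner list stands for one partition.
- `copy.deepcopy` stands in for the serialization a distributed engine would perform.
- `__del__` is omitted, on the assumption that it may be left out when there is nothing to release.

```python
import copy
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Type


@dataclass
class MeanState:
    """Intermediate state for an arithmetic mean."""
    total: float = 0.0
    count: int = 0


class Mean:
    """Tiny UDAF-shaped class; it holds no resources, so __del__ is omitted."""

    def __init__(self, **kwargs: Any):
        self._state = MeanState()

    @property
    def aggregate_state(self) -> MeanState:
        return self._state

    def accumulate(self, x: float) -> None:
        self._state.total += x
        self._state.count += 1

    def merge(self, other_state: MeanState) -> None:
        self._state.total += other_state.total
        self._state.count += other_state.count

    def finish(self) -> Optional[float]:
        s = self._state
        return None if s.count == 0 else s.total / s.count


def run_udaf_locally(
    udaf_cls: Type[Any],
    partitions: Sequence[Iterable[tuple]],
    init_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Hypothetical single-process driver for the phases in Table 1."""
    init_kwargs = init_kwargs or {}
    instances: List[Any] = []
    for rows in partitions:
        inst = udaf_cls(**init_kwargs)   # __init__: one instance per partition
        for row in rows:
            inst.accumulate(*row)        # scan phase: one call per row
        instances.append(inst)
    if not instances:
        instances.append(udaf_cls(**init_kwargs))  # no input: empty state
    root = instances[0]
    for other in instances[1:]:
        # deepcopy stands in for the state leaving its partition;
        # a real engine would serialize it instead.
        root.merge(copy.deepcopy(other.aggregate_state))  # reduce phase
    return root.finish()                 # final result


result = run_udaf_locally(Mean, [[(1.0,), (2.0,)], [(3.0,)], []])
print(result)  # 2.0
```

The state keeps a running total and count rather than a running mean. That choice is what lets `merge` be a plain addition.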

Example

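import datetime as dt
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional

import ibis
import pandas as pd
import fabric_data as fabric


# Minimal state definition, inferred from how PythonVWAP uses it below
@dataclass
class _AggState:
    """Serializable intermediate state shared by accumulate and merge."""
    sum_px_qty: float = 0.0  # running sum of price * quantity
    sum_qty: float = 0.0     # running sum of quantity
    rows_seen: int = 0       # rows passed to accumulate
    rows_used: int = 0       # rows that passed all filters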

# create session
con = ai_lake.connect(...)

# UDAF definition
class PythonVWAP:
    """A generic UDAF that computes VWAP (volume-weighted average price) and total quantity."""

    def __init__(
        self,
        *,
        symbol: Optional[str] = None,
        start_date: Optional[dt.date] = None,
        end_date: Optional[dt.date] = None,
        min_qty: float = 0.0,
        round_digits: Optional[int] = None,
        json_indent: Optional[int] = None,
        **kwargs: Any,
    ):
        self._symbol = symbol
        self._start = start_date
        self._end = end_date
        self._min_qty = float(min_qty)
        self._round = round_digits
        self._json_indent = json_indent
        self._agg_state = _AggState()

    @property
    def aggregate_state(self) -> _AggState:
        return self._agg_state

    def _date_in_range(self, d: Optional[dt.date]) -> bool:
        if d is None:
            return not (self._start or self._end)
        if self._start and d < self._start:
            return False
        if self._end and d > self._end:
            return False
        return True

    def accumulate(
        self,
        quantity: Optional[float],
        price: Optional[float],
        date: Optional[dt.date] = None,
        symbol: Optional[str] = None,
    ) -> None:
        st = self._agg_state
        st.rows_seen += 1

        if quantity is None or price is None:
            return
        try:
            q = float(quantity)
            p = float(price)
        except (TypeError, ValueError):
            return
        if q <= 0 or p <= 0 or q < self._min_qty:
            return
        if self._symbol is not None and symbol is not None and symbol != self._symbol:
            return
        if not self._date_in_range(date):
            return

        st.sum_px_qty += p * q
        st.sum_qty += q
        st.rows_used += 1

    def merge(self, other_agg_state: _AggState) -> None:
        st = self._agg_state
        st.sum_px_qty += other_agg_state.sum_px_qty
        st.sum_qty += other_agg_state.sum_qty
        st.rows_seen += other_agg_state.rows_seen
        st.rows_used += other_agg_state.rows_used

    def finish(self) -> str:
        st = self._agg_state
        if st.sum_qty <= 0:
            payload: Dict[str, Optional[float | int]] = {
                "vwap": None,
                "total_quantity": 0.0,
                "rows_seen": st.rows_seen,
                "rows_used": st.rows_used,
            }
        else:
            v = st.sum_px_qty / st.sum_qty
            if self._round is not None:
                v = round(v, self._round)
            payload = {
                "vwap": v,
                "total_quantity": st.sum_qty,
                "rows_seen": st.rows_seen,
                "rows_used": st.rows_used,
            }

        # Return JSON *string*
        return json.dumps(
            payload,
            indent=self._json_indent,
            ensure_ascii=False,
            sort_keys=True,
            separators=None if self._json_indent is not None else (",", ":"),
        )

# register udaf
con.create_agg_function(
    PythonVWAP, 
    database="test", 
    signature=fabric.Signature(
        parameters=[
            fabric.Parameter(name="quantity", annotation=int),
            fabric.Parameter(name="price", annotation=float),
            fabric.Parameter(name="date", annotation=dt.date),
            fabric.Parameter(name="symbol", annotation=str),
        ],
        return_annotation=str,
    ),
    volatility=fabric.udf.VolatilityType.IMMUTABLE,
    strict=False,
)

# use udaf
ds = con.load_dataset("stock_table", database="test")
vwap_handler = con.get_function("PythonVWAP", database="test")

from fabric_data.multimodal.function import AggregateFnBuilder
udaf_builder = AggregateFnBuilder(
    fn=vwap_handler,
    on=[ds.quantity, ds.price, ds.date, ds.symbol],  # same order as the registered signature
    as_col="VWAP_column",
    num_dpus=0.5,
)

ds = ds.aggregate(udaf_builder, by=[ds.symbol])

# trigger executing
res = ds.execute()

df = res.copy()
# VWAP_column is a JSON string; expand to columns
parsed = df["VWAP_column"].apply(json.loads).apply(pd.Series)
df = pd.concat([df.drop(columns=["VWAP_column"]), parsed], axis=1)

print(df)
# Sample output (second row lists the column dtypes):
# | symbol | vwap    | total_quantity | rows_seen | rows_used |
# |--------|---------|----------------|-----------|-----------|
# | string | float64 | float64        | int64     | int64     |
# | AAPL   | 188.42  | 10000.0        | 120       | 90        |
# | MSFT   | 321.15  |  8000.0        | 110       | 85        |
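In the example, `merge` adds the other partition's `sum_px_qty` and `sum_qty` rather than combining finished VWAP values. The standalone check below shows why this matters.

- VWAP = Σ(price × quantity) / Σ(quantity), so partial sums can be added across partitions and divided once in `finish`.
- Averaging per-partition VWAPs instead gives a small partition as much weight as a large one.

The helpers `partial_sums` and `vwap_from_sums` are hypothetical, and the trade numbers are made up for illustration.

```python
from typing import List, Tuple

Trade = Tuple[float, float]  # (quantity, price)


def partial_sums(trades: List[Trade]) -> Tuple[float, float]:
    """Per-partition state, mirroring sum_px_qty and sum_qty in PythonVWAP."""
    sum_px_qty = sum(p * q for q, p in trades)
    sum_qty = sum(q for q, _ in trades)
    return sum_px_qty, sum_qty


def vwap_from_sums(sum_px_qty: float, sum_qty: float) -> float:
    return sum_px_qty / sum_qty


part_a = [(100.0, 10.0), (300.0, 12.0)]  # 400 shares
part_b = [(100.0, 20.0)]                 # 100 shares

# Single pass over all rows.
single = vwap_from_sums(*partial_sums(part_a + part_b))

# Merge: add the partial sums, then divide once (what merge + finish do).
a_num, a_den = partial_sums(part_a)
b_num, b_den = partial_sums(part_b)
merged = vwap_from_sums(a_num + b_num, a_den + b_den)

# Averaging each partition's VWAP ignores how much volume each one carries.
naive = (vwap_from_sums(a_num, a_den) + vwap_from_sums(b_num, b_den)) / 2

print(single, merged, naive)  # 13.2 13.2 15.75
```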
