UDAF
UDAF是输入多行返回一组聚合值的函数,适用于典型的多模态数据处理场景,包括视频/音频/图像的统计特征聚合,多模态标签总结等。UDAF具体的定义和使用约束请参考UDAF。
UDAF注册
注册UDAF时传入的必须是Python类。类中的`__init__`、`aggregate_state`、`accumulate`、`merge`、`finish`、`__del__`这6个实例方法具有特定含义,详情请参见下表;对其他的Python类方法和实例方法不做限制,用户可以任意添加。
| Python Class方法 | 是否必须 | 参数 | 含义 | 适用场景 |
|---|---|---|---|---|
| `__init__(self, *args, **kwargs)` | 否 | 通过with_arguments方法传入参数,只允许传入标量Scalar。 | UDAF的构造方法。 | 初始化UDAF属性(如保存参数、打开文件、建立连接等)。 |
| `aggregate_state(self)` | 是 | 无。 | 可序列化的聚合中间态,用于跨分区传递。 | accumulate与merge间共享状态。 |
| `accumulate(self, *args, **kwargs)` | 是 | 通过UDAF算子传入参数,可以传入标量Scalar和列名Column。 | 增量更新聚合中间态。 | 正常的数据扫描阶段。 |
| `merge(self, other_state)` | 是 | 来自其它分区的中间态。 | 合并其它分区状态到当前中间态。 | 分布式reduce阶段,多分区聚合结果归并。 |
| `finish(self)` | 是 | 无。 | 输出最终聚合结果。 | 聚合收尾。 |
| `__del__(self)` | 否 | 不支持传入参数。 | UDAF的析构方法。 | UDAF资源清理(如关闭文件、断开网络连接等)。 |
示例
import datetime as dt
import json
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional

import fabric_data as fabric
import ibis
import pandas as pd
# create session
# NOTE(review): `ai_lake` is never imported in this example — confirm the
# intended entry point for `connect` before copying this line.
con = ai_lake.connect(...)
# UDAF definition
@dataclass
class _AggState:
    """Intermediate aggregation state held by ``PythonVWAP``.

    NOTE(review): the original example referenced ``_AggState`` without
    defining it. This definition is inferred from how its fields are read and
    written below — confirm it meets the framework's serialization needs.
    """

    sum_px_qty: float = 0.0  # running sum of price * quantity over accepted rows
    sum_qty: float = 0.0  # running sum of quantity over accepted rows
    rows_seen: int = 0  # every row handed to accumulate()
    rows_used: int = 0  # rows that passed all filters


class PythonVWAP:
    """A generic UDAF that computes VWAP (volume-weighted average price) and total quantity.

    Rows are filtered by optional symbol, date range and minimum quantity
    before they contribute to the aggregate. The final result is returned as
    a JSON string with the keys ``vwap``, ``total_quantity``, ``rows_seen``
    and ``rows_used``.
    """

    def __init__(
        self,
        *,
        symbol: Optional[str] = None,
        start_date: Optional[dt.date] = None,
        end_date: Optional[dt.date] = None,
        min_qty: float = 0.0,
        round_digits: Optional[int] = None,
        json_indent: Optional[int] = None,
        **kwargs: Any,
    ):
        """Store the filter and formatting options and start from an empty state.

        Args:
            symbol: When set, rows carrying a different non-null symbol are skipped.
            start_date: Inclusive lower bound for the row date, if any.
            end_date: Inclusive upper bound for the row date, if any.
            min_qty: Rows whose quantity is below this value are skipped.
            round_digits: Number of digits the VWAP is rounded to; ``None``
                leaves it unrounded.
            json_indent: ``indent`` passed to ``json.dumps``; ``None`` yields
                compact output.
            **kwargs: Accepted and ignored.
        """
        self._symbol = symbol
        self._start = start_date
        self._end = end_date
        self._min_qty = float(min_qty)
        self._round = round_digits
        self._json_indent = json_indent
        self._agg_state = _AggState()

    @property
    def aggregate_state(self) -> _AggState:
        """Return the current intermediate aggregation state."""
        return self._agg_state

    def _date_in_range(self, d: Optional[dt.date]) -> bool:
        """Return True when *d* satisfies the configured date bounds.

        A missing date is accepted only when no bound is configured at all.
        """
        if d is None:
            return self._start is None and self._end is None
        if self._start is not None and d < self._start:
            return False
        if self._end is not None and d > self._end:
            return False
        return True

    def accumulate(
        self,
        quantity: Optional[float],
        price: Optional[float],
        date: Optional[dt.date] = None,
        symbol: Optional[str] = None,
    ) -> None:
        """Fold one input row into the aggregation state.

        Every call increments ``rows_seen``. The row is only added to the sums
        (and to ``rows_used``) when quantity and price are present, numeric,
        finite and positive, the quantity reaches ``min_qty``, and the symbol
        and date filters accept it.
        """
        st = self._agg_state
        st.rows_seen += 1
        if quantity is None or price is None:
            return
        try:
            q = float(quantity)
            p = float(price)
        except (TypeError, ValueError):
            return
        # NaN compares False against everything, so it would slip past the
        # positivity check below, poison the running sums and make finish()
        # emit the non-JSON token ``NaN``. Reject non-finite values up front.
        if not (math.isfinite(q) and math.isfinite(p)):
            return
        if q <= 0 or p <= 0 or q < self._min_qty:
            return
        # A row without a symbol is not rejected by the symbol filter.
        if self._symbol is not None and symbol is not None and symbol != self._symbol:
            return
        if not self._date_in_range(date):
            return
        st.sum_px_qty += p * q
        st.sum_qty += q
        st.rows_used += 1

    def merge(self, other_agg_state: _AggState) -> None:
        """Add the sums and counters of *other_agg_state* into this state."""
        st = self._agg_state
        st.sum_px_qty += other_agg_state.sum_px_qty
        st.sum_qty += other_agg_state.sum_qty
        st.rows_seen += other_agg_state.rows_seen
        st.rows_used += other_agg_state.rows_used

    def finish(self) -> str:
        """Return the final aggregate as a JSON string.

        ``vwap`` is ``null`` and ``total_quantity`` is ``0.0`` when no quantity
        was accumulated; otherwise ``vwap`` is ``sum_px_qty / sum_qty``,
        rounded when ``round_digits`` was given.
        """
        st = self._agg_state
        has_volume = st.sum_qty > 0
        vwap: Optional[float] = None
        if has_volume:
            vwap = st.sum_px_qty / st.sum_qty
            if self._round is not None:
                vwap = round(vwap, self._round)
        payload: Dict[str, Any] = {
            "vwap": vwap,
            "total_quantity": st.sum_qty if has_volume else 0.0,
            "rows_seen": st.rows_seen,
            "rows_used": st.rows_used,
        }
        # Return JSON *string*; use compact separators unless indentation was requested.
        return json.dumps(
            payload,
            indent=self._json_indent,
            ensure_ascii=False,
            sort_keys=True,
            separators=None if self._json_indent is not None else (",", ":"),
        )
# register udaf
# Parameters are declared in the same order as the arguments of
# PythonVWAP.accumulate: quantity, price, date, symbol.
vwap_parameters = [
    fabric.Parameter(name="quantity", annotation=int),
    fabric.Parameter(name="price", annotation=float),
    fabric.Parameter(name="date", annotation=dt.date),
    fabric.Parameter(name="symbol", annotation=str),
]
vwap_signature = fabric.Signature(
    parameters=vwap_parameters,
    return_annotation=str,
)
con.create_agg_function(
    PythonVWAP,
    database="test",
    signature=vwap_signature,
    volatility=fabric.udf.VolatilityType.IMMUTABLE,
    strict=False,
)
# use udaf
from fabric_data.multimodal.function import AggregateFnBuilder

ds = con.load_dataset("stock_table", database="test")
VWAP_handler = con.get_function("PythonVWAP", database="test")
udaf_builder = AggregateFnBuilder(
    fn=VWAP_handler,
    # Columns are listed in the same order as the registered parameters
    # (quantity, price, date, symbol).
    on=[ds.quantity, ds.price, ds.date, ds.symbol],
    as_col="VWAP_column",
    num_dpus=0.5,
)
ds = ds.aggregate(udaf_builder, by=[ds.symbol])
# trigger executing
res = ds.execute()
df = res.copy()
# VWAP_column is a JSON string; expand to columns
parsed = df["VWAP_column"].apply(json.loads).apply(pd.Series)
df = pd.concat([df.drop(columns=["VWAP_column"]), parsed], axis=1)
print(df)
| symbol | vwap | total_quantity | rows_seen | rows_used |
|--------|--------|----------------|-----------|-----------|
| string | float64| float64 | int64 | int64 |
|--------|--------|----------------|-----------|-----------|
| AAPL | 188.42 | 10000.0 | 120 | 90 |
| MSFT | 321.15 | 8000.0 | 110 | 85 |