论如何优雅地在 WeNet 中支持 ONNX 导出
前言
千呼万唤始出来,犹抱琵琶半遮面。WeNet 社区期待良久的完全体 ONNX 终于和大家见面了。ONNX 成立的初衷就是解决神经网络模型在不同训练、推理框架之间的转换问题,而模型转换又是模型部署过程中一个重要的环节。因此,ONNX 在模型部署中扮演着重要的角色。
对于第三方部署框架来说(如 OnnxRuntime,TensorRT,MACE,MNN,NCNN 等),只要提供一个 ONNX 转换工具便可使其真正做到与训练框架的解耦;对于第三方芯片厂商来说(如 Nvidia,Intel,地平线,寒武纪,昆仑芯等),支持 ONNX 这种标准定义的模型表示格式可以大大减轻其开发工作量。对于要部署模型的普通用户来说,可以不受模型训练框架的约束,通过 ONNX 转换自由选择合适的部署框架和硬件平台。由此可见,支持导出 ONNX 模型(流式 + 非流式)是 WeNet 生态里不可或缺的重要一环。
在上一期中(虎牙在 WeNet 中开源 ONNX 推理支持)我们介绍了如何支持 onnxruntime 的 CPU 推理,然而大家似乎对导出 ONNX 模型的过程有一种云里雾里的朦胧感。在本期,我们将详细介绍 WeNet 开发者们针对早期流式模型导出方案所提出的优化策略,以期实现对 wenet/transformer/xx.py 代码的完全零注入修改。
往期回顾
在早先的第一版流式模型导出方案中(作业帮:基于 WeNet + ONNX 的端到端语音识别方案),面临的几个主要问题是:
一:ONNX 不支持 torch.tensor 转 index 的切片操作二:不支持传入 NoneType 类型参数三:不支持 List[tensor] 形式的输入和输出四:支持 If-else 的动态变化
针对以上问题,第一版的解决方案是:
1. 引入 torch.jit.script 和 slice_helper2. 引入音频长度为 1、值为 0 的 cache 张量替代 None3. 通过 cat 将 list 中的 tensor 合并为一个输出4. 引入 torch.jit.script,将 if-else 部分独立成函数 get_next_cache_start
显然第一版的方案是存在诸多不便的,比如 1、2、4 需要对 wenet/transformer/xx.py 的代码实施注入性修改,这些注入性修改包括:
- 在 4 中使用的 torch.jit.Script 需要将要修饰的代码抽离出本体;此外,被torch.jit.Script 修饰的代码会导致模型在 Pytorch 上无法进行计算,即无法进行训练。为了保证训练的正常进行,原方案不得不引入大量的 if onnx_mode 分支以区分导出和训练。
- 在 2 中引入的长度为 1 的张量使得推理过程和训练过程不匹配。这是由于推理过程对输入和输出的 slice 需要根据 cache 长度计算,因此这里的长度 1 会导致推理和训练的 slice 不一致,需要做特殊处理。在原方案中,该特殊处理即引入大量的 onnx 分支以区分导出和训练。
if onnx_mode: # 原方案onnx_mode分支示例
导出分支
else:
训练分支优化策略
为了克服上述的诸多不便,WeNet 开发者给出了新版的解决方案:
1. opset >= 13 时 ONNX 已经可以支持上述的切片操作(链接 https://github.com/onnx/onnx/blob/main/docs/Operators.md#Slice 由 @Mddct 提供)2. 引入音频长度为 0、值为 0 的 dummy 张量替代 None3. 在我们最新版的 forward_chunk 接口中(https://github.com/wenet-e2e/wenet/pull/1002),在设计时就考虑到了该情况,输入输出已经不再是 List 形式4. 通过巧妙地构造输入的 cache 形状,保证了整个 inference 流程中遇到的所有 if-else 在给定的解码策略下永远都只走同一个分支,也即 if-else 从动态变为静态,这时 ONNX 导出时报的类似如下的 warning 便可以愉快忽略!
/home/xcsong/workspace/wenet/wenet/transformer/encoder.py:226: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if required_cache_size < 0:
对于第一点,感谢 ONNX 的持续更新。
对于第二点, Torch 和 ONNX 均可以接受维度中存在 0 的 tensor,且可以对这种 tensor 做常规的 slice 和 concat 操作(举例如下)。这样做的好处是由于 cache 长度为 0 ,训练和推理的逻辑不需要做任何修改,两者完全相同。
a = torch.ones((1, 2, 0, 4))
b = torch.ones((1, 2, 3, 4))
c = torch.cat((a, b), dim=2)
torch.equal(b, c) # True
d = torch.split(a, 2, dim=-1)
torch.equal(d[0], d[1]) # True
对于第四点,其服务的目标代码(https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L226-L231)是
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = attention_key_size
else:
next_cache_start = max(attention_key_size - required_cache_size, 0)
可见,对于 required_cache_size < 0(16 chunksize / -1 leftchunks) 和 required_cache_size == 0(16 chunksize / 0 leftchunks) 两种情况, if-else 本就不起作用,这两种情况下代码永远选择走固定的 branch。if-else 的动态变化主要出现在(16 chunksize / 4 leftchunks)的解码配置下(此时动态变化主要出现在 max 操作,该操作相当于一种变相的 if-else)。
在 16 / 4 配置下,若我们对第一个 chunk 送入长度为 0 的 cache,那么前四个 chunk 的 next_cache_start 值均为 0,对于第五个 chunk,由于 attention_key_size - required_cache_size > 0,计算得到的 next_cache_start 将不再是 0 ,此即所谓的“动态变化”。
为了保证在 16 / 4 配置下 next_cache_start 的值在推理的全过程中维持不变,我们选择对第一个 chunk 送入长度为 required_cache_size 而非长度为 0 的 cache(换句话说,从第一个 chunk 开始就送入“真实”的 cache),此时 next_cache_start == attention_key_size - required_cache_size 永远成立。
从第一个 chunk 开始就送入拥有真实长度的 cache 还有一个额外的好处:对于很多不支持动态 shape 的框架,我们永远无法实现 libtorch 中第一个 chunk 搭配长度为 0 的 cache,后面的 chunk 搭配长度为 required_cache_size 的 cache 的操作,此时上述修改便是必须的。
至此,经过我们对第一版方案的逐点优化,我们以完全零注入修改的方式实现了 ONNX 模型的导出,详细代码见 PR (https://github.com/wenet-e2e/wenet/pull/1023)。
补充介绍
除了第一版方案遇到的四个问题外,我们在实践过程中还遇到了如下两个问题:
我们的解决方案:五:最初的方案只涉及 U2 模型,对于 U2++ 模型中双向 decoder 的导出,反向 decoder 需要构造反向输入,其中涉及到 pad_sequence 这个 op,ONNX 是不支持导出的。六:超参数的存取问题,不希望单独写到一个文件中,最好可以和 ONNX 模型耦合
5. 针对该问题,由@Mddct重新设计了一个与 pad_sequence 等价且能被 ONNX 感知到 shape 变化的函数 https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/asr_model.py#L683-L721
# NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
# in `reverse_pad_list` thus we have to refine the below code.
# Issue: https://github.com/wenet-e2e/wenet/issues/1113
# Equal to:
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
max_len = torch.max(r_hyps_lens)
index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range # (beam, max_len)
# >>> seq_mask
# >>> tensor([[ True, True, True],
# >>> [ True, True, True],
# >>> [ True, False, False]])
index = (seq_len_expand - 1) - index_range # (beam, max_len)
# >>> index
# >>> tensor([[ 2, 1, 0],
# >>> [ 2, 1, 0],
# >>> [ 0, -1, -2]])
index = index * seq_mask
# >>> index
# >>> tensor([[2, 1, 0],
# >>> [2, 1, 0],
# >>> [0, 0, 0]])
r_hyps = torch.gather(r_hyps, 1, index)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
r_hyps = torch.where(seq_mask, r_hyps, self.eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, eos, eos]])
r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
# >>> r_hyps
# >>> tensor([[sos, 3, 2, 1],
# >>> [sos, 4, 8, 9],
# >>> [sos, 2, eos, eos]])
6. 通过 onnx 的 metadata 接口将超参数全部存入 onnx 模型# write
onnx_encoder = onnx.load(encoder_outpath)
for (k, v) in args.items():
meta = onnx_encoder.metadata_props.add()
meta.key, meta.value = str(k), str(v)
onnx.save(onnx_encoder, encoder_outpath)
# read
ort_session = onnxruntime.InferenceSession(encoder_outpath)
meta = ort_session.get_modelmeta()
print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))
后记
作为一种几乎公认的中间格式模型定义,ONNX 之于 WeNet 社区小伙伴们的意义已经不止于“加速”:有了这样一个万能的“桥梁”, OpenVINO / MNN / NCNN 推理还会远吗?正所谓:
蒹葭苍苍,白露为霜。所谓伊人,在水一方。溯洄从之,道阻且长。“渡桥”从之,宛在手中央。
?
长按添加语音小管家
备注:姓名-学校/公司-研究方向-城市(如:张三-清华-语音合成-北京)即可申请加入语音之家语音合成/语音增强/声纹识别/说话人日志等技术交流群每月大咖直播分享、求职内推、干货资讯汇总、与 5000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企语音开发者互动交流~
永久福利 直投简历
(简历投递):[email protected]
语音杂谈内推助力,leader直收简历
企业招聘旺季,推荐机会不容错过
觉得本篇文章不错?
扫码关注我们
语音人的技术客栈
专注于语音技术分享与干货推送
这篇关于论如何优雅地在 WeNet 中支持 ONNX 导出的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!