论如何优雅地在 WeNet 中支持 ONNX 导出

本文介绍了论如何优雅地在 WeNet 中支持 ONNX 导出,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

前言

千呼万唤始出来,犹抱琵琶半遮面。WeNet 社区期待良久的完全体 ONNX 终于和大家见面了。ONNX 成立的初衷就是解决神经网络模型在不同训练、推理框架之间的转换问题,而模型转换又是模型部署过程中一个重要的环节。因此,ONNX 在模型部署中扮演着重要的角色。

对于第三方部署框架来说(如 OnnxRuntime,TensorRT,MACE,MNN,NCNN 等),只要提供一个  ONNX  转换工具便可使其真正做到与训练框架的解耦;对于第三方芯片厂商来说(如 Nvidia,Intel,地平线,寒武纪,昆仑芯等),支持 ONNX 这种标准定义的模型表示格式可以大大减轻其开发工作量。对于要部署模型的普通用户来说,可以不受模型训练框架的约束,通过 ONNX 转换自由选择合适的部署框架和硬件平台。由此可见,支持导出  ONNX 模型(流式 + 非流式)是 WeNet 生态里不可或缺的重要一环。

在上一期中(虎牙在 WeNet 中开源 ONNX 推理支持)我们介绍了如何支持 onnxruntime 的 CPU 推理,然而大家似乎对导出 ONNX 模型的过程有一种云里雾里的朦胧感。在本期,我们将详细介绍 WeNet 开发者们针对早期流式模型导出方案所提出的优化策略,以期实现对 wenet/transformer/xx.py 代码的完全零注入修改。

往期回顾

在早先的第一版流式模型导出方案中(作业帮:基于 WeNet + ONNX 的端到端语音识别方案),面临的几个主要问题是:

一:ONNX 不支持 torch.tensor 转 index 的切片操作二:不支持传入 NoneType 类型参数三:不支持 List[tensor] 形式的输入和输出四:支持 If-else 的动态变化

针对以上问题,第一版的解决方案是:

1. 引入 torch.jit.script 和 slice_helper2. 引入音频长度为 1、值为 0 的 cache 张量替代 None3. 通过 cat 将 list 中的 tensor 合并为一个输出4. 引入 torch.jit.script,将 if-else 部分独立成函数 get_next_cache_start

显然第一版的方案是存在诸多不便的,比如 1、2、4 需要对 wenet/transformer/xx.py 的代码实施注入性修改,这些注入性修改包括:

  1. 在 4 中使用的 torch.jit.Script 需要将要修饰的代码抽离出本体;此外,被torch.jit.Script 修饰的代码会导致模型在 Pytorch 上无法进行计算,即无法进行训练。为了保证训练的正常进行,原方案不得不引入大量的 if  onnx_mode 分支以区分导出和训练。
  2. 在 2 中引入的长度为 1 的张量使得推理过程和训练过程不匹配。这是由于推理过程对输入和输出的 slice 需要根据 cache 长度计算,因此这里的长度 1 会导致推理和训练的 slice 不一致,需要做特殊处理。在原方案中,该特殊处理即引入大量的 onnx 分支以区分导出和训练。
if onnx_mode:  # 原方案onnx_mode分支示例
    导出分支
else:
    训练分支

优化策略

为了克服上述的诸多不便,WeNet 开发者给出了新版的解决方案:

1. opset >= 13 时 ONNX 已经可以支持上述的切片操作(链接 https://github.com/onnx/onnx/blob/main/docs/Operators.md#Slice 由 @Mddct 提供)2. 引入音频长度为 0、值为 0 的 dummy 张量替代 None3. 在我们最新版的 forward_chunk 接口中(https://github.com/wenet-e2e/wenet/pull/1002),在设计时就考虑到了该情况,输入输出已经不再是 List 形式4. 通过巧妙地构造输入的 cache 形状,保证了整个 inference 流程中遇到的所有 if-else 在给定的解码策略下永远都只走同一个分支,也即 if-else 从动态变为静态,这时 ONNX 导出时报的类似如下的 warning 便可以愉快忽略!
/home/xcsong/workspace/wenet/wenet/transformer/encoder.py:226: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if required_cache_size < 0:

对于第一点,感谢 ONNX 的持续更新。

对于第二点, Torch 和 ONNX 均可以接受维度中存在 0 的 tensor,且可以对这种 tensor 做常规的 slice 和 concat 操作(举例如下)。这样做的好处是由于 cache 长度为 0 ,训练和推理的逻辑不需要做任何修改,两者完全相同。

a = torch.ones((1, 2, 0, 4))
b = torch.ones((1, 2, 3, 4))
c = torch.cat((a, b), dim=2)
torch.equal(b, c)        # True
d = torch.split(a, 2, dim=-1)
torch.equal(d[0], d[1])  # True

对于第四点,其服务的目标代码(https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L226-L231)是

if required_cache_size < 0:
    next_cache_start = 0
elif required_cache_size == 0:
    next_cache_start = attention_key_size
else:
    next_cache_start = max(attention_key_size - required_cache_size, 0)

可见,对于 required_cache_size < 0(16 chunksize / -1 leftchunks) 和 required_cache_size == 0(16 chunksize / 0 leftchunks) 两种情况, if-else 本就不起作用,这两种情况下代码永远选择走固定的 branch。if-else 的动态变化主要出现在(16 chunksize / 4 leftchunks)的解码配置下(此时动态变化主要出现在 max 操作,该操作相当于一种变相的 if-else)。

在 16 / 4 配置下,若我们对第一个 chunk 送入长度为 0 的 cache,那么前四个 chunk 的 next_cache_start 值均为 0,对于第五个 chunk,由于 attention_key_size - required_cache_size > 0,计算得到的 next_cache_start 将不再是 0 ,此即所谓的“动态变化”。

为了保证在 16 / 4 配置下 next_cache_start 的值在推理的全过程中维持不变,我们选择对第一个 chunk 送入长度为 required_cache_size 而非长度为 0 的 cache(换句话说,从第一个 chunk 开始就送入“真实”的 cache),此时  next_cache_start == attention_key_size - required_cache_size 永远成立。

从第一个 chunk 开始就送入拥有真实长度的 cache 还有一个额外的好处:对于很多不支持动态 shape 的框架,我们永远无法实现 libtorch 中第一个 chunk 搭配长度为 0 的 cache,后面的 chunk 搭配长度为  required_cache_size 的 cache 的操作,此时上述修改便是必须的。

至此,经过我们对第一版方案的逐点优化,我们以完全零注入修改的方式实现了 ONNX 模型的导出,详细代码见 PR (https://github.com/wenet-e2e/wenet/pull/1023)。

补充介绍

除了第一版方案遇到的四个问题外,我们在实践过程中还遇到了如下两个问题:

五:最初的方案只涉及 U2 模型,对于 U2++ 模型中双向 decoder 的导出,反向 decoder 需要构造反向输入,其中涉及到 pad_sequence 这个 op,ONNX 是不支持导出的。 六:超参数的存取问题,不希望单独写到一个文件中,最好可以和 ONNX 模型耦合
我们的解决方案:
5. 针对该问题,由@Mddct重新设计了一个与 pad_sequence 等价且能被 ONNX 感知到 shape 变化的函数 https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/asr_model.py#L683-L721
# NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
#   in `reverse_pad_list` thus we have to refine the below code.
#   Issue: https://github.com/wenet-e2e/wenet/issues/1113
# Equal to:
#   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
#   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
max_len = torch.max(r_hyps_lens)
index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
seq_len_expand = r_hyps_lens.unsqueeze(1)
seq_mask = seq_len_expand > index_range  # (beam, max_len)
#   >>> seq_mask
#   >>> tensor([[ True,  True,  True],
#   >>>         [ True,  True,  True],
#   >>>         [ True, False, False]])
index = (seq_len_expand - 1) - index_range  # (beam, max_len)
#   >>> index
#   >>> tensor([[ 2,  1,  0],
#   >>>         [ 2,  1,  0],
#   >>>         [ 0, -1, -2]])
index = index * seq_mask
#   >>> index
#   >>> tensor([[2, 1, 0],
#   >>>         [2, 1, 0],
#   >>>         [0, 0, 0]])
r_hyps = torch.gather(r_hyps, 1, index)
#   >>> r_hyps
#   >>> tensor([[3, 2, 1],
#   >>>         [4, 8, 9],
#   >>>         [2, 2, 2]])
r_hyps = torch.where(seq_mask, r_hyps, self.eos)
#   >>> r_hyps
#   >>> tensor([[3, 2, 1],
#   >>>         [4, 8, 9],
#   >>>         [2, eos, eos]])
r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
#   >>> r_hyps
#   >>> tensor([[sos, 3, 2, 1],
#   >>>         [sos, 4, 8, 9],
#   >>>         [sos, 2, eos, eos]])
6. 通过 onnx 的 metadata 接口将超参数全部存入 onnx 模型
# write
onnx_encoder = onnx.load(encoder_outpath)
for (k, v) in args.items():
    meta = onnx_encoder.metadata_props.add()
    meta.key, meta.value = str(k), str(v)
onnx.save(onnx_encoder, encoder_outpath)
# read
ort_session = onnxruntime.InferenceSession(encoder_outpath)
meta = ort_session.get_modelmeta()
print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))

后记

作为一种几乎公认的中间格式模型定义,ONNX 之于 WeNet 社区小伙伴们的意义已经不止于“加速”:有了这样一个万能的“桥梁”, OpenVINO / MNN / NCNN 推理还会远吗?正所谓:

蒹葭苍苍,白露为霜。所谓伊人,在水一方。溯洄从之,道阻且长。“渡桥”从之,宛在手中央。

语音技术社群邀请函 

?

长按添加语音小管家

备注:姓名-学校/公司-研究方向-城市(如:张三-清华-语音合成-北京)即可申请加入语音之家语音合成/语音增强/声纹识别/说话人日志等技术交流群

每月大咖直播分享、求职内推、干货资讯汇总、与 5000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企语音开发者互动交流~


永久福利 直投简历

(简历投递):[email protected]

语音杂谈内推助力,leader直收简历

企业招聘旺季,推荐机会不容错过

觉得本篇文章不错?


① 点击右下角“在看”,让更多的人看到这篇文章;② 分享给你的朋友圈;③ 关注语音杂谈公众号。

扫码关注我们


语音人的技术客栈

专注于语音技术分享与干货推送

这篇关于论如何优雅地在 WeNet 中支持 ONNX 导出的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!