-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wenet] nn context biasing #1982
base: main
Are you sure you want to change the base?
Conversation
…tion problem due to context mismatch.
wenet/transformer/context_module.py
Outdated
_, last_state = self.sen_rnn(pack_seq) | ||
laste_h = last_state[0] | ||
laste_c = last_state[1] | ||
state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi,这里的实现是最后一层BLSTM的reverse last_h_state和第一层的forward last_h_state?
torch.nn.LSTM
**h_n**: tensor of shape :math:
(D * \text{num_layers}, H_{out}) for unbatched input or :math:
(D * \text{num_layers}, N, H_{out})containing the final hidden state for each element in the sequence. When ``bidirectional=True``,
h_n will contain a concatenation of the final forward and reverse hidden states, respectively.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是我写错了,0应该改成-2,感谢指正
…hunk during bias module training
for utt_label in batch_label: | ||
st_index_list = [] | ||
for i in range(len(utt_label)): | ||
if '▁' not in symbol_table: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我想请问下,这里如果我的建模单元是中文汉字+英文bpe,这里是不是不太适用,需要改下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,我自己训练的时候都是纯中文或者纯英文,英文在热词采样的时候对下划线特殊处理了下保证不会采样出半个词的情况,如果同时有中文和英文这部分最好是改下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
谢谢~
可以提供一些模型训练时候的conf.yaml参数设置吗?谢谢 |
上面的模型链接中有我用的yaml文件,可以直接下载 |
我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀? |
漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象 |
很严重,就是一段一段的漏,原始设置的热词表大小是187,modelscope上开源的热词测试集,然后是设置了context_filtering参数进行过滤,如果过滤后热词表只有【0】的话,基本上就整句话漏了,如果是有热词的情况,也会出现成片漏掉的情况,设置的deep_score=1,filter_threshold=-4。目前训练迭代了17个epoch,loss_bias在10左右 |
那确实很奇怪,总体loss的情况正常吗,正常情况下收敛到差不多的时候,bias loss应该是和ctc loss差不多,总体的loss应该会比没有训练热词模块之前更低一些,在aishell上大概是3.4左右。你用的热词相关的yaml配置是否都和我上面给出的一致 |
还有就是我在做aishell1实验的时候发现对于aishell1这种句子大部分都很短的数据集,热词采样的代码需要去掉那个判断采样热词不能交叉的逻辑,不然很容易一句话只能采样出一个热词,这样训出来热词增强的效果会差一些,不过这个问题并不会导致漏字的情况。 |
目前训练出来整体的loss还算是正常,从3.1下降到了2.5,bias loss会比ctc loss高一些。我现在的热词配置就是您给的这个哈 |
会不会是你修改的热词采样部分的代码有点问题,我这边确实没遇到过你描述的状况,也想不出是什么原因,漏字而且还和传入的热词数量有关,理论上来说热词列表只剩个0应该对于正常解码的影响是最小的 |
您好,我尝试复现您在librispeech的结果,但是在训练热词增强模型时,出现cv loss值不下降的情况(保持在160多),并且train loss也是下降到四五十就不太下降了。 另外,我发现每次训练几个batch时,都会花五六分钟去训练下一个batch,正常情况我的显卡每训练一个batch的时间是30s左右,下面是一小段训练日志。。。 我没修改任何代码,训练conf文件也是您提供那个train_bias, 能大概分析下出现问题的原因吗? 谢谢! 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 77.121582 loss_att 68.588936 loss_ctc 90.912209 loss_bias 61.188702 lr 0.00001204 rank 3 |
你是不是直接从头开始训练了,为了减少对原本asr性能的影响,我写的是从一个预训练好的asr模型开始训,除了热词模块之外的参数都给冻结了。从头开始训应该也能够收敛,但是至少得把冻结的参数先解冻。 |
没有,也是用的之前在librispeech上预训练好的asr模型,做了参数冻结 |
我之前在aishell1上测试的时候没有遇到你说的情况,cv loss不正常、解码时U-WER明显增强和过滤丢字都挺奇怪的,感觉训的不太对。你的asr模型有没有啥特殊的地方,训热词是直接用这里的代码吗还是有加什么修改 |
你好,我核对了wenet公布的librispeech的en模型(解压后前缀20210610那个),和你发布的nn_bias模型。发nn_bias/bias_model/units.txt和20210610模型里的uinit.txt无法匹配(不匹配的点在eos/eos)。初始化参数的基础模型是不是搞错了呢? |
确实不一样,我看了下貌似是wenet上个月更新了一版预训练模型,把字典给改了。新下下来的那个模型最后编辑时间都是23年12月了,我这边之前下载的librispeech模型最后编辑时间是22年9月,之前初始化模型用的参数是22年9月的那个预训练模型 |
方便上传一下wenet历史的基础模型嘛,谢谢呢~我在复现librispeech结果时,也同样遇到了loss下不去的问题。排查后才发现是字典不匹配 |
传到上面那个huggingface链接里了 |
请问下有测试中文的U-WER,B-WER的脚本和性能吗 |
非常感谢,代码没有改过,只是换了模型与发音词典。 |
测试脚本传huggingface上了,脚本代码后来改过一点,所以测出来可能和表里会有点区别 |
至少只用我这边代码的话其他人训过应该是没问题的,你要不再检查下,我确实没见过这种情况,也判断不出来是哪里有问题 |
大佬,试了下你的cppn,长尾词召回效果挺好的,赞一个! 大佬有补充不 |
你的思路很对,我基本上是在triton上改的,我遇到一个问题,因为deepbiasing 是对每个词做context_emb ,所以好像没法做batch(一个batch里的音频不一样,两阶段算法后的热词数和热词字数可能不一样,context_list 的size不固定),也就是没法做onnx导出,于是我就在triton里面加了个pytorch推理的代码。。感觉有点丑陋,这周我把code发到仓库里,大佬有空一起看下,看看怎么改进一下。。 |
我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:
|
导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。 |
torch导出不是支持dynamic_axes吗,我这里是给了动态size的,后面的测试目前看是过了的 |
大佬,你尝试过改runtime代码吗,很难对原始代码不侵入,需要调整ForwardEncoderFunc, 加一个热词输入处理以及热词模块前向,主要ForwardEncoderFunc这个函数把ctc_prob计算也加进去了,其实每个模块抽出来会比较好改 |
你可以看下我仓库里的nn_bias 分支,我已经把代码合进来了,我做了如下改动:
|
好的,感谢大佬分享哈,我正在尝试改runtime,改完了也分享一下 |
@dahu1 大佬这个模型导出确实是有维度问题的,不过torch导出是支持dynamic_axes,这个确实是有效的;我目前定位到cppn模型内部使用了torch.nn.utils.rnn.pack_padded_sequence,把数据处理成变长的了,这个代码删了,动态维度导出是正常的,但是会对效果有明显影响,我理解可能是引入了padding的数据做embedding计算了: wenet/wenet/transformer/context_module.py Line 48 in 762e199
|
找到一个讨论的帖子,按帖子里这么做确实可以正常导出,onnxruntime可以跑通,就是会报warning;python代码最后导出的检查部分精度差距会比较大:17%,我看了下cer也会损失绝对0.1%左右;我在尝试用mask的形式去处理数据 |
@dahu1 大佬,我这尝试支持了下cppn的onnxruntime和模型导出,可以做个参考:https://github.com/fclearner/wenet/tree/nn_bias |
这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态 |
您好,您的邮件我已经收到,会尽快给您回复!祝您生活愉快,工作顺利!
|
好的,感谢大佬分享,我学习一下。确实单个推理会比较慢,如果能做到转onnx进行batch推理,会快很多。另外你说的把lstm改成单向的,那是不是性能会有损啊? |
我已经把新改好的代码提交了,
我这边刚做完测试,lstm单向相较双向确实热词召回率变差了六七个点,目前来看解决变长的onnx导出问题有两种策略:1、参考funasr,使用单向的lstm,状态索引用热词列表长度;2、使用torch.jit.script先转成静态图,但是运行的时候会报warning,需要调整onnx的日志屏蔽级别;我刚更新了第一种的代码,稍后我把第二种也commit一下;如果有大佬有更好的方法也可以回复下 |
更改为单向lstm后,是不是context 模型需要重新训练?毕竟训练和推理要保持一致。 |
是的,需要重新训练,主要是因为lstm的输出需要接一个context_encoder映射成embedding_size,lstm的单向状态仍然是可以通过热词列表长度索引的,我已经更新了双向的模型导出代码了,你先试试吧,因为我是直接copy到github的,没有做测试,可能会有问题:https://github.com/fclearner/wenet/blob/nn_bias/wenet/bin/export_onnx_cpu.py |
记录下问题, 问题二: 问题三: |
请问您跑这个huggingface上的bwer计算脚本跑通了吗,我跑出来一直不对,wer非常高 |
您好,您的邮件我已经收到,会尽快给您回复!祝您生活愉快,工作顺利!
|
The Deep biasing method comes from: https://arxiv.org/abs/2305.12493
The pre-trained ASR model is fine-tuned to achieve biasing. During the training process, the original parameters of the ASR model are frozen, and only the parameters related to deep biasing are trained. use_dynamic_chunk cannot be enabled during fine-tuning (the biasing effect will decrease), but the biasing effects of streaming and non-streaming inference are basically the same.
RESULT:
Model link: https://huggingface.co/kxhuang/Wenet_Librispeech_deep_biasing/tree/main
(I used the BLSTM forward state incorrectly when training this model, so to test this model you need to change the -2 to 0 in the forward function of the BLSTM class in wenet/transformer/context_module.py)
Using the Wenet Librispeech pre-trained AED model, after fine-tuning for 30 epochs, the final model was obtained with an average of 3 epochs. The following are the test results of the Librispeech test other.
The context list for the test set is sourced from: https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias
Non-streaming inference:
+ deep biasing
+ deep biasing
Streaming inference (chunk 16):
+ deep biasing