Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wenet] nn context biasing #1982

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open

Conversation

kaixunhuang0
Copy link
Collaborator

@kaixunhuang0 kaixunhuang0 commented Aug 31, 2023

The Deep biasing method comes from: https://arxiv.org/abs/2305.12493

The pre-trained ASR model is fine-tuned to achieve biasing. During the training process, the original parameters of the ASR model are frozen, and only the parameters related to deep biasing are trained. use_dynamic_chunk cannot be enabled during fine-tuning (the biasing effect will decrease), but the biasing effects of streaming and non-streaming inference are basically the same.

RESULT:
Model link: https://huggingface.co/kxhuang/Wenet_Librispeech_deep_biasing/tree/main
(I used the BLSTM forward state incorrectly when training this model, so to test this model you need to change the -2 to 0 in the forward function of the BLSTM class in wenet/transformer/context_module.py)

Using the Wenet Librispeech pre-trained AED model, after fine-tuning for 30 epochs, the final model was obtained with an average of 3 epochs. The following are the test results of the Librispeech test other.
The context list for the test set is sourced from: https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias

Non-streaming inference:

Method List size Graph score Biasing score WER U-WER B-WER
baseline / / / 8.77 5.58 36.84
context graph 3838 3.0 / 7.75 5.83 24.62
deep biasing 3838 / 1.5 7.93 5.92 25.64
context graph
+ deep biasing
3838 2.0 1.0 7.66 6.08 21.48
context graph 100 3.0 / 7.32 5.45 23.70
deep biasing 100 / 2.0 7.08 5.33 22.41
context graph
+ deep biasing
100 2.5 1.5 6.55 5.33 17.27

Streaming inference (chunk 16):

Method List size Graph score Biasing score WER U-WER B-WER
baseline / / / 10.47 7.07 40.30
context graph 100 3.0 / 9.06 6.99 27.21
deep biasing 100 / 2.0 8.86 6.87 26.28
context graph
+ deep biasing
100 2.5 1.5 8.17 6.85 19.72

@kaixunhuang0 kaixunhuang0 marked this pull request as ready for review September 6, 2023 00:49
_, last_state = self.sen_rnn(pack_seq)
laste_h = last_state[0]
laste_c = last_state[1]
state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :],

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi,这里的实现是最后一层BLSTM的reverse last_h_state和第一层的forward last_h_state?
torch.nn.LSTM
**h_n**: tensor of shape :math:(D * \text{num_layers}, H_{out}) for unbatched input or :math:(D * \text{num_layers}, N, H_{out})containing the final hidden state for each element in the sequence. When ``bidirectional=True``,h_n will contain a concatenation of the final forward and reverse hidden states, respectively.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是我写错了,0应该改成-2,感谢指正

for utt_label in batch_label:
st_index_list = []
for i in range(len(utt_label)):
if '▁' not in symbol_table:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我想请问下,这里如果我的建模单元是中文汉字+英文bpe,这里是不是不太适用,需要改下?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,我自己训练的时候都是纯中文或者纯英文,英文在热词采样的时候对下划线特殊处理了下保证不会采样出半个词的情况,如果同时有中文和英文这部分最好是改下

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

谢谢~

@zyjcsf
Copy link

zyjcsf commented Sep 28, 2023

可以提供一些模型训练时候的conf.yaml参数设置吗?谢谢

@kaixunhuang0
Copy link
Collaborator Author

可以提供一些模型训练时的conf.yaml参数设置吗?谢谢

上面的模型链接中有我用的yaml文件,可以直接下载

@zyjcsf
Copy link

zyjcsf commented Oct 9, 2023

我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀?

@kaixunhuang0
Copy link
Collaborator Author

我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀?

漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象

@zyjcsf
Copy link

zyjcsf commented Oct 9, 2023

我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀?

漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象

很严重,就是一段一段的漏,原始设置的热词表大小是187,modelscope上开源的热词测试集,然后是设置了context_filtering参数进行过滤,如果过滤后热词表只有【0】的话,基本上就整句话漏了,如果是有热词的情况,也会出现成片漏掉的情况,设置的deep_score=1,filter_threshold=-4。目前训练迭代了17个epoch,loss_bias在10左右

@kaixunhuang0
Copy link
Collaborator Author

我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀?

漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象

很严重,就是一段一段的漏,原始设置的热词表大小是187,modelscope上开源的热词测试集,然后是设置了context_filtering参数进行过滤,如果过滤后热词表只有【0】的话,基本上就整句话漏了,如果是有热词的情况,也会出现成片漏掉的情况,设置的deep_score=1,filter_threshold=-4

那确实很奇怪,总体loss的情况正常吗,正常情况下收敛到差不多的时候,bias loss应该是和ctc loss差不多,总体的loss应该会比没有训练热词模块之前更低一些,在aishell上大概是3.4左右。你用的热词相关的yaml配置是否都和我上面给出的一致

@kaixunhuang0
Copy link
Collaborator Author

我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀?

还有就是我在做aishell1实验的时候发现对于aishell1这种句子大部分都很短的数据集,热词采样的代码需要去掉那个判断采样热词不能交叉的逻辑,不然很容易一句话只能采样出一个热词,这样训出来热词增强的效果会差一些,不过这个问题并不会导致漏字的情况。

@zyjcsf
Copy link

zyjcsf commented Oct 9, 2023

我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀?

漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象

很严重,就是一段一段的漏,原始设置的热词表大小是187,modelscope上开源的热词测试集,然后是设置了context_filtering参数进行过滤,如果过滤后热词表只有【0】的话,基本上就整句话漏了,如果是有热词的情况,也会出现成片漏掉的情况,设置的deep_score=1,filter_threshold=-4

那确实很奇怪,总体loss的情况正常吗,正常情况下收敛到差不多的时候,bias loss应该是和ctc loss差不多,总体的loss应该会比没有训练热词模块之前更低一些,在aishell上大概是3.4左右。你用的热词相关的yaml配置是否都和我上面给出的一致

目前训练出来整体的loss还算是正常,从3.1下降到了2.5,bias loss会比ctc loss高一些。我现在的热词配置就是您给的这个哈

@kaixunhuang0
Copy link
Collaborator Author

我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀?

漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象

很严重,就是一段一段的漏,原始设置的热词表大小是187,modelscope上开源的热词测试集,然后是设置了context_filtering参数进行过滤,如果过滤后热词表只有【0】的话,基本上就整句话漏了,如果是有热词的情况,也会出现成片漏掉的情况,设置的deep_score=1,filter_threshold=-4

那确实很奇怪,总体loss的情况正常吗,正常情况下收敛到差不多的时候,bias loss应该是和ctc loss差不多,总体的loss应该会比没有训练热词模块之前更低一些,在aishell上大概是3.4左右。你用的热词相关的yaml配置是否都和我上面给出的一致

目前训练出来整体的loss还算是正常,从3.1下降到了2.5,bias loss会比ctc loss高一些。我现在的热词配置就是您给的这个哈

会不会是你修改的热词采样部分的代码有点问题,我这边确实没遇到过你描述的状况,也想不出是什么原因,漏字而且还和传入的热词数量有关,理论上来说热词列表只剩个0应该对于正常解码的影响是最小的

@wpupup
Copy link

wpupup commented Oct 18, 2023

您好,我尝试复现您在librispeech的结果,但是在训练热词增强模型时,出现cv loss值不下降的情况(保持在160多),并且train loss也是下降到四五十就不太下降了。 另外,我发现每次训练几个batch时,都会花五六分钟去训练下一个batch,正常情况我的显卡每训练一个batch的时间是30s左右,下面是一小段训练日志。。。

我没修改任何代码,训练conf文件也是您提供那个train_bias, 能大概分析下出现问题的原因吗? 谢谢!

2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 77.121582 loss_att 68.588936 loss_ctc 90.912209 loss_bias 61.188702 lr 0.00001204 rank 3
2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 60.563221 loss_att 53.186646 loss_ctc 73.329880 loss_bias 44.453613 lr 0.00001204 rank 7
2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 66.905380 loss_att 60.219139 loss_ctc 76.915077 loss_bias 55.915321 lr 0.00001204 rank 1
2023-10-17 18:48:14,599 DEBUG TRAIN Batch 0/300 loss 58.367058 loss_att 54.565548 loss_ctc 63.268948 loss_bias 39.683086 lr 0.00001204 rank 0
2023-10-17 18:48:54,507 DEBUG TRAIN Batch 0/400 loss 69.295921 loss_att 62.990799 loss_ctc 78.056396 loss_bias 59.514668 lr 0.00001604 rank 7
2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 60.892227 loss_att 55.707409 loss_ctc 68.627617 loss_bias 43.625130 lr 0.00001604 rank 6
2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 70.570961 loss_att 63.955940 loss_ctc 81.632156 loss_bias 43.738525 lr 0.00001604 rank 2
2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 56.387531 loss_att 51.965221 loss_ctc 61.897652 loss_bias 48.085854 lr 0.00001604 rank 5
2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 57.394482 loss_att 53.534023 loss_ctc 62.557728 loss_bias 38.444881 lr 0.00001604 rank 1
2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 61.427593 loss_att 57.190033 loss_ctc 66.434952 loss_bias 48.802876 lr 0.00001604 rank 4
2023-10-17 18:48:54,513 DEBUG TRAIN Batch 0/400 loss 66.382660 loss_att 61.784157 loss_ctc 71.916908 loss_bias 51.955982 lr 0.00001604 rank 3
2023-10-17 18:48:54,517 DEBUG TRAIN Batch 0/400 loss 69.309433 loss_att 61.884018 loss_ctc 81.042137 loss_bias 55.932556 lr 0.00001604 rank 0
2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 60.114948 loss_att 58.303940 loss_ctc 60.731007 loss_bias 36.096294 lr 0.00002004 rank 7
2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 56.977654 loss_att 53.650196 loss_ctc 61.347378 loss_bias 33.943447 lr 0.00002004 rank 1
2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 56.869381 loss_att 54.544899 loss_ctc 58.243603 loss_bias 40.495705 lr 0.00002004 rank 2
2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 58.940693 loss_att 57.577057 loss_ctc 57.989662 loss_bias 41.328430 lr 0.00002004 rank 4
2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 63.078079 loss_att 60.879333 loss_ctc 64.494652 loss_bias 37.138424 lr 0.00002004 rank 3
2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 62.410076 loss_att 58.739368 loss_ctc 67.138695 loss_bias 38.363663 lr 0.00002004 rank 6
2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 61.162239 loss_att 57.996552 loss_ctc 63.624905 loss_bias 49.239365 lr 0.00002004 rank 5
2023-10-17 18:55:16,909 DEBUG TRAIN Batch 0/500 loss 62.478779 loss_att 60.307823 loss_ctc 63.295692 loss_bias 42.486469 lr 0.00002004 rank 0
2023-10-17 18:55:57,183 DEBUG TRAIN Batch 0/600 loss 62.084000 loss_att 62.485199 loss_ctc 56.836884 loss_bias 43.109840 lr 0.00002404 rank 7
2023-10-17 18:55:57,186 DEBUG TRAIN Batch 0/600 loss 63.226624 loss_att 62.583645 loss_ctc 60.321804 loss_bias 44.050949 lr 0.00002404 rank 3

@kaixunhuang0
Copy link
Collaborator Author

您好,我尝试复现您在librispeech的结果,但是在训练热词增强模型时,出现cv loss值不下降的情况(保持在160多),并且train loss也是下降到四五十就不太下降了。 另外,我发现每次训练几个batch时,都会花五六分钟去训练下一个batch,正常情况我的显卡每训练一个batch的时间是30s左右,下面是一小段训练日志。。。

我没修改任何代码,训练conf文件也是您提供那个train_bias, 能大概分析下出现问题的原因吗? 谢谢!

2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 77.121582 loss_att 68.588936 loss_ctc 90.912209 loss_bias 61.188702 lr 0.00001204 rank 3 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 60.563221 loss_att 53.186646 loss_ctc 73.329880 loss_bias 44.453613 lr 0.00001204 rank 7 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 66.905380 loss_att 60.219139 loss_ctc 76.915077 loss_bias 55.915321 lr 0.00001204 rank 1 2023-10-17 18:48:14,599 DEBUG TRAIN Batch 0/300 loss 58.367058 loss_att 54.565548 loss_ctc 63.268948 loss_bias 39.683086 lr 0.00001204 rank 0 2023-10-17 18:48:54,507 DEBUG TRAIN Batch 0/400 loss 69.295921 loss_att 62.990799 loss_ctc 78.056396 loss_bias 59.514668 lr 0.00001604 rank 7 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 60.892227 loss_att 55.707409 loss_ctc 68.627617 loss_bias 43.625130 lr 0.00001604 rank 6 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 70.570961 loss_att 63.955940 loss_ctc 81.632156 loss_bias 43.738525 lr 0.00001604 rank 2 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 56.387531 loss_att 51.965221 loss_ctc 61.897652 loss_bias 48.085854 lr 0.00001604 rank 5 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 57.394482 loss_att 53.534023 loss_ctc 62.557728 loss_bias 38.444881 lr 0.00001604 rank 1 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 61.427593 loss_att 57.190033 loss_ctc 66.434952 loss_bias 48.802876 lr 0.00001604 rank 4 2023-10-17 18:48:54,513 DEBUG TRAIN Batch 0/400 loss 66.382660 loss_att 61.784157 loss_ctc 71.916908 loss_bias 51.955982 lr 0.00001604 rank 3 2023-10-17 18:48:54,517 DEBUG TRAIN Batch 0/400 loss 69.309433 loss_att 61.884018 loss_ctc 81.042137 loss_bias 55.932556 lr 0.00001604 rank 0 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 60.114948 loss_att 58.303940 loss_ctc 60.731007 loss_bias 36.096294 lr 0.00002004 rank 7 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 56.977654 loss_att 53.650196 loss_ctc 61.347378 loss_bias 33.943447 lr 0.00002004 rank 1 2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 56.869381 loss_att 54.544899 loss_ctc 58.243603 loss_bias 40.495705 lr 0.00002004 rank 2 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 58.940693 loss_att 57.577057 loss_ctc 57.989662 loss_bias 41.328430 lr 0.00002004 rank 4 2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 63.078079 loss_att 60.879333 loss_ctc 64.494652 loss_bias 37.138424 lr 0.00002004 rank 3 2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 62.410076 loss_att 58.739368 loss_ctc 67.138695 loss_bias 38.363663 lr 0.00002004 rank 6 2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 61.162239 loss_att 57.996552 loss_ctc 63.624905 loss_bias 49.239365 lr 0.00002004 rank 5 2023-10-17 18:55:16,909 DEBUG TRAIN Batch 0/500 loss 62.478779 loss_att 60.307823 loss_ctc 63.295692 loss_bias 42.486469 lr 0.00002004 rank 0 2023-10-17 18:55:57,183 DEBUG TRAIN Batch 0/600 loss 62.084000 loss_att 62.485199 loss_ctc 56.836884 loss_bias 43.109840 lr 0.00002404 rank 7 2023-10-17 18:55:57,186 DEBUG TRAIN Batch 0/600 loss 63.226624 loss_att 62.583645 loss_ctc 60.321804 loss_bias 44.050949 lr 0.00002404 rank 3

你是不是直接从头开始训练了,为了减少对原本asr性能的影响,我写的是从一个预训练好的asr模型开始训,除了热词模块之外的参数都给冻结了。从头开始训应该也能够收敛,但是至少得把冻结的参数先解冻。

@wpupup
Copy link

wpupup commented Oct 18, 2023

您好,我尝试复现您在librispeech的结果,但是在训练热词增强模型时,出现cv loss值不下降的情况(保持在160多),并且train loss也是下降到四五十就不太下降了。 另外,我发现每次训练几个batch时,都会花五六分钟去训练下一个batch,正常情况我的显卡每训练一个batch的时间是30s左右,下面是一小段训练日志。。。
我没修改任何代码,训练conf文件也是您提供那个train_bias, 能大概分析下出现问题的原因吗? 谢谢!
2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 77.121582 loss_att 68.588936 loss_ctc 90.912209 loss_bias 61.188702 lr 0.00001204 rank 3 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 60.563221 loss_att 53.186646 loss_ctc 73.329880 loss_bias 44.453613 lr 0.00001204 rank 7 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 66.905380 loss_att 60.219139 loss_ctc 76.915077 loss_bias 55.915321 lr 0.00001204 rank 1 2023-10-17 18:48:14,599 DEBUG TRAIN Batch 0/300 loss 58.367058 loss_att 54.565548 loss_ctc 63.268948 loss_bias 39.683086 lr 0.00001204 rank 0 2023-10-17 18:48:54,507 DEBUG TRAIN Batch 0/400 loss 69.295921 loss_att 62.990799 loss_ctc 78.056396 loss_bias 59.514668 lr 0.00001604 rank 7 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 60.892227 loss_att 55.707409 loss_ctc 68.627617 loss_bias 43.625130 lr 0.00001604 rank 6 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 70.570961 loss_att 63.955940 loss_ctc 81.632156 loss_bias 43.738525 lr 0.00001604 rank 2 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 56.387531 loss_att 51.965221 loss_ctc 61.897652 loss_bias 48.085854 lr 0.00001604 rank 5 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 57.394482 loss_att 53.534023 loss_ctc 62.557728 loss_bias 38.444881 lr 0.00001604 rank 1 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 61.427593 loss_att 57.190033 loss_ctc 66.434952 loss_bias 48.802876 lr 0.00001604 rank 4 2023-10-17 18:48:54,513 DEBUG TRAIN Batch 0/400 loss 66.382660 loss_att 61.784157 loss_ctc 71.916908 loss_bias 51.955982 lr 0.00001604 rank 3 2023-10-17 18:48:54,517 DEBUG TRAIN Batch 0/400 loss 69.309433 loss_att 61.884018 loss_ctc 81.042137 loss_bias 55.932556 lr 0.00001604 rank 0 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 60.114948 loss_att 58.303940 loss_ctc 60.731007 loss_bias 36.096294 lr 0.00002004 rank 7 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 56.977654 loss_att 53.650196 loss_ctc 61.347378 loss_bias 33.943447 lr 0.00002004 rank 1 2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 56.869381 loss_att 54.544899 loss_ctc 58.243603 loss_bias 40.495705 lr 0.00002004 rank 2 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 58.940693 loss_att 57.577057 loss_ctc 57.989662 loss_bias 41.328430 lr 0.00002004 rank 4 2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 63.078079 loss_att 60.879333 loss_ctc 64.494652 loss_bias 37.138424 lr 0.00002004 rank 3 2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 62.410076 loss_att 58.739368 loss_ctc 67.138695 loss_bias 38.363663 lr 0.00002004 rank 6 2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 61.162239 loss_att 57.996552 loss_ctc 63.624905 loss_bias 49.239365 lr 0.00002004 rank 5 2023-10-17 18:55:16,909 DEBUG TRAIN Batch 0/500 loss 62.478779 loss_att 60.307823 loss_ctc 63.295692 loss_bias 42.486469 lr 0.00002004 rank 0 2023-10-17 18:55:57,183 DEBUG TRAIN Batch 0/600 loss 62.084000 loss_att 62.485199 loss_ctc 56.836884 loss_bias 43.109840 lr 0.00002404 rank 7 2023-10-17 18:55:57,186 DEBUG TRAIN Batch 0/600 loss 63.226624 loss_att 62.583645 loss_ctc 60.321804 loss_bias 44.050949 lr 0.00002404 rank 3

你是不是直接从头开始训练了,为了减少对原本asr性能的影响,我写的是从一个预训练好的asr模型开始训,除了热词模块之外的参数都给冻结了。从头开始训应该也能够收敛,但是至少得把冻结的参数先解冻。

没有,也是用的之前在librispeech上预训练好的asr模型,做了参数冻结

@kaixunhuang0
Copy link
Collaborator Author

您好,我再自己的asr模型上,用aishell训练集进行热词模块的微调训练,现在从训练loss看,train的ctc loss与bias loss接近,但是在cv loss上,bias loss会比ctc loss高很多(ctc loss是1.5,bias loss是8.7). 然后用模型解码的时候,不用热词模块的时候,aishell_hotwords测试集(wer=14.06, U-wer=5.9, B-Wer=43.37),然后我用deep bias,不用context_filtering,热词是187个,测试集wer=13.88 U-wer=8.82, B-WER=32.04; 最后当我用context_filtering ,过滤阈值设为-4的时候,我的模型就存在大量漏字,完全解不出来的情况,请教一下这会是什么问题?我看加了过滤后,每句解码的热词词表里是包含该句的热词的。非常感谢。

我之前在aishell1上测试的时候没有遇到你说的情况,cv loss不正常、解码时U-WER明显增强和过滤丢字都挺奇怪的,感觉训的不太对。你的asr模型有没有啥特殊的地方,训热词是直接用这里的代码吗还是有加什么修改

@SwingSoulF
Copy link

您好,我尝试复现您在librispeech的结果,但是在训练热词增强模型时,出现cv loss值不下降的情况(保持在160多),并且train loss也是下降到四五十就不太下降了。 另外,我发现每次训练几个batch时,都会花五六分钟去训练下一个batch,正常情况我的显卡每训练一个batch的时间是30s左右,下面是一小段训练日志。。。
我没修改任何代码,训练conf文件也是您提供那个train_bias, 能大概分析下出现问题的原因吗? 谢谢!
2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 77.121582 loss_att 68.588936 loss_ctc 90.912209 loss_bias 61.188702 lr 0.00001204 rank 3 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 60.563221 loss_att 53.186646 loss_ctc 73.329880 loss_bias 44.453613 lr 0.00001204 rank 7 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 66.905380 loss_att 60.219139 loss_ctc 76.915077 loss_bias 55.915321 lr 0.00001204 rank 1 2023-10-17 18:48:14,599 DEBUG TRAIN Batch 0/300 loss 58.367058 loss_att 54.565548 loss_ctc 63.268948 loss_bias 39.683086 lr 0.00001204 rank 0 2023-10-17 18:48:54,507 DEBUG TRAIN Batch 0/400 loss 69.295921 loss_att 62.990799 loss_ctc 78.056396 loss_bias 59.514668 lr 0.00001604 rank 7 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 60.892227 loss_att 55.707409 loss_ctc 68.627617 loss_bias 43.625130 lr 0.00001604 rank 6 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 70.570961 loss_att 63.955940 loss_ctc 81.632156 loss_bias 43.738525 lr 0.00001604 rank 2 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 56.387531 loss_att 51.965221 loss_ctc 61.897652 loss_bias 48.085854 lr 0.00001604 rank 5 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 57.394482 loss_att 53.534023 loss_ctc 62.557728 loss_bias 38.444881 lr 0.00001604 rank 1 2023-10-17 18:48:54,512 DEBUG TRAIN Batch 0/400 loss 61.427593 loss_att 57.190033 loss_ctc 66.434952 loss_bias 48.802876 lr 0.00001604 rank 4 2023-10-17 18:48:54,513 DEBUG TRAIN Batch 0/400 loss 66.382660 loss_att 61.784157 loss_ctc 71.916908 loss_bias 51.955982 lr 0.00001604 rank 3 2023-10-17 18:48:54,517 DEBUG TRAIN Batch 0/400 loss 69.309433 loss_att 61.884018 loss_ctc 81.042137 loss_bias 55.932556 lr 0.00001604 rank 0 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 60.114948 loss_att 58.303940 loss_ctc 60.731007 loss_bias 36.096294 lr 0.00002004 rank 7 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 56.977654 loss_att 53.650196 loss_ctc 61.347378 loss_bias 33.943447 lr 0.00002004 rank 1 2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 56.869381 loss_att 54.544899 loss_ctc 58.243603 loss_bias 40.495705 lr 0.00002004 rank 2 2023-10-17 18:55:16,906 DEBUG TRAIN Batch 0/500 loss 58.940693 loss_att 57.577057 loss_ctc 57.989662 loss_bias 41.328430 lr 0.00002004 rank 4 2023-10-17 18:55:16,907 DEBUG TRAIN Batch 0/500 loss 63.078079 loss_att 60.879333 loss_ctc 64.494652 loss_bias 37.138424 lr 0.00002004 rank 3 2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 62.410076 loss_att 58.739368 loss_ctc 67.138695 loss_bias 38.363663 lr 0.00002004 rank 6 2023-10-17 18:55:16,908 DEBUG TRAIN Batch 0/500 loss 61.162239 loss_att 57.996552 loss_ctc 63.624905 loss_bias 49.239365 lr 0.00002004 rank 5 2023-10-17 18:55:16,909 DEBUG TRAIN Batch 0/500 loss 62.478779 loss_att 60.307823 loss_ctc 63.295692 loss_bias 42.486469 lr 0.00002004 rank 0 2023-10-17 18:55:57,183 DEBUG TRAIN Batch 0/600 loss 62.084000 loss_att 62.485199 loss_ctc 56.836884 loss_bias 43.109840 lr 0.00002404 rank 7 2023-10-17 18:55:57,186 DEBUG TRAIN Batch 0/600 loss 63.226624 loss_att 62.583645 loss_ctc 60.321804 loss_bias 44.050949 lr 0.00002404 rank 3

你是不是直接从头开始训练了,为了减少对原本asr性能的影响,我写的是从一个预训练好的asr模型开始训,除了热词模块之外的参数都给冻结了。从头开始训应该也能够收敛,但是至少得把冻结的参数先解冻。

没有,也是用的之前在librispeech上预训练好的asr模型,做了参数冻结

我试了下,用直接clone下来的代码+github上预训练的librispeech模型+我提供的yaml是可以正常收敛的,大概在1000个batch的时候loss就降到10了。会不会是你用的预训练asr模型和我提供的yaml里面某些参数对不上,导致模型随机初始化了一些参数并且还被冻结了。

我试过重头开始训练是可以收敛的。。。。 对于预训练模型,我对比了和你的训练参数是一模一样的,这就奇怪了。。。。你用的是哪一个预训练的librispeech呢? 我再检查下原因

我用的就是wenet在github上提供下载的这个librispeech模型 https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.en.md

你好,我核对了wenet公布的librispeech的en模型(解压后前缀20210610那个),和你发布的nn_bias模型。发nn_bias/bias_model/units.txt和20210610模型里的uinit.txt无法匹配(不匹配的点在eos/eos)。初始化参数的基础模型是不是搞错了呢?

@kaixunhuang0
Copy link
Collaborator Author

你好,我核对了wenet公布的librispeech的en模型(解压后前缀20210610那个),和你发布的nn_bias模型。发nn_bias/bias_model/units.txt和20210610模型里的uinit.txt无法匹配(不匹配的点在eos/eos)。初始化参数的基础模型是不是搞错了呢?

确实不一样,我看了下貌似是wenet上个月更新了一版预训练模型,把字典给改了。新下下来的那个模型最后编辑时间都是23年12月了,我这边之前下载的librispeech模型最后编辑时间是22年9月,之前初始化模型用的参数是22年9月的那个预训练模型

@SwingSoulF
Copy link

你好,我核对了wenet公布的librispeech的en模型(解压后前缀20210610那个),和你发布的nn_bias模型。发nn_bias/bias_model/units.txt和20210610模型里的uinit.txt无法匹配(不匹配的点在eos/eos)。初始化参数的基础模型是不是搞错了呢?

确实不一样,我看了下貌似是wenet上个月更新了一版预训练模型,把字典给改了。新下下来的那个模型最后编辑时间都是23年12月了,我这边之前下载的librispeech模型最后编辑时间是22年9月,之前初始化模型用的参数是22年9月的那个预训练模型

方便上传一下wenet历史的基础模型嘛,谢谢呢~我在复现librispeech结果时,也同样遇到了loss下不去的问题。排查后才发现是字典不匹配

@kaixunhuang0
Copy link
Collaborator Author

方便上传一下wenet历史的基础模型嘛,谢谢呢~我在复现librispeech结果时,也同样遇到了loss下不去的问题。排查后才发现是字典不匹配

传到上面那个huggingface链接里了

@programYoung
Copy link

请问下有测试中文的U-WER,B-WER的脚本和性能吗

@csf123123
Copy link

您好,我再自己的asr模型上,用aishell训练集进行热词模块的微调训练,现在从训练loss看,train的ctc loss与bias loss接近,但是在cv loss上,bias loss会比ctc loss高很多(ctc loss是1.5,bias loss是8.7). 然后用模型解码的时候,不用热词模块的时候,aishell_hotwords测试集(wer=14.06, U-wer=5.9, B-Wer=43.37),然后我用deep bias,不用context_filtering,热词是187个,测试集wer=13.88 U-wer=8.82, B-WER=32.04; 最后当我用context_filtering ,过滤阈值设为-4的时候,我的模型就存在大量漏字,完全解不出来的情况,请教一下这会是什么问题?我看加了过滤后,每句解码的热词词表里是包含该句的热词的。非常感谢。

我之前在aishell1上测试的时候没有遇到你说的情况,cv loss不正常、解码时U-WER明显增强和过滤丢字都挺奇怪的,感觉训的不太对。你的asr模型有没有啥特殊的地方,训热词是直接用这里的代码吗还是有加什么修改

非常感谢,代码没有改过,只是换了模型与发音词典。

@kaixunhuang0
Copy link
Collaborator Author

请问下有测试中文的U-WER,B-WER的脚本和性能吗

测试脚本传huggingface上了,脚本代码后来改过一点,所以测出来可能和表里会有点区别

@kaixunhuang0
Copy link
Collaborator Author

您好,我换成librispeech也是这么个情况,完全用的您这边的代码程序,模型也是您刚上传的librispeech的pretrain model

至少只用我这边代码的话其他人训过应该是没问题的,你要不再检查下,我确实没见过这种情况,也判断不出来是哪里有问题

@fclearner
Copy link
Contributor

大佬,试了下你的cppn,长尾词召回效果挺好的,赞一个!
然后我看了下runtime的部分,大概要调整的地方有:
1、非runtime部分:支持deep_biasing模块模型导出
2、runtime相关---新建deep_biasing类,沿用context_graph热词词表整理的代码,然后实现二阶段热词筛选(其实这个热词筛选在context_graph里面应该也挺有用的)
3、runtime相关---onnx推理部分:onnx_asr_model里面定义一个deep_biasing_Ort,然后在encoder输出的部分过一下deep_biasing(这里需要支持一个超参配置传入&&是否使用神经网络热词的传入&&以及一些默认值的配置)

大佬有补充不
不过如果把context_module放到encoder里面感觉就不用做模型导出和模块定义的操作了

@dahu1
Copy link

dahu1 commented Jan 16, 2024

大佬,试了下你的cppn,长尾词召回效果挺好的,赞一个! 然后我看了下runtime的部分,大概要调整的地方有: 1、非runtime部分:支持deep_biasing模块模型导出 2、runtime相关---新建deep_biasing类,沿用context_graph热词词表整理的代码,然后实现二阶段热词筛选(其实这个热词筛选在context_graph里面应该也挺有用的) 3、runtime相关---onnx推理部分:onnx_asr_model里面定义一个deep_biasing_Ort,然后在encoder输出的部分过一下deep_biasing(这里需要支持一个超参配置传入&&是否使用神经网络热词的传入&&以及一些默认值的配置)

大佬有补充不 不过如果把context_module放到encoder里面感觉就不用做模型导出和模块定义的操作了

你的思路很对,我基本上是在triton上改的,我遇到一个问题,因为deepbiasing 是对每个词做context_emb ,所以好像没法做batch(一个batch里的音频不一样,两阶段算法后的热词数和热词字数可能不一样,context_list 的size不固定),也就是没法做onnx导出,于是我就在triton里面加了个pytorch推理的代码。。感觉有点丑陋,这周我把code发到仓库里,大佬有空一起看下,看看怎么改进一下。。

@dahu1 dahu1 mentioned this pull request Jan 17, 2024
@fclearner
Copy link
Contributor

fclearner commented Jan 21, 2024

大佬,试了下你的cppn,长尾词召回效果挺好的,赞一个! 然后我看了下runtime的部分,大概要调整的地方有: 1、非runtime部分:支持deep_biasing模块模型导出 2、runtime相关---新建deep_biasing类,沿用context_graph热词词表整理的代码,然后实现二阶段热词筛选(其实这个热词筛选在context_graph里面应该也挺有用的) 3、runtime相关---onnx推理部分:onnx_asr_model里面定义一个deep_biasing_Ort,然后在encoder输出的部分过一下deep_biasing(这里需要支持一个超参配置传入&&是否使用神经网络热词的传入&&以及一些默认值的配置)
大佬有补充不 不过如果把context_module放到encoder里面感觉就不用做模型导出和模块定义的操作了

你的思路很对,我基本上是在triton上改的,我遇到一个问题,因为deepbiasing 是对每个词做context_emb ,所以好像没法做batch(一个batch里的音频不一样,两阶段算法后的热词数和热词字数可能不一样,context_list 的size不固定),也就是没法做onnx导出,于是我就在triton里面加了个pytorch推理的代码。。感觉有点丑陋,这周我把code发到仓库里,大佬有空一起看下,看看怎么改进一下。。

我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:

context_list = torch.randint(low=0, high=args['vocab_size'],
                             size=(200, 10))
context_lengths = torch.randint(low=1, high=10,
                                size=(200,))
context_lengths = torch.tensor([x.size(0) for x in context_list],
                               dtype=torch.int32)
encoder_out = torch.randn((1, 200, args['output_size']))

print("\tStage-4.2: torch.onnx.export")
dynamic_axes = {'context_list': {0: 'context_lengths'},    # 批量动态轴
                'context_lengths': {0: 'context_lengths'},
                'encoder_out': {1: 'T'},
                'encoder_bias_out': {1: 'T'}
                }
inputs = (context_list, context_lengths, encoder_out, 1.0, True)
torch.onnx.export(context_module,
                  inputs,
                  context_module_outpath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['context_list', 'context_lengths', 
                               'encoder_out', 'biasing_score', 'recognize'],
                  output_names=['encoder_bias_out', 'bias_out'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

@dahu1
Copy link

dahu1 commented Jan 22, 2024

我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:

context_list = torch.randint(low=0, high=args['vocab_size'],
                             size=(200, 10))
context_lengths = torch.randint(low=1, high=10,
                                size=(200,))
context_lengths = torch.tensor([x.size(0) for x in context_list],
                               dtype=torch.int32)
encoder_out = torch.randn((1, 200, args['output_size']))

print("\tStage-4.2: torch.onnx.export")
dynamic_axes = {'context_list': {0: 'context_lengths'},    # 批量动态轴
                'context_lengths': {0: 'context_lengths'},
                'encoder_out': {1: 'T'},
                'encoder_bias_out': {1: 'T'}
                }
inputs = (context_list, context_lengths, encoder_out, 1.0, True)
torch.onnx.export(context_module,
                  inputs,
                  context_module_outpath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['context_list', 'context_lengths', 
                               'encoder_out', 'biasing_score', 'recognize'],
                  output_names=['encoder_bias_out', 'bias_out'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。

@fclearner
Copy link
Contributor

我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:

context_list = torch.randint(low=0, high=args['vocab_size'],
                             size=(200, 10))
context_lengths = torch.randint(low=1, high=10,
                                size=(200,))
context_lengths = torch.tensor([x.size(0) for x in context_list],
                               dtype=torch.int32)
encoder_out = torch.randn((1, 200, args['output_size']))

print("\tStage-4.2: torch.onnx.export")
dynamic_axes = {'context_list': {0: 'context_lengths'},    # 批量动态轴
                'context_lengths': {0: 'context_lengths'},
                'encoder_out': {1: 'T'},
                'encoder_bias_out': {1: 'T'}
                }
inputs = (context_list, context_lengths, encoder_out, 1.0, True)
torch.onnx.export(context_module,
                  inputs,
                  context_module_outpath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['context_list', 'context_lengths', 
                               'encoder_out', 'biasing_score', 'recognize'],
                  output_names=['encoder_bias_out', 'bias_out'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。

torch导出不是支持dynamic_axes吗,我这里是给了动态size的,后面的测试目前看是过了的

@fclearner
Copy link
Contributor

我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:

context_list = torch.randint(low=0, high=args['vocab_size'],
                             size=(200, 10))
context_lengths = torch.randint(low=1, high=10,
                                size=(200,))
context_lengths = torch.tensor([x.size(0) for x in context_list],
                               dtype=torch.int32)
encoder_out = torch.randn((1, 200, args['output_size']))

print("\tStage-4.2: torch.onnx.export")
dynamic_axes = {'context_list': {0: 'context_lengths'},    # 批量动态轴
                'context_lengths': {0: 'context_lengths'},
                'encoder_out': {1: 'T'},
                'encoder_bias_out': {1: 'T'}
                }
inputs = (context_list, context_lengths, encoder_out, 1.0, True)
torch.onnx.export(context_module,
                  inputs,
                  context_module_outpath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['context_list', 'context_lengths', 
                               'encoder_out', 'biasing_score', 'recognize'],
                  output_names=['encoder_bias_out', 'bias_out'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。

大佬,你尝试过改runtime代码吗,很难对原始代码不侵入,需要调整ForwardEncoderFunc, 加一个热词输入处理以及热词模块前向,主要ForwardEncoderFunc这个函数把ctc_prob计算也加进去了,其实每个模块抽出来会比较好改

@dahu1
Copy link

dahu1 commented Jan 30, 2024

大佬,你尝试过改runtime代码吗,很难对原始代码不侵入,需要调整ForwardEncoderFunc, 加一个热词输入处理以及热词模块前向,主要ForwardEncoderFunc这个函数把ctc_prob计算也加进去了,其实每个模块抽出来会比较好改

你可以看下我仓库里的nn_bias 分支,我已经把代码合进来了,我做了如下改动:

  1. deepbias 模块用torch直接导出,没有直接转onnx,如果你能转的话,你后面更新一下。链接,然后再进行推理
  2. ctc prefix beam search 这一块,我去掉了原始wenet-triton里的那一套,因为那是来自ctc_decoder,里面的热词实现和训练用的python版本的graph热词不匹配,以及ctc search也不匹配,python train的ctc prefix beam search 和runtime/libtorch 下c++版本的是一致的,我这里还用的是python的,一是学习,而是想快速迁移过来看效果的,接下来就是把c++ 的迁移过来,可以看下面的图,python的太慢了。如果你能集成就更好了,我对c++ 还不是很熟。
  3. wenet-triton 里的score改动,参考配置文件 ,以及score/1/model.py ,我加了每条音频自定义热词以及热词权重。

image

@fclearner
Copy link
Contributor

大佬,你尝试过改runtime代码吗,很难对原始代码不侵入,需要调整ForwardEncoderFunc, 加一个热词输入处理以及热词模块前向,主要ForwardEncoderFunc这个函数把ctc_prob计算也加进去了,其实每个模块抽出来会比较好改

你可以看下我仓库里的nn_bias 分支,我已经把代码合进来了,我做了如下改动:

  1. deepbias 模块用torch直接导出,没有直接转onnx,如果你能转的话,你后面更新一下。链接,然后再进行推理
  2. ctc prefix beam search 这一块,我去掉了原始wenet-triton里的那一套,因为那是来自ctc_decoder,里面的热词实现和训练用的python版本的graph热词不匹配,以及ctc search也不匹配,python train的ctc prefix beam search 和runtime/libtorch 下c++版本的是一致的,我这里还用的是python的,一是学习,而是想快速迁移过来看效果的,接下来就是把c++ 的迁移过来,可以看下面的图,python的太慢了。如果你能集成就更好了,我对c++ 还不是很熟。
  3. wenet-triton 里的score改动,参考配置文件 ,以及score/1/model.py ,我加了每条音频自定义热词以及热词权重。

image

好的,感谢大佬分享哈,我正在尝试改runtime,改完了也分享一下

@fclearner
Copy link
Contributor

我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:

context_list = torch.randint(low=0, high=args['vocab_size'],
                             size=(200, 10))
context_lengths = torch.randint(low=1, high=10,
                                size=(200,))
context_lengths = torch.tensor([x.size(0) for x in context_list],
                               dtype=torch.int32)
encoder_out = torch.randn((1, 200, args['output_size']))

print("\tStage-4.2: torch.onnx.export")
dynamic_axes = {'context_list': {0: 'context_lengths'},    # 批量动态轴
                'context_lengths': {0: 'context_lengths'},
                'encoder_out': {1: 'T'},
                'encoder_bias_out': {1: 'T'}
                }
inputs = (context_list, context_lengths, encoder_out, 1.0, True)
torch.onnx.export(context_module,
                  inputs,
                  context_module_outpath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['context_list', 'context_lengths', 
                               'encoder_out', 'biasing_score', 'recognize'],
                  output_names=['encoder_bias_out', 'bias_out'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。

torch导出不是支持dynamic_axes吗,我这里是给了动态size的,后面的测试目前看是过了的

@dahu1 大佬这个模型导出确实是有维度问题的,不过torch导出是支持dynamic_axes,这个确实是有效的;我目前定位到cppn模型内部使用了torch.nn.utils.rnn.pack_padded_sequence,把数据处理成变长的了,这个代码删了,动态维度导出是正常的,但是会对效果有明显影响,我理解可能是引入了padding的数据做embedding计算了:

pack_seq = torch.nn.utils.rnn.pack_padded_sequence(

@fclearner
Copy link
Contributor

我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:

context_list = torch.randint(low=0, high=args['vocab_size'],
                             size=(200, 10))
context_lengths = torch.randint(low=1, high=10,
                                size=(200,))
context_lengths = torch.tensor([x.size(0) for x in context_list],
                               dtype=torch.int32)
encoder_out = torch.randn((1, 200, args['output_size']))

print("\tStage-4.2: torch.onnx.export")
dynamic_axes = {'context_list': {0: 'context_lengths'},    # 批量动态轴
                'context_lengths': {0: 'context_lengths'},
                'encoder_out': {1: 'T'},
                'encoder_bias_out': {1: 'T'}
                }
inputs = (context_list, context_lengths, encoder_out, 1.0, True)
torch.onnx.export(context_module,
                  inputs,
                  context_module_outpath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['context_list', 'context_lengths', 
                               'encoder_out', 'biasing_score', 'recognize'],
                  output_names=['encoder_bias_out', 'bias_out'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。

torch导出不是支持dynamic_axes吗,我这里是给了动态size的,后面的测试目前看是过了的

@dahu1 大佬这个模型导出确实是有维度问题的,不过torch导出是支持dynamic_axes,这个确实是有效的;我目前定位到cppn模型内部使用了torch.nn.utils.rnn.pack_padded_sequence,把数据处理成变长的了,这个代码删了,动态维度导出是正常的,但是会对效果有明显影响,我理解可能是引入了padding的数据做embedding计算了:

pack_seq = torch.nn.utils.rnn.pack_padded_sequence(

找到一个讨论的帖子,按帖子里这么做确实可以正常导出,onnxruntime可以跑通,就是会报warning;python代码最后导出的检查部分精度差距会比较大:17%,我看了下cer也会损失绝对0.1%左右;我在尝试用mask的形式去处理数据
pytorch/pytorch#62240

@fclearner
Copy link
Contributor

大佬,你尝试过改runtime代码吗,很难对原始代码不侵入,需要调整ForwardEncoderFunc, 加一个热词输入处理以及热词模块前向,主要ForwardEncoderFunc这个函数把ctc_prob计算也加进去了,其实每个模块抽出来会比较好改

你可以看下我仓库里的nn_bias 分支,我已经把代码合进来了,我做了如下改动:

  1. deepbias 模块用torch直接导出,没有直接转onnx,如果你能转的话,你后面更新一下。链接,然后再进行推理
  2. ctc prefix beam search 这一块,我去掉了原始wenet-triton里的那一套,因为那是来自ctc_decoder,里面的热词实现和训练用的python版本的graph热词不匹配,以及ctc search也不匹配,python train的ctc prefix beam search 和runtime/libtorch 下c++版本的是一致的,我这里还用的是python的,一是学习,而是想快速迁移过来看效果的,接下来就是把c++ 的迁移过来,可以看下面的图,python的太慢了。如果你能集成就更好了,我对c++ 还不是很熟。
  3. wenet-triton 里的score改动,参考配置文件 ,以及score/1/model.py ,我加了每条音频自定义热词以及热词权重。

image

@dahu1 大佬,我这尝试支持了下cppn的onnxruntime和模型导出,可以做个参考:https://github.com/fclearner/wenet/tree/nn_bias

@fclearner
Copy link
Contributor

我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:

context_list = torch.randint(low=0, high=args['vocab_size'],
                             size=(200, 10))
context_lengths = torch.randint(low=1, high=10,
                                size=(200,))
context_lengths = torch.tensor([x.size(0) for x in context_list],
                               dtype=torch.int32)
encoder_out = torch.randn((1, 200, args['output_size']))

print("\tStage-4.2: torch.onnx.export")
dynamic_axes = {'context_list': {0: 'context_lengths'},    # 批量动态轴
                'context_lengths': {0: 'context_lengths'},
                'encoder_out': {1: 'T'},
                'encoder_bias_out': {1: 'T'}
                }
inputs = (context_list, context_lengths, encoder_out, 1.0, True)
torch.onnx.export(context_module,
                  inputs,
                  context_module_outpath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['context_list', 'context_lengths', 
                               'encoder_out', 'biasing_score', 'recognize'],
                  output_names=['encoder_bias_out', 'bias_out'],
                  dynamic_axes=dynamic_axes,
                  verbose=True)

导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。

torch导出不是支持dynamic_axes吗,我这里是给了动态size的,后面的测试目前看是过了的

@dahu1 大佬这个模型导出确实是有维度问题的,不过torch导出是支持dynamic_axes,这个确实是有效的;我目前定位到cppn模型内部使用了torch.nn.utils.rnn.pack_padded_sequence,把数据处理成变长的了,这个代码删了,动态维度导出是正常的,但是会对效果有明显影响,我理解可能是引入了padding的数据做embedding计算了:

pack_seq = torch.nn.utils.rnn.pack_padded_sequence(

这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态

@XiaGuangmin
Copy link

XiaGuangmin commented Feb 25, 2024 via email

@dahu1
Copy link

dahu1 commented Feb 26, 2024

这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态

好的,感谢大佬分享,我学习一下。确实单个推理会比较慢,如果能做到转onnx进行batch推理,会快很多。另外你说的把lstm改成单向的,那是不是性能会有损啊?

@fclearner
Copy link
Contributor

这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态

好的,感谢大佬分享,我学习一下。确实单个推理会比较慢,如果能做到转onnx进行batch推理,会快很多。另外你说的把lstm改成单向的,那是不是性能会有损啊?

我已经把新改好的代码提交了,

这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态

好的,感谢大佬分享,我学习一下。确实单个推理会比较慢,如果能做到转onnx进行batch推理,会快很多。另外你说的把lstm改成单向的,那是不是性能会有损啊?

我这边刚做完测试,lstm单向相较双向确实热词召回率变差了六七个点,目前来看解决变长的onnx导出问题有两种策略:1、参考funasr,使用单向的lstm,状态索引用热词列表长度;2、使用torch.jit.script先转成静态图,但是运行的时候会报warning,需要调整onnx的日志屏蔽级别;我刚更新了第一种的代码,稍后我把第二种也commit一下;如果有大佬有更好的方法也可以回复下

@dahu1
Copy link

dahu1 commented Feb 27, 2024

我这边刚做完测试,lstm单向相较双向确实热词召回率变差了六七个点,目前来看解决变长的onnx导出问题有两种策略:1、参考funasr,使用单向的lstm,状态索引用热词列表长度;2、使用torch.jit.script先转成静态图,但是运行的时候会报warning,需要调整onnx的日志屏蔽级别;我刚更新了第一种的代码,稍后我把第二种也commit一下;如果有大佬有更好的方法也可以回复下

更改为单向lstm后,是不是context 模型需要重新训练?毕竟训练和推理要保持一致。

@fclearner
Copy link
Contributor

我这边刚做完测试,lstm单向相较双向确实热词召回率变差了六七个点,目前来看解决变长的onnx导出问题有两种策略:1、参考funasr,使用单向的lstm,状态索引用热词列表长度;2、使用torch.jit.script先转成静态图,但是运行的时候会报warning,需要调整onnx的日志屏蔽级别;我刚更新了第一种的代码,稍后我把第二种也commit一下;如果有大佬有更好的方法也可以回复下

更改为单向lstm后,是不是context 模型需要重新训练?毕竟训练和推理要保持一致。

是的,需要重新训练,主要是因为lstm的输出需要接一个context_encoder映射成embedding_size,lstm的单向状态仍然是可以通过热词列表长度索引的,我已经更新了双向的模型导出代码了,你先试试吧,因为我是直接copy到github的,没有做测试,可能会有问题:https://github.com/fclearner/wenet/blob/nn_bias/wenet/bin/export_onnx_cpu.py

@fclearner
Copy link
Contributor

这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态

好的,感谢大佬分享,我学习一下。确实单个推理会比较慢,如果能做到转onnx进行batch推理,会快很多。另外你说的把lstm改成单向的,那是不是性能会有损啊?

我已经把新改好的代码提交了,

这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态

好的,感谢大佬分享,我学习一下。确实单个推理会比较慢,如果能做到转onnx进行batch推理,会快很多。另外你说的把lstm改成单向的,那是不是性能会有损啊?

我这边刚做完测试,lstm单向相较双向确实热词召回率变差了六七个点,目前来看解决变长的onnx导出问题有两种策略:1、参考funasr,使用单向的lstm,状态索引用热词列表长度;2、使用torch.jit.script先转成静态图,但是运行的时候会报warning,需要调整onnx的日志屏蔽级别;我刚更新了第一种的代码,稍后我把第二种也commit一下;如果有大佬有更好的方法也可以回复下

记录下问题,
问题一:
这里记录的方法2双向lstm onnx导出会引入通用效果损失(虽然召回率还可以),暂时找不到问题,onnx导出的校验是可以正常通过的;建议使用单向lstm;

问题二:
lstm导出时batchsize大于1会有问题,具体信息:UerWarning: Exporting a model to ONNX with a batch size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model,这里看到funasr的做法是导出时把batch_first设为false;

问题三:
小数据量微调热词模块会引入通用效果损失,参考论文使用全量数据微调,并且用的static batch收敛效果更好

@Swagger-z
Copy link

看起来有点奇怪,单独使用某个方法时,graph score 从 2 到 3,bias score 从 1 到 1.5 我当时实验出来都是有提升的,所以我才会用这个值,但是你测的结果都没什么变化或者更差了

@kaixunhuang0 你在传bwer工具到huggingface上的时候,可否把你跑的识别结果和 解码参数当成log 文件也传一下?我看看和我的区别,按理说我们用的是同一套模型和参数,不应该有区别的。。

请问您跑这个huggingface上的bwer计算脚本跑通了吗,我跑出来一直不对,wer非常高

@XiaGuangmin
Copy link

XiaGuangmin commented May 16, 2024 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet