forward() got an unexpected keyword argument 'log_probs' #220

Open
ChaofanTao opened this issue Feb 4, 2024 · 1 comment
@ChaofanTao

Environment info

  • Platform: Linux
  • Python version: 3.9.18
  • PyTorch version (GPU?): 2.0.0+cu118
  • Using GPU in script?: yes

Information

I want to train ContextNet on the LibriSpeech dataset. Here is my training script, located in openspeech/scripts (the first time, I set dataset.dataset_download=True to download the dataset):

# sh scripts/train.sh 
python3 ./openspeech_cli/hydra_train.py \
    dataset=librispeech \
    dataset.dataset_download=False \
    dataset.dataset_path=$DATASET_PATH \
    dataset.manifest_file_path=$MANIFEST_FILE_PATH \
    tokenizer=libri_subword \
    model=contextnet \
    audio=fbank \
    lr_scheduler=warmup_reduce_lr_on_plateau \
    trainer=gpu \
    criterion=cross_entropy

It returns

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 139, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp_spawn.py", line 288, in validation_step
    return self.model(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/home/mnt/cftao/openspeech/openspeech/models/contextnet/model.py", line 133, in validation_step
    return self.collect_outputs(
  File "/home/mnt/cftao/openspeech/openspeech/models/openspeech_ctc_model.py", line 73, in collect_outputs
    loss = self.criterion(
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'log_probs'

How can I solve this problem? Thanks.
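The last frames of the traceback show `collect_outputs` in openspeech_ctc_model.py (line 73) calling `self.criterion(...)` with a `log_probs` keyword, while the script selects `criterion=cross_entropy`. A minimal sketch for checking this kind of caller/criterion mismatch, assuming openspeech's criteria have `forward()` signatures shaped like PyTorch's built-in losses: `cross_entropy_forward` and `ctc_forward` below are stand-ins that only mirror the documented signatures of `torch.nn.CrossEntropyLoss` and `torch.nn.CTCLoss`, and `accepts` is a hypothetical helper built on `inspect.Signature.bind_partial`, not part of openspeech.

```python
import inspect
from typing import Any, Callable, Dict


# Stand-ins mirroring the documented forward() signatures of
# torch.nn.CrossEntropyLoss and torch.nn.CTCLoss. They are assumptions about
# what openspeech's cross_entropy and CTC criteria look like, not their code.
def cross_entropy_forward(input, target):
    return "cross-entropy loss"


def ctc_forward(log_probs, targets, input_lengths, target_lengths):
    return "ctc loss"


def accepts(fn: Callable[..., Any], call_kwargs: Dict[str, Any]) -> bool:
    """Hypothetical helper: True if fn has a parameter for every given keyword.

    bind_partial ignores missing arguments but still raises TypeError for an
    unexpected keyword, which is the failure seen in the traceback.
    """
    try:
        inspect.signature(fn).bind_partial(**call_kwargs)
    except TypeError as err:
        print(f"{fn.__name__}: {err}")
        return False
    return True


# Only the log_probs keyword is known from the traceback; the value is a placeholder.
call_kwargs = {"log_probs": None}

print(accepts(cross_entropy_forward, call_kwargs))
# prints "cross_entropy_forward: got an unexpected keyword argument 'log_probs'", then False
print(accepts(ctc_forward, call_kwargs))
# prints True
```

If the real criterion classes behave like these stand-ins, a CTC criterion is the likely match for `model=contextnet`, since its `validation_step` goes through the CTC model's `collect_outputs`. The exact config name for that criterion (possibly `criterion=ctc`) is an assumption to verify against the openspeech configs.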

@upskyy
Member

upskyy commented Feb 5, 2024

@ChaofanTao
Thank you for reporting the issue. I will check and leave a comment.
