diff --git a/convlab2/policy/genTUS/train_model.py b/convlab2/policy/genTUS/train_model.py index be59fca76a0f0612f5501c9dbe2a91102edf548e..bafe79a7b8b03403ec7e07de91fab150dc074fc6 100644 --- a/convlab2/policy/genTUS/train_model.py +++ b/convlab2/policy/genTUS/train_model.py @@ -201,6 +201,7 @@ def compute_metrics(eval_preds): preds, labels = eval_preds if isinstance(preds, tuple): preds = preds[0] + preds = np.where(preds != -100, preds, TOKENIZER.pad_token_id) decoded_preds = TOKENIZER.batch_decode( preds, skip_special_tokens=True, max_length=400)