From 4426a3d405c5d92cbd87d797cdbde7dd1851fee6 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Thu, 8 Jun 2023 15:49:48 +0200
Subject: [PATCH] fix decode bug because new huggingface version

---
 convlab2/policy/genTUS/train_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/convlab2/policy/genTUS/train_model.py b/convlab2/policy/genTUS/train_model.py
index be59fca..bafe79a 100644
--- a/convlab2/policy/genTUS/train_model.py
+++ b/convlab2/policy/genTUS/train_model.py
@@ -201,6 +201,7 @@ def compute_metrics(eval_preds):
     preds, labels = eval_preds
     if isinstance(preds, tuple):
         preds = preds[0]
+    preds = np.where(preds != -100, preds, TOKENIZER.pad_token_id)
     decoded_preds = TOKENIZER.batch_decode(
         preds, skip_special_tokens=True, max_length=400)
 
-- 
GitLab