From 7a091a39089ab1731173338f34f411d4bf9555bc Mon Sep 17 00:00:00 2001
From: zz-jacob <zhangz.goal@gmail.com>
Date: Wed, 30 Nov 2022 10:46:56 +0800
Subject: [PATCH] fix scgpt mask bug

---
 convlab/nlg/scgpt/scgpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py
index aede6790..4977b4f3 100644
--- a/convlab/nlg/scgpt/scgpt.py
+++ b/convlab/nlg/scgpt/scgpt.py
@@ -31,7 +31,7 @@ class SCGPT(NLG):
             sent_ids = [sent + [0] * (max_len - len(sent))  for sent in sent_ids]
             inputs = torch.LongTensor(sent_ids).to(self.device)
             model_to_run = self.model.module if type(self.model) is DDP else self.model
-            outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs==0).float(),
+            outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs!=0).float(),
                                             eos_token_id=self.tokenizer.pad_token_id)  # greedy
             outputs = outputs[:, len(inputs[0]):]
             def clean_sentence(sent):
-- 
GitLab