diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py index aede679019252b4703493155dbe475dea54a33c0..4977b4f3778229293bb60c2e98b7e715ebe227dd 100644 --- a/convlab/nlg/scgpt/scgpt.py +++ b/convlab/nlg/scgpt/scgpt.py @@ -31,7 +31,7 @@ class SCGPT(NLG): sent_ids = [sent + [0] * (max_len - len(sent)) for sent in sent_ids] inputs = torch.LongTensor(sent_ids).to(self.device) model_to_run = self.model.module if type(self.model) is DDP else self.model - outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs==0).float(), + outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs!=0).float(), eos_token_id=self.tokenizer.pad_token_id) # greedy outputs = outputs[:, len(inputs[0]):] def clean_sentence(sent):