Skip to content
Snippets Groups Projects
Commit 7a091a39 authored by zz-jacob's avatar zz-jacob
Browse files

fix scgpt mask bug

parent 3ac4a347
No related branches found
No related tags found
No related merge requests found
...@@ -31,7 +31,7 @@ class SCGPT(NLG): ...@@ -31,7 +31,7 @@ class SCGPT(NLG):
sent_ids = [sent + [0] * (max_len - len(sent)) for sent in sent_ids] sent_ids = [sent + [0] * (max_len - len(sent)) for sent in sent_ids]
inputs = torch.LongTensor(sent_ids).to(self.device) inputs = torch.LongTensor(sent_ids).to(self.device)
model_to_run = self.model.module if type(self.model) is DDP else self.model model_to_run = self.model.module if type(self.model) is DDP else self.model
outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs==0).float(), outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs!=0).float(),
eos_token_id=self.tokenizer.pad_token_id) # greedy eos_token_id=self.tokenizer.pad_token_id) # greedy
outputs = outputs[:, len(inputs[0]):] outputs = outputs[:, len(inputs[0]):]
def clean_sentence(sent): def clean_sentence(sent):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment