diff --git a/convlab/nlg/scgpt/evaluate.sh b/convlab/nlg/scgpt/evaluate.sh
index ca76577b1da53a685b64a32bf503b511d537dd97..ef357a1da67d48c07959f740356335af274d88c1 100755
--- a/convlab/nlg/scgpt/evaluate.sh
+++ b/convlab/nlg/scgpt/evaluate.sh
@@ -1,6 +1,6 @@
-CUDA_VISIBLE_DEVICES="1" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2051 main.py \
---batch_size 64 \
+CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2050 main.py \
+--batch_size 128 \
 --base_model_name_path gpt2-medium \
---dataset multiwoz21 \
---exp_name gpt2_mwoz2 \
---model_path saved_models/gpt2_mwoz/epoch_2/epoch_2_step1329.pt \
\ No newline at end of file
+--dataset sgd \
+--exp_name gpt2_sgd_test \
+--model_path saved_models/exp_name/epoch_x/epoch_7_step10312.pt \
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py
index 801cc74fa199805cfbf8e9ddba92ec3f049e1ffa..ee5d4c9d2646b9b10133fb8c0b4a4d1ce44b795b 100644
--- a/convlab/nlg/scgpt/scgpt.py
+++ b/convlab/nlg/scgpt/scgpt.py
@@ -2,27 +2,22 @@ import sys
 sys.path.append('../../..')
 
 import torch
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from convlab.nlg.nlg import NLG
 from util import act2str
 from scgpt_special_tokens import *
 
-special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
 
 class SCGPT(NLG):
     def __init__(self, dataset_name, model_path, device='cpu'):
         super(SCGPT, self).__init__()
         self.device = device
-        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2')).to(self.device)
         self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
-                                           'additional_special_tokens': special_tokens})
-        self.model.resize_token_embeddings(len(self.tokenizer))
         self.model.load_state_dict(torch.load(model_path))
 
-
     def generate(self, action):
         action_str = act2str(action)
         output = self._inference_batch([action_str])[0]
@@ -30,17 +25,14 @@ class SCGPT(NLG):
 
     def _inference_batch(self, sents):
         with torch.no_grad():
-            sents = [sent + ' ' + START_OF_PRED for sent in sents]
-            sent_ids = [self.tokenizer.encode(sent) for sent in sents]
+            sents = [sent for sent in sents]
+            sent_ids = [self.tokenizer.encode(sent) + [self.tokenizer._convert_token_to_id_with_added_voc('&')] for sent in sents]
             max_len = max([len(sent) for sent in sent_ids])
-            sent_ids = [[self.tokenizer.pad_token_id] * (max_len - len(sent)) + sent for sent in sent_ids]
+            sent_ids = [sent + [0] * (max_len - len(sent)) for sent in sent_ids]
             inputs = torch.LongTensor(sent_ids).to(self.device)
             model_to_run = self.model.module if type(self.model) is DDP else self.model
-            outputs = model_to_run.generate(inputs, max_length=256,
-                                            eos_token_id=self.tokenizer.pad_token_id,
-                                            pad_token_id=self.tokenizer.pad_token_id)  # greedy
-            # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
-            #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-
-            output_strs = [self.tokenizer.decode(item) for item in outputs]
+            outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs==0).float(),
+                                            eos_token_id=self.tokenizer.pad_token_id)  # greedy
+            outputs = outputs[:, len(inputs[0]):]
+            output_strs = [self.tokenizer.decode(item).strip() for item in outputs]
             return output_strs
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/train.sh b/convlab/nlg/scgpt/train.sh
index 1bf17226712a7d7d13eba4735219b279fcd61c2b..469216db11e00a353397aaf72869922929f596eb 100755
--- a/convlab/nlg/scgpt/train.sh
+++ b/convlab/nlg/scgpt/train.sh
@@ -1,13 +1,13 @@
-CUDA_VISIBLE_DEVICES="3" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2043 main.py \
+CUDA_VISIBLE_DEVICES="2" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2042 main.py \
 --batch_size 32 \
 --accumulation_step 4 \
---epoch_num 20 \
+--epoch_num 100 \
 --lr 5e-5 \
---base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \
---val_step 1000 \
---exp_name scgpt_mwoz \
+--base_model_name_path gpt2-medium \
+--val_step 100 \
+--exp_name gpt2_mwoz001_direct \
 --do_train \
---dataset sgd \
---train_ratio 1.0 \
-# --scgpt_model_ckpt_path saved_models/sgd_tm_1e4/epoch_8/epoch_8_step41094.pt
+--dataset multiwoz21 \
+--train_ratio 0.01 \
+# --scgpt_model_ckpt_path saved_models/gpt2_sgd_tm/epoch_2/epoch_2_step13698.pt
 # --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \
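
For context, a minimal sketch of how the patched SCGPT class could be driven after training. The checkpoint path and dataset name are the ones referenced in evaluate.sh above; the dialogue-act quadruple and device string are illustrative assumptions, not part of the patch, and the script assumes it is run from convlab/nlg/scgpt so that the module's `from util import act2str` import resolves.

# usage_sketch.py -- hypothetical driver script, not part of this patch
from scgpt import SCGPT

# The patched constructor builds an untrained `gpt2` architecture from
# GPT2Config and then loads every weight from the checkpoint, so the
# checkpoint must contain a full state_dict matching the `gpt2` config.
nlg = SCGPT(dataset_name='sgd',
            model_path='saved_models/exp_name/epoch_x/epoch_7_step10312.pt',
            device='cuda:0')

# generate() flattens the act with act2str, appends the '&' separator token
# used by the patched _inference_batch as the prompt terminator, and decodes
# only the tokens produced after the prompt.
action = [['Inform', 'Hotel', 'Area', 'east']]  # illustrative act quadruple
print(nlg.generate(action))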