Commit e0c4d01b authored by zz-jacob

fix scgpt.py

parent 54c6842d
CUDA_VISIBLE_DEVICES="1" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2051 main.py \
--batch_size 64 \
CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2050 main.py \
--batch_size 128 \
--base_model_name_path gpt2-medium \
--dataset multiwoz21 \
--exp_name gpt2_mwoz2 \
--model_path saved_models/gpt2_mwoz/epoch_2/epoch_2_step1329.pt \
\ No newline at end of file
--dataset sgd \
--exp_name gpt2_sgd_test \
--model_path saved_models/exp_name/epoch_x/epoch_7_step10312.pt \
\ No newline at end of file
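Both the old and new commands go through `torch.distributed.launch` with a single process (`--nproc_per_node 1`). For context, a minimal sketch of the plumbing that launcher provides, assuming `main.py` follows the standard pattern; none of this code is from the repository:

```python
# Sketch: what torch.distributed.launch expects of main.py.
# The launcher spawns --nproc_per_node processes, passes --local_rank=<i>
# to each one, and exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE
# (MASTER_PORT comes from --master_port, e.g. 2050 above).
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by the launcher
args, _ = parser.parse_known_args()  # the script's own flags are parsed elsewhere

torch.distributed.init_process_group(backend='nccl')  # rendezvous via the env vars
torch.cuda.set_device(args.local_rank)
```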
scgpt.py:

@@ -2,27 +2,22 @@ import sys
 sys.path.append('../../..')
 import torch
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from torch.nn.parallel import DistributedDataParallel as DDP
 from convlab.nlg.nlg import NLG
 from util import act2str
 from scgpt_special_tokens import *

 special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]

 class SCGPT(NLG):
     def __init__(self, dataset_name, model_path, device='cpu'):
         super(SCGPT, self).__init__()
         self.device = device
-        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2')).to(self.device)
         self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
         self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
                                            'additional_special_tokens': special_tokens})
         self.model.resize_token_embeddings(len(self.tokenizer))
         self.model.load_state_dict(torch.load(model_path))

     def generate(self, action):
         action_str = act2str(action)
         output = self._inference_batch([action_str])[0]
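The constructor change above is the heart of the fix: `GPT2LMHeadModel.from_pretrained('gpt2')` downloads and loads pretrained GPT-2 weights, whereas `GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2'))` builds the same architecture with freshly initialized parameters. Because `load_state_dict(torch.load(model_path))` overwrites every parameter with the fine-tuned checkpoint a few lines later, downloading the pretrained weights was wasted work. A minimal sketch, assuming a hypothetical checkpoint path `'ckpt.pt'`:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Old: fetches ~0.5 GB of pretrained weights that are then discarded.
# model = GPT2LMHeadModel.from_pretrained('gpt2')

# New: only the (small) config is fetched; weights are randomly initialized.
model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2'))

# The embedding table must match the checkpoint's enlarged vocabulary
# (base vocab + the added special tokens) before loading the state dict.
model.resize_token_embeddings(model.config.vocab_size + 5)  # +5 is illustrative
model.load_state_dict(torch.load('ckpt.pt'))  # 'ckpt.pt' is hypothetical
```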
@@ -30,17 +25,14 @@ class SCGPT(NLG):
     def _inference_batch(self, sents):
         with torch.no_grad():
-            sents = [sent + ' ' + START_OF_PRED for sent in sents]
-            sent_ids = [self.tokenizer.encode(sent) for sent in sents]
+            sents = [sent for sent in sents]
+            sent_ids = [self.tokenizer.encode(sent) + [self.tokenizer._convert_token_to_id_with_added_voc('&')] for sent in sents]
             max_len = max([len(sent) for sent in sent_ids])
-            sent_ids = [[self.tokenizer.pad_token_id] * (max_len - len(sent)) + sent for sent in sent_ids]
+            sent_ids = [sent + [0] * (max_len - len(sent)) for sent in sent_ids]
             inputs = torch.LongTensor(sent_ids).to(self.device)
             model_to_run = self.model.module if type(self.model) is DDP else self.model
-            outputs = model_to_run.generate(inputs, max_length=256,
-                                            eos_token_id=self.tokenizer.pad_token_id,
-                                            pad_token_id=self.tokenizer.pad_token_id)  # greedy
-            # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
-            #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-            output_strs = [self.tokenizer.decode(item) for item in outputs]
+            outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs==0).float(),
+                                            eos_token_id=self.tokenizer.pad_token_id)  # greedy
+            outputs = outputs[:, len(inputs[0]):]
+            output_strs = [self.tokenizer.decode(item).strip() for item in outputs]
             return output_strs
\ No newline at end of file
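The rewritten `_inference_batch` drops the `START_OF_PRED` suffix, appends the id of `'&'` to each prompt as a stop sentinel, right-pads batches with token id 0 instead of left-padding with the pad token, and slices the prompt tokens off the output before decoding. One caveat: Hugging Face's convention is `attention_mask == 1` for positions the model should attend to, so `(inputs == 0)` marks the padding positions rather than the content; the conventional form would be `(inputs != 0)`. A self-contained sketch of the same batched greedy-decoding pattern with the conventional mask, using stock `'gpt2'` as a stand-in for the fine-tuned checkpoint:

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()

prompts = ['inform ( restaurant = Pizza Hut )', 'request ( area )']
ids = [tokenizer.encode(p) for p in prompts]
max_len = max(len(x) for x in ids)
ids = [x + [0] * (max_len - len(x)) for x in ids]  # right-pad with id 0
inputs = torch.LongTensor(ids)

with torch.no_grad():
    outputs = model.generate(inputs, max_length=64,
                             attention_mask=(inputs != 0).float(),
                             pad_token_id=tokenizer.eos_token_id)  # greedy

# Keep only the newly generated continuation, as the diff does.
continuations = [tokenizer.decode(o).strip() for o in outputs[:, inputs.shape[1]:]]
```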
CUDA_VISIBLE_DEVICES="3" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2043 main.py \
CUDA_VISIBLE_DEVICES="2" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2042 main.py \
--batch_size 32 \
--accumulation_step 4 \
--epoch_num 20 \
--epoch_num 100 \
--lr 5e-5 \
--base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \
--val_step 1000 \
--exp_name scgpt_mwoz \
--base_model_name_path gpt2-medium \
--val_step 100 \
--exp_name gpt2_mwoz001_direct \
--do_train \
--dataset sgd \
--train_ratio 1.0 \
# --scgpt_model_ckpt_path saved_models/sgd_tm_1e4/epoch_8/epoch_8_step41094.pt
--dataset multiwoz21 \
--train_ratio 0.01 \
# --scgpt_model_ckpt_path saved_models/gpt2_sgd_tm/epoch_2/epoch_2_step13698.pt
# --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \
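With `--batch_size 32` and `--accumulation_step 4`, gradients are accumulated over four forward/backward passes before each optimizer step, for an effective batch size of 128. A minimal, self-contained sketch of that pattern (the linear model and random data are stand-ins, not the repository's training loop):

```python
import torch

accumulation_step = 4                            # as in --accumulation_step 4
model = torch.nn.Linear(10, 1)                   # stand-in for GPT-2
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)  # as in --lr 5e-5

optimizer.zero_grad()
for i in range(8):                               # stand-in for the data loader
    x, y = torch.randn(32, 10), torch.randn(32, 1)  # micro-batch of 32
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accumulation_step).backward()        # scale so grads average
    if (i + 1) % accumulation_step == 0:
        optimizer.step()                         # one update per 4 micro-batches
        optimizer.zero_grad()
```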