Commit d6295047 authored by zz-jacob

fix scgpt training

parent e0c4d01b
@@ -24,7 +24,7 @@ class Logging:
        f.write('\n')
        f.close()

def evaluate(predict_result, ontology):
def evaluate(predict_result, ontology, filter_empty_acts=True):
    predict_result = json.load(open(predict_result))
    metrics = {}
@@ -33,7 +33,15 @@ def evaluate(predict_result, ontology):
    references = []
    candidates = []
    for i in range(len(predict_result)):
        if filter_empty_acts:
            acts = predict_result[i]['dialogue_acts']
            acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
            if acts_size == 0:
                continue
        references.append(predict_result[i]['utterance'])
        if 'prediction' in predict_result[i]:
            candidates.append(predict_result[i]['prediction'])
        else:
            candidates.append(predict_result[i]['predictions']['utterance'])
    # metrics['bleu'] = corpus_bleu(references, candidates)
    references = [" " if ref=="" else ref for ref in references]
@@ -55,7 +63,7 @@ def evaluate(predict_result, ontology):
    score_list = []
    for item in predict_result:
        da = item['dialogue_acts']
        utterance = item['predictions']['utterance']
        utterance = item['predictions']['utterance'] if 'predictions' in item else item['prediction']
        missing_count = 0
        redundant_count = 0
        all_count = 0
......
CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2050 main.py \
--batch_size 128 \
CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2052 main.py \
--batch_size 1 \
--base_model_name_path gpt2-medium \
--dataset sgd \
--exp_name gpt2_sgd_test \
--model_path saved_models/exp_name/epoch_x/epoch_7_step10312.pt \
\ No newline at end of file
--dataset tm3 \
--exp_name tm3_mst_test \
--model_path saved_models/mwoz_sgd_tm_train/epoch_5/epoch_5_step19206.pt \
# --model_path saved_models/gpt2_tm_direct/epoch_19/epoch_19_step65540.pt \
# --model_path saved_models/gpt2_tm_direct/epoch_6/epoch_6_step22939.pt \
\ No newline at end of file
......@@ -45,6 +45,7 @@ parser.add_argument('--scgpt_model_ckpt_path', default=None, type=str, help="The
parser.add_argument('--save_path', default="saved_models", type=str, help="Model save path.")
parser.add_argument('--exp_name', default="default_name", type=str, help="Current experiment name.")
parser.add_argument("--max_seq_len", default=128, type=int)
parser.add_argument("--save_epoch_interval", default=1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank
@@ -80,8 +81,8 @@ def cal_loss(input, target, seq_lens, seq_lens_input):
    input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
    output_mask = torch.logical_xor(mask, input_mask)
    pad_mask = torch.logical_not(mask)
    # masked_loss = loss * output_mask
    masked_loss = loss * (output_mask + pad_mask)
    masked_loss = loss * output_mask
    # masked_loss = loss * (output_mask + pad_mask)
    mean_loss = torch.sum(masked_loss) / torch.sum(output_mask + pad_mask)
    return mean_loss
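
Note on the hunk above: the fix multiplies the per-token loss by output_mask only, so prompt and padding positions no longer contribute to the numerator. A minimal standalone sketch of that masking idea, with hypothetical tensor names and shapes rather than this repository's exact cal_loss, might look like:

import torch
import torch.nn.functional as F

def masked_lm_loss(logits, targets, output_mask):
    # logits: (batch, seq_len, vocab); targets: (batch, seq_len)
    # output_mask: 1.0 on response tokens, 0.0 on prompt and padding tokens
    per_token = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')  # (batch, seq_len)
    masked = per_token * output_mask                       # keep only response positions
    return masked.sum() / output_mask.sum().clamp(min=1)   # average over response tokens

The patched cal_loss still divides by torch.sum(output_mask + pad_mask), so its denominator also counts padding positions; the sketch above normalizes over response tokens only.
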
@@ -111,11 +112,11 @@ def pad_collate(ori_batch):
## Training Hyper-params
def train(model, nlg_data, global_step=0):
    train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
    train_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['train']), tokenizer)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
    val_dataset = SCGPTDataset(nlg_data['validation'], tokenizer)
    val_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['validation']), tokenizer)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_dataloader = DataLoader(val_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
@@ -138,6 +139,7 @@ def train(model, nlg_data, global_step=0):
            preds = outputs[0]
            loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
            loss /= FLAGS.accumulation_step
            loss /= dist.get_world_size()
            loss.backward()
            # update params
@@ -147,6 +149,7 @@ def train(model, nlg_data, global_step=0):
                scheduler.step()
                model.zero_grad()
                # tensorboard
                if dist.get_rank() == 0:
                    tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
                    tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
                    tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
@@ -154,12 +157,14 @@ def train(model, nlg_data, global_step=0):
                    model.eval()
                    val_loss = eval(model, val_dataloader)
                    ppl = np.exp(val_loss)
                    if dist.get_rank() == 0:
                        tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
                        tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
                    model.train()
        # save the model when each epoch ends
        if dist.get_rank() == 0:
            if (epoch+1) % FLAGS.save_epoch_interval == 0:
                # vaidation
                model.eval()
                val_loss = eval(model, val_dataloader)
@@ -167,7 +172,6 @@ def train(model, nlg_data, global_step=0):
                tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
                tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
                model.train()
                # save model
                save_dir = os.path.join(FLAGS.save_path, FLAGS.exp_name, f'epoch_{epoch}')
                os.makedirs(save_dir, exist_ok=True)
@@ -202,7 +206,7 @@ def inference_batch(model, sents):
    sent_ids = [tokenizer.encode(sent) for sent in sents]
    max_len = max([len(sent) for sent in sent_ids])
    # ma_len = min(max_len, FLAGS.max_seq_len)
    sent_ids = [sent + [0]*(max_len-len(sent)) for sent in sent_ids]
    sent_ids = [[0]*(max_len-len(sent)) + sent for sent in sent_ids]
    inputs = torch.LongTensor(sent_ids).to(local_rank)
    model_to_run = model.module if type(model) is DDP else model
    outputs = model_to_run.generate(inputs, attention_mask=(inputs != 0).float(), max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.eos_token_id) # greedy
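
The padding change in inference_batch above moves the pad tokens from the right of each prompt to the left, which matters for decoder-only models like GPT-2 because generation continues from the last input position. For reference, a hedged sketch of the same left-padding behaviour using standard Hugging Face tokenizer options (not this repository's manual id-0 padding; prompts are illustrative) could be:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
tokenizer.pad_token = tokenizer.eos_token   # GPT-2 ships without a pad token
tokenizer.padding_side = 'left'             # left-pad prompts for batched generation
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')

batch = tokenizer(["tell me about cheap hotels in the centre", "book a table for two"],
                  return_tensors='pt', padding=True)
outputs = model.generate(input_ids=batch['input_ids'],
                         attention_mask=batch['attention_mask'],
                         max_length=128,
                         pad_token_id=tokenizer.eos_token_id)
# strip the (uniform-length, left-padded) prompt before decoding the generated part
print(tokenizer.batch_decode(outputs[:, batch['input_ids'].shape[1]:], skip_special_tokens=True))
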
@@ -250,7 +254,7 @@ def test(model, nlg_data, ontology, model_path):
    model.eval()
    print(f'model loaded from [{model_path}]')
    # Load test nlg data
    test_data = nlg_data['test']
    test_data = filter_empty_nlg_data(nlg_data['test'])
    dialog_acts = [act2str(item['dialogue_acts']).strip() for item in test_data]
    golden_responses = [item['utterance'].strip() for item in test_data]
    # dialog_acts = dialog_acts[:10]
@@ -327,6 +331,20 @@ def test(model, nlg_data, ontology, model_path):
    # f.write(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
    # f.close()


def filter_empty_nlg_data(data):
    ret = []
    empty_number = 0
    for item in data:
        acts = item['dialogue_acts']
        acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
        if acts_size == 0:
            empty_number += 1
            continue
        else:
            ret.append(item)
    print('empty count: ', empty_number)
    return ret


if __name__ == '__main__':
    if '_' in FLAGS.dataset:
......
@@ -14,8 +14,8 @@ class SCGPT(NLG):
    def __init__(self, dataset_name, model_path, device='cpu'):
        super(SCGPT, self).__init__()
        self.device = device
        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2')).to(self.device)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
        self.model.load_state_dict(torch.load(model_path))

    def generate(self, action):
@@ -34,5 +34,10 @@ class SCGPT(NLG):
        outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs==0).float(),
                                        eos_token_id=self.tokenizer.pad_token_id)  # greedy
        outputs = outputs[:, len(inputs[0]):]
        output_strs = [self.tokenizer.decode(item).strip() for item in outputs]
        def clean_sentence(sent):
            sent = sent.strip()
            if self.tokenizer.eos_token in sent:
                sent = sent[:sent.index(self.tokenizer.eos_token)]
            return sent
        output_strs = [clean_sentence(item) for item in outputs]
        return output_strs
\ No newline at end of file
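
For completeness, a hypothetical usage sketch of the updated SCGPT wrapper; the import path, checkpoint path, and dialogue-act values are assumptions for illustration, with the act dict following the binary/categorical/non-categorical layout used elsewhere in this commit:

import torch
from convlab.nlg.scgpt.scgpt import SCGPT  # assumed module path

nlg = SCGPT(dataset_name='multiwoz21',
            model_path='saved_models/mwoz_sgd_tm_train/epoch_5/epoch_5_step19206.pt',
            device='cuda' if torch.cuda.is_available() else 'cpu')
# Illustrative dialogue act in the unified binary/categorical/non-categorical format
action = {
    'binary': [],
    'categorical': [],
    'non-categorical': [{'intent': 'inform', 'domain': 'hotel', 'slot': 'area', 'value': 'centre'}],
}
print(nlg.generate(action))
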
CUDA_VISIBLE_DEVICES="2" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2042 main.py \
--batch_size 32 \
--accumulation_step 4 \
--epoch_num 100 \
CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2040 main.py \
--batch_size 64 \
--accumulation_step 2 \
--epoch_num 20 \
--lr 5e-5 \
--base_model_name_path gpt2-medium \
--val_step 100 \
--exp_name gpt2_mwoz001_direct \
--val_step 500 \
--exp_name mwoz_sgd_tm_train \
--do_train \
--dataset multiwoz21 \
--train_ratio 0.01 \
--dataset multiwoz21_sgd_tm1_tm2_tm3 \
--train_ratio 1.0 \
# --scgpt_model_ckpt_path saved_models/gpt2_sgd_tm/epoch_2/epoch_2_step13698.pt
# --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \