Commit 855e6ab4 authored by zqwerty

Merge branch 'master' into replace_azure

parents a471f4ed 7a091a39
@@ -24,7 +24,7 @@ class Logging:
        f.write('\n')
        f.close()

-def evaluate(predict_result, ontology):
+def evaluate(predict_result, ontology, filter_empty_acts=True):
    predict_result = json.load(open(predict_result))
    metrics = {}
@@ -33,7 +33,15 @@ def evaluate(predict_result, ontology):
    references = []
    candidates = []
    for i in range(len(predict_result)):
+        if filter_empty_acts:
+            acts = predict_result[i]['dialogue_acts']
+            acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
+            if acts_size == 0:
+                continue
        references.append(predict_result[i]['utterance'])
-        candidates.append(predict_result[i]['predictions']['utterance'])
+        if 'prediction' in predict_result[i]:
+            candidates.append(predict_result[i]['prediction'])
+        else:
+            candidates.append(predict_result[i]['predictions']['utterance'])
    # metrics['bleu'] = corpus_bleu(references, candidates)
    references = [" " if ref=="" else ref for ref in references]
@@ -55,7 +63,7 @@ def evaluate(predict_result, ontology):
    score_list = []
    for item in predict_result:
        da = item['dialogue_acts']
-        utterance = item['predictions']['utterance']
+        utterance = item['predictions']['utterance'] if 'predictions' in item else item['prediction']
        missing_count = 0
        redundant_count = 0
        all_count = 0
...
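For reference, a minimal sketch of one record in the predict_result JSON that the updated evaluate() accepts; the keys follow the diff above, the values are purely illustrative, and either the nested 'predictions' form or the flat 'prediction' key is handled:

    # illustrative record only; keys follow the diff above, values are made up
    example_record = {
        "dialogue_acts": {"binary": [], "categorical": [], "non-categorical": []},  # empty acts are skipped when filter_empty_acts=True
        "utterance": "There are 5 hotels in the north.",
        "predictions": {"utterance": "I found 5 hotels in the north."}  # or a flat "prediction": "..."
    }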
CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py \ CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2052 main.py \
--dataset multiwoz21 \ --batch_size 1 \
--scgpt_model_ckpt_path /data/zhangzheng/scgpt \ --base_model_name_path gpt2-medium \
--model_path /data/zhangzheng/ConvLab-3/convlab/nlg/scgpt/saved_model/epoch_4/epoch_4_step8875.pt --dataset tm3 \
\ No newline at end of file --exp_name tm3_mst_test \
--model_path saved_models/mwoz_sgd_tm_train/epoch_5/epoch_5_step19206.pt \
# --model_path saved_models/gpt2_tm_direct/epoch_19/epoch_19_step65540.pt \
# --model_path saved_models/gpt2_tm_direct/epoch_6/epoch_6_step22939.pt \
\ No newline at end of file
@@ -4,11 +4,13 @@ sys.path.append('../../..')
import argparse
import json
from tqdm import tqdm
+import time
import torch
+from functools import reduce
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
@@ -29,11 +31,21 @@ code_test = False
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
+parser.add_argument("--lr", default=1e-5, type=float, help="learning rate")
+parser.add_argument("--batch_size", default=32, type=int)
+parser.add_argument("--train_ratio", default=1.0, type=float)
+parser.add_argument("--accumulation_step", default=4, type=int)
+parser.add_argument("--epoch_num", default=20, type=int)
+parser.add_argument("--val_step", default=100, type=int)
parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
parser.add_argument('--model_path', default="", type=str, help="The path of model for testing.")
-parser.add_argument('--scgpt_model_ckpt_path', default="", type=str, help="The path of model for testing.")
+parser.add_argument('--base_model_name_path', default="gpt2", type=str, help="The path of base model.")
+parser.add_argument('--scgpt_model_ckpt_path', default=None, type=str, help="The path of model for testing.")
+parser.add_argument('--save_path', default="saved_models", type=str, help="Model save path.")
+parser.add_argument('--exp_name', default="default_name", type=str, help="Current experiment name.")
parser.add_argument("--max_seq_len", default=128, type=int)
+parser.add_argument("--save_epoch_interval", default=1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank
@@ -41,25 +53,20 @@ torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

# TensorBoard
-tb_dir = './runs'
+tb_dir = 'runs/' + FLAGS.exp_name
if not os.path.exists(tb_dir):
    os.mkdir(tb_dir)
-tb_writer = SummaryWriter(tb_dir)
+tb_writer = SummaryWriter(tb_dir, flush_secs=5)

-special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
## load model
-if FLAGS.scgpt_model_ckpt_path == '':
-    tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
-    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
-    model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
-    model.resize_token_embeddings(len(tokenizer))
+if FLAGS.scgpt_model_ckpt_path is None:
+    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
+    model = GPT2LMHeadModel.from_pretrained(FLAGS.base_model_name_path).to(local_rank)
else:
-    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.scgpt_model_ckpt_path)
-    tokenizer.add_special_tokens(
-        {'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
-    model = GPT2LMHeadModel.from_pretrained(FLAGS.scgpt_model_ckpt_path).to(local_rank)
+    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
+    model = GPT2LMHeadModel(config=GPT2Config.from_pretrained(FLAGS.base_model_name_path)).to(local_rank)
+    model.load_state_dict(torch.load(FLAGS.scgpt_model_ckpt_path))
    print('model load from ' + FLAGS.scgpt_model_ckpt_path)
-model.resize_token_embeddings(len(tokenizer))

nll_loss = nn.NLLLoss(reduce=False).to(local_rank)
ce_loss = nn.CrossEntropyLoss(reduce=False).to(local_rank)
@@ -74,91 +81,90 @@ def cal_loss(input, target, seq_lens, seq_lens_input):
    input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
    output_mask = torch.logical_xor(mask, input_mask)
    pad_mask = torch.logical_not(mask)
-    # masked_loss = loss * output_mask
-    masked_loss = loss * (output_mask + pad_mask)
+    masked_loss = loss * output_mask
+    # masked_loss = loss * (output_mask + pad_mask)
    mean_loss = torch.sum(masked_loss) / torch.sum(output_mask + pad_mask)
    return mean_loss
-def pad_collate(batch):
+def pad_collate(ori_batch):
    """
    Returns:
    batch: batch * max_len
    seq_lens: the length of len(da)+1+len(response)
    seq_lens_input: the length of len(da)
    """
-    START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
-    pad_token_id = tokenizer.pad_token_id
-    batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
+    START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc('&')
+    batch = [item[0] + [START_OF_PRED_ID] + item[1] + [tokenizer.eos_token_id] for item in ori_batch]
+    output_lens = [len(item[1])+1 for item in ori_batch]
    batch = [item[-FLAGS.max_seq_len:] for item in batch]
    max_len = max([len(item) for item in batch])
    # print('max_len', max_len)
    seq_lens = [len(item) for item in batch]
-    split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
-    def get_x_len(tokens):
-        """Get the length of dialogue act tokens"""
-        split_idx = len(tokens)
-        try:
-            split_idx = tokens.index(split_id)+1
-        except:
-            pass
-        return split_idx
-    seq_lens_input = [get_x_len(item) for item in batch]
-    batch = [item + [pad_token_id]*(max_len-len(item)) for item in batch]
-    # print(batch)
-    # print(seq_lens)
-    # print(seq_lens_input)
+    seq_lens_input = []
+    for idx in range(len(batch)):
+        curr_ipt_len = seq_lens[idx] - output_lens[idx]
+        if curr_ipt_len < 0:
+            curr_ipt_len = 0
+        seq_lens_input.append(curr_ipt_len)
+    batch = [item + [0]*(max_len-len(item)) for item in batch]
    return torch.LongTensor(batch), torch.LongTensor(seq_lens), torch.LongTensor(seq_lens_input)
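A toy walk-through of the new pad_collate (the token ids below are invented; only the arithmetic follows the code above):

    # ori_batch holds (da_tokens, response_tokens) pairs; ids are made up
    # ori_batch = [([11, 12], [21, 22, 23]), ([31], [41])]
    # with '&' -> id 5 and tokenizer.eos_token_id -> 50256, the collated batch is
    #   [[11, 12, 5, 21, 22, 23, 50256],
    #    [31, 5, 41, 50256, 0, 0, 0]]        # right-padded with 0
    # seq_lens = [7, 4], seq_lens_input = [3, 2]  (dialogue-act prefix including '&')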
## Training Hyper-params
-EPOCH_NUM = 20
-BATCH_SIZE = 32 # real_batch_size = BATCH_SIZE * num_gpu
-VAL_STEP = 500
-WARM_STEPS = 250
-if code_test:
-    EPOCH_NUM = 2
-    BATCH_SIZE = 4
-    VAL_STEP = 2
-    WARM_STEPS = 3
-LR = 5e-5
-SAVE_PATH = f'./saved_model'
def train(model, nlg_data, global_step=0):
-    train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
+    train_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['train']), tokenizer)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
+    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
-    val_dataset = SCGPTDataset(nlg_data['validation'], tokenizer)
+    val_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['validation']), tokenizer)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
-    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
+    val_dataloader = DataLoader(val_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)

    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARM_STEPS,
-                                                num_training_steps=len(train_dataloader) * EPOCH_NUM)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=FLAGS.lr)
+    t_total = len(train_dataloader) * FLAGS.epoch_num // FLAGS.accumulation_step
+    warm_steps = int(0.1 * t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_steps,
+                                                num_training_steps=t_total)
    model.train()
-    for epoch in range(EPOCH_NUM):
+    for epoch in range(FLAGS.epoch_num):
        train_dataloader.sampler.set_epoch(epoch)
-        for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{EPOCH_NUM}]')):
+        for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{FLAGS.epoch_num}]')):
+            if (batch_id+1) % FLAGS.accumulation_step == 0:
+                global_step += 1
            inputs = inputs.to(local_rank)
            seq_lens = seq_lens.to(local_rank)
            seq_lens_input = seq_lens_input.to(local_rank)
-            outputs = model(inputs)
+            outputs = model(inputs, attention_mask=(inputs!=0).float())
            preds = outputs[0]
            loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
-            optimizer.zero_grad()
+            loss /= FLAGS.accumulation_step
+            loss /= dist.get_world_size()
            loss.backward()
-            optimizer.step()
-            scheduler.step()
+            # update params
+            if (batch_id+1) % FLAGS.accumulation_step == 0:
+                optimizer.step()
+                scheduler.step()
+                model.zero_grad()
+                # tensorboard
+                if dist.get_rank() == 0:
                    tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
                    tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
                    tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
+                if global_step % FLAGS.val_step == 0:
+                    model.eval()
+                    val_loss = eval(model, val_dataloader)
+                    ppl = np.exp(val_loss)
+                    if dist.get_rank() == 0:
+                        tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
+                        tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
+                    model.train()
-            global_step += 1

        # save the model when each epoch ends
        if dist.get_rank() == 0:
+            if (epoch+1) % FLAGS.save_epoch_interval == 0:
                # vaidation
                model.eval()
                val_loss = eval(model, val_dataloader)
@@ -166,16 +172,14 @@ def train(model, nlg_data, global_step=0):
                tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
                tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
                model.train()
                # save model
-                save_dir = os.path.join(SAVE_PATH, f'epoch_{epoch}')
+                save_dir = os.path.join(FLAGS.save_path, FLAGS.exp_name, f'epoch_{epoch}')
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
                tokenizer.save_pretrained(save_dir)
                torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
                torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
                print(f'Save model checkpoint to [{save_dir}]')
    tb_writer.flush()
@@ -198,17 +202,19 @@ def eval(model, loader, use_tqdm=False):
def inference_batch(model, sents):
    """Inference model given a batch of sents."""
    with torch.no_grad():
-        sents = [sent + ' ' + START_OF_PRED for sent in sents]
+        sents = [sent + ' &' for sent in sents]
        sent_ids = [tokenizer.encode(sent) for sent in sents]
        max_len = max([len(sent) for sent in sent_ids])
-        sent_ids = [sent + [tokenizer.pad_token_id]*(max_len-len(sent)) for sent in sent_ids]
+        # ma_len = min(max_len, FLAGS.max_seq_len)
+        sent_ids = [[0]*(max_len-len(sent)) + sent for sent in sent_ids]
        inputs = torch.LongTensor(sent_ids).to(local_rank)
        model_to_run = model.module if type(model) is DDP else model
-        outputs = model_to_run.generate(inputs, max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.pad_token_id,
-                                        pad_token_id=tokenizer.pad_token_id)  # greedy
+        outputs = model_to_run.generate(inputs, attention_mask=(inputs != 0).float(), max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.eos_token_id)  # greedy
        # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
        #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-        output_strs = [tokenizer.decode(item) for item in outputs]
+        # output_strs = [tokenizer.decode(item) for item in outputs]
+        outputs = outputs[:, len(inputs[0]):]
+        output_strs = tokenizer.batch_decode(outputs)
        return output_strs
@@ -226,26 +232,42 @@ def inference_sents(model, sents):
    return outputs

+def inference_sents_by_batch(model, sents):
+    """Get the outputs of multiple sentences."""
+    start_idx = 0
+    ret = []
+    start = time.time()
+    while start_idx < len(sents):
+        end_idx = start_idx + FLAGS.batch_size
+        curr_sents = sents[start_idx:end_idx]
+        outputs = inference_batch(model, curr_sents)
+        ret += outputs
+        start_idx += FLAGS.batch_size
+        time_remain = (time.time()-start) / start_idx * (len(sents) - start_idx)
+        print('{}/{}, time remaining: {:.2f}'.format(start_idx, len(sents), time_remain))
+    return ret
def test(model, nlg_data, ontology, model_path):
    """Run with the number of GPUs in the shell script set to 1."""
    model.load_state_dict(torch.load(model_path))
    model.eval()
    print(f'model loaded from [{model_path}]')
    # Load test nlg data
-    test_data = nlg_data['test']
+    test_data = filter_empty_nlg_data(nlg_data['test'])
    dialog_acts = [act2str(item['dialogue_acts']).strip() for item in test_data]
    golden_responses = [item['utterance'].strip() for item in test_data]
    # dialog_acts = dialog_acts[:10]
    # golden_responses = golden_responses[:10]
-    outputs = inference_sents(model, dialog_acts)
+    outputs = inference_sents_by_batch(model, dialog_acts)
    def get_real_output(ipt):
-        if '[start_of_pred]' in ipt:
-            ipt = ipt[ipt.index('[start_of_pred]')+15:].strip()
-        if '[_pad_token_]' in ipt:
-            ipt = ipt[:ipt.index('[_pad_token_]')].strip()
+        if tokenizer.eos_token in ipt:
+            ipt = ipt[:ipt.index(tokenizer.eos_token)].strip()
        return ipt
    outputs = [get_real_output(item) for item in outputs]
-    output_file = './test_output.json'
+    if not os.path.exists('./test_outputs'):
+        os.makedirs('./test_outputs', exist_ok=True)
+    output_file = f'./test_outputs/{FLAGS.exp_name}.json'
    if dist.get_rank() == 0:
        with open(output_file, 'w+') as f:
            result = []
@@ -253,7 +275,9 @@ def test(model, nlg_data, ontology, model_path):
                result.append({
                    'dialogue_acts': test_data[i]['dialogue_acts'],
                    'utterance': test_data[i]['utterance'],
-                    'prediction': outputs[i]
+                    'predictions': {
+                        'utterance': outputs[i]
+                    }
                })
            json.dump(result, f, indent=2, ensure_ascii=False)
    evaluator = GentScorer()
@@ -307,9 +331,42 @@ def test(model, nlg_data, ontology, model_path):
    # f.write(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
    # f.close()
+def filter_empty_nlg_data(data):
+    ret = []
+    empty_number = 0
+    for item in data:
+        acts = item['dialogue_acts']
+        acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
+        if acts_size == 0:
+            empty_number += 1
+            continue
+        else:
+            ret.append(item)
+    print('empty count: ', empty_number)
+    return ret
if __name__ == '__main__':
-    dataset = load_dataset(FLAGS.dataset)
+    if '_' in FLAGS.dataset:
+        spans = FLAGS.dataset.split('_')
+        data_list = spans
+        datasets = [load_dataset(item) for item in data_list]
+        nlg_datas = [load_nlg_data(item) for item in datasets]
+        ret = {}
+        def aggregrate(nlg_datas, split):
+            ret = []
+            for item in nlg_datas:
+                ret += item[split]
+            return ret
+        ret['train'] = aggregrate(nlg_datas, 'train')
+        ret['validation'] = aggregrate(nlg_datas, 'validation')
+        ret['test'] = aggregrate(nlg_datas, 'test')
+        if FLAGS.do_train:
+            train(model, ret)
+        else:
+            print('not supported')
+    else:
+        dataset = load_dataset(FLAGS.dataset, dial_ids_order=0, split2ratio={'train': FLAGS.train_ratio})
    ontology = load_ontology(FLAGS.dataset)
    nlg_data = load_nlg_data(dataset)
    if FLAGS.do_train:
...
@@ -26,3 +26,27 @@ class SCGPTDataset(Dataset):
    def __getitem__(self, idx):
        return self.data[idx]
+class SGD_TMDataset(Dataset):
+    def __init__(self, data, tokenizer):
+        """
+        Args:
+            data: [[da_str, response], [da_str, response], ...]
+            tokenizer: GPT2 Tokenizer
+        """
+        self.data = []
+        length_list = []
+        for item in data:
+            da, response = item['dialogue_acts'], item['utterance']
+            da_tokens = tokenizer.encode(act2str(da))
+            response_tokens = tokenizer.encode(response)
+            length_list.append(len(da_tokens) + len(response_tokens) + 1)
+            self.data.append([da_tokens, response_tokens])
+        print(f'max: {np.max(length_list)}, min: {np.min(length_list)}, median: {np.quantile(length_list, 0.5)}, 0.99: {np.quantile(length_list, 0.99)}')
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 4 21:43:42 2020
@author: truthless
"""
from convlab.nlg.scgpt.multiwoz.scgpt import SCGPT
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 14 11:38:53 2020
@author: truthless
"""
import os
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from convlab.nlg.scgpt.utils import dict2dict, dict2seq
import zipfile
def read_zipped_json(filepath, filename):
print("zip file path = ", filepath)
archive = zipfile.ZipFile(filepath, 'r')
return json.load(archive.open(filename))
def init_domain():
return {'Attraction':False,
'Hospital':False,
'Hotel':False,
'Police':False,
'Restaurant':False,
'Taxi':False,
'Train':False}
def write_file(name, data, role='usr'):
with open(f'{name}.txt', 'w', encoding='utf-8') as f:
for ID in data:
sess = data[ID]
sess_domains = init_domain()
for turn in sess:
if role == 'usr':
if not turn['usr_da']:
continue
turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
elif role == 'sys':
if not turn['sys_da']:
continue
turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
else:
raise NameError('Invalid Role: Select usr/sys.')
for domain in domains:
if domain not in ['general', 'Booking'] and not sess_domains[domain]:
da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
sess_domains[domain] = True
if role == 'usr':
da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
elif role == 'sys':
da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
f.write(f'{da_seq} & {da_uttr}\n')
if __name__ == '__main__':
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--role', type=str, default='usr')
args = parser.parse_args()
cur_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
cur_dir)))), 'data/multiwoz/')
keys = ['train', 'val', 'test']
data = {}
for key in keys:
data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
print('load {}, size {}'.format(key, len(data_key)))
data = dict(data, **data_key)
with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
val_list = f.read().splitlines()
with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
test_list = f.read().splitlines()
results = {}
results_val = {}
results_test = {}
for title, sess in data.items():
logs = sess['log']
turns = []
turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
current_domain = None
for i, diag in enumerate(logs):
text = diag['text']
da = diag['dialog_act']
span = diag['span_info']
if current_domain:
da = eval(str(da).replace('Booking', current_domain))
span = eval(str(span).replace('Booking', current_domain))
if i % 2 == 0:
turn['usr'] = text
turn['usr_da'] = da
turn['usr_span'] = span
turns.append(turn)
else:
turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
turn['sys'] = text
turn['sys_da'] = da
turn['sys_span'] = span
for key in da:
domain = key.split('-')[0]
if domain not in ['general', 'Booking']:
current_domain = domain
else:
if args.role == 'sys':
turns.append(turn)
title = title
if title in val_list:
current = results_val
elif title in test_list:
current = results_test
else:
current = results
current[title] = turns
results = eval(str(results).replace(" n't", " not"))
results_val = eval(str(results_val).replace(" n't", " not"))
results_test = eval(str(results_test).replace(" n't", " not"))
if not os.path.exists(os.path.join(cur_dir,'data')):
os.makedirs(os.path.join(cur_dir, 'data'))
write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
from tqdm import trange
import torch
import torch.nn.functional as F
import numpy as np
import sys
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
MODEL_CLASSES = {
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
'xlm': (XLMWithLMHeadModel, XLMTokenizer),
}
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--length", type=int, default=40)
parser.add_argument("--num_samples", type=int, default=1)
parser.add_argument("--temperature", type=float, default=1.0,
help="temperature of 0 implies greedy sampling")
parser.add_argument("--repetition_penalty", type=float, default=1.0,
help="primarily useful for CTRL model; in that case, use 1.2")
parser.add_argument("--top_k", type=int, default=50)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--stop_token', type=str, default=None,
help="Token at which text generation is stopped")
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument('--input_file', type=str, default=None,
help="file")
parser.add_argument('--output_file', type=str, default=None,
help="file")
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
set_seed(args)
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, pad_token='<PAD>', padding_side='left')
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
model.eval()
if args.length < 0 and model.config.max_position_embeddings > 0:
args.length = model.config.max_position_embeddings
elif 0 < model.config.max_position_embeddings < args.length:
args.length = model.config.max_position_embeddings # No generation bigger than model size
elif args.length < 0:
args.length = MAX_LENGTH # avoid infinite loop
logger.info(args)
if args.model_type in ["ctrl"]:
if args.temperature > 0.7:
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
fin = open(args.input_file)
inputs = [i.strip() for i in fin]
output_tests = []
for idx in range(0, len(inputs), args.batch_size):
logger.info(f"PROGRESS: {int(idx/len(inputs)*100)}%")
# raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
raw_inputs = []
for i in range(idx, min(idx+args.batch_size, len(inputs))):
lines = inputs[i]
raw_text = lines.split(' & ')[0] + ' & '
if args.model_type in ["transfo-xl", "xlnet"]:
# Models with memory likes to have a long prompt for short inputs.
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
raw_inputs.append(raw_text)
encoding_inputs = tokenizer.batch_encode_plus(raw_inputs, pad_to_max_length=True, add_special_tokens=False)
context_tokens = torch.LongTensor(encoding_inputs['input_ids']).to(args.device)
max_length = len(context_tokens[0])
attention_mask = torch.LongTensor(encoding_inputs['attention_mask']).to(args.device)
position_ids = (attention_mask.cumsum(-1) - 1)
position_ids.masked_fill_(attention_mask==0, 0)
if args.model_type == "ctrl":
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
out_ids = model.generate(
input_ids=context_tokens,
attention_mask=attention_mask,
position_ids=position_ids,
num_beams=args.num_samples,
num_return_sequences=args.num_samples,
max_length=args.length,
temperature=args.temperature,
do_sample=True,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty
)
out_ids = out_ids.reshape(len(raw_inputs), args.num_samples, -1)[:, :, max_length:].tolist()
for j, out in enumerate(out_ids):
examples = [inputs[j]]
for o in out:
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
text = text[: text.find(args.stop_token) if args.stop_token else None]
examples.append(text)
output_tests.append(examples)
# break
# if args.prompt:
# break
import json
json.dump(output_tests, open(args.output_file,'w'), indent=2)
return text
if __name__ == '__main__':
main()
@@ -2,27 +2,22 @@ import sys
sys.path.append('../../..')
import torch
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.nn.parallel import DistributedDataParallel as DDP
from convlab.nlg.nlg import NLG
from util import act2str
from scgpt_special_tokens import *

-special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]

class SCGPT(NLG):
    def __init__(self, dataset_name, model_path, device='cpu'):
        super(SCGPT, self).__init__()
        self.device = device
-        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
-        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
-                                           'additional_special_tokens': special_tokens})
-        self.model.resize_token_embeddings(len(self.tokenizer))
+        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
        self.model.load_state_dict(torch.load(model_path))

    def generate(self, action):
        action_str = act2str(action)
        output = self._inference_batch([action_str])[0]
@@ -30,16 +25,19 @@ class SCGPT(NLG):
    def _inference_batch(self, sents):
        with torch.no_grad():
-            sents = [sent + ' ' + START_OF_PRED for sent in sents]
-            sent_ids = [self.tokenizer.encode(sent) for sent in sents]
+            sents = [sent for sent in sents]
+            sent_ids = [self.tokenizer.encode(sent) + [self.tokenizer._convert_token_to_id_with_added_voc('&')] for sent in sents]
            max_len = max([len(sent) for sent in sent_ids])
-            sent_ids = [sent + [self.tokenizer.pad_token_id] * (max_len - len(sent)) for sent in sent_ids]
+            sent_ids = [sent + [0] * (max_len - len(sent)) for sent in sent_ids]
            inputs = torch.LongTensor(sent_ids).to(self.device)
            model_to_run = self.model.module if type(self.model) is DDP else self.model
-            outputs = model_to_run.generate(inputs, max_length=256,
-                                            eos_token_id=self.tokenizer.pad_token_id,
-                                            pad_token_id=self.tokenizer.pad_token_id)  # greedy
-            # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
-            #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-            output_strs = [self.tokenizer.decode(item) for item in outputs]
+            outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs!=0).float(),
+                                            eos_token_id=self.tokenizer.pad_token_id)  # greedy
+            outputs = outputs[:, len(inputs[0]):]
+            def clean_sentence(sent):
+                sent = sent.strip()
+                if self.tokenizer.eos_token in sent:
+                    sent = sent[:sent.index(self.tokenizer.eos_token)]
+                return sent
+            output_strs = [clean_sentence(item) for item in outputs]
            return output_strs
\ No newline at end of file
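As a rough usage sketch of the updated SCGPT wrapper (the import path, checkpoint path, and act contents below are assumptions for illustration, not part of this commit):

    from convlab.nlg.scgpt.scgpt import SCGPT  # module path assumed

    # checkpoint path is illustrative; any state_dict saved by main.py should fit
    nlg = SCGPT(dataset_name='multiwoz21',
                model_path='saved_models/mwoz_sgd_tm_train/epoch_5/epoch_5_step19206.pt',
                device='cuda')
    # fill with dialogue acts in the unified format consumed by act2str()
    action = {'binary': [], 'categorical': [], 'non-categorical': []}
    print(nlg.generate(action))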
@@ -3,7 +3,7 @@ SYS_SPEAK = '[sys_speak]'
USR_SPEAK = '[usr_speak]'
START_OF_PRED = '[start_of_pred]'
END_OF_PRED = '[end_of_pred]'
-PAD_TOKEN = '[_pad_token_]'
+PAD_TOKEN = '<|pad_token|>'
START_OF_INTENT = '[start_of_intent]'
END_OF_INTENT = '[end_of_intent]'
START_OF_SLOT = ''
...
CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21 --scgpt_model_ckpt_path /data/zhangzheng/scgpt CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2040 main.py \
\ No newline at end of file --batch_size 64 \
--accumulation_step 2 \
--epoch_num 20 \
--lr 5e-5 \
--base_model_name_path gpt2-medium \
--val_step 500 \
--exp_name mwoz_sgd_tm_train \
--do_train \
--dataset multiwoz21_sgd_tm1_tm2_tm3 \
--train_ratio 1.0 \
# --scgpt_model_ckpt_path saved_models/gpt2_sgd_tm/epoch_2/epoch_2_step13698.pt
# --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \