diff --git a/convlab/nlg/evaluate_unified_datasets.py b/convlab/nlg/evaluate_unified_datasets.py
index 7a19a49267f1c526bbca29201893a41f948afc67..23c937ed732ae547b0ca48dd81e6cce6f71c3c62 100644
--- a/convlab/nlg/evaluate_unified_datasets.py
+++ b/convlab/nlg/evaluate_unified_datasets.py
@@ -24,7 +24,7 @@ class Logging:
         f.write('\n')
         f.close()
 
-def evaluate(predict_result, ontology):
+def evaluate(predict_result, ontology, filter_empty_acts=True):
     predict_result = json.load(open(predict_result))
 
     metrics = {}
@@ -33,8 +33,16 @@ def evaluate(predict_result, ontology):
     references = []
     candidates = []
     for i in range(len(predict_result)):
+        if filter_empty_acts:
+            acts = predict_result[i]['dialogue_acts']
+            acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
+            if acts_size == 0:
+                continue
         references.append(predict_result[i]['utterance'])
-        candidates.append(predict_result[i]['predictions']['utterance'])
+        if 'prediction' in predict_result[i]:
+            candidates.append(predict_result[i]['prediction'])
+        else:
+            candidates.append(predict_result[i]['predictions']['utterance'])
     # metrics['bleu'] = corpus_bleu(references, candidates)
     references = [" " if ref=="" else ref for ref in references]
     metrics['bleu'] = sacrebleu.corpus_bleu(candidates, [references], lowercase=True).score
@@ -55,7 +63,7 @@
     score_list = []
     for item in predict_result:
         da = item['dialogue_acts']
-        utterance = item['predictions']['utterance']
+        utterance = item['predictions']['utterance'] if 'predictions' in item else item['prediction']
         missing_count = 0
         redundant_count = 0
         all_count = 0
diff --git a/convlab/nlg/scgpt/evaluate.sh b/convlab/nlg/scgpt/evaluate.sh
old mode 100644
new mode 100755
index 9bf8f52ac6c351d71b76dd442ce0257001bf07b3..c2ef94d35c0fdc48ef5de405bfce0d5c8a3828a2
--- a/convlab/nlg/scgpt/evaluate.sh
+++ b/convlab/nlg/scgpt/evaluate.sh
@@ -1,4 +1,8 @@
-CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py \
---dataset multiwoz21 \
---scgpt_model_ckpt_path /data/zhangzheng/scgpt \
---model_path /data/zhangzheng/ConvLab-3/convlab/nlg/scgpt/saved_model/epoch_4/epoch_4_step8875.pt
\ No newline at end of file
+CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2052 main.py \
+--batch_size 1 \
+--base_model_name_path gpt2-medium \
+--dataset tm3 \
+--exp_name tm3_mst_test \
+--model_path saved_models/mwoz_sgd_tm_train/epoch_5/epoch_5_step19206.pt \
+# --model_path saved_models/gpt2_tm_direct/epoch_19/epoch_19_step65540.pt \
+# --model_path saved_models/gpt2_tm_direct/epoch_6/epoch_6_step22939.pt \
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/main.py b/convlab/nlg/scgpt/main.py
index f3d48fbe7876c0a3bc31351dcb316327f3700da5..6c2fd505ed7da4ddac9971267bd6a6ba50bc402f 100644
--- a/convlab/nlg/scgpt/main.py
+++ b/convlab/nlg/scgpt/main.py
@@ -4,11 +4,13 @@ sys.path.append('../../..')
 import argparse
 import json
 from tqdm import tqdm
+import time
 import torch
+from functools import reduce
 import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 import os
@@ -29,11 +31,21 @@ code_test = False
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--local_rank", default=-1, type=int)
+parser.add_argument("--lr", default=1e-5, type=float, help="learning rate")
+parser.add_argument("--batch_size", default=32, type=int)
+parser.add_argument("--train_ratio", default=1.0, type=float)
+parser.add_argument("--accumulation_step", default=4, type=int)
+parser.add_argument("--epoch_num", default=20, type=int)
+parser.add_argument("--val_step", default=100, type=int)
 parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
 parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
 parser.add_argument('--model_path', default="", type=str, help="The path of model for testing.")
-parser.add_argument('--scgpt_model_ckpt_path', default="", type=str, help="The path of model for testing.")
+parser.add_argument('--base_model_name_path', default="gpt2", type=str, help="The path of base model.")
+parser.add_argument('--scgpt_model_ckpt_path', default=None, type=str, help="The path of model for testing.")
+parser.add_argument('--save_path', default="saved_models", type=str, help="Model save path.")
+parser.add_argument('--exp_name', default="default_name", type=str, help="Current experiment name.")
 parser.add_argument("--max_seq_len", default=128, type=int)
+parser.add_argument("--save_epoch_interval", default=1, type=int)
 FLAGS = parser.parse_args()
 local_rank = FLAGS.local_rank
@@ -41,25 +53,20 @@ torch.cuda.set_device(local_rank)
 dist.init_process_group(backend='nccl')
 
 # TensorBoard
-tb_dir = './runs'
+tb_dir = 'runs/' + FLAGS.exp_name
 if not os.path.exists(tb_dir):
     os.mkdir(tb_dir)
-tb_writer = SummaryWriter(tb_dir)
+tb_writer = SummaryWriter(tb_dir, flush_secs=5)
 
-special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
 ## load model
-if FLAGS.scgpt_model_ckpt_path == '':
-    tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
-    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
-    model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
-    model.resize_token_embeddings(len(tokenizer))
+if FLAGS.scgpt_model_ckpt_path is None:
+    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
+    model = GPT2LMHeadModel.from_pretrained(FLAGS.base_model_name_path).to(local_rank)
 else:
-    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.scgpt_model_ckpt_path)
-    tokenizer.add_special_tokens(
-        {'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
-    model = GPT2LMHeadModel.from_pretrained(FLAGS.scgpt_model_ckpt_path).to(local_rank)
+    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
+    model = GPT2LMHeadModel(config=GPT2Config.from_pretrained(FLAGS.base_model_name_path)).to(local_rank)
+    model.load_state_dict(torch.load(FLAGS.scgpt_model_ckpt_path))
     print('model load from ' + FLAGS.scgpt_model_ckpt_path)
-    model.resize_token_embeddings(len(tokenizer))
 
 nll_loss = nn.NLLLoss(reduce=False).to(local_rank)
 ce_loss = nn.CrossEntropyLoss(reduce=False).to(local_rank)
@@ -74,108 +81,105 @@ def cal_loss(input, target, seq_lens, seq_lens_input):
     input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
     output_mask = torch.logical_xor(mask, input_mask)
     pad_mask = torch.logical_not(mask)
-    # masked_loss = loss * output_mask
-    masked_loss = loss * (output_mask + pad_mask)
+    masked_loss = loss * output_mask
+    # masked_loss = loss * (output_mask + pad_mask)
     mean_loss = torch.sum(masked_loss) / torch.sum(output_mask + pad_mask)
     return mean_loss
 
 
-def pad_collate(batch):
+def pad_collate(ori_batch):
     """
     Returns:
     batch: batch * max_len
     seq_lens: the length of len(da)+1+len(response)
     seq_lens_input: the length of len(da)
     """
-    START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
-    pad_token_id = tokenizer.pad_token_id
-    batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
+    START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc('&')
+    batch = [item[0] + [START_OF_PRED_ID] + item[1] + [tokenizer.eos_token_id] for item in ori_batch]
+    output_lens = [len(item[1])+1 for item in ori_batch]
     batch = [item[-FLAGS.max_seq_len:] for item in batch]
     max_len = max([len(item) for item in batch])
     # print('max_len', max_len)
     seq_lens = [len(item) for item in batch]
-    split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
-    def get_x_len(tokens):
-        """Get the length of dialogue act tokens"""
-        split_idx = len(tokens)
-        try:
-            split_idx = tokens.index(split_id)+1
-        except:
-            pass
-        return split_idx
-    seq_lens_input = [get_x_len(item) for item in batch]
-    batch = [item + [pad_token_id]*(max_len-len(item)) for item in batch]
-    # print(batch)
-    # print(seq_lens)
-    # print(seq_lens_input)
+    seq_lens_input = []
+    for idx in range(len(batch)):
+        curr_ipt_len = seq_lens[idx] - output_lens[idx]
+        if curr_ipt_len < 0:
+            curr_ipt_len = 0
+        seq_lens_input.append(curr_ipt_len)
+    batch = [item + [0]*(max_len-len(item)) for item in batch]
     return torch.LongTensor(batch), torch.LongTensor(seq_lens), torch.LongTensor(seq_lens_input)
 
 
 ## Training Hyper-params
-EPOCH_NUM = 20
-BATCH_SIZE = 32 # real_batch_size = BATCH_SIZE * num_gpu
-VAL_STEP = 500
-WARM_STEPS = 250
-if code_test:
-    EPOCH_NUM = 2
-    BATCH_SIZE = 4
-    VAL_STEP = 2
-    WARM_STEPS = 3
-LR = 5e-5
-SAVE_PATH = f'./saved_model'
 def train(model, nlg_data, global_step=0):
-    train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
+    train_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['train']), tokenizer)
     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
+    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
 
-    val_dataset = SCGPTDataset(nlg_data['validation'], tokenizer)
+    val_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['validation']), tokenizer)
     val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
-    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
+    val_dataloader = DataLoader(val_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
 
     model = DDP(model, device_ids=[local_rank], output_device=local_rank)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARM_STEPS,
-                                                num_training_steps=len(train_dataloader) * EPOCH_NUM)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=FLAGS.lr)
+    t_total = len(train_dataloader) * FLAGS.epoch_num // FLAGS.accumulation_step
+    warm_steps = int(0.1 * t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_steps,
+                                                num_training_steps=t_total)
     model.train()
-    for epoch in range(EPOCH_NUM):
+    for epoch in range(FLAGS.epoch_num):
         train_dataloader.sampler.set_epoch(epoch)
-        for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{EPOCH_NUM}]')):
+        for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{FLAGS.epoch_num}]')):
+            if (batch_id+1) % FLAGS.accumulation_step == 0:
+                global_step += 1
             inputs = inputs.to(local_rank)
             seq_lens = seq_lens.to(local_rank)
             seq_lens_input = seq_lens_input.to(local_rank)
-
-            outputs = model(inputs)
+            outputs = model(inputs, attention_mask=(inputs!=0).float())
             preds = outputs[0]
             loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
-
-            optimizer.zero_grad()
+            loss /= FLAGS.accumulation_step
+            loss /= dist.get_world_size()
             loss.backward()
-            optimizer.step()
-            scheduler.step()
-            tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
-            tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
-            tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
-
-            global_step += 1
+            # update params
+
+
+            if (batch_id+1) % FLAGS.accumulation_step == 0:
+                optimizer.step()
+                scheduler.step()
+                model.zero_grad()
+                # tensorboard
+                if dist.get_rank() == 0:
+                    tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
+                    tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
+                    tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
+                if global_step % FLAGS.val_step == 0:
+                    model.eval()
+                    val_loss = eval(model, val_dataloader)
+                    ppl = np.exp(val_loss)
+                    if dist.get_rank() == 0:
+                        tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
+                        tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
+                    model.train()
+        # save the model when each epoch ends
         if dist.get_rank() == 0:
-
-            # vaidation
-            model.eval()
-            val_loss = eval(model, val_dataloader)
-            ppl = np.exp(val_loss)
-            tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
-            tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
-            model.train()
-
-            # save model
-            save_dir = os.path.join(SAVE_PATH, f'epoch_{epoch}')
-            os.makedirs(save_dir, exist_ok=True)
-            torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
-            tokenizer.save_pretrained(save_dir)
-            torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
-            torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
-            print(f'Save model checkpoint to [{save_dir}]')
-
+            if (epoch+1) % FLAGS.save_epoch_interval == 0:
+                # validation
+                model.eval()
+                val_loss = eval(model, val_dataloader)
+                ppl = np.exp(val_loss)
+                tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
+                tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
+                model.train()
+                # save model
+                save_dir = os.path.join(FLAGS.save_path, FLAGS.exp_name, f'epoch_{epoch}')
+                os.makedirs(save_dir, exist_ok=True)
+                torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
+                tokenizer.save_pretrained(save_dir)
+                torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
+                torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
+                print(f'Save model checkpoint to [{save_dir}]')
     tb_writer.flush()
@@ -198,17 +202,19 @@ def eval(model, loader, use_tqdm=False):
 
 
 def inference_batch(model, sents):
     """Inference model given a batch of sents."""
     with torch.no_grad():
-        sents = [sent + ' ' + START_OF_PRED for sent in sents]
+        sents = [sent + ' &' for sent in sents]
         sent_ids = [tokenizer.encode(sent) for sent in sents]
         max_len = max([len(sent) for sent in sent_ids])
-        sent_ids = [sent + [tokenizer.pad_token_id]*(max_len-len(sent)) for sent in sent_ids]
+        # ma_len = min(max_len, FLAGS.max_seq_len)
+        sent_ids = [[0]*(max_len-len(sent)) + sent for sent in sent_ids]
         inputs = torch.LongTensor(sent_ids).to(local_rank)
         model_to_run = model.module if type(model) is DDP else model
-        outputs = model_to_run.generate(inputs, max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.pad_token_id,
-                                        pad_token_id=tokenizer.pad_token_id)  # greedy
+        outputs = model_to_run.generate(inputs, attention_mask=(inputs != 0).float(), max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.eos_token_id)  # greedy
         # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
         #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-        output_strs = [tokenizer.decode(item) for item in outputs]
+        # output_strs = [tokenizer.decode(item) for item in outputs]
+        outputs = outputs[:, len(inputs[0]):]
+        output_strs = tokenizer.batch_decode(outputs)
         return output_strs
@@ -226,26 +232,42 @@ def inference_sents(model, sents):
     return outputs
 
 
+def inference_sents_by_batch(model, sents):
+    """Get the outputs of multiple sentences."""
+    start_idx = 0
+    ret = []
+    start = time.time()
+    while start_idx < len(sents):
+        end_idx = start_idx + FLAGS.batch_size
+        curr_sents = sents[start_idx:end_idx]
+        outputs = inference_batch(model, curr_sents)
+        ret += outputs
+        start_idx += FLAGS.batch_size
+        time_remain = (time.time()-start) / start_idx * (len(sents) - start_idx)
+        print('{}/{}, time remaining: {:.2f}'.format(start_idx, len(sents), time_remain))
+    return ret
+
+
 def test(model, nlg_data, ontology, model_path):
     """Run with the GPU count in the shell script set to 1."""
     model.load_state_dict(torch.load(model_path))
     model.eval()
     print(f'model loaded from [{model_path}]')
     # Load test nlg data
-    test_data = nlg_data['test']
+    test_data = filter_empty_nlg_data(nlg_data['test'])
     dialog_acts = [act2str(item['dialogue_acts']).strip() for item in test_data]
     golden_responses = [item['utterance'].strip() for item in test_data]
     # dialog_acts = dialog_acts[:10]
     # golden_responses = golden_responses[:10]
-    outputs = inference_sents(model, dialog_acts)
+    outputs = inference_sents_by_batch(model, dialog_acts)
     def get_real_output(ipt):
-        if '[start_of_pred]' in ipt:
-            ipt = ipt[ipt.index('[start_of_pred]')+15:].strip()
-        if '[_pad_token_]' in ipt:
-            ipt = ipt[:ipt.index('[_pad_token_]')].strip()
+        if tokenizer.eos_token in ipt:
+            ipt = ipt[:ipt.index(tokenizer.eos_token)].strip()
        return ipt
     outputs = [get_real_output(item) for item in outputs]
-    output_file = './test_output.json'
+    if not os.path.exists('./test_outputs'):
+        os.makedirs('./test_outputs', exist_ok=True)
+    output_file = f'./test_outputs/{FLAGS.exp_name}.json'
     if dist.get_rank() == 0:
         with open(output_file, 'w+') as f:
             result = []
@@ -253,7 +275,9 @@ def test(model, nlg_data, ontology, model_path):
                 result.append({
                     'dialogue_acts': test_data[i]['dialogue_acts'],
                     'utterance': test_data[i]['utterance'],
-                    'prediction': outputs[i]
+                    'predictions': {
+                        'utterance': outputs[i]
+                    }
                 })
             json.dump(result, f, indent=2, ensure_ascii=False)
     evaluator = GentScorer()
@@ -307,12 +331,45 @@
     # f.write(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
     # f.close()
 
 
+def filter_empty_nlg_data(data):
+    ret = []
+    empty_number = 0
+    for item in data:
+        acts = item['dialogue_acts']
+        acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
+        if acts_size == 0:
+            empty_number += 1
+            continue
+        else:
+            ret.append(item)
+    print('empty count: ', empty_number)
+    return ret
+
 if __name__ == '__main__':
-    dataset = load_dataset(FLAGS.dataset)
-    ontology = load_ontology(FLAGS.dataset)
-    nlg_data = load_nlg_data(dataset)
-    if FLAGS.do_train:
-        train(model, nlg_data)
+    if '_' in FLAGS.dataset:
+        spans = FLAGS.dataset.split('_')
+        data_list = spans
+        datasets = [load_dataset(item) for item in data_list]
+        nlg_datas = [load_nlg_data(item) for item in datasets]
+        ret = {}
+        def aggregrate(nlg_datas, split):
+            ret = []
+            for item in nlg_datas:
+                ret += item[split]
+            return ret
+        ret['train'] = aggregrate(nlg_datas, 'train')
+        ret['validation'] = aggregrate(nlg_datas, 'validation')
+        ret['test'] = aggregrate(nlg_datas, 'test')
+        if FLAGS.do_train:
+            train(model, ret)
+        else:
+            print('not supported')
     else:
-        test(model, nlg_data, ontology, FLAGS.model_path)
+        dataset = load_dataset(FLAGS.dataset, dial_ids_order=0, split2ratio={'train': FLAGS.train_ratio})
+        ontology = load_ontology(FLAGS.dataset)
+        nlg_data = load_nlg_data(dataset)
+        if FLAGS.do_train:
+            train(model, nlg_data)
+        else:
+            test(model, nlg_data, ontology, FLAGS.model_path)
diff --git a/convlab/nlg/scgpt/model.py b/convlab/nlg/scgpt/model.py
index 9a41cfa53e92938d781abc27bc7082c3948b7a0d..6d41319ec4d7254a27620dfe6d20fcdb684e8773 100644
--- a/convlab/nlg/scgpt/model.py
+++ b/convlab/nlg/scgpt/model.py
@@ -24,5 +24,29 @@ class SCGPTDataset(Dataset):
     def __len__(self):
         return len(self.data)
 
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+
+class SGD_TMDataset(Dataset):
+    def __init__(self, data, tokenizer):
+        """
+        Args:
+            data: [[da_str, response], [da_str, response], ...]
+            tokenizer: GPT2 Tokenizer
+        """
+        self.data = []
+        length_list = []
+        for item in data:
+            da, response = item['dialogue_acts'], item['utterance']
+            da_tokens = tokenizer.encode(act2str(da))
+            response_tokens = tokenizer.encode(response)
+            length_list.append(len(da_tokens) + len(response_tokens) + 1)
+            self.data.append([da_tokens, response_tokens])
+        print(f'max: {np.max(length_list)}, min: {np.min(length_list)}, median: {np.quantile(length_list, 0.5)}, 0.99: {np.quantile(length_list, 0.99)}')
+
+    def __len__(self):
+        return len(self.data)
+
     def __getitem__(self, idx):
         return self.data[idx]
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/multiwoz/__init__.py b/convlab/nlg/scgpt/multiwoz/__init__.py
deleted file mode 100644
index 88c7ca2e9735ded913e007644fc8b46fd78535f6..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Sat Apr 4 21:43:42 2020
-
-@author: truthless
-"""
-
-from convlab.nlg.scgpt.multiwoz.scgpt import SCGPT
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/multiwoz/preprocess.py b/convlab/nlg/scgpt/multiwoz/preprocess.py
deleted file mode 100644
index 3f5cf70f664895cd7f1c7d67f6c31903d157809a..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/preprocess.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Sep 14 11:38:53 2020
-@author: truthless
-"""
-
-import os
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-from convlab.nlg.scgpt.utils import dict2dict, dict2seq
-import zipfile
-
-def read_zipped_json(filepath, filename):
-    print("zip file path = ", filepath)
-    archive = zipfile.ZipFile(filepath, 'r')
-    return json.load(archive.open(filename))
-
-def init_domain():
-    return {'Attraction':False,
-            'Hospital':False,
-            'Hotel':False,
-            'Police':False,
-            'Restaurant':False,
-            'Taxi':False,
-            'Train':False}
-
-def write_file(name, data, role='usr'):
-    with open(f'{name}.txt', 'w', encoding='utf-8') as f:
-        for ID in data:
-            sess = data[ID]
-            sess_domains = init_domain()
-            for turn in sess:
-                if role == 'usr':
-                    if not turn['usr_da']:
-                        continue
-                    turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
-                    da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
-                    domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
-                elif role == 'sys':
-                    if not turn['sys_da']:
-                        continue
-                    turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
-                    da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
-                    domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
-                else:
-                    raise NameError('Invalid Role: Select usr/sys.')
-                for domain in domains:
-                    if domain not in ['general', 'Booking'] and not sess_domains[domain]:
-                        da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
-                        sess_domains[domain] = True
-
-                if role == 'usr':
-                    da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
-                elif role == 'sys':
-                    da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
-                f.write(f'{da_seq} & {da_uttr}\n')
-
-
-if __name__ == '__main__':
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--role', type=str, default='usr')
-    args = parser.parse_args()
-
-    cur_dir = os.path.dirname(os.path.abspath(__file__))
-    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
-        cur_dir)))), 'data/multiwoz/')
-
-    keys = ['train', 'val', 'test']
-    data = {}
-    for key in keys:
-        data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
-        print('load {}, size {}'.format(key, len(data_key)))
-        data = dict(data, **data_key)
-
-    with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
-        val_list = f.read().splitlines()
-    with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
-        test_list = f.read().splitlines()
-
-    results = {}
-    results_val = {}
-    results_test = {}
-
-    for title, sess in data.items():
-        logs = sess['log']
-        turns = []
-        turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
-        current_domain = None
-        for i, diag in enumerate(logs):
-            text = diag['text']
-            da = diag['dialog_act']
-            span = diag['span_info']
-            if current_domain:
-                da = eval(str(da).replace('Booking', current_domain))
-                span = eval(str(span).replace('Booking', current_domain))
-            if i % 2 == 0:
-                turn['usr'] = text
-                turn['usr_da'] = da
-                turn['usr_span'] = span
-                turns.append(turn)
-            else:
-                turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
-                turn['sys'] = text
-                turn['sys_da'] = da
-                turn['sys_span'] = span
-            for key in da:
-                domain = key.split('-')[0]
-                if domain not in ['general', 'Booking']:
-                    current_domain = domain
-        else:
-            if args.role == 'sys':
-                turns.append(turn)
-        title = title
-        if title in val_list:
-            current = results_val
-        elif title in test_list:
-            current = results_test
-        else:
-            current = results
-        current[title] = turns
-
-    results = eval(str(results).replace(" n't", " not"))
-    results_val = eval(str(results_val).replace(" n't", " not"))
-    results_test = eval(str(results_test).replace(" n't", " not"))
-
-    if not os.path.exists(os.path.join(cur_dir,'data')):
-        os.makedirs(os.path.join(cur_dir, 'data'))
-    write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
-    write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
diff --git a/convlab/nlg/scgpt/multiwoz/run.py b/convlab/nlg/scgpt/multiwoz/run.py
deleted file mode 100644
index e583fe72fb26cd4262a6c4aae7776aabee49293b..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/run.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
-
-import argparse
-import logging
-from tqdm import trange
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-
-import sys
-
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
-
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
-
-
-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
-logger = logging.getLogger(__name__)
-
-MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
-
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
-
-MODEL_CLASSES = {
-    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
-    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
-    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
-    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
-    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
-}
-
-# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
-# in https://github.com/rusiaaman/XLNet-gen#methodology
-# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
-(except for Alexei and Maria) are discovered.
-The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
-remainder of the story. 1883 Western Siberia,
-a young Grigori Rasputin is asked by his father and a group of men to perform magic.
-Rasputin has a vision and denounces one of the men as a horse thief. Although his
-father initially slaps him for making such an accusation, Rasputin watches as the
-man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
-the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
-
-
-def set_seed(args):
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_type", default=None, type=str, required=True,
-                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
-    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
-                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
-    parser.add_argument("--prompt", type=str, default="")
-    parser.add_argument("--padding_text", type=str, default="")
-    parser.add_argument("--length", type=int, default=40)
-    parser.add_argument("--num_samples", type=int, default=1)
-    parser.add_argument("--temperature", type=float, default=1.0,
-                        help="temperature of 0 implies greedy sampling")
-    parser.add_argument("--repetition_penalty", type=float, default=1.0,
-                        help="primarily useful for CTRL model; in that case, use 1.2")
-    parser.add_argument("--top_k", type=int, default=50)
-    parser.add_argument("--top_p", type=float, default=0.9)
-    parser.add_argument("--no_cuda", action='store_true',
-                        help="Avoid using CUDA when available")
-    parser.add_argument('--seed', type=int, default=42,
-                        help="random seed for initialization")
-    parser.add_argument('--stop_token', type=str, default=None,
-                        help="Token at which text generation is stopped")
-    parser.add_argument("--batch_size", default=1, type=int)
-    parser.add_argument('--input_file', type=str, default=None,
-                        help="file")
-    parser.add_argument('--output_file', type=str, default=None,
-                        help="file")
-
-    args = parser.parse_args()
-
-    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
-    args.n_gpu = torch.cuda.device_count()
-
-    set_seed(args)
-
-    args.model_type = args.model_type.lower()
-    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, pad_token='<PAD>', padding_side='left')
-    model = model_class.from_pretrained(args.model_name_or_path)
-    model.to(args.device)
-    model.eval()
-
-    if args.length < 0 and model.config.max_position_embeddings > 0:
-        args.length = model.config.max_position_embeddings
-    elif 0 < model.config.max_position_embeddings < args.length:
-        args.length = model.config.max_position_embeddings # No generation bigger than model size
-    elif args.length < 0:
-        args.length = MAX_LENGTH # avoid infinite loop
-
-    logger.info(args)
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
-
-    fin = open(args.input_file)
-    inputs = [i.strip() for i in fin]
-    output_tests = []
-    for idx in range(0, len(inputs), args.batch_size):
-        logger.info(f"PROGRESS: {int(idx/len(inputs)*100)}%")
-
-        # raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
-        raw_inputs = []
-        for i in range(idx, min(idx+args.batch_size, len(inputs))):
-            lines = inputs[i]
-            raw_text = lines.split(' & ')[0] + ' & '
-            if args.model_type in ["transfo-xl", "xlnet"]:
-                # Models with memory likes to have a long prompt for short inputs.
-                raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
-            raw_inputs.append(raw_text)
-
-        encoding_inputs = tokenizer.batch_encode_plus(raw_inputs, pad_to_max_length=True, add_special_tokens=False)
-        context_tokens = torch.LongTensor(encoding_inputs['input_ids']).to(args.device)
-        max_length = len(context_tokens[0])
-        attention_mask = torch.LongTensor(encoding_inputs['attention_mask']).to(args.device)
-        position_ids = (attention_mask.cumsum(-1) - 1)
-        position_ids.masked_fill_(attention_mask==0, 0)
-
-        if args.model_type == "ctrl":
-            if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
-                logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
-        out_ids = model.generate(
-            input_ids=context_tokens,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            num_beams=args.num_samples,
-            num_return_sequences=args.num_samples,
-            max_length=args.length,
-            temperature=args.temperature,
-            do_sample=True,
-            top_k=args.top_k,
-            top_p=args.top_p,
-            repetition_penalty=args.repetition_penalty
-        )
-        out_ids = out_ids.reshape(len(raw_inputs), args.num_samples, -1)[:, :, max_length:].tolist()
-        for j, out in enumerate(out_ids):
-            examples = [inputs[j]]
-            for o in out:
-                text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
-                text = text[: text.find(args.stop_token) if args.stop_token else None]
-                examples.append(text)
-            output_tests.append(examples)
-        # break
-        # if args.prompt:
-        #     break
-    import json
-    json.dump(output_tests, open(args.output_file,'w'), indent=2)
-    return text
-
-if __name__ == '__main__':
-    main()
diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py
index 22763df86987755fbbdd7c5edc618e4673a90004..4977b4f3778229293bb60c2e98b7e715ebe227dd 100644
--- a/convlab/nlg/scgpt/scgpt.py
+++ b/convlab/nlg/scgpt/scgpt.py
@@ -2,27 +2,22 @@ import sys
 sys.path.append('../../..')
 
 import torch
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from convlab.nlg.nlg import NLG
 from util import act2str
 from scgpt_special_tokens import *
 
-special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
 
 class SCGPT(NLG):
     def __init__(self, dataset_name, model_path, device='cpu'):
         super(SCGPT, self).__init__()
         self.device = device
-        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
-        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
-                                           'additional_special_tokens': special_tokens})
-        self.model.resize_token_embeddings(len(self.tokenizer))
+        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
         self.model.load_state_dict(torch.load(model_path))
 
-
     def generate(self, action):
         action_str = act2str(action)
         output = self._inference_batch([action_str])[0]
@@ -30,16 +25,19 @@ class SCGPT(NLG):
 
     def _inference_batch(self, sents):
         with torch.no_grad():
-            sents = [sent + ' ' + START_OF_PRED for sent in sents]
-            sent_ids = [self.tokenizer.encode(sent) for sent in sents]
+            sents = [sent for sent in sents]
+            sent_ids = [self.tokenizer.encode(sent) + [self.tokenizer._convert_token_to_id_with_added_voc('&')] for sent in sents]
             max_len = max([len(sent) for sent in sent_ids])
-            sent_ids = [sent + [self.tokenizer.pad_token_id] * (max_len - len(sent)) for sent in sent_ids]
+            sent_ids = [sent + [0] * (max_len - len(sent)) for sent in sent_ids]
             inputs = torch.LongTensor(sent_ids).to(self.device)
             model_to_run = self.model.module if type(self.model) is DDP else self.model
-            outputs = model_to_run.generate(inputs, max_length=256,
-                                            eos_token_id=self.tokenizer.pad_token_id,
-                                            pad_token_id=self.tokenizer.pad_token_id)  # greedy
-            # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
-            #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-            output_strs = [self.tokenizer.decode(item) for item in outputs]
+            outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs!=0).float(),
+                                            eos_token_id=self.tokenizer.pad_token_id)  # greedy
+            outputs = outputs[:, len(inputs[0]):]
+            def clean_sentence(sent):
+                sent = sent.strip()
+                if self.tokenizer.eos_token in sent:
+                    sent = sent[:sent.index(self.tokenizer.eos_token)]
+                return sent
+            output_strs = [clean_sentence(item) for item in outputs]
             return output_strs
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/scgpt_special_tokens.py b/convlab/nlg/scgpt/scgpt_special_tokens.py
index 643820dd04e26bde83edddcd4581784577ad3853..4610be5ff74e6322d58a37a6eaf04b0dce7c7216 100644
--- a/convlab/nlg/scgpt/scgpt_special_tokens.py
+++ b/convlab/nlg/scgpt/scgpt_special_tokens.py
@@ -3,7 +3,7 @@ SYS_SPEAK = '[sys_speak]'
 USR_SPEAK = '[usr_speak]'
 START_OF_PRED = '[start_of_pred]'
 END_OF_PRED = '[end_of_pred]'
-PAD_TOKEN = '[_pad_token_]'
+PAD_TOKEN = '<|pad_token|>'
 START_OF_INTENT = '[start_of_intent]'
 END_OF_INTENT = '[end_of_intent]'
 START_OF_SLOT = ''
diff --git a/convlab/nlg/scgpt/train.sh b/convlab/nlg/scgpt/train.sh
old mode 100644
new mode 100755
index d36d1066abec984ca89c203435a5cf7111209c98..fbfed6a496c387cb8f5c090ac20aee33f79d0325
--- a/convlab/nlg/scgpt/train.sh
+++ b/convlab/nlg/scgpt/train.sh
@@ -1 +1,13 @@
-CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21 --scgpt_model_ckpt_path /data/zhangzheng/scgpt
\ No newline at end of file
+CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2040 main.py \
+--batch_size 64 \
+--accumulation_step 2 \
+--epoch_num 20 \
+--lr 5e-5 \
+--base_model_name_path gpt2-medium \
+--val_step 500 \
+--exp_name mwoz_sgd_tm_train \
+--do_train \
+--dataset multiwoz21_sgd_tm1_tm2_tm3 \
+--train_ratio 1.0 \
+# --scgpt_model_ckpt_path saved_models/gpt2_sgd_tm/epoch_2/epoch_2_step13698.pt
+# --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \