Commit 54c6842d authored by zz-jacob

fix scgpt bugs

parent fa5d36c2
CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py \
CUDA_VISIBLE_DEVICES="1" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2051 main.py \
--batch_size 64 \
--base_model_name_path gpt2-medium \
--dataset multiwoz21 \
--scgpt_model_ckpt_path /data/zhangzheng/scgpt \
--model_path /data/zhangzheng/ConvLab-3/convlab/nlg/scgpt/saved_model/epoch_4/epoch_4_step8875.pt
\ No newline at end of file
--exp_name gpt2_mwoz2 \
--model_path saved_models/gpt2_mwoz/epoch_2/epoch_2_step1329.pt \
\ No newline at end of file
......@@ -4,11 +4,13 @@ sys.path.append('../../..')
import argparse
import json
from tqdm import tqdm
import time
import torch
from functools import reduce
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
......@@ -29,10 +31,19 @@ code_test = False
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--lr", default=1e-5, type=float, help="learning rate")
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--train_ratio", default=1.0, type=float)
parser.add_argument("--accumulation_step", default=4, type=int)
parser.add_argument("--epoch_num", default=20, type=int)
parser.add_argument("--val_step", default=100, type=int)
parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
parser.add_argument('--model_path', default="", type=str, help="The path of model for testing.")
parser.add_argument('--scgpt_model_ckpt_path', default="", type=str, help="The path of model for testing.")
parser.add_argument('--base_model_name_path', default="gpt2", type=str, help="The path of base model.")
parser.add_argument('--scgpt_model_ckpt_path', default=None, type=str, help="The path of model for testing.")
parser.add_argument('--save_path', default="saved_models", type=str, help="Model save path.")
parser.add_argument('--exp_name', default="default_name", type=str, help="Current experiment name.")
parser.add_argument("--max_seq_len", default=128, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank
......@@ -41,25 +52,20 @@ torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
# TensorBoard
tb_dir = './runs'
tb_dir = 'runs/' + FLAGS.exp_name
if not os.path.exists(tb_dir):
os.mkdir(tb_dir)
tb_writer = SummaryWriter(tb_dir)
tb_writer = SummaryWriter(tb_dir, flush_secs=5)
special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
## load model
if FLAGS.scgpt_model_ckpt_path == '':
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
model.resize_token_embeddings(len(tokenizer))
if FLAGS.scgpt_model_ckpt_path is None:
tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
model = GPT2LMHeadModel.from_pretrained(FLAGS.base_model_name_path).to(local_rank)
else:
tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.scgpt_model_ckpt_path)
tokenizer.add_special_tokens(
{'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
model = GPT2LMHeadModel.from_pretrained(FLAGS.scgpt_model_ckpt_path).to(local_rank)
tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
model = GPT2LMHeadModel(config=GPT2Config.from_pretrained(FLAGS.base_model_name_path)).to(local_rank)
model.load_state_dict(torch.load(FLAGS.scgpt_model_ckpt_path))
print('model load from ' + FLAGS.scgpt_model_ckpt_path)
model.resize_token_embeddings(len(tokenizer))
nll_loss = nn.NLLLoss(reduction='none').to(local_rank)
ce_loss = nn.CrossEntropyLoss(reduction='none').to(local_rank)
......@@ -80,85 +86,80 @@ def cal_loss(input, target, seq_lens, seq_lens_input):
return mean_loss
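The body of cal_loss sits outside this hunk. For orientation, here is a minimal sketch, assuming it averages token-level cross-entropy over the response positions only (masking out the dialogue-act prompt and padding); cal_loss_sketch and its internals are hypothetical, not the committed implementation, and the off-by-one handling for the shifted targets at the call site is omitted.
def cal_loss_sketch(logits, target, seq_lens, seq_lens_input):
    # logits: (batch, seq_len, vocab), target: (batch, seq_len)
    batch_size, seq_len = target.shape
    positions = torch.arange(seq_len, device=target.device).unsqueeze(0)   # (1, seq_len)
    # keep positions after the dialogue-act prompt and before padding
    mask = (positions >= seq_lens_input.unsqueeze(1)) & (positions < seq_lens.unsqueeze(1))
    token_loss = F.cross_entropy(logits.transpose(1, 2), target, reduction='none')  # (batch, seq_len)
    return (token_loss * mask.float()).sum() / mask.float().sum()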
def pad_collate(batch):
def pad_collate(ori_batch):
"""
Returns:
batch: LongTensor of shape (batch_size, max_len), right-padded with 0
seq_lens: length of each full sequence (dialogue act + '&' + response + eos), truncated to max_seq_len
seq_lens_input: length of each dialogue-act prompt, including the '&' separator
"""
START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
pad_token_id = tokenizer.pad_token_id
batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc('&')
batch = [item[0] + [START_OF_PRED_ID] + item[1] + [tokenizer.eos_token_id] for item in ori_batch]
output_lens = [len(item[1])+1 for item in ori_batch]
batch = [item[-FLAGS.max_seq_len:] for item in batch]
max_len = max([len(item) for item in batch])
# print('max_len', max_len)
seq_lens = [len(item) for item in batch]
split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
def get_x_len(tokens):
"""Get the length of dialogue act tokens"""
split_idx = len(tokens)
try:
split_idx = tokens.index(split_id)+1
except:
pass
return split_idx
seq_lens_input = [get_x_len(item) for item in batch]
batch = [item + [pad_token_id]*(max_len-len(item)) for item in batch]
# print(batch)
# print(seq_lens)
# print(seq_lens_input)
seq_lens_input = []
for idx in range(len(batch)):
curr_ipt_len = seq_lens[idx] - output_lens[idx]
if curr_ipt_len < 0:
curr_ipt_len = 0
seq_lens_input.append(curr_ipt_len)
batch = [item + [0]*(max_len-len(item)) for item in batch]
return torch.LongTensor(batch), torch.LongTensor(seq_lens), torch.LongTensor(seq_lens_input)
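A hedged usage sketch of the collate function above; the token ids are made up, and the module-level tokenizer and FLAGS are assumed to be initialized as earlier in this file.
toy_batch = [
    [[11, 12, 13], [21, 22]],          # (dialogue-act ids, response ids)
    [[31, 32], [41, 42, 43, 44]],
]
batch, seq_lens, seq_lens_input = pad_collate(toy_batch)
# batch: LongTensor of shape (2, max_len), right-padded with 0
# seq_lens: full length per item (da + '&' separator + response + eos), truncated to max_seq_len
# seq_lens_input: length of the dialogue-act prompt (da + separator) per item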
## Training Hyper-params
EPOCH_NUM = 20
BATCH_SIZE = 32 # real_batch_size = BATCH_SIZE * num_gpu
VAL_STEP = 500
WARM_STEPS = 250
if code_test:
EPOCH_NUM = 2
BATCH_SIZE = 4
VAL_STEP = 2
WARM_STEPS = 3
LR = 5e-5
SAVE_PATH = f'./saved_model'
def train(model, nlg_data, global_step=0):
train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
val_dataset = SCGPTDataset(nlg_data['validation'], tokenizer)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
val_dataloader = DataLoader(val_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARM_STEPS,
num_training_steps=len(train_dataloader) * EPOCH_NUM)
optimizer = torch.optim.AdamW(model.parameters(), lr=FLAGS.lr)
t_total = len(train_dataloader) * FLAGS.epoch_num // FLAGS.accumulation_step
warm_steps = int(0.1 * t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_steps,
num_training_steps=t_total)
model.train()
for epoch in range(EPOCH_NUM):
for epoch in range(FLAGS.epoch_num):
train_dataloader.sampler.set_epoch(epoch)
for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{EPOCH_NUM}]')):
for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{FLAGS.epoch_num}]')):
if (batch_id+1) % FLAGS.accumulation_step == 0:
global_step += 1
inputs = inputs.to(local_rank)
seq_lens = seq_lens.to(local_rank)
seq_lens_input = seq_lens_input.to(local_rank)
outputs = model(inputs)
outputs = model(inputs, attention_mask=(inputs!=0).float())
preds = outputs[0]
loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
optimizer.zero_grad()
loss /= FLAGS.accumulation_step
loss.backward()
# update params
if (batch_id+1) % FLAGS.accumulation_step == 0:
optimizer.step()
scheduler.step()
model.zero_grad()
# tensorboard
tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
if global_step % FLAGS.val_step == 0:
model.eval()
val_loss = eval(model, val_dataloader)
ppl = np.exp(val_loss)
tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
model.train()
global_step += 1
# save the model when each epoch ends
if dist.get_rank() == 0:
# vaidation
model.eval()
val_loss = eval(model, val_dataloader)
......@@ -168,14 +169,13 @@ def train(model, nlg_data, global_step=0):
model.train()
# save model
save_dir = os.path.join(SAVE_PATH, f'epoch_{epoch}')
save_dir = os.path.join(FLAGS.save_path, FLAGS.exp_name, f'epoch_{epoch}')
os.makedirs(save_dir, exist_ok=True)
torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
tokenizer.save_pretrained(save_dir)
torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
print(f'Save model checkpoint to [{save_dir}]')
tb_writer.flush()
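For reference on the accumulation logic above: the effective batch size per optimizer update is batch_size × world_size × accumulation_step. A small illustration, assuming the module-level FLAGS and initialized process group:
# With the shipped defaults (batch_size=32, accumulation_step=4) on one GPU,
# each optimizer.step() covers 32 * 1 * 4 = 128 examples, and
# t_total = len(train_dataloader) * epoch_num // accumulation_step counts those updates.
effective_batch = FLAGS.batch_size * dist.get_world_size() * FLAGS.accumulation_step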
......@@ -198,17 +198,19 @@ def eval(model, loader, use_tqdm=False):
def inference_batch(model, sents):
"""Inference model given a batch of sents."""
with torch.no_grad():
sents = [sent + ' ' + START_OF_PRED for sent in sents]
sents = [sent + ' &' for sent in sents]
sent_ids = [tokenizer.encode(sent) for sent in sents]
max_len = max([len(sent) for sent in sent_ids])
sent_ids = [sent + [tokenizer.pad_token_id]*(max_len-len(sent)) for sent in sent_ids]
# ma_len = min(max_len, FLAGS.max_seq_len)
sent_ids = [sent + [0]*(max_len-len(sent)) for sent in sent_ids]
inputs = torch.LongTensor(sent_ids).to(local_rank)
model_to_run = model.module if type(model) is DDP else model
outputs = model_to_run.generate(inputs, max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.pad_token_id,
pad_token_id=tokenizer.pad_token_id) # greedy
outputs = model_to_run.generate(inputs, attention_mask=(inputs != 0).float(), max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.eos_token_id) # greedy
# outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
# pad_token_id=gpt2_tokenizer.pad_token_id) # beam search
output_strs = [tokenizer.decode(item) for item in outputs]
# output_strs = [tokenizer.decode(item) for item in outputs]
outputs = outputs[:, len(inputs[0]):]
output_strs = tokenizer.batch_decode(outputs)
return output_strs
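A hedged usage sketch of inference_batch; the dialogue-act strings below are illustrative only (the real inputs come from act2str and may use a different surface form), and the module-level model and tokenizer are assumed to be loaded.
das = [
    'inform ( restaurant ) ( area = centre ; food = italian )',   # made-up DA surface form
    'request ( hotel ) ( price range )',
]
responses = inference_batch(model, das)
# responses: one generated utterance per input, with the prompt portion stripped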
......@@ -226,6 +228,22 @@ def inference_sents(model, sents):
return outputs
def inference_sents_by_batch(model, sents):
"""Get the outputs of multiple sentences."""
start_idx = 0
ret = []
start = time.time()
while start_idx < len(sents):
end_idx = start_idx + FLAGS.batch_size
curr_sents = sents[start_idx:end_idx]
outputs = inference_batch(model, curr_sents)
ret += outputs
start_idx += FLAGS.batch_size
time_remain = (time.time()-start) / start_idx * (len(sents) - start_idx)
print('{}/{}, time remaining: {:.2f}'.format(start_idx, len(sents), time_remain))
return ret
def test(model, nlg_data, ontology, model_path):
"""将sheel中的GPU个数设为1运行"""
model.load_state_dict(torch.load(model_path))
......@@ -237,15 +255,15 @@ def test(model, nlg_data, ontology, model_path):
golden_responses = [item['utterance'].strip() for item in test_data]
# dialog_acts = dialog_acts[:10]
# golden_responses = golden_responses[:10]
outputs = inference_sents(model, dialog_acts)
outputs = inference_sents_by_batch(model, dialog_acts)
def get_real_output(ipt):
if '[start_of_pred]' in ipt:
ipt = ipt[ipt.index('[start_of_pred]')+15:].strip()
if '[_pad_token_]' in ipt:
ipt = ipt[:ipt.index('[_pad_token_]')].strip()
if tokenizer.eos_token in ipt:
ipt = ipt[:ipt.index(tokenizer.eos_token)].strip()
return ipt
outputs = [get_real_output(item) for item in outputs]
output_file = './test_output.json'
if not os.path.exists('./test_outputs'):
os.makedirs('./test_outputs', exist_ok=True)
output_file = f'./test_outputs/{FLAGS.exp_name}.json'
if dist.get_rank() == 0:
with open(output_file, 'w+') as f:
result = []
......@@ -253,7 +271,9 @@ def test(model, nlg_data, ontology, model_path):
result.append({
'dialogue_acts': test_data[i]['dialogue_acts'],
'utterance': test_data[i]['utterance'],
'prediction': outputs[i]
'predictions': {
'utterance': outputs[i]
}
})
json.dump(result, f, indent=2, ensure_ascii=False)
evaluator = GentScorer()
......@@ -309,7 +329,26 @@ def test(model, nlg_data, ontology, model_path):
if __name__ == '__main__':
dataset = load_dataset(FLAGS.dataset)
if '_' in FLAGS.dataset:
spans = FLAGS.dataset.split('_')
data_list = spans
datasets = [load_dataset(item) for item in data_list]
nlg_datas = [load_nlg_data(item) for item in datasets]
ret = {}
def aggregate(nlg_datas, split):
ret = []
for item in nlg_datas:
ret += item[split]
return ret
ret['train'] = aggregate(nlg_datas, 'train')
ret['validation'] = aggregate(nlg_datas, 'validation')
ret['test'] = aggregate(nlg_datas, 'test')
if FLAGS.do_train:
train(model, ret)
else:
print('not supported')
else:
dataset = load_dataset(FLAGS.dataset, dial_ids_order=0, split2ratio={'train': FLAGS.train_ratio})
ontology = load_ontology(FLAGS.dataset)
nlg_data = load_nlg_data(dataset)
if FLAGS.do_train:
......
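A small worked example of the combined-dataset branch above; the corpus names are illustrative only.
# Passing e.g. --dataset sgd_tm1 splits the name on '_', loads each corpus separately,
# and concatenates their NLG splits before training.
parts = 'sgd_tm1'.split('_')                                   # ['sgd', 'tm1']
# nlg_datas = [load_nlg_data(load_dataset(p)) for p in parts]
# train_data = sum((d['train'] for d in nlg_datas), [])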
......@@ -26,3 +26,27 @@ class SCGPTDataset(Dataset):
def __getitem__(self, idx):
return self.data[idx]
class SGD_TMDataset(Dataset):
def __init__(self, data, tokenizer):
"""
Args:
data: list of dicts with 'dialogue_acts' and 'utterance' keys (ConvLab-3 NLG format)
tokenizer: GPT2 Tokenizer
"""
self.data = []
length_list = []
for item in data:
da, response = item['dialogue_acts'], item['utterance']
da_tokens = tokenizer.encode(act2str(da))
response_tokens = tokenizer.encode(response)
length_list.append(len(da_tokens) + len(response_tokens) + 1)
self.data.append([da_tokens, response_tokens])
print(f'max: {np.max(length_list)}, min: {np.min(length_list)}, median: {np.quantile(length_list, 0.5)}, 0.99: {np.quantile(length_list, 0.99)}')
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
\ No newline at end of file
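A hedged usage sketch of SGD_TMDataset; the sample item is invented, and its dialogue_acts structure is only an assumption about the ConvLab-3 format expected by act2str.
sample = [{
    'dialogue_acts': {'categorical': [], 'non-categorical': [], 'binary': []},  # assumed structure
    'utterance': 'The restaurant is in the centre of town.',
}]
ds = SGD_TMDataset(sample, tokenizer)
da_tokens, response_tokens = ds[0]   # two lists of GPT-2 token ids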
......@@ -33,7 +33,7 @@ class SCGPT(NLG):
sents = [sent + ' ' + START_OF_PRED for sent in sents]
sent_ids = [self.tokenizer.encode(sent) for sent in sents]
max_len = max([len(sent) for sent in sent_ids])
sent_ids = [sent + [self.tokenizer.pad_token_id] * (max_len - len(sent)) for sent in sent_ids]
sent_ids = [[self.tokenizer.pad_token_id] * (max_len - len(sent)) + sent for sent in sent_ids]
inputs = torch.LongTensor(sent_ids).to(self.device)
model_to_run = self.model.module if type(self.model) is DDP else self.model
outputs = model_to_run.generate(inputs, max_length=256,
......@@ -41,5 +41,6 @@ class SCGPT(NLG):
pad_token_id=self.tokenizer.pad_token_id) # greedy
# outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
# pad_token_id=gpt2_tokenizer.pad_token_id) # beam search
output_strs = [self.tokenizer.decode(item) for item in outputs]
return output_strs
\ No newline at end of file
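On the padding change above: GPT-2 is decoder-only and continues from the last position, so right-padded prompts would make it generate after pad tokens; left-padding keeps every prompt flush with the generation start. A minimal illustration with made-up ids:
pad_id = 0                                # stand-in for tokenizer.pad_token_id
prompts = [[11, 12, 13], [21, 22]]        # two prompts of different lengths
max_len = max(len(p) for p in prompts)
left_padded = [[pad_id] * (max_len - len(p)) + p for p in prompts]
# [[11, 12, 13], [0, 21, 22]] -> generation starts immediately after each real prompt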
......@@ -3,7 +3,7 @@ SYS_SPEAK = '[sys_speak]'
USR_SPEAK = '[usr_speak]'
START_OF_PRED = '[start_of_pred]'
END_OF_PRED = '[end_of_pred]'
PAD_TOKEN = '[_pad_token_]'
PAD_TOKEN = '<|pad_token|>'
START_OF_INTENT = '[start_of_intent]'
END_OF_INTENT = '[end_of_intent]'
START_OF_SLOT = ''
......
CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21 --scgpt_model_ckpt_path /data/zhangzheng/scgpt
\ No newline at end of file
CUDA_VISIBLE_DEVICES="3" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2043 main.py \
--batch_size 32 \
--accumulation_step 4 \
--epoch_num 20 \
--lr 5e-5 \
--base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \
--val_step 1000 \
--exp_name scgpt_mwoz \
--do_train \
--dataset sgd \
--train_ratio 1.0 \
# --scgpt_model_ckpt_path saved_models/sgd_tm_1e4/epoch_8/epoch_8_step41094.pt
# --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \