Unverified commit dd3ae1e2 authored by zz-jacob, committed by GitHub

Merge pull request #45 from ConvLab/nlg-scgpt

Nlg scgpt
parents 9f70248a b2cc2e40
# SC-GPT
The code derives from [HuggingFace/Transformers](https://github.com/huggingface/transformers). This is the implementation of [SC-GPT](https://aclanthology.org/2020.findings-emnlp.17), proposed by Peng et al., 2020.
## Train
```python
./train.sh
```
When using the training code, you may have to adjust the parameters
according to your machine configuration. Note that the training code
only supports GPU training.
## Test
```python
./test.sh
```
The test code also only supports GPU mode. We report the BLEU score
and ERR score according to the original SC-GPT paper (Peng et al., 2020).
## NLG Interface
The NLG interface of SC-GPT is implemented in ./scgpt.py.
```python
def generate(self, action)
```
This class supports both CPU and GPU mode; choose one by passing the 'device' parameter to the constructor.
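A minimal usage sketch (assuming the module lives at `convlab2.nlg.scgpt.scgpt`; the checkpoint path below is a placeholder, not a real file):
```python
from convlab2.nlg.scgpt.scgpt import SCGPT

# hypothetical checkpoint path; point it to a trained SC-GPT state dict
nlg = SCGPT(dataset_name='multiwoz21', model_path='path/to/scgpt_checkpoint.pt', device='cuda:0')
action = {'categorical': [{'intent': 'inform', 'domain': 'restaurant', 'slot': 'area', 'value': 'north'}],
          'non-categorical': [],
          'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'}]}
print(nlg.generate(action))  # natural-language system response
```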
## Reference
```
@inproceedings{peng-etal-2020-shot,
title = "Few-shot Natural Language Generation for Task-Oriented Dialog",
author = "Peng, Baolin and Zhu, Chenguang and Li, Chunyuan and Li, Xiujun and Li, Jinchao and Zeng, Michael and Gao, Jianfeng",
booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
month = nov,
year = "2020",
publisher = "Association for Computational Linguistics",
pages = "172--182",
}
```
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 4 21:34:38 2020
@author: truthless
"""
import numpy as np
import torch
def set_seed(seed, n_gpu):
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(seed)
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size x vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
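# Example (illustration): with top_k=1, only the largest logit per row survives and
# the rest are pushed to -inf, e.g.
#   top_k_top_p_filtering(torch.tensor([[2.0, 1.0, 0.5]]), top_k=1)
#   -> tensor([[2., -inf, -inf]])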
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
with torch.no_grad():
for _ in range(length):
inputs = {'input_ids': generated}
if is_xlnet:
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
target_mapping[0, 0, -1] = 1.0 # predict last token
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
if is_xlm_mlm and xlm_mask_token:
# XLM MLM models are direct models (predict same token, not next token)
# => need one additional dummy token in the input (will be masked and guessed)
input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
inputs = {'input_ids': input_ids}
if xlm_lang is not None:
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for previous_token in set(generated[i].tolist()):
                    next_token_logits[i, previous_token] /= repetition_penalty
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
if temperature == 0: # greedy sampling:
next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
else:
next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token), dim=1)
return generated
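# Usage sketch (illustration; `model` and `tokenizer` are assumed to be a GPT-2
# LM-head model and its tokenizer loaded elsewhere):
#   context = tokenizer.encode('restaurant { inform ( area = north ) }')
#   out = sample_sequence(model, length=50, context=context, num_samples=1,
#                         temperature=1.0, top_p=0.9, device='cuda')
#   text = tokenizer.decode(out[0, len(context):].tolist())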
# Part of the evaluation script is adopted from https://github.com/pengbaolin/SC-GPT.
import os
import json
import sys
import math
import operator
import nltk
from collections import Counter
from nltk.util import ngrams
file = open  # keep the Python 2-style `file(...)` calls below working on Python 3
class ERRScorer():
    ## Scorer for calculating the slot errors
    ## it scores utterances one by one
    ## using two levels of matching:
    ## 1. exact match for categorical values
    ## 2. multiple keyword matching for binary values
    ## note: it cannot handle dontcare and none values
def __init__(self, detectfile):
self.detectPairs = []
fin = file(detectfile)
self.detectPairs = json.load(fin)
fin.close()
def countSlots(self, dataset, reader):
count = 0
for t in dataset:
feat = reader.formatter.format(t[0])[0]
c = count
for s, v in feat:
# skip type token
if s == 'type':
continue
if v == '_' or v == 'yes' or v == 'none' or v == 'no':
count += 1
return count
def score(self, a, feat, gen):
# import pdb
# pdb.set_trace()
# total slots
slot_count = 0
# exact match for categorical slots
caty_slot_error = 0
        # for each slot - token pair in the detect pair dict
for s, tok in self.detectPairs['general'].items():
# token compare to
comparetos = ['sv.' + s + '._1', 'sv.' + s + '._2', 'sv.' + s + '._3']
# count feature count in da feature
fcnt = 0
for f in feat:
for compareto in comparetos:
if compareto == f: fcnt += 1
# count generated semantic tokens
gcnt = gen.split().count(tok)
# count the slot difference
# if fcnt!=gcnt:
# caty_slot_error += 1.0
caty_slot_error += abs(fcnt - gcnt)
# accumulate slot count
slot_count += fcnt
# key word match for binary slots, only an approximation
bnay_slot_error = 0
# for each binary slot
for s, toks in self.detectPairs['binary'].items():
# tokens compare to
comparetos = ['sv.' + s + '.yes', 'sv.' + s + '.no',
'sv.' + s + '.dontcare', 'sv.' + s + '.none']
# count feature occurrence in da
fcnt = 0
for f in feat:
for compareto in comparetos:
if compareto == f: fcnt += 1
# count generated semantic tokens
gcnt = sum([gen.split().count(tok) for tok in toks])
# count the slot difference
bnay_slot_error += abs(fcnt - gcnt)
# accumulate slot count
slot_count += fcnt
# total slot error
total_slot_error = caty_slot_error + bnay_slot_error
# when ?select/suggest act, only consider categorical errors
if a == [4] or a == [14]:
# return slot_count, caty_slot_error, caty_slot_error
return 0.0, 0.0, 0.0
else:
return slot_count, total_slot_error, caty_slot_error
class BLEUScorer(object):
    ## BLEU score calculator via GentScorer interface
    ## it calculates BLEU-4 by taking the entire corpus in
    ## the calculation is based on multiple candidates against multiple references
def __init__(self):
pass
def score(self, parallel_corpus):
# ref_ = []
# hyp_ = []
# for hyps,refs in parallel_corpus:
# ref_.append(refs)
# hyp_.append(hyps[0])
# return nltk.translate.bleu_score.corpus_bleu(ref_, hyp_)
# asdf
# containers and parameters
r, c = 0, 0
count = [0, 0, 0, 0]
clip_count = [0, 0, 0, 0]
weights = [0.25, 0.25, 0.25, 0.25]
# accumulate ngram statistics
for hyps, refs in parallel_corpus:
# BLEUscore = nltk.translate.bleu_score.sentence_bleu(refs, hyps[0])
# print(hyps, refs, BLEUscore)
hyps = [hyp.lower().split() for hyp in hyps]
refs = [ref.lower().split() for ref in refs]
# compute ngram counts by matching each hypothesis
for hyp in hyps:
# for each ngram
for i in range(4):
# accumulate hyp ngram counts
hypcnts = Counter(ngrams(hyp, i + 1))
cnt = sum(hypcnts.values())
count[i] += cnt
# compute clipped counts
max_counts = {}
# compare to each reference
for ref in refs:
# get reference ngrams
refcnts = Counter(ngrams(ref, i + 1))
# for each ngram
for ng in hypcnts:
# clipped counts
max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
# compute clipped counts by clipping the hyp count if necessary
clipcnt = dict((ng, min(count, max_counts[ng])) \
for ng, count in hypcnts.items())
clip_count[i] += sum(clipcnt.values())
# accumulate r & c, find best match among all references
bestmatch = [1000, 1000]
for ref in refs:
if bestmatch[0] == 0: break
# length difference
diff = abs(len(ref) - len(hyp))
# if the current diff less than stored one, change it
if diff < bestmatch[0]:
bestmatch[0] = diff
bestmatch[1] = len(ref)
# extract the best length match in references
r += bestmatch[1]
c += len(hyp)
# computing bleu score
# for numerical stability
p0 = 1e-7
        # brevity penalty
bp = 1 if c > r else math.exp(1 - float(r) / float(c))
# modified prec.
p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \
for i in range(4)]
# weighted prec.
s = math.fsum(w * math.log(p_n) \
for w, p_n in zip(weights, p_ns) if p_n)
# final bleu score
bleu = bp * math.exp(s)
return bleu
def sentence_bleu_4(self, parallel_corpus):
# input : single sentence, multiple references
count = [0, 0, 0, 0]
clip_count = [0, 0, 0, 0]
weights = [0.25, 0.25, 0.25, 0.25]
r = 0
c = 0
# accumulate ngram statistics
for hyps, refs in parallel_corpus:
hyps = [hyp.split() for hyp in hyps]
refs = [ref.split() for ref in refs]
# compute ngram counts by matching each hypothesis
for hyp in hyps:
# for each ngram
for i in range(4):
# accumulate hyp ngram counts
hypcnts = Counter(ngrams(hyp, i + 1))
cnt = sum(hypcnts.values())
count[i] += cnt
# compute clipped counts
max_counts = {}
# compare to each reference
for ref in refs:
# get reference ngrams
refcnts = Counter(ngrams(ref, i + 1))
# for each ngram
for ng in hypcnts:
# clipped counts
max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
# compute clipped counts by clipping the hyp count if necessary
clipcnt = dict((ng, min(count, max_counts[ng])) \
for ng, count in hypcnts.items())
clip_count[i] += sum(clipcnt.values())
# accumulate r & c, find best match among all references
bestmatch = [1000, 1000]
for ref in refs:
if bestmatch[0] == 0: break
# length difference
diff = abs(len(ref) - len(hyp))
# if the current diff less than stored one, change it
if diff < bestmatch[0]:
bestmatch[0] = diff
bestmatch[1] = len(ref)
# extract the best length match in references
r += bestmatch[1]
c += len(hyp)
# for numerical stability
p0 = 1e-7
        # modified brevity penalty
bp = math.exp(-abs(1.0 - float(r) / float(c + p0)))
# smoothed version of modified prec.
p_ns = [0, 0, 0, 0]
for i in range(4):
if i < 2: # original version n-gram counts
p_ns[i] = float(clip_count[i]) / float(count[i] + p0) + p0
else: # smoothed version of ngram counts
smooth_term = 5 * p_ns[i - 1] * p_ns[i - 1] / p_ns[i - 2]
p_ns[i] = float(clip_count[i] + smooth_term) / float(count[i] + 5) + p0
# weighted prec.
s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
# final sentence bleu score
bleu_hyp = bp * math.exp(s)
return bleu_hyp
class GentScorer(object):
    ## main Scorer interfaces for all scorers
    ## it can do
    ## 1. Compute bleu score
    ## 2. Compute slot error rate
    ## 3. Detailed illustration of how different splits
    ##    of data affect performance
def __init__(self):
self.bleuscorer = BLEUScorer()
    def scoreERR(self, parallel_pairs):
        """input: [[dialogue_act, utterance], [dialogue_act, utterance], ...]
        Not implemented here; the slot error rate is computed directly in the test script."""
def scoreBLEU(self, parallel_corpus):
return self.bleuscorer.score(parallel_corpus)
def scoreSBLEU(self, parallel_corpus):
return self.bleuscorer.sentence_bleu_4(parallel_corpus)
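# Usage sketch (illustration): each entry of a parallel corpus is a pair of
# (hypothesis list, reference list).
#   scorer = GentScorer()
#   corpus = [(['there is a cheap hotel in the north'],
#              ['there is a cheap hotel in the north part of town'])]
#   corpus_bleu = scorer.scoreBLEU(corpus)
#   sentence_bleu = scorer.scoreSBLEU(corpus)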
import sys
sys.path.append('../../..')
import argparse
from tqdm import tqdm
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import os
from transformers import get_linear_schedule_with_warmup
from convlab2.util.unified_datasets_util import load_dataset, load_nlg_data, load_ontology
from convlab2.nlg.scgpt.util import act2str
from convlab2.nlg.scgpt.model import SCGPTDataset
from evaluate import GentScorer
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from util import build_mask
from scgpt_special_tokens import *
code_test = False
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
parser.add_argument('--model_path', default="", type=str, help="The path of model for testing.")
parser.add_argument("--max_seq_len", default=256, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
# TensorBoard
tb_writer = SummaryWriter()
special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
## load model
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
model.resize_token_embeddings(len(tokenizer))
nll_loss = nn.NLLLoss(reduction='none').to(local_rank)
ce_loss = nn.CrossEntropyLoss(reduction='none').to(local_rank)
def cal_loss(input, target, seq_lens, seq_lens_input):
    """Only calculate loss on responses, not on the dialog act.
    Input: [batch, length, vocab]; target: [batch, length]; seq_lens: [batch]"""
    global nll_loss
log_probs = F.log_softmax(input, dim=-1).transpose(1, 2)
loss = nll_loss(log_probs, target)
# loss = ce_loss(input, target)
mask = build_mask(torch.max(seq_lens).item()-1, seq_lens-1).to(local_rank)
input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
output_mask = torch.logical_xor(mask, input_mask)
pad_mask = torch.logical_not(mask)
# masked_loss = loss * output_mask
masked_loss = loss * (output_mask + pad_mask)
mean_loss = torch.sum(masked_loss) / torch.sum(output_mask + pad_mask)
return mean_loss
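# Worked illustration of the masking above (one batch element, after the usual
# shift-by-one for next-token prediction): with max(seq_lens) = 7, seq_lens = [6]
# and seq_lens_input = [3],
#   mask        = [1, 1, 1, 1, 1, 0]   (real tokens)
#   input_mask  = [1, 1, 0, 0, 0, 0]   (dialogue-act prefix)
#   output_mask = [0, 0, 1, 1, 1, 0]   (response tokens, via XOR)
#   pad_mask    = [0, 0, 0, 0, 0, 1]   (padding)
# so the averaged loss covers the response positions plus the padding positions,
# while the dialogue-act prefix is excluded.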
def pad_collate(batch):
"""
Returns:
batch: batch * max_len
seq_lens: the length of len(da)+1+len(response)
seq_lens_input: the length of len(da)
"""
START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
pad_token_id = tokenizer.pad_token_id
batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
batch = [item[-FLAGS.max_seq_len:] for item in batch]
max_len = max([len(item) for item in batch])
seq_lens = [len(item) for item in batch]
split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
def get_x_len(tokens):
"""Get the length of dialogue act tokens"""
split_idx = len(tokens)
try:
split_idx = tokens.index(split_id)+1
except:
pass
return split_idx
seq_lens_input = [get_x_len(item) for item in batch]
batch = [item + [pad_token_id]*(max_len-len(item)) for item in batch]
return torch.LongTensor(batch), torch.LongTensor(seq_lens), torch.LongTensor(seq_lens_input)
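# Illustration (token ids are symbolic): for batch = [([d1, d2], [r1, r2, r3])],
# the collated row is [d1, d2, SOP, r1, r2, r3] padded to the batch maximum,
# with seq_lens = [6] and seq_lens_input = [3] (the dialogue-act tokens plus the
# [start_of_pred] separator SOP).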
## Training Hyper-params
EPOCH_NUM = 20
BATCH_SIZE = 20 # real_batch_size = BATCH_SIZE * num_gpu
VAL_STEP = 300
WARM_STEPS = 250
if code_test:
EPOCH_NUM = 2
BATCH_SIZE = 4
VAL_STEP = 2
WARM_STEPS = 3
LR = 5e-5
TASK_TYPE = 'nlu' # nlu or dst
SAVE_PATH = f'./saved_model'
def train(model, nlg_data, global_step=0):
train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
val_dataset = SCGPTDataset(nlg_data['validation'], tokenizer)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARM_STEPS,
num_training_steps=len(train_dataloader) * EPOCH_NUM)
model.train()
for epoch in range(EPOCH_NUM):
train_dataloader.sampler.set_epoch(epoch)
for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{EPOCH_NUM}]')):
inputs = inputs.to(local_rank)
seq_lens = seq_lens.to(local_rank)
seq_lens_input = seq_lens_input.to(local_rank)
outputs = model(inputs)
preds = outputs[0]
loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
if batch_id % VAL_STEP == 0:
model.eval()
val_loss = eval(model, val_dataloader)
ppl = np.exp(val_loss)
tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
model.train()
global_step += 1
# save the model when each epoch ends
if dist.get_rank() == 0:
save_dir = os.path.join(SAVE_PATH, f'epoch_{epoch}')
os.makedirs(save_dir, exist_ok=True)
torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
tokenizer.save_pretrained(save_dir)
torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
print(f'Save model checkpoint to [{save_dir}]')
tb_writer.flush()
def eval(model, loader, use_tqdm=False):
with torch.no_grad():
loss_list = []
iter = tqdm(loader, desc='Val') if use_tqdm else loader
for inputs, seq_lens, seq_lens_input in iter:
inputs = inputs.to(local_rank)
seq_lens = seq_lens.to(local_rank)
seq_lens_input = seq_lens_input.to(local_rank)
outputs = model(inputs)
preds = outputs[0]
loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
loss_list.append(loss.item())
mean_loss = np.mean(loss_list)
return mean_loss
def inference_batch(model, sents):
"""Inference model given a batch of sents."""
with torch.no_grad():
sents = [sent + ' ' + START_OF_PRED for sent in sents]
sent_ids = [tokenizer.encode(sent) for sent in sents]
max_len = max([len(sent) for sent in sent_ids])
sent_ids = [sent + [tokenizer.pad_token_id]*(max_len-len(sent)) for sent in sent_ids]
inputs = torch.LongTensor(sent_ids).to(local_rank)
model_to_run = model.module if type(model) is DDP else model
outputs = model_to_run.generate(inputs, max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.pad_token_id,
pad_token_id=tokenizer.pad_token_id) # greedy
# outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
# pad_token_id=gpt2_tokenizer.pad_token_id) # beam search
output_strs = [tokenizer.decode(item) for item in outputs]
return output_strs
def inference_sent(model, sent):
"""Inference model given one single sentence."""
return inference_batch(model, [sent])[0]
def inference_sents(model, sents):
"""Get the outputs of multiple sentences."""
outputs = []
for sent in tqdm(sents, desc='Inference Sentences'):
output = inference_sent(model, sent)
outputs.append(output)
return outputs
def test(model, nlg_data, ontology, model_path):
"""将sheel中的GPU个数设为1运行"""
model.load_state_dict(torch.load(model_path))
model.eval()
print(f'model loaded from [{model_path}]')
# sample_file = os.path.join(f'../../../data/dstc2/sample50_{TASK_TYPE}_input_data.txt')
# Load test nlg data
test_data = nlg_data['test']
dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
golden_responses = [item['utterance'] for item in test_data]
# dialog_acts = dialog_acts[:10]
# golden_responses = golden_responses[:10]
outputs = inference_sents(model, dialog_acts)
if dist.get_rank() == 0:
output_file = './test_output.txt'
with open(output_file, 'w+') as f:
for i in range(len(dialog_acts)):
f.write(f'{dialog_acts[i]}\n{golden_responses[i]}\n{outputs[i]}\n\n')
f.close()
evaluator = GentScorer()
parallel_corpus = []
# BLEU
for i in range(len(dialog_acts)):
parallel_corpus.append([[golden_responses[i]], [outputs[i]]])
BLEU_Score = evaluator.scoreSBLEU(parallel_corpus)
# ERR
## all values in ontology
val2ds_dict = {}
for domain_name in ontology['domains']:
domain = ontology['domains'][domain_name]
for slot_name in domain['slots']:
slot = domain['slots'][slot_name]
if 'possible_values' not in slot:
continue
possible_vals = slot['possible_values']
if len(possible_vals) > 0:
for val in possible_vals:
val2ds_dict[val] = f'{domain_name}-{slot_name}'
## missing values
score_list = []
for item in test_data:
da = item['dialogue_acts']
utterance = item['utterance']
missing_count = 0
redundant_count = 0
all_count = 0
all_values = set()
for key in da:
slot_value = da[key]
for triple in slot_value:
if 'value' in triple:
value = triple['value']
all_values.add(value)
if value.strip().lower() not in utterance.lower():
missing_count += 1
all_count += 1
if all_count == 0:
continue
## redundant values
for val in val2ds_dict:
if f' {val.strip().lower()} ' in f' {utterance.strip().lower()} ' and val.strip().lower() not in all_values:
redundant_count += 1
            # slot error rate for this utterance: (missing + redundant) / total values
            item_score = float(missing_count + redundant_count) / all_count
score_list.append(item_score)
ERR_Score = np.mean(score_list)
print(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
if __name__ == '__main__':
dataset = load_dataset(FLAGS.dataset)
ontology = load_ontology(FLAGS.dataset)
nlg_data = load_nlg_data(dataset)
if FLAGS.do_train:
train(model, nlg_data)
else:
test(model, nlg_data, ontology, FLAGS.model_path)
from torch.utils.data import Dataset
from util import act2str
from scgpt_special_tokens import *
import torch
import numpy as np
class SCGPTDataset(Dataset):
def __init__(self, data, tokenizer):
"""
Args:
data: [[da_str, response], [da_str, response], ...]
tokenizer: GPT2 Tokenizer
"""
self.data = []
length_list = []
for item in data:
da, response = item['dialogue_acts'], item['utterance']
da_tokens = tokenizer.encode(act2str(da))
response_tokens = tokenizer.encode(response)
length_list.append(len(da_tokens) + len(response_tokens) + 1)
self.data.append([da_tokens, response_tokens])
print(f'max: {np.max(length_list)}, min: {np.min(length_list)}, median: {np.quantile(length_list, 0.5)}, 0.99: {np.quantile(length_list, 0.99)}')
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
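# Illustration: dataset[idx] is a pair [da_token_ids, response_token_ids]; the
# [start_of_pred] insertion, truncation and padding happen later in the training
# script's pad_collate.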
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch.cuda.amp as amp
import transformers
from transformers import GPT2LMHeadModel
# reference: https://pytorch.org/docs/master/notes/amp_examples.html
class AmpGPT2LMHeadModel(GPT2LMHeadModel):
if TYPE_CHECKING:
# For IDE's code hinting
forward = GPT2LMHeadModel.forward
else:
def forward(self, *args, **kwargs):
with amp.autocast():
return super().forward(*args, **kwargs)
def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
if model.supports_gradient_checkpointing:
model.gradient_checkpointing_enable()
else:
warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
class AmpHelper:
"""
References:
https://pytorch.org/docs/master/notes/amp_examples.html
"""
def __init__(self, use_amp=True):
self.use_amp = use_amp
self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
self.scaler = amp.GradScaler()
def backward(self, loss):
if self.use_amp:
return self.scaler.scale(loss).backward()
else:
return loss.backward()
def step(self, optimizer):
if self.use_amp:
self.scaler.step(optimizer)
self.scaler.update()
else:
optimizer.step()
def might_unscale_(self, optimizer):
if self.use_amp:
# Unscales the gradients of optimizer's assigned params in-place
self.scaler.unscale_(optimizer)
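# Usage sketch (illustration) for a single training step; `compute_loss` is a
# hypothetical stand-in for the actual forward pass:
#   amp_helper = AmpHelper(use_amp=True)
#   with amp_helper.might_enable_autocast:
#       loss = compute_loss(model, batch)
#   amp_helper.backward(loss)
#   amp_helper.might_unscale_(optimizer)   # e.g. before gradient clipping
#   amp_helper.step(optimizer)
#   optimizer.zero_grad()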
import sys
sys.path.append('../../..')
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.nn.parallel import DistributedDataParallel as DDP
from convlab2.nlg.nlg import NLG
from util import act2str
from scgpt_special_tokens import *
special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
class SCGPT(NLG):
def __init__(self, dataset_name, model_path, device='cpu'):
super(SCGPT, self).__init__()
self.device = device
self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
'additional_special_tokens': special_tokens})
self.model.resize_token_embeddings(len(self.tokenizer))
self.model.load_state_dict(torch.load(model_path))
def generate(self, action):
action_str = act2str(action)
output = self._inference_batch([action_str])[0]
return output
def _inference_batch(self, sents):
with torch.no_grad():
sents = [sent + ' ' + START_OF_PRED for sent in sents]
sent_ids = [self.tokenizer.encode(sent) for sent in sents]
max_len = max([len(sent) for sent in sent_ids])
sent_ids = [sent + [self.tokenizer.pad_token_id] * (max_len - len(sent)) for sent in sent_ids]
inputs = torch.LongTensor(sent_ids).to(self.device)
model_to_run = self.model.module if type(self.model) is DDP else self.model
outputs = model_to_run.generate(inputs, max_length=256,
eos_token_id=self.tokenizer.pad_token_id,
pad_token_id=self.tokenizer.pad_token_id) # greedy
# outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
# pad_token_id=gpt2_tokenizer.pad_token_id) # beam search
output_strs = [self.tokenizer.decode(item) for item in outputs]
return output_strs
# separator
SYS_SPEAK = '[sys_speak]'
USR_SPEAK = '[usr_speak]'
START_OF_PRED = '[start_of_pred]'
END_OF_PRED = '[end_of_pred]'
PAD_TOKEN = '[_pad_token_]'
START_OF_INTENT = '[start_of_intent]'
END_OF_INTENT = '[end_of_intent]'
START_OF_SLOT = ''
SPECIAL_TOKENS = [val for name, val in globals().items() if
str.isupper(name) and isinstance(val, str) and val and val[0] == '[' and val[-1] == ']']
assert all(token.islower() for token in SPECIAL_TOKENS)
CUDA_VISIBLE_DEVICES="6" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py --dataset multiwoz21
CUDA_VISIBLE_DEVICES="1" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21
import torch
def act2str(act):
"""Convert unified dataset dialog act dict to string.
act:
{'categorical': [{'intent': 'inform', 'domain': 'restaurant', 'slot': 'area', 'value': 'north'}],
'non-categorical': [{'intent': 'inform', 'domain': 'hotel', 'slot': 'area', 'value': 'north'}],
'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'}]}
return:
restaurant { inform ( area = north ) } | hotel { inform ( area = north ) @ request ( area ) }
""" """
Created on Tue Mar 24 18:34:55 2020 old_format_dict = convert2old_format(act)
return dict2seq(old_format_dict)
@author: truthless def build_mask(max_len, seq_lens, use_float=False):
""" """
make one-hot masks given seq_lens list.
e.g., input: max_len=4, seq_lens=[2,3], return: [[1,1,0,0], [1,1,1,0]]
Args:
max_len (int): maximum sequence length
seq_lens (torch.Tensor): (batch)
Returns:
mask (torch.Tensor): (batch, max_len)
"""
a = torch.arange(max_len)[None, :]
b = seq_lens[:, None].cpu()
mask = a < b
if use_float:
mask = mask.float()
return mask
def tuple2dict(t):
'''
tuple: [(intent, domain, slot, value)]
dict: [domain: { intent: [slot, value] }]
'''
d = {}
for intent, domain, slot, value in t:
if domain not in d:
d[domain] = {}
if intent not in d[domain]:
d[domain][intent] = []
if slot == 'none' or slot is None:
continue
d[domain][intent].append([slot, value])
return d
def convert2old_format(act):
    """
    dict: {'categorical': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'}], 'non-categorical': [...], 'binary': [...]}
    """
    new_act = {}
    for key in act:
        for item_dic in act[key]:
            domain = item_dic['domain']
            if domain not in new_act:
                new_act[domain] = {}
            intent = item_dic['intent']
            if intent not in new_act[domain]:
                new_act[domain][intent] = []
            slot = item_dic['slot']
            if 'value' in item_dic:
                value = item_dic['value']
            else:
                value = None
            new_act[domain][intent].append([slot, value])
    return new_act
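# Illustration: {'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'}]}
# is converted to {'hotel': {'request': [['area', None]]}}.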
def dict2seq(d):
    '''
@@ -74,25 +88,7 @@ def dict2seq(d):
    s += ' }'
    return s.lower()
if __name__ == '__main__':
    ipt = {'categorical': [{'intent': 'inform', 'domain': 'restaurant', 'slot': 'area', 'value': 'north'}], 'non-categorical': [{'intent': 'inform', 'domain': 'hotel', 'slot': 'area', 'value': 'north'}], 'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'}]}
    print(act2str(ipt))