Commit af67d123 authored by Christian

Merge branch 'master' of https://github.com/ConvLab/ConvLab-3 into github_master

parents 8f4e7632 9577daba
include LICENSE.txt
include README.md
prune convlab2/*/__pycache__
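For context: the `include` rules above ship `LICENSE.txt` and `README.md` in source distributions, while `prune` drops the compiled `__pycache__` directories matching `convlab2/*/__pycache__` from the package.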
n_gpus=1
task_name="dst"
dataset_name=$1  # dataset name is passed as the first argument, e.g. multiwoz21
speaker="user"
context_window_size=100
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="dst_metric.py"
metric_for_best_model="accuracy"
source_column="context"
target_column="state_seq"
truncation_side="left"  # truncate from the left so the most recent context is kept
max_source_length=1024
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=10
# 1) serialize the dataset into (context, state sequence) pairs for DST
python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -l t5-small

# 2) fine-tune t5-small on the serialized data
python -m torch.distributed.launch \
    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
    --task_name ${task_name} \
    --train_file ${train_file} \
    --validation_file ${validation_file} \
    --test_file ${test_file} \
    --source_column ${source_column} \
    --target_column ${target_column} \
    --max_source_length ${max_source_length} \
    --max_target_length ${max_target_length} \
    --truncation_side ${truncation_side} \
    --model_name_or_path ${model_name_or_path} \
    --do_train \
    --do_eval \
    --do_predict \
    --save_strategy epoch \
    --evaluation_strategy epoch \
    --prediction_loss_only \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
    --logging_dir ${logging_dir} \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_train_batch_size ${per_device_train_batch_size} \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --learning_rate ${lr} \
    --num_train_epochs ${num_train_epochs} \
    --debug underflow_overflow \
    --adafactor \
    --gradient_checkpointing
# 3) decode the test split with the fine-tuned checkpoint to obtain generated_predictions.json
python -m torch.distributed.launch \
    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
    --task_name ${task_name} \
    --test_file ${test_file} \
    --source_column ${source_column} \
    --target_column ${target_column} \
    --max_source_length ${max_source_length} \
    --max_target_length ${max_target_length} \
    --truncation_side ${truncation_side} \
    --model_name_or_path ${output_dir} \
    --do_predict \
    --predict_with_generate \
    --metric_name_or_path ${metric_name_or_path} \
    --cache_dir ${cache_dir} \
    --output_dir ${output_dir} \
    --logging_dir ${logging_dir} \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size}
# 4) merge the generated predictions back into the unified data format and evaluate them
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
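For reference, a typical invocation of the script above (its file name is not shown in this view, so `run_dst.sh` here is illustrative) would be `bash run_dst.sh multiwoz21`: the single positional argument fills `dataset_name`, after which the script serializes the data, fine-tunes `t5-small`, re-runs inference with `--predict_with_generate`, and evaluates the merged predictions.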
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab2.nlg.nlg import NLG
from convlab2.base_models.t5.nlu.serialization import serialize_dialogue_acts
from convlab2.util.custom_util import model_downloader


class T5NLG(NLG):
    def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
        assert speaker in ['user', 'system']
        self.speaker = speaker
        self.opponent = 'system' if speaker == 'user' else 'user'
        self.context_window_size = context_window_size
        self.use_context = context_window_size > 0

        # download the model if the given path does not exist locally
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(model_name_or_path):
            model_downloader(model_dir, model_file)

        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)
        self.model.eval()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logging.info("T5NLG loaded")

    def generate(self, dialogue_acts, context=list()):
        if self.use_context:
            # a context turn may be a [speaker, utterance] pair; keep the utterance only
            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
                context = [item[1] for item in context]
            utts = context + ['']
        else:
            utts = ['']
        # prefix each turn with its speaker; the trailing empty turn belongs to self.speaker
        input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
        dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts)
        input_seq = dialogue_acts_seq + '\n' + input_seq
        # print(input_seq)
        input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
        # print(input_seq)
        output_seq = self.model.generate(**input_seq, max_length=256)
        # print(output_seq)
        output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
        # print(output_seq)
        return output_seq
if __name__ == '__main__':
    das = [
        {
            "categorical": [],
            "non-categorical": [],
            "binary": [
                {
                    "intent": "request",
                    "domain": "taxi",
                    "slot": "leave at"
                },
                {
                    "intent": "request",
                    "domain": "taxi",
                    "slot": "arrive by"
                }
            ]
        },
        {
            "categorical": [],
            "non-categorical": [
                {
                    "intent": "inform",
                    "domain": "taxi",
                    "slot": "type",
                    "value": "blue honda",
                    "start": 38,
                    "end": 48
                },
                {
                    "intent": "inform",
                    "domain": "taxi",
                    "slot": "phone",
                    "value": "07218068540",
                    "start": 67,
                    "end": 78
                }
            ],
            "binary": [
                {
                    "intent": "book",
                    "domain": "taxi",
                    "slot": ""
                }
            ]
        },
        {
            "categorical": [],
            "non-categorical": [],
            "binary": [
                {
                    "intent": "reqmore",
                    "domain": "general",
                    "slot": ""
                }
            ]
        },
        {
            "categorical": [],
            "non-categorical": [],
            "binary": [
                {
                    "intent": "bye",
                    "domain": "general",
                    "slot": ""
                }
            ]
        }
    ]
    contexts = [
        ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton."],
        ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
         "What time do you want to leave and what time do you want to arrive by?",
         "I want to leave after 17:15."],
        ["I want to leave after 17:15.",
         "Booking completed! your taxi will be blue honda Contact number is 07218068540",
         "Thank you for all the help! I appreciate it."],
        ["Thank you for all the help! I appreciate it.",
         "You are welcome. Is there anything else I can help you with today?",
         "No, I am all set. Have a nice day. Bye."],
    ]
    nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='output/nlg/multiwoz21/system/context_3')
    for da, context in zip(das, contexts):
        print(da)
        print(nlg.generate(da, context))
        print()
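The speaker-tagging expression in `generate` above is dense; the following standalone sketch (toy utterances, not from any dataset) demonstrates that the trailing empty slot is always tagged with the generating speaker, with earlier turns alternating backwards from it:

```python
# A minimal sketch of the speaker-tagging rule in T5NLG.generate (toy utterances).
speaker, opponent = 'system', 'user'
utts = ["I want a taxi.", "Where to?", "Pizza Hut Fen Ditton.", ""]  # context + empty generation slot
tagged = [f"{opponent if (i % 2) == (len(utts) % 2) else speaker}: {utt}"
          for i, utt in enumerate(utts)]
print('\n'.join(tagged))
# user: I want a taxi.
# system: Where to?
# user: Pizza Hut Fen Ditton.
# system:
```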
import logging
import os
import json
import torch
from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab2.nlu.nlu import NLU
from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts
@@ -16,7 +14,6 @@ class T5NLU(NLU):
        self.opponent = 'system' if speaker == 'user' else 'user'
        self.context_window_size = context_window_size
        self.use_context = context_window_size > 0
        self.prefix = "parse the dialogue action of the last utterance: "
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(model_name_or_path):
@@ -38,7 +35,7 @@ class T5NLU(NLU):
            utts = context + [utterance]
        else:
            utts = [utterance]
        input_seq = ' '.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
        input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
        # print(input_seq)
        input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
        # print(input_seq)
...
import sys
from nltk.translate.bleu_score import corpus_bleu
from nltk.tokenize import word_tokenize
sys.path.append('../..')
import json
from pprint import pprint
import sacrebleu
from evaluate_util import GentScorer
from convlab2.util.unified_datasets_util import load_ontology
import numpy as np

logger = None


def evaluate(predict_result):
    predict_result = json.load(open(predict_result))


class Logging:
    def __init__(self, path):
        file = open(path, 'w+')
        file.write('')
        file.close()
        self.path = path

    def log(self, sent):
        with open(self.path, 'a') as f:
            f.write(sent)
            f.write('\n')
            f.close()
def evaluate(predict_result, ontology):
    predict_result = json.load(open(predict_result))
    metrics = {}

    predictions, references = [], []
    for sample in predict_result:
        references.append(sample['utterance'])
        predictions.append(sample['predictions']['utterance'])
    metrics['bleu'] = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score

    # BLEU Score
    evaluator = GentScorer()
    references = []
    candidates = []
    for i in range(len(predict_result)):
        references.append([word_tokenize(predict_result[i]['utterance'])])
        candidates.append(word_tokenize(predict_result[i]['prediction']))
    metrics['bleu'] = corpus_bleu(references, candidates)

    # ERROR Rate
    ## collect all values in the ontology
    val2ds_dict = {}
    for domain_name in ontology['domains']:
        domain = ontology['domains'][domain_name]
        for slot_name in domain['slots']:
            slot = domain['slots'][slot_name]
            if 'possible_values' not in slot:
                continue
            possible_vals = slot['possible_values']
            if len(possible_vals) > 0:
                for val in possible_vals:
                    val2ds_dict[val] = f'{domain_name}-{slot_name}'

    score_list = []
    for item in predict_result:
        da = item['dialogue_acts']
        utterance = item['prediction']
        missing_count = 0
        redundant_count = 0
        all_count = 0
        all_values = set()
        ## missing values
        for key in da:
            slot_value = da[key]
            for triple in slot_value:
                if 'value' in triple:
                    value = triple['value']
                    all_values.add(value)
                    if value.strip().lower() not in utterance.lower():
                        missing_count += 1
                        # logger.log(f"missing: {triple['slot']}-{triple['value']} | {item['prediction']} | {item['utterance']}")
                    all_count += 1
        if all_count == 0:
            continue
        ## redundant values
        for val in val2ds_dict:
            if f' {val.strip().lower()} ' in f' {utterance.strip().lower()} ' and val.strip().lower() not in all_values:
                wlist = val2ds_dict[val].split('-')
                domain, slot = wlist[0], wlist[1]
                if f' {slot.strip().lower()}' in f' {utterance.strip().lower()} ':
                    redundant_count += 1
                    # logger.log(f"redundant: {val}/{val2ds_dict[val]} | {item['prediction']} | {item['utterance']}")
        item_score = float(missing_count + redundant_count) / all_count
        # logger.log(f"redundant: {redundant_count} | missing_count: {missing_count} | all_count: {all_count}")
        score_list.append(item_score)
    metrics['err'] = np.mean(score_list)
    return metrics
if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the prediction file that in the unified data format')
    parser = ArgumentParser(description="calculate NLG metrics for unified datasets")
    parser.add_argument('--predict_result', '-p', type=str, required=True,
                        help='path to the prediction file in the unified data format')
    parser.add_argument('--dataset_name', type=str, required=True,
                        help='the name of the dataset to be evaluated')
    args = parser.parse_args()
    print(args)
    metrics = evaluate(args.predict_result)
    ontology = load_ontology(args.dataset_name)
    # logger = Logging('./evaluate_unified_datasets.log')
    metrics = evaluate(args.predict_result, ontology)
    pprint(metrics)
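The `err` metric above averages `(missing_count + redundant_count) / all_count` over turns. A minimal toy computation (hypothetical utterance and values, missing-value half only) makes the per-turn score concrete:

```python
# Toy illustration of the per-turn slot-error score used above; values are invented.
utterance = "your taxi is a blue honda"
values = ["blue honda", "07218068540"]  # values the utterance should realise
missing = sum(1 for v in values if v.strip().lower() not in utterance.lower())
err = missing / len(values)  # redundant_count omitted in this sketch
print(err)  # 0.5: the phone number is missing from the utterance
```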
# Part of the evaluation script is adopted from https://github.com/pengbaolin/SC-GPT.
import os
import json
import sys
import math
import operator
import nltk
from collections import Counter
from nltk.util import ngrams

file = open  # alias kept from the original Python 2 code, where file() opened files
class ERRScorer():
    ## Scorer for calculating the slot errors
    ## it scores utterances one by one
    ## using two levels of matching:
    ## 1. exact match for categorical values
    ## 2. multiple keyword matching for binary values
    ## note: it cannot deal with dontcare and none values
    def __init__(self, detectfile):
        self.detectPairs = []
        fin = file(detectfile)
        self.detectPairs = json.load(fin)
        fin.close()

    def countSlots(self, dataset, reader):
        count = 0
        for t in dataset:
            feat = reader.formatter.format(t[0])[0]
            c = count
            for s, v in feat:
                # skip type token
                if s == 'type':
                    continue
                if v == '_' or v == 'yes' or v == 'none' or v == 'no':
                    count += 1
        return count
    def score(self, a, feat, gen):
        # total slots
        slot_count = 0
        # exact match for categorical slots
        caty_slot_error = 0
        # for each slot-token pair in the detect pair dict
        for s, tok in self.detectPairs['general'].items():
            # tokens to compare to
            comparetos = ['sv.' + s + '._1', 'sv.' + s + '._2', 'sv.' + s + '._3']
            # count feature occurrence in the da feature
            fcnt = 0
            for f in feat:
                for compareto in comparetos:
                    if compareto == f:
                        fcnt += 1
            # count generated semantic tokens
            gcnt = gen.split().count(tok)
            # count the slot difference
            caty_slot_error += abs(fcnt - gcnt)
            # accumulate slot count
            slot_count += fcnt

        # keyword match for binary slots, only an approximation
        bnay_slot_error = 0
        # for each binary slot
        for s, toks in self.detectPairs['binary'].items():
            # tokens to compare to
            comparetos = ['sv.' + s + '.yes', 'sv.' + s + '.no',
                          'sv.' + s + '.dontcare', 'sv.' + s + '.none']
            # count feature occurrence in the da
            fcnt = 0
            for f in feat:
                for compareto in comparetos:
                    if compareto == f:
                        fcnt += 1
            # count generated semantic tokens
            gcnt = sum([gen.split().count(tok) for tok in toks])
            # count the slot difference
            bnay_slot_error += abs(fcnt - gcnt)
            # accumulate slot count
            slot_count += fcnt

        # total slot error
        total_slot_error = caty_slot_error + bnay_slot_error
        # for ?select/suggest acts, only consider categorical errors
        if a == [4] or a == [14]:
            return 0.0, 0.0, 0.0
        else:
            return slot_count, total_slot_error, caty_slot_error
class BLEUScorer(object):
    ## BLEU score calculator via the GentScorer interface
    ## it calculates corpus-level BLEU-4 over the entire corpus,
    ## scoring multiple candidates against multiple references
    def __init__(self):
        pass

    def score(self, parallel_corpus):
        # containers and parameters
        r, c = 0, 0
        count = [0, 0, 0, 0]
        clip_count = [0, 0, 0, 0]
        weights = [0.25, 0.25, 0.25, 0.25]

        # accumulate ngram statistics
        for hyps, refs in parallel_corpus:
            hyps = [hyp.lower().split() for hyp in hyps]
            refs = [ref.lower().split() for ref in refs]
            # compute ngram counts by matching each hypothesis
            for hyp in hyps:
                # for each ngram order
                for i in range(4):
                    # accumulate hyp ngram counts
                    hypcnts = Counter(ngrams(hyp, i + 1))
                    cnt = sum(hypcnts.values())
                    count[i] += cnt

                    # compute clipped counts
                    max_counts = {}
                    # compare to each reference
                    for ref in refs:
                        # get reference ngrams
                        refcnts = Counter(ngrams(ref, i + 1))
                        # for each ngram
                        for ng in hypcnts:
                            # clipped counts
                            max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
                    # clip the hyp count if necessary
                    clipcnt = dict((ng, min(count, max_counts[ng]))
                                   for ng, count in hypcnts.items())
                    clip_count[i] += sum(clipcnt.values())

                # accumulate r & c, find the best length match among all references
                bestmatch = [1000, 1000]
                for ref in refs:
                    if bestmatch[0] == 0:
                        break
                    # length difference
                    diff = abs(len(ref) - len(hyp))
                    # if the current diff is less than the stored one, replace it
                    if diff < bestmatch[0]:
                        bestmatch[0] = diff
                        bestmatch[1] = len(ref)
                # extract the best length match in references
                r += bestmatch[1]
                c += len(hyp)

        # compute the bleu score
        # for numerical stability
        p0 = 1e-7
        # brevity penalty
        bp = 1 if c > r else math.exp(1 - float(r) / float(c))
        # modified precisions
        p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0
                for i in range(4)]
        # weighted log precisions
        s = math.fsum(w * math.log(p_n)
                      for w, p_n in zip(weights, p_ns) if p_n)
        # final bleu score
        bleu = bp * math.exp(s)
        return bleu
    def sentence_bleu_4(self, parallel_corpus):
        # input: single sentence, multiple references
        count = [0, 0, 0, 0]
        clip_count = [0, 0, 0, 0]
        weights = [0.25, 0.25, 0.25, 0.25]
        r = 0
        c = 0

        # accumulate ngram statistics
        for hyps, refs in parallel_corpus:
            hyps = [hyp.split() for hyp in hyps]
            refs = [ref.split() for ref in refs]
            # compute ngram counts by matching each hypothesis
            for hyp in hyps:
                # for each ngram order
                for i in range(4):
                    # accumulate hyp ngram counts
                    hypcnts = Counter(ngrams(hyp, i + 1))
                    cnt = sum(hypcnts.values())
                    count[i] += cnt

                    # compute clipped counts
                    max_counts = {}
                    # compare to each reference
                    for ref in refs:
                        # get reference ngrams
                        refcnts = Counter(ngrams(ref, i + 1))
                        # for each ngram
                        for ng in hypcnts:
                            # clipped counts
                            max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
                    # clip the hyp count if necessary
                    clipcnt = dict((ng, min(count, max_counts[ng]))
                                   for ng, count in hypcnts.items())
                    clip_count[i] += sum(clipcnt.values())

                # accumulate r & c, find the best length match among all references
                bestmatch = [1000, 1000]
                for ref in refs:
                    if bestmatch[0] == 0:
                        break
                    # length difference
                    diff = abs(len(ref) - len(hyp))
                    # if the current diff is less than the stored one, replace it
                    if diff < bestmatch[0]:
                        bestmatch[0] = diff
                        bestmatch[1] = len(ref)
                # extract the best length match in references
                r += bestmatch[1]
                c += len(hyp)

        # for numerical stability
        p0 = 1e-7
        # modified brevity penalty
        bp = math.exp(-abs(1.0 - float(r) / float(c + p0)))
        # smoothed version of modified precisions
        p_ns = [0, 0, 0, 0]
        for i in range(4):
            if i < 2:  # original n-gram counts for unigrams and bigrams
                p_ns[i] = float(clip_count[i]) / float(count[i] + p0) + p0
            else:  # smoothed counts for higher-order ngrams
                smooth_term = 5 * p_ns[i - 1] * p_ns[i - 1] / p_ns[i - 2]
                p_ns[i] = float(clip_count[i] + smooth_term) / float(count[i] + 5) + p0
        # weighted log precisions
        s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
        # final sentence bleu score
        bleu_hyp = bp * math.exp(s)
        return bleu_hyp
class GentScorer(object):
    ## main scorer interface for all scorers
    ## it can:
    ## 1. compute the BLEU score
    ## 2. compute the slot error rate
    ## 3. give a detailed illustration of how different splits
    ##    of data affect performance
    def __init__(self):
        self.bleuscorer = BLEUScorer()

    def scoreERR(self, parallel_pairs):
        """input: [[dialogue_act, utterance], [dialogue_act, utterance], ...]"""

    def scoreBLEU(self, parallel_corpus):
        return self.bleuscorer.score(parallel_corpus)

    def scoreSBLEU(self, parallel_corpus):
        return self.bleuscorer.sentence_bleu_4(parallel_corpus)
\ No newline at end of file
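For reference, a minimal usage sketch of the scorer above (toy strings; it assumes this file is importable as `evaluate_util`, as the NLG evaluation script imports it):

```python
# Toy usage of GentScorer's corpus-level BLEU interface; the strings are invented.
from evaluate_util import GentScorer

scorer = GentScorer()
# each entry pairs a list of hypotheses with a list of references
parallel_corpus = [
    (["your taxi is a blue honda"],
     ["your taxi is a blue honda , contact number 07218068540"]),
]
print(scorer.scoreBLEU(parallel_corpus))
```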
@@ -2,6 +2,12 @@
This is the implementation of [SC-GPT](https://aclanthology.org/2020.findings-emnlp.17), which was proposed by
Peng et al., 2020.

You should first download and unzip the SC-GPT checkpoint

```bash
wget https://bapengstorage.blob.core.windows.net/fileshare/scgpt.tar.gz
tar -xvf scgpt.tar.gz
```

and, if you want to use this checkpoint, you have to specify its path through the ``--scgpt_model_ckpt_path`` parameter in ``train.sh`` and ``test.sh``.
## Train
@@ -12,9 +18,9 @@ When using the training code, you may have to adjust the parameters
according to your machine configuration. Note that the training code
only supports GPU training.
## Test
## Evaluation

```bash
./test.sh
./evaluate.sh
```

The test code also only supports GPU mode. We report the BLEU score
and ERR score following the original SC-GPT paper (Peng et al., 2020).
...
CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py \
--dataset multiwoz21 \
--scgpt_model_ckpt_path /data/zhangzheng/scgpt \
--model_path /data/zhangzheng/ConvLab-3/convlab2/nlg/scgpt/saved_model/epoch_4/epoch_4_step8875.pt
\ No newline at end of file
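A note on the launch commands in these scripts: `python -m torch.distributed.launch --nproc_per_node 1` starts one distributed worker on the GPU selected via `CUDA_VISIBLE_DEVICES`; on recent PyTorch releases this launcher is deprecated in favour of `torchrun`, so a deprecation warning here is expected and harmless. The absolute paths (`/data/zhangzheng/...`) are machine-specific and must be adapted to your environment.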
@@ -2,6 +2,7 @@ import sys
sys.path.append('../../..')
import argparse
import json
from tqdm import tqdm
import torch
import numpy as np
@@ -9,7 +10,6 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import os
from transformers import get_linear_schedule_with_warmup
@@ -32,7 +32,8 @@ parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
parser.add_argument('--model_path', default="", type=str, help="The path of model for testing.")
parser.add_argument("--max_seq_len", default=256, type=int)
parser.add_argument('--scgpt_model_ckpt_path', default="", type=str, help="The path of model for testing.")
parser.add_argument("--max_seq_len", default=128, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank
@@ -40,14 +41,25 @@ torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

# TensorBoard
tb_writer = SummaryWriter()
tb_dir = './runs'
if not os.path.exists(tb_dir):
    os.mkdir(tb_dir)
tb_writer = SummaryWriter(tb_dir)

special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]

## load model
if FLAGS.scgpt_model_ckpt_path == '':
    tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
    model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
    model.resize_token_embeddings(len(tokenizer))
else:
    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.scgpt_model_ckpt_path)
    tokenizer.add_special_tokens(
        {'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
    model = GPT2LMHeadModel.from_pretrained(FLAGS.scgpt_model_ckpt_path).to(local_rank)
    print('model load from ' + FLAGS.scgpt_model_ckpt_path)
    model.resize_token_embeddings(len(tokenizer))

nll_loss = nn.NLLLoss(reduce=False).to(local_rank)  # note: reduce=False is deprecated; reduction='none' is the modern equivalent
ce_loss = nn.CrossEntropyLoss(reduce=False).to(local_rank)
@@ -80,6 +92,7 @@ def pad_collate(batch):
    batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
    batch = [item[-FLAGS.max_seq_len:] for item in batch]
    max_len = max([len(item) for item in batch])
    # print('max_len', max_len)
    seq_lens = [len(item) for item in batch]
    split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)

    def get_x_len(tokens):
@@ -92,12 +105,15 @@ def pad_collate(batch):
        return split_idx

    seq_lens_input = [get_x_len(item) for item in batch]
    batch = [item + [pad_token_id]*(max_len-len(item)) for item in batch]
    # print(batch)
    # print(seq_lens)
    # print(seq_lens_input)
    return torch.LongTensor(batch), torch.LongTensor(seq_lens), torch.LongTensor(seq_lens_input)
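# Toy illustration of the collate scheme above (invented token ids): with
# START_OF_PRED_ID=100, pad_token_id=0 and max_seq_len=8,
#   [1, 2, 3] + [100] + [7, 8] -> [1, 2, 3, 100, 7, 8]
#   [4, 5]    + [100] + [9]    -> [4, 5, 100, 9, 0, 0]  (right-padded to the batch max)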
## Training Hyper-params
EPOCH_NUM = 20
BATCH_SIZE = 20  # real_batch_size = BATCH_SIZE * num_gpu
VAL_STEP = 300
BATCH_SIZE = 32  # real_batch_size = BATCH_SIZE * num_gpu
VAL_STEP = 500
WARM_STEPS = 250
if code_test:
    EPOCH_NUM = 2
@@ -105,7 +121,6 @@ if code_test:
    VAL_STEP = 2
    WARM_STEPS = 3
LR = 5e-5
TASK_TYPE = 'nlu'  # nlu or dst
SAVE_PATH = f'./saved_model'
def train(model, nlg_data, global_step=0):
    train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
@@ -140,16 +155,19 @@ def train(model, nlg_data, global_step=0):
            tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
            tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)

            if batch_id % VAL_STEP == 0:
                global_step += 1
                # save the model when each epoch ends
                if dist.get_rank() == 0:
                    # validation
                    model.eval()
                    val_loss = eval(model, val_dataloader)
                    ppl = np.exp(val_loss)
                    tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
                    tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
                    model.train()
            global_step += 1
        # save the model when each epoch ends
        if dist.get_rank() == 0:
            # save model
            save_dir = os.path.join(SAVE_PATH, f'epoch_{epoch}')
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
@@ -157,6 +175,7 @@ def train(model, nlg_data, global_step=0):
            torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
            torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
            print(f'Save model checkpoint to [{save_dir}]')

    tb_writer.flush()
@@ -212,20 +231,31 @@ def test(model, nlg_data, ontology, model_path):
    model.load_state_dict(torch.load(model_path))
    model.eval()
    print(f'model loaded from [{model_path}]')
    # sample_file = os.path.join(f'../../../data/dstc2/sample50_{TASK_TYPE}_input_data.txt')
    # Load test nlg data
    test_data = nlg_data['test']
    dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
    golden_responses = [item['utterance'] for item in test_data]
    dialog_acts = [act2str(item['dialogue_acts']).strip() for item in test_data]
    golden_responses = [item['utterance'].strip() for item in test_data]
    # dialog_acts = dialog_acts[:10]
    # golden_responses = golden_responses[:10]
    outputs = inference_sents(model, dialog_acts)

    def get_real_output(ipt):
        # strip the prompt prefix and trailing padding from a decoded sequence
        if '[start_of_pred]' in ipt:
            ipt = ipt[ipt.index('[start_of_pred]') + 15:].strip()
        if '[_pad_token_]' in ipt:
            ipt = ipt[:ipt.index('[_pad_token_]')].strip()
        return ipt

    outputs = [get_real_output(item) for item in outputs]

    output_file = './test_output.json'
    if dist.get_rank() == 0:
        output_file = './test_output.txt'
        with open(output_file, 'w+') as f:
            result = []
            for i in range(len(dialog_acts)):
                f.write(f'{dialog_acts[i]}\n{golden_responses[i]}\n{outputs[i]}\n\n')
                f.close()
                result.append({
                    'dialogue_acts': test_data[i]['dialogue_acts'],
                    'utterance': test_data[i]['utterance'],
                    'prediction': outputs[i]
                })
            json.dump(result, f, indent=2, ensure_ascii=False)
    evaluator = GentScorer()
    parallel_corpus = []
    # BLEU
@@ -273,6 +303,9 @@ def test(model, nlg_data, ontology, model_path):
        score_list.append(item_score)
    ERR_Score = np.mean(score_list)
    print(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
    # with open(output_file, 'a') as f:
    #     f.write(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
    #     f.close()
if __name__ == '__main__':
...
CUDA_VISIBLE_DEVICES="6" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py --dataset multiwoz21
\ No newline at end of file
CUDA_VISIBLE_DEVICES="1" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21
\ No newline at end of file
CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21 --scgpt_model_ckpt_path /data/zhangzheng/scgpt
\ No newline at end of file