diff --git a/convlab2/nlg/scgpt/evaluate.py b/convlab2/nlg/scgpt/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..c54b079f3579c6e29e5dee7f58edbe37f8a2ee27
--- /dev/null
+++ b/convlab2/nlg/scgpt/evaluate.py
@@ -0,0 +1,260 @@
+# Part of the evaluation script is adapted from https://github.com/pengbaolin/SC-GPT.
+import os
+import json
+import sys
+import math
+import operator
+import nltk
+from collections import Counter
+from nltk.util import ngrams
+
+
+class ERRScorer():
+
+    ## Scorer for calculating the slot errors
+    ## it scores utterances one by one
+    ## using two levels of matching
+    ## 1. exact match for categorical values
+    ## 2. multiple keyword matching for binary values
+    ## 3. cannot deal with don't care and none values
+    def __init__(self, detectfile):
+        self.detectPairs = []
+        fin = open(detectfile)
+        self.detectPairs = json.load(fin)
+        fin.close()
+
+    def countSlots(self, dataset, reader):
+        count = 0
+        for t in dataset:
+            feat = reader.formatter.format(t[0])[0]
+            for s, v in feat:
+                # skip the type token
+                if s == 'type':
+                    continue
+                if v == '_' or v == 'yes' or v == 'none' or v == 'no':
+                    count += 1
+        return count
+
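+    ## score() compares the slot-value tokens required by a dialogue-act
+    ## feature (feat) against the semantic tokens found in the generated
+    ## surface form (gen). The detect-pair file is assumed to be a JSON dict
+    ## with a 'general' section ({slot: token}) and a 'binary' section
+    ## ({slot: [keywords]}), which is how self.detectPairs is consumed below.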
+    def score(self, a, feat, gen):
+        # total slots
+        slot_count = 0
+        # exact match for categorical slots
+        caty_slot_error = 0
+        # for each slot-token pair in the detect-pair dict
+        for s, tok in self.detectPairs['general'].items():
+            # tokens to compare to
+            comparetos = ['sv.' + s + '._1', 'sv.' + s + '._2', 'sv.' + s + '._3']
+            # count feature occurrences in the da feature
+            fcnt = 0
+            for f in feat:
+                for compareto in comparetos:
+                    if compareto == f: fcnt += 1
+            # count generated semantic tokens
+            gcnt = gen.split().count(tok)
+            # count the slot difference
+            caty_slot_error += abs(fcnt - gcnt)
+            # accumulate slot count
+            slot_count += fcnt
+
+        # keyword match for binary slots, only an approximation
+        bnay_slot_error = 0
+        # for each binary slot
+        for s, toks in self.detectPairs['binary'].items():
+            # tokens to compare to
+            comparetos = ['sv.' + s + '.yes', 'sv.' + s + '.no',
+                          'sv.' + s + '.dontcare', 'sv.' + s + '.none']
+            # count feature occurrences in the da
+            fcnt = 0
+            for f in feat:
+                for compareto in comparetos:
+                    if compareto == f: fcnt += 1
+            # count generated semantic tokens
+            gcnt = sum([gen.split().count(tok) for tok in toks])
+            # count the slot difference
+            bnay_slot_error += abs(fcnt - gcnt)
+            # accumulate slot count
+            slot_count += fcnt
+        # total slot error
+        total_slot_error = caty_slot_error + bnay_slot_error
+        # for ?select/suggest acts, only consider categorical errors
+        if a == [4] or a == [14]:
+            return 0.0, 0.0, 0.0
+        else:
+            return slot_count, total_slot_error, caty_slot_error
+
+
+class BLEUScorer(object):
+    ## BLEU score calculator via the GentScorer interface
+    ## it calculates BLEU-4 by taking the entire corpus in
+    ## Calculated over multiple candidates against multiple references
+    def __init__(self):
+        pass
+
+    def score(self, parallel_corpus):
+        # containers and parameters
+        r, c = 0, 0
+        count = [0, 0, 0, 0]
+        clip_count = [0, 0, 0, 0]
+        weights = [0.25, 0.25, 0.25, 0.25]
+
+        # accumulate ngram statistics
+        for hyps, refs in parallel_corpus:
+            hyps = [hyp.lower().split() for hyp in hyps]
+            refs = [ref.lower().split() for ref in refs]
+            # compute ngram counts by matching each hypothesis
+            for hyp in hyps:
+                # for each ngram order
+                for i in range(4):
+                    # accumulate hyp ngram counts
+                    hypcnts = Counter(ngrams(hyp, i + 1))
+                    cnt = sum(hypcnts.values())
+                    count[i] += cnt
+
+                    # compute clipped counts
+                    max_counts = {}
+                    # compare to each reference
+                    for ref in refs:
+                        # get reference ngrams
+                        refcnts = Counter(ngrams(ref, i + 1))
+                        # for each ngram
+                        for ng in hypcnts:
+                            # clipped counts
+                            max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
+                    # clip the hyp count where necessary
+                    clipcnt = dict((ng, min(count, max_counts[ng])) \
+                                   for ng, count in hypcnts.items())
+                    clip_count[i] += sum(clipcnt.values())
+
+                # accumulate r & c, find the best match among all references
+                bestmatch = [1000, 1000]
+                for ref in refs:
+                    if bestmatch[0] == 0: break
+                    # length difference
+                    diff = abs(len(ref) - len(hyp))
+                    # if the current diff is smaller than the stored one, keep it
+                    if diff < bestmatch[0]:
+                        bestmatch[0] = diff
+                        bestmatch[1] = len(ref)
+                # extract the best length match among references
+                r += bestmatch[1]
+                c += len(hyp)
+
+        # compute the bleu score
+        # for numerical stability
+        p0 = 1e-7
+        # brevity penalty
+        bp = 1 if c > r else math.exp(1 - float(r) / float(c))
+        # modified precisions
+        p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \
+                for i in range(4)]
+        # weighted precision
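+        # Standard corpus-level BLEU-4:
+        #     BLEU = BP * exp( sum_n w_n * log(p_n) ),  w_n = 0.25, n = 1..4
+        # where p_n are the clipped n-gram precisions computed above and BP is
+        # the brevity penalty.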
+        s = math.fsum(w * math.log(p_n) \
+                      for w, p_n in zip(weights, p_ns) if p_n)
+        # final bleu score
+        bleu = bp * math.exp(s)
+        return bleu
+
+    def sentence_bleu_4(self, parallel_corpus):
+        # input: a single sentence, multiple references
+        count = [0, 0, 0, 0]
+        clip_count = [0, 0, 0, 0]
+        weights = [0.25, 0.25, 0.25, 0.25]
+        r = 0
+        c = 0
+
+        # accumulate ngram statistics
+        for hyps, refs in parallel_corpus:
+            hyps = [hyp.split() for hyp in hyps]
+            refs = [ref.split() for ref in refs]
+            # compute ngram counts by matching each hypothesis
+            for hyp in hyps:
+                # for each ngram order
+                for i in range(4):
+                    # accumulate hyp ngram counts
+                    hypcnts = Counter(ngrams(hyp, i + 1))
+                    cnt = sum(hypcnts.values())
+                    count[i] += cnt
+
+                    # compute clipped counts
+                    max_counts = {}
+                    # compare to each reference
+                    for ref in refs:
+                        # get reference ngrams
+                        refcnts = Counter(ngrams(ref, i + 1))
+                        # for each ngram
+                        for ng in hypcnts:
+                            # clipped counts
+                            max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
+                    # clip the hyp count where necessary
+                    clipcnt = dict((ng, min(count, max_counts[ng])) \
+                                   for ng, count in hypcnts.items())
+                    clip_count[i] += sum(clipcnt.values())
+
+                # accumulate r & c, find the best match among all references
+                bestmatch = [1000, 1000]
+                for ref in refs:
+                    if bestmatch[0] == 0: break
+                    # length difference
+                    diff = abs(len(ref) - len(hyp))
+                    # if the current diff is smaller than the stored one, keep it
+                    if diff < bestmatch[0]:
+                        bestmatch[0] = diff
+                        bestmatch[1] = len(ref)
+                # extract the best length match among references
+                r += bestmatch[1]
+                c += len(hyp)
+
+        # for numerical stability
+        p0 = 1e-7
+        # modified brevity penalty
+        bp = math.exp(-abs(1.0 - float(r) / float(c + p0)))
+        # smoothed version of the modified precisions
+        p_ns = [0, 0, 0, 0]
+        for i in range(4):
+            if i < 2:  # original n-gram counts
+                p_ns[i] = float(clip_count[i]) / float(count[i] + p0) + p0
+            else:      # smoothed n-gram counts
+                smooth_term = 5 * p_ns[i - 1] * p_ns[i - 1] / p_ns[i - 2]
+                p_ns[i] = float(clip_count[i] + smooth_term) / float(count[i] + 5) + p0
+        # weighted precision
+        s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
+        # final sentence bleu score
+        bleu_hyp = bp * math.exp(s)
+        return bleu_hyp
+
+
+class GentScorer(object):
+    ## main Scorer interface for all scorers
+    ## it can
+    ## 1. Compute the bleu score
+    ## 2. Compute the slot error rate
+    ## 3. Give a detailed illustration of how different data splits
+    ##    affect performance
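+    ## NOTE: only the BLEU scorer is wired up here; in this repository the
+    ## slot error rate is computed directly in main.py from the dataset
+    ## ontology rather than through ERRScorer.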
+    def __init__(self, detectfile=None):
+        self.bleuscorer = BLEUScorer()
+
+    def scoreERR(self, parallel_pairs):
+        """input: [[dialogue_act, utterance], [dialogue_act, utterance], ...]"""
+        pass
+
+    def scoreBLEU(self, parallel_corpus):
+        return self.bleuscorer.score(parallel_corpus)
+
+    def scoreSBLEU(self, parallel_corpus):
+        return self.bleuscorer.sentence_bleu_4(parallel_corpus)
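+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only; the two sentences below are
+    # made-up examples, not taken from any dataset).
+    scorer = GentScorer()
+    hyp = 'there is an expensive restaurant in the centre of town'
+    ref = 'an expensive restaurant can be found in the town centre'
+    corpus = [([hyp], [ref])]
+    print('corpus BLEU   :', scorer.scoreBLEU(corpus))
+    print('sentence BLEU4:', scorer.scoreSBLEU(corpus))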
diff --git a/convlab2/nlg/scgpt/main.py b/convlab2/nlg/scgpt/main.py
index a15f76129438bf2ec280ceadacb7d3e69c092f60..a48bf7b47392ad15c3826d561e88c9242a422f53 100644
--- a/convlab2/nlg/scgpt/main.py
+++ b/convlab2/nlg/scgpt/main.py
@@ -14,9 +14,10 @@ from torch.utils.tensorboard import SummaryWriter
 import os
 from transformers import get_linear_schedule_with_warmup
 
-from convlab2.util.unified_datasets_util import load_dataset, load_nlg_data
+from convlab2.util.unified_datasets_util import load_dataset, load_nlg_data, load_ontology
 from convlab2.nlg.scgpt.util import act2str
 from convlab2.nlg.scgpt.model import SCGPTDataset
+from evaluate import GentScorer
 
 # distributed training
 import torch.distributed as dist
@@ -34,6 +35,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--local_rank", default=-1, type=int)
 parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
 parser.add_argument('--dataset', default="multiwoz21", type=str, help="Whether to run training.")
+parser.add_argument("--max_seq_len", default=256, type=int)
 FLAGS = parser.parse_args()
 local_rank = FLAGS.local_rank
@@ -80,7 +82,7 @@ def pad_collate(batch):
     START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
     pad_token_id = tokenizer.pad_token_id
     batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
-    batch = [item[-512:] for item in batch]  # TF限制输入长度
+    batch = [item[-FLAGS.max_seq_len:] for item in batch]  # truncate to the transformer's maximum input length
     max_len = max([len(item) for item in batch])
     seq_lens = [len(item) for item in batch]
     split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
@@ -98,8 +100,8 @@
 
 ## Training Hyper-params
 EPOCH_NUM = 20
-BATCH_SIZE = 10  # real_batch_size = BATCH_SIZE * num_gpu
-VAL_STEP = 30
+BATCH_SIZE = 20  # real_batch_size = BATCH_SIZE * num_gpu
+VAL_STEP = 300
 WARM_STEPS = 250
 if code_test:
     EPOCH_NUM = 2
@@ -187,7 +189,7 @@ def inference_batch(model, sents):
     sent_ids = [sent + [tokenizer.pad_token_id]*(max_len-len(sent)) for sent in sent_ids]
     inputs = torch.LongTensor(sent_ids).to(local_rank)
     model_to_run = model.module if type(model) is DDP else model
-    outputs = model_to_run.generate(inputs, max_length=513, eos_token_id=tokenizer.eos_token_id,
+    outputs = model_to_run.generate(inputs, max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.pad_token_id,
                                     pad_token_id=tokenizer.pad_token_id)  # greedy
     # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
     #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
@@ -209,31 +211,73 @@ def inference_sents(model, sents):
     return outputs
 
 
-def test(model, nlg_data, model_path):
+def test(model, nlg_data, ontology, model_path):
     """Run this with the number of GPUs in the launch script set to 1."""
     model.load_state_dict(torch.load(model_path))
+    model.eval()
     print(f'model loaded from [{model_path}]')
     # sample_file = os.path.join(f'../../../data/dstc2/sample50_{TASK_TYPE}_input_data.txt')
     # Load test nlg data
     test_data = nlg_data['test']
     dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
     golden_responses = [item['utterance'] for item in test_data]
-    outputs = inference_sents(model, dialog_acts, use_tqdm=True)
+    outputs = inference_sents(model, dialog_acts)
     if dist.get_rank() == 0:
         output_file = './test_output.txt'
         with open(output_file, 'w+') as f:
             for i in range(len(dialog_acts)):
                 f.write(f'{dialog_acts[i]}\n{golden_responses[i]}\n{outputs[i]}\n\n')
             f.close()
+        evaluator = GentScorer()
+        # BLEU: generated responses are the hypotheses, golden responses the references
+        parallel_corpus = []
+        for i in range(len(dialog_acts)):
+            parallel_corpus.append([[outputs[i]], [golden_responses[i]]])
+        BLEU_Score = evaluator.scoreSBLEU(parallel_corpus)
+        # ERR
+        ## all values in the ontology
+        val2ds_dict = {}
+        for domain_name in ontology['domains']:
+            domain = ontology['domains'][domain_name]
+            for slot_name in domain['slots']:
+                slot = domain['slots'][slot_name]
+                possible_vals = slot['possible_values']
+                if len(possible_vals) > 0:
+                    for val in possible_vals:
+                        val2ds_dict[val] = f'{domain_name}-{slot_name}'
+        ## missing values
+        score_list = []
+        for item in test_data:
+            da = item['dialogue_acts']
+            utterance = item['utterance']
+            missing_count = 0
+            redundant_count = 0
+            all_count = 0
+            all_values = set()
+            for key in da:
+                slot_value = da[key]
+                for triple in slot_value:
+                    if 'value' in triple:
+                        value = triple['value']
+                        all_values.add(value.strip().lower())
+                        if value.strip().lower() not in utterance.lower():
+                            missing_count += 1
+                        all_count += 1
+            ## redundant values
+            for val in val2ds_dict:
+                if f' {val.strip().lower()} ' in f' {utterance.strip().lower()} ' and val.strip().lower() not in all_values:
+                    redundant_count += 1
+            if all_count == 0:
+                continue
+            item_score = float(missing_count + redundant_count) / all_count
+            score_list.append(item_score)
+        ERR_Score = sum(score_list) / len(score_list)
+        print(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
 
 
 if __name__ == '__main__':
     dataset = load_dataset(FLAGS.dataset)
+    ontology = load_ontology(FLAGS.dataset)
     nlg_data = load_nlg_data(dataset)
     if FLAGS.do_train:
         train(model, nlg_data)
     else:
-        test(model, nlg_data, 'saved_model/{TASK_TYPE}/19_save/19_step5840.pt')
-        # test_samples(f'saved_model/{TASK_TYPE}/19_save/19_step5840.pt')
-    # elif FLAGS.show_attn:
-    #     show_attention(f'saved_model/{TASK_TYPE}/19_save/19_step5840.pt')
\ No newline at end of file
+        test(model, nlg_data, ontology, './saved_model/epoch_0/epoch_0_step2839.pt')
diff --git a/convlab2/nlg/scgpt/model.py b/convlab2/nlg/scgpt/model.py
index 82df1464355327cc6bba69d35066fe99912ce12c..9a41cfa53e92938d781abc27bc7082c3948b7a0d 100644
--- a/convlab2/nlg/scgpt/model.py
+++ b/convlab2/nlg/scgpt/model.py
@@ -2,6 +2,7 @@ from torch.utils.data import Dataset
 from util import act2str
 from scgpt_special_tokens import *
 import torch
+import numpy as np
 
 class SCGPTDataset(Dataset):
     def __init__(self, data, tokenizer):
@@ -11,11 +12,14 @@
         tokenizer: GPT2 Tokenizer
         """
         self.data = []
+        length_list = []
        for item in data:
             da, response = item['dialogue_acts'], item['utterance']
             da_tokens = tokenizer.encode(act2str(da))
             response_tokens = tokenizer.encode(response)
+            length_list.append(len(da_tokens) + len(response_tokens) + 1)
             self.data.append([da_tokens, response_tokens])
+        print(f'max: {np.max(length_list)}, min: {np.min(length_list)}, median: {np.quantile(length_list, 0.5)}, 0.99: {np.quantile(length_list, 0.99)}')
 
     def __len__(self):
         return len(self.data)
diff --git a/convlab2/nlg/scgpt/scgpt.py b/convlab2/nlg/scgpt/scgpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d184abff9c8d7179a907c3084bb2f22b5760dc1e
--- /dev/null
+++ b/convlab2/nlg/scgpt/scgpt.py
@@ -0,0 +1,45 @@
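+# Deployment wrapper around a fine-tuned SC-GPT checkpoint: implements the
+# ConvLab-2 NLG interface, mapping a dialogue act to a natural-language
+# response via greedy decoding (see _inference_batch below).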
+import sys
+sys.path.append('../../..')
+
+import torch
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from convlab2.nlg.nlg import NLG
+from util import act2str
+from scgpt_special_tokens import *
+
+special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
+
+
+class SCGPT(NLG):
+    def __init__(self, dataset_name, model_path, device='cpu'):
+        super(SCGPT, self).__init__()
+        self.device = device
+        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
+                                           'additional_special_tokens': special_tokens})
+        self.model.resize_token_embeddings(len(self.tokenizer))
+        self.model.load_state_dict(torch.load(model_path))
+
+    def generate(self, action):
+        action_str = act2str(action)
+        output = self._inference_batch([action_str])[0]
+        return output
+
+    def _inference_batch(self, sents):
+        with torch.no_grad():
+            sents = [sent + ' ' + START_OF_PRED for sent in sents]
+            sent_ids = [self.tokenizer.encode(sent) for sent in sents]
+            max_len = max([len(sent) for sent in sent_ids])
+            sent_ids = [sent + [self.tokenizer.pad_token_id] * (max_len - len(sent)) for sent in sent_ids]
+            inputs = torch.LongTensor(sent_ids).to(self.device)
+            model_to_run = self.model.module if type(self.model) is DDP else self.model
+            outputs = model_to_run.generate(inputs, max_length=256,
+                                            eos_token_id=self.tokenizer.pad_token_id,
+                                            pad_token_id=self.tokenizer.pad_token_id)  # greedy
+            # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
+            #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
+            output_strs = [self.tokenizer.decode(item) for item in outputs]
+            return output_strs
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/test.sh b/convlab2/nlg/scgpt/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..96d51bc6429e5d5248be24ab1c3331b0f195349c
--- /dev/null
+++ b/convlab2/nlg/scgpt/test.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES="6" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py --dataset multiwoz21
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/train.sh b/convlab2/nlg/scgpt/train.sh
index 5a869e3d63bd14664b7535e1ce906ea52edbfe4d..12f190670e835c0d1f8f5df198dade11ac59ce77 100644
--- a/convlab2/nlg/scgpt/train.sh
+++ b/convlab2/nlg/scgpt/train.sh
@@ -1 +1 @@
-CUDA_VISIBLE_DEVICES="0,1,2,3" python -m torch.distributed.launch --nproc_per_node 4 main.py --do_train --dataset multiwoz21
\ No newline at end of file
+CUDA_VISIBLE_DEVICES="1" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21
\ No newline at end of file