From 12a1da7a103f1ed238d8f53172b86814b9e91237 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Sat, 19 Mar 2022 14:13:13 +0800 Subject: [PATCH] add interface for t5 nlu --- convlab2/base_models/t5/create_data.py | 51 +----------- .../base_models/t5/nlu/merge_predict_res.py | 34 ++++++++ convlab2/base_models/t5/nlu/nlu.py | 80 +++++++++++++++++++ convlab2/base_models/t5/nlu/nlu_metric.py | 40 ++-------- .../base_models/t5/nlu/run_multiwoz21_user.sh | 2 + .../t5/nlu/run_multiwoz21_user_context3.sh | 2 + convlab2/base_models/t5/nlu/run_tm1_user.sh | 2 + .../t5/nlu/run_tm1_user_context3.sh | 2 + convlab2/base_models/t5/nlu/run_tm2_user.sh | 2 + .../t5/nlu/run_tm2_user_context3.sh | 2 + convlab2/base_models/t5/nlu/run_tm3_user.sh | 2 + .../t5/nlu/run_tm3_user_context3.sh | 2 + convlab2/base_models/t5/nlu/serialization.py | 51 ++++++++++++ convlab2/nlu/evaluate_unified_datasets.py | 4 +- 14 files changed, 190 insertions(+), 86 deletions(-) create mode 100755 convlab2/base_models/t5/nlu/merge_predict_res.py create mode 100755 convlab2/base_models/t5/nlu/nlu.py create mode 100644 convlab2/base_models/t5/nlu/serialization.py diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py index 33b87018..19be0b81 100644 --- a/convlab2/base_models/t5/create_data.py +++ b/convlab2/base_models/t5/create_data.py @@ -3,6 +3,7 @@ import json from tqdm import tqdm import re from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data +from convlab2.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq def create_rg_data(dataset, data_dir, args): data_by_split = load_rg_data(dataset, speaker=args.speaker) @@ -29,56 +30,6 @@ def create_nlu_data(dataset, data_dir, args): data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}') os.makedirs(data_dir, exist_ok=True) - def serialize_dialogue_acts(dialogue_acts): - da_seqs = [] - for da_type in dialogue_acts: - for da in dialogue_acts[da_type]: - intent, domain, slot = da['intent'], da['domain'], da['slot'] - if da_type == 'binary': - da_seq = f'[{da_type}][{intent}][{domain}][{slot}]' - else: - value = da['value'] - da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]' - da_seqs.append(da_seq) - return ';'.join(da_seqs) - - def deserialize_dialogue_acts(das_seq): - dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []} - if len(das_seq) == 0: - return dialogue_acts - da_seqs = das_seq.split('];[') - for i, da_seq in enumerate(da_seqs): - if i == 0: - assert da_seq[0] == '[' - da_seq = da_seq[1:] - if i == len(da_seqs) - 1: - assert da_seq[-1] == ']' - da_seq = da_seq[:-1] - da = da_seq.split('][') - if len(da) == 0: - continue - da_type = da[0] - if len(da) == 5 and da_type in ['categorical', 'non-categorical']: - dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]}) - elif len(da) == 4 and da_type == 'binary': - dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]}) - else: - # invalid da format, skip - # print(das_seq) - # print(da_seq) - # print() - pass - return dialogue_acts - - def equal_da_seq(dialogue_acts, das_seq): - predict_dialogue_acts = deserialize_dialogue_acts(das_seq) - for da_type in ['binary', 'categorical', 'non-categorical']: - das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]]) - predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]]) - if das != predict_das: - return False - return True - data_splits = data_by_split.keys() file_name = os.path.join(data_dir, f"source_prefix.txt") with open(file_name, "w") as f: diff --git a/convlab2/base_models/t5/nlu/merge_predict_res.py b/convlab2/base_models/t5/nlu/merge_predict_res.py new file mode 100755 index 00000000..f3386b21 --- /dev/null +++ b/convlab2/base_models/t5/nlu/merge_predict_res.py @@ -0,0 +1,34 @@ +import json +import os +from convlab2.util import load_dataset, load_nlu_data +from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts + + +def merge(dataset_name, speaker, save_dir, context_window_size, predict_result): + assert os.path.exists(predict_result) + dataset = load_dataset(dataset_name) + data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] + + if save_dir is None: + save_dir = os.path.dirname(predict_result) + else: + os.makedirs(save_dir, exist_ok=True) + predict_result = [deserialize_dialogue_acts(json.loads(x)['predictions'].strip()) for x in open(predict_result)] + + for sample, prediction in zip(data, predict_result): + sample['predictions'] = {'dialogue_acts': prediction} + + json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation") + parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset') + parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') + parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result') + parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + args = parser.parse_args() + print(args) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) diff --git a/convlab2/base_models/t5/nlu/nlu.py b/convlab2/base_models/t5/nlu/nlu.py new file mode 100755 index 00000000..2be43330 --- /dev/null +++ b/convlab2/base_models/t5/nlu/nlu.py @@ -0,0 +1,80 @@ +import logging +import os +import json +import torch +from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig +from convlab2.nlu.nlu import NLU +from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts +from convlab2.util.custom_util import model_downloader + + +class T5NLU(NLU): + def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'): + assert speaker in ['user', 'system'] + self.speaker = speaker + self.opponent = 'system' if speaker == 'user' else 'user' + self.context_window_size = context_window_size + self.use_context = context_window_size > 0 + self.prefix = "parse the dialogue action of the last utterance: " + + model_dir = os.path.dirname(os.path.abspath(__file__)) + if not os.path.exists(model_name_or_path): + model_downloader(model_dir, model_file) + + self.config = AutoConfig.from_pretrained(model_name_or_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config) + self.model.eval() + self.device = device if torch.cuda.is_available() else "cpu" + self.model.to(self.device) + + logging.info("T5NLU loaded") + + def predict(self, utterance, context=list()): + if self.use_context: + if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1: + context = [item[1] for item in context] + utts = context + [utterance] + else: + utts = [utterance] + input_seq = ' '.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)]) + # print(input_seq) + input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device) + # print(input_seq) + output_seq = self.model.generate(**input_seq, max_length=256) + # print(output_seq) + output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True) + # print(output_seq) + dialogue_acts = deserialize_dialogue_acts(output_seq.strip()) + return dialogue_acts + + +if __name__ == '__main__': + texts = [ + "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.", + "I want to leave after 17:15.", + "Thank you for all the help! I appreciate it.", + "Please find a restaurant called Nusha.", + "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.", + "It's not a restaurant, it's an attraction. Nusha." + ] + contexts = [ + [], + ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.", + "What time do you want to leave and what time do you want to arrive by?"], + ["What time do you want to leave and what time do you want to arrive by?", + "I want to leave after 17:15.", + "Booking completed! your taxi will be blue honda Contact number is 07218068540"], + [], + ["Please find a restaurant called Nusha.", + "I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?"], + ["I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?", + "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.", + "Could you double check that you've spelled the name correctly? The closest I can find is Nandos."] + ] + nlu = T5NLU(speaker='user', context_window_size=3, model_name_or_path='output/nlu/multiwoz21/user/context_3') + for text, context in zip(texts, contexts): + print(text) + print(nlu.predict(text, context)) + print() diff --git a/convlab2/base_models/t5/nlu/nlu_metric.py b/convlab2/base_models/t5/nlu/nlu_metric.py index feff4edd..1eb57c84 100644 --- a/convlab2/base_models/t5/nlu/nlu_metric.py +++ b/convlab2/base_models/t5/nlu/nlu_metric.py @@ -14,7 +14,7 @@ """NLU Metric""" import datasets -import re +from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts # TODO: Add BibTeX citation @@ -42,8 +42,8 @@ Returns: Examples: >>> nlu_metric = datasets.load_metric("nlu_metric.py") - >>> predictions = ["[binary]-[thank]-[general]-[]", "[non-categorical]-[inform]-[taxi]-[leave at]-[17:15]"] - >>> references = ["[binary]-[thank]-[general]-[]", "[non-categorical]-[inform]-[train]-[leave at]-[17:15]"] + >>> predictions = ["[binary][thank][general][]", "[non-categorical][inform][taxi][leave at][17:15]"] + >>> references = ["[binary][thank][general][]", "[non-categorical][inform][train][leave at][17:15]"] >>> results = nlu_metric.compute(predictions=predictions, references=references) >>> print(results) {'seq_em': 0.5, 'accuracy': 0.5, @@ -70,36 +70,6 @@ class NLUMetrics(datasets.Metric): }) ) - def deserialize_dialogue_acts(self, das_seq): - dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []} - if len(das_seq) == 0: - return dialogue_acts - da_seqs = das_seq.split('];[') - for i, da_seq in enumerate(da_seqs): - if len(da_seq) == 0: - continue - if i == 0: - if da_seq[0] == '[': - da_seq = da_seq[1:] - if i == len(da_seqs) - 1: - if da_seq[-1] == ']': - da_seq = da_seq[:-1] - da = da_seq.split('][') - if len(da) == 0: - continue - da_type = da[0] - if len(da) == 5 and da_type in ['categorical', 'non-categorical']: - dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]}) - elif len(da) == 4 and da_type == 'binary': - dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]}) - else: - # invalid da format, skip - # print(das_seq) - # print(da_seq) - # print() - pass - return dialogue_acts - def _compute(self, predictions, references): """Returns the scores: sequence exact match, dialog acts accuracy and f1""" seq_em = [] @@ -108,8 +78,8 @@ class NLUMetrics(datasets.Metric): for prediction, reference in zip(predictions, references): seq_em.append(prediction.strip()==reference.strip()) - pred_da = self.deserialize_dialogue_acts(prediction) - gold_da = self.deserialize_dialogue_acts(reference) + pred_da = deserialize_dialogue_acts(prediction) + gold_da = deserialize_dialogue_acts(reference) flag = True for da_type in ['binary', 'categorical', 'non-categorical']: if da_type == 'binary': diff --git a/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh b/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh index a5bd0527..85f3ec83 100644 --- a/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh +++ b/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh b/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh index d9f59dbf..8d7b5c93 100644 --- a/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm1_user.sh b/convlab2/base_models/t5/nlu/run_tm1_user.sh index ffd8aa41..16a16fdb 100644 --- a/convlab2/base_models/t5/nlu/run_tm1_user.sh +++ b/convlab2/base_models/t5/nlu/run_tm1_user.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh index b99d7fad..ccb67609 100644 --- a/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm2_user.sh b/convlab2/base_models/t5/nlu/run_tm2_user.sh index e41d2f39..8686822f 100644 --- a/convlab2/base_models/t5/nlu/run_tm2_user.sh +++ b/convlab2/base_models/t5/nlu/run_tm2_user.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh index 832c20b4..03c24899 100644 --- a/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm3_user.sh b/convlab2/base_models/t5/nlu/run_tm3_user.sh index 39399a76..470cb7d7 100644 --- a/convlab2/base_models/t5/nlu/run_tm3_user.sh +++ b/convlab2/base_models/t5/nlu/run_tm3_user.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh index fce613ad..5e325d1f 100644 --- a/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh +++ b/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh @@ -66,3 +66,5 @@ python -m torch.distributed.launch \ --overwrite_output_dir \ --preprocessing_num_workers 4 \ --per_device_eval_batch_size ${per_device_eval_batch_size} \ + +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json diff --git a/convlab2/base_models/t5/nlu/serialization.py b/convlab2/base_models/t5/nlu/serialization.py new file mode 100644 index 00000000..5a620f46 --- /dev/null +++ b/convlab2/base_models/t5/nlu/serialization.py @@ -0,0 +1,51 @@ +def serialize_dialogue_acts(dialogue_acts): + da_seqs = [] + for da_type in dialogue_acts: + for da in dialogue_acts[da_type]: + intent, domain, slot = da['intent'], da['domain'], da['slot'] + if da_type == 'binary': + da_seq = f'[{da_type}][{intent}][{domain}][{slot}]' + else: + value = da['value'] + da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]' + da_seqs.append(da_seq) + return ';'.join(da_seqs) + +def deserialize_dialogue_acts(das_seq): + dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []} + if len(das_seq) == 0: + return dialogue_acts + da_seqs = das_seq.split('];[') + for i, da_seq in enumerate(da_seqs): + if len(da_seq) == 0: + continue + if i == 0: + if da_seq[0] == '[': + da_seq = da_seq[1:] + if i == len(da_seqs) - 1: + if da_seq[-1] == ']': + da_seq = da_seq[:-1] + da = da_seq.split('][') + if len(da) == 0: + continue + da_type = da[0] + if len(da) == 5 and da_type in ['categorical', 'non-categorical']: + dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]}) + elif len(da) == 4 and da_type == 'binary': + dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]}) + else: + # invalid da format, skip + # print(das_seq) + # print(da_seq) + # print() + pass + return dialogue_acts + +def equal_da_seq(dialogue_acts, das_seq): + predict_dialogue_acts = deserialize_dialogue_acts(das_seq) + for da_type in ['binary', 'categorical', 'non-categorical']: + das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]]) + predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]]) + if das != predict_das: + return False + return True diff --git a/convlab2/nlu/evaluate_unified_datasets.py b/convlab2/nlu/evaluate_unified_datasets.py index 86e91747..907b1afa 100644 --- a/convlab2/nlu/evaluate_unified_datasets.py +++ b/convlab2/nlu/evaluate_unified_datasets.py @@ -17,6 +17,8 @@ def evaluate(predict_result): else: predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['predictions']['dialogue_acts'][da_type]] labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['dialogue_acts'][da_type]] + predicts = sorted(list(set(predicts))) + labels = sorted(list(set(labels))) for ele in predicts: if ele in labels: metrics['overall']['TP'] += 1 @@ -28,7 +30,7 @@ def evaluate(predict_result): if ele not in predicts: metrics['overall']['FN'] += 1 metrics[da_type]['FN'] += 1 - flag &= (sorted(predicts)==sorted(labels)) + flag &= (predicts==labels) acc.append(flag) for metric in metrics: -- GitLab