diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py
index 33b87018271e8c34be85f7efcdb7422dc4341dad..19be0b81520cf4077ac34166e7b2e7a0d12f80a3 100644
--- a/convlab2/base_models/t5/create_data.py
+++ b/convlab2/base_models/t5/create_data.py
@@ -3,6 +3,7 @@ import json
 from tqdm import tqdm
 import re
 from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
+from convlab2.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq
 
 def create_rg_data(dataset, data_dir, args):
     data_by_split = load_rg_data(dataset, speaker=args.speaker)
@@ -29,56 +30,6 @@ def create_nlu_data(dataset, data_dir, args):
     data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}')
     os.makedirs(data_dir, exist_ok=True)
 
-    def serialize_dialogue_acts(dialogue_acts):
-        da_seqs = []
-        for da_type in dialogue_acts:
-            for da in dialogue_acts[da_type]:
-                intent, domain, slot = da['intent'], da['domain'], da['slot']
-                if da_type == 'binary':
-                    da_seq = f'[{da_type}][{intent}][{domain}][{slot}]'
-                else:
-                    value = da['value']
-                    da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]'
-                da_seqs.append(da_seq)
-        return ';'.join(da_seqs)
-
-    def deserialize_dialogue_acts(das_seq):
-        dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []}
-        if len(das_seq) == 0:
-            return dialogue_acts
-        da_seqs = das_seq.split('];[')
-        for i, da_seq in enumerate(da_seqs):
-            if i == 0:
-                assert da_seq[0] == '['
-                da_seq = da_seq[1:]
-            if i == len(da_seqs) - 1:
-                assert da_seq[-1] == ']'
-                da_seq = da_seq[:-1]
-            da = da_seq.split('][')
-            if len(da) == 0:
-                continue
-            da_type = da[0]
-            if len(da) == 5 and da_type in ['categorical', 'non-categorical']:
-                dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]})
-            elif len(da) == 4 and da_type == 'binary':
-                dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]})
-            else:
-                # invalid da format, skip
-                # print(das_seq)
-                # print(da_seq)
-                # print()
-                pass
-        return dialogue_acts
-
-    def equal_da_seq(dialogue_acts, das_seq):
-        predict_dialogue_acts = deserialize_dialogue_acts(das_seq)
-        for da_type in ['binary', 'categorical', 'non-categorical']:
-            das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]])
-            predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]])
-            if das != predict_das:
-                return False
-        return True
-
     data_splits = data_by_split.keys()
     file_name = os.path.join(data_dir, f"source_prefix.txt")
     with open(file_name, "w") as f:
diff --git a/convlab2/base_models/t5/nlu/merge_predict_res.py b/convlab2/base_models/t5/nlu/merge_predict_res.py
new file mode 100755
index 0000000000000000000000000000000000000000..f3386b210817a6ae26c153776e47324793c70546
--- /dev/null
+++ b/convlab2/base_models/t5/nlu/merge_predict_res.py
@@ -0,0 +1,34 @@
+import json
+import os
+from convlab2.util import load_dataset, load_nlu_data
+from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts
+
+
+def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+    assert os.path.exists(predict_result)
+    dataset = load_dataset(dataset_name)
+    data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+
+    if save_dir is None:
+        save_dir = os.path.dirname(predict_result)
+    else:
+        os.makedirs(save_dir, exist_ok=True)
+    predict_result = [deserialize_dialogue_acts(json.loads(x)['predictions'].strip()) for x in open(predict_result)]
+
+    for sample, prediction in zip(data, predict_result):
+        sample['predictions'] = {'dialogue_acts': prediction}
+
+    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: in the same directory as predict_result')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
+    args = parser.parse_args()
+    print(args)
+    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
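
Reviewer note, not part of the patch: merge() above assumes generated_predictions.json is in JSON-lines form, one object with a "predictions" string per test sample. A minimal sketch of the per-line parsing it performs (the DA strings are illustrative, borrowed from the nlu_metric.py docstring below):

import json

# Two hypothetical lines of generated_predictions.json: one JSON object
# per line, aligned with the order of the test split.
lines = [
    '{"predictions": "[binary][thank][general][]"}',
    '{"predictions": "[non-categorical][inform][taxi][leave at][17:15]"}',
]
# Mirrors the list comprehension in merge() up to the deserialization step.
predictions = [json.loads(x)['predictions'].strip() for x in lines]
print(predictions)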
diff --git a/convlab2/base_models/t5/nlu/nlu.py b/convlab2/base_models/t5/nlu/nlu.py
new file mode 100755
index 0000000000000000000000000000000000000000..2be433301049505d115abda0e4e1fe9899b479a1
--- /dev/null
+++ b/convlab2/base_models/t5/nlu/nlu.py
@@ -0,0 +1,80 @@
+import logging
+import os
+import json
+import torch
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
+from convlab2.nlu.nlu import NLU
+from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts
+from convlab2.util.custom_util import model_downloader
+
+
+class T5NLU(NLU):
+    def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
+        assert speaker in ['user', 'system']
+        self.speaker = speaker
+        self.opponent = 'system' if speaker == 'user' else 'user'
+        self.context_window_size = context_window_size
+        self.use_context = context_window_size > 0
+        self.prefix = "parse the dialogue action of the last utterance: "
+
+        model_dir = os.path.dirname(os.path.abspath(__file__))
+        if not os.path.exists(model_name_or_path):
+            model_downloader(model_dir, model_file)
+
+        self.config = AutoConfig.from_pretrained(model_name_or_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)
+        self.model.eval()
+        self.device = device if torch.cuda.is_available() else "cpu"
+        self.model.to(self.device)
+
+        logging.info("T5NLU loaded")
+
+    def predict(self, utterance, context=list()):
+        if self.use_context:
+            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
+                context = [item[1] for item in context]
+            utts = context + [utterance]
+        else:
+            utts = [utterance]
+        input_seq = ' '.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
+        # print(input_seq)
+        input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
+        # print(input_seq)
+        output_seq = self.model.generate(**input_seq, max_length=256)
+        # print(output_seq)
+        output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
+        # print(output_seq)
+        dialogue_acts = deserialize_dialogue_acts(output_seq.strip())
+        return dialogue_acts
+
+
+if __name__ == '__main__':
+    texts = [
+        "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
+        "I want to leave after 17:15.",
+        "Thank you for all the help! I appreciate it.",
+        "Please find a restaurant called Nusha.",
+        "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
+        "It's not a restaurant, it's an attraction. Nusha."
+    ]
+    contexts = [
+        [],
+        ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
+        "What time do you want to leave and what time do you want to arrive by?"],
+        ["What time do you want to leave and what time do you want to arrive by?",
+        "I want to leave after 17:15.",
+        "Booking completed! your taxi will be blue honda Contact number is 07218068540"],
+        [],
+        ["Please find a restaurant called Nusha.",
+        "I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?"],
+        ["I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?",
+        "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
+        "Could you double check that you've spelled the name correctly? The closest I can find is Nandos."]
+    ]
+    nlu = T5NLU(speaker='user', context_window_size=3, model_name_or_path='output/nlu/multiwoz21/user/context_3')
+    for text, context in zip(texts, contexts):
+        print(text)
+        print(nlu.predict(text, context))
+        print()
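
Reviewer note, not part of the patch: the input_seq expression in predict() labels utterances by index parity so that the last utterance always gets self.speaker and labels alternate backwards from there. A standalone sketch of that rule with made-up utterances:

# The last utterance gets `speaker`; earlier ones alternate backwards.
speaker, opponent = 'user', 'system'
utts = ['I need a taxi.', 'Where to?', 'To Pizza Hut Fen Ditton.']
input_seq = ' '.join([f"{opponent if (i % 2) == (len(utts) % 2) else speaker}: {utt}"
                      for i, utt in enumerate(utts)])
print(input_seq)
# user: I need a taxi. system: Where to? user: To Pizza Hut Fen Ditton.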
diff --git a/convlab2/base_models/t5/nlu/nlu_metric.py b/convlab2/base_models/t5/nlu/nlu_metric.py
index feff4edd4bb28cbca763c03e88d0cad542ab7179..1eb57c84a02bd1f019eb1978271cb53c3b3a1916 100644
--- a/convlab2/base_models/t5/nlu/nlu_metric.py
+++ b/convlab2/base_models/t5/nlu/nlu_metric.py
@@ -14,7 +14,7 @@
 """NLU Metric"""
 
 import datasets
-import re
+from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts
 
 
 # TODO: Add BibTeX citation
@@ -42,8 +42,8 @@ Returns:
 Examples:
 
     >>> nlu_metric = datasets.load_metric("nlu_metric.py")
-    >>> predictions = ["[binary]-[thank]-[general]-[]", "[non-categorical]-[inform]-[taxi]-[leave at]-[17:15]"]
-    >>> references = ["[binary]-[thank]-[general]-[]", "[non-categorical]-[inform]-[train]-[leave at]-[17:15]"]
+    >>> predictions = ["[binary][thank][general][]", "[non-categorical][inform][taxi][leave at][17:15]"]
+    >>> references = ["[binary][thank][general][]", "[non-categorical][inform][train][leave at][17:15]"]
     >>> results = nlu_metric.compute(predictions=predictions, references=references)
     >>> print(results)
     {'seq_em': 0.5, 'accuracy': 0.5,
@@ -70,36 +70,6 @@ class NLUMetrics(datasets.Metric):
             })
         )
 
-    def deserialize_dialogue_acts(self, das_seq):
-        dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []}
-        if len(das_seq) == 0:
-            return dialogue_acts
-        da_seqs = das_seq.split('];[')
-        for i, da_seq in enumerate(da_seqs):
-            if len(da_seq) == 0:
-                continue
-            if i == 0:
-                if da_seq[0] == '[':
-                    da_seq = da_seq[1:]
-            if i == len(da_seqs) - 1:
-                if da_seq[-1] == ']':
-                    da_seq = da_seq[:-1]
-            da = da_seq.split('][')
-            if len(da) == 0:
-                continue
-            da_type = da[0]
-            if len(da) == 5 and da_type in ['categorical', 'non-categorical']:
-                dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]})
-            elif len(da) == 4 and da_type == 'binary':
-                dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]})
-            else:
-                # invalid da format, skip
-                # print(das_seq)
-                # print(da_seq)
-                # print()
-                pass
-        return dialogue_acts
-
     def _compute(self, predictions, references):
         """Returns the scores: sequence exact match, dialog acts accuracy and f1"""
         seq_em = []
@@ -108,8 +78,8 @@
 
         for prediction, reference in zip(predictions, references):
             seq_em.append(prediction.strip()==reference.strip())
-            pred_da = self.deserialize_dialogue_acts(prediction)
-            gold_da = self.deserialize_dialogue_acts(reference)
+            pred_da = deserialize_dialogue_acts(prediction)
+            gold_da = deserialize_dialogue_acts(reference)
             flag = True
             for da_type in ['binary', 'categorical', 'non-categorical']:
                 if da_type == 'binary':
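
Reviewer note, not part of the patch: the truncated docstring output above ends with the F1 scores that _compute derives from the accumulated TP/FP/FN counts. A hedged sketch of that standard computation (the function name is illustrative):

# Conventional F1 from the TP/FP/FN counts accumulated in _compute.
def f1_from_counts(tp, fp, fn):
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# In the docstring example one DA matches and one has the wrong domain
# (taxi vs. train), i.e. tp=1, fp=1, fn=1, giving F1 = 0.5.
print(f1_from_counts(1, 1, 1))  # 0.5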
diff --git a/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh b/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh
index a5bd0527b9961f20d3f179d802c97018bfb1744e..85f3ec8302d161b29ba71b760a56d0f64a6b4dfc 100644
--- a/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh
+++ b/convlab2/base_models/t5/nlu/run_multiwoz21_user.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh b/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh
index d9f59dbfeedbd3dd062a76483e9a8fcb5ab92d73..8d7b5c93e8deb9c8c5da9ecd03e42bbc53341442 100644
--- a/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh
+++ b/convlab2/base_models/t5/nlu/run_multiwoz21_user_context3.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/nlu/run_tm1_user.sh b/convlab2/base_models/t5/nlu/run_tm1_user.sh
index ffd8aa4148c7ac05eacbb1b4dc9967fe26f6ba5e..16a16fdb106f09a7001190477de8b0878d2e20f3 100644
--- a/convlab2/base_models/t5/nlu/run_tm1_user.sh
+++ b/convlab2/base_models/t5/nlu/run_tm1_user.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh
index b99d7fade71c24701a62dd53f86c1c5eb1be70fa..ccb67609279be5c4b044a9baadc19672d69c1532 100644
--- a/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh
+++ b/convlab2/base_models/t5/nlu/run_tm1_user_context3.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/nlu/run_tm2_user.sh b/convlab2/base_models/t5/nlu/run_tm2_user.sh
index e41d2f393703ec906ba36739e0dc87c70ebece26..8686822fea882cb75776bee89dbd4344b71ea64b 100644
--- a/convlab2/base_models/t5/nlu/run_tm2_user.sh
+++ b/convlab2/base_models/t5/nlu/run_tm2_user.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh
index 832c20b4da1b70b2f406cb03cc8a2d14966863f3..03c2489940e38dd16256f6b4f2683a413f514235 100644
--- a/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh
+++ b/convlab2/base_models/t5/nlu/run_tm2_user_context3.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/nlu/run_tm3_user.sh b/convlab2/base_models/t5/nlu/run_tm3_user.sh
index 39399a76e8326c4051f96ff5e698def61ad8ca0d..470cb7d71c2b7a630e6917912e21d2c61ca1c075 100644
--- a/convlab2/base_models/t5/nlu/run_tm3_user.sh
+++ b/convlab2/base_models/t5/nlu/run_tm3_user.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
diff --git a/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh b/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh
index fce613ad425169246157db71bce1bcf34150efe2..5e325d1fe2b127ef1af0b0733dd5db03bb1cbe3c 100644
--- a/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh
+++ b/convlab2/base_models/t5/nlu/run_tm3_user_context3.sh
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
    --overwrite_output_dir \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
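
Reviewer note, not part of the patch: the appended line is identical across all eight scripts. In Python it corresponds to a call like the following, where the concrete path stands in for whatever ${output_dir} expands to in each script (that variable is defined earlier in the scripts, outside this hunk):

from convlab2.base_models.t5.nlu.merge_predict_res import merge

# save_dir=None writes predictions.json next to generated_predictions.json.
merge(dataset_name='multiwoz21', speaker='user', save_dir=None, context_window_size=0,
      predict_result='output/nlu/multiwoz21/user/context_0/generated_predictions.json')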
diff --git a/convlab2/base_models/t5/nlu/serialization.py b/convlab2/base_models/t5/nlu/serialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a620f4689519accaccdc1149a54ed6c8efb52d8
--- /dev/null
+++ b/convlab2/base_models/t5/nlu/serialization.py
@@ -0,0 +1,51 @@
+def serialize_dialogue_acts(dialogue_acts):
+    da_seqs = []
+    for da_type in dialogue_acts:
+        for da in dialogue_acts[da_type]:
+            intent, domain, slot = da['intent'], da['domain'], da['slot']
+            if da_type == 'binary':
+                da_seq = f'[{da_type}][{intent}][{domain}][{slot}]'
+            else:
+                value = da['value']
+                da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]'
+            da_seqs.append(da_seq)
+    return ';'.join(da_seqs)
+
+def deserialize_dialogue_acts(das_seq):
+    dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []}
+    if len(das_seq) == 0:
+        return dialogue_acts
+    da_seqs = das_seq.split('];[')
+    for i, da_seq in enumerate(da_seqs):
+        if len(da_seq) == 0:
+            continue
+        if i == 0:
+            if da_seq[0] == '[':
+                da_seq = da_seq[1:]
+        if i == len(da_seqs) - 1:
+            if da_seq[-1] == ']':
+                da_seq = da_seq[:-1]
+        da = da_seq.split('][')
+        if len(da) == 0:
+            continue
+        da_type = da[0]
+        if len(da) == 5 and da_type in ['categorical', 'non-categorical']:
+            dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]})
+        elif len(da) == 4 and da_type == 'binary':
+            dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]})
+        else:
+            # invalid da format, skip
+            # print(das_seq)
+            # print(da_seq)
+            # print()
+            pass
+    return dialogue_acts
+
+def equal_da_seq(dialogue_acts, das_seq):
+    predict_dialogue_acts = deserialize_dialogue_acts(das_seq)
+    for da_type in ['binary', 'categorical', 'non-categorical']:
+        das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]])
+        predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]])
+        if das != predict_das:
+            return False
+    return True
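
Reviewer note, not part of the patch: a quick round trip through the shared helpers above, using the acts from the nlu_metric.py docstring example:

from convlab2.base_models.t5.nlu.serialization import (
    serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq)

acts = {
    'binary': [{'intent': 'thank', 'domain': 'general', 'slot': ''}],
    'categorical': [],
    'non-categorical': [{'intent': 'inform', 'domain': 'taxi', 'slot': 'leave at', 'value': '17:15'}],
}
seq = serialize_dialogue_acts(acts)
print(seq)  # [binary][thank][general][];[non-categorical][inform][taxi][leave at][17:15]
assert deserialize_dialogue_acts(seq) == acts
assert equal_da_seq(acts, seq)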
diff --git a/convlab2/nlu/evaluate_unified_datasets.py b/convlab2/nlu/evaluate_unified_datasets.py
index 86e91747702a1dedd727b5945bf6cbcaa08540d9..907b1afaaee6788c1e90e3bd85b67b3360c9c2da 100644
--- a/convlab2/nlu/evaluate_unified_datasets.py
+++ b/convlab2/nlu/evaluate_unified_datasets.py
@@ -17,6 +17,8 @@ def evaluate(predict_result):
             else:
                 predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['predictions']['dialogue_acts'][da_type]]
                 labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['dialogue_acts'][da_type]]
+            predicts = sorted(list(set(predicts)))
+            labels = sorted(list(set(labels)))
             for ele in predicts:
                 if ele in labels:
                     metrics['overall']['TP'] += 1
@@ -28,7 +30,7 @@ def evaluate(predict_result):
                 if ele not in predicts:
                     metrics['overall']['FN'] += 1
                     metrics[da_type]['FN'] += 1
-            flag &= (sorted(predicts)==sorted(labels))
+            flag &= (predicts==labels)
             acc.append(flag)
 
     for metric in metrics:
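
Reviewer note, not part of the patch: a minimal illustration of what the two added sorted(list(set(...))) lines change. Without deduplication, a repeated predicted act counts as two true positives against a single gold act, and the old flag comparison fails on lists that are equal as sets:

predicts = [('inform', 'taxi', 'leave at', '17:15'),
            ('inform', 'taxi', 'leave at', '17:15')]  # duplicate prediction
labels = [('inform', 'taxi', 'leave at', '17:15')]
predicts = sorted(list(set(predicts)))
labels = sorted(list(set(labels)))
assert predicts == labels  # flag now compares deduplicated, sorted lists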