Commit 12a1da7a authored by zqwerty

add interface for t5 nlu

parent a82fcf16
Showing with 190 additions and 86 deletions
@@ -3,6 +3,7 @@ import json
 from tqdm import tqdm
 import re
 from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
+from convlab2.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq

 def create_rg_data(dataset, data_dir, args):
     data_by_split = load_rg_data(dataset, speaker=args.speaker)
@@ -29,56 +30,6 @@ def create_nlu_data(dataset, data_dir, args):
     data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}')
     os.makedirs(data_dir, exist_ok=True)
-def serialize_dialogue_acts(dialogue_acts):
-    da_seqs = []
-    for da_type in dialogue_acts:
-        for da in dialogue_acts[da_type]:
-            intent, domain, slot = da['intent'], da['domain'], da['slot']
-            if da_type == 'binary':
-                da_seq = f'[{da_type}][{intent}][{domain}][{slot}]'
-            else:
-                value = da['value']
-                da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]'
-            da_seqs.append(da_seq)
-    return ';'.join(da_seqs)
-
-def deserialize_dialogue_acts(das_seq):
-    dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []}
-    if len(das_seq) == 0:
-        return dialogue_acts
-    da_seqs = das_seq.split('];[')
-    for i, da_seq in enumerate(da_seqs):
-        if i == 0:
-            assert da_seq[0] == '['
-            da_seq = da_seq[1:]
-        if i == len(da_seqs) - 1:
-            assert da_seq[-1] == ']'
-            da_seq = da_seq[:-1]
-        da = da_seq.split('][')
-        if len(da) == 0:
-            continue
-        da_type = da[0]
-        if len(da) == 5 and da_type in ['categorical', 'non-categorical']:
-            dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]})
-        elif len(da) == 4 and da_type == 'binary':
-            dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]})
-        else:
-            # invalid da format, skip
-            pass
-    return dialogue_acts
-
-def equal_da_seq(dialogue_acts, das_seq):
-    predict_dialogue_acts = deserialize_dialogue_acts(das_seq)
-    for da_type in ['binary', 'categorical', 'non-categorical']:
-        das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]])
-        predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]])
-        if das != predict_das:
-            return False
-    return True
     data_splits = data_by_split.keys()
     file_name = os.path.join(data_dir, f"source_prefix.txt")
     with open(file_name, "w") as f:
...
merge_predict_res.py (new file)

import json
import os
from convlab2.util import load_dataset, load_nlu_data
from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts


def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
    assert os.path.exists(predict_result)
    dataset = load_dataset(dataset_name)
    data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']

    if save_dir is None:
        save_dir = os.path.dirname(predict_result)
    else:
        os.makedirs(save_dir, exist_ok=True)
    # the prediction file is read as JSON Lines: one JSON object per line
    predict_result = [deserialize_dialogue_acts(json.loads(x)['predictions'].strip()) for x in open(predict_result)]

    for sample, prediction in zip(data, predict_result):
        sample['predictions'] = {'dialogue_acts': prediction}

    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation")
    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
    parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: the same directory as predict_result')
    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
    args = parser.parse_args()
    print(args)
    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
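
For reference, the file passed via --predict_result is JSON Lines: each line is a standalone JSON object whose predictions field holds a serialized dialogue-act string. A minimal sketch of how one such line is decoded (the line's value here is hypothetical, not taken from the commit):

import json
from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts

line = '{"predictions": "[binary][thank][general][]"}'  # hypothetical model output
da = deserialize_dialogue_acts(json.loads(line)['predictions'].strip())
# -> {'binary': [{'intent': 'thank', 'domain': 'general', 'slot': ''}],
#     'categorical': [], 'non-categorical': []}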
import logging
import os
import json
import torch
from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab2.nlu.nlu import NLU
from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts
from convlab2.util.custom_util import model_downloader


class T5NLU(NLU):
    def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
        assert speaker in ['user', 'system']
        self.speaker = speaker
        self.opponent = 'system' if speaker == 'user' else 'user'
        self.context_window_size = context_window_size
        self.use_context = context_window_size > 0
        self.prefix = "parse the dialogue action of the last utterance: "

        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(model_name_or_path):
            model_downloader(model_dir, model_file)

        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)
        self.model.eval()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logging.info("T5NLU loaded")

    def predict(self, utterance, context=list()):
        if self.use_context:
            # context items may be [speaker, utterance] pairs; keep only the text
            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
                context = [item[1] for item in context]
            utts = context + [utterance]
        else:
            utts = [utterance]
        # label speakers by alternating backwards from the last utterance,
        # which always belongs to self.speaker
        input_seq = ' '.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
        input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
        output_seq = self.model.generate(**input_seq, max_length=256)
        output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
        dialogue_acts = deserialize_dialogue_acts(output_seq.strip())
        return dialogue_acts


if __name__ == '__main__':
    texts = [
        "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
        "I want to leave after 17:15.",
        "Thank you for all the help! I appreciate it.",
        "Please find a restaurant called Nusha.",
        "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
        "It's not a restaurant, it's an attraction. Nusha."
    ]
    contexts = [
        [],
        ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
         "What time do you want to leave and what time do you want to arrive by?"],
        ["What time do you want to leave and what time do you want to arrive by?",
         "I want to leave after 17:15.",
         "Booking completed! your taxi will be blue honda Contact number is 07218068540"],
        [],
        ["Please find a restaurant called Nusha.",
         "I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?"],
        ["I don't seem to be finding anything called Nusha. What type of food does the restaurant serve?",
         "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
         "Could you double check that you've spelled the name correctly? The closest I can find is Nandos."]
    ]
    nlu = T5NLU(speaker='user', context_window_size=3, model_name_or_path='output/nlu/multiwoz21/user/context_3')
    for text, context in zip(texts, contexts):
        print(text)
        print(nlu.predict(text, context))
        print()
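
To make the context serialization concrete: for the second test utterance above (speaker='user' with two context turns), the alternation rule in predict() labels turns backwards from the last utterance, so the model sees the single line below. This rendering is illustrative; note that predict() as written does not prepend self.prefix.

# input_seq for texts[1] with contexts[1]:
# "user: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton. system: What time do you want to leave and what time do you want to arrive by? user: I want to leave after 17:15."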
nlu_metric.py
@@ -14,7 +14,7 @@
 """NLU Metric"""

 import datasets
-import re
+from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts

 # TODO: Add BibTeX citation
@@ -42,8 +42,8 @@ Returns:
 Examples:
     >>> nlu_metric = datasets.load_metric("nlu_metric.py")
-    >>> predictions = ["[binary]-[thank]-[general]-[]", "[non-categorical]-[inform]-[taxi]-[leave at]-[17:15]"]
-    >>> references = ["[binary]-[thank]-[general]-[]", "[non-categorical]-[inform]-[train]-[leave at]-[17:15]"]
+    >>> predictions = ["[binary][thank][general][]", "[non-categorical][inform][taxi][leave at][17:15]"]
+    >>> references = ["[binary][thank][general][]", "[non-categorical][inform][train][leave at][17:15]"]
     >>> results = nlu_metric.compute(predictions=predictions, references=references)
     >>> print(results)
     {'seq_em': 0.5, 'accuracy': 0.5,
@@ -70,36 +70,6 @@ class NLUMetrics(datasets.Metric):
             })
         )
-    def deserialize_dialogue_acts(self, das_seq):
-        dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []}
-        if len(das_seq) == 0:
-            return dialogue_acts
-        da_seqs = das_seq.split('];[')
-        for i, da_seq in enumerate(da_seqs):
-            if len(da_seq) == 0:
-                continue
-            if i == 0:
-                if da_seq[0] == '[':
-                    da_seq = da_seq[1:]
-            if i == len(da_seqs) - 1:
-                if da_seq[-1] == ']':
-                    da_seq = da_seq[:-1]
-            da = da_seq.split('][')
-            if len(da) == 0:
-                continue
-            da_type = da[0]
-            if len(da) == 5 and da_type in ['categorical', 'non-categorical']:
-                dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]})
-            elif len(da) == 4 and da_type == 'binary':
-                dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]})
-            else:
-                # invalid da format, skip
-                pass
-        return dialogue_acts
     def _compute(self, predictions, references):
         """Returns the scores: sequence exact match, dialog acts accuracy and f1"""
         seq_em = []
@@ -108,8 +78,8 @@ class NLUMetrics(datasets.Metric):
         for prediction, reference in zip(predictions, references):
             seq_em.append(prediction.strip()==reference.strip())
-            pred_da = self.deserialize_dialogue_acts(prediction)
-            gold_da = self.deserialize_dialogue_acts(reference)
+            pred_da = deserialize_dialogue_acts(prediction)
+            gold_da = deserialize_dialogue_acts(reference)
             flag = True
             for da_type in ['binary', 'categorical', 'non-categorical']:
                 if da_type == 'binary':
...
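
The remainder of _compute (elided above) turns the accumulated TP/FP/FN counts into precision, recall and F1. A minimal sketch of that standard computation, not the verbatim metric code:

def f1_from_counts(TP, FP, FN):
    # precision/recall guarded against empty denominators
    precision = TP / (TP + FP) if TP + FP else 0.0
    recall = TP / (TP + FN) if TP + FN else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0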
@@ -66,3 +66,5 @@ python -m torch.distributed.launch \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
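
The same two-line addition is made to each of the dataset-specific run scripts touched by this commit, so every evaluation run finishes by merging its generated predictions back into the test data for unified NLU evaluation.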
convlab2/base_models/t5/nlu/serialization.py (new file)

def serialize_dialogue_acts(dialogue_acts):
    da_seqs = []
    for da_type in dialogue_acts:
        for da in dialogue_acts[da_type]:
            intent, domain, slot = da['intent'], da['domain'], da['slot']
            if da_type == 'binary':
                da_seq = f'[{da_type}][{intent}][{domain}][{slot}]'
            else:
                value = da['value']
                da_seq = f'[{da_type}][{intent}][{domain}][{slot}][{value}]'
            da_seqs.append(da_seq)
    return ';'.join(da_seqs)


def deserialize_dialogue_acts(das_seq):
    dialogue_acts = {'binary': [], 'categorical': [], 'non-categorical': []}
    if len(das_seq) == 0:
        return dialogue_acts
    da_seqs = das_seq.split('];[')
    for i, da_seq in enumerate(da_seqs):
        if len(da_seq) == 0:
            continue
        # strip the leading '[' of the first act and the trailing ']' of the last,
        # which survive the split on '];['
        if i == 0:
            if da_seq[0] == '[':
                da_seq = da_seq[1:]
        if i == len(da_seqs) - 1:
            if da_seq[-1] == ']':
                da_seq = da_seq[:-1]
        da = da_seq.split('][')
        if len(da) == 0:
            continue
        da_type = da[0]
        if len(da) == 5 and da_type in ['categorical', 'non-categorical']:
            dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3], 'value': da[4]})
        elif len(da) == 4 and da_type == 'binary':
            dialogue_acts[da_type].append({'intent': da[1], 'domain': da[2], 'slot': da[3]})
        else:
            # invalid da format, skip
            pass
    return dialogue_acts


def equal_da_seq(dialogue_acts, das_seq):
    predict_dialogue_acts = deserialize_dialogue_acts(das_seq)
    for da_type in ['binary', 'categorical', 'non-categorical']:
        das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in dialogue_acts[da_type]])
        predict_das = sorted([(da['intent'], da['domain'], da['slot'], da.get('value', '')) for da in predict_dialogue_acts[da_type]])
        if das != predict_das:
            return False
    return True
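
As a quick round-trip illustration of the serialized format these helpers define (the dialogue acts below are hypothetical sample values):

from convlab2.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq

das = {'binary': [{'intent': 'thank', 'domain': 'general', 'slot': ''}],
       'categorical': [],
       'non-categorical': [{'intent': 'inform', 'domain': 'taxi', 'slot': 'leave at', 'value': '17:15'}]}
seq = serialize_dialogue_acts(das)
# '[binary][thank][general][];[non-categorical][inform][taxi][leave at][17:15]'
assert deserialize_dialogue_acts(seq) == das
assert equal_da_seq(das, seq)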
@@ -17,6 +17,8 @@ def evaluate(predict_result):
             else:
                 predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['predictions']['dialogue_acts'][da_type]]
                 labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['dialogue_acts'][da_type]]
+            predicts = sorted(list(set(predicts)))
+            labels = sorted(list(set(labels)))
             for ele in predicts:
                 if ele in labels:
                     metrics['overall']['TP'] += 1
@@ -28,7 +30,7 @@ def evaluate(predict_result):
                 if ele not in predicts:
                     metrics['overall']['FN'] += 1
                     metrics[da_type]['FN'] += 1
-            flag &= (sorted(predicts)==sorted(labels))
+            flag &= (predicts==labels)
             acc.append(flag)

     for metric in metrics:
...
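
A small hypothetical example of why the added de-duplication matters: if the model emits the same act twice, each copy matches in the TP loop (inflating TP) and sorted(predicts) differs from sorted(labels), so the sample is marked wrong even though the set of predicted acts is correct. After the set() normalization, both sides reduce to the same sorted list:

predicts = [('inform', 'taxi', 'leave at', '17:15'), ('inform', 'taxi', 'leave at', '17:15')]
labels = [('inform', 'taxi', 'leave at', '17:15')]
assert sorted(set(predicts)) == sorted(set(labels))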