diff --git a/convlab2/nlu/jointBERT/unified_datasets/preprocess.py b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py index 055e0724ab4c8d5d8f467a6d11354426974da3d8..ca942b38f039abc449dcc9c80ba1ab352aac2483 100755 --- a/convlab2/nlu/jointBERT/unified_datasets/preprocess.py +++ b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py @@ -1,7 +1,7 @@ import json import os from collections import Counter -from convlab2.util import load_dataset, load_ontology, load_nlu_data +from convlab2.util import load_dataset, load_nlu_data from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer from tqdm import tqdm diff --git a/convlab2/nlu/milu/README.md b/convlab2/nlu/milu/README.md index b1c6bf5a130215e89c4b6145a73c720fca86be18..4cf508e1b0b45f69672ba5e0802a50ff1a2f5904 100755 --- a/convlab2/nlu/milu/README.md +++ b/convlab2/nlu/milu/README.md @@ -5,16 +5,41 @@ MILU is a joint neural model that allows you to simultaneously predict multiple ## Example usage We based our implementation on the [AllenNLP library](https://github.com/allenai/allennlp). For an introduction to this library, you should check [these tutorials](https://allennlp.org/tutorials). +To use this model, you need to additionally install `overrides==4.1.2, allennlp==0.9.0` and use `python>=3.6,<=3.8`. + +### On MultiWOZ dataset + ```bash -$ PYTHONPATH=../../.. python train.py multiwoz/configs/[base|context3].jsonnet -s serialization_dir -$ PYTHONPATH=../../.. python evaluate.py serialization_dir/model.tar.gz {test_file} --cuda-device {CUDA_DEVICE} +$ python train.py multiwoz/configs/[base|context3].jsonnet -s serialization_dir +$ python evaluate.py serialization_dir/model.tar.gz {test_file} --cuda-device {CUDA_DEVICE} ``` If you want to perform end-to-end evaluation, you can include the trained model by adding the model path (serialization_dir/model.tar.gz) to your ConvLab spec file. -## Data +#### Data We use the multiwoz data (data/multiwoz/[train|val|test].json.zip). +### MILU on datasets in unified format +We support training MILU on datasets that are in our unified format. + +- For **non-categorical** dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values. +- For **categorical** and **binary** dialogue acts whose values may not be presented in the utterances, we treat them as **intents** of the utterances. + +Takes MultiWOZ 2.1 (unified format) as an example, +```bash +$ python train.py unified_datasets/configs/multiwoz21_user_context3.jsonnet -s serialization_dir +$ python evaluate.py serialization_dir/model.tar.gz test --cuda-device {CUDA_DEVICE} +``` +Note that the config file is different from the above. You should set: +- `"use_unified_datasets": true` in `dataset_reader` and `model` +- `"dataset_name": "multiwoz21"` in `dataset_reader` +- `"train_data_path": "train"` +- `"validation_data_path": "validation"` +- `"test_data_path": "test"` + +## Predict +See `nlu.py` under `multiwoz` and `unified_datasets` directories. + ## References ``` @inproceedings{lee2019convlab, diff --git a/convlab2/nlu/milu/dai_f1_measure.py b/convlab2/nlu/milu/dai_f1_measure.py index f82d9acd627192d3239b24bc3456614aac85a5e7..4bb7ec868259dd4c09351d1edc348bb50caa8542 100755 --- a/convlab2/nlu/milu/dai_f1_measure.py +++ b/convlab2/nlu/milu/dai_f1_measure.py @@ -9,7 +9,7 @@ from allennlp.training.metrics.metric import Metric class DialogActItemF1Measure(Metric): """ """ - def __init__(self) -> None: + def __init__(self, use_unified_datasets) -> None: """ Parameters ---------- @@ -18,6 +18,7 @@ class DialogActItemF1Measure(Metric): self._true_positives = 0 self._false_positives = 0 self._false_negatives = 0 + self.use_unified_datasets = use_unified_datasets def __call__(self, @@ -32,17 +33,36 @@ class DialogActItemF1Measure(Metric): A tensor of integer class label of shape (batch_size, sequence_length). It must be the same shape as the ``predictions`` tensor without the ``num_classes`` dimension. """ - for prediction, gold_label in zip(predictions, gold_labels): - for dat in prediction: - for sv in prediction[dat]: - if dat not in gold_label or sv not in gold_label[dat]: - self._false_positives += 1 + if self.use_unified_datasets: + for prediction, gold_label in zip(predictions, gold_labels): + for da_type in ['non-categorical', 'categorical', 'binary']: + if da_type == 'binary': + predicts = [(x['intent'], x['domain'], x['slot']) for x in prediction[da_type]] + labels = [(x['intent'], x['domain'], x['slot']) for x in gold_label[da_type]] else: - self._true_positives += 1 - for dat in gold_label: - for sv in gold_label[dat]: - if dat not in prediction or sv not in prediction[dat]: - self._false_negatives += 1 + predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in prediction[da_type]] + labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in gold_label[da_type]] + + for ele in predicts: + if ele in labels: + self._true_positives += 1 + else: + self._false_positives += 1 + for ele in labels: + if ele not in predicts: + self._false_negatives += 1 + else: + for prediction, gold_label in zip(predictions, gold_labels): + for dat in prediction: + for sv in prediction[dat]: + if dat not in gold_label or sv not in gold_label[dat]: + self._false_positives += 1 + else: + self._true_positives += 1 + for dat in gold_label: + for sv in gold_label[dat]: + if dat not in prediction or sv not in prediction[dat]: + self._false_negatives += 1 def get_metric(self, reset: bool = False): diff --git a/convlab2/nlu/milu/dataset_reader.py b/convlab2/nlu/milu/dataset_reader.py index 5e00af04e7fe6c13ddbb21d60f22c51d5cbbc106..35f71903ab4a269b0f9e5d3cd208d78e48278349 100755 --- a/convlab2/nlu/milu/dataset_reader.py +++ b/convlab2/nlu/milu/dataset_reader.py @@ -13,6 +13,8 @@ from allennlp.data.fields import TextField, SequenceLabelField, MultiLabelField, from allennlp.data.instance import Instance from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer from allennlp.data.tokenizers import Token +from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer +from convlab2.util import load_dataset, load_nlu_data from overrides import overrides from convlab2.util.file_util import cached_path @@ -45,6 +47,8 @@ class MILUDatasetReader(DatasetReader): def __init__(self, context_size: int = 0, agent: str = None, + use_unified_datasets: bool = False, + dataset_name: str = None, random_context_size: bool = True, token_delimiter: str = None, token_indexers: Dict[str, TokenIndexer] = None, @@ -52,81 +56,150 @@ class MILUDatasetReader(DatasetReader): super().__init__(lazy) self._context_size = context_size self._agent = agent + self.use_unified_datasets = use_unified_datasets + if self.use_unified_datasets: + self._dataset_name = dataset_name + self._dataset = load_dataset(self._dataset_name) self._random_context_size = random_context_size self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} self._token_delimiter = token_delimiter + self._sent_tokenizer = PunktSentenceTokenizer() + self._word_tokenizer = TreebankWordTokenizer() @overrides def _read(self, file_path): - # if `file_path` is a URL, redirect to the cache - file_path = cached_path(file_path) - - if file_path.endswith("zip"): - archive = zipfile.ZipFile(file_path, "r") - data_file = archive.open(os.path.basename(file_path)[:-4]) - else: - data_file = open(file_path, "r") - - logger.info("Reading instances from lines in file at: %s", file_path) - - dialogs = json.load(data_file) - - for dial_name in dialogs: - dialog = dialogs[dial_name]["log"] - context_tokens_list = [] - for i, turn in enumerate(dialog): - if self._agent and self._agent == "user" and i % 2 == 1: - context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"]) - continue - if self._agent and self._agent == "system" and i % 2 == 0: - context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"]) - continue - - tokens = turn["text"].split() - - dialog_act = {} - for dacts in turn["span_info"]: - if dacts[0] not in dialog_act: - dialog_act[dacts[0]] = [] - dialog_act[dacts[0]].append([dacts[1], " ".join(tokens[dacts[3]: dacts[4]+1])]) - - spans = turn["span_info"] - tags = [] - for j in range(len(tokens)): - for span in spans: - if j == span[3]: - tags.append("B-"+span[0]+"+"+span[1]) - break - if j > span[3] and j <= span[4]: - tags.append("I-"+span[0]+"+"+span[1]) - break - else: - tags.append("O") + if self.use_unified_datasets: + data_split = file_path + logger.info("Reading instances from unified dataset %s[%s]", self._dataset_name, data_split) + + data = load_nlu_data(self._dataset, data_split=data_split, speaker=self._agent, use_context=self._context_size>0, context_window_size=self._context_size)[data_split] + + for sample in data: + utterance = sample['utterance'] + sentences = self._sent_tokenizer.tokenize(utterance) + sent_spans = self._sent_tokenizer.span_tokenize(utterance) + tokens = [token for sent in sentences for token in self._word_tokenizer.tokenize(sent)] + token_spans = [(sent_span[0]+token_span[0], sent_span[0]+token_span[1]) for sent, sent_span in zip(sentences, sent_spans) for token_span in self._word_tokenizer.span_tokenize(sent)] + tags = ['O'] * len(tokens) + + for da in sample['dialogue_acts']['non-categorical']: + if 'start' not in da: + # skip da that doesn't have span annotation + continue + char_start = da['start'] + char_end = da['end'] + word_start, word_end = -1, -1 + for i, token_span in enumerate(token_spans): + if char_start == token_span[0]: + word_start = i + if char_end == token_span[1]: + word_end = i + 1 + if word_start == -1 and word_end == -1: + # char span does not match word, maybe there is an error in the annotation, skip + print('char span does not match word, skipping') + print('\t', 'utteance:', utterance) + print('\t', 'value:', utterance[char_start: char_end]) + print('\t', 'da:', da, '\n') + continue + intent, domain, slot = da['intent'], da['domain'], da['slot'] + tags[word_start] = f"B-{intent}+{domain}+{slot}" + for i in range(word_start+1, word_end): + tags[i] = f"I-{intent}+{domain}+{slot}" intents = [] - for dacts in turn["dialog_act"]: - for dact in turn["dialog_act"][dacts]: - if dacts not in dialog_act or dact[0] not in [sv[0] for sv in dialog_act[dacts]]: - if dact[1] in ["none", "?", "yes", "no", "dontcare", "do nt care", "do n't care"]: - intents.append(dacts+"+"+dact[0]+"*"+dact[1]) - - for dacts in turn["dialog_act"]: - for dact in turn["dialog_act"][dacts]: - if dacts not in dialog_act: - dialog_act[dacts] = turn["dialog_act"][dacts] - break - elif dact[0] not in [sv[0] for sv in dialog_act[dacts]]: - dialog_act[dacts].append(dact) + for da in sample['dialogue_acts']['categorical']: + intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'].strip().lower() + intent = str((intent, domain, slot, value)) + intents.append(intent) + for da in sample['dialogue_acts']['binary']: + intent, domain, slot = da['intent'], da['domain'], da['slot'] + intent = str((intent, domain, slot)) + intents.append(intent) + + wrapped_tokens = [Token(token) for token in tokens] + wrapped_context_tokens = [] num_context = random.randint(0, self._context_size) if self._random_context_size else self._context_size - if len(context_tokens_list) > 0 and num_context > 0: - wrapped_context_tokens = [Token(token) for context_tokens in context_tokens_list[-num_context:] for token in context_tokens] + if num_context > 0 and len(sample['context']) > 0: + for utt in sample['context']: + for sent in self._sent_tokenizer.tokenize(utt['utterance']): + for token in self._word_tokenizer.tokenize(sent): + wrapped_context_tokens.append(Token(token)) + wrapped_context_tokens.append(Token("SENT_END")) else: wrapped_context_tokens = [Token("SENT_END")] - wrapped_tokens = [Token(token) for token in tokens] - context_tokens_list.append(tokens + ["SENT_END"]) - yield self.text_to_instance(wrapped_context_tokens, wrapped_tokens, tags, intents, dialog_act) + yield self.text_to_instance(wrapped_context_tokens, wrapped_tokens, tags, intents, sample['dialogue_acts']) + else: + # if `file_path` is a URL, redirect to the cache + file_path = cached_path(file_path) + + if file_path.endswith("zip"): + archive = zipfile.ZipFile(file_path, "r") + data_file = archive.open(os.path.basename(file_path)[:-4]) + else: + data_file = open(file_path, "r") + + logger.info("Reading instances from lines in file at: %s", file_path) + + dialogs = json.load(data_file) + + for dial_name in dialogs: + dialog = dialogs[dial_name]["log"] + context_tokens_list = [] + for i, turn in enumerate(dialog): + if self._agent and self._agent == "user" and i % 2 == 1: + context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"]) + continue + if self._agent and self._agent == "system" and i % 2 == 0: + context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"]) + continue + + tokens = turn["text"].split() + + dialog_act = {} + for dacts in turn["span_info"]: + if dacts[0] not in dialog_act: + dialog_act[dacts[0]] = [] + dialog_act[dacts[0]].append([dacts[1], " ".join(tokens[dacts[3]: dacts[4]+1])]) + + spans = turn["span_info"] + tags = [] + for j in range(len(tokens)): + for span in spans: + if j == span[3]: + tags.append("B-"+span[0]+"+"+span[1]) + break + if j > span[3] and j <= span[4]: + tags.append("I-"+span[0]+"+"+span[1]) + break + else: + tags.append("O") + + intents = [] + for dacts in turn["dialog_act"]: + for dact in turn["dialog_act"][dacts]: + if dacts not in dialog_act or dact[0] not in [sv[0] for sv in dialog_act[dacts]]: + if dact[1] in ["none", "?", "yes", "no", "dontcare", "do nt care", "do n't care"]: + intents.append(dacts+"+"+dact[0]+"*"+dact[1]) + + for dacts in turn["dialog_act"]: + for dact in turn["dialog_act"][dacts]: + if dacts not in dialog_act: + dialog_act[dacts] = turn["dialog_act"][dacts] + break + elif dact[0] not in [sv[0] for sv in dialog_act[dacts]]: + dialog_act[dacts].append(dact) + + num_context = random.randint(0, self._context_size) if self._random_context_size else self._context_size + if len(context_tokens_list) > 0 and num_context > 0: + wrapped_context_tokens = [Token(token) for context_tokens in context_tokens_list[-num_context:] for token in context_tokens] + else: + wrapped_context_tokens = [Token("SENT_END")] + wrapped_tokens = [Token(token) for token in tokens] + context_tokens_list.append(tokens + ["SENT_END"]) + + yield self.text_to_instance(wrapped_context_tokens, wrapped_tokens, tags, intents, dialog_act) def text_to_instance(self, context_tokens: List[Token], tokens: List[Token], tags: List[str] = None, diff --git a/convlab2/nlu/milu/model.py b/convlab2/nlu/milu/model.py index e4831926bd08c1b2df6b924f3f65f65d097e8b87..8d5031c7891e3abb010ff0b72ff4a42489dc39a6 100755 --- a/convlab2/nlu/milu/model.py +++ b/convlab2/nlu/milu/model.py @@ -57,6 +57,7 @@ class MILU(Model): feedforward: Optional[FeedForward] = None, label_encoding: Optional[str] = None, include_start_end_transitions: bool = True, + use_unified_datasets: bool = False, crf_decoding: bool = False, constrain_crf_decoding: bool = None, focal_loss_gamma: float = None, @@ -83,6 +84,7 @@ class MILU(Model): self.tag_encoder = intent_encoder self._feedforward = feedforward self._verbose_metrics = verbose_metrics + self.use_unified_datasets = use_unified_datasets self.rl = False if attention: @@ -164,7 +166,7 @@ class MILU(Model): self._f1_metric = SpanBasedF1Measure(vocab, tag_namespace=sequence_label_namespace, label_encoding=label_encoding) - self._dai_f1_metric = DialogActItemF1Measure() + self._dai_f1_metric = DialogActItemF1Measure(self.use_unified_datasets) check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), "text field embedding dim", "encoder input dim") @@ -355,29 +357,64 @@ class MILU(Model): for i, tags in enumerate(output_dict["tags"]): seq_len = len(output_dict["words"][i]) spans = bio_tags_to_spans(tags[:seq_len]) - dialog_act = {} - for span in spans: - domain_act = span[0].split("+")[0] - slot = span[0].split("+")[1] - value = " ".join(output_dict["words"][i][span[1][0]:span[1][1]+1]) - if domain_act not in dialog_act: - dialog_act[domain_act] = [[slot, value]] - else: - dialog_act[domain_act].append([slot, value]) - for intent in output_dict["intents"][i]: - if "+" in intent: - if "*" in intent: - intent, value = intent.split("*", 1) + if self.use_unified_datasets: + dialog_act = { + 'categorical': [], + 'non-categorical': [], + 'binary': [] + } + for span in spans: + intent, domain, slot = span[0].split("+") + value = " ".join(output_dict["words"][i][span[1][0]:span[1][1]+1]) + dialog_act['non-categorical'].append({ + 'intent': intent, + 'domain': domain, + 'slot': slot, + 'value': value + }) + + for intent in output_dict["intents"][i]: + intent = eval(intent) + if len(intent) == 3: + dialog_act['binary'].append({ + 'intent': intent[0], + 'domain': intent[1], + 'slot': intent[2] + }) else: - value = "?" - domain_act = intent.split("+")[0] + assert len(intent) == 4 + dialog_act['categorical'].append({ + 'intent': intent[0], + 'domain': intent[1], + 'slot': intent[2], + 'value': intent[3] + }) + output_dict["dialog_act"].append(dialog_act) + + else: + dialog_act = {} + for span in spans: + domain_act = span[0].split("+")[0] + slot = span[0].split("+")[1] + value = " ".join(output_dict["words"][i][span[1][0]:span[1][1]+1]) if domain_act not in dialog_act: - dialog_act[domain_act] = [[intent.split("+")[1], value]] + dialog_act[domain_act] = [[slot, value]] + else: + dialog_act[domain_act].append([slot, value]) + for intent in output_dict["intents"][i]: + if "+" in intent: + if "*" in intent: + intent, value = intent.split("*", 1) + else: + value = "?" + domain_act = intent.split("+")[0] + if domain_act not in dialog_act: + dialog_act[domain_act] = [[intent.split("+")[1], value]] + else: + dialog_act[domain_act].append([intent.split("+")[1], value]) else: - dialog_act[domain_act].append([intent.split("+")[1], value]) - else: - dialog_act[intent] = [["none", "none"]] - output_dict["dialog_act"].append(dialog_act) + dialog_act[intent] = [["none", "none"]] + output_dict["dialog_act"].append(dialog_act) return output_dict diff --git a/convlab2/nlu/milu/unified_datasets/__init__.py b/convlab2/nlu/milu/unified_datasets/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..902c9a23af7fe1f0c8423394085a0f4ea0bb55f9 --- /dev/null +++ b/convlab2/nlu/milu/unified_datasets/__init__.py @@ -0,0 +1 @@ +from convlab2.nlu.milu.unified_datasets.nlu import MILU diff --git a/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user.jsonnet b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user.jsonnet new file mode 100755 index 0000000000000000000000000000000000000000..858a57a19ecbe2acd59f02465d79c1d852341f60 --- /dev/null +++ b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user.jsonnet @@ -0,0 +1,104 @@ +{ + "dataset_reader": { + "type": "milu", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + }, + "token_characters": { + "type": "characters", + "min_padding_length": 3 + }, + }, + "context_size": 0, + "agent": "user", + "use_unified_datasets": true, + "dataset_name": "multiwoz21", + "random_context_size": false + }, + "train_data_path": "train", + "validation_data_path": "validation", + "test_data_path": "test", + "model": { + "type": "milu", + "label_encoding": "BIO", + "use_unified_datasets": true, + "dropout": 0.3, + "include_start_end_transitions": false, + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 50, + "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.50d.txt.gz", + "trainable": true + }, + "token_characters": { + "type": "character_encoding", + "embedding": { + "embedding_dim": 16 + }, + "encoder": { + "type": "cnn", + "embedding_dim": 16, + "num_filters": 128, + "ngram_filter_sizes": [3], + "conv_layer_activation": "relu" + } + } + } + }, + "encoder": { + "type": "lstm", + "input_size": 178, + "hidden_size": 200, + "num_layers": 1, + "dropout": 0.5, + "bidirectional": true + }, + "intent_encoder": { + "type": "lstm", + "input_size": 400, + "hidden_size": 200, + "num_layers": 1, + "dropout": 0.5, + "bidirectional": true + }, + "attention": { + "type": "bilinear", + "vector_dim": 400, + "matrix_dim": 400 + }, + "context_for_intent": true, + "context_for_tag": false, + "attention_for_intent": false, + "attention_for_tag": false, + "regularizer": [ + [ + "scalar_parameters", + { + "type": "l2", + "alpha": 0.1 + } + ] + ] + }, + "iterator": { + "type": "basic", + "batch_size": 64 + }, + "trainer": { + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "validation_metric": "+f1-measure", + "num_serialized_models_to_keep": 3, + "num_epochs": 40, + "grad_norm": 5.0, + "patience": 75, + "cuda_device": 4 + }, + "evaluate_on_test": true +} diff --git a/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user_context3.jsonnet b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user_context3.jsonnet new file mode 100755 index 0000000000000000000000000000000000000000..1ce1d0d0b4135a2ee03d1ba7dda084f5c22da0c0 --- /dev/null +++ b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user_context3.jsonnet @@ -0,0 +1,104 @@ +{ + "dataset_reader": { + "type": "milu", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + }, + "token_characters": { + "type": "characters", + "min_padding_length": 3 + }, + }, + "context_size": 3, + "agent": "user", + "use_unified_datasets": true, + "dataset_name": "multiwoz21", + "random_context_size": false + }, + "train_data_path": "train", + "validation_data_path": "validation", + "test_data_path": "test", + "model": { + "type": "milu", + "label_encoding": "BIO", + "use_unified_datasets": true, + "dropout": 0.3, + "include_start_end_transitions": false, + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 50, + "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.50d.txt.gz", + "trainable": true + }, + "token_characters": { + "type": "character_encoding", + "embedding": { + "embedding_dim": 16 + }, + "encoder": { + "type": "cnn", + "embedding_dim": 16, + "num_filters": 128, + "ngram_filter_sizes": [3], + "conv_layer_activation": "relu" + } + } + } + }, + "encoder": { + "type": "lstm", + "input_size": 178, + "hidden_size": 200, + "num_layers": 1, + "dropout": 0.5, + "bidirectional": true + }, + "intent_encoder": { + "type": "lstm", + "input_size": 400, + "hidden_size": 200, + "num_layers": 1, + "dropout": 0.5, + "bidirectional": true + }, + "attention": { + "type": "bilinear", + "vector_dim": 400, + "matrix_dim": 400 + }, + "context_for_intent": true, + "context_for_tag": false, + "attention_for_intent": false, + "attention_for_tag": false, + "regularizer": [ + [ + "scalar_parameters", + { + "type": "l2", + "alpha": 0.1 + } + ] + ] + }, + "iterator": { + "type": "basic", + "batch_size": 64 + }, + "trainer": { + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "validation_metric": "+f1-measure", + "num_serialized_models_to_keep": 3, + "num_epochs": 40, + "grad_norm": 5.0, + "patience": 75, + "cuda_device": 0 + }, + "evaluate_on_test": true +} diff --git a/convlab2/nlu/milu/unified_datasets/nlu.py b/convlab2/nlu/milu/unified_datasets/nlu.py new file mode 100755 index 0000000000000000000000000000000000000000..8d5e96852c6369a4db61702c241e2eb4e4cee1a5 --- /dev/null +++ b/convlab2/nlu/milu/unified_datasets/nlu.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +""" + +import os +from pprint import pprint +import torch +from allennlp.common.checks import check_for_gpu +from allennlp.data import DatasetReader +from allennlp.models.archival import load_archive +from allennlp.data.tokenizers import Token + +from convlab2.util.file_util import cached_path +from convlab2.nlu.milu import dataset_reader, model +from convlab2.nlu.nlu import NLU +from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer + +DEFAULT_CUDA_DEVICE = -1 +DEFAULT_DIRECTORY = "models" +DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "milu_multiwoz_all_context.tar.gz") + +class MILU(NLU): + """Multi-intent language understanding model.""" + + def __init__(self, + archive_file, + cuda_device, + model_file, + context_size): + """ Constructor for NLU class. """ + + self.context_size = context_size + cuda_device = 0 if torch.cuda.is_available() else DEFAULT_CUDA_DEVICE + check_for_gpu(cuda_device) + + if not os.path.isfile(archive_file): + if not model_file: + raise Exception("No model for MILU is specified!") + + archive_file = cached_path(model_file) + + archive = load_archive(archive_file, + cuda_device=cuda_device) + self.sent_tokenizer = PunktSentenceTokenizer() + self.word_tokenizer = TreebankWordTokenizer() + + dataset_reader_params = archive.config["dataset_reader"] + self.dataset_reader = DatasetReader.from_params(dataset_reader_params) + self.model = archive.model + self.model.eval() + + + def predict(self, utterance, context=list()): + """ + Predict the dialog act of a natural language utterance and apply error model. + Args: + utterance (str): A natural language utterance. + Returns: + output (dict): The dialog act of utterance. + """ + if len(utterance) == 0: + return [] + + if self.context_size > 0 and len(context) > 0: + context_tokens = [] + for utt in context[-self.context_size:]: + for sent in self.sent_tokenizer.tokenize(utt): + for token in self.word_tokenizer.tokenize(sent): + context_tokens.append(Token(token)) + context_tokens.append(Token("SENT_END")) + else: + context_tokens = [Token("SENT_END")] + sentences = self.sent_tokenizer.tokenize(utterance) + tokens = [Token(token) for sent in sentences for token in self.word_tokenizer.tokenize(sent)] + instance = self.dataset_reader.text_to_instance(context_tokens, tokens) + outputs = self.model.forward_on_instance(instance) + + tuples = [] + for da_type in outputs['dialog_act']: + for da in outputs['dialog_act'][da_type]: + tuples.append([da['intent'], da['domain'], da['slot'], da.get('value','')]) + return tuples + + +if __name__ == "__main__": + nlu = MILU(archive_file='../output/multiwoz21_user/model.tar.gz', cuda_device=3, model_file=None, context_size=3) + test_utterances = [ + "What type of accommodations are they. No , i just need their address . Can you tell me if the hotel has internet available ?", + "What type of accommodations are they.", + "No , i just need their address .", + "Can you tell me if the hotel has internet available ?", + "yes. it should be moderately priced.", + "i want to book a table for 6 at 18:45 on thursday", + "i will be departing out of stevenage.", + "What is the name of attraction ?", + "Can I get the name of restaurant?", + "Can I get the address and phone number of the restaurant?", + "do you have a specific area you want to stay in?" + ] + for utt in test_utterances: + print(utt) + pprint(nlu.predict(utt)) diff --git a/setup.py b/setup.py index 62cbf1bd880e55e5dc13b26322dcf67ce6b1b73b..900b92f3912b72831bb027da851da556a9d8ad3d 100755 --- a/setup.py +++ b/setup.py @@ -46,7 +46,6 @@ setup( 'datasets>=1.8', 'seqeval', 'spacy', - 'allennlp', 'simplejson', 'unidecode', 'jieba',