diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py index 508d06b56080f72146d65a98286a0fd099f9b4c2..a8915301153f8e824e0a2ed91cd6eb9cd34e2605 100755 --- a/convlab/dialog_agent/env.py +++ b/convlab/dialog_agent/env.py @@ -49,6 +49,9 @@ class Environment(): dialog_act = self.sys_nlu.predict( observation) if self.sys_nlu else observation self.sys_dst.state['user_action'] = dialog_act + self.sys_dst.state['history'].append(["sys", model_response]) + self.sys_dst.state['history'].append(["user", observation]) + state = self.sys_dst.update(dialog_act) self.sys_dst.state['history'].append(["sys", model_response]) self.sys_dst.state['history'].append(["usr", observation]) diff --git a/convlab/dst/trippy/README.md b/convlab/dst/trippy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..178c578287152c894c8727156a8db6e3c255b549 --- /dev/null +++ b/convlab/dst/trippy/README.md @@ -0,0 +1,59 @@ +# Introduction + +This is the TripPy DST module for ConvLab-3. + +## Supported encoders + +* RoBERTa +* BERT (full support w.i.p.) +* ELECTRA (full support w.i.p.) + +## Supported datasets + +* MultiWOZ 2.X +* Unified Data Format + +## Requirements + +transformers (tested: 4.18.0) +torch (tested: 1.8.0) + +# Parameters + +``` +model_type # Default: "roberta", Type of the model (Supported: "roberta", "bert", "electra") +model_name # Default: "roberta-base", Name of the model (Use -h to print a list of names) +model_path # Path to a model checkpoint +dataset_name # Default: "multiwoz21", Name of the dataset the model was trained on and/or is being applied to +local_files_only # Default: False, Set to True to load local files only. Useful for offline systems +nlu_usr_config # Path to a NLU config file. Only needed for internal evaluation +nlu_sys_config # Path to a NLU config file. Only needed for internal evaluation +nlu_usr_path # Path to a NLU model file. Only needed for internal evaluation +nlu_sys_path # Path to a NLU model file. 
Only needed for internal evaluation +no_eval # Default: True, Set to False if internal evaluation should be conducted +no_history # Default: False, Set to True if dialogue history should be omitted during inference +``` + +# Training + +TripPy can easily be trained for the abovementioned supported datasets using the original code in the official [TripPy repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public). Simply clone the code and run the appropriate DO.* script to train a TripPy DST. After training, set model_path to the preferred checkpoint to use TripPy in ConvLab-3. + +# Training and evaluation with PPO policy + +Switch to the directory: +``` +cd ../../policy/ppo +``` + +Edit trippy_config.json and trippy_config_eval.json accordingly, e.g., edit paths to model checkpoints. + +For training, run +``` +train.py --path trippy_config.json +``` + +For evaluation, set training epochs to 0. + +# Paper + +[TripPy: A Triple Copy Strategy for Value Independent Neural Dialog State Tracking](https://aclanthology.org/2020.sigdial-1.4/) diff --git a/convlab/dst/trippy/__init__.py b/convlab/dst/trippy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb36114aeabe0859fdae1c4230c147a7a40305ff --- /dev/null +++ b/convlab/dst/trippy/__init__.py @@ -0,0 +1 @@ +from convlab.dst.trippy.tracker import TRIPPY diff --git a/convlab/dst/trippy/dataset_interfacer.py b/convlab/dst/trippy/dataset_interfacer.py new file mode 100644 index 0000000000000000000000000000000000000000..51a692bbaa50349a70053f05933a0cea1f8246c3 --- /dev/null +++ b/convlab/dst/trippy/dataset_interfacer.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
import re
import logging


class DatasetInterfacer(object):
    """Generic, dataset-agnostic interfacer between TripPy and the unified data format (UDF).

    Every method of this base class is a pass-through; dataset specific
    subclasses override the class-level mapping tables and the normalization
    hooks.
    """

    # TripPy domain name -> UDF domain name.
    _domain_map_trippy_to_udf = {}
    # UDF domain name -> {TripPy slot name -> UDF slot name}.
    _slot_map_trippy_to_udf = {}
    # UDF domain name -> {UDF slot name -> generic referral phrase}.
    _generic_referral = {}

    def __init__(self):
        pass

    def map_trippy_to_udf(self, domain, slot):
        """Translate a TripPy (domain, slot) pair into UDF naming.

        Names without an entry in the mapping tables are returned unchanged.

        Returns:
            A ``(domain, slot)`` tuple in UDF naming.
        """
        d = self._domain_map_trippy_to_udf.get(domain, domain)
        # The slot table is keyed by the already-mapped domain name.
        s = self._slot_map_trippy_to_udf.get(d, {}).get(slot, slot)
        return d, s

    def get_generic_referral(self, domain, slot):
        """Return a generic phrase that refers to the given slot.

        Domains without a referral table yield ``"the <domain> <slot>"``.
        For domains that do have a table, the configured phrase is returned,
        or the bare UDF slot name if the slot has no entry.
        """
        d, s = self.map_trippy_to_udf(domain, slot)
        if d in self._generic_referral:
            # NOTE(review): the fallback here is the bare slot name, not the
            # "the <domain> <slot>" phrase used for unknown domains. Kept
            # as-is; confirm that this asymmetry is intended.
            return self._generic_referral[d].get(s, s)
        return "the %s %s" % (d, s)

    def normalize_values(self, text):
        """Normalize a predicted slot value. The generic version is a no-op."""
        return text

    def normalize_text(self, text):
        """Normalize an utterance. The generic version is a no-op."""
        return text

    def normalize_prediction(self, domain, slot, value, predictions=None, config=None,
                             class_predictions=None):
        """Post-process a single slot prediction. The generic version is a no-op.

        ``class_predictions`` is accepted (and ignored) so that callers can use
        the same keyword arguments as for dataset specific interfacers. It is
        appended after ``config`` to keep existing positional calls working;
        pass the optional arguments by keyword.
        """
        return value


class MultiwozInterfacer(DatasetInterfacer):
    """Interfacer for MultiWOZ: slot name mapping, referral phrases and normalization."""

    _slot_map_trippy_to_udf = {
        'hotel': {
            'pricerange': 'price range',
            'book_stay': 'book stay',
            'book_day': 'book day',
            'book_people': 'book people',
            'addr': 'address',
            'post': 'postcode',
            'price': 'price range',
            'people': 'book people'
        },
        'restaurant': {
            'pricerange': 'price range',
            'book_time': 'book time',
            'book_day': 'book day',
            'book_people': 'book people',
            'addr': 'address',
            'post': 'postcode',
            'price': 'price range',
            'people': 'book people'
        },
        'taxi': {
            'arriveBy': 'arrive by',
            'leaveAt': 'leave at',
            'arrive': 'arrive by',
            'leave': 'leave at',
            'car': 'type',
            'car type': 'type',
            'depart': 'departure',
            'dest': 'destination'
        },
        'train': {
            'arriveBy': 'arrive by',
            'leaveAt': 'leave at',
            'book_people': 'book people',
            'arrive': 'arrive by',
            'leave': 'leave at',
            'depart': 'departure',
            'dest': 'destination',
            'id': 'train id',
            'people': 'book people',
            'time': 'duration',
            'ticket': 'price',
            'trainid': 'train id'
        },
        'attraction': {
            'post': 'postcode',
            'addr': 'address',
            'fee': 'entrance fee',
            'price': 'entrance fee'
        },
        'general': {},
        'hospital': {
            'post': 'postcode',
            'addr': 'address'
        },
        'police': {
            'post': 'postcode',
            'addr': 'address'
        }
    }

    _generic_referral = {
        'hotel': {
            'name': 'the hotel',
            'area': 'same area as the hotel',
            'price range': 'in the same price range as the hotel'
        },
        'restaurant': {
            'name': 'the restaurant',
            'area': 'same area as the restaurant',
            'price range': 'in the same price range as the restaurant'
        },
        'attraction': {
            'name': 'the attraction',
            'area': 'same area as the attraction'
        }
    }

    # Values that are replaced by a digit when they make up the whole value.
    _TEXT_TO_NUM = {"zero": "0", "one": "1", "me": "1", "two": "2", "three": "3",
                    "four": "4", "five": "5", "six": "6", "seven": "7"}
    # Boolean class predictions for hotel-type and the values they stand for.
    _HOTEL_TYPE_FROM_BOOLEAN = {"yes": "hotel", "no": "guesthouse"}

    # Raw strings: "\s" and "\W" are invalid escape sequences in plain literals.
    _RE_AROUND_NON_WORD = re.compile(r"\s*(\W)\s*")
    _RE_PLURAL_GENITIVE = re.compile(r"s'([^s])")
    _RE_NON_WORD_RUN = re.compile(r"(\W+)")

    def normalize_values(self, text):
        """Lower-case a slot value, tidy its punctuation and map number words to digits."""
        text = text.lower()
        text = self._RE_AROUND_NON_WORD.sub(r"\1", text)  # Re-attach special characters
        text = self._RE_PLURAL_GENITIVE.sub(r"s' \1", text)  # Add space after plural genitive apostrophe
        return self._TEXT_TO_NUM.get(text, text)

    def normalize_text(self, text):
        """Lower-case an utterance and separate words and punctuation by single spaces."""
        pieces = (tok.strip() for tok in self._RE_NON_WORD_RUN.split(text.lower()))
        return ' '.join(tok for tok in pieces if tok)

    def normalize_prediction(self, domain, slot, value, predictions=None, class_predictions=None, config=None):
        """Post-process one slot prediction; only ``hotel``/``type`` is altered.

        Args:
            domain: UDF domain name of the slot.
            slot: UDF slot name.
            value: Predicted (already value-normalized) slot value.
            predictions: Optional dict of all value predictions of the turn,
                keyed by ``"<domain>-<slot>"``.
            class_predictions: Optional dict of predicted class indices, keyed
                like ``predictions``.
            config: Optional model config providing ``dst_class_types``.

        Returns:
            The value to store. For hotel-type, ``"yes"``/``"no"`` become
            ``"hotel"``/``"guesthouse"``, and ``"none"`` is returned whenever
            the hotfix below suppresses the prediction.
        """
        if domain != 'hotel' or slot != 'type':
            return value

        # Map Boolean predictions to regular predictions. The lookup is done
        # in one step so that neither mapping can overwrite the other.
        v = self._HOTEL_TYPE_FROM_BOOLEAN.get(value, value)

        # HOTFIX: Avoid overprediction of hotel type caused by ambiguous rule based user simulator NLG.
        # Each check is applied only if the information it needs was supplied.
        if predictions is not None and predictions.get('hotel-name', 'none') != 'none':
            return 'none'
        if class_predictions is not None and config is not None and 'hotel-none' in class_predictions:
            if config.dst_class_types[class_predictions['hotel-none']] == 'request':
                return 'none'
        return v


DATASET_INTERFACERS = {
    'multiwoz21': MultiwozInterfacer()
}


def create_dataset_interfacer(dataset_name="multiwoz21"):
    """Return the interfacer registered for ``dataset_name``.

    Registered interfacers are shared instances. For an unknown dataset a
    warning is logged and a fresh generic ``DatasetInterfacer`` is returned.
    """
    try:
        return DATASET_INTERFACERS[dataset_name]
    except KeyError:
        logging.warning("You attempt to create a dataset interfacer for an unknown dataset '%s'. "
                        "Creating generic dataset interfacer.", dataset_name)
        return DatasetInterfacer()
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import (BertModel, BertPreTrainedModel,
                          RobertaModel, RobertaPreTrainedModel,
                          ElectraModel, ElectraPreTrainedModel)

# Model type name -> pretrained base class the DST model derives from.
PARENT_CLASSES = {
    'bert': BertPreTrainedModel,
    'roberta': RobertaPreTrainedModel,
    'electra': ElectraPreTrainedModel
}

# Pretrained base class -> encoder class instantiated inside the DST model.
MODEL_CLASSES = {
    BertPreTrainedModel: BertModel,
    RobertaPreTrainedModel: RobertaModel,
    ElectraPreTrainedModel: ElectraModel
}


class ElectraPooler(nn.Module):
    """Pooling head (dense layer + tanh) applied to the first token's hidden state.

    Used by the DST model for ELECTRA, where the pooled representation is
    computed from the sequence output instead of being taken from the encoder
    outputs.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


def TransformerForDST(parent_name):
    """Build and return the TripPy DST model class for the given encoder family.

    Args:
        parent_name: Key of ``PARENT_CLASSES`` ('bert', 'roberta' or 'electra').

    Returns:
        A model class (not an instance) derived from the matching pretrained
        base class.

    Raises:
        ValueError: If ``parent_name`` is not a key of ``PARENT_CLASSES``.
    """
    if parent_name not in PARENT_CLASSES:
        raise ValueError("Unknown model %s" % (parent_name))

    class TransformerForDST(PARENT_CLASSES[parent_name]):
        """Encoder with per-slot class, token (span start/end) and refer heads."""

        def __init__(self, config):
            assert config.model_type in PARENT_CLASSES
            assert self.__class__.__bases__[0] in MODEL_CLASSES
            super(TransformerForDST, self).__init__(config)
            self.model_type = config.model_type
            self.slot_list = config.dst_slot_list
            self.class_types = config.dst_class_types
            self.class_labels = config.dst_class_labels
            self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
            self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
            self.stack_token_logits = config.dst_stack_token_logits
            self.class_aux_feats_inform = config.dst_class_aux_feats_inform
            self.class_aux_feats_ds = config.dst_class_aux_feats_ds
            self.class_loss_ratio = config.dst_class_loss_ratio

            # Only use refer loss if refer class is present in dataset.
            if 'refer' in self.class_types:
                self.refer_index = self.class_types.index('refer')
            else:
                self.refer_index = -1

            # Make sure this module has the same name as in the pretrained checkpoint you want to load!
            self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config))
            if self.model_type == "electra":
                self.pooler = ElectraPooler(config)

            self.dropout = nn.Dropout(config.dst_dropout_rate)
            self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)

            if self.class_aux_feats_inform:
                self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
            if self.class_aux_feats_ds:
                self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))

            aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds)  # second term is 0, 1 or 2

            for slot in self.slot_list:
                self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
                self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
                self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))

            self.init_weights()

        def forward(self,
                    input_ids,
                    input_mask=None,
                    segment_ids=None,
                    position_ids=None,
                    head_mask=None,
                    start_pos=None,
                    end_pos=None,
                    inform_slot_id=None,
                    refer_id=None,
                    class_label_id=None,
                    diag_state=None):
            """Encode the input and compute per-slot logits and, given labels, the loss.

            ``start_pos``, ``end_pos``, ``refer_id``, ``class_label_id``,
            ``inform_slot_id`` and ``diag_state`` are indexed per slot
            (presumably dicts keyed by slot name holding one tensor per slot -
            verify against the caller). The loss is only computed when
            ``class_label_id``, ``start_pos``, ``end_pos`` and ``refer_id``
            are all given. Note that the entries of ``start_pos`` and
            ``end_pos`` are squeezed/clamped in place.

            Returns:
                A tuple ``(total_loss, per_slot_per_example_loss,
                per_slot_class_logits, per_slot_start_logits,
                per_slot_end_logits, per_slot_refer_logits)`` followed by any
                encoder outputs beyond the first two. Without labels,
                ``total_loss`` is ``0`` and the per-example loss dict is empty.

            Raises:
                ValueError: If an auxiliary feature is enabled in the config
                    but the corresponding input (``inform_slot_id`` or
                    ``diag_state``) is missing.
            """
            outputs = getattr(self, self.model_type)(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=segment_ids,
                position_ids=position_ids,
                head_mask=head_mask
            )

            sequence_output = outputs[0]
            if self.model_type == "electra":
                pooled_output = self.pooler(sequence_output)
            else:
                pooled_output = outputs[1]

            sequence_output = self.dropout(sequence_output)
            pooled_output = self.dropout(pooled_output)

            inform_labels = None
            if inform_slot_id is not None:
                inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
            diag_state_labels = None
            if diag_state is not None:
                diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)

            # The input of the class and refer heads is identical for all
            # slots, so it is built once: pooled output, followed by the
            # projected inform features and the projected dialogue state
            # features (each only if enabled).
            aux_parts = [pooled_output]
            if self.class_aux_feats_inform:
                if inform_labels is None:
                    raise ValueError("inform_slot_id is required when dst_class_aux_feats_inform is enabled")
                aux_parts.append(self.inform_projection(inform_labels))
            if self.class_aux_feats_ds:
                if diag_state_labels is None:
                    raise ValueError("diag_state is required when dst_class_aux_feats_ds is enabled")
                aux_parts.append(self.ds_projection(diag_state_labels))
            if len(aux_parts) > 1:
                pooled_output_aux = torch.cat(aux_parts, 1)
            else:
                pooled_output_aux = pooled_output

            total_loss = 0
            per_slot_per_example_loss = {}
            per_slot_class_logits = {}
            per_slot_start_logits = {}
            per_slot_end_logits = {}
            per_slot_refer_logits = {}
            for slot in self.slot_list:
                class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))

                # The token head emits two values per position: start and end logit.
                token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
                start_logits, end_logits = token_logits.split(1, dim=-1)
                start_logits = start_logits.squeeze(-1)
                end_logits = end_logits.squeeze(-1)

                refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))

                per_slot_class_logits[slot] = class_logits
                per_slot_start_logits[slot] = start_logits
                per_slot_end_logits[slot] = end_logits
                per_slot_refer_logits[slot] = refer_logits

                # If there are no labels, don't compute loss
                if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
                    # If we are on multi-GPU, split add a dimension
                    if len(start_pos[slot].size()) > 1:
                        start_pos[slot] = start_pos[slot].squeeze(-1)
                    if len(end_pos[slot].size()) > 1:
                        end_pos[slot] = end_pos[slot].squeeze(-1)
                    # sometimes the start/end positions are outside our model inputs, we ignore these terms
                    ignored_index = start_logits.size(1)  # This is a single index
                    start_pos[slot].clamp_(0, ignored_index)
                    end_pos[slot].clamp_(0, ignored_index)

                    class_loss_fct = CrossEntropyLoss(reduction='none')
                    token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
                    refer_loss_fct = CrossEntropyLoss(reduction='none')

                    if not self.stack_token_logits:
                        start_loss = token_loss_fct(start_logits, start_pos[slot])
                        end_loss = token_loss_fct(end_logits, end_pos[slot])
                    else:
                        start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot])
                        end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot])

                    token_loss = (start_loss + end_loss) / 2.0

                    # Unless configured otherwise, the token loss only counts
                    # for examples whose start position is greater than 0.
                    token_is_pointable = (start_pos[slot] > 0).float()
                    if not self.token_loss_for_nonpointable:
                        token_loss *= token_is_pointable

                    # Likewise, the refer loss only counts for examples
                    # labeled with the refer class.
                    refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
                    token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
                    if not self.refer_loss_for_nonpointable:
                        refer_loss *= token_is_referrable

                    class_loss = class_loss_fct(class_logits, class_label_id[slot])

                    if self.refer_index > -1:
                        per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
                    else:
                        per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss

                    total_loss += per_example_loss.sum()
                    per_slot_per_example_loss[slot] = per_example_loss

            # add hidden states and attention if they are here
            outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]

            return outputs

    return TransformerForDST
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import json +import copy +import logging + +import torch +from transformers import (BertConfig, BertTokenizer, + RobertaConfig, RobertaTokenizer, + ElectraConfig, ElectraTokenizer) + +from convlab.dst.dst import DST +from convlab.dst.trippy.modeling_dst import (TransformerForDST) +from convlab.dst.trippy.dataset_interfacer import (create_dataset_interfacer) +from convlab.util import relative_import_module_from_unified_datasets + +MODEL_CLASSES = { + 'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer), + 'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer), + 'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer), +} + + +class TRIPPY(DST): + def print_header(self): + logging.info(" _________ ________ ___ ________ ________ ___ ___ ") + logging.info("|\___ ___\\\ __ \|\ \|\ __ \|\ __ \|\ \ / /|") + logging.info("\|___ \ \_\ \ \|\ \ \ \ \ \|\ \ \ \|\ \ \ \/ / /") + logging.info(" \ \ \ \ \ _ _\ \ \ \ ____\ \ ____\ \ / / ") + logging.info(" \ \ \ \ \ \\\ \\\ \ \ \ \___|\ \ \___|\/ / / ") + logging.info(" \ \__\ \ \__\\\ _\\\ \__\ \__\ \ \__\ __/ / / ") + logging.info(" \|__| \|__|\|__|\|__|\|__| \|__||\___/ / ") + logging.info(" (c) 2022 Heinrich Heine University \|___|/ ") + logging.info("") + + def print_dialog(self, hst): + logging.info("Dialogue %s, turn %s:" % (self.global_diag_cnt, self.global_turn_cnt)) + for utt in hst[:-2]: + logging.info(" \033[92m%s\033[0m" % (utt)) + if 
len(hst) > 1: + logging.info(" %s" % (hst[-2])) + logging.info(" %s" % (hst[-1])) + + def print_inform_memory(self, inform_mem): + logging.info("Inform memory:") + is_all_none = True + for s in inform_mem: + if inform_mem[s] != 'none': + logging.info(" %s = %s" % (s, inform_mem[s])) + is_all_none = False + if is_all_none: + logging.info(" -") + + def eval_user_acts(self, user_act, user_acts): + logging.info("User acts:") + for ua in user_acts: + if ua not in user_act: + logging.info(" \033[33m%s\033[0m" % (ua)) + else: + logging.info(" \033[92m%s\033[0m" % (ua)) + for ua in user_act: + if ua not in user_acts: + logging.info(" \033[91m%s\033[0m" % (ua)) + + def eval_dialog_state(self, state_updates, new_belief_state): + logging.info("Dialogue state:") + for d in self.gt_belief_state: + logging.info(" %s:" % (d)) + for s in new_belief_state[d]: + is_printed = False + is_updated = False + if state_updates[d][s] > 0: + is_updated = True + log_str = "" + if is_updated: + log_str += "\033[3m" + if new_belief_state[d][s] != self.gt_belief_state[d][s]: + self.global_eval_stats[d][s]['FP'] += 1 + if self.gt_belief_state[d][s] == '': + log_str += " \033[33m%s: %s\033[0m" % (s, new_belief_state[d][s]) + else: + log_str += " \033[91m%s: %s\033[0m (label: %s)" % (s, new_belief_state[d][s] if new_belief_state[d][s] != '' else 'none', self.gt_belief_state[d][s]) + self.global_eval_stats[d][s]['FN'] += 1 + is_printed = True + elif new_belief_state[d][s] != '': + log_str += " \033[92m%s: %s\033[0m" % (s, new_belief_state[d][s]) + self.global_eval_stats[d][s]['TP'] += 1 + is_printed = True + if is_updated: + log_str += " (%s)" % (self.config.dst_class_types[state_updates[d][s]]) + logging.info(log_str) + elif is_printed: + logging.info(log_str) + + def eval_print_stats(self): + logging.info("Statistics:") + for d in self.global_eval_stats: + for s in self.global_eval_stats[d]: + TP = self.global_eval_stats[d][s]['TP'] + FP = self.global_eval_stats[d][s]['FP'] + FN = 
self.global_eval_stats[d][s]['FN'] + prec = TP / ( TP + FP + 1e-8) + rec = TP / ( TP + FN + 1e-8) + f1 = 2 * ((prec * rec) / (prec + rec + 1e-8)) + logging.info(" %s %s Recall: %.2f, Precision: %.2f, F1: %.2f" % (d, s, rec, prec, f1)) + + def __init__(self, model_type="roberta", + model_name="roberta-base", + model_path="", + dataset_name="multiwoz21", + local_files_only=False, + nlu_usr_config="", + nlu_sys_config="", + nlu_usr_path="", + nlu_sys_path="", + no_eval=True, + no_history=False): + super(TRIPPY, self).__init__() + + self.print_header() + + self.model_type = model_type.lower() + self.model_name = model_name.lower() + self.model_path = model_path + self.local_files_only = local_files_only + self.nlu_usr_config = nlu_usr_config + self.nlu_sys_config = nlu_sys_config + self.nlu_usr_path = nlu_usr_path + self.nlu_sys_path = nlu_sys_path + self.dataset_name = dataset_name + self.no_eval = no_eval + self.no_history = no_history + + assert self.model_type in ['roberta'] # TODO: ensure proper behavior for 'bert', 'electra' + assert self.dataset_name in ['multiwoz21', 'multiwoz22', 'multiwoz23'] # TODO: ensure proper behavior for other datasets + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + _ontology = relative_import_module_from_unified_datasets(self.dataset_name, 'preprocess.py', 'ontology') + self.template_state = _ontology['state'] + + self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[self.model_type] + self.config = self.config_class.from_pretrained(self.model_path, local_files_only=self.local_files_only) + + self.dataset_interfacer = create_dataset_interfacer(dataset_name) + + # For internal evaluation only + self.nlu_usr = None + self.nlu_sys = None + self.global_eval_stats = copy.deepcopy(self.template_state) + for d in self.global_eval_stats: + for s in self.global_eval_stats[d]: + self.global_eval_stats[d][s] = {'TP': 0, 'FP': 0, 'FN': 0} + self.global_diag_cnt = -3 + self.global_turn_cnt = -1 
+ if not self.no_eval: + global BERTNLU + from convlab.nlu.jointBERT.unified_datasets import BERTNLU + self.load_nlu() + + # For semantic action pipelines only + self.nlg_usr = None + self.nlg_sys = None + + logging.info("DST INIT VARS: %s" % (vars(self))) + + self.load_weights() + + def load_weights(self): + self.tokenizer = self.tokenizer_class.from_pretrained(self.model_name, local_files_only=self.local_files_only) # TODO: do_lower_case? + self.model = self.model_class.from_pretrained(self.model_path, config=self.config, local_files_only=self.local_files_only) + self.model.to(self.device) + self.model.eval() + logging.info("DST model weights loaded from %s" % (self.model_path)) + + def load_nlu(self): + """ Loads NLUs for internal evaluation """ + # NLU for system utterances is used in case the policy does or can not provide semantic actions. + # The sole purpose of this is to fill the inform memory. + # NLU for user utterances is used in case the user simulator does or can not provide semantic actions. + # The sole purpose of this is to enable internal DST evaluation. 
+ if self.nlu_usr_config == self.nlu_sys_config and \ + self.nlu_usr_path == self.nlu_sys_path: + self.nlu_usr = BERTNLU(mode="all", config_file=self.nlu_usr_config, model_file=self.nlu_usr_path) + self.nlu_sys = self.nlu_usr + else: + self.nlu_usr = BERTNLU(mode="user", config_file=self.nlu_usr_config, model_file=self.nlu_usr_path) + self.nlu_sys = BERTNLU(mode="sys", config_file=self.nlu_sys_config, model_file=self.nlu_sys_path) + logging.info("DST user NLU model weights loaded from %s" % (self.nlu_usr_path)) + logging.info("DST sys NLU model weights loaded from %s" % (self.nlu_sys_path)) + + def load_nlg(self): + if self.dataset_name in ['multiwoz21', 'multiwoz22', 'multiwoz23']: + from convlab.nlg.template.multiwoz import TemplateNLG + self.nlg_usr = TemplateNLG(is_user=True) + self.nlg_sys = TemplateNLG(is_user=False) + logging.info("DST template NLG loaded for dataset %s" % (self.dataset_name)) + else: + raise Exception("DST no NLG for dataset %s available." % (self.dataset_name)) + + def init_session(self): + # Initialise empty state + self.state = {'user_action': [], + 'system_action': [], + 'belief_state': {}, + 'booked': {}, + 'request_state': {}, + 'terminated': False, + 'history': []} + self.state['belief_state'] = copy.deepcopy(self.template_state) + self.history = [] + self.ds_aux = {slot: torch.tensor([0]).to(self.device) for slot in self.config.dst_slot_list} + self.gt_belief_state = copy.deepcopy(self.template_state) + self.global_diag_cnt += 1 + self.global_turn_cnt = -1 + + def update_gt_belief_state(self, user_act): + for intent, domain, slot, value in user_act: + if domain == 'police': + continue + if intent == 'inform': + if slot == 'none' or slot == '': + continue + if slot in self.gt_belief_state[domain]: + self.gt_belief_state[domain][slot] = value + + def update(self, user_act=''): + prev_state = self.state + + if not self.no_eval: + logging.info("-" * 40) + + if self.no_history: + self.history = [] + self.history.append(['sys', 
self.get_text(prev_state['history'][-2][1], is_user=False, normalize=True)]) + self.history.append(['user', self.get_text(prev_state['history'][-1][1], is_user=True, normalize=True)]) + + self.global_turn_cnt += 1 + if not self.no_eval: + self.print_dialog(self.history) + + # --- Get inform memory and auxiliary features --- + + # system_action is a list of semantic system actions. + # TripPy uses system actions to fill the inform memory. + # End-to-end policies like Lava produce plain text instead. + # If system_action is plain text, get acts using NLU. + if isinstance(prev_state['system_action'], str): + s_acts = self.get_acts(prev_state['system_action']) + elif isinstance(prev_state['system_action'], list): + s_acts = prev_state['system_action'] + else: + raise Exception('Unknown format for system action:', prev_state['system_action']) + + if not self.no_eval: + # user_action is a list of semantic user actions if no NLG is used + # in the pipeline, otherwise user_action is plain text. + # TripPy uses user actions to perform internal DST evaluation. + # If user_action is plain text, get acts using NLU. + if isinstance(prev_state['user_action'], str): + u_acts = self.get_acts(prev_state['user_action'], is_user=True) + elif isinstance(prev_state['user_action'], list): + u_acts = prev_state['user_action'] # This is the same as user_act + else: + raise Exception('Unknown format for user action:', prev_state['user_action']) + + # Fill the inform memory. 
+ inform_aux, inform_mem = self.get_inform_aux(s_acts) + if not self.no_eval: + self.print_inform_memory(inform_mem) + + # --- Tokenize dialogue context and feed DST model --- + + used_ds_aux = None if not self.config.dst_class_aux_feats_ds else self.ds_aux + used_inform_aux = None if not self.config.dst_class_aux_feats_inform else inform_aux + features = self.get_features(self.history, ds_aux=used_ds_aux, inform_aux=used_inform_aux) + pred_states, pred_classes = self.predict(features, inform_mem) + + # --- Update ConvLab-style dialogue state --- + + new_belief_state = copy.deepcopy(prev_state['belief_state']) + user_acts = [] + for state, value in pred_states.items(): + value = self.dataset_interfacer.normalize_values(value) + domain, slot = state.split('-', 1) + value = self.dataset_interfacer.normalize_prediction(domain, slot, value, + predictions=pred_states, + class_predictions=pred_classes, + config=self.config) + if value == 'none': + continue + if slot in new_belief_state[domain]: + new_belief_state[domain][slot] = value + user_acts.append(['inform', domain, slot, value]) + else: + raise Exception('Unknown slot name <{}> with value <{}> of domain <{}>'.format(slot, value, domain)) + + if not self.no_eval: + self.update_gt_belief_state(u_acts) # For evaluation + + # BELIEF STATE UPDATE + new_state = copy.deepcopy(dict(prev_state)) + new_state['belief_state'] = new_belief_state # TripPy + + state_updates = {} + for cl in pred_classes: + cl_d, cl_s = cl.split('-') + # Some reformatting for the evaluation further down + if cl_d not in state_updates: + state_updates[cl_d] = {} + state_updates[cl_d][cl_s] = pred_classes[cl] + # We care only about the requestable slots here + if self.config.dst_class_types[pred_classes[cl]] != 'request': + continue + if cl_d != 'general' and cl_s == 'none': + user_acts.append(['inform', cl_d, '', '']) + elif cl_d == 'general': + user_acts.append([cl_s, 'general', '', '']) + else: + user_acts.append(['request', cl_d, cl_s, '']) + + 
# USER ACTS UPDATE + new_state['user_action'] = user_acts # TripPy + + if not self.no_eval: + self.eval_user_acts(u_acts, user_acts) + self.eval_dialog_state(state_updates, new_belief_state) + + self.state = new_state + + # Print eval statistics + if self.state['terminated'] and not self.no_eval: + logging.info("Booked: %s" % self.state['booked']) + self.eval_print_stats() + logging.info("=" * 10 + "End of the dialogue" + "=" * 10) + self.ds_aux = self.update_ds_aux(self.state['belief_state'], pred_states) + + return self.state + + def predict(self, features, inform_mem): + with torch.no_grad(): + outputs = self.model(input_ids=features['input_ids'], + input_mask=features['attention_mask'], + inform_slot_id=features['inform_slot_id'], + diag_state=features['diag_state']) + + input_tokens = self.tokenizer.convert_ids_to_tokens(features['input_ids'][0]) # unmasked! + + per_slot_class_logits = outputs[2] + per_slot_start_logits = outputs[3] + per_slot_end_logits = outputs[4] + per_slot_refer_logits = outputs[5] + + # TODO: maybe add assert to check that batch=1 + + predictions = {} + class_predictions = {} + + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + slot_udf = "%s-%s" % (self.dataset_interfacer.map_trippy_to_udf(d, s)) + + predictions[slot_udf] = 'none' + class_predictions[slot_udf] = 0 + + class_logits = per_slot_class_logits[slot][0] + start_logits = per_slot_start_logits[slot][0] + end_logits = per_slot_end_logits[slot][0] + refer_logits = per_slot_refer_logits[slot][0] + + class_prediction = int(class_logits.argmax()) + start_prediction = int(start_logits.argmax()) + end_prediction = int(end_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if class_prediction == self.config.dst_class_types.index('dontcare'): + predictions[slot_udf] = 'dontcare' + elif class_prediction == self.config.dst_class_types.index('copy_value'): + predictions[slot_udf] = ' '.join(input_tokens[start_prediction:end_prediction + 1]) + 
predictions[slot_udf] = re.sub("(^| )##", "", predictions[slot_udf]) + if "\u0120" in predictions[slot_udf]: + predictions[slot_udf] = re.sub(" ", "", predictions[slot_udf]) + predictions[slot_udf] = re.sub("\u0120", " ", predictions[slot_udf]) + predictions[slot_udf] = predictions[slot_udf].strip() + elif 'true' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('true'): + predictions[slot_udf] = "yes" # 'true' + elif 'false' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('false'): + predictions[slot_udf] = "no" # 'false' + elif class_prediction == self.config.dst_class_types.index('inform'): + predictions[slot_udf] = inform_mem[slot_udf] + # Referral case is handled below + + # Referral case. All other slot values need to be seen first in order + # to be able to do this correctly. + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + slot_udf = "%s-%s" % (self.dataset_interfacer.map_trippy_to_udf(d, s)) + + class_logits = per_slot_class_logits[slot][0] + refer_logits = per_slot_refer_logits[slot][0] + + class_prediction = int(class_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if 'refer' in self.config.dst_class_types and class_prediction == self.config.dst_class_types.index('refer'): + # Only slots that have been mentioned before can be referred to. + # First try to resolve a reference within the same turn. (One can think of a situation + # where one slot is referred to in the same utterance. This phenomenon is however + # currently not properly covered in the training data label generation process) + # Then try to resolve a reference given the current dialogue state. 
+ referred_slot = self.config.dst_slot_list[refer_prediction - 1] + referred_slot_d, referred_slot_s = referred_slot.split('-') + referred_slot_d, referred_slot_s = self.dataset_interfacer.map_trippy_to_udf(referred_slot_d, referred_slot_s) + referred_slot_udf = "%s-%s" % (referred_slot_d, referred_slot_s) + predictions[slot_udf] = predictions[referred_slot_udf] + if predictions[slot_udf] == 'none': + if self.state['belief_state'][referred_slot_d][referred_slot_s] != '': + predictions[slot_udf] = self.state['belief_state'][referred_slot_d][referred_slot_s] + if predictions[slot_udf] == 'none': + ref_slot = self.config.dst_slot_list[refer_prediction - 1] + ref_slot_d, ref_slot_s = ref_slot.split('-') + generic_ref = self.dataset_interfacer.get_generic_referral(ref_slot_d, ref_slot_s) + predictions[slot_udf] = generic_ref + + class_predictions[slot_udf] = class_prediction + + return predictions, class_predictions + + def get_features(self, context, ds_aux=None, inform_aux=None): + assert(self.model_type == "roberta") # TODO: generalize to other BERT-like models + input_tokens = ['<s>'] # TODO: use tokenizer token names rather than strings + e_itr = 0 + for e_itr, e in enumerate(reversed(context)): + if e[1] not in ['null', '']: + input_tokens.append(e[1]) + if e_itr < 2: + input_tokens.append('</s> </s>') + if e_itr == 0: + input_tokens.append('</s> </s>') + input_tokens.append('</s>') + input_tokens = ' '.join(input_tokens) + + # TODO: delex sys utt currently not supported + features = self.tokenizer.encode_plus(input_tokens, add_special_tokens=False, max_length=self.config.dst_max_seq_length) + + input_ids = torch.tensor(features['input_ids']).reshape(1,-1).to(self.device) + attention_mask = torch.tensor(features['attention_mask']).reshape(1,-1).to(self.device) + features = {'input_ids': input_ids, + 'attention_mask': attention_mask, + 'inform_slot_id': inform_aux, + 'diag_state': ds_aux} + + return features + + def update_ds_aux(self, state, pred_states, 
terminated=False): + ds_aux = copy.deepcopy(self.ds_aux) + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + d_udf, s_udf = self.dataset_interfacer.map_trippy_to_udf(d, s) + slot_udf = "%s-%s" % (d_udf, s_udf) + if d_udf in state and s_udf in state[d_udf]: + ds_aux[slot][0] = int(state[d_udf][s_udf] != '') + else: + # Requestable slots are not found in the DS + ds_aux[slot][0] = int(pred_states[slot_udf] != 'none') + return ds_aux + + def get_inform_aux(self, state): + # Initialise auxiliary variables. + # For inform_aux, only the proper order of slots + # as defined in dst_slot_list is relevant, but not + # the actual slot names (as inform_aux will be + # converted into a simple binary list in the model) + inform_aux = {} + inform_mem = {} + for slot in self.config.dst_slot_list: + d, s = slot.split('-') + d_udf, s_udf = self.dataset_interfacer.map_trippy_to_udf(d, s) + inform_aux["%s-%s" % (d_udf, s_udf)] = torch.tensor([0]).to(self.device) + inform_mem["%s-%s" % (d_udf, s_udf)] = 'none' + for e in state: + a, d, s, v = e + # TODO: offerbook needed? booked needed? 
+ if a in ['inform', 'recommend', 'select', 'book', 'offerbook']: + slot = "%s-%s" % (d, s) + if slot in inform_aux: + inform_aux[slot][0] = 1 + inform_mem[slot] = self.dataset_interfacer.normalize_values(v) + return inform_aux, inform_mem + + def get_acts(self, act, is_user=False): + if isinstance(act, list): + return act + context = self.state['history'] + if context[-1][0] not in ['user', 'usr']: + raise Exception("Wrong order of utterances, check your input.") + system_context = [self.get_text(t) for s,t in context[:-2]] + user_context = [self.get_text(t, is_user=True) for s,t in context[:-1]] + if is_user: + if self.nlu_usr is None: + raise Exception("You attempt to convert semantic user actions into text, but no NLU module is loaded.") + acts = self.nlu_usr.predict(act, context=user_context) + else: + if self.nlu_sys is None: + raise Exception("You attempt to convert semantic system actions into text, but no NLU module is loaded.") + acts = self.nlu_sys.predict(act, context=system_context) + for act_itr in range(len(acts)): + acts[act_itr][-1] = self.dataset_interfacer.normalize_values(acts[act_itr][-1]) + return acts + + def get_text(self, act, is_user=False, normalize=False): + if act == 'null': + return 'null' + if not isinstance(act, list): + result = act + else: + if self.nlg_usr is None or self.nlg_sys is None: + logging.warn("You attempt to input semantic actions into TripPy, which expects text.") + logging.warn("Attempting to load NLG modules in order to convert actions into text.") + self.load_nlg() + if is_user: + result = self.nlg_usr.generate(act) + else: + result = self.nlg_sys.generate(act) + if normalize: + return self.dataset_interfacer.normalize_text(result) + else: + return result diff --git a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json index 414f0b80b740d0c33568ba21bab5e13e822d95df..300369b525d3d5c4781ea00587f198e56657a37b 100755 --- 
a/convlab/nlg/template/multiwoz/manual_system_template_nlg.json +++ b/convlab/nlg/template/multiwoz/manual_system_template_nlg.json @@ -225,7 +225,7 @@ "Do you know the name of it ?", "can you give me the name of it ?" ], - "Price": [ + "Fee": [ "any specific price range to help narrow down available options ?", "What price range would you like ?", "what is your price range for that ?", @@ -363,42 +363,6 @@ "I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ." ] }, - "Booking-Request": { - "Day": [ - "What day would you like your booking for ?", - "What day would you like that reservation ?", - "what day would you like the booking to be made for ?", - "What day would you like to book ?", - "Ok , what day would you like to make the reservation on ?" - ], - "Stay": [ - "How many nights will you be staying ?", - "And how many nights ?", - "for how many days ?", - "And for how many days ?", - "how many days would you like to stay ?", - "How many nights would you like to book it for ?", - "And what nights would you like me to reserve for you ?", - "How many nights are you wanting to stay ?", - "How many days will you be staying ?" - ], - "People": [ - "For how many people ?", - "How many people will be ?", - "How many people will be with you ?", - "How many people is the reservation for ?" - ], - "Time": [ - "Do you have a time preference ?", - "what time are you looking for a reservation at ?", - "For what time ?", - "What time would you like me to make your reservation ?", - "What time would you like the reservation for ?", - "what time should I make the reservation for ?", - "What time would you prefer ?", - "What time would you like the reservation for ?" - ] - }, "Hotel-Inform": { "Internet": [ "it has wifi .", @@ -697,6 +661,30 @@ "Do you need free parking ?", "Will you need parking while you 're there ?", "Will you be needing free parking ?" 
+ ], + "Day": [ + "What day would you like your booking for ?", + "What day would you like that reservation ?", + "what day would you like the booking to be made for ?", + "What day would you like to book ?", + "Ok , what day would you like to make the reservation on ?" + ], + "Stay": [ + "How many nights will you be staying ?", + "And how many nights ?", + "for how many days ?", + "And for how many days ?", + "how many days would you like to stay ?", + "How many nights would you like to book it for ?", + "And what nights would you like me to reserve for you ?", + "How many nights are you wanting to stay ?", + "How many days will you be staying ?" + ], + "People": [ + "For how many people ?", + "How many people will be ?", + "How many people will be with you ?", + "How many people is the reservation for ?" ] }, "Restaurant-Inform": { @@ -918,6 +906,29 @@ "what is the name of the restaurant you are needing information on ?", "Do you know the name of the location ?", "Is there a certain restaurant you 're looking for ?" + ], + "Day": [ + "What day would you like your booking for ?", + "What day would you like that reservation ?", + "what day would you like the booking to be made for ?", + "What day would you like to book ?", + "Ok , what day would you like to make the reservation on ?" + ], + "People": [ + "For how many people ?", + "How many people will be ?", + "How many people will be with you ?", + "How many people is the reservation for ?" + ], + "Time": [ + "Do you have a time preference ?", + "what time are you looking for a reservation at ?", + "For what time ?", + "What time would you like me to make your reservation ?", + "What time would you like the reservation for ?", + "what time should I make the reservation for ?", + "What time would you prefer ?", + "What time would you like the reservation for ?" ] }, "Taxi-Inform": { @@ -1331,6 +1342,77 @@ "Is there a time you need to arrive by ?" 
] }, + "Police-Inform": { + "Addr": [ + "it is located in #POLICE-INFORM-ADDR#", + "adress is #POLICE-INFORM-ADDR#", + "It is on #POLICE-INFORM-ADDR# .", + "their address in our system is listed as #POLICE-INFORM-ADDR# .", + "The address is #POLICE-INFORM-ADDR# .", + "it 's located at #POLICE-INFORM-ADDR# .", + "#POLICE-INFORM-ADDR# is the address", + "They are located at #POLICE-INFORM-ADDR# ." + ], + "Post": [ + "The postcode of the police is #POLICE-INFORM-POST# .", + "The post code is #POLICE-INFORM-POST# .", + "Its postcode is #POLICE-INFORM-POST# .", + "Their postcode is #POLICE-INFORM-POST# ." + ], + "Name": [ + "I think a fun place to visit is #POLICE-INFORM-NAME# .", + "#POLICE-INFORM-NAME# looks good .", + "#POLICE-INFORM-NAME# is available , would that work for you ?", + "we have #POLICE-INFORM-NAME# .", + "#POLICE-INFORM-NAME# is popular among visitors .", + "How about #POLICE-INFORM-NAME# ?", + "What about #POLICE-INFORM-NAME# ?", + "you might want to try the #POLICE-INFORM-NAME# ." + ], + "Phone": [ + "The police phone number is #POLICE-INFORM-PHONE# .", + "Here is the police phone number , #POLICE-INFORM-PHONE# ." + ] + }, + "Hospital-Inform": { + "Addr": [ + "it is located in #HOSPITAL-INFORM-ADDR#", + "adress is #HOSPITAL-INFORM-ADDR#", + "It is on #HOSPITAL-INFORM-ADDR# .", + "their address in our system is listed as #HOSPITAL-INFORM-ADDR# .", + "The address is #HOSPITAL-INFORM-ADDR# .", + "it 's located at #HOSPITAL-INFORM-ADDR# .", + "#HOSPITAL-INFORM-ADDR# is the address", + "They are located at #HOSPITAL-INFORM-ADDR# ." + ], + "Post": [ + "The postcode of the hospital is #HOSPITAL-INFORM-POST# .", + "The post code is #HOSPITAL-INFORM-POST# .", + "Its postcode is #HOSPITAL-INFORM-POST# .", + "Their postcode is #HOSPITAL-INFORM-POST# ." 
+ ], + "Department": [ + "The department of the hospital is #HOSPITAL-INFORM-POST# .", + "The department is #HOSPITAL-INFORM-POST# .", + "Its department is #HOSPITAL-INFORM-POST# .", + "Their department is #HOSPITAL-INFORM-POST# ." + + ], + "Phone": [ + "The hospital phone number is #HOSPITAL-INFORM-PHONE# .", + "Here is the hospital phone number , #HOSPITAL-INFORM-PHONE# ." + ] + }, + "Hospital-Request": { + "Department": [ + "What is the name of the hospital department ?", + "What hospital department are you thinking about ?", + "I ' m sorry for the confusion , what hospital department are you interested in ?", + "What hospital department were you thinking of ?", + "Do you know the department of it ?", + "can you give me the department of it ?" + ] + }, "general-bye": { "none": [ "Thank you for using our services .", @@ -1378,4 +1460,4 @@ "You 're welcome . Have a good day !" ] } -} \ No newline at end of file +} diff --git a/convlab/nlg/template/multiwoz/nlg.py b/convlab/nlg/template/multiwoz/nlg.py index f83a6db4e2ad6f77f9bdc154cfda7bf2db0ff2c5..5f362ebbabd68ee0b6fab62c692caf0e8da436e9 100755 --- a/convlab/nlg/template/multiwoz/nlg.py +++ b/convlab/nlg/template/multiwoz/nlg.py @@ -31,33 +31,33 @@ def read_json(filename): # supported slot Slot2word = { - 'Fee': 'fee', + 'Fee': 'entrance fee', 'Addr': 'address', 'Area': 'area', - 'Stars': 'stars', - 'Internet': 'Internet', + 'Stars': 'number of stars', + 'Internet': 'internet', 'Department': 'department', 'Choice': 'choice', 'Ref': 'reference number', 'Food': 'food', 'Type': 'type', 'Price': 'price range', - 'Stay': 'stay', + 'Stay': 'length of the stay', 'Phone': 'phone number', 'Post': 'postcode', 'Day': 'day', 'Name': 'name', 'Car': 'car type', - 'Leave': 'leave', + 'Leave': 'departure time', 'Time': 'time', - 'Arrive': 'arrive', - 'Ticket': 'ticket', + 'Arrive': 'arrival time', + 'Ticket': 'ticket price', 'Depart': 'departure', - 'People': 'people', + 'People': 'number of people', 'Dest': 'destination', 
'Parking': 'parking', - 'Open': 'open', - 'Id': 'Id', + 'Open': 'opening hours', + 'Id': 'id', # 'TrainID': 'TrainID' } @@ -271,6 +271,10 @@ class TemplateNLG(NLG): elif 'request' == intent[1]: for slot, value in slot_value_pairs: if dialog_act not in template or slot not in template[dialog_act]: + if dialog_act not in template: + print("WARNING (nlg.py): (User?: %s) dialog_act '%s' not in template!" % (self.is_user, dialog_act)) + else: + print("WARNING (nlg.py): (User?: %s) slot '%s' of dialog_act '%s' not in template!" % (self.is_user, slot, dialog_act)) sentence = 'What is the {} of {} ? '.format( slot.lower(), dialog_act.split('-')[0].lower()) sentences += self._add_random_noise(sentence) @@ -288,7 +292,7 @@ class TemplateNLG(NLG): value_lower = value.lower() if value in ["do nt care", "do n't care", "dontcare"]: sentence = 'I don\'t care about the {} of the {}'.format( - slot, dialog_act.split('-')[0]) + slot2word.get(slot, slot), dialog_act.split('-')[0]) elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any': # user have no preference, any choice is ok sentence = random.choice([ diff --git a/convlab/nlu/jointBERT/multiwoz/nlu.py b/convlab/nlu/jointBERT/multiwoz/nlu.py index e25fbad1227c4b1f85ae2ae42a8ac899fa61d7b8..1373919e5861156c87a2ba14d6506d13e0842204 100755 --- a/convlab/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab/nlu/jointBERT/multiwoz/nlu.py @@ -74,7 +74,8 @@ class BERTNLU(NLU): for token in token_list: token = token.strip() self.nlp.tokenizer.add_special_case( - token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) + #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) + token, [{ORTH: token}]) logging.info("BERTNLU loaded") def predict(self, utterance, context=list()): diff --git a/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_sys_context3.json b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_sys_context3.json new file mode 100755 index 
0000000000000000000000000000000000000000..dfbef5a39963f030ccf989276fcaf7efb8141cc1 --- /dev/null +++ b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_sys_context3.json @@ -0,0 +1,27 @@ +{ + "dataset_name": "multiwoz21", + "data_dir": "unified_datasets/data/multiwoz21/system/context_window_size_3", + "output_dir": "unified_datasets/output/multiwoz21/system/context_window_size_3", + "zipped_model_path": "unified_datasets/output/multiwoz21/system/context_window_size_3/bertnlu_unified_multiwoz21_system_context3.zip", + "log_dir": "unified_datasets/output/multiwoz21/system/context_window_size_3/log", + "DEVICE": "cuda:0", + "seed": 2019, + "cut_sen_len": 40, + "use_bert_tokenizer": true, + "context_window_size": 3, + "model": { + "finetune": true, + "context": true, + "context_grad": true, + "pretrained_weights": "bert-base-uncased", + "check_step": 1000, + "max_step": 10000, + "batch_size": 128, + "learning_rate": 1e-4, + "adam_epsilon": 1e-8, + "warmup_steps": 0, + "weight_decay": 0.0, + "dropout": 0.1, + "hidden_units": 1536 + } +} diff --git a/convlab/policy/mle/loader.py b/convlab/policy/mle/loader.py index bb898ab4a30ad957db632b4eb74fca581f05f05b..9c10d2a7a3841efb55ba418e5cf56708e4e94b7d 100755 --- a/convlab/policy/mle/loader.py +++ b/convlab/policy/mle/loader.py @@ -56,13 +56,50 @@ class PolicyDataVectorizer: state['belief_state'] = data_point['context'][-1]['state'] state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts']) - else: + elif "setsumbt" in str(self.dst): last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else '' self.dst.state['history'].append(['sys', last_system_utt]) usr_utt = data_point['context'][-1]['utterance'] state = deepcopy(self.dst.update(usr_utt)) self.dst.state['history'].append(['usr', usr_utt]) + elif "trippy" in str(self.dst): + # Get last system acts and text. + # System acts are used to fill the inform memory. 
+ last_system_acts = [] + last_system_utt = '' + if len(data_point['context']) > 1: + last_system_acts = [] + for act_type in data_point['context'][-2]['dialogue_acts']: + for act in data_point['context'][-2]['dialogue_acts'][act_type]: + value = '' + if 'value' not in act: + if act['intent'] == 'request': + value = '?' + elif act['intent'] == 'inform': + value = 'yes' + else: + value = act['value'] + last_system_acts.append([act['intent'], act['domain'], act['slot'], value]) + last_system_utt = data_point['context'][-2]['utterance'] + + # Get current user acts and text. + # User acts are used for internal evaluation. + usr_acts = [] + for act_type in data_point['context'][-1]['dialogue_acts']: + for act in data_point['context'][-1]['dialogue_acts'][act_type]: + usr_acts.append([act['intent'], act['domain'], act['slot'], act['value'] if 'value' in act else '']) + usr_utt = data_point['context'][-1]['utterance'] + + # Update the state for DST, then update the state via DST. + self.dst.state['system_action'] = last_system_acts + self.dst.state['user_action'] = usr_acts + self.dst.state['history'].append(['sys', last_system_utt]) + self.dst.state['history'].append(['usr', usr_utt]) + state = deepcopy(self.dst.update(usr_utt)) + else: + raise NameError(f"Tracker: {self.dst} not implemented.") + last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {} state['system_action'] = flatten_acts(last_system_act) state['terminated'] = data_point['terminated'] diff --git a/convlab/policy/mle/train.py b/convlab/policy/mle/train.py index c2477760c8a189cbacd978151755fa790996a421..5253f95d9709fe19071ab7e36ed5a7339597bbe9 100755 --- a/convlab/policy/mle/train.py +++ b/convlab/policy/mle/train.py @@ -12,6 +12,7 @@ from convlab.util.custom_util import set_seed, init_logging, save_config from convlab.util.train_util import to_device from convlab.policy.rlmodule import MultiDiscretePolicy from convlab.policy.vector.vector_binary import 
VectorBinary +from convlab.policy.vector.vector_binary_fuzzy import VectorBinaryFuzzy root_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) @@ -195,8 +196,15 @@ if __name__ == '__main__': use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info) else: vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking) + elif args.dst == "trippy": + dst_args = [arg.split('=', 1) for arg in args.dst_args.split(', ') + if '=' in arg] if args.dst_args is not None else [] + dst_args = {key: eval(value) for key, value in dst_args} + from convlab.dst.trippy import TRIPPY + dst = TRIPPY(**dst_args) + vector = VectorBinaryFuzzy(dataset_name=args.dataset_name, use_masking=args.use_masking) else: - raise NameError(f"Tracker: {args.tracker} not implemented.") + raise NameError(f"Tracker: {args.dst} not implemented.") manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst) agent = MLE_Trainer(manager, vector, cfg) diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/trippy_config.json new file mode 100644 index 0000000000000000000000000000000000000000..41b1c3623aca944312c6389e55e34d72422fb6e0 --- /dev/null +++ b/convlab/policy/ppo/trippy_config.json @@ -0,0 +1,73 @@ +{ + "model": { + "load_path": "/path/to/model/checkpoint", + "pretrained_load_path": "", + "use_pretrained_initialisation": false, + "batchsz": 1000, + "seed": 0, + "epoch": 50, + "eval_frequency": 5, + "process_num": 2, + "num_eval_dialogues": 500, + "sys_semantic_to_usr": false + }, + "vectorizer_sys": { + "fuzzy_vector_mul": { + "class_path": "convlab.policy.vector.vector_binary_fuzzy.VectorBinaryFuzzy", + "ini_params": { + "use_masking": true, + "manually_add_entity_names": true, + "seed": 0 + } + } + }, + "nlu_sys": {}, + "dst_sys": { + "TripPy": { + "class_path": "convlab.dst.trippy.TRIPPY", + "ini_params": { + "model_type": "roberta", + "model_name": "roberta-base", + "model_path": 
"/path/to/model/checkpoint", + "dataset_name": "multiwoz21" + } + } + }, + "sys_nlg": { + "TemplateNLG": { + "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", + "ini_params": { + "is_user": false + } + } + }, + "nlu_usr": { + "BERTNLU": { + "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU", + "ini_params": { + "mode": "sys", + "config_file": "multiwoz21_sys_context3.json", + "model_file": "/path/to/model/checkpoint.zip" + } + } + }, + "dst_usr": {}, + "policy_usr": { + "RulePolicy": { + "class_path": "convlab.policy.rule.multiwoz.RulePolicy", + "ini_params": { + "character": "usr" + } + } + }, + "usr_nlg": { + "TemplateNLG": { + "class_path": "convlab.nlg.template.multiwoz.TemplateNLG", + "ini_params": { + "is_user": true, + "label_noise": 0.0, + "text_noise": 0.0 + } + } + } +} diff --git a/convlab/policy/vector/vector_binary_fuzzy.py b/convlab/policy/vector/vector_binary_fuzzy.py new file mode 100755 index 0000000000000000000000000000000000000000..5314cf1303f99f0ac2d0fbc5a2bca34c4e3d74c4 --- /dev/null +++ b/convlab/policy/vector/vector_binary_fuzzy.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +import sys +import numpy as np +from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da +from .vector_binary import VectorBinary + + +class VectorBinaryFuzzy(VectorBinary): + + def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True, + seed=0): + + super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed) + + def dbquery_domain(self, domain): + """ + query entities of specified domain + Args: + domain string: + domain to query + Returns: + entities list: + list of entities of the specified domain + """ + # Get all user constraints + constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \ + if domain in self.state else [] + xx = self.db.query(domain=domain, state=[], soft_contraints=constraints, fuzzy_match_ratio=100, 
topk=10) + yy = self.db.query(domain=domain, state=constraints, topk=10) + #print("STRICT:", yy) + #print("FUZZY :", xx) + #if len(yy) == 1 and len(xx) > 1: + # import pdb + # pdb.set_trace() + return xx + #return self.db.query(domain=domain, state=[], soft_contraints=constraints, fuzzy_match_ratio=100, topk=10) diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py index a23908ee13b460614213ecaad3f747fc791deff3..f599f95acfb9e89efc11572f3fb779f25900604e 100644 --- a/data/unified_datasets/multiwoz21/database.py +++ b/data/unified_datasets/multiwoz21/database.py @@ -57,18 +57,23 @@ class Database(BaseDatabase): for key, val in state: if key == 'department': department = val + if not department: + for key, val in soft_contraints: + if key == 'department': + department = val if not department: return deepcopy(self.dbs['hospital']) else: return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()] state = list(map(lambda ele: (self.slot2dbattr.get(ele[0], ele[0]), ele[1]) if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), state)) + soft_contraints = list(map(lambda ele: (self.slot2dbattr.get(ele[0], ele[0]), ele[1]) if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), soft_contraints)) found = [] for i, record in enumerate(self.dbs[domain]): constraints_iterator = zip(state, [False] * len(state)) soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints)) for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator): - if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"]: + if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care", "do not care"]: pass else: try: