Skip to content
Snippets Groups Projects
Commit 0a23ba4f authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Merge branch 'trippy_merge' into 'github_master'

# Conflicts:
#   convlab/dialog_agent/env.py
parents 0d233eee 2084da56
No related branches found
No related tags found
No related merge requests found
Showing
with 1325 additions and 53 deletions
...@@ -49,6 +49,9 @@ class Environment(): ...@@ -49,6 +49,9 @@ class Environment():
dialog_act = self.sys_nlu.predict( dialog_act = self.sys_nlu.predict(
observation) if self.sys_nlu else observation observation) if self.sys_nlu else observation
self.sys_dst.state['user_action'] = dialog_act self.sys_dst.state['user_action'] = dialog_act
self.sys_dst.state['history'].append(["sys", model_response])
self.sys_dst.state['history'].append(["user", observation])
state = self.sys_dst.update(dialog_act) state = self.sys_dst.update(dialog_act)
self.sys_dst.state['history'].append(["sys", model_response]) self.sys_dst.state['history'].append(["sys", model_response])
self.sys_dst.state['history'].append(["usr", observation]) self.sys_dst.state['history'].append(["usr", observation])
......
# Introduction
This is the TripPy DST module for ConvLab-3.
## Supported encoders
* RoBERTa
* BERT (full support w.i.p.)
* ELECTRA (full support w.i.p.)
## Supported datasets
* MultiWOZ 2.X
* Unified Data Format
## Requirements
transformers (tested: 4.18.0)
torch (tested: 1.8.0)
# Parameters
```
model_type # Default: "roberta", Type of the model (Supported: "roberta", "bert", "electra")
model_name # Default: "roberta-base", Name of the model (Use -h to print a list of names)
model_path # Path to a model checkpoint
dataset_name # Default: "multiwoz21", Name of the dataset the model was trained on and/or is being applied to
local_files_only # Default: False, Set to True to load local files only. Useful for offline systems
nlu_usr_config # Path to a NLU config file. Only needed for internal evaluation
nlu_sys_config # Path to a NLU config file. Only needed for internal evaluation
nlu_usr_path # Path to a NLU model file. Only needed for internal evaluation
nlu_sys_path # Path to a NLU model file. Only needed for internal evaluation
no_eval # Default: True, Set to False if internal evaluation should be conducted
no_history # Default: False, Set to True if dialogue history should be omitted during inference
```
# Training
TripPy can easily be trained for the supported datasets listed above using the original code in the official [TripPy repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public). Simply clone the code and run the appropriate DO.* script to train a TripPy DST. After training, set model_path to the preferred checkpoint to use TripPy in ConvLab-3.
# Training and evaluation with PPO policy
Switch to the directory:
```
cd ../../policy/ppo
```
Edit trippy_config.json and trippy_config_eval.json accordingly, e.g., edit paths to model checkpoints.
For training, run
```
train.py --path trippy_config.json
```
For evaluation, set training epochs to 0.
# Paper
[TripPy: A Triple Copy Strategy for Value Independent Neural Dialog State Tracking](https://aclanthology.org/2020.sigdial-1.4/)
from convlab.dst.trippy.tracker import TRIPPY
# coding=utf-8
#
# Copyright 2020-2022 Heinrich Heine University Duesseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import logging
class DatasetInterfacer(object):
    """Generic, dataset-agnostic interfacer between TripPy naming and the unified data format (UDF).

    The base implementation is an identity mapping: with the empty class-level
    tables below, domains, slots, values and text are returned unchanged.
    Dataset-specific subclasses override the tables and the normalize_* hooks.
    """

    # TripPy domain name -> UDF domain name.
    _domain_map_trippy_to_udf = {}
    # Mapped (UDF) domain name -> {TripPy slot name -> UDF slot name}.
    _slot_map_trippy_to_udf = {}
    # Mapped (UDF) domain name -> {mapped slot name -> referral phrase}.
    _generic_referral = {}

    def __init__(self):
        pass

    def map_trippy_to_udf(self, domain, slot):
        """Translate a TripPy (domain, slot) pair into its UDF counterpart.

        Names without an entry in the mapping tables are passed through as is.

        Returns:
            Tuple ``(domain, slot)`` after mapping.
        """
        d = self._domain_map_trippy_to_udf.get(domain, domain)
        # The slot table is keyed by the already mapped domain name.
        s = self._slot_map_trippy_to_udf.get(d, {}).get(slot, slot)
        return d, s

    def get_generic_referral(self, domain, slot):
        """Return a generic phrase that refers to the given domain-slot pair.

        Domains without a referral table yield ``"the <domain> <slot>"``.
        For a domain that has a table, a slot missing from it yields the bare
        mapped slot name.
        """
        d, s = self.map_trippy_to_udf(domain, slot)
        ref = "the %s %s" % (d, s)
        if d in self._generic_referral:
            ref = self._generic_referral[d].get(s, s)
        return ref

    def normalize_values(self, text):
        """Normalize a slot value string. The generic interfacer returns it unchanged."""
        return text

    def normalize_text(self, text):
        """Normalize an utterance. The generic interfacer returns it unchanged."""
        return text

    def normalize_prediction(self, domain, slot, value, predictions=None, config=None, class_predictions=None):
        """Post-process a predicted slot value. The generic interfacer returns it unchanged.

        ``class_predictions`` is accepted (and ignored) so that keyword calls
        written against dataset-specific interfacers also work on this generic
        fallback. It is appended after ``config`` to keep existing positional
        calls valid; pass the optional arguments by keyword, since subclasses
        may order them differently.
        """
        return value
class MultiwozInterfacer(DatasetInterfacer):
    """Interfacer for MultiWOZ: maps TripPy slot names to UDF names and normalizes text/values."""

    # Domain -> {TripPy slot name -> UDF slot name}.
    _slot_map_trippy_to_udf = {
        'hotel': {
            'pricerange': 'price range',
            'book_stay': 'book stay',
            'book_day': 'book day',
            'book_people': 'book people',
            'addr': 'address',
            'post': 'postcode',
            'price': 'price range',
            'people': 'book people'
        },
        'restaurant': {
            'pricerange': 'price range',
            'book_time': 'book time',
            'book_day': 'book day',
            'book_people': 'book people',
            'addr': 'address',
            'post': 'postcode',
            'price': 'price range',
            'people': 'book people'
        },
        'taxi': {
            'arriveBy': 'arrive by',
            'leaveAt': 'leave at',
            'arrive': 'arrive by',
            'leave': 'leave at',
            'car': 'type',
            'car type': 'type',
            'depart': 'departure',
            'dest': 'destination'
        },
        'train': {
            'arriveBy': 'arrive by',
            'leaveAt': 'leave at',
            'book_people': 'book people',
            'arrive': 'arrive by',
            'leave': 'leave at',
            'depart': 'departure',
            'dest': 'destination',
            'id': 'train id',
            'people': 'book people',
            'time': 'duration',
            'ticket': 'price',
            'trainid': 'train id'
        },
        'attraction': {
            'post': 'postcode',
            'addr': 'address',
            'fee': 'entrance fee',
            'price': 'entrance fee'
        },
        'general': {},
        'hospital': {
            'post': 'postcode',
            'addr': 'address'
        },
        'police': {
            'post': 'postcode',
            'addr': 'address'
        }
    }

    # Domain -> {UDF slot name -> phrase used to refer to that slot generically}.
    _generic_referral = {
        'hotel': {
            'name': 'the hotel',
            'area': 'same area as the hotel',
            'price range': 'in the same price range as the hotel'
        },
        'restaurant': {
            'name': 'the restaurant',
            'area': 'same area as the restaurant',
            'price range': 'in the same price range as the restaurant'
        },
        'attraction': {
            'name': 'the attraction',
            'area': 'same area as the attraction'
        }
    }

    # Whole-value replacements applied by normalize_values (only when the
    # complete normalized value equals the key).
    _TEXT_TO_NUM = {"zero": "0", "one": "1", "me": "1", "two": "2", "three": "3",
                    "four": "4", "five": "5", "six": "6", "seven": "7"}

    def normalize_values(self, text):
        """Lowercase a slot value, tighten punctuation spacing and map number words to digits."""
        text = text.lower()
        text = re.sub(r"\s*(\W)\s*", r"\1", text)  # Re-attach special characters
        text = re.sub(r"s'([^s])", r"s' \1", text)  # Add space after plural genitive apostrophe
        return self._TEXT_TO_NUM.get(text, text)

    def normalize_text(self, text):
        """Lowercase an utterance and separate word and non-word runs by single spaces."""
        norm_text = text.lower()
        # NOTE: expanding "n't" to " not" was tried and does not make much of a difference.
        tokens = (tok.strip() for tok in re.split(r"(\W+)", norm_text))
        return ' '.join(tok for tok in tokens if tok)

    def normalize_prediction(self, domain, slot, value, predictions=None, class_predictions=None, config=None):
        """Post-process a predicted value; only the hotel 'type' slot is altered.

        Args:
            domain: Domain of the predicted slot.
            slot: Name of the predicted slot.
            value: Predicted value.
            predictions: Optional mapping of "domain-slot" to predicted value;
                the 'hotel-name' entry is consulted.
            class_predictions: Optional mapping of "domain-slot" to predicted
                class index; the 'hotel-none' entry is consulted.
            config: Optional model config providing ``dst_class_types``.

        Returns:
            The normalized value.
        """
        v = value
        if domain == 'hotel' and slot == 'type':
            # Map Boolean predictions to regular predictions.
            # (Previously the second assignment overwrote the first, so "yes"
            # was never translated to "hotel".)
            if value == "yes":
                v = "hotel"
            elif value == "no":
                v = "guesthouse"
            # HOTFIX: Avoid overprediction of hotel type caused by ambiguous rule based user simulator NLG.
            # The optional inputs default to None, so only consult them when given.
            if predictions is not None and predictions['hotel-name'] != 'none':
                v = 'none'
            if (config is not None and class_predictions is not None
                    and config.dst_class_types[class_predictions['hotel-none']] == 'request'):
                v = 'none'
        return v
# Registry of ready-made interfacer instances, keyed by dataset name.
# The instances are created once at import time and shared by all callers.
DATASET_INTERFACERS = {
    'multiwoz21': MultiwozInterfacer()
}


def create_dataset_interfacer(dataset_name="multiwoz21"):
    """Return the interfacer registered for ``dataset_name``.

    Unknown dataset names do not raise: a warning is logged and a fresh
    generic ``DatasetInterfacer`` (identity mappings) is returned instead.
    """
    if dataset_name in DATASET_INTERFACERS:
        return DATASET_INTERFACERS[dataset_name]
    # logging.warn is a deprecated alias; use logging.warning with lazy %-args.
    logging.warning(
        "You attempt to create a dataset interfacer for an unknown dataset '%s'. "
        "Creating generic dataset interfacer.", dataset_name)
    return DatasetInterfacer()
# coding=utf-8
#
# Copyright 2020-2022 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (BertModel, BertPreTrainedModel,
RobertaModel, RobertaPreTrainedModel,
ElectraModel, ElectraPreTrainedModel)
# Encoder family name -> transformers pre-trained base class the DST model derives from.
PARENT_CLASSES = dict(
    bert=BertPreTrainedModel,
    roberta=RobertaPreTrainedModel,
    electra=ElectraPreTrainedModel,
)

# Pre-trained base class -> encoder model class instantiated inside the DST model.
MODEL_CLASSES = dict([
    (BertPreTrainedModel, BertModel),
    (RobertaPreTrainedModel, RobertaModel),
    (ElectraPreTrainedModel, ElectraModel),
])
class ElectraPooler(nn.Module):
    """Pooling head: dense layer plus tanh applied to the first token's hidden state."""

    def __init__(self, config):
        super(ElectraPooler, self).__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pooling" here just means picking the hidden state at position 0
        # along the second axis and projecting it.
        return self.activation(self.dense(hidden_states[:, 0]))
def TransformerForDST(parent_name):
    """Create the TripPy DST model class on top of the requested encoder family.

    Args:
        parent_name: Key into ``PARENT_CLASSES`` ("bert", "roberta" or "electra").

    Returns:
        A model class (not an instance) deriving from the matching
        pre-trained base class.

    Raises:
        ValueError: If ``parent_name`` is not a known encoder family.
    """
    if parent_name not in PARENT_CLASSES:
        raise ValueError("Unknown model %s" % (parent_name))

    class TransformerForDST(PARENT_CLASSES[parent_name]):
        """Encoder with per-slot class, token (span start/end) and refer heads."""

        def __init__(self, config):
            assert config.model_type in PARENT_CLASSES
            assert self.__class__.__bases__[0] in MODEL_CLASSES
            super(TransformerForDST, self).__init__(config)
            self.model_type = config.model_type
            self.slot_list = config.dst_slot_list
            self.class_types = config.dst_class_types
            self.class_labels = config.dst_class_labels
            self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
            self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
            self.stack_token_logits = config.dst_stack_token_logits
            self.class_aux_feats_inform = config.dst_class_aux_feats_inform
            self.class_aux_feats_ds = config.dst_class_aux_feats_ds
            self.class_loss_ratio = config.dst_class_loss_ratio

            # Only use refer loss if refer class is present in dataset.
            if 'refer' in self.class_types:
                self.refer_index = self.class_types.index('refer')
            else:
                self.refer_index = -1

            # Make sure this module has the same name as in the pretrained checkpoint you want to load!
            self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config))
            if self.model_type == "electra":
                self.pooler = ElectraPooler(config)

            self.dropout = nn.Dropout(config.dst_dropout_rate)
            self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)

            if self.class_aux_feats_inform:
                self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
            if self.class_aux_feats_ds:
                self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))

            aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds)  # second term is 0, 1 or 2

            for slot in self.slot_list:
                self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
                self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
                self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))

            self.init_weights()

        def forward(self,
                    input_ids,
                    input_mask=None,
                    segment_ids=None,
                    position_ids=None,
                    head_mask=None,
                    start_pos=None,
                    end_pos=None,
                    inform_slot_id=None,
                    refer_id=None,
                    class_label_id=None,
                    diag_state=None):
            """Run the encoder and all per-slot heads; compute the loss if labels are given.

            ``start_pos``, ``end_pos``, ``refer_id``, ``class_label_id``,
            ``inform_slot_id`` and ``diag_state`` are mappings keyed by slot name.
            The loss is only computed when class, start, end and refer labels
            are all provided. Note that ``start_pos``/``end_pos`` entries are
            squeezed and clamped in place.

            Returns:
                Tuple ``(total_loss, per_slot_per_example_loss,
                per_slot_class_logits, per_slot_start_logits,
                per_slot_end_logits, per_slot_refer_logits)`` followed by any
                extra encoder outputs beyond the first two. ``total_loss`` is
                the int ``0`` when no labels are given.

            Raises:
                ValueError: If an auxiliary feature is enabled in the config
                    but its input (``inform_slot_id`` / ``diag_state``) is missing.
            """
            outputs = getattr(self, self.model_type)(
                input_ids,
                attention_mask=input_mask,
                token_type_ids=segment_ids,
                position_ids=position_ids,
                head_mask=head_mask
            )

            sequence_output = outputs[0]
            if self.model_type == "electra":
                pooled_output = self.pooler(sequence_output)
            else:
                pooled_output = outputs[1]

            sequence_output = self.dropout(sequence_output)
            pooled_output = self.dropout(pooled_output)

            # The auxiliary features are identical for every slot, so project and
            # concatenate them once instead of once per slot.
            aux_feats = [pooled_output]
            if self.class_aux_feats_inform:
                if inform_slot_id is None:
                    raise ValueError("inform_slot_id is required when dst_class_aux_feats_inform is enabled.")
                inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
                aux_feats.append(self.inform_projection(inform_labels))
            if self.class_aux_feats_ds:
                if diag_state is None:
                    raise ValueError("diag_state is required when dst_class_aux_feats_ds is enabled.")
                diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
                aux_feats.append(self.ds_projection(diag_state_labels))
            pooled_output_aux = torch.cat(aux_feats, 1) if len(aux_feats) > 1 else pooled_output

            # If there are no labels, don't compute loss
            has_labels = (class_label_id is not None and start_pos is not None
                          and end_pos is not None and refer_id is not None)
            if has_labels:
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                # The per-slot start logits have one entry per input position, so
                # their length equals the sequence length of the encoder output.
                ignored_index = sequence_output.size(1)  # This is a single index
                class_loss_fct = CrossEntropyLoss(reduction='none')
                token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
                refer_loss_fct = CrossEntropyLoss(reduction='none')

            total_loss = 0
            per_slot_per_example_loss = {}
            per_slot_class_logits = {}
            per_slot_start_logits = {}
            per_slot_end_logits = {}
            per_slot_refer_logits = {}
            for slot in self.slot_list:
                # Head order (class, token, refer) is kept so dropout consumes
                # random numbers in the same sequence as before.
                class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))

                token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
                start_logits, end_logits = token_logits.split(1, dim=-1)
                start_logits = start_logits.squeeze(-1)
                end_logits = end_logits.squeeze(-1)

                refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))

                per_slot_class_logits[slot] = class_logits
                per_slot_start_logits[slot] = start_logits
                per_slot_end_logits[slot] = end_logits
                per_slot_refer_logits[slot] = refer_logits

                if has_labels:
                    # If we are on multi-GPU, split add a dimension
                    if len(start_pos[slot].size()) > 1:
                        start_pos[slot] = start_pos[slot].squeeze(-1)
                    if len(end_pos[slot].size()) > 1:
                        end_pos[slot] = end_pos[slot].squeeze(-1)
                    start_pos[slot].clamp_(0, ignored_index)
                    end_pos[slot].clamp_(0, ignored_index)

                    if not self.stack_token_logits:
                        start_loss = token_loss_fct(start_logits, start_pos[slot])
                        end_loss = token_loss_fct(end_logits, end_pos[slot])
                    else:
                        start_loss = token_loss_fct(torch.cat((start_logits, end_logits), 1), start_pos[slot])
                        end_loss = token_loss_fct(torch.cat((end_logits, start_logits), 1), end_pos[slot])

                    token_loss = (start_loss + end_loss) / 2.0
                    # Examples whose start position is 0 are treated as non-pointable.
                    token_is_pointable = (start_pos[slot] > 0).float()
                    if not self.token_loss_for_nonpointable:
                        token_loss *= token_is_pointable

                    refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
                    token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
                    if not self.refer_loss_for_nonpointable:
                        refer_loss *= token_is_referrable

                    class_loss = class_loss_fct(class_logits, class_label_id[slot])

                    if self.refer_index > -1:
                        per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
                    else:
                        per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss

                    total_loss += per_example_loss.sum()
                    per_slot_per_example_loss[slot] = per_example_loss

            # add hidden states and attention if they are here
            outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]

            return outputs

    return TransformerForDST
This diff is collapsed.
...@@ -225,7 +225,7 @@ ...@@ -225,7 +225,7 @@
"Do you know the name of it ?", "Do you know the name of it ?",
"can you give me the name of it ?" "can you give me the name of it ?"
], ],
"Price": [ "Fee": [
"any specific price range to help narrow down available options ?", "any specific price range to help narrow down available options ?",
"What price range would you like ?", "What price range would you like ?",
"what is your price range for that ?", "what is your price range for that ?",
...@@ -363,42 +363,6 @@ ...@@ -363,42 +363,6 @@
"I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ." "I ' m sorry but there is no availability for #BOOKING-NOBOOK-PEOPLE# people ."
] ]
}, },
"Booking-Request": {
"Day": [
"What day would you like your booking for ?",
"What day would you like that reservation ?",
"what day would you like the booking to be made for ?",
"What day would you like to book ?",
"Ok , what day would you like to make the reservation on ?"
],
"Stay": [
"How many nights will you be staying ?",
"And how many nights ?",
"for how many days ?",
"And for how many days ?",
"how many days would you like to stay ?",
"How many nights would you like to book it for ?",
"And what nights would you like me to reserve for you ?",
"How many nights are you wanting to stay ?",
"How many days will you be staying ?"
],
"People": [
"For how many people ?",
"How many people will be ?",
"How many people will be with you ?",
"How many people is the reservation for ?"
],
"Time": [
"Do you have a time preference ?",
"what time are you looking for a reservation at ?",
"For what time ?",
"What time would you like me to make your reservation ?",
"What time would you like the reservation for ?",
"what time should I make the reservation for ?",
"What time would you prefer ?",
"What time would you like the reservation for ?"
]
},
"Hotel-Inform": { "Hotel-Inform": {
"Internet": [ "Internet": [
"it has wifi .", "it has wifi .",
...@@ -697,6 +661,30 @@ ...@@ -697,6 +661,30 @@
"Do you need free parking ?", "Do you need free parking ?",
"Will you need parking while you 're there ?", "Will you need parking while you 're there ?",
"Will you be needing free parking ?" "Will you be needing free parking ?"
],
"Day": [
"What day would you like your booking for ?",
"What day would you like that reservation ?",
"what day would you like the booking to be made for ?",
"What day would you like to book ?",
"Ok , what day would you like to make the reservation on ?"
],
"Stay": [
"How many nights will you be staying ?",
"And how many nights ?",
"for how many days ?",
"And for how many days ?",
"how many days would you like to stay ?",
"How many nights would you like to book it for ?",
"And what nights would you like me to reserve for you ?",
"How many nights are you wanting to stay ?",
"How many days will you be staying ?"
],
"People": [
"For how many people ?",
"How many people will be ?",
"How many people will be with you ?",
"How many people is the reservation for ?"
] ]
}, },
"Restaurant-Inform": { "Restaurant-Inform": {
...@@ -918,6 +906,29 @@ ...@@ -918,6 +906,29 @@
"what is the name of the restaurant you are needing information on ?", "what is the name of the restaurant you are needing information on ?",
"Do you know the name of the location ?", "Do you know the name of the location ?",
"Is there a certain restaurant you 're looking for ?" "Is there a certain restaurant you 're looking for ?"
],
"Day": [
"What day would you like your booking for ?",
"What day would you like that reservation ?",
"what day would you like the booking to be made for ?",
"What day would you like to book ?",
"Ok , what day would you like to make the reservation on ?"
],
"People": [
"For how many people ?",
"How many people will be ?",
"How many people will be with you ?",
"How many people is the reservation for ?"
],
"Time": [
"Do you have a time preference ?",
"what time are you looking for a reservation at ?",
"For what time ?",
"What time would you like me to make your reservation ?",
"What time would you like the reservation for ?",
"what time should I make the reservation for ?",
"What time would you prefer ?",
"What time would you like the reservation for ?"
] ]
}, },
"Taxi-Inform": { "Taxi-Inform": {
...@@ -1331,6 +1342,77 @@ ...@@ -1331,6 +1342,77 @@
"Is there a time you need to arrive by ?" "Is there a time you need to arrive by ?"
] ]
}, },
"Police-Inform": {
"Addr": [
"it is located in #POLICE-INFORM-ADDR#",
"adress is #POLICE-INFORM-ADDR#",
"It is on #POLICE-INFORM-ADDR# .",
"their address in our system is listed as #POLICE-INFORM-ADDR# .",
"The address is #POLICE-INFORM-ADDR# .",
"it 's located at #POLICE-INFORM-ADDR# .",
"#POLICE-INFORM-ADDR# is the address",
"They are located at #POLICE-INFORM-ADDR# ."
],
"Post": [
"The postcode of the police is #POLICE-INFORM-POST# .",
"The post code is #POLICE-INFORM-POST# .",
"Its postcode is #POLICE-INFORM-POST# .",
"Their postcode is #POLICE-INFORM-POST# ."
],
"Name": [
"I think a fun place to visit is #POLICE-INFORM-NAME# .",
"#POLICE-INFORM-NAME# looks good .",
"#POLICE-INFORM-NAME# is available , would that work for you ?",
"we have #POLICE-INFORM-NAME# .",
"#POLICE-INFORM-NAME# is popular among visitors .",
"How about #POLICE-INFORM-NAME# ?",
"What about #POLICE-INFORM-NAME# ?",
"you might want to try the #POLICE-INFORM-NAME# ."
],
"Phone": [
"The police phone number is #POLICE-INFORM-PHONE# .",
"Here is the police phone number , #POLICE-INFORM-PHONE# ."
]
},
"Hospital-Inform": {
"Addr": [
"it is located in #HOSPITAL-INFORM-ADDR#",
"adress is #HOSPITAL-INFORM-ADDR#",
"It is on #HOSPITAL-INFORM-ADDR# .",
"their address in our system is listed as #HOSPITAL-INFORM-ADDR# .",
"The address is #HOSPITAL-INFORM-ADDR# .",
"it 's located at #HOSPITAL-INFORM-ADDR# .",
"#HOSPITAL-INFORM-ADDR# is the address",
"They are located at #HOSPITAL-INFORM-ADDR# ."
],
"Post": [
"The postcode of the hospital is #HOSPITAL-INFORM-POST# .",
"The post code is #HOSPITAL-INFORM-POST# .",
"Its postcode is #HOSPITAL-INFORM-POST# .",
"Their postcode is #HOSPITAL-INFORM-POST# ."
],
"Department": [
"The department of the hospital is #HOSPITAL-INFORM-DEPARTMENT# .",
"The department is #HOSPITAL-INFORM-DEPARTMENT# .",
"Its department is #HOSPITAL-INFORM-DEPARTMENT# .",
"Their department is #HOSPITAL-INFORM-DEPARTMENT# ."
],
"Phone": [
"The hospital phone number is #HOSPITAL-INFORM-PHONE# .",
"Here is the hospital phone number , #HOSPITAL-INFORM-PHONE# ."
]
},
"Hospital-Request": {
"Department": [
"What is the name of the hospital department ?",
"What hospital department are you thinking about ?",
"I ' m sorry for the confusion , what hospital department are you interested in ?",
"What hospital department were you thinking of ?",
"Do you know the department of it ?",
"can you give me the department of it ?"
]
},
"general-bye": { "general-bye": {
"none": [ "none": [
"Thank you for using our services .", "Thank you for using our services .",
......
...@@ -31,33 +31,33 @@ def read_json(filename): ...@@ -31,33 +31,33 @@ def read_json(filename):
# supported slot # supported slot
Slot2word = { Slot2word = {
'Fee': 'fee', 'Fee': 'entrance fee',
'Addr': 'address', 'Addr': 'address',
'Area': 'area', 'Area': 'area',
'Stars': 'stars', 'Stars': 'number of stars',
'Internet': 'Internet', 'Internet': 'internet',
'Department': 'department', 'Department': 'department',
'Choice': 'choice', 'Choice': 'choice',
'Ref': 'reference number', 'Ref': 'reference number',
'Food': 'food', 'Food': 'food',
'Type': 'type', 'Type': 'type',
'Price': 'price range', 'Price': 'price range',
'Stay': 'stay', 'Stay': 'length of the stay',
'Phone': 'phone number', 'Phone': 'phone number',
'Post': 'postcode', 'Post': 'postcode',
'Day': 'day', 'Day': 'day',
'Name': 'name', 'Name': 'name',
'Car': 'car type', 'Car': 'car type',
'Leave': 'leave', 'Leave': 'departure time',
'Time': 'time', 'Time': 'time',
'Arrive': 'arrive', 'Arrive': 'arrival time',
'Ticket': 'ticket', 'Ticket': 'ticket price',
'Depart': 'departure', 'Depart': 'departure',
'People': 'people', 'People': 'number of people',
'Dest': 'destination', 'Dest': 'destination',
'Parking': 'parking', 'Parking': 'parking',
'Open': 'open', 'Open': 'opening hours',
'Id': 'Id', 'Id': 'id',
# 'TrainID': 'TrainID' # 'TrainID': 'TrainID'
} }
...@@ -271,6 +271,10 @@ class TemplateNLG(NLG): ...@@ -271,6 +271,10 @@ class TemplateNLG(NLG):
elif 'request' == intent[1]: elif 'request' == intent[1]:
for slot, value in slot_value_pairs: for slot, value in slot_value_pairs:
if dialog_act not in template or slot not in template[dialog_act]: if dialog_act not in template or slot not in template[dialog_act]:
if dialog_act not in template:
print("WARNING (nlg.py): (User?: %s) dialog_act '%s' not in template!" % (self.is_user, dialog_act))
else:
print("WARNING (nlg.py): (User?: %s) slot '%s' of dialog_act '%s' not in template!" % (self.is_user, slot, dialog_act))
sentence = 'What is the {} of {} ? '.format( sentence = 'What is the {} of {} ? '.format(
slot.lower(), dialog_act.split('-')[0].lower()) slot.lower(), dialog_act.split('-')[0].lower())
sentences += self._add_random_noise(sentence) sentences += self._add_random_noise(sentence)
...@@ -288,7 +292,7 @@ class TemplateNLG(NLG): ...@@ -288,7 +292,7 @@ class TemplateNLG(NLG):
value_lower = value.lower() value_lower = value.lower()
if value in ["do nt care", "do n't care", "dontcare"]: if value in ["do nt care", "do n't care", "dontcare"]:
sentence = 'I don\'t care about the {} of the {}'.format( sentence = 'I don\'t care about the {} of the {}'.format(
slot, dialog_act.split('-')[0]) slot2word.get(slot, slot), dialog_act.split('-')[0])
elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any': elif self.is_user and dialog_act.split('-')[1] == 'inform' and slot == 'choice' and value_lower == 'any':
# user have no preference, any choice is ok # user have no preference, any choice is ok
sentence = random.choice([ sentence = random.choice([
......
...@@ -74,7 +74,8 @@ class BERTNLU(NLU): ...@@ -74,7 +74,8 @@ class BERTNLU(NLU):
for token in token_list: for token in token_list:
token = token.strip() token = token.strip()
self.nlp.tokenizer.add_special_case( self.nlp.tokenizer.add_special_case(
token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}])
token, [{ORTH: token}])
logging.info("BERTNLU loaded") logging.info("BERTNLU loaded")
def predict(self, utterance, context=list()): def predict(self, utterance, context=list()):
......
{
"dataset_name": "multiwoz21",
"data_dir": "unified_datasets/data/multiwoz21/system/context_window_size_3",
"output_dir": "unified_datasets/output/multiwoz21/system/context_window_size_3",
"zipped_model_path": "unified_datasets/output/multiwoz21/system/context_window_size_3/bertnlu_unified_multiwoz21_system_context3.zip",
"log_dir": "unified_datasets/output/multiwoz21/system/context_window_size_3/log",
"DEVICE": "cuda:0",
"seed": 2019,
"cut_sen_len": 40,
"use_bert_tokenizer": true,
"context_window_size": 3,
"model": {
"finetune": true,
"context": true,
"context_grad": true,
"pretrained_weights": "bert-base-uncased",
"check_step": 1000,
"max_step": 10000,
"batch_size": 128,
"learning_rate": 1e-4,
"adam_epsilon": 1e-8,
"warmup_steps": 0,
"weight_decay": 0.0,
"dropout": 0.1,
"hidden_units": 1536
}
}
...@@ -56,13 +56,50 @@ class PolicyDataVectorizer: ...@@ -56,13 +56,50 @@ class PolicyDataVectorizer:
state['belief_state'] = data_point['context'][-1]['state'] state['belief_state'] = data_point['context'][-1]['state']
state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts']) state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
else: elif "setsumbt" in str(self.dst):
last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else '' last_system_utt = data_point['context'][-2]['utterance'] if len(data_point['context']) > 1 else ''
self.dst.state['history'].append(['sys', last_system_utt]) self.dst.state['history'].append(['sys', last_system_utt])
usr_utt = data_point['context'][-1]['utterance'] usr_utt = data_point['context'][-1]['utterance']
state = deepcopy(self.dst.update(usr_utt)) state = deepcopy(self.dst.update(usr_utt))
self.dst.state['history'].append(['usr', usr_utt]) self.dst.state['history'].append(['usr', usr_utt])
elif "trippy" in str(self.dst):
# Get last system acts and text.
# System acts are used to fill the inform memory.
last_system_acts = []
last_system_utt = ''
if len(data_point['context']) > 1:
last_system_acts = []
for act_type in data_point['context'][-2]['dialogue_acts']:
for act in data_point['context'][-2]['dialogue_acts'][act_type]:
value = ''
if 'value' not in act:
if act['intent'] == 'request':
value = '?'
elif act['intent'] == 'inform':
value = 'yes'
else:
value = act['value']
last_system_acts.append([act['intent'], act['domain'], act['slot'], value])
last_system_utt = data_point['context'][-2]['utterance']
# Get current user acts and text.
# User acts are used for internal evaluation.
usr_acts = []
for act_type in data_point['context'][-1]['dialogue_acts']:
for act in data_point['context'][-1]['dialogue_acts'][act_type]:
usr_acts.append([act['intent'], act['domain'], act['slot'], act['value'] if 'value' in act else ''])
usr_utt = data_point['context'][-1]['utterance']
# Update the state for DST, then update the state via DST.
self.dst.state['system_action'] = last_system_acts
self.dst.state['user_action'] = usr_acts
self.dst.state['history'].append(['sys', last_system_utt])
self.dst.state['history'].append(['usr', usr_utt])
state = deepcopy(self.dst.update(usr_utt))
else:
raise NameError(f"Tracker: {self.dst} not implemented.")
last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {} last_system_act = data_point['context'][-2]['dialogue_acts'] if len(data_point['context']) > 1 else {}
state['system_action'] = flatten_acts(last_system_act) state['system_action'] = flatten_acts(last_system_act)
state['terminated'] = data_point['terminated'] state['terminated'] = data_point['terminated']
......
...@@ -12,6 +12,7 @@ from convlab.util.custom_util import set_seed, init_logging, save_config ...@@ -12,6 +12,7 @@ from convlab.util.custom_util import set_seed, init_logging, save_config
from convlab.util.train_util import to_device from convlab.util.train_util import to_device
from convlab.policy.rlmodule import MultiDiscretePolicy from convlab.policy.rlmodule import MultiDiscretePolicy
from convlab.policy.vector.vector_binary import VectorBinary from convlab.policy.vector.vector_binary import VectorBinary
from convlab.policy.vector.vector_binary_fuzzy import VectorBinaryFuzzy
root_dir = os.path.dirname( root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
...@@ -195,8 +196,15 @@ if __name__ == '__main__': ...@@ -195,8 +196,15 @@ if __name__ == '__main__':
use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info) use_state_knowledge_uncertainty=dst.return_belief_state_mutual_info)
else: else:
vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking) vector = VectorBinary(dataset_name=args.dataset_name, use_masking=args.use_masking)
elif args.dst == "trippy":
# Parse args.dst_args into [key, value] pairs: entries are separated by ', '
# (comma + space) and split on the first '='; entries without '=' are dropped.
# When args.dst_args is None, no extra tracker arguments are passed.
dst_args = [arg.split('=', 1) for arg in args.dst_args.split(', ')
if '=' in arg] if args.dst_args is not None else []
# SECURITY NOTE(review): eval() evaluates every value as a Python expression,
# so any code embedded in args.dst_args is executed. If the values are only
# ever literals (strings, numbers, booleans), ast.literal_eval would be the
# safer choice -- TODO confirm no caller relies on non-literal expressions.
dst_args = {key: eval(value) for key, value in dst_args}
# Imported inside this branch, presumably so TripPy is only loaded when selected.
from convlab.dst.trippy import TRIPPY
dst = TRIPPY(**dst_args)
# The TripPy tracker is paired with the fuzzy-matching binary vectoriser.
vector = VectorBinaryFuzzy(dataset_name=args.dataset_name, use_masking=args.use_masking)
else: else:
raise NameError(f"Tracker: {args.tracker} not implemented.") raise NameError(f"Tracker: {args.dst} not implemented.")
manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst) manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector, dst=dst)
agent = MLE_Trainer(manager, vector, cfg) agent = MLE_Trainer(manager, vector, cfg)
......
{
"model": {
"load_path": "/path/to/model/checkpoint",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 1000,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
"process_num": 2,
"num_eval_dialogues": 500,
"sys_semantic_to_usr": false
},
"vectorizer_sys": {
"fuzzy_vector_mul": {
"class_path": "convlab.policy.vector.vector_binary_fuzzy.VectorBinaryFuzzy",
"ini_params": {
"use_masking": true,
"manually_add_entity_names": true,
"seed": 0
}
}
},
"nlu_sys": {},
"dst_sys": {
"TripPy": {
"class_path": "convlab.dst.trippy.TRIPPY",
"ini_params": {
"model_type": "roberta",
"model_name": "roberta-base",
"model_path": "/path/to/model/checkpoint",
"dataset_name": "multiwoz21"
}
}
},
"sys_nlg": {
"TemplateNLG": {
"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
"ini_params": {
"is_user": false
}
}
},
"nlu_usr": {
"BERTNLU": {
"class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU",
"ini_params": {
"mode": "sys",
"config_file": "multiwoz21_sys_context3.json",
"model_file": "/path/to/model/checkpoint.zip"
}
}
},
"dst_usr": {},
"policy_usr": {
"RulePolicy": {
"class_path": "convlab.policy.rule.multiwoz.RulePolicy",
"ini_params": {
"character": "usr"
}
}
},
"usr_nlg": {
"TemplateNLG": {
"class_path": "convlab.nlg.template.multiwoz.TemplateNLG",
"ini_params": {
"is_user": true,
"label_noise": 0.0,
"text_noise": 0.0
}
}
}
}
# -*- coding: utf-8 -*-
import sys
import numpy as np
from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
from .vector_binary import VectorBinary
class VectorBinaryFuzzy(VectorBinary):
    """VectorBinary variant whose database lookup uses soft constraints.

    The only behavioural difference to the parent class that is visible here
    is ``dbquery_domain``: the constraints taken from the tracked state are
    handed to the database as soft constraints instead of as the strict
    ``state`` argument.
    """

    def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True,
                 seed=0):
        """Forward all arguments unchanged to ``VectorBinary.__init__``."""
        super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)

    def dbquery_domain(self, domain):
        """
        query entities of specified domain
        Args:
            domain string:
                domain to query
        Returns:
            entities list:
                list of entities of the specified domain
        """
        # Collect every slot of this domain that has a non-empty value in the
        # tracked state; a domain that is absent from the state yields no
        # constraints at all.
        constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \
            if domain in self.state else []
        # All constraints are passed as soft constraints (the strict ``state``
        # list stays empty). The keyword is spelled ``soft_contraints`` because
        # that is the parameter name this call has always used.
        # The former additional strict query was only used for debugging and
        # its result was discarded, so it is no longer issued.
        return self.db.query(domain=domain, state=[], soft_contraints=constraints,
                             fuzzy_match_ratio=100, topk=10)
...@@ -57,18 +57,23 @@ class Database(BaseDatabase): ...@@ -57,18 +57,23 @@ class Database(BaseDatabase):
for key, val in state: for key, val in state:
if key == 'department': if key == 'department':
department = val department = val
if not department:
for key, val in soft_contraints:
if key == 'department':
department = val
if not department: if not department:
return deepcopy(self.dbs['hospital']) return deepcopy(self.dbs['hospital'])
else: else:
return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()] return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()]
state = list(map(lambda ele: (self.slot2dbattr.get(ele[0], ele[0]), ele[1]) if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), state)) state = list(map(lambda ele: (self.slot2dbattr.get(ele[0], ele[0]), ele[1]) if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), state))
soft_contraints = list(map(lambda ele: (self.slot2dbattr.get(ele[0], ele[0]), ele[1]) if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), soft_contraints))
found = [] found = []
for i, record in enumerate(self.dbs[domain]): for i, record in enumerate(self.dbs[domain]):
constraints_iterator = zip(state, [False] * len(state)) constraints_iterator = zip(state, [False] * len(state))
soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints)) soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints))
for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator): for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator):
if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"]: if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care", "do not care"]:
pass pass
else: else:
try: try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment