diff --git a/convlab2/dialog_agent/agent.py b/convlab2/dialog_agent/agent.py index 2feed5ad6e0aebe0fea218d3801c9b1a150aea6f..a5e437c277547f09522fd9bffa35019c1afc552c 100755 --- a/convlab2/dialog_agent/agent.py +++ b/convlab2/dialog_agent/agent.py @@ -197,7 +197,7 @@ class PipelineAgent(Agent): if domain.lower() not in ['general', 'booking']: self.cur_domain = domain if intent == "book": - self.dst.state['belief_state'][domain.lower()]['book']['booked'] = [{slot.lower(): value}] + self.dst.state['booked'][domain] = [{slot.lower(): value}] else: self.dst.state['user_action'] = self.output_action # user dst is also updated by itself @@ -403,14 +403,8 @@ class DialogueAgent(Agent): for intent, domain, slot, value in self.output_action: if domain.lower() not in ['general', 'booking']: self.cur_domain = domain - dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}' - if dial_act == 'booking-book-ref' and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']: - if self.cur_domain: - self.dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref': - self.dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'taxi-inform-car': - self.dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}] + if intent == "book": + self.dst.state['booked'][domain] = [{slot.lower(): value}] self.history.append([self.name, model_response]) self.turn += 1 diff --git a/convlab2/dialog_agent/env.py b/convlab2/dialog_agent/env.py index 331737b075bf121b71f9134c4f9c8595efb057ac..1fd190d26c0a7baac3f84499e7daa3b41f584bdd 100755 --- a/convlab2/dialog_agent/env.py +++ b/convlab2/dialog_agent/env.py @@ -37,16 +37,8 @@ class Environment(): # If system takes booking action add booking info to the 'book-booked' section of the belief state if type(action) == list: for intent, domain, slot, value in action: - if domain.lower() not in ['general', 'booking']: - self.cur_domain = domain - dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}' - if dial_act == 'booking-book-ref' and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']: - if self.cur_domain: - self.sys_dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref': - self.sys_dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'taxi-inform-car': - self.sys_dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}] + if intent == "book": + self.sys_dst.state['booked'][domain] = [{slot.lower(): value}] observation = self.usr.response(model_response) if self.evaluator: @@ -59,11 +51,6 @@ class Environment(): state = self.sys_dst.update(dialog_act) dialog_act = self.sys_dst.state['user_action'] - if type(dialog_act) == list: - for intent, domain, slot, value in dialog_act: - if domain.lower() not in ['booking', 'general']: - self.cur_domain = domain - state['history'].append(["sys", model_response]) state['history'].append(["usr", observation]) diff --git a/convlab2/dst/rule/multiwoz/dst.py b/convlab2/dst/rule/multiwoz/dst.py index 3974ab848f47a70455c9d1a80f3f14baafcea17e..e4ac0f432396b6f5a21acec2fd9c4e31e7f6afc6 100755 --- a/convlab2/dst/rule/multiwoz/dst.py +++ b/convlab2/dst/rule/multiwoz/dst.py @@ -1,6 +1,7 @@ import json import os +from convlab2.util import load_ontology from convlab2.util.multiwoz.state import default_state from convlab2.dst.rule.multiwoz.dst_util import normalize_value from convlab2.dst.dst import DST @@ -17,9 +18,11 @@ class RuleDST(DST): It helps check whether ``user_act`` has correct content. """ - def __init__(self): + def __init__(self, dataset_name='multiwoz21'): DST.__init__(self) + self.ontology = load_ontology(dataset_name) self.state = default_state() + self.state['belief_state'] = self.ontology['state'] path = os.path.dirname( os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) path = os.path.join(path, 'data/multiwoz/value_dict.json') @@ -31,34 +34,17 @@ class RuleDST(DST): :param user_act: :return: """ - #print("dst", user_act) for intent, domain, slot, value in user_act: - domain = domain.lower() - intent = intent.lower() - if domain in ['unk', 'general', 'booking']: + if domain not in self.state['belief_state']: continue if intent == 'inform': k = REF_SYS_DA[domain.capitalize()].get(slot, slot) if k is None: continue - try: - assert domain in self.state['belief_state'] - except: - raise Exception( - 'Error: domain <{}> not in new belief state'.format(domain)) domain_dic = self.state['belief_state'][domain] - assert 'semi' in domain_dic - assert 'book' in domain_dic - if k in domain_dic['semi']: + if k in domain_dic: nvalue = normalize_value(self.value_dict, domain, k, value) - self.state['belief_state'][domain]['semi'][k] = nvalue - elif k in domain_dic['book']: - self.state['belief_state'][domain]['book'][k] = value - elif k.lower() in domain_dic['book']: - self.state['belief_state'][domain]['book'][k.lower() - ] = value - elif k == 'trainID' and domain == 'train': - self.state['belief_state'][domain]['book'][k] = normalize_value(self.value_dict, domain, k, value) + self.state['belief_state'][domain][k] = nvalue elif k != 'none': # raise Exception('unknown slot name <{}> of domain <{}>'.format(k, domain)) with open('unknown_slot.log', 'a+') as f: @@ -76,6 +62,7 @@ class RuleDST(DST): def init_session(self): """Initialize ``self.state`` with a default state, which ``convlab2.util.multiwoz.state.default_state`` returns.""" self.state = default_state() + self.state['belief_state'] = self.ontology['state'] if __name__ == '__main__': diff --git a/convlab2/policy/mle/__init__.py b/convlab2/policy/mle/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..e96e3c71a9984c5c0a97bab1edbaf811006b3f5d 100755 --- a/convlab2/policy/mle/__init__.py +++ b/convlab2/policy/mle/__init__.py @@ -0,0 +1 @@ +from .mle import MLE \ No newline at end of file diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index b2778a31925aaee64374b17975899d54a1490a62..9fbad07a0fb2ef98f183176d5941868fe0b91646 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -57,10 +57,6 @@ class UserPolicyAgendaMultiWoz(Policy): """ self.max_turn = 40 self.max_initiative = 4 - self.forbidden_domains = {} - self.mandatory_domains = [] - self.only_single_domains = False - self.goal_generator = GoalGenerator() self.__turn = 0 @@ -75,25 +71,12 @@ class UserPolicyAgendaMultiWoz(Policy): """ Build new Goal and Agenda for next session """ self.reset_turn() if not ini_goal: - self.find_valid_goal() + self.goal = Goal(self.goal_generator) else: self.goal = ini_goal self.domain_goals = self.goal.domain_goals self.agenda = Agenda(self.goal) - def find_valid_goal(self): - - valid = False - while not valid: - self.goal = Goal(self.goal_generator) - if self.only_single_domains and len(self.goal.domain_goals) > 1: - continue - if set(self.forbidden_domains) & set(self.goal.domain_goals): - continue - if not set(self.mandatory_domains).issubset(set(self.goal.domain_goals)): - continue - valid = True - def predict(self, sys_dialog_act): """ Predict an user act based on state and preorder system action. diff --git a/convlab2/policy/vector/vector_base.py b/convlab2/policy/vector/vector_base.py index 5880da6fd94f8cc39c57d169549406e0596d86c5..108ac83d71fef8f7b547a9f94ec195ff832f131d 100644 --- a/convlab2/policy/vector/vector_base.py +++ b/convlab2/policy/vector/vector_base.py @@ -4,37 +4,23 @@ import sys import numpy as np import copy import logging + from copy import deepcopy from convlab2.policy.vec import Vector from convlab2.util.custom_util import flatten_acts from convlab2.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA - from convlab2.util import load_ontology, load_database, load_dataset -DEFAULT_INTENT_FILEPATH = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))), - 'data/multiwoz/trackable_intent.json' -) root_dir = os.path.dirname(os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(root_dir) -SLOT_MAP = {'taxi_types': 'car type'} - -#TODO: The masks depend on multiwoz, deal with that somehow, shall we build a Mask class? -#TODO: Check the masks with new action strings -#TODO: Where should i save the action dicts? -#TODO: Load actions from ontology properly -#TODO: method AddName is properly not working right anymore - - class VectorBase(Vector): - def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True, + def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=False, seed=0): super().__init__() @@ -78,10 +64,27 @@ class VectorBase(Vector): def load_action_dicts(self): - self.load_actions_from_data() + dir_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + f'action_dicts/{self.dataset_name}_{type(self).__name__}') + if not (os.path.exists(os.path.join(dir_path, "sys_da_voc.txt")) + and os.path.exists(os.path.join(dir_path, "user_da_voc.txt"))): + print("Load actions from data..") + self.load_actions_from_data() + else: + print("Load actions from file..") + with open(os.path.join(dir_path, "sys_da_voc.txt")) as f: + self.da_voc = f.read().splitlines() + with open(os.path.join(dir_path, "user_da_voc.txt")) as f: + self.da_voc_opp = f.read().splitlines() + self.generate_dict() def load_actions_from_data(self, frequency_threshold=50): + """ + Loads the action sets for user and system using a data set. + The frequency_threshold prohibits adding actions that occur fewer times than this threshold in the data + (for instance there might be incorrectly labelled actions) + """ data_split = load_dataset(self.dataset_name) system_dict = {} @@ -117,27 +120,35 @@ class VectorBase(Vector): if user_dict[key] < frequency_threshold: del user_dict[key] - with open("sys_da_voc.txt", "w") as f: - system_acts = list(system_dict.keys()) - system_acts.sort() - for act in system_acts: + self.da_voc = list(system_dict.keys()) + self.da_voc.sort() + self.da_voc_opp = list(user_dict.keys()) + self.da_voc_opp.sort() + + self.save_acts_to_txt() + + def save_acts_to_txt(self): + dir_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + f'action_dicts/{self.dataset_name}_{type(self).__name__}') + os.makedirs(dir_path, exist_ok=True) + with open(os.path.join(dir_path, "sys_da_voc.txt"), "w") as f: + for act in self.da_voc: f.write(act + "\n") - with open("user_da_voc.txt", "w") as f: - user_acts = list(user_dict.keys()) - user_acts.sort() - for act in user_acts: + with open(os.path.join(dir_path, "user_da_voc.txt"), "w") as f: + for act in self.da_voc_opp: f.write(act + "\n") - print("Saved new action dict.") - - self.da_voc = system_acts - self.da_voc_opp = user_acts def load_actions_from_ontology(self): + """ + Loads the action sets for user and system if an ontology is provided. + It is recommended to use load_actions_from_data to guarantee consistency with previous results + """ self.da_voc = [] self.da_voc_opp = [] for act_type in self.ontology['dialogue_acts']: for act in self.ontology['dialogue_acts'][act_type]: + act = eval(act) system = act['system'] user = act['user'] if system: @@ -148,6 +159,9 @@ class VectorBase(Vector): user_acts_with_value = self.add_values_to_act(act['domain'], act['intent'], act['slot'], False) self.da_voc_opp.extend(user_acts_with_value) + self.da_voc.sort() + self.da_voc_opp.sort() + def generate_dict(self): """ init the dict for mapping state/action into vector @@ -212,7 +226,10 @@ class VectorBase(Vector): np.random.seed(seed) def compute_domain_mask(self, domain_active_dict): - + ''' + Can not speak about a domain if that domain is not active. + A domain is active if the user mentioned it in the current turn or if a slot is filled with a value + ''' mask_list = np.zeros(self.da_dim) for i in range(self.da_dim): @@ -232,49 +249,36 @@ class VectorBase(Vector): action = self.vec2act[i] domain, intent, slot, value = action.split('-') - # NoBook-SLOT does not make sense because policy can not know which constraint made booking impossible + # NoBook/NoOffer-SLOT does not make sense because policy can not know which constraint made offer impossible # If one wants to do it, lexicaliser needs to do it - if intent.lower() in ['nobook', 'nooffer'] and slot.lower() != 'none': + if intent in ['nobook', 'nooffer'] and slot != 'none': mask_list[i] = 1.0 - # see policy/rule/multiwoz/policy_agenda_multiwoz.py: illegal booking slot. Is self.cur_domain correct? - if self.cur_domain is not None: - if slot.lower() == 'time' and self.cur_domain.lower() not in ['train', 'restaurant']: - if domain.lower() == 'booking': - mask_list[i] = 1.0 - - if slot.lower() in self.state[self.cur_domain.lower()]['book']: - if not self.state[self.cur_domain.lower()]['book'][slot.lower()] and intent.lower() == 'inform': - mask_list[i] = 1.0 + if "book" in slot and intent.lower() == 'inform' and not self.state[domain][slot]: + mask_list[i] = 1.0 - if domain.lower() == 'taxi': - slot = REF_SYS_DA.get(domain, {}).get(slot, slot.lower()) - if slot in self.state['taxi']['semi']: - if not self.state['taxi']['semi'][slot] and intent.lower() == 'inform': + if domain == 'taxi': + if slot in self.state['taxi']: + if not self.state['taxi'][slot] and intent.lower() == 'inform': mask_list[i] = 1.0 return mask_list def compute_entity_mask(self, number_entities_dict): + ''' + 1. If there is no i-th entity in the data base, can not inform/recommend/select on that entity + 2. If there is an entity available, can not say NoOffer or NoBook + ''' mask_list = np.zeros(self.da_dim) for i in range(self.da_dim): action = self.vec2act[i] domain, intent, slot, value = action.split('-') domain_entities = number_entities_dict.get(domain, 1) - if intent.lower() in ['inform', 'select', 'recommend'] and value != None and value != 'none': - if(int(value) > domain_entities): - mask_list[i] = 1.0 - - if intent.lower() in ['inform', 'select', 'recommend'] and domain.lower() in ['booking']: - if number_entities_dict.get(self.cur_domain, 0) == 0: + if intent in ['inform', 'select', 'recommend'] and value != None and value != 'none': + if int(value) > domain_entities: mask_list[i] = 1.0 - - # mask Booking-NoBook if an entity is available in the current domain - if intent.lower() in ['nobook'] and number_entities_dict.get(self.cur_domain, 0) > 0: - mask_list[i] = 1.0 - - if intent.lower() in ['nooffer'] and number_entities_dict.get(domain, 0) > 0: + if intent in ['nooffer', 'nobook'] and number_entities_dict.get(domain, 0) > 0: mask_list[i] = 1.0 return mask_list @@ -372,12 +376,13 @@ class VectorBase(Vector): entities = {} for domint in action: domain, intent = domint.split('-') - if domain not in entities and domain.lower() not in ['general', 'booking']: + if domain not in entities and domain.lower() not in ['general']: entities[domain] = self.dbquery_domain(domain) if self.cur_domain and self.cur_domain not in entities: entities[self.cur_domain] = self.dbquery_domain(self.cur_domain) - nooffer = [domint for domint in action if 'NoOffer' in domint] + #TODO: Rewrite find_noffer_slot + nooffer = [domint for domint in action if 'nooffer' in domint] for domint in nooffer: domain, intent = domint.split('-') slot = self.find_nooffer_slot(domain) @@ -385,7 +390,7 @@ class VectorBase(Vector): action[domint] = [[slot, '1'] ] if slot != 'none' else [[slot, 'none']] - nobook = [domint for domint in action if 'NoBook' in domint] + nobook = [domint for domint in action if 'nobook' in domint] for domint in nobook: domain = self.cur_domain if self.cur_domain else 'none' if domain.lower() in self.state: @@ -401,7 +406,6 @@ class VectorBase(Vector): ] if slot != 'none' else [[slot, 'none']] # When there is a INFORM(1 name) or OFFER(multiple) action then inform the name - if self.use_add_name: action = self.add_name(action) @@ -424,18 +428,18 @@ class VectorBase(Vector): name_inform = [] contains_name = False # General Inform Condition for Naming - cur_inform = str(self.cur_domain) + '-Inform' - cur_request = str(self.cur_domain) + '-Request' + cur_inform = str(self.cur_domain) + '-inform' + cur_request = str(self.cur_domain) + '-request' index = -1 if cur_inform in action: for [item, idx] in action[cur_inform]: - if item == 'Name': + if item == 'name': contains_name = True - elif self.cur_domain == 'Train' and item == 'Id': + elif self.cur_domain == 'train' and item == 'id': contains_name = True - elif self.cur_domain == 'Hospital': + elif self.cur_domain == 'hospital': contains_name = True - elif item == 'Choice' and cur_request in action: + elif item == 'choice' and cur_request in action: contains_name = True if index != -1 and index != idx and idx is not None: @@ -445,10 +449,10 @@ class VectorBase(Vector): index = idx if contains_name == False: - if self.cur_domain == 'Train': - name_act = ['Id', index] + if self.cur_domain == 'train': + name_act = ['id', index] else: - name_act = ['Name', index] + name_act = ['name', index] tmp = [name_act] + action[cur_inform] name_inform = name_act diff --git a/convlab2/policy/vector/vector_binary.py b/convlab2/policy/vector/vector_binary.py index 8fde29144b98d1fc1b6b21002c1cc646e5047bcf..996a72ff9e49475d6be4afb17bd9a34b7c2e12a6 100755 --- a/convlab2/policy/vector/vector_binary.py +++ b/convlab2/policy/vector/vector_binary.py @@ -1,19 +1,9 @@ # -*- coding: utf-8 -*- import sys -import os import numpy as np from convlab2.util.multiwoz.lexicalize import delexicalize_da, flat_da from .vector_base import VectorBase -DEFAULT_INTENT_FILEPATH = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))), - 'data/multiwoz/trackable_intent.json' -) - - -SLOT_MAP = {'taxi_types': 'car type'} - class VectorBinary(VectorBase):