Skip to content
Snippets Groups Projects
Select Git revision
  • ec48f3d916f4a138cd928f42ef30e4b540abbfa4
  • master default protected
2 results

Jars.prob2project

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    vector_base.py 20.99 KiB
    # -*- coding: utf-8 -*-
    import os
    import sys
    import numpy as np
    import logging
    import json
    
    from copy import deepcopy
    from convlab.policy.vec import Vector
    from convlab.util.custom_util import flatten_acts, timeout
    from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da
    from convlab.util import load_ontology, load_database, load_dataset
    
    
    root_dir = os.path.dirname(os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
    sys.path.append(root_dir)
    
    
    class VectorBase(Vector):
    
        def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=False,
                     always_inform_booking_reference=True, seed=0, use_none=True):
    
            super().__init__()
    
            logging.info(f"Vectorizer: Data set used is {dataset_name}")
            self.set_seed(seed)
            self.ontology = load_ontology(dataset_name)
            try:
                # execute to make sure that the database exists or is downloaded otherwise
                if dataset_name == "multiwoz21" or dataset_name == "crosswoz":
                    load_database(dataset_name)
                # the following two lines are needed for pickling correctly during multi-processing
                exec(f'from data.unified_datasets.{dataset_name}.database import Database')
                self.db = eval('Database()')
                self.db_domains = self.db.domains
            except Exception as e:
                self.db = None
                self.db_domains = []
                print(f"VectorBase: {e}")
    
            self.dataset_name = dataset_name
            self.max_actionval = {}
            self.use_mask = use_masking
            self.use_add_name = manually_add_entity_names
            self.always_inform_booking_reference = always_inform_booking_reference
            self.reqinfo_filler_action = None
            self.character = character
            self.use_none = use_none
            self.requestable = ['request']
            self.informable = ['inform', 'recommend']
    
            self.load_attributes()
            self.get_state_dim()
            print(f"State dimension: {self.state_dim}")
    
        def load_attributes(self):
    
            self.domains = list(self.ontology['domains'].keys())
            self.domains.sort()
    
            self.previous_name_actions = {domain: [] for domain in self.domains}
    
            self.state = self.ontology['state']
            self.belief_domains = list(self.state.keys())
            self.belief_domains.sort()
    
            self.load_action_dicts()
    
        def load_action_dicts(self):
    
            dir_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                    f'action_dicts/{self.dataset_name}_{type(self).__name__}')
            if not (os.path.exists(os.path.join(dir_path, "sys_da_voc.txt"))
                    and os.path.exists(os.path.join(dir_path, "user_da_voc.txt"))):
                print("Load actions from data..")
                self.load_actions_from_data()
            else:
                print("Load actions from file..")
                with open(os.path.join(dir_path, "sys_da_voc.txt")) as f:
                    self.da_voc = f.read().splitlines()
                    if self.da_voc[0][0] != "[":
                        # if act is not a list, we still have the old action dict
                        self.load_actions_from_data()
                    else:
                        self.da_voc = [tuple(json.loads(act)) for act in self.da_voc]
                with open(os.path.join(dir_path, "user_da_voc.txt")) as f:
                    self.da_voc_opp = f.read().splitlines()
                    self.da_voc_opp = [tuple(json.loads(act)) for act in self.da_voc_opp]
    
            self.generate_dict()
    
        def load_actions_from_data(self, frequency_threshold=50):
            """
            Loads the action sets for user and system using a data set.
            The frequency_threshold prohibits adding actions that occur fewer times than this threshold in the data
            (for instance there might be incorrectly labelled actions)
            """
    
            data_split = load_dataset(self.dataset_name)
            system_dict = {}
            user_dict = {}
            for key in data_split:
                data = data_split[key]
                for dialogue in data:
                    for turn in dialogue['turns']:
                        dialogue_acts = turn['dialogue_acts']
                        act_list = flatten_acts(dialogue_acts)
                        delex_acts = delexicalize_da(act_list, self.requestable)
    
                        if turn['speaker'] == 'system':
                            for act in delex_acts:
                                act = tuple([a.lower() for a in act])
                                if act not in system_dict:
                                    system_dict[act] = 1
                                else:
                                    system_dict[act] += 1
                        else:
                            for act in delex_acts:
                                act = tuple([a.lower() for a in act])
                                if act not in user_dict:
                                    user_dict[act] = 1
                                else:
                                    user_dict[act] += 1
    
            for key in deepcopy(system_dict):
                if system_dict[key] < frequency_threshold:
                    del system_dict[key]
    
            for key in deepcopy(user_dict):
                if user_dict[key] < frequency_threshold:
                    del user_dict[key]
    
            self.da_voc = list(system_dict.keys())
            self.da_voc.sort()
            self.da_voc_opp = list(user_dict.keys())
            self.da_voc_opp.sort()
    
            self.save_acts_to_txt()
    
        def save_acts_to_txt(self):
            dir_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                    f'action_dicts/{self.dataset_name}_{type(self).__name__}')
            os.makedirs(dir_path, exist_ok=True)
            with open(os.path.join(dir_path, "sys_da_voc.txt"), "w") as f:
                for act in self.da_voc:
                    f.write(json.dumps(act) + "\n")
            with open(os.path.join(dir_path, "user_da_voc.txt"), "w") as f:
                for act in self.da_voc_opp:
                    f.write(json.dumps(act) + "\n")
    
        def load_actions_from_ontology(self):
            """
            Loads the action sets for user and system if an ontology is provided.
            It is recommended to use load_actions_from_data to guarantee consistency with previous results
            """
    
            self.da_voc = []
            self.da_voc_opp = []
            for act_type in self.ontology['dialogue_acts']:
                for act in self.ontology['dialogue_acts'][act_type]:
                    act = eval(act)
                    system = act['system']
                    user = act['user']
                    if system:
                        system_acts_with_value = self.add_values_to_act(
                            act['domain'], act['intent'], act['slot'], True)
                        self.da_voc.extend(system_acts_with_value)
    
                    if user:
                        user_acts_with_value = self.add_values_to_act(
                            act['domain'], act['intent'], act['slot'], False)
                        self.da_voc_opp.extend(user_acts_with_value)
    
            self.da_voc.sort()
            self.da_voc_opp.sort()
    
        def generate_dict(self):
            """
            init the dict for mapping state/action into vector
            """
            self.act2vec = dict((a, i) for i, a in enumerate(self.da_voc))
            self.vec2act = dict((v, k) for k, v in self.act2vec.items())
            self.da_dim = len(self.da_voc)
            self.opp2vec = dict((a, i) for i, a in enumerate(self.da_voc_opp))
            self.da_opp_dim = len(self.da_voc_opp)
    
            print(f"Dimension of system actions: {self.da_dim}")
            print(f"Dimension of user actions: {self.da_opp_dim}")
    
        def get_state_dim(self):
            '''
            Compute the state dimension for the policy input
            '''
            self.state_dim = 0
            raise NotImplementedError
    
        def state_vectorize(self, state):
            """vectorize a state
            Args:
                state (tuple):
                    Dialog state
            Returns:
                state_vec (np.array):
                    Dialog state vector
            """
            raise NotImplementedError
    
        def add_values_to_act(self, domain, intent, slot, system):
            '''
            The ontology does not contain information about the value of an act. This method will add the value and
            is based on how it is created in MultiWOZ. This might need to be changed for other datasets such as SGD.
            '''
    
            if intent == 'request':
                return [f"{domain}_{intent}_{slot}_?"]
    
            if slot == '':
                return [f"{domain}_{intent}_none_none"]
    
            if system:
                if intent in ['recommend', 'select', 'inform']:
                    return [f"{domain}_{intent}_{slot}_{i}" for i in range(1, 4)]
                else:
                    return [f"{domain}_{intent}_{slot}_1"]
            else:
                return [f"{domain}_{intent}_{slot}_1"]
    
        def init_domain_active_dict(self):
            domain_active_dict = {}
            for domain in self.domains:
                if domain == 'general':
                    continue
                domain_active_dict[domain] = False
            return domain_active_dict
    
        def set_seed(self, seed):
            np.random.seed(seed)
    
        def compute_domain_mask(self, domain_active_dict):
            '''
            Can not speak about a domain if that domain is not active.
            A domain is active if the user mentioned it in the current turn or if a slot is filled with a value
            '''
            mask_list = np.zeros(self.da_dim)
    
            for i in range(self.da_dim):
                action = self.vec2act[i]
                action_domain = action[0]
                if action_domain in domain_active_dict.keys():
                    if not domain_active_dict[action_domain]:
                        mask_list[i] = 1.0
    
            return mask_list
    
        def compute_general_mask(self):
    
            mask_list = np.zeros(self.da_dim)
    
            for i in range(self.da_dim):
                action = self.vec2act[i]
                domain, intent, slot, value = action
    
                # NoBook/NoOffer-SLOT does not make sense because policy can not know which constraint made offer impossible
                # If one wants to do it, lexicaliser needs to do it
                if intent in ['nobook', 'nooffer'] and slot != 'none':
                    mask_list[i] = 1.0
    
                if "book" in slot and intent == 'inform':
                    if not self.state.get(domain, {}).get(slot, {}):
                        mask_list[i] = 1.0
    
                if domain == 'taxi':
                    if slot in self.state.get('taxi', {}):
                        if not self.state['taxi'][slot] and intent == 'inform':
                            mask_list[i] = 1.0
    
            return mask_list
    
        def compute_entity_mask(self, number_entities_dict):
            '''
            1. If there is no i-th entity in the data base, can not inform/recommend/select on that entity
            2. If there is an entity available, can not say NoOffer or NoBook
            '''
            mask_list = np.zeros(self.da_dim)
            if number_entities_dict is None:
                return mask_list
            for i in range(self.da_dim):
                action = self.vec2act[i]
                domain, intent, slot, value = action
                domain_entities = number_entities_dict.get(domain, 1)
    
                if intent in ['inform', 'select', 'recommend'] and value != None and value != 'none':
                    if int(value) > domain_entities:
                        mask_list[i] = 1.0
                if intent in ['nooffer', 'nobook'] and number_entities_dict.get(domain, 0) > 0:
                    mask_list[i] = 1.0
    
            return mask_list
    
        def dbquery_domain(self, domain):
            """
            query entities of specified domain
            Args:
                domain string:
                    domain to query
            Returns:
                entities list:
                    list of entities of the specified domain
            """
            #constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \
            #    if domain in self.state else []
            state = self.state if domain in self.state else {domain: {}}
            if domain.lower() == "general":
                return []
            return self.db.query(domain, state, topk=10)
    
        def find_nooffer_slot(self, domain):
            """
            Function used to find which user constraint results in no entities being found
            query entities of specified domain
            Args:
                domain string:
                    domain to query
            Returns:
                entities list:
                    list of entities of the specified domain
            """
            constraints = self.state.get(domain, {})
    
            # Leave slots out of constraints to find which slot constraint results in no entities being found
            for constraint_slot in constraints:
                state = [[slot, value] for slot,
                         value in constraints.items() if slot != constraint_slot]
                entities = self.db.query(domain, state, topk=1)
                if entities:
                    return constraint_slot
    
            # If no single slot results in no entities being found try the above with pairs of slots
            slots = [slot for slot in constraints]
            pairs = []
            for i, slot in enumerate(slots):
                for j, slot1 in enumerate(slots):
                    if j > i:
                        pairs.append((slot, slot1))
    
            for constraint_slots in pairs:
                state = [[slot, value] for slot, value in constraints.items() if slot not in constraint_slots]
                entities = self.db.query(domain, state, topk=1)
                if entities:
                    return np.random.choice(constraint_slots)
    
            # If no single slots or pairs removed results in success then set slot 'none'
            return 'none'
    
        def action_vectorize(self, action):
            action = delexicalize_da(action, self.requestable)
            #action = flat_da(action)
            act_vec = np.zeros(self.da_dim)
    
            for da in action:
                da = tuple(da)
                if da in self.act2vec:
                    act_vec[self.act2vec[da]] = 1.
            return act_vec
    
        def action_devectorize(self, action_vec):
            """
            recover an action
            Args:
                action_vec (np.array):
                    Dialog act vector
            Returns:
                action (tuple):
                    Dialog act
            """
    
            act_array = []
    
            for i, idx in enumerate(action_vec):
                if idx == 1:
                    act_array.append(self.vec2act[i])
    
            if len(act_array) == 0:
                if self.reqinfo_filler_action:
                    act_array.append(("general", "reqinfo", "none", "none"))
                else:
                    act_array.append(("general", "reqmore", "none", "none"))
    
            action = deflat_da(act_array)
            entities = {}
            for domint in action:
                domain, intent = domint
                if domain not in entities and domain not in ['general']:
                    entities[domain] = self.dbquery_domain(domain)
    
            # From db query find which slot causes no_offer
            nooffer = [domint for domint in action if 'nooffer' in domint[1]]
            for domint in nooffer:
                domain, intent = domint
                slot = self.find_nooffer_slot(domain)
                action[domint] = [[slot, '1']
                                  ] if slot != 'none' else [[slot, 'none']]
    
            # Randomly select booking constraint "causing" no_book
            nobook = [domint for domint in action if 'nobook' in domint[1]]
            for domint in nobook:
                domain, intent = domint
                if domain in self.state:
                    slots = self.state[domain]
                    slots = [slot for slot, i in slots.items()
                             if i and 'book' in slot]
                    slots.append('none')
                    slot = np.random.choice(slots)
                else:
                    slot = 'none'
                action[domint] = [[slot, '1']
                                  ] if slot != 'none' else [[slot, 'none']]
    
            if self.always_inform_booking_reference:
                action = self.add_booking_reference(action)
    
            # When there is a INFORM(1 name) or OFFER(multiple) action then inform the name
            if self.use_add_name:
                action = self.add_name(action)
    
            for key in action.keys():
                index = -1
                for [item, idx] in action[key]:
                    if index != -1 and index != idx and idx != '?':
                        pass
                        # logging.debug(
                        #    "System is likely refering multiple entities within this turn")
                        # logging.debug(action[key])
                    index = idx
            action = lexicalize_da(action, entities, self.state, self.requestable)
    
            if not self.use_none:
                # replace all occurences of "none" with an empty string ""
                f = lambda x: x if x != "none" else ""
                action = [[f(x) for x in a_list] for a_list in action]
                #action = [[ for a_tuple in a_list] for a_list in action]
    
            return action
    
        def add_booking_reference(self, action):
            new_acts = {}
            for domint in action:
                domain, intent = domint
    
                if intent == 'book' and action[domint]:
                    ref_domint = (domain, "inform")
                    if ref_domint not in new_acts:
                        new_acts[ref_domint] = []
                    new_acts[ref_domint].append(['ref', '1'])
                    if domint not in new_acts:
                        new_acts[domint] = []
                    new_acts[domint].append(['none', '1'])
                elif domint in new_acts:
                    new_acts[domint] += action[domint]
                else:
                    new_acts[domint] = action[domint]
    
            return new_acts
    
        def add_name(self, action):
    
            name_inform = {domain: [] for domain in self.domains}
            # General Inform Condition for Naming
            domains = [domint[0] for domint in action]
            domains = list(set([d for d in domains if d not in ['general']]))
            for domain in domains:
                contains_name = False
                if domain == 'none':
                    raise NameError('Domain not defined')
                cur_inform = (domain, "inform")
                cur_request = (domain, "request")
                index = -1
                if cur_inform in action:
                    # Check if current inform within a domain is accompanied by a name inform
                    for [slot, value_id] in action[cur_inform]:
                        if slot == 'name':
                            contains_name = True
                        elif domain == 'train' and slot == 'id':
                            contains_name = True
                        elif domain == 'hospital':
                            contains_name = True
                        elif slot == 'choice' and cur_request in action:
                            contains_name = True
    
                    if not contains_name:
                        # Construct name inform act if name is not contained in acts
                        if domain == 'train':
                            name_inform[domain] = ['id', value_id]
                        else:
                            name_inform[domain] = ['name', value_id]
    
                        # If name inform act has not been taken before then add to action set
                        if name_inform[domain] != self.previous_name_actions[domain]:
                            action[cur_inform] += [name_inform[domain]]
                            self.previous_name_actions[domain] = name_inform[domain]
    
            return action
    
        def pointer(self):
            pointer_vector = np.zeros(6 * len(self.db_domains))
            number_entities_dict = {}
            for domain in self.db_domains:
                entities = self.dbquery_domain(domain)
                number_entities_dict[domain] = len(entities)
                pointer_vector = self.one_hot_vector(
                    len(entities), domain, pointer_vector)
    
            return pointer_vector, number_entities_dict
    
        def one_hot_vector(self, num, domain, vector):
            """Return number of available entities for particular domain."""
            if domain != 'train':
                idx = self.db_domains.index(domain)
                if num == 0:
                    vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0])
                elif num == 1:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0])
                elif num == 2:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0])
                elif num == 3:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0])
                elif num == 4:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0])
                elif num >= 5:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1])
            else:
                idx = self.db_domains.index(domain)
                if num == 0:
                    vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0])
                elif num <= 2:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0])
                elif num <= 5:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0])
                elif num <= 10:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0])
                elif num <= 40:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0])
                elif num > 40:
                    vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1])
    
            return vector
    
    
    if __name__ == '__main__':
        vector = VectorBase()