diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 821f7271c6dadbf9c4f3ef5e766011513ac89be0..5d0a42b3331160c1efe604bbc965eba01e9860ae 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -3,6 +3,7 @@ import os import sys import numpy as np import logging +import json from copy import deepcopy from convlab.policy.vec import Vector @@ -79,8 +80,13 @@ class VectorBase(Vector): print("Load actions from file..") with open(os.path.join(dir_path, "sys_da_voc.txt")) as f: self.da_voc = f.read().splitlines() + if self.da_voc[0][0] != "[": + # if act is not a list, we still have the old action dict + self.load_actions_from_data() + self.da_voc = [tuple(json.loads(act)) for act in self.da_voc] with open(os.path.join(dir_path, "user_da_voc.txt")) as f: self.da_voc_opp = f.read().splitlines() + self.da_voc_opp = [tuple(json.loads(act)) for act in self.da_voc_opp] self.generate_dict() @@ -104,14 +110,14 @@ class VectorBase(Vector): if turn['speaker'] == 'system': for act in delex_acts: - act = "_".join(act) + act = tuple(act) if act not in system_dict: system_dict[act] = 1 else: system_dict[act] += 1 else: for act in delex_acts: - act = "_".join(act) + act = tuple(act) if act not in user_dict: user_dict[act] = 1 else: @@ -138,10 +144,10 @@ class VectorBase(Vector): os.makedirs(dir_path, exist_ok=True) with open(os.path.join(dir_path, "sys_da_voc.txt"), "w") as f: for act in self.da_voc: - f.write(act + "\n") + f.write(json.dumps(act) + "\n") with open(os.path.join(dir_path, "user_da_voc.txt"), "w") as f: for act in self.da_voc_opp: - f.write(act + "\n") + f.write(json.dumps(act) + "\n") def load_actions_from_ontology(self): """ @@ -240,7 +246,7 @@ class VectorBase(Vector): for i in range(self.da_dim): action = self.vec2act[i] - action_domain = action.split('_')[0] + action_domain = action[0] if action_domain in domain_active_dict.keys(): if not domain_active_dict[action_domain]: mask_list[i] = 1.0 @@ -253,7 +259,7 @@ class VectorBase(Vector): for i in range(self.da_dim): action = self.vec2act[i] - domain, intent, slot, value = action.split('_') + domain, intent, slot, value = action # NoBook/NoOffer-SLOT does not make sense because policy can not know which constraint made offer impossible # If one wants to do it, lexicaliser needs to do it @@ -281,7 +287,7 @@ class VectorBase(Vector): return mask_list for i in range(self.da_dim): action = self.vec2act[i] - domain, intent, slot, value = action.split('_') + domain, intent, slot, value = action domain_entities = number_entities_dict.get(domain, 1) if intent in ['inform', 'select', 'recommend'] and value != None and value != 'none': @@ -376,29 +382,29 @@ class VectorBase(Vector): if len(act_array) == 0: if self.reqinfo_filler_action: - act_array.append('general_reqinfo_none_none') + act_array.append(("general", "reqinfo", "none", "none")) else: - act_array.append('general_reqmore_none_none') + act_array.append(("general", "reqmore", "none", "none")) action = deflat_da(act_array) entities = {} for domint in action: - domain, intent = domint.split('_') + domain, intent = domint if domain not in entities and domain not in ['general']: entities[domain] = self.dbquery_domain(domain) # From db query find which slot causes no_offer - nooffer = [domint for domint in action if 'nooffer' in domint] + nooffer = [domint for domint in action if 'nooffer' in domint[1]] for domint in nooffer: - domain, intent = domint.split('_') + domain, intent = domint slot = self.find_nooffer_slot(domain) action[domint] = [[slot, '1'] ] if slot != 'none' else [[slot, 'none']] # Randomly select booking constraint "causing" no_book - nobook = [domint for domint in action if 'nobook' in domint] + nobook = [domint for domint in action if 'nobook' in domint[1]] for domint in nobook: - domain, intent = domint.split('_') + domain, intent = domint if domain in self.state: slots = self.state[domain] slots = [slot for slot, i in slots.items() @@ -430,17 +436,19 @@ class VectorBase(Vector): if not self.use_none: # replace all occurences of "none" with an empty string "" - action = [[a_string.replace('none', '') for a_string in a_list] for a_list in action] + f = lambda x: x if x != "none" else "" + action = [[f(x) for x in a_list] for a_list in action] + #action = [[ for a_tuple in a_list] for a_list in action] return action def add_booking_reference(self, action): new_acts = {} for domint in action: - domain, intent = domint.split('_', 1) + domain, intent = domint if intent == 'book' and action[domint]: - ref_domint = f'{domain}_inform' + ref_domint = (domain, "inform") if ref_domint not in new_acts: new_acts[ref_domint] = [] new_acts[ref_domint].append(['ref', '1']) @@ -458,14 +466,14 @@ class VectorBase(Vector): name_inform = {domain: [] for domain in self.domains} # General Inform Condition for Naming - domains = [domint.split('_', 1)[0] for domint in action] + domains = [domint[0] for domint in action] domains = list(set([d for d in domains if d not in ['general']])) for domain in domains: contains_name = False if domain == 'none': raise NameError('Domain not defined') - cur_inform = domain + '_inform' - cur_request = domain + '_request' + cur_inform = (domain, "inform") + cur_request = (domain, "request") index = -1 if cur_inform in action: # Check if current inform within a domain is accompanied by a name inform diff --git a/convlab/policy/vector/vector_binary.py b/convlab/policy/vector/vector_binary.py index e780dc645043f4775b208479abf022dccce649a5..c6b02a1122adac3002ad1e32dd1f495046d0e6be 100755 --- a/convlab/policy/vector/vector_binary.py +++ b/convlab/policy/vector/vector_binary.py @@ -94,9 +94,10 @@ class VectorBinary(VectorBase): def vectorize_system_act(self, state): action = state['system_action'] if self.character == 'sys' else state['user_action'] action = delexicalize_da(action, self.requestable) - action = flat_da(action) + #action = flat_da(action) last_act_vec = np.zeros(self.da_dim) for da in action: + da = tuple(da) if da in self.act2vec: last_act_vec[self.act2vec[da]] = 1. return last_act_vec @@ -104,9 +105,10 @@ class VectorBinary(VectorBase): def vectorize_user_act(self, state): action = state['user_action'] if self.character == 'sys' else state['system_action'] opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) + #opp_action = flat_da(opp_action) opp_act_vec = np.zeros(self.da_opp_dim) for da in opp_action: + da = tuple(da) if da in self.opp2vec: prob = 1.0 opp_act_vec[self.opp2vec[da]] = prob diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py index 24b1c1045a55960949c4d5747c066fff7c5906e9..2c7712bc9d7df74d9ae35a36bd8bb9edd4886c60 100644 --- a/convlab/policy/vector/vector_nodes.py +++ b/convlab/policy/vector/vector_nodes.py @@ -116,11 +116,12 @@ class VectorNodes(VectorBase): feature_type = 'last system act' action = state['system_action'] if self.character == 'sys' else state['user_action'] action = delexicalize_da(action, self.requestable) - action = flat_da(action) + #action = flat_da(action) for da in action: + da = tuple(da) if da in self.act2vec: - domain = da.split('_')[0] - description = "system-" + da + domain = da[0] + description = "system-" + "_".join(da) value = 1.0 self.add_graph_node(domain, feature_type, description.lower(), value) @@ -129,12 +130,13 @@ class VectorNodes(VectorBase): feature_type = 'user act' action = state['user_action'] if self.character == 'sys' else state['system_action'] opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) + #opp_action = flat_da(opp_action) for da in opp_action: + da = tuple(da) if da in self.opp2vec: - domain = da.split('_')[0] - description = "user-" + da + domain = da[0] + description = "user-" + "_".join(da) value = 1.0 self.add_graph_node(domain, feature_type, description.lower(), value) diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py index afe8a5b89e2caac03217709f9d36632cbe3904c2..20bf9736b78ed75ae23741c141413aad749c8979 100644 --- a/convlab/policy/vector/vector_uncertainty.py +++ b/convlab/policy/vector/vector_uncertainty.py @@ -95,12 +95,13 @@ class VectorUncertainty(VectorBinary): self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None action = state['user_action'] if self.character == 'sys' else state['system_action'] opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) + #opp_action = flat_da(opp_action) opp_act_vec = np.zeros(self.da_opp_dim) for da in opp_action: + da = tuple(da) if da in self.opp2vec: if 'belief_state_probs' in state and self.use_confidence_scores: - domain, intent, slot, value = da.split('_') + domain, intent, slot, value = da if domain in state['belief_state_probs']: slot = slot if slot else 'none' if slot in state['belief_state_probs'][domain]: diff --git a/convlab/policy/vtrace_DPT/create_descriptions.py b/convlab/policy/vtrace_DPT/create_descriptions.py index c6e88daba8132dd30c0aaeeff23e6e2b619e1c92..138861262d20aa43901ef8774571d9babf1a94ed 100644 --- a/convlab/policy/vtrace_DPT/create_descriptions.py +++ b/convlab/policy/vtrace_DPT/create_descriptions.py @@ -20,14 +20,8 @@ def create_description_dicts(name='multiwoz21'): db = None db_domains = [] - root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - voc_file = os.path.join(root_dir, f'vector/action_dicts/{name}_VectorBinary/sys_da_voc.txt') - voc_opp_file = os.path.join(root_dir, f'vector/action_dicts/{name}_VectorBinary/user_da_voc.txt') - - with open(voc_file) as f: - da_voc = f.read().splitlines() - with open(voc_opp_file) as f: - da_voc_opp = f.read().splitlines() + da_voc = vector.da_voc + da_voc_opp = vector.da_voc_opp description_dict_semantic = {} @@ -47,13 +41,15 @@ def create_description_dicts(name='multiwoz21'): description_dict_semantic[f"general-{domain}"] = f"domain {domain}" for act in da_voc: - domain, intent, slot, value = act.split("_") + domain, intent, slot, value = act domain = domain.lower() + act = "_".join(act) description_dict_semantic["system-"+act.lower()] = f"last system act {domain} {intent} {slot} {value}" for act in da_voc_opp: - domain, intent, slot, value = [item.lower() for item in act.split("_")] + domain, intent, slot, value = [item.lower() for item in act] domain = domain.lower() + act = "_".join(act) description_dict_semantic["user-"+act.lower()] = f"user act {domain} {intent} {slot} {value}" root_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py index 8ec1d059388048be63be891806e9f73f724f531a..243ba344cba84b5a6f86288b5db45cf9acf9b695 100644 --- a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py +++ b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py @@ -24,7 +24,7 @@ class ActionEmbedder(nn.Module): = self.create_dicts(action_dict) #EOS token is considered a "domain" - self.action_dict = dict((key.lower(), value) for key, value in action_dict.items()) + self.action_dict = dict((key, value) for key, value in action_dict.items()) self.action_dict_reversed = dict((value, key) for key, value in self.action_dict.items()) self.embed_domain = torch.randn(len(self.domain_dict), embedding_dim) self.embed_intent = torch.randn(len(self.intent_dict), embedding_dim) @@ -88,19 +88,19 @@ class ActionEmbedder(nn.Module): elif not intent: # Domain was selected, check intents that are allowed for intent in self.intent_dict: - domain_intent = f"{domain}_{intent}" for idx, not_allow in enumerate(legal_mask): semantic_act = self.action_dict_reversed[idx] - if domain_intent in semantic_act and not_allow == 0: + if domain == semantic_act[0] and intent == semantic_act[1] and not_allow == 0: action_mask[self.small_action_dict[intent]] = 0 break else: # Selected domain and intent, need slot-value for slot_value in self.slot_value_dict: - domain_intent_slot = f"{domain}_{intent}_{slot_value}" + slot, value = slot_value for idx, not_allow in enumerate(legal_mask): semantic_act = self.action_dict_reversed[idx] - if domain_intent_slot in semantic_act and not_allow == 0: + if domain == semantic_act[0] and intent == semantic_act[1] \ + and slot == semantic_act[2] and value == semantic_act[3] and not_allow == 0: action_mask[self.small_action_dict[slot_value]] = 0 break @@ -128,14 +128,15 @@ class ActionEmbedder(nn.Module): elif not intent: # Domain was selected, need intent now for intent in self.intent_dict: - domain_intent = f"{domain}_{intent}" - valid = self.is_valid(domain_intent + "_") + domain_intent = (domain, intent) + valid = self.is_valid(domain_intent) if valid: action_mask[self.small_action_dict[intent]] = 0 else: # Selected domain and intent, need slot-value for slot_value in self.slot_value_dict: - domain_intent_slot = f"{domain}_{intent}_{slot_value}" + slot, value = slot_value + domain_intent_slot = (domain, intent, slot, value) valid = self.is_valid(domain_intent_slot) if valid: action_mask[self.small_action_dict[slot_value]] = 0 @@ -160,9 +161,8 @@ class ActionEmbedder(nn.Module): def is_valid(self, part_action): for act in self.action_dict: - if act.startswith(part_action): + if part_action == act[:len(part_action)]: return True - return False def create_action_embeddings(self, embedding_dim): @@ -178,7 +178,7 @@ class ActionEmbedder(nn.Module): action_embeddings[len(small_action_dict)] = self.embed_intent[idx] small_action_dict[intent] = len(small_action_dict) for slot_value in self.slot_value_dict: - slot, value = slot_value.split("_") + slot, value = slot_value slot_idx = self.slot_dict[slot] value_idx = self.value_dict[value] action_embeddings[len(small_action_dict)] = torch.cat( @@ -201,7 +201,7 @@ class ActionEmbedder(nn.Module): action_embeddings.append(intent) small_action_dict[intent] = len(small_action_dict) for slot_value in self.slot_value_dict: - slot, value = slot_value.split("_") + slot, value = slot_value action_embeddings.append(f"{slot} {value}") small_action_dict[slot_value] = len(small_action_dict) @@ -211,7 +211,7 @@ class ActionEmbedder(nn.Module): action_embeddings_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'action_embeddings_{self.dataset_name}.pt') small_action_dict_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), - f'small_action_dict_{self.dataset_name}.json') + f'small_action_dict_{self.dataset_name}.txt') if os.path.exists(action_embeddings_path): self.action_embeddings = torch.load(action_embeddings_path).to(DEVICE) @@ -220,11 +220,15 @@ class ActionEmbedder(nn.Module): torch.save(self.action_embeddings, action_embeddings_path) if os.path.exists(small_action_dict_path): - self.small_action_dict = json.load(open(small_action_dict_path, 'r')) + with open(os.path.join(small_action_dict_path)) as f: + self.small_action_dict = f.read().splitlines() + self.small_action_dict = [tuple(json.loads(act)) for act in self.small_action_dict] + self.small_action_dict = dict((name, idx) for idx, name in enumerate(self.small_action_dict)) else: self.small_action_dict = small_action_dict with open(small_action_dict_path, 'w') as f: - json.dump(self.small_action_dict, f) + for act in self.small_action_dict: + f.write(json.dumps(act) + "\n") self.small_action_dict = small_action_dict @@ -235,7 +239,7 @@ class ActionEmbedder(nn.Module): value_dict = {} slot_value_dict = {} for action in action_dict: - domain, intent, slot, value = [act.lower() for act in action.split('_')] + domain, intent, slot, value = [act.lower() for act in action] if domain not in domain_dict: domain_dict[domain] = len(domain_dict) if intent not in intent_dict: @@ -244,8 +248,8 @@ class ActionEmbedder(nn.Module): slot_dict[slot] = len(slot_dict) if value not in value_dict: value_dict[value] = len(value_dict) - if slot + "_" + value not in slot_value_dict: - slot_value_dict[slot + "_" + value] = len(slot_value_dict) + if (slot, value) not in slot_value_dict: + slot_value_dict[(slot, value)] = len(slot_value_dict) domain_dict['eos'] = len(domain_dict) @@ -255,17 +259,17 @@ class ActionEmbedder(nn.Module): #print("SMALL ACTION LIST:", small_action_list) action_vector = torch.zeros(len(self.action_dict)) - act_string = "" + act_list = [] for idx, act in enumerate(small_action_list): if act == 'eos': break if idx % 3 != 2: - act_string += f"{act}_" + act_list.append(act) else: - act_string += act - action_vector[self.action_dict[act_string]] = 1 - act_string = "" + act_list += list(act) + action_vector[self.action_dict[tuple(act_list)]] = 1 + act_list = [] return action_vector @@ -278,7 +282,8 @@ class ActionEmbedder(nn.Module): action_list = [] for idx, i in enumerate(action): if i == 1: - action_list += self.action_dict_reversed[idx].split("_", 2) + d, i, s, v = self.action_dict_reversed[idx] + action_list += [d, i, (s, v)] if permute and len(action_list) > 3: action_list_new = deepcopy(action_list[-3:]) + deepcopy(action_list[:-3]) diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt deleted file mode 100644 index 2f7e5e9cdf949b082b9ae752c09d2eafeb322bfb..0000000000000000000000000000000000000000 Binary files a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt and /dev/null differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_sgd.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_sgd.pt deleted file mode 100644 index 67e557ce5ef7a0bfd40c60ae5a03937b47b92de2..0000000000000000000000000000000000000000 Binary files a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_sgd.pt and /dev/null differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt deleted file mode 100644 index 5c6193163b6a1fc788b959eff2d0153b3a4bf7cb..0000000000000000000000000000000000000000 Binary files a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt and /dev/null differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_sgd.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_sgd.pt deleted file mode 100644 index 619824588654a36cab3bf795e6fe94527b04ba68..0000000000000000000000000000000000000000 Binary files a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_sgd.pt and /dev/null differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict.json deleted file mode 100644 index 0d5bd2002fa7d0082e7589b80ae3664781732ece..0000000000000000000000000000000000000000 --- a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict.json +++ /dev/null @@ -1 +0,0 @@ -{"attraction": 0, "general": 1, "hospital": 2, "hotel": 3, "police": 4, "restaurant": 5, "taxi": 6, "train": 7, "eos": 8, "inform": 9, "nooffer": 10, "recommend": 11, "request": 12, "select": 13, "bye": 14, "greet": 15, "reqmore": 16, "welcome": 17, "book": 18, "offerbook": 19, "nobook": 20, "address-1": 21, "address-2": 22, "address-3": 23, "area-1": 24, "area-2": 25, "area-3": 26, "choice-1": 27, "choice-2": 28, "choice-3": 29, "entrance fee-1": 30, "entrance fee-2": 31, "name-1": 32, "name-2": 33, "name-3": 34, "name-4": 35, "phone-1": 36, "postcode-1": 37, "type-1": 38, "type-2": 39, "type-3": 40, "type-4": 41, "type-5": 42, "none-none": 43, "area-?": 44, "entrance fee-?": 45, "name-?": 46, "type-?": 47, "department-1": 48, "department-?": 49, "book day-1": 50, "book people-1": 51, "book stay-1": 52, "internet-1": 53, "parking-1": 54, "price range-1": 55, "price range-2": 56, "ref-1": 57, "stars-1": 58, "stars-2": 59, "book day-?": 60, "book people-?": 61, "book stay-?": 62, "internet-?": 63, "parking-?": 64, "price range-?": 65, "stars-?": 66, "book time-1": 67, "food-1": 68, "food-2": 69, "food-3": 70, "food-4": 71, "postcode-2": 72, "book time-?": 73, "food-?": 74, "arrive by-1": 75, "departure-1": 76, "destination-1": 77, "leave at-1": 78, "arrive by-?": 79, "departure-?": 80, "destination-?": 81, "leave at-?": 82, "arrive by-2": 83, "day-1": 84, "duration-1": 85, "leave at-2": 86, "leave at-3": 87, "price-1": 88, "train id-1": 89, "day-?": 90, "pad": 91} \ No newline at end of file diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json deleted file mode 100644 index 0d4ec8d8a388236d6f65cf9b9c2c28560791b818..0000000000000000000000000000000000000000 --- a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json +++ /dev/null @@ -1 +0,0 @@ -{"attraction": 0, "general": 1, "hospital": 2, "hotel": 3, "police": 4, "restaurant": 5, "taxi": 6, "train": 7, "eos": 8, "inform": 9, "nooffer": 10, "recommend": 11, "request": 12, "select": 13, "bye": 14, "greet": 15, "reqmore": 16, "welcome": 17, "book": 18, "offerbook": 19, "nobook": 20, "address_1": 21, "address_2": 22, "address_3": 23, "area_1": 24, "area_2": 25, "area_3": 26, "choice_1": 27, "choice_2": 28, "choice_3": 29, "entrance fee_1": 30, "entrance fee_2": 31, "name_1": 32, "name_2": 33, "name_3": 34, "name_4": 35, "phone_1": 36, "postcode_1": 37, "type_1": 38, "type_2": 39, "type_3": 40, "type_4": 41, "type_5": 42, "none_none": 43, "area_?": 44, "entrance fee_?": 45, "name_?": 46, "type_?": 47, "department_1": 48, "department_?": 49, "book day_1": 50, "book people_1": 51, "book stay_1": 52, "internet_1": 53, "parking_1": 54, "price range_1": 55, "price range_2": 56, "ref_1": 57, "stars_1": 58, "stars_2": 59, "book day_?": 60, "book people_?": 61, "book stay_?": 62, "internet_?": 63, "parking_?": 64, "price range_?": 65, "stars_?": 66, "book time_1": 67, "food_1": 68, "food_2": 69, "food_3": 70, "food_4": 71, "postcode_2": 72, "book time_?": 73, "food_?": 74, "arrive by_1": 75, "departure_1": 76, "destination_1": 77, "leave at_1": 78, "arrive by_?": 79, "departure_?": 80, "destination_?": 81, "leave at_?": 82, "arrive by_2": 83, "day_1": 84, "duration_1": 85, "leave at_2": 86, "leave at_3": 87, "price_1": 88, "train id_1": 89, "day_?": 90, "pad": 91} \ No newline at end of file diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json deleted file mode 100644 index c48573262f50cabe1d4cd3811638ebf4c046a7e0..0000000000000000000000000000000000000000 --- a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json +++ /dev/null @@ -1 +0,0 @@ -{"": 0, "alarm_1": 1, "banks_1": 2, "banks_2": 3, "buses_1": 4, "buses_2": 5, "buses_3": 6, "calendar_1": 7, "events_1": 8, "events_2": 9, "events_3": 10, "flights_1": 11, "flights_2": 12, "flights_3": 13, "flights_4": 14, "homes_1": 15, "homes_2": 16, "hotels_1": 17, "hotels_2": 18, "hotels_3": 19, "hotels_4": 20, "media_1": 21, "media_2": 22, "media_3": 23, "messaging_1": 24, "movies_1": 25, "movies_2": 26, "movies_3": 27, "music_1": 28, "music_2": 29, "music_3": 30, "payment_1": 31, "rentalcars_1": 32, "rentalcars_2": 33, "rentalcars_3": 34, "restaurants_1": 35, "restaurants_2": 36, "ridesharing_1": 37, "ridesharing_2": 38, "services_1": 39, "services_2": 40, "services_3": 41, "services_4": 42, "trains_1": 43, "travel_1": 44, "weather_1": 45, "eos": 46, "goodbye": 47, "req_more": 48, "confirm": 49, "inform_count": 50, "notify_success": 51, "offer": 52, "offer_intent": 53, "request": 54, "inform": 55, "notify_failure": 56, "none-none": 57, "new_alarm_name-1": 58, "new_alarm_time-1": 59, "count-1": 60, "alarm_name-1": 61, "alarm_time-1": 62, "addalarm-1": 63, "new_alarm_time-?": 64, "account_type-1": 65, "amount-1": 66, "recipient_account_name-1": 67, "recipient_account_type-1": 68, "balance-1": 69, "transfermoney-1": 70, "account_type-?": 71, "amount-?": 72, "recipient_account_name-?": 73, "recipient_name-1": 74, "transfer_amount-1": 75, "transfer_time-1": 76, "account_balance-1": 77, "recipient_name-?": 78, "transfer_amount-?": 79, "from_location-1": 80, "leaving_date-1": 81, "leaving_time-1": 82, "to_location-1": 83, "travelers-1": 84, "from_station-1": 85, "to_station-1": 86, "transfers-1": 87, "fare-1": 88, "buybusticket-1": 89, "from_location-?": 90, "leaving_date-?": 91, "leaving_time-?": 92, "to_location-?": 93, "travelers-?": 94, "departure_date-1": 95, "departure_time-1": 96, "destination-1": 97, "fare_type-1": 98, "group_size-1": 99, "origin-1": 100, "destination_station_name-1": 101, "origin_station_name-1": 102, "price-1": 103, "departure_date-?": 104, "departure_time-?": 105, "destination-?": 106, "group_size-?": 107, "origin-?": 108, "additional_luggage-1": 109, "from_city-1": 110, "num_passengers-1": 111, "to_city-1": 112, "category-1": 113, "from_city-?": 114, "num_passengers-?": 115, "to_city-?": 116, "event_date-1": 117, "event_location-1": 118, "event_name-1": 119, "event_time-1": 120, "available_end_time-1": 121, "available_start_time-1": 122, "addevent-1": 123, "event_date-?": 124, "event_location-?": 125, "event_name-?": 126, "event_time-?": 127, "city_of_event-1": 128, "date-1": 129, "number_of_seats-1": 130, "address_of_location-1": 131, "subcategory-1": 132, "time-1": 133, "buyeventtickets-1": 134, "category-?": 135, "city_of_event-?": 136, "date-?": 137, "number_of_seats-?": 138, "city-1": 139, "number_of_tickets-1": 140, "venue-1": 141, "venue_address-1": 142, "city-?": 143, "event_type-?": 144, "number_of_tickets-?": 145, "price_per_ticket-1": 146, "airlines-1": 147, "destination_city-1": 148, "inbound_departure_time-1": 149, "origin_city-1": 150, "outbound_departure_time-1": 151, "passengers-1": 152, "return_date-1": 153, "seating_class-1": 154, "destination_airport-1": 155, "inbound_arrival_time-1": 156, "number_stops-1": 157, "origin_airport-1": 158, "outbound_arrival_time-1": 159, "refundable-1": 160, "reserveonewayflight-1": 161, "reserveroundtripflights-1": 162, "airlines-?": 163, "destination_city-?": 164, "inbound_departure_time-?": 165, "origin_city-?": 166, "outbound_departure_time-?": 167, "return_date-?": 168, "is_redeye-1": 169, "arrives_next_day-1": 170, "destination_airport_name-1": 171, "origin_airport_name-1": 172, "is_nonstop-1": 173, "destination_airport-?": 174, "origin_airport-?": 175, "property_name-1": 176, "visit_date-1": 177, "furnished-1": 178, "pets_allowed-1": 179, "phone_number-1": 180, "address-1": 181, "number_of_baths-1": 182, "number_of_beds-1": 183, "rent-1": 184, "schedulevisit-1": 185, "area-?": 186, "number_of_beds-?": 187, "visit_date-?": 188, "has_garage-1": 189, "in_unit_laundry-1": 190, "intent-?": 191, "number_of_baths-?": 192, "check_in_date-1": 193, "hotel_name-1": 194, "number_of_days-1": 195, "number_of_rooms-1": 196, "has_wifi-1": 197, "price_per_night-1": 198, "street_address-1": 199, "star_rating-1": 200, "reservehotel-1": 201, "check_in_date-?": 202, "hotel_name-?": 203, "number_of_days-?": 204, "check_out_date-1": 205, "number_of_adults-1": 206, "where_to-1": 207, "has_laundry_service-1": 208, "total_price-1": 209, "rating-1": 210, "bookhouse-1": 211, "check_out_date-?": 212, "number_of_adults-?": 213, "where_to-?": 214, "location-1": 215, "pets_welcome-1": 216, "average_rating-1": 217, "location-?": 218, "place_name-1": 219, "stay_length-1": 220, "smoking_allowed-1": 221, "stay_length-?": 222, "subtitles-1": 223, "title-1": 224, "directed_by-1": 225, "genre-1": 226, "title-2": 227, "title-3": 228, "playmovie-1": 229, "genre-?": 230, "title-?": 231, "movie_name-1": 232, "subtitle_language-1": 233, "movie_name-2": 234, "movie_name-3": 235, "rentmovie-1": 236, "starring-1": 237, "contact_name-1": 238, "contact_name-?": 239, "show_date-1": 240, "show_time-1": 241, "show_type-1": 242, "theater_name-1": 243, "buymovietickets-1": 244, "movie_name-?": 245, "show_date-?": 246, "show_time-?": 247, "show_type-?": 248, "aggregate_rating-1": 249, "cast-1": 250, "movie_title-1": 251, "percent_rating-1": 252, "playback_device-1": 253, "song_name-1": 254, "album-1": 255, "year-1": 256, "artist-1": 257, "playsong-1": 258, "song_name-?": 259, "playmedia-1": 260, "device-1": 261, "track-1": 262, "payment_method-1": 263, "private_visibility-1": 264, "receiver-1": 265, "payment_method-?": 266, "receiver-?": 267, "dropoff_date-1": 268, "pickup_date-1": 269, "pickup_location-1": 270, "pickup_time-1": 271, "type-1": 272, "car_name-1": 273, "reservecar-1": 274, "dropoff_date-?": 275, "pickup_city-?": 276, "pickup_date-?": 277, "pickup_location-?": 278, "pickup_time-?": 279, "type-?": 280, "car_type-1": 281, "car_type-?": 282, "add_insurance-1": 283, "end_date-1": 284, "start_date-1": 285, "price_per_day-1": 286, "add_insurance-?": 287, "end_date-?": 288, "start_date-?": 289, "party_size-1": 290, "restaurant_name-1": 291, "cuisine-1": 292, "has_live_music-1": 293, "price_range-1": 294, "serves_alcohol-1": 295, "reserverestaurant-1": 296, "cuisine-?": 297, "restaurant_name-?": 298, "time-?": 299, "has_seating_outdoors-1": 300, "has_vegetarian_options-1": 301, "number_of_riders-1": 302, "shared_ride-1": 303, "approximate_ride_duration-1": 304, "ride_fare-1": 305, "number_of_riders-?": 306, "shared_ride-?": 307, "ride_type-1": 308, "wait_time-1": 309, "ride_type-?": 310, "appointment_date-1": 311, "appointment_time-1": 312, "stylist_name-1": 313, "is_unisex-1": 314, "bookappointment-1": 315, "appointment_date-?": 316, "appointment_time-?": 317, "dentist_name-1": 318, "offers_cosmetic_services-1": 319, "doctor_name-1": 320, "therapist_name-1": 321, "class-1": 322, "date_of_journey-1": 323, "from-1": 324, "journey_start_time-1": 325, "to-1": 326, "trip_protection-1": 327, "total-1": 328, "gettraintickets-1": 329, "date_of_journey-?": 330, "from-?": 331, "to-?": 332, "trip_protection-?": 333, "free_entry-1": 334, "good_for_kids-1": 335, "attraction_name-1": 336, "humidity-1": 337, "wind-1": 338, "precipitation-1": 339, "temperature-1": 340, "pad": 341} \ No newline at end of file diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py index a8df10672e5fe6e1ea8631e632d71bd0c2c7ba51..1e5f7ce69eb046e391d752533d024cf9b766ff66 100755 --- a/convlab/util/multiwoz/lexicalize.py +++ b/convlab/util/multiwoz/lexicalize.py @@ -34,8 +34,8 @@ def deflat_da(meta): meta = deepcopy(meta) dialog_act = {} for da in meta: - d, i, s, v = da.split('_') - k = '_'.join((d, i)) + d, i, s, v = da + k = (d, i) if k not in dialog_act: dialog_act[k] = [] dialog_act[k].append([s, v]) @@ -45,7 +45,7 @@ def deflat_da(meta): def lexicalize_da(meta, entities, state, requestable): meta = deepcopy(meta) for k, v in meta.items(): - domain, intent = k.split('_') + domain, intent = k if domain in ['general']: continue elif intent in requestable: @@ -99,6 +99,6 @@ def lexicalize_da(meta, entities, state, requestable): tuples = [] for domain_intent, svs in meta.items(): for slot, value in svs: - domain, intent = domain_intent.split('_') + domain, intent = domain_intent tuples.append([intent, domain, slot, value]) return tuples