diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index d85cca1086596f2e3a253a0057320b35fc1fcb03..566fd718c9726065680fb28ef81443a0af10e7cf 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -19,7 +19,7 @@ sys.path.append(root_dir) class VectorBase(Vector): def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=False, - always_inform_booking_reference=True, seed=0): + always_inform_booking_reference=True, seed=0, use_none=True): super().__init__() @@ -46,6 +46,7 @@ class VectorBase(Vector): self.always_inform_booking_reference = always_inform_booking_reference self.reqinfo_filler_action = None self.character = character + self.use_none = use_none self.requestable = ['request'] self.informable = ['inform', 'recommend'] @@ -103,14 +104,14 @@ class VectorBase(Vector): if turn['speaker'] == 'system': for act in delex_acts: - act = "-".join(act) + act = "_".join(act) if act not in system_dict: system_dict[act] = 1 else: system_dict[act] += 1 else: for act in delex_acts: - act = "-".join(act) + act = "_".join(act) if act not in user_dict: user_dict[act] = 1 else: @@ -206,18 +207,18 @@ class VectorBase(Vector): ''' if intent == 'request': - return [f"{domain}-{intent}-{slot}-?"] + return [f"{domain}_{intent}_{slot}_?"] if slot == '': - return [f"{domain}-{intent}-none-none"] + return [f"{domain}_{intent}_none_none"] if system: if intent in ['recommend', 'select', 'inform']: - return [f"{domain}-{intent}-{slot}-{i}" for i in range(1, 4)] + return [f"{domain}_{intent}_{slot}_{i}" for i in range(1, 4)] else: - return [f"{domain}-{intent}-{slot}-1"] + return [f"{domain}_{intent}_{slot}_1"] else: - return [f"{domain}-{intent}-{slot}-1"] + return [f"{domain}_{intent}_{slot}_1"] def init_domain_active_dict(self): domain_active_dict = {} @@ -239,7 +240,7 @@ class VectorBase(Vector): for i in range(self.da_dim): action = self.vec2act[i] - action_domain = action.split('-')[0] + action_domain = action.split('_')[0] if action_domain in domain_active_dict.keys(): if not domain_active_dict[action_domain]: mask_list[i] = 1.0 @@ -252,7 +253,7 @@ class VectorBase(Vector): for i in range(self.da_dim): action = self.vec2act[i] - domain, intent, slot, value = action.split('-') + domain, intent, slot, value = action.split('_') # NoBook/NoOffer-SLOT does not make sense because policy can not know which constraint made offer impossible # If one wants to do it, lexicaliser needs to do it @@ -280,7 +281,7 @@ class VectorBase(Vector): return mask_list for i in range(self.da_dim): action = self.vec2act[i] - domain, intent, slot, value = action.split('-') + domain, intent, slot, value = action.split('_') domain_entities = number_entities_dict.get(domain, 1) if intent in ['inform', 'select', 'recommend'] and value != None and value != 'none': @@ -372,21 +373,21 @@ class VectorBase(Vector): if len(act_array) == 0: if self.reqinfo_filler_action: - act_array.append('general-reqinfo-none-none') + act_array.append('general_reqinfo_none_none') else: - act_array.append('general-reqmore-none-none') + act_array.append('general_reqmore_none_none') action = deflat_da(act_array) entities = {} for domint in action: - domain, intent = domint.split('-') + domain, intent = domint.split('_') if domain not in entities and domain not in ['general']: entities[domain] = self.dbquery_domain(domain) # From db query find which slot causes no_offer nooffer = [domint for domint in action if 'nooffer' in domint] for domint in nooffer: - domain, intent = domint.split('-') + domain, intent = domint.split('_') slot = self.find_nooffer_slot(domain) action[domint] = [[slot, '1'] ] if slot != 'none' else [[slot, 'none']] @@ -394,7 +395,7 @@ class VectorBase(Vector): # Randomly select booking constraint "causing" no_book nobook = [domint for domint in action if 'nobook' in domint] for domint in nobook: - domain, intent = domint.split('-') + domain, intent = domint.split('_') if domain in self.state: slots = self.state[domain] slots = [slot for slot, i in slots.items() @@ -424,15 +425,19 @@ class VectorBase(Vector): index = idx action = lexicalize_da(action, entities, self.state, self.requestable) + if not self.use_none: + # replace all occurences of "none" with an empty string "" + action = [[a_string.replace('none', '') for a_string in a_list] for a_list in action] + return action def add_booking_reference(self, action): new_acts = {} for domint in action: - domain, intent = domint.split('-', 1) + domain, intent = domint.split('_', 1) if intent == 'book' and action[domint]: - ref_domint = f'{domain}-inform' + ref_domint = f'{domain}_inform' if ref_domint not in new_acts: new_acts[ref_domint] = [] new_acts[ref_domint].append(['ref', '1']) @@ -450,14 +455,14 @@ class VectorBase(Vector): name_inform = {domain: [] for domain in self.domains} # General Inform Condition for Naming - domains = [domint.split('-', 1)[0] for domint in action] + domains = [domint.split('_', 1)[0] for domint in action] domains = list(set([d for d in domains if d not in ['general']])) for domain in domains: contains_name = False if domain == 'none': raise NameError('Domain not defined') - cur_inform = domain + '-inform' - cur_request = domain + '-request' + cur_inform = domain + '_inform' + cur_request = domain + '_request' index = -1 if cur_inform in action: # Check if current inform within a domain is accompanied by a name inform diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py index 2e073669effc518cee4efd1f03d25bbd501b65af..24b1c1045a55960949c4d5747c066fff7c5906e9 100644 --- a/convlab/policy/vector/vector_nodes.py +++ b/convlab/policy/vector/vector_nodes.py @@ -119,7 +119,7 @@ class VectorNodes(VectorBase): action = flat_da(action) for da in action: if da in self.act2vec: - domain = da.split('-')[0] + domain = da.split('_')[0] description = "system-" + da value = 1.0 self.add_graph_node(domain, feature_type, description.lower(), value) @@ -133,7 +133,7 @@ class VectorNodes(VectorBase): for da in opp_action: if da in self.opp2vec: - domain = da.split('-')[0] + domain = da.split('_')[0] description = "user-" + da value = 1.0 self.add_graph_node(domain, feature_type, description.lower(), value) diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py index 7da05449dbe247cacf01bac4970ee669cd670c44..afe8a5b89e2caac03217709f9d36632cbe3904c2 100644 --- a/convlab/policy/vector/vector_uncertainty.py +++ b/convlab/policy/vector/vector_uncertainty.py @@ -100,7 +100,7 @@ class VectorUncertainty(VectorBinary): for da in opp_action: if da in self.opp2vec: if 'belief_state_probs' in state and self.use_confidence_scores: - domain, intent, slot, value = da.split('-') + domain, intent, slot, value = da.split('_') if domain in state['belief_state_probs']: slot = slot if slot else 'none' if slot in state['belief_state_probs'][domain]: diff --git a/convlab/policy/vtrace_DPT/create_descriptions.py b/convlab/policy/vtrace_DPT/create_descriptions.py index 5357fafafcf4972aa8eebaecf8c56e9fba79afe7..fa5f97313bc44e8208d5857f35a71b20e9ad90bd 100644 --- a/convlab/policy/vtrace_DPT/create_descriptions.py +++ b/convlab/policy/vtrace_DPT/create_descriptions.py @@ -47,12 +47,12 @@ def create_description_dicts(name='multiwoz21'): description_dict_semantic[f"general-{domain}"] = f"domain {domain}" for act in da_voc: - domain, intent, slot, value = act.split("-") + domain, intent, slot, value = act.split("_") domain = domain.lower() description_dict_semantic["system-"+act.lower()] = f"last system act {domain} {intent} {slot} {value}" for act in da_voc_opp: - domain, intent, slot, value = [item.lower() for item in act.split("-")] + domain, intent, slot, value = [item.lower() for item in act.split("_")] domain = domain.lower() description_dict_semantic["user-"+act.lower()] = f"user act {domain} {intent} {slot} {value}" diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py index 283443eb2daf55c56427e6de01f9cf6a2b4049cf..8ec1d059388048be63be891806e9f73f724f531a 100644 --- a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py +++ b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py @@ -88,7 +88,7 @@ class ActionEmbedder(nn.Module): elif not intent: # Domain was selected, check intents that are allowed for intent in self.intent_dict: - domain_intent = f"{domain}-{intent}" + domain_intent = f"{domain}_{intent}" for idx, not_allow in enumerate(legal_mask): semantic_act = self.action_dict_reversed[idx] if domain_intent in semantic_act and not_allow == 0: @@ -97,7 +97,7 @@ class ActionEmbedder(nn.Module): else: # Selected domain and intent, need slot-value for slot_value in self.slot_value_dict: - domain_intent_slot = f"{domain}-{intent}-{slot_value}" + domain_intent_slot = f"{domain}_{intent}_{slot_value}" for idx, not_allow in enumerate(legal_mask): semantic_act = self.action_dict_reversed[idx] if domain_intent_slot in semantic_act and not_allow == 0: @@ -128,14 +128,14 @@ class ActionEmbedder(nn.Module): elif not intent: # Domain was selected, need intent now for intent in self.intent_dict: - domain_intent = f"{domain}-{intent}" - valid = self.is_valid(domain_intent + "-") + domain_intent = f"{domain}_{intent}" + valid = self.is_valid(domain_intent + "_") if valid: action_mask[self.small_action_dict[intent]] = 0 else: # Selected domain and intent, need slot-value for slot_value in self.slot_value_dict: - domain_intent_slot = f"{domain}-{intent}-{slot_value}" + domain_intent_slot = f"{domain}_{intent}_{slot_value}" valid = self.is_valid(domain_intent_slot) if valid: action_mask[self.small_action_dict[slot_value]] = 0 @@ -178,7 +178,7 @@ class ActionEmbedder(nn.Module): action_embeddings[len(small_action_dict)] = self.embed_intent[idx] small_action_dict[intent] = len(small_action_dict) for slot_value in self.slot_value_dict: - slot, value = slot_value.split("-") + slot, value = slot_value.split("_") slot_idx = self.slot_dict[slot] value_idx = self.value_dict[value] action_embeddings[len(small_action_dict)] = torch.cat( @@ -201,7 +201,7 @@ class ActionEmbedder(nn.Module): action_embeddings.append(intent) small_action_dict[intent] = len(small_action_dict) for slot_value in self.slot_value_dict: - slot, value = slot_value.split("-") + slot, value = slot_value.split("_") action_embeddings.append(f"{slot} {value}") small_action_dict[slot_value] = len(small_action_dict) @@ -235,7 +235,7 @@ class ActionEmbedder(nn.Module): value_dict = {} slot_value_dict = {} for action in action_dict: - domain, intent, slot, value = [act.lower() for act in action.split('-')] + domain, intent, slot, value = [act.lower() for act in action.split('_')] if domain not in domain_dict: domain_dict[domain] = len(domain_dict) if intent not in intent_dict: @@ -244,8 +244,8 @@ class ActionEmbedder(nn.Module): slot_dict[slot] = len(slot_dict) if value not in value_dict: value_dict[value] = len(value_dict) - if slot + "-" + value not in slot_value_dict: - slot_value_dict[slot + "-" + value] = len(slot_value_dict) + if slot + "_" + value not in slot_value_dict: + slot_value_dict[slot + "_" + value] = len(slot_value_dict) domain_dict['eos'] = len(domain_dict) @@ -261,7 +261,7 @@ class ActionEmbedder(nn.Module): break if idx % 3 != 2: - act_string += f"{act}-" + act_string += f"{act}_" else: act_string += act action_vector[self.action_dict[act_string]] = 1 @@ -278,7 +278,7 @@ class ActionEmbedder(nn.Module): action_list = [] for idx, i in enumerate(action): if i == 1: - action_list += self.action_dict_reversed[idx].split("-", 2) + action_list += self.action_dict_reversed[idx].split("_", 2) if permute and len(action_list) > 3: action_list_new = deepcopy(action_list[-3:]) + deepcopy(action_list[:-3]) diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt deleted file mode 100644 index 2f7e5e9cdf949b082b9ae752c09d2eafeb322bfb..0000000000000000000000000000000000000000 Binary files a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt and /dev/null differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt deleted file mode 100644 index 5c6193163b6a1fc788b959eff2d0153b3a4bf7cb..0000000000000000000000000000000000000000 Binary files a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt and /dev/null differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json deleted file mode 100644 index 0d5bd2002fa7d0082e7589b80ae3664781732ece..0000000000000000000000000000000000000000 --- a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json +++ /dev/null @@ -1 +0,0 @@ -{"attraction": 0, "general": 1, "hospital": 2, "hotel": 3, "police": 4, "restaurant": 5, "taxi": 6, "train": 7, "eos": 8, "inform": 9, "nooffer": 10, "recommend": 11, "request": 12, "select": 13, "bye": 14, "greet": 15, "reqmore": 16, "welcome": 17, "book": 18, "offerbook": 19, "nobook": 20, "address-1": 21, "address-2": 22, "address-3": 23, "area-1": 24, "area-2": 25, "area-3": 26, "choice-1": 27, "choice-2": 28, "choice-3": 29, "entrance fee-1": 30, "entrance fee-2": 31, "name-1": 32, "name-2": 33, "name-3": 34, "name-4": 35, "phone-1": 36, "postcode-1": 37, "type-1": 38, "type-2": 39, "type-3": 40, "type-4": 41, "type-5": 42, "none-none": 43, "area-?": 44, "entrance fee-?": 45, "name-?": 46, "type-?": 47, "department-1": 48, "department-?": 49, "book day-1": 50, "book people-1": 51, "book stay-1": 52, "internet-1": 53, "parking-1": 54, "price range-1": 55, "price range-2": 56, "ref-1": 57, "stars-1": 58, "stars-2": 59, "book day-?": 60, "book people-?": 61, "book stay-?": 62, "internet-?": 63, "parking-?": 64, "price range-?": 65, "stars-?": 66, "book time-1": 67, "food-1": 68, "food-2": 69, "food-3": 70, "food-4": 71, "postcode-2": 72, "book time-?": 73, "food-?": 74, "arrive by-1": 75, "departure-1": 76, "destination-1": 77, "leave at-1": 78, "arrive by-?": 79, "departure-?": 80, "destination-?": 81, "leave at-?": 82, "arrive by-2": 83, "day-1": 84, "duration-1": 85, "leave at-2": 86, "leave at-3": 87, "price-1": 88, "train id-1": 89, "day-?": 90, "pad": 91} \ No newline at end of file diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py index 4fd9f262ece3fe38935c26bc449b1e93121db498..2fe3f299ca5a4fa8e194c29193f5caa926691c37 100755 --- a/convlab/util/multiwoz/lexicalize.py +++ b/convlab/util/multiwoz/lexicalize.py @@ -17,7 +17,7 @@ def delexicalize_da(da, requestable): if slot == 'none': v = 'none' else: - k = '-'.join([intent, domain, slot]) + k = '_'.join([intent, domain, slot]) counter.setdefault(k, 0) counter[k] += 1 v = str(counter[k]) @@ -26,7 +26,7 @@ def delexicalize_da(da, requestable): def flat_da(delexicalized_da): - flaten = ['-'.join(x) for x in delexicalized_da] + flaten = ['_'.join(x) for x in delexicalized_da] return flaten @@ -34,8 +34,8 @@ def deflat_da(meta): meta = deepcopy(meta) dialog_act = {} for da in meta: - d, i, s, v = da.split('-') - k = '-'.join((d, i)) + d, i, s, v = da.split('_') + k = '_'.join((d, i)) if k not in dialog_act: dialog_act[k] = [] dialog_act[k].append([s, v]) @@ -45,7 +45,7 @@ def deflat_da(meta): def lexicalize_da(meta, entities, state, requestable): meta = deepcopy(meta) for k, v in meta.items(): - domain, intent = k.split('-') + domain, intent = k.split('_') if domain in ['general']: continue elif intent in requestable: @@ -96,6 +96,6 @@ def lexicalize_da(meta, entities, state, requestable): tuples = [] for domain_intent, svs in meta.items(): for slot, value in svs: - domain, intent = domain_intent.split('-') + domain, intent = domain_intent.split('_') tuples.append([intent, domain, slot, value]) return tuples