diff --git a/convlab/policy/README.md b/convlab/policy/README.md index 6cdac456c14b50618ba8a8edc5ee947a6627d43a..233ea0eece10d0789791e3a97c5d1aa8311f232a 100755 --- a/convlab/policy/README.md +++ b/convlab/policy/README.md @@ -43,7 +43,7 @@ The necessary step before starting a training is to set up the environment and p ``` { "model": { - "load_path": "", # specify a loading path to load a pre-trained model + "load_path": "", # path of a pre-trained model to load; omit the trailing .pol.mdl extension "use_pretrained_initialisation": false, # will download a provided ConvLab-3 model "pretrained_load_path": "", "seed": 0, # the seed for the experiment diff --git a/convlab/policy/mle/README.md b/convlab/policy/mle/README.md index db62faaef6855ba92fc449e6d19556269d9aa971..c13140497508d79ec6faedfc588fa4f6af7043f7 100644 --- a/convlab/policy/mle/README.md +++ b/convlab/policy/mle/README.md @@ -14,6 +14,9 @@ The dataset name can be "multiwoz21" or "sgd" for instance. The first time you r Other hyperparameters such as learning rate or number of epochs can be set in the config.json file. +We provide a model trained on multiwoz21 on Hugging Face: https://huggingface.co/ConvLab/mle-policy-multiwoz21 + + ## Evaluation Evaluation on the validation data set takes place during training. 
\ No newline at end of file diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py index c2f6258f48dfc7b27f2bce6c17b7c6e1f92e7705..2e073669effc518cee4efd1f03d25bbd501b65af 100644 --- a/convlab/policy/vector/vector_nodes.py +++ b/convlab/policy/vector/vector_nodes.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- import sys import numpy as np +import logging + from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da from .vector_base import VectorBase @@ -8,9 +10,11 @@ from .vector_base import VectorBase class VectorNodes(VectorBase): def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True, - seed=0): + seed=0, filter_state=True): super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed) + self.filter_state = filter_state + logging.info(f"We filter state by active domains: {self.filter_state}") def get_state_dim(self): self.belief_state_dim = 0 @@ -56,9 +60,16 @@ class VectorNodes(VectorBase): self.get_user_act_feature(state) self.get_sys_act_feature(state) domain_active_dict = self.get_user_goal_feature(state, domain_active_dict) - number_entities_dict = self.get_db_features() self.get_general_features(state, domain_active_dict) + if self.db is not None: + number_entities_dict = self.get_db_features() + else: + number_entities_dict = None + + if self.filter_state: + self.kg_info = self.filter_inactive_domains(domain_active_dict) + if self.use_mask: mask = self.get_mask(domain_active_dict, number_entities_dict) for i in range(self.da_dim): @@ -89,13 +100,15 @@ class VectorNodes(VectorBase): feature_type = 'user goal' for domain in self.belief_domains: - for slot, value in state['belief_state'][domain].items(): - description = f"user goal-{domain}-{slot}".lower() - value = 1.0 if (value and value != "not mentioned") else 0.0 - self.add_graph_node(domain, feature_type, description, value) - - if [slot for slot, value in 
state['belief_state'][domain].items() if value]: - domain_active_dict[domain] = True + # the if case is needed because SGD only saves the dialogue state info for active domains + if domain in state['belief_state']: + for slot, value in state['belief_state'][domain].items(): + description = f"user goal-{domain}-{slot}".lower() + value = 1.0 if (value and value != "not mentioned") else 0.0 + self.add_graph_node(domain, feature_type, description, value) + + if [slot for slot, value in state['belief_state'][domain].items() if value]: + domain_active_dict[domain] = True return domain_active_dict def get_sys_act_feature(self, state): @@ -128,11 +141,12 @@ class VectorNodes(VectorBase): def get_general_features(self, state, domain_active_dict): feature_type = 'general' - for i, domain in enumerate(self.db_domains): - if domain in state['booked']: - description = f"general-{domain}-booked".lower() - value = 1.0 if state['booked'][domain] else 0.0 - self.add_graph_node(domain, feature_type, description, value) + if 'booked' in state: + for i, domain in enumerate(self.db_domains): + if domain in state['booked']: + description = f"general-{domain}-booked".lower() + value = 1.0 if state['booked'][domain] else 0.0 + self.add_graph_node(domain, feature_type, description, value) for domain in self.domains: if domain == 'general': @@ -140,3 +154,17 @@ class VectorNodes(VectorBase): value = 1.0 if domain_active_dict[domain] else 0 description = f"general-{domain}".lower() self.add_graph_node(domain, feature_type, description, value) + + def filter_inactive_domains(self, domain_active_dict): + + kg_filtered = [] + for node in self.kg_info: + domain = node['domain'] + if domain in domain_active_dict: + if domain_active_dict[domain]: + kg_filtered.append(node) + else: + kg_filtered.append(node) + + return kg_filtered + diff --git a/convlab/policy/vtrace_DPT/README.md b/convlab/policy/vtrace_DPT/README.md index 
6dcee257bd655eb1e873fb5945de0f226d469210..002a8a050cc8bf573761a1b5ba2276d844a6db7d 100644 --- a/convlab/policy/vtrace_DPT/README.md +++ b/convlab/policy/vtrace_DPT/README.md @@ -20,7 +20,11 @@ You can specify the dataset that you would like to use, e.g. "multiwoz21" or "sg You can specify hyperparamters such as epoch, supervised_lr and data_percentage (how much of the data you want to use) in the config.json file. +We provide several models trained with supervised learning on Hugging Face to reproduce the results: +- pre-trained on SGD: https://huggingface.co/ConvLab/ddpt-policy-sgd +- pre-trained on 1% multiwoz21: https://huggingface.co/ConvLab/ddpt-policy-0.01multiwoz21 +- pre-trained on SGD and afterwards on 1% multiwoz21: https://huggingface.co/ConvLab/ddpt-policy-sgd_0.01multiwoz21 ## RL training