diff --git a/convlab/policy/vtrace_DPT/README.md b/convlab/policy/vtrace_DPT/README.md
index 002a8a050cc8bf573761a1b5ba2276d844a6db7d..3b9e6f4216f165aa99862d4beca554eee5430669 100644
--- a/convlab/policy/vtrace_DPT/README.md
+++ b/convlab/policy/vtrace_DPT/README.md
@@ -47,11 +47,37 @@ Moreover, you can specify the full dialogue pipeline here, such as the user poli
 
 Parameters that are tied to the RL algorithm and the model architecture can be changed in config.json.
 
+NOTE: you can specify which underlying dataset should be used for creating the action and state space by setting it in your **environment-config**:
+
+```
+environment_config["vectorizer_sys"]["dataset_name"] = dataset_name
+```
+For instance, `dataset_name = "multiwoz21"` or `dataset_name = "sgd"`.
 
 ## Evaluation
 
 For creating evaluation plots and running evaluation dialogues, please have a look in the README of the policy folder.
 
+## Interface
+
+To use a trained model in a dialogue system, load it as follows:
+
+```python
+from convlab.policy.vector.vector_nodes import VectorNodes
+from convlab.policy.vtrace_DPT import VTRACE
+
+vectorizer = VectorNodes(dataset_name='multiwoz21',
+                         use_masking=False,
+                         manually_add_entity_names=True,
+                         seed=0,
+                         filter_state=True)
+ddpt = VTRACE(is_train=True,
+              seed=0,
+              vectorizer=vectorizer,
+              load_path="ddpt")
+```
+Specify the appropriate `load_path` in VTRACE.
+
 ## References
 
 ```
diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py
index b03662c60539a2e3aa80a5132618a3f7563a0f09..0474a3641694104625244150414793e97afa0268 100644
--- a/convlab/policy/vtrace_DPT/vtrace.py
+++ b/convlab/policy/vtrace_DPT/vtrace.py
@@ -59,6 +59,7 @@ class VTRACE(nn.Module, Policy):
         self.last_action = None
 
         self.vector = vectorizer
+        self.cfg['dataset_name'] = self.vector.dataset_name
         self.policy = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
         self.value_helper = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
 
@@ -338,6 +339,7 @@ class VTRACE(nn.Module, Policy):
             if os.path.exists(policy_mdl):
                 self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
                 self.value_helper.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
+                print(f"Loaded policy checkpoint from file: {policy_mdl}")
                 logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
                 break
 
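
A minimal usage sketch of the environment-config change documented in the README hunk above, not part of the diff itself: it assumes the environment config is a plain JSON file whose `vectorizer_sys` section holds the `dataset_name` entry, and the file name `my_environment_config.json` is a placeholder.

```python
import json

# Placeholder path: substitute the environment config used for your training run.
config_path = "my_environment_config.json"

with open(config_path) as f:
    environment_config = json.load(f)

# Switch the dataset used to build the action and state space,
# e.g. "multiwoz21" or "sgd".
environment_config["vectorizer_sys"]["dataset_name"] = "sgd"

with open(config_path, "w") as f:
    json.dump(environment_config, f, indent=4)
```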