Commit 3154fc7d authored by Christian

updated vtrace readme

parent 5ff93d57
......@@ -47,11 +47,37 @@ Moreover, you can specify the full dialogue pipeline here, such as the user poli
Parameters that are tied to the RL algorithm and the model architecture can be changed in config.json.
NOTE: You can choose which dataset is used to build the action and state space by setting the following in your **environment-config**:
```
environment_config["vectorizer_sys"]["dataset_name"] = dataset_name
```
For instance, `dataset_name = "multiwoz21"` or `dataset_name = "sgd"`.
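As a minimal sketch of switching the dataset, the snippet below edits an in-memory dictionary that stands in for the environment config. The helper `set_dataset` and the one-key config are hypothetical and exist only for illustration; a real environment config contains many more entries.

```python
import json

# Hypothetical stand-in for a loaded environment config (real ones have more keys).
environment_config = {"vectorizer_sys": {"dataset_name": "multiwoz21"}}


def set_dataset(config, dataset_name):
    """Return a copy of config whose system vectorizer uses dataset_name."""
    updated = json.loads(json.dumps(config))  # deep copy via a JSON round-trip
    updated["vectorizer_sys"]["dataset_name"] = dataset_name
    return updated


sgd_config = set_dataset(environment_config, "sgd")
print(sgd_config["vectorizer_sys"]["dataset_name"])
```

Copying the config instead of mutating it keeps the original available, which can help when comparing runs on different datasets.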
## Evaluation
To create evaluation plots and run evaluation dialogues, see the README in the policy folder.
## Interface
To use a trained model in a dialogue system, import and instantiate it as follows:
```python
from convlab.policy.vector.vector_nodes import VectorNodes
from convlab.policy.vtrace_DPT import VTRACE

vectorizer = VectorNodes(dataset_name='multiwoz21',
                         use_masking=False,
                         manually_add_entity_names=True,
                         seed=0,
                         filter_state=True)
ddpt = VTRACE(is_train=True,
              seed=0,
              vectorizer=vectorizer,
              load_path="ddpt")
```
Set `load_path` in the `VTRACE` constructor to the location of your trained checkpoint.
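The loading code only restores weights when the checkpoint file exists (it checks `os.path.exists` before calling `load_state_dict`), so it can help to verify the path yourself before building the policy. The sketch below shows one way to do that with the standard library. The helper `first_existing` and the `.pol.mdl` suffix are assumptions made for this example, not part of the ConvLab API.

```python
import os
import tempfile


def first_existing(candidates):
    """Return the first path in candidates that exists on disk, else None."""
    for path in candidates:
        if os.path.exists(path):
            return path
    return None


with tempfile.TemporaryDirectory() as tmp:
    load_path = os.path.join(tmp, "ddpt")
    # ".pol.mdl" is a hypothetical suffix used only for this illustration.
    checkpoint = load_path + ".pol.mdl"
    open(checkpoint, "wb").close()  # create an empty placeholder file

    found = first_existing([load_path, checkpoint])
    found_name = os.path.basename(found)
    missing = first_existing([os.path.join(tmp, "nope")])

print(found_name, missing)
```

If the helper returns `None`, no checkpoint was found at any candidate location, and the policy would presumably keep its freshly initialised weights.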
## References
```
......
......@@ -59,6 +59,7 @@ class VTRACE(nn.Module, Policy):
self.last_action = None
self.vector = vectorizer
self.cfg['dataset_name'] = self.vector.dataset_name
self.policy = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
self.value_helper = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
......@@ -338,6 +339,7 @@ class VTRACE(nn.Module, Policy):
if os.path.exists(policy_mdl):
self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
self.value_helper.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
print(f"Loaded policy checkpoint from file: {policy_mdl}")
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break
......