Skip to content
Snippets Groups Projects
Commit 56303a8c authored by Christian's avatar Christian
Browse files

the policy module is loaded with a default vectorizer in case no vectorizer is...

The policy module is loaded with a default vectorizer when no vectorizer is passed. Also, if the load_path argument is set to "from_pretrained", the model checkpoint is automatically downloaded from the Hugging Face Hub and loaded into the policy.
parent 87768ad6
Branches add_default_vectorizer_and_pretrained_loading
No related tags found
No related merge requests found
...@@ -13,6 +13,7 @@ from convlab.policy.vector.vector_binary import VectorBinary ...@@ -13,6 +13,7 @@ from convlab.policy.vector.vector_binary import VectorBinary
from convlab.util.file_util import cached_path from convlab.util.file_util import cached_path
import zipfile import zipfile
import sys import sys
import urllib.request
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(root_dir) sys.path.append(root_dir)
...@@ -22,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") ...@@ -22,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PG(Policy): class PG(Policy):
def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
cfg = json.load(f) cfg = json.load(f)
self.cfg = cfg self.cfg = cfg
...@@ -32,20 +33,32 @@ class PG(Policy): ...@@ -32,20 +33,32 @@ class PG(Policy):
self.optim_batchsz = cfg['batchsz'] self.optim_batchsz = cfg['batchsz']
self.gamma = cfg['gamma'] self.gamma = cfg['gamma']
self.is_train = is_train self.is_train = is_train
self.vector = vectorizer
self.info_dict = {} self.info_dict = {}
set_seed(seed) set_seed(seed)
self.vector = vectorizer
dir_name = os.path.dirname(os.path.abspath(__file__))
if self.vector is None: if self.vector is None:
logging.info("No vectorizer was set, using default..") logging.info("No vectorizer was set, using default..")
from convlab.policy.vector.vector_binary import VectorBinary self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
self.vector = VectorBinary() use_masking=kwargs.get('use_masking', True),
manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
seed=seed)
if dataset == 'Multiwoz':
self.vector = vectorizer
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE) self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
# self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE) # self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
if is_train: if is_train:
self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr']) self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr'])
...@@ -209,3 +222,17 @@ class PG(Policy): ...@@ -209,3 +222,17 @@ class PG(Policy):
model = cls() model = cls()
model.load_from_pretrained(archive_file, model_file, cfg['load']) model.load_from_pretrained(archive_file, model_file, cfg['load'])
return model return model
def load_policy(self, filename=""):
    """Load policy weights from the first existing checkpoint candidate.

    Candidates are tried in this order: ``filename + '.pol.mdl'`` and
    ``filename + '_ppo.pol.mdl'`` as given, then the same two names joined
    onto the directory that contains this module.

    Args:
        filename: Checkpoint path prefix, without the ``.pol.mdl`` suffix.

    If no candidate exists nothing is loaded and a warning is logged.
    Errors raised while reading the checkpoint are not caught here.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    suffixes = ('.pol.mdl', '_ppo.pol.mdl')
    policy_mdl_candidates = [filename + suffix for suffix in suffixes]
    policy_mdl_candidates += [
        os.path.join(module_dir, filename + suffix) for suffix in suffixes
    ]

    for policy_mdl in policy_mdl_candidates:
        if not os.path.exists(policy_mdl):
            continue
        self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
        # Report success only after the weights are really in place, so a
        # failing load is never preceded by a misleading "Loaded" message.
        print(f"Loaded policy checkpoint from file: {policy_mdl}")
        logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
        return

    # Make the "nothing found" case visible instead of returning silently.
    logging.warning('<<dialog policy>> no checkpoint found for "%s"; tried: %s',
                    filename, policy_mdl_candidates)
\ No newline at end of file
...@@ -6,10 +6,13 @@ import numpy as np ...@@ -6,10 +6,13 @@ import numpy as np
import logging import logging
import os import os
import json import json
from convlab.policy.vector.vector_binary import VectorBinary
from convlab.policy.policy import Policy from convlab.policy.policy import Policy
from convlab.policy.rlmodule import MultiDiscretePolicy, Value from convlab.policy.rlmodule import MultiDiscretePolicy, Value
from convlab.util.custom_util import model_downloader, set_seed from convlab.util.custom_util import model_downloader, set_seed
import sys import sys
import urllib.request
root_dir = os.path.dirname(os.path.dirname( root_dir = os.path.dirname(os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
...@@ -20,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") ...@@ -20,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PPO(Policy): class PPO(Policy):
def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f: with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f:
cfg = json.load(f) cfg = json.load(f)
...@@ -39,19 +42,30 @@ class PPO(Policy): ...@@ -39,19 +42,30 @@ class PPO(Policy):
logging.info('PPO seed ' + str(seed)) logging.info('PPO seed ' + str(seed))
set_seed(seed) set_seed(seed)
dir_name = os.path.dirname(os.path.abspath(__file__))
if self.vector is None: if self.vector is None:
logging.info("No vectorizer was set, using default..") logging.info("No vectorizer was set, using default..")
from convlab.policy.vector.vector_binary import VectorBinary self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
self.vector = VectorBinary() use_masking=kwargs.get('use_masking', True),
manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
seed=seed)
# construct policy and value network
if dataset == 'Multiwoz':
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
self.vector.da_dim, seed).to(device=DEVICE) self.vector.da_dim, seed).to(device=DEVICE)
logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}") logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}")
logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}") logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}")
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
self.value = Value(self.vector.state_dim, self.value = Value(self.vector.state_dim,
cfg['hv_dim']).to(device=DEVICE) cfg['hv_dim']).to(device=DEVICE)
if is_train: if is_train:
...@@ -263,6 +277,20 @@ class PPO(Policy): ...@@ -263,6 +277,20 @@ class PPO(Policy):
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl)) logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break break
def load_policy(self, filename=""):
    """Load policy weights from the first checkpoint file that exists.

    The lookup order is ``filename + '.pol.mdl'``, then
    ``filename + '_ppo.pol.mdl'``, then both names again joined onto the
    directory that contains this module.

    Args:
        filename: Checkpoint path prefix, without the ``.pol.mdl`` suffix.

    If none of the candidates exists nothing is loaded and a warning is
    logged. Errors raised while reading the checkpoint are not caught here.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    endings = ('.pol.mdl', '_ppo.pol.mdl')
    policy_mdl_candidates = [filename + ending for ending in endings]
    policy_mdl_candidates += [
        os.path.join(base_dir, filename + ending) for ending in endings
    ]

    for policy_mdl in policy_mdl_candidates:
        if not os.path.exists(policy_mdl):
            continue
        self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
        # Announce the load only once it has succeeded; announcing first
        # would claim success even when load_state_dict raises.
        print(f"Loaded policy checkpoint from file: {policy_mdl}")
        logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
        return

    # No candidate existed: leave the policy untouched but say so.
    logging.warning('<<dialog policy>> no checkpoint found for "%s"; tried: %s',
                    filename, policy_mdl_candidates)
# Load model from model_path(URL) # Load model from model_path(URL)
def load_from_pretrained(self, model_path=""): def load_from_pretrained(self, model_path=""):
......
...@@ -75,9 +75,9 @@ class VTRACE(nn.Module, Policy): ...@@ -75,9 +75,9 @@ class VTRACE(nn.Module, Policy):
try: try:
if load_path == "from_pretrained": if load_path == "from_pretrained":
urllib.request.urlretrieve( urllib.request.urlretrieve(
"https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl", f"https://huggingface.co/ConvLab/ddpt-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/ddpt.pol.mdl") f"{dir_name}/{self.vector.dataset_name}_ddpt.pol.mdl")
load_path = f"{dir_name}/ddpt" load_path = f"{dir_name}/{self.vector.dataset_name}_ddpt"
self.load_policy(load_path) self.load_policy(load_path)
except Exception as e: except Exception as e:
print(f"Could not load the policy, Exception: {e}") print(f"Could not load the policy, Exception: {e}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment