Skip to content
Snippets Groups Projects
Unverified Commit 2ee60a42 authored by Christian Geishauser's avatar Christian Geishauser Committed by GitHub
Browse files

Merge pull request #149 from ConvLab/add_default_vectorizer_and_pretrained_loading

load default vectorizer if none is given and load a huggingface hub model in case from_pretrained is used as load_path
parents 87768ad6 56303a8c
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@ from convlab.policy.vector.vector_binary import VectorBinary
from convlab.util.file_util import cached_path
import zipfile
import sys
import urllib.request
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(root_dir)
......@@ -22,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PG(Policy):
def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None):
def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
cfg = json.load(f)
self.cfg = cfg
......@@ -32,20 +33,32 @@ class PG(Policy):
self.optim_batchsz = cfg['batchsz']
self.gamma = cfg['gamma']
self.is_train = is_train
self.vector = vectorizer
self.info_dict = {}
set_seed(seed)
self.vector = vectorizer
dir_name = os.path.dirname(os.path.abspath(__file__))
if self.vector is None:
logging.info("No vectorizer was set, using default..")
from convlab.policy.vector.vector_binary import VectorBinary
self.vector = VectorBinary()
self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
use_masking=kwargs.get('use_masking', True),
manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
seed=seed)
if dataset == 'Multiwoz':
self.vector = vectorizer
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
# self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
if is_train:
self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr'])
......@@ -209,3 +222,17 @@ class PG(Policy):
model = cls()
model.load_from_pretrained(archive_file, model_file, cfg['load'])
return model
def load_policy(self, filename=""):
policy_mdl_candidates = [
filename + '.pol.mdl',
filename + '_ppo.pol.mdl',
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'),
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl')
]
for policy_mdl in policy_mdl_candidates:
if os.path.exists(policy_mdl):
print(f"Loaded policy checkpoint from file: {policy_mdl}")
self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break
\ No newline at end of file
......@@ -6,10 +6,13 @@ import numpy as np
import logging
import os
import json
from convlab.policy.vector.vector_binary import VectorBinary
from convlab.policy.policy import Policy
from convlab.policy.rlmodule import MultiDiscretePolicy, Value
from convlab.util.custom_util import model_downloader, set_seed
import sys
import urllib.request
root_dir = os.path.dirname(os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
......@@ -20,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PPO(Policy):
def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None):
def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f:
cfg = json.load(f)
......@@ -39,19 +42,30 @@ class PPO(Policy):
logging.info('PPO seed ' + str(seed))
set_seed(seed)
dir_name = os.path.dirname(os.path.abspath(__file__))
if self.vector is None:
logging.info("No vectorizer was set, using default..")
from convlab.policy.vector.vector_binary import VectorBinary
self.vector = VectorBinary()
self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
use_masking=kwargs.get('use_masking', True),
manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
seed=seed)
# construct policy and value network
if dataset == 'Multiwoz':
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
self.vector.da_dim, seed).to(device=DEVICE)
logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}")
logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}")
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
self.value = Value(self.vector.state_dim,
cfg['hv_dim']).to(device=DEVICE)
if is_train:
......@@ -263,6 +277,20 @@ class PPO(Policy):
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break
def load_policy(self, filename=""):
policy_mdl_candidates = [
filename + '.pol.mdl',
filename + '_ppo.pol.mdl',
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'),
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl')
]
for policy_mdl in policy_mdl_candidates:
if os.path.exists(policy_mdl):
print(f"Loaded policy checkpoint from file: {policy_mdl}")
self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break
# Load model from model_path(URL)
def load_from_pretrained(self, model_path=""):
......
......@@ -75,9 +75,9 @@ class VTRACE(nn.Module, Policy):
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
"https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl",
f"{dir_name}/ddpt.pol.mdl")
load_path = f"{dir_name}/ddpt"
f"https://huggingface.co/ConvLab/ddpt-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/{self.vector.dataset_name}_ddpt.pol.mdl")
load_path = f"{dir_name}/{self.vector.dataset_name}_ddpt"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment