diff --git a/convlab/policy/pg/pg.py b/convlab/policy/pg/pg.py
index 060be694d06748bae2e737ba81853b3b0de2c29d..2230ac8d51886ac3630a89f58bc493b896d99b4a 100755
--- a/convlab/policy/pg/pg.py
+++ b/convlab/policy/pg/pg.py
@@ -13,6 +13,7 @@ from convlab.policy.vector.vector_binary import VectorBinary
 from convlab.util.file_util import cached_path
 import zipfile
 import sys
+import urllib.request
 
 root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 sys.path.append(root_dir)
@@ -22,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class PG(Policy):
 
-    def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None):
+    def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
         with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
             cfg = json.load(f)
         self.cfg = cfg
@@ -32,19 +33,31 @@ class PG(Policy):
         self.optim_batchsz = cfg['batchsz']
         self.gamma = cfg['gamma']
         self.is_train = is_train
-        self.vector = vectorizer
         self.info_dict = {}
 
         set_seed(seed)
 
+        self.vector = vectorizer
+        dir_name = os.path.dirname(os.path.abspath(__file__))
+
         if self.vector is None:
             logging.info("No vectorizer was set, using default..")
-            from convlab.policy.vector.vector_binary import VectorBinary
-            self.vector = VectorBinary()
-
-        if dataset == 'Multiwoz':
-            self.vector = vectorizer
-            self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
+            self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
+                                       use_masking=kwargs.get('use_masking', True),
+                                       manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
+                                       seed=seed)
+
+        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
+
+        try:
+            if load_path == "from_pretrained":
+                urllib.request.urlretrieve(
+                    f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
+                    f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
+                load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
+            self.load_policy(load_path)
+        except Exception as e:
+            print(f"Could not load the policy, Exception: {e}")
 
         # self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
         if is_train:
@@ -208,4 +221,18 @@ class PG(Policy):
             cfg = json.load(f)
         model = cls()
         model.load_from_pretrained(archive_file, model_file, cfg['load'])
-        return model
\ No newline at end of file
+        return model
+
+    def load_policy(self, filename=""):
+        policy_mdl_candidates = [
+            filename + '.pol.mdl',
+            filename + '_ppo.pol.mdl',
+            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'),
+            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl')
+        ]
+        for policy_mdl in policy_mdl_candidates:
+            if os.path.exists(policy_mdl):
+                print(f"Loaded policy checkpoint from file: {policy_mdl}")
+                self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
+                logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
+                break
\ No newline at end of file
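For reference, a minimal usage sketch of the constructor interface introduced above (and applied to PPO in the next file). This sketch is not part of the patch; it assumes ConvLab-3 is installed and uses "multiwoz21" purely as an example dataset_name.

    # Illustrative only: construct the policies through the new keyword-argument interface.
    from convlab.policy.pg.pg import PG
    from convlab.policy.ppo.ppo import PPO

    # Without an explicit vectorizer, a default VectorBinary is built from **kwargs,
    # so dataset_name is required; use_masking and manually_add_entity_names default to True.
    pg_policy = PG(is_train=True, seed=0, dataset_name="multiwoz21")

    # load_path="from_pretrained" downloads ConvLab/mle-policy-<dataset_name> from the
    # Hugging Face Hub and warm-starts the policy via load_policy().
    ppo_policy = PPO(is_train=True, seed=0, load_path="from_pretrained",
                     dataset_name="multiwoz21")
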
diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py
index 28fee71c70c640319b29becc77bddfe8311f2767..ca70f2b6f81e716aee544c5d2404b296369964bf 100755
--- a/convlab/policy/ppo/ppo.py
+++ b/convlab/policy/ppo/ppo.py
@@ -6,10 +6,13 @@ import numpy as np
 import logging
 import os
 import json
+from convlab.policy.vector.vector_binary import VectorBinary
 from convlab.policy.policy import Policy
 from convlab.policy.rlmodule import MultiDiscretePolicy, Value
 from convlab.util.custom_util import model_downloader, set_seed
 import sys
+import urllib.request
+
 
 root_dir = os.path.dirname(os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
@@ -20,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class PPO(Policy):
 
-    def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None):
+    def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
 
         with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f:
             cfg = json.load(f)
@@ -39,18 +42,29 @@ class PPO(Policy):
         logging.info('PPO seed ' + str(seed))
 
         set_seed(seed)
+        dir_name = os.path.dirname(os.path.abspath(__file__))
 
         if self.vector is None:
             logging.info("No vectorizer was set, using default..")
-            from convlab.policy.vector.vector_binary import VectorBinary
-            self.vector = VectorBinary()
-
-        # construct policy and value network
-        if dataset == 'Multiwoz':
-            self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
-                                              self.vector.da_dim, seed).to(device=DEVICE)
-            logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}")
-            logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}")
+            self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
+                                       use_masking=kwargs.get('use_masking', True),
+                                       manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
+                                       seed=seed)
+
+        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
+                                          self.vector.da_dim, seed).to(device=DEVICE)
+        logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}")
+        logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}")
+
+        try:
+            if load_path == "from_pretrained":
+                urllib.request.urlretrieve(
+                    f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
+                    f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
+                load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
+            self.load_policy(load_path)
+        except Exception as e:
+            print(f"Could not load the policy, Exception: {e}")
 
         self.value = Value(self.vector.state_dim, cfg['hv_dim']).to(device=DEVICE)
 
@@ -263,6 +277,20 @@ class PPO(Policy):
                 logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
                 break
 
+    def load_policy(self, filename=""):
+        policy_mdl_candidates = [
+            filename + '.pol.mdl',
+            filename + '_ppo.pol.mdl',
+            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'),
+            os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl')
+        ]
+        for policy_mdl in policy_mdl_candidates:
+            if os.path.exists(policy_mdl):
+                print(f"Loaded policy checkpoint from file: {policy_mdl}")
+                self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
+                logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
+                break
+
     # Load model from model_path(URL)
     def load_from_pretrained(self, model_path=""):
 
diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py
index 5b031c4c297d075d3de66b1f776101a3d8a2e614..85d239e3299c73999042b9b58b34dc039fe1057f 100644
--- a/convlab/policy/vtrace_DPT/vtrace.py
+++ b/convlab/policy/vtrace_DPT/vtrace.py
@@ -75,9 +75,9 @@ class VTRACE(nn.Module, Policy):
         try:
             if load_path == "from_pretrained":
                 urllib.request.urlretrieve(
-                    "https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl",
-                    f"{dir_name}/ddpt.pol.mdl")
-                load_path = f"{dir_name}/ddpt"
+                    f"https://huggingface.co/ConvLab/ddpt-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
+                    f"{dir_name}/{self.vector.dataset_name}_ddpt.pol.mdl")
+                load_path = f"{dir_name}/{self.vector.dataset_name}_ddpt"
             self.load_policy(load_path)
         except Exception as e:
             print(f"Could not load the policy, Exception: {e}")