diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py index 0474a3641694104625244150414793e97afa0268..37527164c3941172fa0374cc52c36fd34cd524c4 100644 --- a/convlab/policy/vtrace_DPT/vtrace.py +++ b/convlab/policy/vtrace_DPT/vtrace.py @@ -5,6 +5,7 @@ import os import sys import torch import torch.nn as nn +import urllib.request from torch import optim from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder @@ -64,6 +65,11 @@ class VTRACE(nn.Module, Policy): self.value_helper = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE) try: + if load_path == "from_pretrained": + urllib.request.urlretrieve( + "https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl", + f"{dir_name}/ddpt.pol.mdl") + load_path = f"{dir_name}/ddpt" self.load_policy(load_path) except Exception as e: print(f"Could not load the policy, Exception: {e}")