Skip to content
Snippets Groups Projects
Commit 89b3685f authored by Christian's avatar Christian
Browse files

added from pretraining loadpath for DDPT to load a mw21 model

parent dc4f268d
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@ import os
import sys
import torch
import torch.nn as nn
import urllib.request
from torch import optim
from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder
......@@ -64,6 +65,11 @@ class VTRACE(nn.Module, Policy):
self.value_helper = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
"https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl",
f"{dir_name}/ddpt.pol.mdl")
load_path = f"{dir_name}/ddpt"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment