From 89b3685f84e3762558d6ec5d71fd37f24aff4f99 Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Wed, 21 Dec 2022 16:01:10 +0100
Subject: [PATCH] added from pretraining loadpath for DDPT to load a mw21 model

---
 convlab/policy/vtrace_DPT/vtrace.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py
index 0474a364..37527164 100644
--- a/convlab/policy/vtrace_DPT/vtrace.py
+++ b/convlab/policy/vtrace_DPT/vtrace.py
@@ -5,6 +5,7 @@ import os
 import sys
 import torch
 import torch.nn as nn
+import urllib.request
 
 from torch import optim
 from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder
@@ -64,6 +65,11 @@ class VTRACE(nn.Module, Policy):
         self.value_helper = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
 
         try:
+            if load_path == "from_pretrained":
+                urllib.request.urlretrieve(
+                    "https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl",
+                    f"{dir_name}/ddpt.pol.mdl")
+                load_path = f"{dir_name}/ddpt"
             self.load_policy(load_path)
         except Exception as e:
             print(f"Could not load the policy, Exception: {e}")
-- 
GitLab