From 89b3685f84e3762558d6ec5d71fd37f24aff4f99 Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Wed, 21 Dec 2022 16:01:10 +0100 Subject: [PATCH] added from pretraining loadpath for DDPT to load a mw21 model --- convlab/policy/vtrace_DPT/vtrace.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py index 0474a364..37527164 100644 --- a/convlab/policy/vtrace_DPT/vtrace.py +++ b/convlab/policy/vtrace_DPT/vtrace.py @@ -5,6 +5,7 @@ import os import sys import torch import torch.nn as nn +import urllib.request from torch import optim from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder @@ -64,6 +65,11 @@ class VTRACE(nn.Module, Policy): self.value_helper = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE) try: + if load_path == "from_pretrained": + urllib.request.urlretrieve( + "https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl", + f"{dir_name}/ddpt.pol.mdl") + load_path = f"{dir_name}/ddpt" self.load_policy(load_path) except Exception as e: print(f"Could not load the policy, Exception: {e}") -- GitLab