From 0d233eee81eb20fba0d5d3615739e417491ceefc Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <niekerk@hhu.de>
Date: Tue, 29 Nov 2022 18:03:52 +0100
Subject: [PATCH] Bug fix and update downloader for setsumbt

---
 convlab/dst/setsumbt/do/nbt.py  | 19 +++++++++++++++++++
 convlab/dst/setsumbt/tracker.py |  4 ++--
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py
index ea099442..d250f29e 100644
--- a/convlab/dst/setsumbt/do/nbt.py
+++ b/convlab/dst/setsumbt/do/nbt.py
@@ -20,6 +20,7 @@ import os
 from shutil import copy2 as copy
 import json
 from copy import deepcopy
+import pdb
 
 import torch
 import transformers
@@ -34,6 +35,7 @@ from convlab.dst.setsumbt.modeling import training
 from convlab.dst.setsumbt.dataset import ontology as embeddings
 from convlab.dst.setsumbt.utils import get_args, update_args
 from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble
+from convlab.util.custom_util import model_downloader
 
 
 # Available model
@@ -55,6 +57,23 @@ def main(args=None, config=None):
 
     # Set up output directory
     OUTPUT_DIR = args.output_dir
+
+    # Download model if needed
+    if not os.path.exists(OUTPUT_DIR):
+        # Get path /.../convlab/dst/setsumbt/multiwoz/models
+        download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        download_path = os.path.join(download_path, 'models')
+        if not os.path.exists(download_path):
+            os.mkdir(download_path)
+        model_downloader(download_path, OUTPUT_DIR)
+        # Downloadable model path format http://.../model_name.zip
+        OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '')
+        OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR)
+
+        args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1])
+        args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1])
+        os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders'))
+
     if not os.path.exists(OUTPUT_DIR):
         os.makedirs(OUTPUT_DIR)
         os.mkdir(os.path.join(OUTPUT_DIR, 'database'))
diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py
index 6b620247..eca7f174 100644
--- a/convlab/dst/setsumbt/tracker.py
+++ b/convlab/dst/setsumbt/tracker.py
@@ -61,8 +61,8 @@ class SetSUMBTTracker(DST):
             if not os.path.exists(download_path):
                 os.mkdir(download_path)
             model_downloader(download_path, self.model_path)
-            # Downloadable model path format http://.../setsumbt_model_name.zip
-            self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '')
+            # Downloadable model path format http://.../model_name.zip
+            self.model_path = self.model_path.split('/')[-1].replace('.zip', '')
             self.model_path = os.path.join(download_path, self.model_path)
 
         # Select model type based on the encoder
-- 
GitLab