From 0d233eee81eb20fba0d5d3615739e417491ceefc Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <niekerk@hhu.de> Date: Tue, 29 Nov 2022 18:03:52 +0100 Subject: [PATCH] Bug fix and update downloader for setsumbt --- convlab/dst/setsumbt/do/nbt.py | 19 +++++++++++++++++++ convlab/dst/setsumbt/tracker.py | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/convlab/dst/setsumbt/do/nbt.py b/convlab/dst/setsumbt/do/nbt.py index ea099442..d250f29e 100644 --- a/convlab/dst/setsumbt/do/nbt.py +++ b/convlab/dst/setsumbt/do/nbt.py @@ -20,6 +20,7 @@ import os from shutil import copy2 as copy import json from copy import deepcopy +import pdb import torch import transformers @@ -34,6 +35,7 @@ from convlab.dst.setsumbt.modeling import training from convlab.dst.setsumbt.dataset import ontology as embeddings from convlab.dst.setsumbt.utils import get_args, update_args from convlab.dst.setsumbt.modeling.ensemble_nbt import setup_ensemble +from convlab.util.custom_util import model_downloader # Available model @@ -55,6 +57,23 @@ def main(args=None, config=None): # Set up output directory OUTPUT_DIR = args.output_dir + + # Download model if needed + if not os.path.exists(OUTPUT_DIR): + # Get path /.../convlab/dst/setsumbt/multiwoz/models + download_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + download_path = os.path.join(download_path, 'models') + if not os.path.exists(download_path): + os.mkdir(download_path) + model_downloader(download_path, OUTPUT_DIR) + # Downloadable model path format http://.../model_name.zip + OUTPUT_DIR = OUTPUT_DIR.split('/')[-1].replace('.zip', '') + OUTPUT_DIR = os.path.join(download_path, OUTPUT_DIR) + + args.tensorboard_path = os.path.join(OUTPUT_DIR, args.tensorboard_path.split('/')[-1]) + args.logging_path = os.path.join(OUTPUT_DIR, args.logging_path.split('/')[-1]) + os.mkdir(os.path.join(OUTPUT_DIR, 'dataloaders')) + if not os.path.exists(OUTPUT_DIR): os.makedirs(OUTPUT_DIR) os.mkdir(os.path.join(OUTPUT_DIR, 'database')) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index 6b620247..eca7f174 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -61,8 +61,8 @@ class SetSUMBTTracker(DST): if not os.path.exists(download_path): os.mkdir(download_path) model_downloader(download_path, self.model_path) - # Downloadable model path format http://.../setsumbt_model_name.zip - self.model_path = self.model_path.split('/')[-1].split('_', 1)[-1].replace('.zip', '') + # Downloadable model path format http://.../model_name.zip + self.model_path = self.model_path.split('/')[-1].replace('.zip', '') self.model_path = os.path.join(download_path, self.model_path) # Select model type based on the encoder -- GitLab