diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py index 0d9f0f4ebe1954e7a75b7f4c9526d43833b8d2da..e79d5535a98bb3a6cdad933154988a04dbfb2bf9 100644 --- a/data/unified_datasets/multiwoz21/database.py +++ b/data/unified_datasets/multiwoz21/database.py @@ -5,13 +5,14 @@ from fuzzywuzzy import fuzz from itertools import chain from zipfile import ZipFile from copy import deepcopy -from convlab.util.unified_datasets_util import BaseDatabase +from convlab.util.unified_datasets_util import BaseDatabase, download_unified_datasets class Database(BaseDatabase): def __init__(self): """extract data.zip and load the database.""" - archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip')) + data_path = download_unified_datasets('multiwoz21', 'data.zip', os.path.dirname(os.path.abspath(__file__))) + archive = ZipFile(data_path) self.domains = ['restaurant', 'hotel', 'attraction', 'train', 'hospital', 'police'] self.dbs = {} for domain in self.domains: