diff --git a/convlab2/nlg/scgpt/multiwoz/preprocess.py b/convlab2/nlg/scgpt/multiwoz/preprocess.py index 3dcda2eb246fc919cf4843fe0c8075d0a9071138..27c5e9005f3ee36bf7108d8477050a61fd164d6e 100644 --- a/convlab2/nlg/scgpt/multiwoz/preprocess.py +++ b/convlab2/nlg/scgpt/multiwoz/preprocess.py @@ -8,13 +8,23 @@ Created on Mon Sep 14 11:38:53 2020 import os import json from convlab2.nlg.scgpt.utils import dict2dict, dict2seq +import zipfile + +def read_zipped_json(filepath, filename): + print("zip file path = ", filepath) + archive = zipfile.ZipFile(filepath, 'r') + return json.load(archive.open(filename)) cur_dir = os.path.dirname(os.path.abspath(__file__)) data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname( cur_dir)))), 'data/multiwoz/') -with open(os.path.join(data_dir, '0807_final.json'),'r', encoding='utf8') as f: - data = json.load(f) +keys = ['train', 'val', 'test'] +data = {} +for key in keys: + data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json') + print('load {}, size {}'.format(key, len(data_key))) + data = dict(data, **data_key) with open(os.path.join(data_dir, 'valListFile'), 'r') as f: val_list = f.read().splitlines()