Skip to content
Snippets Groups Projects
Commit e2addbbd authored by zqwerty, committed by zhuqi
Browse files

fix sclstm crosswoz import issues

parent fb643d2b
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,7 @@ def read_json(filename): ...@@ -20,6 +20,7 @@ def read_json(filename):
return json.load(f) return json.load(f)
def main():
def cmp_intent(intent1: str, intent2: str): def cmp_intent(intent1: str, intent2: str):
assert role in ['sys', 'usr'] assert role in ['sys', 'usr']
intent_order = { intent_order = {
...@@ -185,8 +186,8 @@ def cmp_intent(intent1: str, intent2: str): ...@@ -185,8 +186,8 @@ def cmp_intent(intent1: str, intent2: str):
print(role, intent1, intent2) print(role, intent1, intent2)
return intent_order[role].index(intent1) - intent_order[role].index(intent2) return intent_order[role].index(intent1) - intent_order[role].index(intent2)
data_dir = '../../../../../data/crosswoz'
data_dir = '../../../../data/crosswoz' data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), data_dir))
train_archive = zipfile.ZipFile(os.path.join(data_dir, 'train.json.zip'), 'r') train_archive = zipfile.ZipFile(os.path.join(data_dir, 'train.json.zip'), 'r')
train_data = json.load(train_archive.open('train.json')) train_data = json.load(train_archive.open('train.json'))
valid_archive = zipfile.ZipFile(os.path.join(data_dir, 'val.json.zip'), 'r') valid_archive = zipfile.ZipFile(os.path.join(data_dir, 'val.json.zip'), 'r')
...@@ -473,3 +474,6 @@ for require_role in ['sys', 'usr']: ...@@ -473,3 +474,6 @@ for require_role in ['sys', 'usr']:
with open(os.path.join(output_data_dir, 'split.json'), 'w', encoding='utf-8') as f: with open(os.path.join(output_data_dir, 'split.json'), 'w', encoding='utf-8') as f:
json.dump(split_dict, f, indent=4, sort_keys=True, ensure_ascii=False) json.dump(split_dict, f, indent=4, sort_keys=True, ensure_ascii=False)
if __name__ == '__main__':
main()
...@@ -3,6 +3,7 @@ from collections import defaultdict ...@@ -3,6 +3,7 @@ from collections import defaultdict
from copy import copy from copy import copy
import functools import functools
import zipfile import zipfile
import os
def read_zipped_json(filepath, filename): def read_zipped_json(filepath, filename):
...@@ -10,7 +11,10 @@ def read_zipped_json(filepath, filename): ...@@ -10,7 +11,10 @@ def read_zipped_json(filepath, filename):
return json.load(archive.open(filename)) return json.load(archive.open(filename))
file_path = '../../../../data/crosswoz/train.json.zip' def main():
file_path = '../../../../../data/crosswoz/train.json.zip'
file_path = os.path.abspath(os.path.join(os.path.abspath(__file__), file_path))
print(os.path.abspath(file_path))
data = read_zipped_json(file_path, 'train.json') data = read_zipped_json(file_path, 'train.json')
print('\n\nLength of data: ', len(data)) print('\n\nLength of data: ', len(data))
...@@ -260,4 +264,5 @@ with open('auto_user_template_nlg.json', 'w', encoding='utf-8') as f: ...@@ -260,4 +264,5 @@ with open('auto_user_template_nlg.json', 'w', encoding='utf-8') as f:
with open('auto_system_template_nlg.json', 'w', encoding='utf-8') as f: with open('auto_system_template_nlg.json', 'w', encoding='utf-8') as f:
json.dump(sys_multi_intent_dict, f, indent=4, sort_keys=True, ensure_ascii=False) json.dump(sys_multi_intent_dict, f, indent=4, sort_keys=True, ensure_ascii=False)
if __name__ == '__main__':
main()
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment