Commit 8183031a authored by zqwerty
add preprocess.py for opendialkg and 10% wikidialog

parent 02bd56be
# ===== preprocess.py for opendialkg =====
from zipfile import ZipFile, ZIP_DEFLATED
from shutil import rmtree
from ast import literal_eval
import json
import os
from tqdm import tqdm
import re
import requests
from dateutil import parser as date_parser
from string import punctuation
import csv
import random


def value_in_utt(value, utt):
    """return character level (start, end) if value appears in utt"""
    value = value.strip(punctuation).lower()
    # lowercase the utterance so the substring fallbacks below are
    # case-insensitive too (the regex already uses re.I)
    utt = utt.lower()
    # match the value only when it is delimited by the start/end of the string,
    # whitespace, or common punctuation ('-' is escaped to keep it literal)
    p = r"(^|[\s,\.:\?!\-])(?P<v>{})([\s,\.:\?!\-']|$)".format(re.escape(value))
    p = re.compile(p, re.I)
    m = re.search(p, utt)
    if m:
        # very few values appear more than once; take the first span
        return True, m.span('v')
    else:
        try:
            # handle date representations, e.g. '3 pm' vs '3pm'
            date_parser.parse(value)
            if (value.endswith('pm') or value.endswith('am')) and ''.join(value.split(' ')) in ''.join(utt.split(' ')):
                return True, None
        except (ValueError, OverflowError):
            # not a date; the value appears, but may be inflected (plural, -ing, -ly, etc.)
            if value in utt:
                return True, None
        return False, None
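
# Illustrative behaviour of value_in_utt (hypothetical inputs, not taken from
# the dataset):
#   value_in_utt('Pizza', 'I love pizza!')    -> (True, (7, 12))  # exact match with span
#   value_in_utt('3 pm', 'See you at 3pm')    -> (True, None)     # date form matched, no span
#   value_in_utt('sushi', 'I had a burger')   -> (False, None)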


def preprocess():
    # fix the shuffle below so the split is reproducible
    random.seed(42)
    data_file = "opendialkg.csv"
    if not os.path.exists(data_file):
        # download the raw corpus from the official OpenDialKG repository
        response = requests.get("https://github.com/facebookresearch/opendialkg/raw/main/data/opendialkg.csv")
        with open(data_file, "wb") as f:
            f.write(response.content)
    new_data_dir = 'data'
    os.makedirs(new_data_dir, exist_ok=True)

    dataset = 'opendialkg'
    splits = ['train', 'validation', 'test']
    dialogues_by_split = {split: [] for split in splits}
    # OpenDialKG carries no domain/intent/state/act annotation, so the unified
    # ontology stays empty
    ontology = {
        'domains': {},
        'intents': {},
        'state': {},
        'dialogue_acts': {
            "categorical": {},
            "non-categorical": {},
            "binary": {}
        }
    }

    # read the CSV into a list of {column_name: value} dicts
    data = []
    with open(data_file, newline='') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        header = next(csv_reader)
        for row in csv_reader:
            sample = {}
            for i, col in enumerate(row):
                sample[header[i]] = col
            data.append(sample)

    # shuffle for a random train:validation:test = 70:15:15 split
    random.shuffle(data)
    split2range = {
        'train': [0, round(len(data)*0.7)],
        'validation': [round(len(data)*0.7), round(len(data)*0.85)],
        'test': [round(len(data)*0.85), len(data)],
    }
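
    # e.g. with 10,000 dialogues: train covers [0, 7000), validation [7000, 8500),
    # and test [8500, 10000)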

    cnt = 0
    for data_split in splits:
        for i in tqdm(range(*split2range[data_split])):
            item = data[i]
            dialogue_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
            dialogue = {
                'dataset': dataset,
                'data_split': data_split,
                'dialogue_id': dialogue_id,
                'original_id': f'{data_split}-{len(dialogues_by_split[data_split])}',
                # ratings are stored as literals in the CSV;
                # literal_eval is a safe replacement for a bare eval
                'user_rating': literal_eval(item['User Rating']),
                'system_rating': literal_eval(item['Assistant Rating']),
                'turns': [],
            }
            for turn in literal_eval(item['Messages']):
                speaker = 'user' if turn['sender'] == 'user' else 'system'
                turn_type = turn['type']
                if turn_type == 'chat':
                    # a chat turn should carry exactly 'type', 'sender' and 'message'
                    assert len(turn) == 3
                    if len(dialogue['turns']) > 0 and speaker == dialogue['turns'][-1]['speaker']:
                        # merge consecutive messages from the same speaker
                        dialogue['turns'][-1]['utterance'] += turn['message']
                    else:
                        dialogue['turns'].append({
                            'speaker': speaker,
                            'utterance': turn['message'],
                            'utt_idx': len(dialogue['turns']),
                        })
                elif turn['action_id'] == "meta_thread/send_meta_message":
                    # skip annotator communication
                    pass
                else:
                    assert turn_type == 'action' and turn['action_id'] == "kgwalk/choose_path"
                    assert len(dialogue['turns']) == 0 or (speaker != dialogue['turns'][-1]['speaker']), turn
                    # turn['metadata']['path'] holds [score, triples, rendering];
                    # zip labels the three elements with their keys
                    dialogue['turns'].append({
                        'speaker': speaker,
                        'utterance': '',
                        'kg_path': {k: v for k, v in zip(['score', 'triples', 'rendering'], turn['metadata']['path'])},
                        'utt_idx': len(dialogue['turns']),
                    })
            if len(dialogue['turns']) != 0:
                dialogues_by_split[data_split].append(dialogue)
                if any(['kg_path' in turn for turn in dialogue['turns']]):
                    cnt += 1

    dialogues = dialogues_by_split['train'] + dialogues_by_split['validation'] + dialogues_by_split['test']
    # report how many dialogues contain at least one KG path
    print(cnt, len(dialogues), cnt/len(dialogues))
    json.dump(dialogues[:10], open('dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
    json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
    json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
    with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
        for filename in os.listdir(new_data_dir):
            zf.write(f'{new_data_dir}/{filename}')
    rmtree(new_data_dir)
    return dialogues, ontology


if __name__ == '__main__':
    preprocess()
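
# Running this script writes dummy_data.json (the first 10 dialogues) to the
# working directory, produces data.zip containing ontology.json and
# dialogues.json, and prints the fraction of dialogues that contain a KG path.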


# ===== preprocess.py for wikidialog =====
import gzip
import json
from zipfile import ZipFile, ZIP_DEFLATED
import os
from shutil import rmtree
from tqdm import tqdm


def preprocess():
    original_data_dir = 'WikiDialog-OQ'
    new_data_dir = 'data'
    os.makedirs(new_data_dir, exist_ok=True)

    dataset = 'wikidialog'
    splits = ['train', 'validation']
    dialogues_by_split = {split: [] for split in splits}
    # WikiDialog carries no domain/intent/state/act annotation either, so the
    # unified ontology stays empty
    ontology = {
        'domains': {},
        'intents': {},
        'state': {},
        "dialogue_acts": {
            "categorical": {},
            "non-categorical": {},
            "binary": {}
        }
    }

    def process_dial(line, dial_id, data_split):
        item = json.loads(line)
        dialogue = {
            'dataset': dataset,
            'data_split': data_split,
            'dialogue_id': dial_id,
            'original_id': item['pid'],
            'topic': item['title'],
            'turns': []
        }
        for speaker, utterance in zip(item['author_num'], item['utterances']):
            # author 0 is mapped to the system side, everyone else to the user
            speaker = 'system' if speaker == 0 else 'user'
            turn = {
                'speaker': speaker,
                'utterance': utterance.strip(),
                'utt_idx': len(dialogue['turns']),
            }
            dialogue['turns'].append(turn)
        return dialogue
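
    # Judging from the fields accessed above, each jsonl line looks roughly like
    # (hypothetical values):
    #   {"pid": "...", "title": "Some topic", "utterances": ["...", "..."],
    #    "author_num": [0, 1, 0, 1]}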

    # only the first of the 99 training shards is processed here
    data_split = 'train'
    for shard in tqdm(range(1)):
        with gzip.open(f'{original_data_dir}/data_train.jsonl-000{shard:02}-of-00099.gz', 'r') as fin:
            for line in fin:
                dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
                dialogue = process_dial(line, dial_id, data_split)
                dialogues_by_split[data_split].append(dialogue)

    # cap validation at 10% of the collected training dialogues
    data_split = 'validation'
    with gzip.open(f'{original_data_dir}/data_validation.jsonl.gz', 'r') as fin:
        for line in fin:
            dial_id = f'{dataset}-{data_split}-{len(dialogues_by_split[data_split])}'
            dialogue = process_dial(line, dial_id, data_split)
            dialogues_by_split[data_split].append(dialogue)
            if len(dialogues_by_split[data_split]) >= len(dialogues_by_split['train']) // 10:
                break

    dialogues = dialogues_by_split['train'] + dialogues_by_split['validation']
    json.dump(dialogues[:10], open('dummy_data.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
    json.dump(ontology, open(f'{new_data_dir}/ontology.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
    json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
    with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
        for filename in os.listdir(new_data_dir):
            zf.write(f'{new_data_dir}/{filename}')
    # rmtree(original_data_dir)
    rmtree(new_data_dir)
    return dialogues, ontology


if __name__ == '__main__':
    preprocess()
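
# This script expects the WikiDialog-OQ shards (data_train.jsonl-000XX-of-00099.gz
# and data_validation.jsonl.gz) to be present in ./WikiDialog-OQ. It writes
# dummy_data.json and data.zip (ontology.json, dialogues.json), keeping roughly
# one validation dialogue per ten training dialogues.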