import json
import os
import argparse
from tqdm import tqdm

print("Preprocessing emowoz for glove feature extraction.")

# Input location, resolved relative to the current working directory.
data_dir = '../../data'


def _load_json(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding (e.g. cp1252 on Windows).
    """
    with open(path, encoding='utf-8') as json_file:
        return json.load(json_file)


# Dialogue ids for each split; every split concatenates its MultiWOZ ids
# followed by its DialMAGE ids.
id_dict = _load_json(f'{data_dir}/data-split.json')
train_ids, dev_ids, test_ids = [
    id_dict[split]['multiwoz'] + id_dict[split]['dialmage']
    for split in ('train', 'dev', 'test')
]
# Order must match the output file list used later: train, dev, test.
ids = [train_ids, dev_ids, test_ids]

multiwoz = _load_json(f'{data_dir}/emowoz-multiwoz.json')
dialmage = _load_json(f'{data_dir}/emowoz-dialmage.json')
# Single id -> dialogue lookup; on a duplicate id the DialMAGE entry wins.
dialogues = {**multiwoz, **dialmage}

# Integer label -> label name. -1 ('dontcare') is the value used for turns
# that carry no emotion annotation (system turns).
emo_dict = {
    -1: 'dontcare',
    0: 'neutral',
    1: 'neg-event',
    2: 'neg-dial',
    3: 'apologetic',
    4: 'abusive',
    5: 'pos-event',
    6: 'pos-dial',
}
senti_dict = {
    -1: 'dontcare',
    0: 'neutral',
    1: 'negative',
    2: 'positive',
}

# One output file per split, in the same order as the split id lists.
jsonfiles = [f'emowoz/emowoz-{split}.json' for split in ('train', 'dev', 'test')]

# Outputs are opened in append mode further down, so delete any files left
# over from a previous run first; paths that do not exist are skipped.
for stale_path in filter(os.path.exists, jsonfiles):
    os.remove(stale_path)
print("Existing files removed, if any")

def _convert_turn(log):
    """Convert one EmoWOZ turn into a DialogueRNN-style utterance record.

    Turns whose ``emotion`` list is empty (system turns) get the 'dontcare'
    emotion and sentiment; otherwise the entry at index 3 of that list
    supplies both labels. Returns a dict with keys emotion, sentiment, act,
    text (in that order).
    """
    emo = senti = -1  # -1 -> 'dontcare' for turns without emotion labels
    if log['emotion']:
        # NOTE(review): index 3 presumably holds the final/aggregated label
        # after the individual annotators' labels; confirm against the
        # EmoWOZ schema. Entries with fewer than 4 items raise IndexError.
        final_label = log['emotion'][3]
        emo = final_label['emotion']
        senti = final_label['sentiment']
    # Act is not used by dialoguernn; keep the first domain-act key, or
    # 'unknown' when the turn has no dialog acts.
    dact = next(iter(log['dialog_act']), 'unknown')
    return {
        'emotion': emo_dict[emo],
        'sentiment': senti_dict[senti],
        'act': dact,
        'text': log['text'],
    }


# Write each split as JSON Lines: one dialogue object per line.
for output_file, output_ids in zip(jsonfiles, ids):
    print(f"Generating {output_file}...")
    # The output directory may not exist yet (e.g. on a fresh checkout).
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'a', encoding='utf-8') as f:
        for k in tqdm(output_ids):
            jsonl = {
                # NOTE(review): 'train' is written for every split, including
                # dev and test; confirm the consumer ignores this field.
                'fold': 'train',
                'topic': 'dontcare',  # topic is not used in dialoguernn.
                'dialid': k,
                'dialogue': [_convert_turn(log) for log in dialogues[k]['log']],
            }
            f.write(json.dumps(jsonl) + '\n')

print("Done!")