diff --git a/.gitignore b/.gitignore index ea07c374a8d5c858256dcd34100b5eee1b52e8ab..59ff27e7413065504f08eb935f90ad5c9afddadd 100644 --- a/.gitignore +++ b/.gitignore @@ -104,5 +104,9 @@ convlab/deploy/templates/dialog_eg.html *convlab/policy/vector/action_dicts *.egg-info +.eggs/* pre-trained-models/ venv +*.zip +*/dummy_data.json +*.csv \ No newline at end of file diff --git a/convlab/dialog_agent/agent.py b/convlab/dialog_agent/agent.py index 025b41e88cef8eec0e0ca556fbb807dee84ef72d..640f8c33ca41663a09903ce24fb6291b375e149b 100755 --- a/convlab/dialog_agent/agent.py +++ b/convlab/dialog_agent/agent.py @@ -102,7 +102,7 @@ class PipelineAgent(Agent): self.history = [] self.turn = 0 - #logging.info("Pipeline Agent info_dict check") + # logging.info("Pipeline Agent info_dict check") if hasattr(self.nlu, 'info_dict') == False: logging.warning('nlu info_dict is not initialized') if hasattr(self.dst, 'info_dict') == False: @@ -111,7 +111,7 @@ class PipelineAgent(Agent): logging.warning('policy info_dict is not initialized') if hasattr(self.nlg, 'info_dict') == False: logging.warning('nlg info_dict is not initialized') - #logging.info("Done") + # logging.info("Done") def state_replace(self, agent_state): """ @@ -259,6 +259,8 @@ class PipelineAgent(Agent): return self.input_action def get_out_da(self): + if self.name == "user" and hasattr(self.policy, "semantic_action"): + return self.policy.semantic_action return self.output_action diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py index a8915301153f8e824e0a2ed91cd6eb9cd34e2605..c1d729c51be7172abb157fc3adad44f123f7604c 100755 --- a/convlab/dialog_agent/env.py +++ b/convlab/dialog_agent/env.py @@ -27,7 +27,7 @@ class Environment(): s, r, t = self.step([]) return self.sys_dst.state - def step(self, action): + def step(self, action, user_reward=False): # save last system action self.sys_dst.state['system_action'] = action if not self.use_semantic_acts: @@ -41,9 +41,9 @@ class Environment(): if intent == "book": self.sys_dst.state['booked'][domain] = [{slot: value}] observation = self.usr.response(model_response) - if self.evaluator: - self.evaluator.add_sys_da(self.usr.get_in_da(), self.sys_dst.state['belief_state']) + self.evaluator.add_sys_da( + self.usr.get_in_da(), self.sys_dst.state['belief_state']) self.evaluator.add_usr_da(self.usr.get_out_da()) dialog_act = self.sys_nlu.predict( @@ -59,9 +59,11 @@ class Environment(): state = deepcopy(state) terminated = self.usr.is_terminated() - - if self.evaluator: - reward = self.evaluator.get_reward(terminated) + if not user_reward: + if self.evaluator: + reward = self.evaluator.get_reward(terminated) + else: + reward = self.usr.get_reward() else: reward = self.usr.get_reward() diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index 75c4f2195f0541aa378404d34b09b3050b5a60b0..3ca0905c20361fcdc72146192bcd0a2b7d693775 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -27,8 +27,10 @@ for dom, ref_slots in REF_SYS_DA.items(): REF_SYS_DA_M['taxi']['phone'] = 'phone' REF_SYS_DA_M['taxi']['car'] = 'car type' -reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da') -reverse_da_slot_name_map = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da_slot_name_map') +reverse_da = relative_import_module_from_unified_datasets( + 'multiwoz21', 'preprocess.py', 'reverse_da') +reverse_da_slot_name_map = relative_import_module_from_unified_datasets( + 'multiwoz21', 'preprocess.py', 'reverse_da_slot_name_map') requestable = \ diff --git a/convlab/nlu/jointBERT/multiwoz/nlu.py b/convlab/nlu/jointBERT/multiwoz/nlu.py index eb776fcaeeb846abb63c2a8865c69b1c0257c614..10b79e15f753e498476b25f0855103873f352d97 100755 --- a/convlab/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab/nlu/jointBERT/multiwoz/nlu.py @@ -41,12 +41,13 @@ class BERTNLU(NLU): dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab, pretrained_weights=config['model']['pretrained_weights']) - logging.info('intent num:' + str(len(intent_vocab))) + logging.info('intent num:' + str(len(intent_vocab))) logging.info('tag num:' + str(len(tag_vocab))) if not os.path.exists(output_dir): model_downloader(root_dir, model_file) - model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim) + model = JointBERT(config['model'], DEVICE, + dataloader.tag_dim, dataloader.intent_dim) state_dict = torch.load(os.path.join( output_dir, 'pytorch_model.bin'), DEVICE) @@ -74,7 +75,7 @@ class BERTNLU(NLU): for token in token_list: token = token.strip() self.nlp.tokenizer.add_special_case( - #token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) + # token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) token, [{ORTH: token}]) logging.info("BERTNLU loaded") @@ -97,7 +98,8 @@ class BERTNLU(NLU): intents = [] da = {} - word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq) + word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize( + ori_word_seq, ori_tag_seq) word_seq = word_seq[:510] tag_seq = tag_seq[:510] batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq, diff --git a/convlab/policy/USMDA/evaluate.py b/convlab/policy/USMDA/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..567150b39b0981e1d65c7881a993dd122ab6d4b4 --- /dev/null +++ b/convlab/policy/USMDA/evaluate.py @@ -0,0 +1,79 @@ +import json +import os +from argparse import ArgumentParser + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn import metrics +from tqdm import tqdm +from transformers import AutoModelForSequenceClassification, AutoTokenizer + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model", type=str, default="", + help="model name") + parser.add_argument("--data", type=str) + parser.add_argument("--gen-file", type=str) + return parser.parse_args() + + +def generate_result(model_checkpoint, data): + result = [] + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + model = AutoModelForSequenceClassification.from_pretrained( + model_checkpoint) + data = pd.read_csv(data, index_col=False).astype(str) + # Neutral: 0, Negative: 1, Positive: 2 + t2i = {'3': 0, '1': 1, '2': 1, '4': 2, '5': 2} + prefix = "satisfaction score: " + for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True): + if prefix in input_text: + text = input_text.replace(prefix, '') + target = t2i[target_text] + model_input = tokenizer( + [text], return_tensors="pt", padding=True) + output = model(input_ids=model_input["input_ids"], + attention_mask=model_input["attention_mask"]) + output = int(np.argmax(output, axis=-1)) + result.append({"input_text": text, + "preds": output, + "label": target}) + json.dump(result, open(os.path.join( + model_checkpoint, "uss_result.json"), 'w')) + return result + + +def read_result(result): + preds = [] + label = [] + for r in result: + preds.append(r["preds"]) + label.append(r["label"]) + return preds, label + + +def main(): + args = arg_parser() + if args.gen_file: + preds, label = read_result(json.load(open(args.gen_file))) + else: + results = generate_result(args.model, args.data) + preds, label = read_result(results) + + macro_f1 = metrics.f1_score(label, preds, average="macro") + sep_f1 = metrics.f1_score( + label, preds, average=None, + labels=[0, 1, 2]) + cm = metrics.confusion_matrix( + label, preds, normalize="true", + labels=[0, 1, 2]) + print("Neutral: 0, Negative: 1, Positive: 2") + print("cm", cm) + print("f1", sep_f1) + print("macro", macro_f1) + + +if __name__ == "__main__": + main() diff --git a/convlab/policy/USMDA/example.py b/convlab/policy/USMDA/example.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ddb28c99dd0dff45f0591f502e7132abaf63a7 --- /dev/null +++ b/convlab/policy/USMDA/example.py @@ -0,0 +1,24 @@ +from datasets import Dataset +from transformers import AutoTokenizer + +model_checkpoint = "bert-base-cased" +tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) + +raw_data = { + "train": [{"label": 0, "text": "hi how are you"}, + {"label": 1, "text": "i'm fine thank you"}, ], + "test": [{"label": 0, "text": "hi how are you"}, + {"label": 1, "text": "i'm fine thank you"}, ]} +data = {} +for x in raw_data: + data[x] = Dataset.from_list(raw_data[x]) + + +def tokenize_function(examples): + print(examples) + return tokenizer(examples["text"], padding="max_length", truncation=True) + + +t = data["train"].map(tokenize_function, batched=True) + +print(t) diff --git a/convlab/policy/USMDA/predict.py b/convlab/policy/USMDA/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..63ffb91b7cfc11c5b5d3923ec59d9a0759069c2a --- /dev/null +++ b/convlab/policy/USMDA/predict.py @@ -0,0 +1,30 @@ +from argparse import ArgumentParser + +import numpy as np +from transformers import AutoModelForSequenceClassification, AutoTokenizer + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model", type=str, default="", + help="model name") + parser.add_argument("--data", type=str) + parser.add_argument("--gen-file", type=str) + return parser.parse_args() + + +def main(): + args = arg_parser() + model_checkpoint = args.model + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + model = AutoModelForSequenceClassification.from_pretrained( + model_checkpoint) + input_text = "Yeah, I think we are. This isn't even my dress." + inputs = tokenizer([input_text], return_tensors="pt", padding=True) + output = model(input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"]) + print(np.argmax(output, axis=-1)) + + +if __name__ == "__main__": + main() diff --git a/convlab/policy/USMDA/train.py b/convlab/policy/USMDA/train.py new file mode 100644 index 0000000000000000000000000000000000000000..8410a3f7bc4c067ae822a06801373f9bf5f69cd8 --- /dev/null +++ b/convlab/policy/USMDA/train.py @@ -0,0 +1,122 @@ +import os +import random +from argparse import ArgumentParser +import json + +import numpy as np +import torch +from datasets import load_metric, Dataset +from sklearn.model_selection import train_test_split +from transformers import (AutoModelForSequenceClassification, AutoTokenizer, + Trainer, TrainingArguments) + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--data", type=str, default="", + help="input data") + parser.add_argument("--batch", type=int, default=2, + help="batch size") + + return parser.parse_args() + + +def set_seed(r_seed): + random.seed(r_seed) + np.random.seed(r_seed) + torch.manual_seed(r_seed) + + +def read_data(data_dir): + print("data_dir", data_dir) + subfix = {"train": "trn", "validation": "dev", "test": "tst"} + files = {} + data = {} + for data_split, sub in subfix.items(): + data[data_split] = parse_data(json.load( + open(os.path.join(data_dir, f"emotion-detection-{sub}.json")))) + + return data + + +def parse_data(data): + emo2label = { + "Neutral": 0, + "Scared": 1, + "Mad": 1, + "Sad": 1, + "Joyful": 2, + "Peaceful": 2, + "Powerful": 2 + } + d = [] + for episode in data["episodes"]: + for scene in episode["scenes"]: + for r in range(len(scene["utterances"])-1): + text = ' '.join([scene["utterances"][r]["transcript"], + scene["utterances"][r+1]["transcript"]]) + label = emo2label.get( + scene["utterances"][r+1]["emotion"], "Neutral") + d.append({"label": label, "text": text}) + + return d + + +def main(): + args = arg_parser() + base_name = "convlab/policy/USMDA" + model_checkpoint = "bert-base-cased" + tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) + model = AutoModelForSequenceClassification.from_pretrained( + model_checkpoint, num_labels=3) + metric = load_metric("accuracy") + + fp16 = False + if torch.cuda.is_available(): + print("use cuda") + fp16 = True + model.to("cuda") + + def tokenize_function(examples): + return tokenizer(examples["text"], padding="max_length", truncation=True) + + emory_data = read_data(args.data) + folder_name = os.path.join(base_name, "data") + if not os.path.exists(folder_name): + os.makedirs(folder_name) + json.dump(emory_data, open(os.path.join(folder_name, "data.json"), 'w')) + + data = {} + for data_split, d in emory_data.items(): + d = Dataset.from_list(d) + data[data_split] = d.map(tokenize_function, batched=True) + + model_dir = os.path.join(base_name, "model") + + def compute_metrics(eval_pred): + logits, labels = eval_pred + predictions = np.argmax(logits, axis=-1) + return metric.compute(predictions=predictions, references=labels) + + training_args = TrainingArguments( + output_dir=model_dir, + learning_rate=2e-5, + per_device_train_batch_size=args.batch, + per_device_eval_batch_size=args.batch, + evaluation_strategy="epoch", + num_train_epochs=2, + fp16=fp16) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=data["train"], + eval_dataset=data["test"], + compute_metrics=compute_metrics,) + + trainer.train() + trainer.save_model() + + +if __name__ == "__main__": + main() diff --git a/convlab/policy/emoUS/analysis.py b/convlab/policy/emoUS/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..37d0c7648c49257944506654d2d8aa7843bada21 --- /dev/null +++ b/convlab/policy/emoUS/analysis.py @@ -0,0 +1,307 @@ +import json +import os +from argparse import ArgumentParser + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +result_dir = "convlab/policy/emoUS/result" + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--file", type=str, help="the conversation file") + return parser.parse_args() + + +def basic_analysis(conversation): + info = {"Complete": [], "Success": [], "Success strict": [], "turns": []} + for dialog in conversation: + for x in info: + info[x].append(dialog[x]) + for x in info: + info[x] = np.mean(info[x]) + return info + + +def advance(conversation): + info = {} + for dialog in conversation: + temp = turn_level(dialog["log"]) + for metric, data in temp.items(): + if metric not in info: + info[metric] = {} + for emotion, count in data.items(): + if emotion not in info[metric]: + info[metric][emotion] = 0 + info[metric][emotion] += count + + return info + + +def get_turn_emotion(conversation): + """ Get the emotion of each turn in the conversation + Args: + conversation (list): a list of dialog + Returns: + turn_emotion (list): a list of emotion of each turn + """ + turn_info = {"all": {}, + "Complete": {}, "Not Complete": {}, + "Success": {}, "Not Success": {}, + "Success strict": {}, "Not Success strict": {}} + max_turn = 0 + for dialog in conversation: + for i in range(0, len(dialog["log"]), 2): + turn = int(i / 2) + if turn > max_turn: + max_turn = turn + emotion = emotion_score(dialog["log"][i]["emotion"]) + insert_turn(turn_info["all"], turn, emotion) + for metric in ["Complete", "Success", "Success strict"]: + if dialog[metric]: + insert_turn(turn_info[metric], turn, emotion) + else: + insert_turn(turn_info[f"Not {metric}"], turn, emotion) + print("MAX_TURN", max_turn) + data = {'x': [t for t in range(max_turn)], + 'all_positive': [], + 'all_negative': [], + 'all_mean': [], + 'all_std': []} + + for metric in ["Complete", "Success", "Success strict"]: + data[f"{metric}_positive"] = [] + data[f"{metric}_negative"] = [] + data[f"{metric}_mean"] = [] + data[f"{metric}_std"] = [] + data[f"Not {metric}_positive"] = [] + data[f"Not {metric}_negative"] = [] + data[f"Not {metric}_mean"] = [] + data[f"Not {metric}_std"] = [] + + for t in range(turn): + pos, neg, mean, std = turn_score(turn_info["all"][t]) + data[f"all_positive"].append(pos) + data[f"all_negative"].append(neg) + data[f"all_mean"].append(mean) + data[f"all_std"].append(std) + for raw_metric in ["Complete", "Success", "Success strict"]: + for metric in [raw_metric, f"Not {raw_metric}"]: + if t not in turn_info[metric]: + data[f"{metric}_positive"].append(0) + data[f"{metric}_negative"].append(0) + data[f"{metric}_mean"].append(0) + data[f"{metric}_std"].append(0) + else: + pos, neg, mean, std = turn_score(turn_info[metric][t]) + data[f"{metric}_positive"].append(pos) + data[f"{metric}_negative"].append(neg) + data[f"{metric}_mean"].append(mean) + data[f"{metric}_std"].append(std) + for x in data: + data[x] = np.array(data[x]) + + fig, ax = plt.subplots(figsize=(6.0, 2.5)) + p = {"Complete": {"color": "C0", "label": "Success"}, + "Not Complete": {"color": "C1", "label": "Fail"}, + "all": {"color": "C2", "label": "all"}} + for name, para in p.items(): + + ax.plot(data['x'], + data[f"{name}_mean"], + 'o--', + color=para["color"], + label=para["label"]) + ax.fill_between(data['x'], + data[f"{name}_mean"]+data[f"{name}_std"], + data[f"{name}_mean"]-data[f"{name}_std"], + color=para["color"], alpha=0.2) + + ax.legend() + ax.set_xlabel("turn") + ax.set_ylabel("Sentiment") + ax.set_xticks([t for t in range(0, max_turn, 2)]) + plt.grid(axis='x', color='0.95') + plt.grid(axis='y', color='0.95') + # plt.show() + plt.tight_layout() + plt.savefig(os.path.join(result_dir, "turn2emotion.png")) + + +def turn_score(score_list): + count = len(score_list) + positive = 0 + negative = 0 + for s in score_list: + if s > 0: + positive += 1 + if s < 0: + negative += -1 + return positive/count, negative/count, np.mean(score_list), np.std(score_list, ddof=1)/np.sqrt(len(score_list)) + + +def insert_turn(turn_info, turn, emotion): + if turn not in turn_info: + turn_info[turn] = [] + turn_info[turn].append(emotion) + + +def emotion_score(emotion): + if emotion == "Neutral": + return 0 + if emotion in ["Satisfied", "Excited"]: + return 1 + return -1 + + +def plot(conversation): + pass + + +def turn_level(dialog): + # metric: {emotion: count} + dialog_info = {} + for index in range(2, len(dialog), 2): + pre_usr = dialog[index-2] + sys = dialog[index-1] + cur_usr = dialog[index] + info = neglect_reply(pre_usr, sys, cur_usr) + append_info(dialog_info, info) + info = confirm(pre_usr, sys, cur_usr) + append_info(dialog_info, info) + info = miss_info(pre_usr, sys, cur_usr) + append_info(dialog_info, info) + if index > 2: + info = loop(dialog[index-3], sys, cur_usr) + append_info(dialog_info, info) + + return dialog_info + +# provide wrong info +# action length +# incomplete info? + + +def append_info(dialog_info, info): + if not info: + return + for emotion, metric in info.items(): + if metric not in dialog_info: + dialog_info[metric] = {} + if emotion not in dialog_info[metric]: + dialog_info[metric][emotion] = 0 + dialog_info[metric][emotion] += 1 + + +def get_inform(act): + inform = {} + for intent, domain, slot, value in act: + if intent not in ["inform", "recommend"]: + continue + if domain not in inform: + inform[domain] = [] + inform[domain].append(slot) + return inform + + +def get_request(act): + request = {} + for intent, domain, slot, _ in act: + if intent == "request": + if domain not in request: + request[domain] = [] + request[domain].append(slot) + return request + + +def neglect_reply(pre_usr, sys, cur_usr): + request = get_request(pre_usr["act"]) + if not request: + return {} + + system_inform = get_inform(sys["utt"]) + + for domain, slots in request.items(): + if domain not in system_inform: + return {cur_usr["emotion"]: "neglect"} + for slot in slots: + if slot not in system_inform[domain]: + return {cur_usr["emotion"]: "neglect"} + return {cur_usr["emotion"]: "reply"} + + +def miss_info(pre_usr, sys, cur_usr): + system_request = get_request(sys["utt"]) + if not system_request: + return {} + user_inform = get_inform(pre_usr["act"]) + for domain, slots in system_request.items(): + if domain not in user_inform: + continue + for slot in slots: + if slot in user_inform[domain]: + return {cur_usr["emotion"]: "miss_info"} + return {} + + +def confirm(pre_usr, sys, cur_usr): + user_inform = get_inform(pre_usr["act"]) + + if not user_inform: + return {} + + system_inform = get_inform(sys["utt"]) + + for domain, slots in user_inform.items(): + if domain not in system_inform: + continue + for slot in slots: + if slot in system_inform[domain]: + return {cur_usr["emotion"]: "confirm"} + + return {cur_usr["emotion"]: "no confirm"} + + +def loop(s0, s1, u1): + if s0 == s1: + return {u1["emotion"]: "loop"} + + +def dict2csv(data): + r = {} + emotion = json.load(open("convlab/policy/emoUS/emotion.json")) + for act, value in data.items(): + temp = [0]*(len(emotion)+1) + for emo, count in value.items(): + temp[emotion[emo]] = count + temp[-1] = sum(temp) + for i in range(len(emotion)): + temp[i] /= temp[-1] + r[act] = temp + dataframe = pd.DataFrame.from_dict( + r, orient='index', columns=[emo for emo in emotion]+["count"]) + dataframe.to_csv(open(os.path.join(result_dir, "act2emotion.csv"), 'w')) + + +def main(): + args = arg_parser() + result = {} + if not os.path.exists(result_dir): + os.makedirs(result_dir) + conversation = json.load(open(args.file))["conversation"] + # basic_info = basic_analysis(conversation) + # result["basic_info"] = basic_info + # print(basic_info) + # advance_info = advance(conversation) + # print(advance_info) + # result["advance_info"] = advance_info + # json.dump(result, open( + # os.path.join("conversation_result.json"), 'w'), indent=2) + # dict2csv(advance_info) + get_turn_emotion(conversation) + + +if __name__ == "__main__": + main() diff --git a/convlab/policy/emoUS/dialogue_collector.py b/convlab/policy/emoUS/dialogue_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..1976c0cd54bf5d3438622bb28f6d6a627ca03b7b --- /dev/null +++ b/convlab/policy/emoUS/dialogue_collector.py @@ -0,0 +1,103 @@ +from argparse import ArgumentParser + +from tqdm import tqdm + +from convlab.policy.rule.multiwoz import RulePolicy +from convlab.task.multiwoz.goal_generator import GoalGenerator +from convlab.util.custom_util import (create_goals, data_goals, env_config, + get_config, set_seed) + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--config", type=str, help="the model path") + parser.add_argument("-N", "--num", type=int, + default=500, help="# of evaluation dialogue") + parser.add_argument("--model", type=str, + default="ppo", help="# of evaluation dialogue") + return parser.parse_args() + + +def interact(model_name, config, seed=0, num_goals=500): + conversation = [] + set_seed(seed) + conf = get_config(config, []) + + if model_name == "rule": + policy_sys = RulePolicy() + elif model_name == "ppo": + from convlab.policy.ppo import PPO + policy_sys = PPO(vectorizer=conf['vectorizer_sys_activated']) + + model_path = conf['model']['load_path'] + if model_path: + policy_sys.load(model_path) + + env, sess = env_config(conf, policy_sys) + goal_generator = GoalGenerator() + + goals = create_goals(goal_generator, num_goals=num_goals, + single_domains=False, allowed_domains=None) + + for seed in tqdm(range(1000, 1000 + num_goals)): + dialogue = {"seed": seed, "log": []} + set_seed(seed) + sess.init_session(goal=goals[seed-1000]) + sys_response = [] + actions = 0.0 + total_return = 0.0 + turns = 0 + task_succ = 0 + task_succ_strict = 0 + complete = 0 + dialogue["goal"] = env.usr.policy.policy.goal.domain_goals + dialogue["user info"] = env.usr.policy.policy.user_info + + for i in range(40): + sys_response, user_response, session_over, reward = sess.next_turn( + sys_response) + dialogue["log"].append( + {"role": "usr", + "utt": user_response, + "emotion": env.usr.policy.policy.emotion, + "act": env.usr.policy.policy.semantic_action}) + dialogue["log"].append({"role": "sys", "utt": sys_response}) + + # logging.info(f"Actions in turn: {len(sys_response)}") + turns += 1 + total_return += sess.evaluator.get_reward(session_over) + + if session_over: + task_succ = sess.evaluator.task_success() + task_succ = sess.evaluator.success + task_succ_strict = sess.evaluator.success_strict + complete = sess.evaluator.complete + break + + dialogue['Complete'] = complete + dialogue['Success'] = task_succ + dialogue['Success strict'] = task_succ_strict + dialogue['total_return'] = total_return + dialogue['turns'] = turns + + conversation.append(dialogue) + return conversation + + +if __name__ == "__main__": + import json + from datetime import datetime + import os + time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}" + args = arg_parser() + conversation = interact(model_name=args.model, + config=args.config, + num_goals=args.num) + data = {"config": json.load(open(args.config)), + "conversation": conversation} + folder_name = os.path.join("convlab/policy/emoUS", "conversation") + if not os.path.exists(folder_name): + os.makedirs(folder_name) + json.dump(data, + open(os.path.join(folder_name, f"{time}.json"), 'w'), + indent=2) diff --git a/convlab/policy/emoUS/emoUS-BertNLU-RuleDST-RulePolicy.json b/convlab/policy/emoUS/emoUS-BertNLU-RuleDST-RulePolicy.json new file mode 100644 index 0000000000000000000000000000000000000000..84d4dddb8e3b318040d2e2b29f47459dd46cddf5 --- /dev/null +++ b/convlab/policy/emoUS/emoUS-BertNLU-RuleDST-RulePolicy.json @@ -0,0 +1,56 @@ +{ + "model": { + "load_path": "convlab/policy/ppo/finished_experiments/history/NLGEmoUS/experiment_2023-01-19-17-56-38/save/best_ppo", + "pretrained_load_path": "", + "use_pretrained_initialisation": false, + "batchsz": 200, + "seed": 0, + "epoch": 100, + "eval_frequency": 5, + "process_num": 1, + "num_eval_dialogues": 20, + "sys_semantic_to_usr": false + }, + "vectorizer_sys": { + "uncertainty_vector_mul": { + "class_path": "convlab.policy.vector.vector_binary.VectorBinary", + "ini_params": { + "use_masking": true, + "manually_add_entity_names": true, + "seed": 0 + } + } + }, + "nlu_sys": { + "BertNLU": { + "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU", + "ini_params": { + "mode": "all", + "config_file": "multiwoz21_all.json", + "model_file": "https://huggingface.co/ConvLab/bert-base-nlu/resolve/main/bertnlu_unified_multiwoz21_all_context0.zip" + } + } + }, + "dst_sys": { + "RuleDST": { + "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST", + "ini_params": {} + } + }, + "sys_nlg": {}, + "nlu_usr": {}, + "dst_usr": {}, + "policy_usr": { + "emoUS": { + "class_path": "convlab.policy.emoUS.emoUS.UserPolicy", + "ini_params": { + "model_checkpoint": "convlab/policy/emoUS/unify/experiments/EmoUS_emowoz+dialmage_0_1/23-01-23-15-03/", + "use_sentiment": false, + "add_persona": true, + "sample": false, + "weight": 1 + } + } + }, + "usr_nlg": {} +} \ No newline at end of file diff --git a/convlab/policy/emoUS/emoUS.py b/convlab/policy/emoUS/emoUS.py new file mode 100644 index 0000000000000000000000000000000000000000..2f0c24b78f95d273b0d0855134a429a736fdd0c0 --- /dev/null +++ b/convlab/policy/emoUS/emoUS.py @@ -0,0 +1,492 @@ +import os +import json + +import torch + +from convlab.policy.emoUS.token_map import tokenMap +from convlab.policy.emoUS.unify.knowledge_graph import KnowledgeGraph +from convlab.policy.genTUS.stepGenTUS import \ + UserActionPolicy as GenTUSUserActionPolicy +from convlab.policy.policy import Policy +from convlab.util.custom_util import model_downloader +from convlab.policy.emoUS.unify.Goal import Goal + +DEBUG = False + + +class UserActionPolicy(GenTUSUserActionPolicy): + def __init__(self, model_checkpoint, mode="language", max_turn=40, **kwargs): + self.use_sentiment = kwargs.get("use_sentiment", False) + self.add_persona = kwargs.get("add_persona", True) + self.emotion_mid = kwargs.get("emotion_mid", False) + + if not os.path.exists(os.path.dirname(model_checkpoint)): + os.makedirs(os.path.dirname(model_checkpoint)) + model_downloader(os.path.dirname(model_checkpoint), + "https://zenodo.org/record/7801525/files/EmoUS_default.zip") + + if mode == "language": + only_action = False + elif mode == "semantic": + only_action = True + else: + raise ValueError("mode should be language or semantic") + + super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs) + weight = kwargs.get("weight", None) + self.kg = KnowledgeGraph( + tokenizer=self.tokenizer, + dataset="emowoz", + use_sentiment=self.use_sentiment, + weight=weight) + data_emotion = json.load(open("convlab/policy/emoUS/emotion.json")) + self.emotion_list = [""]*len(data_emotion) + for emotion, index in data_emotion.items(): + self.emotion_list[index] = emotion + + self.init_session() + + def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None): + allow_general_intent = False + self.model.eval() + + if not self.add_sys_from_reward: + self.goal.update_user_goal(action=sys_act, char="sys") + self.sys_acts.append(sys_act) # for terminate conversation + + # update constraint + self.time_step += 2 + + history = [] + if self.usr_acts: + if self.max_history == 1: + history = self.usr_acts[-1] + else: + history = self.usr_acts[-1*self.max_history:] + + input_dict = {"system": sys_act, + "goal": self.goal.get_goal_list(), + "history": history, + "turn": str(int(self.time_step/2))} + + if self.add_persona: + for user, info in self.user_info.items(): + input_dict[user] = info + + inputs = json.dumps(input_dict) + + with torch.no_grad(): + if emotion == "all": + raw_output = self.generate_from_emotion( + raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent) + for emo in raw_output: + output = self._parse_output(raw_output[emo]) + print("emo:", emo) + print("act:", output["action"]) + print("utt:", output["text"]) + raw_output = raw_output["Neutral"] + elif emotion is not None: + raw_output = self.generate_from_emotion( + raw_inputs=inputs, emotion=emotion, mode=mode, allow_general_intent=allow_general_intent) + for emo in raw_output: + output = self._parse_output(raw_output[emo]) + print("emo:", emo) + print("act:", output["action"]) + print("utt:", output["text"]) + raw_output = raw_output[emotion] + else: + raw_output = self._generate_action( + raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent) + output = self._parse_output(raw_output) + self.semantic_action = self._remove_illegal_action(output["action"]) + + if not self.only_action: + self.utterance = output["text"] + + self.emotion = output["emotion"] + if self.use_sentiment: + self.sentiment = output["sentiment"] + + if self.is_finish(): + self.emotion, self.semantic_action, self.utterance = self._good_bye() + if self.use_sentiment: + self.sentiment = "Neutral" + + self.goal.update_user_goal(action=self.semantic_action, char="usr") + self.vector.update_mentioned_domain(self.semantic_action) + self.usr_acts.append(self.semantic_action) + + del inputs + + if self.only_action: + return self.semantic_action + + return self.utterance + + def _parse_output(self, in_str): + in_str = str(in_str) + in_str = in_str.replace('<s>', '').replace( + '<\\s>', '').replace('o"clock', "o'clock") + action = {"emotion": "Neutral", "action": [], "text": ""} + if self.use_sentiment: + action["sentiment"] = "Neutral" + + try: + action = json.loads(in_str) + except: + print("invalid action:", in_str) + print("-"*20) + return action + + def _update_sentiment(self, pos, model_input, mode): + pos = self._update_seq( + self.token_map.get_id('start_sentiment'), pos) + sentiment = self._get_sentiment( + model_input, self.seq[:1, :pos], mode) + pos = self._update_seq(sentiment["token_id"], pos) + return sentiment, pos + + def _update_emotion(self, pos, model_input, mode, emotion_mode, sentiment=None): + pos = self._update_seq( + self.token_map.get_id('start_emotion'), pos) + emotion = self._get_emotion( + model_input, self.seq[:1, :pos], mode, emotion_mode, sentiment) + pos = self._update_seq(emotion["token_id"], pos) + return pos + + def _update_semantic_act(self, pos, model_input, mode, allow_general_intent): + mode = "max" + for act_len in range(self.max_action_len): + pos = self._get_semantic_action( + model_input, pos, mode, allow_general_intent) + + terminate, token_name = self._stop_semantic( + model_input, pos, act_len) + pos = self._update_seq(self.token_map.get_id(token_name), pos) + + if terminate: + break + return pos + + def _sent_act_emo(self, pos, model_input, mode, emotion_mode, allow_general_intent): + # sent + sentiment, pos = self._update_sentiment(pos, model_input, mode) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + # act + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + pos = self._update_semantic_act( + pos, model_input, mode, allow_general_intent) + # emo + pos = self._update_emotion( + pos, model_input, mode, emotion_mode, sentiment["token_name"]) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + + return pos + + def _sent_emo_act(self, pos, model_input, mode, emotion_mode, allow_general_intent): + # sent + sentiment, pos = self._update_sentiment(pos, model_input, mode) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + # emo + pos = self._update_emotion( + pos, model_input, mode, emotion_mode, sentiment["token_name"]) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + # act + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + pos = self._update_semantic_act( + pos, model_input, mode, allow_general_intent) + + return pos + + def _emo_act(self, pos, model_input, mode, emotion_mode, allow_general_intent): + # emo + pos = self._update_emotion( + pos, model_input, mode, emotion_mode) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + # act + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + pos = self._update_semantic_act( + pos, model_input, mode, allow_general_intent) + + return pos + + def _act_emo(self, pos, model_input, mode, emotion_mode, allow_general_intent): + # act + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + pos = self._update_semantic_act( + pos, model_input, mode, allow_general_intent) + # emo + pos = self._update_emotion( + pos, model_input, mode, emotion_mode) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + + return pos + + def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="normal"): + self.kg.parse_input(raw_inputs) + model_input = self.vector.encode(raw_inputs, self.max_in_len) + # start token + self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() + pos = self._update_seq([0], 0) + pos = self._update_seq(self.token_map.get_id('start_json'), pos) + + if self.use_sentiment and self.emotion_mid: + pos = self._sent_act_emo( + pos, model_input, mode, emotion_mode, allow_general_intent) + elif self.use_sentiment and not self.emotion_mid: + pos = self._sent_emo_act( + pos, model_input, mode, emotion_mode, allow_general_intent) + elif not self.use_sentiment and self.emotion_mid: + pos = self._act_emo( + pos, model_input, mode, emotion_mode, allow_general_intent) + else: # defalut method + pos = self._emo_act( + pos, model_input, mode, emotion_mode, allow_general_intent) + + if self.only_action: + # return semantic action. Don't need to generate text + return self.vector.decode(self.seq[0, :pos]) + + pos = self._update_seq(self.token_map.get_id("start_text"), pos) + text = self._get_text(model_input, pos) + + return text + + def generate_from_emotion(self, raw_inputs, emotion=None, mode="max", allow_general_intent=True): + self.kg.parse_input(raw_inputs) + model_input = self.vector.encode(raw_inputs, self.max_in_len) + responses = {} + if emotion: + emotion_list = [emotion] + else: + emotion_list = self.emotion_list + + for emotion in emotion_list: + # start token + self.seq = torch.zeros(1, self.max_out_len, + device=self.device).long() + pos = self._update_seq([0], 0) + pos = self._update_seq(self.token_map.get_id('start_json'), pos) + pos = self._update_seq( + self.token_map.get_id('start_emotion'), pos) + + pos = self._update_seq(self.kg._get_token_id(emotion), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + + # get semantic actions + for act_len in range(self.max_action_len): + pos = self._get_semantic_action( + model_input, pos, mode, allow_general_intent) + + terminate, token_name = self._stop_semantic( + model_input, pos, act_len) + pos = self._update_seq(self.token_map.get_id(token_name), pos) + + if terminate: + break + + if self.only_action: + return self.vector.decode(self.seq[0, :pos]) + + pos = self._update_seq(self.token_map.get_id("start_text"), pos) + text = self._get_text(model_input, pos) + responses[emotion] = text + + return responses + + def generate_text_from_give_semantic(self, raw_inputs, semantic_action, emotion="Neutral"): + self.kg.parse_input(raw_inputs) + model_input = self.vector.encode(raw_inputs, self.max_in_len) + self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() + pos = self._update_seq([0], 0) + pos = self._update_seq(self.token_map.get_id('start_json'), pos) + pos = self._update_seq( + self.token_map.get_id('start_emotion'), pos) + pos = self._update_seq(self.kg._get_token_id(emotion), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + + if len(semantic_action) == 0: + pos = self._update_seq(self.token_map.get_id("end_act"), pos) + + for act_id, (intent, domain, slot, value) in enumerate(semantic_action): + pos = self._update_seq(self.kg._get_token_id(intent), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.kg._get_token_id(domain), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.kg._get_token_id(slot), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.kg._get_token_id(value), pos) + + if act_id == len(semantic_action) - 1: + token_name = "end_act" + else: + token_name = "sep_act" + pos = self._update_seq(self.token_map.get_id(token_name), pos) + pos = self._update_seq(self.token_map.get_id("start_text"), pos) + + raw_output = self._get_text(model_input, pos) + return self._parse_output(raw_output)["text"] + + def _get_sentiment(self, model_input, generated_so_far, mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + return self.kg.get_sentiment(next_token_logits, mode) + + def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal", sentiment=None): + mode = "max" # emotion is always max + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + return self.kg.get_emotion(next_token_logits, mode, emotion_mode, sentiment) + + def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + + return self.kg.get_intent(next_token_logits, mode, allow_general_intent) + + def init_session(self, goal=None): + self.token_map = tokenMap( + tokenizer=self.tokenizer, use_sentiment=self.use_sentiment) + self.token_map.default(only_action=self.only_action) + self.time_step = 0 + remove_domain = "police" # remove police domain in inference + + if not goal: + self._new_goal(remove_domain=remove_domain) + else: + self._read_goal(goal) + + self.vector.init_session(goal=self.goal) + + self.terminated = False + self.add_sys_from_reward = False + self.sys_acts = [] + self.usr_acts = [] + self.semantic_action = [] + self.utterance = "" + self.emotion = "Neutral" + # TODO sentiment? event? user? + self.user_info = self.goal.emotion_info() + + def _read_goal(self, data_goal): + self.goal = Goal(goal=data_goal) + + def _new_goal(self, remove_domain="police", domain_len=None): + self.goal = Goal(goal_generator=self.goal_gen) + + def _good_bye(self): + # add emotion + if self.is_success(): + return "Satisfied", [['thank', 'general', 'none', 'none']], "thank you. bye" + else: + return "Dissatisfied", [["bye", "general", "None", "None"]], "bye" + + def get_reward(self): + if self.is_finish(): + if self.is_success(): + reward = self.reward["success"] + self.success = True + else: + reward = self.reward["fail"] + self.success = False + + else: + reward = -1 + if self.use_sentiment: + if self.sentiment == "Positive": + reward += 1 + elif self.sentiment == "Negative": + reward -= 1 + + self.success = None + return reward + + +class UserPolicy(Policy): + def __init__(self, + model_checkpoint="convlab/policy/emoUS/unify/default/EmoUS_default", + mode="language", + sample=False, + action_penalty=False, + **kwargs): + # self.config = config + print("emoUS model checkpoint: ", model_checkpoint) + if sample: + print("EmoUS will sample action, but emotion is always max") + if not os.path.exists(os.path.dirname(model_checkpoint)): + os.makedirs(os.path.dirname(model_checkpoint)) + model_downloader(os.path.dirname(model_checkpoint), + "https://zenodo.org/record/7801525/files/EmoUS_default.zip") + + self.policy = UserActionPolicy( + model_checkpoint, + mode=mode, + action_penalty=action_penalty, + **kwargs) + self.policy.load(os.path.join( + model_checkpoint, "pytorch_model.bin")) + self.sample = sample + + def predict(self, sys_act, mode="max"): + if self.sample: + mode = "sample" + else: + mode = "max" + response = self.policy.predict(sys_act, mode) + self.semantic_action = self.policy.semantic_action + return response + + def init_session(self, goal=None): + self.policy.init_session(goal) + self.semantic_action = [] + + def is_terminated(self): + return self.policy.is_terminated() + + def get_reward(self): + return self.policy.get_reward() + + def get_goal(self): + if hasattr(self.policy, 'get_goal'): + return self.policy.get_goal() + return None + + def get_emotion(self): + return self.policy.emotion + + +if __name__ == "__main__": + import os + from convlab.dialog_agent import PipelineAgent + from convlab.util.custom_util import set_seed + import time + + use_sentiment, emotion_mid = False, False + set_seed(100) + # Test semantic level behaviour + usr_policy = UserPolicy( + # model_checkpoint, # default location = convlab/policy/emoUS/unify/default/EmoUS_default + mode="semantic", + sample=True, + use_sentiment=use_sentiment, + emotion_mid=emotion_mid) + # usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin")) + usr_nlu = None # BERTNLU() + usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user') + usr.init_session() + usr.init_session() + print(usr.policy.get_goal()) + start = time.time() + + # print(usr.policy.policy.goal.status) + print(usr.response([['inform', 'train', 'day', 'saturday']]), + usr.policy.get_emotion()) + # print(usr.policy.policy.goal.status) + print(usr.response([]), + usr.policy.get_emotion()) + end = time.time() + print("-"*50) + print("time: ", end - start) + # print(usr.policy.policy.goal.status) diff --git a/convlab/policy/emoUS/emotion.json b/convlab/policy/emoUS/emotion.json new file mode 100644 index 0000000000000000000000000000000000000000..464a953e3e74df5ddb06e98bf2b653bae7e76e59 --- /dev/null +++ b/convlab/policy/emoUS/emotion.json @@ -0,0 +1,9 @@ +{ + "Neutral": 0, + "Fearful": 1, + "Dissatisfied": 2, + "Apologetic": 3, + "Abusive": 4, + "Excited": 5, + "Satisfied": 6 +} \ No newline at end of file diff --git a/convlab/policy/emoUS/emotion_eval.py b/convlab/policy/emoUS/emotion_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..511db5d9a1a1f43f2a2a8fc758f82712ca5dd219 --- /dev/null +++ b/convlab/policy/emoUS/emotion_eval.py @@ -0,0 +1,346 @@ +import json +import os +import sys +from argparse import ArgumentParser +from datetime import datetime + +import matplotlib.pyplot as plt +import torch +from datasets import load_metric +from sklearn import metrics +from tqdm import tqdm + +from convlab.nlg.evaluate import fine_SER +from convlab.policy.emoUS.emoUS import UserActionPolicy + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model-checkpoint", type=str, help="the model path") + parser.add_argument("--input-file", type=str, help="the testing input file", + default="") + parser.add_argument("--generated-file", type=str, help="the generated results", + default="") + parser.add_argument("--dataset", default="multiwoz") + + # model parameter + parser.add_argument("--use-sentiment", action="store_true") + parser.add_argument("--emotion-mid", action="store_true") + parser.add_argument("--weight", type=float, default=None) + parser.add_argument("--sample", action="store_true") + return parser.parse_args() + + +class Evaluator: + def __init__(self, model_checkpoint, dataset, **kwargs): + self.dataset = dataset + self.model_checkpoint = model_checkpoint + + self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}" + self.use_sentiment = kwargs.get("use_sentiment", False) + self.add_persona = kwargs.get("add_persona", True) + self.emotion_mid = kwargs.get("emotion_mid", False) + weight = kwargs.get("weight", None) + self.sample = kwargs.get("sample", False) + + self.usr = UserActionPolicy( + model_checkpoint, + dataset=self.dataset, + use_sentiment=self.use_sentiment, + add_persona=self.add_persona, + emotion_mid=self.emotion_mid, + weight=weight) + + self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) + + """ + self.r = {"input", "golden_acts", "golden_utts", "golden_emotions", + emotion_acts, emotion_utts} + """ + + self.r = {"input": [], + "golden_acts": [], + "golden_utts": [], + "golden_emotion": []} + + if self.use_sentiment: + self.r["golden_sentiment"] = [] + self.r["gen_sentiment"] = [] + + self.emotion_list = [] + + for emotion in json.load(open("convlab/policy/emoUS/emotion.json")): + self.emotion_list.append(emotion) + self.r[f"{emotion}_acts"] = [] + self.r[f"{emotion}_utts"] = [] + + sent2emo = json.load( + open("convlab/policy/emoUS/sent2emo.json")) + self.emo2sent = {} + for sent, emotions in sent2emo.items(): + for emo in emotions: + self.emo2sent[emo] = sent + + def _append_result(self, temp): + for x in self.r: + self.r[x].append(temp[x]) + + def generate_results(self, f_eval, golden=False): + emotion_mode = "normal" + in_file = json.load(open(f_eval)) + + for dialog in tqdm(in_file['dialog']): + temp = {} + inputs = dialog["in"] + labels = self.usr._parse_output(dialog["out"]) + + response = self.usr.generate_from_emotion( + raw_inputs=inputs) + + temp["input"] = inputs + temp["golden_acts"] = labels["action"] + temp["golden_utts"] = labels["text"] + temp["golden_emotion"] = labels["emotion"] + + for emotion, resp in response.items(): + output = self.usr._parse_output(resp) + temp[f"{emotion}_acts"] = output["action"] + temp[f"{emotion}_utts"] = output["text"] + + if self.use_sentiment: + temp["golden_sentiment"] = labels["sentiment"] + temp["gen_sentiment"] = output["sentiment"] + + self._append_result(temp) + + def read_generated_result(self, f_eval): + in_file = json.load(open(f_eval)) + + for dialog in tqdm(in_file['dialog']): + for x in dialog: + self.r[x].append(dialog[x]) + + def _transform_result(self): + index = [x for x in self.r] + result = [] + for i in range(len(self.r[index[0]])): + temp = {} + for x in index: + temp[x] = self.r[x][i] + result.append(temp) + return result + + def nlg_evaluation(self, input_file=None, generated_file=None, golden=False): + if input_file: + print("Force generation") + self.generate_results(input_file, golden) + + elif generated_file: + self.read_generated_result(generated_file) + else: + print("You must specify the input_file or the generated_file") + mode = "max" + if self.sample: + mode = "sample" + + nlg_eval = { + "golden": golden, + "mode": mode, + "metrics": {}, + "dialog": self._transform_result() + } + + # TODO emotion metric + + dir_name = self.model_checkpoint + json.dump(nlg_eval, + open(os.path.join( + dir_name, f"{self.time}-nlg_eval.json"), 'w'), + indent=2) + return os.path.join(dir_name, f"{self.time}-nlg_eval.json") + + def evaluation(self, input_file=None, generated_file=None): + # TODO add emotion + gen_file = json.load(open(generated_file)) + self.read_generated_result(generated_file) + + r = {"golden_acts": [], "golden_emotions": [], "golden_utts": []} + for emotion in self.emotion_list: + r[f"{emotion}_acts"] = [] + r[f"{emotion}_utts"] = [] + + for dialog in gen_file['dialog']: + r["golden_acts"].append(dialog["golden_acts"]) + r["golden_emotions"].append(dialog["golden_emotion"]) + r["golden_utts"].append(dialog["golden_utts"]) + for emotion in self.emotion_list: + r[f"{emotion}_acts"].append(dialog[f"{emotion}_acts"]) + r[f"{emotion}_utts"].append(dialog[f"{emotion}_utts"]) + + dialog_result = gen_file['dialog'] + + scores = {} + for emotion in self.emotion_list: + # if emotion == "Neutral": + # continue + scores[emotion] = {"precision": [], + "recall": [], "f1": [], "turn_acc": []} + for gen_act, golden_act in zip(r[f"{emotion}_acts"], r["Neutral_acts"]): + s = f1_measure(preds=gen_act, labels=golden_act) + for metric in scores[emotion]: + scores[emotion][metric].append(s[metric]) + + result = {} + for emotion in self.emotion_list: + # if emotion == "Neutral": + # continue + result[emotion] = {} + for metric in scores[emotion]: + result[emotion][metric] = sum( + scores[emotion][metric])/len(scores[emotion][metric]) + result[emotion]["bleu"] = bleu(golden_utts=r["Neutral_utts"], + gen_utts=r[f"{emotion}_utts"]) + result[emotion]["SER"] = SER(gen_utts=r[f"{emotion}_utts"], + gen_acts=r[f"{emotion}_acts"]) + + result[emotion]["len"] = avg_len(gen_utts=r[f"{emotion}_utts"]) + + rouge_score = rouge(golden_utts=r["Neutral_utts"], + gen_utts=r[f"{emotion}_utts"]) + for metric, score in rouge_score.items(): + result[emotion][metric] = score.mid.fmeasure + + print("emotion:", emotion) + for metric in result[emotion]: + print(f"{metric}: {result[emotion][metric]}") + + # for metric in emo_score: + # result[metric] = emo_score[metric] + # print(f"{metric}: {result[metric]}") + + result["dialog"] = dialog_result + + basename = "semantic_evaluation_result" + json.dump(result, open(os.path.join( + self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'), indent=2) + + +def avg_len(gen_utts): + n = [len(s.split()) for s in gen_utts] + return sum(n)/len(n) + + +def bleu(golden_utts, gen_utts): + bleu_metric = load_metric("sacrebleu") + labels = [[utt] for utt in golden_utts] + + bleu_score = bleu_metric.compute(predictions=gen_utts, + references=labels, + force=True) + return bleu_score["score"] + + +def rouge(golden_utts, gen_utts): + rouge_metric = load_metric("rouge") + rouge_score = rouge_metric.compute(predictions=gen_utts, + references=golden_utts) + return rouge_score + + +def SER(gen_utts, gen_acts): + missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER( + gen_acts, gen_utts) + if total <= 0: + print("ERROR, total = 0") + return 1 + return missing/total + + +def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutral=False): + labels = ["Neutral", "Fearful", "Dissatisfied", + "Apologetic", "Abusive", "Excited", "Satisfied"] + if no_neutral: + labels = labels[1:] + print(labels) + macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro") + sep_f1 = metrics.f1_score( + golden_emotions, gen_emotions, average=None, labels=labels) + cm = metrics.confusion_matrix( + golden_emotions, gen_emotions, normalize="true", labels=labels) + disp = metrics.ConfusionMatrixDisplay( + confusion_matrix=cm, display_labels=labels) + disp.plot() + plt.savefig(os.path.join(dirname, f"{time}-emotion.png")) + r = {"macro_f1": float(macro_f1), "sep_f1": list( + sep_f1), "cm": [list(c) for c in list(cm)]} + print(r) + return r + + +def sentiment_score(golden_sentiment, gen_sentiment, dirname=".", time=""): + labels = ["Neutral", "Negative", "Positive"] + + print(labels) + macro_f1 = metrics.f1_score( + golden_sentiment, gen_sentiment, average="macro") + sep_f1 = metrics.f1_score( + golden_sentiment, gen_sentiment, average=None, labels=labels) + cm = metrics.confusion_matrix( + golden_sentiment, gen_sentiment, normalize="true", labels=labels) + disp = metrics.ConfusionMatrixDisplay( + confusion_matrix=cm, display_labels=labels) + disp.plot() + plt.savefig(os.path.join(dirname, f"{time}-sentiment.png")) + r = {"macro_f1": float(macro_f1), "sep_f1": list( + sep_f1), "cm": [list(c) for c in list(cm)]} + print(r) + return r + + +def f1_measure(preds, labels): + tp = 0 + score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0} + for p in preds: + if p in labels: + tp += 1.0 + if preds: + score["precision"] = tp/len(preds) + if labels: + score["recall"] = tp/len(labels) + if (score["precision"] + score["recall"]) > 0: + score["f1"] = 2*(score["precision"]*score["recall"]) / \ + (score["precision"]+score["recall"]) + if tp == len(preds) and tp == len(labels): + score["turn_acc"] = 1 + return score + + +def main(): + args = arg_parser() + eval = Evaluator(args.model_checkpoint, + args.dataset, + use_sentiment=args.use_sentiment, + emotion_mid=args.emotion_mid, + weight=args.weight, + sample=args.sample) + print("=== evaluation ===") + print("model checkpoint", args.model_checkpoint) + print("generated_file", args.generated_file) + print("input_file", args.input_file) + with torch.no_grad(): + if args.generated_file: + generated_file = args.generated_file + else: + nlg_result = eval.nlg_evaluation(input_file=args.input_file, + generated_file=args.generated_file) + + generated_file = nlg_result + eval.evaluation(args.input_file, + generated_file) + + +if __name__ == '__main__': + main() diff --git a/convlab/policy/emoUS/evaluate.py b/convlab/policy/emoUS/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ac3968902371d99d511735b282fb76f2311436 --- /dev/null +++ b/convlab/policy/emoUS/evaluate.py @@ -0,0 +1,356 @@ +import json +import os +import sys +from argparse import ArgumentParser +from datetime import datetime + +import matplotlib.pyplot as plt +import torch +from datasets import load_metric +from sklearn import metrics +from tqdm import tqdm +from pprint import pprint + +from convlab.nlg.evaluate import fine_SER +from convlab.policy.emoUS.emoUS import UserActionPolicy + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model-checkpoint", type=str, help="the model path") + parser.add_argument("--model-weight", type=str, + help="the model weight", default="") + parser.add_argument("--input-file", type=str, help="the testing input file", + default="") + parser.add_argument("--generated-file", type=str, help="the generated results", + default="") + parser.add_argument("--dataset", default="multiwoz") + parser.add_argument("--golden-emotion", action="store_true", + help="golden emotion -> action + utt") + parser.add_argument("--golden-action", action="store_true", + help="golden emotion + action -> utt") + parser.add_argument("--use-sentiment", action="store_true") + parser.add_argument("--emotion-mid", action="store_true") + parser.add_argument("--weight", type=float, default=None) + parser.add_argument("--sample", action="store_true") + return parser.parse_args() + + +class Evaluator: + def __init__(self, model_checkpoint, dataset, model_weight=None, **kwargs): + self.dataset = dataset + self.model_checkpoint = model_checkpoint + self.result_dir = os.path.join(model_checkpoint, "results") + os.makedirs(self.result_dir, exist_ok=True) + self.model_weight = model_weight + self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M-%S')}" + self.use_sentiment = kwargs.get("use_sentiment", False) + self.add_persona = kwargs.get("add_persona", False) + self.emotion_mid = kwargs.get("emotion_mid", False) + self.emotion_weight = kwargs.get("weight", None) + self.sample = kwargs.get("sample", False) + print("self.emotion_weight", self.emotion_weight) + self.evaluation_result = { + "emotion prediction": {}, + "semantic action prediction": {}, + "natural language generation": {}} + + self.usr = UserActionPolicy( + model_checkpoint, + dataset=self.dataset, + use_sentiment=self.use_sentiment, + add_persona=self.add_persona, + emotion_mid=self.emotion_mid, + weight=self.emotion_weight) + + self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) + + self.r = {"input": [], + "golden_acts": [], + "golden_utts": [], + "golden_emotion": [], + "gen_acts": [], + "gen_utts": [], + "gen_emotion": []} + + if self.use_sentiment: + self.r["golden_sentiment"] = [] + self.r["gen_sentiment"] = [] + + sent2emo = json.load( + open("convlab/policy/emoUS/sent2emo.json")) + self.emo2sent = {} + for sent, emotions in sent2emo.items(): + for emo in emotions: + self.emo2sent[emo] = sent + + def _append_result(self, temp): + for x in self.r: + self.r[x].append(temp[x]) + + def generate_results(self, f_eval, golden_emotion=False, golden_action=False): + emotion_mode = "normal" + in_file = json.load(open(f_eval)) + mode = "max" + if self.sample: + mode = "sample" + + for dialog in tqdm(in_file['dialog']): + inputs = dialog["in"] + labels = self.usr._parse_output(dialog["out"]) + + if golden_action: + usr_act = labels["action"] + usr_emo = labels["emotion"] + usr_utt = self.usr.generate_text_from_give_semantic( + inputs, labels["action"], labels["emotion"]) + elif golden_emotion: + usr_emo = labels["emotion"] + output = self.usr.generate_from_emotion( + inputs, emotion=usr_emo, mode=mode) + output = self.usr._parse_output(output[usr_emo]) + usr_act = self.usr._remove_illegal_action(output["action"]) + usr_utt = output["text"] + else: + output = self.usr._parse_output( + self.usr._generate_action(inputs, mode=mode, emotion_mode=emotion_mode)) + usr_emo = output["emotion"] + usr_act = self.usr._remove_illegal_action(output["action"]) + usr_utt = output["text"] + + temp = {} + temp["input"] = inputs + temp["golden_acts"] = labels["action"] + temp["golden_utts"] = labels["text"] + temp["golden_emotion"] = labels["emotion"] + + temp["gen_acts"] = usr_act + temp["gen_utts"] = usr_utt + temp["gen_emotion"] = usr_emo + + if self.use_sentiment: + temp["golden_sentiment"] = labels["sentiment"] + temp["gen_sentiment"] = output["sentiment"] + + self._append_result(temp) + + # save generations + generations = {} + generations["time"] = self.time + generations["golden"] = False + if golden_action: + # basically, golden_action includes golden_emotion + generations["golden"] = "golden_action" + elif golden_emotion: + generations["golden"] = "golden_emotion" + generations["mode"] = mode + generations["dialog"] = self._transform_result() + + file_name = "generations.json" + if generations["golden"]: + file_name = generations['golden'] + "_" + file_name + + with open(os.path.join(self.result_dir, file_name), "w") as f: + json.dump(generations, f, indent=2) + + def read_generated_result(self, f_eval): + in_file = json.load(open(f_eval)) + + for dialog in tqdm(in_file['dialog']): + for x in dialog: + if x not in self.r: + self.r[x] = [] + self.r[x].append(dialog[x]) + + def _transform_result(self): + index = [x for x in self.r] + result = [] + for i in range(len(self.r[index[0]])): + temp = {} + for x in index: + temp[x] = self.r[x][i] + result.append(temp) + return result + + @staticmethod + def nlg_evaluation(golden_utts, gen_utts, gen_acts): + bleu_metric = load_metric("sacrebleu") + labels = [[utt] for utt in golden_utts] + bleu_score = bleu_metric.compute(predictions=gen_utts, + references=labels, + force=True) + missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER( + gen_acts, gen_utts) + + return {"bleu": bleu_score["score"], "SER": missing/total} + + @staticmethod + def _intent_domain(action): + acts = [] + for intent, domain, slot, value in action: + if [intent, domain] not in acts: + acts.append([intent, domain]) + return acts + + def semantic_evaluation(self, gen_acts, golden_acts): + scores = {"full action": {"precision": [], "recall": [], "f1": [], "turn_acc": []}, + "intent-domain": {"precision": [], "recall": [], "f1": [], "turn_acc": []}} + for gen_act, golden_act in zip(gen_acts, golden_acts): + s = f1_measure(preds=gen_act, labels=golden_act) + for metric in scores["full action"]: + scores["full action"][metric].append(s[metric]) + s = f1_measure(preds=self._intent_domain(gen_act), + labels=self._intent_domain(golden_act)) + for metric in scores["intent-domain"]: + scores["intent-domain"][metric].append(s[metric]) + + result = {} + for metric_type, score in scores.items(): + result[metric_type] = {} + for m, s in score.items(): + result[metric_type][m] = sum(s)/len(s) + return result + + def evaluation(self, input_file="", generated_file="", golden_emotion=False, golden_action=False): + if input_file: + print("Force generation") + self.generate_results(input_file, golden_emotion, golden_action) + elif generated_file: + self.read_generated_result(generated_file) + else: + print("You must specify the input_file or the generated_file") + + r = self.nlg_evaluation( + self.r["golden_utts"], self.r["gen_utts"], self.r["gen_acts"]) + for metric, score in r.items(): + self.evaluation_result["natural language generation"][metric] = score + + if not golden_action: + r = self.semantic_evaluation( + self.r["gen_acts"], self.r["golden_acts"]) + for metric, score in r.items(): + self.evaluation_result["semantic action prediction"][metric] = score + + if not golden_emotion and not golden_action: + r = emotion_score(self.r["golden_emotion"], + self.r["gen_emotion"], + self.result_dir) + self.evaluation_result["emotion prediction"]["emotion"] = {} + self.evaluation_result["emotion prediction"]["emotion"]["macro_f1"] = r["macro_f1"] + self.evaluation_result["emotion prediction"]["emotion"]["sep_f1"] = { + emo: f1 for emo, f1 in zip(r["label"], r["sep_f1"])} + + if self.use_sentiment: + golden_sentiment = self.r["golden_sentiment"] + gen_sentiment = self.r["gen_sentiment"] + else: + # transfer emotions to sentiment if the model do not generate sentiment + golden_sentiment = [self.emo2sent[emo] + for emo in self.r["golden_emotion"]] + gen_sentiment = [self.emo2sent[emo] + for emo in self.r["gen_emotion"]] + r = sentiment_score(golden_sentiment, + gen_sentiment, + self.result_dir) + + self.evaluation_result["emotion prediction"]["sentiment"] = {} + self.evaluation_result["emotion prediction"]["sentiment"]["macro_f1"] = r["macro_f1"] + self.evaluation_result["emotion prediction"]["sentiment"]["sep_f1"] = { + emo: f1 for emo, f1 in zip(r["label"], r["sep_f1"])} + + pprint(self.evaluation_result) + + # def save_results(self): + + # def print_result(self): + # print("=== Natural language generation ===") + # print("Sacre-BLEU", nlg_eval["metrics"]["bleu"]["score"]) + # print("SER", nlg_eval["metrics"]["SER"]) + # self.r[""] + + +def emotion_score(golden_emotions, gen_emotions, dirname=".", no_neutral=False): + labels = ["Neutral", "Fearful", "Dissatisfied", + "Apologetic", "Abusive", "Excited", "Satisfied"] + if no_neutral: + labels = labels[1:] + + macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro") + sep_f1 = metrics.f1_score( + golden_emotions, gen_emotions, average=None, labels=labels) + cm = metrics.confusion_matrix( + golden_emotions, gen_emotions, normalize="true", labels=labels) + disp = metrics.ConfusionMatrixDisplay( + confusion_matrix=cm, display_labels=labels) + disp.plot() + plt.savefig(os.path.join(dirname, f"emotion.png")) + r = {"label": labels, + "macro_f1": float(macro_f1), + "sep_f1": list(sep_f1), + "cm": [list(c) for c in list(cm)]} + return r + + +def sentiment_score(golden_sentiment, gen_sentiment, dirname="."): + labels = ["Neutral", "Negative", "Positive"] + + macro_f1 = metrics.f1_score( + golden_sentiment, gen_sentiment, average="macro") + sep_f1 = metrics.f1_score( + golden_sentiment, gen_sentiment, average=None, labels=labels) + cm = metrics.confusion_matrix( + golden_sentiment, gen_sentiment, normalize="true", labels=labels) + disp = metrics.ConfusionMatrixDisplay( + confusion_matrix=cm, display_labels=labels) + disp.plot() + plt.savefig(os.path.join(dirname, f"sentiment.png")) + r = {"label": labels, + "macro_f1": float(macro_f1), + "sep_f1": list(sep_f1), + "cm": [list(c) for c in list(cm)]} + return r + + +def f1_measure(preds, labels): + tp = 0 + score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0} + for p in preds: + if p in labels: + tp += 1.0 + if preds: + score["precision"] = tp/len(preds) + if labels: + score["recall"] = tp/len(labels) + if (score["precision"] + score["recall"]) > 0: + score["f1"] = 2*(score["precision"]*score["recall"]) / \ + (score["precision"]+score["recall"]) + if tp == len(preds) and tp == len(labels): + score["turn_acc"] = 1 + return score + + +def main(): + args = arg_parser() + eval = Evaluator(args.model_checkpoint, + args.dataset, + args.model_weight, + use_sentiment=args.use_sentiment, + emotion_mid=args.emotion_mid, + weight=args.weight, + sample=args.sample) + print("=== evaluation ===") + print("model checkpoint", args.model_checkpoint) + print("generated_file", args.generated_file) + print("input_file", args.input_file) + with torch.no_grad(): + eval.evaluation(input_file=args.input_file, + generated_file=args.generated_file, + golden_emotion=args.golden_emotion, + golden_action=args.golden_action) + + +if __name__ == '__main__': + main() diff --git a/convlab/policy/emoUS/self_bleu.py b/convlab/policy/emoUS/self_bleu.py new file mode 100644 index 0000000000000000000000000000000000000000..b5ba1cfad2f149bfc48c4333d0f9a83c95c3c30d --- /dev/null +++ b/convlab/policy/emoUS/self_bleu.py @@ -0,0 +1,67 @@ +# from fast_bleu import SelfBLEU +import argparse +import json +from datasets import Dataset, load_metric +from tqdm import tqdm + + +def arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--file", type=str) + parser.add_argument("--fast-bleu", action="store_true") + parser.add_argument("--uss", action="store_true") + return parser.parse_args() + + +def read_file(file_name): + nlg_candidates = json.load(open(file_name)) + return nlg_candidates + + +def get_sent(candidates, bleu_mode="torch", uss=False): + if bleu_mode == "torch": + if uss: + return [x["preds"] for x in candidates] + if "log" in candidates: + return [x["gen_utts"] for x in candidates["log"]] + else: + return [x["gen_utts"] for x in candidates["dialog"]] + else: + if uss: + return [x["preds"].split() for x in candidates] + if "log" in candidates: + return [x["gen_utts"].split() for x in candidates["log"]] + else: + return [x["gen_utts"].split() for x in candidates["dialog"]] + + +def SelfBLEU(sentences): + metric = load_metric("sacrebleu") + result = [] + for i, sent in tqdm(enumerate(sentences), ascii=True): + r = metric.compute(predictions=[sent], references=[ + sentences[i:]+sentences[i+1:]]) + result.append(r["score"]) + + return sum(result)/len(result) + + +def calculate(candidates, bleu_mode="torch", uss=False): + sentences = get_sent(candidates, bleu_mode, uss) + if bleu_mode == "torch": + x = SelfBLEU(sentences) + else: + bleu = fast_bleu.SelfBLEU(sentences) + x = bleu.get_score() + # x = bleu.get_score() + # print(x) + print(sum(x[4])/len(x[4])) + + +if __name__ == "__main__": + args = arg_parser() + if args.fast_bleu: + import fast_bleu + calculate(read_file(args.file), "fast-bleu", uss=args.uss) + else: + calculate(read_file(args.file), uss=args.uss) diff --git a/convlab/policy/emoUS/sent2emo.json b/convlab/policy/emoUS/sent2emo.json new file mode 100644 index 0000000000000000000000000000000000000000..aca824ff7d93e3aadaa5cd8513afa7c8c8cbef4c --- /dev/null +++ b/convlab/policy/emoUS/sent2emo.json @@ -0,0 +1,15 @@ +{ + "Neutral": [ + "Neutral" + ], + "Negative": [ + "Fearful", + "Dissatisfied", + "Apologetic", + "Abusive" + ], + "Positive": [ + "Excited", + "Satisfied" + ] +} \ No newline at end of file diff --git a/convlab/policy/emoUS/sentiment.json b/convlab/policy/emoUS/sentiment.json new file mode 100644 index 0000000000000000000000000000000000000000..7d39df53b29133f0ef817326c37df481f3c936ed --- /dev/null +++ b/convlab/policy/emoUS/sentiment.json @@ -0,0 +1,5 @@ +{ + "Neutral": 0, + "Negative": 1, + "Positive": 2 +} \ No newline at end of file diff --git a/convlab/policy/emoUS/token_map.py b/convlab/policy/emoUS/token_map.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8eef2fca42dbc36fc80e9e4c6355ccd94a0091 --- /dev/null +++ b/convlab/policy/emoUS/token_map.py @@ -0,0 +1,68 @@ +import json + + +class tokenMap: + def __init__(self, tokenizer, **kwargs): + self.tokenizer = tokenizer + self.token_name = {} + self.hash_map = {} + self.debug = False + self.default() + + def default(self, only_action=False): + self.format_tokens = { + 'start_json': '{"', + 'start_sentiment': 'sentiment": "', + 'start_emotion': 'emotion": "', + 'start_act': 'action": [["', + 'sep_token': '", "', + 'sep_act': '"], ["', + 'end_act': '"]], "', + 'start_text': 'text": "', + 'end_json': '}', + 'end_json_2': '"}', + 'book': 'book' + } + + if only_action: + self.format_tokens['end_act'] = '"]]}' + for token_name in self.format_tokens: + self.add_token( + token_name, self.format_tokens[token_name]) + + def add_token(self, token_name, value): + if token_name in self.token_name and self.debug: + print(f"---> duplicate token: {token_name}({value})!!!!!!!") + + token_id = self.tokenizer(str(value), add_special_tokens=False)[ + "input_ids"] + self.token_name[token_name] = {"value": value, "token_id": token_id} + # print(token_id) + hash_id = token_id[0] + if hash_id in self.hash_map and self.debug: + print( + f"---> conflict hash number {hash_id}: {self.hash_map[hash_id]['name']} and {token_name}") + self.hash_map[hash_id] = { + "name": token_name, "value": value, "token_id": token_id} + + def get_info(self, hash_id): + return self.hash_map[hash_id] + + def get_id(self, token_name): + # workaround + # if token_name not in self.token_name[token_name]: + # self.add_token(token_name, token_name) + return self.token_name[token_name]["token_id"] + + def get_token_value(self, token_name): + return self.token_name[token_name]["value"] + + def token_name_is_in(self, token_name): + if token_name in self.token_name: + return True + return False + + def hash_id_is_in(self, hash_id): + if hash_id in self.hash_map: + return True + return False diff --git a/convlab/policy/emoUS/train_model.py b/convlab/policy/emoUS/train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9883ee4503fe2ca043ddf8519259e53b02c8a864 --- /dev/null +++ b/convlab/policy/emoUS/train_model.py @@ -0,0 +1,399 @@ +import json +import os +import sys +from argparse import ArgumentParser +from datetime import datetime + +import numpy as np +import torch +import transformers +from datasets import Dataset, load_metric +from tqdm import tqdm +from transformers import (BartForConditionalGeneration, BartTokenizer, + DataCollatorForSeq2Seq, Seq2SeqTrainer, + Seq2SeqTrainingArguments) + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + +# os.environ["WANDB_DISABLED"] = "true" + +METRIC = load_metric("sacrebleu") +TOKENIZER = BartTokenizer.from_pretrained("facebook/bart-base") +TOKENIZER.add_tokens(["<?>"]) +MAX_IN_LEN = 500 +MAX_OUT_LEN = 500 + + +def arg_parser(): + parser = ArgumentParser() + # data_name, dial_ids_order, split2ratio + parser.add_argument("--model-type", type=str, default="unify", + help="unify or multiwoz") + parser.add_argument("--data-name", type=str, default="emowoz", + help="emowoz or dialmage") + parser.add_argument("--dial-ids-order", type=int, default=0) + parser.add_argument("--split2ratio", type=float, default=1) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--model-checkpoint", type=str, + default="facebook/bart-base") + parser.add_argument("--fine-tune", action="store_true") + return parser.parse_args() + + +def gentus_compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = TOKENIZER.batch_decode( + preds, skip_special_tokens=True, max_length=MAX_OUT_LEN) + + # Replace -100 in the labels as we can't decode them. + labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id) + decoded_labels = TOKENIZER.batch_decode( + labels, skip_special_tokens=True, max_length=MAX_OUT_LEN) + + act, text = postprocess_text(decoded_preds, decoded_labels) + + result = METRIC.compute( + # predictions=decoded_preds, references=decoded_labels) + predictions=text["preds"], references=text["labels"]) + result = {"bleu": result["score"]} + f1_scores = f1_measure(pred_acts=act["preds"], label_acts=act["labels"]) + for s in f1_scores: + result[s] = f1_scores[s] + + result = {k: round(v, 4) for k, v in result.items()} + return result + + +def basic_metric(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = TOKENIZER.batch_decode( + preds, skip_special_tokens=True, max_length=MAX_OUT_LEN) + + # Replace -100 in the labels as we can't decode them. + labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id) + decoded_labels = TOKENIZER.batch_decode( + labels, skip_special_tokens=True, max_length=MAX_OUT_LEN) + labels = [[x] for x in decoded_labels] + + result = METRIC.compute( + predictions=decoded_preds, references=labels) + result = {"bleu": result["score"]} + return result + + +def postprocess_text(preds, labels): + act = {"preds": [], "labels": []} + text = {"preds": [], "labels": []} + + for pred, label in zip(preds, labels): + model_output = parse_output(pred.strip()) + label_output = parse_output(label.strip()) + if len(label_output["text"]) < 1: + continue + act["preds"].append(model_output.get("action", [])) + text["preds"].append(model_output.get("text", pred.strip())) + act["labels"].append(label_output["action"]) + text["labels"].append([label_output["text"]]) + + return act, text + + +def parse_output(in_str): + in_str = in_str.replace('<s>', '').replace('<\\s>', '') + try: + output = json.loads(in_str) + except: + # print(f"invalid action {in_str}") + output = {"action": [], "text": ""} + return output + + +def f1_measure(pred_acts, label_acts): + result = {"precision": [], "recall": [], "f1": []} + for pred, label in zip(pred_acts, label_acts): + r = tp_fn_fp(pred, label) + for m in result: + result[m].append(r[m]) + for m in result: + result[m] = sum(result[m])/len(result[m]) + + return result + + +def tp_fn_fp(pred, label): + tp, fn, fp = 0.0, 0.0, 0.0 + precision, recall, f1 = 0, 0, 0 + for p in pred: + if p in label: + tp += 1 + else: + fp += 1 + for l in label: + if l not in pred: + fn += 1 + if (tp+fp) > 0: + precision = tp / (tp+fp) + if (tp+fn) > 0: + recall = tp/(tp+fn) + if (precision + recall) > 0: + f1 = (2*precision*recall)/(precision+recall) + + return {"precision": precision, "recall": recall, "f1": f1} + + +class TrainerHelper: + def __init__(self, tokenizer, max_input_length=500, max_target_length=500): + print("transformers version is: ", transformers.__version__) + self.tokenizer = tokenizer + self.max_input_length = max_input_length + self.max_target_length = max_target_length + self.base_name = "convlab/policy/emoUS" + self.dir_name = "" + + def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1): + if model_type not in ["unify", "multiwoz"]: + print("Unknown model type. Currently only support unify and multiwoz") + self.dir_name = f"{data_name}_{dial_ids_order}_{split2ratio}" + return os.path.join(self.base_name, model_type, 'data', self.dir_name) + + def get_model_folder(self, model_type): + folder_name = os.path.join( + self.base_name, model_type, "experiments", self.dir_name) + if not os.path.exists(folder_name): + os.makedirs(folder_name) + return folder_name + + def parse_data(self, model_type, data_name, dial_ids_order=0, split2ratio=1): + data_folder = self._get_data_folder( + model_type, data_name, dial_ids_order, split2ratio) + + raw_data = {} + for d_type in ["train", "validation", "test"]: + f_name = os.path.join(data_folder, f"{d_type}.json") + raw_data[d_type] = json.load(open(f_name)) + + tokenized_datasets = {} + for data_type, data in raw_data.items(): + tokenized_datasets[data_type] = Dataset.from_dict( + self._preprocess(data["dialog"])) + + return tokenized_datasets + + def remove_dialmage_action(self): + self.dir_name = "fine_tune" + folder = "convlab/policy/emoUS/unify/data" + data_name = { + "emowoz": "EmoUS_emowoz_0_1", + "dialmage": "EmoUS_dialmage_0_1_emotion_only"} + data = {} + for d, d_n in data_name.items(): + data[d] = {} + for d_type in ["train", "validation", "test"]: + f_name = os.path.join(folder, d_n, f"{d_type}.json") + data[d][d_type] = json.load(open(f_name)) + + tokenized_datasets = {} + for d_n, d in data.items(): + tokenized_datasets[d_n] = {} + for s_d_n, s_d in d.items(): + tokenized_datasets[d_n][s_d_n] = Dataset.from_dict( + self._preprocess(s_d["dialog"])) + return tokenized_datasets + + def _preprocess(self, examples): + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + if isinstance(examples, dict): + examples = [examples] + for example in tqdm(examples): + inputs = self.tokenizer(example["in"], + max_length=self.max_input_length, + truncation=True) + + # Setup the tokenizer for targets + with self.tokenizer.as_target_tokenizer(): + labels = self.tokenizer(example["out"], + max_length=self.max_target_length, + truncation=True) + for key in ["input_ids", "attention_mask"]: + model_inputs[key].append(inputs[key]) + model_inputs["labels"].append(labels["input_ids"]) + + return model_inputs + + +def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"): + tokenizer = TOKENIZER + + train_helper = TrainerHelper( + tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length) + data = train_helper.parse_data(model_type=model_type, + data_name=data_name, + dial_ids_order=dial_ids_order, + split2ratio=split2ratio) + + model = BartForConditionalGeneration.from_pretrained(model_checkpoint) + model.resize_token_embeddings(len(tokenizer)) + fp16 = False + if torch.cuda.is_available(): + print("use cuda") + fp16 = True + model.to("cuda") + + model_dir = os.path.join( + train_helper.get_model_folder(model_type), + f"{datetime.now().strftime('%y-%m-%d-%H-%M')}") + + args = Seq2SeqTrainingArguments( + model_dir, + evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + weight_decay=0.01, + save_total_limit=2, + num_train_epochs=5, + predict_with_generate=True, + fp16=fp16, + push_to_hub=False, + generation_max_length=max_target_length, + logging_dir=os.path.join(model_dir, 'log') + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, model=model, padding=True) + + # customize this trainer + trainer = Seq2SeqTrainer( + model=model, + args=args, + train_dataset=data["train"], + eval_dataset=data["test"], + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=gentus_compute_metrics) + print("start training...") + trainer.train() + print("saving model...") + trainer.save_model() + + +def fine_tune_on_dialmage(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"): + tokenizer = TOKENIZER + + train_helper = TrainerHelper( + tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length) + data = train_helper.remove_dialmage_action() + + model = BartForConditionalGeneration.from_pretrained(model_checkpoint) + model.resize_token_embeddings(len(tokenizer)) + fp16 = False + if torch.cuda.is_available(): + print("use cuda") + fp16 = True + model.to("cuda") + + model_dir = os.path.join( + train_helper.get_model_folder(model_type), + f"{datetime.now().strftime('%y-%m-%d-%H-%M')}") + + # Emowoz + + args = Seq2SeqTrainingArguments( + model_dir, + evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + weight_decay=0.01, + save_total_limit=2, + num_train_epochs=4, + predict_with_generate=True, + fp16=fp16, + push_to_hub=False, + generation_max_length=max_target_length, + logging_dir=os.path.join(model_dir, 'log') + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, model=model, padding=True) + + trainer = Seq2SeqTrainer( + model=model, + args=args, + train_dataset=data["emowoz"]["train"], + eval_dataset=data["emowoz"]["test"], + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=gentus_compute_metrics) + print("start training...") + trainer.train() + print("saving model...") + trainer.save_model() + + # dialmage + args = Seq2SeqTrainingArguments( + model_dir+"_dialmage_fine_tune", + evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + weight_decay=0.01, + save_total_limit=2, + num_train_epochs=1, + predict_with_generate=True, + fp16=fp16, + push_to_hub=False, + generation_max_length=max_target_length, + logging_dir=os.path.join(model_dir, 'log') + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, model=model, padding=True) + + trainer = Seq2SeqTrainer( + model=model, + args=args, + train_dataset=data["dialmage"]["train"], + eval_dataset=data["dialmage"]["test"], + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=basic_metric) + print("start training...") + trainer.train() + print("saving model...") + trainer.save_model() + + +def main(): + args = arg_parser() + print("---> data_name", args.data_name) + if args.fine_tune: + fine_tune_on_dialmage( + model_type=args.model_type, + data_name=args.data_name, + dial_ids_order=args.dial_ids_order, + split2ratio=args.split2ratio, + batch_size=args.batch_size, + max_input_length=MAX_IN_LEN, + max_target_length=MAX_OUT_LEN, + model_checkpoint=args.model_checkpoint + ) + else: + + train( + model_type=args.model_type, + data_name=args.data_name, + dial_ids_order=args.dial_ids_order, + split2ratio=args.split2ratio, + batch_size=args.batch_size, + max_input_length=MAX_IN_LEN, + max_target_length=MAX_OUT_LEN, + model_checkpoint=args.model_checkpoint + ) + + +if __name__ == "__main__": + main() + # sgd+tm: 46000 diff --git a/convlab/policy/emoUS/unify/Goal.py b/convlab/policy/emoUS/unify/Goal.py new file mode 100644 index 0000000000000000000000000000000000000000..c5fa14c7ac3d32ef58a8cfaa869fec3c939e9433 --- /dev/null +++ b/convlab/policy/emoUS/unify/Goal.py @@ -0,0 +1,118 @@ +""" +The user goal for unify data format +""" +from convlab.policy.genTUS.unify.Goal import Goal as GenTUSGoal +from convlab.policy.genTUS.unify.Goal import DEF_VAL_UNK +from random import random +from convlab.util.custom_util import set_seed + + +class Goal(GenTUSGoal): + """ User Goal Model Class. """ + + def __init__(self, goal=None, goal_generator=None, use_sentiment=False): + """ + create new Goal from a dialog or from goal_generator + Args: + goal: can be a list (create from a dialog), an abus goal, or none + """ + super().__init__(goal, goal_generator) + self.use_sentiment = use_sentiment + # TODO sample Exciting? User politeness? + + def _init_goal_from_data(self, goal=None, goal_generator=None): + goal = self._old_goal(goal, goal_generator) + # be careful of this order + for domain, intent, slot, value in goal: + if domain == "none": + continue + if domain not in self.domains: + self.domains.append(domain) + self.domain_goals[domain] = {} + if intent not in self.domain_goals[domain]: + self.domain_goals[domain][intent] = {} + + if not value: + if intent == "request": + self.domain_goals[domain][intent][slot] = DEF_VAL_UNK + else: + print( + f"unknown no value intent {domain}, {intent}, {slot}") + else: + self.domain_goals[domain][intent][slot] = value + + def emotion_info(self): + self.user_persona = {"user": "Polite"} + event = {} + z = random() + if z > 0.95: + self.user_persona["user"] = "Impolite" + # TODO: should check domains only in the user goal + + for d in self.domains: + # Excited + z = random() + if z > 0.8 and d in ["restaurant", "attraction", "train"]: + event[d] = "Excited" + z = random() + if z > 0.95 and d in ["restaurant", "police", "hospital"] and d not in event: + event[d] = "Fearful" + + if event: + self.user_persona["event"] = event + + return self.user_persona + + +def emotion_info(dialog=None, goal=None): + user_persona = {"user": "Polite"} + event_emotion = {1: "Fearful", 5: "Excited"} + event = {} + if dialog is None: + # politeness + z = random() + if z > 0.95: + user_persona = "Impolite" + # TODO: should check domains only in the user goal + + for d in ["restaurant", "attraction", "train"]: + z = random() + if z > 0.8: + event[d] = "Excited" + for d in ["restaurant", "police", "hospital"]: + if d in event: + continue + z = random() + if z > 0.95: + event[d] = "Fearful" + if event: + user_persona["event"] = event + + else: + for turn in dialog['turns']: + if turn['speaker'] == 'user': + emotion = turn["emotion"][-1]["emotion"] + # Fearful and Excited + if int(emotion) in event_emotion: + domain = check_domain(turn["dialogue_acts"]) + for d in domain: + if d not in event: + event[d] = event_emotion[emotion] + # Abusive + if int(emotion) == 4: + user_persona["user"] = "Impolite" + if event: + user_persona["event"] = event + + return user_persona + + +def check_domain(dialog_act): + domain = [] + for _, acts in dialog_act.items(): + for act in acts: + if act["domain"] == "general": + continue + if act["domain"] not in domain: + domain.append(act["domain"]) + return domain diff --git a/convlab/policy/emoUS/unify/build_data.py b/convlab/policy/emoUS/unify/build_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e2af8f83c78cbbd84445eb5cb2f585f06c1469 --- /dev/null +++ b/convlab/policy/emoUS/unify/build_data.py @@ -0,0 +1,213 @@ +import json +import os +import sys +from argparse import ArgumentParser + +from tqdm import tqdm + +from convlab.policy.emoUS.unify.Goal import Goal, emotion_info +from convlab.policy.genTUS.unify.build_data import \ + DataBuilder as GenTUSDataBuilder +from convlab.policy.genTUS.unify.Goal import transform_data_act +from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--dataset", type=str, default="emowoz+dialmage") + parser.add_argument("--dial-ids-order", type=int, default=0) + parser.add_argument("--split2ratio", type=float, default=1) + parser.add_argument("--use-sentiment", action="store_true") + parser.add_argument("--add-persona", action="store_true") + parser.add_argument("--emotion-mid", action="store_true") + parser.add_argument("--emotion-only", action="store_true") + + return parser.parse_args() + + +class DataBuilder(GenTUSDataBuilder): + def __init__(self, dataset='emowoz', **kwargs): + super().__init__(dataset) + self.use_sentiment = kwargs.get("use_sentiment", False) + self.emotion_mid = kwargs.get("emotion_mid", False) + self.add_persona = kwargs.get("add_persona", False) + self.emotion_only = kwargs.get("emotion_only", False) + + self.emotion = {} + for emotion, index in json.load(open("convlab/policy/emoUS/emotion.json")).items(): + self.emotion[int(index)] = emotion + + if use_sentiment: + self.sentiment = {} + for sentiment, index in json.load(open("convlab/policy/emoUS/sentiment.json")).items(): + self.sentiment[int(index)] = sentiment + self.sent2emo = json.load( + open("convlab/policy/emoUS/sent2emo.json")) + # TODO check excited distribution + + def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False): + example = [] + history = [] + + data_goal = self.norm_domain_goal(create_goal(dialog)) + if not data_goal: + return example + user_goal = Goal(goal=data_goal) + user_info = None + if self.add_persona: + user_info = emotion_info(dialog) + # if user_info["user"] == "Impolite": + # print(user_info) + # if "event" in user_info: + # print(user_info) + + for turn_id in range(0, len(dialog["turns"]), 2): + sys_act = self._get_sys_act(dialog, turn_id) + + user_goal.update_user_goal(action=sys_act, char="sys") + usr_goal_str = self._user_goal_str( + user_goal, data_goal, random_order, no_status) + + usr_act = self.norm_domain(transform_data_act( + dialog["turns"][turn_id]["dialogue_acts"])) + user_goal.update_user_goal(action=usr_act, char="usr") + + # change value "?" to "<?>" + usr_act = self._modify_act(usr_act) + usr_emotion = self.emotion[ + dialog["turns"][turn_id]["emotion"][-1]["emotion"]] + + in_str = self._dump_in_str( + sys_act, usr_goal_str, history, turn_id, add_history, user_info) + + if self.use_sentiment: + usr_sentiment = self.sentiment[ + dialog["turns"][turn_id]["emotion"][-1]["sentiment"]] + out_str = self._dump_out_str( + usr_act, dialog["turns"][turn_id]["utterance"], usr_emotion, usr_sentiment) + + else: + out_str = self._dump_out_str( + usr_act, dialog["turns"][turn_id]["utterance"], usr_emotion) + + history.append(usr_act) + if usr_act: + example.append({"in": in_str, "out": out_str}) + + return example + + def _dump_in_str(self, sys_act, usr_goal_str, history, turn_id, add_history, user_info=None): + in_str = {} + in_str["system"] = self._modify_act(sys_act) + in_str["goal"] = usr_goal_str + if add_history: + h = [] + if history: + h = history[-3:] + in_str["history"] = h + in_str["turn"] = str(int(turn_id/2)) + + if self.add_persona: + for info in ["event", "user"]: + if info not in user_info: + continue + in_str[info] = user_info[info] + + return json.dumps(in_str) + + def _dump_out_str(self, usr_act, text, usr_emotion, usr_sentiment=None): + if self.use_sentiment and self.emotion_mid: + out_str = {"sentiment": usr_sentiment, + "action": usr_act, + "emotion": usr_emotion, + "text": text} + elif self.use_sentiment and not self.emotion_mid: + out_str = {"sentiment": usr_sentiment, + "emotion": usr_emotion, + "action": usr_act, + "text": text} + elif not self.use_sentiment and not self.emotion_mid: + if self.emotion_only: + out_str = {"emotion": usr_emotion} + else: + out_str = {"emotion": usr_emotion, + "action": usr_act, + "text": text} + else: + out_str = {"action": usr_act, + "emotion": usr_emotion, + "text": text} + return json.dumps(out_str) + + +if __name__ == "__main__": + args = arg_parser() + + base_name = "convlab/policy/emoUS/unify/data" + dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}" + + use_sentiment = args.use_sentiment + emotion_mid = args.emotion_mid + add_persona = args.add_persona + + data_status = [use_sentiment, emotion_mid, add_persona] + + if data_status == [True, True, True]: + # current sentUS + dir_name = f"SentUS_{dir_name}" + elif data_status == [True, True, False]: + # current sentUS without persona + dir_name = f"SentUS_noPersona_{dir_name}" + elif data_status == [False, False, True]: + # current emoUS with persona + dir_name = f"EmoUS_{dir_name}" + elif data_status == [False, False, False]: + # current emoUS + dir_name = f"EmoUS_noPersona_{dir_name}" + elif data_status == [False, True, True]: + # mid emotion + dir_name = f"MIDemoUS_{dir_name}" + elif data_status == [False, True, False]: + dir_name = f"MIDemoUS_noPersona_{dir_name}" + elif data_status == [True, False, True]: + # sentiment followed by emotion, not act + dir_name = f"SentEmoUS_{dir_name}" + elif data_status == [True, False, False]: + # sentiment followed by emotion, not act, without perosna + dir_name = f"SentEmoUS_noPersona_{dir_name}" + else: + print("NOT DEFINED", use_sentiment, add_persona, emotion_mid) + + if args.emotion_only: + dir_name = dir_name + '_emotion_only' + print("dir_name", dir_name) + + folder_name = os.path.join(base_name, dir_name) + + if not os.path.exists(folder_name): + os.makedirs(folder_name) + dataset = load_experiment_dataset( + data_name=args.dataset, + dial_ids_order=args.dial_ids_order, + split2ratio=args.split2ratio) + data_builder = DataBuilder( + dataset=args.dataset, + use_sentiment=use_sentiment, + add_persona=add_persona, + emotion_mid=emotion_mid, + emotion_only=args.emotion_only) + data = data_builder.setup_data( + raw_data=dataset, + random_order=False, + no_status=False, + add_history=True, + remove_domain=None) + + for data_type in data: + file_name = os.path.join( + folder_name, + f"{data_type}.json") + json.dump(data[data_type], open(file_name, 'w'), indent=2) diff --git a/convlab/policy/emoUS/unify/knowledge_graph.py b/convlab/policy/emoUS/unify/knowledge_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..7b68c2fe55dc3a0161dd2b1a6136acf0f20cbe41 --- /dev/null +++ b/convlab/policy/emoUS/unify/knowledge_graph.py @@ -0,0 +1,84 @@ +import json +from random import choices + +from convlab.policy.genTUS.token_map import tokenMap +from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph as GenTUSKnowledgeGraph + +from transformers import BartTokenizer + +DEBUG = False +DATASET = "unify" + +# TODO add emotion + + +class KnowledgeGraph(GenTUSKnowledgeGraph): + def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="emowoz", use_sentiment=False, weight=None): + super().__init__(tokenizer, ontology_file, dataset="multiwoz") + self.use_sentiment = use_sentiment + + if use_sentiment: + data_sentiment = json.load( + open("convlab/policy/emoUS/sentiment.json")) + self.kg_map = {"sentiment": tokenMap(tokenizer=self.tokenizer)} + self.sentiment = [""]*len(data_sentiment) + for sentiment, index in data_sentiment.items(): + self.sentiment[index] = sentiment + for sentiment in self.sentiment: + self.kg_map["sentiment"].add_token(sentiment, sentiment) + self.kg_map[sentiment] = tokenMap(tokenizer=self.tokenizer) + self.sent2emo = json.load( + open("convlab/policy/emoUS/sent2emo.json")) + for sent in self.sent2emo: + for emo in self.sent2emo[sent]: + self.kg_map[sent].add_token(emo, emo) + + else: + data_emotion = json.load( + open("convlab/policy/emoUS/emotion.json")) + self.emotion = [""]*len(data_emotion) + for emotion, index in data_emotion.items(): + self.emotion[index] = emotion + self.kg_map = {"emotion": tokenMap(tokenizer=self.tokenizer)} + for emotion in self.emotion: + self.kg_map["emotion"].add_token(emotion, emotion) + + self.emotion_weight = {"Neutral": 1, + "Fearful": 1, + "Dissatisfied": 1, + "Apologetic": 1, + "Abusive": 1, + "Excited": 1, + "Satisfied": 1} + self.sentiment_weight = {"Neutral": 1, "Positive": 1, "Negative": 1} + + if weight: + if use_sentiment: + self.sentiment_weight["Neutral"] = weight + else: + self.emotion_weight["Neutral"] = weight + + def get_sentiment(self, outputs, mode="max"): + score = self._get_max_score( + outputs, self.sentiment, "sentiment", weight=self.sentiment_weight) + s = self._select(score, mode) + return score[s] + + def get_emotion(self, outputs, mode="max", emotion_mode="normal", sentiment=None): + if self.use_sentiment: + if not sentiment: + print("You are in 'use_sentiment' mode. Please provide sentiment") + score = self._get_max_score( + outputs, self.sent2emo[sentiment], "sentiment") + else: + if emotion_mode == "normal": + score = self._get_max_score( + outputs, self.emotion, "emotion", weight=self.emotion_weight) + elif emotion_mode == "no_neutral": + score = self._get_max_score( + outputs, self.emotion[1:], "emotion", weight=self.emotion_weight) + else: + print(f"unknown emotion mode: {emotion_mode}") + s = self._select(score, mode) + + return score[s] diff --git a/convlab/policy/genTUS/evaluate.py b/convlab/policy/genTUS/evaluate.py index 87de854970d2701900ba180d2bf15736071e0c1a..23306764207f3c11ed6493efacacdda0fc530a57 100644 --- a/convlab/policy/genTUS/evaluate.py +++ b/convlab/policy/genTUS/evaluate.py @@ -147,6 +147,14 @@ class Evaluator: indent=2) return os.path.join(dir_name, "nlg_eval.json") + @staticmethod + def _intent_domain(action): + acts = [] + for intent, domain, slot, value in action: + if [intent, domain] not in acts: + acts.append([intent, domain]) + return acts + def evaluation(self, input_file=None, generated_file=None): force_prediction = True if generated_file: @@ -187,17 +195,28 @@ class Evaluator: golden_acts.append(dialog["golden_acts"]) dialog_result = gen_file['dialog'] - scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []} + scores = {"complete": {"precision": [], "recall": [], "f1": [], "turn_acc": []}, + "intent_domain": {"precision": [], "recall": [], "f1": [], "turn_acc": []}} for gen_act, golden_act in zip(gen_acts, golden_acts): s = f1_measure(preds=gen_act, labels=golden_act) - for metric in scores: - scores[metric].append(s[metric]) + for metric in scores["complete"]: + scores["complete"][metric].append(s[metric]) + s = f1_measure(preds=self._intent_domain(gen_act), + labels=self._intent_domain(golden_act)) + for metric in scores["intent_domain"]: + scores["intent_domain"][metric].append(s[metric]) result = {} - for metric in scores: - result[metric] = sum(scores[metric])/len(scores[metric]) - print(f"{metric}: {result[metric]}") + # for metric in scores: + # result[metric] = sum(scores[metric])/len(scores[metric]) + # print(f"{metric}: {result[metric]}") + + for metric_type, score in scores.items(): + result[metric_type] = {} + for m, s in score.items(): + result[metric_type][m] = sum(s)/len(s) + print(f"{metric_type}-{m}: {result[metric_type][m]}") result["dialog"] = dialog_result basename = "semantic_evaluation_result" diff --git a/convlab/policy/genTUS/ppo/vector.py b/convlab/policy/genTUS/ppo/vector.py index 4c502a46f87582008ff49219f8a14844378b9ed2..ca1a74152a6a25eae127d13dfb0804271275ef07 100644 --- a/convlab/policy/genTUS/ppo/vector.py +++ b/convlab/policy/genTUS/ppo/vector.py @@ -19,8 +19,8 @@ class stepGenTUSVector: self.mentioned_domain = [] self.allow_general_intent = allow_general_intent self.candidate_num = 5 - if self.allow_general_intent: - print("---> allow_general_intent") + # if self.allow_general_intent: + # print("---> allow_general_intent") def init_session(self, goal: Goal): self.goal = goal diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py index 0b5af9f3315e9db1cc98e4618e1800155d97e670..aa18e46abfa2aeb3cd184b519dfa6ee4a1a603f9 100644 --- a/convlab/policy/genTUS/stepGenTUS.py +++ b/convlab/policy/genTUS/stepGenTUS.py @@ -3,6 +3,8 @@ import os import torch from transformers import BartTokenizer +from random import choices + from convlab.policy.genTUS.ppo.vector import stepGenTUSVector from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel @@ -30,7 +32,7 @@ class UserActionPolicy(Policy): self.max_in_len = 500 self.max_out_len = 100 if only_action else 200 max_act_len = kwargs.get("max_act_len", 2) - print("max_act_len", max_act_len) + # print("max_act_len", max_act_len) self.max_action_len = max_act_len if "max_act_len" in kwargs: self.max_out_len = 30 * self.max_action_len @@ -109,7 +111,7 @@ class UserActionPolicy(Policy): # get text output pos = self._update_seq(self.token_map.get_id("start_text"), pos) - text = self._get_text(model_input, pos) + text = self._get_text(model_input, pos, mode) return text @@ -143,13 +145,18 @@ class UserActionPolicy(Policy): raw_output = self._get_text(model_input, pos) return self._parse_output(raw_output)["text"] - def _get_text(self, model_input, pos): + def _get_text(self, model_input, pos, mode="max"): s_pos = pos + mode = "sample" for i in range(s_pos, self.max_out_len): next_token_logits = self.model.get_next_token_logits( model_input, self.seq[:1, :pos]) - next_token = torch.argmax(next_token_logits, dim=-1) - + if mode == "sample": + s = torch.multinomial(torch.softmax( + next_token_logits, dim=-1), 1) + next_token = s + else: + next_token = torch.argmax(next_token_logits, dim=-1) if self._stop_text(next_token): # text = self.vector.decode(self.seq[0, s_pos:pos]) # text = self._norm_str(text) @@ -216,6 +223,11 @@ class UserActionPolicy(Policy): # get slot slot = self._get_slot( model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) + if "book" in slot["token_name"]: + pos = self._update_seq(self.token_map.get_id('book'), pos) + slot = self._get_book_slot( + model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) + slot["token_name"] = "book" + slot["token_name"] pos = self._update_seq(slot["token_id"], pos) pos = self._update_seq(self.token_map.get_id('sep_token'), pos) @@ -245,6 +257,12 @@ class UserActionPolicy(Policy): is_mentioned = self.vector.is_mentioned(domain) return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned) + def _get_book_slot(self, model_input, generated_so_far, intent, domain, mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + is_mentioned = self.vector.is_mentioned(domain) + return self.kg.get_book_slot(next_token_logits, intent, domain, mode, is_mentioned) + def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"): next_token_logits = self.model.get_next_token_logits( model_input, generated_so_far) @@ -277,6 +295,7 @@ class UserActionPolicy(Policy): return action def predict(self, sys_act, mode="max", allow_general_intent=True): + # TODO emotion # raw_sys_act = sys_act # sys_act = sys_act[:5] # update goal @@ -592,7 +611,7 @@ class UserPolicy(Policy): **kwargs): # self.config = config if not os.path.exists(os.path.dirname(model_checkpoint)): - os.mkdir(os.path.dirname(model_checkpoint)) + os.makedirs(os.path.dirname(model_checkpoint)) model_downloader(os.path.dirname(model_checkpoint), "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip") @@ -612,10 +631,12 @@ class UserPolicy(Policy): else: mode = "max" response = self.policy.predict(sys_act, mode) + self.semantic_action = self.policy.semantic_action return response def init_session(self, goal=None): self.policy.init_session(goal) + self.semantic_action = [] def is_terminated(self): return self.policy.is_terminated() diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py index 7825c2880928c40f68284b0c3199932cd1cfc477..a6187318dc0ba37eea0318c5deebe4874f691fbf 100644 --- a/convlab/policy/genTUS/token_map.py +++ b/convlab/policy/genTUS/token_map.py @@ -14,11 +14,12 @@ class tokenMap: 'start_json': '{"action": [', # 49643, 10845, 7862, 646 'start_act': '["', # 49329 'sep_token': '", "', # 1297('",'), 22 - 'sep_act': '"], ["', # 49177 + 'sep_act': '"], ["', # 49177 'end_act': '"]], "', # 42248, 7479, 22 'start_text': 'text": "', # 29015, 7862, 22 - 'end_json': '}', # 24303 - 'end_json_2': '"}' # 48805 + 'end_json': '}', # 24303 + 'end_json_2': '"}', # 48805 + 'book': 'book' # 6298 } if only_action: self.format_tokens['end_act'] = '"]]}' diff --git a/convlab/policy/genTUS/unify/Goal.py b/convlab/policy/genTUS/unify/Goal.py index 6a77b090a7266d07f653e707bd4b749b6a6114bb..e3049e5af0da750c75136434a4170e628cad909e 100644 --- a/convlab/policy/genTUS/unify/Goal.py +++ b/convlab/policy/genTUS/unify/Goal.py @@ -40,7 +40,7 @@ class Goal: json.dumps(self.domain_goals, indent=4) + \ '\n-----Goal-----' - def _init_goal_from_data(self, goal=None, goal_generator=None): + def _old_goal(self, goal=None, goal_generator=None): if not goal and goal_generator: goal = ABUS_Goal(goal_generator) self.raw_goal = goal.domain_goals @@ -56,6 +56,10 @@ class Goal: # else: # print("unknow goal") + return goal + + def _init_goal_from_data(self, goal=None, goal_generator=None): + goal = self._old_goal(goal, goal_generator) # be careful of this order for domain, intent, slot, value in goal: diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py index 68af13e481fe4799dfc2a6f3763b526611eabd9c..12dd72ae10433d9eee82a6900a6702842ea9124c 100644 --- a/convlab/policy/genTUS/unify/knowledge_graph.py +++ b/convlab/policy/genTUS/unify/knowledge_graph.py @@ -11,7 +11,7 @@ DATASET = "unify" class KnowledgeGraph: def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="multiwoz21"): - print("dataset", dataset) + # print("dataset", dataset) self.debug = DEBUG self.tokenizer = tokenizer @@ -83,7 +83,7 @@ class KnowledgeGraph: if slot not in self.user_goal[domain]: self.user_goal[domain][slot] = [] - self.add_token(domain, "slot") + self.add_token(slot, "slot") if value not in self.user_goal[domain][slot]: value = json.dumps(str(value))[1:-1] @@ -97,7 +97,7 @@ class KnowledgeGraph: if not self.kg_map[map_type].token_name_is_in(token_name): self.kg_map[map_type].add_token(token_name, token_name) - def _get_max_score(self, outputs, candidate_list, map_type): + def _get_max_score(self, outputs, candidate_list, map_type, weight=None): score = {} if not candidate_list: print(f"ERROR: empty candidate list for {map_type}") @@ -107,6 +107,8 @@ class KnowledgeGraph: for x in candidate_list: hash_id = self._get_token_id(x)[0] s = outputs[:, hash_id].item() + if weight: + s *= weight[x] score[s] = {"token_id": self._get_token_id(x), "token_name": x} return score @@ -202,6 +204,17 @@ class KnowledgeGraph: return token_map + def get_book_slot(self, outputs, intent, domain, mode="max", is_mentioned=False): + slot_list = self.candidate( + candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned) + book_slot_list = [s.replace("book", "") + for s in slot_list if 'book' in s] + + token_map = self._get_max_domain_token( + outputs=outputs, candidates=book_slot_list, map_type="slot", mode=mode) + + return token_map + def get_value(self, outputs, intent, domain, slot, mode="max"): if intent in self.general_intent or slot.lower() == "none": token_name = "none" diff --git a/convlab/policy/ppo/GenTUS-BertNLU-RuleDST-PPOPolicy.json b/convlab/policy/ppo/GenTUS-BertNLU-RuleDST-PPOPolicy.json new file mode 100644 index 0000000000000000000000000000000000000000..eda6d88003cef390692d25c73e5b4892522278e2 --- /dev/null +++ b/convlab/policy/ppo/GenTUS-BertNLU-RuleDST-PPOPolicy.json @@ -0,0 +1,54 @@ +{ + "model": { + "load_path": "convlab/policy/ppo/pretrained_models/mle", + "pretrained_load_path": "", + "use_pretrained_initialisation": false, + "batchsz": 200, + "seed": 0, + "epoch": 100, + "eval_frequency": 5, + "process_num": 1, + "num_eval_dialogues": 20, + "sys_semantic_to_usr": false + }, + "vectorizer_sys": { + "uncertainty_vector_mul": { + "class_path": "convlab.policy.vector.vector_binary.VectorBinary", + "ini_params": { + "use_masking": true, + "manually_add_entity_names": true, + "seed": 0 + } + } + }, + "nlu_sys": { + "BertNLU": { + "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU", + "ini_params": { + "mode": "all", + "config_file": "multiwoz21_all.json", + "model_file": "https://huggingface.co/ConvLab/bert-base-nlu/resolve/main/bertnlu_unified_multiwoz21_all_context0.zip" + } + } + }, + "dst_sys": { + "RuleDST": { + "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST", + "ini_params": {} + } + }, + "sys_nlg": {}, + "nlu_usr": {}, + "dst_usr": {}, + "policy_usr": { + "GenTUS": { + "class_path": "convlab.policy.genTUS.stepGenTUS.UserPolicy", + "ini_params": { + "model_checkpoint": "convlab/policy/genTUS/unify/experiments/multiwoz21-exp", + "mode": "language", + "only_action": false + } + } + }, + "usr_nlg": {} +} diff --git a/convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json b/convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json index 7e170f6ddb65798771bb5e497b6a9dbf7e6013f0..75af9b1327532fa53876b40f59c9a361487c6abd 100644 --- a/convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json +++ b/convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json @@ -32,7 +32,7 @@ "nlu_usr": {}, "dst_usr": {}, "policy_usr": { - "RulePolicy": { + "GenTUS": { "class_path": "convlab.policy.genTUS.stepGenTUS.UserPolicy", "ini_params": { "model_checkpoint": "convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0", @@ -41,4 +41,4 @@ } }, "usr_nlg": {} -} +} \ No newline at end of file diff --git a/convlab/policy/ppo/configs/ppo_config.json b/convlab/policy/ppo/configs/ppo_config.json index 19cf280748ca6cb554eaf947594ccea9732a393c..35d9af0244356f54e559a3fa84ca747499ca910f 100755 --- a/convlab/policy/ppo/configs/ppo_config.json +++ b/convlab/policy/ppo/configs/ppo_config.json @@ -3,7 +3,7 @@ "gamma": 0.99, "epsilon": 0.2, "tau": 0.95, - "policy_lr": 0.00001, + "policy_lr": 0.0001, "value_lr": 0.00005, "save_dir": "save", "log_dir": "log", @@ -17,4 +17,4 @@ "lr_supervised": 0.001, "weight_decay": 0.00001, "epoch": 100 -} +} \ No newline at end of file diff --git a/convlab/policy/ppo/emoUS-BertNLU-RuleDST-PPOPolicy.json b/convlab/policy/ppo/emoUS-BertNLU-RuleDST-PPOPolicy.json new file mode 100644 index 0000000000000000000000000000000000000000..eba5bde86b5387c96dbf2b6be2d89c3e0bb71cdb --- /dev/null +++ b/convlab/policy/ppo/emoUS-BertNLU-RuleDST-PPOPolicy.json @@ -0,0 +1,57 @@ +{ + "model": { + "load_path": "convlab/policy/ppo/pretrained_models/mle", + "pretrained_load_path": "", + "use_pretrained_initialisation": false, + "batchsz": 200, + "seed": 0, + "epoch": 100, + "eval_frequency": 5, + "process_num": 1, + "num_eval_dialogues": 20, + "sys_semantic_to_usr": false + }, + "vectorizer_sys": { + "uncertainty_vector_mul": { + "class_path": "convlab.policy.vector.vector_binary.VectorBinary", + "ini_params": { + "use_masking": true, + "manually_add_entity_names": true, + "seed": 0 + } + } + }, + "nlu_sys": { + "BertNLU": { + "class_path": "convlab.nlu.jointBERT.unified_datasets.BERTNLU", + "ini_params": { + "mode": "all", + "config_file": "multiwoz21_all.json", + "model_file": "https://huggingface.co/ConvLab/bert-base-nlu/resolve/main/bertnlu_unified_multiwoz21_all_context0.zip" + } + } + }, + "dst_sys": { + "RuleDST": { + "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST", + "ini_params": {} + } + }, + "sys_nlg": {}, + "nlu_usr": {}, + "dst_usr": {}, + "policy_usr": { + "emoUS": { + "class_path": "convlab.policy.emoUS.emoUS.UserPolicy", + "ini_params": { + "model_checkpoint": "convlab/policy/emoUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17", + "character": "usr", + "mode": "language", + "only_action": false, + "use_sentiment": true, + "sample": false + } + } + }, + "usr_nlg": {} +} \ No newline at end of file diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py index a2814df36a24642d053897f2e42e16cf56a3adea..d116e7cffa22c6d50ec3e1e9575d6347de5bfd0f 100755 --- a/convlab/policy/ppo/train.py +++ b/convlab/policy/ppo/train.py @@ -4,21 +4,25 @@ Created on Sun Jul 14 16:14:07 2019 @author: truthless """ -import sys -import os import logging +import os +import random +import sys import time +from argparse import ArgumentParser +from datetime import datetime + import numpy as np import torch -import random +from torch import multiprocessing as mp from convlab.policy.ppo import PPO from convlab.policy.rlmodule import Memory -from torch import multiprocessing as mp -from argparse import ArgumentParser -from convlab.util.custom_util import set_seed, init_logging, save_config, move_finished_training, env_config, \ - eval_policy, log_start_args, save_best, load_config_file, get_config -from datetime import datetime +from convlab.util.custom_util import (env_config, eval_policy, get_config, + init_logging, load_config_file, + log_start_args, move_finished_training, + save_best, save_config, set_seed) +from convlab.dialog_agent.env import Environment sys.path.append(os.path.dirname(os.path.dirname( os.path.dirname(os.path.abspath(__file__))))) @@ -33,8 +37,7 @@ except RuntimeError: pass -def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0): - +def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0, user_reward=False): """ This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple processes. @@ -46,7 +49,6 @@ def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0): :param batchsz: total sampled items :return: """ - buff = Memory() # we need to sample batchsz of (state, action, next_state, reward, mask) # each trajectory contains `trajectory_len` num of items, so we only need to sample @@ -75,7 +77,7 @@ def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0): # print(f"s : {s['system_action']}") # print(f"a : {a}") # interact with env - next_s, r, done = env.step(a) + next_s, r, done = env.step(a, user_reward=user_reward) # print(f"next_s: {next_s['system_action']}") # a flag indicates ending or not @@ -108,8 +110,7 @@ def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0): evt.wait() -def sample(env, policy, num_train_dialogues, process_num, seed): - +def sample(env, policy, num_train_dialogues, process_num, seed, user_reward=False): """ Given batchsz number of task, the batchsz will be splited equally to each processes and when processes return, it merge all data and return @@ -137,7 +138,7 @@ def sample(env, policy, num_train_dialogues, process_num, seed): evt = mp.Event() processes = [] for i in range(process_num): - process_args = (i, queue, evt, env, policy, process_num_dialogues, train_seeds[i]) + process_args = (i, queue, evt, env, policy, process_num_dialogues, train_seeds[i], user_reward) processes.append(mp.Process(target=sampler, args=process_args)) for p in processes: # set the process as daemon, and it will be killed once the main process is stoped. @@ -157,11 +158,10 @@ def sample(env, policy, num_train_dialogues, process_num, seed): return buff.get_batch() -def update(env, policy, num_dialogues, epoch, process_num, seed=0): +def update(env, policy, num_dialogues, epoch, process_num, seed=0, user_reward=False): # sample data asynchronously - batch = sample(env, policy, num_dialogues, process_num, seed) - + batch = sample(env, policy, num_dialogues, process_num, seed, user_reward) # print(batch) # data in batch is : batch.state: ([1, s_dim], [1, s_dim]...) # batch.action: ([1, a_dim], [1, a_dim]...) @@ -190,12 +190,14 @@ if __name__ == '__main__': help="Set level for logger") parser.add_argument("--save_eval_dials", type=bool, default=False, help="Flag for saving dialogue_info during evaluation") + parser.add_argument("--user-reward", action="store_true") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs', f'{parser.parse_args().config_name}.json') seed = parser.parse_args().seed mode = parser.parse_args().mode save_eval = parser.parse_args().save_eval_dials + use_user_reward = parser.parse_args().user_reward logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ init_logging(os.path.dirname(os.path.abspath(__file__)), mode) @@ -203,19 +205,22 @@ if __name__ == '__main__': args = [('model', 'seed', seed)] if seed is not None else list() environment_config = load_config_file(path) - save_config(vars(parser.parse_args()), environment_config, config_save_path) + save_config(vars(parser.parse_args()), + environment_config, config_save_path) conf = get_config(path, args) seed = conf['model']['seed'] logging.info('Train seed is ' + str(seed)) set_seed(seed) - policy_sys = PPO(True, seed=conf['model']['seed'], vectorizer=conf['vectorizer_sys_activated']) + policy_sys = PPO(True, seed=conf['model']['seed'], + vectorizer=conf['vectorizer_sys_activated']) # Load model if conf['model']['use_pretrained_initialisation']: logging.info("Loading supervised model checkpoint.") - policy_sys.load_from_pretrained(conf['model'].get('pretrained_load_path', "")) + policy_sys.load_from_pretrained( + conf['model'].get('pretrained_load_path', "")) elif conf['model']['load_path']: try: policy_sys.load(conf['model']['load_path']) @@ -229,14 +234,14 @@ if __name__ == '__main__': env, sess = env_config(conf, policy_sys) - policy_sys.current_time = current_time policy_sys.log_dir = config_save_path.replace('configs', 'logs') policy_sys.save_dir = save_path logging.info(f"Evaluating at start - {time_now}" + '-'*60) time_now = time.time() - eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path) + eval_dict = eval_policy(conf, policy_sys, env, sess, + save_eval, log_save_path) logging.info(f"Finished evaluating, time spent: {time.time() - time_now}") for key in eval_dict: @@ -251,13 +256,14 @@ if __name__ == '__main__': for i in range(conf['model']['epoch']): idx = i + 1 # print("Epoch :{}".format(str(idx))) - update(env, policy_sys, conf['model']['num_train_dialogues'], idx, conf['model']['process_num'], seed=seed) + update(env, policy_sys, conf['model']['num_train_dialogues'], idx, conf['model']['process_num'], seed=seed, user_reward=use_user_reward) if idx % conf['model']['eval_frequency'] == 0 and idx != 0: time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) logging.info(f"Evaluating after Dialogues: {idx * conf['model']['num_train_dialogues']} - {time_now}" + '-' * 60) - eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path) + eval_dict = eval_policy( + conf, policy_sys, env, sess, save_eval, log_save_path) best_complete_rate, best_success_rate, best_return = \ save_best(policy_sys, best_complete_rate, best_success_rate, best_return, @@ -266,7 +272,6 @@ if __name__ == '__main__': policy_sys.save(save_path, "last") for key in eval_dict: tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['num_train_dialogues']) - logging.info("End of Training: " + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py index 625fbb9f2ded3ba44ea4fcb47b06ccb380e61c31..709145e801cbdae762277a11efdb9ad5e406295c 100755 --- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -18,7 +18,9 @@ from convlab.task.multiwoz.goal_generator import GoalGenerator from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA, REF_SYS_DA from convlab.util import relative_import_module_from_unified_datasets -reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value']) +reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets( + 'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value']) + def unified_format(acts): new_acts = {'categorical': []} @@ -136,8 +138,9 @@ class UserPolicyAgendaMultiWoz(Policy): action = {} while len(action) == 0: # A -> A' + user_action - action = self.agenda.get_action(random.randint(1, self.max_initiative)) - #action = self.agenda.get_action(self.max_initiative) + action = self.agenda.get_action( + random.randint(1, self.max_initiative)) + # action = self.agenda.get_action(self.max_initiative) # transform to DA action = self._transform_usract_out(action) diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py index d65f72a06e181e66bfe0d7ac0c60f0c03a56ad43..b3e24a57028933ee15776192c83d650a45cbbf53 100644 --- a/convlab/policy/tus/unify/util.py +++ b/convlab/policy/tus/unify/util.py @@ -8,6 +8,7 @@ NOT_MENTIONED = "not mentioned" def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1): ratio = {'train': split2ratio, 'validation': split2ratio} + print("data_name", data_name) if data_name == "all" or data_name == "sgd+tm" or data_name == "tm": print("merge all datasets...") if data_name == "all": @@ -24,6 +25,15 @@ def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2rati dial_ids_order=dial_ids_order, split2ratio=ratio) raw_data = merge_dataset(datasets, all_dataset[0]) + elif data_name == "emowoz+dialmage": + all_dataset = ["emowoz", "dialmage"] + datasets = {} + for name in all_dataset: + datasets[name] = load_dataset( + name, dial_ids_order=None) + raw_data = merge_dataset(datasets, all_dataset[0]) + elif data_name in ["dialmage", "emowoz"]: + raw_data = load_dataset(data_name, dial_ids_order=None) else: print(f"load single dataset {data_name}/{split2ratio}") diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9cd8a329976f8afb3be3c695f5a8bc26d995fd --- /dev/null +++ b/convlab/policy/ussT5/emowoz_evaluate.py @@ -0,0 +1,206 @@ +import json +import os +from argparse import ArgumentParser +from datetime import datetime +import numpy as np + +import matplotlib.pyplot as plt +import pandas as pd +from sklearn import metrics +from tqdm import tqdm +from transformers import T5ForConditionalGeneration, T5Tokenizer + +from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset +from convlab.policy.ussT5.evaluate import tri_convert + +from datasets import load_metric + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model", type=str, default="", + help="model name") + parser.add_argument("--data", default="emowoz+dialmage", type=str) + parser.add_argument("--gen-file", type=str) + parser.add_argument("--stop", default=-1, type=int) + return parser.parse_args() + + +def build_data(raw_data): + sentiments = {} + for sentiment, index in json.load(open("convlab/policy/emoUS/sentiment.json")).items(): + sentiments[int(index)] = sentiment + data = {"input_text": [], "target_text": []} + for prefix in ["satisfaction score: ", "action prediction: ", "utterance generation: "]: + for d in raw_data: + utt = "" + turn_len = len(d["turns"]) + for index, turn in enumerate(d["turns"]): + if turn["speaker"] == "user": + if index == turn_len - 2: + break + if index == 0: + utt = prefix + turn["utterance"] + else: + utt += ' ' + turn["utterance"] + else: + if index == 0: + print("this should no happen (index == 0)") + utt = prefix + turn["utterance"] + if index == turn_len - 1: + print("this should no happen (index == turn_len - 1)") + continue + + utt += ' ' + turn["utterance"] + + data["input_text"].append(utt) + if prefix == "satisfaction score: ": + data["target_text"].append( + sentiments[d["turns"][index+1]["emotion"][-1]["sentiment"]]) + elif prefix == "action prediction: ": + data["target_text"].append( + get_action(d["turns"][index+1]["dialogue_acts"])) + else: + data["target_text"].append( + d["turns"][index+1]["utterance"]) + + json.dump(data, open("convlab/policy/ussT5/emowoz-test.json", 'w'), indent=2) + return data + + +def get_action(dialogue_acts): + acts = [] + for _, act in dialogue_acts.items(): + for a in act: + acts.append( + f"{a['domain'].capitalize()}-{a['intent'].capitalize()}") + if not acts: + return "None" + return ','.join(acts) + + +def generate_result(model_checkpoint, data, stop=-1): + tokenizer = T5Tokenizer.from_pretrained(model_checkpoint) + model = T5ForConditionalGeneration.from_pretrained(model_checkpoint) + results = [] + i = 0 + print("stop", stop) + for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True): + if stop > 0 and i > stop: + break + i += 1 + inputs = tokenizer([input_text], return_tensors="pt", padding=True) + output = model.generate(input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + do_sample=False) + output = tokenizer.batch_decode( + output, skip_special_tokens=True)[0] + + if "satisfaction score" in input_text: + output = tri_convert(output) + results.append({"input_text": input_text, + "preds": output, + "label": target_text}) + json.dump(results, open(os.path.join( + model_checkpoint, "emowoz_result.json"), 'w'), indent=2) + return results + + +def read_result(result): + d = {} + for d_name in ["satisfaction score", "utterance generation", "action prediction"]: + d[d_name] = {"preds": [], "label": []} + for r in result: + for d_name in ["satisfaction score", "utterance generation", "action prediction"]: + if d_name in r["input_text"]: + d[d_name]["preds"].append(r["preds"]) + d[d_name]["label"].append(r["label"]) + return d + + +def satisfaction(model, d): + # satisfaction + all_sentiment = ["Neutral", "Negative", "Positive"] + print(all_sentiment) + tri_f1 = metrics.f1_score( + d["satisfaction score"]["label"], + d["satisfaction score"]["preds"], average="macro") + sep_f1 = metrics.f1_score( + d["satisfaction score"]["label"], + d["satisfaction score"]["preds"], average=None, labels=all_sentiment) + cm = metrics.confusion_matrix( + d["satisfaction score"]["label"], + d["satisfaction score"]["preds"], normalize="true", labels=all_sentiment) + disp = metrics.ConfusionMatrixDisplay( + confusion_matrix=cm, + display_labels=all_sentiment) + disp.plot() + r = {"tri_f1": float(tri_f1), + "sep_f1": list(sep_f1), + "cm": [list(c) for c in list(cm)]} + print(r) + time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}" + plt.savefig(os.path.join(model, f"{time}-emowoz.png")) + + +def utterance(model, d): + bleu_metric = load_metric("sacrebleu") + labels = [[utt] for utt in d["utterance generation"]["label"]] + + bleu_score = bleu_metric.compute( + predictions=d["utterance generation"]["preds"], + references=labels, + force=True) + print(f"{model} bleu_score", bleu_score) + + +def action(model, d): + score = {} + for preds, label in zip(d["action prediction"]["preds"], d["action prediction"]["label"]): + s = f1_score(preds, label) + for n, v in s.items(): + if n not in score: + score[n] = [] + score[n].append(v) + print(f"{model} action") + for n, v in score.items(): + print(n, np.mean(v)) + + +def f1_score(prediction, label): + score = {} + tp = 0 + pre = prediction.split(',') + lab = label.split(',') + for p in pre: + if p in lab: + tp += 1 + score["precision"] = tp/len(pre) + score["recall"] = tp/len(lab) + score["F1"] = 0 + if score["precision"]+score["recall"] > 0: + score["F1"] = 2*score["precision"]*score["recall"] / \ + (score["precision"]+score["recall"]) + if pre == lab: + score["acc"] = 1 + else: + score["acc"] = 0 + return score + + +def main(): + args = arg_parser() + if args.gen_file: + d = read_result(json.load(open(args.gen_file))) + else: + data = build_data(load_experiment_dataset(args.data)["test"]) + results = generate_result(args.model, data, args.stop) + d = read_result(results) + model = args.model + satisfaction(model, d) + utterance(model, d) + action(model, d) + + +if __name__ == "__main__": + main() diff --git a/convlab/policy/ussT5/evaluate.py b/convlab/policy/ussT5/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..501fc9aa12a96713f6ae9548e0bafd5116ac6d4c --- /dev/null +++ b/convlab/policy/ussT5/evaluate.py @@ -0,0 +1,125 @@ +import os +from argparse import ArgumentParser +from datetime import datetime + +import matplotlib.pyplot as plt +import pandas as pd +from sklearn import metrics +from tqdm import tqdm +from transformers import T5ForConditionalGeneration, T5Tokenizer +import json + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model", type=str, default="", + help="model name") + parser.add_argument("--data", type=str) + parser.add_argument("--gen-file", type=str) + return parser.parse_args() + + +def bi_f1(x): + if x in ['1', '2']: + return 0 + elif x in ['3', '4', '5']: + return 1 + else: + return 0 + + +def tri_convert(x): + if x == '3': + return "Neutral" + if x in ['1', '2']: + return "Negative" + if x in ['4', '5']: + return "Positive" + return "Neutral" + + +def bi_check(p, l): + negative = ['1', '2'] + positive = ['3', '4', '5'] + if p in negative and l in negative: + return 1 + if p in positive and l in positive: + return 1 + + return 0 + + +def read_result(result): + preds = {'bi': [], "five": [], 'tri': []} + label = {'bi': [], "five": [], 'tri': []} + for r in result: + p = r["preds"] + l = r["label"] + preds["five"].append(p) + preds["bi"].append(bi_f1(p)) + preds["tri"].append(tri_convert(p)) + + label["five"].append(l) + label["bi"].append(bi_f1(l)) + label["tri"].append(tri_convert(l)) + return preds, label + + +def generate_result(model_checkpoint, data): + tokenizer = T5Tokenizer.from_pretrained(model_checkpoint) + model = T5ForConditionalGeneration.from_pretrained(model_checkpoint) + data = pd.read_csv(data, index_col=False).astype(str) + results = [] + for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True): + if "satisfaction score" in input_text: + inputs = tokenizer([input_text], return_tensors="pt", padding=True) + output = model.generate(input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + do_sample=False) + output = tokenizer.batch_decode(output, skip_special_tokens=True)[0] + if len(output) > 1: + print(output) + output = "illegal" + + results.append({"input_text": input_text, + "preds": output, + "label": target_text}) + json.dump(results, open(os.path.join( + model_checkpoint, "uss_result.json"), 'w')) + return results + + +def main(): + args = arg_parser() + if args.gen_file: + preds, label = read_result(json.load(open(args.gen_file))) + else: + results = generate_result(args.model, args.data) + preds, label = read_result(results) + + macro_f1 = metrics.f1_score(label["five"], preds["five"], average="macro") + tri_f1 = metrics.f1_score(label["tri"], preds["tri"], average="macro") + f1 = metrics.f1_score(label["bi"], preds["bi"]) + sep_f1 = metrics.f1_score( + label["five"], preds["five"], average=None, + labels=['1', '2', '3', '4', '5']) + cm = metrics.confusion_matrix( + label["five"], preds["five"], normalize="true", + labels=['1', '2', '3', '4', '5']) + disp = metrics.ConfusionMatrixDisplay( + confusion_matrix=cm, + display_labels=['1', '2', '3', '4', '5']) + disp.plot() + r = {"macro_f1": float(macro_f1), + "tri_f1": float(tri_f1), + "bi_f1": float(f1), + "sep_f1": list(sep_f1), + "cm": [list(c) for c in list(cm)]} + print(r) + dirname = "convlab/policy/uss-t5/" + time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}" + plt.savefig(os.path.join(args.model, f"{time}-satisfied.png")) + + +if __name__ == "__main__": + main() diff --git a/convlab/policy/ussT5/predict.py b/convlab/policy/ussT5/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..2485d5feb18ec34514f4ff73ff65e041c38e58a4 --- /dev/null +++ b/convlab/policy/ussT5/predict.py @@ -0,0 +1,28 @@ +from argparse import ArgumentParser + +from transformers import T5ForConditionalGeneration, T5Tokenizer + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model", type=str, default="", + help="model name") + + return parser.parse_args() + + +def main(): + args = arg_parser() + model_checkpoint = args.model + tokenizer = T5Tokenizer.from_pretrained(model_checkpoint) + model = T5ForConditionalGeneration.from_pretrained(model_checkpoint) + prefix = 'satisfaction score: ' + text = "hi, i'm looking for an attraction in the center of town to visit. we have quite a few interesting attractions in the center of town. is there anything in particular you would like to see?" + inputs = tokenizer([prefix+text], return_tensors="pt", padding=True) + output = model.generate(input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + do_sample=False) + print(tokenizer.batch_decode(output, skip_special_tokens=True)) + +if __name__ == "__main__": + main() diff --git a/convlab/policy/ussT5/train.py b/convlab/policy/ussT5/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f68f408a7882014dd86ddf4743c214797966a014 --- /dev/null +++ b/convlab/policy/ussT5/train.py @@ -0,0 +1,152 @@ +import os +import random +import sys +from argparse import ArgumentParser +from datetime import datetime + +import numpy as np +import pandas as pd +import torch +from datasets import load_metric +from sklearn.model_selection import train_test_split +from transformers import (DataCollatorForSeq2Seq, Seq2SeqTrainer, + Seq2SeqTrainingArguments, T5ForConditionalGeneration, + T5Tokenizer) + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + + +def set_seed(r_seed): + random.seed(r_seed) + np.random.seed(r_seed) + torch.manual_seed(r_seed) + + +class ForT5Dataset(torch.utils.data.Dataset): + def __init__(self, inputs, targets): + self.inputs = inputs + self.targets = targets + + def __len__(self): + return len(self.targets) + + def __getitem__(self, index): + input_ids = torch.tensor(self.inputs[index]).squeeze() + target_ids = torch.tensor(self.targets[index]).squeeze() + + return {"input_ids": input_ids, "labels": target_ids} + + +def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [[label.strip()] for label in labels] + + return preds, labels + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--task", type=str, default="act-sat-utt", + help="act-sat, act-sat-utt, act-sat_no-alt, or act-sat-utt_no-alt") + parser.add_argument("--data", type=str, default="", + help="input data") + parser.add_argument("--batch", type=int, default=8, + help="batch size") + + return parser.parse_args() + + +def main(): + set_seed(0) + + def compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode( + labels, skip_special_tokens=True) + + decoded_preds, decoded_labels = postprocess_text( + decoded_preds, decoded_labels) + + result = metric.compute(predictions=decoded_preds, + references=decoded_labels) + result = {"bleu": result["score"]} + + prediction_lens = [np.count_nonzero( + pred != tokenizer.pad_token_id) for pred in preds] + result["gen_len"] = np.mean(prediction_lens) + result = {k: round(v, 4) for k, v in result.items()} + return result + + def preprocess_function(examples): + inputs = examples["input_text"].to_list() + targets = examples["target_text"].to_list() + model_inputs = tokenizer(inputs, text_target=targets, + max_length=512, truncation=True) + return model_inputs + + args = arg_parser() + base_name = "convlab/policy/ussT5" + tokenizer = T5Tokenizer.from_pretrained("t5-base") + model = T5ForConditionalGeneration.from_pretrained("t5-base") + data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model) + metric = load_metric("sacrebleu") + + output_dir = os.path.join(base_name, "experiments", args.task) + # f"{datetime.now().strftime('%y-%m-%d-%H-%M')}") + + raw_data = pd.read_csv(args.data, index_col=False).astype(str) + data = {"train": None, "validation": None, "test": None} + train_set, data["test"] = train_test_split(raw_data, test_size=0.1) + data["train"], data["validation"] = train_test_split( + train_set, test_size=0.1) + folder_name = os.path.join(base_name, "data", args.task) + if not os.path.exists(folder_name): + os.makedirs(folder_name) + print("Building data...") + for data_type in data: + data[data_type].to_csv(os.path.join(folder_name, f"{data_type}.csv")) + data[data_type] = preprocess_function(data[data_type]) + data[data_type] = ForT5Dataset(inputs=data[data_type]["input_ids"], + targets=data[data_type]["labels"]) + + fp16 = False + if torch.cuda.is_available(): + print("use cuda") + fp16 = True + model.to("cuda") + + training_args = Seq2SeqTrainingArguments( + output_dir=output_dir, + evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=args.batch, + per_device_eval_batch_size=args.batch, + weight_decay=0.01, + save_total_limit=3, + num_train_epochs=2, + predict_with_generate=True, + fp16=fp16 + ) + + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=data["train"], + eval_dataset=data["validation"], + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics, + ) + + trainer.train() + trainer.save_model() + + +if __name__ == "__main__": + main() diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py index 3bc6550fe903b86f38240b0061f41f91d6326e3e..efd2e137149eb07314fddb1492781cbb007392e8 100644 --- a/convlab/util/custom_util.py +++ b/convlab/util/custom_util.py @@ -253,13 +253,14 @@ def env_config(conf, policy_sys, check_book_constraints=True): if dst_sys: try: if dst_sys.return_confidence_scores: - policy_sys.vector.setup_uncertain_query(dst_sys.confidence_thresholds) + policy_sys.vector.setup_uncertain_query( + dst_sys.confidence_thresholds) except: logging.info('Uncertainty threshold not set.') simulator = PipelineAgent(nlu_usr, dst_usr, policy_usr, usr_nlg, 'user') system_pipeline = PipelineAgent(nlu_sys, dst_sys, policy_sys, sys_nlg, - 'sys', return_semantic_acts=conf['model']['sys_semantic_to_usr']) + 'sys') # , return_semantic_acts=conf['model']['sys_semantic_to_usr']) # assemble evaluator = MultiWozEvaluator( diff --git a/data/unified_datasets/dialmage/README.md b/data/unified_datasets/dialmage/README.md new file mode 100644 index 0000000000000000000000000000000000000000..696fbdd0637bae48ee83b64016004eeb3388aebe --- /dev/null +++ b/data/unified_datasets/dialmage/README.md @@ -0,0 +1,67 @@ +## EmoWOZ + +This is the codebase for [EmoWOZ: A Large-Scale Corpus and Labelling Scheme for Emotion Recognition in Task-Oriented Dialogue Systems](https://arxiv.org/abs/2109.04919). + + +### Data + +The dataset can be found in `data/`. EmoWOZ adopts the same format as MultiWOZ logs. We add an additional `emotion` field in each log item. The emotion contains annotations by three annotators, each identified by an anonymous 8-character global annotator id. The `final` field contains the final label obtained either from majority voting or manual resolution. + +All DialMAGE dialogues have a dialogue id in the form of ''DMAGExxx.json'' where xxx is a number. We provide dialog_act and span_info used to generate system responses in DialMAGE. + +The definition for each label is defined as below: +| Label | Emotion Tokens | Valence | Elicitor | Conduct | +|-------|------------------------------|----------|------------|----------| +| 0 | Neutral | Neutral | Any | Polite | +| 1 | Fearful, sad, disappointed | Negative | Event/fact | Polite | +| 2 | Dissatisfied, disliking | Negative | Operator | Polite | +| 3 | Apologetic | Negative | User | Polite | +| 4 | Abusive | Negative | Operator | Impolite | +| 5 | Excited, happy, anticipating | Positive | Event/fact | Polite | +| 6 | Satisfied, liking | Positive | Operator | Polite | + +EmoWOZ dataset is licensed under Creative Commons Attribution-NonCommercial 4.0 International Public License and later. + + +### Baseline Models + +To test the dataset with baseline models used in the paper, please follow instructions in each model folder of `baselines/`. +The implementation of two models, `baselines/COSMIC/` and `baselines/DialogueRNN/`, are taken and modified from https://github.com/declare-lab/conv-emotion. + +### Requirements + +See `requirements.txt`. These are packages required for running all baseline models. Tested versions are listed below: +- Python (tested: 3.7.8) +- transformers (tested: 4.12.5) +- torch (tested: 1.8.1) +- pandas (tested: 1.3.4) +- sklearn (tested: 1.0.1) +- tqdm (tested: 4.62.3) +- nltk (tested: 3.6.5) +- ftfy (tested: 6.0.3) +- spacy (tested: 3.2.0) +- ipython (tested: 7.30.1) +- keras (tested: 2.7.0) +- tensorflow (2.7.0) + + +### Citation + +If you use EmoWOZ in your own work, please cite our work as follows: + +``` +@misc{feng2021emowoz, + title={EmoWOZ: A Large-Scale Corpus and Labelling Scheme for Emotion in Task-Oriented Dialogue Systems}, + author={Shutong Feng and Nurul Lubis and Christian Geishauser and Hsien-chin Lin and Michael Heck and Carel van Niekerk and Milica Gašić}, + year={2021}, + eprint={2109.04919}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` +Please note that this dataset should only be used for research purpose. + + +### Contact + +Any questions or bug reports can be sent to shutong.feng@hhu.de \ No newline at end of file diff --git a/data/unified_datasets/dialmage/emowoz/data_split.json b/data/unified_datasets/dialmage/emowoz/data_split.json new file mode 100644 index 0000000000000000000000000000000000000000..8021f3c01754384ca7e927109d726d745ab785ae --- /dev/null +++ b/data/unified_datasets/dialmage/emowoz/data_split.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32bb14bab3a5c2a6ed21c399e0825666b04dc340264686193cf396424eb0971b +size 328466 diff --git a/data/unified_datasets/dialmage/emowoz/emowoz-dialmage.json b/data/unified_datasets/dialmage/emowoz/emowoz-dialmage.json new file mode 100644 index 0000000000000000000000000000000000000000..17addd91c651387336bbe17534a46898c1e076d1 --- /dev/null +++ b/data/unified_datasets/dialmage/emowoz/emowoz-dialmage.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfc5d48fe8df1949b5eafd5a016393e57dbed03df1acec64569e3210bec9dbd9 +size 18098247 diff --git a/data/unified_datasets/dialmage/preprocess.py b/data/unified_datasets/dialmage/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e53a6319819146559959040deb557f1d3b8cb282 --- /dev/null +++ b/data/unified_datasets/dialmage/preprocess.py @@ -0,0 +1,1217 @@ +import copy +import re +from zipfile import ZipFile, ZIP_DEFLATED +from shutil import copy2, rmtree +import json +import os +from tqdm import tqdm +from collections import Counter +from pprint import pprint +from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer +import pickle + +ontology = { + "domains": { # descriptions are adapted from multiwoz22, but is_categorical may be different + "attraction": { + "description": "find an attraction", + "slots": { + "area": { + "description": "area to search for attractions", + "is_categorical": True, + "possible_values": [ + "centre", + "east", + "north", + "south", + "west" + ] + }, + "name": { + "description": "name of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "type": { + "description": "type of the attraction", + "is_categorical": True, + "possible_values": [ + "architecture", + "boat", + "cinema", + "college", + "concerthall", + "entertainment", + "museum", + "multiple sports", + "nightclub", + "park", + "swimmingpool", + "theatre" + ] + }, + "entrance fee": { + "description": "how much is the entrance fee", + "is_categorical": False, + "possible_values": [] + }, + "open hours": { + "description": "open hours of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "choice": { + "description": "number of attractions that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "hotel": { + "description": "find and book a hotel", + "slots": { + "internet": { + "description": "whether the hotel has internet", + "is_categorical": True, + "possible_values": [ + "free", + "no", + "yes" + ] + }, + "parking": { + "description": "whether the hotel has parking", + "is_categorical": True, + "possible_values": [ + "free", + "no", + "yes" + ] + }, + "area": { + "description": "area or place of the hotel", + "is_categorical": True, + "possible_values": [ + "centre", + "east", + "north", + "south", + "west" + ] + }, + "stars": { + "description": "star rating of the hotel", + "is_categorical": True, + "possible_values": [ + "0", + "1", + "2", + "3", + "4", + "5" + ] + }, + "price range": { + "description": "price budget of the hotel", + "is_categorical": True, + "possible_values": [ + "expensive", + "cheap", + "moderate" + ] + }, + "type": { + "description": "what is the type of the hotel", + "is_categorical": False, + "possible_values": [ + "guesthouse", + "hotel" + ] + }, + "name": { + "description": "name of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "book people": { + "description": "number of people for the hotel booking", + "is_categorical": False, + "possible_values": [] + }, + "book stay": { + "description": "length of stay at the hotel", + "is_categorical": False, + "possible_values": [] + }, + "book day": { + "description": "day of the hotel booking", + "is_categorical": True, + "possible_values": [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday" + ] + }, + "phone": { + "description": "phone number of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "ref": { + "description": "reference number of the hotel booking", + "is_categorical": False, + "possible_values": [] + }, + "choice": { + "description": "number of hotels that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "taxi": { + "description": "rent taxi to travel", + "slots": { + "destination": { + "description": "destination of taxi", + "is_categorical": False, + "possible_values": [] + }, + "departure": { + "description": "departure location of taxi", + "is_categorical": False, + "possible_values": [] + }, + "leave at": { + "description": "leaving time of taxi", + "is_categorical": False, + "possible_values": [] + }, + "arrive by": { + "description": "arrival time of taxi", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the taxi", + "is_categorical": False, + "possible_values": [] + }, + "type": { + "description": "car type of the taxi", + "is_categorical": False, + "possible_values": [] + } + } + }, + "restaurant": { + "description": "find and book a restaurant", + "slots": { + "price range": { + "description": "price budget for the restaurant", + "is_categorical": True, + "possible_values": [ + "cheap", + "expensive", + "moderate" + ] + }, + "area": { + "description": "area or place of the restaurant", + "is_categorical": True, + "possible_values": [ + "centre", + "east", + "north", + "south", + "west" + ] + }, + "food": { + "description": "the cuisine of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "name": { + "description": "name of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "book people": { + "description": "number of people for the restaurant booking", + "is_categorical": False, + "possible_values": [] + }, + "book time": { + "description": "time of the restaurant booking", + "is_categorical": False, + "possible_values": [] + }, + "book day": { + "description": "day of the restaurant booking", + "is_categorical": True, + "possible_values": [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday" + ] + }, + "ref": { + "description": "reference number of the restaurant booking", + "is_categorical": False, + "possible_values": [] + }, + "choice": { + "description": "number of restaurants that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "train": { + "description": "find a train to travel", + "slots": { + "destination": { + "description": "destination of the train", + "is_categorical": False, + "possible_values": [] + }, + "arrive by": { + "description": "arrival time of the train", + "is_categorical": False, + "possible_values": [] + }, + "departure": { + "description": "departure location of the train", + "is_categorical": False, + "possible_values": [] + }, + "leave at": { + "description": "leaving time for the train", + "is_categorical": False, + "possible_values": [] + }, + "duration": { + "description": "duration of the travel", + "is_categorical": False, + "possible_values": [] + }, + "book people": { + "description": "number of people booking for train", + "is_categorical": False, + "possible_values": [] + }, + "day": { + "description": "day of the train", + "is_categorical": True, + "possible_values": [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday" + ] + }, + "ref": { + "description": "reference number of the train booking", + "is_categorical": False, + "possible_values": [] + }, + "price": { + "description": "price of the train ticket", + "is_categorical": False, + "possible_values": [] + }, + "train id": { + "description": "id of the train", + "is_categorical": False + }, + "choice": { + "description": "number of trains that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "police": { + "description": "find a police station for help", + "slots": { + "name": { + "description": "name of the police station", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the police station", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the police station", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the police station", + "is_categorical": False, + "possible_values": [] + } + } + }, + "hospital": { + "description": "find a hospital for help", + "slots": { + "department": { + "description": "specific department of the hospital", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the hospital", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the hospital", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the hospital", + "is_categorical": False, + "possible_values": [] + } + } + }, + "general": { + "description": "general domain without slots", + "slots": {} + } + }, + "intents": { + "inform": { + "description": "inform the value of a slot" + }, + "request": { + "description": "ask for the value of a slot" + }, + "nobook": { + "description": "inform the user that the booking is failed" + }, + "reqmore": { + "description": "ask the user for more instructions" + }, + "book": { + "description": "book something for the user" + }, + "bye": { + "description": "say goodbye to the user and end the conversation" + }, + "thank": { + "description": "thanks for the help" + }, + "welcome": { + "description": "you're welcome" + }, + "greet": { + "description": "express greeting" + }, + "recommend": { + "description": "recommend a choice to the user" + }, + "select": { + "description": "provide several choices for the user" + }, + "offerbook": { + "description": "ask the user if he or she needs booking" + }, + "offerbooked": { + "description": "provide information about the booking" + }, + "nooffer": { + "description": "inform the user that there is no result satisfies user requirements" + } + }, + "state": { + "attraction": { + "type": "", + "name": "", + "area": "" + }, + "hotel": { + "name": "", + "area": "", + "parking": "", + "price range": "", + "stars": "", + "internet": "", + "type": "", + "book stay": "", + "book day": "", + "book people": "" + }, + "restaurant": { + "food": "", + "price range": "", + "name": "", + "area": "", + "book time": "", + "book day": "", + "book people": "" + }, + "taxi": { + "leave at": "", + "destination": "", + "departure": "", + "arrive by": "" + }, + "train": { + "leave at": "", + "destination": "", + "day": "", + "arrive by": "", + "departure": "", + "book people": "" + }, + "hospital": { + "department": "" + } + }, + "dialogue_acts": { + "categorical": {}, + "non-categorical": {}, + "binary": {} + } +} + +slot_name_map = { + 'addr': "address", + 'post': "postcode", + 'pricerange': "price range", + 'arrive': "arrive by", + 'arriveby': "arrive by", + 'leave': "leave at", + 'leaveat': "leave at", + 'depart': "departure", + 'dest': "destination", + 'fee': "entrance fee", + 'open': 'open hours', + 'car': "type", + 'car type': "type", + 'ticket': 'price', + 'trainid': 'train id', + 'id': 'train id', + 'people': 'book people', + 'stay': 'book stay', + 'none': '', + 'attraction': { + 'price': 'entrance fee' + }, + 'hospital': {}, + 'hotel': { + 'day': 'book day', 'price': "price range" + }, + 'restaurant': { + 'day': 'book day', 'time': 'book time', 'price': "price range" + }, + 'taxi': {}, + 'train': { + 'day': 'day', 'time': "duration" + }, + 'police': {}, + 'booking': {} +} + +reverse_da_slot_name_map = { + 'address': 'Addr', + 'postcode': 'Post', + 'price range': 'Price', + 'arrive by': 'Arrive', + 'leave at': 'Leave', + 'departure': 'Depart', + 'destination': 'Dest', + 'entrance fee': 'Fee', + 'open hours': 'Open', + 'price': 'Ticket', + 'train id': 'Id', + 'book people': 'People', + 'book stay': 'Stay', + 'book day': 'Day', + 'book time': 'Time', + 'duration': 'Time', + 'taxi': { + 'type': 'Car', + 'phone': 'Phone' + } +} + +digit2word = { + '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', '5': 'five', + '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', '10': 'ten' +} + +cnt_domain_slot = Counter() + + +class BookingActRemapper: + + def __init__(self, ontology): + self.ontology = ontology + self.reset() + + def reset(self): + self.current_domains_user = [] + self.current_domains_system = [] + self.booked_domains = [] + + def retrieve_current_domain_from_user(self, turn_id, ori_dialog): + prev_user_turn = ori_dialog[turn_id - 1] + + dialog_acts = prev_user_turn.get('dialog_act', []) + keyword_domains_user = get_keyword_domains(prev_user_turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + self.current_domains_user = current_domains_temp if current_domains_temp else self.current_domains_user + next_user_domains = get_next_user_act_domains(ori_dialog, turn_id) + + return keyword_domains_user, next_user_domains + + def retrieve_current_domain_from_system(self, turn_id, ori_dialog): + + system_turn = ori_dialog[turn_id] + dialog_acts = system_turn.get('dialog_act', []) + keyword_domains_system = get_keyword_domains(system_turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + self.current_domains_system = current_domains_temp if current_domains_temp else self.current_domains_system + booked_domain_current = self.check_domain_booked(system_turn) + + return keyword_domains_system, booked_domain_current + + def remap(self, turn_id, ori_dialog): + + keyword_domains_user, next_user_domains = self.retrieve_current_domain_from_user( + turn_id, ori_dialog) + keyword_domains_system, booked_domain_current = self.retrieve_current_domain_from_system( + turn_id, ori_dialog) + + # only need to remap if there is a dialog action labelled + dialog_acts = ori_dialog[turn_id].get('dialog_act', []) + spans = ori_dialog[turn_id].get('span_info', []) + if dialog_acts: + + flattened_acts = flatten_acts(dialog_acts) + # flattened_spans = flatten_span_acts(spans) + remapped_acts, error_local = remap_acts(flattened_acts, self.current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, self.current_domains_system, + next_user_domains, self.ontology) + + # remapped_spans, _ = remap_acts(flattened_spans, self.current_domains_user, + # booked_domain_current, keyword_domains_user, + # keyword_domains_system, self.current_domains_system, + # next_user_domains, self.ontology) + + deflattened_remapped_acts = deflat_acts(remapped_acts) + # deflattened_remapped_spans = deflat_span_acts(remapped_spans) + + return deflattened_remapped_acts, spans # deflattened_remapped_spans + else: + return dialog_acts, spans + + def check_domain_booked(self, turn): + + booked_domain_current = None + return booked_domain_current + + # workaround + for domain in turn['metadata']: + if turn['metadata'][domain]["book"]["booked"] and domain not in self.booked_domains: + booked_domain_current = domain.capitalize() + self.booked_domains.append(domain) + return booked_domain_current + + +def get_keyword_domains(turn): + keyword_domains = [] + text = turn['text'] + for d in ["Hotel", "Restaurant", "Train"]: + if d.lower() in text.lower(): + keyword_domains.append(d) + return keyword_domains + + +def get_current_domains_from_act(dialog_acts): + + current_domains_temp = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + if domain in ["general", "Booking"]: + continue + if domain not in current_domains_temp: + current_domains_temp.append(domain) + + return current_domains_temp + + +def get_next_user_act_domains(ori_dialog, turn_id): + domains = [] + try: + next_user_act = ori_dialog[turn_id + 1]['dialog_act'] + domains = get_current_domains_from_act(next_user_act) + except: + # will fail if system act is the last act of the dialogue + pass + return domains + + +def flatten_acts(dialog_acts): + flattened_acts = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + for slot_value in dialog_acts[dom_int]: + slot = slot_value[0] + value = slot_value[1] + flattened_acts.append((domain, intent, slot, value)) + + return flattened_acts + + +def flatten_span_acts(span_acts): + + flattened_acts = [] + for span_act in span_acts: + domain, intent = span_act[0].split("-") + flattened_acts.append((domain, intent, span_act[1], span_act[2:])) + return flattened_acts + + +def deflat_acts(flattened_acts): + + dialog_acts = dict() + + for act in flattened_acts: + domain, intent, slot, value = act + if f"{domain}-{intent}" not in dialog_acts.keys(): + dialog_acts[f"{domain}-{intent}"] = [[slot, value]] + else: + dialog_acts[f"{domain}-{intent}"].append([slot, value]) + + return dialog_acts + + +def deflat_span_acts(flattened_acts): + + dialog_span_acts = [] + for act in flattened_acts: + domain, intent, slot, value = act + if value == 'none': + continue + new_act = [f"{domain}-{intent}", slot] + new_act.extend(value) + dialog_span_acts.append(new_act) + + return dialog_span_acts + + +def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_domains_user=None, + keyword_domains_system=None, current_domain_system=None, next_user_domain=None, ontology=None): + + # We now look for all cases that can happen: Booking domain, Booking within a domain or taxi-inform-car for booking + error = 0 + remapped_acts = [] + + # if there is more than one current domain or none at all, we try to get booked domain differently + if len(current_domains) != 1 and booked_domain: + current_domains = [booked_domain] + elif len(current_domains) != 1 and len(keyword_domains_user) == 1: + current_domains = keyword_domains_user + elif len(current_domains) != 1 and len(keyword_domains_system) == 1: + current_domains = keyword_domains_system + elif len(current_domains) != 1 and len(current_domain_system) == 1: + current_domains = current_domain_system + elif len(current_domains) != 1 and len(next_user_domain) == 1: + current_domains = next_user_domain + + for act in flattened_acts: + try: + domain, intent, slot, value = act + if f"{domain}-{intent}-{slot}" == "Booking-Book-Ref": + # We need to remap that booking act now + potential_domain = current_domains[0] + remapped_acts.append( + (potential_domain, "Book", "none", "none")) + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, "Inform", "Ref", value)) + elif domain == "Booking" and intent == "Book" and slot != "Ref": + # the book intent is here actually an inform intent according to the data + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, "Inform", slot, value)) + elif domain == "Booking" and intent == "Inform": + # the inform intent is here actually a request intent according to the data + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, "OfferBook", slot, value)) + elif domain == "Booking" and intent in ["NoBook", "Request"]: + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, intent, slot, value)) + elif f"{domain}-{intent}-{slot}" == "Taxi-Inform-Car": + # taxi-inform-car actually triggers the booking and informs on a car + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, intent, slot, value)) + elif f"{domain}-{intent}-{slot}" in ["Train-Inform-Ref", "Train-OfferBooked-Ref"]: + # train-inform/offerbooked-ref actually triggers the booking and informs on the reference number + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, "Inform", slot, value)) + elif domain == "Train" and intent == "OfferBooked" and slot != "Ref": + # this is actually an inform act + remapped_acts.append((domain, "Inform", slot, value)) + else: + remapped_acts.append(act) + except Exception as e: + print("Error detected:", e) + error += 1 + + return remapped_acts, error + + +def ontology_check(domain_, slot_, init_ontology): + + domain = domain_.lower() + slot = slot_.lower() + if slot not in init_ontology['domains'][domain]['slots']: + if slot in slot_name_map: + slot = slot_name_map[slot] + elif slot in slot_name_map[domain]: + slot = slot_name_map[domain][slot] + return slot in init_ontology['domains'][domain]['slots'] + + +def reverse_da(dialogue_acts): + global reverse_da_slot_name_map + das = {} + for da_type in dialogue_acts: + for da in dialogue_acts[da_type]: + intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da.get( + 'value', '') + if domain == 'general': + Domain_Intent = '-'.join([domain, intent]) + elif intent == 'nooffer': + Domain_Intent = '-'.join([domain.capitalize(), 'NoOffer']) + elif intent == 'nobook': + Domain_Intent = '-'.join([domain.capitalize(), 'NoBook']) + elif intent == 'offerbook': + Domain_Intent = '-'.join([domain.capitalize(), 'OfferBook']) + else: + Domain_Intent = '-'.join([domain.capitalize(), + intent.capitalize()]) + das.setdefault(Domain_Intent, []) + if slot in reverse_da_slot_name_map: + Slot = reverse_da_slot_name_map[slot] + elif domain in reverse_da_slot_name_map and slot in reverse_da_slot_name_map[domain]: + Slot = reverse_da_slot_name_map[domain][slot] + else: + Slot = slot.capitalize() + if value == '': + if intent == 'request': + value = '?' + else: + value = 'none' + if Slot == '': + Slot = 'none' + das[Domain_Intent].append([Slot, value]) + return das + + +def normalize_domain_slot_value(domain, slot, value): + global ontology, slot_name_map + domain = domain.lower() + slot = slot.lower() + value = value.strip() + if value in ['do nt care', "do n't care"]: + value = 'dontcare' + if value in ['?', 'none', 'not mentioned']: + value = "" + if domain not in ontology['domains']: + raise Exception(f'{domain} not in ontology') + if slot not in ontology['domains'][domain]['slots']: + if slot in slot_name_map: + slot = slot_name_map[slot] + elif slot in slot_name_map[domain]: + slot = slot_name_map[domain][slot] + else: + raise Exception(f'{domain}-{slot} not in ontology') + assert slot == '' or slot in ontology['domains'][domain][ + 'slots'], f'{(domain, slot, value)} not in ontology' + return domain, slot, value + + +def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer): + ''' + convert multiwoz dialogue acts to required format + :param da_dict: dict[(intent, domain, slot, value)] = [word_start, word_end] + :param utt: user or system utt + ''' + global ontology, digit2word, cnt_domain_slot + + converted_da = { + 'categorical': [], + 'non-categorical': [], + 'binary': [] + } + sentences = sent_tokenizer.tokenize(utt) + sent_spans = sent_tokenizer.span_tokenize(utt) + tokens = [ + token for sent in sentences for token in word_tokenizer.tokenize(sent)] + token_spans = [(sent_span[0] + token_span[0], sent_span[0] + token_span[1]) for sent, sent_span in + zip(sentences, sent_spans) for token_span in word_tokenizer.span_tokenize(sent)] + # assert len(tokens) == len(token_spans) + # for token, span in zip(tokens, token_spans): + # if utt[span[0]:span[1]] != '"': + # assert utt[span[0]:span[1]] == token + + for (intent, domain, slot, value), span in da_dict.items(): + if intent == 'request' or slot == '' or value == '': + # binary dialog acts + assert value == '' + converted_da['binary'].append({ + 'intent': intent, + 'domain': domain, + 'slot': slot + }) + elif ontology['domains'][domain]['slots'][slot]['is_categorical']: + # categorical dialog acts + converted_da['categorical'].append({ + 'intent': intent, + 'domain': domain, + 'slot': slot, + 'value': value + }) + else: + # non-categorical dialog acts + converted_da['non-categorical'].append({ + 'intent': intent, + 'domain': domain, + 'slot': slot, + 'value': value + }) + # correct some value and try to give char level span + match = False + value = value.lower() + if span and span[0] <= span[1]: + # use original span annotation, but tokenizations are different + start_word, end_word = span + if end_word >= len(tokens): + # due to different tokenization, sometimes will out of index + delta = end_word - len(tokens) + 1 + start_word -= delta + end_word -= delta + start_char, end_char = token_spans[start_word][0], token_spans[end_word][1] + value_span = utt[start_char:end_char].lower() + match = True + if value_span == value: + cnt_domain_slot['span match'] += 1 + elif value.isdigit() and value in digit2word and digit2word[value] == value_span: + # !!!CHANGE VALUE: value is digit but value span is word + cnt_domain_slot['digit value match'] += 1 + elif ''.join(value.split()) == ''.join(value_span.split()): + # !!!CHANGE VALUE: equal when remove blank + cnt_domain_slot['remove blank'] += 1 + elif value in value_span: + # value in value_span + start_char += value_span.index(value) + end_char = start_char + len(value) + assert utt[start_char:end_char].lower( + ) == value, f'{[value, utt[start_char:end_char], utt]}' + cnt_domain_slot['value in span'] += 1 + elif ':' in value and value == '0' + value_span: + # !!!CHANGE VALUE: time x:xx == 0x:xx + cnt_domain_slot['x:xx == 0x:xx'] += 1 + else: + # span mismatch, search near 1-2 words + for window in range(1, 3): + start = max(0, start_word - window) + end = min(len(token_spans) - 1, end_word + window) + large_span = utt[token_spans[start] + [0]:token_spans[end][1]].lower() + if value in large_span: + start_char = token_spans[start][0] + \ + large_span.index(value) + end_char = start_char + len(value) + assert utt[ + start_char:end_char].lower() == value, f'{[value, utt[start_char:end_char], utt]}' + cnt_domain_slot[f'window={window}'] += 1 + break + else: + # still not found + match = False + + if match: + converted_da['non-categorical'][-1]['value'] = utt[start_char:end_char] + converted_da['non-categorical'][-1]['start'] = start_char + converted_da['non-categorical'][-1]['end'] = end_char + cnt_domain_slot['have span'] += 1 + else: + cnt_domain_slot['no span'] += 1 + return converted_da + + +def act_list2dict(act_list): + act_dict = {} + for intent, domain, slot, value in act_list: + key = f"{domain}-{intent}" + if key not in act_dict: + act_dict[key] = [] + act_dict[key].append([slot, value]) + return act_dict + + +def preprocess(): + original_data_dir = 'emowoz' + new_data_dir = 'data' + + if not os.path.exists(original_data_dir): + original_data_zip = 'MultiWOZ_2.1.zip' + if not os.path.exists(original_data_zip): + raise FileNotFoundError( + f'cannot find original data {original_data_zip} in multiwoz21/, should manually download MultiWOZ_2.1.zip from https://github.com/budzianowski/multiwoz/blob/master/data/MultiWOZ_2.1.zip') + else: + archive = ZipFile(original_data_zip) + archive.extractall() + + os.makedirs(new_data_dir, exist_ok=True) + for filename in os.listdir(original_data_dir): + if 'db' in filename: + copy2(f'{original_data_dir}/{filename}', new_data_dir) + + # how about emowoz-dialmage + original_data = json.load( + open(f'{original_data_dir}/emowoz-dialmage.json')) + global ontology, cnt_domain_slot + + raw_data = pickle.load(open('dialog_state.pkl', 'rb')) + actions = raw_data[0] + + data_split = json.load(open(f'{original_data_dir}/data_split.json')) + val_list = data_split["dev"]["dialmage"] + test_list = data_split["test"]["dialmage"] + dataset = 'emowoz-dialmage' + splits = ['train', 'validation', 'test'] + dialogues_by_split = {split: [] for split in splits} + sent_tokenizer = PunktSentenceTokenizer() + word_tokenizer = TreebankWordTokenizer() + booking_remapper = BookingActRemapper(ontology) + for ori_dialog_id, ori_dialog in tqdm(original_data.items()): + act = actions[ori_dialog_id] + for turn_id in range(len(ori_dialog["log"])): + ori_dialog["log"][turn_id]["dialog_act"] = act_list2dict( + act[turn_id]) + + if ori_dialog_id in val_list: + split = 'validation' + elif ori_dialog_id in test_list: + split = 'test' + else: + split = 'train' + dialogue_id = f'{dataset}-{split}-{len(dialogues_by_split[split])}' + + # get user goal and involved domains + cur_domains = [] + + dialogue = { + 'dataset': dataset, + 'data_split': split, + 'dialogue_id': dialogue_id, + 'original_id': ori_dialog_id, + 'domains': cur_domains, # will be updated by dialog_acts and state + 'goal': "", + 'turns': [] + } + + booking_remapper.reset() + belief_domains = ['attraction', 'restaurant', + 'train', 'hotel', 'taxi', 'hospital'] + entity_booked_dict = dict((domain, False) for domain in belief_domains) + + for turn_id, turn in enumerate(ori_dialog['log']): + # correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1 + text = turn['text'] + text = re.sub(" Im ", " I'm ", text) + text = re.sub(" im ", " i'm ", text) + text = re.sub(r"^Im ", "I'm ", text) + text = re.sub(r"^im ", "i'm ", text) + text = re.sub("theres", "there's", text) + text = re.sub("dont", "don't", text) + text = re.sub("whats", "what's", text) + text = re.sub('thats', "that's", text) + utt = text + speaker = 'user' if turn_id % 2 == 0 else 'system' + + das = turn.get('dialog_act', []) + spans = turn.get('span_info', []) + + # if speaker == 'system': + das, spans = booking_remapper.remap(turn_id, ori_dialog['log']) + + da_dict = {} + # transform DA + for Domain_Intent in das: + domain, intent = Domain_Intent.lower().split('-') + assert intent in ontology['intents'], f'{ori_dialog_id}:{turn_id}:da\t{intent} not in ontology' + for Slot, value in das[Domain_Intent]: + domain, slot, value = normalize_domain_slot_value( + domain, Slot, value) + if domain not in cur_domains: + # update original cur_domains + cur_domains.append(domain) + da_dict[(intent, domain, slot, value,)] = [] + + # for span in spans: + # Domain_Intent, Slot, value, start_word, end_word = span + # domain, intent = Domain_Intent.lower().split('-') + # domain, slot, value = normalize_domain_slot_value( + # domain, Slot, value) + # print(da_dict) + # assert (intent, domain, slot, value,) in da_dict + # da_dict[(intent, domain, slot, value,)] = [ + # start_word, end_word] + + dialogue_acts = convert_da( + da_dict, utt, sent_tokenizer, word_tokenizer) + + # reverse_das = reverse_da(dialogue_acts) + # das_list = sorted([(Domain_Intent, Slot, ''.join(value.split()).lower()) for Domain_Intent in das for Slot, value in das[Domain_Intent]]) + # reverse_das_list = sorted([(Domain_Intent, Slot, ''.join(value.split()).lower()) for Domain_Intent in reverse_das for Slot, value in reverse_das[Domain_Intent]]) + # if das_list != reverse_das_list: + # print(das_list) + # print(reverse_das_list) + # print() + # print() + + dialogue['turns'].append({ + 'speaker': speaker, + 'utterance': utt, + 'utt_idx': turn_id, + 'dialogue_acts': dialogue_acts, + 'emotion': turn['emotion'] + }) + + # add to dialogue_acts dictionary in the ontology + for da_type in dialogue_acts: + das = dialogue_acts[da_type] + for da in das: + ontology["dialogue_acts"][da_type].setdefault( + (da['intent'], da['domain'], da['slot']), {}) + ontology["dialogue_acts"][da_type][( + da['intent'], da['domain'], da['slot'])][speaker] = True + + if speaker == 'system': + # add state to last user turn + # add empty db_results + # turn_state = turn['metadata'] + cur_state = copy.deepcopy(ontology['state']) + booked = {} + # for domain in turn_state: + # if domain not in cur_state: + # continue + # for subdomain in ['semi', 'book']: + # for slot, value in turn_state[domain][subdomain].items(): + # if slot == 'ticket': + # continue + # elif slot == 'booked': + # assert domain in ontology['domains'] + # booked[domain] = value + # continue + # _, slot, value = normalize_domain_slot_value( + # domain, slot, value) + # cur_state[domain][slot] = value + dialogue['turns'][-2]['state'] = cur_state + # entity_booked_dict, booked = fix_entity_booked_info( + # entity_booked_dict, booked) + dialogue['turns'][-1]['booked'] = booked + dialogues_by_split[split].append(dialogue) + # pprint(cnt_domain_slot.most_common()) + dialogues = [] + for split in splits: + dialogues += dialogues_by_split[split] + for da_type in ontology['dialogue_acts']: + ontology["dialogue_acts"][da_type] = sorted([str( + {'user': speakers.get('user', False), 'system': speakers.get('system', False), 'intent': da[0], + 'domain': da[1], 'slot': da[2]}) for da, speakers in ontology["dialogue_acts"][da_type].items()]) + json.dump(dialogues[:10], open(f'dummy_data.json', 'w', + encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(ontology, open(f'{new_data_dir}/ontology.json', + 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', + 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: + for filename in os.listdir(new_data_dir): + zf.write(f'{new_data_dir}/{filename}') + # rmtree(original_data_dir) + # rmtree(new_data_dir) + return dialogues, ontology + + +def fix_entity_booked_info(entity_booked_dict, booked): + for domain in entity_booked_dict: + if not entity_booked_dict[domain] and booked[domain]: + entity_booked_dict[domain] = True + booked[domain] = [] + return entity_booked_dict, booked + + +if __name__ == '__main__': + preprocess() diff --git a/data/unified_datasets/dialmage/shuffled_dial_ids.json b/data/unified_datasets/dialmage/shuffled_dial_ids.json new file mode 100644 index 0000000000000000000000000000000000000000..1eaa37c9a3a371632126b96fc719c9915be0500b --- /dev/null +++ b/data/unified_datasets/dialmage/shuffled_dial_ids.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81c036e40d0d7a149163363b11c853c635c8157d2d59be7de686132e81a0b34c +size 4684 diff --git a/data/unified_datasets/emowoz/README.md b/data/unified_datasets/emowoz/README.md new file mode 100644 index 0000000000000000000000000000000000000000..696fbdd0637bae48ee83b64016004eeb3388aebe --- /dev/null +++ b/data/unified_datasets/emowoz/README.md @@ -0,0 +1,67 @@ +## EmoWOZ + +This is the codebase for [EmoWOZ: A Large-Scale Corpus and Labelling Scheme for Emotion Recognition in Task-Oriented Dialogue Systems](https://arxiv.org/abs/2109.04919). + + +### Data + +The dataset can be found in `data/`. EmoWOZ adopts the same format as MultiWOZ logs. We add an additional `emotion` field in each log item. The emotion contains annotations by three annotators, each identified by an anonymous 8-character global annotator id. The `final` field contains the final label obtained either from majority voting or manual resolution. + +All DialMAGE dialogues have a dialogue id in the form of ''DMAGExxx.json'' where xxx is a number. We provide dialog_act and span_info used to generate system responses in DialMAGE. + +The definition for each label is defined as below: +| Label | Emotion Tokens | Valence | Elicitor | Conduct | +|-------|------------------------------|----------|------------|----------| +| 0 | Neutral | Neutral | Any | Polite | +| 1 | Fearful, sad, disappointed | Negative | Event/fact | Polite | +| 2 | Dissatisfied, disliking | Negative | Operator | Polite | +| 3 | Apologetic | Negative | User | Polite | +| 4 | Abusive | Negative | Operator | Impolite | +| 5 | Excited, happy, anticipating | Positive | Event/fact | Polite | +| 6 | Satisfied, liking | Positive | Operator | Polite | + +EmoWOZ dataset is licensed under Creative Commons Attribution-NonCommercial 4.0 International Public License and later. + + +### Baseline Models + +To test the dataset with baseline models used in the paper, please follow instructions in each model folder of `baselines/`. +The implementation of two models, `baselines/COSMIC/` and `baselines/DialogueRNN/`, are taken and modified from https://github.com/declare-lab/conv-emotion. + +### Requirements + +See `requirements.txt`. These are packages required for running all baseline models. Tested versions are listed below: +- Python (tested: 3.7.8) +- transformers (tested: 4.12.5) +- torch (tested: 1.8.1) +- pandas (tested: 1.3.4) +- sklearn (tested: 1.0.1) +- tqdm (tested: 4.62.3) +- nltk (tested: 3.6.5) +- ftfy (tested: 6.0.3) +- spacy (tested: 3.2.0) +- ipython (tested: 7.30.1) +- keras (tested: 2.7.0) +- tensorflow (2.7.0) + + +### Citation + +If you use EmoWOZ in your own work, please cite our work as follows: + +``` +@misc{feng2021emowoz, + title={EmoWOZ: A Large-Scale Corpus and Labelling Scheme for Emotion in Task-Oriented Dialogue Systems}, + author={Shutong Feng and Nurul Lubis and Christian Geishauser and Hsien-chin Lin and Michael Heck and Carel van Niekerk and Milica Gašić}, + year={2021}, + eprint={2109.04919}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` +Please note that this dataset should only be used for research purpose. + + +### Contact + +Any questions or bug reports can be sent to shutong.feng@hhu.de \ No newline at end of file diff --git a/data/unified_datasets/emowoz/emowoz/data_split.json b/data/unified_datasets/emowoz/emowoz/data_split.json new file mode 100644 index 0000000000000000000000000000000000000000..8021f3c01754384ca7e927109d726d745ab785ae --- /dev/null +++ b/data/unified_datasets/emowoz/emowoz/data_split.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32bb14bab3a5c2a6ed21c399e0825666b04dc340264686193cf396424eb0971b +size 328466 diff --git a/data/unified_datasets/emowoz/emowoz/emowoz-dialmage.json b/data/unified_datasets/emowoz/emowoz/emowoz-dialmage.json new file mode 100644 index 0000000000000000000000000000000000000000..17addd91c651387336bbe17534a46898c1e076d1 --- /dev/null +++ b/data/unified_datasets/emowoz/emowoz/emowoz-dialmage.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfc5d48fe8df1949b5eafd5a016393e57dbed03df1acec64569e3210bec9dbd9 +size 18098247 diff --git a/data/unified_datasets/emowoz/emowoz/emowoz-multiwoz.json b/data/unified_datasets/emowoz/emowoz/emowoz-multiwoz.json new file mode 100644 index 0000000000000000000000000000000000000000..2e5feede56c724bf2ecf5f01c448394ab68724f5 --- /dev/null +++ b/data/unified_datasets/emowoz/emowoz/emowoz-multiwoz.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:588c22e2d0d2f66399fcc545bb9d2596b816f4357bca7450d889d51c73e29318 +size 159559868 diff --git a/data/unified_datasets/emowoz/preprocess.py b/data/unified_datasets/emowoz/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..f89deae111271e914e067da5059aea310f108849 --- /dev/null +++ b/data/unified_datasets/emowoz/preprocess.py @@ -0,0 +1,1198 @@ +import copy +import re +from zipfile import ZipFile, ZIP_DEFLATED +from shutil import copy2, rmtree +import json +import os +from tqdm import tqdm +from collections import Counter +from pprint import pprint +from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer + +ontology = { + "domains": { # descriptions are adapted from multiwoz22, but is_categorical may be different + "attraction": { + "description": "find an attraction", + "slots": { + "area": { + "description": "area to search for attractions", + "is_categorical": True, + "possible_values": [ + "centre", + "east", + "north", + "south", + "west" + ] + }, + "name": { + "description": "name of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "type": { + "description": "type of the attraction", + "is_categorical": True, + "possible_values": [ + "architecture", + "boat", + "cinema", + "college", + "concerthall", + "entertainment", + "museum", + "multiple sports", + "nightclub", + "park", + "swimmingpool", + "theatre" + ] + }, + "entrance fee": { + "description": "how much is the entrance fee", + "is_categorical": False, + "possible_values": [] + }, + "open hours": { + "description": "open hours of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the attraction", + "is_categorical": False, + "possible_values": [] + }, + "choice": { + "description": "number of attractions that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "hotel": { + "description": "find and book a hotel", + "slots": { + "internet": { + "description": "whether the hotel has internet", + "is_categorical": True, + "possible_values": [ + "free", + "no", + "yes" + ] + }, + "parking": { + "description": "whether the hotel has parking", + "is_categorical": True, + "possible_values": [ + "free", + "no", + "yes" + ] + }, + "area": { + "description": "area or place of the hotel", + "is_categorical": True, + "possible_values": [ + "centre", + "east", + "north", + "south", + "west" + ] + }, + "stars": { + "description": "star rating of the hotel", + "is_categorical": True, + "possible_values": [ + "0", + "1", + "2", + "3", + "4", + "5" + ] + }, + "price range": { + "description": "price budget of the hotel", + "is_categorical": True, + "possible_values": [ + "expensive", + "cheap", + "moderate" + ] + }, + "type": { + "description": "what is the type of the hotel", + "is_categorical": False, + "possible_values": [ + "guesthouse", + "hotel" + ] + }, + "name": { + "description": "name of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "book people": { + "description": "number of people for the hotel booking", + "is_categorical": False, + "possible_values": [] + }, + "book stay": { + "description": "length of stay at the hotel", + "is_categorical": False, + "possible_values": [] + }, + "book day": { + "description": "day of the hotel booking", + "is_categorical": True, + "possible_values": [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday" + ] + }, + "phone": { + "description": "phone number of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the hotel", + "is_categorical": False, + "possible_values": [] + }, + "ref": { + "description": "reference number of the hotel booking", + "is_categorical": False, + "possible_values": [] + }, + "choice": { + "description": "number of hotels that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "taxi": { + "description": "rent taxi to travel", + "slots": { + "destination": { + "description": "destination of taxi", + "is_categorical": False, + "possible_values": [] + }, + "departure": { + "description": "departure location of taxi", + "is_categorical": False, + "possible_values": [] + }, + "leave at": { + "description": "leaving time of taxi", + "is_categorical": False, + "possible_values": [] + }, + "arrive by": { + "description": "arrival time of taxi", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the taxi", + "is_categorical": False, + "possible_values": [] + }, + "type": { + "description": "car type of the taxi", + "is_categorical": False, + "possible_values": [] + } + } + }, + "restaurant": { + "description": "find and book a restaurant", + "slots": { + "price range": { + "description": "price budget for the restaurant", + "is_categorical": True, + "possible_values": [ + "cheap", + "expensive", + "moderate" + ] + }, + "area": { + "description": "area or place of the restaurant", + "is_categorical": True, + "possible_values": [ + "centre", + "east", + "north", + "south", + "west" + ] + }, + "food": { + "description": "the cuisine of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "name": { + "description": "name of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the restaurant", + "is_categorical": False, + "possible_values": [] + }, + "book people": { + "description": "number of people for the restaurant booking", + "is_categorical": False, + "possible_values": [] + }, + "book time": { + "description": "time of the restaurant booking", + "is_categorical": False, + "possible_values": [] + }, + "book day": { + "description": "day of the restaurant booking", + "is_categorical": True, + "possible_values": [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday" + ] + }, + "ref": { + "description": "reference number of the restaurant booking", + "is_categorical": False, + "possible_values": [] + }, + "choice": { + "description": "number of restaurants that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "train": { + "description": "find a train to travel", + "slots": { + "destination": { + "description": "destination of the train", + "is_categorical": False, + "possible_values": [] + }, + "arrive by": { + "description": "arrival time of the train", + "is_categorical": False, + "possible_values": [] + }, + "departure": { + "description": "departure location of the train", + "is_categorical": False, + "possible_values": [] + }, + "leave at": { + "description": "leaving time for the train", + "is_categorical": False, + "possible_values": [] + }, + "duration": { + "description": "duration of the travel", + "is_categorical": False, + "possible_values": [] + }, + "book people": { + "description": "number of people booking for train", + "is_categorical": False, + "possible_values": [] + }, + "day": { + "description": "day of the train", + "is_categorical": True, + "possible_values": [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday" + ] + }, + "ref": { + "description": "reference number of the train booking", + "is_categorical": False, + "possible_values": [] + }, + "price": { + "description": "price of the train ticket", + "is_categorical": False, + "possible_values": [] + }, + "train id": { + "description": "id of the train", + "is_categorical": False + }, + "choice": { + "description": "number of trains that meet the requirement", + "is_categorical": False, + "possible_values": [] + } + } + }, + "police": { + "description": "find a police station for help", + "slots": { + "name": { + "description": "name of the police station", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the police station", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the police station", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the police station", + "is_categorical": False, + "possible_values": [] + } + } + }, + "hospital": { + "description": "find a hospital for help", + "slots": { + "department": { + "description": "specific department of the hospital", + "is_categorical": False, + "possible_values": [] + }, + "address": { + "description": "address of the hospital", + "is_categorical": False, + "possible_values": [] + }, + "phone": { + "description": "phone number of the hospital", + "is_categorical": False, + "possible_values": [] + }, + "postcode": { + "description": "postcode of the hospital", + "is_categorical": False, + "possible_values": [] + } + } + }, + "general": { + "description": "general domain without slots", + "slots": {} + } + }, + "intents": { + "inform": { + "description": "inform the value of a slot" + }, + "request": { + "description": "ask for the value of a slot" + }, + "nobook": { + "description": "inform the user that the booking is failed" + }, + "reqmore": { + "description": "ask the user for more instructions" + }, + "book": { + "description": "book something for the user" + }, + "bye": { + "description": "say goodbye to the user and end the conversation" + }, + "thank": { + "description": "thanks for the help" + }, + "welcome": { + "description": "you're welcome" + }, + "greet": { + "description": "express greeting" + }, + "recommend": { + "description": "recommend a choice to the user" + }, + "select": { + "description": "provide several choices for the user" + }, + "offerbook": { + "description": "ask the user if he or she needs booking" + }, + "offerbooked": { + "description": "provide information about the booking" + }, + "nooffer": { + "description": "inform the user that there is no result satisfies user requirements" + } + }, + "state": { + "attraction": { + "type": "", + "name": "", + "area": "" + }, + "hotel": { + "name": "", + "area": "", + "parking": "", + "price range": "", + "stars": "", + "internet": "", + "type": "", + "book stay": "", + "book day": "", + "book people": "" + }, + "restaurant": { + "food": "", + "price range": "", + "name": "", + "area": "", + "book time": "", + "book day": "", + "book people": "" + }, + "taxi": { + "leave at": "", + "destination": "", + "departure": "", + "arrive by": "" + }, + "train": { + "leave at": "", + "destination": "", + "day": "", + "arrive by": "", + "departure": "", + "book people": "" + }, + "hospital": { + "department": "" + } + }, + "dialogue_acts": { + "categorical": {}, + "non-categorical": {}, + "binary": {} + } +} + +slot_name_map = { + 'addr': "address", + 'post': "postcode", + 'pricerange': "price range", + 'arrive': "arrive by", + 'arriveby': "arrive by", + 'leave': "leave at", + 'leaveat': "leave at", + 'depart': "departure", + 'dest': "destination", + 'fee': "entrance fee", + 'open': 'open hours', + 'car': "type", + 'car type': "type", + 'ticket': 'price', + 'trainid': 'train id', + 'id': 'train id', + 'people': 'book people', + 'stay': 'book stay', + 'none': '', + 'attraction': { + 'price': 'entrance fee' + }, + 'hospital': {}, + 'hotel': { + 'day': 'book day', 'price': "price range" + }, + 'restaurant': { + 'day': 'book day', 'time': 'book time', 'price': "price range" + }, + 'taxi': {}, + 'train': { + 'day': 'day', 'time': "duration" + }, + 'police': {}, + 'booking': {} +} + +reverse_da_slot_name_map = { + 'address': 'Addr', + 'postcode': 'Post', + 'price range': 'Price', + 'arrive by': 'Arrive', + 'leave at': 'Leave', + 'departure': 'Depart', + 'destination': 'Dest', + 'entrance fee': 'Fee', + 'open hours': 'Open', + 'price': 'Ticket', + 'train id': 'Id', + 'book people': 'People', + 'book stay': 'Stay', + 'book day': 'Day', + 'book time': 'Time', + 'duration': 'Time', + 'taxi': { + 'type': 'Car', + 'phone': 'Phone' + } +} + +digit2word = { + '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', '5': 'five', + '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', '10': 'ten' +} + +cnt_domain_slot = Counter() + + +class BookingActRemapper: + + def __init__(self, ontology): + self.ontology = ontology + self.reset() + + def reset(self): + self.current_domains_user = [] + self.current_domains_system = [] + self.booked_domains = [] + + def retrieve_current_domain_from_user(self, turn_id, ori_dialog): + prev_user_turn = ori_dialog[turn_id - 1] + + dialog_acts = prev_user_turn.get('dialog_act', []) + keyword_domains_user = get_keyword_domains(prev_user_turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + self.current_domains_user = current_domains_temp if current_domains_temp else self.current_domains_user + next_user_domains = get_next_user_act_domains(ori_dialog, turn_id) + + return keyword_domains_user, next_user_domains + + def retrieve_current_domain_from_system(self, turn_id, ori_dialog): + + system_turn = ori_dialog[turn_id] + dialog_acts = system_turn.get('dialog_act', []) + keyword_domains_system = get_keyword_domains(system_turn) + current_domains_temp = get_current_domains_from_act(dialog_acts) + self.current_domains_system = current_domains_temp if current_domains_temp else self.current_domains_system + booked_domain_current = self.check_domain_booked(system_turn) + + return keyword_domains_system, booked_domain_current + + def remap(self, turn_id, ori_dialog): + + keyword_domains_user, next_user_domains = self.retrieve_current_domain_from_user( + turn_id, ori_dialog) + keyword_domains_system, booked_domain_current = self.retrieve_current_domain_from_system( + turn_id, ori_dialog) + + # only need to remap if there is a dialog action labelled + dialog_acts = ori_dialog[turn_id].get('dialog_act', []) + spans = ori_dialog[turn_id].get('span_info', []) + if dialog_acts: + + flattened_acts = flatten_acts(dialog_acts) + flattened_spans = flatten_span_acts(spans) + remapped_acts, error_local = remap_acts(flattened_acts, self.current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, self.current_domains_system, + next_user_domains, self.ontology) + + remapped_spans, _ = remap_acts(flattened_spans, self.current_domains_user, + booked_domain_current, keyword_domains_user, + keyword_domains_system, self.current_domains_system, + next_user_domains, self.ontology) + + deflattened_remapped_acts = deflat_acts(remapped_acts) + deflattened_remapped_spans = deflat_span_acts(remapped_spans) + + return deflattened_remapped_acts, deflattened_remapped_spans + else: + return dialog_acts, spans + + def check_domain_booked(self, turn): + + booked_domain_current = None + return booked_domain_current + + # workaround + for domain in turn['metadata']: + if turn['metadata'][domain]["book"]["booked"] and domain not in self.booked_domains: + booked_domain_current = domain.capitalize() + self.booked_domains.append(domain) + return booked_domain_current + + +def get_keyword_domains(turn): + keyword_domains = [] + text = turn['text'] + for d in ["Hotel", "Restaurant", "Train"]: + if d.lower() in text.lower(): + keyword_domains.append(d) + return keyword_domains + + +def get_current_domains_from_act(dialog_acts): + + current_domains_temp = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + if domain in ["general", "Booking"]: + continue + if domain not in current_domains_temp: + current_domains_temp.append(domain) + + return current_domains_temp + + +def get_next_user_act_domains(ori_dialog, turn_id): + domains = [] + try: + next_user_act = ori_dialog[turn_id + 1]['dialog_act'] + domains = get_current_domains_from_act(next_user_act) + except: + # will fail if system act is the last act of the dialogue + pass + return domains + + +def flatten_acts(dialog_acts): + flattened_acts = [] + for dom_int in dialog_acts: + domain, intent = dom_int.split('-') + for slot_value in dialog_acts[dom_int]: + slot = slot_value[0] + value = slot_value[1] + flattened_acts.append((domain, intent, slot, value)) + + return flattened_acts + + +def flatten_span_acts(span_acts): + + flattened_acts = [] + for span_act in span_acts: + domain, intent = span_act[0].split("-") + flattened_acts.append((domain, intent, span_act[1], span_act[2:])) + return flattened_acts + + +def deflat_acts(flattened_acts): + + dialog_acts = dict() + + for act in flattened_acts: + domain, intent, slot, value = act + if f"{domain}-{intent}" not in dialog_acts.keys(): + dialog_acts[f"{domain}-{intent}"] = [[slot, value]] + else: + dialog_acts[f"{domain}-{intent}"].append([slot, value]) + + return dialog_acts + + +def deflat_span_acts(flattened_acts): + + dialog_span_acts = [] + for act in flattened_acts: + domain, intent, slot, value = act + if value == 'none': + continue + new_act = [f"{domain}-{intent}", slot] + new_act.extend(value) + dialog_span_acts.append(new_act) + + return dialog_span_acts + + +def remap_acts(flattened_acts, current_domains, booked_domain=None, keyword_domains_user=None, + keyword_domains_system=None, current_domain_system=None, next_user_domain=None, ontology=None): + + # We now look for all cases that can happen: Booking domain, Booking within a domain or taxi-inform-car for booking + error = 0 + remapped_acts = [] + + # if there is more than one current domain or none at all, we try to get booked domain differently + if len(current_domains) != 1 and booked_domain: + current_domains = [booked_domain] + elif len(current_domains) != 1 and len(keyword_domains_user) == 1: + current_domains = keyword_domains_user + elif len(current_domains) != 1 and len(keyword_domains_system) == 1: + current_domains = keyword_domains_system + elif len(current_domains) != 1 and len(current_domain_system) == 1: + current_domains = current_domain_system + elif len(current_domains) != 1 and len(next_user_domain) == 1: + current_domains = next_user_domain + + for act in flattened_acts: + try: + domain, intent, slot, value = act + if f"{domain}-{intent}-{slot}" == "Booking-Book-Ref": + # We need to remap that booking act now + potential_domain = current_domains[0] + remapped_acts.append( + (potential_domain, "Book", "none", "none")) + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, "Inform", "Ref", value)) + elif domain == "Booking" and intent == "Book" and slot != "Ref": + # the book intent is here actually an inform intent according to the data + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, "Inform", slot, value)) + elif domain == "Booking" and intent == "Inform": + # the inform intent is here actually a request intent according to the data + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, "OfferBook", slot, value)) + elif domain == "Booking" and intent in ["NoBook", "Request"]: + potential_domain = current_domains[0] + if ontology_check(potential_domain, slot, ontology): + remapped_acts.append( + (potential_domain, intent, slot, value)) + elif f"{domain}-{intent}-{slot}" == "Taxi-Inform-Car": + # taxi-inform-car actually triggers the booking and informs on a car + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, intent, slot, value)) + elif f"{domain}-{intent}-{slot}" in ["Train-Inform-Ref", "Train-OfferBooked-Ref"]: + # train-inform/offerbooked-ref actually triggers the booking and informs on the reference number + remapped_acts.append((domain, "Book", "none", "none")) + remapped_acts.append((domain, "Inform", slot, value)) + elif domain == "Train" and intent == "OfferBooked" and slot != "Ref": + # this is actually an inform act + remapped_acts.append((domain, "Inform", slot, value)) + else: + remapped_acts.append(act) + except Exception as e: + print("Error detected:", e) + error += 1 + + return remapped_acts, error + + +def ontology_check(domain_, slot_, init_ontology): + + domain = domain_.lower() + slot = slot_.lower() + if slot not in init_ontology['domains'][domain]['slots']: + if slot in slot_name_map: + slot = slot_name_map[slot] + elif slot in slot_name_map[domain]: + slot = slot_name_map[domain][slot] + return slot in init_ontology['domains'][domain]['slots'] + + +def reverse_da(dialogue_acts): + global reverse_da_slot_name_map + das = {} + for da_type in dialogue_acts: + for da in dialogue_acts[da_type]: + intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da.get( + 'value', '') + if domain == 'general': + Domain_Intent = '-'.join([domain, intent]) + elif intent == 'nooffer': + Domain_Intent = '-'.join([domain.capitalize(), 'NoOffer']) + elif intent == 'nobook': + Domain_Intent = '-'.join([domain.capitalize(), 'NoBook']) + elif intent == 'offerbook': + Domain_Intent = '-'.join([domain.capitalize(), 'OfferBook']) + else: + Domain_Intent = '-'.join([domain.capitalize(), + intent.capitalize()]) + das.setdefault(Domain_Intent, []) + if slot in reverse_da_slot_name_map: + Slot = reverse_da_slot_name_map[slot] + elif domain in reverse_da_slot_name_map and slot in reverse_da_slot_name_map[domain]: + Slot = reverse_da_slot_name_map[domain][slot] + else: + Slot = slot.capitalize() + if value == '': + if intent == 'request': + value = '?' + else: + value = 'none' + if Slot == '': + Slot = 'none' + das[Domain_Intent].append([Slot, value]) + return das + + +def normalize_domain_slot_value(domain, slot, value): + global ontology, slot_name_map + domain = domain.lower() + slot = slot.lower() + value = value.strip() + if value in ['do nt care', "do n't care"]: + value = 'dontcare' + if value in ['?', 'none', 'not mentioned']: + value = "" + if domain not in ontology['domains']: + raise Exception(f'{domain} not in ontology') + if slot not in ontology['domains'][domain]['slots']: + if slot in slot_name_map: + slot = slot_name_map[slot] + elif slot in slot_name_map[domain]: + slot = slot_name_map[domain][slot] + else: + raise Exception(f'{domain}-{slot} not in ontology') + assert slot == '' or slot in ontology['domains'][domain][ + 'slots'], f'{(domain, slot, value)} not in ontology' + return domain, slot, value + + +def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer): + ''' + convert multiwoz dialogue acts to required format + :param da_dict: dict[(intent, domain, slot, value)] = [word_start, word_end] + :param utt: user or system utt + ''' + global ontology, digit2word, cnt_domain_slot + + converted_da = { + 'categorical': [], + 'non-categorical': [], + 'binary': [] + } + sentences = sent_tokenizer.tokenize(utt) + sent_spans = sent_tokenizer.span_tokenize(utt) + tokens = [ + token for sent in sentences for token in word_tokenizer.tokenize(sent)] + token_spans = [(sent_span[0] + token_span[0], sent_span[0] + token_span[1]) for sent, sent_span in + zip(sentences, sent_spans) for token_span in word_tokenizer.span_tokenize(sent)] + # assert len(tokens) == len(token_spans) + # for token, span in zip(tokens, token_spans): + # if utt[span[0]:span[1]] != '"': + # assert utt[span[0]:span[1]] == token + + for (intent, domain, slot, value), span in da_dict.items(): + if intent == 'request' or slot == '' or value == '': + # binary dialog acts + assert value == '' + converted_da['binary'].append({ + 'intent': intent, + 'domain': domain, + 'slot': slot + }) + elif ontology['domains'][domain]['slots'][slot]['is_categorical']: + # categorical dialog acts + converted_da['categorical'].append({ + 'intent': intent, + 'domain': domain, + 'slot': slot, + 'value': value + }) + else: + # non-categorical dialog acts + converted_da['non-categorical'].append({ + 'intent': intent, + 'domain': domain, + 'slot': slot, + 'value': value + }) + # correct some value and try to give char level span + match = False + value = value.lower() + if span and span[0] <= span[1]: + # use original span annotation, but tokenizations are different + start_word, end_word = span + if end_word >= len(tokens): + # due to different tokenization, sometimes will out of index + delta = end_word - len(tokens) + 1 + start_word -= delta + end_word -= delta + start_char, end_char = token_spans[start_word][0], token_spans[end_word][1] + value_span = utt[start_char:end_char].lower() + match = True + if value_span == value: + cnt_domain_slot['span match'] += 1 + elif value.isdigit() and value in digit2word and digit2word[value] == value_span: + # !!!CHANGE VALUE: value is digit but value span is word + cnt_domain_slot['digit value match'] += 1 + elif ''.join(value.split()) == ''.join(value_span.split()): + # !!!CHANGE VALUE: equal when remove blank + cnt_domain_slot['remove blank'] += 1 + elif value in value_span: + # value in value_span + start_char += value_span.index(value) + end_char = start_char + len(value) + assert utt[start_char:end_char].lower( + ) == value, f'{[value, utt[start_char:end_char], utt]}' + cnt_domain_slot['value in span'] += 1 + elif ':' in value and value == '0' + value_span: + # !!!CHANGE VALUE: time x:xx == 0x:xx + cnt_domain_slot['x:xx == 0x:xx'] += 1 + else: + # span mismatch, search near 1-2 words + for window in range(1, 3): + start = max(0, start_word - window) + end = min(len(token_spans) - 1, end_word + window) + large_span = utt[token_spans[start] + [0]:token_spans[end][1]].lower() + if value in large_span: + start_char = token_spans[start][0] + \ + large_span.index(value) + end_char = start_char + len(value) + assert utt[ + start_char:end_char].lower() == value, f'{[value, utt[start_char:end_char], utt]}' + cnt_domain_slot[f'window={window}'] += 1 + break + else: + # still not found + match = False + + if match: + converted_da['non-categorical'][-1]['value'] = utt[start_char:end_char] + converted_da['non-categorical'][-1]['start'] = start_char + converted_da['non-categorical'][-1]['end'] = end_char + cnt_domain_slot['have span'] += 1 + else: + cnt_domain_slot['no span'] += 1 + return converted_da + + +def preprocess(): + original_data_dir = 'emowoz' + new_data_dir = 'data' + + if not os.path.exists(original_data_dir): + original_data_zip = 'MultiWOZ_2.1.zip' + if not os.path.exists(original_data_zip): + raise FileNotFoundError( + f'cannot find original data {original_data_zip} in multiwoz21/, should manually download MultiWOZ_2.1.zip from https://github.com/budzianowski/multiwoz/blob/master/data/MultiWOZ_2.1.zip') + else: + archive = ZipFile(original_data_zip) + archive.extractall() + + os.makedirs(new_data_dir, exist_ok=True) + for filename in os.listdir(original_data_dir): + if 'db' in filename: + copy2(f'{original_data_dir}/{filename}', new_data_dir) + + # how about emowoz-dialmage + original_data = json.load( + open(f'{original_data_dir}/emowoz-multiwoz.json')) + global ontology, cnt_domain_slot + + data_split = json.load(open(f'{original_data_dir}/data_split.json')) + val_list = data_split["dev"]["multiwoz"] + test_list = data_split["test"]["multiwoz"] + # val_list = set(open(f'{original_data_dir}/valListFile.txt').read().split()) + # test_list = set(open(f'{original_data_dir}/testListFile.txt').read().split()) + dataset = 'multiwoz21' + splits = ['train', 'validation', 'test'] + dialogues_by_split = {split: [] for split in splits} + sent_tokenizer = PunktSentenceTokenizer() + word_tokenizer = TreebankWordTokenizer() + booking_remapper = BookingActRemapper(ontology) + for ori_dialog_id, ori_dialog in tqdm(original_data.items()): + if ori_dialog_id in val_list: + split = 'validation' + elif ori_dialog_id in test_list: + split = 'test' + else: + split = 'train' + dialogue_id = f'{dataset}-{split}-{len(dialogues_by_split[split])}' + + # get user goal and involved domains + cur_domains = [] + + dialogue = { + 'dataset': dataset, + 'data_split': split, + 'dialogue_id': dialogue_id, + 'original_id': ori_dialog_id, + 'domains': cur_domains, # will be updated by dialog_acts and state + 'goal': "", + 'turns': [] + } + + booking_remapper.reset() + belief_domains = ['attraction', 'restaurant', + 'train', 'hotel', 'taxi', 'hospital'] + entity_booked_dict = dict((domain, False) for domain in belief_domains) + for turn_id, turn in enumerate(ori_dialog['log']): + # correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1 + text = turn['text'] + text = re.sub(" Im ", " I'm ", text) + text = re.sub(" im ", " i'm ", text) + text = re.sub(r"^Im ", "I'm ", text) + text = re.sub(r"^im ", "i'm ", text) + text = re.sub("theres", "there's", text) + text = re.sub("dont", "don't", text) + text = re.sub("whats", "what's", text) + text = re.sub('thats', "that's", text) + utt = text + speaker = 'user' if turn_id % 2 == 0 else 'system' + + das = turn.get('dialog_act', []) + spans = turn.get('span_info', []) + + if speaker == 'system': + das, spans = booking_remapper.remap(turn_id, ori_dialog['log']) + + da_dict = {} + # transform DA + for Domain_Intent in das: + domain, intent = Domain_Intent.lower().split('-') + assert intent in ontology['intents'], f'{ori_dialog_id}:{turn_id}:da\t{intent} not in ontology' + for Slot, value in das[Domain_Intent]: + domain, slot, value = normalize_domain_slot_value( + domain, Slot, value) + if domain not in cur_domains: + # update original cur_domains + cur_domains.append(domain) + da_dict[(intent, domain, slot, value,)] = [] + + for span in spans: + Domain_Intent, Slot, value, start_word, end_word = span + domain, intent = Domain_Intent.lower().split('-') + domain, slot, value = normalize_domain_slot_value( + domain, Slot, value) + assert (intent, domain, slot, value,) in da_dict + da_dict[(intent, domain, slot, value,)] = [ + start_word, end_word] + + dialogue_acts = convert_da( + da_dict, utt, sent_tokenizer, word_tokenizer) + + # reverse_das = reverse_da(dialogue_acts) + # das_list = sorted([(Domain_Intent, Slot, ''.join(value.split()).lower()) for Domain_Intent in das for Slot, value in das[Domain_Intent]]) + # reverse_das_list = sorted([(Domain_Intent, Slot, ''.join(value.split()).lower()) for Domain_Intent in reverse_das for Slot, value in reverse_das[Domain_Intent]]) + # if das_list != reverse_das_list: + # print(das_list) + # print(reverse_das_list) + # print() + # print() + + dialogue['turns'].append({ + 'speaker': speaker, + 'utterance': utt, + 'utt_idx': turn_id, + 'dialogue_acts': dialogue_acts, + 'emotion': turn['emotion'] + }) + + # add to dialogue_acts dictionary in the ontology + for da_type in dialogue_acts: + das = dialogue_acts[da_type] + for da in das: + ontology["dialogue_acts"][da_type].setdefault( + (da['intent'], da['domain'], da['slot']), {}) + ontology["dialogue_acts"][da_type][( + da['intent'], da['domain'], da['slot'])][speaker] = True + + if speaker == 'system': + # add state to last user turn + # add empty db_results + # turn_state = turn['metadata'] + cur_state = copy.deepcopy(ontology['state']) + booked = {} + # for domain in turn_state: + # if domain not in cur_state: + # continue + # for subdomain in ['semi', 'book']: + # for slot, value in turn_state[domain][subdomain].items(): + # if slot == 'ticket': + # continue + # elif slot == 'booked': + # assert domain in ontology['domains'] + # booked[domain] = value + # continue + # _, slot, value = normalize_domain_slot_value( + # domain, slot, value) + # cur_state[domain][slot] = value + dialogue['turns'][-2]['state'] = cur_state + # entity_booked_dict, booked = fix_entity_booked_info( + # entity_booked_dict, booked) + dialogue['turns'][-1]['booked'] = booked + dialogues_by_split[split].append(dialogue) + # pprint(cnt_domain_slot.most_common()) + dialogues = [] + for split in splits: + dialogues += dialogues_by_split[split] + for da_type in ontology['dialogue_acts']: + ontology["dialogue_acts"][da_type] = sorted([str( + {'user': speakers.get('user', False), 'system': speakers.get('system', False), 'intent': da[0], + 'domain': da[1], 'slot': da[2]}) for da, speakers in ontology["dialogue_acts"][da_type].items()]) + json.dump(dialogues[:10], open(f'dummy_data.json', 'w', + encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(ontology, open(f'{new_data_dir}/ontology.json', + 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', + 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: + for filename in os.listdir(new_data_dir): + zf.write(f'{new_data_dir}/{filename}') + # rmtree(original_data_dir) + # rmtree(new_data_dir) + return dialogues, ontology + + +def fix_entity_booked_info(entity_booked_dict, booked): + for domain in entity_booked_dict: + if not entity_booked_dict[domain] and booked[domain]: + entity_booked_dict[domain] = True + booked[domain] = [] + return entity_booked_dict, booked + + +if __name__ == '__main__': + preprocess() diff --git a/data/unified_datasets/emowoz/shuffled_dial_ids.json b/data/unified_datasets/emowoz/shuffled_dial_ids.json new file mode 100644 index 0000000000000000000000000000000000000000..bc2752e1c759d25a64a826fcbd96c0196a106912 --- /dev/null +++ b/data/unified_datasets/emowoz/shuffled_dial_ids.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb4fa37179f665437e5add87a780e777089a43dee35f36bc4ef316b3ad714042 +size 619950 diff --git a/setup.py b/setup.py index 0a94fb0e50cbc36f96b4b3102690656bfa10f2e1..b88454ccc61aa941f888e69494e030848599f558 100755 --- a/setup.py +++ b/setup.py @@ -77,6 +77,6 @@ setup( url='https://github.com/ConvLab/ConvLab-3', author='convlab', author_email='convlab@googlegroups.com', - python_requires='>=3.8', + python_requires='>=3.7', zip_safe=False )