From 4e28a9f0d92d6935cf977f6f4b2e880ab62d7734 Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Mon, 23 Jan 2023 22:22:21 +0100 Subject: [PATCH] add different version for ablation study --- convlab/policy/emoTUS/unify/build_data.py | 91 ++++++++++++++++------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/convlab/policy/emoTUS/unify/build_data.py b/convlab/policy/emoTUS/unify/build_data.py index 82dd20da..86071b80 100644 --- a/convlab/policy/emoTUS/unify/build_data.py +++ b/convlab/policy/emoTUS/unify/build_data.py @@ -17,22 +17,22 @@ sys.path.append(os.path.dirname(os.path.dirname( def arg_parser(): parser = ArgumentParser() - parser.add_argument("--dataset", type=str, default="emowoz") - parser.add_argument("--use-sentiment", action="store_true") + parser.add_argument("--dataset", type=str, default="emowoz+dialmage") parser.add_argument("--dial-ids-order", type=int, default=0) parser.add_argument("--split2ratio", type=float, default=1) - parser.add_argument("--random-order", action="store_true") - parser.add_argument("--no-status", action="store_true") - parser.add_argument("--add-history", action="store_true") - parser.add_argument("--remove-domain", type=str, default="") + parser.add_argument("--use-sentiment", action="store_true") + parser.add_argument("--add-persona", action="store_true") + parser.add_argument("--emotion-mid", action="store_true") return parser.parse_args() class DataBuilder(GenTUSDataBuilder): - def __init__(self, dataset='emowoz', use_sentiment=False): + def __init__(self, dataset='emowoz', **kwargs): super().__init__(dataset) - self.use_sentiment = use_sentiment + self.use_sentiment = kwargs.get("use_sentiment", False) + self.emotion_mid = kwargs.get("emotion_mid", False) + self.add_persona = kwargs.get("add_persona", False) self.emotion = {} for emotion, index in json.load(open("convlab/policy/emoTUS/emotion.json")).items(): @@ -55,7 +55,7 @@ class DataBuilder(GenTUSDataBuilder): return example user_goal = 
Goal(goal=data_goal) user_info = None - if self.use_sentiment: + if self.add_persona: user_info = emotion_info(dialog) # if user_info["user"] == "Impolite": # print(user_info) @@ -108,7 +108,7 @@ class DataBuilder(GenTUSDataBuilder): in_str["history"] = h in_str["turn"] = str(int(turn_id/2)) - if self.use_sentiment: + if self.add_persona: for info in ["event", "user"]: if info not in user_info: continue @@ -117,15 +117,24 @@ class DataBuilder(GenTUSDataBuilder): return json.dumps(in_str) def _dump_out_str(self, usr_act, text, usr_emotion, usr_sentiment=None): - if self.use_sentiment: + if self.use_sentiment and self.emotion_mid: out_str = {"sentiment": usr_sentiment, "action": usr_act, "emotion": usr_emotion, "text": text} - else: + elif self.use_sentiment and not self.emotion_mid: + out_str = {"sentiment": usr_sentiment, + "emotion": usr_emotion, + "action": usr_act, + "text": text} + elif not self.use_sentiment and not self.emotion_mid: out_str = {"emotion": usr_emotion, "action": usr_act, "text": text} + else: + out_str = {"action": usr_act, + "emotion": usr_emotion, + "text": text} return json.dumps(out_str) @@ -142,8 +151,41 @@ if __name__ == "__main__": base_name = "convlab/policy/emoTUS/unify/data" dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}" + + use_sentiment = args.use_sentiment + emotion_mid = args.emotion_mid + add_persona = args.add_persona + + data_status = [use_sentiment, emotion_mid, add_persona] + + if data_status == [True, True, True]: + # current sentUS + dir_name = f"SentUS_{dir_name}" + elif data_status == [True, True, False]: + # current sentUS without persona + dir_name = f"SentUS_noPersona_{dir_name}" + elif data_status == [False, False, True]: + # current emoUS with persona + dir_name = f"EmoUS_{dir_name}" + elif data_status == [False, False, False]: + # current emoUS + dir_name = f"EmoUS_noPersona_{dir_name}" + elif data_status == [False, True, True]: + # mid emotion + dir_name = f"MIDemoUS_{dir_name}" + elif 
data_status == [False, True, False]: + dir_name = f"MIDemoUS_noPersona_{dir_name}" + elif data_status == [True, False, True]: + # sentiment followed by emotion, not act + dir_name = f"SentEmoUS_{dir_name}" + elif data_status == [True, False, False]: + # sentiment followed by emotion, not act, without persona + dir_name = f"SentEmoUS_noPersona_{dir_name}" + else: + print("NOT DEFINED", use_sentiment, add_persona, emotion_mid) + print("dir_name", dir_name) + folder_name = os.path.join(base_name, dir_name) - remove_domain = args.remove_domain if not os.path.exists(folder_name): os.makedirs(folder_name) @@ -153,21 +195,18 @@ if __name__ == "__main__": split2ratio=args.split2ratio) data_builder = DataBuilder( dataset=args.dataset, - use_sentiment=args.use_sentiment) + use_sentiment=use_sentiment, + add_persona=add_persona, + emotion_mid=emotion_mid) data = data_builder.setup_data( raw_data=dataset, - random_order=args.random_order, - no_status=args.no_status, - add_history=args.add_history, - remove_domain=remove_domain) + random_order=False, + no_status=False, + add_history=True, + remove_domain=None) for data_type in data: - if remove_domain: - file_name = os.path.join( - folder_name, - f"no{remove_domain}_{data_type}.json") - else: - file_name = os.path.join( - folder_name, - f"{data_type}.json") + file_name = os.path.join( + folder_name, + f"{data_type}.json") json.dump(data[data_type], open(file_name, 'w'), indent=2) -- GitLab