Skip to content
Snippets Groups Projects
Commit 4e28a9f0 authored by Hsien-Chin Lin
Browse files

add different version for ablation study

parent 799fffba
No related branches found
No related tags found
No related merge requests found
......@@ -17,22 +17,22 @@ sys.path.append(os.path.dirname(os.path.dirname(
def arg_parser():
    """Parse the command-line options of the data-building script.

    Each option string is registered exactly once; registering the same
    option string twice on one parser makes argparse raise a
    conflicting-option error.

    Returns:
        argparse.Namespace with the attributes ``dataset``,
        ``dial_ids_order``, ``split2ratio``, ``random_order``,
        ``no_status``, ``add_history``, ``remove_domain``,
        ``use_sentiment``, ``add_persona`` and ``emotion_mid``.
    """
    parser = ArgumentParser()
    # Dataset selection and splitting.
    parser.add_argument("--dataset", type=str, default="emowoz+dialmage")
    parser.add_argument("--dial-ids-order", type=int, default=0)
    parser.add_argument("--split2ratio", type=float, default=1)
    # Data-construction switches; kept so existing invocations still parse.
    parser.add_argument("--random-order", action="store_true")
    parser.add_argument("--no-status", action="store_true")
    parser.add_argument("--add-history", action="store_true")
    parser.add_argument("--remove-domain", type=str, default="")
    # Ablation switches selecting which variant of the data is produced.
    parser.add_argument("--use-sentiment", action="store_true")
    parser.add_argument("--add-persona", action="store_true")
    parser.add_argument("--emotion-mid", action="store_true")
    return parser.parse_args()
class DataBuilder(GenTUSDataBuilder):
def __init__(self, dataset='emowoz', use_sentiment=False):
def __init__(self, dataset='emowoz', **kwargs):
super().__init__(dataset)
self.use_sentiment = use_sentiment
self.use_sentiment = kwargs.get("use_sentiment", False)
self.emotion_mid = kwargs.get("emotion_mid", False)
self.add_persona = kwargs.get("add_persona", False)
self.emotion = {}
for emotion, index in json.load(open("convlab/policy/emoTUS/emotion.json")).items():
......@@ -55,7 +55,7 @@ class DataBuilder(GenTUSDataBuilder):
return example
user_goal = Goal(goal=data_goal)
user_info = None
if self.use_sentiment:
if self.add_persona:
user_info = emotion_info(dialog)
# if user_info["user"] == "Impolite":
# print(user_info)
......@@ -108,7 +108,7 @@ class DataBuilder(GenTUSDataBuilder):
in_str["history"] = h
in_str["turn"] = str(int(turn_id/2))
if self.use_sentiment:
if self.add_persona:
for info in ["event", "user"]:
if info not in user_info:
continue
......@@ -117,15 +117,24 @@ class DataBuilder(GenTUSDataBuilder):
return json.dumps(in_str)
def _dump_out_str(self, usr_act, text, usr_emotion, usr_sentiment=None):
    """Serialize one user turn as a JSON object string.

    The set and order of keys depend on two instance flags:

    * ``self.use_sentiment`` -- when true, a leading ``"sentiment"`` key
      is emitted.
    * ``self.emotion_mid`` -- when true, ``"action"`` precedes
      ``"emotion"``; otherwise ``"emotion"`` precedes ``"action"``.

    ``"text"`` is always the last key.

    Args:
        usr_act: value stored under ``"action"``.
        text: value stored under ``"text"``.
        usr_emotion: value stored under ``"emotion"``.
        usr_sentiment: value stored under ``"sentiment"``; only used when
            ``self.use_sentiment`` is true.

    Returns:
        The JSON string produced by ``json.dumps``.
    """
    # Key order is significant: dicts keep insertion order and json.dumps
    # writes keys in that order.
    # NOTE(review): presumably this order is the order a model trained on
    # these strings is expected to generate the fields -- confirm.
    if self.emotion_mid:
        fields = [("action", usr_act), ("emotion", usr_emotion)]
    else:
        fields = [("emotion", usr_emotion), ("action", usr_act)]
    if self.use_sentiment:
        fields.insert(0, ("sentiment", usr_sentiment))
    fields.append(("text", text))
    return json.dumps(dict(fields))
......@@ -142,8 +151,41 @@ if __name__ == "__main__":
base_name = "convlab/policy/emoTUS/unify/data"
dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}"
use_sentiment = args.use_sentiment
emotion_mid = args.emotion_mid
add_persona = args.add_persona
data_status = [use_sentiment, emotion_mid, add_persona]
if data_status == [True, True, True]:
# current sentUS
dir_name = f"SentUS_{dir_name}"
elif data_status == [True, True, False]:
# current sentUS without persona
dir_name = f"SentUS_noPersona_{dir_name}"
elif data_status == [False, False, True]:
# current emoUS with persona
dir_name = f"EmoUS_{dir_name}"
elif data_status == [False, False, False]:
# current emoUS
dir_name = f"EmoUS_noPersona_{dir_name}"
elif data_status == [False, True, True]:
# mid emotion
dir_name = f"MIDemoUS_{dir_name}"
elif data_status == [False, True, False]:
dir_name = f"MIDemoUS_noPersona_{dir_name}"
elif data_status == [True, False, True]:
# sentiment followed by emotion, not act
dir_name = f"SentEmoUS_{dir_name}"
elif data_status == [True, False, False]:
# sentiment followed by emotion, not act, without persona
dir_name = f"SentEmoUS_noPersona_{dir_name}"
else:
print("NOT DEFINED", use_sentiment, add_persona, emotion_mid)
print("dir_name", dir_name)
folder_name = os.path.join(base_name, dir_name)
remove_domain = args.remove_domain
if not os.path.exists(folder_name):
os.makedirs(folder_name)
......@@ -153,20 +195,17 @@ if __name__ == "__main__":
split2ratio=args.split2ratio)
data_builder = DataBuilder(
dataset=args.dataset,
use_sentiment=args.use_sentiment)
use_sentiment=use_sentiment,
add_persona=add_persona,
emotion_mid=emotion_mid)
data = data_builder.setup_data(
raw_data=dataset,
random_order=args.random_order,
no_status=args.no_status,
add_history=args.add_history,
remove_domain=remove_domain)
random_order=False,
no_status=False,
add_history=True,
remove_domain=None)
for data_type in data:
if remove_domain:
file_name = os.path.join(
folder_name,
f"no{remove_domain}_{data_type}.json")
else:
file_name = os.path.join(
folder_name,
f"{data_type}.json")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment