Skip to content
Snippets Groups Projects
Commit aebfa274 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

analysis

parent 14f05d86
Branches
No related tags found
No related merge requests found
from argparse import ArgumentParser
import numpy as np
import json
import pandas as pd
import matplotlib.pyplot as plt
def arg_parser():
parser = ArgumentParser()
parser.add_argument("--file", type=str, help="the conversation file")
return parser.parse_args()
def basic_analysis(conversation):
info = {"Complete": [], "Success": [], "Success strict": [], "turns": []}
for dialog in conversation:
for x in info:
info[x].append(dialog[x])
for x in info:
info[x] = np.mean(info[x])
return info
def advance(conversation):
info = {}
for dialog in conversation:
temp = turn_level(dialog["log"])
for metric, data in temp.items():
if metric not in info:
info[metric] = {}
for emotion, count in data.items():
if emotion not in info[metric]:
info[metric][emotion] = 0
info[metric][emotion] += count
return info
def get_turn_emotion(conversation):
turn_info = {"all": {},
"Complete": {}, "Not Complete": {},
"Success": {}, "Not Success": {},
"Success strict": {}, "Not Success strict": {}}
max_turn = 0
for dialog in conversation:
for i in range(0, len(dialog["log"]), 2):
turn = int(i / 2)
if turn > max_turn:
max_turn = turn
emotion = emotion_score(dialog["log"][i]["emotion"])
insert_turn(turn_info["all"], turn, emotion)
for metric in ["Complete", "Success", "Success strict"]:
if dialog[metric]:
insert_turn(turn_info[metric], turn, emotion)
else:
insert_turn(turn_info[f"Not {metric}"], turn, emotion)
data = {'x': [t for t in range(turn)], 'all_positive': [
], 'all_negative': [], 'all_mean': []}
for metric in ["Complete", "Success", "Success strict"]:
data[f"{metric}_positive"] = []
data[f"{metric}_negative"] = []
data[f"{metric}_mean"] = []
data[f"Not {metric}_positive"] = []
data[f"Not {metric}_negative"] = []
data[f"Not {metric}_mean"] = []
for t in range(turn):
pos, neg, mean = turn_score(turn_info["all"][t])
data[f"all_positive"].append(pos)
data[f"all_negative"].append(neg)
data[f"all_mean"].append(mean)
for raw_metric in ["Complete", "Success", "Success strict"]:
for metric in [raw_metric, f"Not {raw_metric}"]:
if t not in turn_info[metric]:
data[f"{metric}_positive"].append(0)
data[f"{metric}_negative"].append(0)
data[f"{metric}_mean"].append(0)
else:
pos, neg, mean = turn_score(turn_info[metric][t])
data[f"{metric}_positive"].append(pos)
data[f"{metric}_negative"].append(neg)
data[f"{metric}_mean"].append(mean)
fig, ax = plt.subplots()
ax.plot(data['x'], data["Complete_mean"],
'o--', color='C0', label="Complete")
ax.fill_between(data['x'], data["Complete_positive"],
data["Complete_negative"], color='C0', alpha=0.2)
ax.plot(data['x'], data["Not Complete_mean"],
'o--', color='C1', label="Not Complete")
ax.fill_between(data['x'], data["Not Complete_positive"],
data["Not Complete_negative"], color='C1', alpha=0.2)
ax.plot(data['x'], data["all_mean"], 'o--', color='C2',
label="All")
ax.fill_between(data['x'], data["all_positive"],
data["all_negative"], color='C2', alpha=0.2)
ax.legend()
ax.set_xlabel("turn")
ax.set_ylabel("Sentiment")
# plt.show()
plt.savefig("convlab/policy/emoTUS/fig.png")
def turn_score(score_list):
count = len(score_list)
positive = 0
negative = 0
for s in score_list:
if s > 0:
positive += 1
if s < 0:
negative += -1
return positive/count, negative/count, np.mean(score_list)
def insert_turn(turn_info, turn, emotion):
if turn not in turn_info:
turn_info[turn] = []
turn_info[turn].append(emotion)
def emotion_score(emotion):
if emotion == "Neutral":
return 0
if emotion in ["Satisfied", "Excited"]:
return 1
return -1
def plot(conversation):
pass
def turn_level(dialog):
# metric: {emotion: count}
dialog_info = {}
for index in range(2, len(dialog), 2):
pre_usr = dialog[index-2]
sys = dialog[index-1]
cur_usr = dialog[index]
info = neglect_reply(pre_usr, sys, cur_usr)
append_info(dialog_info, info)
info = confirm(pre_usr, sys, cur_usr)
append_info(dialog_info, info)
info = miss_info(pre_usr, sys, cur_usr)
append_info(dialog_info, info)
if index > 2:
info = loop(dialog[index-3], sys, cur_usr)
append_info(dialog_info, info)
return dialog_info
# provide wrong info
# action length
# incomplete info?
def append_info(dialog_info, info):
if not info:
return
for emotion, metric in info.items():
if metric not in dialog_info:
dialog_info[metric] = {}
if emotion not in dialog_info[metric]:
dialog_info[metric][emotion] = 0
dialog_info[metric][emotion] += 1
def get_inform(act):
inform = {}
for intent, domain, slot, value in act:
if intent not in ["inform", "recommend"]:
continue
if domain not in inform:
inform[domain] = []
inform[domain].append(slot)
return inform
def get_request(act):
request = {}
for intent, domain, slot, _ in act:
if intent == "request":
if domain not in request:
request[domain] = []
request[domain].append(slot)
return request
def neglect_reply(pre_usr, sys, cur_usr):
request = get_request(pre_usr["act"])
if not request:
return {}
system_inform = get_inform(sys["utt"])
for domain, slots in request.items():
if domain not in system_inform:
return {cur_usr["emotion"]: "neglect"}
for slot in slots:
if slot not in system_inform[domain]:
return {cur_usr["emotion"]: "neglect"}
return {cur_usr["emotion"]: "reply"}
def miss_info(pre_usr, sys, cur_usr):
system_request = get_request(sys["utt"])
if not system_request:
return {}
user_inform = get_inform(pre_usr["act"])
for domain, slots in system_request.items():
if domain not in user_inform:
continue
for slot in slots:
if slot in user_inform[domain]:
return {cur_usr["emotion"]: "miss_info"}
return {}
def confirm(pre_usr, sys, cur_usr):
user_inform = get_inform(pre_usr["act"])
if not user_inform:
return {}
system_inform = get_inform(sys["utt"])
for domain, slots in user_inform.items():
if domain not in system_inform:
continue
for slot in slots:
if slot in system_inform[domain]:
return {cur_usr["emotion"]: "confirm"}
return {cur_usr["emotion"]: "no confirm"}
def loop(s0, s1, u1):
if s0 == s1:
return {u1["emotion"]: "loop"}
def dict2csv(data):
r = {}
emotion = json.load(open("convlab/policy/emoTUS/emotion.json"))
for act, value in data.items():
temp = [0]*(len(emotion)+1)
for emo, count in value.items():
temp[emotion[emo]] = count
temp[-1] = sum(temp)
for i in range(len(emotion)):
temp[i] /= temp[-1]
r[act] = temp
dataframe = pd.DataFrame.from_dict(
r, orient='index', columns=[emo for emo in emotion]+["count"])
dataframe.to_csv(open("convlab/policy/emoTUS/act2emotion.csv", 'w'))
def main():
args = arg_parser()
result = {}
conversation = json.load(open(args.file))
basic_info = basic_analysis(conversation)
result["basic_info"] = basic_info
print(basic_info)
advance_info = advance(conversation)
print(advance_info)
result["advance_info"] = advance_info
json.dump(result, open(
"convlab/policy/emoTUS/conversation_result.json", 'w'), indent=2)
dict2csv(advance_info)
get_turn_emotion(conversation)
if __name__ == "__main__":
main()
......@@ -14,7 +14,7 @@ def arg_parser():
parser.add_argument("-N", "--num", type=int,
default=500, help="# of evaluation dialogue")
parser.add_argument("--model", type=str,
default="rule", help="# of evaluation dialogue")
default="ppo", help="# of evaluation dialogue")
return parser.parse_args()
......@@ -25,7 +25,7 @@ def interact(model_name, config, seed=0, num_goals=500):
if model_name == "rule":
policy_sys = RulePolicy()
elif model_name == "PPO":
elif model_name == "ppo":
from convlab.policy.ppo import PPO
policy_sys = PPO(vectorizer=conf['vectorizer_sys_activated'])
......@@ -91,8 +91,13 @@ if __name__ == "__main__":
time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
args = arg_parser()
conversation = interact(model_name=args.model,
config=args.config, num_goals=args.num)
json.dump(conversation,
open(os.path.join("convlab/policy/emoTUS",
f"conversation-{time}.json"), 'w'),
config=args.config,
num_goals=args.num)
data = {"config": json.load(open(args.config)),
"conversation": conversation}
folder_name = os.path.join("convlab/policy/emoTUS", "conversation")
if not os.path.exists(folder_name):
os.makedirs(folder_name)
json.dump(data,
open(os.path.join(folder_name, f"{time}.json"), 'w'),
indent=2)
......@@ -47,7 +47,8 @@
"model_checkpoint": "convlab/policy/emoTUS/unify/experiments/EmoUS_emowoz+dialmage_0_1",
"use_sentiment": false,
"add_persona": true,
"sample": false
"sample": false,
"weight": 1
}
}
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment