From aebfa2746629b706cfc5b6550d6984c87baaf06f Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Fri, 3 Feb 2023 01:41:30 +0100
Subject: [PATCH] analysis

---
 convlab/policy/emoTUS/analysis.py             | 276 ++++++++++++++++++
 convlab/policy/emoTUS/dialogue_collector.py   |  17 +-
 .../emoTUS-BertNLU-RuleDST-RulePolicy.json    |   3 +-
 3 files changed, 289 insertions(+), 7 deletions(-)
 create mode 100644 convlab/policy/emoTUS/analysis.py

diff --git a/convlab/policy/emoTUS/analysis.py b/convlab/policy/emoTUS/analysis.py
new file mode 100644
index 00000000..8ac60c74
--- /dev/null
+++ b/convlab/policy/emoTUS/analysis.py
@@ -0,0 +1,276 @@
+from argparse import ArgumentParser
+import numpy as np
+import json
+import pandas as pd
+import matplotlib.pyplot as plt
+
+
+def arg_parser():
+    parser = ArgumentParser()
+    parser.add_argument("--file", type=str, help="the conversation file")
+    return parser.parse_args()
+
+
+def basic_analysis(conversation):
+    info = {"Complete": [], "Success": [], "Success strict": [], "turns": []}
+    for dialog in conversation:
+        for x in info:
+            info[x].append(dialog[x])
+    for x in info:
+        info[x] = np.mean(info[x])
+    return info
+
+
+def advance(conversation):
+    info = {}
+    for dialog in conversation:
+        temp = turn_level(dialog["log"])
+        for metric, data in temp.items():
+            if metric not in info:
+                info[metric] = {}
+            for emotion, count in data.items():
+                if emotion not in info[metric]:
+                    info[metric][emotion] = 0
+                info[metric][emotion] += count
+
+    return info
+
+
+def get_turn_emotion(conversation):
+    turn_info = {"all": {},
+                 "Complete": {}, "Not Complete": {},
+                 "Success": {}, "Not Success": {},
+                 "Success strict": {}, "Not Success strict": {}}
+    max_turn = 0
+    for dialog in conversation:
+        for i in range(0, len(dialog["log"]), 2):
+            turn = int(i / 2)
+            if turn > max_turn:
+                max_turn = turn
+            emotion = emotion_score(dialog["log"][i]["emotion"])
+            insert_turn(turn_info["all"], turn, emotion)
+            for metric in ["Complete", "Success", "Success strict"]:
+                if dialog[metric]:
+                    insert_turn(turn_info[metric], turn, emotion)
+                else:
+                    insert_turn(turn_info[f"Not {metric}"], turn, emotion)
+    data = {'x': [t for t in range(max_turn + 1)], 'all_positive': [
+    ], 'all_negative': [], 'all_mean': []}
+    for metric in ["Complete", "Success", "Success strict"]:
+        data[f"{metric}_positive"] = []
+        data[f"{metric}_negative"] = []
+        data[f"{metric}_mean"] = []
+        data[f"Not {metric}_positive"] = []
+        data[f"Not {metric}_negative"] = []
+        data[f"Not {metric}_mean"] = []
+
+    for t in range(max_turn + 1):  # iterate all observed turns, not just the last dialog's length
+        pos, neg, mean = turn_score(turn_info["all"][t])
+        data[f"all_positive"].append(pos)
+        data[f"all_negative"].append(neg)
+        data[f"all_mean"].append(mean)
+        for raw_metric in ["Complete", "Success", "Success strict"]:
+            for metric in [raw_metric, f"Not {raw_metric}"]:
+                if t not in turn_info[metric]:
+                    data[f"{metric}_positive"].append(0)
+                    data[f"{metric}_negative"].append(0)
+                    data[f"{metric}_mean"].append(0)
+                else:
+                    pos, neg, mean = turn_score(turn_info[metric][t])
+                    data[f"{metric}_positive"].append(pos)
+                    data[f"{metric}_negative"].append(neg)
+                    data[f"{metric}_mean"].append(mean)
+
+    fig, ax = plt.subplots()
+    ax.plot(data['x'], data["Complete_mean"],
+            'o--', color='C0', label="Complete")
+    ax.fill_between(data['x'], data["Complete_positive"],
+                    data["Complete_negative"], color='C0', alpha=0.2)
+    ax.plot(data['x'], data["Not Complete_mean"],
+            'o--', color='C1', label="Not Complete")
+    ax.fill_between(data['x'], data["Not Complete_positive"],
+                    data["Not Complete_negative"], color='C1', alpha=0.2)
+    ax.plot(data['x'], data["all_mean"], 'o--', color='C2',
+            label="All")
+    ax.fill_between(data['x'], data["all_positive"],
+                    data["all_negative"], color='C2', alpha=0.2)
+    ax.legend()
+    ax.set_xlabel("turn")
+    ax.set_ylabel("Sentiment")
+    # plt.show()
+    plt.savefig("convlab/policy/emoTUS/fig.png")
+
+
+def turn_score(score_list):
+    count = len(score_list)
+    positive = 0
+    negative = 0
+    for s in score_list:
+        if s > 0:
+            positive += 1
+        if s < 0:
+            negative += -1  # accumulated as -1 on purpose: the returned fraction is negative so it plots below zero
+    return positive/count, negative/count, np.mean(score_list)
+
+
+def insert_turn(turn_info, turn, emotion):
+    if turn not in turn_info:
+        turn_info[turn] = []
+    turn_info[turn].append(emotion)
+
+
+def emotion_score(emotion):
+    if emotion == "Neutral":
+        return 0
+    if emotion in ["Satisfied", "Excited"]:
+        return 1
+    return -1
+
+
+def plot(conversation):
+    pass
+
+
+def turn_level(dialog):
+    # metric: {emotion: count}
+    dialog_info = {}
+    for index in range(2, len(dialog), 2):
+        pre_usr = dialog[index-2]
+        sys = dialog[index-1]
+        cur_usr = dialog[index]
+        info = neglect_reply(pre_usr, sys, cur_usr)
+        append_info(dialog_info, info)
+        info = confirm(pre_usr, sys, cur_usr)
+        append_info(dialog_info, info)
+        info = miss_info(pre_usr, sys, cur_usr)
+        append_info(dialog_info, info)
+        if index > 2:
+            info = loop(dialog[index-3], sys, cur_usr)
+            append_info(dialog_info, info)
+
+    return dialog_info
+
+# provide wrong info
+# action length
+# incomplete info?
+
+
+def append_info(dialog_info, info):
+    if not info:
+        return
+    for emotion, metric in info.items():
+        if metric not in dialog_info:
+            dialog_info[metric] = {}
+        if emotion not in dialog_info[metric]:
+            dialog_info[metric][emotion] = 0
+        dialog_info[metric][emotion] += 1
+
+
+def get_inform(act):
+    inform = {}
+    for intent, domain, slot, value in act:
+        if intent not in ["inform", "recommend"]:
+            continue
+        if domain not in inform:
+            inform[domain] = []
+        inform[domain].append(slot)
+    return inform
+
+
+def get_request(act):
+    request = {}
+    for intent, domain, slot, _ in act:
+        if intent == "request":
+            if domain not in request:
+                request[domain] = []
+            request[domain].append(slot)
+    return request
+
+
+def neglect_reply(pre_usr, sys, cur_usr):
+    request = get_request(pre_usr["act"])
+    if not request:
+        return {}
+
+    system_inform = get_inform(sys["utt"])  # NOTE(review): system acts read from "utt" (user acts use "act") -- confirm log schema
+
+    for domain, slots in request.items():
+        if domain not in system_inform:
+            return {cur_usr["emotion"]: "neglect"}
+        for slot in slots:
+            if slot not in system_inform[domain]:
+                return {cur_usr["emotion"]: "neglect"}
+    return {cur_usr["emotion"]: "reply"}
+
+
+def miss_info(pre_usr, sys, cur_usr):
+    system_request = get_request(sys["utt"])
+    if not system_request:
+        return {}
+    user_inform = get_inform(pre_usr["act"])
+    for domain, slots in system_request.items():
+        if domain not in user_inform:
+            continue
+        for slot in slots:
+            if slot in user_inform[domain]:
+                return {cur_usr["emotion"]: "miss_info"}
+    return {}
+
+
+def confirm(pre_usr, sys, cur_usr):
+    user_inform = get_inform(pre_usr["act"])
+
+    if not user_inform:
+        return {}
+
+    system_inform = get_inform(sys["utt"])
+
+    for domain, slots in user_inform.items():
+        if domain not in system_inform:
+            continue
+        for slot in slots:
+            if slot in system_inform[domain]:
+                return {cur_usr["emotion"]: "confirm"}
+
+    return {cur_usr["emotion"]: "no confirm"}
+
+
+def loop(s0, s1, u1):
+    if s0 == s1:
+        return {u1["emotion"]: "loop"}
+
+
+def dict2csv(data):
+    r = {}
+    emotion = json.load(open("convlab/policy/emoTUS/emotion.json"))
+    for act, value in data.items():
+        temp = [0]*(len(emotion)+1)
+        for emo, count in value.items():
+            temp[emotion[emo]] = count
+        temp[-1] = sum(temp)
+        for i in range(len(emotion)):
+            temp[i] /= temp[-1]
+        r[act] = temp
+    dataframe = pd.DataFrame.from_dict(
+        r, orient='index', columns=[emo for emo in emotion]+["count"])
+    dataframe.to_csv(open("convlab/policy/emoTUS/act2emotion.csv", 'w'))
+
+
+def main():
+    args = arg_parser()
+    result = {}
+    conversation = json.load(open(args.file))
+    basic_info = basic_analysis(conversation)
+    result["basic_info"] = basic_info
+    print(basic_info)
+    advance_info = advance(conversation)
+    print(advance_info)
+    result["advance_info"] = advance_info
+    json.dump(result, open(
+        "convlab/policy/emoTUS/conversation_result.json", 'w'), indent=2)
+    dict2csv(advance_info)
+    get_turn_emotion(conversation)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/convlab/policy/emoTUS/dialogue_collector.py b/convlab/policy/emoTUS/dialogue_collector.py
index 07707377..3422bed1 100644
--- a/convlab/policy/emoTUS/dialogue_collector.py
+++ b/convlab/policy/emoTUS/dialogue_collector.py
@@ -14,7 +14,7 @@ def arg_parser():
     parser.add_argument("-N", "--num", type=int,
                         default=500, help="# of evaluation dialogue")
     parser.add_argument("--model", type=str,
-                        default="rule", help="# of evaluation dialogue")
+                        default="ppo", help="system policy model: rule or ppo")
     return parser.parse_args()
 
 
@@ -25,7 +25,7 @@ def interact(model_name, config, seed=0, num_goals=500):
 
     if model_name == "rule":
         policy_sys = RulePolicy()
-    elif model_name == "PPO":
+    elif model_name == "ppo":
         from convlab.policy.ppo import PPO
         policy_sys = PPO(vectorizer=conf['vectorizer_sys_activated'])
 
@@ -91,8 +91,13 @@ if __name__ == "__main__":
     time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
     args = arg_parser()
     conversation = interact(model_name=args.model,
-                            config=args.config, num_goals=args.num)
-    json.dump(conversation,
-              open(os.path.join("convlab/policy/emoTUS",
-                   f"conversation-{time}.json"), 'w'),
+                            config=args.config,
+                            num_goals=args.num)
+    data = {"config": json.load(open(args.config)),
+            "conversation": conversation}
+    folder_name = os.path.join("convlab/policy/emoTUS", "conversation")
+    if not os.path.exists(folder_name):
+        os.makedirs(folder_name)
+    json.dump(data,
+              open(os.path.join(folder_name, f"{time}.json"), 'w'),
               indent=2)
diff --git a/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json b/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json
index 458b0628..5bb9e524 100644
--- a/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json
+++ b/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json
@@ -47,7 +47,8 @@
                 "model_checkpoint": "convlab/policy/emoTUS/unify/experiments/EmoUS_emowoz+dialmage_0_1",
                 "use_sentiment": false,
                 "add_persona": true,
-                "sample": false
+                "sample": false,
+                "weight": 1
             }
         }
     },
-- 
GitLab