From 622284b8da9ee97ecfb3c69199b8585283a692ec Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Mon, 6 Feb 2023 00:01:03 +0100 Subject: [PATCH] wip --- convlab/policy/ussT5/emowoz_evaluate.py | 68 ++++++++++++++++--------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py index 01cae7a8..5c2586f8 100644 --- a/convlab/policy/ussT5/emowoz_evaluate.py +++ b/convlab/policy/ussT5/emowoz_evaluate.py @@ -28,34 +28,54 @@ def build_data(raw_data): for sentiment, index in json.load(open("convlab/policy/emoTUS/sentiment.json")).items(): sentiments[int(index)] = sentiment data = {"input_text": [], "target_text": []} - prefix = "satisfaction score: " - for d in raw_data: - utt = "" - turn_len = len(d["turns"]) - for index, turn in enumerate(d["turns"]): - if turn["speaker"] == "user": - if index == turn_len - 2: - break - if index == 0: - utt = prefix + turn["utterance"] + for prefix in ["satisfaction score: ", "action prediction: ", "utterance generation: "]: + for d in raw_data: + utt = "" + turn_len = len(d["turns"]) + for index, turn in enumerate(d["turns"]): + if turn["speaker"] == "user": + if index == turn_len - 2: + break + if index == 0: + utt = prefix + turn["utterance"] + else: + utt += ' ' + turn["utterance"] else: + if index == 0: + print("this should no happen (index == 0)") + utt = prefix + turn["utterance"] + if index == turn_len - 1: + print("this should no happen (index == turn_len - 1)") + continue + utt += ' ' + turn["utterance"] - else: - if index == 0: - print("this should no happen (index == 0)") - utt = prefix + turn["utterance"] - if index == turn_len - 1: - print("this should no happen (index == turn_len - 1)") - continue - - utt += ' ' + turn["utterance"] - - data["input_text"].append(utt) - data["target_text"].append( - sentiments[d["turns"][index+1]["emotion"][-1]["sentiment"]]) + + data["input_text"].append(utt) + if prefix == "satisfaction score: ": + data["target_text"].append( + sentiments[d["turns"][index+1]["emotion"][-1]["sentiment"]]) + elif prefix == "action prediction: ": + data["target_text"].append( + get_action(d["turns"][index+1]["dialogue_acts"])) + else: + data["target_text"].append( + d["turns"][index+1]["utterance"]) + + json.dump(data, open("convlab/policy/ussT5/emowoz-test.json", 'w'), indent=2) return data +def get_action(dialogue_acts): + acts = [] + for _, act in dialogue_acts.items(): + for a in act: + acts.append( + f"{a['domain'].capitalize()}-{a['intent'].capitalize()}") + if not acts: + return "None" + return ','.join(acts) + + def generate_result(model_checkpoint, data, stop=-1): tokenizer = T5Tokenizer.from_pretrained(model_checkpoint) model = T5ForConditionalGeneration.from_pretrained(model_checkpoint) @@ -81,7 +101,7 @@ def generate_result(model_checkpoint, data, stop=-1): "preds": output, "label": target_text}) json.dump(results, open(os.path.join( - model_checkpoint, "emowoz_result.json"), 'w')) + model_checkpoint, "emowoz_result.json"), 'w'), indent=2) return results -- GitLab