diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py index 01cae7a82fa1c6c43c25b9ce01d1f2f5f8ae4029..5c2586f83f3d4c7fd29a8f55a90ee5790cfb6edd 100644 --- a/convlab/policy/ussT5/emowoz_evaluate.py +++ b/convlab/policy/ussT5/emowoz_evaluate.py @@ -28,34 +28,54 @@ def build_data(raw_data): for sentiment, index in json.load(open("convlab/policy/emoTUS/sentiment.json")).items(): sentiments[int(index)] = sentiment data = {"input_text": [], "target_text": []} - prefix = "satisfaction score: " - for d in raw_data: - utt = "" - turn_len = len(d["turns"]) - for index, turn in enumerate(d["turns"]): - if turn["speaker"] == "user": - if index == turn_len - 2: - break - if index == 0: - utt = prefix + turn["utterance"] + for prefix in ["satisfaction score: ", "action prediction: ", "utterance generation: "]: + for d in raw_data: + utt = "" + turn_len = len(d["turns"]) + for index, turn in enumerate(d["turns"]): + if turn["speaker"] == "user": + if index == turn_len - 2: + break + if index == 0: + utt = prefix + turn["utterance"] + else: + utt += ' ' + turn["utterance"] else: + if index == 0: + print("this should no happen (index == 0)") + utt = prefix + turn["utterance"] + if index == turn_len - 1: + print("this should no happen (index == turn_len - 1)") + continue + utt += ' ' + turn["utterance"] - else: - if index == 0: - print("this should no happen (index == 0)") - utt = prefix + turn["utterance"] - if index == turn_len - 1: - print("this should no happen (index == turn_len - 1)") - continue - - utt += ' ' + turn["utterance"] - - data["input_text"].append(utt) - data["target_text"].append( - sentiments[d["turns"][index+1]["emotion"][-1]["sentiment"]]) + + data["input_text"].append(utt) + if prefix == "satisfaction score: ": + data["target_text"].append( + sentiments[d["turns"][index+1]["emotion"][-1]["sentiment"]]) + elif prefix == "action prediction: ": + data["target_text"].append( + 
def get_action(dialogue_acts) -> str:
    """Flatten dialogue-act annotations into a comma-separated label string.

    Args:
        dialogue_acts: Mapping whose values are iterables of act dicts.
            Each act dict is read for its ``"domain"`` and ``"intent"``
            string entries; the mapping's keys are ignored.

    Returns:
        ``"Domain-Intent"`` labels (each part passed through
        ``str.capitalize``) joined by ``","`` in iteration order, or the
        literal string ``"None"`` when no acts are present.
    """
    # Only the grouped act lists matter, so iterate the values directly
    # instead of unpacking and discarding the keys.
    acts = [
        f"{act['domain'].capitalize()}-{act['intent'].capitalize()}"
        for group in dialogue_acts.values()
        for act in group
    ]
    # The string "None" (not the None object) is the label for "no acts".
    return ','.join(acts) if acts else "None"