Skip to content
Snippets Groups Projects
Commit 622284b8 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent 73fad7e8
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ def build_data(raw_data):
for sentiment, index in json.load(open("convlab/policy/emoTUS/sentiment.json")).items():
sentiments[int(index)] = sentiment
data = {"input_text": [], "target_text": []}
prefix = "satisfaction score: "
for prefix in ["satisfaction score: ", "action prediction: ", "utterance generation: "]:
for d in raw_data:
utt = ""
turn_len = len(d["turns"])
......@@ -51,11 +51,31 @@ def build_data(raw_data):
utt += ' ' + turn["utterance"]
data["input_text"].append(utt)
if prefix == "satisfaction score: ":
data["target_text"].append(
sentiments[d["turns"][index+1]["emotion"][-1]["sentiment"]])
elif prefix == "action prediction: ":
data["target_text"].append(
get_action(d["turns"][index+1]["dialogue_acts"]))
else:
data["target_text"].append(
d["turns"][index+1]["utterance"])
json.dump(data, open("convlab/policy/ussT5/emowoz-test.json", 'w'), indent=2)
return data
def get_action(dialogue_acts):
acts = []
for _, act in dialogue_acts.items():
for a in act:
acts.append(
f"{a['domain'].capitalize()}-{a['intent'].capitalize()}")
if not acts:
return "None"
return ','.join(acts)
def generate_result(model_checkpoint, data, stop=-1):
tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
......@@ -81,7 +101,7 @@ def generate_result(model_checkpoint, data, stop=-1):
"preds": output,
"label": target_text})
json.dump(results, open(os.path.join(
model_checkpoint, "emowoz_result.json"), 'w'))
model_checkpoint, "emowoz_result.json"), 'w'), indent=2)
return results
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment