Skip to content
Snippets Groups Projects
Commit 73fad7e8 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent 58bd6ce2
No related branches found
No related tags found
No related merge requests found
......@@ -65,7 +65,6 @@ def generate_result(model_checkpoint, data, stop=-1):
for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True):
if stop > 0 and i > stop:
break
if "satisfaction score" in input_text:
i += 1
inputs = tokenizer([input_text], return_tensors="pt", padding=True)
output = model.generate(input_ids=inputs["input_ids"],
......@@ -76,9 +75,10 @@ def generate_result(model_checkpoint, data, stop=-1):
if len(output) > 1:
print(output)
output = "illegal"
if "satisfaction score" in input_text:
output = tri_convert(output)
results.append({"input_text": input_text,
"preds": tri_convert(output),
"preds": output,
"label": target_text})
json.dump(results, open(os.path.join(
model_checkpoint, "emowoz_result.json"), 'w'))
......@@ -89,6 +89,7 @@ def read_result(result):
preds = []
label = []
for r in result:
if "satisfaction score" in r["input_text"]:
preds.append(r["preds"])
label.append(r["label"])
return preds, label
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment