From 73fad7e8827577f53e0911ae0a51425822e7ffbc Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Sun, 5 Feb 2023 23:40:39 +0100
Subject: [PATCH] wip

---
 convlab/policy/ussT5/emowoz_evaluate.py | 33 +++++++++++++------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py
index 30ead5be..01cae7a8 100644
--- a/convlab/policy/ussT5/emowoz_evaluate.py
+++ b/convlab/policy/ussT5/emowoz_evaluate.py
@@ -65,21 +65,21 @@ def generate_result(model_checkpoint, data, stop=-1):
     for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True):
         if stop > 0 and i > stop:
             break
+        i += 1
+        inputs = tokenizer([input_text], return_tensors="pt", padding=True)
+        output = model.generate(input_ids=inputs["input_ids"],
+                                attention_mask=inputs["attention_mask"],
+                                do_sample=False)
+        output = tokenizer.batch_decode(
+            output, skip_special_tokens=True)[0]
+        if len(output) > 1:
+            print(output)
+            output = "illegal"
         if "satisfaction score" in input_text:
-            i += 1
-            inputs = tokenizer([input_text], return_tensors="pt", padding=True)
-            output = model.generate(input_ids=inputs["input_ids"],
-                                    attention_mask=inputs["attention_mask"],
-                                    do_sample=False)
-            output = tokenizer.batch_decode(
-                output, skip_special_tokens=True)[0]
-            if len(output) > 1:
-                print(output)
-                output = "illegal"
-
-            results.append({"input_text": input_text,
-                            "preds": tri_convert(output),
-                            "label": target_text})
+            output = tri_convert(output)
+        results.append({"input_text": input_text,
+                        "preds": output,
+                        "label": target_text})
     json.dump(results, open(os.path.join(
         model_checkpoint, "emowoz_result.json"), 'w'))
     return results
@@ -89,8 +89,9 @@ def read_result(result):
     preds = []
     label = []
     for r in result:
-        preds.append(r["preds"])
-        label.append(r["label"])
+        if "satisfaction score" in r["input_text"]:
+            preds.append(r["preds"])
+            label.append(r["label"])
     return preds, label
--
GitLab