diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py
index 30ead5be5ee174b4659452ee65ff7af88065dd64..01cae7a82fa1c6c43c25b9ce01d1f2f5f8ae4029 100644
--- a/convlab/policy/ussT5/emowoz_evaluate.py
+++ b/convlab/policy/ussT5/emowoz_evaluate.py
@@ -65,21 +65,21 @@ def generate_result(model_checkpoint, data, stop=-1):
     for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True):
         if stop > 0 and i > stop:
             break
+        i += 1
+        inputs = tokenizer([input_text], return_tensors="pt", padding=True)
+        output = model.generate(input_ids=inputs["input_ids"],
+                                attention_mask=inputs["attention_mask"],
+                                do_sample=False)
+        output = tokenizer.batch_decode(
+            output, skip_special_tokens=True)[0]
+        if len(output) > 1:
+            print(output)
+            output = "illegal"
         if "satisfaction score" in input_text:
-            i += 1
-            inputs = tokenizer([input_text], return_tensors="pt", padding=True)
-            output = model.generate(input_ids=inputs["input_ids"],
-                                    attention_mask=inputs["attention_mask"],
-                                    do_sample=False)
-            output = tokenizer.batch_decode(
-                output, skip_special_tokens=True)[0]
-            if len(output) > 1:
-                print(output)
-                output = "illegal"
-
-            results.append({"input_text": input_text,
-                            "preds": tri_convert(output),
-                            "label": target_text})
+            output = tri_convert(output)
+        results.append({"input_text": input_text,
+                        "preds": output,
+                        "label": target_text})
     json.dump(results, open(os.path.join(
         model_checkpoint, "emowoz_result.json"), 'w'))
     return results
@@ -89,8 +89,9 @@ def read_result(result):
     preds = []
     label = []
     for r in result:
-        preds.append(r["preds"])
-        label.append(r["label"])
+        if "satisfaction score" in r["input_text"]:
+            preds.append(r["preds"])
+            label.append(r["label"])
     return preds, label
 
 