Skip to content
Snippets Groups Projects
Commit 446a9675 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent d878c8fa
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ def arg_parser(): ...@@ -15,7 +15,7 @@ def arg_parser():
parser.add_argument("--model", type=str, default="", parser.add_argument("--model", type=str, default="",
help="model name") help="model name")
parser.add_argument("--data", type=str) parser.add_argument("--data", type=str)
parser.add_argument("--gen-file", type=str)
return parser.parse_args() return parser.parse_args()
...@@ -49,18 +49,27 @@ def bi_check(p, l): ...@@ -49,18 +49,27 @@ def bi_check(p, l):
return 0 return 0
def main(): def read_result(result):
args = arg_parser() preds = {'bi': [], "five": [], 'tri': []}
model_checkpoint = args.model label = {'bi': [], "five": [], 'tri': []}
for r in result:
p = r["preds"]
l = r["label"]
preds["five"].append(p)
preds["bi"].append(bi_f1(p))
preds["tri"].append(tri_convert(p))
label["five"].append(l)
label["bi"].append(bi_f1(l))
label["tri"].append(tri_convert(l))
return preds, label
def generate_result(model_checkpoint, data):
tokenizer = T5Tokenizer.from_pretrained(model_checkpoint) tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint) model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
data = pd.read_csv(args.data, index_col=False).astype(str) data = pd.read_csv(data, index_col=False).astype(str)
preds = {'bi': [], "five": []}
label = {'bi': [], "five": []}
bi_f1_score = []
results = [] results = []
for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True): for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True):
if "satisfaction score" in input_text: if "satisfaction score" in input_text:
inputs = tokenizer([input_text], return_tensors="pt", padding=True) inputs = tokenizer([input_text], return_tensors="pt", padding=True)
...@@ -71,17 +80,23 @@ def main(): ...@@ -71,17 +80,23 @@ def main():
if len(output) > 1: if len(output) > 1:
print(output) print(output)
output = "illegal" output = "illegal"
label["five"].append(target_text)
preds["five"].append(output)
label["bi"].append(bi_f1(target_text))
preds["bi"].append(bi_f1(output))
bi_f1_score.append(bi_check(output, target_text))
results.append({"input_text": input_text, results.append({"input_text": input_text,
"preds": output, "preds": output,
"label": target_text}) "label": target_text})
json.dump(results, open(os.path.join(model_checkpoint, "result.json"))) return results
def main():
args = arg_parser()
if args.gen_file:
preds, label = read_result(json.load(open(args.gen_file)))
else:
results = generate_result(args.model, args.data)
preds, label = read_result(results)
macro_f1 = metrics.f1_score(label["five"], preds["five"], average="macro") macro_f1 = metrics.f1_score(label["five"], preds["five"], average="macro")
tri_f1 = metrics.f1_score(label["tri"], preds["tri"], average="macro")
f1 = metrics.f1_score(label["bi"], preds["bi"]) f1 = metrics.f1_score(label["bi"], preds["bi"])
sep_f1 = metrics.f1_score( sep_f1 = metrics.f1_score(
label["five"], preds["five"], average=None, label["five"], preds["five"], average=None,
...@@ -94,13 +109,14 @@ def main(): ...@@ -94,13 +109,14 @@ def main():
display_labels=['1', '2', '3', '4', '5']) display_labels=['1', '2', '3', '4', '5'])
disp.plot() disp.plot()
r = {"macro_f1": float(macro_f1), r = {"macro_f1": float(macro_f1),
"tri_f1": float(tri_f1),
"bi_f1": float(f1), "bi_f1": float(f1),
"sep_f1": list(sep_f1), "sep_f1": list(sep_f1),
"cm": [list(c) for c in list(cm)]} "cm": [list(c) for c in list(cm)]}
print(r) print(r)
dirname = "convlab/policy/uss-t5/" dirname = "convlab/policy/uss-t5/"
time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}" time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
plt.savefig(os.path.join(model_checkpoint, f"{time}-satisfied.png")) plt.savefig(os.path.join(args.model, f"{time}-satisfied.png"))
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment