# coding=utf-8 # # Copyright 2020-2022 Heinrich Heine University Duesseldorf # # Part of this code is based on the source code of BERT-DST # (arXiv:1907.03040) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import glob import json import sys import numpy as np import re import math import argparse def load_dataset_config(dataset_config): with open(dataset_config, "r", encoding='utf-8') as f: raw_config = json.load(f) return raw_config['class_types'], raw_config['slots'], raw_config['label_maps'], raw_config['noncategorical'], raw_config['boolean'] def tokenize(text): if "\u0120" in text: text = re.sub(" ", "", text) text = re.sub("\u0120", " ", text) else: text = re.sub(" ##", "", text) text = text.strip() return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0]) def filter_sequences(seqs, mode="first"): if mode == "first": return tokenize(seqs[0][0][0]) elif mode == "max_first": max_conf = 0 max_idx = 0 for e_itr, e in enumerate(seqs[0]): if e[1] > max_conf: max_conf = e[1] max_idx = e_itr return tokenize(seqs[0][max_idx][0]) elif mode == "max": max_conf = 0 max_t_idx = 0 for t_itr, t in enumerate(seqs): for e_itr, e in enumerate(t): if e[1] > max_conf: max_conf = e[1] max_t_idx = t_itr max_idx = e_itr return tokenize(seqs[max_t_idx][max_idx][0]) else: print("WARN: mode %s unknown. Aborting." % mode) exit() def is_in_list(tok, value): found = False tok_list = [item for item in map(str.strip, re.split("(\W+)", tok)) if len(item) > 0] value_list = [item for item in map(str.strip, re.split("(\W+)", value)) if len(item) > 0] tok_len = len(tok_list) value_len = len(value_list) for i in range(tok_len + 1 - value_len): if tok_list[i:i + value_len] == value_list: found = True break return found def check_slot_inform(value_label, inform_label, label_maps): value = inform_label if value_label == inform_label: value = value_label elif is_in_list(inform_label, value_label): value = value_label elif is_in_list(value_label, inform_label): value = value_label elif inform_label in label_maps: for inform_label_variant in label_maps[inform_label]: if value_label == inform_label_variant: value = value_label break elif is_in_list(inform_label_variant, value_label): value = value_label break elif is_in_list(value_label, inform_label_variant): value = value_label break elif value_label in label_maps: for value_label_variant in label_maps[value_label]: if value_label_variant == inform_label: value = value_label break elif is_in_list(inform_label, value_label_variant): value = value_label break elif is_in_list(value_label_variant, inform_label): value = value_label break return value def match(gt, pd, label_maps): # We want to be as conservative as possible here. # We only allow maps according to label_maps and # tolerate the absence/presence of the definite article. if pd[:4] == "the " and gt == pd[4:]: return True if gt[:4] == "the " and gt[4:] == pd: return True if gt in label_maps: for variant in label_maps[gt]: if variant == pd: return True return False def get_joint_slot_correctness(fp, args, class_types, label_maps, key_class_label_id='class_label_id', key_class_prediction='class_prediction', key_start_pos='start_pos', key_start_prediction='start_prediction', key_start_confidence='start_confidence', key_refer_id='refer_id', key_refer_prediction='refer_prediction', key_slot_groundtruth='slot_groundtruth', key_slot_prediction='slot_prediction', key_slot_dist_prediction='slot_dist_prediction', key_slot_dist_confidence='slot_dist_confidence', key_value_prediction='value_prediction', key_value_groundtruth='value_groundtruth', key_value_confidence='value_confidence', key_slot_value_prediction='slot_value_prediction', key_slot_value_confidence='slot_value_confidence', noncategorical=False, boolean=False): with open(fp) as f: preds = json.load(f) class_correctness = [[] for cl in range(len(class_types) + 1)] confusion_matrix = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))] pos_correctness = [] refer_correctness = [] val_correctness = [] total_correctness = [] c_tp = {ct: 0 for ct in range(len(class_types))} c_tn = {ct: 0 for ct in range(len(class_types))} c_fp = {ct: 0 for ct in range(len(class_types))} c_fn = {ct: 0 for ct in range(len(class_types))} s_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} s_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} a_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} a_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} value_match_cnt = 0 for pred in preds: guid = pred['guid'] # List: set_type, dialogue_idx, turn_idx turn_gt_class = pred[key_class_label_id] turn_pd_class = pred[key_class_prediction] gt_start_pos = pred[key_start_pos] pd_start_pos = pred[key_start_prediction] pd_start_conf = pred[key_start_confidence] gt_refer = pred[key_refer_id] pd_refer = pred[key_refer_prediction] gt_slot = tokenize(pred[key_slot_groundtruth]) pd_slot = pred[key_slot_prediction] pd_slot_dist_pred = tokenize(pred[key_slot_dist_prediction]) pd_slot_dist_conf = float(pred[key_slot_dist_confidence]) pd_slot_value_pred = tokenize(pred[key_slot_value_prediction]) pd_slot_value_conf = pred[key_slot_value_confidence] pd_slot_raw = pd_slot if isinstance(pd_slot, list): pd_slot = filter_sequences(pd_slot, mode="max") else: pd_slot = tokenize(pd_slot) # Make sure the true turn labels are contained in the prediction json file! joint_gt_slot = gt_slot # Sequence tagging confidence if len(pd_start_pos) > 0: avg_s_conf = np.mean(pd_start_conf) if avg_s_conf == 0.0: avg_s_conf += 1e-8 s_c_bin = "%.1f" % (math.ceil(avg_s_conf * 10) / 10) if gt_start_pos == pd_start_pos: s_confidence_bins[s_c_bin] += 1 s_confidence_cnts[s_c_bin] += 1 # Distance based value matching confidence if pd_slot_dist_conf == 0.0: pd_slot_dist_conf += 1e-8 c_bin = "%.1f" % (math.ceil(pd_slot_dist_conf * 10) / 10) if joint_gt_slot == pd_slot_dist_pred: confidence_bins[c_bin] += 1 confidence_cnts[c_bin] += 1 # Attention based value matching confidence if pd_slot_value_conf == 0.0: pd_slot_value_conf += 1e-8 c_bin = "%.1f" % (math.ceil(pd_slot_value_conf * 10) / 10) if joint_gt_slot == pd_slot_value_pred: a_confidence_bins[c_bin] += 1 a_confidence_cnts[c_bin] += 1 if guid[-1] == '0': # First turn, reset the slots joint_pd_slot = 'none' # If turn_pd_class or a value to be copied is "none", do not update the dialog state. if turn_pd_class == class_types.index('none'): pass elif turn_pd_class == class_types.index('dontcare'): if not boolean: joint_pd_slot = 'dontcare' elif turn_pd_class == class_types.index('copy_value'): if not boolean: if pd_slot not in ["< none >", "[ NONE ]"]: joint_pd_slot = pd_slot elif 'true' in class_types and turn_pd_class == class_types.index('true'): if boolean: joint_pd_slot = 'true' elif 'false' in class_types and turn_pd_class == class_types.index('false'): if boolean: joint_pd_slot = 'false' elif 'refer' in class_types and turn_pd_class == class_types.index('refer'): if not boolean: if pd_slot[0:2] == "§§": if pd_slot[2:].strip() != 'none': joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:].strip(), label_maps) elif pd_slot != 'none': joint_pd_slot = pd_slot elif 'inform' in class_types and turn_pd_class == class_types.index('inform'): if not boolean: if pd_slot[0:2] == "§§": if pd_slot[2:].strip() != 'none': joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:].strip(), label_maps) elif 'request' in class_types and turn_pd_class == class_types.index('request'): pass else: print("ERROR: Unexpected class_type. Aborting.") exit() # Value matching if args.confidence_threshold < 1.0 and turn_pd_class == class_types.index('copy_value') and not boolean: # Treating categorical slots if not noncategorical: max_conf = max(np.mean(pd_start_conf), pd_slot_dist_conf, pd_slot_value_conf) if max_conf == pd_slot_dist_conf and max_conf > args.confidence_threshold: joint_pd_slot = tokenize(pd_slot_dist_pred) value_match_cnt += 1 elif max_conf == pd_slot_value_conf and max_conf > args.confidence_threshold: joint_pd_slot = tokenize(pd_slot_value_pred) value_match_cnt += 1 # Treating all slots (including categorical slots) if pd_slot_dist_conf > args.confidence_threshold: joint_pd_slot = tokenize(pd_slot_dist_pred) value_match_cnt += 1 total_correct = True # Check the per turn correctness of the class_type prediction if turn_gt_class == turn_pd_class: class_correctness[turn_gt_class].append(1.0) class_correctness[-1].append(1.0) c_tp[turn_gt_class] += 1 # Only where there is a span, we check its per turn correctness if turn_gt_class == class_types.index('copy_value'): if gt_start_pos == pd_start_pos: pos_correctness.append(1.0) else: pos_correctness.append(0.0) # Only where there is a referral, we check its per turn correctness if 'refer' in class_types and turn_gt_class == class_types.index('refer'): if gt_refer == pd_refer: refer_correctness.append(1.0) print(" [%s] Correct referral: %s | %s" % (guid, gt_refer, pd_refer)) else: refer_correctness.append(0.0) print(" [%s] Incorrect referral: %s | %s" % (guid, gt_refer, pd_refer)) else: if turn_gt_class == class_types.index('copy_value'): pos_correctness.append(0.0) if 'refer' in class_types and turn_gt_class == class_types.index('refer'): refer_correctness.append(0.0) class_correctness[turn_gt_class].append(0.0) class_correctness[-1].append(0.0) confusion_matrix[turn_gt_class][turn_pd_class].append(1.0) c_fn[turn_gt_class] += 1 c_fp[turn_pd_class] += 1 for cc in range(len(class_types)): if cc != turn_gt_class and cc != turn_pd_class: c_tn[cc] += 1 # Check the joint slot correctness. # If the value label is not none, then we need to have a value prediction. # Even if the class_type is 'none', there can still be a value label, # it might just not be pointable in the current turn. It might however # be referrable and thus predicted correctly. if joint_gt_slot == joint_pd_slot: val_correctness.append(1.0) elif joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false': is_match = match(joint_gt_slot, joint_pd_slot, label_maps) if not is_match: val_correctness.append(0.0) total_correct = False print(" [%s] Incorrect value (variant): %s (turn class: %s) | %s (turn class: %s) | %.2f %s %.2f %s %s %s" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class, np.mean(pd_start_conf), pd_slot_raw, pd_slot_dist_conf, pd_slot_dist_pred, "%.2f" % pd_slot_value_conf if pd_slot_value_pred != "" else "", pd_slot_value_pred)) else: val_correctness.append(1.0) else: val_correctness.append(0.0) total_correct = False print(" [%s] Incorrect value: %s (turn class: %s) | %s (turn class: %s) | %.2f %s %.2f %s %s %s" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class, np.mean(pd_start_conf), pd_slot_raw, pd_slot_dist_conf, pd_slot_dist_pred, "%.2f" % pd_slot_value_conf if pd_slot_value_pred != "" else "", pd_slot_value_pred)) total_correctness.append(1.0 if total_correct else 0.0) # Account for empty lists (due to no instances of spans or referrals being seen) if pos_correctness == []: pos_correctness.append(1.0) if refer_correctness == []: refer_correctness.append(1.0) for ct in range(len(class_types)): if c_tp[ct] + c_fp[ct] > 0: precision = c_tp[ct] / (c_tp[ct] + c_fp[ct]) else: precision = 1.0 if c_tp[ct] + c_fn[ct] > 0: recall = c_tp[ct] / (c_tp[ct] + c_fn[ct]) else: recall = 1.0 if precision + recall > 0: f1 = 2 * ((precision * recall) / (precision + recall)) else: f1 = 1.0 if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0: acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct]) else: acc = 1.0 print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" % (class_types[ct], ct, recall, np.sum(class_correctness[ct]), len(class_correctness[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct])) print("Confusion matrix:") for cl in range(len(class_types)): print(" %s" % (cl), end="") print("") for cl_a in range(len(class_types)): print("%s " % (cl_a), end="") for cl_b in range(len(class_types)): if len(class_correctness[cl_a]) > 0: print("%.2f " % (np.sum(confusion_matrix[cl_a][cl_b]) / len(class_correctness[cl_a])), end="") else: print("---- ", end="") print("") print("Confidence bins for sequence tagging:") print(" bin cor") for c in s_confidence_bins: print(" %s %.2f (%d of %d)" % (c, s_confidence_bins[c] / (s_confidence_cnts[c] + 1e-8), s_confidence_bins[c], s_confidence_cnts[c])) print("Confidence bins for distance based value matching:") print(" bin cor") for c in confidence_bins: print(" %s %.2f (%d of %d)" % (c, confidence_bins[c] / (confidence_cnts[c] + 1e-8), confidence_bins[c], confidence_cnts[c])) print("Confidence bins for attention based value matching:") print(" bin cor") for c in a_confidence_bins: print(" %s %.2f (%d of %d)" % (c, a_confidence_bins[c] / (a_confidence_cnts[c] + 1e-8), a_confidence_bins[c], a_confidence_cnts[c])) print("Values replaced by value matching:", value_match_cnt) return np.asarray(total_correctness), np.asarray(val_correctness), np.asarray(class_correctness), np.asarray(pos_correctness), np.asarray(refer_correctness), np.asarray(confusion_matrix), c_tp, c_tn, c_fp, c_fn, s_confidence_bins, s_confidence_cnts, confidence_bins, confidence_cnts, a_confidence_bins, a_confidence_cnts if __name__ == "__main__": acc_list = [] s_acc_list = [] key_class_label_id = 'class_label_id_%s' key_class_prediction = 'class_prediction_%s' key_start_pos = 'start_pos_%s' key_start_prediction = 'start_prediction_%s' key_start_confidence = 'start_confidence_%s' key_refer_id = 'refer_id_%s' key_refer_prediction = 'refer_prediction_%s' key_slot_groundtruth = 'slot_groundtruth_%s' key_slot_prediction = 'slot_prediction_%s' key_slot_dist_prediction = 'slot_dist_prediction_%s' key_slot_dist_confidence = 'slot_dist_confidence_%s' key_value_prediction = 'value_prediction_%s' key_value_groundtruth = 'value_label_id_%s' key_value_confidence = 'value_confidence_%s' key_slot_value_prediction = 'slot_value_prediction_%s' key_slot_value_confidence = 'slot_value_confidence_%s' parser = argparse.ArgumentParser() # Required parameters parser.add_argument("--dataset_config", default=None, type=str, required=True, help="Dataset configuration file.") parser.add_argument("--file_list", default=None, type=str, required=True, help="List of input files.") # Other parameters parser.add_argument("--confidence_threshold", default=1.0, type=float, help="Threshold for value matching confidence. 1.0 means no value matching is used.") args = parser.parse_args() assert args.confidence_threshold >= 0.0 and args.confidence_threshold <= 1.0 class_types, slots, label_maps, noncategorical, boolean = load_dataset_config(args.dataset_config) # Prepare label_maps label_maps_tmp = {} for v in label_maps: label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]] label_maps = label_maps_tmp for fp in sorted(glob.glob(args.file_list)): # Infer slot list from data if not provided. if len(slots) == 0: with open(fp) as f: preds = json.load(f) for e in preds[0]: slot = re.match("^slot_groundtruth_(.*)$", e) slot = slot[1] if slot else None if slot and slot not in slots: slots.append(slot) print(fp) goal_correctness = 1.0 cls_acc = [[] for cl in range(len(class_types))] cls_conf = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))] c_tp = {ct: 0 for ct in range(len(class_types))} c_tn = {ct: 0 for ct in range(len(class_types))} c_fp = {ct: 0 for ct in range(len(class_types))} c_fn = {ct: 0 for ct in range(len(class_types))} s_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} s_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} a_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} a_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} for slot in slots: tot_cor, joint_val_cor, cls_cor, pos_cor, ref_cor, \ conf_mat, ctp, ctn, cfp, cfn, \ scbins, sccnts, cbins, ccnts, acbins, accnts = get_joint_slot_correctness(fp, args, class_types, label_maps, key_class_label_id=(key_class_label_id % slot), key_class_prediction=(key_class_prediction % slot), key_start_pos=(key_start_pos % slot), key_start_prediction=(key_start_prediction % slot), key_start_confidence=(key_start_confidence % slot), key_refer_id=(key_refer_id % slot), key_refer_prediction=(key_refer_prediction % slot), key_slot_groundtruth=(key_slot_groundtruth % slot), key_slot_prediction=(key_slot_prediction % slot), key_slot_dist_prediction=(key_slot_dist_prediction % slot), key_slot_dist_confidence=(key_slot_dist_confidence % slot), key_value_prediction=(key_value_prediction % slot), key_value_groundtruth=(key_value_groundtruth % slot), key_value_confidence=(key_value_confidence % slot), key_slot_value_prediction=(key_slot_value_prediction % slot), key_slot_value_confidence=(key_slot_value_confidence % slot), noncategorical=slot in noncategorical, boolean=slot in boolean) print('%s: joint slot acc: %g, joint value acc: %g, turn class acc: %g, turn position acc: %g, turn referral acc: %g' % (slot, np.mean(tot_cor), np.mean(joint_val_cor), np.mean(cls_cor[-1]), np.mean(pos_cor), np.mean(ref_cor))) goal_correctness *= tot_cor for cl_a in range(len(class_types)): cls_acc[cl_a] += cls_cor[cl_a] for cl_b in range(len(class_types)): cls_conf[cl_a][cl_b] += list(conf_mat[cl_a][cl_b]) c_tp[cl_a] += ctp[cl_a] c_tn[cl_a] += ctn[cl_a] c_fp[cl_a] += cfp[cl_a] c_fn[cl_a] += cfn[cl_a] for c in scbins: s_confidence_bins[c] += scbins[c] s_confidence_cnts[c] += sccnts[c] for c in cbins: confidence_bins[c] += cbins[c] confidence_cnts[c] += ccnts[c] for c in cbins: a_confidence_bins[c] += acbins[c] a_confidence_cnts[c] += accnts[c] for ct in range(len(class_types)): if c_tp[ct] + c_fp[ct] > 0: precision = c_tp[ct] / (c_tp[ct] + c_fp[ct]) else: precision = 1.0 if c_tp[ct] + c_fn[ct] > 0: recall = c_tp[ct] / (c_tp[ct] + c_fn[ct]) else: recall = 1.0 if precision + recall > 0: f1 = 2 * ((precision * recall) / (precision + recall)) else: f1 = 1.0 if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0: acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct]) else: acc = 1.0 print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" % (class_types[ct], ct, recall, np.sum(cls_acc[ct]), len(cls_acc[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct])) print("Confusion matrix:") for cl in range(len(class_types)): print(" %s" % (cl), end="") print("") for cl_a in range(len(class_types)): print("%s " % (cl_a), end="") for cl_b in range(len(class_types)): if len(cls_acc[cl_a]) > 0: print("%.2f " % (np.sum(cls_conf[cl_a][cl_b]) / len(cls_acc[cl_a])), end="") else: print("---- ", end="") print("") print("Confidence bins for sequence tagging:") print(" bin cor") for c in s_confidence_bins: print(" %s %.2f (%d of %d)" % (c, s_confidence_bins[c] / (s_confidence_cnts[c] + 1e-8), s_confidence_bins[c], s_confidence_cnts[c])) print("Confidence bins for distance based value matching:") print(" bin cor") for c in confidence_bins: print(" %s %.2f (%d of %d)" % (c, confidence_bins[c] / (confidence_cnts[c] + 1e-8), confidence_bins[c], confidence_cnts[c])) print("Confidence bins for attention based value matching:") print(" bin cor") for c in a_confidence_bins: print(" %s %.2f (%d of %d)" % (c, a_confidence_bins[c] / (a_confidence_cnts[c] + 1e-8), a_confidence_bins[c], a_confidence_cnts[c])) acc = np.mean(goal_correctness) acc_list.append((fp, acc)) acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True) for (fp, acc) in acc_list_s: print('Joint goal acc: %g, %s' % (acc, fp))