Skip to content
Snippets Groups Projects
Select Git revision
  • 8e8eba91a073d5e969d97787c737e822c62e1bfc
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

calibration_plots.py

Blame
  • user avatar
    zqwerty authored
    8e8eba91
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    calibration_plots.py 3.74 KiB
    # -*- coding: utf-8 -*-
    # Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    """Calibration Plot plotting script"""
    
    import os
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
    
    import torch
    from matplotlib import pyplot as plt
    
    
    def main():
        """Plot calibration curves for every model found under ``--data_dir``.

        Each entry ``<data_dir>/<model>/test.belief`` that exists is passed to
        ``get_calibration``; the resulting (confidence, accuracy) points are drawn
        as one line per model, together with the diagonal from (0, 0) to (1, 1),
        and the figure is written to ``--output``.
        """
        parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
        parser.add_argument('--data_dir', help='Location of the belief states', required=True)
        parser.add_argument('--output', help='Output image path', default='calibration_plot.png')
        parser.add_argument('--n_bins', help='Number of bins', default=10, type=int)
        args = parser.parse_args()

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        path = args.data_dir

        # os.listdir returns entries in arbitrary order; sort so that the line
        # style / legend order is reproducible between runs.
        candidates = [os.path.join(path, entry, 'test.belief') for entry in sorted(os.listdir(path))]
        # Skip stray files and model directories that have no belief-state dump
        # instead of crashing inside torch.load.
        models = [candidate for candidate in candidates if os.path.isfile(candidate)]

        plt.figure(figsize=(14, 8))
        font = 20
        plt.tick_params(labelsize=font - 2)
        linestyle = ['-', ':', (0, (3, 5, 1, 5)), '-.', (0, (5, 10))]
        for i, model in enumerate(models):
            conf, acc = get_calibration(model, device, n_bins=args.n_bins)
            # The model name is the directory that holds test.belief.
            name = os.path.basename(os.path.dirname(model)).strip()
            print(name, conf, acc)
            # Cycle through the styles so more than len(linestyle) models do not
            # raise IndexError; move tensors to CPU before handing them to matplotlib.
            plt.plot(conf.cpu(), acc.cpu(), label=name,
                     linestyle=linestyle[i % len(linestyle)], linewidth=3)

        # Reference diagonal: confidence equal to accuracy.
        plt.plot([0, 1], [0, 1], linestyle='--', color='black', linewidth=3)
        plt.xlabel('Confidence', fontsize=font)
        plt.ylabel('Joint Goal Accuracy', fontsize=font)
        plt.legend(fontsize=font)

        plt.savefig(args.output)
    
    
    def get_calibration(path, device, n_bins=10, temperature=1.00):
        """Compute (confidence, joint goal accuracy) points for a calibration plot.

        The file at ``path`` is loaded with ``torch.load`` and must contain a dict
        with the keys ``'belief_states'`` and ``'labels'``, each mapping slot names
        to tensors. For every flattened position the prediction of a slot is the
        argmax over the last dimension of its belief state, and the position's
        confidence is the minimum over slots of the maximum belief-state value.
        Positions are grouped into ``n_bins`` equal-width confidence bins on [0, 1].

        Args:
            path: Path of the saved belief-state file.
            device: Device passed to ``torch.load`` as ``map_location``.
            n_bins: Number of equal-width confidence bins.
            temperature: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(conf, acc)`` of 1-D CPU tensors of equal length. Both start
            with 0.0, followed by the mean confidence and the mean joint goal
            accuracy of each bin that holds at least one position whose label
            (of the first slot in ``labels``) is >= 0.
        """
        # NOTE(review): torch.load unpickles its input; only use trusted files.
        data = torch.load(path, map_location=device)
        y_true = data['labels']
        logits = data['belief_states']

        # Collapse all leading dimensions so each row is one position.
        flat = {slot: logits[slot].reshape(-1, logits[slot].size(-1)) for slot in logits}

        y_pred = {slot: flat[slot].argmax(-1) for slot in flat}
        slot_correct = [(y_pred[slot] == y_true[slot].reshape(-1)).int() for slot in y_pred]
        # A position counts as correct only if every predicted slot is correct.
        # Compare against the number of slots actually summed, not len(y_true),
        # so extra entries in the labels dict cannot zero out the accuracy.
        goal_acc = (sum(slot_correct) == len(slot_correct)).int()

        # Confidence of a position: the least confident slot's top value.
        scores = torch.stack([flat[slot].max(-1)[0] for slot in flat], 0).min(0)[0]

        step = 1.0 / float(n_bins)
        bin_ranges = torch.arange(0.0, 1.0 + 1e-10, step)
        bins = []
        for b in range(n_bins):
            lower, upper = bin_ranges[b], bin_ranges[b + 1]
            # Only the first bin is closed on the left so that a score of exactly
            # 0 is counted; all other bins are (lower, upper].
            above = (scores >= lower) if b == 0 else (scores > lower)
            bins.append(torch.where(above & (scores <= upper))[0])

        # Negative labels of the first slot are excluded from the accuracy
        # (presumably padding — TODO confirm against the data writer).
        first_slot = next(iter(y_true))
        valid = y_true[first_slot].reshape(-1) >= 0

        conf = [0.0]
        acc = [0.0]
        for ids in bins:
            kept = goal_acc[ids][valid[ids]]
            # Skip bins with no usable position. The original `size(0) >= 0`
            # check was always true, so such bins yielded a NaN accuracy.
            if kept.size(0) == 0:
                continue
            conf.append(scores[ids].mean().item())
            acc.append(kept.float().mean().item())

        return torch.tensor(conf), torch.tensor(acc)
    
    
    # Run the plotting script only when executed directly, not when imported.
    if __name__ == '__main__':
        main()