# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracting the Turn Encoder from the model checkpoint"""

import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import json

import torch


def main():
    """Extract the turn encoder weights from a model checkpoint.

    Reads ``config.json`` and ``pytorch_model.bin`` from ``--model_dir``,
    keeps only the state-dict entries whose key contains ``'turn_encoder'``,
    and writes the config together with the filtered state dict to
    ``--output_dir`` under the same two file names.

    Command-line arguments:
        --model_dir: directory holding the source checkpoint (required).
        --output_dir: directory that receives the extracted files; it is
            created (including missing parent directories) if necessary.

    Raises:
        FileNotFoundError: if ``config.json`` or ``pytorch_model.bin`` is
            missing from ``--model_dir``.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model_dir', required=True,
                        help='Directory containing config.json and pytorch_model.bin')
    # NOTE(review): the default below looks like a leftover from a plotting
    # script (it names a directory "calibration_plot.png"). It is kept so
    # that existing invocations relying on it behave as before -- confirm
    # whether it should be changed or the argument made required.
    parser.add_argument('--output_dir', default='calibration_plot.png',
                        help='Directory to write the extracted turn encoder to')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race and makedirs also handles
    # nested output paths, which os.mkdir cannot create.
    os.makedirs(args.output_dir, exist_ok=True)

    # The config is copied through a load/dump round trip, so the output is
    # re-serialised JSON rather than a byte-for-byte copy of the input.
    with open(os.path.join(args.model_dir, 'config.json'), 'r', encoding='utf-8') as reader:
        config = json.load(reader)
    with open(os.path.join(args.output_dir, 'config.json'), 'w', encoding='utf-8') as writer:
        json.dump(config, writer)

    # SECURITY: torch.load unpickles the file; only run this on checkpoints
    # from a trusted source.
    state_dict = torch.load(os.path.join(args.model_dir, 'pytorch_model.bin'), map_location='cpu')
    # Substring match on the parameter name selects the turn encoder entries.
    turn_encoder_state = {key: item for key, item in state_dict.items() if 'turn_encoder' in key}
    torch.save(turn_encoder_state, os.path.join(args.output_dir, 'pytorch_model.bin'))


# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()