# -*- coding: utf-8 -*-
# Copyright 2020 DSML Group, Heinrich Heine University, Düsseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracting the Turn Encoder from the model checkpoint"""

import json
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import torch


def _extract_turn_encoder(model_dir, output_dir):
    """Copy the config and the ``turn_encoder`` weights into ``output_dir``.

    Args:
        model_dir: Directory containing ``config.json`` and
            ``pytorch_model.bin``.
        output_dir: Directory that receives the copied ``config.json`` and the
            filtered ``pytorch_model.bin``. Created (including missing
            parents) if it does not exist.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # handles nested output paths, which a bare os.mkdir cannot create.
    os.makedirs(output_dir, exist_ok=True)

    # Round-trip the config through json so the output is a re-serialised
    # copy of the input; context managers close the handles even on error.
    with open(os.path.join(model_dir, 'config.json'), 'r', encoding='utf-8') as reader:
        config = json.load(reader)
    with open(os.path.join(output_dir, 'config.json'), 'w', encoding='utf-8') as writer:
        json.dump(config, writer)

    # NOTE: torch.load unpickles the file; only run this on checkpoints from
    # a trusted source.
    state_dict = torch.load(os.path.join(model_dir, 'pytorch_model.bin'),
                            map_location='cpu')

    # Keep only the entries whose key mentions the turn encoder.
    state_dict = {key: item for key, item in state_dict.items()
                  if 'turn_encoder' in key}
    torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))


def main():
    """Parse the command line and extract the turn encoder weights.

    Reads ``config.json`` and ``pytorch_model.bin`` from ``--model_dir`` and
    writes the config plus the ``turn_encoder`` subset of the state dict to
    ``--output_dir``.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model_dir',
                        help='Directory containing the model checkpoint '
                             '(config.json and pytorch_model.bin)',
                        required=True)
    # NOTE(review): the default below looks like a leftover from a plotting
    # script; it is kept unchanged so existing invocations behave the same.
    parser.add_argument('--output_dir',
                        help='Directory to write the extracted turn encoder to',
                        default='calibration_plot.png')
    args = parser.parse_args()

    _extract_turn_encoder(args.model_dir, args.output_dir)


if __name__ == '__main__':
    main()