Skip to content
Snippets Groups Projects
Select Git revision
  • 0b0c0b84a58ebcf1fc8c7d8dda7fec34d15c47c4
  • master default protected
  • release/1.1.4
  • release/1.1.3
  • release/1.1.1
  • 1.4.2
  • 1.4.1
  • 1.4.0
  • 1.3.0
  • 1.2.1
  • 1.2.0
  • 1.1.5
  • 1.1.4
  • 1.1.3
  • 1.1.1
  • 1.1.0
  • 1.0.9
  • 1.0.8
  • 1.0.7
  • v1.0.5
  • 1.0.5
21 results

ComplexExpressionTest.java

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    merge_predict_res.py 2.19 KiB
    import json
    import os
    from convlab.util import load_dataset, load_nlg_data
    
    
    def merge(dataset_names, speaker, save_dir, context_window_size, predict_result, dial_ids_order=None):
        """Attach model predictions to the test split and write them to predictions.json.

        Reads ``predict_result`` as one JSON object per line and pairs each
        object's stripped ``'predictions'`` string, in order, with the test
        samples returned by ``load_nlg_data``. Samples whose ``dialogue_acts``
        are all empty are skipped and receive no prediction.

        Args:
            dataset_names: a dataset name, or several names joined with '+'.
            speaker: speaker filter forwarded to ``load_nlg_data``.
            save_dir: output directory. When None, the directory containing
                ``predict_result`` is used; otherwise it is created if missing.
            context_window_size: number of context utterances. Context is
                enabled only when this is > 0.
            predict_result: path to the line-delimited prediction file.
            dial_ids_order: forwarded to ``load_dataset`` (None by default).

        Raises:
            FileNotFoundError: if ``predict_result`` does not exist.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(predict_result):
            raise FileNotFoundError(predict_result)

        if save_dir is None:
            save_dir = os.path.dirname(predict_result)
        else:
            os.makedirs(save_dir, exist_ok=True)

        # One JSON object per line. Blank lines (e.g. a stray trailing newline) are ignored.
        with open(predict_result, encoding='utf-8') as f:
            predictions = [json.loads(line)['predictions'].strip() for line in f if line.strip()]

        merged = []
        i = 0  # index of the next unused prediction, shared across all datasets
        for dataset_name in dataset_names.split('+'):
            print(dataset_name)
            dataset = load_dataset(dataset_name, dial_ids_order)
            data = load_nlg_data(dataset, data_split='test', speaker=speaker,
                                 use_context=context_window_size > 0,
                                 context_window_size=context_window_size)['test']

            for sample in data:
                # Samples without any dialogue act get no prediction. Presumably
                # the prediction file was generated with the same filter, so
                # skipping here keeps line i aligned with its sample — TODO confirm.
                if all(len(acts) == 0 for acts in sample['dialogue_acts'].values()):
                    continue
                sample['predictions'] = {'utterance': predictions[i]}
                i += 1
                merged.append(sample)

        with open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8') as f:
            json.dump(merged, f, indent=2, ensure_ascii=False)


    if __name__ == '__main__':
        from argparse import ArgumentParser
        parser = ArgumentParser(description="merge predict results with original data for unified NLG evaluation")
        parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, required=True, help="name of the unified dataset; join several names with '+'")
        parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
        parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result')
        parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
        parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
        parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
        args = parser.parse_args()
        print(args)
        merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order)