diff --git a/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py index 6de31fbea9825b54b2e29bcf51e035a571de1c6b..a6be2242357d9eb9b0f5a317c69df63a5013efd3 100755 --- a/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py +++ b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py @@ -15,7 +15,7 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result): predict_result = json.load(open(predict_result)) for sample, prediction in zip(data, predict_result): - sample['dialogue_acts_prediction'] = prediction['predict'] + sample['predictions'] = {'dialogue_acts': prediction['predict']} json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)