diff --git a/convlab/policy/plot_results/example_map.json b/convlab/policy/plot_results/example_map.json index 6148c50a5c7f260f1c2a52b2abd13d881d2ab9e7..72b5c6cc721b487cd313d2af7731b39607230a1a 100644 --- a/convlab/policy/plot_results/example_map.json +++ b/convlab/policy/plot_results/example_map.json @@ -4,7 +4,7 @@ "legend": "PPO" }, { - "dir": "vtrace", - "legend": "Vtrace" + "dir": "pg", + "legend": "PG" } ] \ No newline at end of file diff --git a/convlab/policy/plot_results/plot.py b/convlab/policy/plot_results/plot.py index 7a1a10db22d209a913dddf284ca584078978361c..8dca2b84b79ff2c150ec4cb7389bc6e6ae43d5f4 100644 --- a/convlab/policy/plot_results/plot.py +++ b/convlab/policy/plot_results/plot.py @@ -18,7 +18,7 @@ def get_args(): parser.add_argument('--tb-dir', type=str, default="TB_summary", help='The last dir for tensorboard files') parser.add_argument("--map-file", type=str, default="results/map.json") - parser.add_argument("--out-file", type=str, default="results/fig") + parser.add_argument("--out-file", type=str, default="results/") parser.add_argument("--max-dialogues", type=int, default=0) parser.add_argument("--fill-between", type=float, default=0.3, help="the transparency of the std err area") @@ -104,6 +104,7 @@ if __name__ == "__main__": for plot_type in ["complete_rate", "success_rate", 'turns', 'avg_return']: file_name, file_extension = os.path.splitext(args.out_file) + os.makedirs(file_name, exist_ok=True) fig_name = f"{file_name}_{plot_type}{file_extension}" data = read_data(exp_dir=args.dir, tb_dir=args.tb_dir,