Skip to content
Snippets Groups Projects
Commit 39084550 authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

New plotting options

parent c5c1ea22
Branches
No related tags found
No related merge requests found
......@@ -13,8 +13,6 @@ from tqdm import tqdm
from convlab.policy.plot_results.plot_action_distributions import plot_distributions
plt.rcParams["font.family"] = "Times New Roman"
def get_args():
parser = argparse.ArgumentParser(description='Export tensorboard data')
......@@ -27,8 +25,11 @@ def get_args():
parser.add_argument("--max-dialogues", type=int, default=0)
parser.add_argument("--fill-between", type=float, default=0.3,
help="the transparency of the std err area")
parser.add_argument("--fontsize", type=int, default=16)
parser.add_argument("--font", type=str, default="Times New Roman")
args = parser.parse_args()
plt.rcParams["font.family"] = args.font
return args
......@@ -61,15 +62,17 @@ def read_tb_data(in_path):
return df
def plot(data, out_file, plot_type="complete_rate", show_image=False, fill_between=0.3, max_dialogues=0, y_label=''):
def plot(data, out_file, plot_type="complete_rate", show_image=False, fill_between=0.3, max_dialogues=0, y_label='',
fontsize=16):
legends = [alg for alg in data]
clrs = sns.color_palette("husl", len(legends))
plt.figure(plot_type)
plt.gca().patch.set_facecolor('#E6E6E6')
plt.grid(color='w', linestyle='solid')
largest_max = -sys.maxsize
smallest_min = sys.maxsize
with sns.axes_style("darkgrid"):
for i, alg in enumerate(legends):
max_step = min([len(d[plot_type]) for d in data[alg]])
......@@ -90,13 +93,15 @@ def plot(data, out_file, plot_type="complete_rate", show_image=False, fill_betwe
largest_max = mean.max() if mean.max() > largest_max else largest_max
smallest_min = mean.min() if mean.min() < smallest_min else smallest_min
plt.xlabel('Training dialogues')
plt.xlabel('Training dialogues', fontsize=fontsize)
#plt.gca().yaxis.set_major_locator(plt.MultipleLocator(round((largest_max - smallest_min) / 10.0, 2)))
if len(y_label) > 0:
plt.ylabel(y_label)
plt.ylabel(y_label, fontsize=fontsize)
else:
plt.ylabel(plot_type)
plt.legend(fancybox=True, shadow=False, ncol=1, loc='lower right')
plt.ylabel(plot_type, fontsize=fontsize)
plt.xticks(fontsize=fontsize-4)
plt.yticks(fontsize=fontsize-4)
plt.legend(fancybox=True, shadow=False, ncol=1, loc='lower right', fontsize=fontsize)
plt.savefig(out_file + ".pdf", bbox_inches='tight')
if show_image:
......@@ -122,7 +127,8 @@ if __name__ == "__main__":
plot_type=plot_type,
fill_between=args.fill_between,
max_dialogues=args.max_dialogues,
y_label=y_label_dict[plot_type])
y_label=y_label_dict[plot_type],
fontsize=args.fontsize)
plot_distributions(args.dir, json.load(open(args.map_file)), args.out_file)
plot_distributions(args.dir, json.load(open(args.map_file)), args.out_file, fontsize=args.fontsize, font=args.font)
......@@ -4,8 +4,6 @@ import os
import seaborn as sns
import pandas as pd
plt.rcParams["font.family"] = "Times New Roman"
def extract_action_distributions_across_seeds(algorithm_dir_path):
'''
......@@ -69,7 +67,8 @@ def extract_action_distributions_across_seeds(algorithm_dir_path):
return distribution_per_step_dict
def plot_distributions(dir_path, alg_maps, output_dir, fill_between=0.3):
def plot_distributions(dir_path, alg_maps, output_dir, fill_between=0.3, fontsize=16, font="Times New Roman"):
plt.rcParams["font.family"] = font
clrs = sns.color_palette("husl", len(alg_maps))
alg_paths = [os.path.join(dir_path, alg_map['dir'])
......@@ -79,10 +78,13 @@ def plot_distributions(dir_path, alg_maps, output_dir, fill_between=0.3):
possible_actions = action_distributions[0][0].keys()
create_bar_plots(action_distributions, alg_maps,
possible_actions, output_dir)
possible_actions, output_dir,
fontsize)
for action in possible_actions:
plt.clf()
plt.gca().patch.set_facecolor('#E6E6E6')
plt.grid(color='w', linestyle='solid')
largest_max = 0
smallest_min = 1
......@@ -97,7 +99,7 @@ def plot_distributions(dir_path, alg_maps, output_dir, fill_between=0.3):
seeds_used = distributions.shape[1]
std_error = std_dev / np.sqrt(seeds_used)
with sns.axes_style("darkgrid"):
# with sns.axes_style("darkgrid"):
plt.plot(steps, mean, c=clrs[i],
label=f"{alg_maps[i]['legend']}")
plt.fill_between(
......@@ -114,16 +116,17 @@ def plot_distributions(dir_path, alg_maps, output_dir, fill_between=0.3):
if round((largest_max - smallest_min) / 10.0, 2) > 0:
plt.gca().yaxis.set_major_locator(plt.MultipleLocator(
round((largest_max - smallest_min) / 10.0, 2)))
plt.xticks(fontsize=7, rotation=0)
plt.xlabel('Training dialogues')
plt.ylabel(f"{action} action probability")
plt.title(f"{action.upper()} action probability")
plt.xticks(fontsize=fontsize-4, rotation=0)
plt.yticks(fontsize=fontsize-4)
plt.xlabel('Training dialogues', fontsize=fontsize)
plt.ylabel(f"{action.title()} action probability", fontsize=fontsize)
plt.title(f"{action.title()} action probability", fontsize=fontsize)
plt.legend(fancybox=True, shadow=False, ncol=1, loc='upper left')
plt.savefig(
output_dir + f'/{action}_probability.pdf', bbox_inches='tight')
def create_bar_plots(action_distributions, alg_maps, possible_actions, output_dir):
def create_bar_plots(action_distributions, alg_maps, possible_actions, output_dir, fontsize):
max_step = max(action_distributions[0].keys())
final_distributions = [distribution[max_step]
......@@ -131,15 +134,20 @@ def create_bar_plots(action_distributions, alg_maps, possible_actions, output_di
df_list = []
for action in possible_actions:
action_list = [action]
action_list = [action.title()]
for distribution in final_distributions:
action_list.append(np.mean(distribution[action]))
df_list.append(action_list)
df = pd.DataFrame(df_list, columns=[
'Probabilities'] + [alg_map["legend"] for alg_map in alg_maps])
plt.figure()
plt.rcParams.update({'font.size': fontsize})
fig = df.plot(x='Probabilities', kind='bar', stacked=False, title='Final Action Distributions',
rot=0, grid=True, color=sns.color_palette("husl", len(alg_maps))).get_figure()
plt.yticks(np.arange(0, 1, 0.1))
rot=0, grid=True, color=sns.color_palette("husl", len(alg_maps)),
fontsize=fontsize).get_figure()
plt.gca().patch.set_facecolor('#E6E6E6')
plt.grid(color='w', linestyle='solid')
plt.yticks(np.arange(0, 1, 0.1), fontsize=fontsize-4)
plt.xticks(fontsize=fontsize-4)
fig.savefig(os.path.join(output_dir, "final_action_probabilities.pdf"))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment