diff --git a/.gitignore b/.gitignore
index 387c65fcf809dab91a7faaf6fa55d58277d6b498..895f82fa24c81da7155bf01e32976ab005b0641c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,142 @@ public
 # Byte-compiled / optimized / DLL files
 __pycache__/
-*.py[cod]
\ No newline at end of file
+*.py[cod]
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Pydial logs and models
+_*/
+*.log
+*.json
+*.dct
+*.prm
+*.pyc
+*.model
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# PyCharm stuff
+.idea/
+.xml
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/config/pydial_benchmarks/env1-dqn-CR.cfg b/config/pydial_benchmarks/env1-dqn-CR.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..df989ee54d7396b02608970e2183773725736265
--- /dev/null
+++ b/config/pydial_benchmarks/env1-dqn-CR.cfg
@@ -0,0 +1,228 @@
+# Error model: 0% error rate, additive confscorer, uniform nbestgenerator
+# User model: standard sampled params, sampled patience
+# Masks: on
+
+###### General parameters ######
+[GENERAL]
+domains = CamRestaurants
+singledomain = True
+tracedialog = 0
+seed = 07051991
+
+[exec_config]
+configdir = _benchmarkpolicies
+logfiledir = _benchmarklogs
+numtrainbatches = 4
+traindialogsperbatch = 1000
+numbatchtestdialogs = 500
+trainsourceiteration = 0
+numtestdialogs = 500
+trainerrorrate = 0
+testerrorrate = 0
+testeverybatch = True
+#deleteprevpolicy = True
+
+[logging]
+usecolor = False
+screen_level = results
+file_level = results
+file = auto
+
+###### Environment parameters ######
+
+[agent]
+maxturns = 25
+
+[usermodel]
+usenewgoalscenarios = True
+oldstylepatience = False
+patience = 4,6
+configfile = config/sampledUM.cfg
+
+[errormodel]
+nbestsize = 1
+confusionmodel = RandomConfusions
+nbestgeneratormodel = SampledNBestGenerator
+confscorer = additive
+
+
+[summaryacts]
+maxinformslots = 5
+informmask = True
+requestmask = True
+informcountaccepted = 4
+byemask = True
+
+###### Dialogue Manager parameters ######
+
+## Uncomment for handcrafted policy ##
+# [policy]
+# policydir = _benchmarkpolicies
+# belieftype = focus
+# useconfreq = False
+# learning = True
+# policytype = hdc
+# startwithhello = False
+# inpolicyfile = auto
+# outpolicyfile = auto
+
+## Uncomment for GP policy ##
+#[policy]
+#policydir = _benchmarkpolicies
+#belieftype = focus
+#useconfreq = False
+#learning = True
+#policytype = gp
+#startwithhello = False
+#inpolicyfile = auto
+#outpolicyfile = auto
+#
+#[gppolicy]
+#kernel = polysort
+#
+#[gpsarsa]
+#random = False
+#scale = 3
+
+## Uncomment for DQN policy ##
+[policy]
+policydir = _benchmarkpolicies
+belieftype = focus
+useconfreq = False
+learning = True
+policytype = dqn
+startwithhello = False
+inpolicyfile = auto
+outpolicyfile = auto
+
+[dqnpolicy]
+maxiter = 4000
+gamma = 0.99
+learning_rate = 0.001
+tau = 0.02
+replay_type = vanilla
+minibatch_size = 64
+capacity = 6000
+exploration_type = e-greedy
+episodeNum = 0.0
+epsilon_start = 0.3
+epsilon_end = 0.0
+n_in = 268
+features = ["discourseAct", "method", "requested", "full", "lastActionInformNone", "offerHappened", "inform_info"]
+max_k = 5
+learning_algorithm = dqn
+architecture = vanilla
+h1_size = 300
+h2_size = 100
+training_frequency = 2
+n_samples = 1
+stddev_var_mu = 0.01
+stddev_var_logsigma = 0.01
+mean_log_sigma = 0.000001
+sigma_prior = 1.5
+alpha = 0.85
+alpha_divergence = False
+sigma_eps = 0.01
+delta = 1.0
+beta = 0.95
+is_threshold = 5.0
+train_iters_per_episode = 1
+
+## Uncomment for A2C policy ##
+#[policy]
+#policydir = _benchmarkpolicies
+#belieftype = focus
+#useconfreq = False
+#learning = True
+#policytype = a2c
+#startwithhello = False
+#inpolicyfile = auto
+#outpolicyfile = auto
+
+#[dqnpolicy]
+#maxiter = 4000
+#gamma = 0.99
+#learning_rate = 0.001
+#tau = 0.02
+#replay_type = vanilla
+#minibatch_size = 64
+#capacity = 1000
+#exploration_type = e-greedy
+#episodeNum = 0.0
+#epsilon_start = 0.5
+#epsilon_end = 0.0
+#n_in = 268
+#features = ["discourseAct", "method", "requested", "full", "lastActionInformNone", "offerHappened", "inform_info"]
+#max_k = 5
+#learning_algorithm = dqn
+#architecture = vanilla
+#h1_size = 200
+#h2_size = 75
+#training_frequency = 2
+#n_samples = 1
+#stddev_var_mu = 0.01
+#stddev_var_logsigma = 0.01
+#mean_log_sigma = 0.000001
+#sigma_prior = 1.5
+#alpha = 0.85
+#alpha_divergence = False
+#sigma_eps = 0.01
+#delta = 1.0
+#beta = 0.95
+#is_threshold = 5.0
+#train_iters_per_episode = 1
+
+## Uncomment for eNAC policy ##
+#[policy]
+#policydir = _benchmarkpolicies
+#belieftype = focus
+#useconfreq = False
+#learning = True
+#policytype = enac
+#startwithhello = False
+#inpolicyfile = auto
+#outpolicyfile = auto
+
+#[dqnpolicy]
+#maxiter = 4000
+#gamma = 0.99
+#learning_rate = 0.001
+#tau = 0.02
+#replay_type = vanilla
+#minibatch_size = 64
+#capacity = 1000
+#exploration_type = e-greedy
+#episodeNum = 0.0
+#epsilon_start = 0.3
+#epsilon_end = 0.0
+#n_in = 268
+#features = ["discourseAct", "method", "requested", "full", "lastActionInformNone", "offerHappened", "inform_info"]
+#max_k = 5
+#learning_algorithm = dqn
+#architecture = vanilla
+#h1_size = 130
+#h2_size = 50
+#training_frequency = 2
+#n_samples = 1
+#stddev_var_mu = 0.01
+#stddev_var_logsigma = 0.01
+#mean_log_sigma = 0.000001
+#sigma_prior = 1.5
+#alpha = 0.85
+#alpha_divergence = False
+#sigma_eps = 0.01
+#delta = 1.0
+#beta = 0.95
+#is_threshold = 5.0
+#train_iters_per_episode = 1
+
+###### Evaluation parameters ######
+
+[eval]
+rewardvenuerecommended = 0
+penaliseallturns = True
+wrongvenuepenalty = 0
+notmentionedvaluepenalty = 0
+successmeasure = objective
+successreward = 20
+
diff --git a/policy/DQNPolicy.py b/policy/DQNPolicy.py
index 30aa6ca2a8c288180529851f5158fd61d4e16ca8..2f3931bf4f44af9cb8c59dd41ef3c9c22d8d66ff 100644
--- a/policy/DQNPolicy.py
+++ b/policy/DQNPolicy.py
@@ -768,7 +768,7 @@ class DQNPolicy(Policy.Policy):
             curiosity_loss = self.curiosityFunctions.training(s2_batch, s_batch, a_batch_one_hot)
             # self.curiositypred_loss.append(curiosity_loss) # for plotting

-            predicted_q_value, currentLoss = self.dqn.train(s_batch, a_batch_one_hot, reshaped_yi)
+            predicted_q_value, _, currentLoss = self.dqn.train(s_batch, a_batch_one_hot, reshaped_yi)

         if self.episodecount % 1 == 0:
             # Update target networks
diff --git a/pydial.py b/pydial.py
index aefc9930a1d4cd4b146a2a3e93ded6464ab0dab7..606eb2382bddbcaa4e8645204de98c0102c77f50 100644
--- a/pydial.py
+++ b/pydial.py
@@ -931,6 +931,9 @@ def test_command(configfile, iteration, seed=None, testerrorrate=None, trainerro
             policyname = '-'.join(ps[:-1] + ['seed{}'.format(orig_seed)] + [ps[-1]])
         else:
             policyname = "%s-%02d.%d" % (configId, gtrainerrorrate, i)
+            if 'seed' not in policyname:
+                ps = policyname.split('-')
+                policyname = '-'.join(ps[:-1] + ['seed{}'.format(orig_seed)] + [ps[-1]])
     poldirpath = path(policy_dir)
     if poldirpath.isdir():
         policyfiles = poldirpath.files()
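
For reference, a minimal standalone sketch of the seed-tagging behaviour the pydial.py hunk adds. The helper name and example policy names below are illustrative assumptions, not code from pydial.py; the "%s-%02d.%d" naming scheme and orig_seed come from the surrounding diff context.

# Hypothetical helper mirroring the three lines added to test_command above.
def tag_policyname_with_seed(policyname, orig_seed):
    # Insert a 'seed<N>' token before the last '-'-separated component,
    # unless the name already carries a seed tag.
    if 'seed' not in policyname:
        ps = policyname.split('-')
        policyname = '-'.join(ps[:-1] + ['seed{}'.format(orig_seed)] + [ps[-1]])
    return policyname

# e.g. a '<configId>-<errorrate>.<iteration>' name gains a seed token:
assert tag_policyname_with_seed('env1-dqn-CR-00.1', '07051991') == 'env1-dqn-CR-seed07051991-00.1'
# while an already-tagged name passes through unchanged:
assert tag_policyname_with_seed('env1-dqn-CR-seed07051991-00.1', '07051991') == 'env1-dqn-CR-seed07051991-00.1'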