    #! /usr/bin/env python
    
    ###############################################################################
    # PyDial: Multi-domain Statistical Spoken Dialogue System Software
    ###############################################################################
    #
    # Copyright 2015 - 2019
    # Cambridge University Engineering Department Dialogue Systems Group
    #
    # 
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #
    ###############################################################################
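
    """ pydial.py - command line interface to PyDial.
        Runs the simulator to train and test dialogue policies, plots learning curves and
        performance vs error rate, and runs texthub in multi-domain chat mode.
    """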
    
    
    import os
    from scriptine import run, path, log, command
    import re
    import numpy as np
    
    # macOS users: uncomment one of the matplotlib backends below if the default backend fails
    import matplotlib
    # matplotlib.use('TkAgg')
    # matplotlib.use('Agg')
    
    import matplotlib.pyplot as plt
    
    # Uncomment for 4K screens
    # matplotlib.rcParams.update({'font.size': 22})
    
    # PyDial modules
    import Simulate
    import Texthub
    from utils import Settings
    from utils import ContextLogger
    from ontology import Ontology
    import utils.ContextLogger as clog
    import pprint
    pp = pprint.PrettyPrinter(indent=4)
    
    # Remove tensorflow deprecation warnings
    #import tensorflow.python.util.deprecation as deprecation
    #deprecation._PRINT_DEPRECATION_WARNINGS = False
    
    logger = None
    tracedialog = 2
    policy_dir = ""
    conf_dir = ""
    log_dir = ""
    logfile = ""
    
    gnumtrainbatches = 0
    gtraindialogsperbatch = 0
    gnumbatchtestdialogs = 0
    gnumtestdialogs = 0
    gtrainerrorrate = 0
    gtesterrorrate = 0
    gtrainsourceiteration = 0
    gtesteverybatch = False
    
    gpscale = 1
    
    gplotnum = 1
    
    gbatchnum = 0
    
    isSingleDomain = False
    taskID = ""
    domain = ""
    domains = []
    policytype = "hdc"
    
    policytypes = {}
    
    
    def help_command():
        """ Provide an overview of pydial functionality
        """
        print("\n pydial - command line interface to PyDial")
        print('""""""""""""""""""""""""""""""""""""""""""""')
        print(' o Runs simulator to train and test policies')
        print(' o Plots learning rates and performance vs error rate')
        print(' o Runs texthub in multi-domain chat mode\n')
        print('Basic usage:')
        print('  a) Make pydial.py executable and add a symbolic link to it (eg pydial) from your')
        print('     local bin directory.  Create a directory called ID and cd into it.\n')
        print("  b) create a config file and add an exec_config section eg:\n")
        print('     [exec_config]')
        print('     domain = CamRestaurants     # specific train/test domain')
        print('     policytype = gp             # type of policy to train/test')
        print('     configdir = cfgdir          # folder to store configs')
        print('     logfiledir = logdir         # folder to store logfiles')
        print('     numtrainbatches = 2         # num training batches (iterations)')
        print('     traindialogsperbatch = 10   # num dialogs per batch')
        print('     numbatchtestdialogs =  20   # num dialogs to eval each batch')
        print('     trainsourceiteration = 0    # index of initial source policy to update')
        print('     testiteration = 1           # policy iteration to test')
        print('     numtestdialogs =  10        # num dialogs per test')
        print('     trainerrorrate = 0          # train error rate in %')
        print('     testerrorrate  = 0          # test error rate in %')
        print('     testeverybatch = True       # enable batch testing\n')
        print('     by convention the config file name for training and testing should be of the')
        print('     form ID-policytype-domain.cfg where ID is a user-defined id.')
        print('     (There is more detail on naming conventions below.)')
        print('     Also unless the current directory is the same as the PyDial root')
        print('     make sure that [GENERAL]root points to root of the PyDial source tree.\n')
        print('  c) to train a policy as specified in the config file, type')
        print('       > pydial train config')
        print('     if trainsourceiteration=0 this creates a new policy in n batches where')
        print('     n=numtrainbatches, otherwise an existing policy is trained further.\n')
        print('  d) to test a policy as specified in the config file, type')
        print('       > pydial test config\n')
        print('     texthub.py can be invoked to interact with a policy from the keyboard by:')
        print('       > pydial chat config')
        print('     Note that train and test must refer to a specific domain as per [exec_config] domain')
        print('     whereas chat mode can specify multiple domains via the [GENERAL]domains variable.\n')
        print('  e) for convenience, many config parameters can be overridden on the command line, eg')
        print('       > pydial train config --trainerrorrate=20')
        print('       > pydial test config --iteration=4 --trainerrorrate=20 --testerrorrate=50')
        print('     to train a policy at 20% error rate and test the 4th iteration at 50% error rate.')
        print('     A range of test error rates can be specified as a triple (stErr,enErr,stepSize), eg')
        print("       > pydial test config --iteration=4 --trainerrorrate=20 --testerrorrate='(0,50,10)'")
        print('     to test a policy at 0%, 10%, 20%, 30%, 40%, and 50% error rates.\n')
        print('  f) logfiles for each train/test run are stored in logfiledir.')
        print('     The plot command scans one or more logfiles and extracts info to plot, eg')
        print('       > pydial plot logdir/*train*')
        print('     Setting the option --printtab also tabulates the performance data.\n')
        print('  All policy information is stored in the policydir specified in the corresponding ')
        print('  config file section with name [policy_domain]. Since pydial overrides some config')
        print('  params, the actual configs used for each run are recorded in configdir.\n')
        print('  Derived file naming convention:')
        print("     Policyname: ID-poltype-domain-TrainErrRate               eg S0-gp-CamRestaurants-20")
        print("     Policy: ID-poltype-domain-TrainErrRate.Iteration         eg S0-gp-CamRestaurants-20.3")
        print("     Policyfile: ID-poltype-domain-TrainErrRate.Iteration.ext eg S0-gp-CamRestaurants-20.3.dct")
        print("     TrainLogfiles: PolicyName.IterationRange.train.log       eg S0-gp-CamRestaurants-20.1-3.train.log")
        print("     EvalLogfiles:  Policy.eval.ErrorRange.eval.log           eg S0-gp-CamRestaurants-20.3.eval.00-50.log\n")
        print("To get further help:")
        print("  pydial             list of available commands")
        print("  pydial help        this overview")
        print("  pydial cmd --help  help for a specific command\n")
    
    
    def conventionCheck(name):
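        """ Check that a config id follows the ID-policytype-domain convention and matches the
            configured policytype and domain; non-standard names are currently ignored.
        """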
        global taskID, domain, policytype
        try:
            if name.find('-') < 0:
                raise Exception('no separators')
            (taskID,p,d)=name.split('-')
            if p != policytype:
                raise Exception('policytype != config param')
            if d != domain:
                raise Exception('domain name != config param')
        except Exception as x:
            # log.warn("Non-standard config name [%s] (preferred format ID-policytype-domain.cfg)", x.args[0])
            pass
    
    
    def getConfigId(configFileName):
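        """ Check that configFileName exists and has a .cfg extension and return its basename
            without the extension.
        """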
        i = configFileName.rfind('.')
        if i < 0 or configFileName[i+1:] != 'cfg':
            print(("Config file %s does not have required .cfg extension" % configFileName))
            exit(0)
    
        cfg = path(configFileName)
        if not cfg.isfile():
            print(("Config file %s does not exist" % configFileName))
            exit(0)
        id = configFileName[:i]
        j = id.rfind('/')
        if j >= 0: id = id[j+1:]
        return id
    
    
    def getOptionalConfigVar(configvarname, default='', section='exec_config'):
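        """ Return the string value of configvarname from section if present, otherwise default.
        """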
        value = default
        if Settings.config.has_option(section, configvarname):
            value = Settings.config.get(section, configvarname)
        return value
    
    
    def getRequiredDirectory(directoryname, section='exec_config'):
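        """ Return the directory configured under directoryname in section, ensuring a trailing '/'.
        """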
        assert Settings.config.has_option(section, directoryname),\
            "Value {} in section {} is missing.".format(directoryname, section)
        dir = Settings.config.get(section, directoryname)
        if dir[-1] != '/': dir = dir+'/'
        return dir
    
    
    def getOptionalConfigInt(configvarname, default='0', section='exec_config'):
        value = default
        if Settings.config.has_option(section, configvarname):
            try:
                value = Settings.config.getint(section, configvarname)
            except ValueError:
                value = Settings.config.get(section, configvarname)
        return value
    
    
    def getOptionalConfigBool(configvarname, default='False', section='exec_config'):
        value = default
        if Settings.config.has_option(section, configvarname):
            value = Settings.config.getboolean(section, configvarname)
        return value
    
    
    def initialise(configId, config_file, seed, mode, trainerrorrate=None, trainsourceiteration=None,
                   numtrainbatches=None, traindialogsperbatch=None, numtestdialogs=None,
                   testerrorrate=None, testenderrorrate=None, iteration=None, traindomains=None, testdomains=None,
                   dbprefix=None):
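        """ Initialise a run: read the [exec_config] settings (command-line arguments override
            config values), create the config/log/policy directories, derive the logfile name
            for the given mode (train/eval/chat) and set up the ontology, logger and random seed.
        """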
        global logger, logfile, tracedialog, isSingleDomain
        global policy_dir, conf_dir, log_dir
        global gnumtrainbatches, gtraindialogsperbatch, gnumbatchtestdialogs, gnumtestdialogs
        global gtrainerrorrate, gtesterrorrate, gtrainsourceiteration
        global taskID, domain, domains, policytype, gtesteverybatch, gpscale
        global gdeleteprevpolicy, isSingleModel
        global policytypes
    
        if seed is not None:
            seed = int(seed)
        seed = Settings.init(config_file, seed)
        taskID = 'ID'
    
        isSingleDomain = getOptionalConfigBool("singledomain", isSingleDomain, "GENERAL")
        isSingleModel = getOptionalConfigBool("singlemodel", False, "policycommittee")
        tracedialog    = getOptionalConfigInt("tracedialog", tracedialog, "GENERAL")
        domain         = getOptionalConfigVar("domains", '', "GENERAL")
        if len(domain.split(',')) > 1 and isSingleDomain:
            print('Config error: singledomain is set but multiple domains are defined. Check the config file.')
            exit(0)
        if isSingleDomain:
            if Settings.config.has_section('policy_' + domain):
                policytype = getOptionalConfigVar('policytype', policytype, 'policy_' + domain)
            else:
                policytype = getOptionalConfigVar('policytype', policytype, 'policy')
            conventionCheck(configId)
        else:
            domains = getOptionalConfigVar("domains", "", "GENERAL").split(',')
            policytypes = {}
            for domain in domains:
                if Settings.config.has_section('policy_' + domain):
                    policytypes[domain] = getOptionalConfigVar('policytype', policytype, 'policy_' + domain)
                else:
                    policytypes[domain] = getOptionalConfigVar('policytype', policytype, 'policy')
    
        # if gp, make sure to save required scale before potentially overriding
        if isSingleDomain:
            if policytype == "gp":
                if Settings.config.has_section("gpsarsa_" + domain):
                    try:
                        gpscale = Settings.config.getint("gpsarsa_" + domain, "scale")
                    except ValueError:
                        gpscale = Settings.config.get("gpsarsa_" + domain, "scale")
                else:
                    try:
                        gpscale = Settings.config.getint("gpsarsa", "scale")
                    except ValueError:
                        gpscale = Settings.config.get("gpsarsa", "scale")
        else:
            gpscales = {}
            for domain in domains:
                if policytypes[domain] == "gp":
                    if Settings.config.has_section("gpsarsa_" + domain):
                        try:
                            gpscales[domain] = Settings.config.getint("gpsarsa_"+ domain, "scale")
                        except ValueError:
                            gpscales[domain] = Settings.config.get("gpsarsa_"+ domain, "scale")
    
                    else:
                        try:
                            gpscales[domain] = Settings.config.getint("gpsarsa", "scale")
                        except ValueError:
                            gpscales[domain] = Settings.config.get("gpsarsa", "scale")
    
        # if deep-rl model, make sure to set the correct n_in
        if isSingleDomain:
            if Settings.config.has_section("dqnpolicy"):
                if domain == 'CamRestaurants':
                    Settings.config.set("dqnpolicy", 'n_in', '268')
                elif domain == 'SFRestaurants':
                    Settings.config.set("dqnpolicy", 'n_in', '636')
                elif domain == 'Laptops11':
                    Settings.config.set("dqnpolicy", 'n_in', '257')
                    # TODO: set rest of environments and multidomain
    
        # Get required folders and create if necessary
        log_dir    = getRequiredDirectory("logfiledir")
        conf_dir   = getRequiredDirectory("configdir")
        if isSingleDomain:
            if policytype != 'hdc':
                if Settings.config.has_section("policy_"+domain):
                    policy_dir = getRequiredDirectory("policydir", "policy_"+domain)
                else:
                    policy_dir = getRequiredDirectory("policydir", "policy")
                pd = path(policy_dir)
                if not pd.isdir():
                    print("Policy dir %s does not exist, creating it" % policy_dir)
                    pd.mkdir()
        else:
            for domain in domains:
                if policytypes[domain] != 'hdc':
                    if Settings.config.has_section("policy_" + domain):
                        policy_dir = getRequiredDirectory("policydir", "policy_" + domain)
                    else:
                        policy_dir = getRequiredDirectory("policydir", "policy")
                    pd = path(policy_dir)
                    if not pd.isdir():
                        print("Policy dir %s does not exist, creating it" % policy_dir)
                        pd.mkdir()
    
        cd = path(conf_dir)
        if not cd.isdir():
            print("Config dir %s does not exist, creating it" % conf_dir)
            cd.mkdir()
        ld = path(log_dir)
        if not ld.isdir():
            print("Log dir %s does not exist, creating it" % log_dir)
            ld.mkdir()
    
    
        # optional config settings
        if numtrainbatches:
            gnumtrainbatches = int(numtrainbatches)
        else:
            gnumtrainbatches = getOptionalConfigInt("numtrainbatches", 1)
        if traindialogsperbatch:
            gtraindialogsperbatch = int(traindialogsperbatch)
        else:
            gtraindialogsperbatch = getOptionalConfigInt("traindialogsperbatch", 100)
        if trainerrorrate:
            gtrainerrorrate = int(trainerrorrate)
        else:
            gtrainerrorrate = getOptionalConfigInt("trainerrorrate", 0)
        if testerrorrate:
            gtesterrorrate = int(testerrorrate)
        else:
            gtesterrorrate = getOptionalConfigInt("testerrorrate",0)
        if trainsourceiteration:
            gtrainsourceiteration = int(trainsourceiteration)
        else:
            gtrainsourceiteration = getOptionalConfigInt("trainsourceiteration",0)
        if numtestdialogs:
            gnumtestdialogs = int(numtestdialogs)
        else:
            gnumtestdialogs = getOptionalConfigInt("numtestdialogs", 50)
    
        gnumbatchtestdialogs = getOptionalConfigInt("numbatchtestdialogs", 20)
        gtesteverybatch = getOptionalConfigBool("testeverybatch",True)
        gdeleteprevpolicy = getOptionalConfigBool("deleteprevpolicy", False)
        if seed is not None and not 'seed' in configId:
            if seed >= 100 and seed < 200:
                seed_string = 'seed{}-'.format(seed - 100)
            else:
                seed_string = 'seed{}-'.format(seed)
        else:
            seed_string = ''
        if mode == "train":
            if gnumtrainbatches>1:
                enditeration = gtrainsourceiteration+gnumtrainbatches
                logfile = "%s-%s%02d.%d-%d.train.log" % (configId, seed_string,gtrainerrorrate,gtrainsourceiteration+1,enditeration)
            else:
                logfile = "%s-%s%02d.%d.train.log" % (configId, seed_string, gtrainerrorrate, gtrainsourceiteration + 1)
        elif mode == "eval":
            if testenderrorrate:
                logfile = "%s-%s%02d.%d.eval.%02d-%02d.log" % (configId, seed_string,gtrainerrorrate,iteration,
                                                             gtesterrorrate,testenderrorrate)
            else:
                if type(iteration) == str:
                    logfile = "{}_vs_{}-{}.eval.log".format(configId, iteration, seed_string[:-1])
                else:
                    logfile = "%s-%s%02d.%d.eval.%02d.log" % (configId, seed_string, gtrainerrorrate, iteration, gtesterrorrate)
        elif mode == "chat":
            logfile = "%s-%s%02d.%d.chat.log" % (configId, seed_string, gtrainerrorrate, gtrainsourceiteration)
        else:
            print("Unknown initialisation mode:",mode)
            exit(0)
        print('*** logfile: {} ***'.format(logfile))
        Settings.config.set("logging", "file", log_dir + logfile)
        if traindomains:
            Settings.config.set("GENERAL", "traindomains", traindomains)
        if testdomains:
            Settings.config.set("GENERAL", "testdomains", testdomains)
        if dbprefix:
            Settings.config.set("exec_config", "dbprefix", dbprefix)
        if not Ontology.global_ontology:
            ContextLogger.createLoggingHandlers(config=Settings.config)
            logger = ContextLogger.getLogger('')
            Ontology.init_global_ontology()
        else:
            ContextLogger.resetLoggingHandlers()
            ContextLogger.createLoggingHandlers(config=Settings.config)
            logger = ContextLogger.getLogger('')
    
        Settings.random.seed(int(seed))
        if Settings.root == '':
            Settings.root = os.getcwd()
        logger.info("Seed = %d", seed)
        logger.info("Root = %s", Settings.root)
    
    
    def setupPolicy(domain, configId, trainerr, source_iteration, target_iteration, seed=None):
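        """ Derive the in/out policy file names for a domain from the config id, training error
            rate and iteration (or seed) and write them into the config's policy section.
            Returns (inpolicyfile, outpolicyfile).
        """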
        if Settings.config.has_section("policy_" + domain):
            policy_section = "policy_" + domain
        else:
            policy_section = "policy"
        if not str(source_iteration).isdigit():
            inpolicyfile = source_iteration
            outpolicyfile = source_iteration
        elif seed is not None:
            inpolicyfile = "%s-seed%s-%02d.%d" % (configId, seed, trainerr, source_iteration)
            outpolicyfile = "%s-seed%s-%02d.%d" % (configId, seed, trainerr, target_iteration)
        else:
            inpolicyfile = "%s-%02d.%d" % (configId, trainerr, source_iteration)
            outpolicyfile = "%s-%02d.%d" % (configId, trainerr, target_iteration)
        if isSingleDomain:
            Settings.config.set(policy_section, "inpolicyfile", policy_dir + inpolicyfile)
            Settings.config.set(policy_section, "outpolicyfile", policy_dir + outpolicyfile)
        else:
            multi_policy_dir = policy_dir + domain
            pd = path(multi_policy_dir)
            if not pd.isdir():
                print("Policy dir %s does not exist, creating it" % multi_policy_dir)
                pd.mkdir()
            Settings.config.set(policy_section, "inpolicyfile", multi_policy_dir + inpolicyfile)
            Settings.config.set(policy_section, "outpolicyfile", multi_policy_dir + outpolicyfile)
        return (inpolicyfile, outpolicyfile)
    
    
    def trainBatch(domain, configId, trainerr, ndialogs, source_iteration, seed=None):
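        """ Run one batch of simulated training dialogs: set up the policy files, switch learning
            on, restore the gp scale, save the per-iteration config to configdir and run ndialogs
            dialogs at the given error rate; previous policy files are optionally deleted.
        """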
        if isSingleDomain:
            (inpolicy, outpolicy) = setupPolicy(domain, configId, trainerr, source_iteration, source_iteration + 1, seed=seed)
            mess = "*** Training Iteration %s->%s: iter=%d, error-rate=%d, num-dialogs=%d ***" % (
                inpolicy, outpolicy, source_iteration, trainerr, ndialogs)
            if tracedialog > 0: print(mess)
            logger.results(mess)
            # make sure that learning is switched on
            if Settings.config.has_section("policy_" + domain):
                Settings.config.set("policy_" + domain, "learning", 'True')
            else:
                Settings.config.set("policy", "learning", 'True')
            # if gp, make sure to reset scale to config setting
            if policytype == "gp":
                if Settings.config.has_section("gpsarsa_" + domain):
                    Settings.config.set("gpsarsa_" + domain, "scale", str(gpscale))
                else:
                    Settings.config.set("gpsarsa", "scale", str(gpscale))
            # Define the config file for this iteration
            confsavefile = conf_dir + outpolicy + ".train.cfg"
        else:
            mess = "*** Training Iteration: iter=%d, error-rate=%d, num-dialogs=%d ***" % (
                source_iteration, trainerr, ndialogs)
            if tracedialog > 0: print(mess)
            logger.results(mess)
            for dom in domain:
                setupPolicy(dom, configId, trainerr, source_iteration, source_iteration + 1, seed=seed)
                # make sure that learning is switched on
                if Settings.config.has_section("policy_" + dom):
                    Settings.config.set("policy_" + dom, "learning", 'True')
                else:
                    Settings.config.set("policy", "learning", 'True')
                # if gp, make sure to reset scale to config setting
                if policytype == "gp":
                    Settings.config.set("gpsarsa_" + dom, "scale", str(gpscale))
            # Define the config file for this iteration
            multipolicy = "%s-%02d.%d" % (configId, trainerr, source_iteration + 1)
            confsavefile = conf_dir + multipolicy + ".train.cfg"
    
        # Save the config file for this iteration
        with open(confsavefile, 'w') as cf:
            Settings.config.write(cf)
        error = float(trainerr) / 100.0
        # run the system
        simulator = Simulate.SimulationSystem(error_rate=error)
        simulator.run_dialogs(ndialogs)
        if gdeleteprevpolicy:
            if isSingleDomain:
                if inpolicy[-1] != '0':
                    if Settings.config.has_section("policy_" + domain):
                        for f in os.listdir(Settings.config.get('policy_{}'.format(domain), 'policydir')):
                            if re.search(inpolicy, f):
                                os.remove(os.path.join(Settings.config.get('policy_{}'.format(domain), 'policydir'), f))
                    else:
                        for f in os.listdir(Settings.config.get('policy', 'policydir')):
                            if re.search(inpolicy, f):
                                os.remove(os.path.join(Settings.config.get('policy', 'policydir'), f))
    
    
    def setEvalConfig(domain, configId, evalerr, ndialogs, iteration, seed=None):
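        """ Prepare the config for evaluating a policy iteration: switch learning off, reset the
            gp scale to 1 and save a copy of the evaluation config to configdir.
        """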
        (_, policy) = setupPolicy(domain, configId, gtrainerrorrate, iteration, iteration, seed=seed)
        if isSingleDomain:
            mess = "*** Evaluating %s: error-rate=%d num-dialogs=%d ***" % (policy, evalerr, ndialogs)
        else:
            mess = "*** Evaluating %s: error-rate=%d num-dialogs=%d ***" % (policy.replace('Multidomain', domain),
                                                                            evalerr, ndialogs)
        if tracedialog > 0: print(mess)
        logger.results(mess)
        # make sure that learning is switched off
        if Settings.config.has_section("policy_" + domain):
            Settings.config.set("policy_" + domain, "learning", 'False')
        else:
            Settings.config.set("policy", "learning", 'False')
        # if gp, make sure to reset scale to 1 for evaluation
        if policytype == "gp":
            if Settings.config.has_section("gpsarsa_" + domain):
                Settings.config.set("gpsarsa_" + domain, "scale", "1")
            else:
                Settings.config.set("gpsarsa", "scale", "1")
        # Save a copy of config file
        confsavefile = conf_dir + "%s.eval.%02d.cfg" % (policy, evalerr)
        with open(confsavefile, 'w') as cf:
            Settings.config.write(cf)
    
    
    def evalPolicy(domain, configId, evalerr, ndialogs, iteration, seed=None):
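        """ Configure evaluation of the given policy iteration (for all domains in multi-domain
            mode) and run ndialogs simulated dialogs at the given error rate.
        """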
        if isSingleDomain:
            setEvalConfig(domain, configId, evalerr, ndialogs, iteration, seed=seed)
        else:
            for dom in domains:
                setEvalConfig(dom, configId, evalerr, ndialogs, iteration, seed=seed)
    
        error = float(evalerr) / 100.0
        # finally run the system
        simulator = Simulate.SimulationSystem(error_rate=error)
        simulator.run_dialogs(ndialogs)
    
    
    def getIntParam(line, key):
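        """ Extract the integer value of ' key = <int>' from a log line.
        """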
        m = re.search(r" %s *= *(\d+)" % key, line)
        if m is None:
            print("Cannot find int %s in %s" % (key, line))
            exit(0)
        return int(m.group(1))
    
    
    def getFloatRange(line,key):
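        """ Extract the '<mean> +- <error>' float pair for key from a log line.
        """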
        m = re.search(r" %s *= *(-?\d+\.\d+) *\+- *(\d+\.\d+)" % key, line)
        if m is None:
            print("Cannot find float %s in %s" % (key, line))
            exit(0)
        return (float(m.group(1)), float(m.group(2)))
    
    
    def getDomainFromLog(l):
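        """ Return the comma-separated list of domains from a 'List of domains:' log line.
        """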
        return l.split()[-1].split(',')
    
    
    def extractEvalData(lines):
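        """ Parse log lines into a per-domain dictionary of results, keyed by training iteration
            (for train logs) or error rate (for eval logs), holding (mean, error) pairs for
            reward, success and turns plus the number of dialogs.
        """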
        evalData = {}
        training = False
        domain_list = []
        #domain_list = []#['SFRestaurants','SFHotels','Laptops11']
        #for dom in domain_list:
        #    evalData[dom] = {}
        cur_domain = None
        for l in lines:
            if l.find('List of domains:') >= 0:
                # get the list of domains from the log by reading the lines where the ontologies are loaded
                doms = getDomainFromLog(l)
                for domain in doms:
                    if domain not in domain_list:
                        domain_list.append(domain)
                        evalData[domain] = {}
            if l.find('*** Training Iteration') >= 0:
                iteration = getIntParam(l, 'iter')+1
                if iteration in list(evalData.keys()):
                    print("Duplicate iteration %d" % iteration)
                    exit(0)
                for domain in domain_list:
                    evalData[domain][iteration] = {}
                    evalData[domain][iteration]['erate'] = getIntParam(l, 'error-rate')
                    evalData[domain][iteration]['ndialogs'] = getIntParam(l, 'num-dialogs')
                training = True
            elif l.find('*** Evaluating')>=0 and not training:
                l = l.replace('CR', 'CamRestaurants') 
                erate = getIntParam(l, 'error-rate')
                ll = l[l.find('*** Evaluating') + len('*** Evaluating')+1:]
                (ll,x) = ll.split(':')
                for domain in domain_list:
                    if domain in ll:
                        evalData[domain][erate] = {}
                        evalData[domain][erate]['policy'] = ll
                        evalData[domain][erate]['ndialogs'] = getIntParam(l, 'num-dialogs')
            elif l.find('Results for domain:') >= 0:
                cur_domain = l.split('Results for domain:')[1].split('--')[0].strip()
            elif l.find('Average reward') >= 0:
                if training:
                    evalData[cur_domain][iteration]['reward'] = getFloatRange(l, 'Average reward')
                else:
                    evalData[cur_domain][erate]['reward'] = getFloatRange(l, 'Average reward')
            elif l.find('Average success') >= 0:
                if training:
                    evalData[cur_domain][iteration]['success'] = getFloatRange(l, 'Average success')
                else:
                    evalData[cur_domain][erate]['success'] = getFloatRange(l, 'Average success')
    
            elif l.find('Average turns') >= 0:
                if training:
                    evalData[cur_domain][iteration]['turns'] = getFloatRange(l, 'Average turns')
                else:
                    evalData[cur_domain][erate]['turns'] = getFloatRange(l, 'Average turns')
        return evalData
    
    
    def plotTrain(dname, rtab, stab, block=True, saveplot=False):
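        """ Plot reward (top) and success (bottom) against the cumulative number of training
            dialogs for each policy in rtab/stab; if saveplot is set, the figure is written to
            _plots/<dname>.png instead of being shown.
        """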
        font = {
                'weight': 'bold',
                'size': 20}
        plt.rc('font', **font)
    
        global gplotnum
        policylist = sorted(rtab.keys())
        ncurves = len(policylist)
        plt.figure(gplotnum)
    
        gplotnum += 1
        for policy in policylist:
            tab = rtab[policy]
            plt.subplot(211)
            # plt.xlim((800, 4200))
            if len(tab['x']) < 2:
                plt.axhline(y=tab['y'][0], linestyle='--')
            else:
                plt.errorbar(tab['x'],tab['y'], yerr=tab['var'], label=policy)
                # plt.errorbar(tab['x'], tab['y'], label=policy)
            tab = stab[policy]
            plt.subplot(212)
            # plt.xlim((800, 4200))
            if len(tab['x']) < 2:
                plt.axhline(y=tab['y'][0], linestyle='--')
            else:
                plt.errorbar(tab['x'],tab['y'],yerr=tab['var'],label=policy)
                # plt.errorbar(tab['x'], tab['y'], label=policy)
        plt.subplot(211)
        plt.grid()
        plt.legend(loc='lower right', fontsize=14)  # loc='lower right', best,
        plt.title("Performance vs Num Train Dialogues")
        plt.ylabel('Reward')
        plt.subplot(212)
        plt.grid()
        plt.legend(loc='lower right', fontsize=14)
        plt.xlabel('Num Dialogues')
        plt.ylabel('Success')
        if saveplot:
            if not os.path.exists('_plots'):
                os.mkdir('_plots')
            plt.savefig('_plots/' + dname + '.png', bbox_inches='tight')
            print('plot saved as', dname)
        else:
            plt.show(block=block)
    
    
    def plotTest(dname, rtab, stab, block=True, saveplot=False):
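        """ Plot reward (top) and success (bottom) against error rate for each policy in
            rtab/stab.
        """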
        global gplotnum
        policylist = sorted(rtab.keys())
        ncurves = len(policylist)
        plt.figure(gplotnum)
        gplotnum += 1
        for policy in policylist:
            tab = rtab[policy]
            plt.subplot(211)
            plt.errorbar(tab['x'], tab['y'], yerr=tab['var'], label=policy)
            tab = stab[policy]
            plt.subplot(212)
            plt.errorbar(tab['x'], tab['y'], yerr=tab['var'], label=policy)
        plt.subplot(211)
        plt.grid()
        plt.legend(loc='lower left', fontsize=12-ncurves)
        plt.title(dname+" Performance vs Error Rate")
        plt.ylabel('Reward')
        plt.subplot(212)
        plt.grid()
        plt.legend(loc='lower left', fontsize=12-ncurves)
        plt.xlabel('Error Rate')
        plt.ylabel('Success')
        if saveplot:
            # save the figure like plotTrain does
            if not os.path.exists('_plots'):
                os.mkdir('_plots')
            plt.savefig('_plots/' + dname + '.png', bbox_inches='tight')
            print('plot saved as', dname)
        else:
            plt.show(block=block)
    
    
    def printTable(title, tab):
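        """ Print a table of 'y +- var' values per policy, with the x values as the header row.
        """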
        firstrow = True
        policylist = sorted(tab.keys())
        for policy in policylist:
            xvals = tab[policy]['x']
            if firstrow:
                s = "%-20s" % title
                for i in range(0, len(xvals)): s += "%13d" % xvals[i]
                print(s)
                firstrow = False
            s = "%-18s :" % policy
            for i in range(0,len(xvals)):
                s+= "%6.1f +-%4.1f" % (tab[policy]['y'][i],tab[policy]['var'][i])
            print(s)
        print("")
    
    
    def tabulateTrain(dataSet):
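        """ Convert extracted training data into reward/success/turns tables indexed by the
            cumulative number of training dialogs, averaging results over seeds where the policy
            names carry a 'seedN-' tag.  Returns [rtab, stab, ttab].
        """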
        #pp.pprint(dataSet)
        rtab = {}
        stab = {}
        ttab = {}
        oldx = []
        for policy in list(dataSet.keys()):
            yvals = []
            xvals = []
            dialogsum = 0
            for iteration in list(dataSet[policy].keys()):
                d = dataSet[policy][iteration]
                (yr, yrv) = d['reward']
                (ys, ysv) = d['success']
                (yt, ytv) = d['turns']
                ndialogs = d['ndialogs']
                dialogsum += ndialogs
                yvals.append((yr, yrv, ys, ysv, yt, ytv))
                xvals.append(dialogsum)
            yvals = [yy for (xx, yy) in sorted(zip(xvals, yvals))]
            x = [xx for (xx, yy) in sorted(zip(xvals, yvals))]
            #if oldx != [] and oldx != x:
            #    print "Policy %s has inconsistent batch sizes" % policy
            oldx = x
            yrew = [yr for (yr, yrv, ys, ysv, yt, ytv) in yvals]
            yrerr = [yrv for (yr, yrv, ys, ysv, yt, ytv) in yvals]
            ysucc = [ys for (yr, yrv, ys, ysv, yt, ytv) in yvals]
            yserr = [ysv for (yr, yrv, ys, ysv, yt, ytv) in yvals]
            yturn = [yt for (yr, yrv, ys, ysv, yt, ytv) in yvals]
            yterr = [ytv for (yr, yrv, ys, ysv, yt, ytv) in yvals]
            if not (policy in list(rtab.keys())): rtab[policy] = {}
            rtab[policy]['y'] = yrew
            rtab[policy]['var'] = yrerr
            rtab[policy]['x'] = x
            if not (policy in list(stab.keys())): stab[policy] = {}
            stab[policy]['y'] = ysucc
            stab[policy]['var'] = yserr
            stab[policy]['x'] = x
            if not (policy in list(ttab.keys())): ttab[policy] = {}
            ttab[policy]['y'] = yturn
            ttab[policy]['var'] = yterr
            ttab[policy]['x'] = x
        # average results over seeds
        averaged_result_list = []
        for result in [rtab, stab, ttab]:
            averaged_result = {}
            n_seeds = {}
            for policy_key in result:
                if "seed" in policy_key:
                    seed_n = policy_key[policy_key.find("seed"):]
                    seed_n = seed_n.split('-')[0]
                    general_policy_key = policy_key.replace(seed_n + '-', '')
                else:
                    general_policy_key = policy_key
                if not general_policy_key in averaged_result:
                    averaged_result[general_policy_key] = {}
                    n_seeds[general_policy_key] = 1
                else:
                    n_seeds[general_policy_key] += 1
                for key in result[policy_key]:
                    if not key in averaged_result[general_policy_key]:
                        averaged_result[general_policy_key][key] = np.array(result[policy_key][key])
                    else:
                        averaged_result[general_policy_key][key] += np.array(result[policy_key][key])
            for policy_key in averaged_result:
                for key in averaged_result[policy_key]:
                    averaged_result[policy_key][key] = averaged_result[policy_key][key]/n_seeds[policy_key]
            averaged_result_list.append(averaged_result)
    
        return averaged_result_list
    
    
    def tabulateTest(dataSet):
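        """ Convert extracted evaluation data into reward/success/turns tables indexed by error
            rate; all policies must cover the same range of error rates.
        """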
        #pp.pprint(dataSet)
        rtab = {}
        stab = {}
        ttab = {}
        oldx = []
        for policy in list(dataSet.keys()):
            yvals = []
            xvals = []
            for erate in list(dataSet[policy].keys()):
                d = dataSet[policy][erate]
                (yr,yrv) = d['reward']
                (ys,ysv) = d['success']
                (yt,ytv) = d['turns']
                yvals.append((yr, yrv, ys, ysv, yt, ytv))
                xvals.append(erate)
            yvals = [yy for (xx, yy) in sorted(zip(xvals, yvals))]
            x = [xx for (xx,yy) in sorted(zip(xvals, yvals))]
            if oldx != [] and oldx != x:
                print("Policy %s has inconsistent range of error rates" % policy)
                exit(0)
            oldx = x
            yrew = [yr for (yr,yrv,ys,ysv,yt,ytv) in yvals]
            yrerr = [yrv for (yr,yrv,ys,ysv,yt,ytv) in yvals]
            ysucc = [ys for (yr,yrv,ys,ysv,yt,ytv) in yvals]
            yserr = [ysv for (yr,yrv,ys,ysv,yt,ytv) in yvals]
            yturn = [yt for (yr,yrv,ys,ysv,yt,ytv) in yvals]
            yterr = [ytv for (yr,yrv,ys,ysv,yt,ytv) in yvals]
            if not (policy in list(rtab.keys())): rtab[policy]={}
            rtab[policy]['y'] = yrew
            rtab[policy]['var'] = yrerr
            rtab[policy]['x'] = x
            if not (policy in list(stab.keys())): stab[policy] = {}
            stab[policy]['y'] = ysucc
            stab[policy]['var'] = yserr
            stab[policy]['x'] = x
            if not (policy in list(ttab.keys())): ttab[policy] = {}
            ttab[policy]['y'] = yturn
            ttab[policy]['var'] = yterr
            ttab[policy]['x'] = x
        return rtab, stab, ttab
    
    
    def train_command(configfile, seed=None, trainerrorrate=None,trainsourceiteration=None,
                      numtrainbatches=None, traindialogsperbatch=None, traindomains=None, dbprefix=None):
        """ Train a policy according to the supplied configfile.
            Results are stored in the directories specified in the [exec_config] section of the config file.
            Optional parameters over-ride the corresponding config parameters of the same name.
        """
    
        try:
            if seed and seed.startswith('('):
                seeds = seed.replace('(', '').replace(')', '').split(',')
                if len(seeds) == 2 and int(seeds[0]) < int(seeds[1]):
                    seeds = [str(x) for x in range(int(seeds[0]), 1+int(seeds[1]))]
                for seed in seeds:
                    print('*** Seed {} ***'.format(seed))
                    train_command(configfile, seed=seed, trainerrorrate=trainerrorrate,
                                  trainsourceiteration=trainsourceiteration,
                                  numtrainbatches=numtrainbatches, traindialogsperbatch=traindialogsperbatch,
                                  traindomains=traindomains, dbprefix=dbprefix)
    
            else:
                configId = getConfigId(configfile)
                if seed:
                    seed = int(seed)
                initialise(configId,configfile,seed,"train",trainerrorrate=trainerrorrate,
                           trainsourceiteration=trainsourceiteration,numtrainbatches=numtrainbatches,
                           traindialogsperbatch=traindialogsperbatch,traindomains=traindomains,dbprefix=dbprefix)
                for i in range(gtrainsourceiteration, gtrainsourceiteration+gnumtrainbatches):
                    Settings.global_numiter = i + 1
                    if isSingleDomain:
                        logger.results('List of domains: {}'.format(domain))
                        trainBatch(domain, configId, gtrainerrorrate, gtraindialogsperbatch, i, seed=seed)
                    else:
                        logger.results('List of domains: {}'.format(','.join(domains)))
                        trainBatch(domains, configId, gtrainerrorrate, gtraindialogsperbatch, i, seed=seed)
                    if gtesteverybatch and gnumbatchtestdialogs>0 and i+1 < gtrainsourceiteration+gnumtrainbatches:
                        if isSingleDomain:
                            evalPolicy(domain, configId, gtrainerrorrate, gnumbatchtestdialogs, i + 1, seed=seed)
                        else:
                            evalPolicy(domains, configId, gtrainerrorrate, gnumbatchtestdialogs, i + 1, seed=seed)
                if gnumbatchtestdialogs > 0:
                    if isSingleDomain:
                        logger.results('List of domains: {}'.format(domain))
                        evalPolicy(domain, configId, gtrainerrorrate, gnumbatchtestdialogs, i + 1, seed=seed)
                    else:
                        logger.results('List of domains: {}'.format(','.join(domains)))
                        evalPolicy(domains, configId, gtrainerrorrate, gnumbatchtestdialogs, i + 1, seed=seed)
    
                logger.results("*** Training complete. log: %s - final policy is %s-%02d-%02d" % (logfile, configId, gtrainerrorrate, i+1))
        except clog.ExceptionRaisedByLogger:
            print("Command Aborted - see Log file for error:", logfile)
            exit(0)
        except KeyboardInterrupt:
            print("\nCommand Aborted from Keyboard")
    
    
    def test_command(configfile, iteration, seed=None, testerrorrate=None, trainerrorrate=None,
                     numtestdialogs=None, testdomains=None, dbprefix=None, inputpolicy=None):
        """ Test a specific policy iteration trained at a specific error rate according to the supplied configfile.
            Results are embedded in the logfile specified in the config file.
            Optional parameters over-ride the corresponding config parameters of the same name.
            The testerrorrate can also be specified as a triple (e1,e2,stepsize).  This will repeat the test
            over a range of error rates from e1 to e2.  NB the tuple must be quoted on the command line.
        """
        try:
            if seed and seed.startswith('('):
                seeds = seed.replace('(','').replace(')','').split(',')
                if len(seeds) == 2 and int(seeds[0]) < int(seeds[1]):
                    seeds = [str(x) for x in range(int(seeds[0]), 1 + int(seeds[1]))]
                for seed in seeds:
                    print('*** Seed {} ***'.format(seed))
                    test_command(configfile, iteration, seed=seed, testerrorrate=testerrorrate,
                                  trainerrorrate=trainerrorrate, numtestdialogs=numtestdialogs, testdomains=testdomains,
                                  dbprefix=dbprefix, inputpolicy=inputpolicy)
    
            else:
                errStepping = False
                enErr = None
                if testerrorrate and testerrorrate[0] == '(':
                    if testerrorrate[-1] != ')':
                        print("Missing closing parenthesis in error range %s" % testerrorrate)
                        exit(0)
                    errRange = eval(testerrorrate)
                    if len(errRange) != 3:
                        print("Ill-formed error range %s" % testerrorrate)
                        exit(0)
                    (stErr, enErr, stepErr) = errRange
                    if enErr < stErr or stepErr <= 0:
                        print("Ill-formed test error range [%d,%d,%d]" % testerrorrate)
                        exit(0)
                    errStepping = True
                    testerrorrate = stErr
                if iteration.isdigit():
                    i = int(iteration)
                else:
                    i = iteration
                #if i < 1:
                #    print 'iteration must be > 0'
                #    exit(0)
                configId = getConfigId(configfile)
                orig_seed = '0'
                if seed:
                    orig_seed = seed
                    seed = int(seed) + 100  # To have a different seed during training and testing
                initialise(configId, configfile, seed, "eval", iteration=i, testerrorrate=testerrorrate,
                           testenderrorrate=enErr, trainerrorrate=trainerrorrate,
                           numtestdialogs=numtestdialogs,testdomains=testdomains, dbprefix=dbprefix)
                if type(i) == str:
                    policyname = i
                    if not 'seed' in policyname:
                        ps= policyname.split('-')
                        policyname = '-'.join(ps[:-1] + ['seed{}'.format(orig_seed)] + [ps[-1]])
                else:
                    policyname = "%s-%02d.%d" % (configId, gtrainerrorrate, i)
                poldirpath = path(policy_dir)
                if poldirpath.isdir():
                    policyfiles = poldirpath.files()
                    policynamelist = [p.namebase for p in policyfiles]
                    if isSingleDomain:
                        logger.results('List of domains: {}'.format(domain))
                        if policyname in policynamelist:
                            if errStepping:
                                while stErr <= enErr:
                                    evalPolicy(domain, configId, stErr, gnumtestdialogs, i, seed=seed)
                                    stErr += stepErr
                            else:
                                evalPolicy(domain, configId, gtesterrorrate, gnumtestdialogs, i, seed=seed)
                            logger.results("*** Testing complete. logfile: %s - policy %s evaluated" % (logfile, policyname))
                        else:
                            print("Cannot find policy iteration %s in %s" % (policyname, policy_dir))
                    else:
                        allPolicyFiles = True
                        logger.results('List of domains: {}'.format(','.join(domains)))
                        for dom in domains:
                            multi_policyname = dom+policyname
                            if isSingleModel:
                                multi_policyname = 'singlemodel'+policyname
                            if not multi_policyname in policynamelist:
                                print("Cannot find policy iteration %s in %s" % (multi_policyname, policy_dir))
                                allPolicyFiles = False
                        if allPolicyFiles:
                            if errStepping:
                                while stErr <= enErr:
                                    evalPolicy(domain, configId, stErr, gnumtestdialogs, i)
                                    stErr += stepErr
                            else:
                                evalPolicy(domain, configId, gtesterrorrate, gnumtestdialogs, i)
                            logger.results("*** Testing complete. logfile: %s - policy %s evaluated" % (logfile, policyname))
                else:
                    print("Policy folder %s does not exist" % policy_dir)
        except clog.ExceptionRaisedByLogger:
            print("Command Aborted - see Log file for error:", logfile)
            exit(0)
        except KeyboardInterrupt:
            print("\nCommand Aborted from Keyboard")
    
    
    def plotTrainLogs(logfilelist, printtab, noplot, saveplot, datasetname, block=True):
        """
            Extract data from given log files and display.
        """
        try:
            resultset = {}
            ncurves = 0
            domains = None
    
            if len(logfilelist) < 1:
                print("No log files specified")
                exit(0)
            for fname in logfilelist:
                fn = open(fname, "r")
                if fn:
                    logName = path(fname).namebase
                    if 'epsil0.' in logName:
                        logName = logName.replace('epsil0.', 'epsil0')
                    i = logName.find('.')
                    if i < 0:
                        print("No index info in train log file name")
                        exit(0)
                    curveName = logName[:i]
                    if datasetname == '':
                        i = curveName.find('-')
                        if i >= 0:
                            datasetname=curveName[:i]
                    lines = fn.read().splitlines()
                    results = extractEvalData(lines)
                    npoints = len(results[list(results.keys())[0]])
                    if npoints == 0:
                        print("Log file %s has no plotable data" % fname)
                    else:
                        if len(resultset) == 0:
                            # the list of domains needs to be read from the logfile
                            domains = list(results.keys())
                            for domain in domains:
                                resultset[domain] = {}
                        else:
                            domains_1 = sorted(resultset.keys())
                            domains_2 = sorted(results.keys())
                            assert domains_1 == domains_2, 'The logfiles have different domains'
                        ncurves += 1
                        for domain in domains:
                            if curveName in list(resultset[domain].keys()):
                                curve = resultset[domain][curveName]
                                for iteration in list(results[domain].keys()):
                                    curve[iteration] = results[domain][iteration]
                            else:
                                resultset[domain][curveName] = results[domain]
                else:
                    print(("Cannot find logfile %s" % fname))
            if ncurves > 0:
                average_results = [[], [], []]
                for domain in domains:
                    (rtab, stab, ttab) = tabulateTrain(resultset[domain])
                    average_results[0].append(rtab)
                    average_results[1].append(stab)
                    average_results[2].append(ttab)
                    if printtab:
                        print("\n%s-%s: Performance vs Num Dialogs\n" % (datasetname, domain))
                        printTable('Reward', rtab)
                        printTable('Success', stab)
                        printTable('Turns', ttab)
                        '''for key in rtab:
                            print key
                            if len(stab[key]['y']) == 1:
                                print '1K', stab[key]['y'][0], rtab[key]['y'][0]
                            else:
                                print '1K', stab[key]['y'][5], rtab[key]['y'][5]
                                print '4K', stab[key]['y'][-1], rtab[key]['y'][-1]'''
                        #print rtab
                        #print stab
    
                    if not noplot:
                        plotTrain(datasetname+'-'+domain,rtab,stab,block=block,saveplot=saveplot)
                # Print average for all domains
                if len(domains) > 1:
                    av_rtab, av_stab, av_ttab = getAverageResults(average_results)
                    plotTrain(datasetname+'-mean', av_rtab, av_stab, block=block,saveplot=saveplot)
            else:
                print("No plotable train data found")
        except clog.ExceptionRaisedByLogger:
            print("Command Aborted - see Log file for error:")
    
    
    def getAverageResults(average_result_list):
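        """ Average the per-domain [rtab, stab, ttab] tables into one table per metric; the
            'var' entries are combined via their square roots and squared again after averaging.
        """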
        averaged_results = []
        for tab_list in average_result_list:
            n_domains = len(tab_list)
            tab_av_results = {}
            for domain_rtab in tab_list:
                for policy_key in domain_rtab:
                    if not policy_key in tab_av_results:
                        tab_av_results[policy_key] = {}
                    for key in domain_rtab[policy_key]:
                        if not key in tab_av_results[policy_key]:
                            if key == 'var':
                                tab_av_results[policy_key][key] = np.sqrt(np.array(domain_rtab[policy_key][key]))
                            else:
                                tab_av_results[policy_key][key] = np.array(domain_rtab[policy_key][key])
                        else:
                            if key == 'var':
                                tab_av_results[policy_key][key] += np.sqrt(np.array(domain_rtab[policy_key][key]))
                            else:
                                tab_av_results[policy_key][key] += np.array(domain_rtab[policy_key][key])
            #normalise
            for policy_key in tab_av_results:
                for key in tab_av_results[policy_key]:
                    tab_av_results[policy_key][key] /= n_domains
                    if key == 'var':
                        tab_av_results[policy_key][key] = np.square(tab_av_results[policy_key][key])
            averaged_results.append(tab_av_results)
        return averaged_results
    
    
    def plotTestLogs(logfilelist, printtab, noplot, saveplot, datasetname, block=True):
        """
            Extract data from given eval log files and display performance
            as a function of error rate
        """
        try:
            resultset = {}
            domains = None
            for fname in logfilelist:
                fn = open(fname,"r")
                if fn:
                    lines = fn.read().splitlines()
                    results = extractEvalData(lines)
                    if results:
                        domains = list(results.keys())
                        for domain in domains:
                            if not domain in list(resultset.keys()): 
                                resultset[domain] = {}
                            akey = list(results[domain].keys())[0]
                            aresult = results[domain][akey]
                            if 'policy' in list(aresult.keys()):
                                policyname = results[domain][akey]['policy']
                                if datasetname == '':
                                    i = policyname.find('-')
                                    if i >= 0:
                                        datasetname=policyname[:i]
                                if not policyname in resultset[domain]: resultset[domain][policyname]={}
                                for erate in list(results[domain].keys()):
                                    resultset[domain][policyname][erate] = results[domain][erate]
                            else:
                                print('Format error in log file',fname)
                                exit(0)
                else:
                    print("Cannot find logfile %s" % fname)
                    exit(0)
            for domain in domains:
                if len(list(resultset[domain].keys()))>0:
                    (rtab,stab,ttab) = tabulateTest(resultset[domain])
                    if printtab:
                        print("\n%s-%s: Performance vs Error Rate\n" % (datasetname, domain))
                        printTable('Reward', rtab)
                        printTable('Success', stab)
                        printTable('Turns', ttab)
                    if not noplot:
                        plotTest('%s-%s' % (datasetname, domain), rtab, stab, block=block, saveplot=saveplot)
                else:
                    print("No data found")
        except clog.ExceptionRaisedByLogger:
            print("Command Aborted - see Log file for error:")
    
    
    @command.fetch_all('args')
    def plot_command(args="", printtab=False, noplot=False, saveplot=False, datasetname=''):
        """ Call plot with a list of log files and it will print train and test curves.
            For train log files it plots performance vs num dialogs.
            For test log files it plots performance vs error rate.
            Set the printtab option to print a table of results.
        A name can be given to the plot via the datasetname option.
        """
        trainlogs = []
        testlogs = []
        for fname in args:
            if fname.find('train') >= 0:
                trainlogs.append(fname)
            elif fname.find('eval') >= 0:
                testlogs.append(fname)
        block = True
        # if testlogs: block = False
        if noplot: printtab = True    # otherwise no point!
        if trainlogs:
            plotTrainLogs(trainlogs, printtab, noplot, saveplot, datasetname, block)
        if testlogs:
            plotTestLogs(testlogs, printtab, noplot, saveplot, datasetname)
    
    
    def chat_command(configfile, seed=None, trainerrorrate=None, trainsourceiteration=None):
            """ Run the texthub according to the supplied configfile.
            """
            try:
                configId = getConfigId(configfile)
                initialise(configId, configfile, seed, "chat", trainerrorrate=trainerrorrate,
                           trainsourceiteration=trainsourceiteration)
                for dom in domains:
                    if policytypes[dom] != 'hdc':
                        setupPolicy(dom, configId, gtrainerrorrate,
                                    gtrainsourceiteration, gtrainsourceiteration)
                        # make sure that learning is switched off
                        if Settings.config.has_section("policy_" + dom):
                            Settings.config.set("policy_" + dom, "learning", 'False')
                        else:
                            Settings.config.set("policy", "learning", 'False')
                        # if gp, make sure to reset scale to 1 for evaluation
                        if policytypes[dom] == "gp":
                            if Settings.config.has_section("gpsarsa_" + dom):
                                Settings.config.set("gpsarsa_" + dom, "scale", "1")
                            else:
                                Settings.config.set("gpsarsa", "scale", "1")
                mess = "*** Chatting with policies %s: ***" % str(domains)
                if tracedialog > 0: print(mess)
                logger.dial(mess)
    
                # create text hub and run it
                hub = Texthub.ConsoleHub()
                hub.run()
                logger.dial("*** Chat complete")
                # Save a copy of config file
                confsavefile = conf_dir + configId + ".chat.cfg"
                cf = open(confsavefile, 'w')
                Settings.config.write(cf)
            except clog.ExceptionRaisedByLogger:
                print("Command Aborted - see Log file for error:", logfile)
                exit(0)
            except KeyboardInterrupt:
                print("\nCommand Aborted from Keyboard")
    
    
    # class addInfo(object):
    #     def __init__(self, gbatchnum):
    #         self.batch_number = gbatchnum
    
    run()