#! /usr/bin/env python

###############################################################################
# PyDial: Multi-domain Statistical Spoken Dialogue System Software
###############################################################################
#
# Copyright 2015 - 2019
# Cambridge University Engineering Department Dialogue Systems Group
#
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################


import os
from scriptine import run, path, log, command
import re
import numpy as np

# Uncomment for mac os users
import matplotlib
# matplotlib.use('TkAgg')
# matplotlib.use('Agg')

import matplotlib.pyplot as plt

#Uncomment for 4k screens
# matplotlib.rcParams.update({'font.size': 22})

# PyDial modules
import Simulate
import Texthub
from utils import Settings
from utils import ContextLogger
from ontology import Ontology
import utils.ContextLogger as clog
import pprint
pp = pprint.PrettyPrinter(indent=4)

# Remove tensorflow deprecation warnings
#import tensorflow.python.util.deprecation as deprecation
#deprecation._PRINT_DEPRECATION_WARNINGS = False

# ---------------------------------------------------------------------------
# Module-level state shared by the functions below.  Most of these names are
# declared `global` and (re)assigned inside initialise().
# ---------------------------------------------------------------------------

# Logger handle; stays None until initialise() creates the logging handlers.
logger = None
# Trace level: passed as the fallback when initialise() reads the
# [GENERAL] tracedialog option, and tested (> 0) by the train/eval helpers
# before echoing their banner message to stdout.
tracedialog = 2
# Directories and log file name resolved by initialise().
policy_dir = ""
conf_dir = ""
log_dir = ""
logfile = ""

# Run parameters; initialise() fills them from its arguments or from the
# [exec_config] section of the config file.
gnumtrainbatches = 0
gtraindialogsperbatch = 0
gnumbatchtestdialogs = 0
gnumtestdialogs = 0
gtrainerrorrate = 0
gtesterrorrate = 0
gtrainsourceiteration = 0
gtesteverybatch = False

# GP 'scale' option saved by initialise() for single-domain gp policies so
# that trainBatch() can write it back before each training batch.
gpscale = 1

# Number of the next matplotlib figure; incremented by plotTrain()/plotTest().
gplotnum = 1

# NOTE(review): not referenced in the part of the file visible here.
gbatchnum = 0

# Domain / policy selection resolved by initialise().
isSingleDomain = False
taskID = ""
domain = ""
domains = []
# Fallback policy type used when the config does not specify one.
policytype = "hdc"

# Per-domain policy types, built by initialise() in multi-domain mode.
policytypes = {}


def help_command():
    """ Provide an overview of pydial functionality
    """
    # Each entry is written with one print() call, so embedded "\n" characters
    # produce the blank lines that separate the sections of the overview.
    overview = (
        "\n pydial - command line interface to PyDial",
        '""""""""""""""""""""""""""""""""""""""""""""',
        ' o Runs simulator to train and test policies',
        ' o Plots learning rates and performance vs error rate',
        ' o Runs texthub in multi-domain chat mode\n',
        'Basic usage:',
        '  a) Make pydial.py executable and add a symbolic link to it (eg pydial) from your',
        '     local bin directory.  Create a directory called ID and cd into it.\n',
        "  b) create a config file and add an exec_config section eg:\n",
        '     [exec_config]',
        '     domain = CamRestaurants     # specific train/test domain',
        '     policytype = gp             # type of policy to train/test',
        '     configdir = cfgdir          # folder to store configs',
        '     logfiledir = logdir         # folder to store logfiles',
        '     numtrainbatches = 2         # num training batches (iterations)',
        '     traindialogsperbatch = 10   # num dialogs per batch',
        '     numbatchtestdialogs =  20   # num dialogs to eval each batch',
        '     trainsourceiteration = 0    # index of initial source policy to update',
        '     testiteration = 1           # policy iteration to test',
        '     numtestdialogs =  10        # num dialogs per test',
        '     trainerrorrate = 0          # train error rate in %',
        '     testerrorrate  = 0          # test error rate in %',
        '     testeverybatch = True       # enable batch testing\n',
        '     by convention the config file name for training and testing should be of the',
        '     form ID-policytype-domain.cfg where ID is a user-defined id.',
        '     (There is more detail on naming conventions below.)',
        '     Also unless the current directory is the same as the PyDial root',
        '     make sure that [GENERAL]root points to root of the PyDial source tree.\n',
        '  c) to train a policy as specified in the config file, type',
        '       > pydial train config',
        '     if trainsourceiteration=0 this creates a new policy in n batches where',
        '     n=numtrainbatches, otherwise an existing policy is trained further.\n',
        '  d) to test a policy as specified in the config file, type',
        '       > pydial test config\n',
        '     texthub.py can be invoked to interact with a policy from the keyboard by:',
        '       > pydial chat config',
        '     Note that train and test must refer to a specific domain as per [exec_config] domain',
        '     whereas chat mode can specify multiple domains via the [GENERAL]domains variable.\n',
        '  e) for convenience, many config parameters can be overridden on the command line, eg',
        '       > pydial train config --trainerrorrate=20',
        '       > pydial test config --iteration=4 --trainerrorrate=20 --testerrorrate=50',
        '     to train a policy at 20% error rate and test the 4th iteration at 50% error rate.',
        '     A range of test error rates can be specified as a triple (stErr,enErr,stepSize), eg',
        "       > pydial test config --iteration=4 --trainerrorrate=20 --testerrorrate='(0,50,10)'",
        '     to test a policy at 0%, 10%, 20%, 30%, 40%, and 50% error rates.\n',
        '  f) logfiles for each train/test run are stored in logfiledir.',
        '     The plot command scans one or more logfiles and extract info to plot eg',
        '       > pydial plot logdir/*train*',
        '     Setting the option --printtab, also tabulates the performance data.\n',
        '  All policy information is stored in the policydir specified in the corresponding ',
        '  config file section with name [policy_domain]. Since pydial overrides some config',
        '  params, the actual configs used for each run are recorded in configdir.\n',
        '  Derived file naming convention:',
        "     Policyname: ID-poltype-domain-TrainErrRate               eg S0-gp-CamRestaurants-20",
        "     Policy: ID-poltype-domain-TrainErrRate.Iteration         eg S0-gp-CamRestaurants-20.3",
        "     Policyfile: ID-poltype-domain-TrainErrRate.Iteration.ext eg S0-gp-CamRestaurants-20.3.dct",
        "     TrainLogfiles: PolicyName.IterationRange.train.log       eg S0-gp-CamRestaurants-20.1-3.train.log",
        "     EvalLogfiles:  Policy.eval.ErrorRange.eval.log           eg S0-gp-CamRestaurants-20.3.eval.00-50.log\n",
        "To get further help:",
        "  pydial             list of available commands",
        "  pydial help        this overview",
        "  pydial cmd --help  help for a specific command\n",
    )
    for text in overview:
        print(text)


def conventionCheck(name):
    """ Check a config id against the preferred ID-policytype-domain form.

        When name splits on '-' into exactly three fields, the first field is
        stored in the global taskID.  Any deviation from the convention (no
        separator, wrong number of fields, policy type or domain differing
        from the module globals) is silently ignored: the warning that used to
        be logged here is disabled.
    """
    global taskID, domain, policytype
    try:
        if name.find('-') == -1:
            raise Exception('no separators')
        fields = name.split('-')
        # Unpacking fails (and is swallowed below) unless there are exactly
        # three fields; taskID is only updated on success.
        taskID, cfg_policytype, cfg_domain = fields
        problem = None
        if cfg_policytype != policytype:
            problem = 'policytype != config param'
        elif cfg_domain != domain:
            problem = 'domain name != config param'
        if problem is not None:
            raise Exception(problem)
    except Exception:
        # Non-standard names are tolerated without any message.
        pass


def getConfigId(configFileName):
    """ Return the config id: the file name without directory and .cfg extension.

        Prints a message and calls exit(0) when the name does not end in
        '.cfg' or when the file does not exist.
    """
    dot = configFileName.rfind('.')
    has_cfg_extension = dot >= 0 and configFileName[dot+1:] == 'cfg'
    if not has_cfg_extension:
        print("Config file %s does not have required .cfg extension" % configFileName)
        exit(0)

    if not path(configFileName).isfile():
        print("Config file %s does not exist" % configFileName)
        exit(0)

    stem = configFileName[:dot]
    # rfind gives -1 when there is no '/', so the slice then keeps the whole stem
    return stem[stem.rfind('/') + 1:]


def getOptionalConfigVar(configvarname, default='', section='exec_config'):
    """ Return option configvarname of section as read by Settings.config.get,
        or default when the option is not present.
    """
    if not Settings.config.has_option(section, configvarname):
        return default
    return Settings.config.get(section, configvarname)


def getRequiredDirectory(directoryname, section='exec_config'):
    """ Return the directory stored in option directoryname of section,
        with a trailing '/' appended if it does not already end in one.

        Fails an assertion when the option is missing.
    """
    assert Settings.config.has_option(section, directoryname),\
        "Value {} in section {} is missing.".format(directoryname, section)
    folder = Settings.config.get(section, directoryname)
    return folder if folder[-1] == '/' else folder + '/'


def getOptionalConfigInt(configvarname, default='0',section='exec_config'):
    """ Return option configvarname of section as an int, or default when the
        option is not present.

        If the stored value cannot be parsed as an int, the raw value returned
        by Settings.config.get is handed back instead.
    """
    if not Settings.config.has_option(section, configvarname):
        return default
    try:
        return Settings.config.getint(section, configvarname)
    except ValueError:
        return Settings.config.get(section, configvarname)


def getOptionalConfigBool(configvarname, default='False', section='exec_config'):
    """ Return option configvarname of section as a bool, or default when the
        option is not present.

        A string default that spells a boolean ('1'/'yes'/'true'/'on' or
        '0'/'no'/'false'/'off', case-insensitive) is converted to the
        matching bool.  Without this the signature's own default, the
        non-empty string 'False', would be truthy.  Any other default is
        returned unchanged.
    """
    if Settings.config.has_option(section, configvarname):
        return Settings.config.getboolean(section, configvarname)
    if isinstance(default, str):
        spelled = default.strip().lower()
        if spelled in ('1', 'yes', 'true', 'on'):
            return True
        if spelled in ('0', 'no', 'false', 'off'):
            return False
    return default


def _policySectionFor(dom):
    """ Return 'policy_<dom>' if the config has that section, else 'policy'. """
    section = 'policy_' + dom
    return section if Settings.config.has_section(section) else 'policy'


def _readGpScale(dom):
    """ Read the 'scale' option from 'gpsarsa_<dom>' (or 'gpsarsa' when there is no
        domain-specific section) as an int, falling back to the raw value. """
    section = "gpsarsa_" + dom
    if not Settings.config.has_section(section):
        section = "gpsarsa"
    try:
        return Settings.config.getint(section, "scale")
    except ValueError:
        return Settings.config.get(section, "scale")


def _ensureDirectory(dirname, label):
    """ Create dirname if it is not an existing directory, announcing it on stdout. """
    target = path(dirname)
    if not target.isdir():
        print("%s dir %s does not exist, creating it" % (label, dirname))
        target.mkdir()


def initialise(configId, config_file, seed, mode, trainerrorrate=None, trainsourceiteration=None,
               numtrainbatches=None, traindialogsperbatch=None, numtestdialogs=None,
               testerrorrate=None, testenderrorrate=None, iteration=None, traindomains=None, testdomains=None,
               dbprefix=None):
    """ Load the config file and set up the module globals for one run.

        configId      id used as the stem of the log file name
        config_file   config file handed to Settings.init
        seed          random seed (converted with int() when not None)
        mode          "train", "eval" or "chat"; selects the log file name format,
                      any other value prints a message and exits
        iteration     policy iteration (or, as a string, a policy name) used in
                      the "eval" log file name
        testenderrorrate  when truthy, the "eval" log file name carries an
                      error-rate range
        traindomains, testdomains, dbprefix  written into the config when truthy

        The remaining optional arguments override the [exec_config] option of
        the same name when truthy; otherwise the config value (or a built-in
        default) is used.

        Side effects: assigns the module globals listed below, creates the
        policy/config/log directories if needed, points [logging] file at the
        derived log file, (re)creates the logging handlers and seeds
        Settings.random.
    """
    global logger, logfile, tracedialog, traceDialog, isSingleDomain
    global policy_dir, conf_dir, log_dir
    global gnumtrainbatches, gtraindialogsperbatch, gnumbatchtestdialogs, gnumtestdialogs
    global gtrainerrorrate, gtesterrorrate, gtrainsourceiteration
    global taskID, domain, domains, policytype, gtesteverybatch, gpscale
    global gdeleteprevpolicy, isSingleModel
    global policytypes

    if seed is not None:
        seed = int(seed)
    seed = Settings.init(config_file, seed)
    taskID = 'ID'

    isSingleDomain = getOptionalConfigBool("singledomain", isSingleDomain, "GENERAL")
    isSingleModel = getOptionalConfigBool("singlemodel", False, "policycommittee")
    # The configured level used to be stored only in `traceDialog`, while the
    # banner printing in this module tests `tracedialog`, so the option had no
    # effect.  Both names are now set; `tracedialog` is only updated with an
    # int because it is compared with `> 0`.
    traceDialog = getOptionalConfigInt("tracedialog", tracedialog, "GENERAL")
    if isinstance(traceDialog, int):
        tracedialog = traceDialog
    domain = getOptionalConfigVar("domains", '', "GENERAL")
    if len(domain.split(',')) > 1 and isSingleDomain:
        message = 'It cannot be singledomain and have several domains defined, Check config file.'
        # The logger is only created further down, so on a first call it is
        # still None and cannot be used to report the problem.
        if logger is None:
            raise SystemExit(message)
        logger.error(message)
    if isSingleDomain:
        policytype = getOptionalConfigVar('policytype', policytype, _policySectionFor(domain))
        conventionCheck(configId)
    else:
        domains = getOptionalConfigVar("domains", "", "GENERAL").split(',')
        policytypes = {}
        # `domain` is a module global, so these loops leave it bound to the
        # last entry of `domains`.
        for domain in domains:
            policytypes[domain] = getOptionalConfigVar('policytype', policytype, _policySectionFor(domain))

    # if gp, make sure to save required scale before potentially overriding
    if isSingleDomain:
        if policytype == "gp":
            gpscale = _readGpScale(domain)
    else:
        # Read per-domain scales; the values are not used further in this function.
        gpscales = {}
        for domain in domains:
            if policytypes[domain] == "gp":
                gpscales[domain] = _readGpScale(domain)

    # if deep-rl model, make sure to set the correct n_in
    if isSingleDomain:
        if Settings.config.has_section("dqnpolicy"):
            if domain == 'CamRestaurants':
                Settings.config.set("dqnpolicy", 'n_in', '268')
            elif domain == 'SFRestaurants':
                Settings.config.set("dqnpolicy", 'n_in', '636')
            elif domain == 'Laptops11':
                Settings.config.set("dqnpolicy", 'n_in', '257')
                # TODO: set rest of environments and multidomain

    # Get required folders and create if necessary
    log_dir = getRequiredDirectory("logfiledir")
    conf_dir = getRequiredDirectory("configdir")
    if isSingleDomain:
        if policytype != 'hdc':
            policy_dir = getRequiredDirectory("policydir", _policySectionFor(domain))
            _ensureDirectory(policy_dir, "Policy")
    else:
        for domain in domains:
            if policytypes[domain] != 'hdc':
                policy_dir = getRequiredDirectory("policydir", _policySectionFor(domain))
                _ensureDirectory(policy_dir, "Policy")

    _ensureDirectory(conf_dir, "Config")
    _ensureDirectory(log_dir, "Log")

    # optional config settings: a truthy argument wins over the config file
    if numtrainbatches:
        gnumtrainbatches = int(numtrainbatches)
    else:
        gnumtrainbatches = getOptionalConfigInt("numtrainbatches", 1)
    if traindialogsperbatch:
        gtraindialogsperbatch = int(traindialogsperbatch)
    else:
        gtraindialogsperbatch = getOptionalConfigInt("traindialogsperbatch", 100)
    if trainerrorrate:
        gtrainerrorrate = int(trainerrorrate)
    else:
        gtrainerrorrate = getOptionalConfigInt("trainerrorrate", 0)
    if testerrorrate:
        gtesterrorrate = int(testerrorrate)
    else:
        gtesterrorrate = getOptionalConfigInt("testerrorrate", 0)
    if trainsourceiteration:
        gtrainsourceiteration = int(trainsourceiteration)
    else:
        gtrainsourceiteration = getOptionalConfigInt("trainsourceiteration", 0)
    if numtestdialogs:
        gnumtestdialogs = int(numtestdialogs)
    else:
        gnumtestdialogs = getOptionalConfigInt("numtestdialogs", 50)

    gnumbatchtestdialogs = getOptionalConfigInt("numbatchtestdialogs", 20)
    gtesteverybatch = getOptionalConfigBool("testeverybatch", True)
    gdeleteprevpolicy = getOptionalConfigBool("deleteprevpolicy", False)

    # Tag the log file with the seed unless the config id already mentions one;
    # seeds 100..199 are shown with 100 subtracted.
    if seed is not None and 'seed' not in configId:
        if seed >= 100 and seed < 200:
            seed_string = 'seed{}-'.format(seed - 100)
        else:
            seed_string = 'seed{}-'.format(seed)
    else:
        seed_string = ''

    if mode == "train":
        if gnumtrainbatches > 1:
            enditeration = gtrainsourceiteration + gnumtrainbatches
            logfile = "%s-%s%02d.%d-%d.train.log" % (configId, seed_string, gtrainerrorrate,
                                                     gtrainsourceiteration + 1, enditeration)
        else:
            logfile = "%s-%s%02d.%d.train.log" % (configId, seed_string, gtrainerrorrate, gtrainsourceiteration + 1)
    elif mode == "eval":
        if testenderrorrate:
            logfile = "%s-%s%02d.%d.eval.%02d-%02d.log" % (configId, seed_string, gtrainerrorrate, iteration,
                                                         gtesterrorrate, testenderrorrate)
        elif isinstance(iteration, str):
            # a policy name rather than an iteration number
            logfile = "{}_vs_{}-{}.eval.log".format(configId, iteration, seed_string[:-1])
        else:
            logfile = "%s-%s%02d.%d.eval.%02d.log" % (configId, seed_string, gtrainerrorrate, iteration, gtesterrorrate)
    elif mode == "chat":
        logfile = "%s-%s%02d.%d.chat.log" % (configId, seed_string, gtrainerrorrate, gtrainsourceiteration)
    else:
        print("Unknown initialisation mode:", mode)
        exit(0)
    print('*** logfile: {} ***'.format(logfile))
    Settings.config.set("logging", "file", log_dir + logfile)
    if traindomains:
        Settings.config.set("GENERAL", "traindomains", traindomains)
    if testdomains:
        Settings.config.set("GENERAL", "testdomains", testdomains)
    if dbprefix:
        Settings.config.set("exec_config", "dbprefix", dbprefix)
    if not Ontology.global_ontology:
        ContextLogger.createLoggingHandlers(config=Settings.config)
        logger = ContextLogger.getLogger('')
        Ontology.init_global_ontology()
    else:
        # a previous initialise() already set things up: rebuild the handlers
        # so that logging goes to the new log file
        ContextLogger.resetLoggingHandlers()
        ContextLogger.createLoggingHandlers(config=Settings.config)
        logger = ContextLogger.getLogger('')

    Settings.random.seed(int(seed))
    if Settings.root == '':
        Settings.root = os.getcwd()
    logger.info("Seed = %d", seed)
    logger.info("Root = %s", Settings.root)


def setupPolicy(domain, configId, trainerr, source_iteration, target_iteration, seed=None):
    """ Set inpolicyfile/outpolicyfile in the policy section used for domain.

        The section is 'policy_<domain>' when it exists, otherwise 'policy'.
        When source_iteration is not a string of digits it is taken as a
        ready-made policy name and used for both files; otherwise the names are
        configId[-seed<seed>]-<trainerr>.<iteration> for the source and target
        iterations.

        Returns the pair (inpolicyfile, outpolicyfile) without directory prefix.
    """
    policy_section = "policy_" + domain
    if not Settings.config.has_section(policy_section):
        policy_section = "policy"

    if not str(source_iteration).isdigit():
        inpolicyfile = outpolicyfile = source_iteration
    else:
        if seed is None:
            template, head = "%s-%02d.%d", (configId,)
        else:
            template, head = "%s-seed%s-%02d.%d", (configId, seed)
        inpolicyfile = template % (head + (trainerr, source_iteration))
        outpolicyfile = template % (head + (trainerr, target_iteration))

    if isSingleDomain:
        prefix = policy_dir
    else:
        prefix = policy_dir + domain
        domain_dir = path(prefix)
        if not domain_dir.isdir():
            print("Policy dir %s does not exist, creating it" % prefix)
            domain_dir.mkdir()
        # NOTE(review): the file name is appended to `prefix` directly, with no
        # '/' in between, so the files are not placed inside the directory
        # created above -- confirm whether this is intended.
    Settings.config.set(policy_section, "inpolicyfile", prefix + inpolicyfile)
    Settings.config.set(policy_section, "outpolicyfile", prefix + outpolicyfile)
    return (inpolicyfile, outpolicyfile)


def trainBatch(domain, configId, trainerr, ndialogs, source_iteration, seed=None):
    """ Train one batch of ndialogs simulated dialogs at error rate trainerr (%).

        In single-domain mode `domain` is a domain name; otherwise it is
        iterated over, one policy being set up per element.  The policy of
        iteration source_iteration is the starting point and the result is
        stored as iteration source_iteration+1.  The config actually used is
        saved in conf_dir as <policy>.train.cfg before the simulator is run.

        If deleteprevpolicy is enabled (single-domain only) the files of the
        source policy are removed from the policy directory afterwards, except
        for the initial policy (iteration 0).
    """
    if isSingleDomain:
        (inpolicy, outpolicy) = setupPolicy(domain, configId, trainerr, source_iteration, source_iteration + 1, seed=seed)
        mess = "*** Training Iteration %s->%s: iter=%d, error-rate=%d, num-dialogs=%d ***" % (
            inpolicy, outpolicy, source_iteration, trainerr, ndialogs)
        if tracedialog > 0: print(mess)
        logger.results(mess)
        # make sure that learning is switched on
        if Settings.config.has_section("policy_" + domain):
            Settings.config.set("policy_" + domain, "learning", 'True')
        else:
            Settings.config.set("policy", "learning", 'True')
        # if gp, make sure to reset scale to config setting
        if policytype == "gp":
            if Settings.config.has_section("gpsarsa_" + domain):
                Settings.config.set("gpsarsa_" + domain, "scale", str(gpscale))
            else:
                Settings.config.set("gpsarsa", "scale", str(gpscale))
        # Define the config file for this iteration
        confsavefile = conf_dir + outpolicy + ".train.cfg"
    else:
        mess = "*** Training Iteration: iter=%d, error-rate=%d, num-dialogs=%d ***" % (
            source_iteration, trainerr, ndialogs)
        if tracedialog > 0: print(mess)
        logger.results(mess)
        for dom in domain:
            setupPolicy(dom, configId, trainerr, source_iteration, source_iteration + 1, seed=seed)
            # make sure that learning is switched on
            if Settings.config.has_section("policy_" + dom):
                Settings.config.set("policy_" + dom, "learning", 'True')
            else:
                Settings.config.set("policy", "learning", 'True')
            # if gp, make sure to reset scale to config setting; fall back to
            # the generic section like the single-domain branch does
            if policytype == "gp":
                if Settings.config.has_section("gpsarsa_" + dom):
                    Settings.config.set("gpsarsa_" + dom, "scale", str(gpscale))
                else:
                    Settings.config.set("gpsarsa", "scale", str(gpscale))
        # Define the config file for this iteration
        multipolicy = "%s-%02d.%d" % (configId, trainerr, source_iteration + 1)
        confsavefile = conf_dir + multipolicy + ".train.cfg"

    # Save the config file for this iteration; close it before the (long)
    # simulation starts so the copy on disk is complete
    with open(confsavefile, 'w') as cf:
        Settings.config.write(cf)
    error = float(trainerr) / 100.0
    # run the system
    simulator = Simulate.SimulationSystem(error_rate=error)
    simulator.run_dialogs(ndialogs)

    # Optionally remove the files of the policy we started from.  Iteration 0
    # is the initial policy and is always kept.
    if gdeleteprevpolicy and isSingleDomain and source_iteration != 0:
        if Settings.config.has_section("policy_" + domain):
            policydir = Settings.config.get('policy_{}'.format(domain), 'policydir')
        else:
            policydir = Settings.config.get('policy', 'policydir')
        # Match the policy name literally, and not when it is followed by a
        # further digit: "...00.1" must not match the files of "...00.10".
        stale = re.compile(re.escape(inpolicy) + r'(?!\d)')
        for f in os.listdir(policydir):
            if stale.search(f):
                os.remove(os.path.join(policydir, f))


def setEvalConfig(domain, configId, evalerr, ndialogs, iteration, seed=None):
    """ Prepare the config for evaluating policy `iteration` of `domain`.

        Points the policy section at the policy files (via setupPolicy, using
        the global training error rate), switches learning off, sets the GP
        scale to 1 for gp policies, reports the evaluation banner and saves a
        copy of the resulting config in conf_dir as
        <policy>.eval.<evalerr>.cfg.
    """
    (_, policy) = setupPolicy(domain, configId, gtrainerrorrate, iteration, iteration, seed=seed)
    if isSingleDomain:
        mess = "*** Evaluating %s: error-rate=%d num-dialogs=%d ***" % (policy, evalerr, ndialogs)
    else:
        mess = "*** Evaluating %s: error-rate=%d num-dialogs=%d ***" % (policy.replace('Multidomain', domain),
                                                                        evalerr, ndialogs)
    if tracedialog > 0: print(mess)
    logger.results(mess)
    # make sure that learning is switched off
    if Settings.config.has_section("policy_" + domain):
        Settings.config.set("policy_" + domain, "learning", 'False')
    else:
        Settings.config.set("policy", "learning", 'False')
    # if gp, make sure to reset scale to 1 for evaluation
    if policytype == "gp":
        if Settings.config.has_section("gpsarsa_" + domain):
            Settings.config.set("gpsarsa_" + domain, "scale", "1")
        else:
            Settings.config.set("gpsarsa", "scale", "1")
    # Save a copy of config file; the with-block makes sure it is flushed and closed
    confsavefile = conf_dir + "%s.eval.%02d.cfg" % (policy, evalerr)
    with open(confsavefile, 'w') as cf:
        Settings.config.write(cf)


def evalPolicy(domain, configId, evalerr, ndialogs, iteration, seed=None):
    """ Evaluate policy `iteration` over ndialogs simulated dialogs at error
        rate evalerr (%).

        In single-domain mode the config is prepared for `domain`; otherwise
        it is prepared for every entry of the global `domains` list.
    """
    targets = [domain] if isSingleDomain else domains
    for dom in targets:
        setEvalConfig(dom, configId, evalerr, ndialogs, iteration, seed=seed)

    # finally run the system
    simulator = Simulate.SimulationSystem(error_rate=float(evalerr) / 100.0)
    simulator.run_dialogs(ndialogs)


def getIntParam(line, key):
    """ Return the integer value of ' key = value' found in line.

        key must be preceded by a space; spaces around '=' are optional.  The
        key is matched literally.  Prints a message and calls exit(0) when no
        such parameter is found.
    """
    # raw string so that \d reaches the regex engine; escape the key so that
    # characters such as '.' in it are not treated as regex operators
    m = re.search(r" %s *= *(\d+)" % re.escape(key), line)
    if m is None:
        print("Cant find int %s in %s" % (key, line))
        exit(0)
    return int(m.group(1))


def getFloatRange(line,key):
    """ Return (value, deviation) parsed from ' key = value +- deviation' in line.

        Both numbers must be written with a decimal point; value may be
        negative.  key must be preceded by a space and is matched literally.
        Prints a message and calls exit(0) when no such entry is found.
    """
    # raw string so that the backslash escapes reach the regex engine; escape
    # the key so that it is not interpreted as a pattern
    m = re.search(r" %s *= *(-?\d+\.\d+) *\+- *(\d+\.\d+)" % re.escape(key), line)
    if m is None:
        print("Cant find float %s in %s" % (key, line))
        exit(0)
    return (float(m.group(1)), float(m.group(2)))


def getDomainFromLog(l):
    """ Return the comma-separated entries of the last whitespace-delimited
        token of log line l as a list.
    """
    last_token = l.split()[-1]
    return last_token.split(',')


def extractEvalData(lines):
    """ Extract performance figures from the lines of a train or eval log.

        Returns a dict indexed by domain name (taken from 'List of domains:'
        lines).  Each domain maps either the training iteration (if the log
        contains '*** Training Iteration' banners) or the evaluation error rate
        (from '*** Evaluating' banners) to a dict holding 'ndialogs', 'erate'
        or 'policy', and the (value, deviation) pairs 'reward', 'success' and
        'turns' read from the 'Average ...' lines that follow a
        'Results for domain:' line.

        Prints a message and calls exit(0) when the same training iteration
        appears twice.
    """
    evalData = {}
    training = False
    domain_list = []
    cur_domain = None
    # log label -> key under which the figures are stored
    metrics = (('Average reward', 'reward'),
               ('Average success', 'success'),
               ('Average turns', 'turns'))
    for l in lines:
        if l.find('List of domains:') >= 0:
            # get the list of domains from the log by reading the lines where the ontologies are loaded
            for domain in getDomainFromLog(l):
                if domain not in domain_list:
                    domain_list.append(domain)
                    evalData[domain] = {}
        if l.find('*** Training Iteration') >= 0:
            iteration = getIntParam(l, 'iter')+1
            # evalData is indexed by domain first, so the iteration has to be
            # looked up in the per-domain tables (testing it against the
            # top-level keys could never detect a duplicate)
            if any(iteration in evalData[dom] for dom in domain_list):
                print("Duplicate iteration %d" % iteration)
                exit(0)
            for domain in domain_list:
                evalData[domain][iteration] = {}
                evalData[domain][iteration]['erate'] = getIntParam(l, 'error-rate')
                evalData[domain][iteration]['ndialogs'] = getIntParam(l, 'num-dialogs')
            training = True
        elif l.find('*** Evaluating')>=0 and not training:
            l = l.replace('CR', 'CamRestaurants')
            erate = getIntParam(l, 'error-rate')
            # text between the banner and the ':' is the policy name
            ll = l[l.find('*** Evaluating') + len('*** Evaluating')+1:]
            (ll,x) = ll.split(':')
            for domain in domain_list:
                if domain in ll:
                    evalData[domain][erate] = {}
                    evalData[domain][erate]['policy'] = ll
                    evalData[domain][erate]['ndialogs'] = getIntParam(l, 'num-dialogs')
        elif l.find('Results for domain:') >= 0:
            cur_domain = l.split('Results for domain:')[1].split('--')[0].strip()
        else:
            for label, field in metrics:
                if l.find(label) >= 0:
                    # training logs are indexed by iteration, eval logs by error rate
                    slot = iteration if training else erate
                    evalData[cur_domain][slot][field] = getFloatRange(l, label)
                    break
    return evalData


def plotTrain(dname, rtab, stab, block=True, saveplot=False):
    """ Plot reward (upper panel) and success (lower panel) against the number
        of training dialogues, one curve per policy in rtab.

        A policy with fewer than two points is drawn as a dashed horizontal
        line.  With saveplot the figure is written to _plots/<dname>.png
        (creating _plots if needed); otherwise it is shown, blocking according
        to `block`.
    """
    global gplotnum
    plt.rc('font', weight='bold', size=20)

    names = sorted(rtab.keys())
    plt.figure(gplotnum)
    gplotnum += 1

    def draw(panel, series, name):
        # select the panel, then add either a level line or a curve with error bars
        plt.subplot(panel)
        if len(series['x']) < 2:
            plt.axhline(y=series['y'][0], linestyle='--')
        else:
            plt.errorbar(series['x'], series['y'], yerr=series['var'], label=name)

    for name in names:
        draw(211, rtab[name], name)
        draw(212, stab[name], name)

    # upper panel: reward
    plt.subplot(211)
    plt.grid()
    plt.legend(loc='lower right', fontsize=14)
    plt.title("Performance vs Num Train Dialogues")
    plt.ylabel('Reward')
    # lower panel: success
    plt.subplot(212)
    plt.grid()
    plt.legend(loc='lower right', fontsize=14)
    plt.xlabel('Num Dialogues')
    plt.ylabel('Success')

    if not saveplot:
        plt.show(block=block)
        return
    if not os.path.exists('_plots'):
        os.mkdir('_plots')
    plt.savefig('_plots/' + dname + '.png', bbox_inches='tight')
    print('plot saved as', dname)


def plotTest(dname, rtab, stab, block=True, saveplot=False):
    """ Plot reward (upper panel) and success (lower panel) against the error
        rate, one curve per policy in rtab, and show the figure.

        The legend font size shrinks with the number of curves.  The `block`
        and `saveplot` arguments are accepted but not used here.
    """
    global gplotnum
    names = sorted(rtab.keys())
    legend_size = 12 - len(names)
    plt.figure(gplotnum)
    gplotnum += 1

    for name in names:
        for panel, table in ((211, rtab), (212, stab)):
            series = table[name]
            plt.subplot(panel)
            plt.errorbar(series['x'], series['y'], yerr=series['var'], label=name)

    # upper panel: reward
    plt.subplot(211)
    plt.grid()
    plt.legend(loc='lower left', fontsize=legend_size)
    plt.title(dname+" Performance vs Error Rate")
    plt.ylabel('Reward')
    # lower panel: success
    plt.subplot(212)
    plt.grid()
    plt.legend(loc='lower left', fontsize=legend_size)
    plt.xlabel('Error Rate')
    plt.ylabel('Success')
    plt.show()


def printTable(title, tab):
    """ Print tab as a text table, one row per policy (sorted by name).

        The header row, holding title and the x values, is taken from the
        first policy; every row lists 'y +- var' for each x value.  A blank
        line is printed at the end.
    """
    header_printed = False
    for policy in sorted(tab.keys()):
        entry = tab[policy]
        xvals = entry['x']
        if not header_printed:
            print("%-20s" % title + "".join("%13d" % xv for xv in xvals))
            header_printed = True
        cells = ["%6.1f +-%4.1f" % (entry['y'][i], entry['var'][i])
                 for i in range(len(xvals))]
        print("%-18s :" % policy + "".join(cells))
    print("")


def tabulateTrain(dataSet):
    """ Turn per-policy training data into reward, success and turns tables,
        averaged over seeds.

        dataSet maps policy name -> iteration -> dict with the (value,
        deviation) pairs 'reward', 'success', 'turns' and the dialog count
        'ndialogs'.  The x axis is the running total of ndialogs.  Policies
        whose name contains a 'seed<..>-' tag are merged under the name with
        that tag removed, and their 'x', 'y' and 'var' arrays are averaged
        element-wise.

        Returns the list [reward_table, success_table, turns_table]; each maps
        the merged policy name to a dict of numpy arrays 'y', 'var' and 'x'.
    """
    rtab, stab, ttab = {}, {}, {}
    for policy in dataSet:
        points = []
        total_dialogs = 0
        for record in dataSet[policy].values():
            (rew, rew_err) = record['reward']
            (succ, succ_err) = record['success']
            (turn, turn_err) = record['turns']
            total_dialogs += record['ndialogs']
            points.append((total_dialogs, (rew, rew_err, succ, succ_err, turn, turn_err)))
        points.sort()
        x = [count for (count, _) in points]
        figures = [vals for (_, vals) in points]

        def column(index):
            return [vals[index] for vals in figures]

        rtab[policy] = {'y': column(0), 'var': column(1), 'x': x}
        stab[policy] = {'y': column(2), 'var': column(3), 'x': x}
        ttab[policy] = {'y': column(4), 'var': column(5), 'x': x}

    # average results over seeds
    averaged_result_list = []
    for table in (rtab, stab, ttab):
        sums = {}
        n_seeds = {}
        for policy_key, series in table.items():
            general_key = policy_key
            if "seed" in policy_key:
                # strip the 'seed<..>-' tag to get the name shared by all seeds
                seed_tag = policy_key[policy_key.find("seed"):].split('-')[0]
                general_key = policy_key.replace(seed_tag + '-', '')
            n_seeds[general_key] = n_seeds.get(general_key, 0) + 1
            bucket = sums.setdefault(general_key, {})
            for key, values in series.items():
                if key in bucket:
                    bucket[key] += np.array(values)
                else:
                    bucket[key] = np.array(values)
        averaged = {}
        for general_key, bucket in sums.items():
            averaged[general_key] = {key: total / n_seeds[general_key]
                                     for key, total in bucket.items()}
        averaged_result_list.append(averaged)

    return averaged_result_list


def tabulateTest(dataSet):
    """Reshape per-policy evaluation results into reward/success/turns tables.

    dataSet maps policy name -> error rate -> result dict whose 'reward',
    'success' and 'turns' entries are two-element (value, spread) pairs.

    Returns a triple (rtab, stab, ttab).  Each table maps the policy name to a
    dict with 'y' (first elements of the pairs), 'var' (second elements) and
    'x' (the error rates), all ordered by increasing error rate.

    Every policy must cover the same set of error rates; if one does not, a
    message is printed and the process exits.
    """
    reward_tab, success_tab, turns_tab = {}, {}, {}
    previous_rates = []
    for policy in dataSet:
        # Collect one flat six-tuple per error rate.
        rows = {}
        for rate, entry in dataSet[policy].items():
            reward, reward_var = entry['reward']
            success, success_var = entry['success']
            turns, turns_var = entry['turns']
            rows[rate] = (reward, reward_var, success, success_var, turns, turns_var)

        # Error rates are dict keys and therefore unique, so sorting the keys
        # alone fixes the row order.
        rates = sorted(rows)
        ordered = [rows[rate] for rate in rates]

        if previous_rates != [] and previous_rates != rates:
            print("Policy %s has inconsistent range of error rates" % policy)
            exit(0)
        previous_rates = rates

        # The same 'x' list object is shared by all three tables of a policy.
        reward_tab[policy] = {
            'y': [row[0] for row in ordered],
            'var': [row[1] for row in ordered],
            'x': rates,
        }
        success_tab[policy] = {
            'y': [row[2] for row in ordered],
            'var': [row[3] for row in ordered],
            'x': rates,
        }
        turns_tab[policy] = {
            'y': [row[4] for row in ordered],
            'var': [row[5] for row in ordered],
            'x': rates,
        }
    return reward_tab, success_tab, turns_tab


def train_command(configfile, seed=None, trainerrorrate=None,trainsourceiteration=None,
                  numtrainbatches=None, traindialogsperbatch=None, traindomains=None, dbprefix=None):
    """ Train a policy according to the supplied configfile.
        Results are stored in the directories specified in the [exec_config] section of the config file.
        Optional parameters over-ride the corresponding config parameters of the same name.
        The seed can also be given as a quoted, parenthesised, comma separated list.  A pair (a,b)
        with a < b is expanded to every seed from a to b inclusive; training is then run once
        per seed.
    """

    try:
        if seed and seed.startswith('('):
            # Several seeds requested: expand the list and train once per seed.
            seeds = seed.replace('(', '').replace(')', '').split(',')
            if len(seeds) == 2 and int(seeds[0]) < int(seeds[1]):
                seeds = [str(x) for x in range(int(seeds[0]), 1+int(seeds[1]))]
            for seed in seeds:
                print('*** Seed {} ***'.format(seed))
                train_command(configfile, seed=seed, trainerrorrate=trainerrorrate,
                              trainsourceiteration=trainsourceiteration,
                              numtrainbatches=numtrainbatches, traindialogsperbatch=traindialogsperbatch,
                              traindomains=traindomains, dbprefix=dbprefix)

        else:
            configId = getConfigId(configfile)
            if seed:
                seed = int(seed)
            initialise(configId,configfile,seed,"train",trainerrorrate=trainerrorrate,
                       trainsourceiteration=trainsourceiteration,numtrainbatches=numtrainbatches,
                       traindialogsperbatch=traindialogsperbatch,traindomains=traindomains,dbprefix=dbprefix)
            # Iteration number of the most recently produced policy.  It starts at the
            # source iteration so that the final evaluation and the summary line below
            # remain well defined when no training batch is run (the loop variable
            # would otherwise be unbound).
            lastiteration = gtrainsourceiteration
            for i in range(gtrainsourceiteration, gtrainsourceiteration+gnumtrainbatches):
                Settings.global_numiter = i + 1
                if isSingleDomain:
                    logger.results('List of domains: {}'.format(domain))
                    trainBatch(domain, configId, gtrainerrorrate, gtraindialogsperbatch, i, seed=seed)
                else:
                    logger.results('List of domains: {}'.format(','.join(domains)))
                    trainBatch(domains, configId, gtrainerrorrate, gtraindialogsperbatch, i, seed=seed)
                lastiteration = i + 1
                # Intermediate evaluation after every batch except the last one,
                # which is evaluated unconditionally after the loop.
                if gtesteverybatch and gnumbatchtestdialogs>0 and i+1 < gtrainsourceiteration+gnumtrainbatches:
                    if isSingleDomain:
                        evalPolicy(domain, configId, gtrainerrorrate, gnumbatchtestdialogs, lastiteration, seed=seed)
                    else:
                        evalPolicy(domains, configId, gtrainerrorrate, gnumbatchtestdialogs, lastiteration, seed=seed)
            if gnumbatchtestdialogs > 0:
                if isSingleDomain:
                    logger.results('List of domains: {}'.format(domain))
                    evalPolicy(domain, configId, gtrainerrorrate, gnumbatchtestdialogs, lastiteration, seed=seed)
                else:
                    logger.results('List of domains: {}'.format(','.join(domains)))
                    evalPolicy(domains, configId, gtrainerrorrate, gnumbatchtestdialogs, lastiteration, seed=seed)

            logger.results("*** Training complete. log: %s - final policy is %s-%02d-%02d" % (logfile, configId, gtrainerrorrate, lastiteration))
    except clog.ExceptionRaisedByLogger:
        print("Command Aborted - see Log file for error:", logfile)
        exit(0)
    except KeyboardInterrupt:
        print("\nCommand Aborted from Keyboard")


def test_command(configfile, iteration, seed=None, testerrorrate=None, trainerrorrate=None,
                 numtestdialogs=None, testdomains=None, dbprefix=None, inputpolicy=None):
    """ Test a specific policy iteration trained at a specific error rate according to the supplied configfile.
        Results are embedded in the logfile specified in the config file.
        Optional parameters over-ride the corresponding config parameters of the same name.
        The testerrorrate can also be specified as a triple (e1,e2,stepsize).  This will repeat the test
        over a range of error rates from e1 to e2.  NB the tuple must be quoted on the command line.
        The seed can also be given as a quoted, parenthesised, comma separated list.  A pair (a,b)
        with a < b is expanded to every seed from a to b inclusive; the test is then run once per seed.
    """
    try:
        if seed and seed.startswith('('):
            # Several seeds requested: expand the list and test once per seed.
            seeds = seed.replace('(','').replace(')','').split(',')
            if len(seeds) == 2 and int(seeds[0]) < int(seeds[1]):
                seeds = [str(x) for x in range(int(seeds[0]), 1 + int(seeds[1]))]
            for seed in seeds:
                print('*** Seed {} ***'.format(seed))
                test_command(configfile, iteration, seed=seed, testerrorrate=testerrorrate,
                              trainerrorrate=trainerrorrate, numtestdialogs=numtestdialogs, testdomains=testdomains,
                              dbprefix=dbprefix, inputpolicy=inputpolicy)

        else:
            errStepping = False
            enErr = None
            if testerrorrate and testerrorrate[0] == '(':
                if testerrorrate[-1] != ')':
                    print("Missing closing parenthesis in error range %s" % testerrorrate)
                    exit(0)
                # NOTE(review): eval() executes whatever expression is given on the
                # command line.  Acceptable for a locally run tool, but it must never be
                # fed untrusted input; ast.literal_eval would be the safe alternative.
                errRange = eval(testerrorrate)
                if not isinstance(errRange, (tuple, list)) or len(errRange) != 3:
                    print("Ill-formed error range %s" % testerrorrate)
                    exit(0)
                (stErr, enErr, stepErr) = errRange
                if enErr < stErr or stepErr <= 0:
                    # Format the parsed values: the raw option is a single string and
                    # cannot fill three placeholders.
                    print("Ill-formed test error range [%s,%s,%s]" % (stErr, enErr, stepErr))
                    exit(0)
                errStepping = True
                testerrorrate = stErr
            # A numeric iteration selects <configId>-<errorrate>.<iteration>; anything
            # else is taken as an explicit policy name.
            if iteration.isdigit():
                i = int(iteration)
            else:
                i = iteration
            configId = getConfigId(configfile)
            orig_seed = '0'
            if seed:
                orig_seed = seed
                seed = int(seed) + 100  # To have a different seed during training and testing
            initialise(configId, configfile, seed, "eval", iteration=i, testerrorrate=testerrorrate,
                       testenderrorrate=enErr, trainerrorrate=trainerrorrate,
                       numtestdialogs=numtestdialogs,testdomains=testdomains, dbprefix=dbprefix)
            if isinstance(i, str):
                policyname = i
                if not 'seed' in policyname:
                    # Insert the seed tag just before the last '-' separated field.
                    ps= policyname.split('-')
                    policyname = '-'.join(ps[:-1] + ['seed{}'.format(orig_seed)] + [ps[-1]])
            else:
                policyname = "%s-%02d.%d" % (configId, gtrainerrorrate, i)
            poldirpath = path(policy_dir)
            if poldirpath.isdir():
                policyfiles = poldirpath.files()
                policynamelist = [p.namebase for p in policyfiles]
                if isSingleDomain:
                    logger.results('List of domains: {}'.format(domain))
                    if policyname in policynamelist:
                        if errStepping:
                            while stErr <= enErr:
                                evalPolicy(domain, configId, stErr, gnumtestdialogs, i, seed=seed)
                                stErr += stepErr
                        else:
                            evalPolicy(domain, configId, gtesterrorrate, gnumtestdialogs, i, seed=seed)
                        logger.results("*** Testing complete. logfile: %s - policy %s evaluated" % (logfile, policyname))
                    else:
                        print("Cannot find policy iteration %s in %s" % (policyname, policy_dir))
                else:
                    allPolicyFiles = True
                    logger.results('List of domains: {}'.format(','.join(domains)))
                    # Every domain (or the single shared model) needs its own policy file.
                    for dom in domains:
                        multi_policyname = dom+policyname
                        if isSingleModel:
                            multi_policyname = 'singlemodel'+policyname
                        if not multi_policyname in policynamelist:
                            print("Cannot find policy iteration %s in %s" % (multi_policyname, policy_dir))
                            allPolicyFiles = False
                    if allPolicyFiles:
                        # Pass the domain list and the seed, as the multi-domain branch of
                        # train_command and the single-domain branch above do.
                        if errStepping:
                            while stErr <= enErr:
                                evalPolicy(domains, configId, stErr, gnumtestdialogs, i, seed=seed)
                                stErr += stepErr
                        else:
                            evalPolicy(domains, configId, gtesterrorrate, gnumtestdialogs, i, seed=seed)
                        logger.results("*** Testing complete. logfile: %s - policy %s evaluated" % (logfile, policyname))
            else:
                print("Policy folder %s does not exist" % policy_dir)
    except clog.ExceptionRaisedByLogger:
        print("Command Aborted - see Log file for error:", logfile)
        exit(0)
    except KeyboardInterrupt:
        print("\nCommand Aborted from Keyboard")


def plotTrainLogs(logfilelist, printtab, noplot, saveplot, datasetname, block=True):
    """
        Extract data from given log files and display.

        logfilelist -- paths of training log files; the part of each file's base name
                       before the first '.' names the curve the file contributes to
        printtab    -- print the Reward/Success/Turns tables for every domain
        noplot      -- suppress all plots (per-domain and mean)
        saveplot    -- passed through to plotTrain
        datasetname -- plot title prefix; when '' it is taken from the text before the
                       first '-' of the first curve name
        block       -- passed through to plotTrain

        Unreadable log files are reported and skipped.  All log files must cover the
        same set of domains.
    """
    try:
        resultset = {}
        ncurves = 0
        domains = None

        if len(logfilelist) < 1:
            print("No log files specified")
            exit(0)
        for fname in logfilelist:
            # Read the whole file up front so the handle is closed straight away.
            try:
                with open(fname, "r") as fn:
                    lines = fn.read().splitlines()
            except IOError:
                print(("Cannot find logfile %s" % fname))
                continue
            logName = path(fname).namebase
            if 'epsil0.' in logName:
                # Keep the '.' of an 'epsil0.x' tag from being mistaken for the index separator.
                logName = logName.replace('epsil0.', 'epsil0')
            i = logName.find('.')
            if i < 0:
                print("No index info in train log file name")
                exit(0)
            curveName = logName[:i]
            if datasetname == '':
                i = curveName.find('-')
                if i >= 0:
                    datasetname=curveName[:i]
            results = extractEvalData(lines)
            if not results or len(results[list(results.keys())[0]]) == 0:
                print("Log file %s has no plotable data" % fname)
                continue
            if len(resultset) == 0:
                # the list of domains needs to be read from the logfile
                domains = list(results.keys())
                for domain in domains:
                    resultset[domain] = {}
            else:
                # sorted() returns the ordered lists; list.sort() would return None
                # and make this comparison vacuous.
                domains_1 = sorted(resultset.keys())
                domains_2 = sorted(results.keys())
                assert domains_1 == domains_2, 'The logfiles have different domains'
            ncurves += 1
            for domain in domains:
                if curveName in resultset[domain]:
                    # Several files feed the same curve: merge their iterations.
                    curve = resultset[domain][curveName]
                    for iteration in results[domain]:
                        curve[iteration] = results[domain][iteration]
                else:
                    resultset[domain][curveName] = results[domain]
        if ncurves > 0:
            average_results = [[], [], []]
            for domain in domains:
                (rtab, stab, ttab) = tabulateTrain(resultset[domain])
                average_results[0].append(rtab)
                average_results[1].append(stab)
                average_results[2].append(ttab)
                if printtab:
                    print("\n%s-%s: Performance vs Num Dialogs\n" % (datasetname, domain))
                    printTable('Reward', rtab)
                    printTable('Success', stab)
                    printTable('Turns', ttab)

                if not noplot:
                    plotTrain(datasetname+'-'+domain,rtab,stab,block=block,saveplot=saveplot)
            # Plot the average over all domains
            if len(domains) > 1 and not noplot:
                av_rtab, av_stab, av_ttab = getAverageResults(average_results)
                plotTrain(datasetname+'-mean', av_rtab, av_stab, block=block,saveplot=saveplot)
        else:
            print("No plotable train data found")
    except clog.ExceptionRaisedByLogger:
        print("Command Aborted - see Log file for error:")


def getAverageResults(average_result_list):
    """Average result tables over domains.

    average_result_list is a list of table lists (plotTrainLogs passes the
    reward, success and turns lists).  Each table list holds one table per
    domain, and each table maps policy name -> {key: sequence of numbers}.

    Returns one averaged table per input list, in the same order, mapping
    policy name -> {key: float numpy array}.  Every key is averaged
    element-wise over the domains in which the policy appears.  Entries under
    'var' are averaged on their square roots and squared again afterwards.

    The inputs are not modified.
    """
    averaged_results = []
    for tab_list in average_result_list:
        sums = {}
        counts = {}
        for domain_tab in tab_list:
            for policy_key, entries in domain_tab.items():
                policy_sums = sums.setdefault(policy_key, {})
                counts[policy_key] = counts.get(policy_key, 0) + 1
                for key, values in entries.items():
                    # Always a fresh float array: integer inputs would make the
                    # in-place additions/divisions below fail, and the caller's
                    # data must not be touched.
                    values = np.array(values, dtype=float)
                    if key == 'var':
                        values = np.sqrt(values)
                    if key in policy_sums:
                        policy_sums[key] += values
                    else:
                        policy_sums[key] = values
        # normalise by the number of domains that actually contributed
        for policy_key, policy_sums in sums.items():
            for key in policy_sums:
                policy_sums[key] /= counts[policy_key]
                if key == 'var':
                    policy_sums[key] = np.square(policy_sums[key])
        averaged_results.append(sums)
    return averaged_results


def plotTestLogs(logfilelist,printtab,noplot,datasetname,block=True):
    """
        Extract data from given eval log files and display performance
        as a function of error rate

        logfilelist -- paths of evaluation log files
        printtab    -- print the Reward/Success/Turns tables for every domain
        noplot      -- suppress the plots
        datasetname -- plot title prefix; when '' it is taken from the text before the
                       first '-' of the first policy name found
        block       -- passed through to plotTest

        An unreadable log file, or one whose results carry no 'policy' entry, aborts
        the command.
    """
    try:
        resultset = {}
        for fname in logfilelist:
            # Read the whole file up front so the handle is closed straight away.
            try:
                with open(fname,"r") as fn:
                    lines = fn.read().splitlines()
            except IOError:
                print("Cannot find logfile %s" % fname)
                exit(0)
            results = extractEvalData(lines)
            if not results:
                continue
            for domain in results:
                if not domain in resultset:
                    resultset[domain] = {}
                if not results[domain]:
                    # nothing recorded for this domain in this file
                    continue
                # The policy name is read from the first result of the domain and
                # applied to every error rate found in this file.
                akey = list(results[domain].keys())[0]
                aresult = results[domain][akey]
                if 'policy' in aresult:
                    policyname = aresult['policy']
                    if datasetname == '':
                        i = policyname.find('-')
                        if i >= 0:
                            datasetname=policyname[:i]
                    if not policyname in resultset[domain]: resultset[domain][policyname]={}
                    for erate in results[domain]:
                        resultset[domain][policyname][erate] = results[domain][erate]
                else:
                    print('Format error in log file',fname)
                    exit(0)
        if not resultset:
            print("No data found")
        # Walk every domain collected from any file, not just those of the last file.
        for domain in resultset:
            if len(resultset[domain])>0:
                (rtab,stab,ttab) = tabulateTest(resultset[domain])
                if printtab:
                    print("\n%s-%s: Performance vs Error Rate\n" % (datasetname, domain))
                    printTable('Reward', rtab)
                    printTable('Success', stab)
                    printTable('Turns', ttab)
                if not noplot:
                    plotTest('%s-%s'%(datasetname, domain),rtab,stab,block=block)
            else:
                print("No data found")
    except clog.ExceptionRaisedByLogger:
        print("Command Aborted - see Log file for error:")


@command.fetch_all('args')
def plot_command(args="", printtab=False, noplot=False, saveplot=False, datasetname=''):
    """ Call plot with a list of log files and it will print train and test curves.
        For train log files it plots performance vs num dialogs.
        For test log files it plots performance vs error rate.
        Set the printtab option to print a table of results.
        A name can be given to plot via dataset name.
        Log files are classified by name: a name containing 'train' is a train log, otherwise
        a name containing 'eval' is a test log; any other file is ignored.
        The saveplot option only applies to train plots.
    """
    trainlogs = []
    testlogs = []
    for fname in args:
        if 'train' in fname:
            trainlogs.append(fname)
        elif 'eval' in fname:
            testlogs.append(fname)
    block = True
    if noplot: printtab = True    # otherwise no point!
    if trainlogs:
        plotTrainLogs(trainlogs, printtab, noplot, saveplot, datasetname, block)
    if testlogs:
        # plotTestLogs has no saveplot parameter; passing it positionally would shift
        # saveplot into datasetname and datasetname into block.
        plotTestLogs(testlogs, printtab, noplot, datasetname, block)


def chat_command(configfile, seed=None, trainerrorrate=None, trainsourceiteration=None):
    """ Run the texthub according to the supplied configfile.
        For every domain whose policy type is not 'hdc' the policy is set up from the
        source iteration with learning switched off.  After the chat ends a copy of the
        active configuration is written to <conf_dir><configId>.chat.cfg.
    """
    try:
        configId = getConfigId(configfile)
        initialise(configId, configfile, seed, "chat", trainerrorrate=trainerrorrate,
                   trainsourceiteration=trainsourceiteration)
        for dom in domains:
            if policytypes[dom] != 'hdc':
                setupPolicy(dom, configId, gtrainerrorrate,
                            gtrainsourceiteration, gtrainsourceiteration)
                # make sure that learning is switched off
                if Settings.config.has_section("policy_" + dom):
                    Settings.config.set("policy_" + dom, "learning", 'False')
                else:
                    Settings.config.set("policy", "learning", 'False')
                # if gp, make sure to reset scale to 1 for evaluation
                if policytypes[dom] == "gp":
                    if Settings.config.has_section("gpsarsa_" + dom):
                        Settings.config.set("gpsarsa_" + dom, "scale", "1")
                    else:
                        Settings.config.set("gpsarsa", "scale", "1")
        mess = "*** Chatting with policies %s: ***" % str(domains)
        if tracedialog > 0: print(mess)
        logger.dial(mess)

        # create text hub and run it
        hub = Texthub.ConsoleHub()
        hub.run()
        logger.dial("*** Chat complete")
        # Save a copy of config file; the with-block guarantees it is flushed and closed
        confsavefile = conf_dir + configId + ".chat.cfg"
        with open(confsavefile, 'w') as cf:
            Settings.config.write(cf)
    except clog.ExceptionRaisedByLogger:
        print("Command Aborted - see Log file for error:", logfile)
        exit(0)
    except KeyboardInterrupt:
        print("\nCommand Aborted from Keyboard")


# class addInfo(object):
#     def __init__(self, gbatchnum):
#         self.batch_number = gbatchnum

# Script entry point: hand control to scriptine's run() (imported at the top of the file).
# NOTE(review): presumably this dispatches to the *_command function selected on the
# command line -- confirm against the scriptine documentation.
run()