###############################################################################
# PyDial: Multi-domain Statistical Spoken Dialogue System Software
###############################################################################
#
# Copyright 2015 - 2019
# Cambridge University Engineering Department Dialogue Systems Group
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################

'''
Simulate.py - semantic level user simulator system.
====================================================

Copyright CUED Dialogue Systems Group 2015 - 2017

Runs dialogues between a dialogue agent and a simulated user. Depending on the
``[usermodel] simlevel`` setting, the simulated user talks to the agent with
semantic dialogue acts (``dial_act``), with text generated from those acts
(``text``), or with text returned directly by the simulator (``sys2text``).

**Basic Execution**:
    >>> python Simulate.py [-h] -C CONFIG [-n N] [-r R] [-s SEED] [--nocolor]

Optional arguments/flags [default values]::

    -C/-c       config file (required)
    -n          Number of dialogs [1]
    -r          semantic error rate, as an integer percentage [0]
    -s          set random seed
    --nocolor   no colour in logging output
    -h          help

**Relevant Config variables** [Default values]::

    [GENERAL]
    tracedialog = 2

    [usermodel]
    simlevel = dial_act
    textsampling = dict

    [simulate]
    maxturns = 30
    continuewhensuccessful = False
    forcenullpositive = False
    confscorer = additive

.. seealso:: CUED Imports/Dependencies:

    import :mod:`utils.ContextLogger` |.|
    import :mod:`utils.Settings` |.|
    import :mod:`usersimulator.SimulatedUsersManager` |.|
    import :mod:`ontology.Ontology` |.|
    import :mod:`Agent` |.|

************************

'''
import os
import argparse

import Agent
from usersimulator import SimulatedUsersManager
from utils import Settings
from utils import ContextLogger
from ontology import Ontology
logger = ContextLogger.getLogger('')

__author__ = "cued_dialogue_systems_group"
__version__ = Settings.__version__


class SimulationSystem(object):
    '''
    Semantic level simulated dialog system: one dialogue agent (from an AgentFactory)
    talking to a SimulatedUsersManager.
    '''

    def __init__(self, error_rate):
        '''
        Build the agent factory and the simulated user, and read the simulation settings.

        Config options read here (defaults in brackets):
        [GENERAL] tracedialog [2] - values > 1 print the dialogue to the console;
        [usermodel] simlevel [dial_act] - 'dial_act', 'text' or 'sys2text';
        [usermodel] textsampling [dict] - only used when simlevel is 'text'.

        :param error_rate: error rate of the simulated environment
        :type error_rate: float
        '''
        # Dialogue Agent Factory:
        #-----------------------------------------
        self.agent_factory = Agent.AgentFactory(hub_id='simulate')
        # NOTE - using agent factory here rather than just an agent - since for simulate I can easily envisage
        # wanting to have multiple agents and looking at combining their policies etc... This is not being used
        # now though; will just use a single agent in here at present.

        # Simulated User.
        #-----------------------------------------
        self.simulator = SimulatedUsersManager.SimulatedUsersManager(error_rate)

        # Defaults, each overridable from the config file.
        self.traceDialog = 2
        self.sim_level = 'dial_act'
        self.text_sampling = 'dict'
        if Settings.config.has_option("GENERAL", "tracedialog"):
            self.traceDialog = Settings.config.getint("GENERAL", "tracedialog")
        if Settings.config.has_option("usermodel", "simlevel"):
            self.sim_level = Settings.config.get("usermodel", "simlevel")
        if Settings.config.has_option("usermodel", "textsampling"):
            self.text_sampling = Settings.config.get("usermodel", "textsampling")

        if self.sim_level == 'text':
            # Load the text generator that turns user dialogue acts into text.
            if self.text_sampling == 'dict':
                sampling_dict = os.path.join(Settings.root, 'usersimulator/textgenerator/textgen_dict.pkl')
            else:
                sampling_dict = None
            # Imported here so the text toolkit is only needed when simlevel is 'text'.
            import usersimulator.textgenerator.textgen_toolkit.SCTranslate as SCT
            self.SCT = SCT.SCTranslate(sampling_dict=sampling_dict)
        elif self.sim_level == 'sys2text':
            pass  # TODO: load florian's model here

    def run_dialogs(self, numDialogs):
        '''
        Run the run() method once per dialogue, then power down the agent factory.

        :param numDialogs: number of dialogues to loop over.
        :type numDialogs: int
        :return: None
        '''
        for i in range(numDialogs):
            logger.info('Dialogue %d' % (i + 1))
            self.run(session_id='simulate_dialog' + str(i), sim_level=self.sim_level)
        # Important! -uses FORCE_SAVE on policy- which will finalise learning and save policy.
        self.agent_factory.power_down_factory()

    def run(self, session_id, agent_id='Smith', sim_level='dial_act'):
        '''
        Runs one episode through the simulator.

        The dialogue ends after the first turn in which either the user or the system says 'bye'.

        :param session_id: session id (run_dialogs() passes 'simulate_dialog<i>')
        :type session_id: str
        :param agent_id: agent id, default = 'Smith'
        :type agent_id: string
        :param sim_level: how the simulated user talks to the agent: 'dial_act', 'text' or
            'sys2text'. It should match the configured [usermodel] simlevel; run_dialogs()
            passes self.sim_level.
        :type sim_level: str
        :return: None
        '''
        agent = self.agent_factory.agents[agent_id]
        # In 'sys2text' mode the simulator returns plain text, so there is no semantic
        # user act to log, inspect for 'bye', or read a goal from.
        has_semantic_user_act = sim_level != 'sys2text'

        # RESET THE USER SIMULATOR:
        self.simulator.restart()
        if has_semantic_user_act:
            self._log_user_goals('User will execute the following goal: {}')

        user_act = ''
        text_user_act = ''
        endingDialogue = False

        # SYSTEM STARTS THE CALL:
        sys_act = agent.start_call(session_id,
                                   domainSimulatedUsers=self.simulator.simUserManagers,
                                   maxNumTurnsScaling=self.simulator.number_domains_this_dialog)
        self._trace_prompt(sys_act.prompt)

        # LOOP OVER TURNS:
        while not endingDialogue:
            # USER ACT:
            #---------------------------------------------------------------------------------------------------------
            sys_act = agent.retrieve_last_sys_act()
            if sim_level == 'sys2text':
                text_user_act, user_actsDomain, _ = self.simulator.act_on(sys_act)
                hyps = [(text_user_act, 1.0)]
            else:
                user_act, user_actsDomain, hyps = self.simulator.act_on(sys_act)
                if sim_level == 'text':
                    # Replace the simulator's hypotheses with one generated text hypothesis.
                    text_user_act = self.SCT.translateUserAct(str(user_act), 1)[2]
                    try:
                        text_user_act = text_user_act[0]
                    except (IndexError, KeyError, TypeError):
                        # Build a single string: user_act is a dialogue-act object, not a str.
                        logger.error('Wrong user act: {} {}'.format(user_act, text_user_act))
                    hyps = [(text_user_act, 1.0)]
            # user_actsDomain is the TRUE domain of the user act - it could be used to avoid doing topic tracking.

            if self.traceDialog > 1:
                print(' User >', user_act if has_semantic_user_act else text_user_act)
            if has_semantic_user_act:
                logger.dial('| User > ' + user_act.to_string())
            else:
                logger.dial('| User > ' + text_user_act)

            # SYSTEM ACT:
            #---------------------------------------------------------------------------------------------------------
            sys_act = agent.continue_call(asr_info=hyps, domainString=user_actsDomain,
                                          domainSimulatedUsers=self.simulator.simUserManagers)
            self._trace_prompt(sys_act.prompt)

            if has_semantic_user_act:
                user_said_bye = 'bye' == user_act.act
            else:
                user_said_bye = 'bye' in text_user_act
            endingDialogue = user_said_bye or 'bye' == sys_act.act

        # Process ends.
        if has_semantic_user_act:
            self._log_user_goals('User goal at the end of the dialogue: {}')
        agent.end_call(domainSimulatedUsers=self.simulator.simUserManagers)

    def _trace_prompt(self, prompt_str):
        '''Print (when tracing) and log a system prompt; does nothing when prompt_str is None.'''
        # prompt_str is None when no text is generated, i.e. the system stays at the semantic level.
        if prompt_str is not None:
            if self.traceDialog > 1:
                print(' Prompt >', prompt_str)
            logger.info('| Prompt > ' + prompt_str)

    def _log_user_goals(self, template):
        '''Log the goal of every active simulated user, formatted into template's {} slot.'''
        managers = self.simulator.simUserManagers
        for domain in managers:
            if managers[domain]:
                goal = managers[domain].um.goal
                logger.dial(template.format(str(goal.request_type) + str(goal.constraints)
                                            + str(list(goal.requests))))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Simulate')
    parser.add_argument('-C', '-c', '--config', help='set config file', required=True,
                        type=argparse.FileType('r'))
    parser.add_argument('-n', '--number', help='set the number of dialogues', type=int,
                        default=1)  # default number of dialogs
    parser.add_argument('-r', '--error', help='set error rate', type=int,
                        default=0)  # default simulated error rate, in percent
    parser.set_defaults(use_color=True)
    parser.add_argument('--nocolor', dest='use_color', action='store_false',
                        help='no color in logging. best to turn off if dumping to file. '
                             'Will be overriden by [logging] config setting of "usecolor=".')
    parser.add_argument('-s', '--seed', help='set random seed', type=int)
    args = parser.parse_args()

    # argparse opened the config file only to validate it; Settings.init is given just its name.
    config_path = args.config.name
    args.config.close()

    seed = Settings.init(config_file=config_path, seed=args.seed)
    ContextLogger.createLoggingHandlers(config=Settings.config, use_color=args.use_color)
    logger.info("Random Seed is {}".format(seed))
    Ontology.init_global_ontology()

    simulator = SimulationSystem(error_rate=float(args.error) / 100)
    simulator.run_dialogs(args.number)

#END OF FILE