Skip to content
Snippets Groups Projects
Select Git revision
  • 370a1ff7918ad95ada91c605e539173a87d14ffc
  • main default protected
2 results

AgentFactory.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    AgentFactory.py 12.19 KiB
    '''
    AgentFactory.py - Session management between agents and dialogue server.
    ==========================================================================
    
    @author: Songbo and Neo
    
    '''
    from convlab2.dialog_agent import agent
    from convlab2 import policy
    import copy
    import json
    import numpy as np
    from convlab2.dialcrowd_server.SubjectiveFeedbackManager import SubjectiveFeedbackManager
    from convlab2.util import ContextLogger
    from configparser import ConfigParser
    import logging
    import time
    import os
    logger = ContextLogger.getLogger('')
    
    
    class AgentFactory(object):
        '''
        Pool of dialogue agents, one agent per active dialogue session.

        Agents are deep copies of a single template agent built from the [AGENT]
        section of the config file. The template's policy and NLU are detached
        in __init__ and re-attached by reference to every copy in new_agent, so
        all agents share one policy object and one NLU object.
        '''

        def __init__(self, configPath, savePath, saveFlag):
            '''
            :param configPath: path to an INI config file providing
                [AGENT] agentPath, agentClass, maxTurn, maxNumberAgent and
                [SUBJECTIVE] enabled.
            :type configPath: string

            :param savePath: stored as self.savepath; end_call writes dialogues
                under self.filepath instead.
            :type savePath: string

            :param saveFlag: if true, end_call dumps each agent's agent_saves to
                a JSON file under self.filepath.
            :type saveFlag: bool
            '''
            self.init_agents()
            self.session2agent = {}  # session_id -> agent_id for sessions currently bound to an agent
            self.historical_sessions = []  # every session_id ever passed to start_call
            self.savepath = savePath
            self.saveFlag = saveFlag
            self.number_agents_total = 0  # running counter used to build unique agent ids

            # These messages will control sessions for dialogueCrowd. Be careful when you change them, particularly for the first two.
            self.ending_message = "Thanks for your participation. You can now click the Blue Finish Button."
            self.query_taskID_message = "Please now enter the 5 digit task number"
            self.willkommen_message = "Welcome to the dialogue system. How can I help you?"
            self.query_feedback_message = "Got it, thanks. Have you found all the information you were looking for and were all necessary entities booked? Please enter 1 for yes, and 0 for no."
            self.ask_rate_again_message = "Please try again. Have you found all the information you were looking for and were all necessary entities booked? Please enter 1 for yes, and 0 for no."

            config = ConfigParser()
            config.read(configPath)
            agentPath = config.get("AGENT", "agentPath")
            agentClass = config.get("AGENT", "agentClass")
            self.maxTurn = config.getint("AGENT", "maxTurn")
            self.maxNumberAgent = config.getint("AGENT", "maxNumberAgent")

            # Instantiate the template agent once. Policy and NLU are detached
            # from it so new_agent only deep-copies the lightweight remainder
            # (loading/copying the NLU is slow) and shares these by reference.
            mod = __import__(agentPath, fromlist=[agentClass])
            klass = getattr(mod, agentClass)
            self.template_agent_class = klass
            self.template_agent_instances = klass()
            self.policy = self.template_agent_instances.policy
            self.nlu = copy.deepcopy(self.template_agent_instances.nlu)
            self.template_agent_instances.policy = None
            self.template_agent_instances.nlu = None

            self.subjectiveFeedbackEnabled = config.getboolean("SUBJECTIVE", "enabled")
            self.subjectiveFeedbackManager = None
            self.terminateFlag = False  # set by continue_call when a user ends a dialogue with "bye"

            # Output directory: <two levels above this file's directory>/user_trial/<timestamp>.
            # abspath is applied first so a relative __file__ cannot collapse to ''.
            parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            self.filepath = os.path.join(
                os.path.dirname(parent_dir),
                'user_trial',
                time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
            # exist_ok: two factories started within the same second share the directory
            os.makedirs(self.filepath, exist_ok=True)
            # TODO
            # subjectiveFeedbackManager should be independent with subjectiveFeedbackEnabled
            # subjectiveFeedbackManager is used for saving every information
            # subjectiveFeedbackEnabled is used for updating the policy through interacting with real users
            if self.subjectiveFeedbackEnabled:
                self.subjectiveFeedbackManager = SubjectiveFeedbackManager(
                    configPath,
                    self.policy,
                    agent_name=self.template_agent_instances.agent_name)

        def init_agents(self):
            '''
            Reset the agent pool to empty (agent_id -> agent instance).
            '''
            self.agents = {}

        def start_call(self, session_id):
            '''
            Locates an agent to take this call and uses that agent's init_session method.

            Preference order: the agent already bound to session_id, then any
            agent with no session attached, then a freshly created agent.

            :param session_id: session_id
            :type session_id: string

            :return: string -- the id of the agent now handling the session
            '''
            logger.info('start_call for session %s' % session_id)

            # 1. reuse the agent already bound to this session, if any
            agent_id = self.session2agent.get(session_id)

            # 2. otherwise reuse an agent that has no session attached
            if agent_id is None:
                agent_id = next(
                    (a_id for a_id, a in self.agents.items() if a.session_id is None),
                    None)

            # 3. otherwise create a new agent for this call
            if agent_id is None:
                agent_id = self.new_agent()
            else:
                logger.info('Agent {} has been reactivated.'.format(agent_id))

            # 4. record that this session is with this agent, and that it existed
            self.session2agent[session_id] = agent_id
            self.historical_sessions.append(session_id)

            # 5. start the call with this agent
            agent = self.agents[agent_id]
            agent.session_id = session_id
            agent.init_session()
            agent.agent_saves['session_id'] = session_id
            agent.agent_saves['agent_id'] = agent_id
            return agent_id

        def continue_call(self, agent_id, user_id, userUtterance):
            '''
            Wrapper for one dialogue turn of the Agent() instance identified by agent_id.

            Any utterance containing "bye" (case-insensitive substring match)
            ends the dialogue: user_id and a timestamp are stored in the agent's
            agent_saves, end_call is invoked, and self.ending_message is returned.
            Otherwise the agent's response is returned, and the agent is flagged
            ENDING_DIALOG once its turn count reaches maxTurn.

            :param agent_id: agent id
            :type agent_id: string

            :param user_id: user identifier, stored in agent_saves['user_id'] on "bye"
            :type user_id: string

            :param userUtterance: user input to dialogue agent
            :type userUtterance: str

            :return: string -- the system's response
            '''
            agent = self.agents[agent_id]

            # If user say "bye", end the dialogue. A user must say "bye" to end the conversation.
            if "bye" in userUtterance.lower():
                agent.ENDING_DIALOG = True
                agent.agent_saves['user_id'] = user_id
                agent.agent_saves['timestamp'] = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
                self.end_call(agent_id)
                self.terminateFlag = True
                return self.ending_message

            #     return self.query_feedback_message

            # # This captures the user subjective feedback. "1" is for achieving the goal. "0" is for not achieving the goal.
            # if self.agents[agent_id].ENDING_DIALOG:
            #     if str.lower(userUtterance) in ["1", "0"]:
            #         self.agents[agent_id].USER_RATED = True
            #         self.agents[agent_id].USER_GOAL_ACHIEVED = (
            #             str.lower(userUtterance) == "1")
            #         self.end_call(agent_id)
            #         return self.ending_message
            #     else:
            #         return self.ask_rate_again_message

            # Get system responses.
            prompt_str = agent.response(userUtterance)
            if agent.turn >= self.maxTurn:
                agent.ENDING_DIALOG = True

            return prompt_str

        def end_call(self, agent_id=None, session_id=None):
            '''
            Can pass session_id or agent_id as we use this in cases
                1) normally ending a dialogue, (via agent_id)
                2) cleaning a hung up call      (via session_id)

            Unbinds the session, optionally saves the agent's agent_saves as
            JSON (when saveFlag is set), and always removes the agent, even if
            saving fails.

            :param agent_id: agent id
            :type agent_id: string

            :param session_id: session_id
            :type session_id: string

            :return: None
            '''
            # 1. find the agent if only given session_id
            if agent_id is None:  # implicitly assume session_id is given then
                agent_id = self.retrieve_agent(session_id)
                if not agent_id:
                    return
            logger.info('Ending agents %s call' % agent_id)

            # 2. remove session from active list (tolerate an already-missing entry)
            agent = self.agents[agent_id]
            session_id = agent.session_id
            self.session2agent.pop(session_id, None)
            logger.info('Session %s released; active sessions: %s' % (session_id, self.session2agent))

            # 3. Train the policy according to the subject feedback from the real user.
            # if self.subjectiveFeedbackEnabled:
            #     training_state = self.agents[agent_id].sys_state_history
            #     training_action = self.agents[agent_id].sys_action_history
            #     training_utterance = self.agents[agent_id].sys_utterance_history
            #     training_reward = None #self.agents[agent_id].retrieve_reward()
            #     training_subjectiveFeedback = self.agents[agent_id].USER_GOAL_ACHIEVED
            #     system_outputs = self.agents[agent_id].sys_output_history
            #     try:
            #         prob_history = self.agents[agent_id].action_prob_history
            #     except:
            #         prob_history = []

            #     task_id = self.agents[agent_id].taskID
            #     self.subjectiveFeedbackManager.add_state_action_lists(
            #         training_utterance, training_state, training_action, training_subjectiveFeedback, training_reward,
            #         task_id, system_outputs, prob_history)

            try:
                if self.saveFlag:
                    # basename: session ids come from outside; keep the file inside self.filepath
                    file_name = os.path.basename(str(session_id).split("\t")[0]) + '_save.json'
                    with open(os.path.join(self.filepath, file_name), "w") as save_file:
                        json.dump(agent.agent_saves, save_file, cls=NumpyEncoder)
            finally:
                # 4. bye bye, agent : (  -- released even if saving raised
                self.kill_agent(agent_id)

        def agent2session(self, agent_id):
            '''
            Gets str describing session_id agent is currently on

            :param agent_id: agent id
            :type agent_id: string

            :return: string -- the session id
            '''
            return self.agents[agent_id].session_id

        def retrieve_agent(self, session_id):
            '''
            Returns str describing agent_id.

            :param session_id: session_id
            :type session_id: string

            :return: string -- the agent id, or "" (after logging an error) if the session is unknown
            '''
            agent_id = self.session2agent.get(session_id)
            if agent_id is None:
                logger.error(
                    'Attempted to get an agent for unknown session %s' % session_id)
                return ""
            return agent_id

        def new_agent(self):
            '''
            Creates a new agent to handle some concurrency.
            Here deepcopy is used to create clean copy rather than referencing,
            leaving it in a clean state to commence a new call.

            If the pool has reached maxNumberAgent, inactive agents are evicted;
            the agent created here is never evicted by that sweep.

            :return: string -- the agent id
            '''
            agent_id = 'Smith' + str(self.number_agents_total)
            self.number_agents_total += 1

            # This method has efficiency issue. Loading BERT NLU takes too long which will raise errors for socket.
            # Could also just do a deepcopy of everything here and not setting policy to None in the init
            agent = copy.deepcopy(self.template_agent_instances)
            agent.policy = self.policy
            agent.nlu = self.nlu
            agent.dst.init_session()

            if self.subjectiveFeedbackEnabled:
                agent.policy = self.subjectiveFeedbackManager.getUpdatedPolicy()

            self.agents[agent_id] = agent

            if len(self.agents) >= self.maxNumberAgent:
                # The new agent has no session yet and might look inactive; protect it.
                self.kill_inactive_agent(exclude=agent_id)

            logger.info(
                'Created new agent %s. We now have %d agents in total.' % (agent_id, len(self.agents)))

            return agent_id

        def kill_agent(self, agent_id):
            '''
            Remove an agent from the pool.

            :param agent_id: agent id
            :type agent_id: string

            :return: None
            '''
            del self.agents[agent_id]

        def power_down_factory(self):
            '''
            Log a summary line for each remaining agent and the list of all
            sessions this factory handled. Nothing is saved here.

            :return: None
            '''
            for agent_id in self.agents:
                logger.info('Summary of agent: %s' % agent_id)
            logger.info('Factory handled these sessions: %s' %
                        self.historical_sessions)

        def kill_inactive_agent(self, exclude=None):
            '''
            Kill inactive agents in the agents pool if there are too many agents running.

            :param exclude: optional agent id that must not be killed (e.g. one just created)
            :type exclude: string

            :return: None
            '''
            killed = 0
            for a_id in list(self.agents):
                if a_id == exclude:
                    continue
                agent = self.agents[a_id]
                if agent.is_inactive():
                    # Only drop the mapping if it belongs to this agent; an agent
                    # may have no session bound at all (session_id None).
                    if self.session2agent.get(agent.session_id) == a_id:
                        del self.session2agent[agent.session_id]
                    self.kill_agent(a_id)
                    killed += 1
            logger.info('%s of agents are killed.' % killed)
    
    class NumpyEncoder(json.JSONEncoder):
        """
        JSON encoder that converts numpy values to native Python types.

        numpy integers -> int, numpy floats -> float, numpy booleans -> bool,
        ndarrays -> nested lists. Anything else is delegated to the base
        encoder, which raises TypeError for unsupported objects.
        """
        def default(self, obj):
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            # np.bool_ is not a Python bool subclass, so json cannot encode it natively
            if isinstance(obj, np.bool_):
                return bool(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return super().default(obj)
    # END OF FILE