Select Git revision
MachineContext.java
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
AgentFactory.py 12.19 KiB
'''
AgentFactory.py - Session management between agents and dialogue server.
==========================================================================
@author: Songbo and Neo
'''
from convlab2.dialog_agent import agent
from convlab2 import policy
import copy
import json
import numpy as np
from convlab2.dialcrowd_server.SubjectiveFeedbackManager import SubjectiveFeedbackManager
from convlab2.util import ContextLogger
from configparser import ConfigParser
import logging
import time
import os
logger = ContextLogger.getLogger('')
class AgentFactory(object):
    '''
    Pool of dialogue agents, one per active dialogue session.

    Agents are deep copies of a template agent built once from the ``[AGENT]``
    section of the config file. The template's policy and NLU are detached
    before copying, and every agent is handed references to the same shared
    instances, so those components are only built once.
    '''

    def __init__(self, configPath, savePath, saveFlag):
        '''
        :param configPath: path of an INI file with an ``[AGENT]`` section
            (agentPath, agentClass, maxTurn, maxNumberAgent) and a
            ``[SUBJECTIVE]`` section (enabled)
        :type configPath: str
        :param savePath: stored as ``self.savepath``
        :type savePath: str
        :param saveFlag: when true, ``end_call`` writes each finished
            dialogue's ``agent_saves`` to a JSON file under ``self.filepath``
        :type saveFlag: bool
        '''
        self.init_agents()
        self.session2agent = {}  # session_id -> agent_id for active sessions
        self.historical_sessions = []  # every session_id ever started
        # NOTE(review): savePath is stored but not used by this class; dialogue
        # logs are written under self.filepath instead.
        self.savepath = savePath
        self.saveFlag = saveFlag
        self.number_agents_total = 0  # counter used to build unique agent ids
        # These messages control sessions for DialCrowd. Be careful when you
        # change them, particularly the first two.
        self.ending_message = "Thanks for your participation. You can now click the Blue Finish Button."
        self.query_taskID_message = "Please now enter the 5 digit task number"
        self.willkommen_message = "Welcome to the dialogue system. How can I help you?"
        self.query_feedback_message = "Got it, thanks. Have you found all the information you were looking for and were all necessary entities booked? Please enter 1 for yes, and 0 for no."
        self.ask_rate_again_message = "Please try again. Have you found all the information you were looking for and were all necessary entities booked? Please enter 1 for yes, and 0 for no."

        config = ConfigParser()
        config.read(configPath)
        agentPath = config.get("AGENT", "agentPath")
        agentClass = config.get("AGENT", "agentClass")
        self.maxTurn = config.getint("AGENT", "maxTurn")
        self.maxNumberAgent = config.getint("AGENT", "maxNumberAgent")

        # Import the agent class named in the config and build the template.
        mod = __import__(agentPath, fromlist=[agentClass])
        self.template_agent_class = getattr(mod, agentClass)
        self.template_agent_instances = self.template_agent_class()
        # Detach policy and NLU from the template so that new_agent's deepcopy
        # does not duplicate them; every new agent gets these shared objects.
        self.policy = self.template_agent_instances.policy
        self.nlu = copy.deepcopy(self.template_agent_instances.nlu)
        self.template_agent_instances.policy = None
        self.template_agent_instances.nlu = None

        self.subjectiveFeedbackEnabled = config.getboolean("SUBJECTIVE", "enabled")
        self.subjectiveFeedbackManager = None
        self.terminateFlag = False  # set by continue_call when a user says "bye"

        # Dialogue logs go to <two levels above this module's directory>/user_trial/<start time>.
        parent_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
        self.filepath = os.path.join(
            os.path.dirname(parent_dir), 'user_trial',
            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
        # exist_ok: factories started within the same second share the folder
        # instead of crashing with FileExistsError.
        os.makedirs(self.filepath, exist_ok=True)
        # TODO
        # subjectiveFeedbackManager should be independent of subjectiveFeedbackEnabled
        # subjectiveFeedbackManager is used for saving every piece of information
        # subjectiveFeedbackEnabled is used for updating the policy through interacting with real users
        if self.subjectiveFeedbackEnabled:
            self.subjectiveFeedbackManager = SubjectiveFeedbackManager(
                configPath,
                self.policy,
                agent_name=self.template_agent_instances.agent_name)

    def init_agents(self):
        '''
        Reset the agent pool (agent_id -> agent instance) to empty.
        '''
        self.agents = {}

    def start_call(self, session_id):
        '''
        Locate an agent to take this call and start a session on it.

        The agent is chosen in this order:
        1. the agent already bound to ``session_id``;
        2. an idle agent (one whose ``session_id`` is None);
        3. a newly created agent.

        :param session_id: session_id
        :type session_id: string
        :return: id of the agent now handling the session (string)
        '''
        logger.debug('start_call for session %s' % session_id)
        # 1. reuse the agent already bound to this session, if any
        agent_id = self.session2agent.get(session_id)
        # 2. otherwise reuse an idle agent
        if agent_id is None:
            agent_id = next(
                (a_id for a_id, candidate in self.agents.items()
                 if candidate.session_id is None),
                None)
        # 3. otherwise create a new agent for this call
        if agent_id is None:
            agent_id = self.new_agent()
        else:
            logger.info('Agent {} has been reactivated.'.format(agent_id))
        # 4. record that this session is with this agent, and that it existed
        self.session2agent[session_id] = agent_id
        self.historical_sessions.append(session_id)
        # 5. start the call with this agent
        chosen = self.agents[agent_id]
        chosen.session_id = session_id
        chosen.init_session()
        chosen.agent_saves['session_id'] = session_id
        chosen.agent_saves['agent_id'] = agent_id
        return agent_id

    def continue_call(self, agent_id, user_id, userUtterance):
        '''
        Forward one user turn to the agent identified by ``agent_id``.

        If the utterance contains "bye" (case-insensitive substring match, so
        "goodbye" also counts), the dialogue is ended and ``ending_message`` is
        returned. Otherwise the agent's response is returned. When the agent's
        turn count reaches ``maxTurn``, the agent is only flagged with
        ``ENDING_DIALOG``; the call is not ended here.

        :param agent_id: agent id
        :type agent_id: string
        :param user_id: id of the user; stored in the agent's saves on "bye"
        :type user_id: string
        :param userUtterance: user input to dialogue agent
        :type userUtterance: str
        :return: string -- the system's response
        '''
        active = self.agents[agent_id]
        # Saying "bye" is the only way for the user to end the conversation.
        if "bye" in userUtterance.lower():
            active.ENDING_DIALOG = True
            active.agent_saves['user_id'] = user_id
            active.agent_saves['timestamp'] = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
            self.end_call(agent_id)
            self.terminateFlag = True
            return self.ending_message
        # NOTE: asking for subjective 1/0 feedback after "bye" is currently
        # disabled; query_feedback_message / ask_rate_again_message belong to
        # that flow.
        prompt_str = active.response(userUtterance)
        if active.turn >= self.maxTurn:
            active.ENDING_DIALOG = True
        return prompt_str

    def end_call(self, agent_id=None, session_id=None):
        '''
        Finish a call, optionally save its log, and remove its agent.

        Pass either argument, because this is used in two cases:
        1) normally ending a dialogue (via agent_id);
        2) cleaning up a hung-up call (via session_id).
        Unknown sessions and already-ended agents are ignored.

        :param agent_id: agent id
        :type agent_id: string
        :param session_id: session_id
        :type session_id: string
        :return: None
        '''
        # 1. find the agent if only given session_id
        if agent_id is None:
            agent_id = self.retrieve_agent(session_id)
        if not agent_id:
            return
        if agent_id not in self.agents:
            # e.g. the call was already ended normally before the hang-up cleanup
            logger.info('Agent %s has no active call to end' % agent_id)
            return
        logger.info('Ending agents %s call' % agent_id)
        ending = self.agents[agent_id]
        # 2. remove session from active list
        session_id = ending.session_id
        self.session2agent.pop(session_id, None)
        logger.debug('Removed session %s; active sessions: %s' % (session_id, self.session2agent))
        # NOTE: training the policy from the user's subjective feedback
        # (SubjectiveFeedbackManager.add_state_action_lists) is currently disabled.
        try:
            # 3. save the dialogue log. The file name uses the part of the
            # session id before the first tab.
            # NOTE(review): session ids are used in a file path unsanitised;
            # confirm they cannot contain path separators.
            if self.saveFlag:
                file_name = str(session_id).split("\t")[0] + '_save.json'
                with open(os.path.join(self.filepath, file_name), "w") as save_file:
                    json.dump(ending.agent_saves, save_file, cls=NumpyEncoder)
        finally:
            # 4. bye bye, agent :( -- also when saving failed, so no
            # session-less agent is left behind in the pool
            self.kill_agent(agent_id)

    def agent2session(self, agent_id):
        '''
        Get the session id the given agent is currently serving.

        :param agent_id: agent id
        :type agent_id: string
        :return: string -- the session id (None for an idle agent)
        '''
        return self.agents[agent_id].session_id

    def retrieve_agent(self, session_id):
        '''
        Look up the agent serving a session.

        :param session_id: session_id
        :type session_id: string
        :return: string -- the agent id, or "" (after logging an error) if the
            session is unknown
        '''
        agent_id = self.session2agent.get(session_id)
        if agent_id is None:
            logger.error(
                'Attempted to get an agent for unknown session %s' % session_id)
            return ""
        return agent_id

    def new_agent(self):
        '''
        Create a new agent in a clean state and add it to the pool.

        The template is deep-copied rather than referenced, so each agent
        starts with clean state; the shared policy and NLU are attached
        afterwards. If the pool is full, inactive agents are purged first.

        :return: string -- the agent id
        '''
        agent_id = 'Smith' + str(self.number_agents_total)
        self.number_agents_total += 1
        # Efficiency note: building components such as a BERT NLU per agent
        # takes too long and causes socket errors, which is why policy and NLU
        # are shared instead of copied.
        fresh = copy.deepcopy(self.template_agent_instances)
        fresh.policy = self.policy
        fresh.nlu = self.nlu
        fresh.dst.init_session()
        if self.subjectiveFeedbackEnabled:
            fresh.policy = self.subjectiveFeedbackManager.getUpdatedPolicy()
        # Purge before registering the new agent so the purge can never kill
        # the agent we are about to return. "+ 1" keeps the original threshold
        # (pool size after insertion >= maxNumberAgent).
        if len(self.agents) + 1 >= self.maxNumberAgent:
            self.kill_inactive_agent()
        self.agents[agent_id] = fresh
        logger.info(
            f"Created new agent {agent_id}. We now have {len(self.agents)} agents in total.")
        return agent_id

    def kill_agent(self, agent_id):
        '''
        Remove an agent from the pool; does not touch ``session2agent``.

        :param agent_id: agent id
        :type agent_id: string
        :return: None
        :raises KeyError: if the agent does not exist
        '''
        del self.agents[agent_id]

    def power_down_factory(self):
        '''
        Log a shutdown summary: the id of every agent still in the pool and
        all sessions this factory has handled. Nothing is saved here.

        :return: None
        '''
        for agent_id in self.agents:
            logger.info('Summary of agent: %s' % agent_id)
        logger.info('Factory handled these sessions: %s' %
                    self.historical_sessions)

    def kill_inactive_agent(self):
        '''
        Remove every agent whose ``is_inactive()`` is true, together with its
        session mapping. Called when too many agents are running.
        '''
        killed = 0
        for a_id in list(self.agents):
            candidate = self.agents[a_id]
            if candidate.is_inactive():
                # Idle agents (session_id None) have no mapping, hence pop().
                self.session2agent.pop(candidate.session_id, None)
                self.kill_agent(a_id)
                killed += 1
        logger.info('%s of agents are killed.' % killed)
class NumpyEncoder(json.JSONEncoder):
    """
    JSON encoder that also serialises numpy scalars and arrays.

    - numpy integers -> int
    - numpy floats -> float
    - numpy booleans -> bool
    - ndarrays -> (nested) lists

    Anything else falls back to the standard encoder, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a Python bool subclass, so json cannot encode it natively
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# END OF FILE