diff --git a/convlab/dialcrowd_server/AgentFactory.py b/convlab/dialcrowd_server/AgentFactory.py deleted file mode 100644 index 5632f639512b21cd21ed216f18f75810736cfee7..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/AgentFactory.py +++ /dev/null @@ -1,381 +0,0 @@ -''' -AgentFactory.py - Session management between agents and dialogue server. -========================================================================== - -@author: Songbo and Neo - -''' -from convlab2.dialcrowd_server.Goal import _process_goal -import copy -import json -import numpy as np -import torch -from convlab2.dialcrowd_server.SubjectiveFeedbackManager import SubjectiveFeedbackManager -from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator -from convlab2.util import ContextLogger -from configparser import ConfigParser -import logging -import time -import os -import shutil -logger = ContextLogger.getLogger('') - - -class AgentFactory(object): - - def __init__(self, configPath, savePath, saveFlag=True, task_file=None): - self.init_agents() - self.session2agent = {} - self.historical_sessions = [] - self.savepath = savePath - self.saveFlag = saveFlag - self.number_agents_total = 0 - assert task_file is not None, print("YOU NEED TO PASS A TASK FILE FOR OBJECTIVE SUCCESS.") - - self.filepath = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) #get parent directory - self.filepath = os.path.dirname(self.filepath) #get grandparent directory - self.filepath = os.path.join(self.filepath, 'human_trial', time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) - os.makedirs(self.filepath) - os.makedirs(os.path.join(self.filepath, 'dialogues')) - - shutil.copy(task_file, self.filepath) - - with open(task_file, 'r') as f: - self.tasks = [] - for line in f: - line = line.strip() - self.tasks.append(json.loads(line)) - - # These messages will control sessions for dialogueCrowd. Be careful when you change them, particularly for the fisrt two. - self.ending_message = "Thanks for your participation. You can now click the Blue Finish Button." - self.query_taskID_message = "Please now enter the 5 digit task number" - self.willkommen_message = "Welcome to the dialogue system. How can I help you?" - self.query_feedback_message = "Got it, thanks. Have you found all the information you were looking for and were all necessary entities booked? Please enter 1 for yes, and 0 for no." - self.ask_rate_again_message = "Please try again. Have you found all the information you were looking for and were all necessary entities booked? Please enter 1 for yes, and 0 for no." - - configparser = ConfigParser() - configparser.read(configPath) - agentPath = (configparser.get("AGENT", "agentPath")) - agentClass = (configparser.get("AGENT", "agentClass")) - self.maxTurn = int(configparser.get("AGENT", "maxTurn")) - self.maxNumberAgent = int(configparser.get("AGENT", "maxNumberAgent")) - - mod = __import__(agentPath, fromlist=[agentClass]) - klass = getattr(mod, agentClass) - self.template_agent_class = klass - self.template_agent_instances = klass() - self.policy = self.template_agent_instances.policy - self.nlu = copy.deepcopy(self.template_agent_instances.nlu) - self.nlg = copy.deepcopy(self.template_agent_instances.nlg) - self.template_agent_instances.policy = None - self.template_agent_instances.nlu = None - self.template_agent_instances.nlg = None - - self.subjectiveFeedbackEnabled = ( - configparser.getboolean("SUBJECTIVE", "enabled")) - self.subjectiveFeedbackManager = None - self.terminateFlag = False - - # TODO - # subjectiveFeedbackManager should be independent with subjectiveFeedbackEnabled - # subjectiveFeedbackManager is used for saving every information - # subjectiveFeedbackEnabled is used for updating the policy through interacting with real users - if self.subjectiveFeedbackEnabled: - self.subjectiveFeedbackManager = SubjectiveFeedbackManager( - configPath, - self.policy, - agent_name=self.template_agent_instances.agent_name) - - def init_agents(self): - - self.agents = {} - - def start_call(self, session_id, user_id=None, task_id=None): - ''' - Locates an agent to take this call and uses that agents start_call method. - - :param session_id: session_id - :type session_id: string - - :return: start_call() function of agent id (String) - ''' - - agent_id = None - - print(session_id) - - # 1. make sure session_id is not in use by any agent - if session_id in list(self.session2agent.keys()): - agent_id = self.session2agent[session_id] - - # 2. check if there is an inactive agent - if agent_id is None: - for a_id in list(self.agents.keys()): - if self.agents[a_id].session_id is None: - agent_id = a_id - break - - # 3. otherwise create a new agent for this call - if agent_id is None: - agent_id = self.new_agent() - else: - logger.info('Agent {} has been reactivated.'.format(agent_id)) - - # 4. record that this session is with this agent, and that it existed: - self.session2agent[session_id] = agent_id - self.historical_sessions.append(session_id) - - # 5. start the call with this agent: - self.agents[agent_id].session_id = session_id - self.agents[agent_id].init_session() - self.agents[agent_id].agent_saves['session_id'] = session_id - self.agents[agent_id].agent_saves['agent_id'] = agent_id - self.agents[agent_id].agent_saves['task_id'] = task_id - self.agents[agent_id].agent_saves['user_id'] = user_id - return agent_id - - def continue_call(self, agent_id, user_id, userUtterance, task_id): - ''' - wrapper for continue_call for the specific Agent() instance identified by agent_id - - :param agent_id: agent id - :type agent_id: string - - :param userUtterance: user input to dialogue agent - :type userUtterance: str - - :return: string -- the system's response - ''' - - # If user say "bye", end the dialgue. A user must say "bye" to end the conversation. - if(str.lower(userUtterance).__contains__("bye")): - self.agents[agent_id].ENDING_DIALOG = True - self.agents[agent_id].agent_saves['user_id'] = user_id - self.agents[agent_id].agent_saves['task_id'] = task_id - self.agents[agent_id].agent_saves['timestamp'] = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - self.end_call(agent_id) - self.terminateFlag = True - return self.ending_message - - # return self.query_feedback_message - - # # This captures the user subjective feedback. "1" is for achieving the goal. "0" is for not achieving the goal. - # if self.agents[agent_id].ENDING_DIALOG: - # if str.lower(userUtterance) in ["1", "0"]: - # self.agents[agent_id].USER_RATED = True - # self.agents[agent_id].USER_GOAL_ACHIEVED = ( - # str.lower(userUtterance) == "1") - # self.end_call(agent_id) - # return self.ending_message - # else: - # return self.ask_rate_again_message - - # Get system responses. - with torch.no_grad(): - prompt_str = self.agents[agent_id].response(userUtterance) - if(self.agents[agent_id].turn >= self.maxTurn): - self.agents[agent_id].ENDING_DIALOG = True - - return prompt_str - - def end_call(self, agent_id=None, session_id=None): - ''' - Can pass session_id or agent_id as we use this in cases - 1) normally ending a dialogue, (via agent_id) - 2) cleaning a hung up call (via session_id) - - :param agent_id: agent id - :type agent_id: string - - :param session_id: session_id - :type session_id: string - - :return: None - ''' - - # 1. find the agent if only given session_id - if agent_id is None: # implicitly assume session_id is given then - agent_id = self.retrieve_agent(session_id) - if not agent_id: - return - logger.info('Ending agents %s call' % agent_id) - - # 2. remove session from active list - session_id = self.agents[agent_id].session_id - print("SESSION IDDDDDD: ", session_id) - - del self.session2agent[session_id] - print("SESSION2AGENT: ", self.session2agent) - - # 3. Train the policy according to the subject feedback from the real user. - # if self.subjectiveFeedbackEnabled: - # training_state = self.agents[agent_id].sys_state_history - # training_action = self.agents[agent_id].sys_action_history - # training_utterance = self.agents[agent_id].sys_utterance_history - # training_reward = None #self.agents[agent_id].retrieve_reward() - # training_subjectiveFeedback = self.agents[agent_id].USER_GOAL_ACHIEVED - # system_outputs = self.agents[agent_id].sys_output_history - # try: - # prob_history = self.agents[agent_id].action_prob_history - # except: - # prob_history = [] - - # task_id = self.agents[agent_id].taskID - # self.subjectiveFeedbackManager.add_state_action_lists( - # training_utterance, training_state, training_action, training_subjectiveFeedback, training_reward, - # task_id, system_outputs, prob_history) - - if self.saveFlag: - user_id = self.agents[agent_id].agent_saves['user_id'] - suffix = str(user_id) + "-" + str(session_id).split("\t")[0] + '.pkl' - save_path = os.path.join(self.filepath, "dialogues", suffix) - - try: - task_id = self.agents[agent_id].agent_saves['task_id'] - dialogue = self.agents[agent_id].agent_saves["dialogue_info_fundamental"] - objective_performance = self.get_objective_performance(task_id, dialogue) - self.agents[agent_id].agent_saves['performance'] = objective_performance - print("OBJECTIVE PERFORMANCE:", objective_performance) - except Exception as e: - print(f"Could not calculate objective performance: {e}") - - torch.save(self.agents[agent_id].agent_saves, save_path) - # 4. bye bye, agent : ( - self.kill_agent(agent_id) - - def agent2session(self, agent_id): - ''' - Gets str describing session_id agent is currently on - - :param agent_id: agent id - :type agent_id: string - - :return: string -- the session id - ''' - return self.agents[agent_id].session_id - - def retrieve_agent(self, session_id): - ''' - Returns str describing agent_id. - - :param session_id: session_id - :type session_id: string - - :return: string -- the agent id - ''' - if session_id not in list(self.session2agent.keys()): - logger.error( - 'Attempted to get an agent for unknown session %s' % session_id) - return "" - return self.session2agent[session_id] - - def new_agent(self): - ''' - Creates a new agent to handle some concurrency. - Here deepcopy is used to create clean copy rather than referencing, - leaving it in a clean state to commence a new call. - - :return: string -- the agent id - ''' - - agent_id = 'Smith' + str(self.number_agents_total) - self.number_agents_total += 1 - - # This method has efficiency issue. Loading BERT NLU takes too long which will raise errors for socket. - # Could also just do a deepcopy of everything here and not setting policy to None in the init - self.agents[agent_id] = copy.deepcopy(self.template_agent_instances) - self.agents[agent_id].policy = self.policy - self.agents[agent_id].nlu = self.nlu - self.agents[agent_id].nlg = self.nlg - - self.agents[agent_id].dst.init_session() - - if self.subjectiveFeedbackEnabled: - self.agents[agent_id].policy = self.subjectiveFeedbackManager.getUpdatedPolicy( - ) - #logger.info('Agent {} has been created.'.format(agent_id)) - - if len(self.agents) >= self.maxNumberAgent: - self.kill_inactive_agent() - - logging.info( - f"Created new agent {agent_id}. We now have {len(self.agents)} agents in total.") - - return agent_id - - def kill_agent(self, agent_id): - ''' - - :param agent_id: agent id - :type agent_id: string - - :return: None - ''' - - del self.agents[agent_id] - torch.cuda.empty_cache() - - def power_down_factory(self): - ''' - Finalise agents, print the evaluation summary and save the policy we close dialogue server. - - :return: None - ''' - - for agent_id in list(self.agents.keys()): - logger.info('Summary of agent: %s' % agent_id) - logger.info('Factory handled these sessions: %s' % - self.historical_sessions) - - def kill_inactive_agent(self): - ''' - Kill inactive agent in the agents list if there are too many agents running. - ''' - con = 0 - for a_id in list(self.agents.keys()): - if self.agents[a_id].is_inactive(): - - session_id = self.agents[a_id].session_id - - del self.session2agent[session_id] - self.kill_agent(a_id) - con += 1 - logger.info('%s of agents are killed.' % con) - - def get_objective_performance(self, task_id, dialogue): - - goal = self.task_id_to_goal(task_id) - evaluator = MultiWozEvaluator() - - user_acts = [] - system_acts = [] - belief_states = [] - - for turn in dialogue: - system_acts.append(turn['output_action']) - user_acts.append(turn['state']['user_action']) - belief_states.append(turn['state']['belief_state']) - performance_dict = evaluator.evaluate_dialog(goal, user_acts, system_acts, belief_states) - - return performance_dict - - def task_id_to_goal(self, task_id): - goal = None - for task in self.tasks: - if task_id == task['taskID']: - goal = task - break - return _process_goal(goal) - - -class NumpyEncoder(json.JSONEncoder): - """ Special json encoder for numpy types """ - def default(self, obj): - if isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - return json.JSONEncoder.default(self, obj) -# END OF FILE diff --git a/convlab/dialcrowd_server/DialcrowdTaskGen.py b/convlab/dialcrowd_server/DialcrowdTaskGen.py deleted file mode 100644 index a53f9f4e54c9af0777e4dffa52bdfe7efafa0a37..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/DialcrowdTaskGen.py +++ /dev/null @@ -1,267 +0,0 @@ -''' -DialcrowdTaskGen.py - Generating tasks for Dialcrowd. -========================================================================== - -Only work for MultiWoz domain. - -@author: Songbo - -''' - -import json -import argparse -from uuid import uuid1 as uuid -from convlab2.task.multiwoz.goal_generator import GoalGenerator -from convlab2.policy.tus.multiwoz.Goal import Goal - - -def timeline_task(goal, task_id, goal_generator): - goal["taskID"] = task_id - # TODO need to modify the Dialcrowd server backend - goal["taskType"] = "convlab2" - goal["taskMessage"] = goal_generator.build_message(goal) - return goal - - -def multiwoz_task(goal, task_id): - task_info = {} - task_info["taskID"] = task_id - task_info["tasks"] = [] - for domain in goal: - task = { - 'reqs': goal[domain].get('reqt', []), - } - reqs = goal[domain].get('reqt', []) - - for slot_type in ['info', 'book']: - task[slot_type] = slot_str(goal, domain, slot_type) - - task_info["tasks"].append({ - "Dom": domain.capitalize(), - "Cons": ", ".join(task['info']), - "Book": ", ".join(task['book']), - "Reqs": ", ".join(task['reqs'])}) - - return task_info - - -def normal_task(goal, task_id): - task_info = {} - task_info["taskID"] = task_id - task_info["tasks"] = [] - for domain in goal["domain_ordering"]: - task = { - 'reqs': goal[domain].get('reqt', []), - } - reqs = goal[domain].get('reqt', []) - - for slot_type in ['info', 'book']: - task[slot_type] = slot_str(goal, domain, slot_type) - - task_info["tasks"].append({ - "Dom": domain.capitalize(), - "Cons": ", ".join(task['info']), - "Book": ", ".join(task['book']), - "Reqs": ", ".join(task['reqs'])}) - - return task_info - - -def slot_str(goal, domain, slot_type): - slots = [] - for slot in goal[domain].get(slot_type, []): - value = info_str(goal, domain, slot, slot_type) - slots.append(f"{slot}={value}") - return slots - - -def check_slot_length(task_info, max_slot_len=5): - for task in task_info["tasks"]: - if len(task["Cons"].split(", ")) > max_slot_len: - print(f'SLOT LENGTH ERROR: {len(task["Cons"].split(", "))}') - return False - return True - - -def check_domain_number(task_info, min_domain_len=0, max_domain_len=3): - if len(task_info["tasks"]) > min_domain_len and len(task_info["tasks"]) <= max_domain_len: - return True - print(f'DOMAIN NUMBER ERROR: {len(task_info["tasks"])}') - return False - - -def info_str(goal, domain, slot, slot_type): - if slot_type == 'info': - fail_type = 'fail_info' - elif slot_type == 'book': - fail_type = 'book_again' - # print(goal, domain, slot, slot_type) - value = goal[domain][slot_type][slot] - if fail_type not in goal[domain]: - return value - else: - fail_info = goal[domain][fail_type].get(slot, "") - if fail_info and fail_info != value: - return f"{fail_info} (if unavailable use: {value})" - else: - return value - - -def write_task(task_size, task_type, out_file, test_file=None): - goal_generator = GoalGenerator(boldify=True) - if test_file: - test_list = [task_id for task_id in test_file] - con = 0 - output = [] - while len(output) < task_size: - # try: - if test_file: - goal = Goal(goal=test_file[test_list[len(output)]]["goal"]) - task_info = multiwoz_task(goal.domain_goals, len(output) + 10000) - - else: - goal = goal_generator.get_user_goal() - if 'police' in goal['domain_ordering']: - no_police = list(goal['domain_ordering']) - no_police.remove('police') - goal['domain_ordering'] = tuple(no_police) - del goal['police'] - - if task_type == "convlab2": - task_info = timeline_task( - goal, len(output) + 10000, goal_generator) - elif task_type == "normal": - # task_info = normal_task(goal, len(output) + 10000) - task_info = normal_task(goal, str(uuid())) - else: - print("unseen task type. No goal is created.") - - if check_domain_number(task_info) and check_slot_length(task_info): - output.append(json.dumps(task_info)) - - # except Exception as e: - # print(goal) - # con += 1 - - print("{} exceptions in total." .format(con)) - f = open(out_file, "w") - f.write("\n".join(output)) - f.close() - - -def write_task_single_domain(task_size, task_type, out_file, test_file=None): - goal_generator = GoalGenerator(boldify=True) - if test_file: - test_list = [task_id for task_id in test_file] - con = 0 - output = [] - goals_per_domain = dict() - # we want task_size / 5 many goals for each of the 5 domains we have - goals_needed = task_size / 5 - while len(output) < task_size: - # try: - if test_file: - goal = Goal(goal=test_file[test_list[len(output)]]["goal"]) - task_info = multiwoz_task(goal.domain_goals, len(output) + 10000) - - else: - goal = goal_generator.get_user_goal() - if 'police' in goal['domain_ordering']: - no_police = list(goal['domain_ordering']) - no_police.remove('police') - goal['domain_ordering'] = tuple(no_police) - del goal['police'] - if 'hospital' in goal['domain_ordering']: - no_police = list(goal['domain_ordering']) - no_police.remove('hospital') - goal['domain_ordering'] = tuple(no_police) - del goal['hospital'] - - # make sure we only get single domain goals - num_goals = len(goal['domain_ordering']) - - if num_goals == 0: - continue - - while num_goals > 1: - domain_removed = list(goal['domain_ordering'])[-1] - goal['domain_ordering'] = tuple(list(goal['domain_ordering'])[:-1]) - del goal[domain_removed] - num_goals = len(goal['domain_ordering']) - - domain = goal['domain_ordering'] - - type_request = False - for key in goal: - if key == "domain_ordering": - continue - if 'reqt' in goal[key]: - if 'type' in goal[key]['reqt']: - type_request = True - break - if type_request: - continue - - # if we have enough domains, continue to search - if goals_per_domain.get(domain, 0) >= goals_needed: - continue - - if domain not in goals_per_domain: - goals_per_domain[domain] = 1 - else: - goals_per_domain[domain] += 1 - - if task_type == "convlab2": - task_info = timeline_task( - goal, len(output) + 10000, goal_generator) - elif task_type == "normal": - task_info = normal_task(goal, str(uuid())) - else: - print("unseen task type. No goal is created.") - - if check_domain_number(task_info) and check_slot_length(task_info): - output.append(json.dumps(task_info)) - - # except Exception as e: - # print(goal) - # con += 1 - - print("{} exceptions in total." .format(con)) - f = open(out_file, "w") - f.write("\n".join(output)) - f.close() - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--num-task', default=400, type=int, - help="How many tasks would you like in you task list.") - parser.add_argument('--task-type', default="normal", type=str, - help="the format of task type, 'normal' or 'timeline'.") - parser.add_argument('--out-file', default="task.out", type=str, - help="the output file name") - parser.add_argument('--test-file', default="", type=str, - help="the multiwoz test file") - parser.add_argument('--single-domains', action='store_true', - help="if it should generate uniform distribution over single domain goals") - - args = parser.parse_args() - test_file = None - num_task = args.num_task - if args.test_file: - test_file = json.load(open(args.test_file)) - num_task = len(test_file) - print(f"use test file, length={num_task}") - if not args.single_domains: - write_task(num_task, args.task_type, args.out_file, test_file) - else: - write_task_single_domain(num_task, args.task_type, args.out_file, test_file) - - -if __name__ == '__main__': - - # How many tasks would you like in you task list. - main() - - -# END OF FILE diff --git a/convlab/dialcrowd_server/DialogueServer.py b/convlab/dialcrowd_server/DialogueServer.py deleted file mode 100644 index 240e32cb2ab058f7fa69ef0f4cf6c06b682d4f39..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/DialogueServer.py +++ /dev/null @@ -1,279 +0,0 @@ -''' -DialogueServer.py - Web interface for DialCrowd. -========================================================================== - -@author: Songbo - -''' - -from configparser import ConfigParser -from convlab2.dialcrowd_server.AgentFactory import AgentFactory -from datetime import datetime -import json -import http.server -import os -import sys -import shutil -from argparse import ArgumentParser -import logging - -from convlab2.util.custom_util import init_logging - -project_path = (os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))) -sys.path.append(project_path) - - -# ================================================================================================ -# SERVER BEHAVIOUR -# ================================================================================================ - - -def make_request_handler_class(dialServer): - """ - """ - - class RequestHandler(http.server.BaseHTTPRequestHandler): - ''' - Process HTTP Requests - :return: - ''' - - def do_POST(self): - ''' - Handle only POST requests. Please note that GET requests ARE NOT SUPPORTED! :) - ''' - self.error_free = True # boolean which can become False if we encounter a problem - - agent_id = None - self.currentSession = None - self.currentUser = None - prompt_str = '' - reply = {} - - # ------Get the "Request" link. - logging.info('-' * 30) - - request = self.path[1:] if self.path.find( - '?') < 0 else self.path[1:self.path.find('?')] - - logging.info(f'Request: ' + str(request)) - logging.info(f'POST full path: {self.path}') - - if not 'Content-Length' in self.headers: - data_string = self.path[self.path.find('?') + 1:] - else: - data_string = self.rfile.read( - int(self.headers['Content-Length'])) - - # contains e.g: {"session": "voip-5595158237"} - logging.info("Request Data:" + str(data_string)) - - recognition_fail = True # default until we confirm we have received data - try: - data = json.loads(data_string) # ValueError - self.currentSession = data["sessionID"] # KeyError - self.currentUser = data["userID"] # KeyError - except Exception as e: - logging.info(f"Not a valid JSON object (or object lacking info) received. {e}") - else: - recognition_fail = False - - if request == 'init': - try: - user_id = data.get("userID", None) - task_id = data.get("taskID", None) - agent_id = dialServer.agent_factory.start_call( - session_id=self.currentSession, user_id=user_id, task_id=task_id) - reply = dialServer.prompt( - dialServer.agent_factory.willkommen_message, session_id=self.currentSession) - except Exception as e: - self.error_free = False - logging.info(f"COULD NOT INIT SESSION, EXCEPTION: {e}") - else: - logging.info(f"A new call has started. Session: {self.currentSession}") - - elif request == 'next': - - # Next step in the conversation flow - # map session_id to agent_id - - try: - agent_id = dialServer.agent_factory.retrieve_agent( - session_id=self.currentSession) - except Exception as e: # Throws a ExceptionRaisedByLogger - self.error_free = False - logging.info(f"NEXT: tried to retrieve agent but: {e}") - else: - logging.info(f"Continuing session: {self.currentSession} with agent_id {agent_id}") - if self.error_free: - try: - userUtterance = data["text"] # KeyError - user_id = data["userID"] - task_id = data.get("taskID", None) - logging.info(f"Received user utterance {userUtterance}") - prompt_str = dialServer.agent_factory.continue_call( - agent_id, user_id, userUtterance, task_id) - - if(prompt_str == dial_server.agent_factory.ending_message): - reply = dialServer.prompt( - prompt_str, session_id=self.currentSession, isfinal=True) - else: - reply = dialServer.prompt( - prompt_str, session_id=self.currentSession, isfinal=False) - except Exception as e: - logging.info(f"NEXT: tried to continue call but {e}") - else: - reply = None - - elif request == 'end': - - # Request to stop the session. - - logging.info("Received request to Clean Session ID from the VoiceBroker...:" + self.currentSession) - - self.error_free = False - - try: - agent_id = dialServer.agent_factory.end_call( - session_id=self.currentSession) - except Exception as e: # an ExceptionRaisedByLogger - logging.info(f"END: tried to end call but: {e}") - - # ------ Completed turn -------------- - - # POST THE REPLY BACK TO THE SPEECH SYSTEM - logging.info("Sending prompt: " + prompt_str + " to tts.") - self.send_response(200) # 200=OK W3C HTTP Standard codes - self.send_header('Content-type', 'text/json') - self.end_headers() - logging.info(reply) - try: - self.wfile.write(reply.encode('utf-8')) - except Exception as e: - logging.info(f"wanted wo wfile.write but {e}") - reply = dialServer.prompt(f"Error: I am sorry, we are working on this.", session_id=self.currentSession, isfinal=False) - self.wfile.write(reply.encode('utf-8')) - logging.info(dialServer.agent_factory.session2agent) - return RequestHandler - - -# ================================================================================================ -# DIALOGUE SERVER -# ================================================================================================ - -class DialogueServer(object): - - ''' - This class implements an HTTP Server - ''' - - def __init__(self, configPath): - """ HTTP Server - """ - - configparser = ConfigParser() - configparser.read(configPath) - host = (configparser.get("GENERAL", "host")) - task_file = (configparser.get("GENERAL", "task_file")) - port = int(configparser.get("GENERAL", "port")) - agentPath = (configparser.get("AGENT", "agentPath")) - agentClass = (configparser.get("AGENT", "agentClass")) - dialogueSave = (configparser.get("AGENT", "dialogueSave")) - saveFlag = True - - if configparser.get("AGENT", "saveFlag") == "True": - saveFlag = True - - mod = __import__(agentPath, fromlist=[agentClass]) - klass = getattr(mod, agentClass) - self.host = host - self.port = port - self.agent_factory = AgentFactory(configPath, dialogueSave, saveFlag, task_file) - - shutil.copy(configPath, self.agent_factory.filepath) - shutil.copy(agentPath.replace(".", "/") + ".py", self.agent_factory.filepath) - - logging.info("Server init") - - def run(self): - """Listen to request in host dialhost and port dialport""" - - RequestHandlerClass = make_request_handler_class(self) - - server = http.server.HTTPServer( - (self.host, self.port), RequestHandlerClass) - logging.info(f'Server starting {self.host}:{self.port} (level=info)') - - try: - while 1: - server.serve_forever() - except KeyboardInterrupt: - pass - finally: - logging.info(f'Server stopping {self.host}:{self.port}') - server.server_close() - - self.agent_factory.power_down_factory() - - def prompt(self, prompt, session_id, isfinal=False): - ''' - Create a prompt, for the moment the arguments are - - :param prompt: the text to be prompt - :param isfinal: if it is the final sentence before the end of dialogue - :return: reply in json - ''' - - reply = {} - reply["sessionID"] = session_id - reply["version"] = "0.1" - reply["terminal"] = isfinal - reply['sys'] = self._clean_text(prompt) - reply["timeStamp"] = datetime.now().isoformat() - - logging.info(reply) - return json.dumps(reply, ensure_ascii=False) - - def _clean_text(self, RAW_TEXT): - """ - """ - # The replace() is because of how words with ' come out of the Template SemO. - JUNK_CHARS = ['(', ')', '{', '}', '<', '>', '"', "'"] - return ''.join(c for c in RAW_TEXT.replace("' ", "") if c not in JUNK_CHARS) - - -def save_log_to_file(): - import time - - dir_name = os.path.join(os.path.dirname( - os.path.abspath(__file__)), 'dialogueServer_LOG') - if not os.path.exists(dir_name): - os.makedirs(dir_name) - current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - output_file = open(os.path.join( - dir_name, f"stdout_{current_time}.txt"), 'w') - sys.stdout = output_file - sys.stderr = output_file - - -# ================================================================================================ -# MAIN FUNCTION -# ================================================================================================ -if __name__ == "__main__": - - logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \ - init_logging(os.path.dirname(os.path.abspath(__file__)), 'info') - - parser = ArgumentParser() - parser.add_argument("--config", type=str, default="./convlab2/dialcrowd_server/dialogueServer.cfg", - help="path of server config file to load") - args = parser.parse_args() - - # save_log_to_file() - logging.info(f"Config-file being used: {args.config}") - - dial_server = DialogueServer(args.config) - dial_server.run() - -# END OF FILE diff --git a/convlab/dialcrowd_server/Goal.py b/convlab/dialcrowd_server/Goal.py deleted file mode 100644 index 87ae7ee2ee86e2efe9c922cb79edcd40a2a7cef2..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/Goal.py +++ /dev/null @@ -1,340 +0,0 @@ -import json -import os - -from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator -from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal -from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA - -# import reflect table -REF_SYS_DA_M = {} -for dom, ref_slots in REF_SYS_DA.items(): - dom = dom.lower() - REF_SYS_DA_M[dom] = {} - for slot_a, slot_b in ref_slots.items(): - if slot_a == 'Ref': - slot_b = 'ref' - REF_SYS_DA_M[dom][slot_a.lower()] = slot_b - REF_SYS_DA_M[dom]['none'] = 'none' -REF_SYS_DA_M['taxi']['phone'] = 'phone' -REF_SYS_DA_M['taxi']['car'] = 'car type' - -# Goal slot mapping table -mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'price': 'pricerange'}, - 'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name', - 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'}, - 'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'type': 'type'}, - 'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination', - 'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'}, - 'taxi': {'car': 'car type', 'phone': 'phone'}, - 'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'}, - 'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}} - -DEF_VAL_UNK = '?' # Unknown -DEF_VAL_DNC = 'dontcare' # Do not care -DEF_VAL_NUL = 'none' # for none -DEF_VAL_BOOKED = 'yes' # for booked -DEF_VAL_NOBOOK = 'no' # for booked -NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK, ""] - -ref_slot_data2stand = { - 'train': { - 'duration': 'time', 'price': 'ticket', 'trainid': 'id' - } -} - - -class Goal(object): - """ User Goal Model Class. """ - - def __init__(self, goal): - self.domain_goals = _process_goal(goal) - self.domains = [d for d in self.domain_goals] - - path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - path = os.path.join(path, 'data/multiwoz/all_value.json') - self.all_values = json.load(open(path)) - - self.init_info_record() - self.actions = None - self.evaluator = MultiWozEvaluator() - self.evaluator.add_goal(self.domain_goals) - self.cur_domain = None - - def init_info_record(self): - self.info = {} - for domain in self.domains: - if 'info' in self.domain_goals[domain].keys(): - self.info[domain] = {} - for slot in self.domain_goals[domain]['info']: - self.info[domain][slot] = DEF_VAL_NUL - - def add_sys_da(self, sys_act, belief_state): - self.evaluator.add_sys_da(sys_act, belief_state) - self.update_user_goal(sys_act, belief_state) - - def add_usr_da(self, usr_act): - self.evaluator.add_usr_da(usr_act) - - usr_domain = [d for i, d, s, v in usr_act][0] if usr_act else self.cur_domain - usr_domain = usr_domain if usr_domain else 'general' - self.cur_domain = usr_domain if usr_domain.lower() not in ['general', 'booking'] else self.cur_domain - - def task_complete(self): - """ - Check that all requests have been met - Returns: - (boolean): True to accomplish. - """ - if self.evaluator.success == 1: - return True - for domain in self.domains: - if 'reqt' in self.domain_goals[domain]: - reqt_vals = self.domain_goals[domain]['reqt'].values() - for val in reqt_vals: - if val in NOT_SURE_VALS: - return False - if 'booked' in self.domain_goals[domain]: - if self.domain_goals[domain]['booked'] in NOT_SURE_VALS: - return False - return True - - def __str__(self): - return '-----Goal-----\n' + \ - json.dumps(self.domain_goals, indent=4) + \ - '\n-----Goal-----' - - def get_booking_domain(self, slot, value, all_values): - for domain in self.domains: - if slot in all_values["all_value"] and value in all_values["all_value"][slot]: - return domain - print("NOT FOUND BOOKING DOMAIN") - return "" - - def update_user_goal(self, action=None, state=None): - # update request and booked - if action: - self._update_user_goal_from_action(action) - if state: - self._update_user_goal_from_state(state) - self._check_booked(state) # this should always check - - if action is None and state is None: - print("Warning!!!! Both action and state are None") - - def _check_booked(self, state): - for domain in self.domains: - if "booked" in self.domain_goals[domain]: - if self._check_book_info(state, domain): - self.domain_goals[domain]["booked"] = DEF_VAL_BOOKED - else: - self.domain_goals[domain]["booked"] = DEF_VAL_NOBOOK - - def _check_book_info(self, state, domain): - # need to check info, reqt for booked? - if domain not in state: - return False - - for slot_type in ['info', 'book']: - for slot in self.domain_goals[domain].get(slot_type, {}): - user_value = self.domain_goals[domain][slot_type][slot] - if slot in state[domain]["semi"]: - state_value = state[domain]["semi"][slot] - - elif slot in state[domain]["book"]: - state_value = state[domain]["book"][slot] - else: - state_value = "" - # only check mentioned values (?) - if state_value and state_value != user_value: - # print( - # f"booking info is incorrect, for slot {slot}: " - # f"goal {user_value} != state {state_value}") - return False - - return True - - def _update_user_goal_from_action(self, action): - for intent, domain, slot, value in action: - # print("update user goal from action") - # print(intent, domain, slot, value) - # print("action:", intent) - domain = domain.lower() - value = value.lower() - slot = slot.lower() - if slot == "ref": # TODO ref!!!! not bug free!!!! - for usr_domain in self.domains: - if "booked" in self.domain_goals[usr_domain]: - self.domain_goals[usr_domain]["booked"] = DEF_VAL_BOOKED - else: - domain, slot = self._norm_domain_slot(domain, slot, value) - - if self._check_update_request(domain, slot) and value != "?": - self.domain_goals[domain]['reqt'][slot] = value - # print(f"update reqt {slot} = {value} from system action") - - if intent.lower() == 'inform': - if domain.lower() in self.domain_goals: - if 'reqt' in self.domain_goals[domain.lower()]: - if REF_SYS_DA_M.get(domain, {}).get(slot, slot) in self.domain_goals[domain]['reqt']: - if value in NOT_SURE_VALS: - value = '\"' + value + '\"' - self.domain_goals[domain]['reqt'][REF_SYS_DA_M.get(domain, {}).get(slot, slot)] = value - - if domain not in ['general', 'booking']: - self.cur_domain = domain - - if domain and intent and slot: - dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}' - else: - dial_act = '' - - if dial_act == 'booking-book-ref' and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']: - if self.cur_domain in self.domain_goals and 'booked' in self.domain_goals[self.cur_domain.lower()]: - self.domain_goals[self.cur_domain.lower()]['booked'] = DEF_VAL_BOOKED - elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref': - if 'train' in self.domain_goals and 'booked' in self.domain_goals['train']: - self.domain_goals['train']['booked'] = DEF_VAL_BOOKED - elif dial_act == 'taxi-inform-car': - if 'taxi' in self.domain_goals and 'booked' in self.domain_goals['taxi']: - self.domain_goals['taxi']['booked'] = DEF_VAL_BOOKED - if intent.lower() in ['book', 'offerbooked'] and self.cur_domain.lower() in self.domain_goals: - if 'booked' in self.domain_goals[self.cur_domain.lower()]: - self.domain_goals[self.cur_domain.lower()]['booked'] = DEF_VAL_BOOKED - - def _norm_domain_slot(self, domain, slot, value): - if domain == "booking": - # ["book", "booking", "people", 7] - if slot in SysDa2Goal[domain]: - slot = SysDa2Goal[domain][slot] - domain = self._get_booking_domain(slot, value) - else: - domain = "" - for d in SysDa2Goal: - if slot in SysDa2Goal[d]: - domain = d - slot = SysDa2Goal[d][slot] - if not domain: # TODO make sure what happened! - return "", "" - return domain, slot - - elif domain in self.domains: - if slot in SysDa2Goal[domain]: - # ["request", "restaurant", "area", "north"] - slot = SysDa2Goal[domain][slot] - elif slot in UsrDa2Goal[domain]: - slot = UsrDa2Goal[domain][slot] - elif slot in SysDa2Goal["booking"]: - # ["inform", "hotel", "stay", 2] - slot = SysDa2Goal["booking"][slot] - # else: - # print( - # f"UNSEEN SLOT IN UPDATE GOAL {intent, domain, slot, value}") - return domain, slot - - else: - # domain = general - return "", "" - - def _update_user_goal_from_state(self, state): - for domain in state: - for slot in state[domain]["semi"]: - if self._check_update_request(domain, slot): - self._update_user_goal_from_semi(state, domain, slot) - for slot in state[domain]["book"]: - if slot == "booked" and state[domain]["book"]["booked"]: - self._update_booked(state, domain) - - elif state[domain]["book"][slot] and self._check_update_request(domain, slot): - self._update_book(state, domain, slot) - - def _update_slot(self, domain, slot, value): - self.domain_goals[domain]['reqt'][slot] = value - - def _update_user_goal_from_semi(self, state, domain, slot): - if self._check_value(state[domain]["semi"][slot]): - self._update_slot(domain, slot, state[domain]["semi"][slot]) - # print("update reqt {} in semi".format(slot), - # state[domain]["semi"][slot]) - - def _update_booked(self, state, domain): - # check state and goal is fulfill - self.domain_goals[domain]["booked"] = DEF_VAL_BOOKED - print("booked") - for booked_slot in state[domain]["book"]["booked"][0]: - if self._check_update_request(domain, booked_slot): - self._update_slot(domain, booked_slot, - state[domain]["book"]["booked"][0][booked_slot]) - # print("update reqt {} in booked".format(booked_slot), - # self.domain_goals[domain]['reqt'][booked_slot]) - - def _update_book(self, state, domain, slot): - if self._check_value(state[domain]["book"][slot]): - self._update_slot(domain, slot, state[domain]["book"][slot]) - # print("update reqt {} in book".format(slot), - # state[domain]["book"][slot]) - - def _check_update_request(self, domain, slot): - # check whether one slot is a request slot - if domain not in self.domain_goals: - return False - if 'reqt' not in self.domain_goals[domain]: - return False - if slot not in self.domain_goals[domain]['reqt']: - return False - return True - - def _check_value(self, value=None): - if not value: - return False - if value in NOT_SURE_VALS: - return False - return True - - def _get_booking_domain(self, slot, value): - """ - find the domain for domain booking, excluding slot "ref" - """ - found = "" - if not slot: # work around - return found - slot = slot.lower() - value = value.lower() - for domain in self.all_values["all_value"]: - if slot in self.all_values["all_value"][domain]: - if value in self.all_values["all_value"][domain][slot]: - if domain in self.domains: - found = domain - return found - - -def _process_goal(tasks): - goal = {} - for task in tasks['tasks']: - goal[task['Dom'].lower()] = {} - if task['Book']: - goal[task['Dom'].lower()]['booked'] = DEF_VAL_UNK - goal[task['Dom'].lower()]['book'] = {} - for con in task['Book'].split(', '): - slot, val = con.split('=', 1) - slot = mapping[task['Dom'].lower()].get(slot, slot) - goal[task['Dom'].lower()]['book'][slot] = val - if task['Cons']: - goal[task['Dom'].lower()]['info'] = {} - goal[task['Dom'].lower()]['fail_info'] = {} - for con in task['Cons'].split(', '): - slot, val = con.split('=', 1) - slot = mapping[task['Dom'].lower()].get(slot, slot) - if " (otherwise " in val: - value = val.split(" (if unavailable use: ") - goal[task['Dom'].lower()]['fail_info'][slot] = value[0] - goal[task['Dom'].lower()]['info'][slot] = value[1][:-1] - else: - goal[task['Dom'].lower()]['info'][slot] = val - - if task['Reqs']: - goal[task['Dom'].lower()]['reqt'] = {mapping[task['Dom'].lower()].get(s, s): DEF_VAL_UNK for s in - task['Reqs'].split(', ')} - - return goal \ No newline at end of file diff --git a/convlab/dialcrowd_server/SubjectiveFeedbackManager.py b/convlab/dialcrowd_server/SubjectiveFeedbackManager.py deleted file mode 100644 index 110762b46b90f65fa66c217e6c63daff9c02cc77..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/SubjectiveFeedbackManager.py +++ /dev/null @@ -1,224 +0,0 @@ -''' -SubjectiveFeedbackManager.py - Update policy according to subjective feedback -========================================================================== - -@author: Songbo, Chris - -''' - -from configparser import ConfigParser -import torch -import pickle -import os -import time -import logging - -# from convlab2.util.train_util import save_to_bucket - - -class SubjectiveFeedbackManager(object): - - def __init__(self, configPath, policy, agent_name=""): - self.sys_dialogue_utterance = [] - self.sys_dialogue_state_vec = [] - self.sys_action_mask_vec = [] - self.sys_dialogue_act_vec = [] - self.sys_dialogue_reward_vec = [] - self.sys_dialogue_mask_vec = [] - self.agent_name = agent_name - configparser = ConfigParser() - configparser.read(configPath) - self.turn_reward = int(configparser.get("SUBJECTIVE", "turnReward")) - self.subject_reward = int( - configparser.get("SUBJECTIVE", "subjectReward")) - self.updatePerSession = int( - configparser.get("SUBJECTIVE", "updatePerSession")) - self.memory = Memory() - - # All policy update is done by this instances. - self.policy = policy - self.add_counter = 0 - self.add_counter_total = 0 - - self.trainingEpoch = int( - configparser.get("SUBJECTIVE", "trainingEpoch")) - - current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) - self.save_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - f"policy/AMT/AMT_REAL_{agent_name}_{current_time}") - os.makedirs(self.save_dir, exist_ok=True) - - def init_list(self): - self.sys_dialogue_utterance = [] - self.sys_dialogue_state_vec = [] - self.sys_dialogue_act_vec = [] - self.sys_dialogue_reward_vec = [] - self.sys_dialogue_mask_vec = [] - self.sys_action_mask_vec = [] - self.add_counter = 0 - - def get_reward_vector(self, state_list, act_list, isGoalAchieved): - assert len(state_list) == len(act_list) - reward_vector = [] - for i in range(1, len(state_list)): - reward_vector.append(self.turn_reward) - if isGoalAchieved: - reward_vector.append(2 * self.subject_reward) - else: - reward_vector.append(-self.subject_reward) - return reward_vector - - def get_mask_vector(self, state_list, act_list): - assert len(state_list) == len(act_list) - mask_vector = [] - for i in range(1, len(state_list)): - mask_vector.append(1) - mask_vector.append(0) - return mask_vector - - def add_state_action_lists(self, utterance_list, state_list, act_list, isGoalAchieved, reward_list, task_id, - system_outputs, prob_history): - assert len(state_list) == len(act_list) and isGoalAchieved in [True, False] - - self.sys_dialogue_utterance.extend(utterance_list) - - state_list_vec = [] - action_mask_vec = [] - for s in state_list: - s_vec, mask_ = self.policy.vector.state_vectorize( - s, output_mask=True) - mask_ = mask_ + [0, 0] - state_list_vec.append(s_vec) - action_mask_vec.append(mask_) - - self.sys_dialogue_state_vec.extend(state_list_vec) - self.sys_action_mask_vec.extend(action_mask_vec) - - reward_list_new = self.get_reward_vector( - state_list, act_list, isGoalAchieved) - - try: - action_list_vec = list( - map(self.policy.vector.action_vectorize, act_list)) - self.sys_dialogue_act_vec.extend(action_list_vec) - except: - # we assume the acts are already action_vectorized - self.sys_dialogue_act_vec.extend(act_list) - - # TODO: Change the reward here!! - self.sys_dialogue_reward_vec.extend(reward_list_new) - - self.sys_dialogue_mask_vec.extend( - self.get_mask_vector(state_list, act_list)) - self.add_counter += 1 - self.add_counter_total += 1 - - logging.info( - f"Added dialog, we now have {self.add_counter} dialogs in total.") - - try: - if hasattr(self.policy, "last_action"): - if len(state_list_vec) == 0 or len(act_list) == 0 or len(reward_list) == 0: - pass - else: - self.policy.update_memory( - utterance_list, state_list_vec, act_list, reward_list_new) - self.memory.add_experience(utterance_list, state_list, state_list_vec, act_list, reward_list, - isGoalAchieved, task_id, system_outputs, prob_history) - else: - self.policy.update_memory( - utterance_list, state_list_vec, action_list_vec, reward_list_new) - self.memory.add_experience(utterance_list, state_list, state_list_vec, action_list_vec, reward_list, - isGoalAchieved, task_id, system_outputs, prob_history) - except: - pass - print("Session Added to FeedbackManager {}".format(self.add_counter)) - # if(self.add_counter % self.updatePerSession == 0 and self.add_counter > 400): - if (self.add_counter % self.updatePerSession == 0 and self.add_counter > 0): - - logging.info("Manager updating policy.") - try: - self.updatePolicy() - logging.info("Successfully updated policy.") - except Exception as e: - logging.info("Couldnt update policy. Exception: ", e) - - logging.info("Saving AMT memory") - self.memory.save(self.save_dir) - try: - self.save_into_bucket() - except: - print("SubjectiveFeedbackManager: Could not save into bucket") - - def updatePolicy(self): - try: - train_state_list = torch.Tensor(self.sys_dialogue_state_vec) - train_act_list = torch.Tensor(self.sys_dialogue_act_vec) - train_reward_list = torch.Tensor(self.sys_dialogue_reward_vec) - train_mask_list = torch.Tensor(self.sys_dialogue_mask_vec) - train_action_mask_list = torch.Tensor(self.sys_action_mask_vec) - batchsz = train_state_list.size()[0] - except: - train_state_list = (self.sys_dialogue_state_vec) - train_act_list = (self.sys_dialogue_act_vec) - train_reward_list = (self.sys_dialogue_reward_vec) - train_mask_list = (self.sys_dialogue_mask_vec) - train_action_mask_list = (self.sys_action_mask_vec) - batchsz = 32 - - for i in range(self.trainingEpoch): - # print(train_state_list) - # print(train_action_mask_list) - # print(train_act_list) - # print(train_reward_list) - self.policy.update(i, batchsz, train_state_list, train_act_list, train_reward_list, train_mask_list, - train_action_mask_list) - if self.policy.is_train: - self.policy.save(self.save_dir) - - if self.add_counter_total % 200 == 0: - self.policy.save( - self.save_dir, addition=f"_{self.add_counter}") - - # Empty the current batch. This is needed for on-policy algorithms like PPO. - self.init_list() - - def getUpdatedPolicy(self): - return self.policy - - def save_into_bucket(self): - print(f"Saving into bucket from {self.save_dir}") - bucket_save_name = self.save_dir.split('/')[-1] - for file in os.listdir(self.save_dir): - save_to_bucket('geishauser', f'AMT_Experiments/{bucket_save_name}/{file}', - os.path.join(self.save_dir, file)) - - -class Memory: - - def __init__(self): - self.utterances = [] - self.raw_states = [] - self.states = [] - self.actions = [] - self.rewards = [] - self.feedback = [] - self.task_id = [] - self.sys_outputs = [] - self.action_probs = [] - - def add_experience(self, utterances, raw_states, states, actions, rewards, feedback, task_id, system_outputs, - prob_history): - self.utterances.append(utterances) - self.raw_states.append(raw_states) - self.states.append(states) - self.actions.append(actions) - self.rewards.append(rewards) - self.feedback.append(feedback) - self.task_id.append(task_id) - self.sys_outputs.append(system_outputs) - self.action_probs.append(prob_history) - - def save(self, directory): - with open(directory + '/' + 'AMT_memory.pkl', 'wb') as output: - pickle.dump(self, output, pickle.HIGHEST_PROTOCOL) diff --git a/convlab/dialcrowd_server/__init__.py b/convlab/dialcrowd_server/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/convlab/dialcrowd_server/agents/base_agent.py b/convlab/dialcrowd_server/agents/base_agent.py deleted file mode 100644 index a3d53ef4638ebe6dfc081f9878b1ebf9f6465a94..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/agents/base_agent.py +++ /dev/null @@ -1,19 +0,0 @@ -''' - -Build up an pipeline agent with nlu, dst, policy and nlg. - -@author: Chris Geishauser -''' - -from convlab2.dialog_agent.agent import DialogueAgent - - -class BaseAgent(DialogueAgent): - - def __init__(self, config, policy_sys): - - nlu = config['nlu_sys_activated'] - dst = config['dst_sys_activated'] - nlg = config['sys_nlg_activated'] - - super().__init__(nlu, dst, policy_sys, nlg) diff --git a/convlab/dialcrowd_server/agents/ddpt_agent.cfg b/convlab/dialcrowd_server/agents/ddpt_agent.cfg deleted file mode 100644 index 6bf179e4c049322c41d346a80b1f7fcfa727f4fc..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/agents/ddpt_agent.cfg +++ /dev/null @@ -1,19 +0,0 @@ -[GENERAL] -host = 0.0.0.0 -port = 5001 -task_file = task.out - -[AGENT] -agentPath = convlab2.dialcrowd_server.agents.ddpt_agent -agentClass = Agent -maxTurn = 40 -maxNumberAgent = 20 -dialogueSave = "" -saveFlag = True - -[SUBJECTIVE] -enabled = False -turnReward = -1 -subjectReward = 40 -updatePerSession = 10 -trainingEpoch = 10 \ No newline at end of file diff --git a/convlab/dialcrowd_server/agents/ddpt_agent.py b/convlab/dialcrowd_server/agents/ddpt_agent.py deleted file mode 100644 index 5796a5cf782f59269a9ebbce8b12d4780a655f89..0000000000000000000000000000000000000000 --- a/convlab/dialcrowd_server/agents/ddpt_agent.py +++ /dev/null @@ -1,25 +0,0 @@ -''' - -Build up an pipeline agent with nlu, dst, policy and nlg. - -@author: Chris Geishauser -''' - -from convlab2.dialcrowd_server.agents.base_agent import BaseAgent -from convlab2.policy.vtrace_DPT import VTRACE -from convlab2.util.custom_util import get_config - - -class DDPTAgent(BaseAgent): - - def __init__(self): - - config_path = "" - conf = get_config(config_path, []) - - policy = VTRACE(vectorizer=conf['vectorizer_sys_activated']) - policy.load(conf['model']['load_path']) - - super().__init__(conf, policy) - - self.agent_name = "DDPT"