Skip to content
Snippets Groups Projects
Select Git revision
  • main
1 result

DialogueServer.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    DialogueServer.py 9.24 KiB
    '''
    DialogueServer.py - Web interface for DialCrowd.
    ==========================================================================
    
    @author: Songbo
    
    '''
    
    from argparse import ArgumentParser
    from configparser import ConfigParser
    from datetime import datetime
    import http.server
    import json
    import os
    import sys

    # Append the repository root (three directories above this file; presumably
    # the directory that contains the ``convlab2`` package) to sys.path BEFORE
    # importing project modules, so the project import below also resolves when
    # the script is run directly and convlab2 is not installed as a package.
    project_path = (os.path.dirname(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__)))))
    sys.path.append(project_path)

    from convlab2.dialcrowd_server.AgentFactory import AgentFactory  # noqa: E402
    
    
    # ================================================================================================
    # SERVER BEHAVIOUR
    # ================================================================================================
    
    
    def make_request_handler_class(dialServer):
        """
        Build an HTTP request-handler class bound to *dialServer*.

        http.server creates one handler instance per request and does not let
        callers pass extra constructor arguments, so the DialogueServer is
        captured in this closure instead.

        :param dialServer: the DialogueServer whose ``agent_factory`` and
            ``prompt()`` are used to answer requests.
        :return: a ``http.server.BaseHTTPRequestHandler`` subclass.
        """

        class RequestHandler(http.server.BaseHTTPRequestHandler):
            '''
                Process DialCrowd HTTP POST requests.

                The endpoint is the path before any '?':
                  * ``init`` - start a call for "sessionID" and reply with the
                    factory's welcome message;
                  * ``next`` - pass "text" to the session's agent and reply with
                    its response (``terminal`` is true for the ending message);
                  * ``end``  - end the call for "sessionID".
                The payload is the request body (or, without a Content-Length
                header, the query string) and must be a JSON object with at
                least "sessionID" and "userID". Every request is answered with
                status 200 and a JSON body; ``{}`` when there is no reply.
            '''

            def do_POST(self):
                '''
                Handle only POST requests. Please note that GET requests ARE NOT SUPPORTED! :)
                '''
                self.error_free = True  # boolean which can become False if we encounter a problem
                self.currentSession = None
                self.currentUser = None

                # ------Get the "Request" link.
                print('-' * 30)
                query_start = self.path.find('?')
                request = self.path[1:] if query_start < 0 else self.path[1:query_start]
                print('Request: ' + str(request))
                print('POST full path: %s ' % self.path)

                data = self._read_request_data()

                prompt_str = ''
                reply = None  # None means "no reply": an empty JSON object is sent
                if request == 'init':
                    reply = self._handle_init()
                elif request == 'next':
                    prompt_str, reply = self._handle_next(data)
                elif request == 'end':
                    self._handle_end()

                # ------ Completed turn --------------
                # POST THE REPLY BACK TO THE SPEECH SYSTEM
                print("Sending prompt: {} to tts.".format(prompt_str))
                self._send_json(reply)
                print(dialServer.agent_factory.session2agent)

            def _read_request_data(self):
                """Decode the JSON payload and set currentSession / currentUser.

                :return: the decoded JSON object, or None when the payload is
                    not valid JSON or lacks "sessionID" / "userID". As before,
                    currentSession is still set when only "userID" is missing.
                """
                if 'Content-Length' not in self.headers:
                    data_string = self.path[self.path.find('?') + 1:]
                else:
                    data_string = self.rfile.read(
                        int(self.headers['Content-Length']))

                # contains e.g:  {"sessionID": "voip-5595158237", "userID": "..."}
                print("Request Data:" + str(data_string))

                try:
                    data = json.loads(data_string)
                    self.currentSession = data["sessionID"]
                    self.currentUser = data["userID"]
                except (ValueError, KeyError, TypeError) as e:
                    # ValueError: invalid JSON / bad encoding; KeyError: missing
                    # field; TypeError: JSON value that is not an object.
                    print(
                        "Not a valid JSON object (or object lacking info) received. %s" % e)
                    return None
                return data

            def _handle_init(self):
                """Start a call for currentSession.

                :return: the welcome reply as a JSON string, or None on failure.
                """
                try:
                    dialServer.agent_factory.start_call(
                        session_id=self.currentSession)
                    reply = dialServer.prompt(
                        dialServer.agent_factory.willkommen_message, session_id=self.currentSession)
                except Exception:
                    self.error_free = False
                    print(
                        "Tried to start a new call with a session id: {} already in use".format(self.currentSession))
                    return None
                print("A new call has started. Session: %s " %
                      self.currentSession)
                return reply

            def _handle_next(self, data):
                """Run one dialogue turn for currentSession.

                :param data: the decoded request payload (None if invalid).
                :return: ``(prompt_str, reply)`` - the agent's response text and
                    the JSON reply string, or ``('', None)`` on failure.
                """
                try:
                    agent_id = dialServer.agent_factory.retrieve_agent(
                        session_id=self.currentSession)
                except Exception:  # e.g. an ExceptionRaisedByLogger
                    self.error_free = False
                    print(
                        "Tried to get an agent for the non-existent session id: {}".format(self.currentSession))
                    return '', None
                print("Continuing session: %s with agent_id %s " %
                      (self.currentSession, agent_id))

                if data is None or "text" not in data:
                    self.error_free = False
                    print("Missing or invalid request data for session id: {}".format(
                        self.currentSession))
                    return '', None

                userUtterance = data["text"]
                user_id = data["userID"]  # present: checked by _read_request_data
                print("Received user utterance {}".format(userUtterance))
                prompt_str = dialServer.agent_factory.continue_call(
                    agent_id, user_id, userUtterance)

                # A response equal to the factory's ending message is marked as
                # the final turn of the dialogue.
                isfinal = prompt_str == dialServer.agent_factory.ending_message
                reply = dialServer.prompt(
                    prompt_str, session_id=self.currentSession, isfinal=isfinal)
                return prompt_str, reply

            def _handle_end(self):
                """End the call for currentSession; no reply body is produced."""
                print(
                    "Received request to Clean Session ID from the VoiceBroker...:{}".format(self.currentSession))

                self.error_free = False

                try:
                    dialServer.agent_factory.end_call(
                        session_id=self.currentSession)
                except Exception:  # an ExceptionRaisedByLogger
                    print(
                        "Tried to get an agent for the non-existent session id: {}".format(self.currentSession))

            def _send_json(self, reply):
                """Send *reply* with status 200.

                :param reply: a JSON string, or None, which is sent as ``{}`` so
                    the client always receives a parseable body.
                """
                if reply is None:
                    reply = json.dumps({})
                payload = reply.encode('utf-8')
                self.send_response(200)  # 200=OK W3C HTTP Standard codes
                self.send_header('Content-type', 'text/json')
                self.send_header('Content-Length', str(len(payload)))
                self.end_headers()
                print(reply)
                self.wfile.write(payload)

        return RequestHandler
    
    
    # ================================================================================================
    # DIALOGUE SERVER
    # ================================================================================================
    
    class DialogueServer(object):

        '''
         HTTP server that exposes DialCrowd dialogue agents.

         Settings are read from an INI-style config file; a single AgentFactory
         (``self.agent_factory``) manages the agent of every session.
        '''

        # Translation table used by _clean_text() to delete junk characters.
        _JUNK_TABLE = str.maketrans('', '', '(){}<>"\'')

        def __init__(self, configPath):
            """Read the configuration and create the agent factory.

            :param configPath: path of the config file. It must provide
                [GENERAL] host, port and [AGENT] agentPath, agentClass,
                dialogueSave, saveFlag (only the exact string "True" enables
                saving).
            """
            config = ConfigParser()
            config.read(configPath)
            host = config.get("GENERAL", "host")
            port = config.getint("GENERAL", "port")
            agentPath = config.get("AGENT", "agentPath")
            agentClass = config.get("AGENT", "agentClass")
            dialogueSave = config.get("AGENT", "dialogueSave")
            saveFlag = config.get("AGENT", "saveFlag") == "True"

            # Import the configured agent class up front so that a wrong
            # agentPath / agentClass fails at start-up (ImportError /
            # AttributeError). The class object itself is not used here;
            # presumably AgentFactory loads the agent from configPath.
            module = __import__(agentPath, fromlist=[agentClass])
            getattr(module, agentClass)

            self.host = host
            self.port = port
            self.agent_factory = AgentFactory(configPath, dialogueSave, saveFlag)

            print("Server init")

        def run(self):
            """Serve HTTP requests on (self.host, self.port) until stopped.

            A KeyboardInterrupt (Ctrl-C) stops the server quietly. Once serving
            has started, the socket is closed and the agent factory is powered
            down however the loop ends.
            """
            handler_class = make_request_handler_class(self)

            server = http.server.HTTPServer(
                (self.host, self.port), handler_class)
            print('Server starting %s:%s (level=%s)' %
                  (self.host, self.port, 'info'))

            try:
                # Blocks until server.shutdown() is called or an exception
                # escapes; it is not restarted, so shutdown() really stops it.
                server.serve_forever()
            except KeyboardInterrupt:
                pass
            finally:
                print('Server stopping %s:%s' % (self.host, self.port))
                server.server_close()

                self.agent_factory.power_down_factory()

        def prompt(self, prompt, session_id, isfinal=False):
            '''
            Build the JSON reply that is sent back to the client.

            :param prompt: the system utterance; junk characters are removed by _clean_text()
            :param session_id: the session the reply belongs to (echoed as "sessionID")
            :param isfinal: whether this is the last system turn of the dialogue ("terminal")
            :return: the reply serialised as a JSON string (non-ASCII characters kept as is)
            '''
            reply = {}
            reply["sessionID"] = session_id
            reply["version"] = "0.1"
            reply["terminal"] = isfinal
            reply['sys'] = self._clean_text(prompt)
            reply["timeStamp"] = datetime.now().isoformat()

            print(reply)
            return json.dumps(reply, ensure_ascii=False)

        def _clean_text(self, RAW_TEXT):
            """Return *RAW_TEXT* with "' " sequences and the characters ( ) { } < > " ' removed."""
            # The replace() is because of how words with ' come out of the Template SemO.
            return RAW_TEXT.replace("' ", "").translate(self._JUNK_TABLE)
    
    
    def save_log_to_file():
        """Redirect this process's stdout and stderr to a timestamped log file.

        The file is ``dialogueServer_LOG/stdout_<YYYY-mm-dd-HH-MM-SS>.txt`` in
        the directory of this script. It is intentionally left open because it
        remains the process's stdout/stderr afterwards.
        """
        import time

        dir_name = os.path.join(os.path.dirname(
            os.path.abspath(__file__)), 'dialogueServer_LOG')
        # exist_ok avoids a race between an existence check and the creation.
        os.makedirs(dir_name, exist_ok=True)
        current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        # UTF-8 so non-ASCII text (e.g. replies serialised with
        # ensure_ascii=False) can always be written; line buffering so the log
        # stays current even if the process dies without flushing.
        output_file = open(os.path.join(
            dir_name, f"stdout_{current_time}.txt"), 'w',
            encoding='utf-8', buffering=1)
        sys.stdout = output_file
        sys.stderr = output_file
    
    
    # ================================================================================================
    # MAIN FUNCTION
    # ================================================================================================
    if __name__ == "__main__":

        parser = ArgumentParser()
        parser.add_argument("--config", type=str, default="./convlab2/dialcrowd_server/dialogueServer.cfg",
                            help="path of server config file to load "
                                 "(relative paths are resolved against the current working directory)")
        parser.add_argument("--log-to-file", action="store_true",
                            help="redirect stdout/stderr to a timestamped file in dialogueServer_LOG/")
        args = parser.parse_args()

        # Opt-in replacement for the previously commented-out call.
        if args.log_to_file:
            save_log_to_file()
        print(f"Config-file being used: {args.config} \n")

        # Keep the module-level name `dial_server`: other code in this module
        # may refer to this global.
        dial_server = DialogueServer(args.config)
        dial_server.run()
    
    # END OF FILE