Skip to content
Snippets Groups Projects
Select Git revision
  • 1c614c730f081f9769a4e3b2b3a9989b1ed0f1c2
  • master default protected
2 results

webserver.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    session.py 6.29 KiB
    """Dialog controller classes."""
    from abc import ABC, abstractmethod
    import random
    from convlab2.dialog_agent.agent import Agent
    
    
    class Session(ABC):
        """Base dialog session controller, which manages the agents to conduct a complete dialog session.
        """
    
        @abstractmethod
        def next_agent(self):
            """Decide the next agent to generate a response.
    
            In this base class, this function returns the index randomly.
    
            Returns:
                next_agent (Agent): The index of the next agent.
            """
            pass
    
        @abstractmethod
        def next_response(self, observation):
            """Generated the next response.
            
            Args:
                observation (str or dict): The agent observation of next agent.
            Returns:
                response (str or dict): The agent's response.
            """
            pass
    
        @abstractmethod
        def init_session(self):
            """Init the agent variables for a new session."""
            pass
    
    
    class BiSession(Session):
        """The dialog controller which aggregates several agents to conduct a complete dialog session.
    
        Attributes:
            sys_agent (Agent):
                system dialog agent.
    
            user_agent (Agent):
                user dialog agent.
    
            kb_query (KBquery):
                knowledge base query tool.
    
            dialog_history (list):
                The dialog history, formatted as [[user_uttr1, sys_uttr1], [user_uttr2, sys_uttr2], ...]
        """
    
        def __init__(self, sys_agent: Agent, user_agent: Agent, kb_query=None, evaluator=None):
            """
            Args:
                sys_agent (Agent):
                    An instance of system agent.
    
                user_agent (Agent):
                    An instance of user agent.
    
                kb_query (KBquery):
                    An instance of database query tool.
    
                evaluator (Evaluator):
                    An instance of evaluator.
            """
            self.sys_agent = sys_agent
            self.user_agent = user_agent
            self.kb_query = kb_query
            self.evaluator = evaluator
    
            self.dialog_history = []
            self.__turn_indicator = 0
    
            self.init_session()
    
        def next_agent(self):
            """The user and system agent response in turn."""
            if self.__turn_indicator % 2 == 0:
                next_agent = self.user_agent
            else:
                next_agent = self.sys_agent
            self.__turn_indicator += 1
            return next_agent
    
        def next_response(self, observation):
            next_agent = self.next_agent()
            response = next_agent.response(observation)
            return response
    
        def next_turn(self, last_observation):
            """Conduct a new turn of dialog, which consists of the system response and user response.
    
            The variable type of responses can be either 1) str or 2) dialog act, depends on the dialog mode settings of the
            two agents which are supposed to be the same.
            
            Args:
                last_observation:
                    Last agent response.
            Returns:
                sys_response:
                    The response of system.
    
                user_response:
                    The response of user simulator.
    
                session_over (boolean):
                    True if session ends, else session continues.
    
                reward (float):
                    The reward given by the user.
            """
            user_response = self.next_response(last_observation)
            if self.evaluator:
                self.evaluator.add_sys_da(self.user_agent.get_in_da())
                self.evaluator.add_usr_da(self.user_agent.get_out_da())
            session_over = self.user_agent.is_terminated()
            if hasattr(self.sys_agent, 'dst'):
                self.sys_agent.dst.state['terminated'] = session_over
            # if session_over and self.evaluator:
                # prec, rec, f1 = self.evaluator.inform_F1()
                # print('inform prec. {} rec. {} F1 {}'.format(prec, rec, f1))
                # print('book rate {}'.format(self.evaluator.book_rate()))
                # print('task success {}'.format(self.evaluator.task_success()))
            reward = self.user_agent.get_reward()
            sys_response = self.next_response(user_response)
            self.dialog_history.append([self.user_agent.name, user_response])
            self.dialog_history.append([self.sys_agent.name, sys_response])
    
            return sys_response, user_response, session_over, reward
    
        def train_policy(self):
            """
            Train the parameters of system agent policy.
            """
            self.sys_agent.policy.train()
    
        def init_session(self, **kwargs):
            self.sys_agent.init_session()
            self.user_agent.init_session(**kwargs)
            if self.evaluator:
                self.evaluator.add_goal(self.user_agent.policy.get_goal())
            self.dialog_history = []
            self.__turn_indicator = 0
    
    
    class DealornotSession(Session):
        """A special session for Deal or Not dataset, which is a object dividing negotiation task."""
    
        def __init__(self, alice, bob):
            self.alice = alice
            self.bob = bob
            self.__turn_indicator = 0
            self.init_session()
            self.current_agent = None
            self.dialog_history = []
    
        def next_agent(self):
            """Alice and Bob agents response in turn."""
            if self.__turn_indicator % 2 == 0:
                next_agent = self.alice
            else:
                next_agent = self.bob
            self.__turn_indicator += 1
            return next_agent
    
        def next_response(self, observation):
            agent = self.next_agent()
            self.current_agent = agent
            model_response = self.current_agent.response(observation)
            self.dialog_history.append(model_response)
            return model_response
    
        def is_terminated(self):
            if self.current_agent.is_terminated():
                return True
    
        def get_rewards(self, ctxs):
            """Return the rewards of alice and bob.
    
            Returns:
                reward_1 (float):
                    Reward of Alice.
    
                reward_2 (float):
                    Reward of Bob.
            """
            choices = []
            for agent in [self.alice, self.bob]:
                choice = agent.choose()
                choices.append(choice)
            agree, rewards = self.alice.domain.score_choices(choices, ctxs)
            return agree, rewards
    
        def init_session(self):
            self.__turn_indicator = random.choice([0, 1])
            self.alice.init_session()
            self.bob.init_session()
            self.current_agent = None
            self.dialog_history = []