Skip to content
Snippets Groups Projects
Select Git revision
  • e6178cc979e166c4f9584c9f7b73c7518a722ed6
  • develop default
  • release protected
  • v0.x
  • v2.2.0
  • v2.1.0
6 results

index.ts

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    agent.py 6.72 KiB
    """Dialog agent interface and classes."""
    from abc import ABC, abstractmethod
    from convlab2.nlu import NLU
    from convlab2.dst import DST
    from convlab2.policy import Policy
    from convlab2.nlg import NLG
    from copy import deepcopy
    
    
    class Agent(ABC):
        """Abstract base class that every dialog agent implements."""

        @abstractmethod
        def __init__(self, name: str):
            """Store the agent's name.

            Args:
                name (str):
                    Identifier of the agent, e.g. its role in the dialog ('user' or 'sys').
            """
            self.name = name

        @abstractmethod
        def response(self, observation):
            """Produce the agent's reply to one observation.

            Depending on which modules the agent is built from, both the observation and the reply
            may be natural-language strings or lists of dialog-act tuples.

            Example:
                A pipeline agent made of NLU, DST and Policy (but no NLG) receives a str and
                returns a list of tuples.

            Args:
                observation (str or list of tuples):
                    What the agent receives this turn.
            Returns:
                response (str or list of tuples):
                    What the agent produces this turn.
            """

        @abstractmethod
        def init_session(self, **kwargs):
            """Clear per-session state so that a fresh dialog can begin."""
    
    
    class PipelineAgent(Agent):
        r"""Pipeline dialog agent base class, including NLU, DST, Policy and NLG.

        The combination of pipeline modules is flexible. The only requirement is that the APIs of
        the two interacting agents match.

        Example:
            If agent A is (nlu, tracker, policy), then agent B should be like (tracker, policy, nlg)
            to ensure API matching.
        The valid module combinations are as follows:
               =====   =====    ======  ===     ==      ===
                NLU     DST     Policy  NLG     In      Out
               =====   =====    ======  ===     ==      ===
                \+      \+        \+    \+      nl      nl
                 o      \+        \+    \+      da      nl
                 o      \+        \+     o      da      da
                \+      \+        \+     o      nl      da
                 o       o        \+     o      da      da
               =====   =====    ======  ===     ==      ===

        ``\+`` marks a module that is present and ``o`` one that is ``None``; ``nl`` is natural
        language and ``da`` is dialog acts.
        """

        def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str):
            """The constructor of PipelineAgent class.

            Here are some special combination cases:

                1. If you use a word-level DST (such as Neural Belief Tracker), set ``nlu`` to None.
                   The raw observation is then passed directly to the DST.

                2. If you want to aggregate DST and Policy as a single module, set ``dst`` to None.
                   The policy then receives the (NLU output or raw) input action directly.

            Args:
                nlu (NLU):
                    The natural language understanding module of agent, or None.

                dst (DST):
                    The dialog state tracker of agent, or None.

                policy (Policy):
                    The dialog policy module of agent.

                nlg (NLG):
                    The natural language generator module of agent, or None.

                name (str):
                    The role of this agent, either 'user' or 'sys'. The other role is stored
                    in ``opponent_name``.
            """
            super(PipelineAgent, self).__init__(name=name)
            assert self.name in ['user', 'sys']
            # ``==`` (not ``is``): string literals must be compared by value, not identity.
            self.opponent_name = 'user' if self.name == 'sys' else 'sys'
            self.nlu = nlu
            self.dst = dst
            self.policy = policy
            self.nlg = nlg
            # Defined up front so get_in_da()/get_out_da() return None instead of raising
            # AttributeError before the first call to response().
            self.input_action = None
            self.output_action = None
            self.init_session()
            self.history = []

        def state_replace(self, agent_state):
            """Replace all internal states of the agent with ``agent_state``.

            This interface is reserved for that purpose. The implementation below covers the case
            where the agent state depends only on ``self.history`` and ``self.dst.state``; it
            deep-copies both so the caller's dict is not shared with the agent.
            """
            self.history = deepcopy(agent_state['history'])
            self.dst.state = deepcopy(agent_state['dst_state'])

        def state_return(self):
            """Return all internal states of the agent as a dict.

            This interface is reserved for that purpose. The implementation below covers the case
            where the agent state depends only on ``self.history`` and ``self.dst.state``. It
            returns deep copies under the keys 'history' and 'dst_state', which is the format
            that state_replace() accepts.
            """
            agent_state = {}
            agent_state['history'] = deepcopy(self.history)
            agent_state['dst_state'] = deepcopy(self.dst.state)

            return agent_state

        def response(self, observation):
            """Generate agent response using the agent modules.

            Runs observation -> NLU -> DST -> Policy -> NLG, skipping any module that is None.
            Both the observation and the response are recorded in ``self.history`` and, when a DST
            is present, in the DST state's 'history'.

            Args:
                observation:
                    The opponent's output: natural language when this agent has an NLU module,
                    otherwise dialog acts.
            Returns:
                The NLG output, or the policy's dialog acts when there is no NLG module.
            """
            # Note: If you modify the logic of this function, please ensure that it is consistent with deploy.server.ServerCtrl._turn()
            if self.dst is not None:
                self.dst.state['history'].append([self.opponent_name, observation])  # [['sys', sys_utt], ['user', user_utt],...]
            self.history.append([self.opponent_name, observation])
            # get dialog act
            if self.nlu is not None:
                # Context is the content of all earlier turns, excluding the observation just appended.
                self.input_action = self.nlu.predict(observation, context=[x[1] for x in self.history[:-1]])
            else:
                self.input_action = observation
            self.input_action = deepcopy(self.input_action)  # get rid of reference problem
            # get state
            if self.dst is not None:
                # The incoming action was produced by the opponent, so it goes under the opponent's key.
                if self.name == 'sys':
                    self.dst.state['user_action'] = self.input_action
                else:
                    self.dst.state['system_action'] = self.input_action
                state = self.dst.update(self.input_action)
            else:
                state = self.input_action
            state = deepcopy(state)  # get rid of reference problem
            # get action
            self.output_action = deepcopy(self.policy.predict(state))  # get rid of reference problem
            # get model response
            if self.nlg is not None:
                model_response = self.nlg.generate(self.output_action)
            else:
                model_response = self.output_action
            if self.dst is not None:
                self.dst.state['history'].append([self.name, model_response])
                # The outgoing action is this agent's own, so it goes under this agent's key.
                if self.name == 'sys':
                    self.dst.state['system_action'] = self.output_action
                else:
                    self.dst.state['user_action'] = self.output_action
            self.history.append([self.name, model_response])
            return model_response

        def is_terminated(self):
            """Return the policy's ``is_terminated()`` result, or None if the policy has no such method."""
            if hasattr(self.policy, 'is_terminated'):
                return self.policy.is_terminated()
            return None

        def get_reward(self):
            """Return the policy's ``get_reward()`` result, or None if the policy has no such method."""
            if hasattr(self.policy, 'get_reward'):
                return self.policy.get_reward()
            return None

        def init_session(self, **kwargs):
            """Init the attributes of all present modules and clear the agent history.

            Args:
                **kwargs: Forwarded to ``policy.init_session`` only.
            """
            if self.nlu is not None:
                self.nlu.init_session()
            if self.dst is not None:
                self.dst.init_session()
                if self.name == 'sys':
                    # DST history alternates [['sys', ...], ['user', ...], ...]; a system agent
                    # seeds it with a placeholder 'null' system turn before the first user utterance.
                    self.dst.state['history'].append([self.name, 'null'])
            if self.policy is not None:
                self.policy.init_session(**kwargs)
            if self.nlg is not None:
                self.nlg.init_session()
            self.history = []

        def get_in_da(self):
            """Return the dialog act produced from the latest observation (None before any turn)."""
            return self.input_action

        def get_out_da(self):
            """Return the dialog act chosen by the policy in the latest turn (None before any turn)."""
            return self.output_action