Skip to content
Snippets Groups Projects
Select Git revision
  • 6a800c444d6aaff27aaa24cc8a090ad1333beb27
  • main default protected
  • gentus-public
3 results

ppo_agent.py

Blame
  • user avatar
    Hsien-Chin Lin authored
    ffd2f250
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    ppo_agent.py 923 B
    '''
    mel_agent.py - An example dialogue agent class
    ==========================================================================
    
    Build up an pipeline agent with nlu, dst, policy and nlg.
    
    @author: Songbo
    '''
    
    
    from convlab2.dialog_agent.agent import DialogueAgent
    from convlab2.nlu.jointBERT.multiwoz import BERTNLU
    from convlab2.dst.rule.multiwoz import RuleDST
    from convlab2.policy.ppo_rnn import PPO_RNN
    from convlab2.nlg.template.multiwoz import TemplateNLG
    
    
    class Agent(DialogueAgent):
    
        def __init__(self):
            nlu = BERTNLU()
            dst = RuleDST()
    
            model_path = "convlab2/policy/ppo_rnn/save/2020-08-05-14-34-04_best_complete_rate"
    
            policy = PPO_RNN(is_train=True)
            #TODO: Extend the max_size of the memory!!!
            policy.load(model_path)
    
            nlg = TemplateNLG(is_user=False)
            super().__init__(nlu, dst, policy, nlg)
    
            self.agent_name = "awac_noisy_gamma_shrink"