Skip to content
Snippets Groups Projects
Select Git revision
  • 88f2d26317e4331013d8779c37bb2e00f418333e
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

train.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    train.py 1.50 KiB
    import os
    import torch
    import logging
    import json
    import sys
    # Apply dirname() five times to this file's absolute path, which yields the
    # directory four levels above the one containing this script, then append
    # it to sys.path.
    # NOTE(review): presumably this is the repository root and is needed so the
    # `convlab2` imports below resolve when run as a script — confirm against
    # the file's actual location in the tree.
    root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
    sys.path.append(root_dir)
    
    from convlab2.policy.rlmodule import MultiDiscretePolicy
    from convlab2.policy.vector.vector_camrest import CamrestVector
    from convlab2.policy.mle.train import MLE_Trainer_Abstract
    from convlab2.policy.mle.multiwoz.loader import ActPolicyDataLoaderCamrest
    from convlab2.util.train_util import init_logging_handler
    
    # Device the policy network is moved to: CUDA when available, else CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    class MLE_Trainer(MLE_Trainer_Abstract):
        """MLE trainer that builds its policy from the CamRest act vocabularies.

        The constructor delegates data setup to ``_init_data`` (inherited),
        builds a ``CamrestVector`` from the two vocabulary files under
        ``data/camrest``, and creates a ``MultiDiscretePolicy`` plus an Adam
        optimizer over its parameters.
        """

        def __init__(self, manager, cfg):
            """Set up data, the policy network and its optimizer.

            Args:
                manager: data loader object forwarded unchanged to
                    ``_init_data``.
                cfg: configuration mapping; this method reads ``cfg['h_dim']``
                    and ``cfg['lr']`` and forwards the whole mapping to
                    ``_init_data``.
            """
            self._init_data(manager, cfg)

            # Vocabulary files are located relative to the module-level root_dir.
            sys_vocab_path = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
            usr_vocab_path = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
            vectorizer = CamrestVector(sys_vocab_path, usr_vocab_path)

            # Network dimensions come from the vectorizer; hidden size from cfg.
            network = MultiDiscretePolicy(
                vectorizer.state_dim,
                cfg['h_dim'],
                vectorizer.da_dim,
            )
            self.policy = network.to(device=DEVICE)
            # NOTE(review): the policy is switched to eval mode right after
            # construction; whether training mode is re-enabled later depends on
            # the base class, which is not visible here — confirm this is intended.
            self.policy.eval()

            self.policy_optim = torch.optim.Adam(self.policy.parameters(), lr=cfg['lr'])
            
    if __name__ == '__main__':
        # Script entry point: build the data loader, read config.json from the
        # current working directory, set up logging, then run the epoch loop.
        data_manager = ActPolicyDataLoaderCamrest()

        with open('config.json', 'r') as config_file:
            settings = json.load(config_file)

        init_logging_handler(settings['log_dir'])
        trainer = MLE_Trainer(data_manager, settings)

        logging.debug('start training')

        # imit_test receives the running best value and returns the updated one;
        # it starts at +inf so the first epoch's result always replaces it.
        best_so_far = float('inf')
        for epoch in range(settings['epoch']):
            trainer.imitating(epoch)
            best_so_far = trainer.imit_test(epoch, best_so_far)