Skip to content
Snippets Groups Projects
Select Git revision
  • fcfb467dda51871fb9aabd05eb7abe26031ead74
  • master default protected
2 results

x3dom_button.html

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    train.py 1.50 KiB
    import os
    import torch
    import logging
    import json
    import sys
    root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
    sys.path.append(root_dir)
    
    from convlab2.policy.rlmodule import MultiDiscretePolicy
    from convlab2.policy.vector.vector_camrest import CamrestVector
    from convlab2.policy.mle.train import MLE_Trainer_Abstract
    from convlab2.policy.mle.multiwoz.loader import ActPolicyDataLoaderCamrest
    from convlab2.util.train_util import init_logging_handler
    
    # Device the policy network is moved to: CUDA when torch reports it as available, CPU otherwise.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    
    class MLE_Trainer(MLE_Trainer_Abstract):
        """MLE trainer that pairs a MultiDiscretePolicy with the CamRest vectoriser.

        Construction builds a ``CamrestVector`` from the two dialogue-act
        vocabulary files under ``<root_dir>/data/camrest/``, sizes the policy
        network from that vector, and creates an Adam optimiser over the
        policy's parameters.
        """

        def __init__(self, manager, cfg):
            """Prepare data, policy network and optimiser.

            Args:
                manager: Data loader object, forwarded unchanged to
                    ``self._init_data``.
                cfg: Configuration mapping. This method reads ``cfg['h_dim']``
                    (passed as the second argument of ``MultiDiscretePolicy``)
                    and ``cfg['lr']`` (Adam learning rate); it is also
                    forwarded to ``self._init_data``.
            """
            # _init_data is not defined in this class; presumably inherited
            # from MLE_Trainer_Abstract -- confirm against the base class.
            self._init_data(manager, cfg)

            # System-side and user-side dialogue-act vocabularies.
            sys_voc_path = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
            usr_voc_path = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
            camrest_vector = CamrestVector(sys_voc_path, usr_voc_path)

            # Network dimensions come from the vectoriser and the config.
            network = MultiDiscretePolicy(
                camrest_vector.state_dim,
                cfg['h_dim'],
                camrest_vector.da_dim,
            )
            self.policy = network.to(device=DEVICE)
            # NOTE(review): the policy is left in eval mode here; presumably the
            # inherited training routine switches modes as needed -- confirm.
            self.policy.eval()
            self.policy_optim = torch.optim.Adam(self.policy.parameters(), lr=cfg['lr'])
            
    if __name__ == '__main__':
        # Script entry point: build the data loader, read the configuration,
        # set up logging, then alternate imitation and test passes per epoch.
        manager = ActPolicyDataLoaderCamrest()

        # The configuration file is looked up relative to the working directory.
        with open('config.json', 'r') as cfg_file:
            cfg = json.load(cfg_file)

        init_logging_handler(cfg['log_dir'])
        agent = MLE_Trainer(manager, cfg)

        logging.debug('start training')

        # Start from +inf and carry forward whatever imit_test hands back.
        best_so_far = float('inf')
        num_epochs = cfg['epoch']
        for epoch in range(num_epochs):
            agent.imitating(epoch)
            best_so_far = agent.imit_test(epoch, best_so_far)