Select Git revision
x3dom_button.html
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
train.py 1.50 KiB
import os
import torch
import logging
import json
import sys
# Climb five directory levels up from this file's absolute path and add the
# resulting directory to the import search path (presumably so the
# `convlab2` imports resolve when the script is launched directly).
root_dir = os.path.abspath(__file__)
for _ in range(5):
    root_dir = os.path.dirname(root_dir)
sys.path.append(root_dir)
from convlab2.policy.rlmodule import MultiDiscretePolicy
from convlab2.policy.vector.vector_camrest import CamrestVector
from convlab2.policy.mle.train import MLE_Trainer_Abstract
from convlab2.policy.mle.multiwoz.loader import ActPolicyDataLoaderCamrest
from convlab2.util.train_util import init_logging_handler
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
class MLE_Trainer(MLE_Trainer_Abstract):
    """CamRest specialisation of ``MLE_Trainer_Abstract``.

    The constructor builds a ``CamrestVector`` from the CamRest dialogue-act
    vocabulary files, a ``MultiDiscretePolicy`` sized from that vector, and an
    Adam optimiser over the policy parameters.
    """

    def __init__(self, manager, cfg):
        """Set up data, policy network and optimiser.

        Args:
            manager: Data-loader object forwarded unchanged to
                ``self._init_data``.
            cfg: Mapping of hyper-parameters. This constructor reads
                ``cfg['h_dim']`` (passed as the second argument of the
                policy) and ``cfg['lr']`` (Adam learning rate); the whole
                mapping is also forwarded to ``self._init_data``.
        """
        # `_init_data` is not defined in this class; presumably inherited
        # from MLE_Trainer_Abstract.
        self._init_data(manager, cfg)

        # Both vocabulary files are located relative to `root_dir`.
        sys_vocab_path = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
        usr_vocab_path = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
        vectoriser = CamrestVector(sys_vocab_path, usr_vocab_path)

        policy_net = MultiDiscretePolicy(
            vectoriser.state_dim,
            cfg['h_dim'],
            vectoriser.da_dim,
        )
        self.policy = policy_net.to(device=DEVICE)
        # NOTE(review): the policy is put in eval mode right after
        # construction; confirm the base class switches to train mode
        # before optimising.
        self.policy.eval()

        self.policy_optim = torch.optim.Adam(
            self.policy.parameters(),
            lr=cfg['lr'],
        )
if __name__ == '__main__':
    # Script entry point: load the configuration, build the trainer and run
    # the train / test loop for the configured number of epochs.
    manager = ActPolicyDataLoaderCamrest()

    # Keep honouring a `config.json` in the current working directory, but
    # fall back to the one next to this script so that the trainer can be
    # launched from any directory instead of failing with FileNotFoundError.
    config_path = 'config.json'
    if not os.path.isfile(config_path):
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'config.json')
    with open(config_path, 'r') as f:
        cfg = json.load(f)

    init_logging_handler(cfg['log_dir'])
    agent = MLE_Trainer(manager, cfg)

    logging.debug('start training')

    # `best` starts at +inf and is replaced after every epoch by whatever
    # `imit_test` returns (presumably the best test loss seen so far).
    best = float('inf')
    for e in range(cfg['epoch']):
        agent.imitating(e)
        best = agent.imit_test(e, best)