Commit 78e53b16 authored by Christian

First version that trains an MLE (supervised) policy with the unified data format. Some TODOs remain in vector_base, but training runs and an F1-score of 0.53 is achieved.
parent 88f2d263
Showing changed files with 114 additions and 782 deletions
@@ -86,3 +86,4 @@ test.py
*.egg-info
pre-trained-models/
+venv
\ No newline at end of file
@@ -10,7 +10,7 @@ import copy
from convlab2.policy.policy import Policy
from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
from convlab2.util.train_util import init_logging_handler
-from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.vector.vector_binary import VectorBinary
from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
from convlab2.util.file_util import cached_path
import zipfile
@@ -42,9 +42,7 @@ class DQN(Policy):
        # construct multiwoz vector
        if dataset == 'Multiwoz':
-            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
-            voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
-            self.vector = MultiWozVector(voc_file, voc_opp_file, composite_actions=True, vocab_size=cfg['vocab_size'])
+            self.vector = VectorBinary()
        #replay memory
        self.memory = MemoryReplay(cfg['memory_size'])
...
@@ -9,7 +9,7 @@ import json
import logging
import os
import random
-from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.vector.vector_binary import VectorBinary
import numpy as np
import torch
@@ -168,7 +168,7 @@ def evaluate(args, dataset_name, model_name, load_path, calculate_reward=True, v
    if model_name == "PPO":
        from convlab2.policy.ppo import PPO
        if load_path:
-            policy_sys = PPO(False, vectorizer=MultiWozVector())
+            policy_sys = PPO(False, vectorizer=VectorBinary())
            policy_sys.load(load_path)
        else:
            policy_sys = PPO.from_pretrained()
@@ -183,7 +183,7 @@ def evaluate(args, dataset_name, model_name, load_path, calculate_reward=True, v
        else:
            policy_sys = PG.from_pretrained()
    elif model_name == "MLE":
-        from convlab2.policy.mle.multiwoz import MLE
+        from convlab2.policy.mle import MLE
        if load_path:
            policy_sys = MLE()
            policy_sys.load(load_path)
...
@@ -8,7 +8,7 @@ import json
from convlab2.policy.policy import Policy
from convlab2.policy.rlmodule import MultiDiscretePolicy, Value
from convlab2.util.train_util import init_logging_handler
-from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.vector.vector_binary import VectorBinary
from convlab2.util.file_util import cached_path
import zipfile
import sys
@@ -38,9 +38,7 @@ class GDPL(Policy):
        # construct policy and value network
        if dataset == 'Multiwoz':
-            voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
-            voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
-            self.vector = MultiWozVector(voc_file, voc_opp_file)
+            self.vector = VectorBinary()
        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
        self.value = Value(self.vector.state_dim, cfg['hv_dim']).to(device=DEVICE)
...
# Imitation on camrest
The vanilla MLE policy is trained by imitation learning as a multi-label classifier over compositional actions, where each compositional action is a set of dialog act items.
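Concretely, the policy network emits one logit per dialog act item and is trained against the multi-hot target action with a multi-label loss (the CrossWOZ MLE trainer in this commit uses `nn.MultiLabelSoftMarginLoss`). A minimal sketch of that objective, using illustrative dimensions rather than the real CamrestVector sizes:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; in the real code state_dim/da_dim come from the
# vectorizer (CamrestVector) and h_dim from config.json.
state_dim, h_dim, da_dim = 94, 10, 20

policy = nn.Sequential(nn.Linear(state_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, da_dim))
loss_fn = nn.MultiLabelSoftMarginLoss()

states = torch.randn(32, state_dim)                   # batch of vectorized dialog states
targets = torch.randint(0, 2, (32, da_dim)).float()   # multi-hot compositional actions

loss = loss_fn(policy(states), targets)               # behavioral cloning objective
loss.backward()
```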
## Train
```
python train.py
```
You can modify *config.json* to change the training settings.
## Data
data/camrest/[train/val/test].json
## Performance
|Dialog act accuracy|
|-|
|0.7459|
from convlab2.policy.mle.camrest.mle import MLE
\ No newline at end of file
{
"batchsz": 32,
"epoch": 16,
"lr": 0.01,
"save_dir": "save",
"log_dir": "log",
"print_per_batch": 10,
"save_per_epoch": 5,
"h_dim": 10,
"load": "save/best"
}
\ No newline at end of file
import os
import json
import pickle
import zipfile
from convlab2.util.camrest.state import default_state
from convlab2.util.dataloader.module_dataloader import ActPolicyDataloader
from convlab2.policy.vector.vector_camrest import CamrestVector
class ActPolicyDataLoaderCamrest(ActPolicyDataloader):
def __init__(self):
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
voc_file = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
voc_opp_file = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
self.vector = CamrestVector(voc_file, voc_opp_file)
processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data')
if os.path.exists(processed_dir):
print('Load processed data file')
self._load_data(processed_dir)
else:
print('Start preprocessing the dataset')
self._build_data(root_dir, processed_dir)
def _build_data(self, root_dir, processed_dir): # TODO
raw_data = {}
for part in ['train', 'val', 'test']:
archive = zipfile.ZipFile(os.path.join(root_dir, 'data/camrest/{}.json.zip'.format(part)), 'r')
with archive.open('{}.json'.format(part), 'r') as f:
raw_data[part] = json.load(f)
self.data = {}
for part in ['train', 'val', 'test']:
self.data[part] = []
for key in raw_data[part]:
sess = key['dial']
state = default_state()
action = {}
for i, turn in enumerate(sess):
state['user_action'] = turn['usr']['dialog_act']
if i + 1 == len(sess):
state['terminated'] = True
for da in turn['usr']['slu']:
if da['slots'][0][0] != 'slot':
state['belief_state'][da['slots'][0][0]] = da['slots'][0][1]
action = turn['sys']['dialog_act']
self.data[part].append([self.vector.state_vectorize(state),
self.vector.action_vectorize(action)])
state['system_action'] = turn['sys']['dialog_act']
os.makedirs(processed_dir)
for part in ['train', 'val', 'test']:
with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'wb') as f:
pickle.dump(self.data[part], f)
# -*- coding: utf-8 -*-
import torch
import os
import json
from convlab2.policy.mle.mle import MLEAbstract
from convlab2.policy.rlmodule import MultiDiscretePolicy
from convlab2.policy.vector.vector_camrest import CamrestVector
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_camrest.zip")
class MLE(MLEAbstract):
def __init__(self,
archive_file=DEFAULT_ARCHIVE_FILE,
model_file='https://convlab.blob.core.windows.net/convlab-2/mle_policy_camrest.zip'):
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
cfg = json.load(f)
voc_file = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
voc_opp_file = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
self.vector = CamrestVector(voc_file, voc_opp_file)
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
self.load(archive_file, model_file, cfg['load'])
import os
import torch
import logging
import json
import sys
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
sys.path.append(root_dir)
from convlab2.policy.rlmodule import MultiDiscretePolicy
from convlab2.policy.vector.vector_camrest import CamrestVector
from convlab2.policy.mle.train import MLE_Trainer_Abstract
from convlab2.policy.mle.camrest.loader import ActPolicyDataLoaderCamrest
from convlab2.util.train_util import init_logging_handler
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MLE_Trainer(MLE_Trainer_Abstract):
def __init__(self, manager, cfg):
self._init_data(manager, cfg)
voc_file = os.path.join(root_dir, 'data/camrest/sys_da_voc.txt')
voc_opp_file = os.path.join(root_dir, 'data/camrest/usr_da_voc.txt')
vector = CamrestVector(voc_file, voc_opp_file)
self.policy = MultiDiscretePolicy(vector.state_dim, cfg['h_dim'], vector.da_dim).to(device=DEVICE)
self.policy.eval()
self.policy_optim = torch.optim.Adam(self.policy.parameters(), lr=cfg['lr'])
if __name__ == '__main__':
manager = ActPolicyDataLoaderCamrest()
with open('config.json', 'r') as f:
cfg = json.load(f)
init_logging_handler(cfg['log_dir'])
agent = MLE_Trainer(manager, cfg)
logging.debug('start training')
best = float('inf')
for e in range(cfg['epoch']):
agent.imitating(e)
best = agent.imit_test(e, best)
# Imitation on CrossWOZ
The vanilla MLE policy is trained by imitation learning as a multi-label classifier over compositional actions, where each compositional action is a set of dialog act items.
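The compositional action itself is encoded as a multi-hot vector over the system dialog-act vocabulary (`data/crosswoz/sys_da_voc.json`), which is what `CrossWozVector.action_vectorize` produces. A toy sketch of the idea, with made-up vocabulary entries:

```python
# Toy illustration only; the real mapping lives in
# convlab2.policy.vector.vector_crosswoz.CrossWozVector.action_vectorize.
sys_da_voc = ['greet', 'inform-hotel-name', 'inform-hotel-price', 'request-hotel-area']  # made-up entries
act2id = {act: i for i, act in enumerate(sys_da_voc)}

def action_vectorize(dialog_act_items):
    """Turn a compositional action (a set of dialog act items) into a multi-hot vector."""
    vec = [0.0] * len(sys_da_voc)
    for item in dialog_act_items:
        if item in act2id:
            vec[act2id[item]] = 1.0
    return vec

print(action_vectorize(['inform-hotel-name', 'inform-hotel-price']))  # [0.0, 1.0, 1.0, 0.0]
```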
## Train
```
python train.py
```
You can modify *config.json* to change the training settings.
## Data
data/crosswoz/[train/val/test].json
from convlab2.policy.mle.crosswoz.mle import MLE
\ No newline at end of file
{
"batchsz": 32,
"epoch": 20,
"lr": 0.001,
"save_dir": "save",
"log_dir": "log",
"print_per_batch": 400,
"save_per_epoch": 5,
"h_dim": 100,
"load": "save/best"
}
\ No newline at end of file
from convlab2.policy.mle.crosswoz.mle import MLE
from convlab2.dst.rule.crosswoz.dst import RuleDST
from convlab2.util.crosswoz.state import default_state
from convlab2.policy.rule.crosswoz.rule_simulator import Simulator
from convlab2.dialog_agent import PipelineAgent, BiSession
from convlab2.util.crosswoz.lexicalize import delexicalize_da
from convlab2.nlu.jointBERT.crosswoz.nlu import BERTNLU
from convlab2.nlg.template.crosswoz.nlg import TemplateNLG
from convlab2.nlg.sclstm.crosswoz.sc_lstm import SCLSTM
import os
import zipfile
import json
from copy import deepcopy
import random
import numpy as np
from pprint import pprint
import torch
def read_zipped_json(filepath, filename):
archive = zipfile.ZipFile(filepath, 'r')
return json.load(archive.open(filename))
def calculateF1(predict_golden):
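    # Micro-averaged precision/recall/F1: counts TP/FP/FN over dialog-act tuples across all turns.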
TP, FP, FN = 0, 0, 0
for item in predict_golden:
predicts = item['predict']
labels = item['golden']
for quad in predicts:
if quad in labels:
TP += 1
else:
FP += 1
for quad in labels:
if quad not in predicts:
FN += 1
print(TP, FP, FN)
precision = 1.0 * TP / (TP + FP) if (TP + FP) else 0.
recall = 1.0 * TP / (TP + FN) if (TP + FN) else 0.
F1 = 2.0 * precision * recall / (precision + recall) if (precision + recall) else 0.
return precision, recall, F1
def evaluate_corpus_f1(policy, data, goal_type=None):
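    # Corpus-based evaluation: replay each test dialog through the rule DST and compare the
    # policy's predicted system acts (lexicalized and delexicalized) against the annotations.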
dst = RuleDST()
da_predict_golden = []
delex_da_predict_golden = []
for task_id, sess in data.items():
if goal_type and sess['type']!=goal_type:
continue
dst.init_session()
for i, turn in enumerate(sess['messages']):
if turn['role'] == 'usr':
dst.update(usr_da=turn['dialog_act'])
if i + 2 == len(sess):
dst.state['terminated'] = True
else:
for domain, svs in turn['sys_state'].items():
for slot, value in svs.items():
if slot != 'selectedResults':
dst.state['belief_state'][domain][slot] = value
golden_da = turn['dialog_act']
predict_da = policy.predict(deepcopy(dst.state))
# print(golden_da)
# print(predict_da)
# print()
# if 'Select' in [x[0] for x in sess['messages'][i - 1]['dialog_act']]:
da_predict_golden.append({
'predict': predict_da,
'golden': golden_da
})
delex_da_predict_golden.append({
'predict': delexicalize_da(predict_da),
'golden': delexicalize_da(golden_da)
})
# print(delex_da_predict_golden[-1])
dst.state['system_action'] = golden_da
# break
print('origin precision/recall/f1:', calculateF1(da_predict_golden))
print('delex precision/recall/f1:', calculateF1(delex_da_predict_golden))
def end2end_evaluate_simulation(policy):
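    # End-to-end simulation with BERTNLU and template NLG on both sides; reports task-finish
    # rates per CrossWOZ goal type over repeated batches of simulated sessions.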
nlu = BERTNLU()
nlg_usr = TemplateNLG(is_user=True, mode='auto_manual')
nlg_sys = TemplateNLG(is_user=False, mode='auto_manual')
# nlg_usr = SCLSTM(is_user=True, use_cuda=False)
# nlg_sys = SCLSTM(is_user=False, use_cuda=False)
usr_policy = Simulator()
usr_agent = PipelineAgent(nlu, None, usr_policy, nlg_usr, name='user')
sys_policy = policy
sys_dst = RuleDST()
sys_agent = PipelineAgent(nlu, sys_dst, sys_policy, nlg_sys, name='sys')
sess = BiSession(sys_agent=sys_agent, user_agent=usr_agent)
task_finish = {'All': list(), '单领域': list(), '独立多领域': list(), '独立多领域+交通': list(), '不独立多领域': list(),
'不独立多领域+交通': list()}
simulate_sess_num = 100
repeat = 10
random_seed = 2019
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random_seeds = [random.randint(1, 2**32-1) for _ in range(simulate_sess_num * repeat * 10000)]
while True:
sys_response = ''
random_seed = random_seeds[0]
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random_seeds.pop(0)
sess.init_session()
# print(usr_policy.goal_type)
if len(task_finish[usr_policy.goal_type]) == simulate_sess_num*repeat:
continue
for i in range(15):
sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
# print('user:', user_response)
# print('sys:', sys_response)
# print(session_over, reward)
# print()
if session_over is True:
task_finish['All'].append(1)
task_finish[usr_policy.goal_type].append(1)
break
else:
task_finish['All'].append(0)
task_finish[usr_policy.goal_type].append(0)
print([len(x) for x in task_finish.values()])
# print(min([len(x) for x in task_finish.values()]))
if len(task_finish['All']) % 100 == 0:
for k, v in task_finish.items():
print(k)
all_samples = []
for i in range(repeat):
samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
all_samples += samples
print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
if min([len(x) for x in task_finish.values()]) == simulate_sess_num*repeat:
break
# pprint(usr_policy.original_goal)
# pprint(task_finish)
print('task_finish')
for k, v in task_finish.items():
print(k)
all_samples = []
for i in range(repeat):
samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
all_samples += samples
print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
def da_evaluate_simulation(policy):
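    # Same simulation as above, but agents exchange dialog acts directly (no NLU/NLG).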
usr_policy = Simulator()
usr_agent = PipelineAgent(None, None, usr_policy, None, name='user')
sys_policy = policy
sys_dst = RuleDST()
sys_agent = PipelineAgent(None, sys_dst, sys_policy, None, name='sys')
sess = BiSession(sys_agent=sys_agent, user_agent=usr_agent)
task_finish = {'All': list(), '单领域': list(), '独立多领域': list(), '独立多领域+交通': list(), '不独立多领域': list(),
'不独立多领域+交通': list()}
simulate_sess_num = 100
repeat = 10
random_seed = 2019
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random_seeds = [random.randint(1, 2**32-1) for _ in range(simulate_sess_num * repeat * 10000)]
while True:
sys_response = []
random_seed = random_seeds[0]
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random_seeds.pop(0)
sess.init_session()
# print(usr_policy.goal_type)
if len(task_finish[usr_policy.goal_type]) == simulate_sess_num*repeat:
continue
for i in range(15):
sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
# print('user:', user_response)
# print('sys:', sys_response)
# print(session_over, reward)
# print()
if session_over is True:
# pprint(sys_agent.tracker.state)
task_finish['All'].append(1)
task_finish[usr_policy.goal_type].append(1)
break
else:
task_finish['All'].append(0)
task_finish[usr_policy.goal_type].append(0)
print([len(x) for x in task_finish.values()])
# print(min([len(x) for x in task_finish.values()]))
if len(task_finish['All']) % 100 == 0:
for k, v in task_finish.items():
print(k)
all_samples = []
for i in range(repeat):
samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
all_samples += samples
print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
if min([len(x) for x in task_finish.values()]) == simulate_sess_num*repeat:
break
# pprint(usr_policy.original_goal)
# pprint(task_finish)
print('task_finish')
for k, v in task_finish.items():
print(k)
all_samples = []
for i in range(repeat):
samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
all_samples += samples
print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
if __name__ == '__main__':
random_seed = 2019
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
test_data = os.path.abspath(os.path.join(os.path.abspath(__file__),'../../../../../data/crosswoz/test.json.zip'))
test_data = read_zipped_json(test_data, 'test.json')
policy = MLE()
for goal_type in ['单领域','独立多领域','独立多领域+交通','不独立多领域','不独立多领域+交通',None]:
print(goal_type)
evaluate_corpus_f1(policy, test_data, goal_type=goal_type)
da_evaluate_simulation(policy)
end2end_evaluate_simulation(policy)
import os
import json
import pickle
import zipfile
import torch
import torch.utils.data as data
from convlab2.util.crosswoz.state import default_state
from convlab2.dst.rule.crosswoz.dst import RuleDST
from convlab2.policy.vector.vector_crosswoz import CrossWozVector
from copy import deepcopy
class PolicyDataLoaderCrossWoz():
def __init__(self):
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
voc_file = os.path.join(root_dir, 'data/crosswoz/sys_da_voc.json')
voc_opp_file = os.path.join(root_dir, 'data/crosswoz/usr_da_voc.json')
self.vector = CrossWozVector(sys_da_voc_json=voc_file, usr_da_voc_json=voc_opp_file)
processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data')
if os.path.exists(processed_dir):
print('Load processed data file')
self._load_data(processed_dir)
else:
print('Start preprocessing the dataset')
self._build_data(root_dir, processed_dir)
def _build_data(self, root_dir, processed_dir):
raw_data = {}
for part in ['train', 'val', 'test']:
archive = zipfile.ZipFile(os.path.join(root_dir, 'data/crosswoz/{}.json.zip'.format(part)), 'r')
with archive.open('{}.json'.format(part), 'r') as f:
raw_data[part] = json.load(f)
self.data = {}
# for cur domain update
dst = RuleDST()
for part in ['train', 'val', 'test']:
self.data[part] = []
for key in raw_data[part]:
sess = raw_data[part][key]['messages']
dst.init_session()
for i, turn in enumerate(sess):
if turn['role'] == 'usr':
dst.state['user_action'] = turn['dialog_act']
dst.update(usr_da=turn['dialog_act'])
if i + 2 == len(sess):
dst.state['terminated'] = True
else:
for domain, svs in turn['sys_state'].items():
for slot, value in svs.items():
if slot != 'selectedResults':
dst.state['belief_state'][domain][slot] = value
action = turn['dialog_act']
self.data[part].append([self.vector.state_vectorize(deepcopy(dst.state)),
self.vector.action_vectorize(action)])
dst.state['system_action'] = turn['dialog_act']
os.makedirs(processed_dir)
for part in ['train', 'val', 'test']:
with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'wb') as f:
pickle.dump(self.data[part], f)
def _load_data(self, processed_dir):
self.data = {}
for part in ['train', 'val', 'test']:
with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'rb') as f:
self.data[part] = pickle.load(f)
def create_dataset(self, part, batchsz):
print('Start creating {} dataset'.format(part))
s = []
a = []
for item in self.data[part]:
s.append(torch.Tensor(item[0]))
a.append(torch.Tensor(item[1]))
s = torch.stack(s)
a = torch.stack(a)
dataset = Dataset(s, a)
dataloader = data.DataLoader(dataset, batchsz, True)
print('Finish creating {} dataset'.format(part))
return dataloader
class Dataset(data.Dataset):
def __init__(self, s_s, a_s):
self.s_s = s_s
self.a_s = a_s
self.num_total = len(s_s)
def __getitem__(self, index):
s = self.s_s[index]
a = self.a_s[index]
return s, a
def __len__(self):
return self.num_total
# -*- coding: utf-8 -*-
import torch
import os
import json
import zipfile
from convlab2.util.file_util import cached_path
from convlab2.policy.mle.mle import MLEAbstract
from convlab2.policy.rlmodule import MultiDiscretePolicy
from convlab2.policy.vector.vector_crosswoz import CrossWozVector
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_crosswoz.zip")
class MLE(MLEAbstract):
def __init__(self,
archive_file=DEFAULT_ARCHIVE_FILE,
model_file='https://convlab.blob.core.windows.net/convlab-2/mle_policy_crosswoz.zip'):
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
cfg = json.load(f)
voc_file = os.path.join(root_dir, 'data/crosswoz/sys_da_voc.json')
voc_opp_file = os.path.join(root_dir, 'data/crosswoz/usr_da_voc.json')
self.vector = CrossWozVector(sys_da_voc_json=voc_file, usr_da_voc_json=voc_opp_file)
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.sys_da_dim).to(device=DEVICE)
if not os.path.isfile(archive_file):
if not model_file:
raise Exception("No model for MLE Policy is specified!")
archive_file = cached_path(model_file)
model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'save')
if not os.path.exists(model_dir):
os.mkdir(model_dir)
if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')):
archive = zipfile.ZipFile(archive_file, 'r')
archive.extractall(model_dir)
self.load_from_pretrained(archive_file, model_file, cfg['load'])
import os
import torch
import logging
import torch.nn as nn
import json
import pickle
import sys
import random
import numpy as np
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
sys.path.append(root_dir)
from convlab2.policy.rlmodule import MultiDiscretePolicy
from convlab2.policy.vector.vector_crosswoz import CrossWozVector
from convlab2.policy.mle.crosswoz.loader import PolicyDataLoaderCrossWoz
from convlab2.util.train_util import to_device, init_logging_handler
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MLE_Trainer():
def __init__(self, manager, cfg):
self.data_train = manager.create_dataset('train', cfg['batchsz'])
self.data_valid = manager.create_dataset('val', cfg['batchsz'])
self.data_test = manager.create_dataset('test', cfg['batchsz'])
self.save_dir = cfg['save_dir']
self.print_per_batch = cfg['print_per_batch']
self.save_per_epoch = cfg['save_per_epoch']
voc_file = os.path.join(root_dir, 'data/crosswoz/sys_da_voc.json')
voc_opp_file = os.path.join(root_dir, 'data/crosswoz/usr_da_voc.json')
vector = CrossWozVector(voc_file, voc_opp_file)
self.policy = MultiDiscretePolicy(vector.state_dim, cfg['h_dim'], vector.sys_da_dim).to(device=DEVICE)
self.policy.eval()
self.policy_optim = torch.optim.Adam(self.policy.parameters(), lr=cfg['lr'])
self.multi_entropy_loss = nn.MultiLabelSoftMarginLoss()
def policy_loop(self, data):
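        # One supervised step: predict action logits for a batch of states and compute the multi-label loss.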
s, target_a = to_device(data)
a_weights = self.policy(s)
loss_a = self.multi_entropy_loss(a_weights, target_a)
return loss_a
def imitating(self, epoch):
"""
pretrain the policy by simple imitation learning (behavioral cloning)
"""
self.policy.train()
a_loss = 0.
for i, data in enumerate(self.data_train):
self.policy_optim.zero_grad()
loss_a = self.policy_loop(data)
a_loss += loss_a.item()
loss_a.backward()
self.policy_optim.step()
if (i + 1) % self.print_per_batch == 0:
a_loss /= self.print_per_batch
logging.debug('<<dialog policy>> epoch {}, iter {}, loss_a:{}'.format(epoch, i, a_loss))
a_loss = 0.
if (epoch + 1) % self.save_per_epoch == 0:
self.save(self.save_dir, epoch)
self.policy.eval()
def imit_test(self, epoch, best):
"""
evaluate the policy on held-out validation data (and report the test loss), tracking the best checkpoint
"""
a_loss = 0.
for i, data in enumerate(self.data_valid):
loss_a = self.policy_loop(data)
a_loss += loss_a.item()
a_loss /= len(self.data_valid)
logging.debug('<<dialog policy>> validation, epoch {}, loss_a:{}'.format(epoch, a_loss))
if a_loss < best:
logging.info('<<dialog policy>> best model saved')
best = a_loss
self.save(self.save_dir, 'best')
a_loss = 0.
for i, data in enumerate(self.data_test):
loss_a = self.policy_loop(data)
a_loss += loss_a.item()
a_loss /= len(self.data_test)
logging.debug('<<dialog policy>> test, epoch {}, loss_a:{}'.format(epoch, a_loss))
return best
def test(self):
def f1(a, target):
TP, FP, FN = 0, 0, 0
real = target.nonzero().tolist()
predict = a.nonzero().tolist()
# print(real)
# print(predict)
# print()
for item in real:
if item in predict:
TP += 1
else:
FN += 1
for item in predict:
if item not in real:
FP += 1
return TP, FP, FN
a_TP, a_FP, a_FN = 0, 0, 0
for i, data in enumerate(self.data_test):
s, target_a = to_device(data)
a_weights = self.policy(s)
a = a_weights.ge(0)
# TODO: fix batch F1
TP, FP, FN = f1(a, target_a)
a_TP += TP
a_FP += FP
a_FN += FN
prec = a_TP / (a_TP + a_FP)
rec = a_TP / (a_TP + a_FN)
F1 = 2 * prec * rec / (prec + rec)
print(a_TP, a_FP, a_FN, F1)
def save(self, directory, epoch):
if not os.path.exists(directory):
os.makedirs(directory)
torch.save(self.policy.state_dict(), directory + '/' + str(epoch) + '_mle.pol.mdl')
logging.info('<<dialog policy>> epoch {}: saved network to mdl'.format(epoch))
def load(self, filename='save/best'):
policy_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_mle.pol.mdl')
if os.path.exists(policy_mdl):
self.policy.load_state_dict(torch.load(policy_mdl))
if __name__ == '__main__':
random_seed = 2019
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
manager = PolicyDataLoaderCrossWoz()
with open('config.json', 'r') as f:
cfg = json.load(f)
init_logging_handler(cfg['log_dir'])
agent = MLE_Trainer(manager, cfg)
agent.load()
logging.debug('start training')
best = float('inf')
for e in range(cfg['epoch']):
agent.imitating(e)
best = agent.imit_test(e, best)
# agent.test() # 5731 1483 1880 0.7731534569983137
@@ -3,63 +3,87 @@ import pickle
import torch
import torch.utils.data as data
+from convlab2.policy.vector.vector_binary import VectorBinary
+from convlab2.util import load_policy_data, load_dataset
+from convlab2.util.custom_util import flatten_acts
from convlab2.util.multiwoz.state import default_state
from convlab2.policy.vector.dataset import ActDataset
-from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
-from convlab2.util.dataloader.module_dataloader import ActPolicyDataloader
-class ActMLEPolicyDataLoader:
+class PolicyDataVectorizer:
-    def __init__(self):
+    def __init__(self, dataset_name='multiwoz21', vector=None):
-        self.vector = None
+        self.dataset_name = dataset_name
+        if vector is None:
+            self.vector = VectorBinary(dataset_name)
+        else:
+            self.vector = vector
+        self.process_data()
-    def _build_data(self, root_dir, processed_dir):
+    def process_data(self):
+        processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                     f'processed_data/{self.dataset_name}_{type(self.vector).__name__}')
+        if os.path.exists(processed_dir):
+            print('Load processed data file')
+            self._load_data(processed_dir)
+        else:
+            print('Start preprocessing the dataset, this can take a while..')
+            self._build_data(processed_dir)
+    def _build_data(self, processed_dir):
        self.data = {}
-        print("Initialise DataLoader")
-        data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader())
-        raw_data_all = data_loader.load_data(data_key='all', role='sys')
-        for part in ['train', 'val', 'test']:
-            self.data[part] = []
-            raw_data = raw_data_all[part]
-            for belief_state, context_dialog_act, terminated, dialog_act, goal in \
-                    zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'],
-                        raw_data['dialog_act'], raw_data['goal']):
+        os.makedirs(processed_dir, exist_ok=True)
+        dataset = load_dataset(self.dataset_name)
+        data_split = load_policy_data(dataset, context_window_size=2)
+        for split in data_split:
+            self.data[split] = []
+            raw_data = data_split[split]
+            for data_point in raw_data:
                state = default_state()
-                state['belief_state'] = belief_state
-                state['user_action'] = context_dialog_act[-1]
-                state['system_action'] = context_dialog_act[-2] if len(context_dialog_act) > 1 else {}
-                state['terminated'] = terminated
-                action = dialog_act
-                self.data[part].append([self.vector.state_vectorize(state),
-                                        self.vector.action_vectorize(action)])
+                state['belief_state'] = data_point['context'][-1]['state']
+                state['user_action'] = flatten_acts(data_point['context'][-1]['dialogue_acts'])
+                last_system_act = data_point['context'][-2]['dialogue_acts'] \
+                    if len(data_point['context']) > 1 else {}
+                state['system_action'] = flatten_acts(last_system_act)
+                state['terminated'] = data_point['terminated']
+                state['booked'] = data_point['booked']
+                dialogue_act = flatten_acts(data_point['dialogue_acts'])
+                vectorized_state, mask = self.vector.state_vectorize(state)
+                vectorized_action = self.vector.action_vectorize(dialogue_act)
+                self.data[split].append({"state": vectorized_state, "action": vectorized_action, "mask": mask})
-        os.makedirs(processed_dir)
-        for part in ['train', 'val', 'test']:
-            with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'wb') as f:
-                pickle.dump(self.data[part], f)
+            with open(os.path.join(processed_dir, '{}.pkl'.format(split)), 'wb') as f:
+                pickle.dump(self.data[split], f)
+        print("Data processing done.")
    def _load_data(self, processed_dir):
        self.data = {}
-        for part in ['train', 'val', 'test']:
+        for part in ['train', 'validation', 'test']:
            with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'rb') as f:
                self.data[part] = pickle.load(f)
    def create_dataset(self, part, batchsz):
-        print('Start creating {} dataset'.format(part))
-        s = []
-        a = []
-        m = []
+        states = []
+        actions = []
+        masks = []
        for item in self.data[part]:
-            s.append(torch.Tensor(item[0][0]))
-            a.append(torch.Tensor(item[1]))
-            m.append(torch.zeros(len(item[1])))
+            states.append(torch.Tensor(item['state']))
+            actions.append(torch.Tensor(item['action']))
+            masks.append(torch.Tensor(item['mask']))
-        s = torch.stack(s)
-        a = torch.stack(a)
-        m = torch.stack(m)
+        s = torch.stack(states)
+        a = torch.stack(actions)
+        m = torch.stack(masks)
        dataset = ActDataset(s, a, m)
        dataloader = data.DataLoader(dataset, batchsz, True)
-        print('Finish creating {} dataset'.format(part))
        return dataloader
+if __name__ == '__main__':
+    data_loader = PolicyDataVectorizer()
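For orientation, this is roughly how the new loader is meant to be used from the MLE trainer; the import path and batch size below are assumptions, not part of this diff:

```python
# Sketch only: module path assumed to be convlab2/policy/mle/loader.py.
from convlab2.policy.mle.loader import PolicyDataVectorizer

# Vectorizes the unified-format dataset on first use and caches pickles
# under processed_data/<dataset>_<vectorizer>/.
manager = PolicyDataVectorizer(dataset_name='multiwoz21')

# Each item carries a vectorized state, a multi-hot action and an action mask.
train_loader = manager.create_dataset('train', 32)
for states, actions, masks in train_loader:
    print(states.shape, actions.shape, masks.shape)
    break
```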
# -*- coding: utf-8 -*-
+import zipfile
+import logging
import torch
import os
-import zipfile
+import json
from convlab2.policy.policy import Policy
from convlab2.util.file_util import cached_path
-import logging
+from convlab2.policy.rlmodule import MultiDiscretePolicy
+from convlab2.policy.vector.vector_binary import VectorBinary
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
+DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_multiwoz.zip")
class MLEAbstract(Policy):
-    def __init__(self, archive_file, model_file):
+    def __init__(self, vector, policy):
-        self.vector = None
+        self.vector = vector
-        self.policy = None
+        self.policy = policy
    def predict(self, state):
        """
@@ -67,3 +74,34 @@ class MLEAbstract(Policy):
                self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
                logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
                break
+class MLE(MLEAbstract):
+    def __init__(self):
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+            cfg = json.load(f)
+        self.vector = VectorBinary()
+        self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
+    @classmethod
+    def from_pretrained(cls,
+                        archive_file=DEFAULT_ARCHIVE_FILE,
+                        model_file='https://convlab.blob.core.windows.net/convlab-2/mle_policy_multiwoz.zip'):
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+            cfg = json.load(f)
+        model = cls()
+        model.load_from_pretrained(archive_file, model_file, cfg['load'])
+        return model
+class MLEPolicy(MLE):
+    def __init__(self,
+                 archive_file=DEFAULT_ARCHIVE_FILE,
+                 model_file='https://convlab.blob.core.windows.net/convlab-2/mle_policy_multiwoz.zip'):
+        super().__init__()
+        if model_file:
+            with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+                cfg = json.load(f)
+            self.load_from_pretrained(archive_file, model_file, cfg['load'])
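After this refactor the MultiWOZ-specific wrapper is gone; a rough usage sketch of the unified `MLE` class (import path taken from the evaluate.py hunk above, checkpoint paths are examples):

```python
from convlab2.policy.mle import MLE

# Builds VectorBinary and a MultiDiscretePolicy from the local config.json.
policy = MLE()

# Load a trained checkpoint, e.g. the one produced by train.py:
# policy.load('save/best')            # example path; matches cfg['load'] in config.json
# ... or fetch the published weights instead:
# policy = MLE.from_pretrained()

# predict() (inherited from MLEAbstract) maps a dialog state to a system dialog act:
# system_action = policy.predict(state)   # state comes from a DST during a dialog
```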