Skip to content
Snippets Groups Projects
Commit 5ec83abb authored by Christian
Browse files

working MLE version with remapped data set, getting F1-score of 0.52

parent 0f4ecebf
No related branches found
No related tags found
No related merge requests found
...@@ -16,10 +16,12 @@ class ActMLEPolicyDataLoader: ...@@ -16,10 +16,12 @@ class ActMLEPolicyDataLoader:
def _build_data(self, root_dir, processed_dir): def _build_data(self, root_dir, processed_dir):
self.data = {} self.data = {}
print("Initialise DataLoader")
data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader()) data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader())
raw_data_all = data_loader.load_data(data_key='all', role='sys')
for part in ['train', 'val', 'test']: for part in ['train', 'val', 'test']:
self.data[part] = [] self.data[part] = []
raw_data = data_loader.load_data(data_key=part, role='sys')[part] raw_data = raw_data_all[part]
for belief_state, context_dialog_act, terminated, dialog_act, goal in \ for belief_state, context_dialog_act, terminated, dialog_act, goal in \
zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'],
......
{"args": {"seed": 0, "eval_freq": 1}, "config": {"batchsz": 32, "epoch": 24, "lr_supervised": 0.0001, "save_dir": "save", "log_dir": "log", "print_per_batch": 400, "save_per_epoch": 1, "h_dim": 100, "load": "save/best", "pos_weight": 5, "hidden_size": 256, "weight_decay": 1e-05, "lambda": 1, "tau": 0.005, "policy_freq": 2, "entropy_weight": 0.001}}
\ No newline at end of file
...@@ -7,13 +7,11 @@ class ActMLEPolicyDataLoaderMultiWoz(ActMLEPolicyDataLoader): ...@@ -7,13 +7,11 @@ class ActMLEPolicyDataLoaderMultiWoz(ActMLEPolicyDataLoader):
def __init__(self, vectoriser=None): def __init__(self, vectoriser=None):
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
if vectoriser: if vectoriser:
self.vector = vectoriser self.vector = vectoriser
else: else:
print("We use vanilla Vectoriser") print("We use vanilla Vectoriser")
self.vector = MultiWozVector(voc_file, voc_opp_file) self.vector = MultiWozVector()
processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data') processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data')
......
...@@ -11,6 +11,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") ...@@ -11,6 +11,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_multiwoz.zip") DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_multiwoz.zip")
class MLE(MLEAbstract): class MLE(MLEAbstract):
def __init__(self): def __init__(self):
......
...@@ -486,6 +486,7 @@ class Agenda(object): ...@@ -486,6 +486,7 @@ class Agenda(object):
continue continue
slot_vals = sys_action[diaact] slot_vals = sys_action[diaact]
#TODO: use string "book" instead of "booking"
if 'booking' in diaact: if 'booking' in diaact:
if self.update_booking(diaact, slot_vals, goal): if self.update_booking(diaact, slot_vals, goal):
return return
...@@ -502,6 +503,7 @@ class Agenda(object): ...@@ -502,6 +503,7 @@ class Agenda(object):
if slot == 'name': if slot == 'name':
self._remove_item(diaact.split( self._remove_item(diaact.split(
'-')[0]+'-inform', 'choice') '-')[0]+'-inform', 'choice')
# TODO: use string "book" instead of "booking"
if 'booking' in diaact and self.cur_domain: if 'booking' in diaact and self.cur_domain:
g_book = self._get_goal_infos(self.cur_domain, goal)[-2] g_book = self._get_goal_infos(self.cur_domain, goal)[-2]
if len(g_book) == 0: if len(g_book) == 0:
...@@ -533,6 +535,7 @@ class Agenda(object): ...@@ -533,6 +535,7 @@ class Agenda(object):
:param goal: Goal :param goal: Goal
:return: True:user want to close the session. False:session is continue :return: True:user want to close the session. False:session is continue
""" """
#TODO: Use domain of diaact.split instead of current domain
_, intent = diaact.split('-') _, intent = diaact.split('-')
domain = self.cur_domain domain = self.cur_domain
self.domains['update_booking'] = domain self.domains['update_booking'] = domain
...@@ -540,6 +543,7 @@ class Agenda(object): ...@@ -540,6 +543,7 @@ class Agenda(object):
if domain not in goal.domains: if domain not in goal.domains:
isover = False isover = False
#TODO: Remove inform
elif intent in ['book', 'inform']: elif intent in ['book', 'inform']:
isover = self._handle_inform(domain, intent, slot_vals, goal) isover = self._handle_inform(domain, intent, slot_vals, goal)
...@@ -682,6 +686,7 @@ class Agenda(object): ...@@ -682,6 +686,7 @@ class Agenda(object):
self._push_item(domain + '-inform', slot, g_book[slot]) self._push_item(domain + '-inform', slot, g_book[slot])
info_right = False info_right = False
#TODO: Only use "book"
if intent in ['book', 'offerbooked'] and info_right: if intent in ['book', 'offerbooked'] and info_right:
# booked ok # booked ok
if 'booked' in goal.domain_goals[domain]: if 'booked' in goal.domain_goals[domain]:
......
...@@ -57,7 +57,7 @@ class MultiWozVectorBase(Vector): ...@@ -57,7 +57,7 @@ class MultiWozVectorBase(Vector):
if not voc_file or not voc_opp_file: if not voc_file or not voc_opp_file:
voc_file = os.path.join( voc_file = os.path.join(
root_dir, 'data/multiwoz/sys_da_voc.txt') root_dir, 'data/multiwoz/sys_da_voc_remapped.txt')
voc_opp_file = os.path.join( voc_opp_file = os.path.join(
root_dir, 'data/multiwoz/usr_da_voc.txt') root_dir, 'data/multiwoz/usr_da_voc.txt')
......
...@@ -3,6 +3,7 @@ Dataloader base class. Every dataset should inherit this class and implement its ...@@ -3,6 +3,7 @@ Dataloader base class. Every dataset should inherit this class and implement its
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import os import os
from zipfile import ZipFile
import json import json
import sys import sys
import zipfile import zipfile
...@@ -69,22 +70,22 @@ class MultiWOZDataloader(DatasetDataloader): ...@@ -69,22 +70,22 @@ class MultiWOZDataloader(DatasetDataloader):
'terminated', 'goal'])) 'terminated', 'goal']))
self.data = {'train': {}, 'val': {}, 'test': {}, 'role': role, 'human_val': {}} self.data = {'train': {}, 'val': {}, 'test': {}, 'role': role, 'human_val': {}}
if data_key == 'all':
data_key_list = ['train', 'val', 'test'] archive = ZipFile(os.path.join(data_dir, 'data.zip'))
else: archive.extractall()
data_key_list = [data_key] data = json.load(open(os.path.join(data_dir, 'data/data.json')))
for data_key in data_key_list:
data = read_zipped_json(os.path.join(data_dir, '{}.json.zip'.format(data_key)), '{}.json'.format(data_key)) for k in ['train', 'test', 'val']:
print('loaded {}, size {}'.format(data_key, len(data)))
for x in info_list: for x in info_list:
self.data[data_key][x] = [] self.data[k][x] = []
for sess_id, sess in data.items(): for sess_id, sess in data.items():
data_key = sess['split']
cur_context = [] cur_context = []
cur_context_dialog_act = [] cur_context_dialog_act = []
entity_booked_dict = dict((domain, False) for domain in belief_domains) entity_booked_dict = dict((domain, False) for domain in belief_domains)
for i, turn in enumerate(sess['log']): for i, turn in enumerate(sess['log']):
text = turn['text'] text = turn['text']
da = da2tuples(turn['dialog_act']) da = da2tuples(turn.get('dialog_act', {}))
if role == 'sys' and i % 2 == 0: if role == 'sys' and i % 2 == 0:
cur_context.append(text) cur_context.append(text)
cur_context_dialog_act.append(da) cur_context_dialog_act.append(da)
......
...@@ -217,7 +217,7 @@ def preprocess(): ...@@ -217,7 +217,7 @@ def preprocess():
for ori_dialog_id, ori_dialog in tqdm(original_data.items()): for ori_dialog_id, ori_dialog in tqdm(original_data.items()):
if ori_dialog_id in val_list: if ori_dialog_id in val_list:
split = 'validation' split = 'val'
elif ori_dialog_id in test_list: elif ori_dialog_id in test_list:
split = 'test' split = 'test'
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment