From 5ec83abba8c33f1881615f589300a0b7a72e683f Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Tue, 15 Feb 2022 09:06:39 +0100 Subject: [PATCH] working MLE version with remapped data set, getting F1-score of 0.52 --- convlab2/policy/mle/loader.py | 4 +- ...d568-7cb7-4ba9-96d4-9baa7357316e.fritz.box | Bin 0 -> 40 bytes .../configs/config_saved.json | 1 + convlab2/policy/mle/multiwoz/loader.py | 4 +- convlab2/policy/mle/multiwoz/mle.py | 1 + .../rule/multiwoz/policy_agenda_multiwoz.py | 5 + convlab2/policy/vector/vector_base.py | 2 +- .../util/dataloader/dataset_dataloader.py | 97 +++++++++--------- data/multiwoz/remap_actions.py | 2 +- 9 files changed, 62 insertions(+), 54 deletions(-) create mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box create mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json diff --git a/convlab2/policy/mle/loader.py b/convlab2/policy/mle/loader.py index 8bea164b..349783d8 100755 --- a/convlab2/policy/mle/loader.py +++ b/convlab2/policy/mle/loader.py @@ -16,10 +16,12 @@ class ActMLEPolicyDataLoader: def _build_data(self, root_dir, processed_dir): self.data = {} + print("Initialise DataLoader") data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader()) + raw_data_all = data_loader.load_data(data_key='all', role='sys') for part in ['train', 'val', 'test']: self.data[part] = [] - raw_data = data_loader.load_data(data_key=part, role='sys')[part] + raw_data = raw_data_all[part] for belief_state, context_dialog_act, terminated, dialog_act, goal in \ zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box new file mode 100644 index 0000000000000000000000000000000000000000..a44089b2ed8b2a679ca9b81225aaa574230d0570 GIT binary patch literal 40 rcmb1OfPlsI-b$Pk(%y3{ZMxwo#hX-=n3<>NT9%quVrBHADKZfN)u#;6 literal 0 HcmV?d00001 diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json new file mode 100644 index 00000000..377cab63 --- /dev/null +++ b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json @@ -0,0 +1 @@ +{"args": {"seed": 0, "eval_freq": 1}, "config": {"batchsz": 32, "epoch": 24, "lr_supervised": 0.0001, "save_dir": "save", "log_dir": "log", "print_per_batch": 400, "save_per_epoch": 1, "h_dim": 100, "load": "save/best", "pos_weight": 5, "hidden_size": 256, "weight_decay": 1e-05, "lambda": 1, "tau": 0.005, "policy_freq": 2, "entropy_weight": 0.001}} \ No newline at end of file diff --git a/convlab2/policy/mle/multiwoz/loader.py b/convlab2/policy/mle/multiwoz/loader.py index bc946761..13ea8f35 100755 --- a/convlab2/policy/mle/multiwoz/loader.py +++ b/convlab2/policy/mle/multiwoz/loader.py @@ -7,13 +7,11 @@ class ActMLEPolicyDataLoaderMultiWoz(ActMLEPolicyDataLoader): def __init__(self, vectoriser=None): root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) - voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt') - voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt') if vectoriser: self.vector = vectoriser else: print("We use vanilla Vectoriser") - self.vector = MultiWozVector(voc_file, voc_opp_file) + self.vector = MultiWozVector() processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data') diff --git a/convlab2/policy/mle/multiwoz/mle.py b/convlab2/policy/mle/multiwoz/mle.py index f0377524..b614b55e 100755 --- a/convlab2/policy/mle/multiwoz/mle.py +++ b/convlab2/policy/mle/multiwoz/mle.py @@ -11,6 +11,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_multiwoz.zip") + class MLE(MLEAbstract): def __init__(self): diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index f5c5aa58..13994c67 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -486,6 +486,7 @@ class Agenda(object): continue slot_vals = sys_action[diaact] + #TODO: use string "book" instead of "booking" if 'booking' in diaact: if self.update_booking(diaact, slot_vals, goal): return @@ -502,6 +503,7 @@ class Agenda(object): if slot == 'name': self._remove_item(diaact.split( '-')[0]+'-inform', 'choice') + # TODO: use string "book" instead of "booking" if 'booking' in diaact and self.cur_domain: g_book = self._get_goal_infos(self.cur_domain, goal)[-2] if len(g_book) == 0: @@ -533,6 +535,7 @@ class Agenda(object): :param goal: Goal :return: True:user want to close the session. False:session is continue """ + #TODO: Use domain of diaact.split instead of current domain _, intent = diaact.split('-') domain = self.cur_domain self.domains['update_booking'] = domain @@ -540,6 +543,7 @@ class Agenda(object): if domain not in goal.domains: isover = False + #TODO: Remove inform elif intent in ['book', 'inform']: isover = self._handle_inform(domain, intent, slot_vals, goal) @@ -682,6 +686,7 @@ class Agenda(object): self._push_item(domain + '-inform', slot, g_book[slot]) info_right = False + #TODO: Only use "book" if intent in ['book', 'offerbooked'] and info_right: # booked ok if 'booked' in goal.domain_goals[domain]: diff --git a/convlab2/policy/vector/vector_base.py b/convlab2/policy/vector/vector_base.py index 48e0529e..d040e6e5 100644 --- a/convlab2/policy/vector/vector_base.py +++ b/convlab2/policy/vector/vector_base.py @@ -57,7 +57,7 @@ class MultiWozVectorBase(Vector): if not voc_file or not voc_opp_file: voc_file = os.path.join( - root_dir, 'data/multiwoz/sys_da_voc.txt') + root_dir, 'data/multiwoz/sys_da_voc_remapped.txt') voc_opp_file = os.path.join( root_dir, 'data/multiwoz/usr_da_voc.txt') diff --git a/convlab2/util/dataloader/dataset_dataloader.py b/convlab2/util/dataloader/dataset_dataloader.py index 12c06b14..5387320f 100755 --- a/convlab2/util/dataloader/dataset_dataloader.py +++ b/convlab2/util/dataloader/dataset_dataloader.py @@ -3,6 +3,7 @@ Dataloader base class. Every dataset should inherit this class and implement its """ from abc import ABC, abstractmethod import os +from zipfile import ZipFile import json import sys import zipfile @@ -69,57 +70,57 @@ class MultiWOZDataloader(DatasetDataloader): 'terminated', 'goal'])) self.data = {'train': {}, 'val': {}, 'test': {}, 'role': role, 'human_val': {}} - if data_key == 'all': - data_key_list = ['train', 'val', 'test'] - else: - data_key_list = [data_key] - for data_key in data_key_list: - data = read_zipped_json(os.path.join(data_dir, '{}.json.zip'.format(data_key)), '{}.json'.format(data_key)) - print('loaded {}, size {}'.format(data_key, len(data))) + + archive = ZipFile(os.path.join(data_dir, 'data.zip')) + archive.extractall() + data = json.load(open(os.path.join(data_dir, 'data/data.json'))) + + for k in ['train', 'test', 'val']: for x in info_list: - self.data[data_key][x] = [] - for sess_id, sess in data.items(): - cur_context = [] - cur_context_dialog_act = [] - entity_booked_dict = dict((domain, False) for domain in belief_domains) - for i, turn in enumerate(sess['log']): - text = turn['text'] - da = da2tuples(turn['dialog_act']) - if role == 'sys' and i % 2 == 0: - cur_context.append(text) - cur_context_dialog_act.append(da) - continue - elif role == 'usr' and i % 2 == 1: - cur_context.append(text) - cur_context_dialog_act.append(da) - continue - if utterance: - self.data[data_key]['utterance'].append(text) - if dialog_act: - self.data[data_key]['dialog_act'].append(da) - if context: - self.data[data_key]['context'].append(cur_context[-context_window_size:]) - if context_dialog_act: - self.data[data_key]['context_dialog_act'].append(cur_context_dialog_act[-context_window_size:]) - if belief_state: - entity_booked_dict, fixed_bs = self.fix_entity_booked_info(entity_booked_dict, turn['metadata']) - self.data[data_key]['belief_state'].append(fixed_bs) - if last_opponent_utterance: - self.data[data_key]['last_opponent_utterance'].append( - cur_context[-1] if len(cur_context) >= 1 else '') - if last_self_utterance: - self.data[data_key]['last_self_utterance'].append( - cur_context[-2] if len(cur_context) >= 2 else '') - if session_id: - self.data[data_key]['session_id'].append(sess_id) - if span_info: - self.data[data_key]['span_info'].append(turn['span_info']) - if terminated: - self.data[data_key]['terminated'].append(i + 2 >= len(sess['log'])) - if goal: - self.data[data_key]['goal'].append(sess['goal']) + self.data[k][x] = [] + for sess_id, sess in data.items(): + data_key = sess['split'] + cur_context = [] + cur_context_dialog_act = [] + entity_booked_dict = dict((domain, False) for domain in belief_domains) + for i, turn in enumerate(sess['log']): + text = turn['text'] + da = da2tuples(turn.get('dialog_act', {})) + if role == 'sys' and i % 2 == 0: + cur_context.append(text) + cur_context_dialog_act.append(da) + continue + elif role == 'usr' and i % 2 == 1: cur_context.append(text) cur_context_dialog_act.append(da) + continue + if utterance: + self.data[data_key]['utterance'].append(text) + if dialog_act: + self.data[data_key]['dialog_act'].append(da) + if context: + self.data[data_key]['context'].append(cur_context[-context_window_size:]) + if context_dialog_act: + self.data[data_key]['context_dialog_act'].append(cur_context_dialog_act[-context_window_size:]) + if belief_state: + entity_booked_dict, fixed_bs = self.fix_entity_booked_info(entity_booked_dict, turn['metadata']) + self.data[data_key]['belief_state'].append(fixed_bs) + if last_opponent_utterance: + self.data[data_key]['last_opponent_utterance'].append( + cur_context[-1] if len(cur_context) >= 1 else '') + if last_self_utterance: + self.data[data_key]['last_self_utterance'].append( + cur_context[-2] if len(cur_context) >= 2 else '') + if session_id: + self.data[data_key]['session_id'].append(sess_id) + if span_info: + self.data[data_key]['span_info'].append(turn['span_info']) + if terminated: + self.data[data_key]['terminated'].append(i + 2 >= len(sess['log'])) + if goal: + self.data[data_key]['goal'].append(sess['goal']) + cur_context.append(text) + cur_context_dialog_act.append(da) if ontology: ontology_path = os.path.join(data_dir, 'ontology.json') self.data['ontology'] = json.load(open(ontology_path)) diff --git a/data/multiwoz/remap_actions.py b/data/multiwoz/remap_actions.py index ab7e48ad..097ebce6 100644 --- a/data/multiwoz/remap_actions.py +++ b/data/multiwoz/remap_actions.py @@ -217,7 +217,7 @@ def preprocess(): for ori_dialog_id, ori_dialog in tqdm(original_data.items()): if ori_dialog_id in val_list: - split = 'validation' + split = 'val' elif ori_dialog_id in test_list: split = 'test' else: -- GitLab