From 5ec83abba8c33f1881615f589300a0b7a72e683f Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Tue, 15 Feb 2022 09:06:39 +0100
Subject: [PATCH] working MLE version with remapped data set, getting F1-score
 of 0.52

---
 convlab2/policy/mle/loader.py                 |   4 +-
 ...d568-7cb7-4ba9-96d4-9baa7357316e.fritz.box | Bin 0 -> 40 bytes
 .../configs/config_saved.json                 |   1 +
 convlab2/policy/mle/multiwoz/loader.py        |   4 +-
 convlab2/policy/mle/multiwoz/mle.py           |   1 +
 .../rule/multiwoz/policy_agenda_multiwoz.py   |   5 +
 convlab2/policy/vector/vector_base.py         |   2 +-
 .../util/dataloader/dataset_dataloader.py     |  97 +++++++++---------
 data/multiwoz/remap_actions.py                |   2 +-
 9 files changed, 62 insertions(+), 54 deletions(-)
 create mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box
 create mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json

diff --git a/convlab2/policy/mle/loader.py b/convlab2/policy/mle/loader.py
index 8bea164b..349783d8 100755
--- a/convlab2/policy/mle/loader.py
+++ b/convlab2/policy/mle/loader.py
@@ -16,10 +16,12 @@ class ActMLEPolicyDataLoader:
         
     def _build_data(self, root_dir, processed_dir):
         self.data = {}
+        print("Initialise DataLoader")
         data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader())
+        raw_data_all = data_loader.load_data(data_key='all', role='sys')
         for part in ['train', 'val', 'test']:
             self.data[part] = []
-            raw_data = data_loader.load_data(data_key=part, role='sys')[part]
+            raw_data = raw_data_all[part]
             
             for belief_state, context_dialog_act, terminated, dialog_act, goal in \
                 zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'],
diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box
new file mode 100644
index 0000000000000000000000000000000000000000..a44089b2ed8b2a679ca9b81225aaa574230d0570
GIT binary patch
literal 40
rcmb1OfPlsI-b$Pk(%y3{ZMxwo#hX-=n3<>NT9%quVrBHADKZfN)u#;6

literal 0
HcmV?d00001

diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json
new file mode 100644
index 00000000..377cab63
--- /dev/null
+++ b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json
@@ -0,0 +1 @@
+{"args": {"seed": 0, "eval_freq": 1}, "config": {"batchsz": 32, "epoch": 24, "lr_supervised": 0.0001, "save_dir": "save", "log_dir": "log", "print_per_batch": 400, "save_per_epoch": 1, "h_dim": 100, "load": "save/best", "pos_weight": 5, "hidden_size": 256, "weight_decay": 1e-05, "lambda": 1, "tau": 0.005, "policy_freq": 2, "entropy_weight": 0.001}}
\ No newline at end of file
diff --git a/convlab2/policy/mle/multiwoz/loader.py b/convlab2/policy/mle/multiwoz/loader.py
index bc946761..13ea8f35 100755
--- a/convlab2/policy/mle/multiwoz/loader.py
+++ b/convlab2/policy/mle/multiwoz/loader.py
@@ -7,13 +7,11 @@ class ActMLEPolicyDataLoaderMultiWoz(ActMLEPolicyDataLoader):
 
     def __init__(self, vectoriser=None):
         root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
-        voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
-        voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
         if vectoriser:
             self.vector = vectoriser
         else:
             print("We use vanilla Vectoriser")
-            self.vector = MultiWozVector(voc_file, voc_opp_file)
+            self.vector = MultiWozVector()
 
         processed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'processed_data')
 
diff --git a/convlab2/policy/mle/multiwoz/mle.py b/convlab2/policy/mle/multiwoz/mle.py
index f0377524..b614b55e 100755
--- a/convlab2/policy/mle/multiwoz/mle.py
+++ b/convlab2/policy/mle/multiwoz/mle.py
@@ -11,6 +11,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
 DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "mle_policy_multiwoz.zip")
 
+
 class MLE(MLEAbstract):
     
     def __init__(self):
diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
index f5c5aa58..13994c67 100755
--- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -486,6 +486,7 @@ class Agenda(object):
                 continue
 
             slot_vals = sys_action[diaact]
+            #TODO: use string "book" instead of "booking"
             if 'booking' in diaact:
                 if self.update_booking(diaact, slot_vals, goal):
                     return
@@ -502,6 +503,7 @@ class Agenda(object):
                     if slot == 'name':
                         self._remove_item(diaact.split(
                             '-')[0]+'-inform', 'choice')
+            # TODO: use string "book" instead of "booking"
             if 'booking' in diaact and self.cur_domain:
                 g_book = self._get_goal_infos(self.cur_domain, goal)[-2]
                 if len(g_book) == 0:
@@ -533,6 +535,7 @@ class Agenda(object):
         :param goal:        Goal
         :return:            True:user want to close the session. False:session is continue
         """
+        #TODO: Use domain of diaact.split instead of current domain
         _, intent = diaact.split('-')
         domain = self.cur_domain
         self.domains['update_booking'] = domain
@@ -540,6 +543,7 @@ class Agenda(object):
         if domain not in goal.domains:
             isover = False
 
+        #TODO: Remove inform
         elif intent in ['book', 'inform']:
             isover = self._handle_inform(domain, intent, slot_vals, goal)
 
@@ -682,6 +686,7 @@ class Agenda(object):
                 self._push_item(domain + '-inform', slot, g_book[slot])
                 info_right = False
 
+        #TODO: Only use "book"
         if intent in ['book', 'offerbooked'] and info_right:
             # booked ok
             if 'booked' in goal.domain_goals[domain]:
diff --git a/convlab2/policy/vector/vector_base.py b/convlab2/policy/vector/vector_base.py
index 48e0529e..d040e6e5 100644
--- a/convlab2/policy/vector/vector_base.py
+++ b/convlab2/policy/vector/vector_base.py
@@ -57,7 +57,7 @@ class MultiWozVectorBase(Vector):
 
         if not voc_file or not voc_opp_file:
             voc_file = os.path.join(
-                root_dir, 'data/multiwoz/sys_da_voc.txt')
+                root_dir, 'data/multiwoz/sys_da_voc_remapped.txt')
             voc_opp_file = os.path.join(
                 root_dir, 'data/multiwoz/usr_da_voc.txt')
 
diff --git a/convlab2/util/dataloader/dataset_dataloader.py b/convlab2/util/dataloader/dataset_dataloader.py
index 12c06b14..5387320f 100755
--- a/convlab2/util/dataloader/dataset_dataloader.py
+++ b/convlab2/util/dataloader/dataset_dataloader.py
@@ -3,6 +3,7 @@ Dataloader base class. Every dataset should inherit this class and implement its
 """
 from abc import ABC, abstractmethod
 import os
+from zipfile import ZipFile
 import json
 import sys
 import zipfile
@@ -69,57 +70,57 @@ class MultiWOZDataloader(DatasetDataloader):
                                        'terminated', 'goal']))
 
         self.data = {'train': {}, 'val': {}, 'test': {}, 'role': role, 'human_val': {}}
-        if data_key == 'all':
-            data_key_list = ['train', 'val', 'test']
-        else:
-            data_key_list = [data_key]
-        for data_key in data_key_list:
-            data = read_zipped_json(os.path.join(data_dir, '{}.json.zip'.format(data_key)), '{}.json'.format(data_key))
-            print('loaded {}, size {}'.format(data_key, len(data)))
+
+        archive = ZipFile(os.path.join(data_dir, 'data.zip'))
+        archive.extractall()
+        data = json.load(open(os.path.join(data_dir, 'data/data.json')))
+
+        for k in ['train', 'test', 'val']:
             for x in info_list:
-                self.data[data_key][x] = []
-            for sess_id, sess in data.items():
-                cur_context = []
-                cur_context_dialog_act = []
-                entity_booked_dict = dict((domain, False) for domain in belief_domains)
-                for i, turn in enumerate(sess['log']):
-                    text = turn['text']
-                    da = da2tuples(turn['dialog_act'])
-                    if role == 'sys' and i % 2 == 0:
-                        cur_context.append(text)
-                        cur_context_dialog_act.append(da)
-                        continue
-                    elif role == 'usr' and i % 2 == 1:
-                        cur_context.append(text)
-                        cur_context_dialog_act.append(da)
-                        continue
-                    if utterance:
-                        self.data[data_key]['utterance'].append(text)
-                    if dialog_act:
-                        self.data[data_key]['dialog_act'].append(da)
-                    if context:
-                        self.data[data_key]['context'].append(cur_context[-context_window_size:])
-                    if context_dialog_act:
-                        self.data[data_key]['context_dialog_act'].append(cur_context_dialog_act[-context_window_size:])
-                    if belief_state:
-                        entity_booked_dict, fixed_bs = self.fix_entity_booked_info(entity_booked_dict, turn['metadata'])
-                        self.data[data_key]['belief_state'].append(fixed_bs)
-                    if last_opponent_utterance:
-                        self.data[data_key]['last_opponent_utterance'].append(
-                            cur_context[-1] if len(cur_context) >= 1 else '')
-                    if last_self_utterance:
-                        self.data[data_key]['last_self_utterance'].append(
-                            cur_context[-2] if len(cur_context) >= 2 else '')
-                    if session_id:
-                        self.data[data_key]['session_id'].append(sess_id)
-                    if span_info:
-                        self.data[data_key]['span_info'].append(turn['span_info'])
-                    if terminated:
-                        self.data[data_key]['terminated'].append(i + 2 >= len(sess['log']))
-                    if goal:
-                        self.data[data_key]['goal'].append(sess['goal'])
+                self.data[k][x] = []
+        for sess_id, sess in data.items():
+            data_key = sess['split']
+            cur_context = []
+            cur_context_dialog_act = []
+            entity_booked_dict = dict((domain, False) for domain in belief_domains)
+            for i, turn in enumerate(sess['log']):
+                text = turn['text']
+                da = da2tuples(turn.get('dialog_act', {}))
+                if role == 'sys' and i % 2 == 0:
+                    cur_context.append(text)
+                    cur_context_dialog_act.append(da)
+                    continue
+                elif role == 'usr' and i % 2 == 1:
                     cur_context.append(text)
                     cur_context_dialog_act.append(da)
+                    continue
+                if utterance:
+                    self.data[data_key]['utterance'].append(text)
+                if dialog_act:
+                    self.data[data_key]['dialog_act'].append(da)
+                if context:
+                    self.data[data_key]['context'].append(cur_context[-context_window_size:])
+                if context_dialog_act:
+                    self.data[data_key]['context_dialog_act'].append(cur_context_dialog_act[-context_window_size:])
+                if belief_state:
+                    entity_booked_dict, fixed_bs = self.fix_entity_booked_info(entity_booked_dict, turn['metadata'])
+                    self.data[data_key]['belief_state'].append(fixed_bs)
+                if last_opponent_utterance:
+                    self.data[data_key]['last_opponent_utterance'].append(
+                        cur_context[-1] if len(cur_context) >= 1 else '')
+                if last_self_utterance:
+                    self.data[data_key]['last_self_utterance'].append(
+                        cur_context[-2] if len(cur_context) >= 2 else '')
+                if session_id:
+                    self.data[data_key]['session_id'].append(sess_id)
+                if span_info:
+                    self.data[data_key]['span_info'].append(turn['span_info'])
+                if terminated:
+                    self.data[data_key]['terminated'].append(i + 2 >= len(sess['log']))
+                if goal:
+                    self.data[data_key]['goal'].append(sess['goal'])
+                cur_context.append(text)
+                cur_context_dialog_act.append(da)
         if ontology:
             ontology_path = os.path.join(data_dir, 'ontology.json')
             self.data['ontology'] = json.load(open(ontology_path))
diff --git a/data/multiwoz/remap_actions.py b/data/multiwoz/remap_actions.py
index ab7e48ad..097ebce6 100644
--- a/data/multiwoz/remap_actions.py
+++ b/data/multiwoz/remap_actions.py
@@ -217,7 +217,7 @@ def preprocess():
 
     for ori_dialog_id, ori_dialog in tqdm(original_data.items()):
         if ori_dialog_id in val_list:
-            split = 'validation'
+            split = 'val'
         elif ori_dialog_id in test_list:
             split = 'test'
         else:
-- 
GitLab