From 027f3425635096d0bd57684e0392ac6b11db24b9 Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Tue, 15 Feb 2022 16:10:03 +0100
Subject: [PATCH] First version that converts MultiWOZ data, trains a
 supervised model, and evaluates it with a simulated user; F1-score is 0.52
 and success rate is 73%

---
 .gitignore                                    |   3 +-
 convlab2/dialog_agent/agent.py                |  10 ++-----
 convlab2/evaluator/multiwoz_eval.py           |  27 ++++++++----------
 convlab2/policy/evaluate.py                   |   3 +-
 ...d568-7cb7-4ba9-96d4-9baa7357316e.fritz.box | Bin 40 -> 0 bytes
 .../configs/config_saved.json                 |   1 -
 .../policy/ppo/semantic_level_config.json     |   4 +--
 .../rule/multiwoz/policy_agenda_multiwoz.py   |  14 +++------
 convlab2/util/multiwoz/lexicalize.py          |   8 ++++++
 9 files changed, 32 insertions(+), 38 deletions(-)
 delete mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box
 delete mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json

diff --git a/.gitignore b/.gitignore
index a2820f1e..4e467920 100644
--- a/.gitignore
+++ b/.gitignore
@@ -69,7 +69,8 @@ convlab2.egg-info
 
 # configs
 
-
+*experiment*
+*pretrained_models*
 .ipynb_checkpoints
 
 ## dst files
diff --git a/convlab2/dialog_agent/agent.py b/convlab2/dialog_agent/agent.py
index 1afbc936..2feed5ad 100755
--- a/convlab2/dialog_agent/agent.py
+++ b/convlab2/dialog_agent/agent.py
@@ -196,14 +196,8 @@ class PipelineAgent(Agent):
                     for intent, domain, slot, value in self.output_action:
                         if domain.lower() not in ['general', 'booking']:
                             self.cur_domain = domain
-                        dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}'
-                        if dial_act == 'booking-book-ref' and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']:
-                            if self.cur_domain:
-                                self.dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}]
-                        elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref':
-                            self.dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}]
-                        elif dial_act == 'taxi-inform-car':
-                            self.dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}]
+                        if intent == "book":
+                            self.dst.state['belief_state'][domain.lower()]['book']['booked'] = [{slot.lower(): value}]
             else:
                 self.dst.state['user_action'] = self.output_action
                 # user dst is also updated by itself
diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py
index 202e248f..0ec4aeed 100755
--- a/convlab2/evaluator/multiwoz_eval.py
+++ b/convlab2/evaluator/multiwoz_eval.py
@@ -111,21 +111,18 @@ class MultiWozEvaluator(Evaluator):
             value = str(value)
             self.sys_da_array.append(da + '-' + value)
 
-            if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']:
-                if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value) and \
-                        len(self.dbs[self.cur_domain]) > int(value):
-                    self.booked[self.cur_domain] = self.dbs[self.cur_domain][int(
-                        value)].copy()
-                    self.booked[self.cur_domain]['Ref'] = value
-                    self.booked_states[self.cur_domain] = belief_state[self.cur_domain]
-            elif da == 'train-offerbooked-ref' or da == 'train-inform-ref':
-                if not self.booked['train'] and re.match(r'^\d{8}$', value) and len(self.dbs['train']) > int(value):
-                    self.booked['train'] = self.dbs['train'][int(value)].copy()
-                    self.booked['train']['Ref'] = value
-                    self.booked_states[self.cur_domain] = belief_state[self.cur_domain]
-            elif da == 'taxi-inform-car':
-                if not self.booked['taxi']:
-                    self.booked['taxi'] = 'booked'
+            # new booking actions make life easier
+            if intent.lower() == "book":
+                # taxi has no DB queries
+                if domain.lower() == "taxi":
+                    if not self.booked['taxi']:
+                        self.booked['taxi'] = 'booked'
+                else:
+                    if not self.booked[domain] and re.match(r'^\d{8}$', value) and \
+                            len(self.dbs[domain]) > int(value):
+                        self.booked[domain] = self.dbs[domain][int(value)].copy()
+                        self.booked[domain]['Ref'] = value
+                        self.booked_states[domain] = belief_state[domain]
 
     def add_usr_da(self, da_turn):
         """add usr_da into array
diff --git a/convlab2/policy/evaluate.py b/convlab2/policy/evaluate.py
index 471e6c8e..da5d184f 100755
--- a/convlab2/policy/evaluate.py
+++ b/convlab2/policy/evaluate.py
@@ -9,6 +9,7 @@ import json
 import logging
 import os
 import random
+from convlab2.policy.vector.vector_multiwoz import MultiWozVector
 
 import numpy as np
 import torch
@@ -167,7 +168,7 @@ def evaluate(args, dataset_name, model_name, load_path, calculate_reward=True, v
     if model_name == "PPO":
         from convlab2.policy.ppo import PPO
         if load_path:
-            policy_sys = PPO(False)
+            policy_sys = PPO(False, vectorizer=MultiWozVector())
             policy_sys.load(load_path)
         else:
             policy_sys = PPO.from_pretrained()
diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box
deleted file mode 100644
index a44089b2ed8b2a679ca9b81225aaa574230d0570..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 40
rcmb1OfPlsI-b$Pk(%y3{ZMxwo#hX-=n3<>NT9%quVrBHADKZfN)u#;6

diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json
deleted file mode 100644
index 377cab63..00000000
--- a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json
+++ /dev/null
@@ -1 +0,0 @@
-{"args": {"seed": 0, "eval_freq": 1}, "config": {"batchsz": 32, "epoch": 24, "lr_supervised": 0.0001, "save_dir": "save", "log_dir": "log", "print_per_batch": 400, "save_per_epoch": 1, "h_dim": 100, "load": "save/best", "pos_weight": 5, "hidden_size": 256, "weight_decay": 1e-05, "lambda": 1, "tau": 0.005, "policy_freq": 2, "entropy_weight": 0.001}}
\ No newline at end of file
diff --git a/convlab2/policy/ppo/semantic_level_config.json b/convlab2/policy/ppo/semantic_level_config.json
index 095a21b5..24cfeb57 100644
--- a/convlab2/policy/ppo/semantic_level_config.json
+++ b/convlab2/policy/ppo/semantic_level_config.json
@@ -1,7 +1,7 @@
 {
 	"model": {
-		"load_path": "convlab2/policy/mle/multiwoz/experiment_2021-12-15-11-12-07/save/supervised",
-		"use_pretrained_initialisation": true,
+		"load_path": "convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/save/supervised",
+		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
 		"batchsz": 1000,
 		"seed": 0,
diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
index 13994c67..b2778a31 100755
--- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -486,8 +486,7 @@ class Agenda(object):
                 continue
 
             slot_vals = sys_action[diaact]
-            #TODO: use string "book" instead of "booking"
-            if 'booking' in diaact:
+            if 'book' in diaact:
                 if self.update_booking(diaact, slot_vals, goal):
                     return
             elif 'general' in diaact:
@@ -503,8 +502,7 @@ class Agenda(object):
                     if slot == 'name':
                         self._remove_item(diaact.split(
                             '-')[0]+'-inform', 'choice')
-            # TODO: use string "book" instead of "booking"
-            if 'booking' in diaact and self.cur_domain:
+            if 'book' in diaact and self.cur_domain:
                 g_book = self._get_goal_infos(self.cur_domain, goal)[-2]
                 if len(g_book) == 0:
                     self._push_item(self.cur_domain +
@@ -535,15 +533,12 @@ class Agenda(object):
         :param goal:        Goal
         :return:            True:user want to close the session. False:session is continue
         """
-        #TODO: Use domain of diaact.split instead of current domain
-        _, intent = diaact.split('-')
-        domain = self.cur_domain
+        domain, intent = diaact.split('-')
         self.domains['update_booking'] = domain
         isover = False
         if domain not in goal.domains:
             isover = False
 
-        #TODO: Remove inform
         elif intent in ['book', 'inform']:
             isover = self._handle_inform(domain, intent, slot_vals, goal)
 
@@ -686,8 +681,7 @@ class Agenda(object):
                 self._push_item(domain + '-inform', slot, g_book[slot])
                 info_right = False
 
-        #TODO: Only use "book"
-        if intent in ['book', 'offerbooked'] and info_right:
+        if intent in ['book'] and info_right:
             # booked ok
             if 'booked' in goal.domain_goals[domain]:
                 goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED
diff --git a/convlab2/util/multiwoz/lexicalize.py b/convlab2/util/multiwoz/lexicalize.py
index 427a54ab..3f798e46 100755
--- a/convlab2/util/multiwoz/lexicalize.py
+++ b/convlab2/util/multiwoz/lexicalize.py
@@ -76,6 +76,14 @@ def lexicalize_da(meta, entities, state, requestable, cur_domain=None):
                     else:
                         pair[1] = 'none'
         else:
+            if intent.lower() == "book":
+                for pair in v:
+                    if len(entities[domain]) > 0:
+                        slot = REF_SYS_DA[domain].get('Ref', 'Ref')
+                        if slot in entities[domain][0]:
+                            pair[1] = entities[domain][0][slot]
+                continue
+
             if domain.lower() in ['booking']:
                 if cur_domain and cur_domain in entities:
                     domain = cur_domain
-- 
GitLab