From 027f3425635096d0bd57684e0392ac6b11db24b9 Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Tue, 15 Feb 2022 16:10:03 +0100 Subject: [PATCH] first version that can convert multiwoz data, trains supervised model and evaluates model with simulated user, f1-score is 0.52 and success rate is 73% --- .gitignore | 3 +- convlab2/dialog_agent/agent.py | 10 ++----- convlab2/evaluator/multiwoz_eval.py | 27 ++++++++---------- convlab2/policy/evaluate.py | 3 +- ...d568-7cb7-4ba9-96d4-9baa7357316e.fritz.box | Bin 40 -> 0 bytes .../configs/config_saved.json | 1 - .../policy/ppo/semantic_level_config.json | 4 +-- .../rule/multiwoz/policy_agenda_multiwoz.py | 14 +++------ convlab2/util/multiwoz/lexicalize.py | 8 ++++++ 9 files changed, 32 insertions(+), 38 deletions(-) delete mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box delete mode 100644 convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json diff --git a/.gitignore b/.gitignore index a2820f1e..4e467920 100644 --- a/.gitignore +++ b/.gitignore @@ -69,7 +69,8 @@ convlab2.egg-info # configs - +*experiment* +*pretrained_models* .ipynb_checkpoints ## dst files diff --git a/convlab2/dialog_agent/agent.py b/convlab2/dialog_agent/agent.py index 1afbc936..2feed5ad 100755 --- a/convlab2/dialog_agent/agent.py +++ b/convlab2/dialog_agent/agent.py @@ -196,14 +196,8 @@ class PipelineAgent(Agent): for intent, domain, slot, value in self.output_action: if domain.lower() not in ['general', 'booking']: self.cur_domain = domain - dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}' - if dial_act == 'booking-book-ref' and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']: - if self.cur_domain: - self.dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref': - self.dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}] - elif dial_act == 'taxi-inform-car': - self.dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}] + if intent == "book": + self.dst.state['belief_state'][domain.lower()]['book']['booked'] = [{slot.lower(): value}] else: self.dst.state['user_action'] = self.output_action # user dst is also updated by itself diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py index 202e248f..0ec4aeed 100755 --- a/convlab2/evaluator/multiwoz_eval.py +++ b/convlab2/evaluator/multiwoz_eval.py @@ -111,21 +111,18 @@ class MultiWozEvaluator(Evaluator): value = str(value) self.sys_da_array.append(da + '-' + value) - if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']: - if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value) and \ - len(self.dbs[self.cur_domain]) > int(value): - self.booked[self.cur_domain] = self.dbs[self.cur_domain][int( - value)].copy() - self.booked[self.cur_domain]['Ref'] = value - self.booked_states[self.cur_domain] = belief_state[self.cur_domain] - elif da == 'train-offerbooked-ref' or da == 'train-inform-ref': - if not self.booked['train'] and re.match(r'^\d{8}$', value) and len(self.dbs['train']) > int(value): - self.booked['train'] = self.dbs['train'][int(value)].copy() - self.booked['train']['Ref'] = value - self.booked_states[self.cur_domain] = belief_state[self.cur_domain] - elif da == 'taxi-inform-car': - if not self.booked['taxi']: - self.booked['taxi'] = 'booked' + # new booking actions make life easier + if intent.lower() == "book": + # taxi has no DB queries + if domain.lower() == "taxi": + if not self.booked['taxi']: + self.booked['taxi'] = 'booked' + else: + if not self.booked[domain] and re.match(r'^\d{8}$', value) and \ + len(self.dbs[domain]) > int(value): + self.booked[domain] = self.dbs[domain][int(value)].copy() + self.booked[domain]['Ref'] = value + self.booked_states[domain] = belief_state[domain] def add_usr_da(self, da_turn): """add usr_da into array diff --git a/convlab2/policy/evaluate.py b/convlab2/policy/evaluate.py index 471e6c8e..da5d184f 100755 --- a/convlab2/policy/evaluate.py +++ b/convlab2/policy/evaluate.py @@ -9,6 +9,7 @@ import json import logging import os import random +from convlab2.policy.vector.vector_multiwoz import MultiWozVector import numpy as np import torch @@ -167,7 +168,7 @@ def evaluate(args, dataset_name, model_name, load_path, calculate_reward=True, v if model_name == "PPO": from convlab2.policy.ppo import PPO if load_path: - policy_sys = PPO(False) + policy_sys = PPO(False, vectorizer=MultiWozVector()) policy_sys.load(load_path) else: policy_sys = PPO.from_pretrained() diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/TB_summary/events.out.tfevents.1644860451.f1acd568-7cb7-4ba9-96d4-9baa7357316e.fritz.box deleted file mode 100644 index a44089b2ed8b2a679ca9b81225aaa574230d0570..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 40 rcmb1OfPlsI-b$Pk(%y3{ZMxwo#hX-=n3<>NT9%quVrBHADKZfN)u#;6 diff --git a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json b/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json deleted file mode 100644 index 377cab63..00000000 --- a/convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/configs/config_saved.json +++ /dev/null @@ -1 +0,0 @@ -{"args": {"seed": 0, "eval_freq": 1}, "config": {"batchsz": 32, "epoch": 24, "lr_supervised": 0.0001, "save_dir": "save", "log_dir": "log", "print_per_batch": 400, "save_per_epoch": 1, "h_dim": 100, "load": "save/best", "pos_weight": 5, "hidden_size": 256, "weight_decay": 1e-05, "lambda": 1, "tau": 0.005, "policy_freq": 2, "entropy_weight": 0.001}} \ No newline at end of file diff --git a/convlab2/policy/ppo/semantic_level_config.json b/convlab2/policy/ppo/semantic_level_config.json index 095a21b5..24cfeb57 100644 --- a/convlab2/policy/ppo/semantic_level_config.json +++ b/convlab2/policy/ppo/semantic_level_config.json @@ -1,7 +1,7 @@ { "model": { - "load_path": "convlab2/policy/mle/multiwoz/experiment_2021-12-15-11-12-07/save/supervised", - "use_pretrained_initialisation": true, + "load_path": "convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/save/supervised", + "use_pretrained_initialisation": false, "pretrained_load_path": "", "batchsz": 1000, "seed": 0, diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index 13994c67..b2778a31 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -486,8 +486,7 @@ class Agenda(object): continue slot_vals = sys_action[diaact] - #TODO: use string "book" instead of "booking" - if 'booking' in diaact: + if 'book' in diaact: if self.update_booking(diaact, slot_vals, goal): return elif 'general' in diaact: @@ -503,8 +502,7 @@ class Agenda(object): if slot == 'name': self._remove_item(diaact.split( '-')[0]+'-inform', 'choice') - # TODO: use string "book" instead of "booking" - if 'booking' in diaact and self.cur_domain: + if 'book' in diaact and self.cur_domain: g_book = self._get_goal_infos(self.cur_domain, goal)[-2] if len(g_book) == 0: self._push_item(self.cur_domain + @@ -535,15 +533,12 @@ class Agenda(object): :param goal: Goal :return: True:user want to close the session. False:session is continue """ - #TODO: Use domain of diaact.split instead of current domain - _, intent = diaact.split('-') - domain = self.cur_domain + domain, intent = diaact.split('-') self.domains['update_booking'] = domain isover = False if domain not in goal.domains: isover = False - #TODO: Remove inform elif intent in ['book', 'inform']: isover = self._handle_inform(domain, intent, slot_vals, goal) @@ -686,8 +681,7 @@ class Agenda(object): self._push_item(domain + '-inform', slot, g_book[slot]) info_right = False - #TODO: Only use "book" - if intent in ['book', 'offerbooked'] and info_right: + if intent in ['book'] and info_right: # booked ok if 'booked' in goal.domain_goals[domain]: goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED diff --git a/convlab2/util/multiwoz/lexicalize.py b/convlab2/util/multiwoz/lexicalize.py index 427a54ab..3f798e46 100755 --- a/convlab2/util/multiwoz/lexicalize.py +++ b/convlab2/util/multiwoz/lexicalize.py @@ -76,6 +76,14 @@ def lexicalize_da(meta, entities, state, requestable, cur_domain=None): else: pair[1] = 'none' else: + if intent.lower() == "book": + for pair in v: + if len(entities[domain]) > 0: + slot = REF_SYS_DA[domain].get('Ref', 'Ref') + if slot in entities[domain][0]: + pair[1] = entities[domain][0][slot] + continue + if domain.lower() in ['booking']: if cur_domain and cur_domain in entities: domain = cur_domain -- GitLab