Skip to content
Snippets Groups Projects
Commit 027f3425 authored by Christian's avatar Christian
Browse files

first version that can convert multiwoz data, trains supervised model and...

first version that can convert multiwoz data, trains supervised model and evaluates model with simulated user, f1-score is 0.52 and success rate is 73%
parent 5ec83abb
No related branches found
No related tags found
No related merge requests found
...@@ -69,7 +69,8 @@ convlab2.egg-info ...@@ -69,7 +69,8 @@ convlab2.egg-info
# configs # configs
*experiment*
*pretrained_models*
.ipynb_checkpoints .ipynb_checkpoints
## dst files ## dst files
......
...@@ -196,14 +196,8 @@ class PipelineAgent(Agent): ...@@ -196,14 +196,8 @@ class PipelineAgent(Agent):
for intent, domain, slot, value in self.output_action: for intent, domain, slot, value in self.output_action:
if domain.lower() not in ['general', 'booking']: if domain.lower() not in ['general', 'booking']:
self.cur_domain = domain self.cur_domain = domain
dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}' if intent == "book":
if dial_act == 'booking-book-ref' and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']: self.dst.state['belief_state'][domain.lower()]['book']['booked'] = [{slot.lower(): value}]
if self.cur_domain:
self.dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}]
elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref':
self.dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}]
elif dial_act == 'taxi-inform-car':
self.dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}]
else: else:
self.dst.state['user_action'] = self.output_action self.dst.state['user_action'] = self.output_action
# user dst is also updated by itself # user dst is also updated by itself
......
...@@ -111,21 +111,18 @@ class MultiWozEvaluator(Evaluator): ...@@ -111,21 +111,18 @@ class MultiWozEvaluator(Evaluator):
value = str(value) value = str(value)
self.sys_da_array.append(da + '-' + value) self.sys_da_array.append(da + '-' + value)
if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']: # new booking actions make life easier
if not self.booked[self.cur_domain] and re.match(r'^\d{8}$', value) and \ if intent.lower() == "book":
len(self.dbs[self.cur_domain]) > int(value): # taxi has no DB queries
self.booked[self.cur_domain] = self.dbs[self.cur_domain][int( if domain.lower() == "taxi":
value)].copy()
self.booked[self.cur_domain]['Ref'] = value
self.booked_states[self.cur_domain] = belief_state[self.cur_domain]
elif da == 'train-offerbooked-ref' or da == 'train-inform-ref':
if not self.booked['train'] and re.match(r'^\d{8}$', value) and len(self.dbs['train']) > int(value):
self.booked['train'] = self.dbs['train'][int(value)].copy()
self.booked['train']['Ref'] = value
self.booked_states[self.cur_domain] = belief_state[self.cur_domain]
elif da == 'taxi-inform-car':
if not self.booked['taxi']: if not self.booked['taxi']:
self.booked['taxi'] = 'booked' self.booked['taxi'] = 'booked'
else:
if not self.booked[domain] and re.match(r'^\d{8}$', value) and \
len(self.dbs[domain]) > int(value):
self.booked[domain] = self.dbs[domain][int(value)].copy()
self.booked[domain]['Ref'] = value
self.booked_states[domain] = belief_state[domain]
def add_usr_da(self, da_turn): def add_usr_da(self, da_turn):
"""add usr_da into array """add usr_da into array
......
...@@ -9,6 +9,7 @@ import json ...@@ -9,6 +9,7 @@ import json
import logging import logging
import os import os
import random import random
from convlab2.policy.vector.vector_multiwoz import MultiWozVector
import numpy as np import numpy as np
import torch import torch
...@@ -167,7 +168,7 @@ def evaluate(args, dataset_name, model_name, load_path, calculate_reward=True, v ...@@ -167,7 +168,7 @@ def evaluate(args, dataset_name, model_name, load_path, calculate_reward=True, v
if model_name == "PPO": if model_name == "PPO":
from convlab2.policy.ppo import PPO from convlab2.policy.ppo import PPO
if load_path: if load_path:
policy_sys = PPO(False) policy_sys = PPO(False, vectorizer=MultiWozVector())
policy_sys.load(load_path) policy_sys.load(load_path)
else: else:
policy_sys = PPO.from_pretrained() policy_sys = PPO.from_pretrained()
......
{"args": {"seed": 0, "eval_freq": 1}, "config": {"batchsz": 32, "epoch": 24, "lr_supervised": 0.0001, "save_dir": "save", "log_dir": "log", "print_per_batch": 400, "save_per_epoch": 1, "h_dim": 100, "load": "save/best", "pos_weight": 5, "hidden_size": 256, "weight_decay": 1e-05, "lambda": 1, "tau": 0.005, "policy_freq": 2, "entropy_weight": 0.001}}
\ No newline at end of file
{ {
"model": { "model": {
"load_path": "convlab2/policy/mle/multiwoz/experiment_2021-12-15-11-12-07/save/supervised", "load_path": "convlab2/policy/mle/multiwoz/experiment_2022-02-14-18-40-51/save/supervised",
"use_pretrained_initialisation": true, "use_pretrained_initialisation": false,
"pretrained_load_path": "", "pretrained_load_path": "",
"batchsz": 1000, "batchsz": 1000,
"seed": 0, "seed": 0,
......
...@@ -486,8 +486,7 @@ class Agenda(object): ...@@ -486,8 +486,7 @@ class Agenda(object):
continue continue
slot_vals = sys_action[diaact] slot_vals = sys_action[diaact]
#TODO: use string "book" instead of "booking" if 'book' in diaact:
if 'booking' in diaact:
if self.update_booking(diaact, slot_vals, goal): if self.update_booking(diaact, slot_vals, goal):
return return
elif 'general' in diaact: elif 'general' in diaact:
...@@ -503,8 +502,7 @@ class Agenda(object): ...@@ -503,8 +502,7 @@ class Agenda(object):
if slot == 'name': if slot == 'name':
self._remove_item(diaact.split( self._remove_item(diaact.split(
'-')[0]+'-inform', 'choice') '-')[0]+'-inform', 'choice')
# TODO: use string "book" instead of "booking" if 'book' in diaact and self.cur_domain:
if 'booking' in diaact and self.cur_domain:
g_book = self._get_goal_infos(self.cur_domain, goal)[-2] g_book = self._get_goal_infos(self.cur_domain, goal)[-2]
if len(g_book) == 0: if len(g_book) == 0:
self._push_item(self.cur_domain + self._push_item(self.cur_domain +
...@@ -535,15 +533,12 @@ class Agenda(object): ...@@ -535,15 +533,12 @@ class Agenda(object):
:param goal: Goal :param goal: Goal
:return: True:user want to close the session. False:session is continue :return: True:user want to close the session. False:session is continue
""" """
#TODO: Use domain of diaact.split instead of current domain domain, intent = diaact.split('-')
_, intent = diaact.split('-')
domain = self.cur_domain
self.domains['update_booking'] = domain self.domains['update_booking'] = domain
isover = False isover = False
if domain not in goal.domains: if domain not in goal.domains:
isover = False isover = False
#TODO: Remove inform
elif intent in ['book', 'inform']: elif intent in ['book', 'inform']:
isover = self._handle_inform(domain, intent, slot_vals, goal) isover = self._handle_inform(domain, intent, slot_vals, goal)
...@@ -686,8 +681,7 @@ class Agenda(object): ...@@ -686,8 +681,7 @@ class Agenda(object):
self._push_item(domain + '-inform', slot, g_book[slot]) self._push_item(domain + '-inform', slot, g_book[slot])
info_right = False info_right = False
#TODO: Only use "book" if intent in ['book'] and info_right:
if intent in ['book', 'offerbooked'] and info_right:
# booked ok # booked ok
if 'booked' in goal.domain_goals[domain]: if 'booked' in goal.domain_goals[domain]:
goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED goal.domain_goals[domain]['booked'] = DEF_VAL_BOOKED
......
...@@ -76,6 +76,14 @@ def lexicalize_da(meta, entities, state, requestable, cur_domain=None): ...@@ -76,6 +76,14 @@ def lexicalize_da(meta, entities, state, requestable, cur_domain=None):
else: else:
pair[1] = 'none' pair[1] = 'none'
else: else:
if intent.lower() == "book":
for pair in v:
if len(entities[domain]) > 0:
slot = REF_SYS_DA[domain].get('Ref', 'Ref')
if slot in entities[domain][0]:
pair[1] = entities[domain][0][slot]
continue
if domain.lower() in ['booking']: if domain.lower() in ['booking']:
if cur_domain and cur_domain in entities: if cur_domain and cur_domain in entities:
domain = cur_domain domain = cur_domain
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment