Commit 2d9ba96a authored by zhuqi, committed by GitHub

Merge pull request #28 from ConvLab/remap_book_actions

Map actions for Booking domain to domain-specific actions
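
The remapping named in this commit message can be sketched in isolation. This is a minimal illustration that mirrors the logic the diff adds to `Environment.step` and `DialogueAgent.response`, not the project's actual API: the function name `apply_booking_acts` and the plain-dict belief state are hypothetical, and the extra check that a domain is already active is an assumption added for safety.

```python
def apply_booking_acts(belief_state, actions, cur_domain=None):
    """Write booking references into belief_state; return the active domain.

    `actions` is a list of [intent, domain, slot, value] dialog acts.
    Hypothetical helper for illustration only.
    """
    for intent, domain, slot, value in actions:
        # Any act outside 'general'/'booking' switches the active domain.
        if domain.lower() not in ('general', 'booking'):
            cur_domain = domain
        dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}'
        booked = [{slot.lower(): value}]
        if dial_act == 'booking-book-ref':
            # Generic Booking act: attribute it to the active domain, if any.
            if cur_domain and cur_domain.lower() in ('hotel', 'restaurant', 'train'):
                belief_state[cur_domain.lower()]['book']['booked'] = booked
        elif dial_act in ('train-offerbooked-ref', 'train-inform-ref'):
            belief_state['train']['book']['booked'] = booked
        elif dial_act == 'taxi-inform-car':
            belief_state['taxi']['book']['booked'] = booked
    return cur_domain
```

Called with a hotel `Inform` act followed by a generic `Booking-Book-Ref` act, the reference ends up under the `hotel` entry of the belief state rather than under a separate booking domain.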
parents 0d908dd7 bcb3d852
Showing 2215 additions and 63 deletions
......@@ -69,7 +69,8 @@ convlab2.egg-info
# configs
*experiment*
*pretrained_models*
.ipynb_checkpoints
## dst files
......
FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04
ENV DEBIAN_FRONTEND noninteractive
RUN apt-get update
RUN apt-get install -y --no-install-recommends software-properties-common
RUN add-apt-repository ppa:deadsnakes/ppa
FROM python:3.6-slim AS compile-image
RUN apt-get update
RUN apt-get install -y --no-install-recommends python3.7 python3-pip build-essential libssl-dev libffi-dev python3.7-dev
......
......@@ -252,11 +252,11 @@ We welcome contributions from community.
## Team
**ConvLab-2** is maintained and developed by Tsinghua University Conversational AI group (THU-coai) and Microsoft Research (MSR).
**ConvLab-3** is maintained and developed by Tsinghua University Conversational AI group (THU-coai), the [Dialogue Systems and Machine Learning Group](https://www.cs.hhu.de/en/research-groups/dialog-systems-and-machine-learning.html) at Heinrich Heine University, Düsseldorf, Germany and Microsoft Research (MSR).
We would like to thank:
Yan Fang, Zhuoer Feng, Jianfeng Gao, Qihan Guo, Kaili Huang, Minlie Huang, Sungjin Lee, Bing Li, Jinchao Li, Xiang Li, Xiujun Li, Jiexi Liu, Lingxiao Luo, Wenchang Ma, Mehrad Moradshahi, Baolin Peng, Runze Liang, Ryuichi Takanobu, Hongru Wang, Jiaxin Wen, Yaoqin Zhang, Zheng Zhang, Qi Zhu, Xiaoyan Zhu.
Yan Fang, Zhuoer Feng, Jianfeng Gao, Qihan Guo, Kaili Huang, Minlie Huang, Sungjin Lee, Bing Li, Jinchao Li, Xiang Li, Xiujun Li, Jiexi Liu, Lingxiao Luo, Wenchang Ma, Mehrad Moradshahi, Baolin Peng, Runze Liang, Ryuichi Takanobu, Hongru Wang, Jiaxin Wen, Yaoqin Zhang, Zheng Zhang, Qi Zhu, Xiaoyan Zhu, Carel van Niekerk, Christian Geishauser, Hsien-chin Lin, Nurul Lubis, Xiaochen Zhu, Michael Heck, Shutong Feng, Milica Gašić.
## Citing
......
import os
import os
from convlab2.nlu import NLU
from convlab2.dst import DST
from convlab2.policy import Policy
......
"""Dialog agent interface and classes."""
from abc import ABC, abstractmethod
import logging
from convlab2.nlu import NLU
from convlab2.dst import DST
from convlab2.policy import Policy
from convlab2.nlg import NLG
from copy import deepcopy
import time
from pprint import pprint
class Agent(ABC):
"""Interface for dialog agent classes."""
@abstractmethod
def __init__(self, name: str):
self.name = name
......@@ -38,6 +42,7 @@ class Agent(ABC):
class PipelineAgent(Agent):
"""Pipeline dialog agent base class, including NLU, DST, Policy and NLG.
The combination modes of pipeline agent modules are flexible. The only thing you have to make sure is that
......@@ -58,7 +63,7 @@ class PipelineAgent(Agent):
===== ===== ====== === == ===
"""
def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str):
def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str, return_semantic_acts=False):
"""The constructor of PipelineAgent class.
Here are some special combination cases:
......@@ -70,7 +75,7 @@ class PipelineAgent(Agent):
Args:
nlu (NLU):
The natural langauge understanding module of agent.
The natural language understanding module of agent.
dst (DST):
The dialog state tracker of agent.
......@@ -79,17 +84,33 @@ class PipelineAgent(Agent):
The dialog policy module of agent.
nlg (NLG):
The natural langauge generator module of agent.
The natural language generator module of agent.
"""
super(PipelineAgent, self).__init__(name=name)
assert self.name in ['user', 'sys']
self.opponent_name = 'user' if self.name is 'sys' else 'sys'
self.opponent_name = 'user' if self.name == 'sys' else 'sys'
self.nlu = nlu
self.dst = dst
self.policy = policy
self.nlg = nlg
self.return_semantic_acts = return_semantic_acts
self.init_session()
self.agent_saves = []
self.history = []
self.turn = 0
self.cur_domain = None
#logging.info("Pipeline Agent info_dict check")
if not hasattr(self.nlu, 'info_dict'):
logging.warning('nlu info_dict is not initialized')
if not hasattr(self.dst, 'info_dict'):
logging.warning('dst info_dict is not initialized')
if not hasattr(self.policy, 'info_dict'):
logging.warning('policy info_dict is not initialized')
if not hasattr(self.nlg, 'info_dict'):
logging.warning('nlg info_dict is not initialized')
#logging.info("Done")
def state_replace(self, agent_state):
"""
......@@ -110,46 +131,107 @@ class PipelineAgent(Agent):
return agent_state
def response(self, observation):
"""Generate agent response using the agent modules."""
# Note: If you modify the logic of this function, please ensure that it is consistent with deploy.server.ServerCtrl._turn()
if self.dst is not None:
self.dst.state['history'].append([self.opponent_name, observation]) # [['sys', sys_utt], ['user', user_utt],...]
# [['sys', sys_utt], ['user', user_utt],...]
self.dst.state['history'].append([self.opponent_name, observation])
self.history.append([self.opponent_name, observation])
# get dialog act
if self.name == 'sys':
if self.nlu is not None:
self.input_action = self.nlu.predict(observation, context=[x[1] for x in self.history[:-1]])
self.input_action = self.nlu.predict(
observation, context=[x[1] for x in self.history[:-1]])
else:
self.input_action = observation
self.input_action = deepcopy(self.input_action) # get rid of reference problem
else:
if self.nlu is not None:
self.input_action_eval = self.nlu.predict(
observation, context=[x[1] for x in self.history[:-1]])
self.input_action = self.nlu.predict(
observation, context=[x[1] for x in self.history[:-1]])
else:
self.input_action = observation
self.input_action_eval = observation
# get rid of reference problem
self.input_action = deepcopy(self.input_action)
# get state
if self.dst is not None:
if self.name is 'sys':
if self.name == 'sys':
self.dst.state['user_action'] = self.input_action
else:
self.dst.state['system_action'] = self.input_action
state = self.dst.update(self.input_action)
else:
state = self.input_action
state = deepcopy(state) # get rid of reference problem
# get action
self.output_action = deepcopy(self.policy.predict(state)) # get rid of reference problem
# get rid of reference problem
self.output_action = deepcopy(self.policy.predict(state))
# get model response
if self.nlg is not None:
model_response = self.nlg.generate(self.output_action)
else:
model_response = self.output_action
# print(model_response)
if self.dst is not None:
self.dst.state['history'].append([self.name, model_response])
if self.name is 'sys':
if self.name == 'sys':
self.dst.state['system_action'] = self.output_action
# If system takes booking action add booking info to the 'book-booked' section of the belief state
if type(self.input_action) != list:
self.input_action = self.dst.state['user_action']
if type(self.input_action) == list:
for intent, domain, slot, value in self.input_action:
if domain.lower() not in ['booking', 'general']:
self.cur_domain = domain
if type(self.output_action) == list:
for intent, domain, slot, value in self.output_action:
if domain.lower() not in ['general', 'booking']:
self.cur_domain = domain
if intent == "book":
self.dst.state['belief_state'][domain.lower()]['book']['booked'] = [{slot.lower(): value}]
else:
self.dst.state['user_action'] = self.output_action
# user dst is also updated by itself
state = self.dst.update(self.output_action)
self.history.append([self.name, model_response])
self.turn += 1
if self.return_semantic_acts:
return self.output_action
self.agent_saves.append(self.save_info())
return model_response
def save_info(self):
try:
infos = {}
if hasattr(self.nlu, 'info_dict'):
infos["nlu"] = self.nlu.info_dict
if hasattr(self.dst, 'info_dict'):
infos["dst"] = self.dst.info_dict
if hasattr(self.policy, 'info_dict'):
infos["policy"] = self.policy.info_dict
if hasattr(self.nlg, 'info_dict'):
infos["nlg"] = self.nlg.info_dict
# nlu_info = self.agents[agent_id].nlu.info
# policy_info = self.agents[agent_id].policy.info
# nlg_info = self.agents[agent_id].nlg.info
# infos = {"nlu": nlu_info, "policy": policy_info, "nlg": nlg_info}
# infos = {"nlu": self.turn, "policy": "policy", "nlg": "nlg"}
except Exception:
infos = None
return infos
def is_terminated(self):
if hasattr(self.policy, 'is_terminated'):
return self.policy.is_terminated()
......@@ -162,20 +244,246 @@ class PipelineAgent(Agent):
def init_session(self, **kwargs):
"""Init the attributes of DST and Policy module."""
self.cur_domain = None
if self.nlu is not None:
self.nlu.init_session()
if self.dst is not None:
self.dst.init_session()
if self.name == 'sys':
self.dst.state['history'].append([self.name, 'null'])
self.dst.state['history'].append(
[self.name, 'null']) # TODO: ??
if self.policy is not None:
self.policy.init_session(**kwargs)
if self.nlg is not None:
self.nlg.init_session()
self.history = []
def get_in_da_eval(self):
return self.input_action_eval
def get_in_da(self):
return self.input_action
def get_out_da(self):
return self.output_action
# Agent for the Dialogue Server for HHU Dialcrowd. It is an extension of PipelineAgent with minor modifications.
class DialogueAgent(Agent):
"""Pipeline dialog agent base class, including NLU, DST, Policy and NLG.
"""
def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str = "sys"):
"""The constructor of DialogueAgent class.
Here are some special combination cases:
1. If you use word-level DST (such as Neural Belief Tracker), you should set the nlu_model parameter \
to None. The agent will combine the modules automatically.
2. If you want to aggregate DST and Policy as a single module, set tracker to None.
Args:
nlu (NLU):
The natural language understanding module of agent.
dst (DST):
The dialog state tracker of agent.
policy (Policy):
The dialog policy module of agent.
nlg (NLG):
The natural language generator module of agent.
"""
super(DialogueAgent, self).__init__(name=name)
assert self.name in ['sys']
self.opponent_name = 'user'
self.nlu = nlu
self.dst = dst
self.policy = policy
self.nlg = nlg
self.module_names = ["nlu", "dst", "policy", "nlg"]
self.init_session()
self.history = []
self.session_id = None
self.ENDING_DIALOG = False
self.USER_RATED = False
self.USER_GOAL_ACHIEVED = None
self.taskID = None
self.feedback = None
self.requested_feedback = False
self.sys_state_history = []
self.sys_action_history = []
self.sys_utterance_history = []
self.sys_output_history = []
self.action_mask_history = []
self.action_prob_history = []
self.turn = 0
self.agent_saves = {"session_id": None, "agent_id": None,
"user_id": None, "timestamp": None, "dialogue_info": [], "dialogue_info_fundamental": []}
self.initTime = int(time.time())
self.lastUpdate = int(time.time())
self.cur_domain = None
logging.info("Dialogue Agent info_dict check")
if not hasattr(self.nlu, 'info_dict'):
logging.warning('nlu info_dict is not initialized')
if not hasattr(self.dst, 'info_dict'):
logging.warning('dst info_dict is not initialized')
if not hasattr(self.policy, 'info_dict'):
logging.warning('policy info_dict is not initialized')
if not hasattr(self.nlg, 'info_dict'):
logging.warning('nlg info_dict is not initialized')
def response(self, observation):
"""Generate agent response using the agent modules."""
self.sys_utterance_history.append(observation)
fundamental_info = {'observation': observation}
if self.dst is not None:
self.dst.state['history'].append(
[self.opponent_name, observation]) # [['sys', sys_utt], ['user', user_utt],...]
self.history.append([self.opponent_name, observation])
# get dialog act
if self.nlu is not None:
self.input_action = self.nlu.predict(
observation, context=[x[1] for x in self.history[:-1]])
else:
self.input_action = observation
# get rid of reference problem
self.input_action = deepcopy(self.input_action)
fundamental_info['input_action'] = self.input_action
# get state
if self.dst is not None:
self.dst.state['user_action'] = self.input_action
state = self.dst.update(self.input_action)
else:
state = self.input_action
fundamental_info['state'] = state
state = deepcopy(state) # get rid of reference problem
self.sys_state_history.append(state)
# get action
# get rid of reference problem
self.output_action = deepcopy(self.policy.predict(state))
if hasattr(self.policy, "last_action"):
self.sys_action_history.append(self.policy.last_action)
else:
self.sys_action_history.append(self.output_action)
fundamental_info['output_action'] = self.output_action
if hasattr(self.policy, "prob"):
self.action_prob_history.append(self.policy.prob)
# get model response
if self.nlg is not None:
model_response = self.nlg.generate(self.output_action)
else:
model_response = self.output_action
self.sys_output_history.append(model_response)
fundamental_info['model_response'] = model_response
if self.dst is not None:
self.dst.state['history'].append([self.name, model_response])
self.dst.state['system_action'] = self.output_action
# If system takes booking action add booking info to the 'book-booked' section of the belief state
if type(self.output_action) == list:
for intent, domain, slot, value in self.output_action:
if domain.lower() not in ['general', 'booking']:
self.cur_domain = domain
dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}'
if dial_act == 'booking-book-ref' and self.cur_domain and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']:
self.dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}]
elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref':
self.dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}]
elif dial_act == 'taxi-inform-car':
self.dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}]
self.history.append([self.name, model_response])
self.turn += 1
self.lastUpdate = int(time.time())
self.agent_saves['dialogue_info_fundamental'].append(fundamental_info)
self.agent_saves['dialogue_info'].append(self.get_info())
return model_response
def get_info(self):
info_dict = {}
for name in self.module_names:
module = getattr(self, name)
module_info = getattr(module, "info_dict", None)
info_dict[name] = module_info
return info_dict
def is_terminated(self):
if hasattr(self.policy, 'is_terminated'):
return self.policy.is_terminated()
return None
def retrieve_reward(self):
rewards = [1] * len(self.sys_state_history)
for turn in self.feedback:
turn_number = int((int(turn) - 2) / 2)
if turn_number >= len(self.sys_state_history):
continue
# TODO possibly use text here to check whether rating belongs to the right utterance of the system
text = self.feedback[turn]['text']
rating = self.feedback[turn]["isGood"]
rewards[turn_number] = int(rating)
return rewards
def get_reward(self):
if hasattr(self.policy, 'get_reward'):
return self.policy.get_reward()
return None
def init_session(self):
"""Init the attributes of DST and Policy module."""
self.cur_domain = None
if self.nlu is not None:
self.nlu.init_session()
if self.dst is not None:
self.dst.init_session()
self.dst.state['history'].append([self.name, 'null'])
if self.policy is not None:
self.policy.init_session()
if self.nlg is not None:
self.nlg.init_session()
self.history = []
def get_in_da(self):
return self.input_action
def get_out_da(self):
return self.output_action
def print_ending_agent_summary(self):
print("session_id")
print(self.session_id)
print("taskID")
print(self.taskID)
print("USER_GOAL_ACHIEVED")
print(self.USER_GOAL_ACHIEVED)
print("sys_state_history")
print(self.sys_state_history)
print("sys_action_history")
print(self.sys_action_history)
def is_inactive(self):
currentTime = int(time.time())
return currentTime - self.initTime >= 600 and currentTime - self.lastUpdate >= 60
......@@ -5,37 +5,73 @@ Created on Wed Jul 17 14:27:34 2019
@author: truthless
"""
import pdb
class Environment():
def __init__(self, sys_nlg, usr, sys_nlu, sys_dst, evaluator=None):
def __init__(self, sys_nlg, usr, sys_nlu, sys_dst, evaluator=None, use_semantic_acts=False):
self.sys_nlg = sys_nlg
self.usr = usr
self.sys_nlu = sys_nlu
self.sys_dst = sys_dst
self.evaluator = evaluator
self.use_semantic_acts = use_semantic_acts
self.cur_domain = None
def reset(self):
self.usr.init_session()
self.sys_dst.init_session()
self.cur_domain = None
if self.evaluator:
self.evaluator.add_goal(self.usr.policy.get_goal())
s, r, t = self.step([])
return self.sys_dst.state
def step(self, action):
model_response = self.sys_nlg.generate(action) if self.sys_nlg else action
if not self.use_semantic_acts:
model_response = self.sys_nlg.generate(
action) if self.sys_nlg else action
else:
model_response = action
# If system takes booking action add booking info to the 'book-booked' section of the belief state
if type(action) == list:
for intent, domain, slot, value in action:
if domain.lower() not in ['general', 'booking']:
self.cur_domain = domain
dial_act = f'{domain.lower()}-{intent.lower()}-{slot.lower()}'
if dial_act == 'booking-book-ref' and self.cur_domain and self.cur_domain.lower() in ['hotel', 'restaurant', 'train']:
self.sys_dst.state['belief_state'][self.cur_domain.lower()]['book']['booked'] = [{slot.lower():value}]
elif dial_act == 'train-offerbooked-ref' or dial_act == 'train-inform-ref':
self.sys_dst.state['belief_state']['train']['book']['booked'] = [{slot.lower():value}]
elif dial_act == 'taxi-inform-car':
self.sys_dst.state['belief_state']['taxi']['book']['booked'] = [{slot.lower():value}]
observation = self.usr.response(model_response)
if self.evaluator:
self.evaluator.add_sys_da(self.usr.get_in_da())
self.evaluator.add_sys_da(self.usr.get_in_da(), self.sys_dst.state['belief_state'])
self.evaluator.add_usr_da(self.usr.get_out_da())
dialog_act = self.sys_nlu.predict(observation) if self.sys_nlu else observation
dialog_act = self.sys_nlu.predict(
observation) if self.sys_nlu else observation
self.sys_dst.state['user_action'] = dialog_act
state = self.sys_dst.update(dialog_act)
dialog_act = self.sys_dst.state['user_action']
if type(dialog_act) == list:
for intent, domain, slot, value in dialog_act:
if domain.lower() not in ['booking', 'general']:
self.cur_domain = domain
state['history'].append(["sys", model_response])
state['history'].append(["usr", observation])
terminated = self.usr.is_terminated()
if self.evaluator:
reward = self.evaluator.get_reward()
reward = self.evaluator.get_reward(terminated)
else:
reward = self.usr.get_reward()
terminated = self.usr.is_terminated()
return state, reward, terminated
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 17 14:27:34 2019
@author: truthless
"""
import pdb
from convlab2.dialog_agent.env import Environment
class UsrEnvironment(Environment):
def __init__(self, sys, usr, evaluator=None, use_semantic_acts=False):
self.sys = sys
self.usr = usr
self.evaluator = evaluator
def reset(self):
self.usr.init_session()
self.sys.init_session()
if self.evaluator:
self.evaluator.add_goal(self.usr.policy.get_goal())
sys_response = self.sys.response([])
usr_response = self.usr.response(sys_response)
s, r, t = self.step(usr_response)
# return self.sys_dst.state
print("-" * 20)
return self.usr.dst.state
def step(self, action):
# if not self.use_semantic_acts:
# model_response = self.sys_nlg.generate(
# action) if self.sys_nlg else action
# else:
# model_response = action
# only semantic level
usr_response = action
sys_response = self.sys.response(usr_response)
print(f"(env_usr) usr: {usr_response}")
print(f"(env_usr) sys: {sys_response}")
if self.evaluator:
if not self.usr.get_in_da():
# print("not sure why")
usr_in_da = sys_response
usr_out_da = action
else:
usr_in_da = self.usr.get_in_da()
usr_out_da = self.usr.get_out_da()
# print(f"usr_in_da {usr_in_da}, usr_out_da {usr_out_da}")
self.evaluator.add_sys_da(usr_in_da)
self.evaluator.add_usr_da(usr_out_da)
# dialog_act = self.sys_nlu.predict(
# observation) if self.sys_nlu else observation
# TODO pipeline agent should update the dst itself <- make sure why
state = self.usr.dst.update(sys_response)
self.usr.dst.state['user_action'] = usr_response
self.usr.dst.state['system_action'] = sys_response
self.usr.dst.state['history'].append(["usr", usr_response])
self.usr.dst.state['history'].append(["sys", sys_response])
terminated = self.usr.is_terminated()
if terminated:
# TODO uncomment this line
# if self.evaluator:
# if self.evaluator.task_success():
# reward = 80/40
# elif self.evaluator.cur_domain and self.evaluator.domain_success(self.evaluator.cur_domain):
# reward = 0
# else:
# reward = -40/40
# else:
reward = self.usr.get_reward()
else:
# reward = -1 + self.usr.policy.get_turn_reward()
# reward = reward / 40
reward = self.usr.policy.get_turn_reward()
return state, reward, terminated
......@@ -5,6 +5,7 @@ from convlab2.dialog_agent.agent import Agent
class Session(ABC):
"""Base dialog session controller, which manages the agents to conduct a complete dialog session.
"""
......@@ -82,14 +83,18 @@ class BiSession(Session):
"""The user and system agent response in turn."""
if self.__turn_indicator % 2 == 0:
next_agent = self.user_agent
agent = "user"
else:
next_agent = self.sys_agent
agent = "sys"
self.__turn_indicator += 1
# print(agent + " " + str(self.__turn_indicator))
return next_agent
def next_response(self, observation):
next_agent = self.next_agent()
response = next_agent.response(observation)
# print(response)
return response
def next_turn(self, last_observation):
......@@ -116,8 +121,9 @@ class BiSession(Session):
"""
user_response = self.next_response(last_observation)
if self.evaluator:
self.evaluator.add_sys_da(self.user_agent.get_in_da())
self.evaluator.add_sys_da(self.user_agent.get_in_da_eval(), self.sys_agent.dst.state['belief_state'])
self.evaluator.add_usr_da(self.user_agent.get_out_da())
session_over = self.user_agent.is_terminated()
if hasattr(self.sys_agent, 'dst'):
self.sys_agent.dst.state['terminated'] = session_over
......@@ -130,7 +136,6 @@ class BiSession(Session):
sys_response = self.next_response(user_response)
self.dialog_history.append([self.user_agent.name, user_response])
self.dialog_history.append([self.sys_agent.name, sys_response])
return sys_response, user_response, session_over, reward
def train_policy(self):
......@@ -193,6 +198,7 @@ class DealornotSession(Session):
for agent in [self.alice, self.bob]:
choice = agent.choose()
choices.append(choice)
agree, rewards = self.alice.domain.score_choices(choices, ctxs)
return agree, rewards
......
"""Dialog State Tracker Interface"""
import copy
from abc import abstractmethod
from convlab2.util.module import Module
......
# -*- coding: utf-8 -*-
# -*- coding: gbk -*-
"""
Evaluate DST models on specified dataset
Usage: python evaluate.py [MultiWOZ|CrossWOZ|MultiWOZ-zh|CrossWOZ-en] [TRADE|mdbt|sumbt] [val|test|human_val]
Evaluate NLU models on specified dataset
Usage: python evaluate.py [MultiWOZ|CrossWOZ] [TRADE|mdbt|sumbt|rule]
"""
from convlab2.dst.sumbt.crosswoz_en.sumbt import crosswoz_en_slot_list
from convlab2.dst.sumbt.multiwoz_zh.sumbt import multiwoz_zh_slot_list
import random
import numpy
import torch
......@@ -34,9 +37,6 @@ crosswoz_slot_list = [
"酒店-酒店设施-室内游泳池", "酒店-酒店设施-早餐服务免费", "酒店-酒店设施-公共区域提供wifi", "酒店-酒店设施-室外游泳池"
]
from convlab2.dst.sumbt.multiwoz_zh.sumbt import multiwoz_zh_slot_list
from convlab2.dst.sumbt.crosswoz_en.sumbt import crosswoz_en_slot_list
def format_history(context):
history = []
......@@ -51,6 +51,7 @@ def sentseg(sent):
tmp = " ".join(jieba.cut(sent))
return ' '.join(tmp.split())
def reformat_state(state):
if 'belief_state' in state:
state = state['belief_state']
......@@ -73,11 +74,13 @@ def reformat_state(state):
else:
val = domain_data[slot]
if val is not None and val not in ['', 'not mentioned', '未提及', '未提到', '没有提到']:
new_state.append(domain+'_book' + '-' + slot + '-' + val)
new_state.append(domain+'_book' +
'-' + slot + '-' + val)
# lower
new_state = [item.lower() for item in new_state]
return new_state
def reformat_state_crosswoz(state):
if 'belief_state' in state:
state = state['belief_state']
......@@ -86,11 +89,13 @@ def reformat_state_crosswoz(state):
for domain in state.keys():
domain_data = state[domain]
for slot in domain_data.keys():
if slot == 'selectedResults': continue
if slot == 'selectedResults':
continue
val = domain_data[slot]
if slot == 'Hotel Facilities' and val not in ['', 'none']:
for facility in val.split(','):
new_state.append(domain + '-' + f'Hotel Facilities - {facility}' + 'yes')
new_state.append(
domain + '-' + f'Hotel Facilities - {facility}' + 'yes')
else:
if val is not None and val not in ['', 'none']:
# print(domain, slot, val)
......@@ -98,6 +103,7 @@ def reformat_state_crosswoz(state):
return new_state
def compute_acc(gold, pred, slot_temp):
# TODO: not mentioned in gold
miss_gold = 0
......@@ -116,6 +122,7 @@ def compute_acc(gold, pred, slot_temp):
ACC = ACC / float(ACC_TOTAL)
return ACC
def compute_prf(gold, pred):
TP, FP, FN = 0, 0, 0
if len(gold) != 0:
......@@ -130,7 +137,8 @@ def compute_prf(gold, pred):
FP += 1
precision = TP / float(TP + FP) if (TP + FP) != 0 else 0
recall = TP / float(TP + FN) if (TP + FN) != 0 else 0
F1 = 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0
F1 = 2 * precision * recall / \
float(precision + recall) if (precision + recall) != 0 else 0
else:
if len(pred) == 0:
precision, recall, F1, count = 1, 1, 1, 1
......@@ -138,6 +146,7 @@ def compute_prf(gold, pred):
precision, recall, F1, count = 0, 0, 0, 1
return F1, recall, precision, count
def evaluate_metrics(all_prediction, from_which, slot_temp):
total, turn_acc, joint_acc, F1_pred, F1_count = 0, 0, 0, 0, 0
for d, v in all_prediction.items():
......@@ -148,11 +157,13 @@ def evaluate_metrics(all_prediction, from_which, slot_temp):
total += 1
# Compute prediction slot accuracy
temp_acc = compute_acc(set(cv["turn_belief"]), set(cv[from_which]), slot_temp)
temp_acc = compute_acc(
set(cv["turn_belief"]), set(cv[from_which]), slot_temp)
turn_acc += temp_acc
# Compute prediction joint F1 score
temp_f1, temp_r, temp_p, count = compute_prf(set(cv["turn_belief"]), set(cv[from_which]))
temp_f1, temp_r, temp_p, count = compute_prf(
set(cv["turn_belief"]), set(cv[from_which]))
F1_pred += temp_f1
F1_count += count
......@@ -161,6 +172,7 @@ def evaluate_metrics(all_prediction, from_which, slot_temp):
F1_score = F1_pred / float(F1_count) if F1_count != 0 else 0
return joint_acc_score, F1_score, turn_acc_score
if __name__ == '__main__':
seed = 2020
random.seed(seed)
......@@ -175,7 +187,7 @@ if __name__ == '__main__':
print("\t val=[val|test|human_val]")
sys.exit()
## init phase
# init phase
dataset_name = sys.argv[1]
model_name = sys.argv[2]
data_key = sys.argv[3]
......@@ -200,10 +212,11 @@ if __name__ == '__main__':
else:
raise Exception("Available models: TRADE/mdbt/sumbt")
## load data
# load data
from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader(dataset_name.endswith('zh')))
dataloader = AgentDSTDataloader(
dataset_dataloader=MultiWOZDataloader(dataset_name.endswith('zh')))
data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['belief_state']
all_predictions = {}
......@@ -238,8 +251,10 @@ if __name__ == '__main__':
if len(curr_sess) > 0:
all_predictions[session_count] = copy.deepcopy(curr_sess)
slot_list = multiwoz_zh_slot_list if dataset_name.endswith('zh') else multiwoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(all_predictions, "pred_bs_ptr", slot_list)
slot_list = multiwoz_zh_slot_list if dataset_name.endswith(
'zh') else multiwoz_slot_list
joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = evaluate_metrics(
all_predictions, "pred_bs_ptr", slot_list)
evaluation_metrics = {"Joint Acc": joint_acc_score_ptr, "Turn Acc": turn_acc_score_ptr,
"Joint F1": F1_score_ptr}
print(evaluation_metrics)
......@@ -268,7 +283,8 @@ if __name__ == '__main__':
from convlab2.util.dataloader.module_dataloader import CrossWOZAgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import CrossWOZDataloader
dataloader = CrossWOZAgentDSTDataloader(dataset_dataloader=CrossWOZDataloader(en))
dataloader = CrossWOZAgentDSTDataloader(
dataset_dataloader=CrossWOZDataloader(en))
data = dataloader.load_data(data_key=data_key)[data_key]
context, golden_truth = data['context'], data['sys_state_init']
all_predictions = {}
......@@ -300,13 +316,15 @@ if __name__ == '__main__':
for domain in y.keys():
domain_data = y[domain]
for slot in domain_data.keys():
if slot == 'selectedResults': continue
if slot == 'selectedResults':
continue
val = domain_data[slot]
if val is not None and val != '':
val = sentseg(val)
domain_data[slot] = val
model.init_session()
model.state['history'] = format_history([item if en else sentseg(item) for item in context[i]])
model.state['history'] = format_history(
[item if en else sentseg(item) for item in context[i]])
pred = model.update(x[-1] if len(x) > 0 else '')
curr_sess[turn_count] = {
'turn_belief': reformat_state_crosswoz(y),
......
# Multi-domain Belief DST
The multi-domain belief tracker (MDBT), proposed by
[Ramadan et al., 2018](https://www.aclweb.org/anthology/P18-2069), is a belief tracking model that
exploits the semantic similarity between dialogue utterances and ontology terms.
## Package Structure
We adapted the original code into a flexible module that can be
easily imported into a pipeline dialog framework. The dataset-independent
implementation of MDBT is in ```convlab2/dst/mdbt```, and the implementation for the MultiWOZ
dataset is in ```convlab2/dst/mdbt/multiwoz```.
## Run the Code
The framework automatically downloads the pre-trained models and data
before running. If the automatic download fails, download the pre-trained model and data manually
from [here](https://drive.google.com/open?id=1k6wbabIlYju7kR0Zr4aVXwE_fsGBOtdw),
and put the ```word-vectors```, ```models``` and ```data``` directories under
```convlab2/dst/mdbt/multiwoz/configs```.
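
After a manual download, it can help to confirm that the three directories landed in the right place. The following is a minimal sketch, not part of ConvLab itself: the helper name `missing_mdbt_dirs` is hypothetical, and it only checks that the directories exist, not that their contents are complete.

```python
from pathlib import Path

# Directory names taken from the instructions above.
REQUIRED_DIRS = ("word-vectors", "models", "data")


def missing_mdbt_dirs(config_dir):
    """Return the required directory names that are absent under config_dir."""
    root = Path(config_dir)
    return [name for name in REQUIRED_DIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_mdbt_dirs("convlab2/dst/mdbt/multiwoz/configs")
    if missing:
        print("Missing directories:", ", ".join(missing))
    else:
        print("All MDBT resource directories are in place.")
```

Run it from the repository root so that the relative path resolves correctly.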
## Performance
The performance of our pre-trained MDBT model is 13.9%.
You can train the model yourself for better performance.
import copy
import json
import os
import tensorflow as tf
from convlab2.dst.mdbt.mdbt_util import model_definition, \
track_dialogue, generate_batch, process_history
from convlab2.dst.rule.multiwoz import normalize_value
from convlab2.util.multiwoz.state import default_state
from convlab2.dst.dst import DST
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
from os.path import dirname
train_batch_size = 1
batches_per_eval = 10
no_epochs = 600
device = "gpu"
start_batch = 0
class MDBT(DST):
"""
A multi-domain belief tracker, adapted from https://github.com/osmanio2/multi-domain-belief-tracking.
"""
def __init__(self, ontology_vectors, ontology, slots, data_dir):
DST.__init__(self)
# data profile
self.data_dir = data_dir
self.validation_url = os.path.join(self.data_dir, 'data/validate.json')
self.word_vectors_url = os.path.join(self.data_dir, 'word-vectors/paragram_300_sl999.txt')
self.training_url = os.path.join(self.data_dir, 'data/train.json')
self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json')
self.testing_url = os.path.join(self.data_dir, 'data/test.json')
self.model_url = os.path.join(self.data_dir, 'models/model-1')
self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1')
self.results_url = os.path.join(self.data_dir, 'results/log-1.txt')
self.kb_url = os.path.join(self.data_dir, 'data/') # not used
self.train_model_url = os.path.join(self.data_dir, 'train_models/model-1')
self.train_graph_url = os.path.join(self.data_dir, 'train_graph/graph-1')
self.model_variables = model_definition(ontology_vectors, len(ontology), slots, num_hidden=None,
bidir=True, net_type=None, test=True, dev='cpu')
self.state = default_state()
_config = tf.ConfigProto()
_config.gpu_options.allow_growth = True
_config.allow_soft_placement = True
self.sess = tf.Session(config=_config)
self.param_restored = False
self.det_dic = {}
for domain, dic in REF_USR_DA.items():
for key, value in dic.items():
assert '-' not in key
self.det_dic[key.lower()] = key + '-' + domain
self.det_dic[value.lower()] = key + '-' + domain
def parent_dir(path, time=1):
for _ in range(time):
path = os.path.dirname(path)
return path
root_dir = parent_dir(os.path.abspath(__file__), 4)
self.value_dict = json.load(open(os.path.join(root_dir, 'data/multiwoz/value_dict.json')))
def init_session(self):
self.state = default_state()
if not self.param_restored:
self.restore()
def restore(self):
self.__restore_model(self.sess, tf.train.Saver())
def update_batch(self, batch_action):
pass
def update(self, user_act=None):
"""Update the dialog state."""
if type(user_act) is not str:
raise Exception('Expected user_act to be <class \'str\'> type, but got {}.'.format(type(user_act)))
prev_state = copy.deepcopy(self.state)
if not os.path.exists(os.path.join(self.data_dir, "results")):
os.makedirs(os.path.join(self.data_dir, "results"))
global train_batch_size
model_variables = self.model_variables
(user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels, domain_labels, domain_accuracy,
slot_accuracy, value_accuracy, value_f1, train_step, keep_prob, predictions,
true_predictions, [y, _]) = model_variables
# Note: Comment the following line since the first node is already i
# prev_state['history'] = [['sys', 'null']] if len(prev_state['history']) == 0 else prev_state['history']
assert len(prev_state['history']) > 0
first_turn = prev_state['history'][0]
if first_turn[0] != 'sys':
prev_state['history'] = [['sys', '']] + prev_state['history']
actual_history = []
assert len(prev_state['history']) % 2 == 0
for name, utt in prev_state['history']:
if not utt:
utt = 'null'
if len(actual_history)==0 or len(actual_history[-1])==2:
actual_history.append([utt])
else:
actual_history[-1].append(utt)
# actual_history[-1].append(user_act)
# actual_history = self.normalize_history(actual_history)
# if len(actual_history) == 0:
# actual_history = [['', user_act if len(user_act)>0 else 'fake user act']]
fake_dialogue = {}
turn_no = 0
for _sys, _user in actual_history:
turn = {}
turn['system'] = _sys
fake_user = {}
fake_user['text'] = _user
fake_user['belief_state'] = default_state()['belief_state']
turn['user'] = fake_user
key = str(turn_no)
fake_dialogue[key] = turn
turn_no += 1
context, actual_context = process_history([fake_dialogue], self.word_vectors, self.ontology)
batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \
batch_no_turns = generate_batch(context, 0, 1, len(self.ontology)) # old feature
# run model
[pred, y_pred] = self.sess.run(
[predictions, y],
feed_dict={user: batch_user, sys_res: batch_sys,
labels: batch_labels,
domain_labels: batch_domain_labels,
user_uttr_len: batch_user_uttr_len,
sys_uttr_len: batch_sys_uttr_len,
no_turns: batch_no_turns,
keep_prob: 1.0})
# convert to str output
dialgs, _, _ = track_dialogue(actual_context, self.ontology, pred, y_pred)
assert len(dialgs) >= 1
last_turn = dialgs[0][-1]
predictions = last_turn['prediction']
new_belief_state = copy.deepcopy(prev_state['belief_state'])
# update belief state
for item in predictions:
item = item.lower()
domain, slot, value = item.strip().split('-')
value = value[::-1].split(':', 1)[1][::-1]
if slot == 'price range':
slot = 'pricerange'
if slot not in ['name', 'book']:
if domain not in new_belief_state:
raise Exception('Error: domain <{}> not in belief state'.format(domain))
slot = REF_SYS_DA[domain.capitalize( )].get(slot, slot)
assert 'semi' in new_belief_state[domain]
assert 'book' in new_belief_state[domain]
if 'book' in slot:
assert slot.startswith('book ')
slot = slot.strip().split()[1]
if slot == 'arriveby':
slot = 'arriveBy'
elif slot == 'leaveat':
slot = 'leaveAt'
domain_dic = new_belief_state[domain]
if slot in domain_dic['semi']:
new_belief_state[domain]['semi'][slot] = normalize_value(self.value_dict, domain, slot, value)
elif slot in domain_dic['book']:
new_belief_state[domain]['book'][slot] = value
elif slot.lower() in domain_dic['book']:
new_belief_state[domain]['book'][slot.lower()] = value
else:
with open('mdbt_unknown_slot.log', 'a+') as f:
f.write('unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(slot, value,
domain, item))
new_request_state = copy.deepcopy(prev_state['request_state'])
# update request_state
user_request_slot = self.detect_requestable_slots(user_act)
for domain in user_request_slot:
for key in user_request_slot[domain]:
if domain not in new_request_state:
new_request_state[domain] = {}
if key not in new_request_state[domain]:
new_request_state[domain][key] = user_request_slot[domain][key]
# update state
new_state = copy.deepcopy(dict(prev_state))
new_state['belief_state'] = new_belief_state
new_state['request_state'] = new_request_state
self.state = new_state
return self.state
def normalize_history(self, history):
"""Replace zero-length history."""
for i in range(len(history)):
a, b = history[i]
if len(a) == 0:
history[i][0] = 'sys'
if len(b) == 0:
history[i][1] = 'user'
return history
def detect_requestable_slots(self, observation):
result = {}
observation = observation.lower()
_observation = ' {} '.format(observation)
for value in self.det_dic.keys():
_value = ' {} '.format(value.strip())
if _value in _observation:
key, domain = self.det_dic[value].split('-')
if domain not in result:
result[domain] = {}
result[domain][key] = 0
return result
def __restore_model(self, sess, saver):
saver.restore(sess, self.model_url)
print('Loading trained MDBT model from ', self.model_url)
self.param_restored = True
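The `update` method above turns the flat `[speaker, utterance]` history into `[system, user]` pairs before building the fake dialogue. The following is a standalone sketch of that pairing step only, assuming the same history layout; the function name `pair_history` is hypothetical, and it raises `ValueError` where the original uses `assert`.

```python
def pair_history(history):
    """Group a flat [speaker, utterance] history into [system, user] pairs.

    Illustrative re-statement of the pairing loop in MDBT.update;
    not part of the ConvLab API.
    """
    if not history:
        raise ValueError("history must contain at least one turn")
    # The tracker expects the system to speak first; pad with an empty turn.
    if history[0][0] != 'sys':
        history = [['sys', '']] + history
    if len(history) % 2 != 0:
        raise ValueError("history must alternate system and user turns")

    pairs = []
    for _speaker, utt in history:
        utt = utt or 'null'  # empty utterances are replaced by 'null'
        if not pairs or len(pairs[-1]) == 2:
            pairs.append([utt])
        else:
            pairs[-1].append(utt)
    return pairs


if __name__ == "__main__":
    print(pair_history([
        ['sys', ''],
        ['user', 'Could you book a 4 stars hotel for one night, 1 person?'],
    ]))
```

Each resulting pair corresponds to one numbered turn of the fake dialogue passed to `process_history`.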
from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT as MDBT
import json
import os
import time
import tensorflow as tf
import shutil
import zipfile
from convlab2.dst.mdbt.mdbt import MDBT
from convlab2.dst.mdbt.mdbt_util import load_word_vectors, load_ontology, load_woz_data_new
from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
from convlab2.util.file_util import cached_path
from pprint import pprint
train_batch_size = 1
batches_per_eval = 10
no_epochs = 600
device = "gpu"
start_batch = 0
class MultiWozMDBT(MDBT):
def __init__(self, data_dir='configs', data=None):
"""Constructor of MultiWozMDBT class.
Args:
data_dir (str): The path of data dir, where the root path is convlab2/dst/mdbt/multiwoz.
"""
if data is None:
loader = AgentDSTDataloader(MultiWOZDataloader())
data = loader.load_data()
self.file_url = 'https://convlab.blob.core.windows.net/convlab-2/mdbt_multiwoz_sys.zip'
local_path = os.path.dirname(os.path.abspath(__file__))
self.data_dir = os.path.join(local_path, data_dir) # abstract data path
self.validation_url = os.path.join(self.data_dir, 'data/validate.json')
self.training_url = os.path.join(self.data_dir, 'data/train.json')
self.testing_url = os.path.join(self.data_dir, 'data/test.json')
self.word_vectors_url = os.path.join(self.data_dir, 'word-vectors/paragram_300_sl999.txt')
self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json')
self.model_url = os.path.join(self.data_dir, 'models/model-1')
self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1')
self.results_url = os.path.join(self.data_dir, 'results/log-1.txt')
self.kb_url = os.path.join(self.data_dir, 'data/') # not used
self.train_model_url = os.path.join(self.data_dir, 'train_models/model-1')
self.train_graph_url = os.path.join(self.data_dir, 'train_graph/graph-1')
self.auto_download()
print('Configuring MDBT model...')
self.word_vectors = load_word_vectors(self.word_vectors_url)
# Load the ontology and extract the feature vectors
self.ontology, self.ontology_vectors, self.slots = load_ontology(self.ontology_url, self.word_vectors)
# Load and process the test data
self.test_dialogues, self.actual_dialogues = load_woz_data_new(data['test'], self.word_vectors,
self.ontology, url=self.testing_url)
self.no_dialogues = len(self.test_dialogues)
super(MultiWozMDBT, self).__init__(self.ontology_vectors, self.ontology, self.slots, self.data_dir)
def auto_download(self):
"""Automatically download the pretrained model and necessary data."""
if os.path.exists(os.path.join(self.data_dir, 'models')) and \
os.path.exists(os.path.join(self.data_dir, 'data')) and \
os.path.exists(os.path.join(self.data_dir, 'word-vectors')):
return
cached_path(self.file_url, self.data_dir)
files = os.listdir(self.data_dir)
target_file = ''
for name in files:
if name.endswith('.json'):
target_file = name[:-5]
try:
assert target_file in files
except Exception as e:
print('allennlp download file error: MDBT Multiwoz data download failed.')
raise e
zip_file_path = os.path.join(self.data_dir, target_file+'.zip')
shutil.copyfile(os.path.join(self.data_dir, target_file), zip_file_path)
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
zip_ref.extractall(self.data_dir)
def test_update():
# lower case, tokenized.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
tracker = MultiWozMDBT()
tracker.init_session()
# original usage in Convlab
# tracker.state['history'] = [
# ["null", "am looking for a place to to stay that has cheap price range it should be in a type of hotel"],
# ["Okay, do you have a specific area you want to stay in?", "no, i just need to make sure it's cheap. oh, and i need parking"],
# ["I found 1 cheap hotel for you that includes parking. Do you like me to book it?", "Yes, please. 6 people 3 nights starting on tuesday."],
# ["I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay?", "how about only 2 nights."],
# ["Booking was successful.\nReference number is : 7GAWK763. Anything else I can do for you?"]
# ]
# current usage in Convlab2
tracker.state['history'] = [
['sys', ''],
['user', 'Could you book a 4 stars hotel for one night, 1 person?'],
['sys', 'If you\'d like something cheap, I recommend the Allenbell']
]
tracker.state['history'].append(['user', 'Friday and Can you book it for me and get a reference number ?'])
user_utt = 'Friday and Can you book it for me and get a reference number ?'
from timeit import default_timer as timer
start = timer()
pprint(tracker.update(user_utt))
end = timer()
print(end - start)
start = timer()
tracker.update(user_utt)
end = timer()
print(end - start)
start = timer()
tracker.update(user_utt)
end = timer()
print(end - start)
if __name__ == '__main__':
test_update()
......@@ -31,6 +31,7 @@ class RuleDST(DST):
:param user_act:
:return:
"""
#print("dst", user_act)
for intent, domain, slot, value in user_act:
domain = domain.lower()
intent = intent.lower()
......@@ -43,7 +44,8 @@ class RuleDST(DST):
try:
assert domain in self.state['belief_state']
except:
raise Exception('Error: domain <{}> not in new belief state'.format(domain))
raise Exception(
'Error: domain <{}> not in new belief state'.format(domain))
domain_dic = self.state['belief_state'][domain]
assert 'semi' in domain_dic
assert 'book' in domain_dic
......@@ -53,13 +55,15 @@ class RuleDST(DST):
elif k in domain_dic['book']:
self.state['belief_state'][domain]['book'][k] = value
elif k.lower() in domain_dic['book']:
self.state['belief_state'][domain]['book'][k.lower()] = value
self.state['belief_state'][domain]['book'][k.lower()
] = value
elif k == 'trainID' and domain == 'train':
self.state['belief_state'][domain]['book'][k] = normalize_value(self.value_dict, domain, k, value)
elif k != 'none':
# raise Exception('unknown slot name <{}> of domain <{}>'.format(k, domain))
with open('unknown_slot.log', 'a+') as f:
f.write('unknown slot name <{}> of domain <{}>\n'.format(k, domain))
f.write(
'unknown slot name <{}> of domain <{}>\n'.format(k, domain))
elif intent == 'request':
k = REF_SYS_DA[domain.capitalize()].get(slot, slot)
if domain not in self.state['request_state']:
......
import json
import os
from convlab2.util.multiwoz.state import default_state
from convlab2.dst.rule.multiwoz.dst_util import normalize_value
from convlab2.dst.rule.multiwoz import RuleDST
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
from pprint import pprint
SLOT2SEMI = {
"arriveby": "arriveBy",
"leaveat": "leaveAt",
"trainid": "trainID",
}
class UserRuleDST(RuleDST):
"""Rule based DST which trivially updates new values from NLU result to states.
Attributes:
state(dict):
Dialog state. Function ``convlab2.util.multiwoz.state.default_state`` returns a default state.
value_dict(dict):
It helps check whether ``user_act`` has correct content.
"""
def __init__(self):
super().__init__()
self.mentioned_domain = []
def update(self, sys_act=None):
"""
update belief_state, request_state
:param sys_act:
:return:
"""
# print("dst", user_act)
self.update_mentioned_domain(sys_act)
for intent, domain, slot, value in sys_act:
domain = domain.lower()
intent = intent.lower()
if domain in ['unk', 'general']:
continue
# TODO domain: booking
if domain == "booking":
for domain in self.mentioned_domain:
self.update_inform_request(
intent, domain, slot, value)
else:
self.update_inform_request(intent, domain, slot, value)
return self.state
def init_session(self):
"""Initialize ``self.state`` with a default state, which ``convlab2.util.multiwoz.state.default_state`` returns."""
self.state = default_state()
self.mentioned_domain = []
def update_mentioned_domain(self, sys_act):
if not sys_act:
return
for intent, domain, slot, value in sys_act:
domain = domain.lower()
if domain not in self.mentioned_domain and domain not in ['unk', 'general', 'booking']:
self.mentioned_domain.append(domain)
# print(f"update: mentioned {domain} domain")
def update_inform_request(self, intent, domain, slot, value):
slot = slot.lower()
k = SysDa2Goal[domain].get(slot, slot)
k = SLOT2SEMI.get(k, k)
if k is None:
return
try:
assert domain in self.state['belief_state']
except:
raise Exception(
'Error: domain <{}> not in new belief state'.format(domain))
domain_dic = self.state['belief_state'][domain]
assert 'semi' in domain_dic
assert 'book' in domain_dic
if k in domain_dic['semi']:
nvalue = normalize_value(self.value_dict, domain, k, value)
self.state['belief_state'][domain]['semi'][k] = nvalue
elif k in domain_dic['book']:
self.state['belief_state'][domain]['book'][k] = value
elif k.lower() in domain_dic['book']:
self.state['belief_state'][domain]['book'][k.lower()
] = value
elif k == 'trainID' and domain == 'train':
self.state['belief_state'][domain]['book'][k] = normalize_value(
self.value_dict, domain, k, value)
else:
# print('unknown slot name <{}> of domain <{}>'.format(k, domain))
nvalue = normalize_value(self.value_dict, domain, k, value)
self.state['belief_state'][domain]['semi'][k] = nvalue
with open('unknown_slot.log', 'a+') as f:
f.write(
'unknown slot name <{}> of domain <{}>\n'.format(k, domain))
def update_request(self):
pass
def update_booking(self):
pass
if __name__ == '__main__':
# from convlab2.dst.rule.multiwoz import RuleDST
dst = UserRuleDST()
action = [['Inform', 'Restaurant', 'Phone', '01223323737'],
['reqmore', 'general', 'none', 'none'],
["Inform", "Hotel", "Area", "east"], ]
state = dst.update(action)
pprint(state)
dst.init_session()
# Action is a list of lists. Each inner list is [intent, domain, slot, value]; intent and domain may be uppercase or lowercase.
# The domain may be one of ('Attraction', 'Hospital', 'Booking', 'Hotel', 'Restaurant', 'Taxi', 'Train', 'Police').
# The intent may be "inform" or "request".
# For example, the action below contains ["Inform", "Hotel", "Area", "east"], in which "Inform" is the intent and "Hotel" is the domain.
# "Area" is the slot and "east" is the value. In the second item, "Stars" is the slot and "4" is the value.
action = [
["Inform", "Hotel", "Area", "east"],
["Inform", "Hotel", "Stars", "4"]
]
# method `update` updates the attribute `state` of tracker, and returns it.
state = dst.update(action)
assert state == dst.state
assert state == {'user_action': [],
'system_action': [],
'belief_state': {'police': {'book': {'booked': []}, 'semi': {}},
'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''},
'semi': {'name': '',
'area': 'east',
'parking': '',
'pricerange': '',
'stars': '4',
'internet': '',
'type': ''}},
'attraction': {'book': {'booked': []},
'semi': {'type': '', 'name': '', 'area': ''}},
'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''},
'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}},
'hospital': {'book': {'booked': []}, 'semi': {'department': ''}},
'taxi': {'book': {'booked': []},
'semi': {'leaveAt': '',
'destination': '',
'departure': '',
'arriveBy': ''}},
'train': {'book': {'booked': [], 'people': ''},
'semi': {'leaveAt': '',
'destination': '',
'day': '',
'arriveBy': '',
'departure': ''}}},
'request_state': {},
'terminated': False,
'history': []}
# Please call `init_session` before a new dialog. This initializes the attribute `state` of the tracker with a default state, which `convlab2.util.multiwoz.state.default_state` returns. You needn't call it before the first dialog, because the tracker gets a default state in its constructor.
dst.init_session()
action = [["Inform", "Train", "Arrive", "19:45"]]
state = dst.update(action)
assert state == {'user_action': [],
'system_action': [],
'belief_state': {'police': {'book': {'booked': []}, 'semi': {}},
'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''},
'semi': {'name': '',
'area': '',
'parking': '',
'pricerange': '',
'stars': '',
'internet': '',
'type': ''}},
'attraction': {'book': {'booked': []},
'semi': {'type': '', 'name': '', 'area': ''}},
'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''},
'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}},
'hospital': {'book': {'booked': []}, 'semi': {'department': ''}},
'taxi': {'book': {'booked': []},
'semi': {'leaveAt': '',
'destination': '',
'departure': '',
'arriveBy': ''}},
'train': {'book': {'booked': [], 'people': ''},
'semi': {'leaveAt': '',
'destination': '',
'day': '',
'arriveBy': '19:45',
'departure': ''}}},
'request_state': {},
'terminated': False,
'history': []}
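The `update` method of `UserRuleDST` applies every `booking` action to each domain that has been mentioned so far. The following is a standalone sketch of that remapping, assuming actions are `[intent, domain, slot, value]` lists as in the examples above. The function `remap_booking_actions` and the constant `IGNORED_DOMAINS` are hypothetical names used only for illustration, and the sketch returns the remapped actions instead of updating a belief state.

```python
IGNORED_DOMAINS = {'unk', 'general'}


def remap_booking_actions(sys_act, mentioned_domains=()):
    """Rewrite 'booking' actions as actions on every mentioned domain.

    Returns (remapped_actions, mentioned_domains). Illustrative only;
    it mirrors the control flow of UserRuleDST.update without touching
    any dialog state.
    """
    mentioned = list(mentioned_domains)
    # First pass: record concrete domains, as update_mentioned_domain does.
    for _intent, domain, _slot, _value in sys_act:
        domain = domain.lower()
        if (domain not in mentioned
                and domain not in IGNORED_DOMAINS
                and domain != 'booking'):
            mentioned.append(domain)

    # Second pass: expand booking actions over the mentioned domains.
    remapped = []
    for intent, domain, slot, value in sys_act:
        intent, domain = intent.lower(), domain.lower()
        if domain in IGNORED_DOMAINS:
            continue
        targets = mentioned if domain == 'booking' else [domain]
        for target in targets:
            remapped.append([intent, target, slot, value])
    return remapped, mentioned


if __name__ == "__main__":
    acts = [["Inform", "Hotel", "Area", "east"],
            ["Book", "Booking", "Ref", "ABC"]]
    print(remap_booking_actions(acts))
```

If no concrete domain has been mentioned yet, a `booking` action expands to nothing, which matches the loop over `self.mentioned_domain` in the original.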
# Our paper
[Uncertainty Measures in Neural Belief Tracking and the Effects on Dialogue Policy Performance](https://todo.pdf)
## Structure
![SetSUMBT Architecture](https://gitlab.cs.uni-duesseldorf.de/dsml/convlab-2/-/raw/develop/convlab2/dst/setsumbt/setSUMBT.png?inline=false)
## Usages
### Data preprocessing
We conduct experiments on the following datasets:
* MultiWOZ 2.1 [Download](https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip) to get `MULTIWOZ2.1.zip`
### Train
**Train baseline single instance SetSUMBT**
```
python run.py --run_nbt \
--use_descriptions --set_similarity \
--do_train --do_eval \
--seed 20211202
```
**Train ensemble SetSUMBT**
```
SEED=20211202
MODEL_PATH="models/SetSUMBT-CE-roberta-gru-cosine-labelsmoothing-Seed$SEED-$(date +'%d-%m-%Y')"
./configure_ensemble.sh $SEED $MODEL_PATH
./train_ensemble.sh $SEED $MODEL_PATH
```
**Distill Ensemble SetSUMBT**
```
SEED=20211202
MODEL_PATH="models/SetSUMBT-CE-roberta-gru-cosine-labelsmoothing-Seed$SEED-$(date +'%d-%m-%Y')"
./distill_end.sh $SEED $MODEL_PATH
```
**Distribution Distill Ensemble SetSUMBT**
```
SEED=20211202
MODEL_PATH="models/SetSUMBT-CE-roberta-gru-cosine-labelsmoothing-Seed$SEED-$(date +'%d-%m-%Y')"
./distill_end2.sh $SEED $MODEL_PATH
```
### Evaluation
```
SEED=20211202
MODEL_PATH="models/SetSUMBT-CE-roberta-gru-cosine-labelsmoothing-Seed$SEED-$(date +'%d-%m-%Y')"
python run.py --run_calibration \
--seed $SEED \
--output_dir $MODEL_PATH
```
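Each snippet in this section rebuilds `MODEL_PATH` from the seed and the current date, so a command run on a later day than training points at a different directory. The following is a minimal sketch of how the same string is composed, assuming only the naming pattern shown in the shell snippets; the helper `build_model_path` is hypothetical and not part of the repository.

```python
from datetime import date


def build_model_path(seed, day):
    """Recreate the MODEL_PATH string from the shell snippets for a given day.

    Hypothetical helper for illustration only.
    """
    stamp = day.strftime('%d-%m-%Y')  # same layout as `date +'%d-%m-%Y'`
    return ('models/SetSUMBT-CE-roberta-gru-cosine-labelsmoothing-'
            'Seed{}-{}'.format(seed, stamp))


if __name__ == "__main__":
    # Pass the day on which training was started, not necessarily today.
    print(build_model_path(20211202, date(2021, 12, 2)))
```

Passing the training date explicitly, or exporting `MODEL_PATH` once and reusing it, avoids evaluating against a path that was never written.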
### Convert training setup to convlab model
```
SEED=20211202
MODEL_PATH="models/SetSUMBT-CE-roberta-gru-cosine-labelsmoothing-Seed$SEED-$(date +'%d-%m-%Y')"
OUT_PATH="models/labelsmoothing"
./configure_model.sh $MODEL_PATH data $OUT_PATH
```
### Training PPO policy using SetSUMBT tracker and uncertainty
To train a PPO policy, switch to the policy directory:
```
cd ../../policy/ppo
```
In this directory, run the relevant training script. For example, to train the policy with END-SetSUMBT and no uncertainty metrics, run:
```
./train_setsumbt_end_baseline.sh
```