Commit f5ce0b8d authored by Hsien-Chin Lin, committed by Carel van Niekerk

Us readme

parent 16304fcd
**GenTUS** is a data-driven user simulator built on transformers, which generates both semantic actions and natural language utterances. It can transfer to a new ontology in a zero-shot fashion.
## Introduction
In this work, we propose a generative transformer-based user simulator (GenTUS). GenTUS consists of an encoder-decoder structure, which allows it to optimise the user policy and natural language generation jointly. GenTUS generates semantic actions and natural language utterances, preserving interpretability and enhancing language variation.
The code of GenTUS is in `convlab/policy/genTUS`.
## Usage
### Train GenTUS from scratch
First generate the input files with `build_data.py`, then train the model with `train_model.py`:
```
python3 convlab/policy/genTUS/unify/build_data.py --dataset $dataset --add-history --dial-ids-order $dial_ids_order --split2ratio $split2ratio
python3 convlab/policy/genTUS/train_model.py --data-name $dataset --dial-ids-order $dial_ids_order --split2ratio $split2ratio --batch-size 8
```
`dataset` can be `multiwoz21`, `sgd`, `tm`, `sgd+tm`, or `all`.
`dial_ids_order` can be 0, 1, or 2.
`split2ratio` can be 0.01, 0.1, or 1.
`build_data.py` generates three files, `train.json`, `validation.json`, and `test.json`, under the folder `convlab/policy/genTUS/unify/data/${dataset}_${dial_ids_order}_${split2ratio}`.
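For example, to build the data and train GenTUS on `multiwoz21` with the full training split:
```
python3 convlab/policy/genTUS/unify/build_data.py --dataset multiwoz21 --add-history --dial-ids-order 0 --split2ratio 1
python3 convlab/policy/genTUS/train_model.py --data-name multiwoz21 --dial-ids-order 0 --split2ratio 1 --batch-size 8
```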
We trained GenTUS on an A100 or RTX6000 GPU.
### Evaluate GenTUS
```
python3 convlab/policy/genTUS/evaluate.py --model-checkpoint $model_checkpoint --input-file $in_file --dataset $dataset --do-nlg
```
The `in_file` is a file generated by `build_data.py`, e.g. the `test.json` split.
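For example, assuming the data built above and the `multiwoz21-exp` checkpoint used below (the exact data folder name depends on the arguments you passed to `build_data.py`):
```
python3 convlab/policy/genTUS/evaluate.py --model-checkpoint convlab/policy/genTUS/unify/experiments/multiwoz21-exp --input-file convlab/policy/genTUS/unify/data/multiwoz21_0_1.0/test.json --dataset multiwoz21 --do-nlg
```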
### Train a dialogue policy with GenTUS
You can use it as a normal user simulator with `PipelineAgent`. For example,
```python
from convlab.dialog_agent import PipelineAgent
from convlab.policy.genTUS.stepGenTUS import UserPolicy
from convlab.util.custom_util import set_seed

set_seed(20220220)
# the checkpoint below is downloaded automatically if it is not found locally
model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21-exp'
usr_policy = UserPolicy(model_checkpoint, mode="semantic")
simulator = PipelineAgent(None, None, usr_policy, None, 'user')
```
Then you can train your system policy with this simulator.
You can also change `mode` to `"language"`; GenTUS will then respond in natural language instead of semantic actions.
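To plug the simulator into reinforcement learning, you can wrap it in an environment together with a system-side DST. The following is only a minimal sketch, reusing `simulator` from the snippet above and the `Environment` and `RuleDST` classes that ConvLab uses elsewhere; the full training loop lives in the policy training scripts (e.g. PPO).
```python
from convlab.dialog_agent.env import Environment
from convlab.dst.rule.multiwoz import RuleDST

dst_sys = RuleDST()
# no system NLU/NLG is needed when interacting on the semantic level
env = Environment(None, simulator, None, dst_sys)

obs = env.reset()
# obs, reward, done = env.step(system_action)  # step the environment with your policy's action
```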
<!---citation--->
## Citing
```
@inproceedings{lin-etal-2022-gentus,
title = "{G}en{TUS}: Simulating User Behaviour and Language in Task-oriented Dialogues with Generative Transformers",
author = "Lin, Hsien-chin and
Geishauser, Christian and
Feng, Shutong and
Lubis, Nurul and
van Niekerk, Carel and
Heck, Michael and
Gasic, Milica",
booktitle = "Proceedings of the 23rd Annual Meeting of the Special Interest Group on Discourse and Dialogue",
month = sep,
year = "2022",
address = "Edinburgh, UK",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.sigdial-1.28",
pages = "270--282",
abstract = "User simulators (USs) are commonly used to train task-oriented dialogue systems via reinforcement learning. The interactions often take place on semantic level for efficiency, but there is still a gap from semantic actions to natural language, which causes a mismatch between training and deployment environment. Incorporating a natural language generation (NLG) module with USs during training can partly deal with this problem. However, since the policy and NLG of USs are optimised separately, these simulated user utterances may not be natural enough in a given context. In this work, we propose a generative transformer-based user simulator (GenTUS). GenTUS consists of an encoder-decoder structure, which means it can optimise both the user policy and natural language generation jointly. GenTUS generates both semantic actions and natural language utterances, preserving interpretability and enhancing language variation. In addition, by representing the inputs and outputs as word sequences and by using a large pre-trained language model we can achieve generalisability in feature representation. We evaluate GenTUS with automatic metrics and human evaluation. Our results show that GenTUS generates more natural language and is able to transfer to an unseen ontology in a zero-shot fashion. In addition, its behaviour can be further shaped with reinforcement learning opening the door to training specialised user simulators.",
}
```
## License
Apache License 2.0
......@@ -11,6 +11,8 @@ from convlab.policy.genTUS.unify.Goal import Goal
from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
from convlab.policy.policy import Policy
from convlab.task.multiwoz.goal_generator import GoalGenerator
from convlab.util.custom_util import model_downloader
DEBUG = False
......@@ -589,10 +591,10 @@ class UserPolicy(Policy):
action_penalty=False,
**kwargs):
# self.config = config
# if not os.path.exists(self.config["model_dir"]):
# os.mkdir(self.config["model_dir"])
# model_downloader(self.config["model_dir"],
# "https://zenodo.org/record/5779832/files/default.zip")
if not os.path.exists(os.path.dirname(model_checkpoint)):
os.mkdir(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
self.policy = UserActionPolicy(
model_checkpoint,
......@@ -636,11 +638,11 @@ if __name__ == "__main__":
set_seed(20220220)
# Test semantic level behaviour
model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0'
model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21-exp'
usr_policy = UserPolicy(
model_checkpoint,
mode="semantic")
usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
# usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
usr_nlu = None # BERTNLU()
usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
print(usr.policy.get_goal())
......
......@@ -54,8 +54,8 @@ class Goal:
self.raw_goal = goal.domain_goals
goal = old_goal2list(goal.domain_goals)
else:
print("unknow goal")
# else:
# print("unknow goal")
# be careful of this order
for domain, intent, slot, value in goal:
......
**TUS** is a domain-independent user simulator with transformers for task-oriented dialogue systems. It is based on the [ConvLab-2](https://github.com/thu-coai/ConvLab-2) framework. Therefore, you should follow their instructions to install the package.
**TUS** is a domain-independent user simulator with transformers for task-oriented dialogue systems.
## Introduction
Our model is a domain-independent user simulator, which means it is not based on any domain-dependent features and its output representation is also domain-independent. Therefore, it can easily adapt to a new domain, without additional feature engineering and model retraining.
Our model is a domain-independent user simulator, which means its input and output representations are domain agnostic. Therefore, it can easily adapt to a new domain, without additional feature engineering and model retraining.
The code of TUS is in `convlab/policy/tus`, and a rule-based DST for the user side is also provided in `convlab/dst/rule/multiwoz/usr_dst.py`, based on the rule-based DST in `convlab/dst/rule/multiwoz/dst.py`.
The code of TUS is in `convlab/policy/tus`.
## How to run the model
### Train the user simulator
`python3 convlab/policy/tus/multiwoz/train.py --user_config convlab/policy/tus/multiwoz/exp/default.json`
## Usage
### Train TUS from scratch
One default configuration is placed in `convlab/policy/tus/multiwoz/exp/default.json`. It can be modified based on your requirements. For example, the output directory can be specified in the configuration (`model_dir`).
```
python3 convlab/policy/tus/unify/train.py --dataset $dataset --dial-ids-order $dial_ids_order --split2ratio $split2ratio --user-config $config
```
`dataset` can be `multiwoz21`, `sgd`, `tm`, `sgd+tm`, or `all`.
`dial_ids_order` can be 0, 1, or 2.
`split2ratio` can be 0.01, 0.1, or 1.
Default configurations are placed in `convlab/policy/tus/unify/exp`. They can be modified based on your requirements.
For example, you can train TUS for multiwoz21 by
`python3 convlab/policy/tus/unify/train.py --dataset multiwoz21 --dial-ids-order 0 --split2ratio 1 --user-config "convlab/policy/tus/unify/exp/multiwoz.json"`
### Evaluate TUS
### Train a dialogue policy with TUS
You can use it as a normal user simulator with `PipelineAgent`. For example,
```python
import json
from convlab.dialog_agent.agent import PipelineAgent
from convlab.dst.rule.multiwoz.usr_dst import UserRuleDST
from convlab.policy.tus.multiwoz.TUS import UserPolicy
from convlab.policy.tus.unify.TUS import UserPolicy
user_config_file = "convlab/policy/tus/multiwoz/exp/default.json"
dst_usr = UserRuleDST()
user_config_file = "convlab/policy/tus/unify/exp/multiwoz.json"
user_config = json.load(open(user_config_file))
policy_usr = UserPolicy(user_config)
simulator = PipelineAgent(None, dst_usr, policy_usr, None, 'user')
simulator = PipelineAgent(None, None, policy_usr, None, 'user')
```
then you can train your system with this simulator.
There is an example config, which trains a PPO policy with TUS at the semantic level, in `convlab/policy/ppo/tus_semantic_level_config.json`.
You can train a PPO policy as follows:
```
config="convlab/policy/ppo/tus_semantic_level_config.json"
python3 convlab/policy/ppo/train.py --path $config
```
Note: you should name your pretrained policy `convlab/policy/ppo/pretrained_models/mle` or modify the `load_path` of `model` in the config `convlab/policy/ppo/tus_semantic_level_config.json`.
<!---citation--->
......
......@@ -132,6 +132,8 @@ class UserActionPolicy(Policy):
self.raw_goal = goal.domain_goals
goal_list = old_goal2list(goal.domain_goals)
goal = Goal(goal_list)
elif type(goal) == Goal:
self.raw_goal = goal.domain_goals
else:
goal = ABUS_Goal(self.goal_gen)
self.raw_goal = goal.domain_goals
......@@ -416,7 +418,7 @@ class UserPolicy(Policy):
if not os.path.exists(self.config["model_dir"]):
# os.mkdir(self.config["model_dir"])
model_downloader(os.path.dirname(self.config["model_dir"]),
"https://zenodo.org/record/5779832/files/default.zip")
"https://zenodo.org/record/7369429/files/multiwoz_0.zip")
self.slot2dbattr = {
'open hours': 'openhours',
'price range': 'pricerange',
......@@ -451,30 +453,4 @@ class UserPolicy(Policy):
def get_goal(self):
if hasattr(self.policy, 'get_goal'):
return self.policy.get_goal()
# workaround: convert goal to old format
multiwoz_goal = {}
goal = self.policy.get_goal()
for domain in goal:
multiwoz_goal[domain] = {}
for slot_type in ["info", "reqt"]:
if slot_type not in goal[domain]:
continue
if slot_type not in multiwoz_goal[domain]:
multiwoz_goal[domain][slot_type] = {}
for slot in goal[domain][slot_type]:
value = goal[domain][slot_type][slot].lower()
if "book" in slot:
if "book" not in multiwoz_goal[domain]:
multiwoz_goal[domain]["book"] = {}
norm_slot = slot.split(' ')[-1]
multiwoz_goal[domain]["book"][norm_slot] = value
elif slot in self.slot2dbattr:
norm_slot = self.slot2dbattr[slot]
multiwoz_goal[domain][slot_type][norm_slot] = value
else:
multiwoz_goal[domain][slot_type][slot] = value
for domain in multiwoz_goal:
if "book" in multiwoz_goal[domain]:
multiwoz_goal[domain]["booked"] = '?'
return multiwoz_goal
return None
import argparse
import datetime
import json
import logging
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from convlab.dialog_agent.agent import PipelineAgent
from convlab.dialog_agent.env import Environment
from convlab.dialog_agent.session import BiSession
from convlab.dst.rule.multiwoz import RuleDST
from convlab.dst.rule.multiwoz.usr_dst import UserRuleDST
from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
from convlab.policy.rule.multiwoz import RulePolicy
from convlab.policy.tus.multiwoz import util
from convlab.policy.tus.multiwoz.transformer import TransformerActionPrediction
from convlab.policy.tus.unify.TUS import UserPolicy
from convlab.policy.tus.unify.usermanager import TUSDataManager
from convlab.policy.tus.unify.util import (create_goal, int2onehot,
metadata2state, parse_dialogue_act,
parse_user_goal, split_slot_name)
from convlab.util import load_dataset, load_ontology
from sklearn import metrics
from torch.utils.data import DataLoader
from tqdm import tqdm
def get_f1(target, result):
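# Compute precision, recall and F1 over two aligned binary sequences; F1 is "NAN" when precision or recall is zero.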
target_len = 0
result_len = 0
tp = 0
for t, r in zip(target, result):
if t:
target_len += 1
if r:
result_len += 1
if r == t and t:
tp += 1
precision = 0
recall = 0
if result_len:
precision = tp / result_len
if target_len:
recall = tp / target_len
if precision and recall:
f1_score = 2 / (1 / precision + 1 / recall)
else:
f1_score = "NAN"
return f1_score, precision, recall
from convlab.policy.rule.multiwoz import RulePolicy
from convlab.policy.tus.unify.Goal import Goal
from convlab.policy.tus.unify.TUS import UserPolicy
from convlab.policy.tus.unify.usermanager import TUSDataManager
from convlab.policy.tus.unify.util import create_goal, parse_dialogue_act
from convlab.util import load_dataset
def check_device():
......@@ -63,24 +24,6 @@ def check_device():
return torch.device('cpu')
def init_logging(log_dir_path, path_suffix=None):
if not os.path.exists(log_dir_path):
os.makedirs(log_dir_path)
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
if path_suffix:
log_file_path = os.path.join(
log_dir_path, f"{current_time}_{path_suffix}.log")
else:
log_file_path = os.path.join(
log_dir_path, "{}.log".format(current_time))
stderr_handler = logging.StreamHandler()
file_handler = logging.FileHandler(log_file_path)
format_str = "%(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s"
logging.basicConfig(level=logging.DEBUG, handlers=[
stderr_handler, file_handler], format=format_str)
class Analysis:
def __init__(self, config, analysis_dir='user-analysis-result', show_dialog=False, save_dialog=True):
if not os.path.exists(analysis_dir):
......@@ -95,175 +38,18 @@ class Analysis:
self.save_dialog = save_dialog
self.max_turn = 40
def get_sys(self, sys="rule", load_path=None):
dst = RuleDST()
sys = sys.lower()
if sys == "rule":
policy = RulePolicy()
elif sys == "ppo":
from convlab.policy.ppo import PPO
if load_path:
policy = PPO(False, use_action_mask=True, shrink=False)
policy.load(load_path)
else:
policy = PPO.from_pretrained()
elif sys == "vtrace":
from convlab.policy.vtrace_rnn_action_embedding import VTRACE_RNN
policy = VTRACE_RNN(
is_train=False, seed=0, use_masking=True, shrink=False)
policy.load(load_path)
else:
print(f"Unsupport system type: {sys}")
return dst, policy
def get_usr(self, usr="tus", load_path=None):
# if using "tus", we read config
# for the other user simulators, we read load_path
usr = usr.lower()
if usr == "rule":
dst_usr = None
policy_usr = RulePolicy(character='usr')
elif usr == "tus":
if usr == "tus":
policy_usr = UserPolicy(self.config)
elif usr == "ppo-tus":
from convlab.policy.ppo.ppo_usr import PPO_USR
dst_usr = UserRuleDST()
policy_usr = PPO_USR(pre_trained_config=self.config)
policy_usr.load(load_path)
elif usr == "vhus":
from convlab.policy.vhus.multiwoz import UserPolicyVHUS
dst_usr = None
policy_usr = UserPolicyVHUS(
load_from_zip=True, model_file="vhus_simulator_multiwoz.zip")
else:
print(f"Unsupport user type: {usr}")
# TODO VHUS
return policy_usr
def interact_test(self,
sys="rule",
usr="tus",
sys_load_path=None,
usr_load_path=None,
num_dialog=400,
domain=None):
# TODO need refactor
seed = 20190827
torch.manual_seed(seed)
sys = sys.lower()
usr = usr.lower()
sess = self._set_interactive_test(
sys, usr, sys_load_path, usr_load_path)
task_success = {
# 'All_user_sim': [], 'All_evaluator': [], 'total_return': []}
'complete': [], 'success': [], 'reward': []}
turn_slot_num = {i: [] for i in range(self.max_turn)}
turn_domain_num = {i: [] for i in range(self.max_turn)}
true_max_turn = 0
for seed in tqdm(range(1000, 1000 + num_dialog)):
# logging.info(f"Seed: {seed}")
random.seed(seed)
np.random.seed(seed)
sess.init_session()
# if domain is not none, the user goal must contain certain domain
if domain:
domain = domain.lower()
print(f"check {domain}")
while 1:
if domain in sess.user_agent.policy.get_goal():
break
sess.user_agent.init_session()
sys_uttr = []
actions = 0
total_return = 0.0
if self.save_dialog:
f = open(os.path.join(self.dialog_dir, str(seed)), 'w')
for turn in range(self.max_turn):
sys_uttr, usr_uttr, finish, reward = sess.next_turn(sys_uttr)
if self.show_dialog:
print(f"USR: {usr_uttr}")
print(f"SYS: {sys_uttr}")
if self.save_dialog:
f.write(f"USR: {usr_uttr}\n")
f.write(f"SYS: {sys_uttr}\n")
actions += len(usr_uttr)
turn_slot_num[turn].append(len(usr_uttr))
turn_domain_num[turn].append(self._get_domain_num(usr_uttr))
total_return += sess.user_agent.policy.policy.get_reward()
if finish:
task_succ = sess.evaluator.task_success()
break
if turn > true_max_turn:
true_max_turn = turn
if self.save_dialog:
f.close()
# logging.info(f"Return: {total_return}")
# logging.info(f"Average actions: {actions / (turn+1)}")
task_success['complete'].append(
int(sess.user_agent.policy.policy.goal.task_complete()))
task_success['success'].append(task_succ)
task_success['reward'].append(total_return)
task_summary = {key: [0] for key in task_success}
for key in task_success:
if task_success[key]:
task_summary[key][0] = np.average(task_success[key])
for key in task_success:
logging.info(
f'{key} {len(task_success[key])} {task_summary[key][0]}')
# logging.info("Average action in turn")
write = {'turn_slot_num': [], 'turn_domain_num': []}
for turn in turn_slot_num:
if turn > true_max_turn:
break
avg = 0
if turn_slot_num[turn]:
avg = sum(turn_slot_num[turn]) / len(turn_slot_num[turn])
write['turn_slot_num'].append(avg)
# logging.info(f"turn {turn}: {avg} slots")
for turn in turn_domain_num:
if turn > true_max_turn:
break
avg = 0
if turn_domain_num[turn]:
avg = sum(turn_domain_num[turn]) / len(turn_domain_num[turn])
write['turn_domain_num'].append(avg)
# logging.info(f"turn {turn}: {avg} domains")
# write results
pd.DataFrame.from_dict(write).to_csv(
os.path.join(self.dir, f'{sys}-{usr}-turn-statistics.csv'))
pd.DataFrame.from_dict(task_summary).to_csv(
os.path.join(self.dir, f'{sys}-{usr}-task-summary.csv'))
def _get_domain_num(self, action):
# act: [Intent, Domain, Slot, Value]
return len(set(act[1] for act in action))
def _set_interactive_test(self, sys, usr, sys_load_path, usr_load_path):
dst_sys, policy_sys = self.get_sys(sys, sys_load_path)
dst_usr, policy_usr = self.get_usr(usr, usr_load_path)
usr = PipelineAgent(None, dst_usr, policy_usr, None, 'user')
sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')
env = Environment(None, usr, None, dst_sys)
evaluator = MultiWozEvaluator()
sess = BiSession(sys, usr, None, evaluator)
return sess
def data_interact_test(self, test_data, usr="tus", user_mode=None, load_path=None):
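# Replay the system turns of the test dialogues and collect the simulator's predicted user acts alongside the golden user acts.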
if user_mode:
# origin_model_name = "-".join(self.config["model_name"].split('-')[:-1])
......@@ -276,7 +62,7 @@ class Analysis:
for dialog in tqdm(test_data):
if self.show_dialog:
print(f"dialog_id: {dialog['dialog_id']}")
goal = create_goal(dialog)
goal = Goal(create_goal(dialog))
sys_act = []
policy_usr.init_session(goal=goal)
......@@ -288,21 +74,13 @@ class Analysis:
start = 1
for turn_id in range(start, turn_num, 2):
if turn_id > 0:
# cur_state = data[dialog_id]["log"][turn_id-1]["metadata"]
sys_act = parse_dialogue_act(
dialog["turns"][turn_id - 1]["dialogue_acts"])
usr_act = policy_usr.predict(sys_act)
golden_usr = parse_dialogue_act(
dialog["turns"][turn_id]["dialogue_acts"])
# sys_act = parse_dialogue_act(
# dialog["turns"][turn_id + 1]["dialogue_acts"])
result.append(usr_act)
label.append(golden_usr)
# if self.show_dialog:
# print(f"---> turn {turn_id} ")
# print(f"pre: {usr_act}")
# print(f"ans: {golden_usr}")
# print(f"sys: {sys_act}")
for domain in [None]:
......@@ -329,6 +107,7 @@ class Analysis:
self.config["model_dir"], f'{user_mode}_data_scores.csv'))
def _extract_domain_related_actions(self, actions, select_domain):
#
domain_related_acts = []
for act in actions:
domain = act[1].lower()
......@@ -337,6 +116,7 @@ class Analysis:
return domain_related_acts
def _data_f1(self, result, label, domain=None):
#
statistic = {}
for stat_type in ["precision", "recall", "turn_acc"]:
statistic[stat_type] = {"success": 0, "count": 0}
......@@ -365,6 +145,7 @@ class Analysis:
@staticmethod
def _skip(label, result, domain=None):
#
ignore = False
if domain:
if not label and not result:
......@@ -379,6 +160,7 @@ class Analysis:
return ignore
def _check(self, r, l):
#
# TODO domain check
# [['Inform', 'Attraction', 'Addr', 'dontcare']] [['thank', 'general', 'none', 'none']]
# skip this one
......@@ -412,6 +194,7 @@ class Analysis:
@staticmethod
def _is_in(a, acts):
#
is_none_slot = False
intent, domain, slot, value = a
if slot.lower() == "none" or domain.lower() == "general":
......@@ -425,145 +208,16 @@ class Analysis:
return is_none_slot, True
return is_none_slot, False
def direct_test(self, model, test_data, user_mode=None):
model = model.to(self.device)
model.zero_grad()
model.eval()
y_lable, y_pred = [], []
y_turn = []
result = {} # old way
with torch.no_grad():
for i, data in enumerate(tqdm(test_data, ascii=True, desc="Evaluation"), 0):
input_feature = data["input"].to(self.device)
mask = data["mask"].to(self.device)
label = data["label"].to(self.device)
output = model(input_feature, mask)
y_l, y_p, y_t, r = self.parse_result(output, label)
y_lable += y_l
y_pred += y_p
y_turn += y_t
# old way
for r_type in r:
if r_type not in result:
result[r_type] = {"correct": 0, "total": 0}
for n in result[r_type]:
result[r_type][n] += float(r[r_type][n])
old_result = {}
for r_type in result:
temp = result[r_type]['correct'] / result[r_type]['total']
old_result[r_type] = [temp]
pd.DataFrame.from_dict(old_result).to_csv(
os.path.join(self.dir, f'{user_mode}_old_result.csv'))
cm = self.model_confusion_matrix(y_lable, y_pred)
self.summary(y_lable, y_pred, y_turn, cm,
file_name=f'{user_mode}_scores.csv')
return old_result
def summary(self, y_true, y_pred, y_turn, cm, file_name='scores.csv'):
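# Aggregate F1, precision, recall, non-zero accuracy and turn accuracy, then write them to a CSV file.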
f1, pre, rec = get_f1(y_true, y_pred)
result = {
'f1': f1, # metrics.f1_score(y_true, y_pred, average='micro'),
# metrics.precision_score(y_true, y_pred, average='micro'),
'precision': pre,
# metrics.recall_score(y_true, y_pred, average='micro'),
'recall': rec,
'none-zero-acc': self.none_zero_acc(cm),
'turn-acc': sum(y_turn) / len(y_turn)}
col = [c for c in result]
df_f1 = pd.DataFrame([result[c] for c in col], col)
df_f1.to_csv(os.path.join(self.dir, file_name))
print("summary")
print(df_f1)
def none_zero_acc(self, cm):
# ['Unnamed: 0', 'none', '?', 'dontcare', 'sys', 'usr', 'random']
col = cm.columns[1:]
num_label = cm.sum(axis=1)
correct = 0
for col_name in col:
correct += cm[col_name][col_name]
return correct / sum(num_label[1:])
def model_confusion_matrix(self, y_true, y_pred, file_name='cm.csv', legend=["none", "?", "dontcare", "sys", "usr", "random"]):
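# Build a confusion matrix over the token label classes and save it as a CSV file.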
cm = metrics.confusion_matrix(y_true, y_pred)
df_cm = pd.DataFrame(cm, legend, legend)
df_cm.to_csv(os.path.join(self.dir, file_name))
return df_cm
def parse_result(self, prediction, label):
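# Flatten predictions and labels into per-token lists and per-turn success flags; also fill the legacy result counters.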
_, arg_prediction = torch.max(prediction.data, -1)
batch_size, token_num = label.shape
y_true, y_pred = [], []
y_turn = []
result = {
"non-zero": {"correct": 0, "total": 0},
"total": {"correct": 0, "total": 0},
"turn": {"correct": 0, "total": 0}
}
for batch_num in range(batch_size):
turn_acc = True # old way
turn_success = 1 # new way
for element in range(token_num):
result["total"]["total"] += 1
l = label[batch_num][element].item()
p = arg_prediction[batch_num][element + 1].item()
# old way
if l > 0:
result["non-zero"]["total"] += 1
if p == l:
if l > 0:
result["non-zero"]["correct"] += 1
result["total"]["correct"] += 1
elif p == 0 and l < 0:
result["total"]["correct"] += 1
else:
if l >= 0:
turn_acc = False
# new way
if l >= 0:
y_true.append(l)
y_pred.append(p)
if l >= 0 and l != p:
turn_success = 0
y_turn.append(turn_success)
# old way
result["turn"]["total"] += 1
if turn_acc:
result["turn"]["correct"] += 1
return y_true, y_pred, y_turn, result
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--analysis_dir", type=str,
default="user-analysis-result")
parser.add_argument("--user_config", type=str,
default="convlab/policy/tus/multiwoz/exp/default.json")
default="convlab/policy/tus/multiwoz/exp/multiwoz.json")
parser.add_argument("--user_mode", type=str, default="")
parser.add_argument("--version", type=str, default="")
parser.add_argument("--do_direct", action="store_true")
parser.add_argument("--do_interact", action="store_true")
parser.add_argument("--do_data", action="store_true")
parser.add_argument("--show_dialog", action="store_true")
parser.add_argument("--usr", type=str, default="tus")
parser.add_argument("--sys", type=str, default="rule")
parser.add_argument("--num_dialog", type=int, default=400)
parser.add_argument("--use_mask", action="store_true")
parser.add_argument("--sys_config", type=str, default="")
parser.add_argument("--sys_model_dir", type=str,
default="convlab/policy/ppo/save/")
parser.add_argument("--domain", type=str, default="",
help="the user goal must contain a specific domain")
parser.add_argument("--load_path", type=str, default="",
......@@ -574,29 +228,23 @@ if __name__ == '__main__':
args = parser.parse_args()
analysis_dir = os.path.join(args.analysis_dir, f"{args.sys}-{args.usr}")
if args.version:
analysis_dir = os.path.join(analysis_dir, args.version)
analysis_dir = os.path.join(f"{args.analysis_dir}-{args.usr}")
if not os.path.exists(os.path.join(analysis_dir)):
os.makedirs(analysis_dir)
config = json.load(open(args.user_config))
init_logging(log_dir_path=os.path.join(analysis_dir, "log"))
if args.user_mode:
config["model_name"] = config["model_name"] + '-' + args.user_mode
config["model_dir"] = f'{config["model_dir"]}_{args.dial_ids_order}'
# config["model_dir"] = f'{config["model_dir"]}_{args.dial_ids_order}'
# with open(config["all_slot"]) as f:
# action_list = [line.strip() for line in f]
# config["num_token"] = len(action_list)
if args.use_mask:
config["domain_mask"] = True
ana = Analysis(config, analysis_dir=analysis_dir,
show_dialog=args.show_dialog)
ana = Analysis(config, analysis_dir=analysis_dir)
if (args.usr == "tus" or args.usr == "ppo-tus") and args.do_data:
if args.usr == "tus" and args.do_data:
test_data = load_dataset(args.dataset,
dial_ids_order=args.dial_ids_order)["test"]
if args.user_mode:
......@@ -610,47 +258,3 @@ if __name__ == '__main__':
usr=args.usr,
user_mode=user_mode,
load_path=args.load_path)
if args.usr == "tus" and args.do_direct:
print("direct test")
raw_data = load_dataset(args.dataset,
dial_ids_order=args.dial_ids_order)
test_data = DataLoader(
TUSDataManager(
config, raw_data["test"]),
batch_size=config["batch_size"],
shuffle=True)
model = TransformerActionPrediction(config)
if args.user_mode:
model.load_state_dict(torch.load(
os.path.join(config["model_dir"], config["model_name"])))
print(args.user_mode)
old_result = ana.direct_test(
model, test_data, user_mode=args.user_mode)
print(old_result)
else:
for user_mode in ["loss", "total", "turn", "non-zero"]:
model.load_state_dict(torch.load(
os.path.join(config["model_dir"], config["model_name"] + '-' + user_mode)))
print(user_mode)
old_result = ana.direct_test(
model, test_data, user_mode=user_mode)
print(old_result)
if args.do_interact:
sys_load_path = None
if args.sys_config:
_, file_extension = os.path.splitext(args.sys_config)
# read from config
if file_extension == ".json":
sys_config = json.load(open(args.sys_config))
file_name = f"{sys_config['current_time']}_best_complete_rate_ppo"
sys_load_path = os.path.join(args.sys_model_dir, file_name)
# read from file
else:
sys_load_path = args.sys_config
ana.interact_test(sys=args.sys, usr=args.usr,
sys_load_path=sys_load_path,
num_dialog=args.num_dialog,
domain=args.domain)