Commit 99bfdf7f authored by Hsien-Chin Lin's avatar Hsien-Chin Lin

GenTUS and TUS training

parent 9696fd17
Showing with 2417 additions and 84 deletions
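
# ----- file: convlab/policy/genTUS/evaluate.py (path inferred; the diff's per-file headers were lost in extraction) -----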
import json
import os
import sys
from argparse import ArgumentParser
from pprint import pprint
import torch
from convlab.nlg.evaluate import fine_SER
from datasets import load_metric
# from convlab.policy.genTUS.pg.stepGenTUSagent import \
# stepGenTUSPG as UserPolicy
from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
from tqdm import tqdm
sys.path.append(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))))
def arg_parser():
parser = ArgumentParser()
parser.add_argument("--model-checkpoint", type=str, help="the model path")
parser.add_argument("--model-weight", type=str,
help="the model weight", default="")
parser.add_argument("--input-file", type=str, help="the testing input file",
default="")
parser.add_argument("--generated-file", type=str, help="the generated results",
default="")
parser.add_argument("--only-action", action="store_true")
parser.add_argument("--dataset", default="multiwoz")
parser.add_argument("--do-semantic", action="store_true",
help="do semantic evaluation")
parser.add_argument("--do-nlg", action="store_true",
help="do nlg generation")
parser.add_argument("--do-golden-nlg", action="store_true",
help="do golden nlg generation")
return parser.parse_args()
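
# Illustrative usage (paths and flags below are examples, not from the commit):
#   python -m convlab.policy.genTUS.evaluate --model-checkpoint <ckpt_dir> \
#       --input-file data/test.json --do-nlg --do-golden-nlg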
class Evaluator:
def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False):
self.dataset = dataset
self.model_checkpoint = model_checkpoint
self.model_weight = model_weight
# if model_weight:
# self.usr_policy = UserPolicy(
# self.model_checkpoint, only_action=only_action)
# self.usr_policy.load(model_weight)
# self.usr = self.usr_policy.usr
# else:
self.usr = UserActionPolicy(
model_checkpoint, only_action=only_action, dataset=self.dataset)
self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
def generate_results(self, f_eval, golden=False):
in_file = json.load(open(f_eval))
r = {
"input": [],
"golden_acts": [],
"golden_utts": [],
"gen_acts": [],
"gen_utts": []
}
for dialog in tqdm(in_file['dialog']):
inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"])
if golden:
usr_act = labels["action"]
usr_utt = self.usr.generate_text_from_give_semantic(
inputs, usr_act)
else:
output = self.usr._parse_output(
self.usr._generate_action(inputs))
usr_act = self.usr._remove_illegal_action(output["action"])
usr_utt = output["text"]
r["input"].append(inputs)
r["golden_acts"].append(labels["action"])
r["golden_utts"].append(labels["text"])
r["gen_acts"].append(usr_act)
r["gen_utts"].append(usr_utt)
return r
def read_generated_result(self, f_eval):
in_file = json.load(open(f_eval))
r = {
"input": [],
"golden_acts": [],
"golden_utts": [],
"gen_acts": [],
"gen_utts": []
}
for dialog in tqdm(in_file['dialog']):
for x in dialog:
r[x].append(dialog[x])
return r
def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
if input_file:
print("Force generation")
gen_r = self.generate_results(input_file, golden)
elif generated_file:
gen_r = self.read_generated_result(generated_file)
        else:
            raise ValueError(
                "You must specify the input_file or the generated_file")
nlg_eval = {
"golden": golden,
"metrics": {},
"dialog": []
}
for input, golden_act, golden_utt, gen_act, gen_utt in zip(gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["gen_acts"], gen_r["gen_utts"]):
nlg_eval["dialog"].append({
"input": input,
"golden_acts": golden_act,
"golden_utts": golden_utt,
"gen_acts": gen_act,
"gen_utts": gen_utt
})
if golden:
print("Calculate BLEU")
bleu_metric = load_metric("sacrebleu")
labels = [[utt] for utt in gen_r["golden_utts"]]
bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
references=labels,
force=True)
print("bleu_metric", bleu_score)
nlg_eval["metrics"]["bleu"] = bleu_score
else:
print("Calculate SER")
missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
gen_r["gen_acts"], gen_r["gen_utts"])
print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
"genTUSNLG", missing, total, hallucinate, missing/total))
nlg_eval["metrics"]["SER"] = missing/total
dir_name = self.model_checkpoint
json.dump(nlg_eval,
open(os.path.join(dir_name, "nlg_eval.json"), 'w'),
indent=2)
return os.path.join(dir_name, "nlg_eval.json")
def evaluation(self, input_file=None, generated_file=None):
force_prediction = True
if generated_file:
gen_file = json.load(open(generated_file))
force_prediction = False
if gen_file["golden"]:
force_prediction = True
if force_prediction:
in_file = json.load(open(input_file))
dialog_result = []
gen_acts, golden_acts = [], []
# scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
for dialog in tqdm(in_file['dialog']):
inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"])
ans_action = self.usr._remove_illegal_action(labels["action"])
preds = self.usr._generate_action(inputs)
preds = self.usr._parse_output(preds)
usr_action = self.usr._remove_illegal_action(preds["action"])
gen_acts.append(usr_action)
golden_acts.append(ans_action)
d = {"input": inputs,
"golden_acts": ans_action,
"gen_acts": usr_action}
if "text" in preds:
d["golden_utts"] = labels["text"]
d["gen_utts"] = preds["text"]
# print("pred text", preds["text"])
dialog_result.append(d)
else:
gen_acts, golden_acts = [], []
for dialog in gen_file['dialog']:
gen_acts.append(dialog["gen_acts"])
golden_acts.append(dialog["golden_acts"])
dialog_result = gen_file['dialog']
scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
for gen_act, golden_act in zip(gen_acts, golden_acts):
s = f1_measure(preds=gen_act, labels=golden_act)
for metric in scores:
scores[metric].append(s[metric])
result = {}
for metric in scores:
result[metric] = sum(scores[metric])/len(scores[metric])
print(f"{metric}: {result[metric]}")
result["dialog"] = dialog_result
basename = "semantic_evaluation_result"
json.dump(result, open(os.path.join(
self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
# if self.model_weight:
# json.dump(result, open(os.path.join(
# 'results', f"{basename}.json"), 'w'))
# else:
# json.dump(result, open(os.path.join(
# self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
def f1_measure(preds, labels):
tp = 0
score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
for p in preds:
if p in labels:
tp += 1.0
if preds:
score["precision"] = tp/len(preds)
if labels:
score["recall"] = tp/len(labels)
if (score["precision"] + score["recall"]) > 0:
score["f1"] = 2*(score["precision"]*score["recall"]) / \
(score["precision"]+score["recall"])
if tp == len(preds) and tp == len(labels):
score["turn_acc"] = 1
return score
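
# Worked example (illustrative): one matching act out of two labels gives
#   preds  = [["inform", "hotel", "area", "west"]]
#   labels = [["inform", "hotel", "area", "west"],
#             ["request", "hotel", "phone", "?"]]
#   f1_measure(preds, labels)
#   # -> {"precision": 1.0, "recall": 0.5, "f1": 0.666..., "turn_acc": 0}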
def main():
    args = arg_parser()
    evaluator = Evaluator(args.model_checkpoint,
                          args.dataset,
                          args.model_weight,
                          args.only_action)
    print("model checkpoint", args.model_checkpoint)
    print("generated_file", args.generated_file)
    print("input_file", args.input_file)
    with torch.no_grad():
        if args.do_semantic:
            evaluator.evaluation(args.input_file)
        if args.do_nlg:
            nlg_result = evaluator.nlg_evaluation(input_file=args.input_file,
                                                  generated_file=args.generated_file,
                                                  golden=args.do_golden_nlg)
            if args.generated_file:
                generated_file = args.generated_file
            else:
                generated_file = nlg_result
            evaluator.evaluation(args.input_file,
                                 generated_file)
if __name__ == '__main__':
main()
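
# ----- file: convlab/policy/genTUS/ppo/vector.py (path inferred from imports) -----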
import json
import torch
from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
from convlab.policy.genTUS.token_map import tokenMap
from convlab.policy.tus.unify.Goal import Goal
from transformers import BartTokenizer
class stepGenTUSVector:
def __init__(self, model_checkpoint, max_in_len=400, max_out_len=80, allow_general_intent=True):
self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
self.vocab = len(self.tokenizer)
self.max_in_len = max_in_len
self.max_out_len = max_out_len
self.token_map = tokenMap(tokenizer=self.tokenizer)
self.token_map.default(only_action=True)
self.kg = KnowledgeGraph(self.tokenizer)
self.mentioned_domain = []
self.allow_general_intent = allow_general_intent
self.candidate_num = 5
if self.allow_general_intent:
print("---> allow_general_intent")
def init_session(self, goal: Goal):
self.goal = goal
self.mentioned_domain = []
def encode(self, raw_inputs, max_length, return_tensors="pt", truncation=True):
model_input = self.tokenizer(raw_inputs,
max_length=max_length,
return_tensors=return_tensors,
truncation=truncation,
padding="max_length")
return model_input
def decode(self, generated_so_far, skip_special_tokens=True):
output = self.tokenizer.decode(
generated_so_far, skip_special_tokens=skip_special_tokens)
return output
def state_vectorize(self, action, history, turn):
self.goal.update_user_goal(action=action)
inputs = json.dumps({"system": action,
"goal": self.goal.get_goal_list(),
"history": history,
"turn": str(turn)})
inputs = self.encode(inputs, self.max_in_len)
s_vec, action_mask = inputs["input_ids"][0], inputs["attention_mask"][0]
return s_vec, action_mask
def action_vectorize(self, action, s=None):
# action: [[intent, domain, slot, value], ...]
vec = {"vector": torch.tensor([]), "mask": torch.tensor([])}
if s is not None:
raw_inputs = self.decode(s[0])
self.kg.parse_input(raw_inputs)
self._append(vec, self._get_id("<s>"))
self._append(vec, self.token_map.get_id('start_json'))
self._append(vec, self.token_map.get_id('start_act'))
act_len = len(action)
for i, (intent, domain, slot, value) in enumerate(action):
if value == '?':
value = '<?>'
c_idx = {x: None for x in ["intent", "domain", "slot", "value"]}
if s is not None:
c_idx["intent"] = self._candidate_id(self.kg.candidate(
"intent", allow_general_intent=self.allow_general_intent))
c_idx["domain"] = self._candidate_id(self.kg.candidate(
"domain", intent=intent))
c_idx["slot"] = self._candidate_id(self.kg.candidate(
"slot", intent=intent, domain=domain, is_mentioned=self.is_mentioned(domain)))
c_idx["value"] = self._candidate_id(self.kg.candidate(
"value", intent=intent, domain=domain, slot=slot))
self._append(vec, self._get_id(intent), c_idx["intent"])
self._append(vec, self.token_map.get_id('sep_token'))
self._append(vec, self._get_id(domain), c_idx["domain"])
self._append(vec, self.token_map.get_id('sep_token'))
self._append(vec, self._get_id(slot), c_idx["slot"])
self._append(vec, self.token_map.get_id('sep_token'))
self._append(vec, self._get_id(value), c_idx["value"])
c_idx = [0]*self.candidate_num
c_idx[0] = self.token_map.get_id('end_act')[0]
c_idx[1] = self.token_map.get_id('sep_act')[0]
if i == act_len - 1:
x = self.token_map.get_id('end_act')
else:
x = self.token_map.get_id('sep_act')
self._append(vec, x, c_idx)
self._append(vec, self._get_id("</s>"))
# pad
if len(vec["vector"]) < self.max_out_len:
pad_len = self.max_out_len-len(vec["vector"])
self._append(vec, x=torch.tensor([1]*pad_len))
for vec_type in vec:
vec[vec_type] = vec[vec_type].to(torch.int64)
return vec
def _append(self, vec, x, candidate=None):
if type(x) is list:
x = torch.tensor(x)
mask = self._mask(x, candidate)
vec["vector"] = torch.cat((vec["vector"], x), dim=-1)
vec["mask"] = torch.cat((vec["mask"], mask), dim=0)
def _mask(self, idx, c_idx=None):
mask = torch.zeros(len(idx), self.candidate_num)
mask[:, 0] = idx
if c_idx is not None and len(c_idx) > 1:
mask[0, :] = torch.tensor(c_idx)
return mask
def _candidate_id(self, candidate):
if len(candidate) > self.candidate_num:
print(f"too many candidates. Max = {self.candidate_num}")
c_idx = [0]*self.candidate_num
for i, idx in enumerate([self._get_id(c)[0] for c in candidate[:self.candidate_num]]):
c_idx[i] = idx
return c_idx
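    # Illustrative example: with candidate_num = 5, the candidates
    # ["inform", "request"] become a zero-padded id list
    #   [id("inform"), id("request"), 0, 0, 0]
    # where id(.) is the first sub-token id of each candidate.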
def _get_id(self, value):
token_id = self.tokenizer(value, add_special_tokens=False)
return token_id["input_ids"]
def action_devectorize(self, action_id):
return self.decode(action_id)
def update_mentioned_domain(self, semantic_act):
for act in semantic_act:
domain = act[1]
if domain not in self.mentioned_domain:
self.mentioned_domain.append(domain)
def is_mentioned(self, domain):
if domain in self.mentioned_domain:
return True
return False
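
# ----- file: convlab/policy/genTUS/stepGenTUS.py (path inferred from imports) -----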
import json
import os
import torch
from transformers import BartTokenizer
from convlab.policy.genTUS.ppo.vector import stepGenTUSVector
from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel
from convlab.policy.genTUS.token_map import tokenMap
from convlab.policy.genTUS.unify.Goal import Goal
from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
from convlab.policy.policy import Policy
from convlab.task.multiwoz.goal_generator import GoalGenerator
DEBUG = False
class UserActionPolicy(Policy):
def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
self.mode = mode
# if mode == "semantic" and only_action:
# # only generate semantic action in prediction
print("model_checkpoint", model_checkpoint)
self.only_action = only_action
if self.only_action:
print("change mode to semantic because only_action=True")
self.mode = "semantic"
self.max_in_len = 500
self.max_out_len = 50 if only_action else 200
max_act_len = kwargs.get("max_act_len", 2)
print("max_act_len", max_act_len)
self.max_action_len = max_act_len
if "max_act_len" in kwargs:
self.max_out_len = 30 * self.max_action_len
print("max_act_len", self.max_out_len)
self.max_turn = max_turn
if mode not in ["semantic", "language"]:
print("Unknown user mode")
self.reward = {"success": self.max_turn*2,
"fail": self.max_turn*-1}
self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
train_whole_model = kwargs.get("whole_model", True)
self.model = stepGenTUSmodel(
model_checkpoint, train_whole_model=train_whole_model)
self.model.eval()
self.model.to(self.device)
self.model.share_memory()
self.turn_level_reward = kwargs.get("turn_level_reward", True)
self.cooperative = kwargs.get("cooperative", True)
dataset = kwargs.get("dataset", "")
self.kg = KnowledgeGraph(
tokenizer=self.tokenizer,
dataset=dataset)
self.goal_gen = GoalGenerator()
self.vector = stepGenTUSVector(
model_checkpoint, self.max_in_len, self.max_out_len)
self.norm_reward = False
self.action_penalty = kwargs.get("action_penalty", False)
self.usr_act_penalize = kwargs.get("usr_act_penalize", 0)
self.goal_list_type = kwargs.get("goal_list_type", "normal")
self.update_mode = kwargs.get("update_mode", "normal")
self.max_history = kwargs.get("max_history", 3)
self.init_session()
def _update_seq(self, sub_seq: list, pos: int):
for x in sub_seq:
self.seq[0, pos] = x
pos += 1
return pos
def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True):
# TODO no duplicate
self.kg.parse_input(raw_inputs)
model_input = self.vector.encode(raw_inputs, self.max_in_len)
# start token
self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos)
pos = self._update_seq(self.token_map.get_id('start_act'), pos)
# get semantic actions
for act_len in range(self.max_action_len):
pos = self._get_semantic_action(
model_input, pos, mode, allow_general_intent)
terminate, token_name = self._stop_semantic(
model_input, pos, act_len)
pos = self._update_seq(self.token_map.get_id(token_name), pos)
if terminate:
break
if self.only_action:
# return semantic action. Don't need to generate text
return self.vector.decode(self.seq[0, :pos])
# TODO remove illegal action here?
# get text output
pos = self._update_seq(self.token_map.get_id("start_text"), pos)
text = self._get_text(model_input, pos)
return text
def generate_text_from_give_semantic(self, raw_inputs, semantic_action):
self.kg.parse_input(raw_inputs)
model_input = self.vector.encode(raw_inputs, self.max_in_len)
self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos)
pos = self._update_seq(self.token_map.get_id('start_act'), pos)
if len(semantic_action) == 0:
pos = self._update_seq(self.token_map.get_id("end_act"), pos)
for act_id, (intent, domain, slot, value) in enumerate(semantic_action):
pos = self._update_seq(self.kg._get_token_id(intent), pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
pos = self._update_seq(self.kg._get_token_id(domain), pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
pos = self._update_seq(self.kg._get_token_id(slot), pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
pos = self._update_seq(self.kg._get_token_id(value), pos)
if act_id == len(semantic_action) - 1:
token_name = "end_act"
else:
token_name = "sep_act"
pos = self._update_seq(self.token_map.get_id(token_name), pos)
pos = self._update_seq(self.token_map.get_id("start_text"), pos)
raw_output = self._get_text(model_input, pos)
return self._parse_output(raw_output)["text"]
def _get_text(self, model_input, pos):
s_pos = pos
for i in range(s_pos, self.max_out_len):
next_token_logits = self.model.get_next_token_logits(
model_input, self.seq[:1, :pos])
next_token = torch.argmax(next_token_logits, dim=-1)
if self._stop_text(next_token):
# text = self.vector.decode(self.seq[0, s_pos:pos])
# text = self._norm_str(text)
# return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
break
pos = self._update_seq([next_token], pos)
text = self.vector.decode(self.seq[0, s_pos:pos])
text = self._norm_str(text)
return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
# TODO return None
def _stop_text(self, next_token):
if next_token == self.token_map.get_id("end_json")[0]:
return True
elif next_token == self.token_map.get_id("end_json_2")[0]:
return True
return False
@staticmethod
def _norm_str(text: str):
text = text.strip('"')
text = text.replace('"', "'")
text = text.replace('\\', "")
return text
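    # Illustrative example: outer quotes are stripped, inner double quotes
    # become single quotes, and backslashes are removed:
    #   _norm_str('"a \\"cheap\\" hotel"')  ->  a 'cheap' hotel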
def _stop_semantic(self, model_input, pos, act_length=0):
outputs = self.model.get_next_token_logits(
model_input, self.seq[:1, :pos])
tokens = {}
for token_name in ['sep_act', 'end_act']:
tokens[token_name] = {
"token_id": self.token_map.get_id(token_name)}
hash_id = tokens[token_name]["token_id"][0]
tokens[token_name]["score"] = outputs[:, hash_id].item()
if tokens['end_act']["score"] > tokens['sep_act']["score"]:
terminate = True
else:
terminate = False
if act_length >= self.max_action_len - 1:
terminate = True
token_name = "end_act" if terminate else "sep_act"
return terminate, token_name
def _get_semantic_action(self, model_input, pos, mode="max", allow_general_intent=True):
intent = self._get_intent(
model_input, self.seq[:1, :pos], mode, allow_general_intent)
pos = self._update_seq(intent["token_id"], pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
# get domain
domain = self._get_domain(
model_input, self.seq[:1, :pos], intent["token_name"], mode)
pos = self._update_seq(domain["token_id"], pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
# get slot
slot = self._get_slot(
model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode)
pos = self._update_seq(slot["token_id"], pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
# get value
value = self._get_value(
model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], slot["token_name"], mode)
pos = self._update_seq(value["token_id"], pos)
return pos
def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far)
return self.kg.get_intent(next_token_logits, mode, allow_general_intent)
def _get_domain(self, model_input, generated_so_far, intent, mode="max"):
next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far)
return self.kg.get_domain(next_token_logits, intent, mode)
def _get_slot(self, model_input, generated_so_far, intent, domain, mode="max"):
next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far)
is_mentioned = self.vector.is_mentioned(domain)
return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned)
def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"):
next_token_logits = self.model.get_next_token_logits(
model_input, generated_so_far)
return self.kg.get_value(next_token_logits, intent, domain, slot, mode)
def _remove_illegal_action(self, action):
# Transform illegal action to legal action
new_action = []
for act in action:
if len(act) == 4:
if "<?>" in act[-1]:
act = [act[0], act[1], act[2], "?"]
if act not in new_action:
new_action.append(act)
else:
print("illegal action:", action)
return new_action
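    # Illustrative example: "<?>" values are normalised to "?", duplicates are
    # merged, and malformed acts (not 4-tuples) are dropped with a warning:
    #   [["request", "hotel", "area", "<?>"], ["request", "hotel", "area", "<?>"]]
    #     -> [["request", "hotel", "area", "?"]]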
def _parse_output(self, in_str):
in_str = str(in_str)
        in_str = in_str.replace('<s>', '').replace(
            '</s>', '').replace('o"clock', "o'clock")
action = {"action": [], "text": ""}
try:
action = json.loads(in_str)
except:
print("invalid action:", in_str)
print("-"*20)
return action
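    # The generator is expected to emit a JSON string such as (illustrative):
    #   {"action": [["inform", "hotel", "area", "west"]], "text": "a hotel in the west, please"}
    # Strings that fail to parse fall back to an empty action list and text.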
def predict(self, sys_act, mode="max", allow_general_intent=True):
# raw_sys_act = sys_act
# sys_act = sys_act[:5]
# update goal
# TODO
allow_general_intent = False
self.model.eval()
if not self.add_sys_from_reward:
self.goal.update_user_goal(action=sys_act, char="sys")
self.sys_acts.append(sys_act) # for terminate conversation
# update constraint
self.time_step += 2
history = []
if self.usr_acts:
if self.max_history == 1:
history = self.usr_acts[-1]
else:
history = self.usr_acts[-1*self.max_history:]
inputs = json.dumps({"system": sys_act,
"goal": self.goal.get_goal_list(),
"history": history,
"turn": str(int(self.time_step/2))})
with torch.no_grad():
raw_output = self._generate_action(
raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
output = self._parse_output(raw_output)
self.semantic_action = self._remove_illegal_action(output["action"])
if not self.only_action:
self.utterance = output["text"]
# TODO
if self.is_finish():
self.semantic_action, self.utterance = self._good_bye()
# if self.is_finish():
# print("terminated")
# if self.is_finish():
# good_bye = self._good_bye()
# self.goal.add_usr_da(good_bye)
# return good_bye
self.goal.update_user_goal(action=self.semantic_action, char="usr")
self.vector.update_mentioned_domain(self.semantic_action)
self.usr_acts.append(self.semantic_action)
# if self._usr_terminate(usr_action):
# print("terminated by user")
# self.terminated = True
del inputs
if self.mode == "language":
# print("in", sys_act)
# print("out", self.utterance)
return self.utterance
else:
return self.semantic_action
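    # Illustrative model input built by predict (all values are made up):
    #   {"system": [["request", "hotel", "area", "?"]],
    #    "goal": [["inform", "hotel", "area", "west", "not mentioned"]],
    #    "history": [[["inform", "hotel", "price range", "cheap"]]],
    #    "turn": "2"}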
def init_session(self, goal=None):
self.token_map = tokenMap(tokenizer=self.tokenizer)
self.token_map.default(only_action=self.only_action)
self.time_step = 0
remove_domain = "police" # remove police domain in inference
if not goal:
self._new_goal(remove_domain=remove_domain)
else:
self._read_goal(goal)
self.vector.init_session(goal=self.goal)
self.terminated = False
self.add_sys_from_reward = False
self.sys_acts = []
self.usr_acts = []
self.semantic_action = []
self.utterance = ""
def _read_goal(self, data_goal):
self.goal = Goal(goal=data_goal)
def _new_goal(self, remove_domain="police", domain_len=None):
self.goal = Goal(goal_generator=self.goal_gen)
# keep_generate_goal = True
# # domain_len = 1
# while keep_generate_goal:
# self.goal = Goal(goal_generator=self.goal_gen,
# goal_list_type=self.goal_list_type,
# update_mode=self.update_mode)
# if (domain_len and len(self.goal.domains) != domain_len) or \
# (remove_domain and remove_domain in self.goal.domains):
# keep_generate_goal = True
# else:
# keep_generate_goal = False
def load(self, model_path):
self.model.load_state_dict(torch.load(
model_path, map_location=self.device))
# self.model = BartForConditionalGeneration.from_pretrained(
# model_checkpoint)
def get_goal(self):
if self.goal.raw_goal is not None:
return self.goal.raw_goal
goal = {}
for domain in self.goal.domain_goals:
if domain not in goal:
goal[domain] = {}
for intent in self.goal.domain_goals[domain]:
if intent == "inform":
slot_type = "info"
elif intent == "request":
slot_type = "reqt"
elif intent == "book":
slot_type = "book"
                else:
                    print("unknown slot type")
                    continue
if slot_type not in goal[domain]:
goal[domain][slot_type] = {}
for slot, value in self.goal.domain_goals[domain][intent].items():
goal[domain][slot_type][slot] = value
return goal
def get_reward(self, sys_response=None):
self.add_sys_from_reward = False if sys_response is None else True
if self.add_sys_from_reward:
self.goal.update_user_goal(action=sys_response, char="sys")
self.goal.add_sys_da(sys_response) # for evaluation
self.sys_acts.append(sys_response) # for terminate conversation
if self.is_finish():
if self.is_success():
reward = self.reward["success"]
self.success = True
else:
reward = self.reward["fail"]
self.success = False
else:
reward = -1
if self.turn_level_reward:
reward += self.turn_reward()
self.success = None
# if self.action_penalty:
# reward += self._system_action_penalty()
if self.norm_reward:
reward = (reward - 20)/60
return reward
def _system_action_penalty(self):
free_action_len = 3
if len(self.sys_acts) < 1:
return 0
# TODO only penalize the slots not in user goal
# else:
# penlaty = 0
# for i in range(len(self.sys_acts[-1])):
# penlaty += -1*i
# return penlaty
if len(self.sys_acts[-1]) > 3:
return -1*(len(self.sys_acts[-1])-free_action_len)
return 0
def turn_reward(self):
r = 0
r += self._new_act_reward()
r += self._reply_reward()
r += self._usr_act_len()
return r
def _usr_act_len(self):
last_act = self.usr_acts[-1]
penalty = 0
if len(last_act) > 2:
penalty = (2-len(last_act))*self.usr_act_penalize
return penalty
def _new_act_reward(self):
last_act = self.usr_acts[-1]
if last_act != self.semantic_action:
print(f"---> why? last {last_act} usr {self.semantic_action}")
new_act = []
for act in last_act:
if len(self.usr_acts) < 2:
break
if act[1].lower() == "general":
new_act.append(0)
elif act in self.usr_acts[-2]:
new_act.append(-1)
elif act not in self.usr_acts[-2]:
new_act.append(1)
return sum(new_act)
def _reply_reward(self):
if self.cooperative:
return self._cooperative_reply_reward()
else:
return self._non_cooperative_reply_reward()
def _non_cooperative_reply_reward(self):
r = []
reqts = []
infos = []
reply_len = 0
max_len = 1
for act in self.sys_acts[-1]:
if act[0] == "request":
reqts.append([act[1], act[2]])
for act in self.usr_acts[-1]:
if act[0] == "inform":
infos.append([act[1], act[2]])
for req in reqts:
if req in infos:
if reply_len < max_len:
r.append(1)
elif reply_len == max_len:
r.append(0)
else:
r.append(-5)
if r:
return sum(r)
return 0
def _cooperative_reply_reward(self):
r = []
reqts = []
infos = []
for act in self.sys_acts[-1]:
if act[0] == "request":
reqts.append([act[1], act[2]])
for act in self.usr_acts[-1]:
if act[0] == "inform":
infos.append([act[1], act[2]])
for req in reqts:
if req in infos:
r.append(1)
else:
r.append(-1)
if r:
return sum(r)
return 0
def _usr_terminate(self):
for act in self.semantic_action:
if act[0] in ['thank', 'bye']:
return True
return False
def is_finish(self):
# stop by model generation?
if self._finish_conversation_rule():
self.terminated = True
return True
elif self._usr_terminate():
self.terminated = True
return True
self.terminated = False
return False
def is_success(self):
task_complete = self.goal.task_complete()
# goal_status = self.goal.all_mentioned()
# should mentioned all slots
if task_complete: # and goal_status["complete"] > 0.6:
return True
return False
def _good_bye(self):
if self.is_success():
return [['thank', 'general', 'none', 'none']], "thank you. bye"
# if self.mode == "semantic":
# return [['thank', 'general', 'none', 'none']]
# else:
# return "bye"
else:
return [["bye", "general", "None", "None"]], "bye"
def _finish_conversation_rule(self):
if self.is_success():
return True
if self.time_step > self.max_turn:
return True
if (len(self.sys_acts) > 4) and (self.sys_acts[-1] == self.sys_acts[-2]) and (self.sys_acts[-2] == self.sys_acts[-3]):
return True
return False
def is_terminated(self):
# Is there any action to say?
self.is_finish()
return self.terminated
class UserPolicy(Policy):
def __init__(self,
model_checkpoint,
mode="semantic",
only_action=True,
sample=False,
action_penalty=False,
**kwargs):
# self.config = config
# if not os.path.exists(self.config["model_dir"]):
# os.mkdir(self.config["model_dir"])
# model_downloader(self.config["model_dir"],
# "https://zenodo.org/record/5779832/files/default.zip")
self.policy = UserActionPolicy(
model_checkpoint,
mode=mode,
only_action=only_action,
action_penalty=action_penalty,
**kwargs)
self.policy.load(os.path.join(
model_checkpoint, "pytorch_model.bin"))
self.sample = sample
def predict(self, sys_act, mode="max"):
if self.sample:
mode = "sample"
else:
mode = "max"
response = self.policy.predict(sys_act, mode)
return response
def init_session(self, goal=None):
self.policy.init_session(goal)
def is_terminated(self):
return self.policy.is_terminated()
def get_reward(self, sys_response=None):
return self.policy.get_reward(sys_response)
def get_goal(self):
if hasattr(self.policy, 'get_goal'):
return self.policy.get_goal()
return None
if __name__ == "__main__":
import os
from convlab.dialog_agent import PipelineAgent
# from convlab.nlu.jointBERT.multiwoz import BERTNLU
from convlab.util.custom_util import set_seed
set_seed(20220220)
# Test semantic level behaviour
model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0'
usr_policy = UserPolicy(
model_checkpoint,
mode="semantic")
usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
usr_nlu = None # BERTNLU()
usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
print(usr.policy.get_goal())
print(usr.response([]))
print(usr.policy.policy.goal.status)
print(usr.response([["request", "attraction", "area", "?"]]))
print(usr.policy.policy.goal.status)
print(usr.response([["request", "attraction", "area", "?"]]))
print(usr.policy.policy.goal.status)
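
# ----- file: convlab/policy/genTUS/stepGenTUSmodel.py (path inferred from imports) -----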
import json
import torch
from torch.nn.functional import softmax, one_hot, cross_entropy
from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
from convlab.policy.genTUS.token_map import tokenMap
from convlab.policy.genTUS.utils import append_tokens
from transformers import (BartConfig, BartForConditionalGeneration,
BartTokenizer)
class stepGenTUSmodel(BartForConditionalGeneration):
def __init__(self, model_checkpoint, train_whole_model=True, **kwargs):
config = BartConfig.from_pretrained(model_checkpoint)
super().__init__(config, **kwargs)
self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
self.vocab = len(self.tokenizer)
self.kg = KnowledgeGraph(self.tokenizer)
self.action_kg = KnowledgeGraph(self.tokenizer)
self.token_map = tokenMap(self.tokenizer)
# only_action doesn't matter. it is only used for get_log_prob
self.token_map.default(only_action=True)
if not train_whole_model:
for param in self.parameters():
param.requires_grad = False
for param in self.model.decoder.layers[-1].fc1.parameters():
param.requires_grad = True
for param in self.model.decoder.layers[-1].fc2.parameters():
param.requires_grad = True
def get_trainable_param(self):
return filter(
lambda p: p.requires_grad, self.parameters())
def get_next_token_logits(self, model_input, generated_so_far):
input_ids = model_input["input_ids"].to(self.device)
attention_mask = model_input["attention_mask"].to(self.device)
outputs = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=generated_so_far,
return_dict=True)
return outputs.logits[:, -1, :]
def get_log_prob(self, s, a, action_mask, prob_mask):
output = self.forward(input_ids=s,
attention_mask=action_mask,
decoder_input_ids=a)
prob = self._norm_prob(a[:, 1:].long(),
output.logits[:, :-1, :],
prob_mask[:, 1:, :].long())
return prob
def _norm_prob(self, a, prob, mask):
prob = softmax(prob, -1)
base = self._base(prob, mask).to(self.device) # [b, seq_len]
prob = (prob*one_hot(a, num_classes=self.vocab)).sum(-1)
prob = torch.log(prob / base)
pad_mask = a != 1
prob = prob*pad_mask.float()
return prob.sum(-1)
@staticmethod
def _base(prob, mask):
batch_size, seq_len, dim = prob.shape
base = torch.zeros(batch_size, seq_len)
for b in range(batch_size):
for s in range(seq_len):
temp = [prob[b, s, c] for c in mask[b, s, :] if c > 0]
base[b, s] = torch.sum(torch.tensor(temp))
return base
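
# Note: _base loops over batch and sequence positions in Python. A vectorized
# equivalent (a sketch, assuming mask holds candidate token ids padded with 0):
#   base = (torch.gather(prob, -1, mask.clamp(min=0).long())
#           * (mask > 0)).sum(-1)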
if __name__ == "__main__":
import os
from convlab.util.custom_util import set_seed
from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
set_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model_checkpoint = 'results/genTUS-22-01-31-09-21/'
usr = UserActionPolicy(model_checkpoint=model_checkpoint)
usr.model.load_state_dict(torch.load(
os.path.join(model_checkpoint, "pytorch_model.bin"), map_location=device))
usr.model.eval()
test_file = "convlab/policy/genTUS/data/goal_status_validation_v1.json"
data = json.load(open(test_file))
test_id = 20
inputs = usr.tokenizer(data["dialog"][test_id]["in"],
max_length=400,
return_tensors="pt",
truncation=True)
actions = [data["dialog"][test_id]["out"],
data["dialog"][test_id+100]["out"]]
for action in actions:
action = json.loads(action)
vec = usr.vector.action_vectorize(
action["action"], s=inputs["input_ids"])
print({"action": action["action"]})
print("get_log_prob", usr.model.get_log_prob(
inputs["input_ids"],
torch.unsqueeze(vec["vector"], 0),
inputs["attention_mask"],
torch.unsqueeze(vec["mask"], 0)))
import json
class tokenMap:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.token_name = {}
self.hash_map = {}
self.debug = False
self.default()
def default(self, only_action=False):
self.format_tokens = {
'start_json': '{"action": [', # 49643, 10845, 7862, 646
'start_act': '["', # 49329
'sep_token': '", "', # 1297('",'), 22
'sep_act': '"], ["', # 49177
'end_act': '"]], "', # 42248, 7479, 22
'start_text': 'text": "', # 29015, 7862, 22
'end_json': '}', # 24303
'end_json_2': '"}' # 48805
}
if only_action:
self.format_tokens['end_act'] = '"]]}'
for token_name in self.format_tokens:
self.add_token(
token_name, self.format_tokens[token_name])
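        # Concatenated, these fragments frame model outputs such as (illustrative):
        #   {"action": [["inform", "hotel", "area", "west"]], "text": "..."}
        # or, with only_action=True:
        #   {"action": [["inform", "hotel", "area", "west"]]}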
def add_token(self, token_name, value):
if token_name in self.token_name and self.debug:
print(f"---> duplicate token: {token_name}({value})!!!!!!!")
token_id = self.tokenizer(str(value), add_special_tokens=False)[
"input_ids"]
self.token_name[token_name] = {"value": value, "token_id": token_id}
# print(token_id)
hash_id = token_id[0]
if hash_id in self.hash_map and self.debug:
print(
f"---> conflict hash number {hash_id}: {self.hash_map[hash_id]['name']} and {token_name}")
self.hash_map[hash_id] = {
"name": token_name, "value": value, "token_id": token_id}
def get_info(self, hash_id):
return self.hash_map[hash_id]
def get_id(self, token_name):
# workaround
# if token_name not in self.token_name[token_name]:
# self.add_token(token_name, token_name)
return self.token_name[token_name]["token_id"]
def get_token_value(self, token_name):
return self.token_name[token_name]["value"]
def token_name_is_in(self, token_name):
if token_name in self.token_name:
return True
return False
def hash_id_is_in(self, hash_id):
if hash_id in self.hash_map:
return True
return False
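
# ----- file: convlab/policy/genTUS/train_model.py (path inferred; training script) -----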
import json
import os
import sys
from argparse import ArgumentParser
from datetime import datetime
from pprint import pprint
import numpy as np
import torch
import transformers
from datasets import Dataset, load_metric
from tqdm import tqdm
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
BartForConditionalGeneration, BartTokenizer,
DataCollatorForSeq2Seq, Seq2SeqTrainer,
Seq2SeqTrainingArguments)
sys.path.append(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))))
os.environ["WANDB_DISABLED"] = "true"
METRIC = load_metric("sacrebleu")
TOKENIZER = BartTokenizer.from_pretrained("facebook/bart-base")
TOKENIZER.add_tokens(["<?>"])
MAX_IN_LEN = 500
MAX_OUT_LEN = 500
def arg_parser():
parser = ArgumentParser()
# data_name, dial_ids_order, split2ratio
parser.add_argument("--model-type", type=str, default="unify",
help="unify or multiwoz")
parser.add_argument("--data-name", type=str, default="multiwoz21",
help="multiwoz21, sgd, tm1, tm2, tm3, sgd+tm, or all")
parser.add_argument("--dial-ids-order", type=int, default=0)
parser.add_argument("--split2ratio", type=float, default=1)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--model-checkpoint", type=str,
default="facebook/bart-base")
return parser.parse_args()
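
# Illustrative usage (values are examples, not from the commit):
#   python -m convlab.policy.genTUS.train_model --data-name multiwoz21 \
#       --dial-ids-order 0 --split2ratio 1 --batch-size 16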
def gentus_compute_metrics(eval_preds):
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = TOKENIZER.batch_decode(
preds, skip_special_tokens=True, max_length=MAX_OUT_LEN)
# Replace -100 in the labels as we can't decode them.
labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id)
decoded_labels = TOKENIZER.batch_decode(
labels, skip_special_tokens=True, max_length=MAX_OUT_LEN)
act, text = postprocess_text(decoded_preds, decoded_labels)
result = METRIC.compute(
# predictions=decoded_preds, references=decoded_labels)
predictions=text["preds"], references=text["labels"])
result = {"bleu": result["score"]}
f1_scores = f1_measure(pred_acts=act["preds"], label_acts=act["labels"])
for s in f1_scores:
result[s] = f1_scores[s]
result = {k: round(v, 4) for k, v in result.items()}
return result
def postprocess_text(preds, labels):
act = {"preds": [], "labels": []}
text = {"preds": [], "labels": []}
for pred, label in zip(preds, labels):
model_output = parse_output(pred.strip())
label_output = parse_output(label.strip())
if len(label_output["text"]) < 1:
continue
act["preds"].append(model_output.get("action", []))
text["preds"].append(model_output.get("text", pred.strip()))
act["labels"].append(label_output["action"])
text["labels"].append([label_output["text"]])
return act, text
def parse_output(in_str):
    in_str = in_str.replace('<s>', '').replace('</s>', '')
try:
output = json.loads(in_str)
except:
# print(f"invalid action {in_str}")
output = {"action": [], "text": ""}
return output
def f1_measure(pred_acts, label_acts):
result = {"precision": [], "recall": [], "f1": []}
for pred, label in zip(pred_acts, label_acts):
r = tp_fn_fp(pred, label)
for m in result:
result[m].append(r[m])
for m in result:
result[m] = sum(result[m])/len(result[m])
return result
def tp_fn_fp(pred, label):
tp, fn, fp = 0.0, 0.0, 0.0
precision, recall, f1 = 0, 0, 0
for p in pred:
if p in label:
tp += 1
else:
fp += 1
for l in label:
if l not in pred:
fn += 1
if (tp+fp) > 0:
precision = tp / (tp+fp)
if (tp+fn) > 0:
recall = tp/(tp+fn)
if (precision + recall) > 0:
f1 = (2*precision*recall)/(precision+recall)
return {"precision": precision, "recall": recall, "f1": f1}
class TrainerHelper:
def __init__(self, tokenizer, max_input_length=500, max_target_length=500):
print("transformers version is: ", transformers.__version__)
self.tokenizer = tokenizer
self.max_input_length = max_input_length
self.max_target_length = max_target_length
self.base_name = "convlab/policy/genTUS"
self.dir_name = ""
def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
# base_name = "convlab/policy/genTUS/unify/data"
if model_type not in ["unify", "multiwoz"]:
print("Unknown model type. Currently only support unify and multiwoz")
self.dir_name = f"{data_name}_{dial_ids_order}_{split2ratio}"
return os.path.join(self.base_name, model_type, 'data', self.dir_name)
def get_model_folder(self, model_type):
folder_name = os.path.join(
self.base_name, model_type, "experiments", self.dir_name)
if not os.path.exists(folder_name):
os.makedirs(folder_name)
return folder_name
def parse_data(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
data_folder = self._get_data_folder(
model_type, data_name, dial_ids_order, split2ratio)
raw_data = {}
for d_type in ["train", "validation", "test"]:
f_name = os.path.join(data_folder, f"{d_type}.json")
raw_data[d_type] = json.load(open(f_name))
tokenized_datasets = {}
for data_type, data in raw_data.items():
tokenized_datasets[data_type] = Dataset.from_dict(
self._preprocess(data["dialog"]))
return tokenized_datasets
def _preprocess(self, examples):
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if isinstance(examples, dict):
examples = [examples]
for example in tqdm(examples):
inputs = self.tokenizer(example["in"],
max_length=self.max_input_length,
truncation=True)
# Setup the tokenizer for targets
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(example["out"],
max_length=self.max_target_length,
truncation=True)
for key in ["input_ids", "attention_mask"]:
model_inputs[key].append(inputs[key])
model_inputs["labels"].append(labels["input_ids"])
return model_inputs
def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"):
tokenizer = TOKENIZER
train_helper = TrainerHelper(
tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length)
data = train_helper.parse_data(model_type=model_type,
data_name=data_name,
dial_ids_order=dial_ids_order,
split2ratio=split2ratio)
model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
model.resize_token_embeddings(len(tokenizer))
fp16 = False
if torch.cuda.is_available():
fp16 = True
model_dir = os.path.join(
train_helper.get_model_folder(model_type),
f"{datetime.now().strftime('%y-%m-%d-%H-%M')}")
args = Seq2SeqTrainingArguments(
model_dir,
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=0.01,
save_total_limit=2,
num_train_epochs=5,
predict_with_generate=True,
fp16=fp16,
push_to_hub=False,
generation_max_length=max_target_length,
logging_dir=os.path.join(model_dir, 'log')
)
data_collator = DataCollatorForSeq2Seq(
tokenizer, model=model, padding=True)
# customize this trainer
trainer = Seq2SeqTrainer(
model=model,
args=args,
train_dataset=data["train"],
eval_dataset=data["test"],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=gentus_compute_metrics)
print("start training...")
trainer.train()
print("saving model...")
trainer.save_model()
def main():
args = arg_parser()
print("---> data_name", args.data_name)
train(model_type=args.model_type,
data_name=args.data_name,
dial_ids_order=args.dial_ids_order,
split2ratio=args.split2ratio,
batch_size=args.batch_size,
max_input_length=MAX_IN_LEN,
max_target_length=MAX_OUT_LEN,
model_checkpoint=args.model_checkpoint)
if __name__ == "__main__":
main()
# sgd+tm: 46000
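
# ----- file: convlab/policy/genTUS/unify/Goal.py (path inferred from imports) -----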
"""
The user goal for the unified data format
"""
import json
from convlab.policy.tus.unify.Goal import old_goal2list
from convlab.task.multiwoz.goal_generator import GoalGenerator
from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
from convlab.util.custom_util import slot_mapping
DEF_VAL_UNK = '?' # Unknown
DEF_VAL_DNC = 'dontcare' # Do not care
DEF_VAL_NUL = 'none' # for none
NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, ""]
NOT_MENTIONED = "not mentioned"
FULFILLED = "fulfilled"
REQUESTED = "requested"
CONFLICT = "conflict"
class Goal:
""" User Goal Model Class. """
    def __init__(self, goal=None, goal_generator=None):
        """
        Create a new Goal from a dialog or from a goal_generator.
        Args:
            goal: a list (created from a dialog), a dict or ABUS goal, or None
            goal_generator: used to sample a fresh goal when goal is None
        """
self.domains = []
self.domain_goals = {}
self.status = {}
self.invert_slot_mapping = {v: k for k, v in slot_mapping.items()}
self.raw_goal = None
self._init_goal_from_data(goal, goal_generator)
self._init_status()
def __str__(self):
return '-----Goal-----\n' + \
json.dumps(self.domain_goals, indent=4) + \
'\n-----Goal-----'
    def _init_goal_from_data(self, goal=None, goal_generator=None):
        if not goal and goal_generator:
            goal = ABUS_Goal(goal_generator)
            self.raw_goal = goal.domain_goals
            goal = old_goal2list(goal.domain_goals)
        elif isinstance(goal, dict):
            self.raw_goal = goal
            goal = old_goal2list(goal)
        elif isinstance(goal, ABUS_Goal):
            self.raw_goal = goal.domain_goals
            goal = old_goal2list(goal.domain_goals)
        elif not isinstance(goal, list):
            # goals created from a dialog are already lists; anything else is unexpected
            print("unknown goal")
# be careful of this order
for domain, intent, slot, value in goal:
if domain == "none":
continue
if domain not in self.domains:
self.domains.append(domain)
self.domain_goals[domain] = {}
if intent not in self.domain_goals[domain]:
self.domain_goals[domain][intent] = {}
if not value:
if intent == "request":
self.domain_goals[domain][intent][slot] = DEF_VAL_UNK
else:
                    print(
                        f"unknown intent with no value: {domain}, {intent}, {slot}")
else:
self.domain_goals[domain][intent][slot] = value
def _init_status(self):
for domain, domain_goal in self.domain_goals.items():
if domain not in self.status:
self.status[domain] = {}
for slot_type, sub_domain_goal in domain_goal.items():
if slot_type not in self.status[domain]:
self.status[domain][slot_type] = {}
for slot in sub_domain_goal:
if slot not in self.status[domain][slot_type]:
self.status[domain][slot_type][slot] = {}
self.status[domain][slot_type][slot] = {
"value": str(sub_domain_goal[slot]),
"status": NOT_MENTIONED}
def get_goal_list(self, data_goal=None):
goal_list = []
if data_goal:
# make sure the order!!!
for domain, intent, slot, _ in data_goal:
status = self._get_status(domain, intent, slot)
value = self.domain_goals[domain][intent][slot]
goal_list.append([intent, domain, slot, value, status])
return goal_list
else:
for domain, domain_goal in self.domain_goals.items():
for intent, sub_goal in domain_goal.items():
for slot, value in sub_goal.items():
status = self._get_status(domain, intent, slot)
goal_list.append([intent, domain, slot, value, status])
return goal_list
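    # Each goal-list entry has the form [intent, domain, slot, value, status],
    # e.g. (illustrative) ["inform", "hotel", "area", "west", "not mentioned"].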
def _get_status(self, domain, intent, slot):
if domain not in self.status:
return NOT_MENTIONED
if intent not in self.status[domain]:
return NOT_MENTIONED
if slot not in self.status[domain][intent]:
return NOT_MENTIONED
return self.status[domain][intent][slot]["status"]
def task_complete(self):
"""
        Check whether all requests have been fulfilled.
        Returns:
            (boolean): True if the task is accomplished.
"""
for domain, domain_goal in self.status.items():
if domain not in self.domain_goals:
continue
for slot_type, sub_domain_goal in domain_goal.items():
if slot_type not in self.domain_goals[domain]:
continue
for slot, status in sub_domain_goal.items():
if slot not in self.domain_goals[domain][slot_type]:
continue
# for strict success, turn this on
if status["status"] in [NOT_MENTIONED, CONFLICT]:
if status["status"] == CONFLICT and slot in ["arrive by", "leave at"]:
continue
return False
if "?" in status["value"]:
return False
return True
# TODO change to update()?
def update_user_goal(self, action, char="usr"):
# update request and booked
if char == "usr":
self._user_action_update(action)
elif char == "sys":
self._system_action_update(action)
else:
print("!!!UNKNOWN CHAR!!!")
def _user_action_update(self, action):
# no need to update user goal
for intent, domain, slot, _ in action:
goal_intent = self._check_slot_and_intent(domain, slot)
if not goal_intent:
continue
# fulfilled by user
if is_inform(intent):
self._set_status(goal_intent, domain, slot, FULFILLED)
# requested by user
if is_request(intent):
self._set_status(goal_intent, domain, slot, REQUESTED)
def _system_action_update(self, action):
for intent, domain, slot, value in action:
goal_intent = self._check_slot_and_intent(domain, slot)
if not goal_intent:
continue
# fulfill request by system
if is_inform(intent) and is_request(goal_intent):
self._set_status(goal_intent, domain, slot, FULFILLED)
self._set_goal(goal_intent, domain, slot, value)
            if is_inform(intent) and is_inform(goal_intent):
                # fulfill inform by system
if value == self.domain_goals[domain][goal_intent][slot]:
self._set_status(goal_intent, domain, slot, FULFILLED)
# conflict system inform
else:
self._set_status(goal_intent, domain, slot, CONFLICT)
# requested by system
if is_request(intent) and is_inform(goal_intent):
self._set_status(goal_intent, domain, slot, REQUESTED)
def _set_status(self, intent, domain, slot, status):
self.status[domain][intent][slot]["status"] = status
def _set_goal(self, intent, domain, slot, value):
# old_value = self.domain_goals[domain][intent][slot]
self.domain_goals[domain][intent][slot] = value
self.status[domain][intent][slot]["value"] = value
# print(
# f"updating user goal {intent}-{domain}-{slot} {old_value}-> {value}")
def _check_slot_and_intent(self, domain, slot):
not_found = ""
if domain not in self.domain_goals:
return not_found
for intent in self.domain_goals[domain]:
if slot in self.domain_goals[domain][intent]:
return intent
return not_found
def is_inform(intent):
if "inform" in intent:
return True
return False
def is_request(intent):
if "request" in intent:
return True
return False
def transform_data_act(data_action):
action_list = []
for _, dialog_act in data_action.items():
for act in dialog_act:
value = act.get("value", "")
if not value:
if "request" in act["intent"]:
value = "?"
else:
value = "none"
action_list.append(
[act["intent"], act["domain"], act["slot"], value])
return action_list
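
# Illustrative input for transform_data_act (the key names follow the unified
# data format's dialogue_acts dict; values are made up):
#   {"categorical": [{"intent": "request", "domain": "hotel", "slot": "area"}]}
#     -> [["request", "hotel", "area", "?"]]


# ----- file: convlab/policy/genTUS/unify/build_data.py (path inferred) -----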
import json
import os
import sys
from argparse import ArgumentParser
from tqdm import tqdm
from convlab.policy.genTUS.unify.Goal import Goal, transform_data_act
from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
sys.path.append(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))))
def arg_parser():
parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="multiwoz21",
help="the dataset, such as multiwoz21, sgd, tm1, tm2, and tm3.")
parser.add_argument("--dial-ids-order", type=int, default=0)
parser.add_argument("--split2ratio", type=float, default=1)
parser.add_argument("--random-order", action="store_true")
parser.add_argument("--no-status", action="store_true")
parser.add_argument("--add-history", action="store_true")
parser.add_argument("--remove-domain", type=str, default="")
return parser.parse_args()
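
# Illustrative usage (values are examples, not from the commit):
#   python -m convlab.policy.genTUS.unify.build_data --dataset multiwoz21 \
#       --dial-ids-order 0 --split2ratio 1 --add-history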
class DataBuilder:
def __init__(self, dataset='multiwoz21'):
self.dataset = dataset
def setup_data(self,
raw_data,
random_order=False,
no_status=False,
add_history=False,
                   remove_domain=None):  # note: unused here; only affects output file names in __main__
examples = {data_split: {"dialog": []} for data_split in raw_data}
for data_split, dialogs in raw_data.items():
for dialog in tqdm(dialogs, ascii=True):
example = self._one_dialog(dialog=dialog,
add_history=add_history,
random_order=random_order,
no_status=no_status)
examples[data_split]["dialog"] += example
return examples
def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False):
example = []
history = []
data_goal = self.norm_domain_goal(create_goal(dialog))
if not data_goal:
return example
user_goal = Goal(goal=data_goal)
for turn_id in range(0, len(dialog["turns"]), 2):
sys_act = self._get_sys_act(dialog, turn_id)
user_goal.update_user_goal(action=sys_act, char="sys")
usr_goal_str = self._user_goal_str(user_goal, data_goal, random_order, no_status)
usr_act = self.norm_domain(transform_data_act(
dialog["turns"][turn_id]["dialogue_acts"]))
user_goal.update_user_goal(action=usr_act, char="usr")
# change value "?" to "<?>"
usr_act = self._modify_act(usr_act)
in_str = self._dump_in_str(sys_act, usr_goal_str, history, turn_id, add_history)
out_str = self._dump_out_str(usr_act, dialog["turns"][turn_id]["utterance"])
history.append(usr_act)
if usr_act:
example.append({"in": in_str, "out": out_str})
return example
def _get_sys_act(self, dialog, turn_id):
sys_act = []
if turn_id > 0:
sys_act = self.norm_domain(transform_data_act(
dialog["turns"][turn_id - 1]["dialogue_acts"]))
return sys_act
def _user_goal_str(self, user_goal, data_goal, random_order, no_status):
if random_order:
usr_goal_str = user_goal.get_goal_list()
else:
usr_goal_str = user_goal.get_goal_list(data_goal=data_goal)
if no_status:
usr_goal_str = self._remove_status(usr_goal_str)
return usr_goal_str
def _dump_in_str(self, sys_act, usr_goal_str, history, turn_id, add_history):
in_str = {}
in_str["system"] = self._modify_act(sys_act)
in_str["goal"] = usr_goal_str
if add_history:
h = []
if history:
h = history[-3:]
in_str["history"] = h
in_str["turn"] = str(int(turn_id/2))
return json.dumps(in_str)
def _dump_out_str(self, usr_act, text):
out_str = {"action": usr_act, "text": text}
return json.dumps(out_str)
@staticmethod
def _norm_intent(intent):
if intent in ["inform_intent", "negate_intent", "affirm_intent", "request_alts"]:
return f"_{intent}"
return intent
def norm_domain(self, x):
if not x:
return x
norm_result = []
# print(x)
for intent, domain, slot, value in x:
if "_" in domain:
domain = domain.split('_')[0]
if not domain:
domain = "none"
if not slot:
slot = "none"
if not value:
if intent == "request":
value = "<?>"
else:
value = "none"
norm_result.append([self._norm_intent(intent), domain, slot, value])
return norm_result
def norm_domain_goal(self, x):
if not x:
return x
norm_result = []
# take care of the order!
for domain, intent, slot, value in x:
if "_" in domain:
domain = domain.split('_')[0]
if not domain:
domain = "none"
if not slot:
slot = "none"
if not value:
if intent == "request":
value = "<?>"
else:
value = "none"
norm_result.append([domain, self._norm_intent(intent), slot, value])
return norm_result
@staticmethod
def _remove_status(goal_list):
new_list = [[goal[0], goal[1], goal[2], goal[3]]
for goal in goal_list]
return new_list
@staticmethod
def _modify_act(act):
new_act = []
for i, d, s, value in act:
if value == "?":
new_act.append([i, d, s, "<?>"])
else:
new_act.append([i, d, s, value])
return new_act
if __name__ == "__main__":
args = arg_parser()
base_name = "convlab/policy/genTUS/unify/data"
dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}"
folder_name = os.path.join(base_name, dir_name)
remove_domain = args.remove_domain
if not os.path.exists(folder_name):
os.makedirs(folder_name)
dataset = load_experiment_dataset(
data_name=args.dataset,
dial_ids_order=args.dial_ids_order,
split2ratio=args.split2ratio)
data_builder = DataBuilder(dataset=args.dataset)
data = data_builder.setup_data(
raw_data=dataset,
random_order=args.random_order,
no_status=args.no_status,
add_history=args.add_history,
remove_domain=remove_domain)
for data_type in data:
if remove_domain:
file_name = os.path.join(
folder_name,
f"no{remove_domain}_{data_type}.json")
else:
file_name = os.path.join(
folder_name,
f"{data_type}.json")
json.dump(data[data_type], open(file_name, 'w'), indent=2)
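
# ----- file: convlab/policy/genTUS/unify/knowledge_graph.py (path inferred from imports) -----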
import json
from random import choices
from convlab.policy.genTUS.token_map import tokenMap
from transformers import BartTokenizer
DEBUG = False
DATASET = "unify"
class KnowledgeGraph:
def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="multiwoz21"):
print("dataset", dataset)
self.debug = DEBUG
self.tokenizer = tokenizer
if "multiwoz" in dataset:
self.domain_intent = ["inform", "request"]
self.general_intent = ["thank", "bye"]
# use sgd dataset intents as default
else:
self.domain_intent = ["_inform_intent",
"_negate_intent",
"_affirm_intent",
"inform",
"request",
"affirm",
"negate",
"select",
"_request_alts"]
self.general_intent = ["thank_you", "goodbye"]
self.general_domain = "none"
self.kg_map = {"intent": tokenMap(tokenizer=self.tokenizer)}
for intent in self.domain_intent + self.general_intent:
self.kg_map["intent"].add_token(intent, intent)
self.init()
def init(self):
for map_type in ["domain", "slot", "value"]:
self.kg_map[map_type] = tokenMap(tokenizer=self.tokenizer)
self.add_token("<?>", "value")
def parse_input(self, in_str):
self.init()
inputs = json.loads(in_str)
self.sys_act = inputs["system"]
self.user_goal = {}
self._add_none_domain()
for intent, domain, slot, value, _ in inputs["goal"]:
self._update_user_goal(intent, domain, slot, value, source="goal")
for intent, domain, slot, value in self.sys_act:
self._update_user_goal(intent, domain, slot, value, source="sys")
def _add_none_domain(self):
self.user_goal["none"] = {"none": "none"}
# add slot
self.add_token("none", "domain")
self.add_token("none", "slot")
self.add_token("none", "value")
def _update_user_goal(self, intent, domain, slot, value, source="goal"):
if value == "?":
value = "<?>"
if intent == "request" and source == "sys":
value = "dontcare" # user can "dontcare" system request
if source == "sys" and intent != "request":
return
if domain not in self.user_goal:
self.user_goal[domain] = {}
self.user_goal[domain]["none"] = ["none"]
self.add_token(domain, "domain")
self.add_token("none", "slot")
self.add_token("none", "value")
if slot not in self.user_goal[domain]:
self.user_goal[domain][slot] = []
self.add_token(domain, "slot")
if value not in self.user_goal[domain][slot]:
value = json.dumps(str(value))[1:-1]
self.user_goal[domain][slot].append(value)
value = value.replace('"', "'")
self.add_token(value, "value")
def add_token(self, token_name, map_type):
if map_type == "value":
token_name = token_name.replace('"', "'")
if not self.kg_map[map_type].token_name_is_in(token_name):
self.kg_map[map_type].add_token(token_name, token_name)
    def _get_max_score(self, outputs, candidate_list, map_type):
        # score each candidate by the model output for its first sub-token;
        # note the dict is keyed by score, so exact ties overwrite each other
        score = {}
        if not candidate_list:
            print(f"ERROR: empty candidate list for {map_type}")
            score[1] = {"token_id": self._get_token_id(
                "none"), "token_name": "none"}
        for x in candidate_list:
            hash_id = self._get_token_id(x)[0]
            s = outputs[:, hash_id].item()
            score[s] = {"token_id": self._get_token_id(x),
                        "token_name": x}
        return score
    def _select(self, score, mode="max"):
        probs = [s for s in score]
        if mode == "max":
            s = max(probs)
        elif mode == "sample":
            s = choices(probs, weights=probs, k=1)
            s = s[0]
        else:
            # fall back to greedy selection rather than failing on an unknown mode
            print("unknown select mode, fall back to max")
            s = max(probs)
        return s
def _get_max_domain_token(self, outputs, candidates, map_type, mode="max"):
score = self._get_max_score(outputs, candidates, map_type)
s = self._select(score, mode)
token_id = score[s]["token_id"]
token_name = score[s]["token_name"]
return {"token_id": token_id, "token_name": token_name}
def candidate(self, candidate_type, **kwargs):
if "intent" in kwargs:
intent = kwargs["intent"]
if candidate_type == "intent":
allow_general_intent = kwargs.get("allow_general_intent", True)
if allow_general_intent:
return self.domain_intent + self.general_intent
else:
return self.domain_intent
elif candidate_type == "domain":
if intent in self.general_intent:
return [self.general_domain]
else:
return [d for d in self.user_goal]
elif candidate_type == "slot":
if intent in self.general_intent:
return ["none"]
else:
return self._filter_slot(intent, kwargs["domain"], kwargs["is_mentioned"])
else:
if intent in self.general_intent:
return ["none"]
elif intent.lower() == "request":
return ["<?>"]
else:
return self._filter_value(intent, kwargs["domain"], kwargs["slot"])
def get_intent(self, outputs, mode="max", allow_general_intent=True):
# return intent, token_id_list
# TODO request?
        candidate_list = self.candidate(
            "intent", allow_general_intent=allow_general_intent)
        score = self._get_max_score(outputs, candidate_list, "intent")
s = self._select(score, mode)
return score[s]
def get_domain(self, outputs, intent, mode="max"):
if intent in self.general_intent:
token_name = self.general_domain
token_id = self.tokenizer(token_name, add_special_tokens=False)
token_map = {"token_id": token_id['input_ids'],
"token_name": token_name}
elif intent in self.domain_intent:
# [d for d in self.user_goal]
domain_list = self.candidate("domain", intent=intent)
token_map = self._get_max_domain_token(
outputs=outputs, candidates=domain_list, map_type="domain", mode=mode)
        else:
            # guard against an unseen intent so token_map is always defined
            token_map = {"token_id": [], "token_name": "none"}
            if self.debug:
                print("unknown intent", intent)
return token_map
def get_slot(self, outputs, intent, domain, mode="max", is_mentioned=False):
if intent in self.general_intent:
token_name = "none"
token_id = self.tokenizer(token_name, add_special_tokens=False)
token_map = {"token_id": token_id['input_ids'],
"token_name": token_name}
elif intent in self.domain_intent:
slot_list = self.candidate(
candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned)
token_map = self._get_max_domain_token(
outputs=outputs, candidates=slot_list, map_type="slot", mode=mode)
return token_map
def get_value(self, outputs, intent, domain, slot, mode="max"):
if intent in self.general_intent or slot.lower() == "none":
token_name = "none"
token_id = self.tokenizer(token_name, add_special_tokens=False)
token_map = {"token_id": token_id['input_ids'],
"token_name": token_name}
elif intent.lower() == "request":
token_name = "<?>"
token_id = self.tokenizer(token_name, add_special_tokens=False)
token_map = {"token_id": token_id['input_ids'],
"token_name": token_name}
        elif intent in self.domain_intent:
            # TODO: should "none" be excluded from the value candidates?
            # value_list = [v for v in self.user_goal[domain][slot]]
value_list = self.candidate(
candidate_type="value", intent=intent, domain=domain, slot=slot)
token_map = self._get_max_domain_token(
outputs=outputs, candidates=value_list, map_type="value", mode=mode)
return token_map
def _filter_slot(self, intent, domain, is_mentioned=True):
slot_list = []
for slot in self.user_goal[domain]:
value_list = self._filter_value(intent, domain, slot)
if len(value_list) > 0:
slot_list.append(slot)
if not is_mentioned and intent.lower() != "request":
slot_list.append("none")
return slot_list
def _filter_value(self, intent, domain, slot):
value_list = [v for v in self.user_goal[domain][slot]]
if "none" in value_list:
value_list.remove("none")
if intent.lower() != "request":
if "?" in value_list:
value_list.remove("?")
if "<?>" in value_list:
value_list.remove("<?>")
# print(f"{intent}-{domain}-{slot}= {value_list}")
return value_list
def _get_token_id(self, token):
return self.tokenizer(token, add_special_tokens=False)["input_ids"]
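# Usage sketch (illustrative, not part of the commit; the checkpoint name is
# an assumption): build the graph for one turn, then ask for the legal
# candidates at a decoding step.
#   from transformers import BartTokenizer
#   tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#   kg = KnowledgeGraph(tokenizer, dataset="multiwoz21")
#   kg.parse_input(json.dumps({
#       "system": [["request", "hotel", "area", "?"]],
#       "goal": [["inform", "hotel", "area", "centre", False]]}))
#   kg.candidate("domain", intent="inform")  # -> ["none", "hotel"]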
import torch
def append_tokens(tokens, new_token, device):
return torch.cat((tokens, torch.tensor([new_token]).to(device)), dim=1)
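# e.g. (illustrative): extend a (1, t) batch of token ids by one decoding step:
#   tokens = torch.tensor([[0, 8, 15]])
#   append_tokens(tokens, [42], tokens.device)  # -> tensor([[ 0,  8, 15, 42]])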
{
"model": {
"load_path": "convlab/policy/ppo/pretrained_models/supervised",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 1000,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
"process_num": 1,
"num_eval_dialogues": 500,
"sys_semantic_to_usr": false
},
"vectorizer_sys": {
"uncertainty_vector_mul": {
"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
"ini_params": {
"use_masking": true,
"manually_add_entity_names": true,
"seed": 0
}
}
},
"nlu_sys": {},
"dst_sys": {
"RuleDST": {
"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
"ini_params": {}
}
},
"sys_nlg": {},
"nlu_usr": {},
"dst_usr": {},
"policy_usr": {
"RulePolicy": {
"class_path": "convlab.policy.genTUS.stepGenTUS.UserPolicy",
"ini_params": {
"model_checkpoint": "convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0",
"character": "usr"
}
}
},
"usr_nlg": {}
}
\ No newline at end of file
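# Note (illustrative, not part of the commit): the "class_path"/"ini_params"
# entries in the config above are presumably resolved dynamically; a minimal
# loader could look like this (the file name is hypothetical):
#   import importlib, json
#   cfg = json.load(open("exp_config.json"))
#   path = cfg["policy_usr"]["RulePolicy"]["class_path"]
#   module_name, cls_name = path.rsplit(".", 1)
#   usr_cls = getattr(importlib.import_module(module_name), cls_name)
#   usr = usr_cls(**cfg["policy_usr"]["RulePolicy"]["ini_params"])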
import time
import json
from convlab.policy.tus.unify.util import split_slot_name, slot_name_map
from convlab.util.custom_util import slot_mapping
from random import sample, shuffle
from pprint import pprint
DEF_VAL_UNK = '?' # Unknown
DEF_VAL_DNC = 'dontcare' # Do not care
@@ -27,6 +30,35 @@ def isTimeFormat(input):
return False
def old_goal2list(goal: dict, reorder=False) -> list:
goal_list = []
for domain in goal:
for slot_type in ['info', 'book', 'reqt']:
if slot_type not in goal[domain]:
continue
temp = []
for slot in goal[domain][slot_type]:
s = slot
if slot in slot_name_map:
s = slot_name_map[slot]
elif slot in slot_name_map[domain]:
s = slot_name_map[domain][slot]
# domain, intent, slot, value
if slot_type in ['info', 'book']:
i = "inform"
v = goal[domain][slot_type][slot]
else:
i = "request"
v = DEF_VAL_UNK
s = slot_mapping.get(s, s)
temp.append([domain, i, s, v])
shuffle(temp)
goal_list = goal_list + temp
# shuffle_goal = goal_list[:1] + sample(goal_list[1:], len(goal_list)-1)
# return shuffle_goal
return goal_list
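# Example (illustrative, not part of the commit): an old-format goal such as
#   {"hotel": {"info": {"area": "centre"}, "reqt": ["phone"]}}
# maps (order shuffled within each domain) to roughly
#   [["hotel", "inform", "area", "centre"], ["hotel", "request", "phone", "?"]]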
class Goal(object):
""" User Goal Model Class. """
@@ -101,6 +133,7 @@ class Goal(object):
if self.domain_goals[domain]["reqt"][slot] == DEF_VAL_UNK:
# print(f"not fulfilled request{domain}-{slot}")
return False
return True
def init_local_id(self):
@@ -169,6 +202,10 @@ class Goal(object):
def _update_status(self, action: list, char: str):
for intent, domain, slot, value in action:
if slot == "arrive by":
slot = "arriveBy"
elif slot == "leave at":
slot = "leaveAt"
if domain not in self.status:
self.status[domain] = {}
# update info
@@ -180,6 +217,10 @@ class Goal(object):
def _update_goal(self, action: list, char: str):
# update requt slots in goal
for intent, domain, slot, value in action:
if slot == "arrive by":
slot = "arriveBy"
elif slot == "leave at":
slot = "leaveAt"
if "info" not in intent:
continue
if self._check_update_request(domain, slot) and value != "?":
......
@@ -5,19 +5,18 @@ from copy import deepcopy
import torch
from convlab.policy.policy import Policy
from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import (
act_dict_to_flat_tuple, unified_format)
from convlab.policy.tus.multiwoz.transformer import TransformerActionPrediction
from convlab.policy.tus.unify.Goal import Goal
from convlab.policy.tus.unify.usermanager import BinaryFeature
from convlab.policy.tus.unify.util import create_goal, split_slot_name
from convlab.util import (load_dataset,
relative_import_module_from_unified_datasets)
from convlab.util.custom_util import model_downloader
from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA
from pprint import pprint
from convlab.task.multiwoz.goal_generator import GoalGenerator
from convlab.policy.tus.unify.Goal import old_goal2list
from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets(
'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
@@ -51,7 +50,7 @@ class UserActionPolicy(Policy):
self.user = TransformerActionPrediction(self.config).to(device=DEVICE)
if pretrain:
model_path = os.path.join(
self.config["model_dir"], self.config["model_name"])
self.config["model_dir"], "model-non-zero") # self.config["model_name"])
print(f"loading model from {model_path}...")
self.load(model_path)
self.user.eval()
@@ -61,6 +60,8 @@ class UserActionPolicy(Policy):
self.reward = {"success": 40,
"fail": -20}
self.sys_acts = []
self.goal_gen = GoalGenerator()
self.raw_goal = None
def _no_offer(self, system_in):
for intent, domain, slot, value in system_in:
@@ -127,13 +128,15 @@
self.topic = 'NONE'
remove_domain = "police" # remove police domain in inference
# if not goal:
# self.new_goal(remove_domain=remove_domain)
# else:
# self.read_goal(goal)
if type(goal) == ABUS_Goal:
self.raw_goal = goal.domain_goals
goal_list = old_goal2list(goal.domain_goals)
goal = Goal(goal_list)
else:
goal = ABUS_Goal(self.goal_gen)
            self.raw_goal = goal.domain_goals
goal_list = old_goal2list(goal.domain_goals)
goal = Goal(goal_list)
self.read_goal(goal)
self.feat_handler.initFeatureHandeler(self.goal)
@@ -155,15 +158,15 @@
else:
self.goal = Goal(goal=data_goal)
# def new_goal(self, remove_domain="police", domain_len=None):
# keep_generate_goal = True
# while keep_generate_goal:
# self.goal = Goal(goal_generator=self.goal_gen)
# if (domain_len and len(self.goal.domains) != domain_len) or \
# (remove_domain and remove_domain in self.goal.domains):
# keep_generate_goal = True
# else:
# keep_generate_goal = False
def load(self, model_path=None):
self.user.load_state_dict(torch.load(model_path, map_location=DEVICE))
@@ -171,9 +174,14 @@
def load_state_dict(self, model=None):
self.user.load_state_dict(model)
def _get_goal(self):
# internal usage
return self.goal.domain_goals
def get_goal(self):
# for outside usage, e.g. evaluator
return self.raw_goal
def get_reward(self):
if self.goal.task_complete():
# reward = 2 * self.max_turn
@@ -330,7 +338,7 @@
return predict_domain
def _add_user_action(self, output, domain, slot):
goal = self._get_goal()
is_action = False
act = [[]]
value = None
@@ -403,16 +411,32 @@ class UserPolicy(Policy):
self.config = json.load(open(config))
else:
self.config = config
self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}'
self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz'
if not os.path.exists(self.config["model_dir"]):
# os.mkdir(self.config["model_dir"])
model_downloader(os.path.dirname(self.config["model_dir"]),
"https://zenodo.org/record/5779832/files/default.zip")
self.slot2dbattr = {
'open hours': 'openhours',
'price range': 'pricerange',
'arrive by': 'arriveBy',
'leave at': 'leaveAt',
'train id': 'trainID'
}
self.dbattr2slot = {}
for k, v in self.slot2dbattr.items():
self.dbattr2slot[v] = k
self.policy = UserActionPolicy(self.config)
def predict(self, state):
raw_act = self.policy.predict(state)
act = []
for intent, domain, slot, value in raw_act:
if slot in self.dbattr2slot:
slot = self.dbattr2slot[slot]
act.append([intent, domain, slot, value])
return act
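    # e.g. (illustrative): a raw act [["inform", "train", "arriveBy", "19:45"]]
    # comes back as [["inform", "train", "arrive by", "19:45"]]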
def init_session(self, goal=None):
self.policy.init_session(goal)
@@ -424,14 +448,8 @@
return self.policy.get_reward()
def get_goal(self):
if hasattr(self.policy, 'get_goal'):
return self.policy.get_goal()
# workaround: convert goal to old format
multiwoz_goal = {}
goal = self.policy.get_goal()
@@ -449,8 +467,8 @@
multiwoz_goal[domain]["book"] = {}
norm_slot = slot.split(' ')[-1]
multiwoz_goal[domain]["book"][norm_slot] = value
elif slot in self.slot2dbattr:
norm_slot = self.slot2dbattr[slot]
multiwoz_goal[domain][slot_type][norm_slot] = value
else:
multiwoz_goal[domain][slot_type][slot] = value
......
@@ -97,7 +97,7 @@ class TUSDataManager(Dataset):
action_list, user_goal, cur_state, usr_act)
domain_label = self.feature_handler.domain_label(
user_goal, usr_act)
# pre_state = user_goal.update(action=usr_act, char="user") # trick?
feature["id"].append(dialog["dialogue_id"])
feature["input"].append(input_feature)
feature["mask"].append(mask)
......
from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
from convlab.util import load_dataset
import json
NOT_MENTIONED = "not mentioned"
def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1):
ratio = {'train': split2ratio, 'validation': split2ratio}
if data_name == "all" or data_name == "sgd+tm" or data_name == "tm":
print("merge all datasets...")
if data_name == "all":
all_dataset = ["multiwoz21", "sgd", "tm1", "tm2", "tm3"]
if data_name == "sgd+tm":
all_dataset = ["sgd", "tm1", "tm2", "tm3"]
if data_name == "tm":
all_dataset = ["tm1", "tm2", "tm3"]
datasets = {}
for name in all_dataset:
datasets[name] = load_dataset(
name,
dial_ids_order=dial_ids_order,
split2ratio=ratio)
raw_data = merge_dataset(datasets, all_dataset[0])
else:
print(f"load single dataset {data_name}/{split2ratio}")
raw_data = load_dataset(data_name,
dial_ids_order=dial_ids_order,
split2ratio=ratio)
return raw_data
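# e.g. (illustrative): load half of the SGD data, or merge every supported corpus:
#   load_experiment_dataset("sgd", dial_ids_order=0, split2ratio=0.5)
#   load_experiment_dataset("all", dial_ids_order=0, split2ratio=1)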
def merge_dataset(datasets, data_name):
data_split = [x for x in datasets[data_name]]
raw_data = {}
for data_type in data_split:
raw_data[data_type] = []
for dataname, dataset in datasets.items():
print(f"merge {dataname}...")
raw_data[data_type] += dataset[data_type]
return raw_data
def int2onehot(index, output_dim=6, remove_zero=False):
one_hot = [0] * output_dim
if remove_zero:
@@ -89,50 +129,6 @@ def get_booking_domain(slot, value, all_values, domain_list):
return found
def act2slot(intent, domain, slot, value, all_values):
if domain not in UsrDa2Goal:
# print(f"Not handle domain {domain}")
return ""
if domain == "booking":
slot = SysDa2Goal[domain][slot]
domain = get_booking_domain(slot, value, all_values)
return f"{domain}-{slot}"
elif domain in UsrDa2Goal:
if slot in SysDa2Goal[domain]:
slot = SysDa2Goal[domain][slot]
elif slot in UsrDa2Goal[domain]:
slot = UsrDa2Goal[domain][slot]
elif slot in SysDa2Goal["booking"]:
slot = SysDa2Goal["booking"][slot]
# else:
# print(
# f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}")
return f"{domain}-{slot}"
print("strange!!!")
print(intent, domain, slot, value)
return ""
def get_user_history(dialog, all_values):
turn_num = len(dialog)
mentioned_slot = []
for turn_id in range(0, turn_num, 2):
usr_act = parse_dialogue_act(
dialog[turn_id]["dialog_act"])
for intent, domain, slot, value in usr_act:
slot_name = act2slot(
intent, domain.lower(), slot.lower(), value.lower(), all_values)
if slot_name not in mentioned_slot:
mentioned_slot.append(slot_name)
return mentioned_slot
def update_config_file(file_name, attribute, value):
with open(file_name, 'r') as config_file:
config = json.load(config_file)
@@ -147,7 +143,7 @@ def update_config_file(file_name, attribute, value):
def create_goal(dialog) -> list:
# a list of {'intent': ..., 'domain': ..., 'slot': ..., 'value': ...}
dicts = []
for turn in dialog['turns']:
# print(turn['speaker'])
# assert (i % 2 == 0) == (turn['speaker'] == 'user')
# if i % 2 == 0:
@@ -205,6 +201,45 @@ def split_slot_name(slot_name):
return tokens[0], '-'.join(tokens[1:])
# copy from data.unified_datasets.multiwoz21
slot_name_map = {
'addr': "address",
'post': "postcode",
'pricerange': "price range",
'arrive': "arrive by",
'arriveby': "arrive by",
'leave': "leave at",
'leaveat': "leave at",
'depart': "departure",
'dest': "destination",
'fee': "entrance fee",
'open': 'open hours',
'car': "type",
'car type': "type",
'ticket': 'price',
'trainid': 'train id',
'id': 'train id',
'people': 'book people',
'stay': 'book stay',
'none': '',
'attraction': {
'price': 'entrance fee'
},
'hospital': {},
'hotel': {
'day': 'book day', 'price': "price range"
},
'restaurant': {
'day': 'book day', 'time': 'book time', 'price': "price range"
},
'taxi': {},
'train': {
'day': 'day', 'time': "duration"
},
'police': {},
'booking': {}
}
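# Lookup order (as used in old_goal2list above): the flat entries are tried
# first, then the per-domain dicts, e.g. "pricerange" -> "price range" for any
# domain, while "price" -> "price range" only for hotel/restaurant and
# "price" -> "entrance fee" for attraction.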
if __name__ == "__main__":
print(split_slot_name("restaurant-search-location"))
print(split_slot_name("sports-day.match"))