diff --git a/convlab/policy/genTUS/evaluate.py b/convlab/policy/genTUS/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..87de854970d2701900ba180d2bf15736071e0c1a --- /dev/null +++ b/convlab/policy/genTUS/evaluate.py @@ -0,0 +1,257 @@ +import json +import os +import sys +from argparse import ArgumentParser +from pprint import pprint + +import torch +from convlab.nlg.evaluate import fine_SER +from datasets import load_metric + +# from convlab.policy.genTUS.pg.stepGenTUSagent import \ +# stepGenTUSPG as UserPolicy +from convlab.policy.genTUS.stepGenTUS import UserActionPolicy +from tqdm import tqdm + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model-checkpoint", type=str, help="the model path") + parser.add_argument("--model-weight", type=str, + help="the model weight", default="") + parser.add_argument("--input-file", type=str, help="the testing input file", + default="") + parser.add_argument("--generated-file", type=str, help="the generated results", + default="") + parser.add_argument("--only-action", action="store_true") + parser.add_argument("--dataset", default="multiwoz") + parser.add_argument("--do-semantic", action="store_true", + help="do semantic evaluation") + parser.add_argument("--do-nlg", action="store_true", + help="do nlg generation") + parser.add_argument("--do-golden-nlg", action="store_true", + help="do golden nlg generation") + return parser.parse_args() + + +class Evaluator: + def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False): + self.dataset = dataset + self.model_checkpoint = model_checkpoint + self.model_weight = model_weight + # if model_weight: + # self.usr_policy = UserPolicy( + # self.model_checkpoint, only_action=only_action) + # self.usr_policy.load(model_weight) + # self.usr = self.usr_policy.usr + # else: + self.usr = UserActionPolicy( + model_checkpoint, only_action=only_action, dataset=self.dataset) + self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) + + def generate_results(self, f_eval, golden=False): + in_file = json.load(open(f_eval)) + r = { + "input": [], + "golden_acts": [], + "golden_utts": [], + "gen_acts": [], + "gen_utts": [] + } + for dialog in tqdm(in_file['dialog']): + inputs = dialog["in"] + labels = self.usr._parse_output(dialog["out"]) + if golden: + usr_act = labels["action"] + usr_utt = self.usr.generate_text_from_give_semantic( + inputs, usr_act) + + else: + output = self.usr._parse_output( + self.usr._generate_action(inputs)) + usr_act = self.usr._remove_illegal_action(output["action"]) + usr_utt = output["text"] + r["input"].append(inputs) + r["golden_acts"].append(labels["action"]) + r["golden_utts"].append(labels["text"]) + r["gen_acts"].append(usr_act) + r["gen_utts"].append(usr_utt) + + return r + + def read_generated_result(self, f_eval): + in_file = json.load(open(f_eval)) + r = { + "input": [], + "golden_acts": [], + "golden_utts": [], + "gen_acts": [], + "gen_utts": [] + } + for dialog in tqdm(in_file['dialog']): + for x in dialog: + r[x].append(dialog[x]) + + return r + + def nlg_evaluation(self, input_file=None, generated_file=None, golden=False): + if input_file: + print("Force generation") + gen_r = self.generate_results(input_file, golden) + + elif generated_file: + gen_r = self.read_generated_result(generated_file) + else: + print("You must specify the input_file or the generated_file") + + nlg_eval = { + "golden": 
golden, + "metrics": {}, + "dialog": [] + } + for input, golden_act, golden_utt, gen_act, gen_utt in zip(gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["gen_acts"], gen_r["gen_utts"]): + nlg_eval["dialog"].append({ + "input": input, + "golden_acts": golden_act, + "golden_utts": golden_utt, + "gen_acts": gen_act, + "gen_utts": gen_utt + }) + + if golden: + print("Calculate BLEU") + bleu_metric = load_metric("sacrebleu") + labels = [[utt] for utt in gen_r["golden_utts"]] + + bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"], + references=labels, + force=True) + print("bleu_metric", bleu_score) + nlg_eval["metrics"]["bleu"] = bleu_score + + else: + print("Calculate SER") + missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER( + gen_r["gen_acts"], gen_r["gen_utts"]) + + print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format( + "genTUSNLG", missing, total, hallucinate, missing/total)) + nlg_eval["metrics"]["SER"] = missing/total + + dir_name = self.model_checkpoint + json.dump(nlg_eval, + open(os.path.join(dir_name, "nlg_eval.json"), 'w'), + indent=2) + return os.path.join(dir_name, "nlg_eval.json") + + def evaluation(self, input_file=None, generated_file=None): + force_prediction = True + if generated_file: + gen_file = json.load(open(generated_file)) + force_prediction = False + if gen_file["golden"]: + force_prediction = True + + if force_prediction: + in_file = json.load(open(input_file)) + dialog_result = [] + gen_acts, golden_acts = [], [] + # scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []} + for dialog in tqdm(in_file['dialog']): + inputs = dialog["in"] + labels = self.usr._parse_output(dialog["out"]) + ans_action = self.usr._remove_illegal_action(labels["action"]) + preds = self.usr._generate_action(inputs) + preds = self.usr._parse_output(preds) + usr_action = self.usr._remove_illegal_action(preds["action"]) + + gen_acts.append(usr_action) + golden_acts.append(ans_action) + + d = {"input": inputs, + "golden_acts": ans_action, + "gen_acts": usr_action} + if "text" in preds: + d["golden_utts"] = labels["text"] + d["gen_utts"] = preds["text"] + # print("pred text", preds["text"]) + + dialog_result.append(d) + else: + gen_acts, golden_acts = [], [] + for dialog in gen_file['dialog']: + gen_acts.append(dialog["gen_acts"]) + golden_acts.append(dialog["golden_acts"]) + dialog_result = gen_file['dialog'] + + scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []} + + for gen_act, golden_act in zip(gen_acts, golden_acts): + s = f1_measure(preds=gen_act, labels=golden_act) + for metric in scores: + scores[metric].append(s[metric]) + + result = {} + for metric in scores: + result[metric] = sum(scores[metric])/len(scores[metric]) + print(f"{metric}: {result[metric]}") + + result["dialog"] = dialog_result + basename = "semantic_evaluation_result" + json.dump(result, open(os.path.join( + self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) + # if self.model_weight: + # json.dump(result, open(os.path.join( + # 'results', f"{basename}.json"), 'w')) + # else: + # json.dump(result, open(os.path.join( + # self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) + + +def f1_measure(preds, labels): + tp = 0 + score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0} + for p in preds: + if p in labels: + tp += 1.0 + if preds: + score["precision"] = tp/len(preds) + if labels: + score["recall"] = tp/len(labels) + if (score["precision"] + score["recall"]) > 0: + score["f1"] = 
2*(score["precision"]*score["recall"]) / \ + (score["precision"]+score["recall"]) + if tp == len(preds) and tp == len(labels): + score["turn_acc"] = 1 + return score + + +def main(): + args = arg_parser() + eval = Evaluator(args.model_checkpoint, + args.dataset, + args.model_weight, + args.only_action) + print("model checkpoint", args.model_checkpoint) + print("generated_file", args.generated_file) + print("input_file", args.input_file) + with torch.no_grad(): + if args.do_semantic: + eval.evaluation(args.input_file) + if args.do_nlg: + nlg_result = eval.nlg_evaluation(input_file=args.input_file, + generated_file=args.generated_file, + golden=args.do_golden_nlg) + if args.generated_file: + generated_file = args.generated_file + else: + generated_file = nlg_result + eval.evaluation(args.input_file, + generated_file) + + +if __name__ == '__main__': + main() diff --git a/convlab/policy/genTUS/ppo/vector.py b/convlab/policy/genTUS/ppo/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..4c502a46f87582008ff49219f8a14844378b9ed2 --- /dev/null +++ b/convlab/policy/genTUS/ppo/vector.py @@ -0,0 +1,148 @@ +import json + +import torch +from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph +from convlab.policy.genTUS.token_map import tokenMap +from convlab.policy.tus.unify.Goal import Goal +from transformers import BartTokenizer + + +class stepGenTUSVector: + def __init__(self, model_checkpoint, max_in_len=400, max_out_len=80, allow_general_intent=True): + self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint) + self.vocab = len(self.tokenizer) + self.max_in_len = max_in_len + self.max_out_len = max_out_len + self.token_map = tokenMap(tokenizer=self.tokenizer) + self.token_map.default(only_action=True) + self.kg = KnowledgeGraph(self.tokenizer) + self.mentioned_domain = [] + self.allow_general_intent = allow_general_intent + self.candidate_num = 5 + if self.allow_general_intent: + print("---> allow_general_intent") + + def init_session(self, goal: Goal): + self.goal = goal + self.mentioned_domain = [] + + def encode(self, raw_inputs, max_length, return_tensors="pt", truncation=True): + model_input = self.tokenizer(raw_inputs, + max_length=max_length, + return_tensors=return_tensors, + truncation=truncation, + padding="max_length") + return model_input + + def decode(self, generated_so_far, skip_special_tokens=True): + output = self.tokenizer.decode( + generated_so_far, skip_special_tokens=skip_special_tokens) + return output + + def state_vectorize(self, action, history, turn): + self.goal.update_user_goal(action=action) + inputs = json.dumps({"system": action, + "goal": self.goal.get_goal_list(), + "history": history, + "turn": str(turn)}) + inputs = self.encode(inputs, self.max_in_len) + s_vec, action_mask = inputs["input_ids"][0], inputs["attention_mask"][0] + + return s_vec, action_mask + + def action_vectorize(self, action, s=None): + # action: [[intent, domain, slot, value], ...] 
+ vec = {"vector": torch.tensor([]), "mask": torch.tensor([])} + if s is not None: + raw_inputs = self.decode(s[0]) + self.kg.parse_input(raw_inputs) + + self._append(vec, self._get_id("<s>")) + self._append(vec, self.token_map.get_id('start_json')) + self._append(vec, self.token_map.get_id('start_act')) + + act_len = len(action) + for i, (intent, domain, slot, value) in enumerate(action): + if value == '?': + value = '<?>' + c_idx = {x: None for x in ["intent", "domain", "slot", "value"]} + + if s is not None: + c_idx["intent"] = self._candidate_id(self.kg.candidate( + "intent", allow_general_intent=self.allow_general_intent)) + c_idx["domain"] = self._candidate_id(self.kg.candidate( + "domain", intent=intent)) + c_idx["slot"] = self._candidate_id(self.kg.candidate( + "slot", intent=intent, domain=domain, is_mentioned=self.is_mentioned(domain))) + c_idx["value"] = self._candidate_id(self.kg.candidate( + "value", intent=intent, domain=domain, slot=slot)) + + self._append(vec, self._get_id(intent), c_idx["intent"]) + self._append(vec, self.token_map.get_id('sep_token')) + self._append(vec, self._get_id(domain), c_idx["domain"]) + self._append(vec, self.token_map.get_id('sep_token')) + self._append(vec, self._get_id(slot), c_idx["slot"]) + self._append(vec, self.token_map.get_id('sep_token')) + self._append(vec, self._get_id(value), c_idx["value"]) + + c_idx = [0]*self.candidate_num + c_idx[0] = self.token_map.get_id('end_act')[0] + c_idx[1] = self.token_map.get_id('sep_act')[0] + if i == act_len - 1: + x = self.token_map.get_id('end_act') + else: + x = self.token_map.get_id('sep_act') + + self._append(vec, x, c_idx) + + self._append(vec, self._get_id("</s>")) + + # pad + if len(vec["vector"]) < self.max_out_len: + pad_len = self.max_out_len-len(vec["vector"]) + self._append(vec, x=torch.tensor([1]*pad_len)) + for vec_type in vec: + vec[vec_type] = vec[vec_type].to(torch.int64) + + return vec + + def _append(self, vec, x, candidate=None): + if type(x) is list: + x = torch.tensor(x) + mask = self._mask(x, candidate) + vec["vector"] = torch.cat((vec["vector"], x), dim=-1) + vec["mask"] = torch.cat((vec["mask"], mask), dim=0) + + def _mask(self, idx, c_idx=None): + mask = torch.zeros(len(idx), self.candidate_num) + mask[:, 0] = idx + if c_idx is not None and len(c_idx) > 1: + mask[0, :] = torch.tensor(c_idx) + + return mask + + def _candidate_id(self, candidate): + if len(candidate) > self.candidate_num: + print(f"too many candidates. 
Max = {self.candidate_num}") + c_idx = [0]*self.candidate_num + for i, idx in enumerate([self._get_id(c)[0] for c in candidate[:self.candidate_num]]): + c_idx[i] = idx + return c_idx + + def _get_id(self, value): + token_id = self.tokenizer(value, add_special_tokens=False) + return token_id["input_ids"] + + def action_devectorize(self, action_id): + return self.decode(action_id) + + def update_mentioned_domain(self, semantic_act): + for act in semantic_act: + domain = act[1] + if domain not in self.mentioned_domain: + self.mentioned_domain.append(domain) + + def is_mentioned(self, domain): + if domain in self.mentioned_domain: + return True + return False diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py new file mode 100644 index 0000000000000000000000000000000000000000..f16c0ebeeeaaba50f1739aea6c4db40eb81f8d29 --- /dev/null +++ b/convlab/policy/genTUS/stepGenTUS.py @@ -0,0 +1,653 @@ +import json +import os + +import torch +from transformers import BartTokenizer + +from convlab.policy.genTUS.ppo.vector import stepGenTUSVector +from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel +from convlab.policy.genTUS.token_map import tokenMap +from convlab.policy.genTUS.unify.Goal import Goal +from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph +from convlab.policy.policy import Policy +from convlab.task.multiwoz.goal_generator import GoalGenerator + +DEBUG = False + + +class UserActionPolicy(Policy): + def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs): + self.mode = mode + # if mode == "semantic" and only_action: + # # only generate semantic action in prediction + print("model_checkpoint", model_checkpoint) + self.only_action = only_action + if self.only_action: + print("change mode to semantic because only_action=True") + self.mode = "semantic" + self.max_in_len = 500 + self.max_out_len = 50 if only_action else 200 + max_act_len = kwargs.get("max_act_len", 2) + print("max_act_len", max_act_len) + self.max_action_len = max_act_len + if "max_act_len" in kwargs: + self.max_out_len = 30 * self.max_action_len + print("max_act_len", self.max_out_len) + self.max_turn = max_turn + if mode not in ["semantic", "language"]: + print("Unknown user mode") + + self.reward = {"success": self.max_turn*2, + "fail": self.max_turn*-1} + self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + train_whole_model = kwargs.get("whole_model", True) + self.model = stepGenTUSmodel( + model_checkpoint, train_whole_model=train_whole_model) + self.model.eval() + self.model.to(self.device) + self.model.share_memory() + + self.turn_level_reward = kwargs.get("turn_level_reward", True) + self.cooperative = kwargs.get("cooperative", True) + + dataset = kwargs.get("dataset", "") + self.kg = KnowledgeGraph( + tokenizer=self.tokenizer, + dataset=dataset) + + self.goal_gen = GoalGenerator() + + self.vector = stepGenTUSVector( + model_checkpoint, self.max_in_len, self.max_out_len) + self.norm_reward = False + + self.action_penalty = kwargs.get("action_penalty", False) + self.usr_act_penalize = kwargs.get("usr_act_penalize", 0) + self.goal_list_type = kwargs.get("goal_list_type", "normal") + self.update_mode = kwargs.get("update_mode", "normal") + self.max_history = kwargs.get("max_history", 3) + self.init_session() + + def _update_seq(self, sub_seq: list, pos: int): + for x in sub_seq: + self.seq[0, pos] = x + pos += 1 + + return pos + + def 
_generate_action(self, raw_inputs, mode="max", allow_general_intent=True): + # TODO no duplicate + self.kg.parse_input(raw_inputs) + model_input = self.vector.encode(raw_inputs, self.max_in_len) + # start token + self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() + pos = self._update_seq([0], 0) + pos = self._update_seq(self.token_map.get_id('start_json'), pos) + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + + # get semantic actions + for act_len in range(self.max_action_len): + pos = self._get_semantic_action( + model_input, pos, mode, allow_general_intent) + + terminate, token_name = self._stop_semantic( + model_input, pos, act_len) + pos = self._update_seq(self.token_map.get_id(token_name), pos) + + if terminate: + break + + if self.only_action: + # return semantic action. Don't need to generate text + return self.vector.decode(self.seq[0, :pos]) + + # TODO remove illegal action here? + + # get text output + pos = self._update_seq(self.token_map.get_id("start_text"), pos) + + text = self._get_text(model_input, pos) + + return text + + def generate_text_from_give_semantic(self, raw_inputs, semantic_action): + self.kg.parse_input(raw_inputs) + model_input = self.vector.encode(raw_inputs, self.max_in_len) + self.seq = torch.zeros(1, self.max_out_len, device=self.device).long() + pos = self._update_seq([0], 0) + pos = self._update_seq(self.token_map.get_id('start_json'), pos) + pos = self._update_seq(self.token_map.get_id('start_act'), pos) + + if len(semantic_action) == 0: + pos = self._update_seq(self.token_map.get_id("end_act"), pos) + + for act_id, (intent, domain, slot, value) in enumerate(semantic_action): + pos = self._update_seq(self.kg._get_token_id(intent), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.kg._get_token_id(domain), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.kg._get_token_id(slot), pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + pos = self._update_seq(self.kg._get_token_id(value), pos) + + if act_id == len(semantic_action) - 1: + token_name = "end_act" + else: + token_name = "sep_act" + pos = self._update_seq(self.token_map.get_id(token_name), pos) + pos = self._update_seq(self.token_map.get_id("start_text"), pos) + + raw_output = self._get_text(model_input, pos) + return self._parse_output(raw_output)["text"] + + def _get_text(self, model_input, pos): + s_pos = pos + for i in range(s_pos, self.max_out_len): + next_token_logits = self.model.get_next_token_logits( + model_input, self.seq[:1, :pos]) + next_token = torch.argmax(next_token_logits, dim=-1) + + if self._stop_text(next_token): + # text = self.vector.decode(self.seq[0, s_pos:pos]) + # text = self._norm_str(text) + # return self.vector.decode(self.seq[0, :s_pos]) + text + '"}' + break + + pos = self._update_seq([next_token], pos) + text = self.vector.decode(self.seq[0, s_pos:pos]) + text = self._norm_str(text) + return self.vector.decode(self.seq[0, :s_pos]) + text + '"}' + # TODO return None + + def _stop_text(self, next_token): + if next_token == self.token_map.get_id("end_json")[0]: + return True + elif next_token == self.token_map.get_id("end_json_2")[0]: + return True + + return False + + @staticmethod + def _norm_str(text: str): + text = text.strip('"') + text = text.replace('"', "'") + text = text.replace('\\', "") + return text + + def _stop_semantic(self, model_input, pos, act_length=0): + + outputs = 
self.model.get_next_token_logits( + model_input, self.seq[:1, :pos]) + tokens = {} + for token_name in ['sep_act', 'end_act']: + tokens[token_name] = { + "token_id": self.token_map.get_id(token_name)} + hash_id = tokens[token_name]["token_id"][0] + tokens[token_name]["score"] = outputs[:, hash_id].item() + + if tokens['end_act']["score"] > tokens['sep_act']["score"]: + terminate = True + else: + terminate = False + + if act_length >= self.max_action_len - 1: + terminate = True + + token_name = "end_act" if terminate else "sep_act" + + return terminate, token_name + + def _get_semantic_action(self, model_input, pos, mode="max", allow_general_intent=True): + + intent = self._get_intent( + model_input, self.seq[:1, :pos], mode, allow_general_intent) + pos = self._update_seq(intent["token_id"], pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + + # get domain + domain = self._get_domain( + model_input, self.seq[:1, :pos], intent["token_name"], mode) + pos = self._update_seq(domain["token_id"], pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + + # get slot + slot = self._get_slot( + model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) + pos = self._update_seq(slot["token_id"], pos) + pos = self._update_seq(self.token_map.get_id('sep_token'), pos) + + # get value + + value = self._get_value( + model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], slot["token_name"], mode) + pos = self._update_seq(value["token_id"], pos) + + return pos + + def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + + return self.kg.get_intent(next_token_logits, mode, allow_general_intent) + + def _get_domain(self, model_input, generated_so_far, intent, mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + + return self.kg.get_domain(next_token_logits, intent, mode) + + def _get_slot(self, model_input, generated_so_far, intent, domain, mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + is_mentioned = self.vector.is_mentioned(domain) + return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned) + + def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + + return self.kg.get_value(next_token_logits, intent, domain, slot, mode) + + def _remove_illegal_action(self, action): + # Transform illegal action to legal action + new_action = [] + for act in action: + if len(act) == 4: + if "<?>" in act[-1]: + act = [act[0], act[1], act[2], "?"] + if act not in new_action: + new_action.append(act) + else: + print("illegal action:", action) + return new_action + + def _parse_output(self, in_str): + in_str = str(in_str) + in_str = in_str.replace('<s>', '').replace( + '<\\s>', '').replace('o"clock', "o'clock") + action = {"action": [], "text": ""} + try: + action = json.loads(in_str) + except: + print("invalid action:", in_str) + print("-"*20) + return action + + def predict(self, sys_act, mode="max", allow_general_intent=True): + # raw_sys_act = sys_act + # sys_act = sys_act[:5] + # update goal + # TODO + allow_general_intent = False + self.model.eval() + + if not self.add_sys_from_reward: + self.goal.update_user_goal(action=sys_act, char="sys") + 
self.sys_acts.append(sys_act) # for terminate conversation + + # update constraint + self.time_step += 2 + + history = [] + if self.usr_acts: + if self.max_history == 1: + history = self.usr_acts[-1] + else: + history = self.usr_acts[-1*self.max_history:] + inputs = json.dumps({"system": sys_act, + "goal": self.goal.get_goal_list(), + "history": history, + "turn": str(int(self.time_step/2))}) + with torch.no_grad(): + raw_output = self._generate_action( + raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent) + output = self._parse_output(raw_output) + self.semantic_action = self._remove_illegal_action(output["action"]) + if not self.only_action: + self.utterance = output["text"] + + # TODO + if self.is_finish(): + self.semantic_action, self.utterance = self._good_bye() + + # if self.is_finish(): + # print("terminated") + + # if self.is_finish(): + # good_bye = self._good_bye() + # self.goal.add_usr_da(good_bye) + # return good_bye + + self.goal.update_user_goal(action=self.semantic_action, char="usr") + self.vector.update_mentioned_domain(self.semantic_action) + self.usr_acts.append(self.semantic_action) + + # if self._usr_terminate(usr_action): + # print("terminated by user") + # self.terminated = True + + del inputs + + if self.mode == "language": + # print("in", sys_act) + # print("out", self.utterance) + return self.utterance + else: + return self.semantic_action + + def init_session(self, goal=None): + self.token_map = tokenMap(tokenizer=self.tokenizer) + self.token_map.default(only_action=self.only_action) + self.time_step = 0 + remove_domain = "police" # remove police domain in inference + + if not goal: + self._new_goal(remove_domain=remove_domain) + else: + self._read_goal(goal) + + self.vector.init_session(goal=self.goal) + + self.terminated = False + self.add_sys_from_reward = False + self.sys_acts = [] + self.usr_acts = [] + self.semantic_action = [] + self.utterance = "" + + def _read_goal(self, data_goal): + self.goal = Goal(goal=data_goal) + + def _new_goal(self, remove_domain="police", domain_len=None): + self.goal = Goal(goal_generator=self.goal_gen) + # keep_generate_goal = True + # # domain_len = 1 + # while keep_generate_goal: + # self.goal = Goal(goal_generator=self.goal_gen, + # goal_list_type=self.goal_list_type, + # update_mode=self.update_mode) + # if (domain_len and len(self.goal.domains) != domain_len) or \ + # (remove_domain and remove_domain in self.goal.domains): + # keep_generate_goal = True + # else: + # keep_generate_goal = False + + def load(self, model_path): + self.model.load_state_dict(torch.load( + model_path, map_location=self.device)) + # self.model = BartForConditionalGeneration.from_pretrained( + # model_checkpoint) + + def get_goal(self): + if self.goal.raw_goal is not None: + return self.goal.raw_goal + goal = {} + for domain in self.goal.domain_goals: + if domain not in goal: + goal[domain] = {} + for intent in self.goal.domain_goals[domain]: + if intent == "inform": + slot_type = "info" + elif intent == "request": + slot_type = "reqt" + elif intent == "book": + slot_type = "book" + else: + print("unknown slot type") + if slot_type not in goal[domain]: + goal[domain][slot_type] = {} + for slot, value in self.goal.domain_goals[domain][intent].items(): + goal[domain][slot_type][slot] = value + return goal + + def get_reward(self, sys_response=None): + self.add_sys_from_reward = False if sys_response is None else True + + if self.add_sys_from_reward: + self.goal.update_user_goal(action=sys_response, char="sys") + 
self.goal.add_sys_da(sys_response) # for evaluation + self.sys_acts.append(sys_response) # for terminate conversation + + if self.is_finish(): + if self.is_success(): + reward = self.reward["success"] + self.success = True + else: + reward = self.reward["fail"] + self.success = False + + else: + reward = -1 + if self.turn_level_reward: + reward += self.turn_reward() + + self.success = None + # if self.action_penalty: + # reward += self._system_action_penalty() + + if self.norm_reward: + reward = (reward - 20)/60 + return reward + + def _system_action_penalty(self): + free_action_len = 3 + if len(self.sys_acts) < 1: + return 0 + # TODO only penalize the slots not in user goal + # else: + # penlaty = 0 + # for i in range(len(self.sys_acts[-1])): + # penlaty += -1*i + # return penlaty + if len(self.sys_acts[-1]) > 3: + return -1*(len(self.sys_acts[-1])-free_action_len) + return 0 + + def turn_reward(self): + r = 0 + r += self._new_act_reward() + r += self._reply_reward() + r += self._usr_act_len() + return r + + def _usr_act_len(self): + last_act = self.usr_acts[-1] + penalty = 0 + if len(last_act) > 2: + penalty = (2-len(last_act))*self.usr_act_penalize + return penalty + + def _new_act_reward(self): + last_act = self.usr_acts[-1] + if last_act != self.semantic_action: + print(f"---> why? last {last_act} usr {self.semantic_action}") + new_act = [] + for act in last_act: + if len(self.usr_acts) < 2: + break + if act[1].lower() == "general": + new_act.append(0) + elif act in self.usr_acts[-2]: + new_act.append(-1) + elif act not in self.usr_acts[-2]: + new_act.append(1) + + return sum(new_act) + + def _reply_reward(self): + if self.cooperative: + return self._cooperative_reply_reward() + else: + return self._non_cooperative_reply_reward() + + def _non_cooperative_reply_reward(self): + r = [] + reqts = [] + infos = [] + reply_len = 0 + max_len = 1 + for act in self.sys_acts[-1]: + if act[0] == "request": + reqts.append([act[1], act[2]]) + for act in self.usr_acts[-1]: + if act[0] == "inform": + infos.append([act[1], act[2]]) + for req in reqts: + if req in infos: + if reply_len < max_len: + r.append(1) + elif reply_len == max_len: + r.append(0) + else: + r.append(-5) + + if r: + return sum(r) + return 0 + + def _cooperative_reply_reward(self): + r = [] + reqts = [] + infos = [] + for act in self.sys_acts[-1]: + if act[0] == "request": + reqts.append([act[1], act[2]]) + for act in self.usr_acts[-1]: + if act[0] == "inform": + infos.append([act[1], act[2]]) + for req in reqts: + if req in infos: + r.append(1) + else: + r.append(-1) + if r: + return sum(r) + return 0 + + def _usr_terminate(self): + for act in self.semantic_action: + if act[0] in ['thank', 'bye']: + return True + return False + + def is_finish(self): + # stop by model generation? + if self._finish_conversation_rule(): + self.terminated = True + return True + elif self._usr_terminate(): + self.terminated = True + return True + self.terminated = False + return False + + def is_success(self): + task_complete = self.goal.task_complete() + # goal_status = self.goal.all_mentioned() + # should mentioned all slots + if task_complete: # and goal_status["complete"] > 0.6: + return True + return False + + def _good_bye(self): + if self.is_success(): + return [['thank', 'general', 'none', 'none']], "thank you. 
bye" + # if self.mode == "semantic": + # return [['thank', 'general', 'none', 'none']] + # else: + # return "bye" + else: + return [["bye", "general", "None", "None"]], "bye" + if self.mode == "semantic": + return [["bye", "general", "None", "None"]] + return "bye" + + def _finish_conversation_rule(self): + if self.is_success(): + return True + + if self.time_step > self.max_turn: + return True + + if (len(self.sys_acts) > 4) and (self.sys_acts[-1] == self.sys_acts[-2]) and (self.sys_acts[-2] == self.sys_acts[-3]): + return True + return False + + def is_terminated(self): + # Is there any action to say? + self.is_finish() + return self.terminated + + +class UserPolicy(Policy): + def __init__(self, + model_checkpoint, + mode="semantic", + only_action=True, + sample=False, + action_penalty=False, + **kwargs): + # self.config = config + # if not os.path.exists(self.config["model_dir"]): + # os.mkdir(self.config["model_dir"]) + # model_downloader(self.config["model_dir"], + # "https://zenodo.org/record/5779832/files/default.zip") + + self.policy = UserActionPolicy( + model_checkpoint, + mode=mode, + only_action=only_action, + action_penalty=action_penalty, + **kwargs) + self.policy.load(os.path.join( + model_checkpoint, "pytorch_model.bin")) + self.sample = sample + + def predict(self, sys_act, mode="max"): + if self.sample: + mode = "sample" + else: + mode = "max" + response = self.policy.predict(sys_act, mode) + return response + + def init_session(self, goal=None): + self.policy.init_session(goal) + + def is_terminated(self): + return self.policy.is_terminated() + + def get_reward(self, sys_response=None): + return self.policy.get_reward(sys_response) + + def get_goal(self): + if hasattr(self.policy, 'get_goal'): + return self.policy.get_goal() + return None + + +if __name__ == "__main__": + import os + + from convlab.dialog_agent import PipelineAgent + # from convlab.nlu.jointBERT.multiwoz import BERTNLU + from convlab.util.custom_util import set_seed + + set_seed(20220220) + # Test semantic level behaviour + model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0' + usr_policy = UserPolicy( + model_checkpoint, + mode="semantic") + usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin")) + usr_nlu = None # BERTNLU() + usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user') + print(usr.policy.get_goal()) + + print(usr.response([])) + print(usr.policy.policy.goal.status) + print(usr.response([["request", "attraction", "area", "?"]])) + print(usr.policy.policy.goal.status) + print(usr.response([["request", "attraction", "area", "?"]])) + print(usr.policy.policy.goal.status) diff --git a/convlab/policy/genTUS/stepGenTUSmodel.py b/convlab/policy/genTUS/stepGenTUSmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..e2eaf7bc5064262808120aac4a9cbe2eb007d863 --- /dev/null +++ b/convlab/policy/genTUS/stepGenTUSmodel.py @@ -0,0 +1,114 @@ + +import json + +import torch +from torch.nn.functional import softmax, one_hot, cross_entropy + +from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph +from convlab.policy.genTUS.token_map import tokenMap +from convlab.policy.genTUS.utils import append_tokens +from transformers import (BartConfig, BartForConditionalGeneration, + BartTokenizer) + + +class stepGenTUSmodel(BartForConditionalGeneration): + def __init__(self, model_checkpoint, train_whole_model=True, **kwargs): + config = BartConfig.from_pretrained(model_checkpoint) + super().__init__(config, **kwargs) + + 
self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint) + self.vocab = len(self.tokenizer) + self.kg = KnowledgeGraph(self.tokenizer) + self.action_kg = KnowledgeGraph(self.tokenizer) + self.token_map = tokenMap(self.tokenizer) + # only_action doesn't matter. it is only used for get_log_prob + self.token_map.default(only_action=True) + + if not train_whole_model: + for param in self.parameters(): + param.requires_grad = False + + for param in self.model.decoder.layers[-1].fc1.parameters(): + param.requires_grad = True + for param in self.model.decoder.layers[-1].fc2.parameters(): + param.requires_grad = True + + def get_trainable_param(self): + + return filter( + lambda p: p.requires_grad, self.parameters()) + + def get_next_token_logits(self, model_input, generated_so_far): + input_ids = model_input["input_ids"].to(self.device) + attention_mask = model_input["attention_mask"].to(self.device) + outputs = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=generated_so_far, + return_dict=True) + return outputs.logits[:, -1, :] + + def get_log_prob(self, s, a, action_mask, prob_mask): + output = self.forward(input_ids=s, + attention_mask=action_mask, + decoder_input_ids=a) + prob = self._norm_prob(a[:, 1:].long(), + output.logits[:, :-1, :], + prob_mask[:, 1:, :].long()) + return prob + + def _norm_prob(self, a, prob, mask): + prob = softmax(prob, -1) + base = self._base(prob, mask).to(self.device) # [b, seq_len] + prob = (prob*one_hot(a, num_classes=self.vocab)).sum(-1) + prob = torch.log(prob / base) + pad_mask = a != 1 + prob = prob*pad_mask.float() + return prob.sum(-1) + + @staticmethod + def _base(prob, mask): + batch_size, seq_len, dim = prob.shape + base = torch.zeros(batch_size, seq_len) + for b in range(batch_size): + for s in range(seq_len): + temp = [prob[b, s, c] for c in mask[b, s, :] if c > 0] + base[b, s] = torch.sum(torch.tensor(temp)) + return base + + +if __name__ == "__main__": + import os + from convlab.util.custom_util import set_seed + from convlab.policy.genTUS.stepGenTUS import UserActionPolicy + set_seed(0) + device = "cuda" if torch.cuda.is_available() else "cpu" + + model_checkpoint = 'results/genTUS-22-01-31-09-21/' + usr = UserActionPolicy(model_checkpoint=model_checkpoint) + usr.model.load_state_dict(torch.load( + os.path.join(model_checkpoint, "pytorch_model.bin"), map_location=device)) + usr.model.eval() + + test_file = "convlab/policy/genTUS/data/goal_status_validation_v1.json" + data = json.load(open(test_file)) + test_id = 20 + inputs = usr.tokenizer(data["dialog"][test_id]["in"], + max_length=400, + return_tensors="pt", + truncation=True) + + actions = [data["dialog"][test_id]["out"], + data["dialog"][test_id+100]["out"]] + + for action in actions: + action = json.loads(action) + vec = usr.vector.action_vectorize( + action["action"], s=inputs["input_ids"]) + + print({"action": action["action"]}) + print("get_log_prob", usr.model.get_log_prob( + inputs["input_ids"], + torch.unsqueeze(vec["vector"], 0), + inputs["attention_mask"], + torch.unsqueeze(vec["mask"], 0))) diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py new file mode 100644 index 0000000000000000000000000000000000000000..7825c2880928c40f68284b0c3199932cd1cfc477 --- /dev/null +++ b/convlab/policy/genTUS/token_map.py @@ -0,0 +1,64 @@ +import json + + +class tokenMap: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + self.token_name = {} + self.hash_map = {} + self.debug = False + self.default() + + def 
default(self, only_action=False): + self.format_tokens = { + 'start_json': '{"action": [', # 49643, 10845, 7862, 646 + 'start_act': '["', # 49329 + 'sep_token': '", "', # 1297('",'), 22 + 'sep_act': '"], ["', # 49177 + 'end_act': '"]], "', # 42248, 7479, 22 + 'start_text': 'text": "', # 29015, 7862, 22 + 'end_json': '}', # 24303 + 'end_json_2': '"}' # 48805 + } + if only_action: + self.format_tokens['end_act'] = '"]]}' + for token_name in self.format_tokens: + self.add_token( + token_name, self.format_tokens[token_name]) + + def add_token(self, token_name, value): + if token_name in self.token_name and self.debug: + print(f"---> duplicate token: {token_name}({value})!!!!!!!") + + token_id = self.tokenizer(str(value), add_special_tokens=False)[ + "input_ids"] + self.token_name[token_name] = {"value": value, "token_id": token_id} + # print(token_id) + hash_id = token_id[0] + if hash_id in self.hash_map and self.debug: + print( + f"---> conflict hash number {hash_id}: {self.hash_map[hash_id]['name']} and {token_name}") + self.hash_map[hash_id] = { + "name": token_name, "value": value, "token_id": token_id} + + def get_info(self, hash_id): + return self.hash_map[hash_id] + + def get_id(self, token_name): + # workaround + # if token_name not in self.token_name[token_name]: + # self.add_token(token_name, token_name) + return self.token_name[token_name]["token_id"] + + def get_token_value(self, token_name): + return self.token_name[token_name]["value"] + + def token_name_is_in(self, token_name): + if token_name in self.token_name: + return True + return False + + def hash_id_is_in(self, hash_id): + if hash_id in self.hash_map: + return True + return False diff --git a/convlab/policy/genTUS/train_model.py b/convlab/policy/genTUS/train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2162417461d692514e3b27742dbfd477491fc24e --- /dev/null +++ b/convlab/policy/genTUS/train_model.py @@ -0,0 +1,258 @@ +import json +import os +import sys +from argparse import ArgumentParser +from datetime import datetime +from pprint import pprint +import numpy as np +import torch +import transformers +from datasets import Dataset, load_metric +from tqdm import tqdm +from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer, + BartForConditionalGeneration, BartTokenizer, + DataCollatorForSeq2Seq, Seq2SeqTrainer, + Seq2SeqTrainingArguments) + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + +os.environ["WANDB_DISABLED"] = "true" + +METRIC = load_metric("sacrebleu") +TOKENIZER = BartTokenizer.from_pretrained("facebook/bart-base") +TOKENIZER.add_tokens(["<?>"]) +MAX_IN_LEN = 500 +MAX_OUT_LEN = 500 + + +def arg_parser(): + parser = ArgumentParser() + # data_name, dial_ids_order, split2ratio + parser.add_argument("--model-type", type=str, default="unify", + help="unify or multiwoz") + parser.add_argument("--data-name", type=str, default="multiwoz21", + help="multiwoz21, sgd, tm1, tm2, tm3, sgd+tm, or all") + parser.add_argument("--dial-ids-order", type=int, default=0) + parser.add_argument("--split2ratio", type=float, default=1) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--model-checkpoint", type=str, + default="facebook/bart-base") + return parser.parse_args() + + +def gentus_compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = TOKENIZER.batch_decode( + preds, skip_special_tokens=True, max_length=MAX_OUT_LEN) + + # Replace -100 
in the labels as we can't decode them. + labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id) + decoded_labels = TOKENIZER.batch_decode( + labels, skip_special_tokens=True, max_length=MAX_OUT_LEN) + + act, text = postprocess_text(decoded_preds, decoded_labels) + + result = METRIC.compute( + # predictions=decoded_preds, references=decoded_labels) + predictions=text["preds"], references=text["labels"]) + result = {"bleu": result["score"]} + f1_scores = f1_measure(pred_acts=act["preds"], label_acts=act["labels"]) + for s in f1_scores: + result[s] = f1_scores[s] + + result = {k: round(v, 4) for k, v in result.items()} + return result + + +def postprocess_text(preds, labels): + act = {"preds": [], "labels": []} + text = {"preds": [], "labels": []} + + for pred, label in zip(preds, labels): + model_output = parse_output(pred.strip()) + label_output = parse_output(label.strip()) + if len(label_output["text"]) < 1: + continue + act["preds"].append(model_output.get("action", [])) + text["preds"].append(model_output.get("text", pred.strip())) + act["labels"].append(label_output["action"]) + text["labels"].append([label_output["text"]]) + + return act, text + + +def parse_output(in_str): + in_str = in_str.replace('<s>', '').replace('<\\s>', '') + try: + output = json.loads(in_str) + except: + # print(f"invalid action {in_str}") + output = {"action": [], "text": ""} + return output + + +def f1_measure(pred_acts, label_acts): + result = {"precision": [], "recall": [], "f1": []} + for pred, label in zip(pred_acts, label_acts): + r = tp_fn_fp(pred, label) + for m in result: + result[m].append(r[m]) + for m in result: + result[m] = sum(result[m])/len(result[m]) + + return result + + +def tp_fn_fp(pred, label): + tp, fn, fp = 0.0, 0.0, 0.0 + precision, recall, f1 = 0, 0, 0 + for p in pred: + if p in label: + tp += 1 + else: + fp += 1 + for l in label: + if l not in pred: + fn += 1 + if (tp+fp) > 0: + precision = tp / (tp+fp) + if (tp+fn) > 0: + recall = tp/(tp+fn) + if (precision + recall) > 0: + f1 = (2*precision*recall)/(precision+recall) + + return {"precision": precision, "recall": recall, "f1": f1} + + +class TrainerHelper: + def __init__(self, tokenizer, max_input_length=500, max_target_length=500): + print("transformers version is: ", transformers.__version__) + self.tokenizer = tokenizer + self.max_input_length = max_input_length + self.max_target_length = max_target_length + self.base_name = "convlab/policy/genTUS" + self.dir_name = "" + + def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1): + # base_name = "convlab/policy/genTUS/unify/data" + if model_type not in ["unify", "multiwoz"]: + print("Unknown model type. 
Currently only support unify and multiwoz") + self.dir_name = f"{data_name}_{dial_ids_order}_{split2ratio}" + return os.path.join(self.base_name, model_type, 'data', self.dir_name) + + def get_model_folder(self, model_type): + folder_name = os.path.join( + self.base_name, model_type, "experiments", self.dir_name) + if not os.path.exists(folder_name): + os.makedirs(folder_name) + return folder_name + + def parse_data(self, model_type, data_name, dial_ids_order=0, split2ratio=1): + data_folder = self._get_data_folder( + model_type, data_name, dial_ids_order, split2ratio) + + raw_data = {} + for d_type in ["train", "validation", "test"]: + f_name = os.path.join(data_folder, f"{d_type}.json") + raw_data[d_type] = json.load(open(f_name)) + + tokenized_datasets = {} + for data_type, data in raw_data.items(): + tokenized_datasets[data_type] = Dataset.from_dict( + self._preprocess(data["dialog"])) + + return tokenized_datasets + + def _preprocess(self, examples): + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + if isinstance(examples, dict): + examples = [examples] + for example in tqdm(examples): + inputs = self.tokenizer(example["in"], + max_length=self.max_input_length, + truncation=True) + + # Setup the tokenizer for targets + with self.tokenizer.as_target_tokenizer(): + labels = self.tokenizer(example["out"], + max_length=self.max_target_length, + truncation=True) + for key in ["input_ids", "attention_mask"]: + model_inputs[key].append(inputs[key]) + model_inputs["labels"].append(labels["input_ids"]) + + return model_inputs + + +def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"): + tokenizer = TOKENIZER + + train_helper = TrainerHelper( + tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length) + data = train_helper.parse_data(model_type=model_type, + data_name=data_name, + dial_ids_order=dial_ids_order, + split2ratio=split2ratio) + + model = BartForConditionalGeneration.from_pretrained(model_checkpoint) + model.resize_token_embeddings(len(tokenizer)) + fp16 = False + if torch.cuda.is_available(): + fp16 = True + + model_dir = os.path.join( + train_helper.get_model_folder(model_type), + f"{datetime.now().strftime('%y-%m-%d-%H-%M')}") + + args = Seq2SeqTrainingArguments( + model_dir, + evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + weight_decay=0.01, + save_total_limit=2, + num_train_epochs=5, + predict_with_generate=True, + fp16=fp16, + push_to_hub=False, + generation_max_length=max_target_length, + logging_dir=os.path.join(model_dir, 'log') + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, model=model, padding=True) + + # customize this trainer + trainer = Seq2SeqTrainer( + model=model, + args=args, + train_dataset=data["train"], + eval_dataset=data["test"], + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=gentus_compute_metrics) + print("start training...") + trainer.train() + print("saving model...") + trainer.save_model() + + +def main(): + args = arg_parser() + print("---> data_name", args.data_name) + train(model_type=args.model_type, + data_name=args.data_name, + dial_ids_order=args.dial_ids_order, + split2ratio=args.split2ratio, + batch_size=args.batch_size, + max_input_length=MAX_IN_LEN, + max_target_length=MAX_OUT_LEN, + model_checkpoint=args.model_checkpoint) + + +if __name__ == "__main__": + 
main() + # sgd+tm: 46000 diff --git a/convlab/policy/genTUS/unify/Goal.py b/convlab/policy/genTUS/unify/Goal.py new file mode 100644 index 0000000000000000000000000000000000000000..00f19338fc2f7ebc521bda7e1a8f877a6417a57e --- /dev/null +++ b/convlab/policy/genTUS/unify/Goal.py @@ -0,0 +1,233 @@ +""" +The user goal for unify data format +""" +import json +from convlab.policy.tus.unify.Goal import old_goal2list +from convlab.task.multiwoz.goal_generator import GoalGenerator +from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal +from convlab.util.custom_util import slot_mapping +DEF_VAL_UNK = '?' # Unknown +DEF_VAL_DNC = 'dontcare' # Do not care +DEF_VAL_NUL = 'none' # for none +NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, ""] + +NOT_MENTIONED = "not mentioned" +FULFILLED = "fulfilled" +REQUESTED = "requested" +CONFLICT = "conflict" + + +class Goal: + """ User Goal Model Class. """ + + def __init__(self, goal=None, goal_generator=None): + """ + create new Goal from a dialog or from goal_generator + Args: + goal: can be a list (create from a dialog), an abus goal, or none + """ + self.domains = [] + self.domain_goals = {} + self.status = {} + self.invert_slot_mapping = {v: k for k, v in slot_mapping.items()} + self.raw_goal = None + + self._init_goal_from_data(goal, goal_generator) + self._init_status() + + def __str__(self): + return '-----Goal-----\n' + \ + json.dumps(self.domain_goals, indent=4) + \ + '\n-----Goal-----' + + def _init_goal_from_data(self, goal=None, goal_generator=None): + if not goal and goal_generator: + goal = ABUS_Goal(goal_generator) + self.raw_goal = goal.domain_goals + goal = old_goal2list(goal.domain_goals) + + elif isinstance(goal, dict): + self.raw_goal = goal + goal = old_goal2list(goal) + + elif isinstance(goal, ABUS_Goal): + self.raw_goal = goal.domain_goals + goal = old_goal2list(goal.domain_goals) + + else: + print("unknow goal") + + # be careful of this order + for domain, intent, slot, value in goal: + if domain == "none": + continue + if domain not in self.domains: + self.domains.append(domain) + self.domain_goals[domain] = {} + if intent not in self.domain_goals[domain]: + self.domain_goals[domain][intent] = {} + + if not value: + if intent == "request": + self.domain_goals[domain][intent][slot] = DEF_VAL_UNK + else: + print( + f"unknown no value intent {domain}, {intent}, {slot}") + else: + self.domain_goals[domain][intent][slot] = value + + def _init_status(self): + for domain, domain_goal in self.domain_goals.items(): + if domain not in self.status: + self.status[domain] = {} + for slot_type, sub_domain_goal in domain_goal.items(): + if slot_type not in self.status[domain]: + self.status[domain][slot_type] = {} + for slot in sub_domain_goal: + if slot not in self.status[domain][slot_type]: + self.status[domain][slot_type][slot] = {} + self.status[domain][slot_type][slot] = { + "value": str(sub_domain_goal[slot]), + "status": NOT_MENTIONED} + + def get_goal_list(self, data_goal=None): + goal_list = [] + if data_goal: + # make sure the order!!! 
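+ # each entry is [intent, domain, slot, value, status], where status is one of
+ # NOT_MENTIONED, FULFILLED, REQUESTED or CONFLICT; iterating over data_goal keeps
+ # the slot order identical to the order used when the training data was built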
+ for domain, intent, slot, _ in data_goal: + status = self._get_status(domain, intent, slot) + value = self.domain_goals[domain][intent][slot] + goal_list.append([intent, domain, slot, value, status]) + return goal_list + else: + for domain, domain_goal in self.domain_goals.items(): + for intent, sub_goal in domain_goal.items(): + for slot, value in sub_goal.items(): + status = self._get_status(domain, intent, slot) + goal_list.append([intent, domain, slot, value, status]) + + return goal_list + + def _get_status(self, domain, intent, slot): + if domain not in self.status: + return NOT_MENTIONED + if intent not in self.status[domain]: + return NOT_MENTIONED + if slot not in self.status[domain][intent]: + return NOT_MENTIONED + return self.status[domain][intent][slot]["status"] + + def task_complete(self): + """ + Check that all requests have been met + Returns: + (boolean): True to accomplish. + """ + for domain, domain_goal in self.status.items(): + if domain not in self.domain_goals: + continue + for slot_type, sub_domain_goal in domain_goal.items(): + if slot_type not in self.domain_goals[domain]: + continue + for slot, status in sub_domain_goal.items(): + if slot not in self.domain_goals[domain][slot_type]: + continue + # for strict success, turn this on + if status["status"] in [NOT_MENTIONED, CONFLICT]: + if status["status"] == CONFLICT and slot in ["arrive by", "leave at"]: + continue + return False + if "?" in status["value"]: + return False + + return True + + # TODO change to update()? + def update_user_goal(self, action, char="usr"): + # update request and booked + if char == "usr": + self._user_action_update(action) + elif char == "sys": + self._system_action_update(action) + else: + print("!!!UNKNOWN CHAR!!!") + + def _user_action_update(self, action): + # no need to update user goal + for intent, domain, slot, _ in action: + goal_intent = self._check_slot_and_intent(domain, slot) + if not goal_intent: + continue + # fulfilled by user + if is_inform(intent): + self._set_status(goal_intent, domain, slot, FULFILLED) + # requested by user + if is_request(intent): + self._set_status(goal_intent, domain, slot, REQUESTED) + + def _system_action_update(self, action): + for intent, domain, slot, value in action: + goal_intent = self._check_slot_and_intent(domain, slot) + if not goal_intent: + continue + # fulfill request by system + if is_inform(intent) and is_request(goal_intent): + self._set_status(goal_intent, domain, slot, FULFILLED) + self._set_goal(goal_intent, domain, slot, value) + + if is_inform(intent) and is_inform(goal_intent): + # fulfill infrom by system + if value == self.domain_goals[domain][goal_intent][slot]: + self._set_status(goal_intent, domain, slot, FULFILLED) + # conflict system inform + else: + self._set_status(goal_intent, domain, slot, CONFLICT) + # requested by system + if is_request(intent) and is_inform(goal_intent): + self._set_status(goal_intent, domain, slot, REQUESTED) + + def _set_status(self, intent, domain, slot, status): + self.status[domain][intent][slot]["status"] = status + + def _set_goal(self, intent, domain, slot, value): + # old_value = self.domain_goals[domain][intent][slot] + self.domain_goals[domain][intent][slot] = value + self.status[domain][intent][slot]["value"] = value + # print( + # f"updating user goal {intent}-{domain}-{slot} {old_value}-> {value}") + + def _check_slot_and_intent(self, domain, slot): + not_found = "" + if domain not in self.domain_goals: + return not_found + for intent in self.domain_goals[domain]: + if slot in 
self.domain_goals[domain][intent]: + return intent + return not_found + + +def is_inform(intent): + if "inform" in intent: + return True + return False + + +def is_request(intent): + if "request" in intent: + return True + return False + + +def transform_data_act(data_action): + action_list = [] + for _, dialog_act in data_action.items(): + for act in dialog_act: + value = act.get("value", "") + if not value: + if "request" in act["intent"]: + value = "?" + else: + value = "none" + action_list.append( + [act["intent"], act["domain"], act["slot"], value]) + return action_list diff --git a/convlab/policy/genTUS/unify/build_data.py b/convlab/policy/genTUS/unify/build_data.py new file mode 100644 index 0000000000000000000000000000000000000000..50873a1d4b6ffaa8ee49c84a4b088ae56ad13554 --- /dev/null +++ b/convlab/policy/genTUS/unify/build_data.py @@ -0,0 +1,211 @@ +import json +import os +import sys +from argparse import ArgumentParser + +from tqdm import tqdm + +from convlab.policy.genTUS.unify.Goal import Goal, transform_data_act +from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset + + +sys.path.append(os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__))))) + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--dataset", type=str, default="multiwoz21", + help="the dataset, such as multiwoz21, sgd, tm1, tm2, and tm3.") + parser.add_argument("--dial-ids-order", type=int, default=0) + parser.add_argument("--split2ratio", type=float, default=1) + parser.add_argument("--random-order", action="store_true") + parser.add_argument("--no-status", action="store_true") + parser.add_argument("--add-history", action="store_true") + parser.add_argument("--remove-domain", type=str, default="") + + return parser.parse_args() + +class DataBuilder: + def __init__(self, dataset='multiwoz21'): + self.dataset = dataset + + def setup_data(self, + raw_data, + random_order=False, + no_status=False, + add_history=False, + remove_domain=None): + examples = {data_split: {"dialog": []} for data_split in raw_data} + + for data_split, dialogs in raw_data.items(): + for dialog in tqdm(dialogs, ascii=True): + example = self._one_dialog(dialog=dialog, + add_history=add_history, + random_order=random_order, + no_status=no_status) + examples[data_split]["dialog"] += example + + return examples + + def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False): + example = [] + history = [] + + data_goal = self.norm_domain_goal(create_goal(dialog)) + if not data_goal: + return example + user_goal = Goal(goal=data_goal) + + for turn_id in range(0, len(dialog["turns"]), 2): + sys_act = self._get_sys_act(dialog, turn_id) + + user_goal.update_user_goal(action=sys_act, char="sys") + usr_goal_str = self._user_goal_str(user_goal, data_goal, random_order, no_status) + + usr_act = self.norm_domain(transform_data_act( + dialog["turns"][turn_id]["dialogue_acts"])) + user_goal.update_user_goal(action=usr_act, char="usr") + + # change value "?" 
to "<?>" + usr_act = self._modify_act(usr_act) + + in_str = self._dump_in_str(sys_act, usr_goal_str, history, turn_id, add_history) + out_str = self._dump_out_str(usr_act, dialog["turns"][turn_id]["utterance"]) + + history.append(usr_act) + if usr_act: + example.append({"in": in_str, "out": out_str}) + + return example + + def _get_sys_act(self, dialog, turn_id): + sys_act = [] + if turn_id > 0: + sys_act = self.norm_domain(transform_data_act( + dialog["turns"][turn_id - 1]["dialogue_acts"])) + return sys_act + + def _user_goal_str(self, user_goal, data_goal, random_order, no_status): + if random_order: + usr_goal_str = user_goal.get_goal_list() + else: + usr_goal_str = user_goal.get_goal_list(data_goal=data_goal) + + if no_status: + usr_goal_str = self._remove_status(usr_goal_str) + return usr_goal_str + + def _dump_in_str(self, sys_act, usr_goal_str, history, turn_id, add_history): + in_str = {} + in_str["system"] = self._modify_act(sys_act) + in_str["goal"] = usr_goal_str + if add_history: + h = [] + if history: + h = history[-3:] + in_str["history"] = h + in_str["turn"] = str(int(turn_id/2)) + + return json.dumps(in_str) + + def _dump_out_str(self, usr_act, text): + out_str = {"action": usr_act, "text": text} + return json.dumps(out_str) + + @staticmethod + def _norm_intent(intent): + if intent in ["inform_intent", "negate_intent", "affirm_intent", "request_alts"]: + return f"_{intent}" + return intent + + def norm_domain(self, x): + if not x: + return x + norm_result = [] + # print(x) + for intent, domain, slot, value in x: + if "_" in domain: + domain = domain.split('_')[0] + if not domain: + domain = "none" + if not slot: + slot = "none" + if not value: + if intent == "request": + value = "<?>" + else: + value = "none" + norm_result.append([self._norm_intent(intent), domain, slot, value]) + return norm_result + + def norm_domain_goal(self, x): + if not x: + return x + norm_result = [] + # take care of the order! 
+ for domain, intent, slot, value in x: + if "_" in domain: + domain = domain.split('_')[0] + if not domain: + domain = "none" + if not slot: + slot = "none" + if not value: + if intent == "request": + value = "<?>" + else: + value = "none" + norm_result.append([domain, self._norm_intent(intent), slot, value]) + return norm_result + + @staticmethod + def _remove_status(goal_list): + new_list = [[goal[0], goal[1], goal[2], goal[3]] + for goal in goal_list] + return new_list + + @staticmethod + def _modify_act(act): + new_act = [] + for i, d, s, value in act: + if value == "?": + new_act.append([i, d, s, "<?>"]) + else: + new_act.append([i, d, s, value]) + return new_act + + +if __name__ == "__main__": + args = arg_parser() + + base_name = "convlab/policy/genTUS/unify/data" + dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}" + folder_name = os.path.join(base_name, dir_name) + remove_domain = args.remove_domain + + if not os.path.exists(folder_name): + os.makedirs(folder_name) + + dataset = load_experiment_dataset( + data_name=args.dataset, + dial_ids_order=args.dial_ids_order, + split2ratio=args.split2ratio) + data_builder = DataBuilder(dataset=args.dataset) + data = data_builder.setup_data( + raw_data=dataset, + random_order=args.random_order, + no_status=args.no_status, + add_history=args.add_history, + remove_domain=remove_domain) + + for data_type in data: + if remove_domain: + file_name = os.path.join( + folder_name, + f"no{remove_domain}_{data_type}.json") + else: + file_name = os.path.join( + folder_name, + f"{data_type}.json") + json.dump(data[data_type], open(file_name, 'w'), indent=2) diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..68af13e481fe4799dfc2a6f3763b526611eabd9c --- /dev/null +++ b/convlab/policy/genTUS/unify/knowledge_graph.py @@ -0,0 +1,252 @@ +import json +from random import choices + +from convlab.policy.genTUS.token_map import tokenMap + +from transformers import BartTokenizer + +DEBUG = False +DATASET = "unify" + + +class KnowledgeGraph: + def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="multiwoz21"): + print("dataset", dataset) + self.debug = DEBUG + self.tokenizer = tokenizer + + if "multiwoz" in dataset: + self.domain_intent = ["inform", "request"] + self.general_intent = ["thank", "bye"] + # use sgd dataset intents as default + else: + self.domain_intent = ["_inform_intent", + "_negate_intent", + "_affirm_intent", + "inform", + "request", + "affirm", + "negate", + "select", + "_request_alts"] + self.general_intent = ["thank_you", "goodbye"] + + self.general_domain = "none" + self.kg_map = {"intent": tokenMap(tokenizer=self.tokenizer)} + + for intent in self.domain_intent + self.general_intent: + self.kg_map["intent"].add_token(intent, intent) + + self.init() + + def init(self): + for map_type in ["domain", "slot", "value"]: + self.kg_map[map_type] = tokenMap(tokenizer=self.tokenizer) + self.add_token("<?>", "value") + + def parse_input(self, in_str): + self.init() + inputs = json.loads(in_str) + self.sys_act = inputs["system"] + self.user_goal = {} + self._add_none_domain() + for intent, domain, slot, value, _ in inputs["goal"]: + self._update_user_goal(intent, domain, slot, value, source="goal") + + for intent, domain, slot, value in self.sys_act: + self._update_user_goal(intent, domain, slot, value, source="sys") + + def _add_none_domain(self): + self.user_goal["none"] = {"none": "none"} 
+ # add slot + self.add_token("none", "domain") + self.add_token("none", "slot") + self.add_token("none", "value") + + def _update_user_goal(self, intent, domain, slot, value, source="goal"): + + if value == "?": + value = "<?>" + + if intent == "request" and source == "sys": + value = "dontcare" # user can "dontcare" system request + + if source == "sys" and intent != "request": + return + + if domain not in self.user_goal: + self.user_goal[domain] = {} + self.user_goal[domain]["none"] = ["none"] + self.add_token(domain, "domain") + self.add_token("none", "slot") + self.add_token("none", "value") + + if slot not in self.user_goal[domain]: + self.user_goal[domain][slot] = [] + self.add_token(domain, "slot") + + if value not in self.user_goal[domain][slot]: + value = json.dumps(str(value))[1:-1] + self.user_goal[domain][slot].append(value) + value = value.replace('"', "'") + self.add_token(value, "value") + + def add_token(self, token_name, map_type): + if map_type == "value": + token_name = token_name.replace('"', "'") + if not self.kg_map[map_type].token_name_is_in(token_name): + self.kg_map[map_type].add_token(token_name, token_name) + + def _get_max_score(self, outputs, candidate_list, map_type): + score = {} + if not candidate_list: + print(f"ERROR: empty candidate list for {map_type}") + score[1] = {"token_id": self._get_token_id( + "none"), "token_name": "none"} + + for x in candidate_list: + hash_id = self._get_token_id(x)[0] + s = outputs[:, hash_id].item() + score[s] = {"token_id": self._get_token_id(x), + "token_name": x} + return score + + def _select(self, score, mode="max"): + probs = [s for s in score] + if mode == "max": + s = max(probs) + elif mode == "sample": + s = choices(probs, weights=probs, k=1) + s = s[0] + + else: + print("unknown select mode") + + return s + + def _get_max_domain_token(self, outputs, candidates, map_type, mode="max"): + score = self._get_max_score(outputs, candidates, map_type) + s = self._select(score, mode) + token_id = score[s]["token_id"] + token_name = score[s]["token_name"] + + return {"token_id": token_id, "token_name": token_name} + + def candidate(self, candidate_type, **kwargs): + if "intent" in kwargs: + intent = kwargs["intent"] + if candidate_type == "intent": + allow_general_intent = kwargs.get("allow_general_intent", True) + if allow_general_intent: + return self.domain_intent + self.general_intent + else: + return self.domain_intent + elif candidate_type == "domain": + if intent in self.general_intent: + return [self.general_domain] + else: + return [d for d in self.user_goal] + elif candidate_type == "slot": + if intent in self.general_intent: + return ["none"] + else: + return self._filter_slot(intent, kwargs["domain"], kwargs["is_mentioned"]) + else: + if intent in self.general_intent: + return ["none"] + elif intent.lower() == "request": + return ["<?>"] + else: + return self._filter_value(intent, kwargs["domain"], kwargs["slot"]) + + def get_intent(self, outputs, mode="max", allow_general_intent=True): + # return intent, token_id_list + # TODO request? 
+ canidate_list = self.candidate( + "intent", allow_general_intent=allow_general_intent) + score = self._get_max_score(outputs, canidate_list, "intent") + s = self._select(score, mode) + + return score[s] + + def get_domain(self, outputs, intent, mode="max"): + if intent in self.general_intent: + token_name = self.general_domain + token_id = self.tokenizer(token_name, add_special_tokens=False) + token_map = {"token_id": token_id['input_ids'], + "token_name": token_name} + + elif intent in self.domain_intent: + # [d for d in self.user_goal] + domain_list = self.candidate("domain", intent=intent) + token_map = self._get_max_domain_token( + outputs=outputs, candidates=domain_list, map_type="domain", mode=mode) + else: + if self.debug: + print("unknown intent", intent) + + return token_map + + def get_slot(self, outputs, intent, domain, mode="max", is_mentioned=False): + if intent in self.general_intent: + token_name = "none" + token_id = self.tokenizer(token_name, add_special_tokens=False) + token_map = {"token_id": token_id['input_ids'], + "token_name": token_name} + + elif intent in self.domain_intent: + slot_list = self.candidate( + candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned) + token_map = self._get_max_domain_token( + outputs=outputs, candidates=slot_list, map_type="slot", mode=mode) + + return token_map + + def get_value(self, outputs, intent, domain, slot, mode="max"): + if intent in self.general_intent or slot.lower() == "none": + token_name = "none" + token_id = self.tokenizer(token_name, add_special_tokens=False) + token_map = {"token_id": token_id['input_ids'], + "token_name": token_name} + + elif intent.lower() == "request": + token_name = "<?>" + token_id = self.tokenizer(token_name, add_special_tokens=False) + token_map = {"token_id": token_id['input_ids'], + "token_name": token_name} + + elif intent in self.domain_intent: + # TODO should not none ? + # value_list = [v for v in self.user_goal[domain][slot]] + value_list = self.candidate( + candidate_type="value", intent=intent, domain=domain, slot=slot) + + token_map = self._get_max_domain_token( + outputs=outputs, candidates=value_list, map_type="value", mode=mode) + + return token_map + + def _filter_slot(self, intent, domain, is_mentioned=True): + slot_list = [] + for slot in self.user_goal[domain]: + value_list = self._filter_value(intent, domain, slot) + if len(value_list) > 0: + slot_list.append(slot) + if not is_mentioned and intent.lower() != "request": + slot_list.append("none") + return slot_list + + def _filter_value(self, intent, domain, slot): + value_list = [v for v in self.user_goal[domain][slot]] + if "none" in value_list: + value_list.remove("none") + if intent.lower() != "request": + if "?" 
in value_list: + value_list.remove("?") + if "<?>" in value_list: + value_list.remove("<?>") + # print(f"{intent}-{domain}-{slot}= {value_list}") + return value_list + + def _get_token_id(self, token): + return self.tokenizer(token, add_special_tokens=False)["input_ids"] diff --git a/convlab/policy/genTUS/utils.py b/convlab/policy/genTUS/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39c822dd35f985790d53a2834e8b6fe437864f24 --- /dev/null +++ b/convlab/policy/genTUS/utils.py @@ -0,0 +1,5 @@ +import torch + + +def append_tokens(tokens, new_token, device): + return torch.cat((tokens, torch.tensor([new_token]).to(device)), dim=1) diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json new file mode 100644 index 0000000000000000000000000000000000000000..5bf65c9f97f951a367a0abf461e9aa9172d64021 --- /dev/null +++ b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json @@ -0,0 +1,44 @@ +{ + "model": { + "load_path": "convlab/policy/ppo/pretrained_models/supervised", + "pretrained_load_path": "", + "use_pretrained_initialisation": false, + "batchsz": 1000, + "seed": 0, + "epoch": 50, + "eval_frequency": 5, + "process_num": 1, + "num_eval_dialogues": 500, + "sys_semantic_to_usr": false + }, + "vectorizer_sys": { + "uncertainty_vector_mul": { + "class_path": "convlab.policy.vector.vector_binary.VectorBinary", + "ini_params": { + "use_masking": true, + "manually_add_entity_names": true, + "seed": 0 + } + } + }, + "nlu_sys": {}, + "dst_sys": { + "RuleDST": { + "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST", + "ini_params": {} + } + }, + "sys_nlg": {}, + "nlu_usr": {}, + "dst_usr": {}, + "policy_usr": { + "RulePolicy": { + "class_path": "convlab.policy.genTUS.stepGenTUS.UserPolicy", + "ini_params": { + "model_checkpoint": "convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0", + "character": "usr" + } + } + }, + "usr_nlg": {} +} \ No newline at end of file diff --git a/convlab/policy/tus/unify/Goal.py b/convlab/policy/tus/unify/Goal.py index 82a6e7a9799b4537495cf35e57f6f52711e78af8..610469bb4246cf6c16ef4271c89827cab6421437 100644 --- a/convlab/policy/tus/unify/Goal.py +++ b/convlab/policy/tus/unify/Goal.py @@ -1,6 +1,9 @@ import time import json -from convlab.policy.tus.unify.util import split_slot_name +from convlab.policy.tus.unify.util import split_slot_name, slot_name_map +from convlab.util.custom_util import slot_mapping + +from random import sample, shuffle from pprint import pprint DEF_VAL_UNK = '?' # Unknown DEF_VAL_DNC = 'dontcare' # Do not care @@ -27,6 +30,35 @@ def isTimeFormat(input): return False +def old_goal2list(goal: dict, reorder=False) -> list: + goal_list = [] + for domain in goal: + for slot_type in ['info', 'book', 'reqt']: + if slot_type not in goal[domain]: + continue + temp = [] + for slot in goal[domain][slot_type]: + s = slot + if slot in slot_name_map: + s = slot_name_map[slot] + elif slot in slot_name_map[domain]: + s = slot_name_map[domain][slot] + # domain, intent, slot, value + if slot_type in ['info', 'book']: + i = "inform" + v = goal[domain][slot_type][slot] + else: + i = "request" + v = DEF_VAL_UNK + s = slot_mapping.get(s, s) + temp.append([domain, i, s, v]) + shuffle(temp) + goal_list = goal_list + temp + # shuffle_goal = goal_list[:1] + sample(goal_list[1:], len(goal_list)-1) + # return shuffle_goal + return goal_list + + class Goal(object): """ User Goal Model Class. 
""" @@ -101,6 +133,7 @@ class Goal(object): if self.domain_goals[domain]["reqt"][slot] == DEF_VAL_UNK: # print(f"not fulfilled request{domain}-{slot}") return False + return True def init_local_id(self): @@ -169,6 +202,10 @@ class Goal(object): def _update_status(self, action: list, char: str): for intent, domain, slot, value in action: + if slot == "arrive by": + slot = "arriveBy" + elif slot == "leave at": + slot = "leaveAt" if domain not in self.status: self.status[domain] = {} # update info @@ -180,6 +217,10 @@ class Goal(object): def _update_goal(self, action: list, char: str): # update requt slots in goal for intent, domain, slot, value in action: + if slot == "arrive by": + slot = "arriveBy" + elif slot == "leave at": + slot = "leaveAt" if "info" not in intent: continue if self._check_update_request(domain, slot) and value != "?": diff --git a/convlab/policy/tus/unify/TUS.py b/convlab/policy/tus/unify/TUS.py index 09c50672fb889ffd3e965ec0e90d2dee30b2c360..c380df3914d5d636c95b0e76081b361c26d79ff3 100644 --- a/convlab/policy/tus/unify/TUS.py +++ b/convlab/policy/tus/unify/TUS.py @@ -5,19 +5,18 @@ from copy import deepcopy import torch from convlab.policy.policy import Policy -from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import ( - act_dict_to_flat_tuple, unified_format) from convlab.policy.tus.multiwoz.transformer import TransformerActionPrediction from convlab.policy.tus.unify.Goal import Goal from convlab.policy.tus.unify.usermanager import BinaryFeature -from convlab.policy.tus.unify.util import (create_goal, int2onehot, - metadata2state, parse_dialogue_act, - parse_user_goal, split_slot_name) +from convlab.policy.tus.unify.util import create_goal, split_slot_name from convlab.util import (load_dataset, relative_import_module_from_unified_datasets) from convlab.util.custom_util import model_downloader -from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA -from pprint import pprint +from convlab.task.multiwoz.goal_generator import GoalGenerator +from convlab.policy.tus.unify.Goal import old_goal2list +from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal + + reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets( 'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value']) @@ -51,7 +50,7 @@ class UserActionPolicy(Policy): self.user = TransformerActionPrediction(self.config).to(device=DEVICE) if pretrain: model_path = os.path.join( - self.config["model_dir"], self.config["model_name"]) + self.config["model_dir"], "model-non-zero") # self.config["model_name"]) print(f"loading model from {model_path}...") self.load(model_path) self.user.eval() @@ -61,6 +60,8 @@ class UserActionPolicy(Policy): self.reward = {"success": 40, "fail": -20} self.sys_acts = [] + self.goal_gen = GoalGenerator() + self.raw_goal = None def _no_offer(self, system_in): for intent, domain, slot, value in system_in: @@ -127,13 +128,15 @@ class UserActionPolicy(Policy): self.topic = 'NONE' remove_domain = "police" # remove police domain in inference - # if not goal: - # self.new_goal(remove_domain=remove_domain) - # else: - # self.read_goal(goal) - if not goal: - data = load_dataset(self.dataset, 0) - goal = Goal(create_goal(data["test"][0])) + if type(goal) == ABUS_Goal: + self.raw_goal = goal.domain_goals + goal_list = old_goal2list(goal.domain_goals) + goal = Goal(goal_list) + else: + goal = ABUS_Goal(self.goal_gen) + self.raw_gaol = goal.domain_goals + goal_list = old_goal2list(goal.domain_goals) + goal = 
Goal(goal_list) self.read_goal(goal) self.feat_handler.initFeatureHandeler(self.goal) @@ -155,15 +158,15 @@ class UserActionPolicy(Policy): else: self.goal = Goal(goal=data_goal) - def new_goal(self, remove_domain="police", domain_len=None): - keep_generate_goal = True - while keep_generate_goal: - self.goal = Goal(goal_generator=self.goal_gen) - if (domain_len and len(self.goal.domains) != domain_len) or \ - (remove_domain and remove_domain in self.goal.domains): - keep_generate_goal = True - else: - keep_generate_goal = False + # def new_goal(self, remove_domain="police", domain_len=None): + # keep_generate_goal = True + # while keep_generate_goal: + # self.goal = Goal(goal_generator=self.goal_gen) + # if (domain_len and len(self.goal.domains) != domain_len) or \ + # (remove_domain and remove_domain in self.goal.domains): + # keep_generate_goal = True + # else: + # keep_generate_goal = False def load(self, model_path=None): self.user.load_state_dict(torch.load(model_path, map_location=DEVICE)) @@ -171,9 +174,14 @@ class UserActionPolicy(Policy): def load_state_dict(self, model=None): self.user.load_state_dict(model) - def get_goal(self): + def _get_goal(self): + # internal usage return self.goal.domain_goals + def get_goal(self): + # for outside usage, e.g. evaluator + return self.raw_goal + def get_reward(self): if self.goal.task_complete(): # reward = 2 * self.max_turn @@ -330,7 +338,7 @@ class UserActionPolicy(Policy): return predict_domain def _add_user_action(self, output, domain, slot): - goal = self.get_goal() + goal = self._get_goal() is_action = False act = [[]] value = None @@ -403,16 +411,32 @@ class UserPolicy(Policy): self.config = json.load(open(config)) else: self.config = config - self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}' + self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz' if not os.path.exists(self.config["model_dir"]): # os.mkdir(self.config["model_dir"]) model_downloader(os.path.dirname(self.config["model_dir"]), "https://zenodo.org/record/5779832/files/default.zip") + self.slot2dbattr = { + 'open hours': 'openhours', + 'price range': 'pricerange', + 'arrive by': 'arriveBy', + 'leave at': 'leaveAt', + 'train id': 'trainID' + } + self.dbattr2slot = {} + for k, v in self.slot2dbattr.items(): + self.dbattr2slot[v] = k self.policy = UserActionPolicy(self.config) def predict(self, state): - return self.policy.predict(state) + raw_act = self.policy.predict(state) + act = [] + for intent, domain, slot, value in raw_act: + if slot in self.dbattr2slot: + slot = self.dbattr2slot[slot] + act.append([intent, domain, slot, value]) + return act def init_session(self, goal=None): self.policy.init_session(goal) @@ -424,14 +448,8 @@ class UserPolicy(Policy): return self.policy.get_reward() def get_goal(self): - slot2dbattr = { - 'open hours': 'openhours', - 'price range': 'pricerange', - 'arrive by': 'arriveBy', - 'leave at': 'leaveAt', - 'train id': 'trainID' - } if hasattr(self.policy, 'get_goal'): + return self.policy.get_goal() # workaround: convert goal to old format multiwoz_goal = {} goal = self.policy.get_goal() @@ -449,8 +467,8 @@ class UserPolicy(Policy): multiwoz_goal[domain]["book"] = {} norm_slot = slot.split(' ')[-1] multiwoz_goal[domain]["book"][norm_slot] = value - elif slot in slot2dbattr: - norm_slot = slot2dbattr[slot] + elif slot in self.slot2dbattr: + norm_slot = self.slot2dbattr[slot] multiwoz_goal[domain][slot_type][norm_slot] = value else: multiwoz_goal[domain][slot_type][slot] = value diff --git 
a/convlab/policy/tus/unify/usermanager.py b/convlab/policy/tus/unify/usermanager.py index 640da7b9ee01414f04700e6aaa2976f9c916914f..3192d3d578ca3698424b16f47933d6863208f77c 100644 --- a/convlab/policy/tus/unify/usermanager.py +++ b/convlab/policy/tus/unify/usermanager.py @@ -97,7 +97,7 @@ class TUSDataManager(Dataset): action_list, user_goal, cur_state, usr_act) domain_label = self.feature_handler.domain_label( user_goal, usr_act) - pre_state = user_goal.update(action=usr_act, char="user") + # pre_state = user_goal.update(action=usr_act, char="user") # trick? feature["id"].append(dialog["dialogue_id"]) feature["input"].append(input_feature) feature["mask"].append(mask) diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py index 1978f489534f74839b11dee333e232cbc601f270..d65f72a06e181e66bfe0d7ac0c60f0c03a56ad43 100644 --- a/convlab/policy/tus/unify/util.py +++ b/convlab/policy/tus/unify/util.py @@ -1,9 +1,49 @@ from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal +from convlab.util import load_dataset + import json NOT_MENTIONED = "not mentioned" +def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1): + ratio = {'train': split2ratio, 'validation': split2ratio} + if data_name == "all" or data_name == "sgd+tm" or data_name == "tm": + print("merge all datasets...") + if data_name == "all": + all_dataset = ["multiwoz21", "sgd", "tm1", "tm2", "tm3"] + if data_name == "sgd+tm": + all_dataset = ["sgd", "tm1", "tm2", "tm3"] + if data_name == "tm": + all_dataset = ["tm1", "tm2", "tm3"] + + datasets = {} + for name in all_dataset: + datasets[name] = load_dataset( + name, + dial_ids_order=dial_ids_order, + split2ratio=ratio) + raw_data = merge_dataset(datasets, all_dataset[0]) + + else: + print(f"load single dataset {data_name}/{split2ratio}") + raw_data = load_dataset(data_name, + dial_ids_order=dial_ids_order, + split2ratio=ratio) + return raw_data + + +def merge_dataset(datasets, data_name): + data_split = [x for x in datasets[data_name]] + raw_data = {} + for data_type in data_split: + raw_data[data_type] = [] + for dataname, dataset in datasets.items(): + print(f"merge {dataname}...") + raw_data[data_type] += dataset[data_type] + return raw_data + + def int2onehot(index, output_dim=6, remove_zero=False): one_hot = [0] * output_dim if remove_zero: @@ -89,50 +129,6 @@ def get_booking_domain(slot, value, all_values, domain_list): return found -def act2slot(intent, domain, slot, value, all_values): - - if domain not in UsrDa2Goal: - # print(f"Not handle domain {domain}") - return "" - - if domain == "booking": - slot = SysDa2Goal[domain][slot] - domain = get_booking_domain(slot, value, all_values) - return f"{domain}-{slot}" - - elif domain in UsrDa2Goal: - if slot in SysDa2Goal[domain]: - slot = SysDa2Goal[domain][slot] - elif slot in UsrDa2Goal[domain]: - slot = UsrDa2Goal[domain][slot] - elif slot in SysDa2Goal["booking"]: - slot = SysDa2Goal["booking"][slot] - # else: - # print( - # f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}") - - return f"{domain}-{slot}" - - print("strange!!!") - print(intent, domain, slot, value) - - return "" - - -def get_user_history(dialog, all_values): - turn_num = len(dialog) - mentioned_slot = [] - for turn_id in range(0, turn_num, 2): - usr_act = parse_dialogue_act( - dialog[turn_id]["dialog_act"]) - for intent, domain, slot, value in usr_act: - slot_name = act2slot( - intent, domain.lower(), slot.lower(), value.lower(), all_values) - if slot_name not in mentioned_slot: - 
mentioned_slot.append(slot_name) - return mentioned_slot - - def update_config_file(file_name, attribute, value): with open(file_name, 'r') as config_file: config = json.load(config_file) @@ -147,7 +143,7 @@ def update_config_file(file_name, attribute, value): def create_goal(dialog) -> list: # a list of {'intent': ..., 'domain': ..., 'slot': ..., 'value': ...} dicts = [] - for i, turn in enumerate(dialog['turns']): + for turn in dialog['turns']: # print(turn['speaker']) # assert (i % 2 == 0) == (turn['speaker'] == 'user') # if i % 2 == 0: @@ -205,6 +201,45 @@ def split_slot_name(slot_name): return tokens[0], '-'.join(tokens[1:]) +# copy from data.unified_datasets.multiwoz21 +slot_name_map = { + 'addr': "address", + 'post': "postcode", + 'pricerange': "price range", + 'arrive': "arrive by", + 'arriveby': "arrive by", + 'leave': "leave at", + 'leaveat': "leave at", + 'depart': "departure", + 'dest': "destination", + 'fee': "entrance fee", + 'open': 'open hours', + 'car': "type", + 'car type': "type", + 'ticket': 'price', + 'trainid': 'train id', + 'id': 'train id', + 'people': 'book people', + 'stay': 'book stay', + 'none': '', + 'attraction': { + 'price': 'entrance fee' + }, + 'hospital': {}, + 'hotel': { + 'day': 'book day', 'price': "price range" + }, + 'restaurant': { + 'day': 'book day', 'time': 'book time', 'price': "price range" + }, + 'taxi': {}, + 'train': { + 'day': 'day', 'time': "duration" + }, + 'police': {}, + 'booking': {} +} + if __name__ == "__main__": print(split_slot_name("restaurant-search-location")) print(split_slot_name("sports-day.match"))
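Usage note: the __main__ block of build_data.py writes one JSON file per data split to convlab/policy/genTUS/unify/data/<dataset>_<dial_ids_order>_<split2ratio>/, and each example pairs a JSON-encoded "in" string (keys "system", "goal", optionally "history", and "turn") with a JSON-encoded "out" string (keys "action" and "text"). A minimal inspection sketch, assuming the default multiwoz21 build has already been run; the path follows that naming scheme and the printed values are illustrative only:

import json

# produced by: python convlab/policy/genTUS/unify/build_data.py --dataset multiwoz21
path = "convlab/policy/genTUS/unify/data/multiwoz21_0_1.0/train.json"

with open(path) as f:
    data = json.load(f)

example = data["dialog"][0]
model_in = json.loads(example["in"])   # {"system": [...], "goal": [...], "turn": "0", ...}
label = json.loads(example["out"])     # {"action": [[intent, domain, slot, value], ...], "text": "..."}

print(model_in["turn"], model_in["system"])
print(label["action"])
print(label["text"])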
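The KnowledgeGraph in knowledge_graph.py only scores candidate tokens; the token-by-token decoding loop that drives it presumably lives in stepGenTUS.py and is not part of this file. A rough sketch of the intended call order, assuming in_str is an "in" string as above and outputs is the model's score tensor over the vocabulary; in the real decoder each call would see the scores of its own decoding step, a single tensor stands in for all four here, and "facebook/bart-base" is an assumed tokenizer checkpoint:

from transformers import BartTokenizer
from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # assumed checkpoint
kg = KnowledgeGraph(tokenizer, dataset="multiwoz21")

def pick_one_act(in_str, outputs):
    # register the goal and system-act domains, slots and values for this turn
    kg.parse_input(in_str)
    # one user sub-action is decoded hierarchically, each step restricted to valid candidates
    intent = kg.get_intent(outputs)["token_name"]
    domain = kg.get_domain(outputs, intent)["token_name"]
    slot = kg.get_slot(outputs, intent, domain)["token_name"]
    value = kg.get_value(outputs, intent, domain, slot)["token_name"]
    return [intent, domain, slot, value]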
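The TUS.py changes route agenda-style (ABUS) goals through old_goal2list before constructing the unified Goal, as in init_session. A small illustration with a hand-written old-format goal; real goals come from GoalGenerator via ABUS_Goal.domain_goals, and the exact slot spellings after mapping may differ slightly:

from convlab.policy.tus.unify.Goal import Goal, old_goal2list

# hand-written old-format (agenda-style) goal, for illustration only
old_goal = {
    "restaurant": {
        "info": {"area": "centre", "food": "italian"},
        "reqt": {"phone": "?"},
    }
}

goal_list = old_goal2list(old_goal)
# entries take the form [domain, intent, slot, value], roughly
# ["restaurant", "inform", "area", "centre"] and ["restaurant", "request", "phone", "?"];
# slot order is shuffled within each domain
goal = Goal(goal_list)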