diff --git a/convlab/policy/genTUS/evaluate.py b/convlab/policy/genTUS/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..87de854970d2701900ba180d2bf15736071e0c1a
--- /dev/null
+++ b/convlab/policy/genTUS/evaluate.py
@@ -0,0 +1,257 @@
+import json
+import os
+import sys
+from argparse import ArgumentParser
+from pprint import pprint
+
+import torch
+from convlab.nlg.evaluate import fine_SER
+from datasets import load_metric
+
+# from convlab.policy.genTUS.pg.stepGenTUSagent import \
+#     stepGenTUSPG as UserPolicy
+from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
+from tqdm import tqdm
+
+sys.path.append(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__)))))
+
+
+def arg_parser():
+    parser = ArgumentParser()
+    parser.add_argument("--model-checkpoint", type=str, help="the model path")
+    parser.add_argument("--model-weight", type=str,
+                        help="the model weight", default="")
+    parser.add_argument("--input-file", type=str, help="the testing input file",
+                        default="")
+    parser.add_argument("--generated-file", type=str, help="the generated results",
+                        default="")
+    parser.add_argument("--only-action", action="store_true")
+    parser.add_argument("--dataset", default="multiwoz")
+    parser.add_argument("--do-semantic", action="store_true",
+                        help="do semantic evaluation")
+    parser.add_argument("--do-nlg", action="store_true",
+                        help="do nlg generation")
+    parser.add_argument("--do-golden-nlg", action="store_true",
+                        help="do golden nlg generation")
+    return parser.parse_args()
+
+
+class Evaluator:
+    def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False):
+        self.dataset = dataset
+        self.model_checkpoint = model_checkpoint
+        self.model_weight = model_weight
+        # if model_weight:
+        #     self.usr_policy = UserPolicy(
+        #         self.model_checkpoint, only_action=only_action)
+        #     self.usr_policy.load(model_weight)
+        #     self.usr = self.usr_policy.usr
+        # else:
+        self.usr = UserActionPolicy(
+            model_checkpoint, only_action=only_action, dataset=self.dataset)
+        self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
+
+    def generate_results(self, f_eval, golden=False):
+        in_file = json.load(open(f_eval))
+        r = {
+            "input": [],
+            "golden_acts": [],
+            "golden_utts": [],
+            "gen_acts": [],
+            "gen_utts": []
+        }
+        for dialog in tqdm(in_file['dialog']):
+            inputs = dialog["in"]
+            labels = self.usr._parse_output(dialog["out"])
+            if golden:
+                usr_act = labels["action"]
+                usr_utt = self.usr.generate_text_from_give_semantic(
+                    inputs, usr_act)
+
+            else:
+                output = self.usr._parse_output(
+                    self.usr._generate_action(inputs))
+                usr_act = self.usr._remove_illegal_action(output["action"])
+                usr_utt = output["text"]
+            r["input"].append(inputs)
+            r["golden_acts"].append(labels["action"])
+            r["golden_utts"].append(labels["text"])
+            r["gen_acts"].append(usr_act)
+            r["gen_utts"].append(usr_utt)
+
+        return r
+
+    def read_generated_result(self, f_eval):
+        in_file = json.load(open(f_eval))
+        r = {
+            "input": [],
+            "golden_acts": [],
+            "golden_utts": [],
+            "gen_acts": [],
+            "gen_utts": []
+        }
+        for dialog in tqdm(in_file['dialog']):
+            for x in dialog:
+                r[x].append(dialog[x])
+
+        return r
+
+    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
+        if input_file:
+            print("Force generation")
+            gen_r = self.generate_results(input_file, golden)
+
+        elif generated_file:
+            gen_r = self.read_generated_result(generated_file)
+        else:
+            print("You must specify the input_file or the generated_file")
+
+        nlg_eval = {
+            "golden": golden,
+            "metrics": {},
+            "dialog": []
+        }
+        for input, golden_act, golden_utt, gen_act, gen_utt in zip(gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["gen_acts"], gen_r["gen_utts"]):
+            nlg_eval["dialog"].append({
+                "input": input,
+                "golden_acts": golden_act,
+                "golden_utts": golden_utt,
+                "gen_acts": gen_act,
+                "gen_utts": gen_utt
+            })
+
+        if golden:
+            print("Calculate BLEU")
+            bleu_metric = load_metric("sacrebleu")
+            labels = [[utt] for utt in gen_r["golden_utts"]]
+
+            bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
+                                             references=labels,
+                                             force=True)
+            print("bleu_metric", bleu_score)
+            nlg_eval["metrics"]["bleu"] = bleu_score
+
+        else:
+            print("Calculate SER")
+            missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
+                gen_r["gen_acts"], gen_r["gen_utts"])
+
+            print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
+                "genTUSNLG", missing, total, hallucinate, missing/total))
+            nlg_eval["metrics"]["SER"] = missing/total
+
+        dir_name = self.model_checkpoint
+        json.dump(nlg_eval,
+                  open(os.path.join(dir_name, "nlg_eval.json"), 'w'),
+                  indent=2)
+        return os.path.join(dir_name, "nlg_eval.json")
+
+    def evaluation(self, input_file=None, generated_file=None):
+        force_prediction = True
+        if generated_file:
+            gen_file = json.load(open(generated_file))
+            force_prediction = False
+            if gen_file["golden"]:
+                force_prediction = True
+
+        if force_prediction:
+            in_file = json.load(open(input_file))
+            dialog_result = []
+            gen_acts, golden_acts = [], []
+            # scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
+            for dialog in tqdm(in_file['dialog']):
+                inputs = dialog["in"]
+                labels = self.usr._parse_output(dialog["out"])
+                ans_action = self.usr._remove_illegal_action(labels["action"])
+                preds = self.usr._generate_action(inputs)
+                preds = self.usr._parse_output(preds)
+                usr_action = self.usr._remove_illegal_action(preds["action"])
+
+                gen_acts.append(usr_action)
+                golden_acts.append(ans_action)
+
+                d = {"input": inputs,
+                     "golden_acts": ans_action,
+                     "gen_acts": usr_action}
+                if "text" in preds:
+                    d["golden_utts"] = labels["text"]
+                    d["gen_utts"] = preds["text"]
+                    # print("pred text", preds["text"])
+
+                dialog_result.append(d)
+        else:
+            gen_acts, golden_acts = [], []
+            for dialog in gen_file['dialog']:
+                gen_acts.append(dialog["gen_acts"])
+                golden_acts.append(dialog["golden_acts"])
+            dialog_result = gen_file['dialog']
+
+        scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
+
+        for gen_act, golden_act in zip(gen_acts, golden_acts):
+            s = f1_measure(preds=gen_act, labels=golden_act)
+            for metric in scores:
+                scores[metric].append(s[metric])
+
+        result = {}
+        for metric in scores:
+            result[metric] = sum(scores[metric])/len(scores[metric])
+            print(f"{metric}: {result[metric]}")
+
+        result["dialog"] = dialog_result
+        basename = "semantic_evaluation_result"
+        json.dump(result, open(os.path.join(
+            self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
+        # if self.model_weight:
+        #     json.dump(result, open(os.path.join(
+        #         'results', f"{basename}.json"), 'w'))
+        # else:
+        #     json.dump(result, open(os.path.join(
+        #         self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
+
+
+def f1_measure(preds, labels):
+    tp = 0
+    score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
+    for p in preds:
+        if p in labels:
+            tp += 1.0
+    if preds:
+        score["precision"] = tp/len(preds)
+    if labels:
+        score["recall"] = tp/len(labels)
+    if (score["precision"] + score["recall"]) > 0:
+        score["f1"] = 2*(score["precision"]*score["recall"]) / \
+            (score["precision"]+score["recall"])
+    if tp == len(preds) and tp == len(labels):
+        score["turn_acc"] = 1
+    return score
+
+
+def main():
+    args = arg_parser()
+    eval = Evaluator(args.model_checkpoint,
+                     args.dataset,
+                     args.model_weight,
+                     args.only_action)
+    print("model checkpoint", args.model_checkpoint)
+    print("generated_file", args.generated_file)
+    print("input_file", args.input_file)
+    with torch.no_grad():
+        if args.do_semantic:
+            eval.evaluation(args.input_file)
+        if args.do_nlg:
+            nlg_result = eval.nlg_evaluation(input_file=args.input_file,
+                                             generated_file=args.generated_file,
+                                             golden=args.do_golden_nlg)
+            if args.generated_file:
+                generated_file = args.generated_file
+            else:
+                generated_file = nlg_result
+            eval.evaluation(args.input_file,
+                            generated_file)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/convlab/policy/genTUS/ppo/vector.py b/convlab/policy/genTUS/ppo/vector.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c502a46f87582008ff49219f8a14844378b9ed2
--- /dev/null
+++ b/convlab/policy/genTUS/ppo/vector.py
@@ -0,0 +1,148 @@
+import json
+
+import torch
+from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
+from convlab.policy.genTUS.token_map import tokenMap
+from convlab.policy.tus.unify.Goal import Goal
+from transformers import BartTokenizer
+
+
+class stepGenTUSVector:
+    def __init__(self, model_checkpoint, max_in_len=400, max_out_len=80, allow_general_intent=True):
+        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
+        self.vocab = len(self.tokenizer)
+        self.max_in_len = max_in_len
+        self.max_out_len = max_out_len
+        self.token_map = tokenMap(tokenizer=self.tokenizer)
+        self.token_map.default(only_action=True)
+        self.kg = KnowledgeGraph(self.tokenizer)
+        self.mentioned_domain = []
+        self.allow_general_intent = allow_general_intent
+        self.candidate_num = 5
+        if self.allow_general_intent:
+            print("---> allow_general_intent")
+
+    def init_session(self, goal: Goal):
+        self.goal = goal
+        self.mentioned_domain = []
+
+    def encode(self, raw_inputs, max_length, return_tensors="pt", truncation=True):
+        model_input = self.tokenizer(raw_inputs,
+                                     max_length=max_length,
+                                     return_tensors=return_tensors,
+                                     truncation=truncation,
+                                     padding="max_length")
+        return model_input
+
+    def decode(self, generated_so_far, skip_special_tokens=True):
+        output = self.tokenizer.decode(
+            generated_so_far, skip_special_tokens=skip_special_tokens)
+        return output
+
+    def state_vectorize(self, action, history, turn):
+        self.goal.update_user_goal(action=action)
+        inputs = json.dumps({"system": action,
+                             "goal": self.goal.get_goal_list(),
+                             "history": history,
+                             "turn": str(turn)})
+        inputs = self.encode(inputs, self.max_in_len)
+        s_vec, action_mask = inputs["input_ids"][0], inputs["attention_mask"][0]
+
+        return s_vec, action_mask
+
+    def action_vectorize(self, action, s=None):
+        # action:  [[intent, domain, slot, value], ...]
+        vec = {"vector": torch.tensor([]), "mask": torch.tensor([])}
+        if s is not None:
+            raw_inputs = self.decode(s[0])
+            self.kg.parse_input(raw_inputs)
+
+        self._append(vec, self._get_id("<s>"))
+        self._append(vec, self.token_map.get_id('start_json'))
+        self._append(vec, self.token_map.get_id('start_act'))
+
+        act_len = len(action)
+        for i, (intent, domain, slot, value) in enumerate(action):
+            if value == '?':
+                value = '<?>'
+            c_idx = {x: None for x in ["intent", "domain", "slot", "value"]}
+
+            if s is not None:
+                c_idx["intent"] = self._candidate_id(self.kg.candidate(
+                    "intent", allow_general_intent=self.allow_general_intent))
+                c_idx["domain"] = self._candidate_id(self.kg.candidate(
+                    "domain", intent=intent))
+                c_idx["slot"] = self._candidate_id(self.kg.candidate(
+                    "slot", intent=intent, domain=domain, is_mentioned=self.is_mentioned(domain)))
+                c_idx["value"] = self._candidate_id(self.kg.candidate(
+                    "value", intent=intent, domain=domain, slot=slot))
+
+            self._append(vec, self._get_id(intent), c_idx["intent"])
+            self._append(vec, self.token_map.get_id('sep_token'))
+            self._append(vec, self._get_id(domain), c_idx["domain"])
+            self._append(vec, self.token_map.get_id('sep_token'))
+            self._append(vec, self._get_id(slot), c_idx["slot"])
+            self._append(vec, self.token_map.get_id('sep_token'))
+            self._append(vec, self._get_id(value), c_idx["value"])
+
+            c_idx = [0]*self.candidate_num
+            c_idx[0] = self.token_map.get_id('end_act')[0]
+            c_idx[1] = self.token_map.get_id('sep_act')[0]
+            if i == act_len - 1:
+                x = self.token_map.get_id('end_act')
+            else:
+                x = self.token_map.get_id('sep_act')
+
+            self._append(vec, x, c_idx)
+
+        self._append(vec, self._get_id("</s>"))
+
+        # pad
+        if len(vec["vector"]) < self.max_out_len:
+            pad_len = self.max_out_len-len(vec["vector"])
+            self._append(vec, x=torch.tensor([1]*pad_len))
+        for vec_type in vec:
+            vec[vec_type] = vec[vec_type].to(torch.int64)
+
+        return vec
+
+    def _append(self, vec, x, candidate=None):
+        if type(x) is list:
+            x = torch.tensor(x)
+        mask = self._mask(x, candidate)
+        vec["vector"] = torch.cat((vec["vector"], x), dim=-1)
+        vec["mask"] = torch.cat((vec["mask"], mask), dim=0)
+
+    def _mask(self, idx, c_idx=None):
+        mask = torch.zeros(len(idx), self.candidate_num)
+        mask[:, 0] = idx
+        if c_idx is not None and len(c_idx) > 1:
+            mask[0, :] = torch.tensor(c_idx)
+
+        return mask
+
+    def _candidate_id(self, candidate):
+        if len(candidate) > self.candidate_num:
+            print(f"too many candidates. Max = {self.candidate_num}")
+        c_idx = [0]*self.candidate_num
+        for i, idx in enumerate([self._get_id(c)[0] for c in candidate[:self.candidate_num]]):
+            c_idx[i] = idx
+        return c_idx
+
+    def _get_id(self, value):
+        token_id = self.tokenizer(value, add_special_tokens=False)
+        return token_id["input_ids"]
+
+    def action_devectorize(self, action_id):
+        return self.decode(action_id)
+
+    def update_mentioned_domain(self, semantic_act):
+        for act in semantic_act:
+            domain = act[1]
+            if domain not in self.mentioned_domain:
+                self.mentioned_domain.append(domain)
+
+    def is_mentioned(self, domain):
+        if domain in self.mentioned_domain:
+            return True
+        return False
diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16c0ebeeeaaba50f1739aea6c4db40eb81f8d29
--- /dev/null
+++ b/convlab/policy/genTUS/stepGenTUS.py
@@ -0,0 +1,653 @@
+import json
+import os
+
+import torch
+from transformers import BartTokenizer
+
+from convlab.policy.genTUS.ppo.vector import stepGenTUSVector
+from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel
+from convlab.policy.genTUS.token_map import tokenMap
+from convlab.policy.genTUS.unify.Goal import Goal
+from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
+from convlab.policy.policy import Policy
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+
+DEBUG = False
+
+
+class UserActionPolicy(Policy):
+    def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
+        self.mode = mode
+        # if mode == "semantic" and only_action:
+        #     # only generate semantic action in prediction
+        print("model_checkpoint", model_checkpoint)
+        self.only_action = only_action
+        if self.only_action:
+            print("change mode to semantic because only_action=True")
+            self.mode = "semantic"
+        self.max_in_len = 500
+        self.max_out_len = 50 if only_action else 200
+        max_act_len = kwargs.get("max_act_len", 2)
+        print("max_act_len", max_act_len)
+        self.max_action_len = max_act_len
+        if "max_act_len" in kwargs:
+            self.max_out_len = 30 * self.max_action_len
+            print("max_act_len", self.max_out_len)
+        self.max_turn = max_turn
+        if mode not in ["semantic", "language"]:
+            print("Unknown user mode")
+
+        self.reward = {"success":  self.max_turn*2,
+                       "fail": self.max_turn*-1}
+        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        train_whole_model = kwargs.get("whole_model", True)
+        self.model = stepGenTUSmodel(
+            model_checkpoint, train_whole_model=train_whole_model)
+        self.model.eval()
+        self.model.to(self.device)
+        self.model.share_memory()
+
+        self.turn_level_reward = kwargs.get("turn_level_reward", True)
+        self.cooperative = kwargs.get("cooperative", True)
+
+        dataset = kwargs.get("dataset", "")
+        self.kg = KnowledgeGraph(
+            tokenizer=self.tokenizer,
+            dataset=dataset)
+
+        self.goal_gen = GoalGenerator()
+
+        self.vector = stepGenTUSVector(
+            model_checkpoint, self.max_in_len, self.max_out_len)
+        self.norm_reward = False
+
+        self.action_penalty = kwargs.get("action_penalty", False)
+        self.usr_act_penalize = kwargs.get("usr_act_penalize", 0)
+        self.goal_list_type = kwargs.get("goal_list_type", "normal")
+        self.update_mode = kwargs.get("update_mode", "normal")
+        self.max_history = kwargs.get("max_history", 3)
+        self.init_session()
+
+    def _update_seq(self, sub_seq: list, pos: int):
+        for x in sub_seq:
+            self.seq[0, pos] = x
+            pos += 1
+
+        return pos
+
+    def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True):
+        # TODO no duplicate
+        self.kg.parse_input(raw_inputs)
+        model_input = self.vector.encode(raw_inputs, self.max_in_len)
+        # start token
+        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
+        pos = self._update_seq([0], 0)
+        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+
+        # get semantic actions
+        for act_len in range(self.max_action_len):
+            pos = self._get_semantic_action(
+                model_input, pos, mode, allow_general_intent)
+
+            terminate, token_name = self._stop_semantic(
+                model_input, pos, act_len)
+            pos = self._update_seq(self.token_map.get_id(token_name), pos)
+
+            if terminate:
+                break
+
+        if self.only_action:
+            # return semantic action. Don't need to generate text
+            return self.vector.decode(self.seq[0, :pos])
+
+        # TODO remove illegal action here?
+
+        # get text output
+        pos = self._update_seq(self.token_map.get_id("start_text"), pos)
+
+        text = self._get_text(model_input, pos)
+
+        return text
+
+    def generate_text_from_give_semantic(self, raw_inputs, semantic_action):
+        self.kg.parse_input(raw_inputs)
+        model_input = self.vector.encode(raw_inputs, self.max_in_len)
+        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
+        pos = self._update_seq([0], 0)
+        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+
+        if len(semantic_action) == 0:
+            pos = self._update_seq(self.token_map.get_id("end_act"), pos)
+
+        for act_id, (intent, domain, slot, value) in enumerate(semantic_action):
+            pos = self._update_seq(self.kg._get_token_id(intent), pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+            pos = self._update_seq(self.kg._get_token_id(domain), pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+            pos = self._update_seq(self.kg._get_token_id(slot), pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+            pos = self._update_seq(self.kg._get_token_id(value), pos)
+
+            if act_id == len(semantic_action) - 1:
+                token_name = "end_act"
+            else:
+                token_name = "sep_act"
+            pos = self._update_seq(self.token_map.get_id(token_name), pos)
+        pos = self._update_seq(self.token_map.get_id("start_text"), pos)
+
+        raw_output = self._get_text(model_input, pos)
+        return self._parse_output(raw_output)["text"]
+
+    def _get_text(self, model_input, pos):
+        s_pos = pos
+        for i in range(s_pos, self.max_out_len):
+            next_token_logits = self.model.get_next_token_logits(
+                model_input, self.seq[:1, :pos])
+            next_token = torch.argmax(next_token_logits, dim=-1)
+
+            if self._stop_text(next_token):
+                # text = self.vector.decode(self.seq[0, s_pos:pos])
+                # text = self._norm_str(text)
+                # return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
+                break
+
+            pos = self._update_seq([next_token], pos)
+        text = self.vector.decode(self.seq[0, s_pos:pos])
+        text = self._norm_str(text)
+        return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
+        # TODO return None
+
+    def _stop_text(self, next_token):
+        if next_token == self.token_map.get_id("end_json")[0]:
+            return True
+        elif next_token == self.token_map.get_id("end_json_2")[0]:
+            return True
+
+        return False
+
+    @staticmethod
+    def _norm_str(text: str):
+        text = text.strip('"')
+        text = text.replace('"', "'")
+        text = text.replace('\\', "")
+        return text
+
+    def _stop_semantic(self, model_input, pos, act_length=0):
+
+        outputs = self.model.get_next_token_logits(
+            model_input, self.seq[:1, :pos])
+        tokens = {}
+        for token_name in ['sep_act', 'end_act']:
+            tokens[token_name] = {
+                "token_id": self.token_map.get_id(token_name)}
+            hash_id = tokens[token_name]["token_id"][0]
+            tokens[token_name]["score"] = outputs[:, hash_id].item()
+
+        if tokens['end_act']["score"] > tokens['sep_act']["score"]:
+            terminate = True
+        else:
+            terminate = False
+
+        if act_length >= self.max_action_len - 1:
+            terminate = True
+
+        token_name = "end_act" if terminate else "sep_act"
+
+        return terminate, token_name
+
+    def _get_semantic_action(self, model_input, pos, mode="max", allow_general_intent=True):
+
+        intent = self._get_intent(
+            model_input, self.seq[:1, :pos], mode, allow_general_intent)
+        pos = self._update_seq(intent["token_id"], pos)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
+        # get domain
+        domain = self._get_domain(
+            model_input, self.seq[:1, :pos], intent["token_name"], mode)
+        pos = self._update_seq(domain["token_id"], pos)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
+        # get slot
+        slot = self._get_slot(
+            model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode)
+        pos = self._update_seq(slot["token_id"], pos)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
+        # get value
+
+        value = self._get_value(
+            model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], slot["token_name"], mode)
+        pos = self._update_seq(value["token_id"], pos)
+
+        return pos
+
+    def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+
+        return self.kg.get_intent(next_token_logits, mode, allow_general_intent)
+
+    def _get_domain(self, model_input, generated_so_far, intent, mode="max"):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+
+        return self.kg.get_domain(next_token_logits, intent, mode)
+
+    def _get_slot(self, model_input, generated_so_far, intent, domain, mode="max"):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+        is_mentioned = self.vector.is_mentioned(domain)
+        return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned)
+
+    def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+
+        return self.kg.get_value(next_token_logits, intent, domain, slot, mode)
+
+    def _remove_illegal_action(self, action):
+        # Transform illegal action to legal action
+        new_action = []
+        for act in action:
+            if len(act) == 4:
+                if "<?>" in act[-1]:
+                    act = [act[0], act[1], act[2], "?"]
+                if act not in new_action:
+                    new_action.append(act)
+            else:
+                print("illegal action:", action)
+        return new_action
+
+    def _parse_output(self, in_str):
+        in_str = str(in_str)
+        in_str = in_str.replace('<s>', '').replace(
+            '<\\s>', '').replace('o"clock', "o'clock")
+        action = {"action": [], "text": ""}
+        try:
+            action = json.loads(in_str)
+        except:
+            print("invalid action:", in_str)
+            print("-"*20)
+        return action
+
+    def predict(self, sys_act, mode="max", allow_general_intent=True):
+        # raw_sys_act = sys_act
+        # sys_act = sys_act[:5]
+        # update goal
+        # TODO
+        allow_general_intent = False
+        self.model.eval()
+
+        if not self.add_sys_from_reward:
+            self.goal.update_user_goal(action=sys_act, char="sys")
+            self.sys_acts.append(sys_act)  # for terminate conversation
+
+        # update constraint
+        self.time_step += 2
+
+        history = []
+        if self.usr_acts:
+            if self.max_history == 1:
+                history = self.usr_acts[-1]
+            else:
+                history = self.usr_acts[-1*self.max_history:]
+        inputs = json.dumps({"system": sys_act,
+                             "goal": self.goal.get_goal_list(),
+                             "history": history,
+                             "turn": str(int(self.time_step/2))})
+        with torch.no_grad():
+            raw_output = self._generate_action(
+                raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
+        output = self._parse_output(raw_output)
+        self.semantic_action = self._remove_illegal_action(output["action"])
+        if not self.only_action:
+            self.utterance = output["text"]
+
+        # TODO
+        if self.is_finish():
+            self.semantic_action, self.utterance = self._good_bye()
+
+        # if self.is_finish():
+        #     print("terminated")
+
+        # if self.is_finish():
+        #     good_bye = self._good_bye()
+        #     self.goal.add_usr_da(good_bye)
+        #     return good_bye
+
+        self.goal.update_user_goal(action=self.semantic_action, char="usr")
+        self.vector.update_mentioned_domain(self.semantic_action)
+        self.usr_acts.append(self.semantic_action)
+
+        # if self._usr_terminate(usr_action):
+        #     print("terminated by user")
+        #     self.terminated = True
+
+        del inputs
+
+        if self.mode == "language":
+            # print("in", sys_act)
+            # print("out", self.utterance)
+            return self.utterance
+        else:
+            return self.semantic_action
+
+    def init_session(self, goal=None):
+        self.token_map = tokenMap(tokenizer=self.tokenizer)
+        self.token_map.default(only_action=self.only_action)
+        self.time_step = 0
+        remove_domain = "police"  # remove police domain in inference
+
+        if not goal:
+            self._new_goal(remove_domain=remove_domain)
+        else:
+            self._read_goal(goal)
+
+        self.vector.init_session(goal=self.goal)
+
+        self.terminated = False
+        self.add_sys_from_reward = False
+        self.sys_acts = []
+        self.usr_acts = []
+        self.semantic_action = []
+        self.utterance = ""
+
+    def _read_goal(self, data_goal):
+        self.goal = Goal(goal=data_goal)
+
+    def _new_goal(self, remove_domain="police", domain_len=None):
+        self.goal = Goal(goal_generator=self.goal_gen)
+        # keep_generate_goal = True
+        # # domain_len = 1
+        # while keep_generate_goal:
+        #     self.goal = Goal(goal_generator=self.goal_gen,
+        #                      goal_list_type=self.goal_list_type,
+        #                      update_mode=self.update_mode)
+        #     if (domain_len and len(self.goal.domains) != domain_len) or \
+        #             (remove_domain and remove_domain in self.goal.domains):
+        #         keep_generate_goal = True
+        #     else:
+        #         keep_generate_goal = False
+
+    def load(self, model_path):
+        self.model.load_state_dict(torch.load(
+            model_path, map_location=self.device))
+        # self.model = BartForConditionalGeneration.from_pretrained(
+        #     model_checkpoint)
+
+    def get_goal(self):
+        if self.goal.raw_goal is not None:
+            return self.goal.raw_goal
+        goal = {}
+        for domain in self.goal.domain_goals:
+            if domain not in goal:
+                goal[domain] = {}
+            for intent in self.goal.domain_goals[domain]:
+                if intent == "inform":
+                    slot_type = "info"
+                elif intent == "request":
+                    slot_type = "reqt"
+                elif intent == "book":
+                    slot_type = "book"
+                else:
+                    print("unknown slot type")
+                if slot_type not in goal[domain]:
+                    goal[domain][slot_type] = {}
+                for slot, value in self.goal.domain_goals[domain][intent].items():
+                    goal[domain][slot_type][slot] = value
+        return goal
+
+    def get_reward(self, sys_response=None):
+        self.add_sys_from_reward = False if sys_response is None else True
+
+        if self.add_sys_from_reward:
+            self.goal.update_user_goal(action=sys_response, char="sys")
+            self.goal.add_sys_da(sys_response)  # for evaluation
+            self.sys_acts.append(sys_response)  # for terminate conversation
+
+        if self.is_finish():
+            if self.is_success():
+                reward = self.reward["success"]
+                self.success = True
+            else:
+                reward = self.reward["fail"]
+                self.success = False
+
+        else:
+            reward = -1
+            if self.turn_level_reward:
+                reward += self.turn_reward()
+
+            self.success = None
+            # if self.action_penalty:
+            #     reward += self._system_action_penalty()
+
+        if self.norm_reward:
+            reward = (reward - 20)/60
+        return reward
+
+    def _system_action_penalty(self):
+        free_action_len = 3
+        if len(self.sys_acts) < 1:
+            return 0
+        # TODO only penalize the slots not in user goal
+        # else:
+        #     penlaty = 0
+        #     for i in range(len(self.sys_acts[-1])):
+        #         penlaty += -1*i
+        #     return penlaty
+        if len(self.sys_acts[-1]) > 3:
+            return -1*(len(self.sys_acts[-1])-free_action_len)
+        return 0
+
+    def turn_reward(self):
+        r = 0
+        r += self._new_act_reward()
+        r += self._reply_reward()
+        r += self._usr_act_len()
+        return r
+
+    def _usr_act_len(self):
+        last_act = self.usr_acts[-1]
+        penalty = 0
+        if len(last_act) > 2:
+            penalty = (2-len(last_act))*self.usr_act_penalize
+        return penalty
+
+    def _new_act_reward(self):
+        last_act = self.usr_acts[-1]
+        if last_act != self.semantic_action:
+            print(f"---> why? last {last_act} usr {self.semantic_action}")
+        new_act = []
+        for act in last_act:
+            if len(self.usr_acts) < 2:
+                break
+            if act[1].lower() == "general":
+                new_act.append(0)
+            elif act in self.usr_acts[-2]:
+                new_act.append(-1)
+            elif act not in self.usr_acts[-2]:
+                new_act.append(1)
+
+        return sum(new_act)
+
+    def _reply_reward(self):
+        if self.cooperative:
+            return self._cooperative_reply_reward()
+        else:
+            return self._non_cooperative_reply_reward()
+
+    def _non_cooperative_reply_reward(self):
+        r = []
+        reqts = []
+        infos = []
+        reply_len = 0
+        max_len = 1
+        for act in self.sys_acts[-1]:
+            if act[0] == "request":
+                reqts.append([act[1], act[2]])
+        for act in self.usr_acts[-1]:
+            if act[0] == "inform":
+                infos.append([act[1], act[2]])
+        for req in reqts:
+            if req in infos:
+                if reply_len < max_len:
+                    r.append(1)
+                elif reply_len == max_len:
+                    r.append(0)
+                else:
+                    r.append(-5)
+
+        if r:
+            return sum(r)
+        return 0
+
+    def _cooperative_reply_reward(self):
+        r = []
+        reqts = []
+        infos = []
+        for act in self.sys_acts[-1]:
+            if act[0] == "request":
+                reqts.append([act[1], act[2]])
+        for act in self.usr_acts[-1]:
+            if act[0] == "inform":
+                infos.append([act[1], act[2]])
+        for req in reqts:
+            if req in infos:
+                r.append(1)
+            else:
+                r.append(-1)
+        if r:
+            return sum(r)
+        return 0
+
+    def _usr_terminate(self):
+        for act in self.semantic_action:
+            if act[0] in ['thank', 'bye']:
+                return True
+        return False
+
+    def is_finish(self):
+        # stop by model generation?
+        if self._finish_conversation_rule():
+            self.terminated = True
+            return True
+        elif self._usr_terminate():
+            self.terminated = True
+            return True
+        self.terminated = False
+        return False
+
+    def is_success(self):
+        task_complete = self.goal.task_complete()
+        # goal_status = self.goal.all_mentioned()
+        # should mentioned all slots
+        if task_complete:  # and goal_status["complete"] > 0.6:
+            return True
+        return False
+
+    def _good_bye(self):
+        if self.is_success():
+            return [['thank', 'general', 'none', 'none']], "thank you. bye"
+            # if self.mode == "semantic":
+            #     return [['thank', 'general', 'none', 'none']]
+            # else:
+            #     return "bye"
+        else:
+            return [["bye", "general", "None", "None"]], "bye"
+            if self.mode == "semantic":
+                return [["bye", "general", "None", "None"]]
+            return "bye"
+
+    def _finish_conversation_rule(self):
+        if self.is_success():
+            return True
+
+        if self.time_step > self.max_turn:
+            return True
+
+        if (len(self.sys_acts) > 4) and (self.sys_acts[-1] == self.sys_acts[-2]) and (self.sys_acts[-2] == self.sys_acts[-3]):
+            return True
+        return False
+
+    def is_terminated(self):
+        # Is there any action to say?
+        self.is_finish()
+        return self.terminated
+
+
+class UserPolicy(Policy):
+    def __init__(self,
+                 model_checkpoint,
+                 mode="semantic",
+                 only_action=True,
+                 sample=False,
+                 action_penalty=False,
+                 **kwargs):
+        # self.config = config
+        # if not os.path.exists(self.config["model_dir"]):
+        #     os.mkdir(self.config["model_dir"])
+        #     model_downloader(self.config["model_dir"],
+        #                      "https://zenodo.org/record/5779832/files/default.zip")
+
+        self.policy = UserActionPolicy(
+            model_checkpoint,
+            mode=mode,
+            only_action=only_action,
+            action_penalty=action_penalty,
+            **kwargs)
+        self.policy.load(os.path.join(
+            model_checkpoint, "pytorch_model.bin"))
+        self.sample = sample
+
+    def predict(self, sys_act, mode="max"):
+        if self.sample:
+            mode = "sample"
+        else:
+            mode = "max"
+        response = self.policy.predict(sys_act, mode)
+        return response
+
+    def init_session(self, goal=None):
+        self.policy.init_session(goal)
+
+    def is_terminated(self):
+        return self.policy.is_terminated()
+
+    def get_reward(self, sys_response=None):
+        return self.policy.get_reward(sys_response)
+
+    def get_goal(self):
+        if hasattr(self.policy, 'get_goal'):
+            return self.policy.get_goal()
+        return None
+
+
+if __name__ == "__main__":
+    import os
+
+    from convlab.dialog_agent import PipelineAgent
+    # from convlab.nlu.jointBERT.multiwoz import BERTNLU
+    from convlab.util.custom_util import set_seed
+
+    set_seed(20220220)
+    # Test semantic level behaviour
+    model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0'
+    usr_policy = UserPolicy(
+        model_checkpoint,
+        mode="semantic")
+    usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
+    usr_nlu = None  # BERTNLU()
+    usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
+    print(usr.policy.get_goal())
+
+    print(usr.response([]))
+    print(usr.policy.policy.goal.status)
+    print(usr.response([["request", "attraction", "area", "?"]]))
+    print(usr.policy.policy.goal.status)
+    print(usr.response([["request", "attraction", "area", "?"]]))
+    print(usr.policy.policy.goal.status)
diff --git a/convlab/policy/genTUS/stepGenTUSmodel.py b/convlab/policy/genTUS/stepGenTUSmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2eaf7bc5064262808120aac4a9cbe2eb007d863
--- /dev/null
+++ b/convlab/policy/genTUS/stepGenTUSmodel.py
@@ -0,0 +1,114 @@
+
+import json
+
+import torch
+from torch.nn.functional import softmax, one_hot, cross_entropy
+
+from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
+from convlab.policy.genTUS.token_map import tokenMap
+from convlab.policy.genTUS.utils import append_tokens
+from transformers import (BartConfig, BartForConditionalGeneration,
+                          BartTokenizer)
+
+
+class stepGenTUSmodel(BartForConditionalGeneration):
+    def __init__(self, model_checkpoint, train_whole_model=True, **kwargs):
+        config = BartConfig.from_pretrained(model_checkpoint)
+        super().__init__(config, **kwargs)
+
+        self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
+        self.vocab = len(self.tokenizer)
+        self.kg = KnowledgeGraph(self.tokenizer)
+        self.action_kg = KnowledgeGraph(self.tokenizer)
+        self.token_map = tokenMap(self.tokenizer)
+        # only_action doesn't matter. it is only used for get_log_prob
+        self.token_map.default(only_action=True)
+
+        if not train_whole_model:
+            for param in self.parameters():
+                param.requires_grad = False
+
+            for param in self.model.decoder.layers[-1].fc1.parameters():
+                param.requires_grad = True
+            for param in self.model.decoder.layers[-1].fc2.parameters():
+                param.requires_grad = True
+
+    def get_trainable_param(self):
+
+        return filter(
+            lambda p: p.requires_grad, self.parameters())
+
+    def get_next_token_logits(self, model_input, generated_so_far):
+        input_ids = model_input["input_ids"].to(self.device)
+        attention_mask = model_input["attention_mask"].to(self.device)
+        outputs = self.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=generated_so_far,
+            return_dict=True)
+        return outputs.logits[:, -1, :]
+
+    def get_log_prob(self, s, a, action_mask, prob_mask):
+        output = self.forward(input_ids=s,
+                              attention_mask=action_mask,
+                              decoder_input_ids=a)
+        prob = self._norm_prob(a[:, 1:].long(),
+                               output.logits[:, :-1, :],
+                               prob_mask[:, 1:, :].long())
+        return prob
+
+    def _norm_prob(self, a, prob, mask):
+        prob = softmax(prob, -1)
+        base = self._base(prob, mask).to(self.device)  # [b, seq_len]
+        prob = (prob*one_hot(a, num_classes=self.vocab)).sum(-1)
+        prob = torch.log(prob / base)
+        pad_mask = a != 1
+        prob = prob*pad_mask.float()
+        return prob.sum(-1)
+
+    @staticmethod
+    def _base(prob, mask):
+        batch_size, seq_len, dim = prob.shape
+        base = torch.zeros(batch_size, seq_len)
+        for b in range(batch_size):
+            for s in range(seq_len):
+                temp = [prob[b, s, c] for c in mask[b, s, :] if c > 0]
+                base[b, s] = torch.sum(torch.tensor(temp))
+        return base
+
+
+if __name__ == "__main__":
+    import os
+    from convlab.util.custom_util import set_seed
+    from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
+    set_seed(0)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model_checkpoint = 'results/genTUS-22-01-31-09-21/'
+    usr = UserActionPolicy(model_checkpoint=model_checkpoint)
+    usr.model.load_state_dict(torch.load(
+        os.path.join(model_checkpoint, "pytorch_model.bin"), map_location=device))
+    usr.model.eval()
+
+    test_file = "convlab/policy/genTUS/data/goal_status_validation_v1.json"
+    data = json.load(open(test_file))
+    test_id = 20
+    inputs = usr.tokenizer(data["dialog"][test_id]["in"],
+                           max_length=400,
+                           return_tensors="pt",
+                           truncation=True)
+
+    actions = [data["dialog"][test_id]["out"],
+               data["dialog"][test_id+100]["out"]]
+
+    for action in actions:
+        action = json.loads(action)
+        vec = usr.vector.action_vectorize(
+            action["action"], s=inputs["input_ids"])
+
+        print({"action": action["action"]})
+        print("get_log_prob", usr.model.get_log_prob(
+            inputs["input_ids"],
+            torch.unsqueeze(vec["vector"], 0),
+            inputs["attention_mask"],
+            torch.unsqueeze(vec["mask"], 0)))
diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..7825c2880928c40f68284b0c3199932cd1cfc477
--- /dev/null
+++ b/convlab/policy/genTUS/token_map.py
@@ -0,0 +1,64 @@
+import json
+
+
+class tokenMap:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.token_name = {}
+        self.hash_map = {}
+        self.debug = False
+        self.default()
+
+    def default(self, only_action=False):
+        self.format_tokens = {
+            'start_json': '{"action": [',   # 49643, 10845, 7862, 646
+            'start_act': '["',              # 49329
+            'sep_token': '", "',            # 1297('",'), 22
+            'sep_act': '"], ["',               # 49177
+            'end_act': '"]], "',            # 42248, 7479, 22
+            'start_text': 'text": "',       # 29015, 7862, 22
+            'end_json': '}',                 # 24303
+            'end_json_2': '"}'                 # 48805
+        }
+        if only_action:
+            self.format_tokens['end_act'] = '"]]}'
+        for token_name in self.format_tokens:
+            self.add_token(
+                token_name, self.format_tokens[token_name])
+
+    def add_token(self, token_name, value):
+        if token_name in self.token_name and self.debug:
+            print(f"---> duplicate token: {token_name}({value})!!!!!!!")
+
+        token_id = self.tokenizer(str(value), add_special_tokens=False)[
+            "input_ids"]
+        self.token_name[token_name] = {"value": value, "token_id": token_id}
+        # print(token_id)
+        hash_id = token_id[0]
+        if hash_id in self.hash_map and self.debug:
+            print(
+                f"---> conflict hash number {hash_id}: {self.hash_map[hash_id]['name']} and {token_name}")
+        self.hash_map[hash_id] = {
+            "name": token_name, "value": value, "token_id": token_id}
+
+    def get_info(self, hash_id):
+        return self.hash_map[hash_id]
+
+    def get_id(self, token_name):
+        # workaround
+        # if token_name not in self.token_name[token_name]:
+        #     self.add_token(token_name, token_name)
+        return self.token_name[token_name]["token_id"]
+
+    def get_token_value(self, token_name):
+        return self.token_name[token_name]["value"]
+
+    def token_name_is_in(self, token_name):
+        if token_name in self.token_name:
+            return True
+        return False
+
+    def hash_id_is_in(self, hash_id):
+        if hash_id in self.hash_map:
+            return True
+        return False
diff --git a/convlab/policy/genTUS/train_model.py b/convlab/policy/genTUS/train_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2162417461d692514e3b27742dbfd477491fc24e
--- /dev/null
+++ b/convlab/policy/genTUS/train_model.py
@@ -0,0 +1,258 @@
+import json
+import os
+import sys
+from argparse import ArgumentParser
+from datetime import datetime
+from pprint import pprint
+import numpy as np
+import torch
+import transformers
+from datasets import Dataset, load_metric
+from tqdm import tqdm
+from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
+                          BartForConditionalGeneration, BartTokenizer,
+                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
+                          Seq2SeqTrainingArguments)
+
+sys.path.append(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__)))))
+
+os.environ["WANDB_DISABLED"] = "true"
+
+METRIC = load_metric("sacrebleu")
+TOKENIZER = BartTokenizer.from_pretrained("facebook/bart-base")
+TOKENIZER.add_tokens(["<?>"])
+MAX_IN_LEN = 500
+MAX_OUT_LEN = 500
+
+
+def arg_parser():
+    parser = ArgumentParser()
+    # data_name, dial_ids_order, split2ratio
+    parser.add_argument("--model-type", type=str, default="unify",
+                        help="unify or multiwoz")
+    parser.add_argument("--data-name", type=str, default="multiwoz21",
+                        help="multiwoz21, sgd, tm1, tm2, tm3, sgd+tm, or all")
+    parser.add_argument("--dial-ids-order", type=int, default=0)
+    parser.add_argument("--split2ratio", type=float, default=1)
+    parser.add_argument("--batch-size", type=int, default=16)
+    parser.add_argument("--model-checkpoint", type=str,
+                        default="facebook/bart-base")
+    return parser.parse_args()
+
+
+def gentus_compute_metrics(eval_preds):
+    preds, labels = eval_preds
+    if isinstance(preds, tuple):
+        preds = preds[0]
+    decoded_preds = TOKENIZER.batch_decode(
+        preds, skip_special_tokens=True, max_length=MAX_OUT_LEN)
+
+    # Replace -100 in the labels as we can't decode them.
+    labels = np.where(labels != -100, labels, TOKENIZER.pad_token_id)
+    decoded_labels = TOKENIZER.batch_decode(
+        labels, skip_special_tokens=True, max_length=MAX_OUT_LEN)
+
+    act, text = postprocess_text(decoded_preds, decoded_labels)
+
+    result = METRIC.compute(
+        # predictions=decoded_preds, references=decoded_labels)
+        predictions=text["preds"], references=text["labels"])
+    result = {"bleu": result["score"]}
+    f1_scores = f1_measure(pred_acts=act["preds"], label_acts=act["labels"])
+    for s in f1_scores:
+        result[s] = f1_scores[s]
+
+    result = {k: round(v, 4) for k, v in result.items()}
+    return result
+
+
+def postprocess_text(preds, labels):
+    act = {"preds": [], "labels": []}
+    text = {"preds": [], "labels": []}
+
+    for pred, label in zip(preds, labels):
+        model_output = parse_output(pred.strip())
+        label_output = parse_output(label.strip())
+        if len(label_output["text"]) < 1:
+            continue
+        act["preds"].append(model_output.get("action", []))
+        text["preds"].append(model_output.get("text", pred.strip()))
+        act["labels"].append(label_output["action"])
+        text["labels"].append([label_output["text"]])
+
+    return act, text
+
+
+def parse_output(in_str):
+    in_str = in_str.replace('<s>', '').replace('<\\s>', '')
+    try:
+        output = json.loads(in_str)
+    except:
+        # print(f"invalid action {in_str}")
+        output = {"action": [], "text": ""}
+    return output
+
+
+def f1_measure(pred_acts, label_acts):
+    result = {"precision": [], "recall": [], "f1": []}
+    for pred, label in zip(pred_acts, label_acts):
+        r = tp_fn_fp(pred, label)
+        for m in result:
+            result[m].append(r[m])
+    for m in result:
+        result[m] = sum(result[m])/len(result[m])
+
+    return result
+
+
+def tp_fn_fp(pred, label):
+    tp, fn, fp = 0.0, 0.0, 0.0
+    precision, recall, f1 = 0, 0, 0
+    for p in pred:
+        if p in label:
+            tp += 1
+        else:
+            fp += 1
+    for l in label:
+        if l not in pred:
+            fn += 1
+    if (tp+fp) > 0:
+        precision = tp / (tp+fp)
+    if (tp+fn) > 0:
+        recall = tp/(tp+fn)
+    if (precision + recall) > 0:
+        f1 = (2*precision*recall)/(precision+recall)
+
+    return {"precision": precision, "recall": recall, "f1": f1}
+
+
+class TrainerHelper:
+    def __init__(self, tokenizer, max_input_length=500, max_target_length=500):
+        print("transformers version is: ", transformers.__version__)
+        self.tokenizer = tokenizer
+        self.max_input_length = max_input_length
+        self.max_target_length = max_target_length
+        self.base_name = "convlab/policy/genTUS"
+        self.dir_name = ""
+
+    def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
+        # base_name = "convlab/policy/genTUS/unify/data"
+        if model_type not in ["unify", "multiwoz"]:
+            print("Unknown model type. Currently only support unify and multiwoz")
+        self.dir_name = f"{data_name}_{dial_ids_order}_{split2ratio}"
+        return os.path.join(self.base_name, model_type, 'data', self.dir_name)
+
+    def get_model_folder(self, model_type):
+        folder_name = os.path.join(
+            self.base_name, model_type, "experiments", self.dir_name)
+        if not os.path.exists(folder_name):
+            os.makedirs(folder_name)
+        return folder_name
+
+    def parse_data(self, model_type, data_name, dial_ids_order=0, split2ratio=1):
+        data_folder = self._get_data_folder(
+            model_type, data_name, dial_ids_order, split2ratio)
+
+        raw_data = {}
+        for d_type in ["train", "validation", "test"]:
+            f_name = os.path.join(data_folder, f"{d_type}.json")
+            raw_data[d_type] = json.load(open(f_name))
+
+        tokenized_datasets = {}
+        for data_type, data in raw_data.items():
+            tokenized_datasets[data_type] = Dataset.from_dict(
+                self._preprocess(data["dialog"]))
+
+        return tokenized_datasets
+
+    def _preprocess(self, examples):
+        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+        if isinstance(examples, dict):
+            examples = [examples]
+        for example in tqdm(examples):
+            inputs = self.tokenizer(example["in"],
+                                    max_length=self.max_input_length,
+                                    truncation=True)
+
+            # Setup the tokenizer for targets
+            with self.tokenizer.as_target_tokenizer():
+                labels = self.tokenizer(example["out"],
+                                        max_length=self.max_target_length,
+                                        truncation=True)
+            for key in ["input_ids", "attention_mask"]:
+                model_inputs[key].append(inputs[key])
+            model_inputs["labels"].append(labels["input_ids"])
+
+        return model_inputs
+
+
+def train(model_type, data_name, dial_ids_order, split2ratio, batch_size=16, max_input_length=500, max_target_length=500, model_checkpoint="facebook/bart-base"):
+    tokenizer = TOKENIZER
+
+    train_helper = TrainerHelper(
+        tokenizer=tokenizer, max_input_length=max_input_length, max_target_length=max_target_length)
+    data = train_helper.parse_data(model_type=model_type,
+                                   data_name=data_name,
+                                   dial_ids_order=dial_ids_order,
+                                   split2ratio=split2ratio)
+
+    model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
+    model.resize_token_embeddings(len(tokenizer))
+    fp16 = False
+    if torch.cuda.is_available():
+        fp16 = True
+
+    model_dir = os.path.join(
+        train_helper.get_model_folder(model_type),
+        f"{datetime.now().strftime('%y-%m-%d-%H-%M')}")
+
+    args = Seq2SeqTrainingArguments(
+        model_dir,
+        evaluation_strategy="epoch",
+        learning_rate=2e-5,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        weight_decay=0.01,
+        save_total_limit=2,
+        num_train_epochs=5,
+        predict_with_generate=True,
+        fp16=fp16,
+        push_to_hub=False,
+        generation_max_length=max_target_length,
+        logging_dir=os.path.join(model_dir, 'log')
+    )
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer, model=model, padding=True)
+
+    # customize this trainer
+    trainer = Seq2SeqTrainer(
+        model=model,
+        args=args,
+        train_dataset=data["train"],
+        eval_dataset=data["test"],
+        data_collator=data_collator,
+        tokenizer=tokenizer,
+        compute_metrics=gentus_compute_metrics)
+    print("start training...")
+    trainer.train()
+    print("saving model...")
+    trainer.save_model()
+
+
+def main():
+    args = arg_parser()
+    print("---> data_name", args.data_name)
+    train(model_type=args.model_type,
+          data_name=args.data_name,
+          dial_ids_order=args.dial_ids_order,
+          split2ratio=args.split2ratio,
+          batch_size=args.batch_size,
+          max_input_length=MAX_IN_LEN,
+          max_target_length=MAX_OUT_LEN,
+          model_checkpoint=args.model_checkpoint)
+
+
+if __name__ == "__main__":
+    main()
+    # sgd+tm: 46000
diff --git a/convlab/policy/genTUS/unify/Goal.py b/convlab/policy/genTUS/unify/Goal.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f19338fc2f7ebc521bda7e1a8f877a6417a57e
--- /dev/null
+++ b/convlab/policy/genTUS/unify/Goal.py
@@ -0,0 +1,233 @@
+"""
+The user goal for unify data format
+"""
+import json
+from convlab.policy.tus.unify.Goal import old_goal2list
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
+from convlab.util.custom_util import slot_mapping
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, ""]
+
+NOT_MENTIONED = "not mentioned"
+FULFILLED = "fulfilled"
+REQUESTED = "requested"
+CONFLICT = "conflict"
+
+
+class Goal:
+    """ User Goal Model Class. """
+
+    def __init__(self, goal=None, goal_generator=None):
+        """
+        create new Goal from a dialog or from goal_generator
+        Args:
+            goal: can be a list (create from a dialog), an abus goal, or none
+        """
+        self.domains = []
+        self.domain_goals = {}
+        self.status = {}
+        self.invert_slot_mapping = {v: k for k, v in slot_mapping.items()}
+        self.raw_goal = None
+
+        self._init_goal_from_data(goal, goal_generator)
+        self._init_status()
+
+    def __str__(self):
+        return '-----Goal-----\n' + \
+               json.dumps(self.domain_goals, indent=4) + \
+               '\n-----Goal-----'
+
+    def _init_goal_from_data(self, goal=None, goal_generator=None):
+        if not goal and goal_generator:
+            goal = ABUS_Goal(goal_generator)
+            self.raw_goal = goal.domain_goals
+            goal = old_goal2list(goal.domain_goals)
+
+        elif isinstance(goal, dict):
+            self.raw_goal = goal
+            goal = old_goal2list(goal)
+
+        elif isinstance(goal, ABUS_Goal):
+            self.raw_goal = goal.domain_goals
+            goal = old_goal2list(goal.domain_goals)
+
+        else:
+            print("unknow goal")
+
+        # be careful of this order
+        for domain, intent, slot, value in goal:
+            if domain == "none":
+                continue
+            if domain not in self.domains:
+                self.domains.append(domain)
+                self.domain_goals[domain] = {}
+            if intent not in self.domain_goals[domain]:
+                self.domain_goals[domain][intent] = {}
+
+            if not value:
+                if intent == "request":
+                    self.domain_goals[domain][intent][slot] = DEF_VAL_UNK
+                else:
+                    print(
+                        f"unknown no value intent {domain}, {intent}, {slot}")
+            else:
+                self.domain_goals[domain][intent][slot] = value
+
+    def _init_status(self):
+        for domain, domain_goal in self.domain_goals.items():
+            if domain not in self.status:
+                self.status[domain] = {}
+            for slot_type, sub_domain_goal in domain_goal.items():
+                if slot_type not in self.status[domain]:
+                    self.status[domain][slot_type] = {}
+                for slot in sub_domain_goal:
+                    if slot not in self.status[domain][slot_type]:
+                        self.status[domain][slot_type][slot] = {}
+                    self.status[domain][slot_type][slot] = {
+                        "value": str(sub_domain_goal[slot]),
+                        "status": NOT_MENTIONED}
+
+    def get_goal_list(self, data_goal=None):
+        goal_list = []
+        if data_goal:
+            # make sure the order!!!
+            for domain, intent, slot, _ in data_goal:
+                status = self._get_status(domain, intent, slot)
+                value = self.domain_goals[domain][intent][slot]
+                goal_list.append([intent, domain, slot, value, status])
+            return goal_list
+        else:
+            for domain, domain_goal in self.domain_goals.items():
+                for intent, sub_goal in domain_goal.items():
+                    for slot, value in sub_goal.items():
+                        status = self._get_status(domain, intent, slot)
+                        goal_list.append([intent, domain, slot, value, status])
+
+        return goal_list
+
+    def _get_status(self, domain, intent, slot):
+        if domain not in self.status:
+            return NOT_MENTIONED
+        if intent not in self.status[domain]:
+            return NOT_MENTIONED
+        if slot not in self.status[domain][intent]:
+            return NOT_MENTIONED
+        return self.status[domain][intent][slot]["status"]
+
+    def task_complete(self):
+        """
+        Check that all requests have been met
+        Returns:
+            (boolean): True to accomplish.
+        """
+        for domain, domain_goal in self.status.items():
+            if domain not in self.domain_goals:
+                continue
+            for slot_type, sub_domain_goal in domain_goal.items():
+                if slot_type not in self.domain_goals[domain]:
+                    continue
+                for slot, status in sub_domain_goal.items():
+                    if slot not in self.domain_goals[domain][slot_type]:
+                        continue
+                    # for strict success, turn this on
+                    if status["status"] in [NOT_MENTIONED, CONFLICT]:
+                        if status["status"] == CONFLICT and slot in ["arrive by", "leave at"]:
+                            continue
+                        return False
+                    if "?" in status["value"]:
+                        return False
+
+        return True
+
+    # TODO change to update()?
+    def update_user_goal(self, action, char="usr"):
+        # update request and booked
+        if char == "usr":
+            self._user_action_update(action)
+        elif char == "sys":
+            self._system_action_update(action)
+        else:
+            print("!!!UNKNOWN CHAR!!!")
+
+    def _user_action_update(self, action):
+        # no need to update user goal
+        for intent, domain, slot, _ in action:
+            goal_intent = self._check_slot_and_intent(domain, slot)
+            if not goal_intent:
+                continue
+            # fulfilled by user
+            if is_inform(intent):
+                self._set_status(goal_intent, domain, slot, FULFILLED)
+            # requested by user
+            if is_request(intent):
+                self._set_status(goal_intent, domain, slot, REQUESTED)
+
+    def _system_action_update(self, action):
+        for intent, domain, slot, value in action:
+            goal_intent = self._check_slot_and_intent(domain, slot)
+            if not goal_intent:
+                continue
+            # fulfill request by system
+            if is_inform(intent) and is_request(goal_intent):
+                self._set_status(goal_intent, domain, slot, FULFILLED)
+                self._set_goal(goal_intent, domain, slot, value)
+
+            if is_inform(intent) and is_inform(goal_intent):
+                # fulfill infrom by system
+                if value == self.domain_goals[domain][goal_intent][slot]:
+                    self._set_status(goal_intent, domain, slot, FULFILLED)
+                # conflict system inform
+                else:
+                    self._set_status(goal_intent, domain, slot, CONFLICT)
+            # requested by system
+            if is_request(intent) and is_inform(goal_intent):
+                self._set_status(goal_intent, domain, slot, REQUESTED)
+
+    def _set_status(self, intent, domain, slot, status):
+        self.status[domain][intent][slot]["status"] = status
+
+    def _set_goal(self, intent, domain, slot, value):
+        # old_value = self.domain_goals[domain][intent][slot]
+        self.domain_goals[domain][intent][slot] = value
+        self.status[domain][intent][slot]["value"] = value
+        # print(
+        #     f"updating user goal {intent}-{domain}-{slot} {old_value}-> {value}")
+
+    def _check_slot_and_intent(self, domain, slot):
+        not_found = ""
+        if domain not in self.domain_goals:
+            return not_found
+        for intent in self.domain_goals[domain]:
+            if slot in self.domain_goals[domain][intent]:
+                return intent
+        return not_found
+
+
+def is_inform(intent):
+    if "inform" in intent:
+        return True
+    return False
+
+
+def is_request(intent):
+    if "request" in intent:
+        return True
+    return False
+
+
+def transform_data_act(data_action):
+    action_list = []
+    for _, dialog_act in data_action.items():
+        for act in dialog_act:
+            value = act.get("value", "")
+            if not value:
+                if "request" in act["intent"]:
+                    value = "?"
+                else:
+                    value = "none"
+            action_list.append(
+                [act["intent"], act["domain"], act["slot"], value])
+    return action_list
diff --git a/convlab/policy/genTUS/unify/build_data.py b/convlab/policy/genTUS/unify/build_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..50873a1d4b6ffaa8ee49c84a4b088ae56ad13554
--- /dev/null
+++ b/convlab/policy/genTUS/unify/build_data.py
@@ -0,0 +1,211 @@
+import json
+import os
+import sys
+from argparse import ArgumentParser
+
+from tqdm import tqdm
+
+from convlab.policy.genTUS.unify.Goal import Goal, transform_data_act
+from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
+
+
+sys.path.append(os.path.dirname(os.path.dirname(
+    os.path.dirname(os.path.abspath(__file__)))))
+
+
+def arg_parser():
+    parser = ArgumentParser()
+    parser.add_argument("--dataset", type=str, default="multiwoz21",
+                        help="the dataset, such as multiwoz21, sgd, tm1, tm2, and tm3.")
+    parser.add_argument("--dial-ids-order", type=int, default=0)
+    parser.add_argument("--split2ratio", type=float, default=1)
+    parser.add_argument("--random-order", action="store_true")
+    parser.add_argument("--no-status", action="store_true")
+    parser.add_argument("--add-history",  action="store_true")
+    parser.add_argument("--remove-domain", type=str, default="")
+
+    return parser.parse_args()
+
+class DataBuilder:
+    def __init__(self, dataset='multiwoz21'):
+        self.dataset = dataset
+
+    def setup_data(self,
+                   raw_data,
+                   random_order=False,
+                   no_status=False,
+                   add_history=False,
+                   remove_domain=None):
+        examples = {data_split: {"dialog": []} for data_split in raw_data}
+
+        for data_split, dialogs in raw_data.items():
+            for dialog in tqdm(dialogs, ascii=True):
+                example = self._one_dialog(dialog=dialog,
+                                           add_history=add_history,
+                                           random_order=random_order,
+                                           no_status=no_status)
+                examples[data_split]["dialog"] += example
+
+        return examples
+
+    def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False):
+        example = []
+        history = []
+
+        data_goal = self.norm_domain_goal(create_goal(dialog))
+        if not data_goal:
+            return example
+        user_goal = Goal(goal=data_goal)
+
+        for turn_id in range(0, len(dialog["turns"]), 2):
+            sys_act = self._get_sys_act(dialog, turn_id)
+
+            user_goal.update_user_goal(action=sys_act, char="sys")
+            usr_goal_str = self._user_goal_str(user_goal, data_goal, random_order, no_status)
+
+            usr_act = self.norm_domain(transform_data_act(
+                dialog["turns"][turn_id]["dialogue_acts"]))
+            user_goal.update_user_goal(action=usr_act, char="usr")
+
+            # change value "?" to "<?>"
+            usr_act = self._modify_act(usr_act)
+
+            in_str = self._dump_in_str(sys_act, usr_goal_str, history, turn_id, add_history)
+            out_str = self._dump_out_str(usr_act, dialog["turns"][turn_id]["utterance"])
+
+            history.append(usr_act)
+            if usr_act:
+                example.append({"in": in_str, "out": out_str})
+
+        return example
+
+    def _get_sys_act(self, dialog, turn_id):
+        sys_act = []
+        if turn_id > 0:
+            sys_act = self.norm_domain(transform_data_act(
+                dialog["turns"][turn_id - 1]["dialogue_acts"]))
+        return sys_act
+
+    def _user_goal_str(self, user_goal, data_goal, random_order, no_status):
+        if random_order:
+            usr_goal_str = user_goal.get_goal_list()
+        else:
+            usr_goal_str = user_goal.get_goal_list(data_goal=data_goal)
+
+        if no_status:
+            usr_goal_str = self._remove_status(usr_goal_str)
+        return usr_goal_str
+
+    def _dump_in_str(self, sys_act, usr_goal_str, history, turn_id, add_history):
+        in_str = {}
+        in_str["system"] = self._modify_act(sys_act)
+        in_str["goal"] = usr_goal_str
+        if add_history:
+            h = []
+            if history:
+                h = history[-3:]
+            in_str["history"] = h
+            in_str["turn"] = str(int(turn_id/2))
+
+        return json.dumps(in_str)
+
+    def _dump_out_str(self, usr_act, text):
+        out_str = {"action": usr_act, "text": text}
+        return json.dumps(out_str)
+
+    @staticmethod
+    def _norm_intent(intent):
+        if intent in ["inform_intent", "negate_intent", "affirm_intent", "request_alts"]:
+            return f"_{intent}"
+        return intent
+
+    def norm_domain(self, x):
+        if not x:
+            return x
+        norm_result = []
+        # print(x)
+        for intent, domain, slot, value in x:
+            if "_" in domain:
+                domain = domain.split('_')[0]
+            if not domain:
+                domain = "none"
+            if not slot:
+                slot = "none"
+            if not value:
+                if intent == "request":
+                    value = "<?>"
+                else:
+                    value = "none"
+            norm_result.append([self._norm_intent(intent), domain, slot, value])
+        return norm_result
+
+    def norm_domain_goal(self, x):
+        if not x:
+            return x
+        norm_result = []
+        # take care of the order!
+        for domain, intent, slot, value in x:
+            if "_" in domain:
+                domain = domain.split('_')[0]
+            if not domain:
+                domain = "none"
+            if not slot:
+                slot = "none"
+            if not value:
+                if intent == "request":
+                    value = "<?>"
+                else:
+                    value = "none"
+            norm_result.append([domain, self._norm_intent(intent), slot, value])
+        return norm_result
+
+    @staticmethod
+    def _remove_status(goal_list):
+        new_list = [[goal[0], goal[1], goal[2], goal[3]]
+                    for goal in goal_list]
+        return new_list
+
+    @staticmethod
+    def _modify_act(act):
+        new_act = []
+        for i, d, s, value in act:
+            if value == "?":
+                new_act.append([i, d, s, "<?>"])
+            else:
+                new_act.append([i, d, s, value])
+        return new_act
+
+
+if __name__ == "__main__":
+    args = arg_parser()
+
+    base_name = "convlab/policy/genTUS/unify/data"
+    dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}"
+    folder_name = os.path.join(base_name, dir_name)
+    remove_domain = args.remove_domain
+
+    if not os.path.exists(folder_name):
+        os.makedirs(folder_name)
+
+    dataset = load_experiment_dataset(
+        data_name=args.dataset,
+        dial_ids_order=args.dial_ids_order,
+        split2ratio=args.split2ratio)
+    data_builder = DataBuilder(dataset=args.dataset)
+    data = data_builder.setup_data(
+        raw_data=dataset,
+        random_order=args.random_order,
+        no_status=args.no_status,
+        add_history=args.add_history,
+        remove_domain=remove_domain)
+
+    for data_type in data:
+        if remove_domain:
+            file_name = os.path.join(
+                folder_name,
+                f"no{remove_domain}_{data_type}.json")
+        else:
+            file_name = os.path.join(
+                folder_name,
+                f"{data_type}.json")
+        json.dump(data[data_type], open(file_name, 'w'), indent=2)
diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..68af13e481fe4799dfc2a6f3763b526611eabd9c
--- /dev/null
+++ b/convlab/policy/genTUS/unify/knowledge_graph.py
@@ -0,0 +1,252 @@
+import json
+from random import choices
+
+from convlab.policy.genTUS.token_map import tokenMap
+
+from transformers import BartTokenizer
+
+DEBUG = False
+DATASET = "unify"
+
+
+class KnowledgeGraph:
+    def __init__(self, tokenizer: BartTokenizer, ontology_file=None, dataset="multiwoz21"):
+        print("dataset", dataset)
+        self.debug = DEBUG
+        self.tokenizer = tokenizer
+
+        if "multiwoz" in dataset:
+            self.domain_intent = ["inform", "request"]
+            self.general_intent = ["thank", "bye"]
+        # use sgd dataset intents as default
+        else:
+            self.domain_intent = ["_inform_intent",
+                                  "_negate_intent",
+                                  "_affirm_intent",
+                                  "inform",
+                                  "request",
+                                  "affirm",
+                                  "negate",
+                                  "select",
+                                  "_request_alts"]
+            self.general_intent = ["thank_you", "goodbye"]
+
+        self.general_domain = "none"
+        self.kg_map = {"intent": tokenMap(tokenizer=self.tokenizer)}
+
+        for intent in self.domain_intent + self.general_intent:
+            self.kg_map["intent"].add_token(intent, intent)
+
+        self.init()
+
+    def init(self):
+        for map_type in ["domain", "slot", "value"]:
+            self.kg_map[map_type] = tokenMap(tokenizer=self.tokenizer)
+        self.add_token("<?>", "value")
+
+    def parse_input(self, in_str):
+        self.init()
+        inputs = json.loads(in_str)
+        self.sys_act = inputs["system"]
+        self.user_goal = {}
+        self._add_none_domain()
+        for intent, domain, slot, value, _ in inputs["goal"]:
+            self._update_user_goal(intent, domain, slot, value, source="goal")
+
+        for intent, domain, slot, value in self.sys_act:
+            self._update_user_goal(intent, domain, slot, value, source="sys")
+
+    def _add_none_domain(self):
+        self.user_goal["none"] = {"none": "none"}
+        # add slot
+        self.add_token("none", "domain")
+        self.add_token("none", "slot")
+        self.add_token("none", "value")
+
+    def _update_user_goal(self, intent, domain, slot, value, source="goal"):
+
+        if value == "?":
+            value = "<?>"
+
+        if intent == "request" and source == "sys":
+            value = "dontcare"  # user can "dontcare" system request
+
+        if source == "sys" and intent != "request":
+            return
+
+        if domain not in self.user_goal:
+            self.user_goal[domain] = {}
+            self.user_goal[domain]["none"] = ["none"]
+            self.add_token(domain, "domain")
+            self.add_token("none", "slot")
+            self.add_token("none", "value")
+
+        if slot not in self.user_goal[domain]:
+            self.user_goal[domain][slot] = []
+            self.add_token(domain, "slot")
+
+        if value not in self.user_goal[domain][slot]:
+            value = json.dumps(str(value))[1:-1]
+            self.user_goal[domain][slot].append(value)
+            value = value.replace('"', "'")
+            self.add_token(value, "value")
+
+    def add_token(self, token_name, map_type):
+        if map_type == "value":
+            token_name = token_name.replace('"', "'")
+        if not self.kg_map[map_type].token_name_is_in(token_name):
+            self.kg_map[map_type].add_token(token_name, token_name)
+
+    def _get_max_score(self, outputs, candidate_list, map_type):
+        score = {}
+        if not candidate_list:
+            print(f"ERROR: empty candidate list for {map_type}")
+            score[1] = {"token_id": self._get_token_id(
+                "none"), "token_name": "none"}
+
+        for x in candidate_list:
+            hash_id = self._get_token_id(x)[0]
+            s = outputs[:, hash_id].item()
+            score[s] = {"token_id": self._get_token_id(x),
+                        "token_name": x}
+        return score
+
+    def _select(self, score, mode="max"):
+        probs = [s for s in score]
+        if mode == "max":
+            s = max(probs)
+        elif mode == "sample":
+            s = choices(probs, weights=probs, k=1)
+            s = s[0]
+
+        else:
+            print("unknown select mode")
+
+        return s
+
+    def _get_max_domain_token(self, outputs, candidates, map_type, mode="max"):
+        score = self._get_max_score(outputs, candidates, map_type)
+        s = self._select(score, mode)
+        token_id = score[s]["token_id"]
+        token_name = score[s]["token_name"]
+
+        return {"token_id": token_id, "token_name": token_name}
+
+    def candidate(self, candidate_type, **kwargs):
+        if "intent" in kwargs:
+            intent = kwargs["intent"]
+        if candidate_type == "intent":
+            allow_general_intent = kwargs.get("allow_general_intent", True)
+            if allow_general_intent:
+                return self.domain_intent + self.general_intent
+            else:
+                return self.domain_intent
+        elif candidate_type == "domain":
+            if intent in self.general_intent:
+                return [self.general_domain]
+            else:
+                return [d for d in self.user_goal]
+        elif candidate_type == "slot":
+            if intent in self.general_intent:
+                return ["none"]
+            else:
+                return self._filter_slot(intent, kwargs["domain"], kwargs["is_mentioned"])
+        else:
+            if intent in self.general_intent:
+                return ["none"]
+            elif intent.lower() == "request":
+                return ["<?>"]
+            else:
+                return self._filter_value(intent, kwargs["domain"], kwargs["slot"])
+
+    def get_intent(self, outputs, mode="max", allow_general_intent=True):
+        # return intent, token_id_list
+        # TODO request?
+        canidate_list = self.candidate(
+            "intent", allow_general_intent=allow_general_intent)
+        score = self._get_max_score(outputs, canidate_list, "intent")
+        s = self._select(score, mode)
+
+        return score[s]
+
+    def get_domain(self, outputs, intent, mode="max"):
+        if intent in self.general_intent:
+            token_name = self.general_domain
+            token_id = self.tokenizer(token_name, add_special_tokens=False)
+            token_map = {"token_id": token_id['input_ids'],
+                         "token_name": token_name}
+
+        elif intent in self.domain_intent:
+            # [d for d in self.user_goal]
+            domain_list = self.candidate("domain", intent=intent)
+            token_map = self._get_max_domain_token(
+                outputs=outputs, candidates=domain_list, map_type="domain", mode=mode)
+        else:
+            if self.debug:
+                print("unknown intent", intent)
+
+        return token_map
+
+    def get_slot(self, outputs, intent, domain, mode="max", is_mentioned=False):
+        if intent in self.general_intent:
+            token_name = "none"
+            token_id = self.tokenizer(token_name, add_special_tokens=False)
+            token_map = {"token_id": token_id['input_ids'],
+                         "token_name": token_name}
+
+        elif intent in self.domain_intent:
+            slot_list = self.candidate(
+                candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned)
+            token_map = self._get_max_domain_token(
+                outputs=outputs, candidates=slot_list, map_type="slot", mode=mode)
+
+        return token_map
+
+    def get_value(self, outputs, intent, domain, slot, mode="max"):
+        if intent in self.general_intent or slot.lower() == "none":
+            token_name = "none"
+            token_id = self.tokenizer(token_name, add_special_tokens=False)
+            token_map = {"token_id": token_id['input_ids'],
+                         "token_name": token_name}
+
+        elif intent.lower() == "request":
+            token_name = "<?>"
+            token_id = self.tokenizer(token_name, add_special_tokens=False)
+            token_map = {"token_id": token_id['input_ids'],
+                         "token_name": token_name}
+
+        elif intent in self.domain_intent:
+            # TODO should not none ?
+            # value_list = [v for v in self.user_goal[domain][slot]]
+            value_list = self.candidate(
+                candidate_type="value", intent=intent, domain=domain, slot=slot)
+
+            token_map = self._get_max_domain_token(
+                outputs=outputs, candidates=value_list, map_type="value", mode=mode)
+
+        return token_map
+
+    def _filter_slot(self, intent, domain, is_mentioned=True):
+        slot_list = []
+        for slot in self.user_goal[domain]:
+            value_list = self._filter_value(intent, domain, slot)
+            if len(value_list) > 0:
+                slot_list.append(slot)
+        if not is_mentioned and intent.lower() != "request":
+            slot_list.append("none")
+        return slot_list
+
+    def _filter_value(self, intent, domain, slot):
+        value_list = [v for v in self.user_goal[domain][slot]]
+        if "none" in value_list:
+            value_list.remove("none")
+        if intent.lower() != "request":
+            if "?" in value_list:
+                value_list.remove("?")
+            if "<?>" in value_list:
+                value_list.remove("<?>")
+        # print(f"{intent}-{domain}-{slot}= {value_list}")
+        return value_list
+
+    def _get_token_id(self, token):
+        return self.tokenizer(token, add_special_tokens=False)["input_ids"]
diff --git a/convlab/policy/genTUS/utils.py b/convlab/policy/genTUS/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c822dd35f985790d53a2834e8b6fe437864f24
--- /dev/null
+++ b/convlab/policy/genTUS/utils.py
@@ -0,0 +1,5 @@
+import torch
+
+
+def append_tokens(tokens, new_token, device):
+    return torch.cat((tokens, torch.tensor([new_token]).to(device)), dim=1)
diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
new file mode 100644
index 0000000000000000000000000000000000000000..5bf65c9f97f951a367a0abf461e9aa9172d64021
--- /dev/null
+++ b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
@@ -0,0 +1,44 @@
+{
+	"model": {
+		"load_path": "convlab/policy/ppo/pretrained_models/supervised",
+		"pretrained_load_path": "",
+		"use_pretrained_initialisation": false,
+		"batchsz": 1000,
+		"seed": 0,
+		"epoch": 50,
+		"eval_frequency": 5,
+		"process_num": 1,
+		"num_eval_dialogues": 500,
+		"sys_semantic_to_usr": false
+	},
+	"vectorizer_sys": {
+		"uncertainty_vector_mul": {
+			"class_path": "convlab.policy.vector.vector_binary.VectorBinary",
+			"ini_params": {
+				"use_masking": true,
+				"manually_add_entity_names": true,
+				"seed": 0
+			}
+		}
+	},
+	"nlu_sys": {},
+	"dst_sys": {
+		"RuleDST": {
+			"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
+			"ini_params": {}
+		}
+	},
+	"sys_nlg": {},
+	"nlu_usr": {},
+	"dst_usr": {},
+	"policy_usr": {
+		"RulePolicy": {
+			"class_path": "convlab.policy.genTUS.stepGenTUS.UserPolicy",
+			"ini_params": {
+				"model_checkpoint": "convlab/policy/genTUS/unify/experiments/multiwoz21_0_1.0",
+				"character": "usr"
+			}
+		}
+	},
+	"usr_nlg": {}
+}
\ No newline at end of file
diff --git a/convlab/policy/tus/unify/Goal.py b/convlab/policy/tus/unify/Goal.py
index 82a6e7a9799b4537495cf35e57f6f52711e78af8..610469bb4246cf6c16ef4271c89827cab6421437 100644
--- a/convlab/policy/tus/unify/Goal.py
+++ b/convlab/policy/tus/unify/Goal.py
@@ -1,6 +1,9 @@
 import time
 import json
-from convlab.policy.tus.unify.util import split_slot_name
+from convlab.policy.tus.unify.util import split_slot_name, slot_name_map
+from convlab.util.custom_util import slot_mapping
+
+from random import sample, shuffle
 from pprint import pprint
 DEF_VAL_UNK = '?'  # Unknown
 DEF_VAL_DNC = 'dontcare'  # Do not care
@@ -27,6 +30,35 @@ def isTimeFormat(input):
         return False
 
 
+def old_goal2list(goal: dict, reorder=False) -> list:
+    goal_list = []
+    for domain in goal:
+        for slot_type in ['info', 'book', 'reqt']:
+            if slot_type not in goal[domain]:
+                continue
+            temp = []
+            for slot in goal[domain][slot_type]:
+                s = slot
+                if slot in slot_name_map:
+                    s = slot_name_map[slot]
+                elif slot in slot_name_map[domain]:
+                    s = slot_name_map[domain][slot]
+                # domain, intent, slot, value
+                if slot_type in ['info', 'book']:
+                    i = "inform"
+                    v = goal[domain][slot_type][slot]
+                else:
+                    i = "request"
+                    v = DEF_VAL_UNK
+                s = slot_mapping.get(s, s)
+                temp.append([domain, i, s, v])
+            shuffle(temp)
+            goal_list = goal_list + temp
+    # shuffle_goal = goal_list[:1] + sample(goal_list[1:], len(goal_list)-1)
+    # return shuffle_goal
+    return goal_list
+
+
 class Goal(object):
     """ User Goal Model Class. """
 
@@ -101,6 +133,7 @@ class Goal(object):
                     if self.domain_goals[domain]["reqt"][slot] == DEF_VAL_UNK:
                         # print(f"not fulfilled request{domain}-{slot}")
                         return False
+
         return True
 
     def init_local_id(self):
@@ -169,6 +202,10 @@ class Goal(object):
 
     def _update_status(self, action: list, char: str):
         for intent, domain, slot, value in action:
+            if slot == "arrive by":
+                slot = "arriveBy"
+            elif slot == "leave at":
+                slot = "leaveAt"
             if domain not in self.status:
                 self.status[domain] = {}
             # update info
@@ -180,6 +217,10 @@ class Goal(object):
     def _update_goal(self, action: list, char: str):
         # update requt slots in goal
         for intent, domain, slot, value in action:
+            if slot == "arrive by":
+                slot = "arriveBy"
+            elif slot == "leave at":
+                slot = "leaveAt"
             if "info" not in intent:
                 continue
             if self._check_update_request(domain, slot) and value != "?":
diff --git a/convlab/policy/tus/unify/TUS.py b/convlab/policy/tus/unify/TUS.py
index 09c50672fb889ffd3e965ec0e90d2dee30b2c360..c380df3914d5d636c95b0e76081b361c26d79ff3 100644
--- a/convlab/policy/tus/unify/TUS.py
+++ b/convlab/policy/tus/unify/TUS.py
@@ -5,19 +5,18 @@ from copy import deepcopy
 
 import torch
 from convlab.policy.policy import Policy
-from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import (
-    act_dict_to_flat_tuple, unified_format)
 from convlab.policy.tus.multiwoz.transformer import TransformerActionPrediction
 from convlab.policy.tus.unify.Goal import Goal
 from convlab.policy.tus.unify.usermanager import BinaryFeature
-from convlab.policy.tus.unify.util import (create_goal, int2onehot,
-                                           metadata2state, parse_dialogue_act,
-                                           parse_user_goal, split_slot_name)
+from convlab.policy.tus.unify.util import create_goal, split_slot_name
 from convlab.util import (load_dataset,
                           relative_import_module_from_unified_datasets)
 from convlab.util.custom_util import model_downloader
-from convlab.util.multiwoz.multiwoz_slot_trans import REF_USR_DA
-from pprint import pprint
+from convlab.task.multiwoz.goal_generator import GoalGenerator
+from convlab.policy.tus.unify.Goal import old_goal2list
+from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import Goal as ABUS_Goal
+
+
 reverse_da, normalize_domain_slot_value = relative_import_module_from_unified_datasets(
     'multiwoz21', 'preprocess.py', ['reverse_da', 'normalize_domain_slot_value'])
 
@@ -51,7 +50,7 @@ class UserActionPolicy(Policy):
         self.user = TransformerActionPrediction(self.config).to(device=DEVICE)
         if pretrain:
             model_path = os.path.join(
-                self.config["model_dir"], self.config["model_name"])
+                self.config["model_dir"], "model-non-zero")  # self.config["model_name"])
             print(f"loading model from {model_path}...")
             self.load(model_path)
         self.user.eval()
@@ -61,6 +60,8 @@ class UserActionPolicy(Policy):
         self.reward = {"success": 40,
                        "fail": -20}
         self.sys_acts = []
+        self.goal_gen = GoalGenerator()
+        self.raw_goal = None
 
     def _no_offer(self, system_in):
         for intent, domain, slot, value in system_in:
@@ -127,13 +128,15 @@ class UserActionPolicy(Policy):
         self.topic = 'NONE'
         remove_domain = "police"  # remove police domain in inference
 
-        # if not goal:
-        #     self.new_goal(remove_domain=remove_domain)
-        # else:
-        #     self.read_goal(goal)
-        if not goal:
-            data = load_dataset(self.dataset, 0)
-            goal = Goal(create_goal(data["test"][0]))
+        if type(goal) == ABUS_Goal:
+            self.raw_goal = goal.domain_goals
+            goal_list = old_goal2list(goal.domain_goals)
+            goal = Goal(goal_list)
+        else:
+            goal = ABUS_Goal(self.goal_gen)
+            self.raw_gaol = goal.domain_goals
+            goal_list = old_goal2list(goal.domain_goals)
+            goal = Goal(goal_list)
 
         self.read_goal(goal)
         self.feat_handler.initFeatureHandeler(self.goal)
@@ -155,15 +158,15 @@ class UserActionPolicy(Policy):
         else:
             self.goal = Goal(goal=data_goal)
 
-    def new_goal(self, remove_domain="police", domain_len=None):
-        keep_generate_goal = True
-        while keep_generate_goal:
-            self.goal = Goal(goal_generator=self.goal_gen)
-            if (domain_len and len(self.goal.domains) != domain_len) or \
-                    (remove_domain and remove_domain in self.goal.domains):
-                keep_generate_goal = True
-            else:
-                keep_generate_goal = False
+    # def new_goal(self, remove_domain="police", domain_len=None):
+    #     keep_generate_goal = True
+    #     while keep_generate_goal:
+    #         self.goal = Goal(goal_generator=self.goal_gen)
+    #         if (domain_len and len(self.goal.domains) != domain_len) or \
+    #                 (remove_domain and remove_domain in self.goal.domains):
+    #             keep_generate_goal = True
+    #         else:
+    #             keep_generate_goal = False
 
     def load(self, model_path=None):
         self.user.load_state_dict(torch.load(model_path, map_location=DEVICE))
@@ -171,9 +174,14 @@ class UserActionPolicy(Policy):
     def load_state_dict(self, model=None):
         self.user.load_state_dict(model)
 
-    def get_goal(self):
+    def _get_goal(self):
+        # internal usage
         return self.goal.domain_goals
 
+    def get_goal(self):
+        # for outside usage, e.g. evaluator
+        return self.raw_goal
+
     def get_reward(self):
         if self.goal.task_complete():
             # reward = 2 * self.max_turn
@@ -330,7 +338,7 @@ class UserActionPolicy(Policy):
         return predict_domain
 
     def _add_user_action(self, output, domain, slot):
-        goal = self.get_goal()
+        goal = self._get_goal()
         is_action = False
         act = [[]]
         value = None
@@ -403,16 +411,32 @@ class UserPolicy(Policy):
             self.config = json.load(open(config))
         else:
             self.config = config
-        self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}'
+        self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz'
         if not os.path.exists(self.config["model_dir"]):
             # os.mkdir(self.config["model_dir"])
             model_downloader(os.path.dirname(self.config["model_dir"]),
                              "https://zenodo.org/record/5779832/files/default.zip")
+        self.slot2dbattr = {
+            'open hours': 'openhours',
+            'price range': 'pricerange',
+            'arrive by': 'arriveBy',
+            'leave at': 'leaveAt',
+            'train id': 'trainID'
+        }
+        self.dbattr2slot = {}
+        for k, v in self.slot2dbattr.items():
+            self.dbattr2slot[v] = k
 
         self.policy = UserActionPolicy(self.config)
 
     def predict(self, state):
-        return self.policy.predict(state)
+        raw_act = self.policy.predict(state)
+        act = []
+        for intent, domain, slot, value in raw_act:
+            if slot in self.dbattr2slot:
+                slot = self.dbattr2slot[slot]
+            act.append([intent, domain, slot, value])
+        return act
 
     def init_session(self, goal=None):
         self.policy.init_session(goal)
@@ -424,14 +448,8 @@ class UserPolicy(Policy):
         return self.policy.get_reward()
 
     def get_goal(self):
-        slot2dbattr = {
-            'open hours': 'openhours',
-            'price range': 'pricerange',
-            'arrive by': 'arriveBy',
-            'leave at': 'leaveAt',
-            'train id': 'trainID'
-        }
         if hasattr(self.policy, 'get_goal'):
+            return self.policy.get_goal()
             # workaround: convert goal to old format
             multiwoz_goal = {}
             goal = self.policy.get_goal()
@@ -449,8 +467,8 @@ class UserPolicy(Policy):
                                 multiwoz_goal[domain]["book"] = {}
                             norm_slot = slot.split(' ')[-1]
                             multiwoz_goal[domain]["book"][norm_slot] = value
-                        elif slot in slot2dbattr:
-                            norm_slot = slot2dbattr[slot]
+                        elif slot in self.slot2dbattr:
+                            norm_slot = self.slot2dbattr[slot]
                             multiwoz_goal[domain][slot_type][norm_slot] = value
                         else:
                             multiwoz_goal[domain][slot_type][slot] = value
diff --git a/convlab/policy/tus/unify/usermanager.py b/convlab/policy/tus/unify/usermanager.py
index 640da7b9ee01414f04700e6aaa2976f9c916914f..3192d3d578ca3698424b16f47933d6863208f77c 100644
--- a/convlab/policy/tus/unify/usermanager.py
+++ b/convlab/policy/tus/unify/usermanager.py
@@ -97,7 +97,7 @@ class TUSDataManager(Dataset):
                     action_list, user_goal, cur_state, usr_act)
                 domain_label = self.feature_handler.domain_label(
                     user_goal, usr_act)
-                pre_state = user_goal.update(action=usr_act, char="user")
+                # pre_state = user_goal.update(action=usr_act, char="user") # trick?
                 feature["id"].append(dialog["dialogue_id"])
                 feature["input"].append(input_feature)
                 feature["mask"].append(mask)
diff --git a/convlab/policy/tus/unify/util.py b/convlab/policy/tus/unify/util.py
index 1978f489534f74839b11dee333e232cbc601f270..d65f72a06e181e66bfe0d7ac0c60f0c03a56ad43 100644
--- a/convlab/policy/tus/unify/util.py
+++ b/convlab/policy/tus/unify/util.py
@@ -1,9 +1,49 @@
 from convlab.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+from convlab.util import load_dataset
+
 import json
 
 NOT_MENTIONED = "not mentioned"
 
 
+def load_experiment_dataset(data_name="multiwoz21", dial_ids_order=0, split2ratio=1):
+    ratio = {'train': split2ratio, 'validation': split2ratio}
+    if data_name == "all" or data_name == "sgd+tm" or data_name == "tm":
+        print("merge all datasets...")
+        if data_name == "all":
+            all_dataset = ["multiwoz21", "sgd", "tm1", "tm2", "tm3"]
+        if data_name == "sgd+tm":
+            all_dataset = ["sgd", "tm1", "tm2", "tm3"]
+        if data_name == "tm":
+            all_dataset = ["tm1", "tm2", "tm3"]
+
+        datasets = {}
+        for name in all_dataset:
+            datasets[name] = load_dataset(
+                name,
+                dial_ids_order=dial_ids_order,
+                split2ratio=ratio)
+        raw_data = merge_dataset(datasets, all_dataset[0])
+
+    else:
+        print(f"load single dataset {data_name}/{split2ratio}")
+        raw_data = load_dataset(data_name,
+                                dial_ids_order=dial_ids_order,
+                                split2ratio=ratio)
+    return raw_data
+
+
+def merge_dataset(datasets, data_name):
+    data_split = [x for x in datasets[data_name]]
+    raw_data = {}
+    for data_type in data_split:
+        raw_data[data_type] = []
+        for dataname, dataset in datasets.items():
+            print(f"merge {dataname}...")
+            raw_data[data_type] += dataset[data_type]
+    return raw_data
+
+
 def int2onehot(index, output_dim=6, remove_zero=False):
     one_hot = [0] * output_dim
     if remove_zero:
@@ -89,50 +129,6 @@ def get_booking_domain(slot, value, all_values, domain_list):
     return found
 
 
-def act2slot(intent, domain, slot, value, all_values):
-
-    if domain not in UsrDa2Goal:
-        # print(f"Not handle domain {domain}")
-        return ""
-
-    if domain == "booking":
-        slot = SysDa2Goal[domain][slot]
-        domain = get_booking_domain(slot, value, all_values)
-        return f"{domain}-{slot}"
-
-    elif domain in UsrDa2Goal:
-        if slot in SysDa2Goal[domain]:
-            slot = SysDa2Goal[domain][slot]
-        elif slot in UsrDa2Goal[domain]:
-            slot = UsrDa2Goal[domain][slot]
-        elif slot in SysDa2Goal["booking"]:
-            slot = SysDa2Goal["booking"][slot]
-        # else:
-        #     print(
-        #         f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}")
-
-        return f"{domain}-{slot}"
-
-    print("strange!!!")
-    print(intent, domain, slot, value)
-
-    return ""
-
-
-def get_user_history(dialog, all_values):
-    turn_num = len(dialog)
-    mentioned_slot = []
-    for turn_id in range(0, turn_num, 2):
-        usr_act = parse_dialogue_act(
-            dialog[turn_id]["dialog_act"])
-        for intent, domain, slot, value in usr_act:
-            slot_name = act2slot(
-                intent, domain.lower(), slot.lower(), value.lower(), all_values)
-            if slot_name not in mentioned_slot:
-                mentioned_slot.append(slot_name)
-    return mentioned_slot
-
-
 def update_config_file(file_name, attribute, value):
     with open(file_name, 'r') as config_file:
         config = json.load(config_file)
@@ -147,7 +143,7 @@ def update_config_file(file_name, attribute, value):
 def create_goal(dialog) -> list:
     # a list of {'intent': ..., 'domain': ..., 'slot': ..., 'value': ...}
     dicts = []
-    for i, turn in enumerate(dialog['turns']):
+    for turn in dialog['turns']:
         # print(turn['speaker'])
         # assert (i % 2 == 0) == (turn['speaker'] == 'user')
         # if i % 2 == 0:
@@ -205,6 +201,45 @@ def split_slot_name(slot_name):
         return tokens[0], '-'.join(tokens[1:])
 
 
+# copy from data.unified_datasets.multiwoz21
+slot_name_map = {
+    'addr': "address",
+    'post': "postcode",
+    'pricerange': "price range",
+    'arrive': "arrive by",
+    'arriveby': "arrive by",
+    'leave': "leave at",
+    'leaveat': "leave at",
+    'depart': "departure",
+    'dest': "destination",
+    'fee': "entrance fee",
+    'open': 'open hours',
+    'car': "type",
+    'car type': "type",
+    'ticket': 'price',
+    'trainid': 'train id',
+    'id': 'train id',
+    'people': 'book people',
+    'stay': 'book stay',
+    'none': '',
+    'attraction': {
+        'price': 'entrance fee'
+    },
+    'hospital': {},
+    'hotel': {
+        'day': 'book day', 'price': "price range"
+    },
+    'restaurant': {
+        'day': 'book day', 'time': 'book time', 'price': "price range"
+    },
+    'taxi': {},
+    'train': {
+        'day': 'day', 'time': "duration"
+    },
+    'police': {},
+    'booking': {}
+}
+
 if __name__ == "__main__":
     print(split_slot_name("restaurant-search-location"))
     print(split_slot_name("sports-day.match"))