From 2e9628573fe7f3020e5f1dc5cb072a6795c0fd88 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Mon, 14 Jun 2021 09:58:58 +0200
Subject: [PATCH] add tus

---
 convlab2/dst/rule/multiwoz/usr_dst.py | 217 ++++++++++
 convlab2/policy/tus/Da2Goal.py        |  87 ++++
 convlab2/policy/tus/Goal.py           | 405 +++++++++++++++++++
 convlab2/policy/tus/TUS.py            | 293 ++++++++++++++
 convlab2/policy/tus/__init__.py       |   0
 convlab2/policy/tus/analysis.py       | 430 ++++++++++++++++++++
 convlab2/policy/tus/config.json       |  14 +
 convlab2/policy/tus/exp/default.json  |  32 ++
 convlab2/policy/tus/train.py          | 215 ++++++++++
 convlab2/policy/tus/transformer.py    | 177 +++++++++
 convlab2/policy/tus/usermanager.py    | 550 ++++++++++++++++++++++++++
 convlab2/policy/tus/util.py           | 126 ++++++
 12 files changed, 2546 insertions(+)
 create mode 100755 convlab2/dst/rule/multiwoz/usr_dst.py
 create mode 100644 convlab2/policy/tus/Da2Goal.py
 create mode 100644 convlab2/policy/tus/Goal.py
 create mode 100644 convlab2/policy/tus/TUS.py
 create mode 100644 convlab2/policy/tus/__init__.py
 create mode 100644 convlab2/policy/tus/analysis.py
 create mode 100755 convlab2/policy/tus/config.json
 create mode 100644 convlab2/policy/tus/exp/default.json
 create mode 100644 convlab2/policy/tus/train.py
 create mode 100644 convlab2/policy/tus/transformer.py
 create mode 100644 convlab2/policy/tus/usermanager.py
 create mode 100644 convlab2/policy/tus/util.py

diff --git a/convlab2/dst/rule/multiwoz/usr_dst.py b/convlab2/dst/rule/multiwoz/usr_dst.py
new file mode 100755
index 0000000..00199fd
--- /dev/null
+++ b/convlab2/dst/rule/multiwoz/usr_dst.py
@@ -0,0 +1,217 @@
+import json
+import os
+
+from convlab2.util.multiwoz.state import default_state
+from convlab2.dst.rule.multiwoz.dst_util import normalize_value
+from convlab2.dst.dst import DST
+from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
+from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+from pprint import pprint
+
+# Maps lowercase slot names to the camelCase spellings used as keys in
+# the multiwoz belief-state 'semi' dictionaries.
+SLOT2SEMI = {
+    "arriveby": "arriveBy",
+    "leaveat": "leaveAt",
+    "trainid": "trainID",
+}
+
+
+class UserRuleDST(DST):
+    """Rule based DST which trivially updates new values from NLU result to states.
+
+    Attributes:
+        state(dict):
+            Dialog state. Function ``convlab2.util.multiwoz.state.default_state`` returns a default state.
+        value_dict(dict):
+            It helps check whether ``user_act`` has correct content.
+    """
+
+    def __init__(self):
+        DST.__init__(self)
+        self.state = default_state()
+        path = os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+        path = os.path.join(path, 'data/multiwoz/value_dict.json')
+        self.value_dict = json.load(open(path))
+        self.mentioned_domain = []
+
+    def update(self, sys_act=None):
+        """
+        update belief_state, request_state
+        :param sys_act:
+        :return:
+        """
+        # print("dst", user_act)
+        self.update_mentioned_domain(sys_act)
+        for intent, domain, slot, value in sys_act:
+            domain = domain.lower()
+            intent = intent.lower()
+            if domain in ['unk', 'general']:
+                continue
+            # TODO domain: booking
+            if domain == "booking":
+                for domain in self.mentioned_domain:
+                    self.update_inform_request(
+                        intent, domain, slot, value)
+            else:
+                self.update_inform_request(intent, domain, slot, value)
+
+            # elif intent == 'inform' or intent == 'recommend' or intent == 'request':
+            #     # TODO domain: booking
+            #     self.update_inform_request(self, intent, domain, slot, value)
+
+            # elif intent in ["offerbooked", "book", "nobook"]:
+            #     pass
+            #     # elif intent == 'request':
+                #     k = REF_SYS_DA[domain.capitalize()].get(slot, slot)
+                #     if domain not in self.state['request_state']:
+                #         self.state['request_state'][domain] = {}
+                #     if k not in self.state['request_state'][domain]:
+                #         self.state['request_state'][domain][k] = 0
+                # self.state['user_action'] = user_act  # should be added outside DST module
+        return self.state
+
+    def init_session(self):
+        """Initialize ``self.state`` with a default state, which ``convlab2.util.multiwoz.state.default_state`` returns."""
+        self.state = default_state()
+        self.mentioned_domain = []
+
+    def update_mentioned_domain(self, sys_act):
+        if not sys_act:
+            return
+        for intent, domain, slot, value in sys_act:
+            domain = domain.lower()
+            if domain not in self.mentioned_domain and domain not in ['unk', 'general', 'booking']:
+                self.mentioned_domain.append(domain)
+                # print(f"update: mentioned {domain} domain")
+
+    def update_inform_request(self, intent, domain, slot, value):
+        slot = slot.lower()
+        k = SysDa2Goal[domain].get(slot, slot)
+        k = SLOT2SEMI.get(k, k)
+        if k is None:
+            return
+        try:
+            assert domain in self.state['belief_state']
+        except:
+            raise Exception(
+                'Error: domain <{}> not in new belief state'.format(domain))
+        domain_dic = self.state['belief_state'][domain]
+        assert 'semi' in domain_dic
+        assert 'book' in domain_dic
+        if k in domain_dic['semi']:
+            nvalue = normalize_value(self.value_dict, domain, k, value)
+            self.state['belief_state'][domain]['semi'][k] = nvalue
+        elif k in domain_dic['book']:
+            self.state['belief_state'][domain]['book'][k] = value
+        elif k.lower() in domain_dic['book']:
+            self.state['belief_state'][domain]['book'][k.lower()
+                                                       ] = value
+        elif k == 'trainID' and domain == 'train':
+            self.state['belief_state'][domain]['book'][k] = normalize_value(
+                self.value_dict, domain, k, value)
+        else:
+            # print('unknown slot name <{}> of domain <{}>'.format(k, domain))
+            nvalue = normalize_value(self.value_dict, domain, k, value)
+            self.state['belief_state'][domain]['semi'][k] = nvalue
+            with open('unknown_slot.log', 'a+') as f:
+                f.write(
+                    'unknown slot name <{}> of domain <{}>\n'.format(k, domain))
+
+    def update_request(self):
+        pass
+
+    def update_booking(self):
+        pass
+
+
+if __name__ == '__main__':
+    # Usage example / smoke test for UserRuleDST: feeds hand-written system
+    # acts and asserts the exact resulting belief state.
+    # from convlab2.dst.rule.multiwoz import RuleDST
+
+    dst = UserRuleDST()
+
+    action = [['Inform', 'Restaurant', 'Phone', '01223323737'],
+              ['reqmore', 'general', 'none', 'none'],
+              ["Inform", "Hotel", "Area", "east"], ]
+    state = dst.update(action)
+    pprint(state)
+    dst.init_session()
+
+    # Action is a dict. Its keys are strings(domain-type pairs, both uppercase and lowercase is OK) and its values are list of lists.
+    # The domain may be one of ('Attraction', 'Hospital', 'Booking', 'Hotel', 'Restaurant', 'Taxi', 'Train', 'Police').
+    # The type may be "inform" or "request".
+
+    # For example, the action below has a key "Hotel-Inform", in which "Hotel" is domain and "Inform" is action type.
+    # Each list in the value of "Hotel-Inform" is a slot-value pair. "Area" is slot and "east" is value. "Star" is slot and "4" is value.
+    action = [
+        ["Inform", "Hotel", "Area", "east"],
+        ["Inform", "Hotel", "Stars", "4"]
+    ]
+
+    # method `update` updates the attribute `state` of tracker, and returns it.
+    state = dst.update(action)
+    assert state == dst.state
+    assert state == {'user_action': [],
+                     'system_action': [],
+                     'belief_state': {'police': {'book': {'booked': []}, 'semi': {}},
+                                      'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''},
+                                                'semi': {'name': '',
+                                                         'area': 'east',
+                                                         'parking': '',
+                                                         'pricerange': '',
+                                                         'stars': '4',
+                                                         'internet': '',
+                                                         'type': ''}},
+                                      'attraction': {'book': {'booked': []},
+                                                     'semi': {'type': '', 'name': '', 'area': ''}},
+                                      'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''},
+                                                     'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}},
+                                      'hospital': {'book': {'booked': []}, 'semi': {'department': ''}},
+                                      'taxi': {'book': {'booked': []},
+                                               'semi': {'leaveAt': '',
+                                                        'destination': '',
+                                                        'departure': '',
+                                                        'arriveBy': ''}},
+                                      'train': {'book': {'booked': [], 'people': ''},
+                                                'semi': {'leaveAt': '',
+                                                         'destination': '',
+                                                         'day': '',
+                                                         'arriveBy': '',
+                                                         'departure': ''}}},
+                     'request_state': {},
+                     'terminated': False,
+                     'history': []}
+
+    # Please call `init_session` before a new dialog. This initializes the attribute `state` of tracker with a default state, which `convlab2.util.multiwoz.state.default_state` returns. But You needn't call it before the first dialog, because tracker gets a default state in its constructor.
+    dst.init_session()
+    action = [["Inform", "Train", "Arrive", "19:45"]]
+    state = dst.update(action)
+    assert state == {'user_action': [],
+                     'system_action': [],
+                     'belief_state': {'police': {'book': {'booked': []}, 'semi': {}},
+                                      'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''},
+                                                'semi': {'name': '',
+                                                         'area': '',
+                                                         'parking': '',
+                                                         'pricerange': '',
+                                                         'stars': '',
+                                                         'internet': '',
+                                                         'type': ''}},
+                                      'attraction': {'book': {'booked': []},
+                                                     'semi': {'type': '', 'name': '', 'area': ''}},
+                                      'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''},
+                                                     'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}},
+                                      'hospital': {'book': {'booked': []}, 'semi': {'department': ''}},
+                                      'taxi': {'book': {'booked': []},
+                                               'semi': {'leaveAt': '',
+                                                        'destination': '',
+                                                        'departure': '',
+                                                        'arriveBy': ''}},
+                                      'train': {'book': {'booked': [], 'people': ''},
+                                                'semi': {'leaveAt': '',
+                                                         'destination': '',
+                                                         'day': '',
+                                                         'arriveBy': '19:45',
+                                                         'departure': ''}}},
+                     'request_state': {},
+                     'terminated': False,
+                     'history': []}
diff --git a/convlab2/policy/tus/Da2Goal.py b/convlab2/policy/tus/Da2Goal.py
new file mode 100644
index 0000000..8565297
--- /dev/null
+++ b/convlab2/policy/tus/Da2Goal.py
@@ -0,0 +1,87 @@
+# User dialog-act slot name -> user-goal slot name, keyed by lowercase
+# domain.  'none' maps to None so "no slot" acts can be filtered out.
+UsrDa2Goal = {
+    'attraction': {
+        'area': 'area', 'name': 'name', 'type': 'type',
+        'addr': 'address', 'fee': 'entrance fee', 'phone': 'phone',
+        'post': 'postcode', 'ref': "ref", 'none': None
+    },
+    'hospital': {
+        'department': 'department', 'addr': 'address', 'phone': 'phone',
+        'post': 'postcode', 'ref': "ref", 'none': None
+    },
+    'hotel': {
+        'area': 'area', 'internet': 'internet', 'name': 'name',
+        'parking': 'parking', 'price': 'pricerange', 'stars': 'stars',
+        'type': 'type', 'addr': 'address', 'phone': 'phone',
+        'post': 'postcode', 'day': 'day', 'people': 'people',
+        'stay': 'stay', 'ref': "ref", 'none': None
+    },
+    'police': {
+        'addr': 'address', 'phone': 'phone', 'post': 'postcode', 'name': 'name', 'ref': "ref", 'none': None
+    },
+    'restaurant': {
+        'area': 'area', 'day': 'day', 'food': 'food', 'type': 'type',
+        'name': 'name', 'people': 'people', 'price': 'pricerange',
+        'time': 'time', 'addr': 'address', 'phone': 'phone',
+        'post': 'postcode', 'ref': "ref", 'none': None
+    },
+    'taxi': {
+        'arrive': 'arriveby', 'depart': 'departure', 'dest': 'destination',
+        'leave': 'leaveat', 'car': 'car type', 'phone': 'phone', 'ref': "ref", 'none': None
+    },
+    'train': {
+        'time': "duration", 'arrive': 'arriveby', 'day': 'day', 'ref': "ref",
+        'depart': 'departure', 'dest': 'destination', 'leave': 'leaveat',
+        'people': 'people', 'duration': 'duration', 'price': 'price', 'choice': "choice",
+        'trainid': 'trainid', 'ticket': 'price', 'id': "trainid", 'none': None
+    }
+}
+
+# System dialog-act slot name -> user-goal slot name, keyed by lowercase
+# domain.  Includes the pseudo-domain 'booking', which has no UsrDa2Goal
+# counterpart and must be resolved to a concrete domain by callers.
+SysDa2Goal = {
+    'attraction': {
+        'addr': "address", 'area': "area", 'choice': "choice",
+        'fee': "entrance fee", 'name': "name", 'phone': "phone",
+        'post': "postcode", 'price': "pricerange", 'type': "type",
+        'none': None
+    },
+    'booking': {
+        'day': 'day', 'name': 'name', 'people': 'people',
+        'ref': 'ref', 'stay': 'stay', 'time': 'time',
+        'none': None
+    },
+    'hospital': {
+        'department': 'department', 'addr': 'address', 'post': 'postcode',
+        'phone': 'phone', 'none': None
+    },
+    'hotel': {
+        'addr': "address", 'area': "area", 'choice': "choice",
+        'internet': "internet", 'name': "name", 'parking': "parking",
+        'phone': "phone", 'post': "postcode", 'price': "pricerange",
+        'ref': "ref", 'stars': "stars", 'type': "type",
+        'none': None
+    },
+    'restaurant': {
+        'addr': "address", 'area': "area", 'choice': "choice",
+        'name': "name", 'food': "food", 'phone': "phone",
+        'post': "postcode", 'price': "pricerange", 'ref': "ref",
+        'none': None
+    },
+    'taxi': {
+        'arrive': "arriveby", 'car': "car type", 'depart': "departure",
+        'dest': "destination", 'leave': "leaveat", 'phone': "phone",
+        'none': None
+    },
+    'train': {
+        'arrive': "arriveby", 'choice': "choice", 'day': "day",
+        'depart': "departure", 'dest': "destination", 'id': "trainid", 'trainid': "trainid",
+        'leave': "leaveat", 'people': "people", 'ref': "ref",
+        'ticket': "price", 'time': "duration", 'duration': 'duration', 'none': None
+    },
+    'police': {
+        'addr': "address", 'post': "postcode", 'phone': "phone", 'name': 'name', 'none': None
+    }
+}
+# Dataset slot name -> dialog-act slot name (train domain only).
+ref_slot_data2stand = {
+    'train': {
+        'duration': 'time', 'price': 'ticket', 'trainid': 'id'
+    }
+}
diff --git a/convlab2/policy/tus/Goal.py b/convlab2/policy/tus/Goal.py
new file mode 100644
index 0000000..13ec311
--- /dev/null
+++ b/convlab2/policy/tus/Goal.py
@@ -0,0 +1,405 @@
+import json
+import os
+from random import shuffle
+from convlab2.task.multiwoz.goal_generator import GoalGenerator
+from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+from convlab2.policy.tus.multiwoz.util import parse_user_goal
+
+DEF_VAL_UNK = '?'  # Unknown
+DEF_VAL_DNC = 'dontcare'  # Do not care
+DEF_VAL_NUL = 'none'  # for none
+DEF_VAL_BOOKED = 'yes'  # for booked
+DEF_VAL_NOBOOK = 'no'  # for booked
+# Values that count as "not yet confirmed" when checking goal completion.
+NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK, ""]
+
+# Dataset slot name -> dialog-act slot name (train domain only);
+# duplicates the table in Da2Goal.py.
+ref_slot_data2stand = {
+    'train': {
+        'duration': 'time', 'price': 'ticket', 'trainid': 'id'
+    }
+}
+
+
+class Goal(object):
+    """ User Goal Model Class. """
+
+    def __init__(self, goal_generator=None, goal=None):
+        """
+        create new Goal by random
+        Args:
+            goal_generator (GoalGenerator): Goal Gernerator.
+            goal (dict): a pre-built goal keyed by domain; exactly one of
+                goal_generator / goal should be supplied.
+        """
+        # Sizes of the one-hot id vectors stored in self.local_id.
+        self.max_domain_len = 5
+        self.max_slot_len = 20
+        self.local_id = {}
+        path = os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+        path = os.path.join(path, 'data/multiwoz/all_value.json')
+        # NOTE(review): the handle returned by open() is never closed.
+        self.all_values = json.load(open(path))
+        if goal_generator is not None and goal is None:
+            # Sample a fresh random goal from the generator.
+            self.domain_goals = goal_generator.get_user_goal()
+            self.domains = list(self.domain_goals['domain_ordering'])
+            del self.domain_goals['domain_ordering']
+        elif goal_generator is None and goal is not None:
+            # Adopt a given goal, keeping only non-empty domains known to SysDa2Goal.
+            self.domains = []
+            self.domain_goals = {}
+            for domain in goal:
+                if domain in SysDa2Goal and goal[domain]:  # TODO check order
+                    self.domains.append(domain)
+                    self.domain_goals[domain] = goal[domain]
+        else:
+            print("Warning!!! One of goal_generator or goal should not be None!!!")
+
+        for domain in self.domains:
+            # Requested slots start out unknown ('?').
+            if 'reqt' in self.domain_goals[domain].keys():
+                self.domain_goals[domain]['reqt'] = {
+                    slot: DEF_VAL_UNK for slot in self.domain_goals[domain]['reqt']}
+
+            # Domains that need booking get a 'booked' status flag.
+            if 'book' in self.domain_goals[domain].keys():
+                self.domain_goals[domain]['booked'] = DEF_VAL_UNK
+        self.init_local_id()
+
+    def init_local_id(self):
+        """Assign one-hot domain/slot id vectors for every slot in the goal."""
+        # local_id = {
+        #     "domain 1": {
+        #         "ID": [1, 0, 0],
+        #         "SLOT": {
+        #             "slot 1": [1, 0, 0],
+        #             "slot 2": [0, 1, 0]}}}
+
+        for domain_id, domain in enumerate(self.domains):
+            self._init_domain_id(domain)
+            self._update_domain_id(domain, domain_id)
+            slot_id = 0
+            for slot_type in ["info", "book", "reqt"]:
+                for slot in self.domain_goals[domain].get(slot_type, {}):
+                    self._init_slot_id(domain, slot)
+                    self._update_slot_id(domain, slot, slot_id)
+                    slot_id += 1
+
+    def insert_local_id(self, new_slot_name):
+        """Register an id for a "domain-slot" name not in the original goal.
+
+        NOTE(review): the ``+ 1`` offsets below skip one index (existing ids
+        occupy 0..len-1), and ``self.domains`` is not extended here, so a
+        second unseen domain would receive the same id — looks off-by-one;
+        confirm intended behavior.
+        """
+        domain, slot = new_slot_name.split('-')
+        if domain not in self.local_id:
+            self._init_domain_id(domain)
+            domain_id = len(self.domains) + 1
+            self._update_domain_id(domain, domain_id)
+            self._init_slot_id(domain, slot)
+            # the first slot for a new domain
+            self._update_slot_id(domain, slot, 0)
+
+        else:
+            slot_id = len(self.local_id[domain]["SLOT"]) + 1
+            self._init_slot_id(domain, slot)
+            self._update_slot_id(domain, slot, slot_id)
+
+    def get_slot_id(self, slot_name):
+        """Return (domain one-hot, slot one-hot) for "domain-slot", registering unseen names on the fly."""
+        domain, slot = slot_name.split('-')
+        if domain in self.local_id and slot in self.local_id[domain]["SLOT"]:
+            return self.local_id[domain]["ID"], self.local_id[domain]["SLOT"][slot]
+        else:  # a slot not in original user goal
+            self.insert_local_id(slot_name)
+            domain_id, slot_id = self.get_slot_id(slot_name)
+            return domain_id, slot_id
+
+    def task_complete(self):
+        """
+        Check that all requests have been met
+        Returns:
+            (boolean): True to accomplish.
+        """
+        for domain in self.domains:
+            if 'reqt' in self.domain_goals[domain]:
+                reqt_vals = self.domain_goals[domain]['reqt'].values()
+                for val in reqt_vals:
+                    if val in NOT_SURE_VALS:
+                        return False
+
+            if 'booked' in self.domain_goals[domain]:
+                if self.domain_goals[domain]['booked'] in NOT_SURE_VALS:
+                    return False
+        return True
+
+    def next_domain_incomplete(self):
+        """Return (domain, slot_type, slots) for the first unfinished domain, or (None, None, None)."""
+        # request
+        for domain in self.domains:
+            # reqt
+            if 'reqt' in self.domain_goals[domain]:
+                requests = self.domain_goals[domain]['reqt']
+                unknow_reqts = [
+                    key for (key, val) in requests.items() if val in NOT_SURE_VALS]
+                if len(unknow_reqts) > 0:
+                    # 'name' takes priority over all other unresolved requests.
+                    return domain, 'reqt', ['name'] if 'name' in unknow_reqts else unknow_reqts
+
+            # book
+            if 'booked' in self.domain_goals[domain]:
+                if self.domain_goals[domain]['booked'] in NOT_SURE_VALS:
+                    return domain, 'book', \
+                        self.domain_goals[domain]['fail_book'] if 'fail_book' in self.domain_goals[domain].keys() else \
+                        self.domain_goals[domain]['book']
+
+        return None, None, None
+
+    def __str__(self):
+        return '-----Goal-----\n' + \
+               json.dumps(self.domain_goals, indent=4) + \
+               '\n-----Goal-----'
+
+    def action_list(self, user_history=None, sys_act=None, all_values=None):
+        """Order goal slots by priority: slots mentioned in ``sys_act`` first,
+        then user-history order (or a shuffled order when no history)."""
+        goal_slot = parse_user_goal(self)
+        if user_history:
+            priority_action = self._reorder_based_on_user_history(
+                user_history, goal_slot)
+
+        else:
+            # priority_action = [slot for slot in goal_slot]
+            priority_action = self._reorder_random(goal_slot)
+
+        if sys_act:
+            for intent, domain, slot, value in sys_act:
+                slot_name = self.act2slot(
+                    intent, domain, slot, value, all_values)
+                # print("system_mention:", slot_name)
+                if slot_name and slot_name not in priority_action:
+                    priority_action.insert(0, slot_name)
+
+        return priority_action
+
+    def get_booking_domain(self, slot, value, all_values):
+        """Guess which goal domain a booking slot/value belongs to.
+
+        NOTE(review): the loop variable ``domain`` is not used in the test,
+        and ``all_values["all_value"]`` is indexed by slot here while
+        ``_get_booking_domain`` indexes it by domain first — one of the two
+        lookups looks wrong; confirm against the all_value.json schema.
+        """
+        for domain in self.domains:
+            if slot in all_values["all_value"] and value in all_values["all_value"][slot]:
+                return domain
+        print("NOT FOUND BOOKING DOMAIN")
+        return ""
+
+    def act2slot(self, intent, domain, slot, value, all_values):
+        """Translate a dialog act into the "domain-slot" goal naming; return "" when unresolvable."""
+        domain = domain.lower()
+        slot = slot.lower()
+
+        if domain not in UsrDa2Goal:
+            # print(f"Not handle domain {domain}")
+            return ""
+
+        if domain == "booking":
+            slot = SysDa2Goal[domain][slot]
+            domain = self.get_booking_domain(slot, value, all_values)
+            if domain:
+                return f"{domain}-{slot}"
+
+        elif domain in UsrDa2Goal:
+            if slot in SysDa2Goal[domain]:
+                slot = SysDa2Goal[domain][slot]
+            elif slot in UsrDa2Goal[domain]:
+                slot = UsrDa2Goal[domain][slot]
+            elif slot in SysDa2Goal["booking"]:
+                slot = SysDa2Goal["booking"][slot]
+            # else:
+            #     print(
+            #         f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}")
+
+            return f"{domain}-{slot}"
+        return ""
+
+    def _reorder_random(self, goal_slot):
+        # Keep the first slot fixed and shuffle the remainder.
+        new_order = [slot for slot in goal_slot]
+        random_order = new_order[1:]
+        shuffle(random_order)
+
+        return [new_order[0]] + random_order
+
+    def _reorder_based_on_user_history(self, user_history, goal_slot):
+        """Put slots in user-history order first, then remaining goal slots."""
+        # user_history = [slot_0, slot_1, ...]
+        new_order = []
+        for slot in user_history:
+            if slot and slot not in new_order:
+                new_order.append(slot)
+
+        for slot in goal_slot:
+            if slot not in new_order:
+                new_order.append(slot)
+        return new_order
+
+    def update_user_goal(self, action=None, state=None):
+        """Refresh request values and booking status from a system action and/or belief state."""
+        # update request and booked
+        if action:
+            self._update_user_goal_from_action(action)
+        if state:
+            self._update_user_goal_from_state(state)
+            self._check_booked(state)  # this should always check
+
+        if action is None and state is None:
+            print("Warning!!!! Both action and state are None")
+
+    def _check_booked(self, state):
+        # Mark each bookable domain booked iff the state matches the goal's
+        # info/book values.
+        for domain in self.domains:
+            if "booked" in self.domain_goals[domain]:
+                if self._check_book_info(state, domain):
+                    self.domain_goals[domain]["booked"] = DEF_VAL_BOOKED
+                else:
+                    self.domain_goals[domain]["booked"] = DEF_VAL_NOBOOK
+
+    def _check_book_info(self, state, domain):
+        """Return True when every goal info/book value is consistent with the state."""
+        # need to check info, reqt for booked?
+        if domain not in state:
+            return False
+
+        for slot_type in ['info', 'book']:
+            for slot in self.domain_goals[domain].get(slot_type, {}):
+                user_value = self.domain_goals[domain][slot_type][slot]
+                if slot in state[domain]["semi"]:
+                    state_value = state[domain]["semi"][slot]
+
+                elif slot in state[domain]["book"]:
+                    state_value = state[domain]["book"][slot]
+                else:
+                    state_value = ""
+                # only check mentioned values (?)
+                if state_value and state_value != user_value:
+                    # print(
+                    #     f"booking info is incorrect, for slot {slot}: "
+                    #     f"goal {user_value} != state {state_value}")
+                    return False
+
+        return True
+
+    def _update_user_goal_from_action(self, action):
+        """Fill in requested-slot values (and 'booked' on a 'ref') from a system action."""
+        for intent, domain, slot, value in action:
+            # print("update user goal from action")
+            # print(intent, domain, slot, value)
+            # print("action:", intent)
+            domain = domain.lower()
+            value = value.lower()
+            slot = slot.lower()
+            if slot == "ref":
+                # A booking reference marks every bookable domain as booked.
+                for usr_domain in self.domains:
+                    if "booked" in self.domain_goals[usr_domain]:
+                        self.domain_goals[usr_domain]["booked"] = DEF_VAL_BOOKED
+            else:
+                domain, slot = self._norm_domain_slot(domain, slot, value)
+
+                if self._check_update_request(domain, slot) and value != "?":
+                    self.domain_goals[domain]['reqt'][slot] = value
+                    # print(f"update reqt {slot} = {value} from system action")
+
+    def _norm_domain_slot(self, domain, slot, value):
+        """Map an act's (domain, slot) into goal naming; return ("", "") when unresolvable."""
+        if domain == "booking":
+            # ["book", "booking", "people", 7]
+            if slot in SysDa2Goal[domain]:
+                slot = SysDa2Goal[domain][slot]
+                domain = self._get_booking_domain(slot, value)
+            else:
+                domain = ""
+                for d in SysDa2Goal:
+                    if slot in SysDa2Goal[d]:
+                        domain = d
+                        slot = SysDa2Goal[d][slot]
+            if not domain:
+                return "", ""
+            return domain, slot
+
+        elif domain in self.domains:
+            if slot in SysDa2Goal[domain]:
+                # ["request", "restaurant", "area", "north"]
+                slot = SysDa2Goal[domain][slot]
+            elif slot in UsrDa2Goal[domain]:
+                slot = UsrDa2Goal[domain][slot]
+            elif slot in SysDa2Goal["booking"]:
+                # ["inform", "hotel", "stay", 2]
+                slot = SysDa2Goal["booking"][slot]
+            # else:
+            #     print(
+            #         f"UNSEEN SLOT IN UPDATE GOAL {intent, domain, slot, value}")
+            return domain, slot
+
+        else:
+            # domain = general
+            return "", ""
+
+    def _update_user_goal_from_state(self, state):
+        """Copy confirmed 'semi'/'book' values from the belief state into requested slots."""
+        for domain in state:
+            for slot in state[domain]["semi"]:
+                if self._check_update_request(domain, slot):
+                    self._update_user_goal_from_semi(state, domain, slot)
+            for slot in state[domain]["book"]:
+                if slot == "booked" and state[domain]["book"]["booked"]:
+                    self._update_booked(state, domain)
+
+                elif state[domain]["book"][slot] and self._check_update_request(domain, slot):
+                    self._update_book(state, domain, slot)
+
+    def _update_slot(self, domain, slot, value):
+        # Record a fulfilled request value.
+        self.domain_goals[domain]['reqt'][slot] = value
+
+    def _update_user_goal_from_semi(self, state, domain, slot):
+        if self._check_value(state[domain]["semi"][slot]):
+            self._update_slot(domain, slot, state[domain]["semi"][slot])
+            # print("update reqt {} in semi".format(slot),
+            #       state[domain]["semi"][slot])
+
+    def _update_booked(self, state, domain):
+        # check state and goal is fulfill
+        self.domain_goals[domain]["booked"] = DEF_VAL_BOOKED
+        # NOTE(review): leftover debug print below.
+        print("booked")
+        for booked_slot in state[domain]["book"]["booked"][0]:
+            if self._check_update_request(domain, booked_slot):
+                self._update_slot(domain, booked_slot,
+                                  state[domain]["book"]["booked"][0][booked_slot])
+                # print("update reqt {} in booked".format(booked_slot),
+                #       self.domain_goals[domain]['reqt'][booked_slot])
+
+    def _update_book(self, state, domain, slot):
+        if self._check_value(state[domain]["book"][slot]):
+            self._update_slot(domain, slot, state[domain]["book"][slot])
+            # print("update reqt {} in book".format(slot),
+            #       state[domain]["book"][slot])
+
+    def _check_update_request(self, domain, slot):
+        # check whether one slot is a request slot
+        if domain not in self.domain_goals:
+            return False
+        if 'reqt' not in self.domain_goals[domain]:
+            return False
+        if slot not in self.domain_goals[domain]['reqt']:
+            return False
+        return True
+
+    def _check_value(self, value=None):
+        # A value counts only when it is non-empty and confirmed.
+        if not value:
+            return False
+        if value in NOT_SURE_VALS:
+            return False
+        return True
+
+    def _get_booking_domain(self, slot, value):
+        """ 
+        find the domain for domain booking, excluding slot "ref"
+        """
+        found = ""
+        if not slot:  # work around
+            return found
+        slot = slot.lower()
+        value = value.lower()
+        for domain in self.all_values["all_value"]:
+            if slot in self.all_values["all_value"][domain]:
+                if value in self.all_values["all_value"][domain][slot]:
+                    if domain in self.domains:
+                        found = domain
+        return found
+
+    def _init_domain_id(self, domain):
+        # Zeroed one-hot vector for a domain plus an empty slot table.
+        self.local_id[domain] = {"ID": [0] * self.max_domain_len, "SLOT": {}}
+
+    def _init_slot_id(self, domain, slot):
+        # Zeroed one-hot vector for a slot within a domain.
+        self.local_id[domain]["SLOT"][slot] = [0] * self.max_slot_len
+
+    def _update_domain_id(self, domain, domain_id):
+        if domain_id < self.max_domain_len:
+            self.local_id[domain]["ID"][domain_id] = 1
+        else:
+            print(
+                f"too many doamins: {domain_id} > {self.max_domain_len}")
+
+    def _update_slot_id(self, domain, slot, slot_id):
+        if slot_id < self.max_slot_len:
+            self.local_id[domain]["SLOT"][slot][slot_id] = 1
+        else:
+            print(
+                f"too many slots, {slot_id} > {self.max_slot_len}")
diff --git a/convlab2/policy/tus/TUS.py b/convlab2/policy/tus/TUS.py
new file mode 100644
index 0000000..ac50976
--- /dev/null
+++ b/convlab2/policy/tus/TUS.py
@@ -0,0 +1,293 @@
+import json
+import math
+import os
+from copy import deepcopy
+from random import choice
+
+import numpy as np
+import torch
+from convlab2.policy.tus.multiwoz.Goal import Goal
+from convlab2.policy.tus.multiwoz.transformer import \
+    TransformerActionPrediction
+from convlab2.policy.tus.multiwoz.usermanager import Feature
+from convlab2.policy.tus.multiwoz.util import parse_user_goal
+from convlab2.policy.policy import Policy
+from convlab2.task.multiwoz.goal_generator import GoalGenerator
+from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
+
# Run the user model on GPU when available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Special slot values used when decoding model outputs into dialog acts.
DEF_VAL_UNK = '?'  # Unknown
DEF_VAL_DNC = 'dontcare'  # Do not care
DEF_VAL_NUL = 'none'  # for none
DEF_VAL_BOOKED = 'yes'  # for booked
DEF_VAL_NOBOOK = 'no'  # for not booked
Inform = "Inform"
Request = "Request"
# Values that do not count as a concrete, confirmed slot value.
NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK, ""]

# Map lower-cased slot names back to the camelCase spelling used by the
# MultiWOZ semantic dictionary (REF_USR_DA / REF_SYS_DA).
SLOT2SEMI = {
    "arriveby": "arriveBy",
    "leaveat": "leaveAt",
    "trainid": "trainID",
}
+
+
class UserActionPolicy(Policy):
    """Transformer-based user simulator (TUS) policy for MultiWOZ.

    Wraps a trained TransformerActionPrediction model: each turn it encodes
    the user goal plus dialog state into per-slot features, runs the model,
    and decodes the per-slot output labels into user dialog acts of the
    form [intent, domain, slot, value].
    """

    def __init__(self, config):
        """Load the value ontology, goal generator and trained user model.

        Args:
            config: dict with at least "num_token", "model_dir" and
                "model_name"; optional flags include "domain_mask" and
                "reorder".
        """
        self.config = config
        # climb five directory levels from this file to the repo root to
        # locate the shared MultiWOZ value ontology
        path = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
        path = os.path.join(path, 'data/multiwoz/all_value.json')
        self.all_values = json.load(open(path))
        self.goal_gen = GoalGenerator()
        Policy.__init__(self)
        self.feat_handler = Feature(self.config)
        # NOTE(review): self.config is the same object as config, so this
        # assignment is a no-op.
        self.config["num_token"] = config["num_token"]
        self.user = TransformerActionPrediction(self.config).to(device=DEVICE)
        self.load(os.path.join(
            self.config["model_dir"], self.config["model_name"]))
        self.user.eval()
        self.use_domain_mask = self.config.get("domain_mask", False)
        self.max_turn = 40
        self.mentioned_domain = []

    def predict(self, state):
        """Predict the user's next dialog acts given the system state.

        Args:
            state: dialog state dict with "system_action" and
                "belief_state" keys.

        Returns:
            list of [intent, domain, slot, value] user acts, prefixed with
            an "Inform Domain none none" act the first time each domain is
            mentioned.
        """
        # update goal
        self.goal.update_user_goal(action=state["system_action"],
                                   state=state['belief_state'])
        # presumably counts both the system and the user utterance
        self.time_step += 2

        self.predict_action_list = self.goal.action_list(
            sys_act=state["system_action"],
            all_values=self.all_values)

        feature, mask = self.feat_handler.get_feature(
            self.predict_action_list,
            self.goal,
            state['belief_state'],
            self.sys_history_state,
            state["system_action"],
            self.pre_usr_act)
        feature = torch.tensor([feature], dtype=torch.float).to(DEVICE)
        mask = torch.tensor([mask], dtype=torch.bool).to(DEVICE)

        self.sys_history_state = state['belief_state']

        usr_output = self.user.forward(feature, mask)
        usr_action = self.transform_usr_act(
            usr_output, self.predict_action_list)
        # announce newly mentioned domains with a slot-less Inform act
        domains = [act[1] for act in usr_action]
        none_slot_acts = self._add_none_slot_act(domains)
        usr_action = none_slot_acts + usr_action

        self.pre_usr_act = deepcopy(usr_action)

        return usr_action

    def init_session(self):
        """Start a new dialog session: sample a fresh goal and reset state."""
        self.mentioned_domain = []
        self.time_step = 0
        self.topic = 'NONE'
        remove_domain = "police"  # remove police domain in inference
        # if "remove_domain" in self.config:
        #     remove_domain = self.config["remove_domain"]
        self.new_goal(remove_domain=remove_domain)

        # print(self.goal)
        if self.config.get("reorder", False):
            self.predict_action_list = self.goal.action_list()
        else:
            # NOTE(review): self.action_list is never assigned anywhere in
            # this class, so this branch raises AttributeError whenever
            # config["reorder"] is falsy — confirm the intended source.
            self.predict_action_list = self.action_list
        self.sys_history_state = None  # to save sys history
        self.terminated = False
        self.feat_handler.initFeatureHandeler(self.goal)
        self.pre_usr_act = None

    def new_goal(self, remove_domain="police", domain_len=None):
        """Sample user goals until one satisfies the constraints.

        Args:
            remove_domain: reject goals containing this domain.
            domain_len: if set, require exactly this many domains.
        """
        keep_generate_goal = True
        while keep_generate_goal:
            self.goal = Goal(self.goal_gen)
            if (domain_len and len(self.goal.domains) != domain_len) or \
                    (remove_domain and remove_domain in self.goal.domains):
                keep_generate_goal = True
            else:
                keep_generate_goal = False

    def load(self, model_path=None):
        """Load trained model weights from *model_path* onto DEVICE."""
        self.user.load_state_dict(torch.load(model_path, map_location=DEVICE))

    def get_goal(self):
        """Return the current user goal as a domain -> constraints dict."""
        return self.goal.domain_goals

    def get_reward(self):
        """Reward: +2*max_turn on success, -max_turn on timeout, else -1."""
        if self.goal.task_complete():
            reward = 2 * self.max_turn

        elif self.time_step >= self.max_turn:
            reward = -1 * self.max_turn

        else:
            reward = -1.0
        return reward

    def _add_none_slot_act(self, domains):
        """Build "Inform Domain None None" acts for not-yet-mentioned domains."""
        actions = []
        for domain in domains:
            domain = domain.lower()
            if domain not in self.mentioned_domain and domain != 'general':
                actions.append([Inform, domain.capitalize(), "None", "None"])
                self.mentioned_domain.append(domain)
        return actions

    def transform_usr_act(self, usr_output, action_list):
        """Decode model output into user acts, ending the dialog if needed.

        Returns a "bye" act when the goal is complete or the turn limit is
        exceeded; otherwise decodes each slot's argmax label, falling back
        to the highest-scoring non-"none" action when nothing was selected.
        """
        if self.goal.task_complete():
            self.terminated = True
            # print("task complete")
            return [["bye", "general", "None", "None"]]

        usr_action = []
        # [score, [action]]: when usr_action is none and task doesn't finish
        non_zero_action = []
        for index, slot_name in enumerate(action_list):
            domain, slot = slot_name.split('-')
            # usr_output position 0 is the domain head; slots start at 1
            self._add_user_action(
                usr_action,
                output=torch.argmax(usr_output[0, index + 1, :]).item(),
                domain=domain.capitalize(),
                slot=SLOT2SEMI.get(slot, slot))
            non_zero_action = self._max_non_zero_action(
                action=non_zero_action,
                score=torch.max(usr_output[0, index + 1, 1:]).item(),
                output=torch.argmax(usr_output[0, index + 1, 1:]).item() + 1,
                domain=domain.capitalize(),
                slot=SLOT2SEMI.get(slot, slot))

        if self.use_domain_mask:
            # position 0 of the output is interpreted as per-domain logits
            domain_mask = self._get_prediction_domain(torch.round(
                torch.sigmoid(usr_output[0, 0, :])).tolist())
            usr_action = self._mask_user_action(usr_action, domain_mask)

        if self.time_step > self.max_turn:
            # print(f"max turn: {self.time_step}")
            self.terminated = True
            return [["bye", "general", "None", "None"]]

        # usr_action = None

        if not usr_action and not self.terminated:
            # NOTE(review): if action_list is empty, non_zero_action is []
            # and this indexing raises IndexError — confirm callers always
            # pass a non-empty action list.
            # print(f"pick max none zero action {non_zero_action}")
            return non_zero_action[1]
        if not usr_action:
            print("!!!STRANGE!!!!!!")
            print(usr_action, non_zero_action, action_list)
        return usr_action

    def _mask_user_action(self, usr_action, mask):
        """Keep only acts whose domain (lower-cased) is in *mask*."""
        mask_action = []
        for intent, domain, slot, value in usr_action:
            if domain.lower() in mask:
                mask_action += [[intent, domain, slot, value]]
        return mask_action

    def _get_prediction_domain(self, domain_output):
        """Map the domain-head output vector to a list of predicted domains.

        Index 0 stands for "general"; indices 1..n follow the goal's domain
        order.
        """
        predict_domain = []
        if domain_output[0] > 0:
            predict_domain.append('general')
        for index, value in enumerate(domain_output[1:]):
            if value > 0 and index < len(self.goal.domains):
                predict_domain.append(self.goal.domains[index])
        return predict_domain

    def _max_non_zero_action(self, action, score, output, domain, slot):
        """Track the best-scoring non-"none" fallback action.

        *action* is [best_score, [acts]]; ties append to the same list.
        """
        if not action:
            action = [score, []]

        if score > action[0]:
            temp = []
            self._add_user_action(temp, output, domain, slot)
            action = [score, temp]
        elif score == action[0]:
            self._add_user_action(action[1], output, domain, slot)
        return action

    def _add_user_action(self, usr_action, output, domain, slot):
        """Decode one slot's predicted label into an act appended in place.

        Label semantics: 0 = none (skip), 1 = request ("?"), 2 = dontcare,
        3 = take the value from the system's belief state, 4 = take the
        value from the user goal, 5 = "random" (currently also the goal
        value — see TODO below).
        """
        goal = self.get_goal()
        # "?"
        if output == 1:  # "?"
            self._add_action(usr_action, Request, domain, slot, DEF_VAL_UNK)

        # "dontcare"
        elif output == 2:
            self._add_action(usr_action, Inform, domain, slot, DEF_VAL_DNC)

        # system
        elif output == 3 and self._slot_type(domain, slot):
            slot_type = self._slot_type(domain, slot)
            value = self.sys_history_state[domain.lower()][slot_type].get(
                slot, "")
            if value:
                self._add_action(usr_action, Inform, domain, slot, value)

        elif output == 4 and domain.lower() in goal:  # usr
            value = None
            if slot in goal[domain.lower()].get("info", {}):
                value = goal[domain.lower()]["info"][slot]
            elif slot in goal[domain.lower()].get("book", {}):
                value = goal[domain.lower()]["book"][slot]
            if value:
                self._add_action(usr_action, Inform, domain, slot, value)

        elif output == 5 and domain.lower() in goal:
            # TODO random select, now we use the value from user goal
            # print(f"NOT HANDLE SLOT {slot} IN DOMAIN {domain} (5)!!!")
            value = None
            if slot in goal[domain.lower()].get("info", {}):
                value = goal[domain.lower()]["info"][slot]
            elif slot in goal[domain.lower()].get("book", {}):
                value = goal[domain.lower()]["book"][slot]
            if value:
                self._add_action(usr_action, Inform, domain, slot, value)

    def _get_action_slot(self, domain, slot):
        """Map a semantic slot name to its surface form, or None if unknown."""
        return REF_USR_DA[domain.capitalize()].get(slot, None)

    def _add_action(self, usr_action, intent, domain, slot, value):
        """Append [intent, domain, slot, value] if the slot has a surface form."""
        action_slot = self._get_action_slot(domain, slot)
        if action_slot:
            usr_action += [[intent, domain, action_slot, value]]

    def is_terminated(self):
        # Is there any action to say?
        return self.terminated

    def _slot_type(self, domain, slot):
        """Return "book"/"semi" depending on where *slot* lives, else ""."""
        slot_type = ""
        if slot in self.sys_history_state[domain.lower()]["book"]:
            slot_type = "book"
        elif slot in self.sys_history_state[domain.lower()]["semi"]:
            slot_type = "semi"

        return slot_type
+
+
class UserPolicy(Policy):
    """Thin adapter exposing UserActionPolicy through the Policy interface."""

    def __init__(self, config):
        self.config = config
        self.policy = UserActionPolicy(self.config)

    def predict(self, state):
        """Delegate next-user-action prediction to the inner policy."""
        return self.policy.predict(state)

    def init_session(self):
        """Reset the inner policy for a new dialog session."""
        self.policy.init_session()

    def is_terminated(self):
        """True when the inner policy has ended the dialog."""
        return self.policy.is_terminated()

    def get_reward(self):
        """Return the inner policy's reward for the current state."""
        return self.policy.get_reward()

    def get_goal(self):
        """Return the user goal, or None when the inner policy has none."""
        inner = getattr(self.policy, 'get_goal', None)
        return inner() if inner is not None else None
diff --git a/convlab2/policy/tus/__init__.py b/convlab2/policy/tus/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/convlab2/policy/tus/analysis.py b/convlab2/policy/tus/analysis.py
new file mode 100644
index 0000000..f8d358a
--- /dev/null
+++ b/convlab2/policy/tus/analysis.py
@@ -0,0 +1,430 @@
+import argparse
+import json
+import os
+import random
+import numpy as np
+import logging
+
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sn
+import torch
+from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
+from convlab2.dialog_agent.env import Environment
+from convlab2.dialog_agent.session import BiSession
+from convlab2.dialog_agent.agent import PipelineAgent
+from convlab2.dst.rule.multiwoz import RuleDST
+from convlab2.dst.rule.multiwoz.usr_dst import UserRuleDST
+from convlab2.policy.tus.multiwoz.TUS import UserPolicy
+from convlab2.policy.tus.multiwoz.transformer import \
+    TransformerActionPrediction
+from convlab2.policy.tus.multiwoz.usermanager import \
+    TUSDataManager
+from convlab2.policy.rule.multiwoz import RulePolicy
+from sklearn import metrics
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import datetime
+
+
def check_device():
    """Pick the torch device to run on: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        print("using GPU")
        return torch.device('cuda')
    print("using CPU")
    return torch.device('cpu')
+
+
def init_logging(log_dir_path, path_suffix=None):
    """Configure root logging to stderr plus a timestamped log file.

    The file lives in *log_dir_path* (created if missing) and is named
    "<timestamp>.log", or "<timestamp>_<path_suffix>.log" when a suffix is
    given.
    """
    os.makedirs(log_dir_path, exist_ok=True)
    stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    file_name = f"{stamp}_{path_suffix}.log" if path_suffix else f"{stamp}.log"
    log_file_path = os.path.join(log_dir_path, file_name)

    handlers = [logging.StreamHandler(), logging.FileHandler(log_file_path)]
    format_str = "%(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s"
    logging.basicConfig(level=logging.DEBUG, handlers=handlers,
                        format=format_str)
+
+
class Analysis:
    """Evaluation harness for the TUS user simulator.

    Supports two modes: direct (supervised) evaluation of the
    action-prediction model on a held-out set, and interactive evaluation
    where the simulator converses with a full dialog system.  Results are
    written as CSV files under *analysis_dir*.
    """

    def __init__(self, config, analysis_dir='user-analysis-result', show_dialog=False, save_dialog=True):
        """Create output directories and remember run options.

        Args:
            config: user-simulator configuration dict.
            analysis_dir: root directory for result CSVs and dialog logs.
            show_dialog: print every turn to stdout during interaction.
            save_dialog: write each dialog to a per-seed file.
        """
        if not os.path.exists(analysis_dir):
            os.makedirs(analysis_dir)
        self.dialog_dir = os.path.join(analysis_dir, 'dialog')
        if not os.path.exists(self.dialog_dir):
            os.makedirs(self.dialog_dir)
        self.dir = analysis_dir
        self.config = config
        self.device = check_device()
        self.show_dialog = show_dialog
        self.save_dialog = save_dialog
        self.max_turn = 40  # hard cap on turns per interactive dialog

    def get_sys(self, sys="rule", load_path=None):
        """Build the system side: a rule DST plus the requested policy.

        Args:
            sys: one of "rule", "ppo", "vtrace" (case-insensitive).
            load_path: checkpoint path for learned policies.

        Raises:
            ValueError: for an unsupported system type (previously the code
                printed a warning and then crashed on an unbound variable).
        """
        dst = RuleDST()

        sys = sys.lower()
        if sys == "rule":
            policy = RulePolicy()
        elif sys == "ppo":
            from convlab2.policy.ppo import PPO
            if load_path:
                policy = PPO(False, use_action_mask=True, shrink=False)
                policy.load(load_path)
            else:
                policy = PPO.from_pretrained()
        elif sys == "vtrace":
            from convlab2.policy.vtrace_rnn_action_embedding import VTRACE_RNN
            policy = VTRACE_RNN(
                is_train=False, seed=0, use_masking=True, shrink=False)
            policy.load(load_path)
        else:
            raise ValueError(f"Unsupported system type: {sys}")

        return dst, policy

    def get_usr(self, usr="tus", load_path=None):
        """Build the user side: optional user DST plus the user policy.

        For "tus" the configuration comes from self.config; the other
        simulators read *load_path* / packaged models.

        Raises:
            ValueError: for an unsupported user type.
        """
        usr = usr.lower()
        if usr == "rule":
            dst_usr = None
            policy_usr = RulePolicy(character='usr')
        elif usr == "tus":
            dst_usr = UserRuleDST()
            policy_usr = UserPolicy(self.config)
        elif usr == "vhus":
            from convlab2.policy.vhus.multiwoz import UserPolicyVHUS

            dst_usr = None
            policy_usr = UserPolicyVHUS(
                load_from_zip=True, model_file="vhus_simulator_multiwoz.zip")
        else:
            raise ValueError(f"Unsupported user type: {usr}")

        return dst_usr, policy_usr

    def interact_test(self,
                      sys="rule",
                      usr="tus",
                      sys_load_path=None,
                      usr_load_path=None,
                      num_dialog=400,
                      domain=None):
        """Run *num_dialog* simulated dialogs and write summary CSVs.

        Args:
            sys, usr: system and user types (see get_sys/get_usr).
            sys_load_path, usr_load_path: optional checkpoint paths.
            num_dialog: number of dialogs (seeds 1000 .. 1000+num_dialog-1).
            domain: if set, resample goals until they contain this domain.
        """
        # TODO need refactor
        seed = 20190827
        torch.manual_seed(seed)
        sys = sys.lower()
        usr = usr.lower()

        sess = self._set_interactive_test(
            sys, usr, sys_load_path, usr_load_path)

        task_success = {
            # 'All_user_sim': [], 'All_evaluator': [], 'total_return': []}
            'complete': [], 'success': [], 'reward': []}

        turn_slot_num = {i: [] for i in range(self.max_turn)}
        turn_domain_num = {i: [] for i in range(self.max_turn)}
        true_max_turn = 0

        for seed in tqdm(range(1000, 1000 + num_dialog)):
            random.seed(seed)
            np.random.seed(seed)
            sess.init_session()
            # if domain is not none, the user goal must contain that domain
            if domain:
                domain = domain.lower()
                print(f"check {domain}")
                while 1:
                    if domain in sess.user_agent.policy.get_goal():
                        break
                    sess.user_agent.init_session()
            sys_uttr = []
            actions = 0
            total_return = 0.0
            # bug fix: task_succ used to stay undefined (NameError below)
            # whenever a dialog hit the turn limit without finishing
            task_succ = 0
            if self.save_dialog:
                f = open(os.path.join(self.dialog_dir, str(seed)), 'w')
            for turn in range(self.max_turn):
                sys_uttr, usr_uttr, finish, reward = sess.next_turn(sys_uttr)
                if self.show_dialog:
                    print(f"USR: {usr_uttr}")
                    print(f"SYS: {sys_uttr}")
                if self.save_dialog:
                    f.write(f"USR: {usr_uttr}\n")
                    f.write(f"SYS: {sys_uttr}\n")
                actions += len(usr_uttr)
                turn_slot_num[turn].append(len(usr_uttr))
                turn_domain_num[turn].append(self._get_domain_num(usr_uttr))
                total_return += sess.user_agent.policy.policy.get_reward()

                if finish:
                    task_succ = sess.evaluator.task_success()
                    break
            if turn > true_max_turn:
                true_max_turn = turn
            if self.save_dialog:
                f.close()

            task_success['complete'].append(
                int(sess.user_agent.policy.policy.goal.task_complete()))
            task_success['success'].append(task_succ)
            task_success['reward'].append(total_return)
        task_summary = {key: [0] for key in task_success}
        for key in task_success:
            if task_success[key]:
                task_summary[key][0] = np.average(task_success[key])

        for key in task_success:
            logging.info(
                f'{key} {len(task_success[key])} {task_summary[key][0]}')

        # average number of user acts / domains per turn index
        write = {'turn_slot_num': [], 'turn_domain_num': []}
        for turn in turn_slot_num:
            if turn > true_max_turn:
                break
            avg = 0
            if turn_slot_num[turn]:
                avg = sum(turn_slot_num[turn]) / len(turn_slot_num[turn])
            write['turn_slot_num'].append(avg)
        for turn in turn_domain_num:
            if turn > true_max_turn:
                break
            avg = 0
            if turn_domain_num[turn]:
                avg = sum(turn_domain_num[turn]) / len(turn_domain_num[turn])
            write['turn_domain_num'].append(avg)

        # write results
        pd.DataFrame.from_dict(write).to_csv(
            os.path.join(self.dir, f'{sys}-{usr}-turn-statistics.csv'))
        pd.DataFrame.from_dict(task_summary).to_csv(
            os.path.join(self.dir, f'{sys}-{usr}-task-summary.csv'))

    def _get_domain_num(self, action):
        """Count distinct domains in a list of [intent, domain, slot, value] acts."""
        return len(set(act[1] for act in action))

    def _set_interactive_test(self, sys, usr, sys_load_path, usr_load_path):
        """Wire system and user agents into a BiSession with an evaluator."""
        dst_sys, policy_sys = self.get_sys(sys, sys_load_path)
        dst_usr, policy_usr = self.get_usr(usr, usr_load_path)

        usr = PipelineAgent(None, dst_usr, policy_usr, None, 'user')
        sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys')
        env = Environment(None, usr, None, dst_sys)
        evaluator = MultiWozEvaluator()
        sess = BiSession(sys, usr, None, evaluator)

        return sess

    def direct_test(self, model, test_data):
        """Evaluate *model* on *test_data* and write CSV reports.

        Returns:
            dict mapping result type ("non-zero"/"total"/"turn") to a
            one-element accuracy list (the legacy "old way" numbers).
        """
        model = model.to(self.device)
        model.zero_grad()
        model.eval()
        y_lable, y_pred = [], []
        y_turn = []

        result = {}  # old way

        with torch.no_grad():
            for i, data in enumerate(tqdm(test_data, ascii=True, desc="Evaluation"), 0):
                input_feature = data["input"].to(self.device)
                mask = data["mask"].to(self.device)
                label = data["label"].to(self.device)
                output = model(input_feature, mask)
                y_l, y_p, y_t, r = self.parse_result(output, label)
                y_lable += y_l
                y_pred += y_p
                y_turn += y_t
                # old way: accumulate per-type correct/total counts
                for r_type in r:
                    if r_type not in result:
                        result[r_type] = {"correct": 0, "total": 0}
                    for n in result[r_type]:
                        result[r_type][n] += float(r[r_type][n])
        old_result = {}
        for r_type in result:
            temp = result[r_type]['correct'] / result[r_type]['total']
            old_result[r_type] = [temp]

        # (dropped a stray f-string prefix on the constant file name)
        pd.DataFrame.from_dict(old_result).to_csv(
            os.path.join(self.dir, 'old_result.csv'))

        cm = self.model_confusion_matrix(y_lable, y_pred)
        self.summary(y_lable, y_pred, y_turn, cm)

        return old_result

    def summary(self, y_true, y_pred, y_turn, cm, file_name='scores.csv'):
        """Write micro-averaged F1/precision/recall plus accuracies to CSV."""
        result = {
            'f1': metrics.f1_score(y_true, y_pred, average='micro'),
            'precision': metrics.precision_score(y_true, y_pred, average='micro'),
            'recall': metrics.recall_score(y_true, y_pred, average='micro'),
            'none-zero-acc': self.none_zero_acc(cm),
            'turn-acc': sum(y_turn) / len(y_turn)}
        col = [c for c in result]
        df_f1 = pd.DataFrame([result[c] for c in col], col)
        df_f1.to_csv(os.path.join(self.dir, file_name))

    def none_zero_acc(self, cm):
        """Accuracy over the non-"none" labels of a confusion-matrix frame."""
        # columns: ['Unnamed: 0', 'none', '?', 'dontcare', 'sys', 'usr', 'random']
        col = cm.columns[1:]
        num_label = cm.sum(axis=1)
        correct = 0
        for col_name in col:
            correct += cm[col_name][col_name]
        return correct / sum(num_label[1:])

    def model_confusion_matrix(self, y_true, y_pred, file_name='cm.csv', legend=["none", "?", "dontcare", "sys", "usr", "random"]):
        """Build a labelled confusion matrix, save it as CSV, and return it."""
        cm = metrics.confusion_matrix(y_true, y_pred)
        df_cm = pd.DataFrame(cm, legend, legend)
        df_cm.to_csv(os.path.join(self.dir, file_name))
        return df_cm

    def parse_result(self, prediction, label):
        """Compare model output with labels for one batch.

        Args:
            prediction: model scores; slot tokens start at index 1, so
                position i+1 of the prediction matches label position i.
            label: (batch, token_num) gold labels; negative values mark
                padding and are excluded from the "new way" lists.

        Returns:
            (y_true, y_pred, y_turn, result): flat label/prediction lists,
            per-dialog-turn success flags, and legacy per-type counts.
        """
        _, arg_prediction = torch.max(prediction.data, -1)
        batch_size, token_num = label.shape
        y_true, y_pred = [], []
        y_turn = []
        result = {
            "non-zero": {"correct": 0, "total": 0},
            "total": {"correct": 0, "total": 0},
            "turn": {"correct": 0, "total": 0}
        }

        for batch_num in range(batch_size):

            turn_acc = True  # old way
            turn_success = 1  # new way

            for element in range(token_num):
                result["total"]["total"] += 1
                l = label[batch_num][element].item()
                p = arg_prediction[batch_num][element + 1].item()
                # old way
                if l > 0:
                    result["non-zero"]["total"] += 1
                if p == l:
                    if l > 0:
                        result["non-zero"]["correct"] += 1
                    result["total"]["correct"] += 1
                elif p == 0 and l < 0:
                    # predicting "none" on padding counts as correct
                    result["total"]["correct"] += 1

                else:
                    if l >= 0:
                        turn_acc = False

                # new way
                if l >= 0:
                    y_true.append(l)
                    y_pred.append(p)
                if l >= 0 and l != p:
                    turn_success = 0
            y_turn.append(turn_success)
            # old way
            result["turn"]["total"] += 1
            if turn_acc:
                result["turn"]["correct"] += 1

        return y_true, y_pred, y_turn, result
+
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--analysis_dir", type=str,
                        default="user-analysis-result")
    parser.add_argument("--user_config", type=str,
                        default="convlab2/policy/tus/multiwoz/exp/default.json")
    parser.add_argument("--user_mode", type=str, default="")
    parser.add_argument("--version", type=str, default="")
    parser.add_argument("--do_direct", action="store_true")
    parser.add_argument("--do_interact", action="store_true")
    # NOTE(review): argparse type=bool treats any non-empty string as True
    # ("--show_dialog False" still enables it); action="store_true" would
    # be safer.
    parser.add_argument("--show_dialog", type=bool, default=False)
    parser.add_argument("--usr", type=str, default="tus")
    parser.add_argument("--sys", type=str, default="rule")
    parser.add_argument("--num_dialog", type=int, default=400)
    parser.add_argument("--use_mask", action="store_true")
    parser.add_argument("--sys_config", type=str, default="")
    parser.add_argument("--sys_model_dir", type=str,
                        default="convlab2/policy/ppo/save/")
    parser.add_argument("--domain", type=str, default="",
                        help="the user goal must contain a specific domain")

    args = parser.parse_args()

    # results are grouped per system/user pair, plus an optional version tag
    analysis_dir = os.path.join(args.analysis_dir, f"{args.sys}-{args.usr}")
    if args.version:
        analysis_dir = os.path.join(analysis_dir, args.version)

    if not os.path.exists(os.path.join(analysis_dir)):
        os.makedirs(analysis_dir)

    config = json.load(open(args.user_config))
    init_logging(log_dir_path=os.path.join(analysis_dir, "log"))
    # user_mode selects a checkpoint variant by suffixing the model name
    if args.user_mode:
        config["model_name"] = config["model_name"] + '-' + args.user_mode
    # one feature token per goal slot listed in the "all_slot" file
    with open(config["all_slot"]) as f:
        action_list = [line.strip() for line in f]
    config["num_token"] = len(action_list)
    if args.use_mask:
        config["domain_mask"] = True

    ana = Analysis(config, analysis_dir=analysis_dir,
                   show_dialog=args.show_dialog)
    # direct (supervised) evaluation of the action-prediction model
    if args.usr == "tus" and args.do_direct:
        test_data = DataLoader(
            TUSDataManager(
                config, data_dir="data", set_type='test'),
            batch_size=config["batch_size"],
            shuffle=True)

        model = TransformerActionPrediction(config)
        if args.user_mode:
            model.load_state_dict(torch.load(
                os.path.join(config["model_dir"], config["model_name"])))
            print(args.user_mode)
            old_result = ana.direct_test(model, test_data)
            print(old_result)
        else:
            # no explicit mode: evaluate every checkpoint variant in turn
            for user_mode in ["loss", "total", "turn", "non-zero"]:
                model.load_state_dict(torch.load(
                    os.path.join(config["model_dir"], config["model_name"] + '-' + user_mode)))
                print(user_mode)
                old_result = ana.direct_test(model, test_data)
                print(old_result)

    # interactive evaluation against a full dialog system
    if args.do_interact:
        sys_load_path = None
        if args.sys_config:
            _, file_extension = os.path.splitext(args.sys_config)
            # read from config
            if file_extension == ".json":
                sys_config = json.load(open(args.sys_config))
                file_name = f"{sys_config['current_time']}_best_complete_rate_ppo"
                sys_load_path = os.path.join(args.sys_model_dir, file_name)
            # read from file
            else:
                sys_load_path = args.sys_config
        ana.interact_test(sys=args.sys, usr=args.usr,
                          sys_load_path=sys_load_path,
                          num_dialog=args.num_dialog,
                          domain=args.domain)
diff --git a/convlab2/policy/tus/config.json b/convlab2/policy/tus/config.json
new file mode 100755
index 0000000..da62b04
--- /dev/null
+++ b/convlab2/policy/tus/config.json
@@ -0,0 +1,14 @@
+{
+	"batchsz": 32,
+	"epoch": 16,
+	"lr": 0.001,
+	"save_dir": "save",
+	"log_dir": "log",
+	"print_per_batch": 400,
+	"save_per_epoch": 5,
+	"hu_dim": 200,
+	"eu_dim": 150,
+	"max_ulen": 20,
+	"alpha": 0.01,
+	"load": "save/best"
+}
\ No newline at end of file
diff --git a/convlab2/policy/tus/exp/default.json b/convlab2/policy/tus/exp/default.json
new file mode 100644
index 0000000..8ecad85
--- /dev/null
+++ b/convlab2/policy/tus/exp/default.json
@@ -0,0 +1,32 @@
+{
+    "model_dir": "convlab2/policy/tus/multiwoz/default",
+    "model_name": "model",
+    "num_epoch": 50,
+    "batch_size": 128,
+    "learning_rate": 1e-4,
+    "num_token": 65,
+    "debug": false,
+    "gelu": false,
+    "dropout": 0.1,
+    "embed_dim": 78,
+    "out_dim": 6,
+    "hidden": 200,
+    "num_transformer": 2,
+    "weight_factor": [
+        1,
+        1,
+        1,
+        1,
+        1,
+        5
+    ],
+    "window": 3,
+    "nhead": 4,
+    "dim_feedforward": 200,
+    "num_transform_layer": 2,
+    "turn-pos": false,
+    "reorder": true,
+    "conflict": false,
+    "remove_domain": "police",
+    "domain_feat": true
+}
\ No newline at end of file
diff --git a/convlab2/policy/tus/train.py b/convlab2/policy/tus/train.py
new file mode 100644
index 0000000..85af35e
--- /dev/null
+++ b/convlab2/policy/tus/train.py
@@ -0,0 +1,215 @@
+import json
+import os
+import time
+
+import torch
+import torch.optim as optim
+from tqdm import tqdm
+from convlab2.policy.tus.multiwoz.analysis import Analysis
+
+
def check_device():
    """Return the torch device to run on, preferring an available GPU."""
    use_gpu = torch.cuda.is_available()
    print("using GPU" if use_gpu else "using CPU")
    return torch.device("cuda" if use_gpu else "cpu")
+
+
class Trainer:
    """Supervised trainer for the TUS user-action prediction model.

    Trains `model` with Adam, checkpoints the lowest-loss model, the best
    model per accuracy metric, the epochs 6-9, and always the latest model.
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.num_epoch = self.config["num_epoch"]
        self.batch_size = self.config["batch_size"]
        self.device = check_device()
        print(self.device)
        self.optimizer = optim.Adam(
            model.parameters(), lr=self.config["learning_rate"])
        self.ana = Analysis(config)

    def training(self, train_data, test_data=None):
        """Run the full training loop.

        train_data / test_data: DataLoaders yielding dicts with keys
        "input", "mask", "label" and (optionally) "domain".
        """
        self.model = self.model.to(self.device)
        if not os.path.exists(self.config["model_dir"]):
            os.makedirs(self.config["model_dir"])

        save_path = os.path.join(
            self.config["model_dir"], self.config["model_name"])

        best = {"loss": float("inf")}
        for epoch in range(self.num_epoch):
            print("epoch {}".format(epoch))
            total_loss = self.train_epoch(train_data)
            print("loss: {}".format(total_loss))

            # BUG FIX: `acc` was referenced below even when test_data was
            # None, raising UnboundLocalError on the first epoch.
            acc = self.eval(test_data) if test_data is not None else {}

            # BUG FIX: the original never updated its `lowest_loss`
            # variable, so the "-loss" checkpoint was overwritten by every
            # epoch whose loss was merely below the initial 100.
            if total_loss < best["loss"]:
                best["loss"] = total_loss
                print(f"save model in {save_path}-loss")
                torch.save(self.model.state_dict(), f"{save_path}-loss")

            for acc_type in acc:
                if acc_type not in best:
                    best[acc_type] = 0
                score = acc[acc_type]["correct"] / acc[acc_type]["total"]
                if best[acc_type] < score:
                    best[acc_type] = score
                    print(f"save model in {save_path}-{acc_type}")
                    torch.save(self.model.state_dict(),
                               f"{save_path}-{acc_type}")
            # keep a snapshot of the mid-training epochs 6..9
            if 5 < epoch < 10:
                print(f"save model in {save_path}-{epoch}")
                torch.save(self.model.state_dict(),
                           f"{save_path}-{epoch}")
            print(f"save latest model in {save_path}")
            torch.save(self.model.state_dict(), save_path)

    def train_epoch(self, data_loader):
        """Run one optimisation pass and return the mean per-batch loss."""
        self.model.train()
        total_loss = 0
        count = 0
        for data in tqdm(data_loader, ascii=True, desc="Training"):
            input_feature = data["input"].to(self.device)
            mask = data["mask"].to(self.device)
            label = data["label"].to(self.device)
            # NOTE(review): "domain_traget" looks like a typo for
            # "domain_target"; kept as-is since configs may use this key.
            if self.config.get("domain_traget", True):
                domain = data["domain"].to(self.device)
            else:
                domain = None
            self.optimizer.zero_grad()

            loss, _ = self.model(input_feature, mask, label, domain)

            loss.backward()
            self.optimizer.step()
            total_loss += float(loss)
            count += 1

        return total_loss / count

    def eval(self, test_data):
        """Evaluate on test_data.

        Returns {metric: {"correct": float, "total": float}} accumulated
        over all batches via parse_result.
        """
        self.model.zero_grad()
        self.model.eval()

        result = {}

        with torch.no_grad():
            for data in tqdm(test_data, ascii=True, desc="Evaluation"):
                input_feature = data["input"].to(self.device)
                mask = data["mask"].to(self.device)
                label = data["label"].to(self.device)
                output = self.model(input_feature, mask)
                batch_result = parse_result(output, label)
                for r_type in batch_result:
                    if r_type not in result:
                        result[r_type] = {"correct": 0, "total": 0}
                    for n in result[r_type]:
                        result[r_type][n] += float(batch_result[r_type][n])

        for r_type in result:
            score = result[r_type]['correct'] / result[r_type]['total']
            print(f"{r_type}: {score}")

        return result
+
+
def parse_result(prediction, label):
    """Score slot-action predictions against labels.

    prediction: [batch, num_token + 1, out_dim] logits, position 0 is CLS.
    label:      [batch, num_token] gold classes (-1 = padding/ignored).
    Returns counts for "non-zero" (gold > 0), "total" (all positions,
    padding counted correct when predicted 0) and "turn" (whole-turn) acc.
    """
    _, arg_prediction = torch.max(prediction.data, -1)
    batch_size, token_num = label.shape
    result = {
        "non-zero": {"correct": 0, "total": 0},
        "total": {"correct": 0, "total": 0},
        "turn": {"correct": 0, "total": 0},
    }

    for row in range(batch_size):
        turn_correct = True
        for col in range(token_num):
            gold = label[row][col]
            pred = arg_prediction[row][col + 1]  # offset past the CLS token
            result["total"]["total"] += 1
            if gold > 0:
                result["non-zero"]["total"] += 1

            if pred == gold:
                if gold > 0:
                    result["non-zero"]["correct"] += 1
                result["total"]["correct"] += 1
            elif pred == 0 and gold < 0:
                # predicting "none" on a padded position counts as correct
                result["total"]["correct"] += 1
            elif gold >= 0:
                turn_correct = False

        result["turn"]["total"] += 1
        if turn_correct:
            result["turn"]["correct"] += 1

    return result
+
+
def f1(target, result):
    """Compute (f1, precision, recall) over two parallel binary sequences.

    target / result: iterables of truthy/falsy values compared pairwise.
    Returns all zeros when there are no true positives; the original
    raised ZeroDivisionError whenever precision or recall was 0.
    """
    target_len = 0
    result_len = 0
    tp = 0
    for t, r in zip(target, result):
        if t:
            target_len += 1
        if r:
            result_len += 1
        if r == t and t:
            tp += 1
    precision = tp / result_len if result_len else 0
    recall = tp / target_len if target_len else 0
    # BUG FIX: guard the harmonic mean against zero precision/recall.
    if precision == 0 or recall == 0:
        return 0, precision, recall
    f1_score = 2 / (1 / precision + 1 / recall)
    return f1_score, precision, recall
+
+
if __name__ == "__main__":
    # Script entry point: build data loaders, construct the transformer
    # action-prediction model and run supervised training.
    import argparse
    import os
    from convlab2.policy.tus.multiwoz.transformer import \
        TransformerActionPrediction
    from convlab2.policy.tus.multiwoz.usermanager import \
        TUSDataManager
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser()
    parser.add_argument("--user_config", type=str,
                        default="convlab2/policy/tus/multiwoz/exp/default.json")

    args = parser.parse_args()
    config = json.load(open(args.user_config))

    batch_size = config["batch_size"]

    # train_data = TUSDataManager(
    #     config, data_dir="data", set_type='train')
    # NOTE(review): training currently reads the 'val' split (the 'train'
    # loader above is commented out) — confirm this is intentional.
    train_data = TUSDataManager(config, data_dir="data", set_type='val')
    # feature dimensionality comes from the data and must match the config
    embed_dim = train_data.features["input"].shape[-1]
    assert embed_dim == config["embed_dim"]
    train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_data = DataLoader(
        TUSDataManager(config, data_dir="data", set_type='test'), batch_size=batch_size, shuffle=True)

    model = TransformerActionPrediction(config)
    trainer = Trainer(model, config)
    trainer.training(train_data, test_data)
diff --git a/convlab2/policy/tus/transformer.py b/convlab2/policy/tus/transformer.py
new file mode 100644
index 0000000..38f3174
--- /dev/null
+++ b/convlab2/policy/tus/transformer.py
@@ -0,0 +1,177 @@
+import json
+import math
+import os
+from copy import deepcopy
+from random import choice
+
+import numpy as np
+import torch
+
+from convlab2.policy.policy import Policy
+from torch import nn
+from torch.autograd import Variable
+from torch.nn import (GRU, CrossEntropyLoss, LayerNorm, Linear,
+                      TransformerEncoder, TransformerEncoderLayer)
+
+# TODO masking
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
def get_clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of `module`."""
    return nn.ModuleList(deepcopy(module) for _ in range(N))
+
+
class EncodeLayer(torch.nn.Module):
    """One stack of TransformerEncoder layers over the slot-token sequence.

    Wraps `num_transform_layer` TransformerEncoderLayers (gelu activation)
    with a final LayerNorm as the encoder's output norm.
    """

    def __init__(self, config):
        super(EncodeLayer, self).__init__()
        self.config = config
        self.num_token = self.config["num_token"]
        # model width equals the hidden size, not the raw feature width
        self.embed_dim = self.config["hidden"]
        transform_layer = TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.config["nhead"],
            dim_feedforward=self.config["hidden"],
            activation='gelu')

        self.norm_1 = LayerNorm(self.embed_dim)
        self.encoder = TransformerEncoder(
            encoder_layer=transform_layer,
            num_layers=self.config["num_transform_layer"],
            norm=self.norm_1)
        # NOTE(review): norm_2, fc/fc_1/fc_2, gelu and dropout below are
        # constructed (and initialised) but never used in forward() —
        # possibly kept for checkpoint compatibility; confirm before removal.
        self.norm_2 = LayerNorm(self.embed_dim)
        self.use_gelu = self.config.get("gelu", False)
        if self.use_gelu:
            self.fc_1 = Linear(self.embed_dim, self.embed_dim)
            self.gelu = nn.GELU()
            self.dropout = nn.Dropout(self.config["dropout"])
            self.fc_2 = Linear(self.embed_dim, self.embed_dim)
        else:
            self.fc = Linear(self.embed_dim, self.embed_dim)
        self.init_weights()

    def init_weights(self):
        # uniform init for the (currently unused) projection layers
        initrange = 0.1
        if self.use_gelu:
            self.fc_1.weight.data.uniform_(-initrange, initrange)
            self.fc_2.weight.data.uniform_(-initrange, initrange)
        else:
            self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, mask):
        """Encode x ([seq, batch, hidden]); mask marks padded key positions."""
        x = self.encoder(x, src_key_padding_mask=mask)
        return x
+
+
class TransformerActionPrediction(torch.nn.Module):
    """Transformer that predicts a user-action class per slot token.

    Input features of width `embed_dim` are projected to `hidden`,
    position-encoded, passed through `num_transformer` EncodeLayer stacks
    with a residual connection, and decoded to `out_dim` classes per token.
    """

    def __init__(self, config):
        super(TransformerActionPrediction, self).__init__()
        self.config = config
        self.num_transformer = self.config["num_transformer"]
        self.embed_dim = self.config["embed_dim"]
        self.out_dim = self.config["out_dim"]
        self.hidden = self.config["hidden"]
        self.softmax = nn.Softmax(dim=-1)
        self.num_token = self.config["num_token"]
        self.embed_linear = Linear(self.embed_dim, self.hidden)
        self.position = PositionalEncoding(self.hidden, self.config)
        self.encoder_layers = get_clones(
            EncodeLayer(self.config), N=self.num_transformer)

        self.norm_1 = LayerNorm(self.hidden)
        self.decoder = Linear(self.hidden, self.out_dim)
        self.norm_2 = LayerNorm(self.out_dim)

        # class weights are the reciprocal of weight_factor, so a larger
        # factor down-weights that class in the cross-entropy loss
        weight = [1.0] * self.out_dim
        for i in range(self.out_dim):
            weight[i] /= self.config["weight_factor"][i]

        weight = torch.tensor(weight)
        # label -1 marks padding positions and is excluded from the loss
        self.loss = CrossEntropyLoss(weight=weight, ignore_index=-1)
        # NOTE(review): pick_loss and similarity are unused in this class
        # as written — confirm whether they are needed.
        self.pick_loss = CrossEntropyLoss()
        self.similarity = nn.CosineSimilarity(dim=-1)
        self.domain_loss = nn.BCEWithLogitsLoss()
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed_linear.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input_feat, mask, label=None, domain_label=None):
        """Return logits [batch, seq, out_dim]; with labels, (loss, logits).

        input_feat: [batch, seq, embed_dim]; mask: [batch, seq] padding mask.
        """
        src = self.embed_linear(input_feat) * math.sqrt(self.hidden)
        src = self.position(src)
        # transformer layers expect [seq, batch, hidden]
        src = src.permute(1, 0, 2)
        for i in range(self.num_transformer):
            src = self.encoder_layers[i](src, mask)

        # residual connection around the final normalisation
        out_src = self.norm_1(src)
        out_src = src + out_src

        out_src = out_src.permute(1, 0, 2)

        out = self.decoder(out_src)
        out = self.norm_2(out)

        if label is not None:
            if domain_label is None:
                loss = self.get_loss(out, label)
            else:
                loss = self.get_loss(out, label) + \
                    self.first_token_loss(out, domain_label)
            return loss, out
        return out

    def get_loss(self, prediction, target):
        """Cross-entropy over slot tokens, skipping the leading CLS token."""
        # prediction = [batch_size, num_token, out_dim]
        # target = [batch_size, num_token]
        # first token is CLS
        pre = prediction[:, 1: self.num_token + 1, :]
        pre = torch.reshape(pre, (pre.shape[0]*pre.shape[1], pre.shape[-1]))
        l = self.loss(pre, target.view(-1))

        return l

    def first_token_loss(self, prediction, target):
        """Multi-label BCE on the CLS position for domain prediction."""
        # prediction = [batch_size, num_token, out_dim]
        # target = [batch_size, num_token]
        pre = prediction[:, 0, :]
        l = self.domain_loss(pre, target)
        return l
+
+
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with an optional turn-level variant.

    `pe` encodes the turn index (positions are integer-divided by 65 —
    presumably the number of tokens per turn; TODO confirm against
    num_token), `pe1` is the standard per-token encoding. When the
    config's "turn-pos" flag is set, the two are averaged.
    """

    def __init__(self, d_model, config, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.config = config
        self.dropout = nn.Dropout(p=dropout)
        self.turn_pos = self.config.get("turn-pos", True)

        # turn-level encoding: all tokens of one turn share a position
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        position = position // 65
        div_term = torch.exp(torch.arange(
            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

        # standard token-level sinusoidal encoding
        pe1 = torch.zeros(max_len, d_model)
        position1 = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term1 = torch.exp(torch.arange(
            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe1[:, 0::2] = torch.sin(position1 * div_term1)
        pe1[:, 1::2] = torch.cos(position1 * div_term1)
        pe1 = pe1.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe1', pe1)

    def forward(self, x):
        """Add positional information to x ([batch, seq, d_model])."""
        if self.turn_pos:
            # blend turn-level and token-level encodings equally
            x = x + self.pe1[:x.size(0), :]*0.5 + self.pe[:x.size(0), :]*0.5
        else:
            x = x + self.pe1[:x.size(0), :]
        # x = torch.cat((x, self.pe, self.pe1), -1)
        return self.dropout(x)
diff --git a/convlab2/policy/tus/usermanager.py b/convlab2/policy/tus/usermanager.py
new file mode 100644
index 0000000..1cf3818
--- /dev/null
+++ b/convlab2/policy/tus/usermanager.py
@@ -0,0 +1,550 @@
+from convlab2.policy.tus.multiwoz import util
+from convlab2.dst.rule.multiwoz.usr_dst import UserRuleDST
+from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+from collections import Counter
+from convlab2.policy.tus.multiwoz.Goal import Goal
+from torch.utils.data import DataLoader, Dataset, RandomSampler
+from pprint import pprint
+from random import uniform
+from tqdm import tqdm
+import json
+import torch
+import os
+
+NOT_MENTIONED = "not mentioned"
+
+
def dic2list(da2goal):
    """Flatten {domain: {slot: goal_slot}} into unique "domain-goal_slot"
    action names, preserving first-seen order and skipping None slots."""
    action_list = []
    seen = set()
    for domain, slots in da2goal.items():
        for mapped in slots.values():
            if mapped is None:
                continue
            act = f"{domain}-{mapped}"
            if act not in seen:
                seen.add(act)
                action_list.append(act)
    return action_list
+
+
class TUSDataManager(Dataset):
    """PyTorch Dataset that turns MultiWOZ dialogues into TUS features.

    Loads a data split, replays each dialogue through the user-side rule
    DST and the Feature handler, and stores per-turn tensors under the
    keys "id", "input", "label", "mask" and "domain".
    """

    def __init__(self,
                 config,
                 data_dir='data',
                 set_type='train',
                 max_turns=12):

        self.config = config
        if set_type == "train":
            data = json.load(
                open(os.path.join(data_dir, 'multiwoz/train_modified.json'), 'r'))
        elif set_type == "val":
            data = json.load(
                open(os.path.join(data_dir, 'multiwoz/val.json'), 'r'))
        elif set_type == "test":
            data = json.load(
                open(os.path.join(data_dir, 'multiwoz/test.json'), 'r'))
        else:
            # NOTE(review): `data` stays unbound here, so process() below
            # raises NameError — consider raising ValueError instead.
            print("UNKNOWN DATA TYPE")

        self.all_values = json.load(open("data/multiwoz/all_value.json"))

        # dialogues touching this domain are dropped entirely
        self.remove_domain = self.config.get("remove_domain", "")
        self.feature_handler = Feature(self.config)
        self.features = self.process(data, max_turns)

    def __getitem__(self, index):
        """Return one sample as {feature_name: tensor_row_or_None}."""
        return {label: self.features[label][index] if self.features[label] is not None else None
                for label in self.features}

    def __len__(self):
        return self.features['input'].size(0)

    def resample(self, size=None):
        """Bootstrap-resample the dataset in place (with replacement)."""
        n_dialogues = self.__len__()
        if not size:
            size = n_dialogues

        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
        self.features = {
            label: self.features[label][dialogues] for label in self.features}

    def to(self, device):
        """Move all feature tensors to `device`."""
        self.device = device
        self.features = {label: self.features[label].to(
            device) for label in self.features}

    def process(self, data, max_turns):
        """Replay every dialogue and build the feature/label tensors.

        Iterates user turns (even log indices), updating the DST with the
        preceding system act and the user goal before extracting features.
        """
        feature = {"id": [], "input": [],
                   "label": [], "mask": [], "domain": []}
        dst = UserRuleDST()
        for dialog_id in tqdm(data, ascii=True, desc="Processing"):
            user_goal = Goal(goal=data[dialog_id]["goal"])

            # if one domain is removed, we skip all data related to this domain
            # remove police at default
            if (self.remove_domain and self.remove_domain in user_goal.domains) or \
                    ("police" in user_goal.domains):
                continue

            turn_num = len(data[dialog_id]["log"])
            pre_state = {}
            sys_act = []
            self.feature_handler.initFeatureHandeler(user_goal)
            dst.init_session()
            user_mentioned = util.get_user_history(
                data[dialog_id]["log"], self.all_values)

            for turn_id in range(0, turn_num, 2):
                action_list = user_goal.action_list(
                    user_history=user_mentioned,
                    sys_act=sys_act,
                    all_values=self.all_values)
                if turn_id > 0:
                    # cur_state = data[dialog_id]["log"][turn_id-1]["metadata"]
                    sys_act = util.parse_dialogue_act(
                        data[dialog_id]["log"][turn_id - 1]["dialog_act"])
                cur_state = dst.update(sys_act)["belief_state"]

                user_goal.update_user_goal(sys_act, cur_state)
                usr_act = util.parse_dialogue_act(
                    data[dialog_id]["log"][turn_id]["dialog_act"])
                input_feature, mask = self.feature_handler.get_feature(
                    action_list, user_goal, cur_state, pre_state, sys_act)
                label = self.feature_handler.generate_label(
                    action_list, user_goal, cur_state, usr_act, mode='sys')
                domain_label = self.feature_handler.domain_label(
                    user_goal, usr_act)
                pre_state = dst.update(usr_act)
                pre_state = pre_state["belief_state"]
                feature["id"].append(dialog_id)
                feature["input"].append(input_feature)
                feature["mask"].append(mask)
                feature["label"].append(label)
                feature["domain"].append(domain_label)

        print("label distribution")
        label_distribution = Counter()
        for label in feature["label"]:
            label_distribution += Counter(label)
        print(label_distribution)
        feature["input"] = torch.tensor(feature["input"], dtype=torch.float)
        feature["label"] = torch.tensor(feature["label"], dtype=torch.long)
        feature["mask"] = torch.tensor(feature["mask"], dtype=torch.bool)
        feature["domain"] = torch.tensor(feature["domain"], dtype=torch.float)
        for feat_type in ["input", "label", "mask", "domain"]:
            print("{}: {}".format(feat_type, feature[feat_type].shape))
        return feature

    def _update_act_dict(self, act_dict, act):
        """Collect intents per "domain-slot" name from a dialogue act list."""
        for intent, domain, slot, value in act:
            domain = domain.lower()
            slot = slot.lower()
            value = value.lower()
            if domain == "general":
                continue
            elif domain == "booking":
                # booking acts are remapped onto the concrete domain
                domain = util.get_booking_domain(
                    slot, value, self.feature_handler.all_values)
            if slot in UsrDa2Goal[domain]:
                slot = UsrDa2Goal[domain][slot]
            elif slot in SysDa2Goal["booking"]:
                slot = SysDa2Goal["booking"][slot]
            else:
                print(
                    f"strange action in geathering all actions{intent, domain, slot, value}")
            if slot == "none":
                continue
            act_name = f"{domain}-{slot}"
            if act_name not in act_dict:
                act_dict[act_name] = []
            if intent not in act_dict[act_name]:
                act_dict[act_name].append(intent)
+
+
class Feature:
    """Builds per-slot input feature vectors and labels for the TUS model.

    One instance is reset per dialogue via initFeatureHandeler(); it keeps
    the user goal, the system-mentioned constraints, the previous user
    labels and the per-turn feature history used for padding.
    """

    def __init__(self, config, max_turn=1):
        self.config = config
        self.max_turn = max_turn
        # slot-level intents encoded in the action-feature block
        self.intents = ["inform", "request", "recommend", "select",
                        "book", "nobook", "offerbook", "offerbooked", "nooffer"]
        self.general_intent = ["reqmore", "bye", "thank", "welcome", "greet"]
        self.default_values = ["none", "?", "dontcare"]
        # locate <repo-root>/data/multiwoz/all_value.json relative to this file
        path = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
        path = os.path.join(path, 'data/multiwoz/all_value.json')
        self.all_values = json.load(open(path))

    def initFeatureHandeler(self, goal):
        """Reset all per-dialogue state for a new user goal."""
        self.goal = goal
        self.domain_list = goal.domains
        usr = util.parse_user_goal(goal)
        self.constrains = {}  # slot: fulfill
        self.requirements = {}  # slot: fulfill
        self.pre_usr = []
        self.all_slot = None
        self.user_feat_hist = []
        for slot in usr:
            if usr[slot] != "?":
                self.constrains[slot] = NOT_MENTIONED

    def get_feature(self, all_slot, user_goal, cur_state, pre_state=None, sys_action=None, usr_action=None):
        """Build the padded input feature and mask for the current turn.

        user_goal: Goal()
        cur_state: dict, {domain: "semi": {slot: value}, "book": {slot: value, "booked": []}}
        sys_action / usr_action: [[intent, domain, slot, value]]
        Returns (feature, mask) from pad_feature() over the turn history.
        """
        feature = []
        usr = util.parse_user_goal(user_goal)

        if sys_action:
            self.update_constrain(sys_action)

        cur = util.metadata2state(cur_state)
        pre = {}
        if pre_state is not None:
            pre = util.metadata2state(pre_state)
        if not self.pre_usr:
            self.pre_usr = [0] * len(all_slot)

        if usr_action:
            self.get_user_action_feat(all_slot, user_goal, usr_action)

        # one-hot of the previous user decision per slot
        usr_act_feat = {}
        for index, slot in zip(self.pre_usr, all_slot):
            usr_act_feat[slot] = util.int2onehot(index, self.config["out_dim"])
        for slot in all_slot:
            feat = self.slot_feature(
                slot, usr, cur, pre, sys_action, usr_act_feat)
            feature.append(feat)
        self.user_feat_hist.append(feature)

        feature, mask = self.pad_feature()
        return feature, mask

    def slot_feature(self, slot, user_goal, current_state, previous_state, sys_action, usr_action):
        """Concatenate all sub-feature blocks for one slot token."""
        feat = []
        feat += self.special_token(slot)
        feat += self.value_representation(
            slot, current_state.get(slot, NOT_MENTIONED))
        feat += self.value_representation(
            slot, user_goal.get(slot, NOT_MENTIONED))
        feat += self.is_constrain_request(slot, user_goal)
        feat += self.is_fulfill(slot, user_goal)
        if self.config.get("conflict", True):
            feat += self.conflict_check(user_goal, current_state, slot)
        if self.config.get("domain_feat", False):
            # feat += self.domain_feat(slot)
            if slot in ["CLS", "SEP"]:
                feat += [0] * (self.goal.max_domain_len +
                               self.goal.max_slot_len)
            else:
                domain_feat, slot_feat = self.goal.get_slot_id(slot)
                feat += domain_feat + slot_feat
        feat += self.first_mention_detection(
            previous_state, current_state, slot)
        feat += self.just_mention(slot, sys_action)
        feat += self.action_representation(slot, sys_action)
        # need change from 0 to domain predictor
        if slot in ["CLS", "SEP"]:
            feat += [0] * self.config["out_dim"]
        else:
            feat += usr_action[slot]
        return feat

    def pad_feature(self, max_memory=5):
        """Concatenate up to `max_memory` most recent turns (newest first,
        each prefixed by a CLS/SEP token) and pad to a fixed length.

        Returns (feature, mask) where mask is True on padding positions.
        """
        feature = []
        feat_dim = len(self.user_feat_hist[0][0])
        num_feat = len(self.user_feat_hist)

        for feat_index in range(num_feat - 1, max(num_feat - 1 - max_memory, -1), -1):
            if feat_index == num_feat - 1:
                special_token = self.slot_feature(
                    "CLS", {}, {}, {}, [], [])
            else:
                special_token = self.slot_feature(
                    "SEP", {}, {}, {}, [], [])
            feature += [special_token]
            feature += self.user_feat_hist[feat_index]

        # zero padding, 65 is the maximum tokens per turn -> need to modify
        max_len = max_memory * 65
        # BUG FIX: the mask was computed *after* padding (so it was always
        # all-False) and was left unbound when len(feature) >= max_len
        # (UnboundLocalError). Compute it from the true length and
        # truncate over-long histories to max_len.
        true_len = min(len(feature), max_len)
        feature = feature[:max_len]
        if len(feature) < max_len:
            feature += [[0] * feat_dim] * (max_len - len(feature))
        mask = [False] * true_len + [True] * (max_len - true_len)
        return feature, mask

    def domain_feat(self, slot):
        """One-hot of the slot's domain among the first 3 goal domains."""
        max_domain = 3
        feat = [0] * max_domain
        domain = slot.split('-')[0]
        if domain in self.domain_list:
            index = self.domain_list.index(domain)
            if index < max_domain:
                feat[index] = 1
        return feat

    def domain_label(self, user_goal, dialog_act):
        """Multi-hot domain target; index 0 is 'no goal domain mentioned'."""
        labels = [0] * self.config["out_dim"]
        goal_domains = user_goal.domains
        no_domain = True

        for intent, domain, slot, value in dialog_act:
            domain = domain.lower()
            if domain in goal_domains:
                index = goal_domains.index(domain)
                labels[index + 1] = 1
                no_domain = False
        if no_domain:
            labels[0] = 1
        return labels

    def generate_label(self, action_list, user_goal, cur_state, dialog_act, mode="usr"):
        """Per-slot class labels for one turn.

        Classes: 0 none, 1 "?", 2 dontcare, 3 system value, 4 user-goal
        value, 5 changed value; -1 marks slots outside the goal domains
        (ignored by the loss). Also stored as self.pre_usr for the next turn.
        """
        labels = [-1] * self.config["num_token"]

        usr = util.parse_user_goal(user_goal)
        cur = util.metadata2state(cur_state)
        for intent, domain, slot, value in dialog_act:
            domain = domain.lower()
            value = value.lower()
            slot = slot.lower()
            name = util.act2slot(intent, domain, slot, value, self.all_values)

            if name not in action_list:
                # print(f"Not handle name {name} in getting label")
                continue
            name_id = action_list.index(name)
            if value == "?":
                labels[name_id] = 1
            elif value == "dontcare":
                labels[name_id] = 2
            elif name in cur and value == cur[name]:
                labels[name_id] = 3
            elif name in usr and value == usr[name]:
                labels[name_id] = 4
            elif (name in cur or name in usr) and value not in [cur.get(name), usr.get(name)]:
                labels[name_id] = 5

        # slots of active goal domains default to class 0 ("none")
        for name in action_list:
            domain = name.split('-')[0]
            name_id = action_list.index(name)
            if labels[name_id] < 0 and domain in self.domain_list:
                labels[name_id] = 0

        self.pre_usr = labels

        return labels

    def get_user_action_feat(self, action_list, user_goal, usr_act):
        """Refresh self.pre_usr from an observed user act."""
        usr_label = self.generate_label(
            action_list, user_goal, {}, usr_act, mode="usr")
        self.pre_usr = usr_label

    def special_token(self, slot):
        """One-hot over the special tokens CLS/SEP ([0, 0] for real slots)."""
        special_token = ["CLS", "SEP"]
        feat = [0]*len(special_token)
        if slot in special_token:
            feat[special_token.index(slot)] = 1
        return feat

    def is_constrain_request(self, feature_slot, user_goal):
        """[is_constrain, is_request] of the slot in the user goal."""
        if feature_slot in ["CLS", "SEP"]:
            return [0, 0]
        value = user_goal.get(feature_slot, NOT_MENTIONED)
        if value == "?":
            return [0, 1]
        elif value == NOT_MENTIONED:
            return [0, 0]
        else:
            return [1, 0]

    def is_fulfill(self, feature_slot, user_goal):
        """[1] if the system already provided the user's goal value."""
        if feature_slot in ["CLS", "SEP"]:
            return [0]

        if feature_slot in user_goal and user_goal.get(feature_slot) == self.constrains.get(feature_slot):
            return [1]
        return [0]

    def just_mention(self, feature_slot, sys_action):
        """[1] if the system action of this turn mentioned this slot."""
        if feature_slot in ["CLS", "SEP"]:
            return [0]
        if not sys_action:
            return [0]
        sys_action_slot = []
        for intent, domain, slot, value in sys_action:
            domain = domain.lower()
            slot = slot.lower()
            value = value.lower()
            if domain == "booking":
                domain = util.get_booking_domain(slot, value, self.all_values)
            # BUG FIX: the original guarded this append with
            # `if domain in sys_action`, which is never true because
            # sys_action holds 4-element lists — so this method always
            # returned [0].
            sys_action_slot.append(f"{domain}-{slot}")
        if feature_slot in sys_action_slot:
            return [1]
        return [0]

    def action_representation(self, feature_slot, action):
        """Per-intent [none, "?", other] vectors plus general-intent one-hots."""
        gen_vec = [0] * len(self.general_intent)
        # ["none", "?", other]
        intent2act = {intent: [0] * 3 for intent in self.intents}

        if action is None or feature_slot in ["CLS", "SEP"]:
            return self.concatenate_action_vector(intent2act, gen_vec)
        for intent, domain, slot, value in action:
            domain = domain.lower()
            slot = slot.lower()
            value = value.lower()

            # general
            if domain == "general":
                self.update_general_action(gen_vec, intent)
            else:
                if domain == "booking":
                    domain = util.get_booking_domain(
                        slot, value, self.all_values)
                self.update_intent2act(
                    feature_slot, intent2act,
                    domain, intent, slot, value)

        return self.concatenate_action_vector(intent2act, gen_vec)

    def update_general_action(self, vec, intent):
        """Set the one-hot bit for a general-domain intent (in place)."""
        if intent in self.general_intent:
            vec[self.general_intent.index(intent)] = 1

    def update_intent2act(self, feature_slot, intent2act, domain, intent, slot, value):
        """Mark how an action touches this slot: [none, "?", other] (in place)."""
        feature_domain, feature_slot = feature_slot.split('-')
        intent = intent.lower()
        slot = slot.lower()
        value = value.lower()
        if slot == "none" and feature_domain == domain:  # None slot
            intent2act[intent][2] = 1
        elif feature_domain == domain and slot == feature_slot:
            if value == "none":
                intent2act[intent][0] = 1
            elif value == "?":
                intent2act[intent][1] = 1
            else:
                intent2act[intent][2] = 1

    def concatenate_action_vector(self, intent2act, general):
        """Flatten the intent vectors followed by the general-intent vector."""
        feat = []
        for intent in intent2act:
            feat += intent2act[intent]
        feat += general
        return feat

    def value_representation(self, slot, value):
        """One-hot over [not-mentioned, none, "?", dontcare] + [other]."""
        if slot in ["CLS", "SEP"]:
            return [0, 0, 0, 0]
        if value == NOT_MENTIONED:
            return [1, 0, 0, 0]
        else:
            temp_vector = [0] * (len(self.default_values) + 1)
            if value in self.default_values:
                temp_vector[self.default_values.index(value)] = 1
            else:
                temp_vector[-1] = 1

            return temp_vector

    def conflict_check(self, user_goal, system_state, slot):
        """[1] if the user goal and system state disagree on this slot."""
        if slot in ["CLS", "SEP"]:
            return [0]
        usr = user_goal.get(slot, NOT_MENTIONED)
        sys = system_state.get(slot, NOT_MENTIONED)
        if usr in [NOT_MENTIONED, "none", ""] and sys in [NOT_MENTIONED, "none", ""]:
            return [0]

        if usr != sys or (usr == "?" and sys == "?"):
            conflict = 1
            return [conflict]
        return [0]

    def first_mention_detection(self, pre_state, cur_state, slot):
        """[1] if the slot gained a real value since the previous state."""
        if slot in ["CLS", "SEP"]:
            return [0]

        first_mention = [1]
        not_first_mention = [0]
        cur = cur_state.get(slot, NOT_MENTIONED)
        if pre_state is None:
            if cur not in [NOT_MENTIONED, "none"]:
                return first_mention
            else:
                return not_first_mention

        pre = pre_state.get(slot, NOT_MENTIONED)

        if pre in [NOT_MENTIONED, "none"] and cur not in [NOT_MENTIONED, "none"]:
            return first_mention

        return not_first_mention  # hasn't been mentioned

    def update_constrain(self, action):
        """Update the constraint status from system actions.

        action = [[intent, domain, slot, name]]
        """
        for intent, domain, slot, value in action:
            domain = domain.lower()
            if domain in self.domain_list:
                slot = SysDa2Goal[domain].get(slot, "none")
                slot_name = f"{domain}-{slot}"
            elif domain == "booking":
                if slot.lower() == "ref":
                    continue
                slot = SysDa2Goal[domain].get(slot, "none")
                domain = util.get_booking_domain(slot, value, self.all_values)
                if not domain:
                    continue  # work around
                slot_name = f"{domain}-{slot}"

            else:
                continue
            if value != "?":
                self.constrains[slot_name] = value

    def update_request(self):
        # TODO
        pass

    @staticmethod
    def concatenate_subvectors(vec_list):
        """Concatenate a list of lists into one flat list."""
        vec = []
        for sub_vec in vec_list:
            vec += sub_vec
        return vec
+
+
if __name__ == "__main__":
    # Smoke test. BUG FIX: the original called TUSDataManager() with no
    # arguments, which raises TypeError because `config` is required.
    # Load the default experiment config instead, matching train.py.
    config = json.load(
        open("convlab2/policy/tus/multiwoz/exp/default.json"))
    x = TUSDataManager(config)
diff --git a/convlab2/policy/tus/util.py b/convlab2/policy/tus/util.py
new file mode 100644
index 0000000..22d60cc
--- /dev/null
+++ b/convlab2/policy/tus/util.py
@@ -0,0 +1,126 @@
+from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+
+NOT_MENTIONED = "not mentioned"
+
+
def int2onehot(index, output_dim=6):
    """Return a one-hot list of length `output_dim`.

    A negative `index` yields an all-zero vector; `index` must otherwise
    be smaller than `output_dim`.
    """
    one_hot = [0 for _ in range(output_dim)]
    if index >= 0:
        one_hot[index] = 1
    return one_hot
+
+
def parse_user_goal(user_goal):
    """Flatten a user goal structure into a {"domain-slot": value} dict."""
    flat_goal = {}
    for domain, domain_goal in user_goal.domain_goals.items():
        if domain not in UsrDa2Goal:
            continue
        for slot_type, slots in domain_goal.items():
            if slot_type in ["fail_info", "fail_book", "booked"]:
                continue  # TODO [fail_info] fix in the future
            if slot_type in ["info", "book", "reqt"]:
                for slot in slots:
                    flat_goal[f"{domain}-{slot.lower()}"] = slots[slot]
    return flat_goal
+
+
def parse_dialogue_act(dialogue_act):
    """Transfer a dialogue act from dict form to list form.

    Args:
        dialogue_act: {"Domain-Intent": [[slot, value], ...], ...}

    Returns:
        [[intent, domain, slot, value], ...] with the "do nt care"
        spelling normalised to "dontcare".
    """
    # loop-invariant normalisation table (was rebuilt per slot-value pair)
    value_map = {"do nt care": "dontcare"}
    actions = []
    for act, pairs in dialogue_act.items():
        domain, intent = act.split('-')
        for slot, value in pairs:
            actions.append([intent, domain, slot, value_map.get(value, value)])
    return actions
+
+
def metadata2state(metadata):
    """Parse metadata (from the data set or the DST) into a flat
    {"domain-slot": value} mapping."""
    slot_value = {}
    for domain, frame in metadata.items():
        prefix = domain.lower()

        for slot, value in frame["semi"].items():
            # unset "semi" slots are normalised to "none"
            if not value or value == NOT_MENTIONED:
                value = "none"
            slot_value[f"{prefix}-{slot.lower()}"] = value

        for slot, value in frame["book"].items():
            if slot == "booked":
                continue
            slot_value[f"{prefix}-{slot.lower()}"] = value

    return slot_value
+
+
def get_booking_domain(slot, value, all_values):
    """
    Find the domain a booking slot/value pair belongs to (the "ref" slot
    is excluded by the caller). Returns "" when no domain matches; if
    several match, the last one in iteration order wins.
    """
    if not slot:
        return ""
    slot, value = slot.lower(), value.lower()
    matched = ""
    for domain, domain_slots in all_values["all_value"].items():
        if slot in domain_slots and value in domain_slots[slot]:
            matched = domain
    return matched
+
+
def act2slot(intent, domain, slot, value, all_values):
    """Map an (intent, domain, slot, value) act to its "domain-slot" goal name.

    Returns "" for domains outside the user-goal ontology. For the
    "booking" pseudo-domain the real domain is recovered from the
    slot/value pair via get_booking_domain.
    """
    if domain not in UsrDa2Goal:
        return ""

    if domain == "booking":
        slot = SysDa2Goal[domain][slot]
        domain = get_booking_domain(slot, value, all_values)
        return f"{domain}-{slot}"

    # domain is guaranteed to be in UsrDa2Goal here, so the original
    # trailing "strange!!!" fallthrough was unreachable and is removed.
    if slot in SysDa2Goal[domain]:
        slot = SysDa2Goal[domain][slot]
    elif slot in UsrDa2Goal[domain]:
        slot = UsrDa2Goal[domain][slot]
    elif slot in SysDa2Goal["booking"]:
        slot = SysDa2Goal["booking"][slot]
    # else: unseen action in label generation; keep the raw slot name

    return f"{domain}-{slot}"
+
+
def get_user_history(dialog, all_values):
    """Collect the goal slot names mentioned in user turns, in order of
    first appearance. User turns are the even-indexed turns."""
    mentioned_slots = []
    for user_turn in dialog[::2]:
        user_acts = parse_dialogue_act(user_turn["dialog_act"])
        for intent, domain, slot, value in user_acts:
            slot_name = act2slot(
                intent, domain.lower(), slot.lower(), value.lower(), all_values)
            if slot_name not in mentioned_slots:
                mentioned_slots.append(slot_name)
    return mentioned_slots
-- 
GitLab