From 2e9628573fe7f3020e5f1dc5cb072a6795c0fd88 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Mon, 14 Jun 2021 09:58:58 +0200
Subject: [PATCH] add tus

---
 convlab2/dst/rule/multiwoz/usr_dst.py | 217 ++++++++++
 convlab2/policy/tus/Da2Goal.py        |  87 ++++
 convlab2/policy/tus/Goal.py           | 405 +++++++++++++++++++
 convlab2/policy/tus/TUS.py            | 293 ++++++++++++++
 convlab2/policy/tus/__init__.py       |   0
 convlab2/policy/tus/analysis.py       | 430 ++++++++++++++++++++
 convlab2/policy/tus/config.json       |  14 +
 convlab2/policy/tus/exp/default.json  |  32 ++
 convlab2/policy/tus/train.py          | 215 ++++++++++
 convlab2/policy/tus/transformer.py    | 177 +++++++++
 convlab2/policy/tus/usermanager.py    | 550 ++++++++++++++++++++++++++
 convlab2/policy/tus/util.py           | 126 ++++++
 12 files changed, 2546 insertions(+)
 create mode 100755 convlab2/dst/rule/multiwoz/usr_dst.py
 create mode 100644 convlab2/policy/tus/Da2Goal.py
 create mode 100644 convlab2/policy/tus/Goal.py
 create mode 100644 convlab2/policy/tus/TUS.py
 create mode 100644 convlab2/policy/tus/__init__.py
 create mode 100644 convlab2/policy/tus/analysis.py
 create mode 100755 convlab2/policy/tus/config.json
 create mode 100644 convlab2/policy/tus/exp/default.json
 create mode 100644 convlab2/policy/tus/train.py
 create mode 100644 convlab2/policy/tus/transformer.py
 create mode 100644 convlab2/policy/tus/usermanager.py
 create mode 100644 convlab2/policy/tus/util.py

diff --git a/convlab2/dst/rule/multiwoz/usr_dst.py b/convlab2/dst/rule/multiwoz/usr_dst.py
new file mode 100755
index 0000000..00199fd
--- /dev/null
+++ b/convlab2/dst/rule/multiwoz/usr_dst.py
@@ -0,0 +1,217 @@
+import json
+import os
+
+from convlab2.util.multiwoz.state import default_state
+from convlab2.dst.rule.multiwoz.dst_util import normalize_value
+from convlab2.dst.dst import DST
+from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
+from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+from pprint import pprint
+
+SLOT2SEMI = {
+    "arriveby": "arriveBy",
+    "leaveat": "leaveAt",
+    "trainid": "trainID",
+}
+
+
+class UserRuleDST(DST):
+    """Rule-based DST that trivially updates new values from the NLU result to the state.
+
+    Attributes:
+        state(dict):
+            Dialog state. Function ``convlab2.util.multiwoz.state.default_state`` returns a default state.
+        value_dict(dict):
+            Used to check whether ``user_act`` has correct content.
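+
+    Example (an illustrative call; ``update`` takes a list of
+    ``[intent, domain, slot, value]`` quadruples):
+        dst = UserRuleDST()
+        state = dst.update([["Inform", "Hotel", "Area", "east"]])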
+ """ + + def __init__(self): + DST.__init__(self) + self.state = default_state() + path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + path = os.path.join(path, 'data/multiwoz/value_dict.json') + self.value_dict = json.load(open(path)) + self.mentioned_domain = [] + + def update(self, sys_act=None): + """ + update belief_state, request_state + :param sys_act: + :return: + """ + # print("dst", user_act) + self.update_mentioned_domain(sys_act) + for intent, domain, slot, value in sys_act: + domain = domain.lower() + intent = intent.lower() + if domain in ['unk', 'general']: + continue + # TODO domain: booking + if domain == "booking": + for domain in self.mentioned_domain: + self.update_inform_request( + intent, domain, slot, value) + else: + self.update_inform_request(intent, domain, slot, value) + + # elif intent == 'inform' or intent == 'recommend' or intent == 'request': + # # TODO domain: booking + # self.update_inform_request(self, intent, domain, slot, value) + + # elif intent in ["offerbooked", "book", "nobook"]: + # pass + # # elif intent == 'request': + # k = REF_SYS_DA[domain.capitalize()].get(slot, slot) + # if domain not in self.state['request_state']: + # self.state['request_state'][domain] = {} + # if k not in self.state['request_state'][domain]: + # self.state['request_state'][domain][k] = 0 + # self.state['user_action'] = user_act # should be added outside DST module + return self.state + + def init_session(self): + """Initialize ``self.state`` with a default state, which ``convlab2.util.multiwoz.state.default_state`` returns.""" + self.state = default_state() + self.mentioned_domain = [] + + def update_mentioned_domain(self, sys_act): + if not sys_act: + return + for intent, domain, slot, value in sys_act: + domain = domain.lower() + if domain not in self.mentioned_domain and domain not in ['unk', 'general', 'booking']: + self.mentioned_domain.append(domain) + # print(f"update: mentioned {domain} domain") + + def update_inform_request(self, intent, domain, slot, value): + slot = slot.lower() + k = SysDa2Goal[domain].get(slot, slot) + k = SLOT2SEMI.get(k, k) + if k is None: + return + try: + assert domain in self.state['belief_state'] + except: + raise Exception( + 'Error: domain <{}> not in new belief state'.format(domain)) + domain_dic = self.state['belief_state'][domain] + assert 'semi' in domain_dic + assert 'book' in domain_dic + if k in domain_dic['semi']: + nvalue = normalize_value(self.value_dict, domain, k, value) + self.state['belief_state'][domain]['semi'][k] = nvalue + elif k in domain_dic['book']: + self.state['belief_state'][domain]['book'][k] = value + elif k.lower() in domain_dic['book']: + self.state['belief_state'][domain]['book'][k.lower() + ] = value + elif k == 'trainID' and domain == 'train': + self.state['belief_state'][domain]['book'][k] = normalize_value( + self.value_dict, domain, k, value) + else: + # print('unknown slot name <{}> of domain <{}>'.format(k, domain)) + nvalue = normalize_value(self.value_dict, domain, k, value) + self.state['belief_state'][domain]['semi'][k] = nvalue + with open('unknown_slot.log', 'a+') as f: + f.write( + 'unknown slot name <{}> of domain <{}>\n'.format(k, domain)) + + def update_request(self): + pass + + def update_booking(self): + pass + + +if __name__ == '__main__': + # from convlab2.dst.rule.multiwoz import RuleDST + + dst = UserRuleDST() + + action = [['Inform', 'Restaurant', 'Phone', '01223323737'], + ['reqmore', 'general', 'none', 
'none'],
+              ["Inform", "Hotel", "Area", "east"], ]
+    state = dst.update(action)
+    pprint(state)
+    dst.init_session()
+
+    # An action is a list of quadruples [intent, domain, slot, value].
+    # The domain may be one of ('Attraction', 'Hospital', 'Booking', 'Hotel',
+    # 'Restaurant', 'Taxi', 'Train', 'Police') and the intent may be, e.g.,
+    # "Inform" or "Request". Both uppercase and lowercase are OK.
+
+    # For example, in the action below, "Hotel" is the domain and "Inform" is
+    # the intent. "Area"/"east" and "Stars"/"4" are the slot-value pairs.
+    action = [
+        ["Inform", "Hotel", "Area", "east"],
+        ["Inform", "Hotel", "Stars", "4"]
+    ]
+
+    # The method `update` updates the tracker's `state` attribute and returns it.
+    state = dst.update(action)
+    assert state == dst.state
+    assert state == {'user_action': [],
+                     'system_action': [],
+                     'belief_state': {'police': {'book': {'booked': []}, 'semi': {}},
+                                      'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''},
+                                                'semi': {'name': '',
+                                                         'area': 'east',
+                                                         'parking': '',
+                                                         'pricerange': '',
+                                                         'stars': '4',
+                                                         'internet': '',
+                                                         'type': ''}},
+                                      'attraction': {'book': {'booked': []},
+                                                     'semi': {'type': '', 'name': '', 'area': ''}},
+                                      'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''},
+                                                     'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}},
+                                      'hospital': {'book': {'booked': []}, 'semi': {'department': ''}},
+                                      'taxi': {'book': {'booked': []},
+                                               'semi': {'leaveAt': '',
+                                                        'destination': '',
+                                                        'departure': '',
+                                                        'arriveBy': ''}},
+                                      'train': {'book': {'booked': [], 'people': ''},
+                                                'semi': {'leaveAt': '',
+                                                         'destination': '',
+                                                         'day': '',
+                                                         'arriveBy': '',
+                                                         'departure': ''}}},
+                     'request_state': {},
+                     'terminated': False,
+                     'history': []}
+
+    # Please call `init_session` before each new dialog. It re-initializes the
+    # tracker's `state` attribute with the default state returned by
+    # `convlab2.util.multiwoz.state.default_state`. You needn't call it before
+    # the first dialog, because the tracker gets a default state in its
+    # constructor.
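+
+    # Note on the example below: the slot "Arrive" is normalized through
+    # SysDa2Goal / SLOT2SEMI to the MultiWOZ state key "arriveBy" before it is
+    # written into belief_state['train']['semi'].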
+ dst.init_session() + action = [["Inform", "Train", "Arrive", "19:45"]] + state = dst.update(action) + assert state == {'user_action': [], + 'system_action': [], + 'belief_state': {'police': {'book': {'booked': []}, 'semi': {}}, + 'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''}, + 'semi': {'name': '', + 'area': '', + 'parking': '', + 'pricerange': '', + 'stars': '', + 'internet': '', + 'type': ''}}, + 'attraction': {'book': {'booked': []}, + 'semi': {'type': '', 'name': '', 'area': ''}}, + 'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''}, + 'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}}, + 'hospital': {'book': {'booked': []}, 'semi': {'department': ''}}, + 'taxi': {'book': {'booked': []}, + 'semi': {'leaveAt': '', + 'destination': '', + 'departure': '', + 'arriveBy': ''}}, + 'train': {'book': {'booked': [], 'people': ''}, + 'semi': {'leaveAt': '', + 'destination': '', + 'day': '', + 'arriveBy': '19:45', + 'departure': ''}}}, + 'request_state': {}, + 'terminated': False, + 'history': []} diff --git a/convlab2/policy/tus/Da2Goal.py b/convlab2/policy/tus/Da2Goal.py new file mode 100644 index 0000000..8565297 --- /dev/null +++ b/convlab2/policy/tus/Da2Goal.py @@ -0,0 +1,87 @@ +UsrDa2Goal = { + 'attraction': { + 'area': 'area', 'name': 'name', 'type': 'type', + 'addr': 'address', 'fee': 'entrance fee', 'phone': 'phone', + 'post': 'postcode', 'ref': "ref", 'none': None + }, + 'hospital': { + 'department': 'department', 'addr': 'address', 'phone': 'phone', + 'post': 'postcode', 'ref': "ref", 'none': None + }, + 'hotel': { + 'area': 'area', 'internet': 'internet', 'name': 'name', + 'parking': 'parking', 'price': 'pricerange', 'stars': 'stars', + 'type': 'type', 'addr': 'address', 'phone': 'phone', + 'post': 'postcode', 'day': 'day', 'people': 'people', + 'stay': 'stay', 'ref': "ref", 'none': None + }, + 'police': { + 'addr': 'address', 'phone': 'phone', 'post': 'postcode', 'name': 'name', 'ref': "ref", 'none': None + }, + 'restaurant': { + 'area': 'area', 'day': 'day', 'food': 'food', 'type': 'type', + 'name': 'name', 'people': 'people', 'price': 'pricerange', + 'time': 'time', 'addr': 'address', 'phone': 'phone', + 'post': 'postcode', 'ref': "ref", 'none': None + }, + 'taxi': { + 'arrive': 'arriveby', 'depart': 'departure', 'dest': 'destination', + 'leave': 'leaveat', 'car': 'car type', 'phone': 'phone', 'ref': "ref", 'none': None + }, + 'train': { + 'time': "duration", 'arrive': 'arriveby', 'day': 'day', 'ref': "ref", + 'depart': 'departure', 'dest': 'destination', 'leave': 'leaveat', + 'people': 'people', 'duration': 'duration', 'price': 'price', 'choice': "choice", + 'trainid': 'trainid', 'ticket': 'price', 'id': "trainid", 'none': None + } +} + +SysDa2Goal = { + 'attraction': { + 'addr': "address", 'area': "area", 'choice': "choice", + 'fee': "entrance fee", 'name': "name", 'phone': "phone", + 'post': "postcode", 'price': "pricerange", 'type': "type", + 'none': None + }, + 'booking': { + 'day': 'day', 'name': 'name', 'people': 'people', + 'ref': 'ref', 'stay': 'stay', 'time': 'time', + 'none': None + }, + 'hospital': { + 'department': 'department', 'addr': 'address', 'post': 'postcode', + 'phone': 'phone', 'none': None + }, + 'hotel': { + 'addr': "address", 'area': "area", 'choice': "choice", + 'internet': "internet", 'name': "name", 'parking': "parking", + 'phone': "phone", 'post': "postcode", 'price': "pricerange", + 'ref': "ref", 'stars': "stars", 'type': "type", + 'none': None + }, + 'restaurant': { + 'addr': "address", 
'area': "area", 'choice': "choice", + 'name': "name", 'food': "food", 'phone': "phone", + 'post': "postcode", 'price': "pricerange", 'ref': "ref", + 'none': None + }, + 'taxi': { + 'arrive': "arriveby", 'car': "car type", 'depart': "departure", + 'dest': "destination", 'leave': "leaveat", 'phone': "phone", + 'none': None + }, + 'train': { + 'arrive': "arriveby", 'choice': "choice", 'day': "day", + 'depart': "departure", 'dest': "destination", 'id': "trainid", 'trainid': "trainid", + 'leave': "leaveat", 'people': "people", 'ref': "ref", + 'ticket': "price", 'time': "duration", 'duration': 'duration', 'none': None + }, + 'police': { + 'addr': "address", 'post': "postcode", 'phone': "phone", 'name': 'name', 'none': None + } +} +ref_slot_data2stand = { + 'train': { + 'duration': 'time', 'price': 'ticket', 'trainid': 'id' + } +} diff --git a/convlab2/policy/tus/Goal.py b/convlab2/policy/tus/Goal.py new file mode 100644 index 0000000..13ec311 --- /dev/null +++ b/convlab2/policy/tus/Goal.py @@ -0,0 +1,405 @@ +import json +import os +from random import shuffle +from convlab2.task.multiwoz.goal_generator import GoalGenerator +from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal +from convlab2.policy.tus.multiwoz.util import parse_user_goal + +DEF_VAL_UNK = '?' # Unknown +DEF_VAL_DNC = 'dontcare' # Do not care +DEF_VAL_NUL = 'none' # for none +DEF_VAL_BOOKED = 'yes' # for booked +DEF_VAL_NOBOOK = 'no' # for booked +NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK, ""] + +ref_slot_data2stand = { + 'train': { + 'duration': 'time', 'price': 'ticket', 'trainid': 'id' + } +} + + +class Goal(object): + """ User Goal Model Class. """ + + def __init__(self, goal_generator=None, goal=None): + """ + create new Goal by random + Args: + goal_generator (GoalGenerator): Goal Gernerator. + """ + self.max_domain_len = 5 + self.max_slot_len = 20 + self.local_id = {} + path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + path = os.path.join(path, 'data/multiwoz/all_value.json') + self.all_values = json.load(open(path)) + if goal_generator is not None and goal is None: + self.domain_goals = goal_generator.get_user_goal() + self.domains = list(self.domain_goals['domain_ordering']) + del self.domain_goals['domain_ordering'] + elif goal_generator is None and goal is not None: + self.domains = [] + self.domain_goals = {} + for domain in goal: + if domain in SysDa2Goal and goal[domain]: # TODO check order + self.domains.append(domain) + self.domain_goals[domain] = goal[domain] + else: + print("Warning!!! 
One of goal_generator or goal should not be None!!!") + + for domain in self.domains: + if 'reqt' in self.domain_goals[domain].keys(): + self.domain_goals[domain]['reqt'] = { + slot: DEF_VAL_UNK for slot in self.domain_goals[domain]['reqt']} + + if 'book' in self.domain_goals[domain].keys(): + self.domain_goals[domain]['booked'] = DEF_VAL_UNK + self.init_local_id() + + def init_local_id(self): + # local_id = { + # "domain 1": { + # "ID": [1, 0, 0], + # "SLOT": { + # "slot 1": [1, 0, 0], + # "slot 2": [0, 1, 0]}}} + + for domain_id, domain in enumerate(self.domains): + self._init_domain_id(domain) + self._update_domain_id(domain, domain_id) + slot_id = 0 + for slot_type in ["info", "book", "reqt"]: + for slot in self.domain_goals[domain].get(slot_type, {}): + self._init_slot_id(domain, slot) + self._update_slot_id(domain, slot, slot_id) + slot_id += 1 + + def insert_local_id(self, new_slot_name): + domain, slot = new_slot_name.split('-') + if domain not in self.local_id: + self._init_domain_id(domain) + domain_id = len(self.domains) + 1 + self._update_domain_id(domain, domain_id) + self._init_slot_id(domain, slot) + # the first slot for a new domain + self._update_slot_id(domain, slot, 0) + + else: + slot_id = len(self.local_id[domain]["SLOT"]) + 1 + self._init_slot_id(domain, slot) + self._update_slot_id(domain, slot, slot_id) + + def get_slot_id(self, slot_name): + domain, slot = slot_name.split('-') + if domain in self.local_id and slot in self.local_id[domain]["SLOT"]: + return self.local_id[domain]["ID"], self.local_id[domain]["SLOT"][slot] + else: # a slot not in original user goal + self.insert_local_id(slot_name) + domain_id, slot_id = self.get_slot_id(slot_name) + return domain_id, slot_id + + def task_complete(self): + """ + Check that all requests have been met + Returns: + (boolean): True to accomplish. 
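+
+        Example (illustrative): for a goal containing
+        ``{'hotel': {'reqt': {'phone': '?'}}}``, this returns False until
+        'phone' has been filled with a concrete value.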
+ """ + for domain in self.domains: + if 'reqt' in self.domain_goals[domain]: + reqt_vals = self.domain_goals[domain]['reqt'].values() + for val in reqt_vals: + if val in NOT_SURE_VALS: + return False + + if 'booked' in self.domain_goals[domain]: + if self.domain_goals[domain]['booked'] in NOT_SURE_VALS: + return False + return True + + def next_domain_incomplete(self): + # request + for domain in self.domains: + # reqt + if 'reqt' in self.domain_goals[domain]: + requests = self.domain_goals[domain]['reqt'] + unknow_reqts = [ + key for (key, val) in requests.items() if val in NOT_SURE_VALS] + if len(unknow_reqts) > 0: + return domain, 'reqt', ['name'] if 'name' in unknow_reqts else unknow_reqts + + # book + if 'booked' in self.domain_goals[domain]: + if self.domain_goals[domain]['booked'] in NOT_SURE_VALS: + return domain, 'book', \ + self.domain_goals[domain]['fail_book'] if 'fail_book' in self.domain_goals[domain].keys() else \ + self.domain_goals[domain]['book'] + + return None, None, None + + def __str__(self): + return '-----Goal-----\n' + \ + json.dumps(self.domain_goals, indent=4) + \ + '\n-----Goal-----' + + def action_list(self, user_history=None, sys_act=None, all_values=None): + goal_slot = parse_user_goal(self) + if user_history: + priority_action = self._reorder_based_on_user_history( + user_history, goal_slot) + + else: + # priority_action = [slot for slot in goal_slot] + priority_action = self._reorder_random(goal_slot) + + if sys_act: + for intent, domain, slot, value in sys_act: + slot_name = self.act2slot( + intent, domain, slot, value, all_values) + # print("system_mention:", slot_name) + if slot_name and slot_name not in priority_action: + priority_action.insert(0, slot_name) + + return priority_action + + def get_booking_domain(self, slot, value, all_values): + for domain in self.domains: + if slot in all_values["all_value"] and value in all_values["all_value"][slot]: + return domain + print("NOT FOUND BOOKING DOMAIN") + return "" + + def act2slot(self, intent, domain, slot, value, all_values): + domain = domain.lower() + slot = slot.lower() + + if domain not in UsrDa2Goal: + # print(f"Not handle domain {domain}") + return "" + + if domain == "booking": + slot = SysDa2Goal[domain][slot] + domain = self.get_booking_domain(slot, value, all_values) + if domain: + return f"{domain}-{slot}" + + elif domain in UsrDa2Goal: + if slot in SysDa2Goal[domain]: + slot = SysDa2Goal[domain][slot] + elif slot in UsrDa2Goal[domain]: + slot = UsrDa2Goal[domain][slot] + elif slot in SysDa2Goal["booking"]: + slot = SysDa2Goal["booking"][slot] + # else: + # print( + # f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}") + + return f"{domain}-{slot}" + return "" + + def _reorder_random(self, goal_slot): + new_order = [slot for slot in goal_slot] + random_order = new_order[1:] + shuffle(random_order) + + return [new_order[0]] + random_order + + def _reorder_based_on_user_history(self, user_history, goal_slot): + # user_history = [slot_0, slot_1, ...] + new_order = [] + for slot in user_history: + if slot and slot not in new_order: + new_order.append(slot) + + for slot in goal_slot: + if slot not in new_order: + new_order.append(slot) + return new_order + + def update_user_goal(self, action=None, state=None): + # update request and booked + if action: + self._update_user_goal_from_action(action) + if state: + self._update_user_goal_from_state(state) + self._check_booked(state) # this should always check + + if action is None and state is None: + print("Warning!!!! 
Both action and state are None") + + def _check_booked(self, state): + for domain in self.domains: + if "booked" in self.domain_goals[domain]: + if self._check_book_info(state, domain): + self.domain_goals[domain]["booked"] = DEF_VAL_BOOKED + else: + self.domain_goals[domain]["booked"] = DEF_VAL_NOBOOK + + def _check_book_info(self, state, domain): + # need to check info, reqt for booked? + if domain not in state: + return False + + for slot_type in ['info', 'book']: + for slot in self.domain_goals[domain].get(slot_type, {}): + user_value = self.domain_goals[domain][slot_type][slot] + if slot in state[domain]["semi"]: + state_value = state[domain]["semi"][slot] + + elif slot in state[domain]["book"]: + state_value = state[domain]["book"][slot] + else: + state_value = "" + # only check mentioned values (?) + if state_value and state_value != user_value: + # print( + # f"booking info is incorrect, for slot {slot}: " + # f"goal {user_value} != state {state_value}") + return False + + return True + + def _update_user_goal_from_action(self, action): + for intent, domain, slot, value in action: + # print("update user goal from action") + # print(intent, domain, slot, value) + # print("action:", intent) + domain = domain.lower() + value = value.lower() + slot = slot.lower() + if slot == "ref": + for usr_domain in self.domains: + if "booked" in self.domain_goals[usr_domain]: + self.domain_goals[usr_domain]["booked"] = DEF_VAL_BOOKED + else: + domain, slot = self._norm_domain_slot(domain, slot, value) + + if self._check_update_request(domain, slot) and value != "?": + self.domain_goals[domain]['reqt'][slot] = value + # print(f"update reqt {slot} = {value} from system action") + + def _norm_domain_slot(self, domain, slot, value): + if domain == "booking": + # ["book", "booking", "people", 7] + if slot in SysDa2Goal[domain]: + slot = SysDa2Goal[domain][slot] + domain = self._get_booking_domain(slot, value) + else: + domain = "" + for d in SysDa2Goal: + if slot in SysDa2Goal[d]: + domain = d + slot = SysDa2Goal[d][slot] + if not domain: + return "", "" + return domain, slot + + elif domain in self.domains: + if slot in SysDa2Goal[domain]: + # ["request", "restaurant", "area", "north"] + slot = SysDa2Goal[domain][slot] + elif slot in UsrDa2Goal[domain]: + slot = UsrDa2Goal[domain][slot] + elif slot in SysDa2Goal["booking"]: + # ["inform", "hotel", "stay", 2] + slot = SysDa2Goal["booking"][slot] + # else: + # print( + # f"UNSEEN SLOT IN UPDATE GOAL {intent, domain, slot, value}") + return domain, slot + + else: + # domain = general + return "", "" + + def _update_user_goal_from_state(self, state): + for domain in state: + for slot in state[domain]["semi"]: + if self._check_update_request(domain, slot): + self._update_user_goal_from_semi(state, domain, slot) + for slot in state[domain]["book"]: + if slot == "booked" and state[domain]["book"]["booked"]: + self._update_booked(state, domain) + + elif state[domain]["book"][slot] and self._check_update_request(domain, slot): + self._update_book(state, domain, slot) + + def _update_slot(self, domain, slot, value): + self.domain_goals[domain]['reqt'][slot] = value + + def _update_user_goal_from_semi(self, state, domain, slot): + if self._check_value(state[domain]["semi"][slot]): + self._update_slot(domain, slot, state[domain]["semi"][slot]) + # print("update reqt {} in semi".format(slot), + # state[domain]["semi"][slot]) + + def _update_booked(self, state, domain): + # check state and goal is fulfill + self.domain_goals[domain]["booked"] = DEF_VAL_BOOKED + 
print("booked") + for booked_slot in state[domain]["book"]["booked"][0]: + if self._check_update_request(domain, booked_slot): + self._update_slot(domain, booked_slot, + state[domain]["book"]["booked"][0][booked_slot]) + # print("update reqt {} in booked".format(booked_slot), + # self.domain_goals[domain]['reqt'][booked_slot]) + + def _update_book(self, state, domain, slot): + if self._check_value(state[domain]["book"][slot]): + self._update_slot(domain, slot, state[domain]["book"][slot]) + # print("update reqt {} in book".format(slot), + # state[domain]["book"][slot]) + + def _check_update_request(self, domain, slot): + # check whether one slot is a request slot + if domain not in self.domain_goals: + return False + if 'reqt' not in self.domain_goals[domain]: + return False + if slot not in self.domain_goals[domain]['reqt']: + return False + return True + + def _check_value(self, value=None): + if not value: + return False + if value in NOT_SURE_VALS: + return False + return True + + def _get_booking_domain(self, slot, value): + """ + find the domain for domain booking, excluding slot "ref" + """ + found = "" + if not slot: # work around + return found + slot = slot.lower() + value = value.lower() + for domain in self.all_values["all_value"]: + if slot in self.all_values["all_value"][domain]: + if value in self.all_values["all_value"][domain][slot]: + if domain in self.domains: + found = domain + return found + + def _init_domain_id(self, domain): + self.local_id[domain] = {"ID": [0] * self.max_domain_len, "SLOT": {}} + + def _init_slot_id(self, domain, slot): + self.local_id[domain]["SLOT"][slot] = [0] * self.max_slot_len + + def _update_domain_id(self, domain, domain_id): + if domain_id < self.max_domain_len: + self.local_id[domain]["ID"][domain_id] = 1 + else: + print( + f"too many doamins: {domain_id} > {self.max_domain_len}") + + def _update_slot_id(self, domain, slot, slot_id): + if slot_id < self.max_slot_len: + self.local_id[domain]["SLOT"][slot][slot_id] = 1 + else: + print( + f"too many slots, {slot_id} > {self.max_slot_len}") diff --git a/convlab2/policy/tus/TUS.py b/convlab2/policy/tus/TUS.py new file mode 100644 index 0000000..ac50976 --- /dev/null +++ b/convlab2/policy/tus/TUS.py @@ -0,0 +1,293 @@ +import json +import math +import os +from copy import deepcopy +from random import choice + +import numpy as np +import torch +from convlab2.policy.tus.multiwoz.Goal import Goal +from convlab2.policy.tus.multiwoz.transformer import \ + TransformerActionPrediction +from convlab2.policy.tus.multiwoz.usermanager import Feature +from convlab2.policy.tus.multiwoz.util import parse_user_goal +from convlab2.policy.policy import Policy +from convlab2.task.multiwoz.goal_generator import GoalGenerator +from convlab2.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEF_VAL_UNK = '?' 
# Unknown +DEF_VAL_DNC = 'dontcare' # Do not care +DEF_VAL_NUL = 'none' # for none +DEF_VAL_BOOKED = 'yes' # for booked +DEF_VAL_NOBOOK = 'no' # for booked +Inform = "Inform" +Request = "Request" +NOT_SURE_VALS = [DEF_VAL_UNK, DEF_VAL_DNC, DEF_VAL_NUL, DEF_VAL_NOBOOK, ""] + +SLOT2SEMI = { + "arriveby": "arriveBy", + "leaveat": "leaveAt", + "trainid": "trainID", +} + + +class UserActionPolicy(Policy): + def __init__(self, config): + self.config = config + path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + path = os.path.join(path, 'data/multiwoz/all_value.json') + self.all_values = json.load(open(path)) + self.goal_gen = GoalGenerator() + Policy.__init__(self) + self.feat_handler = Feature(self.config) + self.config["num_token"] = config["num_token"] + self.user = TransformerActionPrediction(self.config).to(device=DEVICE) + self.load(os.path.join( + self.config["model_dir"], self.config["model_name"])) + self.user.eval() + self.use_domain_mask = self.config.get("domain_mask", False) + self.max_turn = 40 + self.mentioned_domain = [] + + def predict(self, state): + # update goal + self.goal.update_user_goal(action=state["system_action"], + state=state['belief_state']) + self.time_step += 2 + + self.predict_action_list = self.goal.action_list( + sys_act=state["system_action"], + all_values=self.all_values) + + feature, mask = self.feat_handler.get_feature( + self.predict_action_list, + self.goal, + state['belief_state'], + self.sys_history_state, + state["system_action"], + self.pre_usr_act) + feature = torch.tensor([feature], dtype=torch.float).to(DEVICE) + mask = torch.tensor([mask], dtype=torch.bool).to(DEVICE) + + self.sys_history_state = state['belief_state'] + + usr_output = self.user.forward(feature, mask) + usr_action = self.transform_usr_act( + usr_output, self.predict_action_list) + domains = [act[1] for act in usr_action] + none_slot_acts = self._add_none_slot_act(domains) + usr_action = none_slot_acts + usr_action + + self.pre_usr_act = deepcopy(usr_action) + + return usr_action + + def init_session(self): + self.mentioned_domain = [] + self.time_step = 0 + self.topic = 'NONE' + remove_domain = "police" # remove police domain in inference + # if "remove_domain" in self.config: + # remove_domain = self.config["remove_domain"] + self.new_goal(remove_domain=remove_domain) + + # print(self.goal) + if self.config.get("reorder", False): + self.predict_action_list = self.goal.action_list() + else: + self.predict_action_list = self.action_list + self.sys_history_state = None # to save sys history + self.terminated = False + self.feat_handler.initFeatureHandeler(self.goal) + self.pre_usr_act = None + + def new_goal(self, remove_domain="police", domain_len=None): + keep_generate_goal = True + while keep_generate_goal: + self.goal = Goal(self.goal_gen) + if (domain_len and len(self.goal.domains) != domain_len) or \ + (remove_domain and remove_domain in self.goal.domains): + keep_generate_goal = True + else: + keep_generate_goal = False + + def load(self, model_path=None): + self.user.load_state_dict(torch.load(model_path, map_location=DEVICE)) + + def get_goal(self): + return self.goal.domain_goals + + def get_reward(self): + if self.goal.task_complete(): + reward = 2 * self.max_turn + + elif self.time_step >= self.max_turn: + reward = -1 * self.max_turn + + else: + reward = -1.0 + return reward + + def _add_none_slot_act(self, domains): + actions = [] + for domain in domains: + domain = domain.lower() + if domain not in 
self.mentioned_domain and domain != 'general': + actions.append([Inform, domain.capitalize(), "None", "None"]) + self.mentioned_domain.append(domain) + return actions + + def transform_usr_act(self, usr_output, action_list): + if self.goal.task_complete(): + self.terminated = True + #print("task compltet") + return [["bye", "general", "None", "None"]] + + usr_action = [] + # [score, [action]]: when usr_action is none and task doesn't finish + non_zero_action = [] + for index, slot_name in enumerate(action_list): + domain, slot = slot_name.split('-') + self._add_user_action( + usr_action, + output=torch.argmax(usr_output[0, index + 1, :]).item(), + domain=domain.capitalize(), + slot=SLOT2SEMI.get(slot, slot)) + non_zero_action = self._max_non_zero_action( + action=non_zero_action, + score=torch.max(usr_output[0, index + 1, 1:]).item(), + output=torch.argmax(usr_output[0, index + 1, 1:]).item() + 1, + domain=domain.capitalize(), + slot=SLOT2SEMI.get(slot, slot)) + + if self.use_domain_mask: + domain_mask = self._get_prediction_domain(torch.round( + torch.sigmoid(usr_output[0, 0, :])).tolist()) + usr_action = self._mask_user_action(usr_action, domain_mask) + + if self.time_step > self.max_turn: + #print(f"max turn: {self.time_step}") + self.terminated = True + return [["bye", "general", "None", "None"]] + + # usr_action = None + + if not usr_action and not self.terminated: + # print(f"pick max none zero action {non_zero_action}") + return non_zero_action[1] + if not usr_action: + print("!!!STRANGE!!!!!!") + print(usr_action, non_zero_action, action_list) + return usr_action + + def _mask_user_action(self, usr_action, mask): + mask_action = [] + for intent, domain, slot, value in usr_action: + if domain.lower() in mask: + mask_action += [[intent, domain, slot, value]] + return mask_action + + def _get_prediction_domain(self, domain_output): + predict_domain = [] + if domain_output[0] > 0: + predict_domain.append('general') + for index, value in enumerate(domain_output[1:]): + if value > 0 and index < len(self.goal.domains): + predict_domain.append(self.goal.domains[index]) + return predict_domain + + def _max_non_zero_action(self, action, score, output, domain, slot): + if not action: + action = [score, []] + + if score > action[0]: + temp = [] + self._add_user_action(temp, output, domain, slot) + action = [score, temp] + elif score == action[0]: + self._add_user_action(action[1], output, domain, slot) + return action + + def _add_user_action(self, usr_action, output, domain, slot): + goal = self.get_goal() + # "?" + if output == 1: # "?" 
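+            # The six output classes follow the label legend used in
+            # analysis.py ("none", "?", "dontcare", "sys", "usr", "random"):
+            # 1 requests the slot, 2 informs "dontcare", 3 copies the value
+            # from the system's dialog state, and 4/5 copy it from the user
+            # goal.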
+ self._add_action(usr_action, Request, domain, slot, DEF_VAL_UNK) + + # "dontcare" + elif output == 2: + self._add_action(usr_action, Inform, domain, slot, DEF_VAL_DNC) + + # system + elif output == 3 and self._slot_type(domain, slot): + slot_type = self._slot_type(domain, slot) + value = self.sys_history_state[domain.lower()][slot_type].get( + slot, "") + if value: + self._add_action(usr_action, Inform, domain, slot, value) + + elif output == 4 and domain.lower() in goal: # usr + value = None + if slot in goal[domain.lower()].get("info", {}): + value = goal[domain.lower()]["info"][slot] + elif slot in goal[domain.lower()].get("book", {}): + value = goal[domain.lower()]["book"][slot] + if value: + self._add_action(usr_action, Inform, domain, slot, value) + + elif output == 5 and domain.lower() in goal: + # TODO random select, now we use the value from user goal + # print(f"NOT HANDLE SLOT {slot} IN DOMAIN {domain} (5)!!!") + value = None + if slot in goal[domain.lower()].get("info", {}): + value = goal[domain.lower()]["info"][slot] + elif slot in goal[domain.lower()].get("book", {}): + value = goal[domain.lower()]["book"][slot] + if value: + self._add_action(usr_action, Inform, domain, slot, value) + + def _get_action_slot(self, domain, slot): + return REF_USR_DA[domain.capitalize()].get(slot, None) + + def _add_action(self, usr_action, intent, domain, slot, value): + action_slot = self._get_action_slot(domain, slot) + if action_slot: + usr_action += [[intent, domain, action_slot, value]] + + def is_terminated(self): + # Is there any action to say? + return self.terminated + + def _slot_type(self, domain, slot): + slot_type = "" + if slot in self.sys_history_state[domain.lower()]["book"]: + slot_type = "book" + elif slot in self.sys_history_state[domain.lower()]["semi"]: + slot_type = "semi" + + return slot_type + + +class UserPolicy(Policy): + def __init__(self, config): + self.config = config + self.policy = UserActionPolicy(self.config) + + def predict(self, state): + return self.policy.predict(state) + + def init_session(self): + self.policy.init_session() + + def is_terminated(self): + return self.policy.is_terminated() + + def get_reward(self): + return self.policy.get_reward() + + def get_goal(self): + if hasattr(self.policy, 'get_goal'): + return self.policy.get_goal() + return None diff --git a/convlab2/policy/tus/__init__.py b/convlab2/policy/tus/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/convlab2/policy/tus/analysis.py b/convlab2/policy/tus/analysis.py new file mode 100644 index 0000000..f8d358a --- /dev/null +++ b/convlab2/policy/tus/analysis.py @@ -0,0 +1,430 @@ +import argparse +import json +import os +import random +import numpy as np +import logging + + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sn +import torch +from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator +from convlab2.dialog_agent.env import Environment +from convlab2.dialog_agent.session import BiSession +from convlab2.dialog_agent.agent import PipelineAgent +from convlab2.dst.rule.multiwoz import RuleDST +from convlab2.dst.rule.multiwoz.usr_dst import UserRuleDST +from convlab2.policy.tus.multiwoz.TUS import UserPolicy +from convlab2.policy.tus.multiwoz.transformer import \ + TransformerActionPrediction +from convlab2.policy.tus.multiwoz.usermanager import \ + TUSDataManager +from convlab2.policy.rule.multiwoz import RulePolicy +from sklearn import metrics +from torch.utils.data import DataLoader +from tqdm import tqdm +import datetime + + +def 
check_device(): + if torch.cuda.is_available(): + print("using GPU") + return torch.device('cuda') + else: + print("using CPU") + return torch.device('cpu') + + +def init_logging(log_dir_path, path_suffix=None): + if not os.path.exists(log_dir_path): + os.makedirs(log_dir_path) + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + if path_suffix: + log_file_path = os.path.join( + log_dir_path, f"{current_time}_{path_suffix}.log") + else: + log_file_path = os.path.join( + log_dir_path, "{}.log".format(current_time)) + + stderr_handler = logging.StreamHandler() + file_handler = logging.FileHandler(log_file_path) + format_str = "%(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s" + logging.basicConfig(level=logging.DEBUG, handlers=[ + stderr_handler, file_handler], format=format_str) + + +class Analysis: + def __init__(self, config, analysis_dir='user-analysis-result', show_dialog=False, save_dialog=True): + if not os.path.exists(analysis_dir): + os.makedirs(analysis_dir) + self.dialog_dir = os.path.join(analysis_dir, 'dialog') + if not os.path.exists(self.dialog_dir): + os.makedirs(self.dialog_dir) + self.dir = analysis_dir + self.config = config + self.device = check_device() + self.show_dialog = show_dialog + self.save_dialog = save_dialog + self.max_turn = 40 + + def get_sys(self, sys="rule", load_path=None): + dst = RuleDST() + + sys = sys.lower() + if sys == "rule": + policy = RulePolicy() + elif sys == "ppo": + from convlab2.policy.ppo import PPO + if load_path: + policy = PPO(False, use_action_mask=True, shrink=False) + policy.load(load_path) + else: + policy = PPO.from_pretrained() + elif sys == "vtrace": + from convlab2.policy.vtrace_rnn_action_embedding import VTRACE_RNN + policy = VTRACE_RNN( + is_train=False, seed=0, use_masking=True, shrink=False) + policy.load(load_path) + else: + print(f"Unsupport system type: {sys}") + + return dst, policy + + def get_usr(self, usr="tus", load_path=None): + # if using "tus", we read config + # for the other user simulators, we read load_path + usr = usr.lower() + if usr == "rule": + dst_usr = None + policy_usr = RulePolicy(character='usr') + elif usr == "tus": + dst_usr = UserRuleDST() + policy_usr = UserPolicy(self.config) + elif usr == "vhus": + from convlab2.policy.vhus.multiwoz import UserPolicyVHUS + + dst_usr = None + policy_usr = UserPolicyVHUS( + load_from_zip=True, model_file="vhus_simulator_multiwoz.zip") + else: + print(f"Unsupport user type: {usr}") + # TODO VHUS + + return dst_usr, policy_usr + + def interact_test(self, + sys="rule", + usr="tus", + sys_load_path=None, + usr_load_path=None, + num_dialog=400, + domain=None): + # TODO need refactor + seed = 20190827 + torch.manual_seed(seed) + sys = sys.lower() + usr = usr.lower() + + sess = self._set_interactive_test( + sys, usr, sys_load_path, usr_load_path) + + task_success = { + # 'All_user_sim': [], 'All_evaluator': [], 'total_return': []} + 'complete': [], 'success': [], 'reward': []} + + turn_slot_num = {i: [] for i in range(self.max_turn)} + turn_domain_num = {i: [] for i in range(self.max_turn)} + true_max_turn = 0 + + for seed in tqdm(range(1000, 1000 + num_dialog)): + # logging.info(f"Seed: {seed}") + + random.seed(seed) + np.random.seed(seed) + sess.init_session() + # if domain is not none, the user goal must contain certain domain + if domain: + domain = domain.lower() + print(f"check {domain}") + while 1: + if domain in sess.user_agent.policy.get_goal(): + break + sess.user_agent.init_session() + sys_uttr = [] + actions = 0 + 
total_return = 0.0 + if self.save_dialog: + f = open(os.path.join(self.dialog_dir, str(seed)), 'w') + for turn in range(self.max_turn): + sys_uttr, usr_uttr, finish, reward = sess.next_turn(sys_uttr) + if self.show_dialog: + print(f"USR: {usr_uttr}") + print(f"SYS: {sys_uttr}") + if self.save_dialog: + f.write(f"USR: {usr_uttr}\n") + f.write(f"SYS: {sys_uttr}\n") + actions += len(usr_uttr) + turn_slot_num[turn].append(len(usr_uttr)) + turn_domain_num[turn].append(self._get_domain_num(usr_uttr)) + total_return += sess.user_agent.policy.policy.get_reward() + + if finish: + task_succ = sess.evaluator.task_success() + break + if turn > true_max_turn: + true_max_turn = turn + if self.save_dialog: + f.close() + # logging.info(f"Return: {total_return}") + # logging.info(f"Average actions: {actions / (turn+1)}") + + task_success['complete'].append( + int(sess.user_agent.policy.policy.goal.task_complete())) + task_success['success'].append(task_succ) + task_success['reward'].append(total_return) + task_summary = {key: [0] for key in task_success} + for key in task_success: + if task_success[key]: + task_summary[key][0] = np.average(task_success[key]) + + for key in task_success: + logging.info( + f'{key} {len(task_success[key])} {task_summary[key][0]}') + + # logging.info("Average action in turn") + write = {'turn_slot_num': [], 'turn_domain_num': []} + for turn in turn_slot_num: + if turn > true_max_turn: + break + avg = 0 + if turn_slot_num[turn]: + avg = sum(turn_slot_num[turn]) / len(turn_slot_num[turn]) + write['turn_slot_num'].append(avg) + # logging.info(f"turn {turn}: {avg} slots") + for turn in turn_domain_num: + if turn > true_max_turn: + break + avg = 0 + if turn_domain_num[turn]: + avg = sum(turn_domain_num[turn]) / len(turn_domain_num[turn]) + write['turn_domain_num'].append(avg) + # logging.info(f"turn {turn}: {avg} domains") + + # write results + pd.DataFrame.from_dict(write).to_csv( + os.path.join(self.dir, f'{sys}-{usr}-turn-statistics.csv')) + pd.DataFrame.from_dict(task_summary).to_csv( + os.path.join(self.dir, f'{sys}-{usr}-task-summary.csv')) + + def _get_domain_num(self, action): + # act: [Intent, Domain, Slot, Value] + return len(set(act[1] for act in action)) + + def _set_interactive_test(self, sys, usr, sys_load_path, usr_load_path): + dst_sys, policy_sys = self.get_sys(sys, sys_load_path) + dst_usr, policy_usr = self.get_usr(usr, usr_load_path) + + usr = PipelineAgent(None, dst_usr, policy_usr, None, 'user') + sys = PipelineAgent(None, dst_sys, policy_sys, None, 'sys') + env = Environment(None, usr, None, dst_sys) + evaluator = MultiWozEvaluator() + sess = BiSession(sys, usr, None, evaluator) + + return sess + + def direct_test(self, model, test_data): + + model = model.to(self.device) + model.zero_grad() + model.eval() + y_lable, y_pred = [], [] + y_turn = [] + + result = {} # old way + + with torch.no_grad(): + for i, data in enumerate(tqdm(test_data, ascii=True, desc="Evaluation"), 0): + input_feature = data["input"].to(self.device) + mask = data["mask"].to(self.device) + label = data["label"].to(self.device) + output = model(input_feature, mask) + y_l, y_p, y_t, r = self.parse_result(output, label) + y_lable += y_l + y_pred += y_p + y_turn += y_t + # old way + for r_type in r: + if r_type not in result: + result[r_type] = {"correct": 0, "total": 0} + for n in result[r_type]: + result[r_type][n] += float(r[r_type][n]) + old_result = {} + for r_type in result: + temp = result[r_type]['correct'] / result[r_type]['total'] + old_result[r_type] = [temp] + + 
pd.DataFrame.from_dict(old_result).to_csv( + os.path.join(self.dir, f'old_result.csv')) + + cm = self.model_confusion_matrix(y_lable, y_pred) + self.summary(y_lable, y_pred, y_turn, cm) + + return old_result + + def summary(self, y_true, y_pred, y_turn, cm, file_name='scores.csv'): + result = { + 'f1': metrics.f1_score(y_true, y_pred, average='micro'), + 'precision': metrics.precision_score(y_true, y_pred, average='micro'), + 'recall': metrics.recall_score(y_true, y_pred, average='micro'), + 'none-zero-acc': self.none_zero_acc(cm), + 'turn-acc': sum(y_turn) / len(y_turn)} + col = [c for c in result] + df_f1 = pd.DataFrame([result[c] for c in col], col) + df_f1.to_csv(os.path.join(self.dir, file_name)) + + def none_zero_acc(self, cm): + # ['Unnamed: 0', 'none', '?', 'dontcare', 'sys', 'usr', 'random'] + col = cm.columns[1:] + num_label = cm.sum(axis=1) + correct = 0 + for col_name in col: + correct += cm[col_name][col_name] + return correct / sum(num_label[1:]) + + def model_confusion_matrix(self, y_true, y_pred, file_name='cm.csv', legend=["none", "?", "dontcare", "sys", "usr", "random"]): + cm = metrics.confusion_matrix(y_true, y_pred) + df_cm = pd.DataFrame(cm, legend, legend) + df_cm.to_csv(os.path.join(self.dir, file_name)) + return df_cm + + def parse_result(self, prediction, label): + _, arg_prediction = torch.max(prediction.data, -1) + batch_size, token_num = label.shape + y_true, y_pred = [], [] + y_turn = [] + result = { + "non-zero": {"correct": 0, "total": 0}, + "total": {"correct": 0, "total": 0}, + "turn": {"correct": 0, "total": 0} + } + + for batch_num in range(batch_size): + + turn_acc = True # old way + turn_success = 1 # new way + + for element in range(token_num): + result["total"]["total"] += 1 + l = label[batch_num][element].item() + p = arg_prediction[batch_num][element + 1].item() + # old way + if l > 0: + result["non-zero"]["total"] += 1 + if p == l: + if l > 0: + result["non-zero"]["correct"] += 1 + result["total"]["correct"] += 1 + elif p == 0 and l < 0: + result["total"]["correct"] += 1 + + else: + if l >= 0: + turn_acc = False + + # new way + if l >= 0: + y_true.append(l) + y_pred.append(p) + if l >= 0 and l != p: + turn_success = 0 + y_turn.append(turn_success) + # old way + result["turn"]["total"] += 1 + if turn_acc: + result["turn"]["correct"] += 1 + + return y_true, y_pred, y_turn, result + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--analysis_dir", type=str, + default="user-analysis-result") + parser.add_argument("--user_config", type=str, + default="convlab2/policy/tus/multiwoz/exp/default.json") + parser.add_argument("--user_mode", type=str, default="") + parser.add_argument("--version", type=str, default="") + parser.add_argument("--do_direct", action="store_true") + parser.add_argument("--do_interact", action="store_true") + parser.add_argument("--show_dialog", type=bool, default=False) + parser.add_argument("--usr", type=str, default="tus") + parser.add_argument("--sys", type=str, default="rule") + parser.add_argument("--num_dialog", type=int, default=400) + parser.add_argument("--use_mask", action="store_true") + parser.add_argument("--sys_config", type=str, default="") + parser.add_argument("--sys_model_dir", type=str, + default="convlab2/policy/ppo/save/") + parser.add_argument("--domain", type=str, default="", + help="the user goal must contain a specific domain") + + args = parser.parse_args() + + analysis_dir = os.path.join(args.analysis_dir, f"{args.sys}-{args.usr}") + if args.version: + 
analysis_dir = os.path.join(analysis_dir, args.version) + + if not os.path.exists(os.path.join(analysis_dir)): + os.makedirs(analysis_dir) + + config = json.load(open(args.user_config)) + init_logging(log_dir_path=os.path.join(analysis_dir, "log")) + if args.user_mode: + config["model_name"] = config["model_name"] + '-' + args.user_mode + with open(config["all_slot"]) as f: + action_list = [line.strip() for line in f] + config["num_token"] = len(action_list) + if args.use_mask: + config["domain_mask"] = True + + ana = Analysis(config, analysis_dir=analysis_dir, + show_dialog=args.show_dialog) + if args.usr == "tus" and args.do_direct: + test_data = DataLoader( + TUSDataManager( + config, data_dir="data", set_type='test'), + batch_size=config["batch_size"], + shuffle=True) + + model = TransformerActionPrediction(config) + if args.user_mode: + model.load_state_dict(torch.load( + os.path.join(config["model_dir"], config["model_name"]))) + print(args.user_mode) + old_result = ana.direct_test(model, test_data) + print(old_result) + else: + for user_mode in ["loss", "total", "turn", "non-zero"]: + model.load_state_dict(torch.load( + os.path.join(config["model_dir"], config["model_name"] + '-' + user_mode))) + print(user_mode) + old_result = ana.direct_test(model, test_data) + print(old_result) + + if args.do_interact: + sys_load_path = None + if args.sys_config: + _, file_extension = os.path.splitext(args.sys_config) + # read from config + if file_extension == ".json": + sys_config = json.load(open(args.sys_config)) + file_name = f"{sys_config['current_time']}_best_complete_rate_ppo" + sys_load_path = os.path.join(args.sys_model_dir, file_name) + # read from file + else: + sys_load_path = args.sys_config + ana.interact_test(sys=args.sys, usr=args.usr, + sys_load_path=sys_load_path, + num_dialog=args.num_dialog, + domain=args.domain) diff --git a/convlab2/policy/tus/config.json b/convlab2/policy/tus/config.json new file mode 100755 index 0000000..da62b04 --- /dev/null +++ b/convlab2/policy/tus/config.json @@ -0,0 +1,14 @@ +{ + "batchsz": 32, + "epoch": 16, + "lr": 0.001, + "save_dir": "save", + "log_dir": "log", + "print_per_batch": 400, + "save_per_epoch": 5, + "hu_dim": 200, + "eu_dim": 150, + "max_ulen": 20, + "alpha": 0.01, + "load": "save/best" +} \ No newline at end of file diff --git a/convlab2/policy/tus/exp/default.json b/convlab2/policy/tus/exp/default.json new file mode 100644 index 0000000..8ecad85 --- /dev/null +++ b/convlab2/policy/tus/exp/default.json @@ -0,0 +1,32 @@ +{ + "model_dir": "convlab2/policy/tus/multiwoz/default", + "model_name": "model", + "num_epoch": 50, + "batch_size": 128, + "learning_rate": 1e-4, + "num_token": 65, + "debug": false, + "gelu": false, + "dropout": 0.1, + "embed_dim": 78, + "out_dim": 6, + "hidden": 200, + "num_transformer": 2, + "weight_factor": [ + 1, + 1, + 1, + 1, + 1, + 5 + ], + "window": 3, + "nhead": 4, + "dim_feedforward": 200, + "num_transform_layer": 2, + "turn-pos": false, + "reorder": true, + "conflict": false, + "remove_domain": "police", + "domain_feat": true +} \ No newline at end of file diff --git a/convlab2/policy/tus/train.py b/convlab2/policy/tus/train.py new file mode 100644 index 0000000..85af35e --- /dev/null +++ b/convlab2/policy/tus/train.py @@ -0,0 +1,215 @@ +import json +import os +import time + +import torch +import torch.optim as optim +from tqdm import tqdm +from convlab2.policy.tus.multiwoz.analysis import Analysis + + +def check_device(): + if torch.cuda.is_available(): + print("using GPU") + return torch.device('cuda') 
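+        # note: this helper duplicates check_device() in analysis.py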
+ else: + print("using CPU") + return torch.device('cpu') + + +class Trainer: + def __init__(self, model, config): + self.model = model + self.config = config + self.num_epoch = self.config["num_epoch"] + self.batch_size = self.config["batch_size"] + self.device = check_device() + print(self.device) + self.optimizer = optim.Adam( + model.parameters(), lr=self.config["learning_rate"]) + # self.exp_time = time.strftime("%d-%b-%Y-%H-%M-%S", time.localtime()) + # os.mkdir(self.exp_time) + self.ana = Analysis(config) + + def training(self, train_data, test_data=None): + + self.model = self.model.to(self.device) + if not os.path.exists(self.config["model_dir"]): + os.makedirs(self.config["model_dir"]) + + save_path = os.path.join( + self.config["model_dir"], self.config["model_name"]) + + # best = [0, 0, 0] + best = {"loss": 100} + lowest_loss = 100 + for epoch in range(self.num_epoch): + print("epoch {}".format(epoch)) + total_loss = self.train_epoch(train_data) + print("loss: {}".format(total_loss)) + if test_data is not None: + acc = self.eval(test_data) + + if total_loss < lowest_loss: + best["loss"] = total_loss + print(f"save model in {save_path}-loss") + torch.save(self.model.state_dict(), f"{save_path}-loss") + + for acc_type in acc: + if acc_type not in best: + best[acc_type] = 0 + temp = acc[acc_type]["correct"] / acc[acc_type]["total"] + if best[acc_type] < temp: + best[acc_type] = temp + print(f"save model in {save_path}-{acc_type}") + torch.save(self.model.state_dict(), + f"{save_path}-{acc_type}") + if epoch < 10 and epoch > 5: + print(f"save model in {save_path}-{epoch}") + torch.save(self.model.state_dict(), + f"{save_path}-{epoch}") + print(f"save latest model in {save_path}") + torch.save(self.model.state_dict(), save_path) + + def train_epoch(self, data_loader): + self.model.train() + total_loss = 0 + result = {} + # result = {"id": {"slot": {"prediction": [],"label": []}}} + count = 0 + for i, data in enumerate(tqdm(data_loader, ascii=True, desc="Training"), 0): + input_feature = data["input"].to(self.device) + mask = data["mask"].to(self.device) + label = data["label"].to(self.device) + if self.config.get("domain_traget", True): + domain = data["domain"].to(self.device) + else: + domain = None + self.optimizer.zero_grad() + + loss, output = self.model(input_feature, mask, label, domain) + + loss.backward() + self.optimizer.step() + total_loss += float(loss) + count += 1 + + return total_loss / count + + def eval(self, test_data): + self.model.zero_grad() + self.model.eval() + + result = {} + + with torch.no_grad(): + correct, total, non_zero_correct, non_zero_total = 0, 0, 0, 0 + for i, data in enumerate(tqdm(test_data, ascii=True, desc="Evaluation"), 0): + input_feature = data["input"].to(self.device) + mask = data["mask"].to(self.device) + label = data["label"].to(self.device) + output = self.model(input_feature, mask) + r = parse_result(output, label) + for r_type in r: + if r_type not in result: + result[r_type] = {"correct": 0, "total": 0} + for n in result[r_type]: + result[r_type][n] += float(r[r_type][n]) + + for r_type in result: + temp = result[r_type]['correct'] / result[r_type]['total'] + print(f"{r_type}: {temp}") + + return result + + +def parse_result(prediction, label): + # result = {"id": {"slot": {"prediction": [],"label": []}}} + # dialog_index = ["dialog-id"_"slot-name", "dialog-id"_"slot-name", ...] + # prdiction = [0, 1, 0, ...] 
# after max + + _, arg_prediction = torch.max(prediction.data, -1) + batch_size, token_num = label.shape + result = { + "non-zero": {"correct": 0, "total": 0}, + "total": {"correct": 0, "total": 0}, + "turn": {"correct": 0, "total": 0} + } + + for batch_num in range(batch_size): + turn_acc = True + for element in range(token_num): + result["total"]["total"] += 1 + if label[batch_num][element] > 0: + result["non-zero"]["total"] += 1 + + if arg_prediction[batch_num][element + 1] == label[batch_num][element]: + if label[batch_num][element] > 0: + result["non-zero"]["correct"] += 1 + result["total"]["correct"] += 1 + + elif arg_prediction[batch_num][element + 1] == 0 and label[batch_num][element] < 0: + result["total"]["correct"] += 1 + + else: + if label[batch_num][element] >= 0: + turn_acc = False + + result["turn"]["total"] += 1 + if turn_acc: + result["turn"]["correct"] += 1 + + return result + + +def f1(target, result): + target_len = 0 + result_len = 0 + tp = 0 + for t, r in zip(target, result): + if t: + target_len += 1 + if r: + result_len += 1 + if r == t and t: + tp += 1 + precision = 0 + recall = 0 + if result_len: + precision = tp / result_len + if target_len: + recall = tp / target_len + f1_score = 2 / (1 / precision + 1 / recall) + return f1_score, precision, recall + + +if __name__ == "__main__": + import argparse + import os + from convlab2.policy.tus.multiwoz.transformer import \ + TransformerActionPrediction + from convlab2.policy.tus.multiwoz.usermanager import \ + TUSDataManager + from torch.utils.data import DataLoader + + parser = argparse.ArgumentParser() + parser.add_argument("--user_config", type=str, + default="convlab2/policy/tus/multiwoz/exp/default.json") + + args = parser.parse_args() + config = json.load(open(args.user_config)) + + batch_size = config["batch_size"] + + # train_data = TUSDataManager( + # config, data_dir="data", set_type='train') + train_data = TUSDataManager(config, data_dir="data", set_type='val') + embed_dim = train_data.features["input"].shape[-1] + assert embed_dim == config["embed_dim"] + train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True) + test_data = DataLoader( + TUSDataManager(config, data_dir="data", set_type='test'), batch_size=batch_size, shuffle=True) + + model = TransformerActionPrediction(config) + trainer = Trainer(model, config) + trainer.training(train_data, test_data) diff --git a/convlab2/policy/tus/transformer.py b/convlab2/policy/tus/transformer.py new file mode 100644 index 0000000..38f3174 --- /dev/null +++ b/convlab2/policy/tus/transformer.py @@ -0,0 +1,177 @@ +import json +import math +import os +from copy import deepcopy +from random import choice + +import numpy as np +import torch + +from convlab2.policy.policy import Policy +from torch import nn +from torch.autograd import Variable +from torch.nn import (GRU, CrossEntropyLoss, LayerNorm, Linear, + TransformerEncoder, TransformerEncoderLayer) + +# TODO masking +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def get_clones(module, N): + return nn.ModuleList([deepcopy(module) for i in range(N)]) + + +class EncodeLayer(torch.nn.Module): + def __init__(self, config): + super(EncodeLayer, self).__init__() + self.config = config + self.num_token = self.config["num_token"] + self.embed_dim = self.config["hidden"] + transform_layer = TransformerEncoderLayer( + d_model=self.embed_dim, + nhead=self.config["nhead"], + dim_feedforward=self.config["hidden"], + activation='gelu') + + self.norm_1 = LayerNorm(self.embed_dim) + 
+
+
+if __name__ == "__main__":
+    import argparse
+    import os
+    from convlab2.policy.tus.multiwoz.transformer import \
+        TransformerActionPrediction
+    from convlab2.policy.tus.multiwoz.usermanager import \
+        TUSDataManager
+    from torch.utils.data import DataLoader
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--user_config", type=str,
+                        default="convlab2/policy/tus/multiwoz/exp/default.json")
+
+    args = parser.parse_args()
+    config = json.load(open(args.user_config))
+
+    batch_size = config["batch_size"]
+
+    train_data = TUSDataManager(config, data_dir="data", set_type='train')
+    embed_dim = train_data.features["input"].shape[-1]
+    assert embed_dim == config["embed_dim"]
+    train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
+    test_data = DataLoader(
+        TUSDataManager(config, data_dir="data", set_type='test'), batch_size=batch_size, shuffle=True)
+
+    model = TransformerActionPrediction(config)
+    trainer = Trainer(model, config)
+    trainer.training(train_data, test_data)
diff --git a/convlab2/policy/tus/transformer.py b/convlab2/policy/tus/transformer.py
new file mode 100644
index 0000000..38f3174
--- /dev/null
+++ b/convlab2/policy/tus/transformer.py
@@ -0,0 +1,177 @@
+import math
+from copy import deepcopy
+
+import torch
+from torch import nn
+from torch.nn import (CrossEntropyLoss, LayerNorm, Linear,
+                      TransformerEncoder, TransformerEncoderLayer)
+
+# TODO masking
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def get_clones(module, N):
+    return nn.ModuleList([deepcopy(module) for _ in range(N)])
+
+
+class EncodeLayer(torch.nn.Module):
+    def __init__(self, config):
+        super(EncodeLayer, self).__init__()
+        self.config = config
+        self.num_token = self.config["num_token"]
+        self.embed_dim = self.config["hidden"]
+        transform_layer = TransformerEncoderLayer(
+            d_model=self.embed_dim,
+            nhead=self.config["nhead"],
+            dim_feedforward=self.config["hidden"],
+            activation='gelu')
+
+        self.norm_1 = LayerNorm(self.embed_dim)
+        self.encoder = TransformerEncoder(
+            encoder_layer=transform_layer,
+            num_layers=self.config["num_transform_layer"],
+            norm=self.norm_1)
+        self.norm_2 = LayerNorm(self.embed_dim)
+        # note: the feed-forward layers below are initialised but not
+        # used in forward()
+        self.use_gelu = self.config.get("gelu", False)
+        if self.use_gelu:
+            self.fc_1 = Linear(self.embed_dim, self.embed_dim)
+            self.gelu = nn.GELU()
+            self.dropout = nn.Dropout(self.config["dropout"])
+            self.fc_2 = Linear(self.embed_dim, self.embed_dim)
+        else:
+            self.fc = Linear(self.embed_dim, self.embed_dim)
+        self.init_weights()
+
+    def init_weights(self):
+        initrange = 0.1
+        if self.use_gelu:
+            self.fc_1.weight.data.uniform_(-initrange, initrange)
+            self.fc_2.weight.data.uniform_(-initrange, initrange)
+        else:
+            self.fc.weight.data.uniform_(-initrange, initrange)
+
+    def forward(self, x, mask):
+        x = self.encoder(x, src_key_padding_mask=mask)
+        return x
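+
+
+# Shape convention, following torch.nn.TransformerEncoder: EncodeLayer
+# takes src as [num_token, batch_size, hidden] and mask as a bool tensor
+# of shape [batch_size, num_token] passed through as src_key_padding_mask,
+# where True marks padded positions. A hedged sketch:
+#
+#     layer = EncodeLayer(config)
+#     src = torch.rand(5 * 65, 8, config["hidden"])    # tokens, batch, hidden
+#     mask = torch.zeros(8, 5 * 65, dtype=torch.bool)  # nothing padded
+#     out = layer(src, mask)                           # same shape as src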
+
+
+class TransformerActionPrediction(torch.nn.Module):
+    def __init__(self, config):
+        super(TransformerActionPrediction, self).__init__()
+        self.config = config
+        self.num_transformer = self.config["num_transformer"]
+        self.embed_dim = self.config["embed_dim"]
+        self.out_dim = self.config["out_dim"]
+        self.hidden = self.config["hidden"]
+        self.softmax = nn.Softmax(dim=-1)
+        self.num_token = self.config["num_token"]
+        self.embed_linear = Linear(self.embed_dim, self.hidden)
+        self.position = PositionalEncoding(self.hidden, self.config)
+        self.encoder_layers = get_clones(
+            EncodeLayer(self.config), N=self.num_transformer)
+
+        self.norm_1 = LayerNorm(self.hidden)
+        self.decoder = Linear(self.hidden, self.out_dim)
+        self.norm_2 = LayerNorm(self.out_dim)
+
+        # down-weight each output class by its configured factor
+        weight = [1.0] * self.out_dim
+        for i in range(self.out_dim):
+            weight[i] /= self.config["weight_factor"][i]
+
+        weight = torch.tensor(weight)
+        self.loss = CrossEntropyLoss(weight=weight, ignore_index=-1)
+        self.pick_loss = CrossEntropyLoss()
+        self.similarity = nn.CosineSimilarity(dim=-1)
+        self.domain_loss = nn.BCEWithLogitsLoss()
+        self.init_weights()
+
+    def init_weights(self):
+        initrange = 0.1
+        self.embed_linear.weight.data.uniform_(-initrange, initrange)
+        self.decoder.bias.data.zero_()
+        self.decoder.weight.data.uniform_(-initrange, initrange)
+
+    def forward(self, input_feat, mask, label=None, domain_label=None):
+
+        src = self.embed_linear(input_feat) * math.sqrt(self.hidden)
+        src = self.position(src)
+        # TransformerEncoder expects [num_token, batch_size, hidden]
+        src = src.permute(1, 0, 2)
+        for i in range(self.num_transformer):
+            src = self.encoder_layers[i](src, mask)
+
+        # residual connection around the final layer norm
+        out_src = self.norm_1(src)
+        out_src = src + out_src
+
+        out_src = out_src.permute(1, 0, 2)
+
+        out = self.decoder(out_src)
+        out = self.norm_2(out)
+
+        if label is not None:
+            if domain_label is None:
+                loss = self.get_loss(out, label)
+            else:
+                loss = self.get_loss(out, label) + \
+                    self.first_token_loss(out, domain_label)
+            return loss, out
+        return out
+
+    def get_loss(self, prediction, target):
+        # prediction = [batch_size, num_token + 1, out_dim]
+        # target = [batch_size, num_token]
+        # the first token is CLS, so it is skipped here
+        pre = prediction[:, 1: self.num_token + 1, :]
+        pre = torch.reshape(pre, (pre.shape[0]*pre.shape[1], pre.shape[-1]))
+        return self.loss(pre, target.view(-1))
+
+    def first_token_loss(self, prediction, target):
+        # prediction = [batch_size, num_token + 1, out_dim]
+        # target = [batch_size, out_dim]
+        pre = prediction[:, 0, :]
+        return self.domain_loss(pre, target)
+
+
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_model, config, dropout=0.1, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        self.config = config
+        self.dropout = nn.Dropout(p=dropout)
+        self.turn_pos = self.config.get("turn-pos", True)
+
+        # pe encodes the turn index: all 65 tokens of a turn share one
+        # position, hence the integer division by 65
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        position = position // 65
+        div_term = torch.exp(torch.arange(
+            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+        # pe1 is the standard token-level positional encoding
+        pe1 = torch.zeros(max_len, d_model)
+        position1 = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term1 = torch.exp(torch.arange(
+            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe1[:, 0::2] = torch.sin(position1 * div_term1)
+        pe1[:, 1::2] = torch.cos(position1 * div_term1)
+        pe1 = pe1.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe1', pe1)
+
+    def forward(self, x):
+        if self.turn_pos:
+            # average the token-level and turn-level encodings
+            x = x + self.pe1[:x.size(0), :]*0.5 + self.pe[:x.size(0), :]*0.5
+        else:
+            x = x + self.pe1[:x.size(0), :]
+        return self.dropout(x)
diff --git a/convlab2/policy/tus/usermanager.py b/convlab2/policy/tus/usermanager.py
new file mode 100644
index 0000000..1cf3818
--- /dev/null
+++ b/convlab2/policy/tus/usermanager.py
@@ -0,0 +1,550 @@
+from collections import Counter
+import json
+import os
+
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from convlab2.dst.rule.multiwoz.usr_dst import UserRuleDST
+from convlab2.policy.tus.multiwoz import util
+from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+from convlab2.policy.tus.multiwoz.Goal import Goal
+
+NOT_MENTIONED = "not mentioned"
+
+
+def dic2list(da2goal):
+    action_list = []
+    for domain in da2goal:
+        for slot in da2goal[domain]:
+            if da2goal[domain][slot] is None:
+                continue
+            act = f"{domain}-{da2goal[domain][slot]}"
+            if act not in action_list:
+                action_list.append(act)
+    return action_list
+
+
+class TUSDataManager(Dataset):
+    def __init__(self,
+                 config,
+                 data_dir='data',
+                 set_type='train',
+                 max_turns=12):
+
+        self.config = config
+        if set_type == "train":
+            data = json.load(
+                open(os.path.join(data_dir, 'multiwoz/train_modified.json'), 'r'))
+        elif set_type == "val":
+            data = json.load(
+                open(os.path.join(data_dir, 'multiwoz/val.json'), 'r'))
+        elif set_type == "test":
+            data = json.load(
+                open(os.path.join(data_dir, 'multiwoz/test.json'), 'r'))
+        else:
+            raise ValueError(f"unknown set type: {set_type}")
+
+        self.all_values = json.load(
+            open(os.path.join(data_dir, "multiwoz/all_value.json")))
+
+        self.remove_domain = self.config.get("remove_domain", "")
+        self.feature_handler = Feature(self.config)
+        self.features = self.process(data, max_turns)
+
+    def __getitem__(self, index):
+        return {label: self.features[label][index] if self.features[label] is not None else None
+                for label in self.features}
+
+    def __len__(self):
+        return self.features['input'].size(0)
+
+    def resample(self, size=None):
+        n_dialogues = self.__len__()
+        if not size:
+            size = n_dialogues
+
+        dialogues = torch.randint(low=0, high=n_dialogues, size=(size,))
+        self.features = {
+            label: self.features[label][dialogues] for label in self.features}
+
+    def to(self, device):
+        self.device = device
+        self.features = {label: self.features[label].to(
+            device) for label in self.features}
+
+    def process(self, data, max_turns):
+
+        feature = {"id": [], "input": [],
+                   "label": [], "mask": [], "domain": []}
+        dst = UserRuleDST()
+        for dialog_id in tqdm(data, ascii=True, desc="Processing"):
+            user_goal = Goal(goal=data[dialog_id]["goal"])
+
+            # if one domain is removed, we skip all data related to this
+            # domain; police is removed by default
+            if (self.remove_domain and self.remove_domain in user_goal.domains) or \
+                    ("police" in user_goal.domains):
+                continue
+
+            turn_num = len(data[dialog_id]["log"])
+            pre_state = {}
+            sys_act = []
+            self.feature_handler.initFeatureHandeler(user_goal)
+            dst.init_session()
+            user_mentioned = util.get_user_history(
+                data[dialog_id]["log"], self.all_values)
+
+            for turn_id in range(0, turn_num, 2):
+                action_list = user_goal.action_list(
+                    user_history=user_mentioned,
+                    sys_act=sys_act,
+                    all_values=self.all_values)
+                if turn_id > 0:
+                    sys_act = util.parse_dialogue_act(
+                        data[dialog_id]["log"][turn_id - 1]["dialog_act"])
+                # with an empty sys_act (first turn) this is the default state
+                cur_state = dst.update(sys_act)["belief_state"]
+
+                user_goal.update_user_goal(sys_act, cur_state)
+                usr_act = util.parse_dialogue_act(
+                    data[dialog_id]["log"][turn_id]["dialog_act"])
+                input_feature, mask = self.feature_handler.get_feature(
+                    action_list, user_goal, cur_state, pre_state, sys_act)
+                label = self.feature_handler.generate_label(
+                    action_list, user_goal, cur_state, usr_act, mode='sys')
+                domain_label = self.feature_handler.domain_label(
+                    user_goal, usr_act)
+                pre_state = dst.update(usr_act)["belief_state"]
+                feature["id"].append(dialog_id)
+                feature["input"].append(input_feature)
+                feature["mask"].append(mask)
+                feature["label"].append(label)
+                feature["domain"].append(domain_label)
+
+        print("label distribution")
+        label_distribution = Counter()
+        for label in feature["label"]:
+            label_distribution += Counter(label)
+        print(label_distribution)
+        feature["input"] = torch.tensor(feature["input"], dtype=torch.float)
+        feature["label"] = torch.tensor(feature["label"], dtype=torch.long)
+        feature["mask"] = torch.tensor(feature["mask"], dtype=torch.bool)
+        feature["domain"] = torch.tensor(feature["domain"], dtype=torch.float)
+        for feat_type in ["input", "label", "mask", "domain"]:
+            print("{}: {}".format(feat_type, feature[feat_type].shape))
+        return feature
+
+    def _update_act_dict(self, act_dict, act):
+        for intent, domain, slot, value in act:
+            domain = domain.lower()
+            slot = slot.lower()
+            value = value.lower()
+            if domain == "general":
+                continue
+            elif domain == "booking":
+                domain = util.get_booking_domain(
+                    slot, value, self.feature_handler.all_values)
+            if slot in UsrDa2Goal[domain]:
+                slot = UsrDa2Goal[domain][slot]
+            elif slot in SysDa2Goal["booking"]:
+                slot = SysDa2Goal["booking"][slot]
+            else:
+                print(
+                    f"strange action in gathering all actions: {intent, domain, slot, value}")
+            if slot == "none":
+                continue
+            act_name = f"{domain}-{slot}"
+            if act_name not in act_dict:
+                act_dict[act_name] = []
+            if intent not in act_dict[act_name]:
+                act_dict[act_name].append(intent)
+
+
+class Feature:
+    def __init__(self, config, max_turn=1):
+        self.config = config
+        self.max_turn = max_turn
+        self.intents = ["inform", "request", "recommend", "select",
+                        "book", "nobook", "offerbook", "offerbooked", "nooffer"]
+        self.general_intent = ["reqmore", "bye", "thank", "welcome", "greet"]
+        self.default_values = ["none", "?", "dontcare"]
+        path = os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+        path = os.path.join(path, 'data/multiwoz/all_value.json')
+        self.all_values = json.load(open(path))
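+
+    # Per-slot input features are assembled by slot_feature below; each
+    # vector is a concatenation of (in order): a CLS/SEP one-hot, the
+    # value representation of the current state and of the user goal, the
+    # constraint/request flags, the fulfilled flag, optionally a conflict
+    # flag and domain/slot ids, the first-mention and just-mentioned
+    # flags, the system action representation, and the previous user
+    # action one-hot.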
+
+    def initFeatureHandeler(self, goal):
+        self.goal = goal
+        self.domain_list = goal.domains
+        usr = util.parse_user_goal(goal)
+        self.constrains = {}  # slot: fulfill
+        self.requirements = {}  # slot: fulfill
+        self.pre_usr = []
+        self.all_slot = None
+        self.user_feat_hist = []
+        for slot in usr:
+            if usr[slot] != "?":
+                self.constrains[slot] = NOT_MENTIONED
+
+    def get_feature(self, all_slot, user_goal, cur_state, pre_state=None, sys_action=None, usr_action=None):
+        """
+        given the current dialog information, return the input feature
+        user_goal: Goal()
+        cur_state: dict, {domain: {"semi": {slot: value}, "book": {slot: value, "booked": []}}} ("metadata" in the data set)
+        sys_action: [[intent, domain, slot, value]]
+        """
+        feature = []
+        usr = util.parse_user_goal(user_goal)
+
+        if sys_action:
+            self.update_constrain(sys_action)
+
+        cur = util.metadata2state(cur_state)
+        pre = {}
+        if pre_state is not None:
+            pre = util.metadata2state(pre_state)
+        if not self.pre_usr:
+            self.pre_usr = [0] * len(all_slot)
+
+        if usr_action:
+            self.get_user_action_feat(all_slot, user_goal, usr_action)
+
+        usr_act_feat = {}
+        for index, slot in zip(self.pre_usr, all_slot):
+            usr_act_feat[slot] = util.int2onehot(index, self.config["out_dim"])
+        for slot in all_slot:
+            feat = self.slot_feature(
+                slot, usr, cur, pre, sys_action, usr_act_feat)
+            feature.append(feat)
+        self.user_feat_hist.append(feature)
+
+        feature, mask = self.pad_feature()
+        return feature, mask
+
+    def slot_feature(self, slot, user_goal, current_state, previous_state, sys_action, usr_action):
+        feat = []
+        feat += self.special_token(slot)
+        feat += self.value_representation(
+            slot, current_state.get(slot, NOT_MENTIONED))
+        feat += self.value_representation(
+            slot, user_goal.get(slot, NOT_MENTIONED))
+        feat += self.is_constrain_request(slot, user_goal)
+        feat += self.is_fulfill(slot, user_goal)
+        if self.config.get("conflict", True):
+            feat += self.conflict_check(user_goal, current_state, slot)
+        if self.config.get("domain_feat", False):
+            if slot in ["CLS", "SEP"]:
+                feat += [0] * (self.goal.max_domain_len +
+                               self.goal.max_slot_len)
+            else:
+                domain_feat, slot_feat = self.goal.get_slot_id(slot)
+                feat += domain_feat + slot_feat
+        feat += self.first_mention_detection(
+            previous_state, current_state, slot)
+        feat += self.just_mention(slot, sys_action)
+        feat += self.action_representation(slot, sys_action)
+        # special tokens carry no user action; everything else gets the
+        # previous user action one-hot (to be replaced by a domain predictor)
+        if slot in ["CLS", "SEP"]:
+            feat += [0] * self.config["out_dim"]
+        else:
+            feat += usr_action[slot]
+        return feat
+
+    def pad_feature(self, max_memory=5):
+        feature = []
+        feat_dim = len(self.user_feat_hist[0][0])
+        num_feat = len(self.user_feat_hist)
+
+        # concatenate up to max_memory turns, newest first, each prefixed
+        # with a CLS (current turn) or SEP (earlier turn) token
+        for feat_index in range(num_feat - 1, max(num_feat - 1 - max_memory, -1), -1):
+            if feat_index == num_feat - 1:
+                special_token = self.slot_feature(
+                    "CLS", {}, {}, {}, [], [])
+            else:
+                special_token = self.slot_feature(
+                    "SEP", {}, {}, {}, [], [])
+            feature += [special_token]
+            feature += self.user_feat_hist[feat_index]
+
+        # zero padding; 65 is the maximum number of tokens per turn
+        true_len = len(feature)
+        max_len = max_memory * 65
+        if true_len < max_len:
+            feature += [[0] * feat_dim] * (max_len - true_len)
+        # mask only the padded positions, not the real tokens
+        mask = [False] * true_len + [True] * (max_len - true_len)
+
+        return feature, mask
+
+    def domain_feat(self, slot):
+        max_domain = 3
+        feat = [0] * max_domain
+        domain = slot.split('-')[0]
+        if domain in self.domain_list:
+            index = self.domain_list.index(domain)
+            if index < max_domain:
+                feat[index] = 1
+        return feat
+
+    def domain_label(self, user_goal, dialog_act):
+        # labels[0] means "no goal domain mentioned"; label i + 1 marks
+        # the i-th domain of the user goal
+        labels = [0] * self.config["out_dim"]
+        goal_domains = user_goal.domains
+        no_domain = True
+
+        for intent, domain, slot, value in dialog_act:
+            domain = domain.lower()
+            if domain in goal_domains:
+                index = goal_domains.index(domain)
+                labels[index + 1] = 1
+                no_domain = False
+        if no_domain:
+            labels[0] = 1
+        return labels
+
+    def generate_label(self, action_list, user_goal, cur_state, dialog_act, mode="usr"):
+        # label values: 0 none, 1 "?", 2 "dontcare", 3 from system,
+        # 4 from user goal, 5 changed value; -1 is ignored by the loss
+        labels = [-1] * self.config["num_token"]
+
+        usr = util.parse_user_goal(user_goal)
+        cur = util.metadata2state(cur_state)
+        for intent, domain, slot, value in dialog_act:
+            domain = domain.lower()
+            value = value.lower()
+            slot = slot.lower()
+            name = util.act2slot(intent, domain, slot, value, self.all_values)
+
+            if name not in action_list:
+                # print(f"Not handle name {name} in getting label")
+                continue
+            name_id = action_list.index(name)
+            if value == "?":
+                labels[name_id] = 1
+            elif value == "dontcare":
+                labels[name_id] = 2
+            elif name in cur and value == cur[name]:
+                labels[name_id] = 3
+            elif name in usr and value == usr[name]:
+                labels[name_id] = 4
+            elif (name in cur or name in usr) and value not in [cur.get(name), usr.get(name)]:
+                labels[name_id] = 5
+
+        for name in action_list:
+            domain = name.split('-')[0]
+            name_id = action_list.index(name)
+            if labels[name_id] < 0 and domain in self.domain_list:
+                labels[name_id] = 0
+
+        self.pre_usr = labels
+
+        return labels
+
+    def get_user_action_feat(self, action_list, user_goal, usr_act):
+        usr_label = self.generate_label(
+            action_list, user_goal, {}, usr_act, mode="usr")
+        self.pre_usr = usr_label
+
+    def special_token(self, slot):
+        special_token = ["CLS", "SEP"]
+        feat = [0]*len(special_token)
+        if slot in special_token:
+            feat[special_token.index(slot)] = 1
+        return feat
+
+    def is_constrain_request(self, feature_slot, user_goal):
+        # [is_constrain, is_request]
+        if feature_slot in ["CLS", "SEP"]:
+            return [0, 0]
+        value = user_goal.get(feature_slot, NOT_MENTIONED)
+        if value == "?":
+            return [0, 1]
+        elif value == NOT_MENTIONED:
+            return [0, 0]
+        else:
+            return [1, 0]
+
+    def is_fulfill(self, feature_slot, user_goal):
+        if feature_slot in ["CLS", "SEP"]:
+            return [0]
+
+        if feature_slot in user_goal and user_goal.get(feature_slot) == self.constrains.get(feature_slot):
+            return [1]
+        return [0]
+
+    def just_mention(self, feature_slot, sys_action):
+        """
+        the system action just mentioned this slot
+        """
+        if feature_slot in ["CLS", "SEP"]:
+            return [0]
+        if not sys_action:
+            return [0]
+        sys_action_slot = []
+        for intent, domain, slot, value in sys_action:
+            domain = domain.lower()
+            slot = slot.lower()
+            value = value.lower()
+            if domain == "booking":
+                domain = util.get_booking_domain(slot, value, self.all_values)
+            if domain:  # skip actions whose domain cannot be resolved
+                sys_action_slot.append(f"{domain}-{slot}")
+        if feature_slot in sys_action_slot:
+            return [1]
+        return [0]
+
+    def action_representation(self, feature_slot, action):
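+        """Encode how the system action relates to feature_slot.
+
+        Each intent in self.intents is encoded as a 3-dim slot
+        [value is "none", value is "?", other value or none-slot match],
+        and general-domain intents as a one-hot vector appended at the
+        end. CLS/SEP tokens and a missing action yield the all-zero
+        encoding.
+        """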
+ + gen_vec = [0] * len(self.general_intent) + # ["none", "?", other] + intent2act = {intent: [0] * 3 for intent in self.intents} + + if action is None or feature_slot in ["CLS", "SEP"]: + return self.concatenate_action_vector(intent2act, gen_vec) + for intent, domain, slot, value in action: + domain = domain.lower() + slot = slot.lower() + value = value.lower() + + # general + if domain == "general": + self.update_general_action(gen_vec, intent) + else: + if domain == "booking": + domain = util.get_booking_domain( + slot, value, self.all_values) + self.update_intent2act( + feature_slot, intent2act, + domain, intent, slot, value) + + return self.concatenate_action_vector(intent2act, gen_vec) + + def update_general_action(self, vec, intent): + if intent in self.general_intent: + vec[self.general_intent.index(intent)] = 1 + + def update_intent2act(self, feature_slot, intent2act, domain, intent, slot, value): + feature_domain, feature_slot = feature_slot.split('-') + intent = intent.lower() + slot = slot.lower() + value = value.lower() + if slot == "none" and feature_domain == domain: # None slot + intent2act[intent][2] = 1 + elif feature_domain == domain and slot == feature_slot: + if value == "none": + intent2act[intent][0] = 1 + elif value == "?": + intent2act[intent][1] = 1 + else: + intent2act[intent][2] = 1 + + def concatenate_action_vector(self, intent2act, general): + feat = [] + for intent in intent2act: + feat += intent2act[intent] + feat += general + return feat + + def value_representation(self, slot, value): + if slot in ["CLS", "SEP"]: + return [0, 0, 0, 0] + if value == NOT_MENTIONED: + return [1, 0, 0, 0] + else: + temp_vector = [0] * (len(self.default_values) + 1) + if value in self.default_values: + temp_vector[self.default_values.index(value)] = 1 + else: + temp_vector[-1] = 1 + + return temp_vector + + def conflict_check(self, user_goal, system_state, slot): + # conflict = [1] else [0] + if slot in ["CLS", "SEP"]: + return [0] + usr = user_goal.get(slot, NOT_MENTIONED) + sys = system_state.get(slot, NOT_MENTIONED) + if usr in [NOT_MENTIONED, "none", ""] and sys in [NOT_MENTIONED, "none", ""]: + return [0] + + if usr != sys or (usr == "?" 
and sys == "?"):
+            # different values, or both sides requesting the same slot
+            conflict = 1
+            return [conflict]
+        return [0]
+
+    def first_mention_detection(self, pre_state, cur_state, slot):
+        if slot in ["CLS", "SEP"]:
+            return [0]
+
+        first_mention = [1]
+        not_first_mention = [0]
+        cur = cur_state.get(slot, NOT_MENTIONED)
+        if pre_state is None:
+            if cur not in [NOT_MENTIONED, "none"]:
+                return first_mention
+            else:
+                return not_first_mention
+
+        pre = pre_state.get(slot, NOT_MENTIONED)
+
+        if pre in [NOT_MENTIONED, "none"] and cur not in [NOT_MENTIONED, "none"]:
+            return first_mention
+
+        return not_first_mention  # already mentioned, or still not mentioned
+
+    def update_constrain(self, action):
+        """
+        update the constraint status by system actions
+        action = [[intent, domain, slot, value]]
+        """
+        for intent, domain, slot, value in action:
+            domain = domain.lower()
+            if domain in self.domain_list:
+                slot = SysDa2Goal[domain].get(slot, "none")
+                slot_name = f"{domain}-{slot}"
+            elif domain == "booking":
+                if slot.lower() == "ref":
+                    continue
+                slot = SysDa2Goal[domain].get(slot, "none")
+                domain = util.get_booking_domain(slot, value, self.all_values)
+                if not domain:
+                    continue  # work around
+                slot_name = f"{domain}-{slot}"
+
+            else:
+                continue
+            if value != "?":
+                self.constrains[slot_name] = value
+
+    def update_request(self):
+        # TODO
+        pass
+
+    @staticmethod
+    def concatenate_subvectors(vec_list):
+        vec = []
+        for sub_vec in vec_list:
+            vec += sub_vec
+        return vec
+
+
+if __name__ == "__main__":
+    # minimal smoke test; assumes the default experiment config
+    config = json.load(
+        open("convlab2/policy/tus/multiwoz/exp/default.json"))
+    data = TUSDataManager(config, data_dir="data", set_type="test")
+    print(len(data))
diff --git a/convlab2/policy/tus/util.py b/convlab2/policy/tus/util.py
new file mode 100644
index 0000000..22d60cc
--- /dev/null
+++ b/convlab2/policy/tus/util.py
@@ -0,0 +1,126 @@
+from convlab2.policy.tus.multiwoz.Da2Goal import SysDa2Goal, UsrDa2Goal
+
+NOT_MENTIONED = "not mentioned"
+
+
+def int2onehot(index, output_dim=6):
+    one_hot = [0] * output_dim
+    if index >= 0:
+        one_hot[index] = 1
+    return one_hot
+
+
+def parse_user_goal(user_goal):
+    """flatten the user goal structure"""
+    goal = user_goal.domain_goals
+    user_goal = {}
+    for domain in goal:
+        if domain not in UsrDa2Goal:
+            continue
+        for slot_type in goal[domain]:
+            if slot_type in ["fail_info", "fail_book", "booked"]:
+                continue  # TODO [fail_info] fix in the future
+            if slot_type in ["info", "book", "reqt"]:
+                for slot in goal[domain][slot_type]:
+                    slot_name = f"{domain}-{slot.lower()}"
+                    user_goal[slot_name] = goal[domain][slot_type][slot]
+
+    return user_goal
+
+
+def parse_dialogue_act(dialogue_act):
+    """transfer a dialogue act from dict to list"""
+    value_dict = {"do nt care": "dontcare"}
+    actions = []
+    for act in dialogue_act:
+        domain, intent = act.split('-')
+        for slot, value in dialogue_act[act]:
+            if value in value_dict:
+                value = value_dict[value]
+            actions.append([intent, domain, slot, value])
+
+    return actions
+
+
+def metadata2state(metadata):
+    """
+    parse metadata in the data set or dst
+    """
+    slot_value = {}
+
+    for domain in metadata:
+        for slot in metadata[domain]["semi"]:
+            slot_name = f"{domain.lower()}-{slot.lower()}"
+            value = metadata[domain]["semi"][slot]
+            if not value or value == NOT_MENTIONED:
+                value = "none"
+            slot_value[slot_name] = value
+
+        for slot in metadata[domain]["book"]:
+            if slot == "booked":
+                continue
+            slot_name = f"{domain.lower()}-{slot.lower()}"
+            value = metadata[domain]["book"][slot]
+            slot_value[slot_name] = value
+
+    return slot_value
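+
+
+# Example: for MultiWOZ-style metadata
+#     {"hotel": {"semi": {"area": "north", "stars": ""},
+#                "book": {"stay": "2", "booked": []}}}
+# metadata2state returns
+#     {"hotel-area": "north", "hotel-stars": "none", "hotel-stay": "2"}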
"ref" + """ + found = "" + if not slot: + return found + slot = slot.lower() + value = value.lower() + for domain in all_values["all_value"]: + if slot in all_values["all_value"][domain] and value in all_values["all_value"][domain][slot]: + found = domain + return found + + +def act2slot(intent, domain, slot, value, all_values): + + if domain not in UsrDa2Goal: + # print(f"Not handle domain {domain}") + return "" + + if domain == "booking": + slot = SysDa2Goal[domain][slot] + domain = get_booking_domain(slot, value, all_values) + return f"{domain}-{slot}" + + elif domain in UsrDa2Goal: + if slot in SysDa2Goal[domain]: + slot = SysDa2Goal[domain][slot] + elif slot in UsrDa2Goal[domain]: + slot = UsrDa2Goal[domain][slot] + elif slot in SysDa2Goal["booking"]: + slot = SysDa2Goal["booking"][slot] + # else: + # print( + # f"UNSEEN ACTION IN GENERATE LABEL {intent, domain, slot, value}") + + return f"{domain}-{slot}" + + print("strange!!!") + print(intent, domain, slot, value) + + return "" + + +def get_user_history(dialog, all_values): + turn_num = len(dialog) + mentioned_slot = [] + for turn_id in range(0, turn_num, 2): + usr_act = parse_dialogue_act( + dialog[turn_id]["dialog_act"]) + for intent, domain, slot, value in usr_act: + slot_name = act2slot( + intent, domain.lower(), slot.lower(), value.lower(), all_values) + if slot_name not in mentioned_slot: + mentioned_slot.append(slot_name) + return mentioned_slot -- GitLab