Skip to content
Snippets Groups Projects
Unverified Commit acfef62a authored by zhuqi's avatar zhuqi Committed by GitHub
Browse files

Improve agenda policy (#62)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba7 (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* fix self.cur_domain=None when system offer book

* fix agenda for 0 choice
parent 89dc730c
No related branches found
No related tags found
No related merge requests found
...@@ -34,7 +34,7 @@ for dom, ref_slots in REF_SYS_DA.items(): ...@@ -34,7 +34,7 @@ for dom, ref_slots in REF_SYS_DA.items():
if slot_a == 'Ref': if slot_a == 'Ref':
slot_b = 'ref' slot_b = 'ref'
REF_SYS_DA_M[dom][slot_a.lower()] = slot_b REF_SYS_DA_M[dom][slot_a.lower()] = slot_b
REF_SYS_DA_M[dom]['none'] = None REF_SYS_DA_M[dom]['none'] = 'none'
REF_SYS_DA_M['taxi']['phone'] = 'phone' REF_SYS_DA_M['taxi']['phone'] = 'phone'
REF_SYS_DA_M['taxi']['car'] = 'car type' REF_SYS_DA_M['taxi']['car'] = 'car type'
...@@ -90,6 +90,11 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -90,6 +90,11 @@ class UserPolicyAgendaMultiWoz(Policy):
sys_action = {} sys_action = {}
for intent, domain, slot, value in sys_dialog_act: for intent, domain, slot, value in sys_dialog_act:
if slot == 'Choice' and value.strip().lower() in ['0', 'zero']:
nooffer_key = '-'.join([domain, 'NoOffer'])
sys_action.setdefault(nooffer_key, [])
sys_action[nooffer_key].append(['none', 'none'])
else:
k = '-'.join([domain, intent]) k = '-'.join([domain, intent])
sys_action.setdefault(k, []) sys_action.setdefault(k, [])
sys_action[k].append([slot, value]) sys_action[k].append([slot, value])
...@@ -591,7 +596,7 @@ class Agenda(object): ...@@ -591,7 +596,7 @@ class Agenda(object):
continue continue
# For multiple choices, add new intent to select one: # For multiple choices, add new intent to select one:
if slot == 'choice': if slot == 'choice' and value.strip().lower() not in ['0', 'zero']:
self._push_item(domain + '-inform', "choice", "any") self._push_item(domain + '-inform', "choice", "any")
if slot in g_reqt: if slot in g_reqt:
...@@ -882,6 +887,7 @@ if __name__ == '__main__': ...@@ -882,6 +887,7 @@ if __name__ == '__main__':
state['user_action'] = user_act state['user_action'] = user_act
sys_act = sys_policy.predict(state) sys_act = sys_policy.predict(state)
sys_act.append(["Request", "Restaurant", "Price", "?"]) sys_act.append(["Request", "Restaurant", "Price", "?"])
sys_act = [['Inform', 'Hotel', 'Choice', '0']]
print(sys_act) print(sys_act)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment