Skip to content
Snippets Groups Projects
Commit a6dea527 authored by Christian's avatar Christian
Browse files

consitency with _ instead of - in evaluator

parent 681e7b80
No related branches found
No related tags found
No related merge requests found
...@@ -358,9 +358,8 @@ class DialogueAgent(Agent): ...@@ -358,9 +358,8 @@ class DialogueAgent(Agent):
else: else:
state = self.input_action state = self.input_action
fundamental_info['state'] = state
state = deepcopy(state) # get rid of reference problem state = deepcopy(state) # get rid of reference problem
fundamental_info['state'] = state
self.sys_state_history.append(state) self.sys_state_history.append(state)
# get action # get action
......
...@@ -172,13 +172,13 @@ class MultiWozEvaluator(Evaluator): ...@@ -172,13 +172,13 @@ class MultiWozEvaluator(Evaluator):
da_turn = self._convert_action(da_turn) da_turn = self._convert_action(da_turn)
for intent, domain, slot, value in da_turn: for intent, domain, slot, value in da_turn:
dom_int = '-'.join([domain, intent]) dom_int = '_'.join([domain, intent])
domain = dom_int.split('-')[0].lower() domain = dom_int.split('_')[0].lower()
if domain in belief_domains and domain != self.cur_domain: if domain in belief_domains and domain != self.cur_domain:
self.cur_domain = domain self.cur_domain = domain
da = (dom_int + '-' + slot).lower() da = (dom_int + '_' + slot).lower()
value = str(value) value = str(value)
self.sys_da_array.append(da + '-' + value) self.sys_da_array.append(da + '_' + value)
# new booking actions make life easier # new booking actions make life easier
if intent.lower() == "book": if intent.lower() == "book":
...@@ -208,13 +208,13 @@ class MultiWozEvaluator(Evaluator): ...@@ -208,13 +208,13 @@ class MultiWozEvaluator(Evaluator):
""" """
da_turn = self._convert_action(da_turn) da_turn = self._convert_action(da_turn)
for intent, domain, slot, value in da_turn: for intent, domain, slot, value in da_turn:
dom_int = '-'.join([domain, intent]) dom_int = '_'.join([domain, intent])
domain = dom_int.split('-')[0].lower() domain = dom_int.split('_')[0].lower()
if domain in belief_domains and domain != self.cur_domain: if domain in belief_domains and domain != self.cur_domain:
self.cur_domain = domain self.cur_domain = domain
da = (dom_int + '-' + slot).lower() da = (dom_int + '_' + slot).lower()
value = str(value) value = str(value)
self.usr_da_array.append(da + '-' + value) self.usr_da_array.append(da + '_' + value)
def _book_rate_goal(self, goal, booked_entity, domains=None): def _book_rate_goal(self, goal, booked_entity, domains=None):
""" """
...@@ -326,7 +326,7 @@ class MultiWozEvaluator(Evaluator): ...@@ -326,7 +326,7 @@ class MultiWozEvaluator(Evaluator):
reqt_not_inform = set() reqt_not_inform = set()
bad_inform = set() bad_inform = set()
for da in sys_history: for da in sys_history:
domain, intent, slot, value = da.split('-', 3) domain, intent, slot, value = da.split('_', 3)
if intent in ['inform', 'recommend', 'offerbook', 'offerbooked'] and \ if intent in ['inform', 'recommend', 'offerbook', 'offerbooked'] and \
domain in domains and slot in mapping[domain] and value.strip() not in NUL_VALUE: domain in domains and slot in mapping[domain] and value.strip() not in NUL_VALUE:
key = mapping[domain][slot] key = mapping[domain][slot]
...@@ -391,7 +391,7 @@ class MultiWozEvaluator(Evaluator): ...@@ -391,7 +391,7 @@ class MultiWozEvaluator(Evaluator):
if domain in self.goal and 'book' in self.goal[domain]: if domain in self.goal and 'book' in self.goal[domain]:
goal[domain]['book'] = self.goal[domain]['book'] goal[domain]['book'] = self.goal[domain]['book']
for da in self.usr_da_array: for da in self.usr_da_array:
d, i, s, v = da.split('-', 3) d, i, s, v = da.split('_', 3)
if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]:
goal[d]['info'][mapping[d][s]] = v goal[d]['info'][mapping[d][s]] = v
score = self._book_rate_goal(goal, self.booked) score = self._book_rate_goal(goal, self.booked)
...@@ -409,7 +409,7 @@ class MultiWozEvaluator(Evaluator): ...@@ -409,7 +409,7 @@ class MultiWozEvaluator(Evaluator):
if domain in self.goal and 'book' in self.goal[domain]: if domain in self.goal and 'book' in self.goal[domain]:
goal[domain]['book'] = self.goal[domain]['book'] goal[domain]['book'] = self.goal[domain]['book']
for da in self.usr_da_array: for da in self.usr_da_array:
d, i, s, v = da.split('-', 3) d, i, s, v = da.split('_', 3)
if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]:
goal[d]['info'][mapping[d][s]] = v goal[d]['info'][mapping[d][s]] = v
score = self._book_goal_constraints(goal, self.booked_states) score = self._book_goal_constraints(goal, self.booked_states)
...@@ -441,7 +441,7 @@ class MultiWozEvaluator(Evaluator): ...@@ -441,7 +441,7 @@ class MultiWozEvaluator(Evaluator):
else: else:
goal = self._init_dict() goal = self._init_dict()
for da in self.usr_da_array: for da in self.usr_da_array:
d, i, s, v = da.split('-', 3) d, i, s, v = da.split('_', 3)
if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]:
goal[d]['info'][mapping[d][s]] = v goal[d]['info'][mapping[d][s]] = v
elif i == 'request': elif i == 'request':
...@@ -502,7 +502,7 @@ class MultiWozEvaluator(Evaluator): ...@@ -502,7 +502,7 @@ class MultiWozEvaluator(Evaluator):
if 'book' in self.goal[domain]: if 'book' in self.goal[domain]:
goal[domain]['book'] = self.goal[domain]['book'] goal[domain]['book'] = self.goal[domain]['book']
for da in self.usr_da_array: for da in self.usr_da_array:
d, i, s, v = da.split('-', 3) d, i, s, v = da.split('_', 3)
if d != domain: if d != domain:
continue continue
if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]:
...@@ -529,7 +529,7 @@ class MultiWozEvaluator(Evaluator): ...@@ -529,7 +529,7 @@ class MultiWozEvaluator(Evaluator):
if 'book' in self.goal[domain]: if 'book' in self.goal[domain]:
goal[domain]['book'] = self.goal[domain]['book'] goal[domain]['book'] = self.goal[domain]['book']
for da in self.usr_da_array: for da in self.usr_da_array:
d, i, s, v = da.split('-', 3) d, i, s, v = da.split('_', 3)
if d != domain: if d != domain:
continue continue
if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]:
...@@ -664,7 +664,11 @@ class MultiWozEvaluator(Evaluator): ...@@ -664,7 +664,11 @@ class MultiWozEvaluator(Evaluator):
def evaluate_dialog(self, goal, user_acts, system_acts, system_states): def evaluate_dialog(self, goal, user_acts, system_acts, system_states):
if isinstance(goal, dict):
self.add_goal(goal)
else:
self.add_goal(goal.domain_goals) self.add_goal(goal.domain_goals)
for sys_act, sys_state, user_act in zip(system_acts, system_states, user_acts): for sys_act, sys_state, user_act in zip(system_acts, system_states, user_acts):
self.add_sys_da(sys_act, sys_state) self.add_sys_da(sys_act, sys_state)
self.add_usr_da(user_act) self.add_usr_da(user_act)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment