diff --git a/convlab/dialog_agent/agent.py b/convlab/dialog_agent/agent.py index 9a9dbb31e0599ca7e73b384bb6028d479b245c25..025b41e88cef8eec0e0ca556fbb807dee84ef72d 100755 --- a/convlab/dialog_agent/agent.py +++ b/convlab/dialog_agent/agent.py @@ -358,9 +358,8 @@ class DialogueAgent(Agent): else: state = self.input_action - fundamental_info['state'] = state - state = deepcopy(state) # get rid of reference problem + fundamental_info['state'] = state self.sys_state_history.append(state) # get action diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index f300914eb9382de570fbc08fae55e048f0f07ffa..62770bb44b83155917420dddbf7dc53a9221d1b0 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -172,13 +172,13 @@ class MultiWozEvaluator(Evaluator): da_turn = self._convert_action(da_turn) for intent, domain, slot, value in da_turn: - dom_int = '-'.join([domain, intent]) - domain = dom_int.split('-')[0].lower() + dom_int = '_'.join([domain, intent]) + domain = dom_int.split('_')[0].lower() if domain in belief_domains and domain != self.cur_domain: self.cur_domain = domain - da = (dom_int + '-' + slot).lower() + da = (dom_int + '_' + slot).lower() value = str(value) - self.sys_da_array.append(da + '-' + value) + self.sys_da_array.append(da + '_' + value) # new booking actions make life easier if intent.lower() == "book": @@ -208,13 +208,13 @@ class MultiWozEvaluator(Evaluator): """ da_turn = self._convert_action(da_turn) for intent, domain, slot, value in da_turn: - dom_int = '-'.join([domain, intent]) - domain = dom_int.split('-')[0].lower() + dom_int = '_'.join([domain, intent]) + domain = dom_int.split('_')[0].lower() if domain in belief_domains and domain != self.cur_domain: self.cur_domain = domain - da = (dom_int + '-' + slot).lower() + da = (dom_int + '_' + slot).lower() value = str(value) - self.usr_da_array.append(da + '-' + value) + self.usr_da_array.append(da + '_' + value) def _book_rate_goal(self, goal, booked_entity, domains=None): """ @@ -326,7 +326,7 @@ class MultiWozEvaluator(Evaluator): reqt_not_inform = set() bad_inform = set() for da in sys_history: - domain, intent, slot, value = da.split('-', 3) + domain, intent, slot, value = da.split('_', 3) if intent in ['inform', 'recommend', 'offerbook', 'offerbooked'] and \ domain in domains and slot in mapping[domain] and value.strip() not in NUL_VALUE: key = mapping[domain][slot] @@ -391,7 +391,7 @@ class MultiWozEvaluator(Evaluator): if domain in self.goal and 'book' in self.goal[domain]: goal[domain]['book'] = self.goal[domain]['book'] for da in self.usr_da_array: - d, i, s, v = da.split('-', 3) + d, i, s, v = da.split('_', 3) if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: goal[d]['info'][mapping[d][s]] = v score = self._book_rate_goal(goal, self.booked) @@ -409,7 +409,7 @@ class MultiWozEvaluator(Evaluator): if domain in self.goal and 'book' in self.goal[domain]: goal[domain]['book'] = self.goal[domain]['book'] for da in self.usr_da_array: - d, i, s, v = da.split('-', 3) + d, i, s, v = da.split('_', 3) if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: goal[d]['info'][mapping[d][s]] = v score = self._book_goal_constraints(goal, self.booked_states) @@ -441,7 +441,7 @@ class MultiWozEvaluator(Evaluator): else: goal = self._init_dict() for da in self.usr_da_array: - d, i, s, v = da.split('-', 3) + d, i, s, v = da.split('_', 3) if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: goal[d]['info'][mapping[d][s]] = v elif i == 'request': @@ -502,7 +502,7 @@ class MultiWozEvaluator(Evaluator): if 'book' in self.goal[domain]: goal[domain]['book'] = self.goal[domain]['book'] for da in self.usr_da_array: - d, i, s, v = da.split('-', 3) + d, i, s, v = da.split('_', 3) if d != domain: continue if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: @@ -529,7 +529,7 @@ class MultiWozEvaluator(Evaluator): if 'book' in self.goal[domain]: goal[domain]['book'] = self.goal[domain]['book'] for da in self.usr_da_array: - d, i, s, v = da.split('-', 3) + d, i, s, v = da.split('_', 3) if d != domain: continue if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]: @@ -664,7 +664,11 @@ class MultiWozEvaluator(Evaluator): def evaluate_dialog(self, goal, user_acts, system_acts, system_states): - self.add_goal(goal.domain_goals) + if isinstance(goal, dict): + self.add_goal(goal) + else: + self.add_goal(goal.domain_goals) + for sys_act, sys_state, user_act in zip(system_acts, system_states, user_acts): self.add_sys_da(sys_act, sys_state) self.add_usr_da(user_act)