diff --git a/convlab2/policy/mle/crosswoz/loader.py b/convlab2/policy/mle/crosswoz/loader.py index d3f160cc458a5fa425fb94c65e8614c6bb54e524..c29c11cabdfc95de0ab220114e3d08165962d695 100755 --- a/convlab2/policy/mle/crosswoz/loader.py +++ b/convlab2/policy/mle/crosswoz/loader.py @@ -45,6 +45,7 @@ class PolicyDataLoaderCrossWoz(): dst.init_session() for i, turn in enumerate(sess): if turn['role'] == 'usr': + dst.state['user_action'] = turn['dialog_act'] dst.update(usr_da=turn['dialog_act']) if i + 2 == len(sess): dst.state['terminated'] = True diff --git a/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py b/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py index 768393658697fa9e239f81f49c17258f76a73c43..06e85eb2bf1e4415115442093d5f46ca9fe53d58 100755 --- a/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py +++ b/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py @@ -114,7 +114,7 @@ class RuleBasedMultiwozBot(Policy): # print("Sys action: ", DA) - if DA == {}: + if len([domain_intent for domain_intent, slots in DA.items() if slots or 'nooffer' in domain_intent.lower()]) == 0: DA = {'general-greet': [['none', 'none']]} tuples = [] for domain_intent, svs in DA.items(): @@ -201,6 +201,17 @@ class RuleBasedMultiwozBot(Policy): self.choice = "" elif self.recommend_flag == 1: self.recommend_flag == 0 + + if len(kb_result) == 0: + if (domain + "-NoOffer") not in DA: + DA[domain + "-NoOffer"] = [] + + for slot in state['belief_state'][domain.lower()]['semi']: + if state['belief_state'][domain.lower()]['semi'][slot] != "" and \ + state['belief_state'][domain.lower()]['semi'][slot] not in ["do nt care", "do n't care", + "dontcare"]: + slot_name = REF_USR_DA[domain].get(slot, slot) + DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]]) if (domain + "-Inform") not in DA: DA[domain + "-Inform"] = [] for slot in user_action[user_act]: