diff --git a/convlab/dialog_agent/agent.py b/convlab/dialog_agent/agent.py index 5ad0c296ce34e1b2a2048bcc67bddaab978e0d7e..79f61e2b18f04e702c690a33cdf313656b2615c6 100755 --- a/convlab/dialog_agent/agent.py +++ b/convlab/dialog_agent/agent.py @@ -64,12 +64,7 @@ class PipelineAgent(Agent): ===== ===== ====== === == === """ -<<<<<<< HEAD def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str): -======= - def __init__(self, nlu: NLU, dst: DST, policy: Policy, nlg: NLG, name: str, return_semantic_acts: bool = False, - word_level_policy_nlu: NLU = None): ->>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805 """The constructor of PipelineAgent class. Here are some special combination cases: @@ -186,27 +181,14 @@ class PipelineAgent(Agent): model_response = self.output_action # print(model_response) - if self.policy_nlu and type(self.output_action) != list: - self.semantic_output_action = self.policy_nlu.predict(self.output_action, - context=[x[1] for x in self.history]) - else: - self.semantic_output_action = None - if self.dst is not None: self.dst.state['history'].append([self.name, model_response]) if self.name == 'sys': - self.dst.state['system_action'] = self.semantic_output_action if self.semantic_output_action else self.output_action + self.dst.state['system_action'] = self.output_action if type(self.output_action) == list: for intent, domain, slot, value in self.output_action: if intent.lower() == "book": -<<<<<<< HEAD -======= - self.dst.state['booked'][domain] = [{slot: value}] - elif self.semantic_output_action: - for intent, domain, slot, value in self.semantic_output_action: - if intent.lower() == "book": ->>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805 self.dst.state['booked'][domain] = [{slot: value}] else: self.dst.state['user_action'] = self.output_action @@ -270,9 +252,6 @@ class PipelineAgent(Agent): def get_in_da_eval(self): return self.input_action_eval - def get_out_da_eval(self): - return self.semantic_output_action - def get_in_da(self): return self.input_action diff --git a/convlab/policy/lava/multiwoz/lava.py b/convlab/policy/lava/multiwoz/lava.py index 203c6c8237caaef7c2140c00a3d8e220c8823e0d..76ea072ddb87f64a0144de9ae421707ea057c7e4 100755 --- a/convlab/policy/lava/multiwoz/lava.py +++ b/convlab/policy/lava/multiwoz/lava.py @@ -237,7 +237,7 @@ def delexicaliseReferenceNumber(sent, state): 'train', 'taxi', 'hospital'] # , 'police'] if state['history'][-1][0]=="sys": - print(state["booked"]) + # print(state["booked"]) for domain in domains: if state['booked'][domain]: for slot in state['booked'][domain][0]: @@ -751,8 +751,8 @@ class LAVA(Policy): # mark_not_mentioned(prev_state) #active_domain = self.get_active_domain_convlab(self.prev_active_domain, prev_bstate, bstate) active_domain = self.get_active_domain_unified(self.prev_active_domain, self.prev_state, state) - print("---------") - print("active domain: ", active_domain) + # print("---------") + # print("active domain: ", active_domain) # if active_domain is not None: # print(f"DST on {active_domain}: {bstate[active_domain]}") @@ -761,7 +761,7 @@ class LAVA(Policy): top_results, num_results = None, None for t_id in range(len(context)): usr = context[t_id] - print(usr) + # print(usr) if t_id == 0: #system turns if usr == "null": @@ -1179,7 +1179,7 @@ class LAVA(Policy): return response def populate_template_unified(self, template, top_results, num_results, state, active_domain): - print("template:",template) + # print("template:",template) # print("top_results:",top_results) # active_domain = None if len( # top_results.keys()) == 0 else list(top_results.keys())[0] diff --git a/examples/agent_examples/test_LAVA.py b/examples/agent_examples/test_LAVA.py index 7cc49c61ba98043affed0b954e1c49af03d588c5..63bf075c5a25dcdfc8b89175d42dbf6272b23f0e 100755 --- a/examples/agent_examples/test_LAVA.py +++ b/examples/agent_examples/test_LAVA.py @@ -9,11 +9,7 @@ import random from pprint import pprint from argparse import ArgumentParser from convlab.nlu.jointBERT.unified_datasets import BERTNLU -<<<<<<< HEAD # from convlab.nlu.jointBERT.multiwoz import BERTNLU as BERTNLU_woz -======= -from convlab.nlu.jointBERT.multiwoz import BERTNLU as BERTNLU_woz ->>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805 # from convlab.nlu.milu.multiwoz import MILU # available DST models from convlab.dst.rule.multiwoz import RuleDST @@ -126,11 +122,7 @@ def test_end2end(args, model_dir): user_agent = PipelineAgent( user_nlu, user_dst, user_policy, user_nlg, name='user') sys_agent = PipelineAgent( -<<<<<<< HEAD sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') -======= - sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys', return_semantic_acts=True, word_level_policy_nlu=user_nlu) ->>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805 sys_agent.add_booking_info = False @@ -140,12 +132,8 @@ def test_end2end(args, model_dir): #seed = 2020 set_seed(args.seed) - model_name = '{}_{}_lava_{}'.format(args.US_type, args.dst_type, model_dir) -<<<<<<< HEAD - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=500) -======= - analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=10) ->>>>>>> 4b3d84a0a205661b8bbb445a48452c4013f43805 + model_name = '{}_{}_lava_{}_tmp'.format(args.US_type, args.dst_type, model_dir) + analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name=model_name, total_dialog=100) if __name__ == '__main__': parser = ArgumentParser()