diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 0bc3351aae8bcb769983b58a5e6601607e3f5f98..0b1150fbfa6b865c819c83d721f4d8eea2ebdaf3 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -111,14 +111,14 @@ class VectorBase(Vector): if turn['speaker'] == 'system': for act in delex_acts: - act = tuple(act) + act = tuple([a.lower() for a in act]) if act not in system_dict: system_dict[act] = 1 else: system_dict[act] += 1 else: for act in delex_acts: - act = tuple(act) + act = tuple([a.lower() for a in act]) if act not in user_dict: user_dict[act] = 1 else: diff --git a/convlab/policy/vtrace_DPT/supervised/train_supervised.py b/convlab/policy/vtrace_DPT/supervised/train_supervised.py index 1807a671da7e2938173a18277cd21980ee577a11..ccc407086e3399656d6ec7840e5c920779e2d058 100644 --- a/convlab/policy/vtrace_DPT/supervised/train_supervised.py +++ b/convlab/policy/vtrace_DPT/supervised/train_supervised.py @@ -182,7 +182,7 @@ if __name__ == '__main__': args = arg_parser() root_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - with open(os.path.join(root_directory, 'config.json'), 'r') as f: + with open(os.path.join(root_directory, 'configs/multiwoz21_dpt.json'), 'r') as f: cfg = json.load(f) cfg['dataset_name'] = args.dataset_name diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt new file mode 100644 index 0000000000000000000000000000000000000000..2f7e5e9cdf949b082b9ae752c09d2eafeb322bfb Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/action_embeddings_multiwoz21.pt differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt new file mode 100644 index 0000000000000000000000000000000000000000..5c6193163b6a1fc788b959eff2d0153b3a4bf7cb Binary files /dev/null and b/convlab/policy/vtrace_DPT/transformer_model/embedded_descriptions_base_multiwoz21.pt differ diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.txt b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.txt new file mode 100644 index 0000000000000000000000000000000000000000..b368cd9fb7e978b36ed0fb261aafbe1baeed4ab4 --- /dev/null +++ b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.txt @@ -0,0 +1,92 @@ +"attraction" +"general" +"hospital" +"hotel" +"police" +"restaurant" +"taxi" +"train" +"eos" +"inform" +"nooffer" +"recommend" +"request" +"select" +"bye" +"greet" +"reqmore" +"welcome" +"book" +"offerbook" +"nobook" +["address", "1"] +["address", "2"] +["address", "3"] +["area", "1"] +["area", "2"] +["area", "3"] +["choice", "1"] +["choice", "2"] +["choice", "3"] +["entrance fee", "1"] +["entrance fee", "2"] +["name", "1"] +["name", "2"] +["name", "3"] +["name", "4"] +["phone", "1"] +["postcode", "1"] +["type", "1"] +["type", "2"] +["type", "3"] +["type", "4"] +["type", "5"] +["none", "none"] +["area", "?"] +["entrance fee", "?"] +["name", "?"] +["type", "?"] +["department", "1"] +["department", "?"] +["book day", "1"] +["book people", "1"] +["book stay", "1"] +["internet", "1"] +["parking", "1"] +["price range", "1"] +["price range", "2"] +["ref", "1"] +["stars", "1"] +["stars", "2"] +["book day", "?"] +["book people", "?"] +["book stay", "?"] +["internet", "?"] +["parking", "?"] +["price range", "?"] +["stars", "?"] +["book time", "1"] +["food", "1"] +["food", "2"] +["food", "3"] +["food", "4"] +["postcode", "2"] +["book time", "?"] +["food", "?"] +["arrive by", "1"] +["departure", "1"] +["destination", "1"] +["leave at", "1"] +["arrive by", "?"] +["departure", "?"] +["destination", "?"] +["leave at", "?"] +["arrive by", "2"] +["day", "1"] +["duration", "1"] +["leave at", "2"] +["leave at", "3"] +["price", "1"] +["train id", "1"] +["day", "?"] +"pad"