diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py index 2c7712bc9d7df74d9ae35a36bd8bb9edd4886c60..a6da8381b715c826d61600aa22defd9721c891b0 100644 --- a/convlab/policy/vector/vector_nodes.py +++ b/convlab/policy/vector/vector_nodes.py @@ -70,6 +70,10 @@ class VectorNodes(VectorBase): if self.filter_state: self.kg_info = self.filter_inactive_domains(domain_active_dict) + # make sure kg is not empty + if len(self.kg_info) == 0: + self.add_user_greet() + if self.use_mask: mask = self.get_mask(domain_active_dict, number_entities_dict) for i in range(self.da_dim): @@ -170,3 +174,14 @@ class VectorNodes(VectorBase): return kg_filtered + def add_user_greet(self): + + feature_type = 'user act' + da = ("general", "greet", "none", "none") + if da in self.opp2vec: + domain = da[0] + description = "user-" + "_".join(da) + value = 1.0 + self.add_graph_node(domain, feature_type, description.lower(), value) + +