Skip to content
Snippets Groups Projects
Commit cf362bc0 authored by Christian's avatar Christian
Browse files

ddpt uses argmax during evaluation and removed domain mask

parent e8b68c20
No related branches found
No related tags found
No related merge requests found
......@@ -83,7 +83,7 @@ convlab2/dst/trade/multiwoz_config/
deploy/bert_multiwoz_all.zip
deploy/templates/dialog_eg.html
test.py
*convlab2/policy/vector/action_dicts
*.egg-info
pre-trained-models/
venv
\ No newline at end of file
......@@ -65,10 +65,10 @@ class VectorBinary(VectorBase):
return state_vec, mask
def get_mask(self, domain_active_dict, number_entities_dict):
domain_mask = self.compute_domain_mask(domain_active_dict)
#domain_mask = self.compute_domain_mask(domain_active_dict)
entity_mask = self.compute_entity_mask(number_entities_dict)
general_mask = self.compute_general_mask()
mask = domain_mask + entity_mask + general_mask
mask = entity_mask + general_mask
return mask
def vectorize_booked(self, state):
......
......@@ -69,10 +69,10 @@ class VectorNodes(VectorBase):
return np.zeros(1), mask
def get_mask(self, domain_active_dict, number_entities_dict):
domain_mask = self.compute_domain_mask(domain_active_dict)
#domain_mask = self.compute_domain_mask(domain_active_dict)
entity_mask = self.compute_entity_mask(number_entities_dict)
general_mask = self.compute_general_mask()
mask = domain_mask + entity_mask + general_mask
mask = entity_mask + general_mask
return mask
def get_db_features(self):
......
......@@ -17,7 +17,7 @@
"uncertainty_vector_mul": {
"class_path": "convlab2.policy.vector.vector_nodes.VectorNodes",
"ini_params": {
"use_masking": false,
"use_masking": true,
"manually_add_entity_names": false,
"seed": 0
}
......
......@@ -80,7 +80,7 @@ class EncoderDecoder(nn.Module):
value_list = torch.Tensor([node['value'] for node in kg_list[0]]).unsqueeze(1).to(DEVICE)
return description_idx_list, value_list
def select_action(self, kg_list, mask=None):
def select_action(self, kg_list, mask=None, eval=False):
'''
:param kg_list: A single knowledge graph consisting of a list of nodes
:return: multi-action
......@@ -159,6 +159,7 @@ class EncoderDecoder(nn.Module):
action_logits = action_logits - action_mask * sys.maxsize
action_distribution = self.softmax(action_logits).squeeze(-1)
if not eval or t % 3 != 0:
dist = Categorical(action_distribution)
rand_state = torch.random.get_rng_state()
action = dist.sample().tolist()[-1]
......@@ -166,6 +167,12 @@ class EncoderDecoder(nn.Module):
semantic_action = self.action_embedder.small_action_dict_reversed[action[-1]]
action_list.append(semantic_action)
action_list_num.append(action[-1])
else:
action = action_distribution[-1, -1, :]
action = torch.argmax(action).item()
semantic_action = self.action_embedder.small_action_dict_reversed[action]
action_list.append(semantic_action)
action_list_num.append(action)
#prepare for next step
next_input = self.action_embedder.action_projector(self.action_embedder.action_embeddings[action]).view(1, 1, -1) + \
......
......@@ -107,7 +107,7 @@ class VTRACE(nn.Module, Policy):
s, action_mask = self.vector.state_vectorize(state)
kg_states = [self.vector.kg_info]
a = self.policy.select_action(kg_states, mask=action_mask).detach().cpu()
a = self.policy.select_action(kg_states, mask=action_mask, eval=not self.is_train).detach().cpu()
self.info_dict = self.policy.info_dict
descr_list = self.info_dict["description_idx_list"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment