Skip to content
Snippets Groups Projects
Commit 2422980a authored by liangrz's avatar liangrz Committed by zhuqi
Browse files

fix the bug of nan gradient

parent c2c24b42
No related branches found
No related tags found
No related merge requests found
......@@ -176,6 +176,9 @@ class GDPL(Policy):
# backprop
surrogate.backward()
for p in self.policy.parameters():
p.grad[p.grad != p.grad] = 0.0
# gradient clipping, for stability
torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
# self.lock.acquire() # retain lock to update weights
......
......@@ -17,7 +17,7 @@ class ActMLEPolicyDataLoader():
data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader())
for part in ['train', 'val', 'test']:
self.data[part] = []
raw_data = data_loader.load_data(data_key=part, role='system')[part]
raw_data = data_loader.load_data(data_key=part, role='sys')[part]
for belief_state, context_dialog_act, terminated, dialog_act in \
zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], raw_data['dialog_act']):
......
......@@ -126,6 +126,9 @@ class PG(Policy):
# backprop
surrogate.backward()
for p in self.policy.parameters():
p.grad[p.grad != p.grad] = 0.0
# gradient clipping, for stability
torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
# self.lock.acquire() # retain lock to update weights
......
Loading...
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment