Commit 2422980a authored by liangrz, committed by zhuqi

fix the bug of nan gradient

parent c2c24b42
@@ -176,6 +176,9 @@ class GDPL(Policy):
         # backprop
         surrogate.backward()
+        for p in self.policy.parameters():
+            p.grad[p.grad != p.grad] = 0.0
         # gradient clipping, for stability
         torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
         # self.lock.acquire() # retain lock to update weights
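The added loop exploits the IEEE-754 rule that NaN is the only value unequal to itself, so p.grad != p.grad is a boolean mask selecting exactly the NaN entries, which are overwritten with 0.0 before gradient clipping. Below is a minimal, self-contained sketch of that step; the one-layer model and the artificially induced NaN loss are illustrative assumptions, not code from this repository.

import torch

# Illustrative stand-in for the policy network (not the repository's model).
policy = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

# 0/0 yields NaN, which backpropagates NaN into every parameter gradient.
loss = policy(x).sum() * (torch.tensor(0.0) / torch.tensor(0.0))
loss.backward()

for p in policy.parameters():
    # NaN != NaN is True, so this mask hits only the NaN entries
    # and replaces them with zero, leaving finite gradients untouched.
    p.grad[p.grad != p.grad] = 0.0

# Clipping (and the optimizer step) now operate on finite gradients only.
torch.nn.utils.clip_grad_norm_(policy.parameters(), 10)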
@@ -17,7 +17,7 @@ class ActMLEPolicyDataLoader():
         data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader())
         for part in ['train', 'val', 'test']:
             self.data[part] = []
-            raw_data = data_loader.load_data(data_key=part, role='system')[part]
+            raw_data = data_loader.load_data(data_key=part, role='sys')[part]
             for belief_state, context_dialog_act, terminated, dialog_act in \
                 zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], raw_data['dialog_act']):
@@ -126,6 +126,9 @@ class PG(Policy):
         # backprop
         surrogate.backward()
+        for p in self.policy.parameters():
+            p.grad[p.grad != p.grad] = 0.0
         # gradient clipping, for stability
         torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
         # self.lock.acquire() # retain lock to update weights
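PG receives the identical guard. A slightly more explicit equivalent would build the mask with torch.isnan; this variant is an assumption about an alternative style, not part of the commit.

import torch

policy = torch.nn.Linear(4, 2)  # hypothetical stand-in for self.policy

for p in policy.parameters():
    if p.grad is not None:                 # gradients are unset before any backward()
        p.grad[torch.isnan(p.grad)] = 0.0  # same effect as p.grad != p.grad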