From 2422980acd1abee6b4ab2d146718c6d869ea770f Mon Sep 17 00:00:00 2001 From: liangrz <liangrz15@mails.tsinghua.edu.cn> Date: Wed, 10 Jun 2020 19:13:06 +0800 Subject: [PATCH] fix the bug of nan gradient --- convlab2/policy/gdpl/gdpl.py | 3 +++ convlab2/policy/mle/loader.py | 2 +- convlab2/policy/pg/pg.py | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/convlab2/policy/gdpl/gdpl.py b/convlab2/policy/gdpl/gdpl.py index c40be3b3..7d213be5 100755 --- a/convlab2/policy/gdpl/gdpl.py +++ b/convlab2/policy/gdpl/gdpl.py @@ -176,6 +176,9 @@ class GDPL(Policy): # backprop surrogate.backward() + + for p in self.policy.parameters(): + p.grad[p.grad != p.grad] = 0.0 # gradient clipping, for stability torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10) # self.lock.acquire() # retain lock to update weights diff --git a/convlab2/policy/mle/loader.py b/convlab2/policy/mle/loader.py index b9b62efd..02119d2a 100755 --- a/convlab2/policy/mle/loader.py +++ b/convlab2/policy/mle/loader.py @@ -17,7 +17,7 @@ class ActMLEPolicyDataLoader(): data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader()) for part in ['train', 'val', 'test']: self.data[part] = [] - raw_data = data_loader.load_data(data_key=part, role='system')[part] + raw_data = data_loader.load_data(data_key=part, role='sys')[part] for belief_state, context_dialog_act, terminated, dialog_act in \ zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], raw_data['dialog_act']): diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py index b7c6efc1..82df69a8 100755 --- a/convlab2/policy/pg/pg.py +++ b/convlab2/policy/pg/pg.py @@ -126,6 +126,9 @@ class PG(Policy): # backprop surrogate.backward() + + for p in self.policy.parameters(): + p.grad[p.grad != p.grad] = 0.0 # gradient clipping, for stability torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10) # self.lock.acquire() # retain lock to update weights -- GitLab