From 2422980acd1abee6b4ab2d146718c6d869ea770f Mon Sep 17 00:00:00 2001
From: liangrz <liangrz15@mails.tsinghua.edu.cn>
Date: Wed, 10 Jun 2020 19:13:06 +0800
Subject: [PATCH] Fix NaN gradients after backprop in GDPL/PG policies; correct 'sys' role key in MLE loader

---
 convlab2/policy/gdpl/gdpl.py  | 3 +++
 convlab2/policy/mle/loader.py | 2 +-
 convlab2/policy/pg/pg.py      | 3 +++
 3 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/convlab2/policy/gdpl/gdpl.py b/convlab2/policy/gdpl/gdpl.py
index c40be3b3..7d213be5 100755
--- a/convlab2/policy/gdpl/gdpl.py
+++ b/convlab2/policy/gdpl/gdpl.py
@@ -176,6 +176,9 @@ class GDPL(Policy):
 
                 # backprop
                 surrogate.backward()
+
+                for p in self.policy.parameters():
+                    p.grad[p.grad != p.grad] = 0.0
                 # gradient clipping, for stability
                 torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
                 # self.lock.acquire() # retain lock to update weights
diff --git a/convlab2/policy/mle/loader.py b/convlab2/policy/mle/loader.py
index b9b62efd..02119d2a 100755
--- a/convlab2/policy/mle/loader.py
+++ b/convlab2/policy/mle/loader.py
@@ -17,7 +17,7 @@ class ActMLEPolicyDataLoader():
         data_loader = ActPolicyDataloader(dataset_dataloader=MultiWOZDataloader())
         for part in ['train', 'val', 'test']:
             self.data[part] = []
-            raw_data = data_loader.load_data(data_key=part, role='system')[part]
+            raw_data = data_loader.load_data(data_key=part, role='sys')[part]
             
             for belief_state, context_dialog_act, terminated, dialog_act in \
                 zip(raw_data['belief_state'], raw_data['context_dialog_act'], raw_data['terminated'], raw_data['dialog_act']):
diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py
index b7c6efc1..82df69a8 100755
--- a/convlab2/policy/pg/pg.py
+++ b/convlab2/policy/pg/pg.py
@@ -126,6 +126,9 @@ class PG(Policy):
 
                 # backprop
                 surrogate.backward()
+                
+                for p in self.policy.parameters():
+                    p.grad[p.grad != p.grad] = 0.0
                 # gradient clipping, for stability
                 torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
                 # self.lock.acquire() # retain lock to update weights
-- 
GitLab