From e0245c1bf4171d155de256d8efa8a9961e3ad2e8 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com>
Date: Thu, 26 Jan 2023 16:57:36 +0100
Subject: [PATCH] Fix TripPy multiprocessing pickling bug (#129)

* Seperate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology enxtraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified chnages :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Fix pickling error of TransformerForDST Trippy Class

* Setsumbt bug fixes

* Policy config refactor

* Policy config refactor

* small bug fix in memory with new config path

* Setsumbt info dict

* Fix generate function for SCGPT

* SCGPT default device GPU

Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: Michael Heck <michael.heck@hhu.de>
Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de>
---
 convlab/dst/trippy/modeling_dst.py |  4 ++--
 convlab/dst/trippy/tracker.py      | 11 ++++++++---
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/convlab/dst/trippy/modeling_dst.py b/convlab/dst/trippy/modeling_dst.py
index 2828d17e..3bd875b6 100644
--- a/convlab/dst/trippy/modeling_dst.py
+++ b/convlab/dst/trippy/modeling_dst.py
@@ -62,7 +62,7 @@ def TransformerForDST(parent_name):
     class TransformerForDST(PARENT_CLASSES[parent_name]):
         def __init__(self, config):
             assert config.model_type in PARENT_CLASSES
-            assert self.__class__.__bases__[0] in MODEL_CLASSES
+            # assert self.__class__.__bases__[0] in MODEL_CLASSES
             super(TransformerForDST, self).__init__(config)
             self.model_type = config.model_type
             self.slot_list = config.dst_slot_list
@@ -82,7 +82,7 @@ def TransformerForDST(parent_name):
                 self.refer_index = -1
 
             # Make sure this module has the same name as in the pretrained checkpoint you want to load!
-            self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config))
+            self.add_module(self.model_type, MODEL_CLASSES[PARENT_CLASSES[self.model_type]](config))
             if self.model_type == "electra":
                 self.pooler = ElectraPooler(config)
             
diff --git a/convlab/dst/trippy/tracker.py b/convlab/dst/trippy/tracker.py
index b0470266..8ceaeddd 100644
--- a/convlab/dst/trippy/tracker.py
+++ b/convlab/dst/trippy/tracker.py
@@ -30,10 +30,15 @@ from convlab.dst.trippy.modeling_dst import (TransformerForDST)
 from convlab.dst.trippy.dataset_interfacer import (create_dataset_interfacer)
 from convlab.util import relative_import_module_from_unified_datasets
 
+
+class BertForDST(TransformerForDST('bert')): pass
+class RobertaForDST(TransformerForDST('roberta')): pass
+class ElectraForDST(TransformerForDST('electra')): pass
+
 MODEL_CLASSES = {
-    'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer),
-    'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer),
-    'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer),
+    'bert': (BertConfig, BertForDST, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
+    'electra': (ElectraConfig, ElectraForDST, ElectraTokenizer),
 }
 
 
-- 
GitLab