Unverified commit e0245c1b authored by Carel van Niekerk, committed by GitHub

Fix TripPy multiprocessing pickling bug (#129)


* Separate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology extraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified changes :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Fix pickling error of TransformerForDST Trippy Class

* Setsumbt bug fixes

* Policy config refactor

* Policy config refactor

* Small bug fix in memory with new config path

* Setsumbt info dict

* Fix generate function for SCGPT

* SCGPT default device GPU

Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: Michael Heck <michael.heck@hhu.de>
Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de>
parent e7f924e9
@@ -62,7 +62,7 @@ def TransformerForDST(parent_name):
     class TransformerForDST(PARENT_CLASSES[parent_name]):
         def __init__(self, config):
             assert config.model_type in PARENT_CLASSES
-            assert self.__class__.__bases__[0] in MODEL_CLASSES
+            # assert self.__class__.__bases__[0] in MODEL_CLASSES
             super(TransformerForDST, self).__init__(config)
             self.model_type = config.model_type
             self.slot_list = config.dst_slot_list
@@ -82,7 +82,7 @@ def TransformerForDST(parent_name):
             self.refer_index = -1
             # Make sure this module has the same name as in the pretrained checkpoint you want to load!
-            self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config))
+            self.add_module(self.model_type, MODEL_CLASSES[PARENT_CLASSES[self.model_type]](config))
             if self.model_type == "electra":
                 self.pooler = ElectraPooler(config)
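Why both changes in this hunk are needed: the commit also introduces named module-level wrappers (BertForDST, RobertaForDST, ElectraForDST; see the next hunk). Once such a wrapper is instantiated, self.__class__.__bases__[0] is the dynamically created TransformerForDST class rather than the transformers parent class, so the old assert and the old MODEL_CLASSES[...] index no longer resolve; indexing through PARENT_CLASSES[self.model_type] is independent of the concrete subclass. A minimal sketch of the effect, using illustrative stand-in names (Parent, make_dst, and the BertForDST below are not the actual ConvLab classes):

# Stand-ins only: Parent plays the transformers parent class, make_dst the
# TransformerForDST(parent_name) factory that builds a class at call time.
class Parent:
    pass

def make_dst(parent):
    class DynamicDST(parent):
        pass
    return DynamicDST

# Named module-level wrapper, analogous to the ones added in this commit.
class BertForDST(make_dst(Parent)):
    pass

# The wrapper's first base is the dynamic class, not Parent, so a dict
# keyed on the parent class can no longer be indexed with __bases__[0].
print(BertForDST.__bases__[0] is Parent)   # False
print(Parent in BertForDST.__mro__)        # True: Parent is still an ancestor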
@@ -30,10 +30,15 @@ from convlab.dst.trippy.modeling_dst import (TransformerForDST)
 from convlab.dst.trippy.dataset_interfacer import (create_dataset_interfacer)
 from convlab.util import relative_import_module_from_unified_datasets
 
+
+class BertForDST(TransformerForDST('bert')): pass
+class RobertaForDST(TransformerForDST('roberta')): pass
+class ElectraForDST(TransformerForDST('electra')): pass
+
 MODEL_CLASSES = {
-    'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer),
-    'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer),
-    'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer),
+    'bert': (BertConfig, BertForDST, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer),
+    'electra': (ElectraConfig, ElectraForDST, ElectraTokenizer),
 }
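These named module-level classes are the heart of the multiprocessing fix: pickle, which multiprocessing uses to ship objects to worker processes, serializes a class by its importable module-qualified name. A class manufactured inside a function call carries '<locals>' in its qualified name and cannot be located again on unpickling, whereas BertForDST, RobertaForDST and ElectraForDST can. A minimal sketch of the difference, assuming only the standard library (the names below are illustrative, not ConvLab code):

import pickle

def make_class():
    # Built at call time; its qualified name contains '<locals>' and
    # cannot be re-imported by name, so pickle rejects its instances.
    class Local:
        pass
    return Local

class ModuleLevel(make_class()):
    # Defined at module scope, so pickle can locate it by name even though
    # its base class came from a factory call.
    pass

try:
    pickle.dumps(make_class()())                 # instance of the factory-local class
except (AttributeError, pickle.PicklingError) as err:
    print("factory-local class fails to pickle:", err)

data = pickle.dumps(ModuleLevel())               # the named subclass pickles fine
print(type(pickle.loads(data)).__name__)         # ModuleLevel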