From e0245c1bf4171d155de256d8efa8a9961e3ad2e8 Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com> Date: Thu, 26 Jan 2023 16:57:36 +0100 Subject: [PATCH] Fix TripPy multiprocessing pickling bug (#129) * Seperate test and train domains * Add progress bars in ontology embedder * Update custom_util.py * Fix custom_util things I broke * Github master * Save dialogue ids in prediction file * Fix bug in ontology enxtraction * Return dialogue ids in predictions file and fix bugs * Add setsumbt starting config loader * Add script to extract golden labels from dataset to match model predictions * Add more setsumbt configs * Add option to use local files only in transformers package * Update starting configurations for setsumbt * Github master * Update README.md * Update README.md * Update convlab/dialog_agent/agent.py * Revert custom_util.py * Update custom_util.py * Commit unverified chnages :(:(:(:( * Fix SetSUMBT bug resulting from new torch feature * Fix pickling error of TransformerForDST Trippy Class * Setsumbt bug fixes * Policy config refactor * Policy config refactor * small bug fix in memory with new config path * Setsumbt info dict * Fix generate function for SCGPT * SCGPT default device GPU Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de> Co-authored-by: Michael Heck <michael.heck@hhu.de> Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de> --- convlab/dst/trippy/modeling_dst.py | 4 ++-- convlab/dst/trippy/tracker.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/convlab/dst/trippy/modeling_dst.py b/convlab/dst/trippy/modeling_dst.py index 2828d17e..3bd875b6 100644 --- a/convlab/dst/trippy/modeling_dst.py +++ b/convlab/dst/trippy/modeling_dst.py @@ -62,7 +62,7 @@ def TransformerForDST(parent_name): class TransformerForDST(PARENT_CLASSES[parent_name]): def __init__(self, config): assert config.model_type in PARENT_CLASSES - assert self.__class__.__bases__[0] in MODEL_CLASSES + # assert self.__class__.__bases__[0] in MODEL_CLASSES super(TransformerForDST, self).__init__(config) self.model_type = config.model_type self.slot_list = config.dst_slot_list @@ -82,7 +82,7 @@ def TransformerForDST(parent_name): self.refer_index = -1 # Make sure this module has the same name as in the pretrained checkpoint you want to load! - self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config)) + self.add_module(self.model_type, MODEL_CLASSES[PARENT_CLASSES[self.model_type]](config)) if self.model_type == "electra": self.pooler = ElectraPooler(config) diff --git a/convlab/dst/trippy/tracker.py b/convlab/dst/trippy/tracker.py index b0470266..8ceaeddd 100644 --- a/convlab/dst/trippy/tracker.py +++ b/convlab/dst/trippy/tracker.py @@ -30,10 +30,15 @@ from convlab.dst.trippy.modeling_dst import (TransformerForDST) from convlab.dst.trippy.dataset_interfacer import (create_dataset_interfacer) from convlab.util import relative_import_module_from_unified_datasets + +class BertForDST(TransformerForDST('bert')): pass +class RobertaForDST(TransformerForDST('roberta')): pass +class ElectraForDST(TransformerForDST('electra')): pass + MODEL_CLASSES = { - 'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer), - 'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer), - 'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer), + 'bert': (BertConfig, BertForDST, BertTokenizer), + 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), + 'electra': (ElectraConfig, ElectraForDST, ElectraTokenizer), } -- GitLab