Commit 5902089a authored by Carel van Niekerk

Refactoring

parent d8df0559
Showing 281 additions and 175 deletions
from convlab.dst.setsumbt.dataset.unified_format import *
\ No newline at end of file
@@ -13,32 +13,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convlab3 Unified Format Dialogue Datasets"""
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers.tokenization_utils import PreTrainedTokenizer
from copy import deepcopy

from convlab.dst.setsumbt.dataset.utils import (load_dataset, get_ontology_slots, ontology_add_values,
                                                get_values_from_data, ontology_add_requestable_slots,
                                                get_requestable_slots, load_dst_data, extract_dialogues,
                                                combine_value_sets)


# Convert dialogue examples to model input features and labels
def convert_examples_to_features(data: list,
                                 ontology: dict,
                                 tokenizer: PreTrainedTokenizer,
                                 max_turns: int = 12,
                                 max_seq_len: int = 64) -> dict:
    """
    Convert dialogue examples to model input features and labels

    Parameters:
        data (list): List of all extracted dialogues
        ontology (dict): Ontology dictionary containing slots, slot descriptions and
            possible value sets including requests
        tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used
        max_turns (int): Maximum number of turns in a dialogue
        max_seq_len (int): Maximum number of tokens in a dialogue turn

    Returns:
        features (dict): All inputs and labels required to train the model
    """
    features = dict()
    ontology = deepcopy(ontology)
@@ -58,20 +64,23 @@ def convert_examples_to_features(data: list, ontology: dict, tokenizer, max_turn
            dial_feats.append(tokenizer.encode_plus(usr, sys, add_special_tokens=True,
                                                    max_length=max_seq_len, padding='max_length',
                                                    truncation='longest_first'))
            # Truncate
            if len(dial_feats) >= max_turns:
                break
        input_feats.append(dial_feats)
    del dial_feats

    # Perform turn level padding
    input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
                 for dial in input_feats]
    if 'token_type_ids' in input_feats[0][0]:
        token_type_ids = [[turn['token_type_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
                          for dial in input_feats]
    else:
        token_type_ids = None
    if 'attention_mask' in input_feats[0][0]:
        attention_mask = [[turn['attention_mask'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
                          for dial in input_feats]
    else:
        attention_mask = None
    del input_feats
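A minimal sketch of the turn-level padding above, using toy dimensions and fabricated feature dicts rather than real tokenizer output: dialogues shorter than max_turns are extended with all-zero turns so every dialogue has the same shape.

max_turns, max_seq_len = 3, 4
input_feats = [[{'input_ids': [101, 7, 8, 102]}],                                    # 1-turn dialogue
               [{'input_ids': [101, 5, 6, 102]}, {'input_ids': [101, 9, 9, 102]}]]  # 2-turn dialogue
input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
             for dial in input_feats]
assert all(len(dial) == max_turns for dial in input_ids)  # every dialogue padded to 3 turns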
@@ -126,7 +135,8 @@ def convert_examples_to_features(data: list, ontology: dict, tokenizer, max_turn
            labs = []
            for turn in dial:
                domain, slot = domslot.split('-', 1)
                acts = [act['intent'] for act in turn['dialogue_acts']
                        if act['domain'] == domain and act['slot'] == slot]
                if acts:
                    act_ = acts[0]
                    if act_ == 'request':
@@ -188,33 +198,54 @@ def convert_examples_to_features(data: list, ontology: dict, tokenizer, max_turn

# Unified Dataset object
class UnifiedFormatDataset(Dataset):
    """
    Class for preprocessing and storing data from the Convlab3 unified format.

    Attributes:
        dataset_dict (dict): Dictionary containing all the data in dataset
        ontology (dict): Set of all domain-slot-value triplets in the ontology of the model
        features (dict): Set of numeric features containing all inputs and labels formatted for the SetSUMBT model
    """
    def __init__(self,
                 dataset_name: str,
                 set_type: str,
                 tokenizer: PreTrainedTokenizer,
                 max_turns: int = 12,
                 max_seq_len: int = 64,
                 train_ratio: float = 1.0,
                 seed: int = 0):
        """
        Parameters:
            dataset_name (str): Name of the dataset to load
            set_type (str): Subset of the dataset to load (train, validation or test)
            tokenizer (PreTrainedTokenizer): Tokenizer for the encoder model used
            max_turns (int): Maximum number of turns in a dialogue
            max_seq_len (int): Maximum number of tokens in a dialogue turn
            train_ratio (float): Fraction of training data to use during training
            seed (int): Seed governing random order of ids for subsampling
        """
        if '+' in dataset_name:
            dataset_args = [{"dataset_name": name} for name in dataset_name.split('+')]
        else:
            dataset_args = [{"dataset_name": dataset_name}]
        if train_ratio != 1.0:
            for dataset_args_ in dataset_args:
                dataset_args_['dial_ids_order'] = seed
                dataset_args_['split2ratio'] = {'train': train_ratio, 'validation': train_ratio}
        self.dataset_dicts = [load_dataset(**dataset_args_) for dataset_args_ in dataset_args]

        self.ontology = get_ontology_slots(dataset_name)
        values = [get_values_from_data(dataset) for dataset in self.dataset_dicts]
        self.ontology = ontology_add_values(self.ontology, combine_value_sets(values))
        self.ontology = ontology_add_requestable_slots(self.ontology, get_requestable_slots(self.dataset_dicts))

        data = [load_dst_data(dataset_dict, data_split=set_type, speaker='all',
                              dialogue_acts=True, split_to_turn=False)
                for dataset_dict in self.dataset_dicts]
        data_list = [data_[set_type] for data_ in data]
        data = []
        for data_ in data_list:
            data += extract_dialogues(data_)
        self.features = convert_examples_to_features(data, self.ontology, tokenizer, max_turns, max_seq_len)

    def __getitem__(self, index):
@@ -272,7 +303,7 @@ def get_dataloader(dataset_name: str, set_type: str, batch_size: int, tokenizer,
    Returns:
        loader (torch dataloader): Dataloader to train and evaluate the setsumbt model
    '''
    data = UnifiedFormatDataset(dataset_name, set_type, tokenizer, max_turns, max_seq_len, train_ratio=train_ratio)
    data.to(device)

    if resampled_size:
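A minimal usage sketch of the renamed class; the dataset name and checkpoint are illustrative and assume the data is obtainable through the unified format loaders. A '+'-joined name such as 'multiwoz21+sgd' would combine datasets as handled in __init__ above.

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
dataset = UnifiedFormatDataset('multiwoz21', 'train', tokenizer, max_turns=12, max_seq_len=64)
print(sorted(dataset.ontology.keys()))  # domains covered by the merged ontology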
...
@@ -16,19 +16,7 @@
"""Convlab3 Unified dataset data processing utilities"""
from convlab.util import load_dataset, load_ontology, load_dst_data, load_nlu_data
from convlab.dst.setsumbt.dataset.value_maps import *
# Load slots, descriptions and categorical slot values from dataset ontology
@@ -41,15 +29,23 @@ def get_ontology_slots(dataset_name: str) -> dict:
    Returns:
        ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
    '''
    dataset_names = dataset_name.split('+') if '+' in dataset_name else [dataset_name]
    ontology_slots = dict()
    for dataset_name in dataset_names:
        ontology = load_ontology(dataset_name)
        domains = [domain for domain in ontology['domains'] if domain not in ['booking', 'general']]
        for domain in domains:
            domain_name = DOMAINS_MAP.get(domain, domain.lower())
            if domain_name not in ontology_slots:
                ontology_slots[domain_name] = dict()
            for slot, slot_info in ontology['domains'][domain]['slots'].items():
                if slot not in ontology_slots[domain_name]:
                    ontology_slots[domain_name][slot] = {'description': slot_info['description'],
                                                         'possible_values': list()}
                if slot_info['is_categorical']:
                    ontology_slots[domain_name][slot]['possible_values'] += slot_info['possible_values']
                ontology_slots[domain_name][slot]['possible_values'] = list(set(ontology_slots[domain_name][slot]['possible_values']))

    return ontology_slots
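A hedged illustration of the merge behaviour: with a '+'-joined name, domains normalised to the same key via DOMAINS_MAP accumulate the union of their categorical values. The dataset names are examples and assume both are installed in the unified format.

merged = get_ontology_slots('multiwoz21+sgd')
for slot, info in merged.get('hotel', {}).items():
    print('hotel', slot, len(info['possible_values']), 'possible values')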
@@ -69,17 +65,33 @@ def get_values_from_data(dataset: dict) -> dict:
    for set_type, dataset in data.items():
        for turn in dataset:
            for domain, substate in turn['state'].items():
                domain_name = DOMAINS_MAP.get(domain, domain.lower())
                if domain_name not in value_sets:
                    value_sets[domain_name] = {}
                for slot, value in substate.items():
                    if slot not in value_sets[domain_name]:
                        value_sets[domain_name][slot] = []
                    if value and value not in value_sets[domain_name][slot]:
                        value_sets[domain_name][slot].append(value)

    return clean_values(value_sets)
def combine_value_sets(value_sets):
    value_set = value_sets[0]
    for _value_set in value_sets[1:]:
        for domain, domain_info in _value_set.items():
            for slot, possible_values in domain_info.items():
                if domain not in value_set:
                    value_set[domain] = dict()
                if slot not in value_set[domain]:
                    value_set[domain][slot] = list()
                value_set[domain][slot] += _value_set[domain][slot]
                value_set[domain][slot] = list(set(value_set[domain][slot]))

    return value_set
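A toy illustration of combine_value_sets: overlapping domain-slot value lists are unioned, new domains and slots are added, and the set() pass means order is not preserved.

a = {'hotel': {'area': ['north', 'south']}}
b = {'hotel': {'area': ['south', 'west']}, 'train': {'day': ['monday']}}
combined = combine_value_sets([a, b])
assert sorted(combined['hotel']['area']) == ['north', 'south', 'west']  # union of both lists
assert combined['train']['day'] == ['monday']                           # new domain carried over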
# Clean the possible values for the ontology
def clean_values(value_sets: dict, value_map: dict = VALUE_MAP) -> dict:
    '''
@@ -147,7 +159,7 @@ def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict:

# Get set of requestable slots from the dataset action labels
def get_requestable_slots(datasets: list) -> dict:
    '''
    Function to get set of requestable slots from the dataset action labels.

    Args:
@@ -156,9 +168,10 @@ def get_requestable_slots(dataset: dict) -> dict:
    Returns:
        slots (dict): Dictionary containing requestable domain-slot pairs
    '''
    datasets = [load_nlu_data(dataset, data_split='all', speaker='user') for dataset in datasets]

    slots = {}
    for data in datasets:
        for set_type, subset in data.items():
            for turn in subset:
                requests = [act for act in turn['dialogue_acts']['categorical'] if act['intent'] == 'request']
@@ -166,9 +179,10 @@ def get_requestable_slots(dataset: dict) -> dict:
                requests += [act for act in turn['dialogue_acts']['binary'] if act['intent'] == 'request']
                requests = [(act['domain'], act['slot']) for act in requests]
                for domain, slot in requests:
                    domain_name = DOMAINS_MAP.get(domain, domain.lower())
                    if domain_name not in slots:
                        slots[domain_name] = []
                    slots[domain_name].append(slot)

    slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()}
@@ -245,8 +259,14 @@ def clean_states(turns: list) -> list:
    clean_turns = []
    for turn in turns:
        clean_state = {}
        clean_acts = []
        for act in turn['dialogue_acts']:
            domain = act['domain']
            act['domain'] = DOMAINS_MAP.get(domain, domain.lower())
            clean_acts.append(act)
        for domain, subset in turn['state'].items():
            domain_name = DOMAINS_MAP.get(domain, domain.lower())
            clean_state[domain_name] = {}
            for slot, value in subset.items():
                # Remove pipe separated values
                value = value.split('|', 1)[0]
@@ -311,8 +331,9 @@ def clean_states(turns: list) -> list:
                elif True in [v in value.lower() for v in ['yes', 'no']]:
                    value = [v for v in ['yes', 'no'] if v in value][0]
                clean_state[domain_name][slot] = value
        turn['state'] = clean_state
        turn['dialogue_acts'] = clean_acts
        clean_turns.append(turn)

    return clean_turns
@@ -333,11 +354,13 @@ def get_active_domains(turns: list) -> list:
        if turn_id == 0:
            domains = [d for d, substate in turns[turn_id]['state'].items() for s, v in substate.items() if v != 'none']
            domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] if act['domain'] in turns[turn_id]['state']]
            domains = [DOMAINS_MAP.get(domain, domain.lower()) for domain in domains]
            turns[turn_id]['active_domains'] = list(set(domains))
        else:
            # Use changes in domains to identify active domains
            domains = []
            for domain, substate in turns[turn_id]['state'].items():
                domain_name = DOMAINS_MAP.get(domain, domain.lower())
                for slot, value in substate.items():
                    if value != turns[turn_id - 1]['state'][domain][slot]:
                        val = value
@@ -346,7 +369,7 @@ def get_active_domains(turns: list) -> list:
                    if value == 'none':
                        val = 'none'
                    if val != 'none':
                        domains.append(domain_name)
            # Add all domains activated by a user action
            domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] if act['domain'] in turns[turn_id]['state']]
            turns[turn_id]['active_domains'] = list(set(domains))
...
# MultiWOZ specific label map to avoid duplication and typos in values
VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
'cityroomz': 'city roomz', ' ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
'worth house': 'the worth house', 'alesbray lodge guest house': 'aylesbray lodge',
'huntingdon hotel': 'huntingdon marriott hotel', 'huntingd': 'huntingdon marriott hotel',
'jamaicanchinese': 'chinese', 'barbequemodern european': 'modern european',
'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
# Domain map for SGD Data
DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buses_1': 'bus', 'Buses_2': 'bus',
'Buses_3': 'bus', 'Calendar_1': 'calendar', 'Events_1': 'events', 'Events_2': 'events',
'Events_3': 'events', 'Flights_1': 'flights', 'Flights_2': 'flights', 'Flights_3': 'flights',
'Flights_4': 'flights', 'Homes_1': 'homes', 'Homes_2': 'homes', 'Hotels_1': 'hotel',
'Hotels_2': 'hotel', 'Hotels_3': 'hotel', 'Hotels_4': 'hotel', 'Media_1': 'media',
'Media_2': 'media', 'Media_3': 'media', 'Messaging_1': 'messaging', 'Movies_1': 'movies',
'Movies_2': 'movies', 'Movies_3': 'movies', 'Music_1': 'music', 'Music_2': 'music', 'Music_3': 'music',
'Payment_1': 'payment', 'RentalCars_1': 'rentalcars', 'RentalCars_2': 'rentalcars',
'RentalCars_3': 'rentalcars', 'Restaurants_1': 'restaurant', 'Restaurants_2': 'restaurant',
'RideSharing_1': 'ridesharing', 'RideSharing_2': 'ridesharing', 'Services_1': 'services',
'Services_2': 'services', 'Services_3': 'services', 'Services_4': 'services', 'Trains_1': 'train',
'Travel_1': 'travel', 'Weather_1': 'weather', 'movie_ticket': 'movies',
'restaurant_reservation': 'restaurant', 'coffee_ordering': 'coffee', 'pizza_ordering': 'takeout',
'auto_repair': 'car_repairs', 'flights': 'flights', 'food-ordering': 'takeout', 'hotels': 'hotel',
'movies': 'movies', 'music': 'music', 'restaurant-search': 'restaurant', 'sports': 'sports',
'movie': 'movies'}
INVERSE_DOMAINS_MAP = {item: key for key, item in DOMAINS_MAP.items()}
SLOTS_MAP = {"account_balance": "balance", "transfer_amount": "amount", "from_location": "departure",
"from_station": "departure", "origin": "departure", "origin_station_name": "departure",
"from_city": "departure", "to_location": "destination", "to_station": "destination",
"destination_station_name": "destination", "to_city": "destination", "leaving_date": "departure_date",
"leaving_time": "departure_time", "fare": "price", "fare_type": "price"}
# Generic value sets for quantity and time slots
QUANTITIES = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10 or more']
TIME = [[(i, j) for i in range(24)] for j in range(0, 60, 5)]
TIME = ['%02i:%02i' % t for l in TIME for t in l]
\ No newline at end of file
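The lookup pattern DOMAINS_MAP.get(domain, domain.lower()) used throughout this commit normalises SGD- and Taskmaster-style domain names onto generic equivalents and falls back to lowercasing for unmapped domains, as this small sketch shows:

from convlab.dst.setsumbt.dataset.value_maps import DOMAINS_MAP

for domain in ['Hotels_2', 'Restaurants_1', 'hotel', 'Taxi']:
    print(domain, '->', DOMAINS_MAP.get(domain, domain.lower()))
# Hotels_2 -> hotel, Restaurants_1 -> restaurant, hotel -> hotel, Taxi -> taxi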
@@ -29,9 +29,9 @@ from tensorboardX import SummaryWriter
from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
from convlab.dst.setsumbt.dataset import unified_format
from convlab.dst.setsumbt.modeling import training
from convlab.dst.setsumbt.dataset import ontology as embeddings
from convlab.dst.setsumbt.utils import get_args, update_args
# from convlab.dst.setsumbt.modeling import ensemble_utils
...
from convlab.dst.setsumbt.modeling.bert_nbt import BertSetSUMBT
from convlab.dst.setsumbt.modeling.roberta_nbt import RobertaSetSUMBT
from convlab.dst.setsumbt.modeling.ensemble_nbt import EnsembleSetSUMBT, DropoutEnsembleSetSUMBT
from convlab.dst.setsumbt.modeling.temperature_scheduler import LinearTemperatureScheduler
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,7 +20,7 @@ import transformers
from torch.autograd import Variable
from transformers import BertModel, BertPreTrainedModel

from convlab.dst.setsumbt.modeling.functional import initialize_setsumbt_model, nbt_forward


class BertSetSUMBT(BertPreTrainedModel):
@@ -35,7 +35,7 @@ class BertSetSUMBT(BertPreTrainedModel):
            for p in self.bert.parameters():
                p.requires_grad = False

        initialize_setsumbt_model(self, config)

    # Add new slot candidates to the model
    def add_slot_candidates(self, slot_candidates):
@@ -103,9 +103,9 @@ class BertSetSUMBT(BertPreTrainedModel):
        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)

        if get_turn_pooled_representation:
            return nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
                               dialogue_size, hidden_state, inform_labels, request_labels, domain_labels,
                               goodbye_labels, calculate_inform_mutual_info) + (bert_output.pooler_output,)
        return nbt_forward(self, turn_embeddings, bert_output.pooler_output, attention_mask, batch_size, dialogue_size,
                           hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
                           calculate_inform_mutual_info)
@@ -29,8 +29,8 @@ from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatch
from convlab.dst.setsumbt.loss.endd_loss import rkl_dirichlet_mediator_loss, logits_to_mutual_info


# Default belief tracker model initialization function
def initialize_setsumbt_model(self, config):
    # Slot Utterance matching attention
    self.slot_attention = MultiheadAttention(config.hidden_size, config.slot_attention_heads)
@@ -143,12 +143,12 @@ def _initialise(self, config):

# Default belief tracker forward pass.
def nbt_forward(self,
                turn_embeddings,
                turn_pooled_representation,
                attention_mask,
                batch_size,
                dialogue_size,
                hidden_state,
                inform_labels,
                request_labels,
@@ -196,7 +196,7 @@ def _nbt_forward(self, turn_embeddings,
    turn_embeddings = turn_embeddings.transpose(0, 1)

    # Compute key padding mask
    key_padding_mask = (attention_mask[:, :, 0] == 0.0)
    key_padding_mask[key_padding_mask[:, 0], :] = False
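    # Note: clearing rows whose first key position is padding keeps turns that
    # consist entirely of padding from masking out every key, which would
    # otherwise produce NaNs in the attention softmax. The boolean row index
    # replaces the earlier '== True' comparison with identical behaviour.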
    # Multi head attention of slot over tokens
    hidden, _ = self.slot_attention(query=slot_embeddings,
                                    key=turn_embeddings,
...
@@ -20,7 +20,7 @@ import transformers
from torch.autograd import Variable
from transformers import RobertaModel, RobertaPreTrainedModel

from convlab.dst.setsumbt.modeling.functional import initialize_setsumbt_model, nbt_forward


class RobertaSetSUMBT(RobertaPreTrainedModel):
@@ -35,7 +35,7 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
            for p in self.roberta.parameters():
                p.requires_grad = False

        initialize_setsumbt_model(self, config)

    # Add new slot candidates to the model
@@ -106,9 +106,9 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
        turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)

        if get_turn_pooled_representation:
            return nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
                               dialogue_size, hidden_state, inform_labels, request_labels, domain_labels,
                               goodbye_labels, calculate_inform_mutual_info) + (roberta_output.pooler_output,)
        return nbt_forward(self, turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size, dialogue_size,
                           hidden_state, inform_labels, request_labels, domain_labels, goodbye_labels,
                           calculate_inform_mutual_info)
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,50 +13,67 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear Temperature Scheduler Class"""


# Temp scheduler class for ensemble distillation
class LinearTemperatureScheduler:
    """
    Temperature scheduler object used for distribution temperature scheduling in distillation

    Attributes:
        state (dict): Internal state of scheduler
    """
    def __init__(self,
                 total_steps: int,
                 base_temp: float = 2.5,
                 cycle_len: float = 0.1):
        """
        Init function for LinearTemperatureScheduler

        Args:
            total_steps (int): Total number of training steps
            base_temp (float): Starting temperature
            cycle_len (float): Fraction of total steps used for scheduling cycle
        """
        self.state = dict()
        self.state['total_steps'] = total_steps
        self.state['current_step'] = 0
        self.state['base_temp'] = base_temp
        self.state['current_temp'] = base_temp
        self.state['cycles'] = [int(total_steps * cycle_len / 2), int(total_steps * cycle_len)]
        self.state['rate'] = (self.state['base_temp'] - 1.0) / (self.state['cycles'][1] - self.state['cycles'][0])

    def step(self):
        """
        Update temperature based on the schedule
        """
        self.state['current_step'] += 1
        assert self.state['current_step'] <= self.state['total_steps']
        if self.state['current_step'] > self.state['cycles'][0]:
            if self.state['current_step'] < self.state['cycles'][1]:
                self.state['current_temp'] -= self.state['rate']
            else:
                self.state['current_temp'] = 1.0

    def temp(self):
        """
        Get current temperature

        Returns:
            temp (float): Current temperature for distribution scaling
        """
        return float(self.state['current_temp'])

    def state_dict(self):
        """
        Return scheduler state

        Returns:
            state (dict): Dictionary format state of the scheduler
        """
        return self.state

    def load_state_dict(self, state_dict: dict):
        """
        Load scheduler state from dictionary

        Parameters:
            state_dict (dict): Dictionary format state of the scheduler
        """
        self.state = state_dict
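A usage sketch of the scheduler: with total_steps=100 and the defaults above, cycles is [5, 10], so the temperature holds at 2.5 for five steps, decays by rate=0.3 per step until step 10, then stays at 1.0.

scheduler = LinearTemperatureScheduler(total_steps=100)
trajectory = []
for _ in range(100):
    trajectory.append(scheduler.temp())
    scheduler.step()
assert trajectory[0] == 2.5 and trajectory[-1] == 1.0  # plateau, linear decay, floor at 1.0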
@@ -33,7 +33,7 @@ except:
    print('Apex not used')

from convlab.dst.setsumbt.utils import clear_checkpoints
from convlab.dst.setsumbt.modeling import LinearTemperatureScheduler


# Load logger and tensorboard summary writer
...
@@ -21,22 +21,25 @@ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from datetime import datetime


def get_args(base_models: dict):
    # Get arguments
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

    # Optional
    parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='')
    parser.add_argument('--logging_path', help='Path for log file', default='')
    parser.add_argument('--seed', help='Seed value for reproducibility', default=0, type=int)

    # DATASET (Optional)
    parser.add_argument('--dataset', help='Dataset Name (See Convlab 3 unified format for possible datasets)',
                        default='multiwoz21')
    parser.add_argument('--dataset_train_ratio', help='Fraction of training set to use in training', default=1.0,
                        type=float)
    parser.add_argument('--max_dialogue_len', help='Maximum number of turns per dialogue', default=12, type=int)
    parser.add_argument('--max_turn_len', help='Maximum number of tokens per turn', default=64, type=int)
    parser.add_argument('--max_slot_len', help='Maximum number of tokens per slot description', default=12, type=int)
    parser.add_argument('--max_candidate_len', help='Maximum number of tokens per value candidate', default=12,
                        type=int)
    parser.add_argument('--force_processing', action='store_true', help='Force preprocessing of data.')
    parser.add_argument('--data_sampling_size', help='Resampled dataset size', default=-1, type=int)
    parser.add_argument('--no_descriptions', help='Use slot names rather than slot descriptions for embeddings',
@@ -64,26 +67,28 @@ def get_args(MODELS):
    parser.add_argument('--distance_measure', default='cosine',
                        help='Similarity measure for candidate scoring: cosine/euclidean')
    # parser.add_argument('--ensemble_size', help='Number of models in ensemble', default=-1, type=int)
    parser.add_argument('--no_set_similarity', action='store_true', help='Set True to not use set similarity')
    parser.add_argument('--set_pooling',
                        help='Set pooling method for set similarity model using single embedding distances',
                        default='cnn')
    parser.add_argument('--candidate_pooling',
                        help='Pooling approach for non set based candidate representations: cls/mean',
                        default='mean')
    parser.add_argument('--no_action_prediction', help='Model does not predict user actions and active domain',
                        action='store_true')

    # Loss
    parser.add_argument('--loss_function',
                        help='Loss Function for training: crossentropy/bayesianmatching/labelsmoothing/...',
                        default='labelsmoothing')
    parser.add_argument('--kl_scaling_factor', help='Scaling factor for KL divergence in bayesian matching loss',
                        type=float)
    parser.add_argument('--prior_constant', help='Constant parameter for prior in bayesian matching loss',
                        type=float)
    parser.add_argument('--ensemble_smoothing', help='Ensemble distribution smoothing constant', type=float)
    parser.add_argument('--annealing_base_temp', help='Ensemble distribution distillation temp annealing base temp',
                        type=float)
    parser.add_argument('--annealing_cycle_len', help='Ensemble distribution distillation temp annealing cycle length',
                        type=float)
    parser.add_argument('--label_smoothing', help='Label smoothing coefficient.', type=float)
    parser.add_argument('--user_goal_loss_weight', help='Weight of the user goal prediction loss. 0.0<weight<=1.0',
@@ -102,7 +107,7 @@ def get_args(MODELS):
                        help='Number of batches accumulated for one update step')
    parser.add_argument('--num_train_epochs', help='Number of training epochs', default=50, type=int)
    parser.add_argument('--patience', help='Number of training steps without improving model before stopping.',
                        default=20, type=int)
    parser.add_argument('--weight_decay', help='Weight decay rate', default=0.01, type=float)
    parser.add_argument('--learning_rate', help='Initial Learning Rate', default=5e-5, type=float)
    parser.add_argument('--warmup_proportion', help='Warmup proportion for linear scheduler', default=0.2, type=float)
@@ -134,7 +139,7 @@ def get_args(MODELS):
    # RUN_NBT ACTIONS
    parser.add_argument('--do_train', help='Perform training', action='store_true')
    parser.add_argument('--do_eval', help='Perform model evaluation during training', action='store_true')
    parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')
    args = parser.parse_args()

    # Simplify args
@@ -147,8 +152,8 @@ def get_args(MODELS):
    args.output_dir = os.path.dirname(os.path.abspath(__file__))
    args.output_dir = os.path.join(args.output_dir, 'models')

    name = 'SetSUMBT' if args.set_similarity else 'SUMBT'
    name += '+ActPrediction' if args.predict_actions else ''
    name += '-' + args.dataset
    name += '-' + str(round(args.dataset_train_ratio*100)) + '%' if args.dataset_train_ratio != 1.0 else ''
    name += '-' + args.model_type
@@ -166,9 +171,6 @@ def get_args(MODELS):
        args.kl_scaling_factor = 0.001
    if not args.prior_constant:
        args.prior_constant = 1.0
    if args.loss_function == 'labelsmoothing':
        if not args.label_smoothing:
            args.label_smoothing = 0.05
@@ -191,10 +193,8 @@ def get_args(MODELS):
    if not args.active_domain_loss_weight:
        args.active_domain_loss_weight = 0.2

    args.tensorboard_path = args.tensorboard_path if args.tensorboard_path else os.path.join(args.output_dir, 'tb_logs')
    args.logging_path = args.logging_path if args.logging_path else os.path.join(args.output_dir, 'run.log')

    # Default model_name's
    if not args.model_name_or_path:
@@ -208,28 +208,22 @@ def get_args(MODELS):
    if not args.candidate_embedding_model_name:
        args.candidate_embedding_model_name = args.model_name_or_path

    if args.model_type in base_models:
        config_class = base_models[args.model_type][-2]
    else:
        raise NameError('NotImplemented')

    config = build_config(config_class, args)
    return args, config


def build_config(config_class, args):
    config = config_class.from_pretrained(args.model_name_or_path)
    if not os.path.exists(args.model_name_or_path):
        config.tokenizer_name = args.model_name_or_path
    try:
        config.tokenizer_name = config.tokenizer_name
    except AttributeError:
        config.tokenizer_name = args.model_name_or_path
    if args.candidate_embedding_model_name:
        config.candidate_embedding_model_name = args.candidate_embedding_model_name
    config.max_dialogue_len = args.max_dialogue_len
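The try/except in build_config is an attribute-existence probe: reading a missing attribute raises AttributeError, so the self-assignment only succeeds when tokenizer_name is already set on the config. A standalone sketch of the same pattern, with SimpleNamespace standing in for a config object:

from types import SimpleNamespace

config = SimpleNamespace()
try:
    config.tokenizer_name = config.tokenizer_name  # succeeds only if already set
except AttributeError:
    config.tokenizer_name = 'roberta-base'         # fallback, illustrative value
print(config.tokenizer_name)                       # roberta-base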
...