Commit d7872038 authored by Christian

Merge branch 'github_master' of gitlab.cs.uni-duesseldorf.de:dsml/convlab/ConvLab3 into github_master

Parents: ecfa6933 16304fcd

Showing 300 additions and 45 deletions
@@ -7,7 +7,6 @@ def evaluate(predict_result):
     metrics = {'TP':0, 'FP':0, 'FN':0}
     acc = []
     for sample in predict_result:
         pred_state = sample['predictions']['state']
         gold_state = sample['state']
...
{
"model_type": "SetSUMBT",
"dataset": "multiwoz21+sgd+tm1+tm2+tm3",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 8,
"test_batch_size": 8,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "multiwoz21",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 16,
"test_batch_size": 16,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "sgd+tm1+tm2+tm3",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 12,
"test_batch_size": 12,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "sgd",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 6,
"test_batch_size": 3,
"run_nbt": true
}
\ No newline at end of file
{
"model_type": "SetSUMBT",
"dataset": "tm1+tm2+tm3",
"no_action_prediction": true,
"model_name_or_path": "/gpfs/project/niekerk/models/transformers/roberta-base",
"transformers_local_files_only": true,
"train_batch_size": 3,
"dev_batch_size": 8,
"test_batch_size": 8,
"run_nbt": true
}
\ No newline at end of file
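These five new JSON files add starting configs for SetSUMBT, one per dataset combination (multiwoz21+sgd+tm1+tm2+tm3, multiwoz21, sgd+tm1+tm2+tm3, sgd, and tm1+tm2+tm3). They are consumed through the --starting_config_name flag introduced in the SetSUMBT utils diff below: get_starting_config loads configs/<name>.json and copies its keys over the argparse defaults. A hypothetical invocation, assuming one of the files above is saved as configs/setsumbt_multiwoz21.json and that the training entry point is the nbt main() shown later in this commit:

    python3 -m convlab.dst.setsumbt.do.nbt --starting_config_name setsumbt_multiwoz21 --do_train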
@@ -27,7 +27,7 @@ from convlab.util import load_dataset
 from convlab.dst.setsumbt.dataset.utils import (get_ontology_slots, ontology_add_values,
                                                 get_values_from_data, ontology_add_requestable_slots,
                                                 get_requestable_slots, load_dst_data, extract_dialogues,
-                                                combine_value_sets)
+                                                combine_value_sets, IdTensor)

 transformers.logging.set_verbosity_error()
@@ -77,6 +77,11 @@ def convert_examples_to_features(data: list,
     del dial_feats

     # Perform turn level padding
+    dial_ids = list()
+    for dial in data:
+        _ids = [turn['dialogue_id'] for turn in dial][:max_turns]
+        _ids += [''] * (max_turns - len(_ids))
+        dial_ids.append(_ids)
     input_ids = [[turn['input_ids'] for turn in dial] + [[0] * max_seq_len] * (max_turns - len(dial))
                  for dial in input_feats]
     if 'token_type_ids' in input_feats[0][0]:
@@ -92,6 +97,7 @@ def convert_examples_to_features(data: list,
     del input_feats

     # Create torch data tensors
+    features['dialogue_ids'] = IdTensor(dial_ids)
     features['input_ids'] = torch.tensor(input_ids)
     features['token_type_ids'] = torch.tensor(token_type_ids) if token_type_ids else None
     features['attention_mask'] = torch.tensor(attention_mask) if attention_mask else None
...
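Note: the dial_ids block added above pads each dialogue's per-turn ID list to max_turns so that it stays aligned with the padded input_ids tensor. A minimal runnable sketch of that padding, using hypothetical two- and four-turn dialogues:

    # Hypothetical data: one short and one long dialogue.
    max_turns = 3
    data = [[{'dialogue_id': 'dial-0'}] * 2, [{'dialogue_id': 'dial-1'}] * 4]

    dial_ids = list()
    for dial in data:
        _ids = [turn['dialogue_id'] for turn in dial][:max_turns]  # truncate to max_turns
        _ids += [''] * (max_turns - len(_ids))                     # pad short dialogues with ''
        dial_ids.append(_ids)

    print(dial_ids)  # [['dial-0', 'dial-0', ''], ['dial-1', 'dial-1', 'dial-1']]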
@@ -15,6 +15,9 @@
 # limitations under the License.
 """Convlab3 Unified dataset data processing utilities"""

+import numpy
+import pdb
+
 from convlab.util import load_ontology, load_dst_data, load_nlu_data
 from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
@@ -66,7 +69,9 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
     data = load_dst_data(dataset, data_split='all', speaker='user')

     # Remove test data from the data when building training/validation ontology
-    if data_split in ['train', 'validation']:
+    if data_split == 'train':
+        data = {key: itm for key, itm in data.items() if key == 'train'}
+    elif data_split == 'validation':
         data = {key: itm for key, itm in data.items() if key in ['train', 'validation']}

     value_sets = {}
@@ -74,13 +79,14 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
         for turn in dataset:
             for domain, substate in turn['state'].items():
                 domain_name = DOMAINS_MAP.get(domain, domain.lower())
-                if domain not in value_sets:
+                if domain_name not in value_sets:
                     value_sets[domain_name] = {}
                 for slot, value in substate.items():
                     if slot not in value_sets[domain_name]:
                         value_sets[domain_name][slot] = []
                     if value and value not in value_sets[domain_name][slot]:
                         value_sets[domain_name][slot].append(value)
+    # pdb.set_trace()

     return clean_values(value_sets)
@@ -163,6 +169,9 @@ def ontology_add_values(ontology_slots: dict, value_sets: dict, data_split: str
         if data_split in ['train', 'validation']:
             if domain not in value_sets:
                 continue
+            possible_values = [v for slot, vals in value_sets[domain].items() for v in vals]
+            if len(possible_values) == 0:
+                continue
         ontology[domain] = {}
         for slot in sorted(ontology_slots[domain]):
             if not ontology_slots[domain][slot]['possible_values']:
@@ -228,12 +237,13 @@ def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict
     return ontology_slots


-def extract_turns(dialogue: list, dataset_name: str) -> list:
+def extract_turns(dialogue: list, dataset_name: str, dialogue_id: str) -> list:
     """
     Extract the required information from the data provided by unified loader

     Args:
         dialogue (list): List of turns within a dialogue
         dataset_name (str): Name of the dataset to which the dialogue belongs
+        dialogue_id (str): ID of the dialogue

     Returns:
         turns (list): List of turns within a dialogue
@@ -261,6 +271,7 @@ def extract_turns(dialogue: list, dataset_name: str) -> list:
         turn_info['state'] = turn['state']
         turn_info['dataset_name'] = dataset_name
+        turn_info['dialogue_id'] = dialogue_id

         if 'system_utterance' in turn_info and 'user_utterance' in turn_info:
             turns.append(turn_info)
@@ -399,6 +410,17 @@ def get_active_domains(turns: list) -> list:
     return turns


+class IdTensor:
+    def __init__(self, values):
+        self.values = numpy.array(values)
+
+    def __getitem__(self, index: int):
+        return self.values[index].tolist()
+
+    def to(self, device):
+        return self
+
+
 def extract_dialogues(data: list, dataset_name: str) -> list:
     """
     Extract all dialogues from dataset
@@ -411,7 +433,8 @@ def extract_dialogues(data: list, dataset_name: str) -> list:
     """
     dialogues = []
     for dial in data:
-        turns = extract_turns(dial['turns'], dataset_name)
+        dial_id = dial['dialogue_id']
+        turns = extract_turns(dial['turns'], dataset_name, dial_id)
         turns = clean_states(turns)
         turns = get_active_domains(turns)
         dialogues.append(turns)
...
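Note: IdTensor exists so that string dialogue IDs can travel through the batching code next to real torch tensors: it supports integer indexing like a tensor and turns .to(device) into a no-op, since strings cannot be moved to a GPU. A small usage sketch (the class body is copied from the hunk above):

    import numpy

    class IdTensor:
        def __init__(self, values):
            self.values = numpy.array(values)

        def __getitem__(self, index: int):
            return self.values[index].tolist()

        def to(self, device):
            return self  # no-op: IDs stay on the host

    ids = IdTensor([['dial-0', 'dial-0', ''], ['dial-1', 'dial-1', 'dial-1']])
    ids = ids.to('cuda')  # safe inside a generic "move batch to device" loop
    print(ids[1])         # ['dial-1', 'dial-1', 'dial-1']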
@@ -65,13 +65,15 @@ def main(args=None, config=None):
     paths = os.listdir(args.output_dir) if os.path.exists(args.output_dir) else []
     if 'pytorch_model.bin' in paths and 'config.json' in paths:
         args.model_name_or_path = args.output_dir
-        config = ConfigClass.from_pretrained(args.model_name_or_path)
+        config = ConfigClass.from_pretrained(args.model_name_or_path,
+                                             local_files_only=args.transformers_local_files_only)
     else:
         paths = [os.path.join(args.output_dir, p) for p in paths if 'checkpoint-' in p]
         if paths:
             paths = paths[0]
             args.model_name_or_path = paths
-            config = ConfigClass.from_pretrained(args.model_name_or_path)
+            config = ConfigClass.from_pretrained(args.model_name_or_path,
+                                                 local_files_only=args.transformers_local_files_only)

     args = update_args(args, config)
@@ -102,12 +104,15 @@ def main(args=None, config=None):
     # Initialise Model
     transformers.utils.logging.set_verbosity_info()
-    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config)
+    model = SetSumbtModel.from_pretrained(args.model_name_or_path, config=config,
+                                          local_files_only=args.transformers_local_files_only)
     model = model.to(device)

     # Create Tokenizer and embedding model for Data Loaders and ontology
-    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name)
-    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config)
+    encoder = CandidateEncoderModel.from_pretrained(config.candidate_embedding_model_name,
+                                                    local_files_only=args.transformers_local_files_only)
+    tokenizer = Tokenizer.from_pretrained(config.tokenizer_name, config=config,
+                                          local_files_only=args.transformers_local_files_only)

     # Set up model training/evaluation
     training.set_logger(logger, tb_writer)
...
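Note: local_files_only is a standard Hugging Face from_pretrained argument; when set, transformers loads only from the local cache or path and raises an error instead of contacting the hub, which is what the /gpfs model paths in the configs above rely on. A minimal sketch with plain transformers classes standing in for the SetSUMBT wrappers:

    from transformers import RobertaConfig, RobertaModel

    path = '/gpfs/project/niekerk/models/transformers/roberta-base'  # local path from the configs above
    config = RobertaConfig.from_pretrained(path, local_files_only=True)
    model = RobertaModel.from_pretrained(path, config=config, local_files_only=True)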
import json
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

from tqdm import tqdm

from convlab.util import load_dataset
from convlab.util import load_dst_data
from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
def extract_data(dataset_names: str) -> list:
    dataset_dicts = [load_dataset(dataset_name=name) for name in dataset_names.split('+')]
    data = []
    for dataset_dict in dataset_dicts:
        dataset = load_dst_data(dataset_dict, data_split='test', speaker='all', dialogue_acts=True,
                                split_to_turn=False)
        for dial in dataset['test']:
            data.append(dial)
    return data
def clean_state(state):
    clean_state = dict()
    for domain, subset in state.items():
        clean_state[domain] = {}
        for slot, value in subset.items():
            # Remove pipe separated values
            value = value.split('|')
            # Map values using value_map
            for old, new in VALUE_MAP.items():
                value = [val.replace(old, new) for val in value]
            value = '|'.join(value)
            # Map dontcare to "do not care" and empty to 'none'
            value = value.replace('dontcare', 'do not care')
            value = value if value else 'none'
            # Map quantity values to the integer quantity value
            if 'people' in slot or 'duration' in slot or 'stay' in slot:
                try:
                    if value not in ['do not care', 'none']:
                        value = int(value)
                        value = str(value) if value < 10 else QUANTITIES[-1]
                except:
                    value = value
            # Map time values to the most appropriate value in the standard time set
            elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
                try:
                    if value not in ['do not care', 'none']:
                        # Strip after/before from time value
                        value = value.replace('after ', '').replace('before ', '')
                        # Extract hours and minutes from different possible formats
                        if ':' not in value and len(value) == 4:
                            h, m = value[:2], value[2:]
                        elif len(value) == 1:
                            h = int(value)
                            m = 0
                        elif 'pm' in value:
                            h = int(value.replace('pm', '')) + 12
                            m = 0
                        elif 'am' in value:
                            h = int(value.replace('am', ''))
                            m = 0
                        elif ':' in value:
                            h, m = value.split(':')
                        elif ';' in value:
                            h, m = value.split(';')
                        # Map to closest 5 minutes
                        if int(m) % 5 != 0:
                            m = round(int(m) / 5) * 5
                            h = int(h)
                            if m == 60:
                                m = 0
                                h += 1
                            if h >= 24:
                                h -= 24
                        # Set in standard 24 hour format
                        h, m = int(h), int(m)
                        value = '%02i:%02i' % (h, m)
                except:
                    value = value
            # Map boolean slots to yes/no value
            elif 'parking' in slot or 'internet' in slot:
                if value not in ['do not care', 'none']:
                    if value == 'free':
                        value = 'yes'
                    elif True in [v in value.lower() for v in ['yes', 'no']]:
                        value = [v for v in ['yes', 'no'] if v in value][0]
            value = value if value != 'none' else ''
            clean_state[domain][slot] = value
    return clean_state
def extract_states(data):
    states_data = {}
    for dial in data:
        states = []
        for turn in dial['turns']:
            if 'state' in turn:
                state = clean_state(turn['state'])
                states.append(state)
        states_data[dial['dialogue_id']] = states
    return states_data
def get_golden_state(prediction, data):
    state = data[prediction['dial_idx']][prediction['utt_idx']]
    pred = prediction['predictions']['state']
    pred = {domain: {slot: pred.get(DOMAINS_MAP.get(domain, domain.lower()), dict()).get(slot, '')
                     for slot in state[domain]} for domain in state}
    prediction['state'] = state
    prediction['predictions']['state'] = pred
    return prediction
if __name__ == "__main__":
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_name', type=str, help='Name of dataset', default="multiwoz21")
    parser.add_argument('--model_path', type=str, help='Path to model dir')
    args = parser.parse_args()

    data = extract_data(args.dataset_name)
    data = extract_states(data)

    reader = open(os.path.join(args.model_path, "predictions", "test.json"), 'r')
    predictions = json.load(reader)
    reader.close()

    predictions = [get_golden_state(pred, data) for pred in tqdm(predictions)]

    writer = open(os.path.join(args.model_path, "predictions", f"test_{args.dataset_name}.json"), 'w')
    json.dump(predictions, writer)
    writer.close()
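This new script re-attaches golden states from the unified test split to a model's prediction file, writing the result next to the original. As a worked example of clean_state's time normalisation: '3pm' becomes '15:00', and '08:17' is rounded to the nearest five minutes, '08:15'. Assuming the script is saved as extract_golden_states.py (its file name is not visible in this view):

    python3 extract_golden_states.py --dataset_name multiwoz21 --model_path /path/to/model
    # reads  /path/to/model/predictions/test.json
    # writes /path/to/model/predictions/test_multiwoz21.json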
@@ -590,8 +590,7 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
                 for sample in eval_output_batch:
                     dom, slt = slot.split('-', 1)
                     lab = state_labels[sample['dial_idx']][sample['utt_idx']].item()
-                    if lab != -1:
-                        lab = ontology[dom][slt]['possible_values'][lab]
+                    lab = ontology[dom][slt]['possible_values'][lab] if lab != -1 else 'NOT_IN_ONTOLOGY'

                     pred = prediction[sample['dial_idx']][sample['utt_idx']].item()
                     pred = ontology[dom][slt]['possible_values'][pred]
@@ -615,7 +614,9 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
             num_inform_slots += (state_labels != -1).float().reshape(-1)

             if return_eval_output:
-                evaluation_output += deepcopy(eval_output_batch)
+                for sample in eval_output_batch:
+                    sample['dial_idx'] = batch['dialogue_ids'][sample['utt_idx']][sample['dial_idx']]
+                    evaluation_output.append(deepcopy(sample))
                 eval_output_batch = []

             if model.config.predict_actions:
@@ -708,16 +709,6 @@ def evaluate(args, model, device, dataloader, return_eval_output=False, is_train
         req_f1, dom_f1, gen_f1 = None, None, None

     if return_eval_output:
-        dial_idx = 0
-        for sample in evaluation_output:
-            if dial_idx == 0 and sample['dial_idx'] == 0 and sample['utt_idx'] == 0:
-                dial_idx = 0
-            elif dial_idx == 0 and sample['dial_idx'] != 0 and sample['utt_idx'] == 0:
-                dial_idx += 1
-            elif sample['utt_idx'] == 0:
-                dial_idx += 1
-            sample['dial_idx'] = dial_idx
-
         return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), evaluation_output

     if is_train:
         return jg_acc, sl_acc, req_f1, dom_f1, gen_f1, tr_loss / len(dataloader), stats
...
@@ -16,6 +16,7 @@
 """SetSUMBT utils"""
 import os
+import json
 import shutil
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 from datetime import datetime
@@ -27,6 +28,9 @@ def get_args(base_models: dict):
     # Get arguments
     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

+    # Config file usage
+    parser.add_argument('--starting_config_name', default=None, type=str)
+
     # Optional
     parser.add_argument('--tensorboard_path', help='Path to tensorboard', default='')
     parser.add_argument('--logging_path', help='Path for log file', default='')
@@ -54,6 +58,8 @@ def get_args(base_models: dict):
     parser.add_argument('--model_name_or_path', help='Name or path of the pretrained model.', default=None)
     parser.add_argument('--candidate_embedding_model_name', default=None,
                         help='Name of the pretrained candidate embedding model.')
+    parser.add_argument('--transformers_local_files_only', help='Use local files only for huggingface transformers',
+                        action='store_true')

     # Architecture
     parser.add_argument('--freeze_encoder', help='No training performed on the turn encoder Bert Model',
@@ -143,6 +149,12 @@ def get_args(base_models: dict):
     parser.add_argument('--do_test', help='Evaluate model on test data', action='store_true')

     args = parser.parse_args()
+    if args.starting_config_name:
+        args = get_starting_config(args)
+
+    if args.do_train:
+        args.do_eval = True

     # Simplify args
     args.set_similarity = not args.no_set_similarity
     args.use_descriptions = not args.no_descriptions
@@ -217,6 +229,31 @@ def get_args(base_models: dict):
     return args, config


+def get_starting_config(args):
+    path = os.path.dirname(os.path.realpath(__file__))
+    path = os.path.join(path, 'configs', f"{args.starting_config_name}.json")
+    reader = open(path, 'r')
+    config = json.load(reader)
+    reader.close()
+
+    if "model_type" in config:
+        if config["model_type"].lower() == 'setsumbt':
+            config["model_type"] = 'roberta'
+            config["no_set_similarity"] = False
+            config["no_descriptions"] = False
+        elif config["model_type"].lower() == 'sumbt':
+            config["model_type"] = 'bert'
+            config["no_set_similarity"] = True
+            config["no_descriptions"] = False
+
+    variables = vars(args).keys()
+    for key, value in config.items():
+        if key in variables:
+            setattr(args, key, value)
+
+    return args
+
+
 def get_git_info():
     repo = Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
     branch_name = repo.active_branch.name
...
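Note: get_starting_config resolves configs/<starting_config_name>.json relative to the setsumbt package, expands the model_type shorthand (setsumbt -> roberta with set similarity and descriptions, sumbt -> bert without set similarity), and then copies every key that matches an existing argparse attribute onto args. A self-contained sketch of that override mechanic with hypothetical values:

    from types import SimpleNamespace

    args = SimpleNamespace(model_type='bert', dataset='sgd', train_batch_size=8,
                           no_set_similarity=True, no_descriptions=True)
    config = {"model_type": "SetSUMBT", "dataset": "multiwoz21", "train_batch_size": 3}

    if config["model_type"].lower() == 'setsumbt':  # shorthand expansion, as in the diff
        config.update(model_type='roberta', no_set_similarity=False, no_descriptions=False)

    for key, value in config.items():  # copy onto matching argparse attributes
        if key in vars(args):
            setattr(args, key, value)

    print(args.model_type, args.dataset, args.train_batch_size)  # roberta multiwoz21 3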
@@ -26,7 +26,7 @@ class UserActionPolicy(Policy):
             print("change mode to semantic because only_action=True")
             self.mode = "semantic"
         self.max_in_len = 500
-        self.max_out_len = 50 if only_action else 200
+        self.max_out_len = 100 if only_action else 200
         max_act_len = kwargs.get("max_act_len", 2)
         print("max_act_len", max_act_len)
         self.max_action_len = max_act_len
...
 {
     "model": {
-        "load_path": "convlab/policy/ppo/pretrained_models/supervised",
+        "load_path": "convlab/policy/ppo/pretrained_models/mle",
         "pretrained_load_path": "",
         "use_pretrained_initialisation": false,
-        "batchsz": 1000,
+        "batchsz": 500,
         "seed": 0,
         "epoch": 50,
         "eval_frequency": 5,
...
 {
     "model": {
-        "load_path": "",
+        "load_path": "convlab/policy/ppo/pretrained_models/mle",
         "use_pretrained_initialisation": false,
         "pretrained_load_path": "",
-        "batchsz": 1000,
+        "batchsz": 500,
         "seed": 0,
         "epoch": 10,
         "eval_frequency": 5,
         "process_num": 4,
         "sys_semantic_to_usr": false,
-        "num_eval_dialogues": 500
+        "num_eval_dialogues": 200
     },
     "vectorizer_sys": {
         "uncertainty_vector_mul": {
...
 {
     "model": {
-        "load_path": "convlab/policy/mle/experiments/experiment_2022-05-23-14-08-43/save/supervised",
+        "load_path": "convlab/policy/ppo/pretrained_models/mle",
         "use_pretrained_initialisation": false,
         "pretrained_load_path": "",
         "batchsz": 1000,
@@ -35,7 +35,7 @@
     "TUSPolicy": {
         "class_path": "convlab.policy.tus.unify.TUS.UserPolicy",
         "ini_params": {
-            "config": "convlab/policy/tus/unify/exp/all.json"
+            "config": "convlab/policy/tus/unify/exp/multiwoz.json"
         }
     }
 },
...
@@ -134,7 +134,7 @@ class UserActionPolicy(Policy):
             goal = Goal(goal_list)
         else:
             goal = ABUS_Goal(self.goal_gen)
-        self.raw_gaol = goal.domain_goals
+        self.raw_goal = goal.domain_goals
         goal_list = old_goal2list(goal.domain_goals)
         goal = Goal(goal_list)
@@ -411,7 +411,8 @@ class UserPolicy(Policy):
             self.config = json.load(open(config))
         else:
             self.config = config
-        self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz'
+        self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}'
+        print("model_dir", self.config['model_dir'])
         if not os.path.exists(self.config["model_dir"]):
             # os.mkdir(self.config["model_dir"])
             model_downloader(os.path.dirname(self.config["model_dir"]),
...