diff --git a/convlab/nlg/scgpt/multiwoz/__init__.py b/convlab/nlg/scgpt/multiwoz/__init__.py
deleted file mode 100644
index 88c7ca2e9735ded913e007644fc8b46fd78535f6..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Sat Apr 4 21:43:42 2020
-
-@author: truthless
-"""
-
-from convlab.nlg.scgpt.multiwoz.scgpt import SCGPT
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/multiwoz/preprocess.py b/convlab/nlg/scgpt/multiwoz/preprocess.py
deleted file mode 100644
index 3f5cf70f664895cd7f1c7d67f6c31903d157809a..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/preprocess.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Sep 14 11:38:53 2020
-@author: truthless
-"""
-
-import os
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-from convlab.nlg.scgpt.utils import dict2dict, dict2seq
-import zipfile
-
-def read_zipped_json(filepath, filename):
-    print("zip file path = ", filepath)
-    archive = zipfile.ZipFile(filepath, 'r')
-    return json.load(archive.open(filename))
-
-def init_domain():
-    return {'Attraction':False,
-            'Hospital':False,
-            'Hotel':False,
-            'Police':False,
-            'Restaurant':False,
-            'Taxi':False,
-            'Train':False}
-
-def write_file(name, data, role='usr'):
-    with open(f'{name}.txt', 'w', encoding='utf-8') as f:
-        for ID in data:
-            sess = data[ID]
-            sess_domains = init_domain()
-            for turn in sess:
-                if role == 'usr':
-                    if not turn['usr_da']:
-                        continue
-                    turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
-                    da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
-                    domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
-                elif role == 'sys':
-                    if not turn['sys_da']:
-                        continue
-                    turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
-                    da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
-                    domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
-                else:
-                    raise NameError('Invalid Role: Select usr/sys.')
-                for domain in domains:
-                    if domain not in ['general', 'Booking'] and not sess_domains[domain]:
-                        da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
-                        sess_domains[domain] = True
-
-                if role == 'usr':
-                    da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
-                elif role == 'sys':
-                    da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
-                f.write(f'{da_seq} & {da_uttr}\n')
-
-
-if __name__ == '__main__':
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--role', type=str, default='usr')
-    args = parser.parse_args()
-
-    cur_dir = os.path.dirname(os.path.abspath(__file__))
-    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
-        cur_dir)))), 'data/multiwoz/')
-
-    keys = ['train', 'val', 'test']
-    data = {}
-    for key in keys:
-        data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
-        print('load {}, size {}'.format(key, len(data_key)))
-        data = dict(data, **data_key)
-
-    with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
-        val_list = f.read().splitlines()
-    with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
-        test_list = f.read().splitlines()
-
-    results = {}
-    results_val = {}
-    results_test = {}
-
-    for title, sess in data.items():
-        logs = sess['log']
-        turns = []
-        turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
-        current_domain = None
-        for i, diag in enumerate(logs):
-            text = diag['text']
-            da = diag['dialog_act']
-            span = diag['span_info']
-            if current_domain:
-                da = eval(str(da).replace('Booking', current_domain))
-                span = eval(str(span).replace('Booking', current_domain))
-            if i % 2 == 0:
-                turn['usr'] = text
-                turn['usr_da'] = da
-                turn['usr_span'] = span
-                turns.append(turn)
-            else:
-                turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
-                turn['sys'] = text
-                turn['sys_da'] = da
-                turn['sys_span'] = span
-            for key in da:
-                domain = key.split('-')[0]
-                if domain not in ['general', 'Booking']:
-                    current_domain = domain
-        else:
-            if args.role == 'sys':
-                turns.append(turn)
-        title = title
-        if title in val_list:
-            current = results_val
-        elif title in test_list:
-            current = results_test
-        else:
-            current = results
-        current[title] = turns
-
-    results = eval(str(results).replace(" n't", " not"))
-    results_val = eval(str(results_val).replace(" n't", " not"))
-    results_test = eval(str(results_test).replace(" n't", " not"))
-
-    if not os.path.exists(os.path.join(cur_dir,'data')):
-        os.makedirs(os.path.join(cur_dir, 'data'))
-    write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
-    write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
diff --git a/convlab/nlg/scgpt/multiwoz/run.py b/convlab/nlg/scgpt/multiwoz/run.py
deleted file mode 100644
index e583fe72fb26cd4262a6c4aae7776aabee49293b..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/run.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
-
-import argparse
-import logging
-from tqdm import trange
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-
-import sys
-
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
-
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
-
-
-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
-logger = logging.getLogger(__name__)
-
-MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
-
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
-
-MODEL_CLASSES = {
-    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
-    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
-    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
-    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
-    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
-}
-
-# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
-# in https://github.com/rusiaaman/XLNet-gen#methodology
-# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
-(except for Alexei and Maria) are discovered.
-The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
-remainder of the story. 1883 Western Siberia,
-a young Grigori Rasputin is asked by his father and a group of men to perform magic.
-Rasputin has a vision and denounces one of the men as a horse thief. Although his
-father initially slaps him for making such an accusation, Rasputin watches as the
-man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
-the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
-
-
-def set_seed(args):
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_type", default=None, type=str, required=True,
-                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
-    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
-                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
-    parser.add_argument("--prompt", type=str, default="")
-    parser.add_argument("--padding_text", type=str, default="")
-    parser.add_argument("--length", type=int, default=40)
-    parser.add_argument("--num_samples", type=int, default=1)
-    parser.add_argument("--temperature", type=float, default=1.0,
-                        help="temperature of 0 implies greedy sampling")
-    parser.add_argument("--repetition_penalty", type=float, default=1.0,
-                        help="primarily useful for CTRL model; in that case, use 1.2")
-    parser.add_argument("--top_k", type=int, default=50)
-    parser.add_argument("--top_p", type=float, default=0.9)
-    parser.add_argument("--no_cuda", action='store_true',
-                        help="Avoid using CUDA when available")
-    parser.add_argument('--seed', type=int, default=42,
-                        help="random seed for initialization")
-    parser.add_argument('--stop_token', type=str, default=None,
-                        help="Token at which text generation is stopped")
-    parser.add_argument("--batch_size", default=1, type=int)
-    parser.add_argument('--input_file', type=str, default=None,
-                        help="file")
-    parser.add_argument('--output_file', type=str, default=None,
-                        help="file")
-
-    args = parser.parse_args()
-
-    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
-    args.n_gpu = torch.cuda.device_count()
-
-    set_seed(args)
-
-    args.model_type = args.model_type.lower()
-    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, pad_token='<PAD>', padding_side='left')
-    model = model_class.from_pretrained(args.model_name_or_path)
-    model.to(args.device)
-    model.eval()
-
-    if args.length < 0 and model.config.max_position_embeddings > 0:
-        args.length = model.config.max_position_embeddings
-    elif 0 < model.config.max_position_embeddings < args.length:
-        args.length = model.config.max_position_embeddings  # No generation bigger than model size
-    elif args.length < 0:
-        args.length = MAX_LENGTH  # avoid infinite loop
-
-    logger.info(args)
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
-
-    fin = open(args.input_file)
-    inputs = [i.strip() for i in fin]
-    output_tests = []
-    for idx in range(0, len(inputs), args.batch_size):
-        logger.info(f"PROGRESS: {int(idx/len(inputs)*100)}%")
-
-        # raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
-        raw_inputs = []
-        for i in range(idx, min(idx+args.batch_size, len(inputs))):
-            lines = inputs[i]
-            raw_text = lines.split(' & ')[0] + ' & '
-            if args.model_type in ["transfo-xl", "xlnet"]:
-                # Models with memory likes to have a long prompt for short inputs.
-                raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
-            raw_inputs.append(raw_text)
-
-        encoding_inputs = tokenizer.batch_encode_plus(raw_inputs, pad_to_max_length=True, add_special_tokens=False)
-        context_tokens = torch.LongTensor(encoding_inputs['input_ids']).to(args.device)
-        max_length = len(context_tokens[0])
-        attention_mask = torch.LongTensor(encoding_inputs['attention_mask']).to(args.device)
-        position_ids = (attention_mask.cumsum(-1) - 1)
-        position_ids.masked_fill_(attention_mask==0, 0)
-
-        if args.model_type == "ctrl":
-            if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
-                logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
-        out_ids = model.generate(
-            input_ids=context_tokens,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            num_beams=args.num_samples,
-            num_return_sequences=args.num_samples,
-            max_length=args.length,
-            temperature=args.temperature,
-            do_sample=True,
-            top_k=args.top_k,
-            top_p=args.top_p,
-            repetition_penalty=args.repetition_penalty
-        )
-        out_ids = out_ids.reshape(len(raw_inputs), args.num_samples, -1)[:, :, max_length:].tolist()
-        for j, out in enumerate(out_ids):
-            examples = [inputs[j]]
-            for o in out:
-                text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
-                text = text[: text.find(args.stop_token) if args.stop_token else None]
-                examples.append(text)
-            output_tests.append(examples)
-            # break
-        # if args.prompt:
-        #     break
-    import json
-    json.dump(output_tests, open(args.output_file,'w'), indent=2)
-    return text
-
-if __name__ == '__main__':
-    main()
diff --git a/convlab/nlg/scgpt/multiwoz/scgpt.py b/convlab/nlg/scgpt/multiwoz/scgpt.py
deleted file mode 100644
index 2c4b10217d43e85198cefae488d7828d98c65f6e..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/scgpt.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import torch
-import numpy as np
-import os
-import zipfile
-from copy import deepcopy
-
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from convlab.nlg.scgpt.utils import tuple2seq
-from convlab.nlg.scgpt.decode import set_seed, sample_sequence
-from convlab.nlg.nlg import NLG
-from convlab.util.file_util import cached_path
-
-MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-
-class SCGPT(NLG):
-
-    def __init__(self, model_file=None,
-                 use_cuda=True, is_user=False):
-        # If no filename is mentioned then set to default
-        if not model_file:
-            if is_user:
-                model_file = 'https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'
-            else:
-                model_file = 'https://zenodo.org/record/5767426/files/neo_scgpt_system.zip'
-
-        # Load from file/url
-        model_dir = os.path.dirname(os.path.abspath(__file__))
-        if not os.path.isfile(model_file):
-            model_file = cached_path(model_file)
-        if not os.path.isdir(model_file):
-            archive = zipfile.ZipFile(model_file, 'r')
-            archive.extractall(model_dir)
-            # Get model directory
-            model_file = archive.filelist[0].filename.replace('/', '')
-            self.model_name_or_path = os.path.join(model_dir, model_file)
-        else:
-            self.model_name_or_path = model_file
-
-        self.length = 50
-        self.num_samples = 5
-        self.temperature = 1.0
-        self.repetition_penalty = 1.0
-        self.top_k = 50
-        self.top_p = 0.9
-        self.seed = 42
-        self.is_user = is_user
-        self.stop_token = '<|endoftext|>'
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu")
-        set_seed(self.seed, torch.cuda.device_count())
-
-        model_class, tokenizer_class = GPT2LMHeadModel, GPT2Tokenizer
-        self.tokenizer = tokenizer_class.from_pretrained(self.model_name_or_path)
-        self.model = model_class.from_pretrained(self.model_name_or_path)
-        self.model.to(self.device)
-        self.model.eval()
-
-        if self.length < 0 and self.model.config.max_position_embeddings > 0:
-            self.length = self.model.config.max_position_embeddings
-        elif 0 < self.model.config.max_position_embeddings < self.length:
-            self.length = self.model.config.max_position_embeddings  # No generation bigger than model size
-        elif self.length < 0:
-            self.length = self.MAX_LENGTH  # avoid infinite loop
-
-        self.init_session()
-
-    def init_session(self):
-        self.sess_domains = {'Attraction':False,
-                             'Hospital':False,
-                             'Hotel':False,
-                             'Police':False,
-                             'Restaurant':False,
-                             'Taxi':False,
-                             'Train':False,}
-        self.cur_domain = None
-        # if not self.is_user:
-        #     self.sess_domains['Booking'] = False
-
-    def generate(self, meta):
-
-        #some actions in testing data is none
-        if not meta:
-            return 'No user action'
-
-        meta = deepcopy(meta)
-        for list_ in meta:
-            domain = list_[1]
-            if domain not in ('general', 'Booking'):
-                self.cur_domain = domain
-        for i, list_ in enumerate(meta):
-            list_ = list(list_)
-            if list_[1] == 'Booking':
-                if self.cur_domain is not None:
-                    list_[1] = self.cur_domain
-                    meta[i] = list_
-                else:
-                    print('`cur_domain` is None, but there is `Booking` in dialog action.')
-        raw_text = tuple2seq(meta)
-        domains = set([item[1] for item in meta])
-        for domain in domains:
-            if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
-                raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
-                self.sess_domains[domain] = True
-        context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
-        out = sample_sequence(
-            model=self.model,
-            context=context_tokens,
-            num_samples=self.num_samples,
-            length=self.length,
-            temperature=self.temperature,
-            top_k=self.top_k,
-            top_p=self.top_p,
-            repetition_penalty=self.repetition_penalty,
-            device=self.device,
-        )
-        out = out[:, len(context_tokens):].tolist()
-        index = np.random.choice([0,1,2,3],p=[0.4,0.3,0.2,0.1])
-        o = out[index]
-        text = self.tokenizer.decode(o, clean_up_tokenization_spaces=True)
-        text = text.split('& ')[-1]
-        text = text[: text.find(self.stop_token) if self.stop_token else None]
-
-        return text
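
Reviewer note: the removed preprocess.py wrote one training example per line as "<dialog-act sequence> & <utterance>", appending " *" after the first mention of each domain in a session; run.py split on ' & ' to recover the prompt, and scgpt.py stripped everything up to '& ' from the decoded output. A hypothetical training line for orientation (the exact linearization comes from dict2seq in convlab/nlg/scgpt/utils.py, so the shape below is an assumption, not copied from the data):

    restaurant * { inform ( food = italian ; area = centre ) } & there is an italian restaurant in the centre .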
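
Reviewer note: the batching logic removed with run.py left-pads prompts and then rebuilds position ids so the padded prefix does not shift the positions of real tokens. A minimal standalone sketch of that computation (the tensor values are made up for illustration):

    import torch

    # One left-padded row: two pad tokens followed by three real tokens.
    attention_mask = torch.LongTensor([[0, 0, 1, 1, 1]])
    position_ids = attention_mask.cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)
    # position_ids is now tensor([[0, 0, 0, 1, 2]]): pads pinned to position 0
    # and real tokens numbered from 0, matching what run.py passed to generate().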
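
Reviewer note: for anyone migrating off the removed wrapper, this is how the SCGPT class in scgpt.py above was driven; a minimal usage sketch against a checkout that still contains this module (the dialog-act quadruples are illustrative values, not taken from the diff):

    from convlab.nlg.scgpt.multiwoz import SCGPT

    # is_user=False selects the system-side checkpoint (the Zenodo URL above).
    nlg = SCGPT(model_file=None, use_cuda=True, is_user=False)
    nlg.init_session()  # reset the per-session domain markers
    # generate() takes ConvLab dialog acts as [intent, domain, slot, value] lists.
    print(nlg.generate([['Inform', 'Hotel', 'Area', 'east'],
                        ['Inform', 'Hotel', 'Price', 'cheap']]))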