Commit 3ac4a347 authored by zz-jacob
remove old codes
parent 245360f0
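# ============================================================================
# File 1 -- likely convlab/nlg/scgpt/multiwoz/__init__.py (path inferred from
# the module it re-exports): exposes the SCGPT class
# ============================================================================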
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 4 21:43:42 2020
@author: truthless
"""
from convlab.nlg.scgpt.multiwoz.scgpt import SCGPT
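# ============================================================================
# File 2 -- MultiWOZ preprocessing script (likely preprocess.py): flattens
# each dialogue into 'dialogue-act sequence & utterance' lines for SC-GPT
# fine-tuning
# ============================================================================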
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 14 11:38:53 2020
@author: truthless
"""
import os
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from convlab.nlg.scgpt.utils import dict2dict, dict2seq
import zipfile
def read_zipped_json(filepath, filename):
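    """Load and parse a JSON file stored inside a zip archive."""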
print("zip file path = ", filepath)
archive = zipfile.ZipFile(filepath, 'r')
return json.load(archive.open(filename))
def init_domain():
return {'Attraction':False,
'Hospital':False,
'Hotel':False,
'Police':False,
'Restaurant':False,
'Taxi':False,
'Train':False}
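
# Write one line per turn: the linearised dialogue act, then ' & ', then the
# utterance. The first mention of each domain in a session is marked with a
# trailing '*' (sess_domains tracks which domains have already appeared).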
def write_file(name, data, role='usr'):
with open(f'{name}.txt', 'w', encoding='utf-8') as f:
for ID in data:
sess = data[ID]
sess_domains = init_domain()
for turn in sess:
if role == 'usr':
if not turn['usr_da']:
continue
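                    # MultiWOZ labels a few acts with the 'Bus' domain; fold
                    # them into 'Train' by round-tripping the nested dict
                    # through str()/eval() (same trick in the sys branch below).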
turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
elif role == 'sys':
if not turn['sys_da']:
continue
turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
else:
                    raise ValueError("Invalid role: expected 'usr' or 'sys'.")
for domain in domains:
if domain not in ['general', 'Booking'] and not sess_domains[domain]:
da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
sess_domains[domain] = True
if role == 'usr':
da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
elif role == 'sys':
da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
f.write(f'{da_seq} & {da_uttr}\n')
if __name__ == '__main__':
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--role', type=str, default='usr')
args = parser.parse_args()
cur_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
cur_dir)))), 'data/multiwoz/')
keys = ['train', 'val', 'test']
data = {}
for key in keys:
data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
print('load {}, size {}'.format(key, len(data_key)))
data = dict(data, **data_key)
with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
val_list = f.read().splitlines()
with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
test_list = f.read().splitlines()
results = {}
results_val = {}
results_test = {}
for title, sess in data.items():
logs = sess['log']
turns = []
turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
current_domain = None
for i, diag in enumerate(logs):
text = diag['text']
da = diag['dialog_act']
span = diag['span_info']
if current_domain:
da = eval(str(da).replace('Booking', current_domain))
span = eval(str(span).replace('Booking', current_domain))
if i % 2 == 0:
turn['usr'] = text
turn['usr_da'] = da
turn['usr_span'] = span
turns.append(turn)
else:
turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
turn['sys'] = text
turn['sys_da'] = da
turn['sys_span'] = span
for key in da:
domain = key.split('-')[0]
if domain not in ['general', 'Booking']:
current_domain = domain
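        # for-else on the log loop: runs once the dialogue is exhausted,
        # appending the trailing system turn when building system-side data.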
else:
if args.role == 'sys':
turns.append(turn)
if title in val_list:
current = results_val
elif title in test_list:
current = results_test
else:
current = results
current[title] = turns
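    # Normalise tokenised contractions ("do n't" -> "do not") in every split,
    # again via the str()/eval() round trip.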
results = eval(str(results).replace(" n't", " not"))
results_val = eval(str(results_val).replace(" n't", " not"))
results_test = eval(str(results_test).replace(" n't", " not"))
if not os.path.exists(os.path.join(cur_dir,'data')):
os.makedirs(os.path.join(cur_dir, 'data'))
write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
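# ============================================================================
# File 3 -- batch decoding script (apparently adapted from the Hugging Face
# run_generation example): generates a response for every dialogue-act line
# in an input file
# ============================================================================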
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
from tqdm import trange
import torch
import torch.nn.functional as F
import numpy as np
import sys
import json
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
MAX_LENGTH = 10000  # hard-coded cap to avoid an infinite generation loop
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
MODEL_CLASSES = {
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
'xlm': (XLMWithLMHeadModel, XLMTokenizer),
}
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--length", type=int, default=40)
parser.add_argument("--num_samples", type=int, default=1)
parser.add_argument("--temperature", type=float, default=1.0,
help="temperature of 0 implies greedy sampling")
parser.add_argument("--repetition_penalty", type=float, default=1.0,
help="primarily useful for CTRL model; in that case, use 1.2")
parser.add_argument("--top_k", type=int, default=50)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--stop_token', type=str, default=None,
help="Token at which text generation is stopped")
parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument('--input_file', type=str, default=None,
                        help="input file; one 'dialogue-act sequence & reference' pair per line")
    parser.add_argument('--output_file', type=str, default=None,
                        help="output JSON file for the generated responses")
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
set_seed(args)
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, pad_token='<PAD>', padding_side='left')
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
model.eval()
if args.length < 0 and model.config.max_position_embeddings > 0:
args.length = model.config.max_position_embeddings
elif 0 < model.config.max_position_embeddings < args.length:
args.length = model.config.max_position_embeddings # No generation bigger than model size
elif args.length < 0:
args.length = MAX_LENGTH # avoid infinite loop
logger.info(args)
if args.model_type in ["ctrl"]:
if args.temperature > 0.7:
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
    with open(args.input_file) as fin:
        inputs = [line.strip() for line in fin]
output_tests = []
for idx in range(0, len(inputs), args.batch_size):
logger.info(f"PROGRESS: {int(idx/len(inputs)*100)}%")
# raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
raw_inputs = []
for i in range(idx, min(idx+args.batch_size, len(inputs))):
lines = inputs[i]
raw_text = lines.split(' & ')[0] + ' & '
if args.model_type in ["transfo-xl", "xlnet"]:
                # Models with memory like to have a long prompt for short inputs.
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
raw_inputs.append(raw_text)
encoding_inputs = tokenizer.batch_encode_plus(raw_inputs, pad_to_max_length=True, add_special_tokens=False)
context_tokens = torch.LongTensor(encoding_inputs['input_ids']).to(args.device)
max_length = len(context_tokens[0])
attention_mask = torch.LongTensor(encoding_inputs['attention_mask']).to(args.device)
position_ids = (attention_mask.cumsum(-1) - 1)
position_ids.masked_fill_(attention_mask==0, 0)
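        # With left padding, the real tokens must still receive positions
        # 0..n-1: the cumulative sum of the attention mask yields exactly that,
        # and masked_fill_ zeroes the positions of the padding tokens.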
if args.model_type == "ctrl":
            # context_tokens is a 2-D tensor here, so compare the first token id
            # explicitly (the upstream script indexed into a flat list of ids).
            if not any(int(context_tokens[0][0]) == x for x in tokenizer.control_codes.values()):
                logger.info("WARNING! You are not starting your generation from a control code, so you won't get good results.")
out_ids = model.generate(
input_ids=context_tokens,
attention_mask=attention_mask,
position_ids=position_ids,
num_beams=args.num_samples,
num_return_sequences=args.num_samples,
max_length=args.length,
temperature=args.temperature,
do_sample=True,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty
)
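        # generate() returns (batch_size * num_samples) sequences; regroup them
        # per input and strip the first max_length ids (the padded prompt).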
out_ids = out_ids.reshape(len(raw_inputs), args.num_samples, -1)[:, :, max_length:].tolist()
for j, out in enumerate(out_ids):
examples = [inputs[j]]
for o in out:
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
text = text[: text.find(args.stop_token) if args.stop_token else None]
examples.append(text)
output_tests.append(examples)
# break
# if args.prompt:
# break
    with open(args.output_file, 'w') as fout:
        json.dump(output_tests, fout, indent=2)
return text
if __name__ == '__main__':
main()
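# ============================================================================
# File 4 -- convlab/nlg/scgpt/multiwoz/scgpt.py: the SCGPT NLG wrapper class
# (path matches the import in File 1)
# ============================================================================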
import torch
import numpy as np
import os
import zipfile
from copy import deepcopy
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from convlab.nlg.scgpt.utils import tuple2seq
from convlab.nlg.scgpt.decode import set_seed, sample_sequence
from convlab.nlg.nlg import NLG
from convlab.util.file_util import cached_path
MAX_LENGTH = 10000  # hard-coded cap to avoid an infinite generation loop
class SCGPT(NLG):
def __init__(self, model_file=None,
use_cuda=True, is_user=False):
        # If no model file is given, fall back to the default checkpoint
if not model_file:
if is_user:
model_file = 'https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'
else:
model_file = 'https://zenodo.org/record/5767426/files/neo_scgpt_system.zip'
# Load from file/url
model_dir = os.path.dirname(os.path.abspath(__file__))
if not os.path.isfile(model_file):
model_file = cached_path(model_file)
if not os.path.isdir(model_file):
archive = zipfile.ZipFile(model_file, 'r')
archive.extractall(model_dir)
# Get model directory
model_file = archive.filelist[0].filename.replace('/', '')
self.model_name_or_path = os.path.join(model_dir, model_file)
else:
self.model_name_or_path = model_file
self.length = 50
self.num_samples = 5
self.temperature = 1.0
self.repetition_penalty = 1.0
self.top_k = 50
self.top_p = 0.9
self.seed = 42
self.is_user = is_user
self.stop_token = '<|endoftext|>'
self.device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu")
set_seed(self.seed, torch.cuda.device_count())
model_class, tokenizer_class = GPT2LMHeadModel, GPT2Tokenizer
self.tokenizer = tokenizer_class.from_pretrained(self.model_name_or_path)
self.model = model_class.from_pretrained(self.model_name_or_path)
self.model.to(self.device)
self.model.eval()
if self.length < 0 and self.model.config.max_position_embeddings > 0:
self.length = self.model.config.max_position_embeddings
elif 0 < self.model.config.max_position_embeddings < self.length:
self.length = self.model.config.max_position_embeddings # No generation bigger than model size
elif self.length < 0:
            self.length = MAX_LENGTH  # module-level constant; avoid an infinite loop
self.init_session()
def init_session(self):
self.sess_domains = {'Attraction':False,
'Hospital':False,
'Hotel':False,
'Police':False,
'Restaurant':False,
'Taxi':False,
'Train':False,}
self.cur_domain = None
# if not self.is_user:
# self.sess_domains['Booking'] = False
def generate(self, meta):
        # some dialogue acts in the test data are empty
if not meta:
return 'No user action'
meta = deepcopy(meta)
for list_ in meta:
domain = list_[1]
if domain not in ('general', 'Booking'):
self.cur_domain = domain
for i, list_ in enumerate(meta):
list_ = list(list_)
if list_[1] == 'Booking':
if self.cur_domain is not None:
list_[1] = self.cur_domain
meta[i] = list_
else:
print('`cur_domain` is None, but there is `Booking` in dialog action.')
raw_text = tuple2seq(meta)
domains = set([item[1] for item in meta])
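        # Mirror the preprocessing convention: mark the first mention of each
        # domain in this session with a trailing '*'.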
for domain in domains:
if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
self.sess_domains[domain] = True
context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
out = sample_sequence(
model=self.model,
context=context_tokens,
num_samples=self.num_samples,
length=self.length,
temperature=self.temperature,
top_k=self.top_k,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
device=self.device,
)
out = out[:, len(context_tokens):].tolist()
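        # Heuristically prefer earlier candidates: sample one of the first four
        # generated sequences with decreasing probability.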
        index = np.random.choice([0, 1, 2, 3], p=[0.4, 0.3, 0.2, 0.1])
o = out[index]
text = self.tokenizer.decode(o, clean_up_tokenization_spaces=True)
text = text.split('& ')[-1]
text = text[: text.find(self.stop_token) if self.stop_token else None]
return text
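
# Minimal usage sketch (illustrative, not part of the module: assumes the
# default checkpoint can be downloaded; ConvLab dialogue acts are
# [intent, domain, slot, value] tuples):
#
#   nlg = SCGPT(use_cuda=False, is_user=False)
#   print(nlg.generate([['Inform', 'Hotel', 'Name', 'Acorn Guest House']]))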