Commit e4872bae authored by zqwerty

pre-train t5 rg

parent a9443416
 import json
+import json_lines
 import os
 import random
 from tqdm import tqdm
@@ -7,85 +8,19 @@ from nltk import sent_tokenize
 def main(args):
     random.seed(42)
     os.makedirs(args.output_dir, exist_ok=True)
-    if args.mode == 'multitask':
-        dataset_name = args.output_dir.split('/')[-1]
-        for data_split in ['validation', 'train']:
-            with open(os.path.join(args.output_dir, f"{data_split}.json"), 'w', encoding='utf-8') as fout:
-                for task_name in ['rg', 'key2gen', 'key2gen_noisy']:
-                    with open(os.path.join(args.input_dir, task_name, 'gpt', dataset_name, f"{data_split}.json")) as fin:
-                        for line in fin:
-                            item = json.loads(line)
-                            fout.write(json.dumps({'source': item['source'], 'target': item['target']}, ensure_ascii=False)+'\n')
-        return
-    if args.mode == 'sen2gen':
-        generated_filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if f.startswith('gen_')]
-        original_filenames = [f[4:] for f in generated_filenames]
-        for ori_f, gen_f in zip(original_filenames, generated_filenames):
-            fori = open(os.path.join(args.input_dir, ori_f))
-            fgen = open(os.path.join(args.input_dir, gen_f))
-            fout = open(os.path.join(args.output_dir, f"{ori_f.split('_')[0]}.json"), 'w', encoding='utf-8')
-            for ori_line, gen_line in zip(fori, fgen):
-                ori_item = json.loads(ori_line)
-                gen_item = json.loads(gen_line)
-                context = ori_item['source'][ori_item['source'].index('context:\n\n'):]
-                gen_sen = gen_item['predictions']
-                ori_item['source'] = f'generate a response: grounded knowledge: | {gen_sen} | {context}'
-                ori_item['gen_sen'] = gen_sen
-                fout.write(json.dumps(ori_item, ensure_ascii=False)+'\n')
-        return
-    if args.mode == 'sen2gen_noisy':
-        def gen_samples(dialog_samples):
-            turn_gen_sens = [sent_tokenize(item['gen_sen']) for item in dialog_samples]
-            for i, sample in enumerate(dialog_samples):
-                possible_sens_turns = turn_gen_sens[i][:]
-                num_possible_sens_turns = min(random.randint(1, 5), len(turn_gen_sens) - 1)
-                for turn_sens in random.sample(turn_gen_sens[:i] + turn_gen_sens[i+1:], num_possible_sens_turns):
-                    possible_sens_turns.extend(turn_sens)
-                random.shuffle(possible_sens_turns)
-                possible_sens = ' | '.join(possible_sens_turns)
-                context = sample['source'][sample['source'].index('context:\n\n'):]
-                sample['source'] = f'generate a response: all knowledge: | {possible_sens} | {context}'
-                yield sample
-        for ori_f in [f for (_, _, fs) in os.walk(args.input_dir) for f in fs]:
-            fori = open(os.path.join(args.input_dir, ori_f))
-            fout = open(os.path.join(args.output_dir, ori_f), 'w', encoding='utf-8')
-            dialog = []
-            prev_num_turns = 0
-            for line in fori:
-                item = json.loads(line)
-                num_turns = item['source'].count('\n')
-                if len(dialog) == 0 or num_turns < prev_num_turns:
-                    # process a dialog with augmented responses
-                    for sample in gen_samples(dialog):
-                        fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
-                    # next dialog
-                    dialog = [item]
-                else:
-                    # next turn
-                    dialog.append(item)
-                prev_num_turns = num_turns
-            for sample in gen_samples(dialog):
-                fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
-        return
-    filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
+    filenames = [os.path.join(args.input_dir, f) for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
     for filename in filenames:
-        data = json.load(open(os.path.join(args.input_dir, filename)))
-        # Splitting the data into multiple pieces.
-        if args.n_splits > 1:
-            len_data_pieces = len(data)//args.n_splits
-            fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}_split_{i}-of-{args.n_splits}.json"), 'w', encoding='utf-8') for i in range(args.n_splits)]
-            random.shuffle(data)
-        else:
-            len_data_pieces = len(data)
-            fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')]
-        for i, fout in enumerate(fouts):
-            for dial in tqdm(data[i*len_data_pieces:(i+1)*len_data_pieces]):
+        dataset_name = filename.split('/')[-2]
+        data_split = filename.split('/')[-1].split('_')[-1].split('.')[0]
+        output_file = os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[-1]}")
+        print(f'processing {dataset_name}: {filename} => {output_file}')
+        cnt = 0
+        with open(filename, 'rb') as fin, open(output_file, 'w', encoding='utf-8') as fout:
+            for dial in tqdm(json_lines.reader(fin)):
                 context = []
                 turns_keywords = [turn['keywords'] for turn in dial]
                 for i, turn in enumerate(dial):
-                    if 'wikidialog' in filename:
+                    if dataset_name == 'wikidialog':
                         # skip user turns that generated by T5 in wikidialog
                         speaker = 'user' if i % 2 == 1 else 'system'
                     else:
@@ -93,18 +28,16 @@ def main(args):
                     utt = turn['utterance']
                     context_seq = '\n'.join([f"{turn['speaker']}: {turn['utt']}" for turn in context]+[f'{speaker}: '])
                     context.append({'speaker': speaker, 'utt': utt})
-                    if i == 0 or ('wikidialog' in filename and speaker == 'user'):
+                    if i == 0 or (dataset_name == 'wikidialog' and speaker == 'user'):
                         continue
                     if args.mode == 'rg':
-                        input_seq = f'generate a response: context:\n\n{context_seq}'
+                        input_seq = f'generate a response: all knowledge: | | context:\n\n{context_seq}'
                         fout.write(json.dumps({
-                            'dataset': filename.split('/')[-1].split('_')[0],
+                            'dataset': dataset_name,
                             'source': input_seq,
-                            'target': utt}, ensure_ascii=False)+'\n')
+                            'target': utt
+                        }, ensure_ascii=False)+'\n')
                         continue
                     if len(turn['keywords']) == 0 or max([len(k) for k in turn['keywords']])>10:
                         continue
                     if args.mode == 'key2gen':
@@ -113,11 +46,14 @@ def main(args):
                             random.shuffle(turn['keywords'][j])
                         keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
                         input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
-                        fout.write(json.dumps({
-                            'dataset': filename.split('/')[-1].split('_')[0],
+                        json2dump = {
+                            'dataset': dataset_name,
                             'source': input_seq,
-                            'target': utt,
-                            'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
+                            'target': utt
+                        }
+                        if data_split == 'validation':
+                            json2dump.update({'keywords': turn['keywords']})
+                        fout.write(json.dumps(json2dump, ensure_ascii=False)+'\n')
                         continue
                     if args.mode == 'key2gen_noisy':
@@ -128,12 +64,14 @@ def main(args):
                         random.shuffle(possible_keywords_sents)
                         possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
                         input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
-                        fout.write(json.dumps({
-                            'dataset': filename.split('/')[-1].split('_')[0],
+                        json2dump = {
+                            'dataset': dataset_name,
                             'source': input_seq,
-                            'target': utt,
-                            'keywords': turn['keywords'],
-                            'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
+                            'target': utt
+                        }
+                        if data_split == 'validation':
+                            json2dump.update({'keywords': turn['keywords'], 'all_keywords': possible_keywords_sents})
+                        fout.write(json.dumps(json2dump, ensure_ascii=False)+'\n')
                         continue
@@ -142,8 +80,7 @@ if __name__ == '__main__':
     parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
    parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
     parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
-    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy', 'sen2gen', 'sen2gen_noisy', 'multitask'], help='which task to perform')
-    parser.add_argument('--n_splits', '-n', type=int, default=1, help='split the data into multiple pieces')
+    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy'], help='which task to perform')
     args = parser.parse_args()
     print(args)
     main(args)
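
For reference, here is a minimal sketch (not part of the commit; the dialog turn and keywords below are invented) of the JSON-lines records the updated script writes: in 'rg' mode the knowledge slot before the context is left empty, while 'key2gen' fills it with the turn's keywords, joined by ' : ' within a sentence and ' | ' across sentences ('key2gen_noisy' uses the same layout under the 'all knowledge' prefix and appears to mix in keywords from other turns, though that sampling sits in a collapsed hunk).

import json

# Illustrative only: a one-turn context ending with the system speaker tag,
# matching the context_seq format built by the script above.
context_seq = '\n'.join(['user: any plans for the weekend?', 'system: '])

rg_record = {
    'dataset': 'dailydialog',
    'source': f'generate a response: all knowledge: | | context:\n\n{context_seq}',
    'target': 'I am thinking about a short hiking trip.',
}
key2gen_record = {
    'dataset': 'dailydialog',
    'source': f'generate a response: grounded knowledge: | hiking : trip | context:\n\n{context_seq}',
    'target': 'I am thinking about a short hiking trip.',
}
for record in (rg_record, key2gen_record):
    print(json.dumps(record, ensure_ascii=False))
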
 # generate data for response generation, key2gen, key2gen_noisy
-for task_name in rg
+for task_name in rg key2gen key2gen_noisy
 do
 dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
 names=$(echo ${dataset_name} | tr "+" "\n")
@@ -8,8 +8,7 @@ do
 mkdir -p ${data_dir}
 train_file="${data_dir}/train.json"
 validation_file="${data_dir}/validation.json"
-test_file="${data_dir}/test.json"
-rm ${train_file} ${validation_file} ${test_file}
+rm ${train_file} ${validation_file}
 for name in ${names}
 do
 echo "preprocessing ${name}"
@@ -17,60 +16,6 @@ do
 if [ "${name}" != "${dataset_name}" ]; then
 cat "data/${task_name}/${model_type}/${name}/train.json" >> ${train_file}
 cat "data/${task_name}/${model_type}/${name}/validation.json" >> ${validation_file}
-cat "data/${task_name}/${model_type}/${name}/test.json" >> ${test_file}
 fi
 done
 done
\ No newline at end of file
-# # generate data for sentence grounded generation
-# task_name="key2gen"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# n_splits=2
-# for ((n=0;n<${n_splits};n++))
-# do
-# rm ${data_dir}/train_split_${n}-of-${n_splits}.json ${data_dir}/validation_split_${n}-of-${n_splits}.json ${data_dir}/test_split_${n}-of-${n_splits}.json
-# done
-# for name in ${names}
-# do
-# echo "preprocessing ${name}"
-# python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} -m ${task_name} -n ${n_splits}
-# if [ "${name}" != "${dataset_name}" ]; then
-# for ((n=0;n<${n_splits};n++))
-# do
-# cat "data/${task_name}/gpt/${name}/train_split_${n}-of-${n_splits}.json" >> "${data_dir}/train_split_${n}-of-${n_splits}.json"
-# cat "data/${task_name}/gpt/${name}/validation_split_${n}-of-${n_splits}.json" >> "${data_dir}/validation_split_${n}-of-${n_splits}.json"
-# cat "data/${task_name}/gpt/${name}/test_split_${n}-of-${n_splits}.json" >> "${data_dir}/test_split_${n}-of-${n_splits}.json"
-# done
-# fi
-# done
-# # merge generated data with original data
-# task_name="sen2gen"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# python gen_pretraining_data.py -i data/key2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
-# # generate sen2gen_noisy data with original data
-# task_name="sen2gen_noisy"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# names=$(echo ${dataset_name} | tr "+" "\n")
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# python gen_pretraining_data.py -i data/sen2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
-# merge data for multitask training
-# task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
-# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
-# model_type="gpt"
-# data_dir=data/${task_name}/${model_type}/${dataset_name}
-# mkdir -p ${data_dir}
-# python gen_pretraining_data.py -i data/ -o data/${task_name}/${model_type}/${dataset_name} -m multitask
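
For illustration, the 'cat ... >> ${train_file}' loop earlier in this script builds each combined corpus by concatenating the per-dataset JSON-lines files. A minimal Python sketch of that merge step (the dataset names and paths here are placeholders following the script's data/${task_name}/${model_type}/${name} layout):

import os

task_name, model_type = 'rg', 'gpt'
names = ['dailydialog', 'metalwoz']  # placeholder subset of the full dataset list
merged_dir = f'data/{task_name}/{model_type}/{"+".join(names)}'
os.makedirs(merged_dir, exist_ok=True)
with open(os.path.join(merged_dir, 'train.json'), 'w', encoding='utf-8') as fout:
    for name in names:
        part = f'data/{task_name}/{model_type}/{name}/train.json'
        if not os.path.exists(part):
            continue  # skip datasets that have not been preprocessed yet
        with open(part, encoding='utf-8') as fin:
            # JSON-lines files can be concatenated directly, like the shell `cat >>`
            fout.write(fin.read())
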
n_gpus=2
master_port=$1
task_name="key2gen"
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
model_type="gpt"
split_id=$2
n_splits=$3
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}/gen"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
# train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
# validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
let infer_split_id=($split_id+1)%$n_splits
test_file="${data_dir}/validation_split_${infer_split_id}-of-${n_splits}.json"
source_column="source"
target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=4
lr=1e-3
num_train_epochs=10
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_predict \
--predict_with_generate \
--do_sample \
--top_p 0.9 \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 16 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
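
In the decoding script above, 'let infer_split_id=($split_id+1)%$n_splits' points each split's model at the next split's validation file, presumably so a model generates for dialogs outside its own training split. The same round-robin mapping, sketched in Python with an illustrative split count:

# Round-robin mapping from a model's training split to the split it decodes.
n_splits = 2  # illustrative; the script takes this as its third argument
for split_id in range(n_splits):
    infer_split_id = (split_id + 1) % n_splits
    print(f'split {split_id} model -> validation_split_{infer_split_id}-of-{n_splits}.json')
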
@@ -127,7 +127,7 @@ def main(args):
     fin = open(word_loss_file, 'rb')
     fout = open(args.output_file, 'w', encoding='utf-8')
-    for item in json_lines.reader(fin):
+    for item in tqdm(json_lines.reader(fin)):
         words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
         losses = [np.mean(loss) for loss in item['losses']]
         dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
...
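
For context, a hedged sketch of the per-word loss averaging performed in the loop above, using an invented record (the real field layout may differ); keywords_filter itself is defined outside this hunk and simply receives these words and averaged losses:

import numpy as np

# Hypothetical JSON-lines record: each entry of 'words' is assumed to hold the
# sub-tokens of one word, and 'losses' the matching per-token LM losses.
item = {
    'words': [['week', 'end'], ['hiking']],
    'losses': [[2.5, 1.5], [3.0]],
}
# Mirrors the loop body above: average sub-token losses to get one loss per word.
losses = [float(np.mean(loss)) for loss in item['losses']]
print(losses)  # [2.0, 3.0]
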
set -e
n_gpus=2
master_port=$1
task_name="key2gen"
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
model_type="gpt"
split_id=$2
n_splits=$3
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
test_file="${data_dir}/test_split_${split_id}-of-${n_splits}.json"
source_column="source"
target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=4
num_workers=16
lr=1e-3
num_train_epochs=1
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--load_best_model_at_end \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers ${num_workers} \
--dataloader_num_workers ${num_workers} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
 set -e
-n_gpus=2
+n_gpus=8
 master_port=23456
 task_name="rg"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
 cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
-validation_file="${data_dir}/validation.json"
-test_file="${data_dir}/test.json"
 source_column="source"
 target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
 model_name_or_path="t5-small"
-per_device_train_batch_size=128
+per_device_train_batch_size=64
 per_device_eval_batch_size=128
-gradient_accumulation_steps=4
+gradient_accumulation_steps=1
+num_workers=16
 lr=1e-3
-num_train_epochs=3
+num_train_epochs=1
 python -m torch.distributed.launch --master_port ${master_port} \
 --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
 --task_name ${task_name} \
 --train_file ${train_file} \
---validation_file ${validation_file} \
---test_file ${test_file} \
 --source_column ${source_column} \
 --target_column ${target_column} \
 --max_source_length ${max_source_length} \
@@ -36,21 +33,18 @@ python -m torch.distributed.launch --master_port ${master_port} \
 --truncation_side ${truncation_side} \
 --model_name_or_path ${model_name_or_path} \
 --do_train \
---do_eval \
---do_predict \
---save_strategy epoch \
---evaluation_strategy epoch \
---load_best_model_at_end \
---prediction_loss_only \
+--save_steps 5000 \
+--save_total_limit 3 \
 --cache_dir ${cache_dir} \
 --output_dir ${output_dir} \
 --logging_dir ${logging_dir} \
---overwrite_output_dir \
---preprocessing_num_workers 4 \
+--preprocessing_num_workers ${num_workers} \
+--dataloader_num_workers ${num_workers} \
 --per_device_train_batch_size ${per_device_train_batch_size} \
 --per_device_eval_batch_size ${per_device_eval_batch_size} \
 --gradient_accumulation_steps ${gradient_accumulation_steps} \
 --learning_rate ${lr} \
 --num_train_epochs ${num_train_epochs} \
---adafactor \
+--optim adafactor \
+--lr_scheduler_type constant \
 --gradient_checkpointing