Commit e4872bae authored by zqwerty

pre-train t5 rg

parent a9443416
import json
import json_lines
import os
import random
from tqdm import tqdm
@@ -7,85 +8,19 @@ from nltk import sent_tokenize
def main(args):
random.seed(42)
os.makedirs(args.output_dir, exist_ok=True)
if args.mode == 'multitask':
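# multitask: merge the jsonlines already produced for the 'rg', 'key2gen' and
# 'key2gen_noisy' tasks into a single file per split, keeping only 'source' and 'target'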
dataset_name = args.output_dir.split('/')[-1]
for data_split in ['validation', 'train']:
with open(os.path.join(args.output_dir, f"{data_split}.json"), 'w', encoding='utf-8') as fout:
for task_name in ['rg', 'key2gen', 'key2gen_noisy']:
with open(os.path.join(args.input_dir, task_name, 'gpt', dataset_name, f"{data_split}.json")) as fin:
for line in fin:
item = json.loads(line)
fout.write(json.dumps({'source': item['source'], 'target': item['target']}, ensure_ascii=False)+'\n')
return
if args.mode == 'sen2gen':
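# sen2gen: pair each original file with its model-generated counterpart ('gen_' prefix)
# and use the generated sentences as grounded knowledge in a rebuilt source sequence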
generated_filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if f.startswith('gen_')]
original_filenames = [f[4:] for f in generated_filenames]
for ori_f, gen_f in zip(original_filenames, generated_filenames):
fori = open(os.path.join(args.input_dir, ori_f))
fgen = open(os.path.join(args.input_dir, gen_f))
fout = open(os.path.join(args.output_dir, f"{ori_f.split('_')[0]}.json"), 'w', encoding='utf-8')
for ori_line, gen_line in zip(fori, fgen):
ori_item = json.loads(ori_line)
gen_item = json.loads(gen_line)
context = ori_item['source'][ori_item['source'].index('context:\n\n'):]
gen_sen = gen_item['predictions']
ori_item['source'] = f'generate a response: grounded knowledge: | {gen_sen} | {context}'
ori_item['gen_sen'] = gen_sen
fout.write(json.dumps(ori_item, ensure_ascii=False)+'\n')
return
if args.mode == 'sen2gen_noisy':
def gen_samples(dialog_samples):
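# for each turn, mix its own generated sentences with those of 1-5 other turns from
# the same dialog, shuffle them, and present the mixture as noisy 'all knowledge'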
turn_gen_sens = [sent_tokenize(item['gen_sen']) for item in dialog_samples]
for i, sample in enumerate(dialog_samples):
possible_sens_turns = turn_gen_sens[i][:]
num_possible_sens_turns = min(random.randint(1, 5), len(turn_gen_sens) - 1)
for turn_sens in random.sample(turn_gen_sens[:i] + turn_gen_sens[i+1:], num_possible_sens_turns):
possible_sens_turns.extend(turn_sens)
random.shuffle(possible_sens_turns)
possible_sens = ' | '.join(possible_sens_turns)
context = sample['source'][sample['source'].index('context:\n\n'):]
sample['source'] = f'generate a response: all knowledge: | {possible_sens} | {context}'
yield sample
for ori_f in [f for (_, _, fs) in os.walk(args.input_dir) for f in fs]:
fori = open(os.path.join(args.input_dir, ori_f))
fout = open(os.path.join(args.output_dir, ori_f), 'w', encoding='utf-8')
dialog = []
prev_num_turns = 0
for line in fori:
item = json.loads(line)
num_turns = item['source'].count('\n')
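# a drop in the turn count signals the start of a new dialog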
if len(dialog) == 0 or num_turns < prev_num_turns:
# process a dialog with augmented responses
for sample in gen_samples(dialog):
fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
# next dialog
dialog = [item]
else:
# next turn
dialog.append(item)
prev_num_turns = num_turns
for sample in gen_samples(dialog):
fout.write(json.dumps(sample, ensure_ascii=False)+'\n')
return
filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
filenames = [os.path.join(args.input_dir, f) for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
for filename in filenames:
data = json.load(open(os.path.join(args.input_dir, filename)))
# Splitting the data into multiple pieces.
if args.n_splits > 1:
len_data_pieces = len(data)//args.n_splits
fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}_split_{i}-of-{args.n_splits}.json"), 'w', encoding='utf-8') for i in range(args.n_splits)]
random.shuffle(data)
else:
len_data_pieces = len(data)
fouts = [open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')]
for i, fout in enumerate(fouts):
for dial in tqdm(data[i*len_data_pieces:(i+1)*len_data_pieces]):
dataset_name = filename.split('/')[-2]
data_split = filename.split('/')[-1].split('_')[-1].split('.')[0]
output_file = os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[-1]}")
print(f'processing {dataset_name}: {filename} => {output_file}')
cnt = 0
with open(filename, 'rb') as fin, open(output_file, 'w', encoding='utf-8') as fout:
for dial in tqdm(json_lines.reader(fin)):
context = []
turns_keywords = [turn['keywords'] for turn in dial]
for i, turn in enumerate(dial):
if 'wikidialog' in filename:
if dataset_name == 'wikidialog':
# skip user turns that were generated by T5 in wikidialog
speaker = 'user' if i % 2 == 1 else 'system'
else:
@@ -93,18 +28,16 @@ def main(args):
utt = turn['utterance']
context_seq = '\n'.join([f"{turn['speaker']}: {turn['utt']}" for turn in context]+[f'{speaker}: '])
context.append({'speaker': speaker, 'utt': utt})
if i == 0 or ('wikidialog' in filename and speaker == 'user'):
if i == 0 or (dataset_name == 'wikidialog' and speaker == 'user'):
continue
if args.mode == 'rg':
input_seq = f'generate a response: context:\n\n{context_seq}'
input_seq = f'generate a response: all knowledge: | | context:\n\n{context_seq}'
fout.write(json.dumps({
'dataset': filename.split('/')[-1].split('_')[0],
'dataset': dataset_name,
'source': input_seq,
'target': utt}, ensure_ascii=False)+'\n')
continue
if len(turn['keywords']) == 0 or max([len(k) for k in turn['keywords']])>10:
'target': utt
}, ensure_ascii=False)+'\n')
continue
if args.mode == 'key2gen':
@@ -113,11 +46,14 @@
random.shuffle(turn['keywords'][j])
keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
fout.write(json.dumps({
'dataset': filename.split('/')[-1].split('_')[0],
json2dump = {
'dataset': dataset_name,
'source': input_seq,
'target': utt,
'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
'target': utt
}
if data_split == 'validation':
json2dump.update({'keywords': turn['keywords']})
fout.write(json.dumps(json2dump, ensure_ascii=False)+'\n')
continue
if args.mode == 'key2gen_noisy':
@@ -128,12 +64,14 @@
random.shuffle(possible_keywords_sents)
possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
fout.write(json.dumps({
'dataset': filename.split('/')[-1].split('_')[0],
json2dump = {
'dataset': dataset_name,
'source': input_seq,
'target': utt,
'keywords': turn['keywords'],
'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
'target': utt
}
if data_split == 'validation':
json2dump.update({'keywords': turn['keywords'], 'all_keywords': possible_keywords_sents})
fout.write(json.dumps(json2dump, ensure_ascii=False)+'\n')
continue
@@ -142,8 +80,7 @@ if __name__ == '__main__':
parser = ArgumentParser(description="generate pre-training data for grounded response generation tasks")
parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy', 'sen2gen', 'sen2gen_noisy', 'multitask'], help='which task to perform')
parser.add_argument('--n_splits', '-n', type=int, default=1, help='split the data into multiple pieces')
parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy'], help='which task to perform')
args = parser.parse_args()
print(args)
main(args)
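# Illustrative sketch (not part of the script): approximate shape of one output line
# per mode, reconstructed from the format strings above; all field values are invented
# examples.
#
# rg:
#   {"dataset": "dailydialog",
#    "source": "generate a response: all knowledge: | | context:\n\nuser: hi there\nsystem: ",
#    "target": "hello, how can I help?"}
#
# key2gen (keywords within a sentence joined by ' : ', sentences joined by ' | '):
#   {"dataset": "dailydialog",
#    "source": "generate a response: grounded knowledge: | help : book | context:\n\nuser: can you help me book a table?\nsystem: ",
#    "target": "sure, for how many people?"}
#
# key2gen_noisy uses the same layout with the 'all knowledge:' prefix and keywords
# pooled from several turns; the 'keywords'/'all_keywords' fields are kept only for
# the validation split.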
# generate data for response generation, key2gen, key2gen_noisy
for task_name in rg
for task_name in rg key2gen key2gen_noisy
do
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
names=$(echo ${dataset_name} | tr "+" "\n")
@@ -8,8 +8,7 @@ do
mkdir -p ${data_dir}
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
rm ${train_file} ${validation_file} ${test_file}
rm ${train_file} ${validation_file}
for name in ${names}
do
echo "preprocessing ${name}"
@@ -17,60 +16,6 @@ do
if [ "${name}" != "${dataset_name}" ]; then
cat "data/${task_name}/${model_type}/${name}/train.json" >> ${train_file}
cat "data/${task_name}/${model_type}/${name}/validation.json" >> ${validation_file}
cat "data/${task_name}/${model_type}/${name}/test.json" >> ${test_file}
fi
done
done
\ No newline at end of file
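# After this loop, data/${task_name}/${model_type}/${dataset_name}/{train,validation}.json
# hold the concatenated jsonlines of every individual dataset for that task
# (model_type is presumably "gpt").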
# # generate data for sentence grounded generation
# task_name="key2gen"
# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
# names=$(echo ${dataset_name} | tr "+" "\n")
# model_type="gpt"
# data_dir=data/${task_name}/${model_type}/${dataset_name}
# mkdir -p ${data_dir}
# n_splits=2
# for ((n=0;n<${n_splits};n++))
# do
# rm ${data_dir}/train_split_${n}-of-${n_splits}.json ${data_dir}/validation_split_${n}-of-${n_splits}.json ${data_dir}/test_split_${n}-of-${n_splits}.json
# done
# for name in ${names}
# do
# echo "preprocessing ${name}"
# python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} -m ${task_name} -n ${n_splits}
# if [ "${name}" != "${dataset_name}" ]; then
# for ((n=0;n<${n_splits};n++))
# do
# cat "data/${task_name}/gpt/${name}/train_split_${n}-of-${n_splits}.json" >> "${data_dir}/train_split_${n}-of-${n_splits}.json"
# cat "data/${task_name}/gpt/${name}/validation_split_${n}-of-${n_splits}.json" >> "${data_dir}/validation_split_${n}-of-${n_splits}.json"
# cat "data/${task_name}/gpt/${name}/test_split_${n}-of-${n_splits}.json" >> "${data_dir}/test_split_${n}-of-${n_splits}.json"
# done
# fi
# done
# # merge generated data with original data
# task_name="sen2gen"
# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
# names=$(echo ${dataset_name} | tr "+" "\n")
# model_type="gpt"
# data_dir=data/${task_name}/${model_type}/${dataset_name}
# mkdir -p ${data_dir}
# python gen_pretraining_data.py -i data/key2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
# # generate sen2gen_noisy data with original data
# task_name="sen2gen_noisy"
# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
# names=$(echo ${dataset_name} | tr "+" "\n")
# model_type="gpt"
# data_dir=data/${task_name}/${model_type}/${dataset_name}
# mkdir -p ${data_dir}
# python gen_pretraining_data.py -i data/sen2gen/${model_type}/${dataset_name} -o data/${task_name}/${model_type}/${dataset_name} -m ${task_name}
# merge data for multitask training
# task_name="rg+key2gen+key2gen_noisy+sen2gen+sen2gen_noisy"
# dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
# model_type="gpt"
# data_dir=data/${task_name}/${model_type}/${dataset_name}
# mkdir -p ${data_dir}
# python gen_pretraining_data.py -i data/ -o data/${task_name}/${model_type}/${dataset_name} -m multitask
n_gpus=2
master_port=$1
task_name="key2gen"
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
model_type="gpt"
split_id=$2
n_splits=$3
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}/gen"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
# train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
# validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
infer_split_id=$(( (split_id + 1) % n_splits ))
test_file="${data_dir}/validation_split_${infer_split_id}-of-${n_splits}.json"
source_column="source"
target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=4
lr=1e-3
num_train_epochs=10
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_predict \
--predict_with_generate \
--do_sample \
--top_p 0.9 \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 16 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
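# Note: this configuration only runs prediction (--do_predict) with nucleus sampling
# (--do_sample --top_p 0.9), using the model trained on one split to generate for the
# other split (infer_split_id). The 'sen2gen' mode of gen_pretraining_data.py later
# consumes files prefixed with 'gen_' that contain a 'predictions' field, presumably
# derived from this script's output.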
@@ -127,7 +127,7 @@ def main(args):
fin = open(word_loss_file, 'rb')
fout = open(args.output_file, 'w', encoding='utf-8')
for item in json_lines.reader(fin):
for item in tqdm(json_lines.reader(fin)):
words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
losses = [np.mean(loss) for loss in item['losses']]
dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
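# Illustrative sketch only: the real keywords_filter is defined elsewhere in this file
# and also tracks sentence/turn indices. The standalone function below just shows the
# basic idea of loss-based keyword selection (keep words whose language-model loss is
# high, i.e. words that are hard to predict and therefore informative).
def select_keywords_by_loss(words, losses, ratio=0.5, stopwords=frozenset()):
    """Return the top-`ratio` fraction of words ranked by per-word LM loss."""
    candidates = [
        (loss, word) for word, loss in zip(words, losses)
        if word.strip() and word.lower() not in stopwords
    ]
    candidates.sort(reverse=True)  # highest loss first
    k = max(1, int(len(candidates) * ratio))
    return [word for _, word in candidates[:k]]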
set -e
n_gpus=2
master_port=$1
task_name="key2gen"
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+wikidialog"
model_type="gpt"
split_id=$2
n_splits=$3
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}_split_${split_id}-of-${n_splits}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train_split_${split_id}-of-${n_splits}.json"
validation_file="${data_dir}/validation_split_${split_id}-of-${n_splits}.json"
test_file="${data_dir}/test_split_${split_id}-of-${n_splits}.json"
source_column="source"
target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=4
num_workers=16
lr=1e-3
num_train_epochs=1
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--load_best_model_at_end \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers ${num_workers} \
--dataloader_num_workers ${num_workers} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
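# The train/validation/test files referenced above are expected to be jsonlines with
# one {"source": ..., "target": ...} object per line, as produced by
# gen_pretraining_data.py; source_column and target_column select those two fields.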
set -e
n_gpus=2
n_gpus=8
master_port=23456
task_name="rg"
dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3+sgd+reddit+wikidialog"
model_type="gpt"
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
source_column="source"
target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_train_batch_size=64
per_device_eval_batch_size=128
gradient_accumulation_steps=4
gradient_accumulation_steps=1
num_workers=16
lr=1e-3
num_train_epochs=3
num_train_epochs=1
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
@@ -36,21 +33,18 @@ python -m torch.distributed.launch --master_port ${master_port} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--load_best_model_at_end \
--prediction_loss_only \
--save_steps 5000 \
--save_total_limit 3 \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--preprocessing_num_workers ${num_workers} \
--dataloader_num_workers ${num_workers} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--optim adafactor \
--lr_scheduler_type constant \
--gradient_checkpointing