Commit 5ead12a1 authored by zqwerty


key2gen_noisy: 1) rg; 2) keywords grounded generation; 3) add noisy keywords from randomly sampled turns within the dialog
parent 243040d1
Showing changed files with 130 additions and 59 deletions
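For reference, the updated gen_pretraining_data.py in the diff below serializes each dialog turn into up to three source/target pairs, one per mode; a minimal sketch of the three source formats, with the turn content invented for illustration:

    context_seq = "user: I need a cheap hotel.\nsystem: "   # dialog history plus speaker prompt (invented)
    utt = "How about the Gonville Hotel?"                   # target response (invented)

    # mode 'rg': plain response generation, no keywords
    rg_source = f'generate a response: context:\n\n{context_seq}'

    # mode 'key2gen': grounded on the current turn's own keywords, joined by ' : '
    keywords = 'cheap : hotel'
    key2gen_source = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'

    # mode 'key2gen_noisy': the turn's keywords plus keywords from 1-5 randomly sampled
    # other turns of the same dialog; turns separated by ' | ', keywords within a turn by ' : '
    all_knowledge = 'cheap : hotel | train : ticket'
    key2gen_noisy_source = f'generate a response: all knowledge: | {all_knowledge} | context:\n\n{context_seq}'

Each pair is written as a JSON line of the form {"source": ..., "target": utt}; the 'key2gen' mode also emits the 'rg' pair, and 'key2gen_noisy' emits all three.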
@@ -19,16 +19,27 @@ def main(predict_result):
         if item["keywords+context"].startswith("keywords"):
             data["keywords"]["predictions"].append(item['predictions'].strip())
             data["keywords"]["references"].append(item['response'].strip())
-            positive_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split(' | ') if len(k) > 0]
+            positive_keywords = [k.strip() for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split('|')[1].split(' : ') if len(k) > 0]
             data["keywords"]["positive_keywords"].append(positive_keywords)
         elif item["keywords+context"].startswith("possible keywords"):
             data["possible keywords"]["predictions"].append(item['predictions'].strip())
             data["possible keywords"]["references"].append(item['response'].strip())
-            possible_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split(' | ') if len(k) > 0]
+            possible_keywords = [k.strip() for ks in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split('|') for k in ks.split(' : ') if len(k) > 0]
+            has_positive = True
             for keyword in positive_keywords:
-                possible_keywords.remove(keyword)
+                if keyword in possible_keywords:
+                    possible_keywords.remove(keyword)
+                else:
+                    has_positive = False
+                    break
-            data["possible keywords"]["positive_keywords"].append(positive_keywords)
+            if has_positive:
+                data["possible keywords"]["positive_keywords"].append(positive_keywords)
+            else:
+                data["possible keywords"]["positive_keywords"].append([])
             data["possible keywords"]["negative_keywords"].append(possible_keywords)
+            # print(data)
+            # if len(data["possible keywords"]["positive_keywords"]) > 0:
+            #     break
     metric = datasets.load_metric('./key2gen_metric.py')
     table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
     if len(data["possible keywords"]["predictions"]) > 0:
......
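A toy walk-through of the new positive/negative bookkeeping above (keyword values invented): when every ground-truth keyword is found among the possible keywords, the positives are removed and whatever remains counts as negatives; otherwise the turn contributes an empty positive list.

    positive_keywords = ['hotel', 'cheap']
    possible_keywords = ['hotel', 'cheap', 'museum', 'ticket']   # ground-truth keywords mixed with noisy ones

    has_positive = True
    for keyword in positive_keywords:
        if keyword in possible_keywords:
            possible_keywords.remove(keyword)
        else:
            has_positive = False
            break

    # has_positive -> True, possible_keywords -> ['museum', 'ticket'] (the negatives),
    # so this example is scored for both positive and negative keyword recall.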
@@ -4,36 +4,46 @@ import random
 from tqdm import tqdm
 def main(args):
-    random.seed(45)
+    random.seed(42)
     os.makedirs(args.output_dir, exist_ok=True)
     filenames = [f for (_, _, fs) in os.walk(args.input_dir) for f in fs if 'keywords' in f]
     for filename in filenames:
         data = json.load(open(os.path.join(args.input_dir, filename)))
         fout = open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')
-        turn_keywords = [turn['keywords'] for dial in data for turn in dial]
-        random.shuffle(turn_keywords)
-        cnt = 0
-        # keywords_set = {keyword for keywords in turn_keywords_set for keyword in keywords}
         for dial in tqdm(data):
             context = []
+            turn_keywords = [turn['keywords'] for turn in dial]
             for i, turn in enumerate(dial):
                 speaker = 'user' if i % 2 == 0 else 'system'
-                random.shuffle(turn['keywords'])
-                keywords = ' | '.join(turn['keywords'])
                 utt = turn['utterance']
                 context_seq = '\n'.join([f"{turn['speaker']}: {turn['utt']}" for turn in context]+[f'{speaker}: '])
-                input_seq = f'keywords: {keywords}\n\ncontext: {context_seq}'
                 context.append({'speaker': speaker, 'utt': utt})
-                fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
-                negative_keywords = turn_keywords[cnt]
-                cnt += 1
-                possible_keywords = turn['keywords'] + list(negative_keywords)
-                random.shuffle(possible_keywords)
-                possible_keywords = ' | '.join(possible_keywords)
-                input_seq = f'possible keywords: {possible_keywords}\n\ncontext: {context_seq}'
-                if args.noisy:
-                    fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
+                if i == 0:
+                    continue
+                input_seq = f'generate a response: context:\n\n{context_seq}'
+                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                if args.mode == 'rg':
+                    continue
+                random.shuffle(turn['keywords'])
+                keywords = ' : '.join(turn['keywords'])
+                input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
+                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                if args.mode == 'key2gen':
+                    continue
+                possible_keywords_turns = [turn['keywords']]
+                num_possible_keywords_turns = min(random.randint(1, 5), len(turn_keywords) - 1)
+                possible_keywords_turns += random.sample(turn_keywords[:i] + turn_keywords[i+1:], num_possible_keywords_turns)
+                random.shuffle(possible_keywords_turns)
+                for possible_keywords_turn in possible_keywords_turns:
+                    random.shuffle(possible_keywords_turn)
+                possible_keywords = ' | '.join([' : '.join(possible_keywords_turn) for possible_keywords_turn in possible_keywords_turns])
+                input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
+                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                if args.mode == 'key2gen_noisy':
+                    continue
 if __name__ == '__main__':
@@ -41,7 +51,7 @@ if __name__ == '__main__':
     parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
     parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
     parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
-    parser.add_argument('--noisy', action='store_true', help='whether add noisy keywords samples')
+    parser.add_argument('--mode', '-m', type=str, choices=['rg', 'key2gen', 'key2gen_noisy'], help='which task to perform')
     args = parser.parse_args()
     print(args)
     main(args)
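The core of the key2gen_noisy mode is the sampling step: keyword sets from one to five other turns of the same dialog are mixed with the current turn's keywords. A self-contained sketch of just that step, with invented dialog data:

    import random

    random.seed(42)

    turn_keywords = [['hello'], ['cheap', 'hotel'], ['wifi'], ['book', 'ticket'], ['thanks']]  # one dialog (invented)
    i = 1
    turn = {'keywords': turn_keywords[i]}

    possible_keywords_turns = [turn['keywords']]
    num_possible_keywords_turns = min(random.randint(1, 5), len(turn_keywords) - 1)
    possible_keywords_turns += random.sample(turn_keywords[:i] + turn_keywords[i+1:], num_possible_keywords_turns)
    random.shuffle(possible_keywords_turns)
    for possible_keywords_turn in possible_keywords_turns:
        random.shuffle(possible_keywords_turn)
    possible_keywords = ' | '.join(' : '.join(t) for t in possible_keywords_turns)
    # yields something like 'wifi | hotel : cheap | ticket : book' -- the true keywords hidden among noise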
-task_name="key2gen_shuffle_noisy"
+task_name="key2gen_noisy"
 dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 names=$(echo ${dataset_name} | tr "+" "\n")
 model_type="gpt"
@@ -11,11 +11,11 @@ test_file="${data_dir}/test.json"
 for name in ${names}
 do
     echo "preprocessing ${name}"
-    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} --noisy
+    python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} -m ${task_name}
     if [ "${name}" != "${dataset_name}" ]; then
         cat "data/${task_name}/gpt/${name}/train.json" >> ${train_file}
         cat "data/${task_name}/gpt/${name}/validation.json" >> ${validation_file}
         cat "data/${task_name}/gpt/${name}/test.json" >> ${test_file}
     fi
 done
-python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/${task_name}/${model_type}/multiwoz21 --noisy
+python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/${task_name}/${model_type}/multiwoz21 -m ${task_name}
\ No newline at end of file
@@ -48,14 +48,6 @@ Returns:
     bleu: corpus-bleu score
     positive_keywords_recall: how many keywords in the ground truth response are generated, micro-averaged
     negative_keywords_recall: how many keywords in the random sampled response are generated, micro-averaged
-Examples:
-    >>> key2gen_metric = datasets.load_metric("key2gen_metric.py")
-    >>> predictions = ["hello there general kenobi", "foo bar foobar"]
-    >>> references = ["hello there kenobi", "foo bar foobar"]
-    >>> results = nlg_metric.compute(predictions=predictions, references=references)
-    >>> print(results)
-    {'bleu': 35.35533905932737}
 """
@@ -77,6 +69,7 @@ class Key2GenMetrics(datasets.Metric):
     def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
         """Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall"""
+        # rouge-1/2/L bleu-1/2 distinct-1/2
         if not negative_keywords:
             negative_keywords = [[]] * len(positive_keywords)
         bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
......
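The recall computation itself lies outside the shown hunk; as an assumption about what a micro-averaged keyword recall could look like (not the repository's verbatim code), it can be sketched as:

    def keywords_recall(predictions, keywords_per_sample):
        # micro-average: count matches over all keywords of all samples, then divide once
        matched, total = 0, 0
        for pred, keywords in zip(predictions, keywords_per_sample):
            pred = pred.lower()
            for keyword in keywords:
                total += 1
                if keyword.lower() in pred:
                    matched += 1
        return matched / total if total > 0 else 0.0

    # positive_keywords_recall = keywords_recall(predictions, positive_keywords)
    # negative_keywords_recall = keywords_recall(predictions, negative_keywords)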
@@ -41,7 +41,7 @@ def merge_tokens(tokens, losses):
             res[-1][0].append(token)
             res[-1][1].append(loss)
         else:
-            res.append([token, loss])
+            res.append([[token], [loss]])
         i += 1
     return res
......
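The one-line fix above matters because the branch just before it extends the last entry in place; a minimal illustration with invented tokens and losses:

    res = []
    res.append([['Hel'], [0.12]])   # fixed code: a new word starts as a list of subword tokens and a list of losses
    res[-1][0].append('lo')         # a continuation subword is appended to the same word
    res[-1][1].append(0.03)

    # with the old res.append([token, loss]), res[-1][0] would be the string 'Hel',
    # and the .append call above would raise AttributeError.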
@@ -6,9 +6,9 @@ def main(args):
     dialogs = []
     for i in range(len(filename2data[first_filename])):
         turns = []
-        for j in range(len(filename2data[first_filename][i])):
+        for j in range(min([len(filename2data[filename][i]) for filename in filename2data])):
             utt = filename2data[first_filename][i][j]['utterance']
-            keywords = {filename.split('_')[2]+'_nonstopword'+filename.split('_')[-1]: ' | '.join([x[0] for x in filename2data[filename][i][j]['keywords']]) for filename in filename2data}
+            keywords = {filename.split('_')[3]+'_nonstopword'+filename.split('_')[-1]: ' | '.join(filename2data[filename][i][j]['keywords']) for filename in filename2data}
             turns.append({
                 "utterance": utt,
                 **keywords
......
 set -e
-n_gpus=2
-task_name="key2gen_shuffle_noisy"
+n_gpus=4
+master_port=23457
+task_name="key2gen_noisy"
 dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
-speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
-cache_dir="../../t5/cache"
+cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
 validation_file="${data_dir}/validation.json"
 test_file="${data_dir}/test.json"
-source_column="keywords+context"
-target_column="response"
+source_column="source"
+target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="output/${task_name}/${model_type}/${dataset_name}"
+model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=4
+gradient_accumulation_steps=2
 lr=1e-3
-num_train_epochs=1
+num_train_epochs=3
-python -m torch.distributed.launch \
+python -m torch.distributed.launch --master_port ${master_port} \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --test_file ${test_file} \
......
 set -e
-n_gpus=2
-task_name="key2gen_shuffle_noisy"
+n_gpus=4
+master_port=23457
+task_name="key2gen_noisy"
 dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
-speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
-cache_dir="../../t5/cache"
+cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
 validation_file="${data_dir}/validation.json"
 test_file="${data_dir}/test.json"
-source_column="keywords+context"
-target_column="response"
+source_column="source"
+target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
 model_name_or_path="t5-small"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=4
+gradient_accumulation_steps=2
 lr=1e-3
-num_train_epochs=1
+num_train_epochs=3
-python -m torch.distributed.launch \
+python -m torch.distributed.launch --master_port ${master_port} \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
......
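With the updated settings above, the effective batch size per optimizer step is n_gpus x per_device_train_batch_size x gradient_accumulation_steps = 4 x 128 x 2 = 1024 sequences (assuming a single node).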
set -e
n_gpus=2
master_port=23456
task_name="rg"
dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
model_type="gpt"
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
source_column="source"
target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=4
lr=1e-3
num_train_epochs=3
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--load_best_model_at_end \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
@@ -445,7 +445,7 @@ def main():
                 inputs.append(examples[source_column][i])
                 targets.append(examples[target_column][i])
-        inputs = [prefix + '\n\n' + inp for inp in inputs]
+        inputs = [prefix + inp for inp in inputs]
         if padding:
             model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
         else:
@@ -566,6 +566,7 @@ def main():
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
     )
     if training_args.load_best_model_at_end:
+        if data_args.early_stopping_patience > 0:
             trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience))
     # Training
......