Commit e265e86b authored by zqwerty

Add key2gen metric; add dailydialog to the pre-training datasets.

parent a1a2f240
import json
import datasets
from tabulate import tabulate


def main(predict_result):
    data = {
        "keywords": {
            "positive_keywords": [], "negative_keywords": None,
            "predictions": [], "references": []
        },
        "possible keywords": {
            "positive_keywords": [], "negative_keywords": [],
            "predictions": [], "references": []
        }
    }
    with open(predict_result) as f:
        for line in f:
            item = json.loads(line)
            if item["keywords+context"].startswith("keywords"):
                data["keywords"]["predictions"].append(item['predictions'].strip())
                data["keywords"]["references"].append(item['response'].strip())
                # the prompt's first block is "keywords: k1 | k2 | ...", followed by a blank line
                positive_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split(' | ') if len(k) > 0]
                data["keywords"]["positive_keywords"].append(positive_keywords)
            elif item["keywords+context"].startswith("possible keywords"):
                data["possible keywords"]["predictions"].append(item['predictions'].strip())
                data["possible keywords"]["references"].append(item['response'].strip())
                possible_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split(' | ') if len(k) > 0]
                # `positive_keywords` carries over from the preceding "keywords" line of the
                # same sample; removing them leaves the negative (distractor) keywords
                for keyword in positive_keywords:
                    possible_keywords.remove(keyword)
                data["possible keywords"]["positive_keywords"].append(positive_keywords)
                data["possible keywords"]["negative_keywords"].append(possible_keywords)

    metric = datasets.load_metric('./key2gen_metric.py')
    table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
    if len(data["possible keywords"]["predictions"]) > 0:
        table.append({'prompt': "possible keywords", **metric.compute(**data["possible keywords"])})
    print(tabulate(table, headers='keys', tablefmt='github'))


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="evaluate keywords-to-response generation performance")
    parser.add_argument('--predict_result', '-p', type=str, required=True,
                        help='path to the output file generated_predictions.json')
    args = parser.parse_args()
    print(args)
    main(args.predict_result)
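A minimal usage sketch (the script name evaluate.py and the file contents below are hypothetical): each line of the predictions file is a JSON object carrying the "keywords+context" prompt, the model output under "predictions", and the gold "response".

import json

# one JSON object per line; field names match what main() reads above
example = {
    "keywords+context": "keywords: kenobi | general\n\ncontext: user: hello there",
    "predictions": "hello there general kenobi",
    "response": "hello there general kenobi",
}
with open("generated_predictions.json", "w") as f:
    f.write(json.dumps(example) + "\n")

# equivalent to: python evaluate.py -p generated_predictions.json
main("generated_predictions.json")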
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""key2gen Metric"""
import datasets
import sacrebleu

_CITATION = """\
@inproceedings{post-2018-call,
    title = "A Call for Clarity in Reporting {BLEU} Scores",
    author = "Post, Matt",
    booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
    month = oct,
    year = "2018",
    address = "Belgium, Brussels",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/W18-6319",
    pages = "186--191",
}
"""

_DESCRIPTION = """\
Metric to evaluate text-to-text models on the keywords grounded generation task.
"""

_KWARGS_DESCRIPTION = """
Calculates corpus-level BLEU-4, positive keywords recall, and negative keywords recall.
Args:
    positive_keywords: list of keyword lists (list of strings), one per ground truth reference
    negative_keywords: list of keyword lists (list of strings), one per randomly sampled reference
    predictions: list of predictions to score. Each prediction should be a string.
    references: list of references, one per prediction. Each reference should be a string.
Returns:
    bleu: corpus BLEU score
    positive_keywords_recall: fraction of keywords from the ground truth response that appear in the generation, micro-averaged
    negative_keywords_recall: fraction of keywords from the randomly sampled response that appear in the generation, micro-averaged
Examples:
    >>> key2gen_metric = datasets.load_metric("key2gen_metric.py")
    >>> predictions = ["hello there general kenobi", "foo bar foobar"]
    >>> references = ["hello there kenobi", "foo bar foobar"]
    >>> positive_keywords = [["kenobi"], ["foo"]]
    >>> results = key2gen_metric.compute(predictions=predictions, references=references,
    ...                                  positive_keywords=positive_keywords)
    >>> print(results)
    {'bleu': 35.35533905932737, 'positive_keywords_recall': 1.0, 'negative_keywords_recall': 0}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Key2GenMetrics(datasets.Metric):
    """Metric to evaluate text-to-text models on the keywords grounded generation task."""

    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features({
                'predictions': datasets.Value('string'),
                'references': datasets.Value('string'),
            })
        )

    def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
        """Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall"""
        if not negative_keywords:
            negative_keywords = [[]] * len(positive_keywords)
        bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
        # micro-averaged recall: count keyword hits across the whole corpus,
        # using a case-insensitive substring match
        cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0}
        for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
            cnt['pos'] += len(poskeys)
            cnt['neg'] += len(negkeys)
            prediction = prediction.lower()
            for key in poskeys:
                if key.lower() in prediction:
                    cnt['pos_recall'] += 1
            for key in negkeys:
                if key.lower() in prediction:
                    cnt['neg_recall'] += 1
        return {
            "bleu": bleu,
            "positive_keywords_recall": cnt['pos_recall'] / cnt['pos'] if cnt['pos'] > 0 else 0,
            "negative_keywords_recall": cnt['neg_recall'] / cnt['neg'] if cnt['neg'] > 0 else 0,
        }
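A minimal sketch of calling the metric directly (inputs illustrative), showing how the micro-averaged keyword recalls are reported alongside BLEU; note that matching is a case-insensitive substring test:

import datasets

key2gen_metric = datasets.load_metric("./key2gen_metric.py")
results = key2gen_metric.compute(
    predictions=["hello there general kenobi"],
    references=["hello there kenobi"],
    positive_keywords=[["kenobi"]],           # appears in the prediction -> recall 1.0
    negative_keywords=[["droid", "galaxy"]],  # neither appears -> recall 0.0
)
print(results)  # {'bleu': ..., 'positive_keywords_recall': 1.0, 'negative_keywords_recall': 0.0}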
 set -e
-n_gpus=1
-task_name="key2gen"
-dataset_name="multiwoz21"
+n_gpus=2
+task_name="key2gen_shuffle_noisy"
+dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
@@ -16,7 +16,7 @@ target_column="response"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="output/key2gen/gpt/metalwoz+sgd+tm1+tm2+tm3"
+model_name_or_path="output/${task_name}/${model_type}/${dataset_name}"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
 gradient_accumulation_steps=4
@@ -40,4 +40,11 @@ python -m torch.distributed.launch \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing

@@ -23,7 +23,7 @@ gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=1
-python -m torch.distributed.launch --master_port 23456\
+python -m torch.distributed.launch \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
...
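Once a model is trained with the updated script, a response can be generated from a keywords-grounded prompt. A hypothetical inference sketch, assuming the checkpoint is a seq2seq model saved by run_seq2seq.py; the path and prompt are illustrative, and the prompt format ("keywords: ..." block, blank line, dialogue context) is inferred from the evaluation script above:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# assumed output directory from the training script above
model_dir = "output/key2gen_shuffle_noisy/gpt/dailydialog+metalwoz+sgd+tm1+tm2+tm3"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

prompt = "keywords: umbrella | rain\n\ncontext: user: should I take anything with me today?"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))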