diff --git a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce8cd9f074ec3a60fb084c16eed6d1b41e53027c
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
@@ -0,0 +1,45 @@
+import json
+import datasets
+from tabulate import tabulate
+
+def main(predict_result):
+    data = {
+        "keywords": {
+            "positive_keywords": [], "negative_keywords": None,
+            "predictions": [], "references": []
+        },
+        "possible keywords": {
+            "positive_keywords": [], "negative_keywords": [],
+            "predictions": [], "references": []
+        }
+    }
+    with open(predict_result) as f:
+        for line in f:
+            item = json.loads(line)
+            if item["keywords+context"].startswith("keywords"):
+                data["keywords"]["predictions"].append(item['predictions'].strip())
+                data["keywords"]["references"].append(item['response'].strip())
+                positive_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split(' | ') if len(k) > 0]
+                data["keywords"]["positive_keywords"].append(positive_keywords)
+            elif item["keywords+context"].startswith("possible keywords"):
+                data["possible keywords"]["predictions"].append(item['predictions'].strip())
+                data["possible keywords"]["references"].append(item['response'].strip())
+                possible_keywords = [k for k in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split(' | ') if len(k) > 0]
+                for keyword in positive_keywords:
+                    possible_keywords.remove(keyword)
+                data["possible keywords"]["positive_keywords"].append(positive_keywords)
+                data["possible keywords"]["negative_keywords"].append(possible_keywords)
+    metric = datasets.load_metric('./key2gen_metric.py')
+    table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
+    if len(data["possible keywords"]["predictions"]) > 0:
+        table.append({'prompt': "possible keywords", **metric.compute(**data["possible keywords"])})
+    print(tabulate(table, headers='keys', tablefmt='github'))
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="evaluate keywords to response generation performance")
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
+    args = parser.parse_args()
+    print(args)
+    main(args.predict_result)
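
For reference (not part of the diff): eval_key2gen.py reads a JSON-lines file in which each record carries the "keywords+context" prompt, the model output under "predictions", and the gold "response". The sketch below prepares such a toy file; the field names and the "keywords: k1 | k2" prefix follow the parsing logic above, while the context portion after the blank line is a made-up assumption.

import json

# Hypothetical record for generated_predictions.json (one JSON object per line).
# The keyword list layout mirrors what eval_key2gen.py parses; the dialogue context is invented.
record = {
    "keywords+context": "keywords: train | cambridge\n\nuser: I need a train to Cambridge.",
    "predictions": "there is a train to cambridge at 10:15 .",
    "response": "There is a train to Cambridge at 10:15.",
}
with open("generated_predictions.json", "w") as f:
    f.write(json.dumps(record) + "\n")

# Score it with: python eval_key2gen.py -p generated_predictions.json
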
+"""key2gen Metric""" + +import datasets +import sacrebleu + +# TODO: Add BibTeX citation +_CITATION = """\ +@inproceedings{post-2018-call, + title = "A Call for Clarity in Reporting {BLEU} Scores", + author = "Post, Matt", + booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers", + month = oct, + year = "2018", + address = "Belgium, Brussels", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/W18-6319", + pages = "186--191", +} +""" + +_DESCRIPTION = """\ +Metric to evaluate text-to-text models on the keywords grounded generation task. +""" + +_KWARGS_DESCRIPTION = """ +Calculates corpus-bleu4, positive keywords recall, negative keywords recall +Args: + positive_keywords: list of keywords (list of string) in the ground truth references + negative_keywords: list of keywords (list of string) in the random sampled references + predictions: list of predictions to score. Each predictions + should be a string. + references: list of reference for each prediction. Each + reference should be a string. +Returns: + bleu: corpus-bleu score + positive_keywords_recall: how many keywords in the ground truth response are generated, micro-averaged + negative_keywords_recall: how many keywords in the random sampled response are generated, micro-averaged +Examples: + + >>> key2gen_metric = datasets.load_metric("key2gen_metric.py") + >>> predictions = ["hello there general kenobi", "foo bar foobar"] + >>> references = ["hello there kenobi", "foo bar foobar"] + >>> results = nlg_metric.compute(predictions=predictions, references=references) + >>> print(results) + {'bleu': 35.35533905932737} +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Key2GenMetrics(datasets.Metric): + """Metric to evaluate text-to-text models on the keywords grounded generation task.""" + + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + # This defines the format of each prediction and reference + features=datasets.Features({ + 'predictions': datasets.Value('string'), + 'references': datasets.Value('string'), + }) + ) + + def _compute(self, predictions, references, positive_keywords, negative_keywords=None): + """Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall""" + if not negative_keywords: + negative_keywords = [[]] * len(positive_keywords) + bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score + cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0} + for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions): + cnt['pos'] += len(poskeys) + cnt['neg'] += len(negkeys) + + prediction = prediction.lower() + for key in poskeys: + key = key.lower() + if key in prediction: + cnt['pos_recall'] += 1 + + for key in negkeys: + key = key.lower() + if key in prediction: + cnt['neg_recall'] += 1 + + return { + "bleu": bleu, + "positive_keywords_recall": cnt['pos_recall']/cnt['pos'] if cnt['pos'] > 0 else 0, + "negative_keywords_recall": cnt['neg_recall']/cnt['neg'] if cnt['neg'] > 0 else 0, + } diff --git a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh index 9a274f8fe344efe3125920903bc688e8aeb7c38e..469ec695ba681a835c7d9c51e95803c674c87d11 100644 --- a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh +++ 
diff --git a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
index 9a274f8fe344efe3125920903bc688e8aeb7c38e..469ec695ba681a835c7d9c51e95803c674c87d11 100644
--- a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
@@ -1,7 +1,7 @@
 set -e
-n_gpus=1
-task_name="key2gen"
-dataset_name="multiwoz21"
+n_gpus=2
+task_name="key2gen_shuffle_noisy"
+dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
@@ -16,7 +16,7 @@ target_column="response"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="output/key2gen/gpt/metalwoz+sgd+tm1+tm2+tm3"
+model_name_or_path="output/${task_name}/${model_type}/${dataset_name}"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
 gradient_accumulation_steps=4
@@ -40,4 +40,11 @@ python -m torch.distributed.launch \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
     --preprocessing_num_workers 4 \
-    --per_device_eval_batch_size ${per_device_eval_batch_size}
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --debug underflow_overflow \
+    --adafactor \
+    --gradient_checkpointing
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
index aa648116737f10909d1d83a0e9ec1ac0a5a682d2..d92365e787e2c58fdd8b6f4a4f870053c7561f2e 100644
--- a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
@@ -23,7 +23,7 @@ gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=1
 
-python -m torch.distributed.launch --master_port 23456\
+python -m torch.distributed.launch \
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
    --train_file ${train_file} \
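
For reference (not part of the diff): the flags appended to the test launch command (--per_device_train_batch_size through --gradient_checkpointing) correspond to standard HuggingFace training arguments. Below is a rough Python equivalent of the added options, assuming run_seq2seq.py parses Seq2SeqTrainingArguments the way the stock transformers seq2seq examples do; output_dir and the learning-rate value are assumptions based on the script variables, not taken from the diff.

from transformers import Seq2SeqTrainingArguments

# Sketch of the options added in test_t5_key2gen.sh, expressed as Seq2SeqTrainingArguments.
training_args = Seq2SeqTrainingArguments(
    output_dir="output/key2gen_shuffle_noisy/gpt/dailydialog+metalwoz+sgd+tm1+tm2+tm3",  # assumed layout
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    gradient_accumulation_steps=4,
    learning_rate=1e-3,              # ${lr}, as set in the training script
    num_train_epochs=1,
    debug="underflow_overflow",      # --debug underflow_overflow: detect numerical under/overflow in activations
    adafactor=True,                  # --adafactor: use Adafactor instead of AdamW
    gradient_checkpointing=True,     # --gradient_checkpointing: trade compute for memory
)
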