From ce2d802dc75e6922b2a292519cdb3a5ee9a61501 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Tue, 28 Dec 2021 04:04:21 +0000 Subject: [PATCH] clean up run_seq2seq.py --- convlab2/base_models/t5/run_seq2seq.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py index 30e2b1d5..0679e367 100644 --- a/convlab2/base_models/t5/run_seq2seq.py +++ b/convlab2/base_models/t5/run_seq2seq.py @@ -1,4 +1,4 @@ -#!python +#!/usr/bin/env python # coding=utf-8 # Copyright 2021 The HuggingFace Team. All rights reserved. # @@ -27,12 +27,10 @@ from dataclasses import dataclass, field from typing import Optional import datasets -import nltk # Here to have a nice missing dependency error message early on import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset import transformers -from filelock import FileLock from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, @@ -43,7 +41,6 @@ from transformers import ( Seq2SeqTrainingArguments, set_seed, ) -from transformers.file_utils import is_offline_mode from transformers.trainer_utils import EvalPrediction, get_last_checkpoint from transformers.utils import check_min_version from transformers.utils.versions import require_version @@ -109,7 +106,7 @@ class DataTrainingArguments: task_name: Optional[str] = field( - default=None, metadata={"help": "The name of the task, e.g., response generation."} + default=None, metadata={"help": "The name of the task, e.g., rg (for rgresponse generation)."} ) dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} @@ -223,7 +220,7 @@ class DataTrainingArguments: and self.validation_file is None and self.test_file is None ): - raise ValueError("Need either a dataset name or a training/validation file/test_file.") + raise ValueError("Need either a dataset name or a training/validation/testing file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -238,8 +235,6 @@ class DataTrainingArguments: self.val_max_target_length = self.max_target_length - - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -586,6 +581,7 @@ def main(): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + # Predict if training_args.do_predict: logger.info("*** Predict ***") @@ -636,4 +632,4 @@ def _mp_fn(index): if __name__ == "__main__": - main() \ No newline at end of file + main() -- GitLab