Skip to content
Snippets Groups Projects
Commit ce2d802d authored by zqwerty's avatar zqwerty
Browse files

clean up run_seq2seq.py

parent 7f82cc0c
Branches
No related tags found
No related merge requests found
#!python
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
......@@ -27,12 +27,10 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import load_dataset, load_metric
from datasets import load_dataset
import transformers
from filelock import FileLock
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
......@@ -43,7 +41,6 @@ from transformers import (
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
......@@ -109,7 +106,7 @@ class DataTrainingArguments:
task_name: Optional[str] = field(
default=None, metadata={"help": "The name of the task, e.g., response generation."}
default=None, metadata={"help": "The name of the task, e.g., rg (for rgresponse generation)."}
)
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
......@@ -223,7 +220,7 @@ class DataTrainingArguments:
and self.validation_file is None
and self.test_file is None
):
raise ValueError("Need either a dataset name or a training/validation file/test_file.")
raise ValueError("Need either a dataset name or a training/validation/testing file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
......@@ -238,8 +235,6 @@ class DataTrainingArguments:
self.val_max_target_length = self.max_target_length
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
......@@ -586,6 +581,7 @@ def main():
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
logger.info("*** Predict ***")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment