@@ -161,6 +163,103 @@ class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
...
@@ -161,6 +163,103 @@ class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
classConvLabSeq2SeqTrainer(Seq2SeqTrainer):
classConvLabSeq2SeqTrainer(Seq2SeqTrainer):
# modifed from Seq2SeqTrainer of 4.26.1: https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/trainer_seq2seq.py
# add generation args in `prediction_step`
defevaluate(
self,
eval_dataset:Optional[Dataset]=None,
ignore_keys:Optional[List[str]]=None,
metric_key_prefix:str="eval",
**gen_kwargs
)->Dict[str,float]:
"""
Run evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init `compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (`Dataset`, *optional*):
Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is `"eval"` (default)
max_length (`int`, *optional*):
The maximum target length to use when predicting with the generate method.
num_beams (`int`, *optional*):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
gen_kwargs:
Additional `generate` specific kwargs.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.