Commit 76c1d32e authored by zqwerty

update get token loss method for multi-gpu inference

parent 46df55a3
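
This commit moves the per-token loss dump out of run_clm.py's hand-rolled, single-process evaluation loop and into a custom DumpTokenLossTrainer. Its overridden evaluation_loop gathers logits and labels across all processes, so the main process can write a complete dump even when evaluation runs on multiple GPUs via torch.distributed. The dump file contains one JSON object per evaluated sample; an illustrative line (made-up values, GPT-2-style tokens):

{"tokens": ["Hello", ",", "Ġhow", "Ġare", "Ġyou", "?"], "losses": [0, 4.13, 2.07, 0.52, 0.18, 1.46]}

The leading 0 is a placeholder for the first token, which has no left context to be predicted from.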
create_data.py:

@@ -35,5 +35,5 @@ if __name__ == '__main__':
     for dataset_name in tqdm(args.datasets, desc='datasets'):
         dataset = load_dataset(dataset_name)
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
-            data_dir = os.path.join('data', task_name, dataset_name, args.model_type)
+            data_dir = os.path.join('data', task_name, args.model_type, dataset_name)
             eval(f"create_{task_name}_data")(dataset, data_dir, args)
n_gpus=4
master_port=23456
task_name="lm"
model_type="gpt"
cache_dir="../cache"
source_column="dialogue"
max_length=512
model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
per_device_eval_batch_size=16
for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
do
    data_dir="data/${task_name}/${model_type}/${dataset_name}"
    output_dir="output/${task_name}/${model_type}/${dataset_name}"
    python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type ${model_type}
    for data_split in validation train
    do
        validation_file="${data_dir}/${data_split}.json"
        dump_eval_loss_to="${data_dir}/token_loss_${data_split}.json"
        rm -f ${dump_eval_loss_to}
        python -m torch.distributed.launch --master_port ${master_port} \
            --nproc_per_node ${n_gpus} ../run_clm.py \
            --dump_eval_loss_to ${dump_eval_loss_to} \
            --model_name_or_path ${model_name_or_path} \
            --output_dir ${data_dir} \
            --validation_file ${validation_file} \
            --source_column ${source_column} \
            --max_length ${max_length} \
            --do_eval \
            --cache_dir ${cache_dir} \
            --preprocessing_num_workers 4 \
            --per_device_eval_batch_size ${per_device_eval_batch_size}
    done
done
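
The dumped files are line-delimited JSON and straightforward to post-process. Below is a minimal consumer sketch (not part of this commit; the path assumes the dailydialog validation split produced by the loop above):

import json
import math

# One JSON object per line: {"tokens": [...], "losses": [...]}
path = "data/lm/gpt/dailydialog/token_loss_validation.json"

total_loss, n_tokens = 0.0, 0
with open(path, encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        # Skip the leading placeholder 0 for each sample's first token.
        for loss in record["losses"][1:]:
            total_loss += loss
            n_tokens += 1

print("token-level perplexity:", math.exp(total_loss / n_tokens))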
run_clm.py:

@@ -44,7 +44,6 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Trainer,
     TrainingArguments,
     DataCollatorForTokenClassification,
     is_torch_tpu_available,
@@ -53,6 +52,7 @@ from transformers import (
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from convlab.base_models.gpt.trainer import DumpTokenLossTrainer

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -491,15 +491,17 @@ def main():
         pad_to_multiple_of=8 if training_args.fp16 else None,
     )

+    training_args.dump_eval_loss_to = data_args.dump_eval_loss_to
+
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = DumpTokenLossTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
-        data_collator=data_collator
+        data_collator=data_collator,
     )

     # Training
@@ -525,7 +527,6 @@ def main():
     # Evaluation
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        if not data_args.dump_eval_loss_to:
-            metrics = trainer.evaluate(metric_key_prefix="eval")
-            max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@@ -538,44 +539,6 @@ def main():
-            trainer.log_metrics("eval", metrics)
-            trainer.save_metrics("eval", metrics)
-        else:
-            if trainer.is_world_process_zero():
-                output_prediction_file = data_args.dump_eval_loss_to
-                writer = open(output_prediction_file, "w", encoding='utf-8')
-                eval_dataloader = DataLoader(
-                    eval_dataset, collate_fn=lambda x: {k: v.to(model.device) for k, v in data_collator(x).items()}, batch_size=training_args.per_device_eval_batch_size
-                )
-                model.eval()
-                losses = []
-                loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
-                for batch in tqdm(eval_dataloader):
-                    with torch.no_grad():
-                        outputs = model(**batch)
-                        loss = outputs.loss
-                    losses.append(loss.repeat(training_args.per_device_eval_batch_size))
-                    shift_logits = outputs.logits[..., :-1, :].contiguous()
-                    shift_labels = batch['labels'][..., 1:].contiguous()
-                    batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-                    batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
-                    labels = batch['labels'].tolist()
-                    for i in range(len(labels)):
-                        token_ids = [x for x in labels[i] if x != -100]
-                        tokens = tokenizer.convert_ids_to_tokens(token_ids)
-                        token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
-                        writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')
-                losses = torch.cat(losses)
-                losses = losses[: len(eval_dataset)]
-                try:
-                    perplexity = math.exp(torch.mean(losses))
-                except OverflowError:
-                    perplexity = float("inf")
-                logger.info(f"perplexity: {perplexity}")
-                writer.close()
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)

     kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
     if data_args.dataset_name is not None:
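
Since evaluation now always goes through trainer.evaluate(), the perplexity that the removed block computed by hand can still be recovered from the returned metrics, as the stock run_clm.py does (sketch, assuming the surrounding variable names):

try:
    perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
    perplexity = float("inf")
metrics["perplexity"] = perplexity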
convlab/base_models/gpt/trainer.py (new file):
from transformers import Trainer
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.deepspeed import deepspeed_init
from transformers.utils import logging
from transformers.trainer_pt_utils import find_batch_size, nested_concat, nested_numpify, IterableDatasetShard, nested_truncate
from transformers.trainer_utils import EvalPrediction, denumpify_detensorize
import torch
from torch.utils.data import DataLoader
import numpy as np
from typing import List, Optional
import json

logger = logging.get_logger(__name__)


class DumpTokenLossTrainer(Trainer):
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        prediction_loss_only = args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so don't know its length.

        if args.dump_eval_loss_to:
            writer = open(args.dump_eval_loss_to, "a", encoding='utf-8')
            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            num_sample_to_write = len(eval_dataset)

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                # labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                # preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)

            if args.dump_eval_loss_to:
                if self.is_world_process_zero() and num_sample_to_write > 0:
                    assert logits is not None and labels is not None, f"need logits and labels to dump token losses; prediction_loss_only={prediction_loss_only}"
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                    batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
                    labels = labels.tolist()
                    for i in range(len(labels)):
                        if num_sample_to_write > 0:
                            num_sample_to_write -= 1
                        else:
                            break
                        token_ids = [x for x in labels[i] if x != -100]
                        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
                        token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
                        writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')

            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        if args.dump_eval_loss_to:
            writer.close()

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
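
As wired up in the run_clm.py diff above, DumpTokenLossTrainer is a drop-in replacement for Trainer; the only extra step is copying the dump path onto the training arguments, since evaluation_loop reads args.dump_eval_loss_to. A minimal usage sketch (model, tokenizer, dataset, and collator setup elided):

training_args.dump_eval_loss_to = data_args.dump_eval_loss_to  # empty/None disables dumping
trainer = DumpTokenLossTrainer(
    model=model,
    args=training_args,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
metrics = trainer.evaluate(metric_key_prefix="eval")

Note that the trainer opens the dump file in append mode ("a"), which is why the shell script removes any stale dump before launching.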