From 76c1d32e6d75cb73de20f60c101b9a223a45dfbc Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Fri, 15 Jul 2022 15:52:28 +0800
Subject: [PATCH] update get token loss method for multi-gpu inference

---
 convlab/base_models/gpt/create_data.py       |   2 +-
 .../gpt/keyword_extraction/get_token_loss.sh |  35 +++
 convlab/base_models/gpt/run_clm.py           |  73 ++----
 convlab/base_models/gpt/trainer.py           | 243 ++++++++++++++++++
 4 files changed, 297 insertions(+), 56 deletions(-)
 create mode 100644 convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
 create mode 100644 convlab/base_models/gpt/trainer.py

diff --git a/convlab/base_models/gpt/create_data.py b/convlab/base_models/gpt/create_data.py
index c9e6e9f0..e6c4d67b 100644
--- a/convlab/base_models/gpt/create_data.py
+++ b/convlab/base_models/gpt/create_data.py
@@ -35,5 +35,5 @@ if __name__ == '__main__':
     for dataset_name in tqdm(args.datasets, desc='datasets'):
         dataset = load_dataset(dataset_name)
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
-            data_dir = os.path.join('data', task_name, dataset_name, args.model_type)
+            data_dir = os.path.join('data', task_name, args.model_type, dataset_name)
             eval(f"create_{task_name}_data")(dataset, data_dir, args)
diff --git a/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
new file mode 100644
index 00000000..4c73a3b0
--- /dev/null
+++ b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
@@ -0,0 +1,35 @@
+n_gpus=4
+master_port=23456
+task_name="lm"
+model_type="gpt"
+cache_dir="../cache"
+source_column="dialogue"
+max_length=512
+model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
+per_device_eval_batch_size=16
+
+for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
+do
+    data_dir="data/${task_name}/${model_type}/${dataset_name}"
+    output_dir="output/${task_name}/${model_type}/${dataset_name}"
+
+    python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type ${model_type}
+    for data_split in validation train
+    do
+        validation_file="${data_dir}/${data_split}.json"
+        dump_eval_loss_to="${data_dir}/token_loss_${data_split}.json"
+        rm ${dump_eval_loss_to}
+        python -m torch.distributed.launch --master_port ${master_port} \
+            --nproc_per_node ${n_gpus} ../run_clm.py \
+            --dump_eval_loss_to ${dump_eval_loss_to}\
+            --model_name_or_path ${model_name_or_path} \
+            --output_dir ${data_dir} \
+            --validation_file ${validation_file} \
+            --source_column ${source_column} \
+            --max_length ${max_length} \
+            --do_eval \
+            --cache_dir ${cache_dir} \
+            --preprocessing_num_workers 4 \
+            --per_device_eval_batch_size ${per_device_eval_batch_size}
+    done
+done
diff --git a/convlab/base_models/gpt/run_clm.py b/convlab/base_models/gpt/run_clm.py
index 9dff4a0a..ace68609 100644
--- a/convlab/base_models/gpt/run_clm.py
+++ b/convlab/base_models/gpt/run_clm.py
@@ -44,7 +44,6 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Trainer,
     TrainingArguments,
     DataCollatorForTokenClassification,
     is_torch_tpu_available,
@@ -53,6 +52,7 @@ from transformers import (
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from convlab.base_models.gpt.trainer import DumpTokenLossTrainer
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -491,15 +491,17 @@ def main():
         pad_to_multiple_of=8 if training_args.fp16 else None,
     )
 
+    training_args.dump_eval_loss_to = data_args.dump_eval_loss_to
+
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = DumpTokenLossTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
-        data_collator=data_collator
+        data_collator=data_collator,
     )
 
     # Training
@@ -525,58 +527,19 @@ def main():
     # Evaluation
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        if not data_args.dump_eval_loss_to:
-            metrics = trainer.evaluate(metric_key_prefix="eval")
-            max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
-            try:
-                perplexity = math.exp(metrics["eval_loss"])
-            except OverflowError:
-                perplexity = float("inf")
-            metrics["eval_perplexity"] = perplexity
-            logger.info(f"eval_perplexity: {perplexity}")
-
-            trainer.log_metrics("eval", metrics)
-            trainer.save_metrics("eval", metrics)
-        else:
-            if trainer.is_world_process_zero():
-                output_prediction_file = data_args.dump_eval_loss_to
-                writer = open(output_prediction_file, "w", encoding='utf-8')
-
-                eval_dataloader = DataLoader(
-                    eval_dataset, collate_fn=lambda x: {k: v.to(model.device) for k, v in data_collator(x).items()}, batch_size=training_args.per_device_eval_batch_size
-                )
-                model.eval()
-                losses = []
-                loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
-                for batch in tqdm(eval_dataloader):
-                    with torch.no_grad():
-                        outputs = model(**batch)
-
-                    loss = outputs.loss
-                    losses.append(loss.repeat(training_args.per_device_eval_batch_size))
-
-                    shift_logits = outputs.logits[..., :-1, :].contiguous()
-                    shift_labels = batch['labels'][..., 1:].contiguous()
-                    batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-                    batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
-                    labels = batch['labels'].tolist()
-                    for i in range(len(labels)):
-                        token_ids = [x for x in labels[i] if x != -100]
-                        tokens = tokenizer.convert_ids_to_tokens(token_ids)
-                        token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
-                        writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')
-
-                losses = torch.cat(losses)
-                losses = losses[: len(eval_dataset)]
-                try:
-                    perplexity = math.exp(torch.mean(losses))
-                except OverflowError:
-                    perplexity = float("inf")
-                logger.info(f"perplexity: {perplexity}")
-
-                writer.close()
-
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+        try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")
+        metrics["eval_perplexity"] = perplexity
+        logger.info(f"eval_perplexity: {perplexity}")
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
     kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
     if data_args.dataset_name is not None:
         kwargs["dataset_tags"] = data_args.dataset_name
diff --git a/convlab/base_models/gpt/trainer.py b/convlab/base_models/gpt/trainer.py
new file mode 100644
index 00000000..5a8ed11c
--- /dev/null
+++ b/convlab/base_models/gpt/trainer.py
@@ -0,0 +1,243 @@
+from transformers import Trainer
+from transformers.trainer_utils import EvalLoopOutput, has_length
+from transformers.deepspeed import deepspeed_init
+from transformers.utils import logging
+from transformers.trainer_pt_utils import find_batch_size, nested_concat, nested_numpify, IterableDatasetShard, nested_truncate
+from transformers.trainer_utils import EvalPrediction, denumpify_detensorize
+import torch
+from torch.utils.data import DataLoader
+import numpy as np
+from typing import List, Optional
+import json
+
+
+logger = logging.get_logger(__name__)
+
+class DumpTokenLossTrainer(Trainer):
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+        Works both with or without labels.
+        """
+        args = self.args
+
+        prediction_loss_only = args.prediction_loss_only
+
+        # if eval is called w/o train init deepspeed here
+        if args.deepspeed and not self.deepspeed:
+
+            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
+            # from the checkpoint eventually
+            deepspeed_engine, _, _ = deepspeed_init(
+                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
+            )
+            self.model = deepspeed_engine.module
+            self.model_wrapped = deepspeed_engine
+            self.deepspeed = deepspeed_engine
+
+        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
+
+        batch_size = self.args.eval_batch_size
+
+        logger.info(f"***** Running {description} *****")
+        if has_length(dataloader):
+            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
+        else:
+            logger.info("  Num examples: Unknown")
+        logger.info(f"  Batch size = {batch_size}")
+
+        model.eval()
+
+        self.callback_handler.eval_dataloader = dataloader
+        # Do this before wrapping.
+        eval_dataset = getattr(dataloader, "dataset", None)
+
+        if args.past_index >= 0:
+            self._past = None
+
+        # Initialize containers
+        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
+        losses_host = None
+        preds_host = None
+        labels_host = None
+        inputs_host = None
+
+        # losses/preds/labels on CPU (final containers)
+        all_losses = None
+        all_preds = None
+        all_labels = None
+        all_inputs = None
+        # Will be useful when we have an iterable dataset so don't know its length.
+
+        if args.dump_eval_loss_to:
+            writer = open(args.dump_eval_loss_to, "a", encoding='utf-8')
+            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+            num_sample_to_write = len(eval_dataset)
+
+        observed_num_examples = 0
+        # Main evaluation loop
+        for step, inputs in enumerate(dataloader):
+            # Update the observed num examples
+            observed_batch_size = find_batch_size(inputs)
+            if observed_batch_size is not None:
+                observed_num_examples += observed_batch_size
+                # For batch samplers, batch_size is not known by the dataloader in advance.
+                if batch_size is None:
+                    batch_size = observed_batch_size
+
+            # Prediction step
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+
+            # Update containers on host
+            if loss is not None:
+                losses = self._nested_gather(loss.repeat(batch_size))
+                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+            if labels is not None:
+                labels = self._pad_across_processes(labels)
+                labels = self._nested_gather(labels)
+                # labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+            if inputs_decode is not None:
+                inputs_decode = self._pad_across_processes(inputs_decode)
+                inputs_decode = self._nested_gather(inputs_decode)
+                inputs_host = (
+                    inputs_decode
+                    if inputs_host is None
+                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+                )
+            if logits is not None:
+                logits = self._pad_across_processes(logits)
+                logits = self._nested_gather(logits)
+                if self.preprocess_logits_for_metrics is not None:
+                    logits = self.preprocess_logits_for_metrics(logits, labels)
+                # preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+
+            if args.dump_eval_loss_to:
+                if self.is_world_process_zero() and num_sample_to_write > 0:
+                    assert logits is not None and labels is not None, print('prediction_loss_only', prediction_loss_only)
+                    shift_logits = logits[..., :-1, :].contiguous()
+                    shift_labels = labels[..., 1:].contiguous()
+                    batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+                    batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
+                    labels = labels.tolist()
+                    for i in range(len(labels)):
+                        if num_sample_to_write > 0:
+                            num_sample_to_write -= 1
+                        else:
+                            break
+                        token_ids = [x for x in labels[i] if x != -100]
+                        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+                        token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
+                        writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')
+
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+                if losses_host is not None:
+                    losses = nested_numpify(losses_host)
+                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+                if preds_host is not None:
+                    logits = nested_numpify(preds_host)
+                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+                if inputs_host is not None:
+                    inputs_decode = nested_numpify(inputs_host)
+                    all_inputs = (
+                        inputs_decode
+                        if all_inputs is None
+                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+                    )
+                if labels_host is not None:
+                    labels = nested_numpify(labels_host)
+                    all_labels = (
+                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+                    )
+
+                # Set back to None to begin a new accumulation
+                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+
+        if args.dump_eval_loss_to:
+            writer.close()
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        if losses_host is not None:
+            losses = nested_numpify(losses_host)
+            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+        if preds_host is not None:
+            logits = nested_numpify(preds_host)
+            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+        if inputs_host is not None:
+            inputs_decode = nested_numpify(inputs_host)
+            all_inputs = (
+                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+            )
+        if labels_host is not None:
+            labels = nested_numpify(labels_host)
+            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+        # Number of samples
+        if has_length(eval_dataset):
+            num_samples = len(eval_dataset)
+        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+        # methods. Therefore we need to make sure it also has the attribute.
+        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
+            num_samples = eval_dataset.num_examples
+        else:
+            if has_length(dataloader):
+                num_samples = self.num_examples(dataloader)
+            else:  # both len(dataloader.dataset) and len(dataloader) fail
+                num_samples = observed_num_examples
+
+        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
+        # samplers has been rounded to a multiple of batch_size, so we truncate.
+        if all_losses is not None:
+            all_losses = all_losses[:num_samples]
+        if all_preds is not None:
+            all_preds = nested_truncate(all_preds, num_samples)
+        if all_labels is not None:
+            all_labels = nested_truncate(all_labels, num_samples)
+        if all_inputs is not None:
+            all_inputs = nested_truncate(all_inputs, num_samples)
+
+        # Metrics!
+        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
+            if args.include_inputs_for_metrics:
+                metrics = self.compute_metrics(
+                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
+                )
+            else:
+                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+        else:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if all_losses is not None:
+            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
--
GitLab
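
Note on the dump format: DumpTokenLossTrainer writes one JSON object per evaluated sample, with parallel "tokens" and "losses" arrays; the first loss is a 0 placeholder because the first token has no left context. The following is only a minimal reader sketch, not part of the patch, assuming the directory layout produced by get_token_loss.sh (the path below is an illustrative example).

    # read_token_loss.py -- illustrative sketch, not part of this patch
    import json
    import math

    # hypothetical example path following the data_dir layout in get_token_loss.sh
    path = "data/lm/gpt/dailydialog/token_loss_validation.json"

    total_loss, total_tokens = 0.0, 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)  # {"tokens": [...], "losses": [...]}
            # skip the leading 0 placeholder loss of each sample
            total_loss += sum(record["losses"][1:])
            total_tokens += len(record["losses"]) - 1

    print("corpus perplexity:", math.exp(total_loss / total_tokens))

Averaging the per-token losses this way should roughly reproduce the eval_perplexity logged by run_clm.py, up to padding and truncation effects in the distributed evaluation loop.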