diff --git a/convlab2/nlg/scgpt/decode.py b/convlab2/nlg/scgpt/decode.py
deleted file mode 100644
index e95025afdde60d8beb41bac5d2fb038e39357f3d..0000000000000000000000000000000000000000
--- a/convlab2/nlg/scgpt/decode.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Sat Apr 4 21:34:38 2020
-
-@author: truthless
-"""
-import numpy as np
-import torch
-
-def set_seed(seed, n_gpu):
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if n_gpu > 0:
-        torch.cuda.manual_seed_all(seed)
-
-
-def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-        Args:
-            logits: logits distribution shape (batch size x vocabulary size)
-            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-    top_k = min(top_k, logits.size(-1))  # Safety check
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
-        logits[indices_to_remove] = filter_value
-    return logits
-
-
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
-                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
-    context = torch.tensor(context, dtype=torch.long, device=device)
-    context = context.unsqueeze(0).repeat(num_samples, 1)
-    generated = context
-    with torch.no_grad():
-        for _ in range(length):
-
-            inputs = {'input_ids': generated}
-            if is_xlnet:
-                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
-                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
-                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
-                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
-                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
-                target_mapping[0, 0, -1] = 1.0  # predict last token
-                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
-
-            if is_xlm_mlm and xlm_mask_token:
-                # XLM MLM models are direct models (predict same token, not next token)
-                # => need one additional dummy token in the input (will be masked and guessed)
-                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
-                inputs = {'input_ids': input_ids}
-
-            if xlm_lang is not None:
-                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
-
-            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
-            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
-
-            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
-            for i in range(num_samples):
-                for _ in set(generated[i].tolist()):
-                    next_token_logits[i, _] /= repetition_penalty
-
-            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-            if temperature == 0: # greedy sampling:
-                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
-            else:
-                next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
-            generated = torch.cat((generated, next_token), dim=1)
-    return generated
diff --git a/convlab2/nlg/scgpt/main.py b/convlab2/nlg/scgpt/main.py
index 2d69ba2f5af38b29efaa7e71dcbfae851aa3a218..9f1ed5817d36d12c5f6a69b29c82c63dd585db44 100644
--- a/convlab2/nlg/scgpt/main.py
+++ b/convlab2/nlg/scgpt/main.py
@@ -19,22 +19,19 @@ from convlab2.nlg.scgpt.util import act2str
 from convlab2.nlg.scgpt.model import SCGPTDataset
 from evaluate import GentScorer
 
-# 分部式训练
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from util import build_mask
 from scgpt_special_tokens import *
-# from plot.attn_plot import plot_attn_encdec
-
-## Model Testing
 code_test = False
-
-## 多GPU并行参数设置
 parser = argparse.ArgumentParser()
 parser.add_argument("--local_rank", default=-1, type=int)
 parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
-parser.add_argument('--dataset', default="multiwoz21", type=str, help="Whether to run training.")
+parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
+parser.add_argument('--model_path', default="", type=str, help="The path of model for testing.")
 parser.add_argument("--max_seq_len", default=256, type=int)
 FLAGS = parser.parse_args()
 local_rank = FLAGS.local_rank
@@ -52,16 +49,15 @@ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
 model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
 model.resize_token_embeddings(len(tokenizer))
 
-## loss计算
 nll_loss = nn.NLLLoss(reduce=False).to(local_rank)
 ce_loss = nn.CrossEntropyLoss(reduce=False).to(local_rank)
 def cal_loss(input, target, seq_lens, seq_lens_input):
     """Only calculate loss on responses, not on dialog act"""
     global nll_loss
     """Input: [batch, length, vocab]; target: [batch, length]; seq_lens: [batch]"""
-    log_probs = F.log_softmax(input, dim=-1).transpose(1, 2) # 类别维度要放在dim=1的位置,nn.NLLLoss的要求
+    log_probs = F.log_softmax(input, dim=-1).transpose(1, 2)
     loss = nll_loss(log_probs, target)
-    # loss = ce_loss(input, target) # 等价
+    # loss = ce_loss(input, target)
     mask = build_mask(torch.max(seq_lens).item()-1, seq_lens-1).to(local_rank)
     input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
     output_mask = torch.logical_xor(mask, input_mask)
@@ -82,7 +78,7 @@ def pad_collate(batch):
     START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
     pad_token_id = tokenizer.pad_token_id
     batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
-    batch = [item[-FLAGS.max_seq_len:] for item in batch] # TF限制输入长度
+    batch = [item[-FLAGS.max_seq_len:] for item in batch]
     max_len = max([len(item) for item in batch])
     seq_lens = [len(item) for item in batch]
     split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
@@ -286,4 +282,4 @@ if __name__ == '__main__':
     if FLAGS.do_train:
         train(model, nlg_data)
     else:
-        test(model, nlg_data, ontology, './saved_model/epoch_0/epoch_0_step2839.pt')
+        test(model, nlg_data, ontology, FLAGS.model_path)
diff --git a/convlab2/nlg/scgpt/modeling_utils.py b/convlab2/nlg/scgpt/modeling_utils.py
deleted file mode 100644
index a8b3f6ddfc6b7347c624446bf7869c67d3064cc1..0000000000000000000000000000000000000000
--- a/convlab2/nlg/scgpt/modeling_utils.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import warnings
-from contextlib import nullcontext
-from typing import TYPE_CHECKING
-import torch.cuda.amp as amp
-import transformers
-from transformers import GPT2LMHeadModel
-
-
-# reference: https://pytorch.org/docs/master/notes/amp_examples.html
-class AmpGPT2LMHeadModel(GPT2LMHeadModel):
-    if TYPE_CHECKING:
-        # For IDE's code hinting
-        forward = GPT2LMHeadModel.forward
-    else:
-        def forward(self, *args, **kwargs):
-            with amp.autocast():
-                return super().forward(*args, **kwargs)
-
-
-def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
-    if model.supports_gradient_checkpointing:
-        model.gradient_checkpointing_enable()
-    else:
-        warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
-
-
-class AmpHelper:
-    """
-    References:
-        https://pytorch.org/docs/master/notes/amp_examples.html
-    """
-    def __init__(self, use_amp=True):
-        self.use_amp = use_amp
-        self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
-        self.scaler = amp.GradScaler()
-
-    def backward(self, loss):
-        if self.use_amp:
-            return self.scaler.scale(loss).backward()
-        else:
-            return loss.backward()
-
-    def step(self, optimizer):
-        if self.use_amp:
-            self.scaler.step(optimizer)
-            self.scaler.update()
-        else:
-            optimizer.step()
-
-    def might_unscale_(self, optimizer):
-        if self.use_amp:
-            # Unscales the gradients of optimizer's assigned params in-place
-            self.scaler.unscale_(optimizer)
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/utils.py b/convlab2/nlg/scgpt/utils.py
deleted file mode 100644
index 7fefc166f096c307590fd5b0478c4db1cf551e7f..0000000000000000000000000000000000000000
--- a/convlab2/nlg/scgpt/utils.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Tue Mar 24 18:34:55 2020
-
-@author: truthless
-"""
-
-def tuple2dict(t):
-    '''
-    tuple: [(intent, domain, slot, value)]
-    dict: [domain: { intent: [slot, value] }]
-    '''
-    d = {}
-    for intent, domain, slot, value in t:
-        if domain not in d:
-            d[domain] = {}
-        if intent not in d[domain]:
-            d[domain][intent] = []
-        if slot == 'none' or slot is None:
-            continue
-        d[domain][intent].append([slot, value])
-    return d
-
-def dict2dict(D):
-    '''
-    dict: [domain-intent: [slot, value]]
-    dict: [domain: { intent: [slot, value] }]
-    '''
-    d = {}
-    for domint in D:
-        domain, intent = domint.split('-')
-        if domain not in d:
-            d[domain] = {}
-        if intent not in d[domain]:
-            d[domain][intent] = []
-        for slot, value in D[domint]:
-            if slot == 'none' or slot is None:
-                continue
-            d[domain][intent].append([slot, value])
-    return d
-
-def dict2seq(d):
-    '''
-    dict: [domain: { intent: [slot, value] }]
-    seq: [domain { intent ( slot = value ; ) @ } | ]
-    '''
-    s = ''
-    first_domain = True
-    first_intent = True
-    first_slot = True
-    for domain in d:
-        if not first_domain:
-            s += ' | '
-        s += domain
-        s += ' { '
-        first_domain = False
-        first_intent = True
-        for intent in d[domain]:
-            if not first_intent:
-                s += ' @ '
-            s += intent
-            s += ' ( '
-            first_intent = False
-            first_slot = True
-            for slot, value in d[domain][intent]:
-                if not first_slot:
-                    s += ' ; '
-                s += slot
-                if value:
-                    s += ' = '
-                    s += value
-                first_slot = False
-            s += ' )'
-        s += ' }'
-    return s.lower()
-
-def tuple2seq(t):
-    d = tuple2dict(t)
-    s = dict2seq(d)
-    return s
-
-if __name__ == '__main__':
-    da_tuple = [('Inform', 'Booking', 'none', 'none'), ('Inform', 'Hotel', 'Price', 'cheap'), ('Inform', 'Hotel', 'Choice', '1'), ('Inform', 'Hotel', 'Parking', 'none')]
-    da_dict = tuple2dict(da_tuple)
-    print(da_dict)
-    da_seq = dict2seq(da_dict)
-    print(da_seq)
-
-    da_tuple = [('Request', 'Hotel', 'Address', '?'), ('Request', 'Hotel', 'Area', '?'), ('Inform', 'Attraction', 'Area', 'center'), ('Inform', 'Hotel', 'Price', 'cheap')]
-    da_dict = tuple2dict(da_tuple)
-    print(da_dict)
-    da_seq = dict2seq(da_dict)
-    print(da_seq)
-
-    D = {'Hotel-Inform': [['Price', 'cheap'], ['Type', 'hotel']]}
-    da_dict = dict2dict(D)
-    print(da_dict)
-