From 036ad7651b40a1ea461c4884fac9bdef424c13c6 Mon Sep 17 00:00:00 2001 From: zz-jacob <zhangz.goal@gmail.com> Date: Wed, 20 Apr 2022 16:40:00 +0800 Subject: [PATCH] delete redundancy codes --- convlab2/nlg/scgpt/main.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/convlab2/nlg/scgpt/main.py b/convlab2/nlg/scgpt/main.py index 9f1ed581..8f459182 100644 --- a/convlab2/nlg/scgpt/main.py +++ b/convlab2/nlg/scgpt/main.py @@ -9,7 +9,6 @@ import torch.nn as nn import torch.nn.functional as F from transformers import GPT2Tokenizer, GPT2LMHeadModel from torch.utils.data import DataLoader -from torch.utils.data import Dataset from torch.utils.tensorboard import SummaryWriter import os from transformers import get_linear_schedule_with_warmup @@ -40,7 +39,7 @@ torch.cuda.set_device(local_rank) dist.init_process_group(backend='nccl') # TensorBoard -tb_writer = SummaryWriter() +tb_writer = SummaryWriter('') special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK] ## load model @@ -105,7 +104,6 @@ if code_test: VAL_STEP = 2 WARM_STEPS = 3 LR = 5e-5 -TASK_TYPE = 'nlu' # nlu or dst SAVE_PATH = f'./saved_model' def train(model, nlg_data, global_step=0): train_dataset = SCGPTDataset(nlg_data['train'], tokenizer) @@ -212,7 +210,6 @@ def test(model, nlg_data, ontology, model_path): model.load_state_dict(torch.load(model_path)) model.eval() print(f'model loaded from [{model_path}]') - # sample_file = os.path.join(f'../../../data/dstc2/sample50_{TASK_TYPE}_input_data.txt') # Load test nlg data test_data = nlg_data['test'] dialog_acts = [act2str(item['dialogue_acts']) for item in test_data] -- GitLab