Skip to content
Snippets Groups Projects
Commit 036ad765 authored by zz-jacob's avatar zz-jacob
Browse files

delete redundancy codes

parent b2cc2e40
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,6 @@ import torch.nn as nn ...@@ -9,7 +9,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import os import os
from transformers import get_linear_schedule_with_warmup from transformers import get_linear_schedule_with_warmup
...@@ -40,7 +39,7 @@ torch.cuda.set_device(local_rank) ...@@ -40,7 +39,7 @@ torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl') dist.init_process_group(backend='nccl')
# TensorBoard # TensorBoard
tb_writer = SummaryWriter() tb_writer = SummaryWriter('')
special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK] special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
## load model ## load model
...@@ -105,7 +104,6 @@ if code_test: ...@@ -105,7 +104,6 @@ if code_test:
VAL_STEP = 2 VAL_STEP = 2
WARM_STEPS = 3 WARM_STEPS = 3
LR = 5e-5 LR = 5e-5
TASK_TYPE = 'nlu' # nlu or dst
SAVE_PATH = f'./saved_model' SAVE_PATH = f'./saved_model'
def train(model, nlg_data, global_step=0): def train(model, nlg_data, global_step=0):
train_dataset = SCGPTDataset(nlg_data['train'], tokenizer) train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
...@@ -212,7 +210,6 @@ def test(model, nlg_data, ontology, model_path): ...@@ -212,7 +210,6 @@ def test(model, nlg_data, ontology, model_path):
model.load_state_dict(torch.load(model_path)) model.load_state_dict(torch.load(model_path))
model.eval() model.eval()
print(f'model loaded from [{model_path}]') print(f'model loaded from [{model_path}]')
# sample_file = os.path.join(f'../../../data/dstc2/sample50_{TASK_TYPE}_input_data.txt')
# Load test nlg data # Load test nlg data
test_data = nlg_data['test'] test_data = nlg_data['test']
dialog_acts = [act2str(item['dialogue_acts']) for item in test_data] dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment