Skip to content
Snippets Groups Projects
Commit 036ad765 authored by zz-jacob's avatar zz-jacob
Browse files

delete redundancy codes

parent b2cc2e40
Branches
No related tags found
No related merge requests found
......@@ -9,7 +9,6 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import os
from transformers import get_linear_schedule_with_warmup
......@@ -40,7 +39,7 @@ torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
# TensorBoard
tb_writer = SummaryWriter()
tb_writer = SummaryWriter('')
special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
## load model
......@@ -105,7 +104,6 @@ if code_test:
VAL_STEP = 2
WARM_STEPS = 3
LR = 5e-5
TASK_TYPE = 'nlu' # nlu or dst
SAVE_PATH = f'./saved_model'
def train(model, nlg_data, global_step=0):
train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
......@@ -212,7 +210,6 @@ def test(model, nlg_data, ontology, model_path):
model.load_state_dict(torch.load(model_path))
model.eval()
print(f'model loaded from [{model_path}]')
# sample_file = os.path.join(f'../../../data/dstc2/sample50_{TASK_TYPE}_input_data.txt')
# Load test nlg data
test_data = nlg_data['test']
dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment