Skip to content
Snippets Groups Projects
Commit f5c0a840 authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

SCGPT default device GPU

parent 7c7f914b
Branches
No related tags found
No related merge requests found
......@@ -8,12 +8,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from convlab.nlg.nlg import NLG
from convlab.nlg.scgpt.util import act2str
from convlab.nlg.scgpt.scgpt_special_tokens import *
class SCGPT(NLG):
def __init__(self, dataset_name, model_path, device='cpu'):
def __init__(self, dataset_name, model_path, device='gpu'):
super(SCGPT, self).__init__()
self.dataset_name = dataset_name
self.device = device
self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment