"""SOLOIST end-to-end model wrapper: loads a seq2seq checkpoint and generates responses."""
import logging
import torch

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)

from convlab.e2e.soloist.multiwoz.config import global_config as cfg

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    )

def cuda_(var):
    """Move *var* to the GPU iff cfg.cuda is set AND CUDA is available; otherwise return it unchanged."""
    return var.cuda() if cfg.cuda and torch.cuda.is_available() else var


def tensor(var):
    """Build a torch tensor from *var* on the device selected by cuda_()."""
    return cuda_(torch.tensor(var))


class SOLOIST:
    """Thin inference wrapper around a pretrained seq2seq LM configured via ``cfg``.

    Loads the checkpoint named by ``cfg.model_name_or_path`` and decodes with
    nucleus sampling using ``cfg.max_length`` / ``cfg.top_p``.
    """

    def __init__(self) -> None:
        # Model config + weights come from the project-configured checkpoint;
        # the tokenizer is fixed to t5-base (presumably matching the checkpoint
        # vocabulary — TODO confirm against training setup).
        self.config = AutoConfig.from_pretrained(cfg.model_name_or_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            cfg.model_name_or_path, config=self.config)
        self.tokenizer = AutoTokenizer.from_pretrained('t5-base')
        logger.info('model loaded!')

        # BUGFIX: previously the model was moved to GPU whenever CUDA was
        # available, ignoring cfg.cuda — while tensor() honors cfg.cuda.
        # With cfg.cuda=False on a CUDA machine that split model and inputs
        # across devices and crashed generate(). Route through cuda_() so
        # model and input tensors always land on the same device.
        self.model = cuda_(self.model)

    def generate(self, inputs):
        """Generate a response string for the single input string *inputs*.

        Returns the first (only) decoded sequence, with special tokens stripped.
        """
        self.model.eval()
        encoded = self.tokenizer([inputs])
        input_ids = tensor(encoded['input_ids'])
        # no_grad: pure inference — skip building the autograd graph.
        with torch.no_grad():
            generated_tokens = self.model.generate(
                input_ids=input_ids,
                max_length=cfg.max_length,
                # BUGFIX: top_p is silently ignored by HF generate() unless
                # sampling is enabled; do_sample=True makes cfg.top_p take
                # effect as intended (nucleus sampling).
                do_sample=True,
                top_p=cfg.top_p,
            )
        decoded_preds = self.tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True)
        return decoded_preds[0]