diff --git a/convlab/e2e/soloist/multiwoz/soloist.py b/convlab/e2e/soloist/multiwoz/soloist.py
index fea24f32dc1d56d205814164541e80ba0c48d321..580a14737d79b17b2e67aa3744611b3dd6272c53 100644
--- a/convlab/e2e/soloist/multiwoz/soloist.py
+++ b/convlab/e2e/soloist/multiwoz/soloist.py
@@ -10,7 +10,7 @@ from nltk.tokenize import word_tokenize
 
 from convlab.util.file_util import cached_path
 from convlab.e2e.soloist.multiwoz.config import global_config as cfg
-from convlab.e2e.soloist.multiwoz.soloist_net import SOLOIST, cuda_
+from convlab.e2e.soloist.multiwoz.soloist_net import SOLOIST
 from convlab.dialog_agent import Agent
 from utils import MultiWozReader
 
diff --git a/convlab/e2e/soloist/multiwoz/soloist_net.py b/convlab/e2e/soloist/multiwoz/soloist_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f23d106076603e6dcd8b33638195dfd8a51a424
--- /dev/null
+++ b/convlab/e2e/soloist/multiwoz/soloist_net.py
@@ -0,0 +1,48 @@
+import logging
+import torch
+
+from transformers import (
+    AutoConfig,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer
+)
+
+from convlab.e2e.soloist.multiwoz.config import global_config as cfg
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
+def cuda_(var):
+    return var.cuda() if cfg.cuda and torch.cuda.is_available() else var
+
+
+def tensor(var):
+    return cuda_(torch.tensor(var))
+
+class SOLOIST:
+
+    def __init__(self) -> None:
+        
+        self.config = AutoConfig.from_pretrained(cfg.model_name_or_path)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name_or_path,config=self.config)
+        self.tokenizer = AutoTokenizer.from_pretrained('t5-base')
+        print('model loaded!')
+
+        self.model = self.model.cuda() if torch.cuda.is_available() else self.model
+
+    def generate(self, inputs):
+
+        self.model.eval()
+        inputs = self.tokenizer([inputs])
+        input_ids = tensor(inputs['input_ids'])
+        generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, top_p=cfg.top_p)
+        decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+
+        return decoded_preds[0]
+
+    
+    
\ No newline at end of file