From 58f9a0d4586bf97cf9195c7a3b5ad8eb4c713d90 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Tue, 13 Dec 2022 12:07:05 +0800
Subject: [PATCH] update nlu interface: auto preprocess for jointbert;
 transform da output for t5nlu

---
 convlab/base_models/t5/nlu/nlu.py             | 11 +++++++----
 convlab/nlu/jointBERT/unified_datasets/nlu.py |  5 ++++-
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/convlab/base_models/t5/nlu/nlu.py b/convlab/base_models/t5/nlu/nlu.py
index 4162aa3f..ef679ffa 100755
--- a/convlab/base_models/t5/nlu/nlu.py
+++ b/convlab/base_models/t5/nlu/nlu.py
@@ -1,4 +1,4 @@
-import logging
+import logging
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
 from convlab.nlu.nlu import NLU
@@ -38,8 +38,11 @@ class T5NLU(NLU):
         # print(output_seq)
         output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
         # print(output_seq)
-        dialogue_acts = deserialize_dialogue_acts(output_seq.strip())
-        return dialogue_acts
+        das = deserialize_dialogue_acts(output_seq.strip())
+        dialog_act = []
+        for da in das:
+            dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value','')])
+        return dialog_act
 
 
 if __name__ == '__main__':
@@ -67,7 +70,7 @@ if __name__ == '__main__':
         "I am not sure of the type of food but could you please check again and see if you can find it? Thank you.",
         "Could you double check that you've spelled the name correctly? The closest I can find is Nandos."]
     ]
-    nlu = T5NLU(speaker='user', context_window_size=3, model_name_or_path='output/nlu/multiwoz21/user/context_3')
+    nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')
     for text, context in zip(texts, contexts):
         print(text)
         print(nlu.predict(text, context))
diff --git a/convlab/nlu/jointBERT/unified_datasets/nlu.py b/convlab/nlu/jointBERT/unified_datasets/nlu.py
index 85cf56d9..311f0b6f 100755
--- a/convlab/nlu/jointBERT/unified_datasets/nlu.py
+++ b/convlab/nlu/jointBERT/unified_datasets/nlu.py
@@ -7,6 +7,7 @@ import transformers
 from convlab.nlu.nlu import NLU
 from convlab.nlu.jointBERT.dataloader import Dataloader
 from convlab.nlu.jointBERT.jointBERT import JointBERT
+from convlab.nlu.jointBERT.unified_datasets.preprocess import preprocess
 from convlab.nlu.jointBERT.unified_datasets.postprocess import recover_intent
 from convlab.util.custom_util import model_downloader
 
@@ -25,7 +26,9 @@ class BERTNLU(NLU):
         data_dir = os.path.join(root_dir, config['data_dir'])
         output_dir = os.path.join(root_dir, config['output_dir'])
 
-        assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), print('Please run preprocess first')
+        if not os.path.exists(os.path.join(data_dir, 'intent_vocab.json')):
+            print('Run preprocess first')
+            preprocess(config['dataset_name'], data_dir.split('/')[-2], os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'), int(data_dir.split('/')[-1].split('_')[-1]))
 
         intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
         tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
-- 
GitLab