From e8ae8881081d0f38e23c4269326f966c0d9b2958 Mon Sep 17 00:00:00 2001
From: xw <48146603+xwwwwww@users.noreply.github.com>
Date: Thu, 24 Sep 2020 15:28:39 +0800
Subject: [PATCH] nlu update and bugfix (#118)

* jointBERT_new available && fix milu dataset_reader && fix jointBERT/tag2id

* remove jointBERT_new

* update milu/multiwoz/nlu.py model_file path
---
 .gitignore                           | 11 ++++++++++-
 convlab2/nlu/jointBERT/dataloader.py |  3 ++-
 convlab2/nlu/milu/dataset_reader.py  |  6 ++++--
 convlab2/nlu/milu/multiwoz/nlu.py    |  2 +-
 convlab2/nlu/milu/train.py           |  3 ++-
 5 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index b8268b1..cd35620 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,12 +33,21 @@ convlab2/nlg/sclstm/**/sclstm.log
 convlab2/nlg/sclstm/**/sclstm_usr.pt
 convlab2/nlg/sclstm/**/sclstm_usr.res
 convlab2/nlg/sclstm/**/sclstm_usr.log
-convlab2/nlu/jointBERT/**/output/
 convlab2/dst/sumbt/multiwoz/output/
 convlab2/nlg/sclstm/**/generated_sens_sys.json
 convlab2/nlg/template/**/generated_sens_sys.json
 convlab2/nlu/jointBERT/crosswoz/**/data
 convlab2/nlu/jointBERT/multiwoz/**/data
+convlab2/nlu/jointBERT/**/output/
+convlab2/nlu/jointBERT_new/crosswoz/**/data
+convlab2/nlu/jointBERT_new/multiwoz/**/data
+convlab2/nlu/jointBERT_new/crosswoz/**/log
+convlab2/nlu/jointBERT_new/multiwoz/**/log
+convlab2/nlu/jointBERT_new/**/output/
+convlab2/nlu/milu/09*
+convlab2/nlu/jointBERT/multiwoz/configs/multiwoz_new_usr_context.json
+convlab2/nlu/milu/multiwoz/configs/system_without_context.jsonnet
+convlab2/nlu/milu/multiwoz/configs/user_without_context.jsonnet
 
 # test script
 *_test.py
diff --git a/convlab2/nlu/jointBERT/dataloader.py b/convlab2/nlu/jointBERT/dataloader.py
index fba4ebf..38fc24e 100755
--- a/convlab2/nlu/jointBERT/dataloader.py
+++ b/convlab2/nlu/jointBERT/dataloader.py
@@ -57,6 +57,7 @@ class Dataloader:
                 new2ori = None
             d.append(new2ori)
             d.append(word_seq)
+
             d.append(self.seq_tag2id(tag_seq))
             d.append(self.seq_intent2id(d[2]))
             # d = (tokens, tags, intents, da2triples(turn["dialog_act"]), context(token id), new2ori, new_word_seq, tag2id_seq, intent2id_seq)
@@ -95,7 +96,7 @@ class Dataloader:
         return split_tokens, new_tag_seq, new2ori
 
     def seq_tag2id(self, tags):
-        return [self.tag2id[x] for x in tags if x in self.tag2id]
+        return [self.tag2id[x] if x in self.tag2id else self.tag2id['O'] for x in tags]
 
     def seq_id2tag(self, ids):
         return [self.id2tag[x] for x in ids]
diff --git a/convlab2/nlu/milu/dataset_reader.py b/convlab2/nlu/milu/dataset_reader.py
index 3a8cf77..5e00af0 100755
--- a/convlab2/nlu/milu/dataset_reader.py
+++ b/convlab2/nlu/milu/dataset_reader.py
@@ -75,9 +75,11 @@ class MILUDatasetReader(DatasetReader):
             dialog = dialogs[dial_name]["log"]
             context_tokens_list = []
             for i, turn in enumerate(dialog):
-                if self._agent and self._agent == "user" and i % 2 != 1: 
+                if self._agent and self._agent == "user" and i % 2 == 1: 
+                    context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
                     continue
-                if self._agent and self._agent == "system" and i % 2 != 0: 
+                if self._agent and self._agent == "system" and i % 2 == 0: 
+                    context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
                     continue
 
                 tokens = turn["text"].split()
diff --git a/convlab2/nlu/milu/multiwoz/nlu.py b/convlab2/nlu/milu/multiwoz/nlu.py
index 5417c6d..002a7dc 100755
--- a/convlab2/nlu/milu/multiwoz/nlu.py
+++ b/convlab2/nlu/milu/multiwoz/nlu.py
@@ -28,7 +28,7 @@ class MILU(NLU):
     def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
-                model_file="https://convlab.blob.core.windows.net/convlab-2/milu_multiwoz_all_context.tar.gz",
+                model_file="https://convlab.blob.core.windows.net/convlab-2/new_milu(20200922)_multiwoz_all_context.tar.gz",
                 context_size=3):
         """ Constructor for NLU class. """
 
diff --git a/convlab2/nlu/milu/train.py b/convlab2/nlu/milu/train.py
index 99db49f..9507a3a 100755
--- a/convlab2/nlu/milu/train.py
+++ b/convlab2/nlu/milu/train.py
@@ -16,7 +16,8 @@ from allennlp.common.checks import check_for_gpu
 from allennlp.common.util import prepare_environment, prepare_global_logging, cleanup_global_logging, dump_metrics
 from allennlp.models.archival import archive_model, CONFIG_NAME
 from allennlp.models.model import Model, _DEFAULT_WEIGHTS
-from allennlp.training.trainer import Trainer, TrainerPieces
+from allennlp.training.trainer import Trainer
+from allennlp.training.trainer_pieces import TrainerPieces
 from allennlp.training.trainer_base import TrainerBase
 from allennlp.training.util import create_serialization_dir, evaluate
 
-- 
GitLab