Skip to content
Snippets Groups Projects
Unverified Commit e8ae8881 authored by xw, committed by GitHub
Browse files

nlu update and bugfix (#118)

* jointBERT_new available && fix milu dataset_reader && fix jointBERT/tag2id

* remove jointBERT_new

* update milu/multiwoz/nlu.py model_file path
parent 8e8b830b
No related branches found
No related tags found
No related merge requests found
...@@ -33,12 +33,21 @@ convlab2/nlg/sclstm/**/sclstm.log ...@@ -33,12 +33,21 @@ convlab2/nlg/sclstm/**/sclstm.log
convlab2/nlg/sclstm/**/sclstm_usr.pt convlab2/nlg/sclstm/**/sclstm_usr.pt
convlab2/nlg/sclstm/**/sclstm_usr.res convlab2/nlg/sclstm/**/sclstm_usr.res
convlab2/nlg/sclstm/**/sclstm_usr.log convlab2/nlg/sclstm/**/sclstm_usr.log
convlab2/nlu/jointBERT/**/output/
convlab2/dst/sumbt/multiwoz/output/ convlab2/dst/sumbt/multiwoz/output/
convlab2/nlg/sclstm/**/generated_sens_sys.json convlab2/nlg/sclstm/**/generated_sens_sys.json
convlab2/nlg/template/**/generated_sens_sys.json convlab2/nlg/template/**/generated_sens_sys.json
convlab2/nlu/jointBERT/crosswoz/**/data convlab2/nlu/jointBERT/crosswoz/**/data
convlab2/nlu/jointBERT/multiwoz/**/data convlab2/nlu/jointBERT/multiwoz/**/data
convlab2/nlu/jointBERT/**/output/
convlab2/nlu/jointBERT_new/crosswoz/**/data
convlab2/nlu/jointBERT_new/multiwoz/**/data
convlab2/nlu/jointBERT_new/crosswoz/**/log
convlab2/nlu/jointBERT_new/multiwoz/**/log
convlab2/nlu/jointBERT_new/**/output/
convlab2/nlu/milu/09*
convlab2/nlu/jointBERT/multiwoz/configs/multiwoz_new_usr_context.json
convlab2/nlu/milu/multiwoz/configs/system_without_context.jsonnet
convlab2/nlu/milu/multiwoz/configs/user_without_context.jsonnet
# test script # test script
*_test.py *_test.py
......
...@@ -57,6 +57,7 @@ class Dataloader: ...@@ -57,6 +57,7 @@ class Dataloader:
new2ori = None new2ori = None
d.append(new2ori) d.append(new2ori)
d.append(word_seq) d.append(word_seq)
d.append(self.seq_tag2id(tag_seq)) d.append(self.seq_tag2id(tag_seq))
d.append(self.seq_intent2id(d[2])) d.append(self.seq_intent2id(d[2]))
# d = (tokens, tags, intents, da2triples(turn["dialog_act"]), context(token id), new2ori, new_word_seq, tag2id_seq, intent2id_seq) # d = (tokens, tags, intents, da2triples(turn["dialog_act"]), context(token id), new2ori, new_word_seq, tag2id_seq, intent2id_seq)
...@@ -95,7 +96,7 @@ class Dataloader: ...@@ -95,7 +96,7 @@ class Dataloader:
return split_tokens, new_tag_seq, new2ori return split_tokens, new_tag_seq, new2ori
def seq_tag2id(self, tags): def seq_tag2id(self, tags):
return [self.tag2id[x] for x in tags if x in self.tag2id] return [self.tag2id[x] if x in self.tag2id else self.tag2id['O'] for x in tags]
def seq_id2tag(self, ids): def seq_id2tag(self, ids):
return [self.id2tag[x] for x in ids] return [self.id2tag[x] for x in ids]
......
...@@ -75,9 +75,11 @@ class MILUDatasetReader(DatasetReader): ...@@ -75,9 +75,11 @@ class MILUDatasetReader(DatasetReader):
dialog = dialogs[dial_name]["log"] dialog = dialogs[dial_name]["log"]
context_tokens_list = [] context_tokens_list = []
for i, turn in enumerate(dialog): for i, turn in enumerate(dialog):
if self._agent and self._agent == "user" and i % 2 != 1: if self._agent and self._agent == "user" and i % 2 == 1:
context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
continue continue
if self._agent and self._agent == "system" and i % 2 != 0: if self._agent and self._agent == "system" and i % 2 == 0:
context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
continue continue
tokens = turn["text"].split() tokens = turn["text"].split()
......
...@@ -28,7 +28,7 @@ class MILU(NLU): ...@@ -28,7 +28,7 @@ class MILU(NLU):
def __init__(self, def __init__(self,
archive_file=DEFAULT_ARCHIVE_FILE, archive_file=DEFAULT_ARCHIVE_FILE,
cuda_device=DEFAULT_CUDA_DEVICE, cuda_device=DEFAULT_CUDA_DEVICE,
model_file="https://convlab.blob.core.windows.net/convlab-2/milu_multiwoz_all_context.tar.gz", model_file="https://convlab.blob.core.windows.net/convlab-2/new_milu(20200922)_multiwoz_all_context.tar.gz",
context_size=3): context_size=3):
""" Constructor for NLU class. """ """ Constructor for NLU class. """
......
...@@ -16,7 +16,8 @@ from allennlp.common.checks import check_for_gpu ...@@ -16,7 +16,8 @@ from allennlp.common.checks import check_for_gpu
from allennlp.common.util import prepare_environment, prepare_global_logging, cleanup_global_logging, dump_metrics from allennlp.common.util import prepare_environment, prepare_global_logging, cleanup_global_logging, dump_metrics
from allennlp.models.archival import archive_model, CONFIG_NAME from allennlp.models.archival import archive_model, CONFIG_NAME
from allennlp.models.model import Model, _DEFAULT_WEIGHTS from allennlp.models.model import Model, _DEFAULT_WEIGHTS
from allennlp.training.trainer import Trainer, TrainerPieces from allennlp.training.trainer import Trainer
from allennlp.training.trainer_pieces import TrainerPieces
from allennlp.training.trainer_base import TrainerBase from allennlp.training.trainer_base import TrainerBase
from allennlp.training.util import create_serialization_dir, evaluate from allennlp.training.util import create_serialization_dir, evaluate
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment