Commit 2e0a2c24 authored by zhuqi, committed by GitHub (unverified)

Maintenance (#119)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization
parent 1f8ed9fd
@@ -18079,7 +18079,7 @@
"价格是[Inform+酒店+价格],地址是[Inform+酒店+地址]。",
"地址是在[Inform+酒店+地址],价格[Inform+酒店+价格]。",
"地址在[Inform+酒店+地址];价格是[Inform+酒店+价格]。",
"价格[Inform+酒店+价格];[Inform+酒店+地址]。",
"价格[Inform+酒店+价格];地址在[Inform+酒店+地址]。",
"这家酒店的地址是[Inform+酒店+地址],价格是[Inform+酒店+价格]。",
"不贵,[Inform+酒店+价格],地址在[Inform+酒店+地址]。",
"它的价格是[Inform+酒店+价格],然后地址是[Inform+酒店+地址]。",
@@ -20988,7 +20988,7 @@
"[Inform+酒店+名称]的评分是[Inform+酒店+评分],地址是[Inform+酒店+地址]。"
],
"Inform+酒店+名称*Inform+酒店+地址*Inform+酒店+酒店类型": [
"给您推荐[Inform+酒店+名称],它[Inform+酒店+地址],为[Inform+酒店+酒店类型]酒店。",
"给您推荐[Inform+酒店+名称],它位于[Inform+酒店+地址],为[Inform+酒店+酒店类型]酒店。",
"[Inform+酒店+名称]是[Inform+酒店+酒店类型]酒店,地址是[Inform+酒店+地址]。",
"[Inform+酒店+名称]是一家[Inform+酒店+酒店类型]的酒店,酒店的地址是[Inform+酒店+地址]。",
"[Inform+酒店+名称]是[Inform+酒店+酒店类型]酒店,地址是[Inform+酒店+地址]。",
@@ -24308,7 +24308,7 @@
"地址是在[Inform+酒店+地址]。",
"地址是在[Inform+酒店+地址]。",
"地址在[Inform+酒店+地址]。",
"[Inform+酒店+地址]。",
"地址在[Inform+酒店+地址]。",
"地址是[Inform+酒店+地址]。",
"地址是[Inform+酒店+地址]。",
"地址是[Inform+酒店+地址]。",
@@ -4,12 +4,13 @@ import json
import torch
from unidecode import unidecode
import spacy
-from convlab2.util.file_util import cached_path
+from convlab2.util.file_util import cached_path, get_root_path
from convlab2.nlu.nlu import NLU
from convlab2.nlu.jointBERT.dataloader import Dataloader
from convlab2.nlu.jointBERT.jointBERT import JointBERT
from convlab2.nlu.jointBERT.multiwoz.postprocess import recover_intent
from convlab2.nlu.jointBERT.multiwoz.preprocess import preprocess
+from spacy.symbols import ORTH, LEMMA, POS
class BERTNLU(NLU):
@@ -56,11 +57,18 @@ class BERTNLU(NLU):
        self.use_context = config['model']['context']
        self.dataloader = dataloader
        self.nlp = spacy.load('en_core_web_sm')
+        with open(os.path.join(get_root_path(), 'data/multiwoz/db/postcode.json'), 'r') as f:
+            token_list = json.load(f)
+        for token in token_list:
+            token = token.strip()
+            self.nlp.tokenizer.add_special_case(token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}])
        print("BERTNLU loaded")

    def predict(self, utterance, context=list()):
        # ori_word_seq = unidecode(utterance).split()
        # tokenization first, very important!
        ori_word_seq = [token.text for token in self.nlp(unidecode(utterance)) if token.text.strip()]
        # print(ori_word_seq)
        ori_tag_seq = ['O'] * len(ori_word_seq)
        if self.use_context:
            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
@@ -94,9 +102,9 @@
if __name__ == '__main__':
-    text = "I will need you departure and arrival city and time ."
+    text = "How about rosa's bed and breakfast ? Their postcode is cb22ha."
    nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
                  model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip')
    print(nlu.predict(text, context=['', "I ' m looking for a train leaving on tuesday please ."]))
-    text = "I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant."
-    print(nlu.predict(text))
+    # text = "I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant."
+    # print(nlu.predict(text))
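The block added to __init__ registers every MultiWOZ postcode as a spaCy tokenizer special case, so an utterance like the new test sentence keeps cb22ha as one token for the BERT tagger instead of being split. A standalone sketch of the mechanism (spaCy v2-style API, matching what ConvLab-2 pins; the postcode is one entry from postcode.json):

import spacy
from spacy.symbols import ORTH, LEMMA, POS

nlp = spacy.load('en_core_web_sm')
# without the special case, the tokenizer is free to split the postcode;
# with it, 'cb22ha' is guaranteed to surface as a single token
nlp.tokenizer.add_special_case('cb22ha', [{ORTH: 'cb22ha', LEMMA: 'cb22ha', POS: 'NOUN'}])
print([t.text for t in nlp('Their postcode is cb22ha.')])
# expected: ['Their', 'postcode', 'is', 'cb22ha', '.']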
@@ -12,11 +12,11 @@ from allennlp.data import DatasetReader
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.models.archival import load_archive
-from convlab2.util.file_util import cached_path
+from convlab2.util.file_util import cached_path, get_root_path
from convlab2.nlu.nlu import NLU
from convlab2.nlu.milu import dataset_reader, model
-from spacy.symbols import ORTH, LEMMA
import json
+from spacy.symbols import ORTH, LEMMA, POS
DEFAULT_CUDA_DEVICE = -1
DEFAULT_DIRECTORY = "models"
@@ -47,6 +47,12 @@ class MILU(NLU):
        self.tokenizer = SpacyWordSplitter(language="en_core_web_sm")
        _special_case = [{ORTH: u"id", LEMMA: u"id"}]
        self.tokenizer.spacy.tokenizer.add_special_case(u"id", _special_case)
+        with open(os.path.join(get_root_path(), 'data/multiwoz/db/postcode.json'), 'r') as f:
+            token_list = json.load(f)
+        for token in token_list:
+            token = token.strip()
+            self.tokenizer.spacy.tokenizer.add_special_case(token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}])

        dataset_reader_params = archive.config["dataset_reader"]
        self.dataset_reader = DatasetReader.from_params(dataset_reader_params)
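The MILU change mirrors the BERTNLU one. AllenNLP's SpacyWordSplitter keeps its underlying spaCy pipeline in a .spacy attribute (already used above for the "id" special case), so the same registration applies to the tokenizer it wraps. A standalone sketch against the allennlp 0.x API:

from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from spacy.symbols import ORTH, LEMMA, POS

splitter = SpacyWordSplitter(language='en_core_web_sm')
splitter.spacy.tokenizer.add_special_case('cb22ha', [{ORTH: 'cb22ha', LEMMA: 'cb22ha', POS: 'NOUN'}])
print([t.text for t in splitter.split_words('Their postcode is cb22ha.')])
# expected: ['Their', 'postcode', 'is', 'cb22ha', '.']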
@@ ... @@ convlab2/util/file_util.py
from pathlib import Path
import zipfile
import json
+import os
from convlab2.util.allennlp_file_utils import cached_path as allennlp_cached_path
@@ -24,3 +25,7 @@ def dump_json(content, filepath):
def write_zipped_json(zip_path, filepath):
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        zf.write(filepath)
+
+
+def get_root_path():
+    return os.path.abspath(os.path.join(os.path.abspath(__file__), '../../..'))
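get_root_path() climbs from convlab2/util/file_util.py to the repository root: joined onto the file's absolute path, the first '..' cancels the filename and the next two step out of util/ and convlab2/. A worked example with a hypothetical checkout path:

import os

f = '/home/user/ConvLab-2/convlab2/util/file_util.py'  # stands in for os.path.abspath(__file__)
print(os.path.abspath(os.path.join(f, '../../..')))
# -> /home/user/ConvLab-2

Both NLU modules above use this to load data/multiwoz/db/postcode.json independently of the working directory; the new script below is what generates that file.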
import json
import os


def main():
    dir_path = os.path.dirname(os.path.abspath(__file__))
    postcode = []
    for domain in ['attraction', 'hotel', 'hospital', 'police', 'restaurant']:
        db = json.load(open(os.path.join(dir_path, "{}_db.json".format(domain))))
        for entry in db:
            if entry['postcode'] not in postcode:
                postcode.append(entry['postcode'])
    json.dump(postcode, open(os.path.join(dir_path, "postcode.json"), 'w'), indent=2)


if __name__ == '__main__':
    main()
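The script collects the postcode field from the five MultiWOZ domain databases and writes the deduplicated list shown below as postcode.json (the script's own filename is not visible on this page). A quick sanity check of the output, run from the same db directory:

import json

postcodes = json.load(open('postcode.json'))
print(len(postcodes), 'entries,', len(set(postcodes)), 'unique')  # expect equal counts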
[
"cb58nt",
"cb58as",
"cb58bs",
"cb23na",
"cb11ln",
"cb21sj",
"cb30af",
"cb58sx",
"cb30aq",
"cb23pj",
"cb13ef",
"cb39ey",
"cb21su",
"cb58ld",
"cb21jf",
"cb23bj",
"cb18dw",
"cb23bu",
"cb30ds",
"cb17dy",
"cb21tl",
"cb39al",
"cb12jb",
"cb21rh",
"cb21dq",
"cb23ap",
"cb58hy",
"cb15dh",
"cb21ta",
"cb23pq",
"cb23nz",
"cb12ew",
"cb58bl",
"cb43px",
"cb23qb",
"cb21st",
"cb42xh",
"cb21qy",
"cb30ag",
"cb46az",
"cb11pt",
"cb23dz",
"cb39da",
"cb21tt",
"cb11ly",
"cb21rf",
"cb3ojg",
"cb39et",
"cb11er",
"cb43ax",
"cb13ew",
"cb21rl",
"cb21tp",
"cb21er",
"cb21rs",
"cb22ad",
"cb23hu",
"cb23qf",
"cb23qe",
"cb41as",
"cb19ej",
"cb23hx",
"cb21rb",
"cb17gx",
"cb12lf",
"cb23hg",
"cb21tq",
"cb11ps",
"cb223ae",
"cb238el",
"cb23rh",
"cb12lj",
"cb12dp",
"cb41da",
"cb12de",
"cb13js",
"cb41xa",
"cb42je",
"cb43pe",
"cb41er",
"cb58rs",
"cb43pd",
"cb17sr",
"cb28rj",
"cb13nx",
"cb43ht",
"cb12tz",
"cb11eg",
"cb13lh",
"cb30nd",
"cb39lh",
"cb41la",
"pe296fl",
"cb41sr",
"cb22ha",
"cb236bw",
"cb21en",
"cb21ad",
"cb11ee",
"cb20qq",
"cb11jg",
"cb21ab",
"cb259aq",
"cb21dp",
"cb17ag",
"cb17aa",
"cb28pb",
"cb58pa",
"cb23nj",
"cb58jj",
"cb21db",
"cb21uf",
"cb23pp",
"cb21aw",
"cb12qa",
"cb43le",
"cb21rt",
"cb12as",
"cb30ad",
"cb12bd",
"cb21nt",
"cb23ar",
"cb12az",
"cb30ah",
"cb21qa",
"cb23dt",
"cb41jy",
"cb21tw",
"cb21eg",
"cb21rg",
"cb13nf",
"cb21rq",
"cb11lh",
"cb41uy",
"cb11bg",
"cb58wr",
"cb21uw",
"cb23ll",
"cb41eh",
"cb43hl",
"cb43lf",
"cb21la",
"cb58aq",
"cb11dg",
"cb58ba",
"cb41nl",
"cb28nx",
"cb30lx",
"cb30dq",
"cb41ep",
"cb21uj",
"cb11hr",
"cb23ju",
"cb21ug",
"cb30df",
"cb13nl",
"cb58rg",
"cb19hx",
"cb41ha",
"cb21nw",
"cb23jx"
]
\ No newline at end of file
new_goal_model.pkl: no preview for this file type (binary file updated).