diff --git a/convlab2/nlg/template/crosswoz/auto_system_template_nlg.json b/convlab2/nlg/template/crosswoz/auto_system_template_nlg.json index a2e40b55533e2ad0c37ceedf4a31f1e194a39f74..d6bd2ecf69c91ba5483b39cd48386c636d9f534c 100755 --- a/convlab2/nlg/template/crosswoz/auto_system_template_nlg.json +++ b/convlab2/nlg/template/crosswoz/auto_system_template_nlg.json @@ -18079,7 +18079,7 @@ "价格是[Inform+酒店+价格],地址是[Inform+酒店+地址]。", "地址是在[Inform+酒店+地址],价格[Inform+酒店+价格]。", "地址在[Inform+酒店+地址];价格是[Inform+酒店+价格]。", - "价格[Inform+酒店+价格];[Inform+酒店+地址]。", + "价格[Inform+酒店+价格];地址在[Inform+酒店+地址]。", "这家酒店的地址是[Inform+酒店+地址],价格是[Inform+酒店+价格]。", "不贵,[Inform+酒店+价格],地址在[Inform+酒店+地址]。", "它的价格是[Inform+酒店+价格],然后地址是[Inform+酒店+地址]。", @@ -20988,7 +20988,7 @@ "[Inform+酒店+名称]的评分是[Inform+酒店+评分],地址是[Inform+酒店+地址]。" ], "Inform+酒店+名称*Inform+酒店+地址*Inform+酒店+酒店类型": [ - "给您推荐[Inform+酒店+名称],它[Inform+酒店+地址],为[Inform+酒店+酒店类型]酒店。", + "给您推荐[Inform+酒店+名称],它位于[Inform+酒店+地址],为[Inform+酒店+酒店类型]酒店。", "[Inform+酒店+名称]是[Inform+酒店+酒店类型]酒店,地址是[Inform+酒店+地址]。", "[Inform+酒店+名称]是一家[Inform+酒店+酒店类型]的酒店,酒店的地址是[Inform+酒店+地址]。", "[Inform+酒店+名称]是[Inform+酒店+酒店类型]酒店,地址是[Inform+酒店+地址]。", @@ -24308,7 +24308,7 @@ "地址是在[Inform+酒店+地址]。", "地址是在[Inform+酒店+地址]。", "地址在[Inform+酒店+地址]。", - "[Inform+酒店+地址]。", + "地址在[Inform+酒店+地址]。", "地址是[Inform+酒店+地址]。", "地址是[Inform+酒店+地址]。", "地址是[Inform+酒店+地址]。", diff --git a/convlab2/nlu/jointBERT/multiwoz/nlu.py b/convlab2/nlu/jointBERT/multiwoz/nlu.py index 17c24d98196eee90523bb8d40d3efc44851f8c5a..e900f1ecf2bbcfc1faa5db1caa7088b6c58154d9 100755 --- a/convlab2/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab2/nlu/jointBERT/multiwoz/nlu.py @@ -4,12 +4,13 @@ import json import torch from unidecode import unidecode import spacy -from convlab2.util.file_util import cached_path +from convlab2.util.file_util import cached_path, get_root_path from convlab2.nlu.nlu import NLU from convlab2.nlu.jointBERT.dataloader import Dataloader from convlab2.nlu.jointBERT.jointBERT import JointBERT from 
convlab2.nlu.jointBERT.multiwoz.postprocess import recover_intent from convlab2.nlu.jointBERT.multiwoz.preprocess import preprocess +from spacy.symbols import ORTH, LEMMA, POS class BERTNLU(NLU): @@ -56,11 +57,18 @@ class BERTNLU(NLU): self.use_context = config['model']['context'] self.dataloader = dataloader self.nlp = spacy.load('en_core_web_sm') + with open(os.path.join(get_root_path(), 'data/multiwoz/db/postcode.json'), 'r') as f: + token_list = json.load(f) + + for token in token_list: + token = token.strip() + self.nlp.tokenizer.add_special_case(token, [{ORTH: token, LEMMA: token, POS: u'NOUN'}]) print("BERTNLU loaded") def predict(self, utterance, context=list()): - # ori_word_seq = unidecode(utterance).split() + # tokenization first, very important! ori_word_seq = [token.text for token in self.nlp(unidecode(utterance)) if token.text.strip()] + # print(ori_word_seq) ori_tag_seq = ['O'] * len(ori_word_seq) if self.use_context: if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1: @@ -94,9 +102,9 @@ class BERTNLU(NLU): if __name__ == '__main__': - text = "I will need you departure and arrival city and time ." + text = "How about rosa's bed and breakfast ? Their postcode is cb22ha." 
nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_all_context.zip') - print(nlu.predict(text, context=['', "I ' m looking for a train leaving on tuesday please ."])) - text = "I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant." 
print(nlu.predict(text)) + # text = "I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant.I don't care about the Price of the restaurant." + # print(nlu.predict(text)) diff --git a/convlab2/nlu/milu/multiwoz/nlu.py b/convlab2/nlu/milu/multiwoz/nlu.py index b5705d13fcfaecf67c2c57708766edabb0894aca..5417c6d958954bf1005d399895bf7e2972379861 100755 --- a/convlab2/nlu/milu/multiwoz/nlu.py +++ b/convlab2/nlu/milu/multiwoz/nlu.py @@ -12,11 +12,11 @@ from allennlp.data import DatasetReader from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter from allennlp.models.archival import load_archive -from convlab2.util.file_util import cached_path +from convlab2.util.file_util import cached_path, get_root_path from convlab2.nlu.nlu import NLU from convlab2.nlu.milu import dataset_reader, model - -from spacy.symbols import ORTH, LEMMA +import json +from spacy.symbols import ORTH, LEMMA, POS DEFAULT_CUDA_DEVICE = -1 DEFAULT_DIRECTORY = "models" @@ -47,6 +47,12 @@ class MILU(NLU): self.tokenizer = SpacyWordSplitter(language="en_core_web_sm") _special_case = [{ORTH: u"id", LEMMA: u"id"}] self.tokenizer.spacy.tokenizer.add_special_case(u"id", _special_case) + with 
# --- convlab2/util/file_util.py -------------------------------------------

def get_root_path():
    """Return the absolute path three levels above this source file.

    The path is computed lexically: ``'../../..'`` is joined onto this
    file's absolute path and the result is normalised with
    ``os.path.abspath``. The first ``..`` cancels the file name itself, so
    the result is the grandparent of the directory containing this file.
    For ``convlab2/util/file_util.py`` that is the directory holding the
    ``convlab2`` package.
    """
    return os.path.abspath(os.path.join(os.path.abspath(__file__), '../../..'))


# --- data/multiwoz/db/extract_postcode4tokenize.py ------------------------

def main(db_dir=None,
         domains=('attraction', 'hotel', 'hospital', 'police', 'restaurant')):
    """Collect the distinct postcodes of the domain DBs into ``postcode.json``.

    For every domain, ``<db_dir>/<domain>_db.json`` is loaded (it must hold a
    list of dict entries) and each entry's ``'postcode'`` value is recorded
    the first time it is seen. The resulting list, in first-seen order, is
    written to ``<db_dir>/postcode.json`` with an indent of 2.

    Args:
        db_dir: Directory containing the ``*_db.json`` files and receiving
            ``postcode.json``. Defaults to the directory of this script.
        domains: Domain names whose DB files are scanned, in order.

    Raises:
        OSError: If a DB file cannot be opened or the output cannot be
            written.
        KeyError: If a DB entry has no ``'postcode'`` key.
    """
    if db_dir is None:
        dir_path = os.path.dirname(os.path.abspath(__file__))
    else:
        dir_path = db_dir

    postcode = []
    # A set gives O(1) membership tests; the list keeps first-seen order.
    seen = set()
    for domain in domains:
        db_path = os.path.join(dir_path, "{}_db.json".format(domain))
        # Context manager guarantees the handle is closed after parsing.
        with open(db_path, encoding='utf-8') as db_file:
            db = json.load(db_file)
        for entry in db:
            code = entry['postcode']
            if code not in seen:
                seen.add(code)
                postcode.append(code)

    out_path = os.path.join(dir_path, "postcode.json")
    # Closing the output file explicitly ensures the JSON is flushed to disk.
    with open(out_path, 'w', encoding='utf-8') as out_file:
        json.dump(postcode, out_file, indent=2)


if __name__ == '__main__':
    main()
0000000000000000000000000000000000000000..1967363bc67fffbcd00dead652c1b9b3d7e20738 --- /dev/null +++ b/data/multiwoz/db/postcode.json @@ -0,0 +1,163 @@ +[ + "cb58nt", + "cb58as", + "cb58bs", + "cb23na", + "cb11ln", + "cb21sj", + "cb30af", + "cb58sx", + "cb30aq", + "cb23pj", + "cb13ef", + "cb39ey", + "cb21su", + "cb58ld", + "cb21jf", + "cb23bj", + "cb18dw", + "cb23bu", + "cb30ds", + "cb17dy", + "cb21tl", + "cb39al", + "cb12jb", + "cb21rh", + "cb21dq", + "cb23ap", + "cb58hy", + "cb15dh", + "cb21ta", + "cb23pq", + "cb23nz", + "cb12ew", + "cb58bl", + "cb43px", + "cb23qb", + "cb21st", + "cb42xh", + "cb21qy", + "cb30ag", + "cb46az", + "cb11pt", + "cb23dz", + "cb39da", + "cb21tt", + "cb11ly", + "cb21rf", + "cb3ojg", + "cb39et", + "cb11er", + "cb43ax", + "cb13ew", + "cb21rl", + "cb21tp", + "cb21er", + "cb21rs", + "cb22ad", + "cb23hu", + "cb23qf", + "cb23qe", + "cb41as", + "cb19ej", + "cb23hx", + "cb21rb", + "cb17gx", + "cb12lf", + "cb23hg", + "cb21tq", + "cb11ps", + "cb223ae", + "cb238el", + "cb23rh", + "cb12lj", + "cb12dp", + "cb41da", + "cb12de", + "cb13js", + "cb41xa", + "cb42je", + "cb43pe", + "cb41er", + "cb58rs", + "cb43pd", + "cb17sr", + "cb28rj", + "cb13nx", + "cb43ht", + "cb12tz", + "cb11eg", + "cb13lh", + "cb30nd", + "cb39lh", + "cb41la", + "pe296fl", + "cb41sr", + "cb22ha", + "cb236bw", + "cb21en", + "cb21ad", + "cb11ee", + "cb20qq", + "cb11jg", + "cb21ab", + "cb259aq", + "cb21dp", + "cb17ag", + "cb17aa", + "cb28pb", + "cb58pa", + "cb23nj", + "cb58jj", + "cb21db", + "cb21uf", + "cb23pp", + "cb21aw", + "cb12qa", + "cb43le", + "cb21rt", + "cb12as", + "cb30ad", + "cb12bd", + "cb21nt", + "cb23ar", + "cb12az", + "cb30ah", + "cb21qa", + "cb23dt", + "cb41jy", + "cb21tw", + "cb21eg", + "cb21rg", + "cb13nf", + "cb21rq", + "cb11lh", + "cb41uy", + "cb11bg", + "cb58wr", + "cb21uw", + "cb23ll", + "cb41eh", + "cb43hl", + "cb43lf", + "cb21la", + "cb58aq", + "cb11dg", + "cb58ba", + "cb41nl", + "cb28nx", + "cb30lx", + "cb30dq", + "cb41ep", + "cb21uj", + "cb11hr", + "cb23ju", + 
"cb21ug", + "cb30df", + "cb13nl", + "cb58rg", + "cb19hx", + "cb41ha", + "cb21nw", + "cb23jx" +] \ No newline at end of file diff --git a/data/multiwoz/goal/new_goal_model.pkl b/data/multiwoz/goal/new_goal_model.pkl index c1bfa6bf2243797f5c7a8b3fc64ede6516322480..ffee4b297b2cd85598e3e33cc4bf16749c86511c 100644 Binary files a/data/multiwoz/goal/new_goal_model.pkl and b/data/multiwoz/goal/new_goal_model.pkl differ