diff --git a/convlab2/nlu/jointBERT/multiwoz/nlu.py b/convlab2/nlu/jointBERT/multiwoz/nlu.py index 8b9f0dcf9fb6c4d8bd426b8423f9158e881d0c0c..2418bd27e72e3e45582d9cc05c398629db52e9c1 100755 --- a/convlab2/nlu/jointBERT/multiwoz/nlu.py +++ b/convlab2/nlu/jointBERT/multiwoz/nlu.py @@ -58,6 +58,14 @@ class BERTNLU(NLU): self.use_context = config['model']['context'] self.dataloader = dataloader self.nlp = spacy.load('en_core_web_sm') + try: + self.nlp = spacy.load("en_core_web_sm") + except Exception: + print('download en_core_web_sm for spacy') + from spacy.cli.download import download as spacy_download + spacy_download("en_core_web_sm") + spacy_model_module = __import__("en_core_web_sm") + self.nlp = spacy_model_module.load() with open(os.path.join(get_root_path(), 'data/multiwoz/db/postcode.json'), 'r') as f: token_list = json.load(f)