diff --git a/.gitignore b/.gitignore index 7ad53b1758697e6040fba129c8a2f16c09245906..2edbed136e2ef1fc6f4a66b5746b72763001a22a 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,7 @@ convlab2/dst/sumbt/multiwoz/output/ convlab2/nlg/sclstm/**/generated_sens_sys.json convlab2/nlg/template/**/generated_sens_sys.json convlab2/nlu/jointBERT/crosswoz/**/data +convlab2/nlu/jointBERT/multiwoz/**/data # test script *_test.py diff --git a/convlab2/dst/sumbt/multiwoz_zh/convert_to_glue_format.py b/convlab2/dst/sumbt/multiwoz_zh/convert_to_glue_format.py index 210cdf07bde79294495d8a8cdb666b7e4ddaaa13..719bbc44e52933113e70335b92c3197dcd86ec1c 100644 --- a/convlab2/dst/sumbt/multiwoz_zh/convert_to_glue_format.py +++ b/convlab2/dst/sumbt/multiwoz_zh/convert_to_glue_format.py @@ -123,7 +123,7 @@ def convert_to_glue_format(data_dir, sumbt_dir): continue # not defined in ontology value = data[file_id]['log'][idx]['metadata'][domain]['book'][slot].strip() - value = trans_value(value, value) + value = trans_value(value) if str('预订' + slot) not in ontology[domain]: print("预订%s is not defined in domain %s" % (slot, domain)) diff --git a/convlab2/dst/sumbt/multiwoz_zh/sumbt.py b/convlab2/dst/sumbt/multiwoz_zh/sumbt.py index 4aa1575b4ee9bafc987c7ffe404c40835a5c0810..cd99f3924ca07c19a5436c25dfdfc59087998f9b 100644 --- a/convlab2/dst/sumbt/multiwoz_zh/sumbt.py +++ b/convlab2/dst/sumbt/multiwoz_zh/sumbt.py @@ -567,7 +567,7 @@ class SUMBTTracker(DST): print('loading weights from trained model') self.load_weights(model_path=os.path.join(SUMBT_PATH, args.output_dir, 'pytorch_model.bin')) else: - raise ValueError('no availabel weights found.') + raise ValueError('no available weights found.') self.param_restored = True def update(self, user_act=None):