diff --git a/convlab2/base_models/bert/create_data.py b/convlab2/base_models/bert/create_data.py index f85f9d1115d0607a12f01ad6a85fdc690cf32263..825c736d2259552017166c93db168069fe0a6976 100644 --- a/convlab2/base_models/bert/create_data.py +++ b/convlab2/base_models/bert/create_data.py @@ -41,16 +41,63 @@ def create_bio_data(dataset, data_dir): labels[word_start] = 'B' for i in range(word_start+1, word_end): labels[i] = "I" - data.append(json.dumps({'tokens': tokens, 'labels': labels})+'\n') + data.append(json.dumps({'tokens': tokens, 'labels': labels}, ensure_ascii=False)+'\n') file_name = os.path.join(data_dir, f"{data_split}.json") - with open(file_name, "w") as f: + with open(file_name, "w", encoding='utf-8') as f: + f.writelines(data) + print('num of spans in utterances', cnt) + +def create_dialogBIO_data(dataset, data_dir): + data_by_split = load_nlu_data(dataset, split_to_turn=False) + os.makedirs(data_dir, exist_ok=True) + + sent_tokenizer = PunktSentenceTokenizer() + word_tokenizer = TreebankWordTokenizer() + + data_splits = data_by_split.keys() + cnt = Counter() + for data_split in data_splits: + data = [] + for dialog in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False): + all_tokens, all_labels = [], [] + for sample in dialog['turns']: + speaker = sample['speaker'] + utterance = sample['utterance'] + dialogue_acts = [da for da in sample['dialogue_acts']['non-categorical'] if 'start' in da] + cnt[len(dialogue_acts)] += 1 + + sentences = sent_tokenizer.tokenize(utterance) + sent_spans = sent_tokenizer.span_tokenize(utterance) + tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)] + token_spans = [(sent_span[0]+token_span[0], sent_span[0]+token_span[1]) for sent, sent_span in zip(sentences, sent_spans) for token_span in word_tokenizer.span_tokenize(sent)] + labels = ['O'] * len(tokens) + for da in dialogue_acts: + char_start = da['start'] + char_end = da['end'] + word_start, word_end = -1, -1 + for i, token_span in enumerate(token_spans): + if char_start == token_span[0]: + word_start = i + if char_end == token_span[1]: + word_end = i + 1 + if word_start == -1 and word_end == -1: + # char span does not match word, skip + continue + labels[word_start] = 'B' + for i in range(word_start+1, word_end): + labels[i] = "I" + all_tokens.extend([speaker, ':']+tokens) + all_labels.extend(['O', 'O']+labels) + data.append(json.dumps({'tokens': all_tokens, 'labels': all_labels}, ensure_ascii=False)+'\n') + file_name = os.path.join(data_dir, f"{data_split}.json") + with open(file_name, "w", encoding='utf-8') as f: f.writelines(data) print('num of spans in utterances', cnt) if __name__ == '__main__': from argparse import ArgumentParser parser = ArgumentParser(description="create data for seq2seq training") - parser.add_argument('--tasks', metavar='task_name', nargs='*', choices=['bio'], help='names of tasks') + parser.add_argument('--tasks', metavar='task_name', nargs='*', choices=['bio', 'dialogBIO'], help='names of tasks') parser.add_argument('--datasets', metavar='dataset_name', nargs='*', help='names of unified datasets') parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data', help='directory to save the data, default: data/$task_name/$dataset_name') args = parser.parse_args() diff --git a/convlab2/base_models/bert/infer_bio.sh b/convlab2/base_models/bert/infer_bio.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed784c515c6703088313da0809b7c0442bcec333 --- /dev/null +++ 
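The `create_bio_data` / `create_dialogBIO_data` logic above only turns a dialogue act's character span into B/I tags when the span boundaries line up with token boundaries; otherwise the span is skipped. Below is a minimal, self-contained sketch of that conversion using the same NLTK tokenizers the script relies on. The helper name `char_span_to_bio` and the example spans are illustrative, not part of the module.

```python
# Sketch of the char-span -> BIO alignment performed in create_bio_data /
# create_dialogBIO_data. The helper name is illustrative only.
from nltk.tokenize import PunktSentenceTokenizer, TreebankWordTokenizer

sent_tokenizer = PunktSentenceTokenizer()
word_tokenizer = TreebankWordTokenizer()

def char_span_to_bio(utterance, char_spans):
    """Tokenize `utterance` and mark tokens covered by (start, end) char spans with B/I."""
    sentences = sent_tokenizer.tokenize(utterance)
    sent_spans = sent_tokenizer.span_tokenize(utterance)
    tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
    # Word offsets are relative to their sentence, so shift them by the sentence offset.
    token_spans = [(s_start + t_start, s_start + t_end)
                   for sent, (s_start, _) in zip(sentences, sent_spans)
                   for t_start, t_end in word_tokenizer.span_tokenize(sent)]
    labels = ['O'] * len(tokens)
    for char_start, char_end in char_spans:
        word_start = next((i for i, (s, _) in enumerate(token_spans) if s == char_start), -1)
        word_end = next((i + 1 for i, (_, e) in enumerate(token_spans) if e == char_end), -1)
        if word_start == -1 or word_end == -1:
            continue  # char span does not match token boundaries, skip it
        labels[word_start] = 'B'
        for i in range(word_start + 1, word_end):
            labels[i] = 'I'
    return tokens, labels

tokens, labels = char_span_to_bio("i want a cheap hotel in the north", [(9, 14), (28, 33)])
# tokens: ['i', 'want', 'a', 'cheap', 'hotel', 'in', 'the', 'north']
# labels: ['O', 'O', 'O', 'B', 'O', 'O', 'O', 'B']
```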
b/convlab2/base_models/bert/infer_bio.sh @@ -0,0 +1,38 @@ +set -e +n_gpus=3 +task_name="dialogBIO" +dataset_name="multiwoz21" +data_dir="data/${task_name}/${dataset_name}" +output_dir="output/${task_name}/${dataset_name}" +cache_dir="cache" +logging_dir="${output_dir}/runs" +source_column="tokens" +target_column="labels" +model_name_or_path="output/dialogBIO/sgd" +per_device_eval_batch_size=16 + +python create_data.py --tasks ${task_name} --datasets ${dataset_name} --save_dir "data" + +for split in test validation train +do + python -m torch.distributed.launch \ + --nproc_per_node ${n_gpus} run_token_classification.py \ + --task_name ${task_name} \ + --train_file ${data_dir}/${split}.json \ + --validation_file ${data_dir}/${split}.json \ + --test_file ${data_dir}/${split}.json \ + --source_column ${source_column} \ + --target_column ${target_column} \ + --model_name_or_path ${model_name_or_path} \ + --do_predict \ + --cache_dir ${cache_dir} \ + --output_dir ${output_dir} \ + --logging_dir ${logging_dir} \ + --overwrite_output_dir \ + --preprocessing_num_workers 4 \ + --per_device_eval_batch_size ${per_device_eval_batch_size} \ + --debug underflow_overflow + + mv ${output_dir}/predictions.json ${output_dir}/${split}.json +done + diff --git a/convlab2/base_models/bert/run_token_classification.py b/convlab2/base_models/bert/run_token_classification.py index 57e156d768b3ebf84514cc73187f7092d5e7082a..c97fc60aa49a50d42a8470522d2dfaa09227b2ce 100644 --- a/convlab2/base_models/bert/run_token_classification.py +++ b/convlab2/base_models/bert/run_token_classification.py @@ -472,7 +472,7 @@ def main(): data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) # Metrics - metric = load_metric("seqeval") + metric = load_metric(path="seqeval_metric.py") def compute_metrics(p: EvalPrediction): predictions, labels = p diff --git a/convlab2/base_models/bert/seqeval_metric.py b/convlab2/base_models/bert/seqeval_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..58a6d07a2675f51a3d0ca8e56c015c4a856b3eda --- /dev/null +++ b/convlab2/base_models/bert/seqeval_metric.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" seqeval metric. 
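The `run_token_classification.py` hunk above switches `load_metric` from the Hub copy of `seqeval` to the bundled `seqeval_metric.py`, so evaluation no longer needs a metric download at runtime. A hedged usage sketch, assuming `datasets==1.16.1` (as pinned in requirements.txt) and that `seqeval_metric.py` sits in the working directory; the example tag sequences are illustrative:

```python
# Sketch: load the local metric script instead of the hub "seqeval" metric.
from datasets import load_metric

metric = load_metric(path="seqeval_metric.py")

# Untyped B/I/O tags, as produced for the bio / dialogBIO tasks above.
predictions = [['O', 'B', 'I', 'O'], ['B', 'O']]
references  = [['O', 'B', 'I', 'O'], ['B', 'I']]

results = metric.compute(predictions=predictions, references=references)
# Per the metric docstring, overall_* keys plus one entry per entity type are returned.
print(results["overall_f1"], results["overall_accuracy"])
```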
""" + +import importlib +from typing import List, Optional, Union + +from seqeval.metrics import accuracy_score, classification_report + +import datasets + + +_CITATION = """\ +@inproceedings{ramshaw-marcus-1995-text, + title = "Text Chunking using Transformation-Based Learning", + author = "Ramshaw, Lance and + Marcus, Mitch", + booktitle = "Third Workshop on Very Large Corpora", + year = "1995", + url = "https://www.aclweb.org/anthology/W95-0107", +} +@misc{seqeval, + title={{seqeval}: A Python framework for sequence labeling evaluation}, + url={https://github.com/chakki-works/seqeval}, + note={Software available from https://github.com/chakki-works/seqeval}, + author={Hiroki Nakayama}, + year={2018}, +} +""" + +_DESCRIPTION = """\ +seqeval is a Python framework for sequence labeling evaluation. +seqeval can evaluate the performance of chunking tasks such as named-entity recognition, part-of-speech tagging, semantic role labeling and so on. +This is well-tested by using the Perl script conlleval, which can be used for +measuring the performance of a system that has processed the CoNLL-2000 shared task data. +seqeval supports following formats: +IOB1 +IOB2 +IOE1 +IOE2 +IOBES +See the [README.md] file at https://github.com/chakki-works/seqeval for more information. +""" + +_KWARGS_DESCRIPTION = """ +Produces labelling scores along with its sufficient statistics +from a source against one or more references. +Args: + predictions: List of List of predicted labels (Estimated targets as returned by a tagger) + references: List of List of reference labels (Ground truth (correct) target values) + suffix: True if the IOB prefix is after type, False otherwise. default: False + scheme: Specify target tagging scheme. Should be one of ["IOB1", "IOB2", "IOE1", "IOE2", "IOBES", "BILOU"]. + default: None + mode: Whether to count correct entity labels with incorrect I/B tags as true positives or not. + If you want to only count exact matches, pass mode="strict". default: None. + sample_weight: Array-like of shape (n_samples,), weights for individual samples. default: None + zero_division: Which value to substitute as a metric value when encountering zero division. Should be on of 0, 1, + "warn". "warn" acts as 0, but the warning is raised. +Returns: + 'scores': dict. 
Summary of the scores for overall and per type + Overall: + 'accuracy': accuracy, + 'precision': precision, + 'recall': recall, + 'f1': F1 score, also known as balanced F-score or F-measure, + Per type: + 'precision': precision, + 'recall': recall, + 'f1': F1 score, also known as balanced F-score or F-measure +Examples: + >>> predictions = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] + >>> references = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] + >>> seqeval = datasets.load_metric("seqeval") + >>> results = seqeval.compute(predictions=predictions, references=references) + >>> print(list(results.keys())) + ['MISC', 'PER', 'overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy'] + >>> print(results["overall_f1"]) + 0.5 + >>> print(results["PER"]["f1"]) + 1.0 +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Seqeval(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + homepage="https://github.com/chakki-works/seqeval", + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"), + "references": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"), + } + ), + codebase_urls=["https://github.com/chakki-works/seqeval"], + reference_urls=["https://github.com/chakki-works/seqeval"], + ) + + def _compute( + self, + predictions, + references, + suffix: bool = False, + scheme: Optional[str] = None, + mode: Optional[str] = None, + sample_weight: Optional[List[int]] = None, + zero_division: Union[str, int] = "warn", + ): + if scheme is not None: + try: + scheme_module = importlib.import_module("seqeval.scheme") + scheme = getattr(scheme_module, scheme) + except AttributeError: + raise ValueError(f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {scheme}") + report = classification_report( + y_true=references, + y_pred=predictions, + suffix=suffix, + output_dict=True, + scheme=scheme, + mode=mode, + sample_weight=sample_weight, + zero_division=zero_division, + ) + report.pop("macro avg") + report.pop("weighted avg") + overall_score = report.pop("micro avg") + + scores = { + type_name: { + "precision": score["precision"], + "recall": score["recall"], + "f1": score["f1-score"], + "number": score["support"], + } + for type_name, score in report.items() + } + scores["overall_precision"] = overall_score["precision"] + scores["overall_recall"] = overall_score["recall"] + scores["overall_f1"] = overall_score["f1-score"] + scores["overall_accuracy"] = accuracy_score(y_true=references, y_pred=predictions) + + return scores \ No newline at end of file diff --git a/convlab2/base_models/bert/run_bio.sh b/convlab2/base_models/bert/train_bio.sh similarity index 92% rename from convlab2/base_models/bert/run_bio.sh rename to convlab2/base_models/bert/train_bio.sh index 3d2e797c3020ea8d6e0ddcded1b900fb1141b384..db2ee860d2464c57dfb20d57a54ea5b34cda85b1 100644 --- a/convlab2/base_models/bert/run_bio.sh +++ b/convlab2/base_models/bert/train_bio.sh @@ -1,5 +1,5 @@ -n_gpus=8 -task_name="bio" +n_gpus=3 +task_name="dialogBIO" dataset_name="sgd" data_dir="data/${task_name}/${dataset_name}" output_dir="output/${task_name}/${dataset_name}" @@ -11,9 +11,9 @@ test_file="${data_dir}/test.json" source_column="tokens" target_column="labels" model_name_or_path="bert-base-uncased" 
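`Seqeval._compute` above forwards seqeval's `mode` and `scheme` arguments and resolves the scheme name to a class in `seqeval.scheme` via `importlib`. A hedged sketch of strict IOB2 matching; the tag values are illustrative:

```python
# Sketch: strict entity matching with an explicit tagging scheme.
from datasets import load_metric

metric = load_metric(path="seqeval_metric.py")

predictions = [['B-PER', 'I-PER', 'O', 'B-LOC']]
references  = [['B-PER', 'I-PER', 'O', 'B-LOC']]

results = metric.compute(
    predictions=predictions,
    references=references,
    scheme="IOB2",   # resolved to seqeval.scheme.IOB2 inside _compute
    mode="strict",   # only exact entity matches count as true positives
)
# Predictions equal references here, so per-type and overall scores are all 1.0.
print(results["PER"], results["overall_f1"])
```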
-per_device_train_batch_size=128 -per_device_eval_batch_size=512 -gradient_accumulation_steps=1 +per_device_train_batch_size=8 +per_device_eval_batch_size=16 +gradient_accumulation_steps=2 lr=2e-5 num_train_epochs=1 metric_for_best_model="f1" diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py index bc760d7e6fafb4a7db7e90bf4e364b04dd14c77a..6e637826c4bdc2a9fead90b71c4f969ea8a92408 100644 --- a/convlab2/base_models/t5/create_data.py +++ b/convlab2/base_models/t5/create_data.py @@ -16,10 +16,10 @@ def create_rg_data(dataset, data_dir): for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False): context = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]) response = f"{sample['speaker']}: {sample['utterance']}" - data.append(json.dumps({'context': context, 'response': response})+'\n') + data.append(json.dumps({'context': context, 'response': response}, ensure_ascii=False)+'\n') file_name = os.path.join(data_dir, f"{data_split}.json") - with open(file_name, "w") as f: + with open(file_name, "w", encoding='utf-8') as f: f.writelines(data) if __name__ == '__main__': diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py index a7f4a2f47ce87804af3af6cf234dbcb196570e0e..aaef4470845bb400fed28ecab1ef164ffa37d4b2 100644 --- a/convlab2/base_models/t5/run_seq2seq.py +++ b/convlab2/base_models/t5/run_seq2seq.py @@ -604,10 +604,10 @@ def main(): ) predictions = [pred.strip() for pred in predictions] output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.json") - with open(output_prediction_file, "w") as writer: + with open(output_prediction_file, "w", encoding='utf-8') as writer: for sample, pred in zip(raw_datasets["test"], predictions): sample["predictions"] = pred - writer.write(json.dumps(sample)+'\n') + writer.write(json.dumps(sample, ensure_ascii=False)+'\n') kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": data_args.task_name} if data_args.dataset_name is not None: diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py index 165b29af217bf28f4e3d993859ccbec3989a518b..e4344bd838785dda7a3736c37d7577a2887fd9d7 100644 --- a/convlab2/util/unified_datasets_util.py +++ b/convlab2/util/unified_datasets_util.py @@ -85,12 +85,14 @@ def load_unified_data( context_window_size=0, terminated=False, goal=False, - active_domains=False + active_domains=False, + split_to_turn=True ): data_splits = dataset.keys() if data_split == 'all' else [data_split] assert speaker in ['user', 'system', 'all'] assert not use_context or context_window_size > 0 info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results'])) + info_list += ['utt_idx'] data_by_split = {} for data_split in data_splits: data_by_split[data_split] = [] @@ -102,11 +104,11 @@ def load_unified_data( if ele in turn: sample[ele] = turn[ele] - if use_context: + if use_context or not split_to_turn: sample_copy = deepcopy(sample) context.append(sample_copy) - if speaker == turn['speaker'] or speaker == 'all': + if split_to_turn and speaker in [turn['speaker'], 'all']: if use_context: sample['context'] = context[-context_window_size-1:-1] if goal: @@ -116,6 +118,9 @@ def load_unified_data( if terminated: sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1 data_by_split[data_split].append(sample) + if not split_to_turn: + dialogue['turns'] = context + data_by_split[data_split].append(dialogue) return 
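With `split_to_turn=False`, `load_unified_data` above stops emitting one sample per user turn: it accumulates the per-turn dicts in `context` and returns whole dialogues whose `turns` field holds them, which is exactly what `create_dialogBIO_data` iterates over. A hedged usage sketch; only `load_nlu_data(..., split_to_turn=False)` appears in this diff, so the `load_dataset` import location and the field list are assumptions based on how `create_dialogBIO_data` consumes the data:

```python
# Sketch: dialogue-level loading as used by create_dialogBIO_data.
# The load_dataset import path is an assumption; load_nlu_data is defined in
# convlab2/util/unified_datasets_util.py per this diff.
from convlab2.util.unified_datasets_util import load_dataset, load_nlu_data

dataset = load_dataset('multiwoz21')
data_by_split = load_nlu_data(dataset, split_to_turn=False)

dialog = data_by_split['test'][0]
for turn in dialog['turns']:
    # Each accumulated turn is expected to keep 'speaker', 'utterance',
    # 'utt_idx' and 'dialogue_acts', matching the access pattern above.
    spans = [da for da in turn['dialogue_acts']['non-categorical'] if 'start' in da]
    print(turn['utt_idx'], turn['speaker'], len(spans))
```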
data_by_split def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs): diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b9fecf8056275f8430c88d292bbb26565cba59a5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,213 @@ +absl-py==1.0.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +allennlp==2.8.0 +argon2-cffi==21.1.0 +argon2-cffi-bindings==21.2.0 +async-timeout==4.0.1 +attrs==21.2.0 +autopep8==1.6.0 +backcall==0.2.0 +backports.csv==1.0.7 +base58==2.1.1 +beautifulsoup4==4.10.0 +bert-score==0.3.11 +bleach==4.1.0 +blis==0.7.5 +boto3==1.20.14 +botocore==1.23.14 +cached-path==0.3.2 +cachetools==4.2.4 +catalogue==2.0.6 +certifi==2021.10.8 +cffi==1.15.0 +chardet==4.0.0 +charset-normalizer==2.0.8 +checklist==0.0.11 +cheroot==8.5.2 +CherryPy==18.6.1 +click==8.0.3 +colorama==0.4.4 +configparser==5.1.0 +cryptography==36.0.0 +cycler==0.11.0 +cymem==2.0.6 +datasets==1.16.1 +debugpy==1.5.1 +decorator==5.1.0 +deepspeech==0.9.3 +defusedxml==0.7.1 +dill==0.3.4 +docker-pycreds==0.4.0 +embeddings==0.0.8 +entrypoints==0.3 +fairscale==0.4.0 +feedparser==6.0.8 +filelock==3.3.2 +fonttools==4.28.5 +frozenlist==1.2.0 +fsspec==2021.11.1 +future==0.18.2 +fuzzywuzzy==0.18.0 +gitdb==4.0.9 +GitPython==3.1.24 +google-api-core==2.2.2 +google-auth==2.3.3 +google-cloud-core==2.2.1 +google-cloud-storage==1.43.0 +google-crc32c==1.3.0 +google-resumable-media==2.1.0 +googleapis-common-protos==1.53.0 +gTTS==2.2.3 +h5py==3.6.0 +huggingface-hub==0.1.2 +idna==3.3 +iniconfig==1.1.1 +ipykernel==6.5.1 +ipython==7.30.0 +ipython-genutils==0.2.0 +ipywidgets==7.6.5 +iso-639==0.4.5 +jaraco.classes==3.2.1 +jaraco.collections==3.4.0 +jaraco.functools==3.4.0 +jaraco.text==3.6.0 +jedi==0.18.1 +jieba==0.42.1 +Jinja2==3.0.3 +jmespath==0.10.0 +joblib==1.1.0 +json-lines==0.5.0 +jsonnet==0.17.0 +jsonpatch==1.32 +jsonpointer==2.2 +jsonschema==4.2.1 +jupyter==1.0.0 +jupyter-client==7.1.0 +jupyter-console==6.4.0 +jupyter-core==4.9.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.2 +kiwisolver==1.3.2 +langcodes==3.3.0 +lmdb==1.2.1 +lxml==4.6.4 +MarkupSafe==2.0.1 +matplotlib==3.5.1 +matplotlib-inline==0.1.3 +mistune==0.8.4 +more-itertools==8.12.0 +multidict==5.2.0 +multiprocess==0.70.12.2 +munch==2.5.0 +murmurhash==1.0.6 +nbclient==0.5.9 +nbconvert==6.3.0 +nbformat==5.1.3 +nest-asyncio==1.5.1 +nltk==3.6.5 +notebook==6.4.6 +numpy==1.21.4 +overrides==3.1.0 +packaging==21.3 +pandas==1.3.4 +pandocfilters==1.5.0 +parso==0.8.2 +pathtools==0.1.2 +pathy==0.6.1 +patternfork-nosql==3.6 +pdfminer.six==20211012 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.4.0 +pipdeptree==2.2.0 +pluggy==1.0.0 +portalocker==2.3.2 +portend==3.1.0 +preshed==3.0.6 +prometheus-client==0.12.0 +promise==2.3 +prompt-toolkit==3.0.23 +protobuf==3.19.1 +psutil==5.8.0 +ptyprocess==0.7.0 +py==1.11.0 +pyarrow==6.0.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pycparser==2.21 +pydantic==1.8.2 +pydub==0.25.1 +Pygments==2.10.0 +pyparsing==3.0.6 +pyrsistent==0.18.0 +pytest==6.2.5 +python-dateutil==2.8.2 +python-docx==0.8.11 +python-Levenshtein==0.12.2 +pytokenizations==0.8.4 +pytz==2021.3 +PyYAML==6.0 +pyzmq==22.3.0 +qtconsole==5.2.1 +QtPy==1.11.2 +quadprog==0.1.10 +regex==2021.11.10 +requests==2.26.0 +rouge-score==0.0.4 +rsa==4.8 +s3transfer==0.5.0 +sacrebleu==2.0.0 +sacremoses==0.0.46 +scikit-learn==1.0.1 +scipy==1.7.3 +Send2Trash==1.8.0 +sentencepiece==0.1.96 +sentry-sdk==1.5.0 +seqeval==1.2.2 +sgmllib3k==1.0.0 +shortuuid==1.0.8 +simplejson==3.17.6 +six==1.16.0 
+smart-open==5.2.1 +smmap==5.0.0 +soupsieve==2.3.1 +spacy==3.1.4 +spacy-legacy==3.0.8 +spacy-loggers==1.0.1 +sqlitedict==1.7.0 +srsly==2.4.2 +subprocess32==3.5.4 +tabulate==0.8.9 +tempora==4.1.2 +tensorboardX==2.4.1 +termcolor==1.1.0 +terminado==0.12.1 +testpath==0.5.0 +thinc==8.0.13 +threadpoolctl==3.0.0 +tokenizers==0.10.3 +toml==0.10.2 +torch==1.8.1+cu101 +torchfile==0.1.0 +torchvision==0.9.1+cu101 +tornado==6.1 +tqdm==4.62.3 +traitlets==5.1.1 +transformers==4.12.5 +typer==0.4.0 +typing_extensions==4.0.0 +Unidecode==1.3.2 +urllib3==1.26.7 +visdom==0.1.8.9 +wandb==0.12.7 +wasabi==0.8.2 +wcwidth==0.2.5 +webencodings==0.5.1 +websocket-client==1.2.1 +widgetsnbextension==3.5.2 +xxhash==2.0.2 +yarl==1.7.2 +yaspin==2.1.0 +zc.lockfile==2.0
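The recurring `ensure_ascii=False` and `encoding='utf-8'` changes in the data-creation and prediction-writing code above keep non-ASCII utterances human-readable in the JSON-lines files instead of being written as escape sequences. A small illustration of the difference (plain Python behaviour, not ConvLab-specific; the utterance is illustrative):

```python
import json

utterance = "好的，帮我订一间酒店"  # illustrative non-ASCII utterance

print(json.dumps({'utterance': utterance}))
# {"utterance": "\u597d\u7684\uff0c\u5e2e\u6211\u8ba2\u4e00\u95f4\u9152\u5e97"}

print(json.dumps({'utterance': utterance}, ensure_ascii=False))
# {"utterance": "好的，帮我订一间酒店"}

# An explicit file encoding avoids depending on the platform default locale.
with open("sample.json", "w", encoding="utf-8") as f:
    f.write(json.dumps({'utterance': utterance}, ensure_ascii=False) + "\n")
```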