Commit d38bfcae authored by zqwerty

Train BERT for BIO tagging; add UTF-8 encoding when dumping JSON.

parent 301e106f
@@ -41,16 +41,63 @@ def create_bio_data(dataset, data_dir):
                 labels[word_start] = 'B'
                 for i in range(word_start+1, word_end):
                     labels[i] = "I"
-            data.append(json.dumps({'tokens': tokens, 'labels': labels})+'\n')
+            data.append(json.dumps({'tokens': tokens, 'labels': labels}, ensure_ascii=False)+'\n')
         file_name = os.path.join(data_dir, f"{data_split}.json")
-        with open(file_name, "w") as f:
+        with open(file_name, "w", encoding='utf-8') as f:
             f.writelines(data)
         print('num of spans in utterances', cnt)
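The recurring change in this commit pairs ensure_ascii=False in json.dumps with an explicit encoding='utf-8' when opening the output file. A minimal sketch of the effect (the record below is illustrative, not taken from the data):

import json

record = {'tokens': ['北京', 'hotel'], 'labels': ['B', 'O']}
# By default json.dumps escapes non-ASCII characters, e.g. "\u5317\u4eac"
print(json.dumps(record))
# ensure_ascii=False keeps the characters readable, so the file must be
# opened with an explicit utf-8 encoding rather than the platform default
with open('sample.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(record, ensure_ascii=False) + '\n')

The larger addition is create_dialogBIO_data, which builds BIO tags over whole dialogues rather than single utterances: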
def create_dialogBIO_data(dataset, data_dir):
    data_by_split = load_nlu_data(dataset, split_to_turn=False)
    os.makedirs(data_dir, exist_ok=True)

    sent_tokenizer = PunktSentenceTokenizer()
    word_tokenizer = TreebankWordTokenizer()

    data_splits = data_by_split.keys()
    cnt = Counter()
    for data_split in data_splits:
        data = []
        for dialog in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
            all_tokens, all_labels = [], []
            for sample in dialog['turns']:
                speaker = sample['speaker']
                utterance = sample['utterance']
                dialogue_acts = [da for da in sample['dialogue_acts']['non-categorical'] if 'start' in da]
                cnt[len(dialogue_acts)] += 1

                sentences = sent_tokenizer.tokenize(utterance)
                sent_spans = sent_tokenizer.span_tokenize(utterance)
                tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
                token_spans = [(sent_span[0]+token_span[0], sent_span[0]+token_span[1])
                               for sent, sent_span in zip(sentences, sent_spans)
                               for token_span in word_tokenizer.span_tokenize(sent)]
                labels = ['O'] * len(tokens)
                for da in dialogue_acts:
                    char_start = da['start']
                    char_end = da['end']
                    word_start, word_end = -1, -1
                    for i, token_span in enumerate(token_spans):
                        if char_start == token_span[0]:
                            word_start = i
                        if char_end == token_span[1]:
                            word_end = i + 1
                    if word_start == -1 or word_end == -1:
                        # char span does not align with token boundaries, skip
                        # ('or' rather than 'and': a half-matched span would
                        # otherwise write a stray label via labels[-1])
                        continue
                    labels[word_start] = 'B'
                    for i in range(word_start+1, word_end):
                        labels[i] = "I"
                all_tokens.extend([speaker, ':']+tokens)
                all_labels.extend(['O', 'O']+labels)
            data.append(json.dumps({'tokens': all_tokens, 'labels': all_labels}, ensure_ascii=False)+'\n')
        file_name = os.path.join(data_dir, f"{data_split}.json")
        with open(file_name, "w", encoding='utf-8') as f:
            f.writelines(data)
        print('num of spans in utterances', cnt)
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="create data for seq2seq training")
-    parser.add_argument('--tasks', metavar='task_name', nargs='*', choices=['bio'], help='names of tasks')
+    parser.add_argument('--tasks', metavar='task_name', nargs='*', choices=['bio', 'dialogBIO'], help='names of tasks')
     parser.add_argument('--datasets', metavar='dataset_name', nargs='*', help='names of unified datasets')
     parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data', help='directory to save the data, default: data/$task_name/$dataset_name')
     args = parser.parse_args()
...
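For illustration, one line of the resulting {data_split}.json would look roughly like this (hand-constructed, not taken from a dataset): every turn's tokens are appended after a speaker token and a colon, which are both labelled 'O':

{"tokens": ["user", ":", "I", "need", "a", "cheap", "hotel"], "labels": ["O", "O", "O", "O", "O", "B", "O"]}

The new inference script below then runs a tagger previously trained on sgd over every split of multiwoz21: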
set -e
n_gpus=3
task_name="dialogBIO"
dataset_name="multiwoz21"
data_dir="data/${task_name}/${dataset_name}"
output_dir="output/${task_name}/${dataset_name}"
cache_dir="cache"
logging_dir="${output_dir}/runs"
source_column="tokens"
target_column="labels"
model_name_or_path="output/dialogBIO/sgd"
per_device_eval_batch_size=16
python create_data.py --tasks ${task_name} --datasets ${dataset_name} --save_dir "data"
for split in test validation train
do
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} run_token_classification.py \
--task_name ${task_name} \
--train_file ${data_dir}/${split}.json \
--validation_file ${data_dir}/${split}.json \
--test_file ${data_dir}/${split}.json \
--source_column ${source_column} \
--target_column ${target_column} \
--model_name_or_path ${model_name_or_path} \
--do_predict \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--debug underflow_overflow
mv ${output_dir}/predictions.json ${output_dir}/${split}.json
done
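Each pass of the loop points --train_file, --validation_file, and --test_file at the same split file but only runs --do_predict, so the model simply tags that split; the mv then stores predictions.json under the split's name. A sketch of consuming those files, assuming run_token_classification.py writes one JSON object per line (mirroring the other writers in this commit):

import json
import os

output_dir = 'output/dialogBIO/multiwoz21'
for split in ['test', 'validation', 'train']:
    with open(os.path.join(output_dir, f'{split}.json'), encoding='utf-8') as f:
        predictions = [json.loads(line) for line in f]
    print(split, len(predictions))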
@@ -472,7 +472,7 @@ def main():
     data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)

     # Metrics
-    metric = load_metric("seqeval")
+    metric = load_metric(path="seqeval_metric.py")

     def compute_metrics(p: EvalPrediction):
         predictions, labels = p
...
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" seqeval metric. """
import importlib
from typing import List, Optional, Union
from seqeval.metrics import accuracy_score, classification_report
import datasets
_CITATION = """\
@inproceedings{ramshaw-marcus-1995-text,
title = "Text Chunking using Transformation-Based Learning",
author = "Ramshaw, Lance and
Marcus, Mitch",
booktitle = "Third Workshop on Very Large Corpora",
year = "1995",
url = "https://www.aclweb.org/anthology/W95-0107",
}
@misc{seqeval,
title={{seqeval}: A Python framework for sequence labeling evaluation},
url={https://github.com/chakki-works/seqeval},
note={Software available from https://github.com/chakki-works/seqeval},
author={Hiroki Nakayama},
year={2018},
}
"""
_DESCRIPTION = """\
seqeval is a Python framework for sequence labeling evaluation.
seqeval can evaluate the performance of chunking tasks such as named-entity recognition, part-of-speech tagging, semantic role labeling and so on.
This is well-tested by using the Perl script conlleval, which can be used for
measuring the performance of a system that has processed the CoNLL-2000 shared task data.
seqeval supports the following formats:
IOB1
IOB2
IOE1
IOE2
IOBES
See the [README.md] file at https://github.com/chakki-works/seqeval for more information.
"""
_KWARGS_DESCRIPTION = """
Produces labelling scores along with their sufficient statistics
from a source against one or more references.
Args:
predictions: List of List of predicted labels (Estimated targets as returned by a tagger)
references: List of List of reference labels (Ground truth (correct) target values)
suffix: True if the IOB prefix is after type, False otherwise. default: False
scheme: Specify target tagging scheme. Should be one of ["IOB1", "IOB2", "IOE1", "IOE2", "IOBES", "BILOU"].
default: None
mode: Whether to count correct entity labels with incorrect I/B tags as true positives or not.
If you want to only count exact matches, pass mode="strict". default: None.
sample_weight: Array-like of shape (n_samples,), weights for individual samples. default: None
    zero_division: Which value to substitute as a metric value when encountering zero division. Should be one of 0, 1,
        or "warn". "warn" acts as 0, but a warning is also raised.
Returns:
'scores': dict. Summary of the scores for overall and per type
Overall:
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': F1 score, also known as balanced F-score or F-measure,
Per type:
'precision': precision,
'recall': recall,
'f1': F1 score, also known as balanced F-score or F-measure
Examples:
>>> predictions = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> references = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> seqeval = datasets.load_metric("seqeval")
>>> results = seqeval.compute(predictions=predictions, references=references)
>>> print(list(results.keys()))
['MISC', 'PER', 'overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy']
>>> print(results["overall_f1"])
0.5
>>> print(results["PER"]["f1"])
1.0
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Seqeval(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/chakki-works/seqeval",
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
                    "references": datasets.Sequence(datasets.Value("string", id="label"), id="sequence"),
                }
            ),
            codebase_urls=["https://github.com/chakki-works/seqeval"],
            reference_urls=["https://github.com/chakki-works/seqeval"],
        )

    def _compute(
        self,
        predictions,
        references,
        suffix: bool = False,
        scheme: Optional[str] = None,
        mode: Optional[str] = None,
        sample_weight: Optional[List[int]] = None,
        zero_division: Union[str, int] = "warn",
    ):
        if scheme is not None:
            try:
                scheme_module = importlib.import_module("seqeval.scheme")
                scheme = getattr(scheme_module, scheme)
            except AttributeError:
                raise ValueError(f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {scheme}")
        report = classification_report(
            y_true=references,
            y_pred=predictions,
            suffix=suffix,
            output_dict=True,
            scheme=scheme,
            mode=mode,
            sample_weight=sample_weight,
            zero_division=zero_division,
        )
        report.pop("macro avg")
        report.pop("weighted avg")
        overall_score = report.pop("micro avg")
        scores = {
            type_name: {
                "precision": score["precision"],
                "recall": score["recall"],
                "f1": score["f1-score"],
                "number": score["support"],
            }
            for type_name, score in report.items()
        }
        scores["overall_precision"] = overall_score["precision"]
        scores["overall_recall"] = overall_score["recall"]
        scores["overall_f1"] = overall_score["f1-score"]
        scores["overall_accuracy"] = accuracy_score(y_true=references, y_pred=predictions)

        return scores
\ No newline at end of file
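A minimal sketch of using this local metric script, matching the load_metric(path="seqeval_metric.py") change in run_token_classification.py above (tag sequences copied from the docstring example):

from datasets import load_metric

metric = load_metric(path="seqeval_metric.py")  # local script instead of the hub's "seqeval"
predictions = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
references = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
results = metric.compute(predictions=predictions, references=references)
print(results['overall_f1'])  # 0.5

The training script was also retuned for the dialogue-level task (fewer GPUs, much smaller per-device batches with gradient accumulation):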
-n_gpus=8
-task_name="bio"
+n_gpus=3
+task_name="dialogBIO"
 dataset_name="sgd"
 data_dir="data/${task_name}/${dataset_name}"
 output_dir="output/${task_name}/${dataset_name}"
@@ -11,9 +11,9 @@ test_file="${data_dir}/test.json"
 source_column="tokens"
 target_column="labels"
 model_name_or_path="bert-base-uncased"
-per_device_train_batch_size=128
-per_device_eval_batch_size=512
-gradient_accumulation_steps=1
+per_device_train_batch_size=8
+per_device_eval_batch_size=16
+gradient_accumulation_steps=2
 lr=2e-5
 num_train_epochs=1
 metric_for_best_model="f1"
...
@@ -16,10 +16,10 @@ def create_rg_data(dataset, data_dir):
         for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             context = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']])
             response = f"{sample['speaker']}: {sample['utterance']}"
-            data.append(json.dumps({'context': context, 'response': response})+'\n')
+            data.append(json.dumps({'context': context, 'response': response}, ensure_ascii=False)+'\n')
         file_name = os.path.join(data_dir, f"{data_split}.json")
-        with open(file_name, "w") as f:
+        with open(file_name, "w", encoding='utf-8') as f:
             f.writelines(data)

 if __name__ == '__main__':
...
@@ -604,10 +604,10 @@ def main():
         )
         predictions = [pred.strip() for pred in predictions]
         output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.json")
-        with open(output_prediction_file, "w") as writer:
+        with open(output_prediction_file, "w", encoding='utf-8') as writer:
             for sample, pred in zip(raw_datasets["test"], predictions):
                 sample["predictions"] = pred
-                writer.write(json.dumps(sample)+'\n')
+                writer.write(json.dumps(sample, ensure_ascii=False)+'\n')

 kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": data_args.task_name}
 if data_args.dataset_name is not None:
...
...@@ -85,12 +85,14 @@ def load_unified_data( ...@@ -85,12 +85,14 @@ def load_unified_data(
context_window_size=0, context_window_size=0,
terminated=False, terminated=False,
goal=False, goal=False,
active_domains=False active_domains=False,
split_to_turn=True
): ):
data_splits = dataset.keys() if data_split == 'all' else [data_split] data_splits = dataset.keys() if data_split == 'all' else [data_split]
assert speaker in ['user', 'system', 'all'] assert speaker in ['user', 'system', 'all']
assert not use_context or context_window_size > 0 assert not use_context or context_window_size > 0
info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results'])) info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
info_list += ['utt_idx']
data_by_split = {} data_by_split = {}
for data_split in data_splits: for data_split in data_splits:
data_by_split[data_split] = [] data_by_split[data_split] = []
...@@ -102,11 +104,11 @@ def load_unified_data( ...@@ -102,11 +104,11 @@ def load_unified_data(
if ele in turn: if ele in turn:
sample[ele] = turn[ele] sample[ele] = turn[ele]
if use_context: if use_context or not split_to_turn:
sample_copy = deepcopy(sample) sample_copy = deepcopy(sample)
context.append(sample_copy) context.append(sample_copy)
if speaker == turn['speaker'] or speaker == 'all': if split_to_turn and speaker in [turn['speaker'], 'all']:
if use_context: if use_context:
sample['context'] = context[-context_window_size-1:-1] sample['context'] = context[-context_window_size-1:-1]
if goal: if goal:
...@@ -116,6 +118,9 @@ def load_unified_data( ...@@ -116,6 +118,9 @@ def load_unified_data(
if terminated: if terminated:
sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1 sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1
data_by_split[data_split].append(sample) data_by_split[data_split].append(sample)
if not split_to_turn:
dialogue['turns'] = context
data_by_split[data_split].append(dialogue)
return data_by_split return data_by_split
def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs): def load_nlu_data(dataset, data_split='all', speaker='user', use_context=False, context_window_size=0, **kwargs):
......
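A sketch of the new split_to_turn=False path as create_dialogBIO_data uses it; the load_dataset import path is hypothetical here, since the diff only shows the data-loading utilities themselves:

# hypothetical import path, for illustration only
from unified_datasets.util import load_dataset, load_nlu_data

dataset = load_dataset('multiwoz21')
data_by_split = load_nlu_data(dataset, split_to_turn=False)
# With split_to_turn=False each element is a whole dialogue whose 'turns'
# hold per-turn dicts (note info_list += ['utt_idx'] above), instead of
# one flat sample per user turn.
dialog = data_by_split['train'][0]
for turn in dialog['turns']:
    print(turn['utt_idx'], turn['speaker'], turn['utterance'])

The commit also pins the environment in a pip-freeze-style requirements list: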
absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
allennlp==2.8.0
argon2-cffi==21.1.0
argon2-cffi-bindings==21.2.0
async-timeout==4.0.1
attrs==21.2.0
autopep8==1.6.0
backcall==0.2.0
backports.csv==1.0.7
base58==2.1.1
beautifulsoup4==4.10.0
bert-score==0.3.11
bleach==4.1.0
blis==0.7.5
boto3==1.20.14
botocore==1.23.14
cached-path==0.3.2
cachetools==4.2.4
catalogue==2.0.6
certifi==2021.10.8
cffi==1.15.0
chardet==4.0.0
charset-normalizer==2.0.8
checklist==0.0.11
cheroot==8.5.2
CherryPy==18.6.1
click==8.0.3
colorama==0.4.4
configparser==5.1.0
cryptography==36.0.0
cycler==0.11.0
cymem==2.0.6
datasets==1.16.1
debugpy==1.5.1
decorator==5.1.0
deepspeech==0.9.3
defusedxml==0.7.1
dill==0.3.4
docker-pycreds==0.4.0
embeddings==0.0.8
entrypoints==0.3
fairscale==0.4.0
feedparser==6.0.8
filelock==3.3.2
fonttools==4.28.5
frozenlist==1.2.0
fsspec==2021.11.1
future==0.18.2
fuzzywuzzy==0.18.0
gitdb==4.0.9
GitPython==3.1.24
google-api-core==2.2.2
google-auth==2.3.3
google-cloud-core==2.2.1
google-cloud-storage==1.43.0
google-crc32c==1.3.0
google-resumable-media==2.1.0
googleapis-common-protos==1.53.0
gTTS==2.2.3
h5py==3.6.0
huggingface-hub==0.1.2
idna==3.3
iniconfig==1.1.1
ipykernel==6.5.1
ipython==7.30.0
ipython-genutils==0.2.0
ipywidgets==7.6.5
iso-639==0.4.5
jaraco.classes==3.2.1
jaraco.collections==3.4.0
jaraco.functools==3.4.0
jaraco.text==3.6.0
jedi==0.18.1
jieba==0.42.1
Jinja2==3.0.3
jmespath==0.10.0
joblib==1.1.0
json-lines==0.5.0
jsonnet==0.17.0
jsonpatch==1.32
jsonpointer==2.2
jsonschema==4.2.1
jupyter==1.0.0
jupyter-client==7.1.0
jupyter-console==6.4.0
jupyter-core==4.9.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.2
kiwisolver==1.3.2
langcodes==3.3.0
lmdb==1.2.1
lxml==4.6.4
MarkupSafe==2.0.1
matplotlib==3.5.1
matplotlib-inline==0.1.3
mistune==0.8.4
more-itertools==8.12.0
multidict==5.2.0
multiprocess==0.70.12.2
munch==2.5.0
murmurhash==1.0.6
nbclient==0.5.9
nbconvert==6.3.0
nbformat==5.1.3
nest-asyncio==1.5.1
nltk==3.6.5
notebook==6.4.6
numpy==1.21.4
overrides==3.1.0
packaging==21.3
pandas==1.3.4
pandocfilters==1.5.0
parso==0.8.2
pathtools==0.1.2
pathy==0.6.1
patternfork-nosql==3.6
pdfminer.six==20211012
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.4.0
pipdeptree==2.2.0
pluggy==1.0.0
portalocker==2.3.2
portend==3.1.0
preshed==3.0.6
prometheus-client==0.12.0
promise==2.3
prompt-toolkit==3.0.23
protobuf==3.19.1
psutil==5.8.0
ptyprocess==0.7.0
py==1.11.0
pyarrow==6.0.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle==2.8.0
pycparser==2.21
pydantic==1.8.2
pydub==0.25.1
Pygments==2.10.0
pyparsing==3.0.6
pyrsistent==0.18.0
pytest==6.2.5
python-dateutil==2.8.2
python-docx==0.8.11
python-Levenshtein==0.12.2
pytokenizations==0.8.4
pytz==2021.3
PyYAML==6.0
pyzmq==22.3.0
qtconsole==5.2.1
QtPy==1.11.2
quadprog==0.1.10
regex==2021.11.10
requests==2.26.0
rouge-score==0.0.4
rsa==4.8
s3transfer==0.5.0
sacrebleu==2.0.0
sacremoses==0.0.46
scikit-learn==1.0.1
scipy==1.7.3
Send2Trash==1.8.0
sentencepiece==0.1.96
sentry-sdk==1.5.0
seqeval==1.2.2
sgmllib3k==1.0.0
shortuuid==1.0.8
simplejson==3.17.6
six==1.16.0
smart-open==5.2.1
smmap==5.0.0
soupsieve==2.3.1
spacy==3.1.4
spacy-legacy==3.0.8
spacy-loggers==1.0.1
sqlitedict==1.7.0
srsly==2.4.2
subprocess32==3.5.4
tabulate==0.8.9
tempora==4.1.2
tensorboardX==2.4.1
termcolor==1.1.0
terminado==0.12.1
testpath==0.5.0
thinc==8.0.13
threadpoolctl==3.0.0
tokenizers==0.10.3
toml==0.10.2
torch==1.8.1+cu101
torchfile==0.1.0
torchvision==0.9.1+cu101
tornado==6.1
tqdm==4.62.3
traitlets==5.1.1
transformers==4.12.5
typer==0.4.0
typing_extensions==4.0.0
Unidecode==1.3.2
urllib3==1.26.7
visdom==0.1.8.9
wandb==0.12.7
wasabi==0.8.2
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.2.1
widgetsnbextension==3.5.2
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zc.lockfile==2.0