Commit 7a6522f0 authored by zqwerty

do not update ft data, but update evaluate & metric

parent 4c6d6d17
@@ -59,24 +59,19 @@ if __name__ == '__main__':
     for shot in tqdm(args.shots, desc='shots', leave=False):
         for output_dir in tqdm(args.output_dirs, desc='models', leave=False):
             model_name = output_dir.split('/')[-1]
-            if task_name == "wow":
-                test_splits = ["_seen", "_unseen"]
-            else:
-                test_splits = [""]
-            for test_split in test_splits:
             results = []
             for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders', leave=False):
-                result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}")
+                result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen")
                 result_file = os.path.join(result_dir, "result.json")
                 if not os.path.exists(result_file):
-                    filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}/generated_predictions.json")
+                    filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen/generated_predictions.json")
                     result = evaluate(filename, metric)
                     json.dump(result, open(result_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
                 else:
                     result = json.load(open(result_file))
                 results.append(result)
             res = {
-                "dataset": f"{task_name}{test_split}-{shot}shot",
+                "dataset": f"{task_name}-{shot}shot",
                 "model": f"{model_name}",
                 **avg_result(results)
             }
...
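For context, the loop above folds the per-run result.json files into a single row via avg_result, which is defined elsewhere in evaluate.py and not shown in this diff. A minimal sketch of what such a helper plausibly does, assuming every run reports the same scalar metrics (only the name comes from the diff; the body is an assumption):

import numpy as np

def avg_result(results):
    # Average each scalar metric across runs (here, the dial_ids_order seeds).
    # Assumes all result dicts share the same keys and hold numeric values.
    return {key: float(np.mean([res[key] for res in results])) for key in results[0]}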
+set -e
+dataset_path=$1
+model_name=$2
+model_name_or_path=$3
+dataset_name=$4
+if [ "${dataset_name}" == "multiwoz21" ]
+then
+    task_name="nlg"
+else
+    task_name=${dataset_name}
+fi
+master_port=$5
 n_gpus=2
-master_port=23456
 cache_dir="../cache"
-dataset_path="dataset_vanilla.py"
 metric_name_or_path="metric.py"
 source_column="context+knowledge"
 target_column="response"
 truncation_side="left"
 max_source_length=512
 max_target_length=512
-model_name="t5-small"
-model_name_or_path="t5-small"
 per_device_train_batch_size=64
 per_device_eval_batch_size=64
 gradient_accumulation_steps=1
@@ -17,14 +26,6 @@ num_workers=16
 lr=1e-3
 num_train_epochs=100
-for dataset_name in multiwoz21 kvret opendialkg wow personachat
-do
-    if [ "${dataset_name}" == "multiwoz21" ]
-    then
-        task_name="nlg"
-    else
-        task_name=${dataset_name}
-    fi
 for shot in 50 100 200
 do
     for dial_ids_order in 0 1 2 3 4
@@ -74,43 +75,6 @@ do
         --gradient_checkpointing

     # inference
-    if [ "${dataset_name}" == "wow" ]
-    then
-        for test_split in seen unseen
-        do
-            test_file="data/${task_name}/test_${test_split}.json"
-            gen_output_dir="${output_dir}/gen_${test_split}"
-            python -m torch.distributed.launch --master_port ${master_port} \
-                --nproc_per_node ${n_gpus} ../run_seq2seq.py \
-                --task_name ${task_name} \
-                --dataset_name ${dataset_path} \
-                --dataset_config_name ${task_name} \
-                --test_file ${test_file} \
-                --source_column ${source_column} \
-                --target_column ${target_column} \
-                --max_source_length ${max_source_length} \
-                --max_target_length ${max_target_length} \
-                --truncation_side ${truncation_side} \
-                --model_name_or_path ${output_dir} \
-                --do_predict \
-                --predict_with_generate \
-                --cache_dir ${cache_dir} \
-                --output_dir ${gen_output_dir} \
-                --logging_dir ${logging_dir} \
-                --overwrite_output_dir \
-                --preprocessing_num_workers ${num_workers} \
-                --dataloader_num_workers ${num_workers} \
-                --per_device_train_batch_size ${per_device_train_batch_size} \
-                --per_device_eval_batch_size ${per_device_eval_batch_size} \
-                --gradient_accumulation_steps ${gradient_accumulation_steps} \
-                --learning_rate ${lr} \
-                --num_train_epochs ${num_train_epochs} \
-                --optim adafactor \
-                --lr_scheduler_type constant \
-                --gradient_checkpointing
-        done
-    else
     test_file="data/${task_name}/test.json"
     gen_output_dir="${output_dir}/gen"
@@ -144,10 +108,9 @@ do
         --optim adafactor \
         --lr_scheduler_type constant \
         --gradient_checkpointing
-    fi
-    done
     done
 done
 # evaluation
-python evaluate.py --output_dirs output/${model_name} -t nlg kvret opendialkg personachat wow -s 50 100 200 -o 0 1 2 3 4
+python evaluate.py --output_dirs output/${model_name} -t ${task_name} -s 50 100 200 -o 0 1 2 3 4
\ No newline at end of file
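With the hard-coded values replaced by positional arguments, each run of the script now covers a single dataset, and the removed defaults suggest an invocation along these lines (the script's filename is not shown in this diff; run_finetune.sh is a placeholder):

bash run_finetune.sh dataset_vanilla.py t5-small t5-small multiwoz21 23456

The five arguments map to dataset_path, model_name, model_name_or_path, dataset_name, and master_port; the multiwoz21-to-nlg task mapping that previously lived inside the dataset loop is now resolved once at the top of the script.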
@@ -65,15 +65,6 @@ Returns:
     unigram f1: unigram overlap, from parlai
     distinct-1/2: from parlai
     other knowledge utility score: task-specific knowledge utility metrics
-Examples:
-    >>> nlg_metric = datasets.load_metric("metric.py", "nlg")
-    >>> predictions = ["hello there general kenobi", "foo bar foobar"]
-    >>> references = ["hello there kenobi", "foo bar foobar"]
-    >>> results = nlg_metric.compute(predictions=predictions, references=references)
-    >>> print(results)
-    {"bleu": 35.35533905932737}
 """

 re_art = re.compile(r'\b(a|an|the)\b')
@@ -325,12 +316,12 @@ def f1_score(y_pred, y_true, average="micro"):
     if average == "macro":
         F1_macro_score = F1_pred / float(F1_count) if F1_count != 0 else 0
-        return F1_macro_score
+        return F1_macro_score * 100
     elif average == "micro":
         P_score = TP_all / float(TP_all + FP_all) if (TP_all + FP_all) != 0 else 0
         R_score = TP_all / float(TP_all + FN_all) if (TP_all + FN_all) != 0 else 0
         F1_micro_score = _compute_F1(P_score, R_score)
-        return F1_micro_score
+        return F1_micro_score * 100
     else:
         raise ValueError("Options other than micro/macro are not supported.")
...
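The only substantive change to f1_score is reporting on a 0-100 scale rather than 0-1, which lines up with the percentage-style BLEU value in the docstring example removed above. A small self-contained illustration of micro-averaged F1 under that convention, assuming _compute_F1 is the usual harmonic mean (the toy counts are hypothetical):

def _compute_F1(precision, recall):
    # Harmonic mean of precision and recall; 0 when both are 0.
    return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

# Toy true-positive/false-positive/false-negative counts pooled over all classes.
TP_all, FP_all, FN_all = 8, 2, 4
P_score = TP_all / (TP_all + FP_all)  # 0.8
R_score = TP_all / (TP_all + FN_all)  # ~0.667
print(_compute_F1(P_score, R_score) * 100)  # ~72.7 on the new 0-100 scale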