diff --git a/convlab/base_models/t5/key2gen/create_data.py b/convlab/base_models/t5/key2gen/create_data.py
index cb4e12c8e720f3031808ec0972c38926a617ef68..54cb8e2aa4b3ac49934d64d52d94496c89ca1b84 100644
--- a/convlab/base_models/t5/key2gen/create_data.py
+++ b/convlab/base_models/t5/key2gen/create_data.py
@@ -4,18 +4,26 @@ from tqdm import tqdm
 from convlab.util import load_dataset, load_unified_data, load_nlu_data
 
 def create_nlg_data(dataset, data_dir, args):
-    data_by_split = load_nlu_data(dataset, speaker='system', use_context=True, context_window_size=3)
+    data_by_split = dataset
     os.makedirs(data_dir, exist_ok=True)
 
     data_splits = data_by_split.keys()
     for data_split in data_splits:
         data = []
-        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
-            context = [(turn['speaker'], turn['utterance']) for turn in sample['context']]
-            response = sample['utterance']
-            if len(context) > 0 and len(response) > 0:
-                knowledge = sample['dialogue_acts']
-                data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
+        num_dial = 0
+        for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
+            context = []
+            is_valid = False
+            for turn in dial['turns']:
+                response = turn['utterance']
+                context.append((turn['speaker'], turn['utterance']))
+                if turn['speaker'] == 'system' and len(context) > 1 and len(response) > 0:
+                    data.append(json.dumps({'context': context[-4:-1], 'knowledge': turn['dialogue_acts'], 'response': response}, ensure_ascii=False)+'\n')
+                    is_valid = True
+            if is_valid:
+                num_dial += 1
+            if 'test' not in data_split and args.shot and isinstance(args.shot, int) and args.shot >= 1 and args.shot == num_dial:
+                break
 
         if 'test' in data_split:
             file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
@@ -27,24 +35,34 @@ def create_nlg_data(dataset, data_dir, args):
     return data_by_split
 
 def create_kvret_data(dataset, data_dir, args):
-    data_by_split = load_unified_data(dataset, speaker='system', utterance=True, db_results=True, use_context=True, context_window_size=100)
+    data_by_split = dataset
     os.makedirs(data_dir, exist_ok=True)
 
     domain2entity_col = {'schedule': 'event' ,'navigate': 'poi', 'weather': 'location'}
     data_splits = data_by_split.keys()
     for data_split in data_splits:
         data = []
-        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
-            context = [(turn['speaker'], turn['utterance']) for turn in sample['context']]
-            response = sample['utterance']
-            if len(context) > 0 and len(response) > 0:
-                knowledge = sample['db_results']
-                for domain, db_items in knowledge.items():
-                    entity_col = domain2entity_col[domain]
-                    for db_item in db_items:
-                        db_item['entity'] = db_item.pop(entity_col)
-
-            data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
+        num_dial = 0
+        for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
+            context = []
+            is_valid = False
+            for turn in dial['turns']:
+                response = turn['utterance']
+                context.append((turn['speaker'], turn['utterance']))
+                if turn['speaker'] == 'system' and len(context) > 1 and len(response) > 0:
+                    knowledge = turn['db_results']
+                    if dial['domains'][0] == 'schedule' and len(knowledge['schedule']) == 0:
+                        continue
+                    for domain, db_items in knowledge.items():
+                        entity_col = domain2entity_col[domain]
+                        for db_item in db_items:
+                            db_item['entity'] = db_item.pop(entity_col)
+                    data.append(json.dumps({'context': context[:-1], 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
+                    is_valid = True
+            if is_valid:
+                num_dial += 1
+            if 'test' not in data_split and args.shot and isinstance(args.shot, int) and args.shot >= 1 and args.shot == num_dial:
+                break
 
         if 'test' in data_split:
             file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
@@ -62,14 +80,21 @@ def create_personachat_data(dataset, data_dir, args):
     data_splits = data_by_split.keys()
     for data_split in data_splits:
         data = []
+        num_dial = 0
         for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             knowledge = dial['persona']['system']
             context = []
+            is_valid = False
             for turn in dial['turns']:
                 response = turn['utterance']
-                if turn['speaker'] == 'system' and len(context) > 0 and len(response) > 0:
-                    data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
                 context.append((turn['speaker'], turn['utterance']))
+                if turn['speaker'] == 'system' and len(context) > 1 and len(response) > 0:
+                    data.append(json.dumps({'context': context[:-1], 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
+                    is_valid = True
+            if is_valid:
+                num_dial += 1
+            if 'test' not in data_split and args.shot and isinstance(args.shot, int) and args.shot >= 1 and args.shot == num_dial:
+                break
 
         if 'test' in data_split:
             file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
@@ -83,22 +108,30 @@ def create_personachat_data(dataset, data_dir, args):
 def create_wow_data(dataset, data_dir, args):
     data_by_split = dataset
     os.makedirs(data_dir, exist_ok=True)
+    data_by_split['test'] = data_by_split['test_seen'] + data_by_split['test_unseen']
+    data_by_split.pop('test_seen')
+    data_by_split.pop('test_unseen')
 
     data_splits = data_by_split.keys()
     for data_split in data_splits:
         data = []
+        num_dial = 0
         for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             context = []
+            is_valid = False
             for turn in dial['turns']:
                 response = turn['utterance']
-                if turn['speaker'] == 'system' and len(context) > 0 and len(response) > 0:
+                context.append((turn['speaker'], turn['utterance']))
+                if turn['speaker'] == 'system' and len(context) > 1 and len(response) > 0:
                     knowledge = turn['checked_passage']
                     if knowledge is None:
-                        knowledge = []
-                    elif isinstance(knowledge, str):
-                        knowledge = [knowledge]
-                    data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
-                context.append((turn['speaker'], turn['utterance']))
+                        continue
+                    data.append(json.dumps({'context': context[:-1], 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
+                    is_valid = True
+            if is_valid:
+                num_dial += 1
+            if 'test' not in data_split and args.shot and isinstance(args.shot, int) and args.shot >= 1 and args.shot == num_dial:
+                break
 
         if 'test' in data_split:
             file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
@@ -116,14 +149,22 @@ def create_opendialkg_data(dataset, data_dir, args):
     data_splits = data_by_split.keys()
     for data_split in data_splits:
         data = []
+        num_dial = 0
         for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             context = []
+            is_valid = False
             for turn in dial['turns']:
                 response = turn['utterance']
+                context.append((turn['speaker'], turn['utterance']))
                 if turn['speaker'] == 'system' and 'kg_path' in turn and len(context) > 0 and len(response) > 0:
                     knowledge = turn['kg_path']['triples']
-                    data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
-                context.append((turn['speaker'], turn['utterance']))
+                    assert len(knowledge) > 0
+                    data.append(json.dumps({'context': context[:-1], 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
+                    is_valid = True
+            if is_valid:
+                num_dial += 1
+            if 'test' not in data_split and args.shot and isinstance(args.shot, int) and args.shot >= 1 and args.shot == num_dial:
+                break
 
         if 'test' in data_split:
             file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
@@ -147,13 +188,14 @@ if __name__ == '__main__':
     for dataset_name in tqdm(args.datasets, desc='datasets'):
         dataset = load_dataset(dataset_name, dial_ids_order=args.dial_ids_order)
         if args.shot:
+            # few-shot setting
            if args.shot < 1:
+                # a fraction: subsample the train/validation dialogues here
                 dataset['train'] = dataset['train'][:round(len(dataset['train'])*args.shot)]
                 dataset['validation'] = dataset['validation'][:round(len(dataset['validation'])*args.shot)]
             else:
+                # an absolute number of dialogues, handled inside the create_*_data functions
                 args.shot = int(args.shot)
-                dataset['train'] = dataset['train'][:args.shot]
-                dataset['validation'] = dataset['validation'][:args.shot]
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
             data_dir = os.path.join('data', task_name, (dataset_name if not args.shot else f'{dataset_name}_{args.shot}shot_order{args.dial_ids_order}'))
             data_by_split = eval(f"create_{task_name}_data")(dataset, data_dir, args)
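Each create_*_data function above emits one JSON object per system turn, written as JSON lines (train/validation under the shot-specific data_dir, test.json one level up). A minimal sketch of the record shape; the utterances are invented placeholders and the empty knowledge value stands in for the task-specific field (dialogue acts, db_results, persona, checked passage, or KG triples):

    import json

    # Hypothetical record mirroring the json.dumps(...) calls above: context holds
    # (speaker, utterance) pairs up to but excluding the current system response.
    example = {
        "context": [
            ["user", "I need a cheap hotel in the north."],
            ["system", "Sure, for how many nights?"],
            ["user", "Three nights from Friday, please."],
        ],
        "knowledge": {},
        "response": "Okay, I will look for a cheap hotel in the north for three nights.",
    }
    print(json.dumps(example, ensure_ascii=False))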
diff --git a/convlab/base_models/t5/key2gen/evaluate.py b/convlab/base_models/t5/key2gen/evaluate.py
index 7acb1118cc857d4cd3e1b401b1d8ecddab2288e9..769fdfcf3d1c899aad1b5389dad2c8d9465c05c6 100644
--- a/convlab/base_models/t5/key2gen/evaluate.py
+++ b/convlab/base_models/t5/key2gen/evaluate.py
@@ -59,31 +59,26 @@ if __name__ == '__main__':
     for shot in tqdm(args.shots, desc='shots', leave=False):
         for output_dir in tqdm(args.output_dirs, desc='models', leave=False):
             model_name = output_dir.split('/')[-1]
-            if task_name == "wow":
-                test_splits = ["_seen", "_unseen"]
-            else:
-                test_splits = [""]
-            for test_split in test_splits:
-                results = []
-                for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders', leave=False):
-                    result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}")
-                    result_file = os.path.join(result_dir, "result.json")
-                    if not os.path.exists(result_file):
-                        filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen{test_split}/generated_predictions.json")
-                        result = evaluate(filename, metric)
-                        json.dump(result, open(result_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
-                    else:
-                        result = json.load(open(result_file))
-                    results.append(result)
-                res = {
-                    "dataset": f"{task_name}{test_split}-{shot}shot",
-                    "model": f"{model_name}",
-                    **avg_result(results)
-                }
-                table.append(res)
-                for k in res:
-                    if k not in fieldnames:
-                        fieldnames.append(k)
+            results = []
+            for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders', leave=False):
+                result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen")
+                result_file = os.path.join(result_dir, "result.json")
+                if not os.path.exists(result_file):
+                    filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen/generated_predictions.json")
+                    result = evaluate(filename, metric)
+                    json.dump(result, open(result_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+                else:
+                    result = json.load(open(result_file))
+                results.append(result)
+            res = {
+                "dataset": f"{task_name}-{shot}shot",
+                "model": f"{model_name}",
+                **avg_result(results)
+            }
+            table.append(res)
+            for k in res:
+                if k not in fieldnames:
+                    fieldnames.append(k)
 
     res = tabulate(table, headers='keys', tablefmt='github')
     with open(f'eval_results.txt', 'w', encoding='utf-8') as f:
diff --git a/convlab/base_models/t5/key2gen/finetune.sh b/convlab/base_models/t5/key2gen/finetune.sh
index 390ea1a908dcf9f335ec74d35a422239e9a923ca..df53406850248c89a54a9ead4395c0b78a1aaba3 100644
--- a/convlab/base_models/t5/key2gen/finetune.sh
+++ b/convlab/base_models/t5/key2gen/finetune.sh
@@ -1,15 +1,24 @@
+set -e
+dataset_path=$1
+model_name=$2
+model_name_or_path=$3
+dataset_name=$4
+if [ "${dataset_name}" == "multiwoz21" ]
+then
+    task_name="nlg"
+else
+    task_name=${dataset_name}
+fi
+master_port=$5
+
 n_gpus=2
-master_port=23456
 cache_dir="../cache"
-dataset_path="dataset_vanilla.py"
 metric_name_or_path="metric.py"
 source_column="context+knowledge"
 target_column="response"
 truncation_side="left"
 max_source_length=512
 max_target_length=512
-model_name="t5-small"
-model_name_or_path="t5-small"
 per_device_train_batch_size=64
 per_device_eval_batch_size=64
 gradient_accumulation_steps=1
@@ -17,137 +26,91 @@ num_workers=16
 lr=1e-3
 num_train_epochs=100
 
-for dataset_name in multiwoz21 kvret opendialkg wow personachat
+for shot in 50 100 200
 do
-    if [ "${dataset_name}" == "multiwoz21" ]
-    then
-        task_name="nlg"
-    else
-        task_name=${dataset_name}
-    fi
-    for shot in 50 100 200
+    for dial_ids_order in 0 1 2 3 4
     do
-        for dial_ids_order in 0 1 2 3 4
-        do
-            python create_data.py -t ${task_name} -d ${dataset_name} -o ${dial_ids_order} -s ${shot}
+        python create_data.py -t ${task_name} -d ${dataset_name} -o ${dial_ids_order} -s ${shot}
 
-            data_dir="data/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}"
-            output_dir="output/${model_name}/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}"
-            logging_dir="${output_dir}/runs"
-            train_file="${data_dir}/train.json"
-            validation_file="${data_dir}/validation.json"
+        data_dir="data/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}"
+        output_dir="output/${model_name}/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}"
+        logging_dir="${output_dir}/runs"
+        train_file="${data_dir}/train.json"
+        validation_file="${data_dir}/validation.json"
 
-            # training
-            python -m torch.distributed.launch --master_port ${master_port} \
-                --nproc_per_node ${n_gpus} ../run_seq2seq.py \
-                --task_name ${task_name} \
-                --dataset_name ${dataset_path} \
-                --dataset_config_name ${task_name} \
-                --train_file ${train_file} \
-                --validation_file ${validation_file} \
-                --source_column ${source_column} \
-                --target_column ${target_column} \
-                --max_source_length ${max_source_length} \
-                --max_target_length ${max_target_length} \
-                --truncation_side ${truncation_side} \
-                --model_name_or_path ${model_name_or_path} \
-                --do_train \
-                --do_eval \
-                --save_strategy epoch \
-                --evaluation_strategy epoch \
-                --save_total_limit 1 \
-                --prediction_loss_only \
-                --load_best_model_at_end \
-                --overwrite_output_dir \
-                --cache_dir ${cache_dir} \
-                --output_dir ${output_dir} \
-                --logging_dir ${logging_dir} \
-                --preprocessing_num_workers ${num_workers} \
-                --dataloader_num_workers ${num_workers} \
-                --per_device_train_batch_size ${per_device_train_batch_size} \
-                --per_device_eval_batch_size ${per_device_eval_batch_size} \
-                --gradient_accumulation_steps ${gradient_accumulation_steps} \
-                --learning_rate ${lr} \
-                --num_train_epochs ${num_train_epochs} \
-                --optim adafactor \
-                --lr_scheduler_type constant \
-                --gradient_checkpointing
+        # training
+        python -m torch.distributed.launch --master_port ${master_port} \
+            --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+            --task_name ${task_name} \
+            --dataset_name ${dataset_path} \
+            --dataset_config_name ${task_name} \
+            --train_file ${train_file} \
+            --validation_file ${validation_file} \
+            --source_column ${source_column} \
+            --target_column ${target_column} \
+            --max_source_length ${max_source_length} \
+            --max_target_length ${max_target_length} \
+            --truncation_side ${truncation_side} \
+            --model_name_or_path ${model_name_or_path} \
+            --do_train \
+            --do_eval \
+            --save_strategy epoch \
+            --evaluation_strategy epoch \
+            --save_total_limit 1 \
+            --prediction_loss_only \
+            --load_best_model_at_end \
+            --overwrite_output_dir \
+            --cache_dir ${cache_dir} \
+            --output_dir ${output_dir} \
+            --logging_dir ${logging_dir} \
+            --preprocessing_num_workers ${num_workers} \
+            --dataloader_num_workers ${num_workers} \
+            --per_device_train_batch_size ${per_device_train_batch_size} \
+            --per_device_eval_batch_size ${per_device_eval_batch_size} \
+            --gradient_accumulation_steps ${gradient_accumulation_steps} \
+            --learning_rate ${lr} \
+            --num_train_epochs ${num_train_epochs} \
+            --optim adafactor \
+            --lr_scheduler_type constant \
+            --gradient_checkpointing
 
-            # inference
-            if [ "${dataset_name}" == "wow" ]
-            then
-                for test_split in seen unseen
-                do
-                    test_file="data/${task_name}/test_${test_split}.json"
-                    gen_output_dir="${output_dir}/gen_${test_split}"
+        # inference
+        test_file="data/${task_name}/test.json"
+        gen_output_dir="${output_dir}/gen"
 
-                    python -m torch.distributed.launch --master_port ${master_port} \
-                        --nproc_per_node ${n_gpus} ../run_seq2seq.py \
-                        --task_name ${task_name} \
-                        --dataset_name ${dataset_path} \
-                        --dataset_config_name ${task_name} \
-                        --test_file ${test_file} \
-                        --source_column ${source_column} \
-                        --target_column ${target_column} \
-                        --max_source_length ${max_source_length} \
-                        --max_target_length ${max_target_length} \
-                        --truncation_side ${truncation_side} \
-                        --model_name_or_path ${output_dir} \
-                        --do_predict \
-                        --predict_with_generate \
-                        --cache_dir ${cache_dir} \
-                        --output_dir ${gen_output_dir} \
-                        --logging_dir ${logging_dir} \
-                        --overwrite_output_dir \
-                        --preprocessing_num_workers ${num_workers} \
-                        --dataloader_num_workers ${num_workers} \
-                        --per_device_train_batch_size ${per_device_train_batch_size} \
-                        --per_device_eval_batch_size ${per_device_eval_batch_size} \
-                        --gradient_accumulation_steps ${gradient_accumulation_steps} \
-                        --learning_rate ${lr} \
-                        --num_train_epochs ${num_train_epochs} \
-                        --optim adafactor \
-                        --lr_scheduler_type constant \
-                        --gradient_checkpointing
-                done
-            else
-                test_file="data/${task_name}/test.json"
-                gen_output_dir="${output_dir}/gen"
-
-                python -m torch.distributed.launch --master_port ${master_port} \
-                    --nproc_per_node ${n_gpus} ../run_seq2seq.py \
-                    --task_name ${task_name} \
-                    --dataset_name ${dataset_path} \
-                    --dataset_config_name ${task_name} \
-                    --metric_name_or_path ${metric_name_or_path} \
-                    --metric_config_name ${task_name} \
-                    --test_file ${test_file} \
-                    --source_column ${source_column} \
-                    --target_column ${target_column} \
-                    --max_source_length ${max_source_length} \
-                    --max_target_length ${max_target_length} \
-                    --truncation_side ${truncation_side} \
-                    --model_name_or_path ${output_dir} \
-                    --do_predict \
-                    --predict_with_generate \
-                    --cache_dir ${cache_dir} \
-                    --output_dir ${gen_output_dir} \
-                    --logging_dir ${logging_dir} \
-                    --overwrite_output_dir \
-                    --preprocessing_num_workers ${num_workers} \
-                    --dataloader_num_workers ${num_workers} \
-                    --per_device_train_batch_size ${per_device_train_batch_size} \
-                    --per_device_eval_batch_size ${per_device_eval_batch_size} \
-                    --gradient_accumulation_steps ${gradient_accumulation_steps} \
-                    --learning_rate ${lr} \
-                    --num_train_epochs ${num_train_epochs} \
-                    --optim adafactor \
-                    --lr_scheduler_type constant \
-                    --gradient_checkpointing
-            fi
-        done
+        python -m torch.distributed.launch --master_port ${master_port} \
+            --nproc_per_node ${n_gpus} ../run_seq2seq.py \
+            --task_name ${task_name} \
+            --dataset_name ${dataset_path} \
+            --dataset_config_name ${task_name} \
+            --metric_name_or_path ${metric_name_or_path} \
+            --metric_config_name ${task_name} \
+            --test_file ${test_file} \
+            --source_column ${source_column} \
+            --target_column ${target_column} \
+            --max_source_length ${max_source_length} \
+            --max_target_length ${max_target_length} \
+            --truncation_side ${truncation_side} \
+            --model_name_or_path ${output_dir} \
+            --do_predict \
+            --predict_with_generate \
+            --cache_dir ${cache_dir} \
+            --output_dir ${gen_output_dir} \
+            --logging_dir ${logging_dir} \
+            --overwrite_output_dir \
+            --preprocessing_num_workers ${num_workers} \
+            --dataloader_num_workers ${num_workers} \
+            --per_device_train_batch_size ${per_device_train_batch_size} \
+            --per_device_eval_batch_size ${per_device_eval_batch_size} \
+            --gradient_accumulation_steps ${gradient_accumulation_steps} \
+            --learning_rate ${lr} \
+            --num_train_epochs ${num_train_epochs} \
+            --optim adafactor \
+            --lr_scheduler_type constant \
+            --gradient_checkpointing
+    done
 done
 
 # evaluation
-python evaluate.py --output_dirs output/${model_name} -t nlg kvret opendialkg personachat wow -s 50 100 200 -o 0 1 2 3 4
\ No newline at end of file
+python evaluate.py --output_dirs output/${model_name} -t ${task_name} -s 50 100 200 -o 0 1 2 3 4
\ No newline at end of file
diff --git a/convlab/base_models/t5/key2gen/metric.py b/convlab/base_models/t5/key2gen/metric.py
index d3e493188194639adad04539619bb63f14284841..ce385c9d9ea073a8a1e2245f1be263857c810668 100644
--- a/convlab/base_models/t5/key2gen/metric.py
+++ b/convlab/base_models/t5/key2gen/metric.py
@@ -65,15 +65,6 @@ Returns:
     unigram f1: unigram overlap, from parlai
     distinct-1/2: from parlai
     other knowledge utility score: task-specific knowledge utility metrics
-
-Examples:
-
-    >>> nlg_metric = datasets.load_metric("metric.py", "nlg")
-    >>> predictions = ["hello there general kenobi", "foo bar foobar"]
-    >>> references = ["hello there kenobi", "foo bar foobar"]
-    >>> results = nlg_metric.compute(predictions=predictions, references=references)
-    >>> print(results)
-    {"bleu": 35.35533905932737}
 """
 
 re_art = re.compile(r'\b(a|an|the)\b')
@@ -325,11 +316,11 @@ def f1_score(y_pred, y_true, average="micro"):
 
     if average == "macro":
         F1_macro_score = F1_pred / float(F1_count) if F1_count != 0 else 0
-        return F1_macro_score
+        return F1_macro_score * 100
     elif average == "micro":
         P_score = TP_all / float(TP_all + FP_all) if (TP_all + FP_all) != 0 else 0
         R_score = TP_all / float(TP_all + FN_all) if (TP_all + FN_all) != 0 else 0
-        F1_micro_score = _compute_F1(P_score, R_score)
+        F1_micro_score = _compute_F1(P_score, R_score) * 100
         return F1_micro_score
     else:
         raise ValueError("Options other than micro/macro are not supported.")
diff --git a/convlab/base_models/t5/run_seq2seq.py b/convlab/base_models/t5/run_seq2seq.py
index 5fa921f0d4c855dc17b7f3b5d1daa8cc404f957c..1566cd4b86d439df79cc3613f7944ebc856aef15 100644
--- a/convlab/base_models/t5/run_seq2seq.py
+++ b/convlab/base_models/t5/run_seq2seq.py
@@ -221,7 +221,7 @@ class DataTrainingArguments:
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
     early_stopping_patience: Optional[int] = field(
-        default=10, metadata={"help": "early stopping patience, set to 0 if you do not want to use early stopping."},
+        default=5, metadata={"help": "early stopping patience, set to 0 if you do not want to use early stopping."},
     )
 
     def __post_init__(self):
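With the per-dataset loops removed, finetune.sh is now driven by positional arguments read at the top of the script: dataset_path ($1), model_name ($2), model_name_or_path ($3), dataset_name ($4), and master_port ($5). A hypothetical invocation that reproduces the values previously hard-coded in the script:

    # dataset_path        model_name  model_name_or_path  dataset_name  master_port
    bash finetune.sh dataset_vanilla.py t5-small t5-small multiwoz21 23456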