Commit 89e557bc authored by zqwerty

add nlg fewshot, update run_seq2seq for separated truncation

parent aee0a26e
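In create_nlg_data (defined in create_data.py, which the script below invokes), the serialized dialogue acts are now joined to the context by a blank line ('\n\n') instead of a single newline, giving the preprocessing in run_seq2seq.py a separator to split on: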
@@ -87,9 +87,9 @@ def create_nlg_data(dataset, data_dir, args):
         dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
         if args.context_window_size>0:
             context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: '])
-            context = f'{dialogue_acts_seq}\n{context}'
+            context = f'{dialogue_acts_seq}\n\n{context}'
         else:
-            context = f'{dialogue_acts_seq}\n{sample["speaker"]}: '
+            context = f'{dialogue_acts_seq}\n\n{sample["speaker"]}: '
         assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
         data.append(json.dumps({'context+da': context, 'response': sample['utterance']}, ensure_ascii=False)+'\n')
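For concreteness, a minimal sketch of the input format this produces; the dialogue-act string and utterances here are invented, and the real serialization comes from serialize_dialogue_acts:

# Hypothetical sample; only the '\n\n' joining mirrors the change above.
dialogue_acts_seq = '[inform][restaurant]([area][centre])'
context = 'user: I want a cheap restaurant.\nsystem: '
model_input = f'{dialogue_acts_seq}\n\n{context}'
print(model_input)
# [inform][restaurant]([area][centre])
#
# user: I want a cheap restaurant.
# system:

The companion change in merge_predict_res.py threads a new --dial_ids_order option through to load_dataset, so the merged predictions line up with the data order used for the few-shot split: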
@@ -5,7 +5,7 @@ from convlab2.util import load_dataset, load_nlg_data
 def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name)
+    dataset = load_dataset(dataset_name, args.dial_ids_order)
     data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     if save_dir is None:
@@ -28,6 +28,7 @@ if __name__ == '__main__':
     parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
     parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json')
+    parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
     args = parser.parse_args()
     print(args)
     merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
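The commit also adds a few-shot training script for NLG. It takes four positional arguments (dataset name, context window size, few-shot ratio, and dialogue-ID order) and derives the data and output directories from them: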
n_gpus=1
task_name="nlg"
dataset_name=$1
speaker="system"
context_window_size=$2
ratio=$3
dial_ids_order=$4
data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="nlg_metric.py"
metric_for_best_model="bleu"
source_column="context+da"
target_column="response"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=64
gradient_accumulation_steps=4
lr=1e-3
num_train_epochs=100
python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -r ${ratio} -o ${dial_ids_order}
python ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 3 \
--prediction_loss_only \
--load_best_model_at_end \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--debug underflow_overflow \
--adafactor \
--gradient_checkpointing
python ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_eval_batch_size ${per_device_eval_batch_size}
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
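Run end to end, the script builds the few-shot split with create_data.py, fine-tunes t5-small, generates predictions on the test set, merges them back onto the dataset, and scores the result with the unified NLG evaluator. The last change implements the separated truncation mentioned in the commit message, in the preprocessing of run_seq2seq.py: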
@@ -25,6 +25,8 @@ import sys
 import json
 from dataclasses import dataclass, field
 from typing import Optional
+from itertools import zip_longest
+from functools import reduce

 import datasets
 import numpy as np
@@ -439,8 +441,14 @@ def main():
                 inputs.append(examples[source_column][i])
                 targets.append(examples[target_column][i])

-        inputs = [prefix + inp for inp in inputs]
-        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
+        inputs = [prefix + '\n\n' + inp for inp in inputs]
+        if padding:
+            model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
+        else:
+            # truncate each part separated by '\n\n' independently
+            split_inputs = [inp.split('\n\n') for inp in inputs]
+            split_model_inputs = [tokenizer(x, max_length=data_args.max_source_length, padding=False, truncation=True) for x in split_inputs]
+            model_inputs = {k: [reduce(lambda x, y: x[:-1]+y, item[k]) for item in split_model_inputs] for k in split_model_inputs[0]}

         # Setup the tokenizer for targets
         with tokenizer.as_target_tokenizer():
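To see what the merge in the new else-branch does, here is a small self-contained sketch. It uses t5-small (the checkpoint from the script above) via the Hugging Face tokenizer; the example strings are invented. Each '\n\n'-separated part is tokenized and truncated on its own, and reduce(lambda x, y: x[:-1] + y, ...) concatenates the per-part token lists, dropping the trailing </s> that T5's tokenizer appends to every part except the last:

from functools import reduce
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')

# Invented example: dialogue acts + context, as produced by create_nlg_data.
parts = [
    '[inform][restaurant]([area][centre])',        # dialogue acts
    'user: I want a cheap restaurant.\nsystem: ',  # dialogue context
]

# Tokenize each part on its own; truncation applies per part, not globally.
enc = tokenizer(parts, max_length=512, padding=False, truncation=True)

# Concatenate the per-part token lists; x[:-1] strips the </s> at each join,
# so only the final part keeps its end-of-sequence token.
merged = {k: reduce(lambda x, y: x[:-1] + y, enc[k]) for k in enc}
print(tokenizer.decode(merged['input_ids']))

Two properties of this scheme are worth noting. First, the '\n\n' separator itself does not survive into the merged encoding: each part is tokenized independently, and T5's SentencePiece normalization would not preserve the newlines anyway. Second, because truncation is applied per part, the concatenated sequence can exceed max_source_length when several long parts each survive truncation; with truncation_side set to "left" as in the script above, each truncated part keeps its most recent tokens, so a long dialogue context retains its latest turns.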