Commit 6f318bea authored by zqwerty

base model run response generation

parent ec1dd65e
import os
import json
from tqdm import tqdm
from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data


def create_rg_data(data_by_split, data_dir):
    """Serialize response generation data: one JSON object per line, pairing a flattened dialogue context with the target system response."""
    data_splits = data_by_split.keys()
    file_name = os.path.join(data_dir, "source_prefix.txt")
    with open(file_name, "w") as f:
        f.write("generate a system response according to the context: ")
    for data_split in data_splits:
        data = []
        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
            # Flatten the dialogue history into "speaker: utterance" pairs.
            context = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']])
            response = f"{sample['speaker']}: {sample['utterance']}"
            data.append(json.dumps({'context': context, 'response': response}) + '\n')
        file_name = os.path.join(data_dir, f"{data_split}.json")
        with open(file_name, "w") as f:
            f.writelines(data)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="create data for seq2seq training")
    parser.add_argument('--tasks', metavar='task_name', nargs='*', choices=['rg'], help='names of tasks')
    parser.add_argument('--datasets', metavar='dataset_name', nargs='*', help='names of unified datasets')
    parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data', help='directory to save the data, default: data/$task_name/$dataset_name')
    args = parser.parse_args()
    print(args)
    for dataset_name in tqdm(args.datasets, desc='datasets'):
        dataset = load_dataset(dataset_name)
        for task_name in tqdm(args.tasks, desc='tasks', leave=False):
            # Dispatch by task name, e.g. 'rg' -> load_rg_data / create_rg_data.
            data_by_split = eval(f"load_{task_name}_data")(dataset)
            data_dir = os.path.join(args.save_dir, task_name, dataset_name)
            os.makedirs(data_dir, exist_ok=True)
            eval(f"create_{task_name}_data")(data_by_split, data_dir)
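Each split is written as JSON Lines: one object per line, with a flattened "context" string and a "response" target. A minimal sketch of reading the generated data back, assuming the script above has been run for multiwoz21 (the utterance values in the comment are illustrative only):

import json

# Read the first record of the generated training split.
with open("data/rg/multiwoz21/train.json") as f:
    sample = json.loads(next(f))

# Each record pairs a flattened dialogue history with the target system turn,
# e.g. (made-up values for illustration):
# {"context": "user: I need a cheap hotel. system: Which area do you prefer?",
#  "response": "system: There are 3 cheap hotels in the centre."}
print(sample["context"])
print(sample["response"])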
# Distributed fine-tuning of t5-small for response generation on MultiWOZ 2.1.
n_gpus=8
task_name="rg"
dataset_name="multiwoz21"
# Paths follow the layout produced by the data creation script above.
data_dir="data/${task_name}/${dataset_name}"
output_dir="output/${task_name}/${dataset_name}"
cache_dir="cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
source_prefix="${data_dir}/source_prefix.txt"
source_column="context"
target_column="response"
model_name_or_path="t5-small"
per_device_train_batch_size=32
per_device_eval_batch_size=128
gradient_accumulation_steps=1
lr=1e-3
num_train_epochs=5
# Launch one training process per GPU.
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--source_prefix ${source_prefix} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--load_best_model_at_end \
--predict_with_generate \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--debug underflow_overflow \
--adafactor \
--gradient_checkpointing
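After training, ${output_dir} holds the final checkpoint (the best one by validation loss, given --load_best_model_at_end). A minimal inference sketch with the Hugging Face transformers API, assuming the prefix stored in source_prefix.txt must be prepended to the context at generation time just as during training (the context value is illustrative):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_dir = "output/rg/multiwoz21"  # output_dir from the script above
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

# Prepend the same task prefix that was used for training.
prefix = open("data/rg/multiwoz21/source_prefix.txt").read()
context = "user: I need a cheap hotel in the centre."  # illustrative input

inputs = tokenizer(prefix + context, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))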