Unverified commit 246bba6f, authored by zhuqi, committed by GitHub

Merge pull request #71 from ConvLab/convlab_exp

Convlab exp
parents b9804957 2cb15bb2
......
@@ -85,6 +85,9 @@ def create_nlg_data(dataset, data_dir, args):
         data = []
         for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
+            if len(dialogue_acts_seq) == 0:
+                # skip empty dialogue acts
+                continue
             if args.context_window_size>0:
                 context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: '])
                 context = f'{dialogue_acts_seq}\n\n{context}'
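The guard added above exists because a turn whose dialogue acts are all empty serializes to an empty string, which would otherwise become a training pair with an empty source side. A minimal illustration (the unified-format keys below are assumptions about the dataset schema):

    # hypothetical sample in the unified data format (keys assumed)
    dialogue_acts = {'categorical': [], 'non-categorical': [], 'binary': []}
    dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts)
    # assumed behavior: empty acts serialize to '', so the sample is skipped
    assert len(dialogue_acts_seq) == 0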
......
@@ -145,10 +148,10 @@ if __name__ == '__main__':
     if args.len_tokenizer:
         tokenizer = AutoTokenizer.from_pretrained(args.len_tokenizer)
     for dataset_name in tqdm(args.datasets, desc='datasets'):
-        dataset = load_dataset(dataset_name, args.dial_ids_order)
         if args.ratio:
-            dataset['train'] = dataset['train'][:round(len(dataset['train'])*args.ratio)]
-            dataset['validation'] = dataset['validation'][:round(len(dataset['validation'])*args.ratio)]
+            dataset = load_dataset(dataset_name, dial_ids_order=args.dial_ids_order, split2ratio={'train': args.ratio, 'validation': args.ratio})
+        else:
+            dataset = load_dataset(dataset_name, args.dial_ids_order)
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
             data_dir = os.path.join('data', task_name, (dataset_name if not args.ratio else f'{dataset_name}_{args.ratio}_order{args.dial_ids_order}'))
             data_by_split = eval(f"create_{task_name}_data")(dataset, data_dir, args)
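For reference, the rewritten branch delegates subsampling to the loader instead of slicing the returned splits. A minimal sketch of the call (values hypothetical, assuming convlab.util.load_dataset as used above):

    from convlab.util import load_dataset

    ratio = 0.1  # hypothetical: keep 10% of train/validation
    dataset = load_dataset('multiwoz21', dial_ids_order=0,
                           split2ratio={'train': ratio, 'validation': ratio})
    # dataset['train'] and dataset['validation'] are now subsampled at load time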
......
import json
import os
import sys

if __name__ == '__main__':
    # merge the per-dataset NLG data files produced by create_data.py,
    # prefixing each source sequence with its dataset name
    merged_data = {'train': [], 'validation': [], 'test': []}
    print(sys.argv)
    for dataset_name in sys.argv[1:]:
        data_dir = os.path.join('data/nlg', dataset_name, 'system/context_0')
        for data_split in merged_data:
            with open(os.path.join(data_dir, f'{data_split}.json'), 'r') as f:
                for line in f:
                    item = json.loads(line)
                    item['context+da'] = f"{dataset_name}: {item['context+da']}"
                    merged_data[data_split].append(item)
    # write the merged splits under a '+'-joined directory name
    for data_split in merged_data:
        data_dir = os.path.join('data/nlg', '+'.join(sys.argv[1:]), 'system/context_0')
        os.makedirs(data_dir, exist_ok=True)
        with open(os.path.join(data_dir, f'{data_split}.json'), 'w') as f:
            for item in merged_data[data_split]:
                f.write(json.dumps(item)+'\n')
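Usage note: the script takes dataset names as positional arguments, reads the line-delimited {split}.json files from data/nlg/<name>/system/context_0, and writes the merged splits to data/nlg/<name1>+<name2>+.../system/context_0. The multitask script below calls it as python merge_data.py sgd tm1 tm2 tm3 multiwoz21.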
......
@@ -3,10 +3,8 @@ import os
 from convlab.util import load_dataset, load_nlg_data

-def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+def merge(dataset_names, speaker, save_dir, context_window_size, predict_result):
     assert os.path.exists(predict_result)
-    dataset = load_dataset(dataset_name, args.dial_ids_order)
-    data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
     if save_dir is None:
         save_dir = os.path.dirname(predict_result)
......
@@ -14,10 +12,20 @@ def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
         os.makedirs(save_dir, exist_ok=True)
     predict_result = [json.loads(x)['predictions'].strip() for x in open(predict_result)]

-    for sample, prediction in zip(data, predict_result):
-        sample['predictions'] = {'utterance': prediction}
+    merged = []
+    i = 0
+    for dataset_name in dataset_names.split('+'):
+        print(dataset_name)
+        dataset = load_dataset(dataset_name, args.dial_ids_order)
+        data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+        for sample in data:
+            if all([len(sample['dialogue_acts'][da_type])==0 for da_type in sample['dialogue_acts']]):
+                continue
+            sample['predictions'] = {'utterance': predict_result[i]}
+            i += 1
+            merged.append(sample)

-    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+    json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)

 if __name__ == '__main__':
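Note the new alignment logic: generated_predictions.json contains one line per test sample that had non-empty dialogue acts (the empty ones were skipped during data creation), so merge() re-applies the same empty-acts filter while walking each dataset's test split, keeping the running index i in step with the prediction file.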
......
......
@@ -73,6 +73,7 @@ class NLGMetrics(datasets.Metric):
     def _compute(self, predictions, references):
         """Returns the scores: bleu"""
+        references = [" " if ref=="" else ref for ref in references]
         bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
         return {
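The added line guards against empty reference strings, which sacrebleu does not handle gracefully; they are replaced with a single space before scoring. A standalone sketch (example strings hypothetical, assuming sacrebleu 2.x):

    import sacrebleu

    predictions = ['there are 5 hotels in the centre', 'you are welcome']
    references = ['there are five hotels in the centre', '']
    references = [' ' if ref == '' else ref for ref in references]  # guard empty refs
    print(sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score)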
......
......
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
......
......
@@ -42,7 +42,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --load_best_model_at_end \
     --cache_dir ${cache_dir} \
......
n_gpus=1
task_name="nlg"
dataset_name="sgd+tm1+tm2+tm3+multiwoz21"
speaker="system"
context_window_size=0
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="nlg_metric.py"
metric_for_best_model="bleu"
source_column="context+da"
target_column="response"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=8
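# effective batch size = per_device_train_batch_size (64) x gradient_accumulation_steps (8) x n_gpus (1) = 512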
lr=1e-3
num_train_epochs=10
names=$(echo ${dataset_name} | tr "+" "\n")
rm -r ${data_dir}
mkdir -p ${data_dir}
for name in ${names};
do
echo "preprocessing ${name}"
python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
done
python merge_data.py $(echo ${dataset_name} | tr "+" " ")
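# merge_data.py writes the merged splits to data/nlg/sgd+tm1+tm2+tm3+multiwoz21/system/context_0, which is exactly ${data_dir} above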
python ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
# python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name}
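The second run_seq2seq.py invocation decodes the test split with the fine-tuned checkpoint in ${output_dir}; merge_predict_res.py then realigns generated_predictions.json with each dataset's test samples as described above, and the commented-out final line would run the corpus-level NLG evaluation on the resulting predictions.json.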
......
@@ -31,13 +31,10 @@ for name in ${names};
 do
     echo "preprocessing ${name}"
     python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size}
-    if [ "${name}" != "${dataset_name}" ]; then
-        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/train.json" >> ${train_file}
-        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/validation.json" >> ${validation_file}
-        cat "data/${task_name}/${name}/${speaker}/context_${context_window_size}/test.json" >> ${test_file}
-    fi
 done

+python merge_data.py $(echo ${dataset_name} | tr "+" " ")
+
 python ../run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
......
......
@@ -36,6 +36,7 @@ def evaluate(predict_result, ontology):
         references.append(predict_result[i]['utterance'])
         candidates.append(predict_result[i]['predictions']['utterance'])
     # metrics['bleu'] = corpus_bleu(references, candidates)
+    references = [" " if ref=="" else ref for ref in references]
     metrics['bleu'] = sacrebleu.corpus_bleu(candidates, [references], lowercase=True).score

     # ERROR Rate
......
@@ -1,4 +1,5 @@
 absl-py==1.1.0
+accelerate==0.10.0
 aiohttp==3.8.1
 aiosignal==1.2.0
 async-timeout==4.0.2
......
@@ -11,6 +12,7 @@ catalogue==2.0.7
 certifi==2022.5.18.1
 charset-normalizer==2.0.12
 click==8.1.3
+colorama==0.4.5
 cycler==0.11.0
 cymem==2.0.6
 datasets==2.3.2
......
@@ -51,8 +53,10 @@ packaging==21.3
 pandas==1.4.2
 pathy==0.6.1
 Pillow==9.1.1
+portalocker==2.4.0
 preshed==3.0.6
 protobuf==3.19.4
+psutil==5.9.1
 pyarrow==8.0.0
 pyasn1==0.4.8
 pyasn1-modules==0.2.8
......
@@ -69,8 +73,10 @@ regex==2022.6.2
 requests==2.28.0
 requests-oauthlib==1.3.1
 responses==0.18.0
+rouge-score==0.0.4
 rsa==4.8
 s3transfer==0.6.0
+sacrebleu==2.1.0
 scikit-learn==1.1.1
 scipy==1.8.1
 seqeval==1.2.2
......
@@ -85,6 +91,7 @@ tabulate==0.8.10
 tensorboard==2.9.1
 tensorboard-data-server==0.6.1
 tensorboard-plugin-wit==1.8.1
+tensorboardX==2.5.1
 thinc==8.0.17
 threadpoolctl==3.1.0
 tokenizers==0.12.1
......
......
@@ -24,6 +24,10 @@ setup(
     ],
     setup_requires=['setuptools-git'],
     install_requires=[
+        'accelerate',
+        'rouge-score',
+        'sacrebleu',
+        'tensorboardX',
         'boto3',
         'matplotlib',
         'tabulate',
......