diff --git a/README.md b/README.md index de367a3a7d6d3ab2187d4cbb4a1a10550936370a..6b26b6a25fc8c1c94801a22490e2a5b7f9833452 100755 --- a/README.md +++ b/README.md @@ -16,11 +16,12 @@ ## Updates +- **2023.2.26**: Update ConvLab on PyPI to 3.0.1 to reflect bug fixes. - **2022.11.30**: ConvLab-3 release. ## Installation -You can install ConvLab-3 in one of the following ways according to your need. Higher versions of `torch` and `transformers` may also work. +You can install ConvLab-3 in one of the following ways according to your needs. We use `torch>=1.10.1,<=1.13` and `transformers>=4.17.0,<=4.24.0`. Higher versions of `torch` and `transformers` may also work. ### Git clone and pip install in development mode (Recommended) diff --git a/convlab/base_models/t5/dst/merge_predict_res.py b/convlab/base_models/t5/dst/merge_predict_res.py index f25279a87a8404faab03523a072782c1a08b738c..9b519229b5b0a941a6d4172bd1473c6c8077658a 100755 --- a/convlab/base_models/t5/dst/merge_predict_res.py +++ b/convlab/base_models/t5/dst/merge_predict_res.py @@ -4,7 +4,7 @@ from convlab.util import load_dataset, load_dst_data from convlab.base_models.t5.dst.serialization import deserialize_dialogue_state -def merge(dataset_names, speaker, save_dir, context_window_size, predict_result): +def merge(dataset_names, speaker, save_dir, context_window_size, predict_result, dial_ids_order): assert os.path.exists(predict_result) if save_dir is None: @@ -17,14 +17,18 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) i = 0 for dataset_name in dataset_names.split('+'): print(dataset_name) - dataset = load_dataset(dataset_name, args.dial_ids_order) + single = [] + dataset = load_dataset(dataset_name, dial_ids_order) data = load_dst_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] for sample in data: sample['predictions'] = {'state': predict_result[i]} i += 1 + single.append(sample) merged.append(sample) + json.dump(single, open(os.path.join(save_dir, f'{dataset_name}_predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) @@ -35,8 +39,8 @@ if __name__ == '__main__': parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json.
default: on the same directory as predict_result') parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') - parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file test_generated_predictions.json') parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') args = parser.parse_args() print(args) - merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order) diff --git a/convlab/base_models/t5/dst/run_dst.sh b/convlab/base_models/t5/dst/run_dst.sh index 05975400bd1ca901e1058dc80587e2cce0b0f1bb..0e2b94969049bed448cba89a76d866c58cdcc775 100644 --- a/convlab/base_models/t5/dst/run_dst.sh +++ b/convlab/base_models/t5/dst/run_dst.sh @@ -1,6 +1,6 @@ n_gpus=1 task_name="dst" -dataset_name=$1 +dataset_name=crosswoz speaker="user" context_window_size=100 data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" @@ -17,10 +17,10 @@ target_column="state_seq" truncation_side="left" max_source_length=1024 max_target_length=512 -model_name_or_path="t5-small" -per_device_train_batch_size=64 -per_device_eval_batch_size=64 -gradient_accumulation_steps=2 +model_name_or_path="/data/zhuqi/pre-trained-models/mt5-small" +per_device_train_batch_size=16 +per_device_eval_batch_size=16 +gradient_accumulation_steps=4 lr=1e-3 num_train_epochs=10 @@ -80,6 +80,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/dst/run_dst_fewshot.sh b/convlab/base_models/t5/dst/run_dst_fewshot.sh index 4acd605706752c67d1f1df3b5fa04df13d2e46ad..2e2e999811149acc7926c7813cdad769ca1d8fbc 100644 --- a/convlab/base_models/t5/dst/run_dst_fewshot.sh +++ b/convlab/base_models/t5/dst/run_dst_fewshot.sh @@ -82,6 +82,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/dst/run_dst_multitask.sh b/convlab/base_models/t5/dst/run_dst_multitask.sh index aefb1d5200db292d0e68e7d39498dfbd182d1fa0..7846e049aecf9302063517533e2810aaea9d339c 100644 --- a/convlab/base_models/t5/dst/run_dst_multitask.sh +++ b/convlab/base_models/t5/dst/run_dst_multitask.sh @@ -30,7 +30,7 @@ mkdir -p ${data_dir} for name in ${names}; do echo "preprocessing ${name}" - # python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c ${context_window_size} + python ../create_data.py -t ${task_name} -d ${name} -s ${speaker} -c 
${context_window_size} done python merge_data.py $(echo ${dataset_name} | tr "+" " ") @@ -89,6 +89,10 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json +for name in ${names}; +do + echo "evaluating ${name}" + python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/${name}_predictions.json +done \ No newline at end of file diff --git a/convlab/base_models/t5/nlg/merge_predict_res.py b/convlab/base_models/t5/nlg/merge_predict_res.py index 7d2995d84737378958a54765f3efc5f996f112e3..a4d1b34155fda3826b6aaa100ddf7a8f5c37c49f 100755 --- a/convlab/base_models/t5/nlg/merge_predict_res.py +++ b/convlab/base_models/t5/nlg/merge_predict_res.py @@ -3,7 +3,7 @@ import os from convlab.util import load_dataset, load_nlg_data -def merge(dataset_names, speaker, save_dir, context_window_size, predict_result): +def merge(dataset_names, speaker, save_dir, context_window_size, predict_result, dial_ids_order): assert os.path.exists(predict_result) if save_dir is None: @@ -16,7 +16,8 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) i = 0 for dataset_name in dataset_names.split('+'): print(dataset_name) - dataset = load_dataset(dataset_name, args.dial_ids_order) + single = [] + dataset = load_dataset(dataset_name, dial_ids_order) data = load_nlg_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] for sample in data: @@ -24,8 +25,11 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) continue sample['predictions'] = {'utterance': predict_result[i]} i += 1 + single.append(sample) merged.append(sample) + json.dump(single, open(os.path.join(save_dir, f'{dataset_name}_predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) @@ -36,8 +40,8 @@ if __name__ == '__main__': parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. 
default: on the same directory as predict_result') parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') - parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file test_generated_predictions.json') parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') args = parser.parse_args() print(args) - merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order) diff --git a/convlab/base_models/t5/nlg/run_nlg.sh b/convlab/base_models/t5/nlg/run_nlg.sh index 0b5fa390dcaf98b098abc17f18026994ee54702c..718dca4a2e344fc903e20f48e75cccd46b7821ce 100644 --- a/convlab/base_models/t5/nlg/run_nlg.sh +++ b/convlab/base_models/t5/nlg/run_nlg.sh @@ -1,8 +1,8 @@ n_gpus=1 task_name="nlg" -dataset_name=$1 -speaker="system" -context_window_size=$2 +dataset_name=crosswoz +speaker="all" +context_window_size=0 data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" cache_dir="../cache" @@ -17,10 +17,10 @@ target_column="response" truncation_side="left" max_source_length=512 max_target_length=512 -model_name_or_path="t5-small" -per_device_train_batch_size=128 -per_device_eval_batch_size=64 -gradient_accumulation_steps=4 +model_name_or_path="/data/zhuqi/pre-trained-models/mt5-small" +per_device_train_batch_size=32 +per_device_eval_batch_size=16 +gradient_accumulation_steps=8 lr=1e-3 num_train_epochs=10 @@ -80,6 +80,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name} diff --git a/convlab/base_models/t5/nlg/run_nlg_fewshot.sh b/convlab/base_models/t5/nlg/run_nlg_fewshot.sh index 61e50cdaa094b301660d38f74fcf8420424a7d3f..17f110a0da83dd2cd4ef74bff0b9f8f43a8637a1 100644 --- a/convlab/base_models/t5/nlg/run_nlg_fewshot.sh +++ b/convlab/base_models/t5/nlg/run_nlg_fewshot.sh @@ -83,6 +83,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name} diff --git a/convlab/base_models/t5/nlg/run_nlg_multitask.sh b/convlab/base_models/t5/nlg/run_nlg_multitask.sh index dec894aab37a37ba7923d60431fb22ef5ac4d6b6..9556bdcc0d55e4e520c371386438f9a99c71280c 100644 --- a/convlab/base_models/t5/nlg/run_nlg_multitask.sh +++ b/convlab/base_models/t5/nlg/run_nlg_multitask.sh @@ -89,6 +89,10 @@ python ../run_seq2seq.py \ --optim adafactor \ 
--gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -# python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/predictions.json --dataset_name ${dataset_name} +for name in ${names}; +do + echo "evaluating ${name}" + python ../../../nlg/evaluate_unified_datasets.py -p ${output_dir}/${name}_predictions.json --dataset_name ${name} +done \ No newline at end of file diff --git a/convlab/base_models/t5/nlu/merge_predict_res.py b/convlab/base_models/t5/nlu/merge_predict_res.py index e247160769f7e5b0c9445b38e4dc2a5caa567fd0..8c522063f98e38773841554ddddfc9f476b18049 100755 --- a/convlab/base_models/t5/nlu/merge_predict_res.py +++ b/convlab/base_models/t5/nlu/merge_predict_res.py @@ -4,7 +4,7 @@ from convlab.util import load_dataset, load_nlu_data from convlab.base_models.t5.nlu.serialization import deserialize_dialogue_acts -def merge(dataset_names, speaker, save_dir, context_window_size, predict_result): +def merge(dataset_names, speaker, save_dir, context_window_size, predict_result, dial_ids_order): assert os.path.exists(predict_result) if save_dir is None: @@ -17,14 +17,18 @@ def merge(dataset_names, speaker, save_dir, context_window_size, predict_result) i = 0 for dataset_name in dataset_names.split('+'): print(dataset_name) - dataset = load_dataset(dataset_name, args.dial_ids_order) + single = [] + dataset = load_dataset(dataset_name, dial_ids_order) data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test'] for sample in data: sample['predictions'] = {'dialogue_acts': predict_result[i]} i += 1 + single.append(sample) merged.append(sample) + json.dump(single, open(os.path.join(save_dir, f'{dataset_name}_predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) + json.dump(merged, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False) @@ -35,8 +39,8 @@ if __name__ == '__main__': parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances') parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. 
default: on the same directory as predict_result') parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered') - parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated_predictions.json') + parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file test_generated_predictions.json') parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') args = parser.parse_args() print(args) - merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result) + merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result, args.dial_ids_order) diff --git a/convlab/base_models/t5/nlu/run_nlu.sh b/convlab/base_models/t5/nlu/run_nlu.sh index b81b04c0f360fe55c25e55f85ff8ceac3578a99d..cf668b5d32179747819ad3dfba04d1fda954acec 100644 --- a/convlab/base_models/t5/nlu/run_nlu.sh +++ b/convlab/base_models/t5/nlu/run_nlu.sh @@ -1,8 +1,8 @@ n_gpus=1 task_name="nlu" -dataset_name=$1 -speaker="user" -context_window_size=$2 +dataset_name=crosswoz +speaker="all" +context_window_size=0 data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}" cache_dir="../cache" @@ -17,10 +17,10 @@ target_column="dialogue_acts_seq" truncation_side="left" max_source_length=512 max_target_length=512 -model_name_or_path="t5-small" -per_device_train_batch_size=128 -per_device_eval_batch_size=64 -gradient_accumulation_steps=2 +model_name_or_path="/data/zhuqi/pre-trained-models/mt5-small" +per_device_train_batch_size=16 +per_device_eval_batch_size=16 +gradient_accumulation_steps=16 lr=1e-3 num_train_epochs=10 @@ -80,6 +80,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh index a966310a5bea242db413dda7b9ca12bcbda0ae43..5f04579f64c097dff1d4da50fa00231354c55d7d 100644 --- a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh +++ b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh @@ -83,6 +83,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_nlu_multitask.sh b/convlab/base_models/t5/nlu/run_nlu_multitask.sh index b91f21e3f02270ff2f1dfa42fe8baa8f16a20acc..4e29168007763e44817b8c8215705864d0131a8f 100644 --- a/convlab/base_models/t5/nlu/run_nlu_multitask.sh +++ b/convlab/base_models/t5/nlu/run_nlu_multitask.sh @@ -89,6 +89,10 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py 
-d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json +for name in ${names}; +do + echo "evaluating ${name}" + python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/${name}_predictions.json +done \ No newline at end of file diff --git a/convlab/base_models/t5/nlu/run_retnlu.sh b/convlab/base_models/t5/nlu/run_retnlu.sh index fd44e063dc84da86e4f77ead69b0e329ac0cc7d1..ede928abd601dc37ae48536dc2ad6229c7a4b556 100644 --- a/convlab/base_models/t5/nlu/run_retnlu.sh +++ b/convlab/base_models/t5/nlu/run_retnlu.sh @@ -81,6 +81,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh index e778c80bdc844dfea732421e9234e8965e20d987..a3bfbfaac21a5b195dcd89572aad386c7b82e612 100644 --- a/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh +++ b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh @@ -84,6 +84,6 @@ num_train_epochs=100 # --optim adafactor \ # --gradient_checkpointing -# python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +# python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh index 775b4b06ed35f82610466ca96e518e95eb9b86f8..5c4a091d533173691d0686c1a508ded7bee68737 100644 --- a/convlab/base_models/t5/nlu/run_retnlu_in_context.sh +++ b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh @@ -81,6 +81,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/test_generated_predictions.json python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh index 913ef7cbad5fae0b3092c29fe0cd5f44604c333d..3a6c4ce0484c1b5366b9d081483fd3c4c9e5c247 100644 --- a/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh +++ b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh @@ -84,6 +84,6 @@ python ../run_seq2seq.py \ --optim adafactor \ --gradient_checkpointing -python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order} +python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p 
${output_dir}/test_generated_predictions.json -o ${dial_ids_order} python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json diff --git a/convlab/base_models/t5/run_seq2seq.py b/convlab/base_models/t5/run_seq2seq.py index 5fa921f0d4c855dc17b7f3b5d1daa8cc404f957c..bdf1b10f21c5f9442efc0e5381e6899966fc8872 100644 --- a/convlab/base_models/t5/run_seq2seq.py +++ b/convlab/base_models/t5/run_seq2seq.py @@ -37,6 +37,8 @@ from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, + T5ForConditionalGeneration, + T5Tokenizer, DataCollatorForSeq2Seq, HfArgumentParser, EarlyStoppingCallback, @@ -358,22 +360,40 @@ def main(): revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast_tokenizer, - truncation_side=model_args.truncation_side, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - model = AutoModelForSeq2SeqLM.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) + try: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + truncation_side=model_args.truncation_side, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + except Exception: + tokenizer = T5Tokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + truncation_side=model_args.truncation_side, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = T5ForConditionalGeneration.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) model.resize_token_embeddings(len(tokenizer)) @@ -612,16 +632,17 @@ def main(): # Predict if training_args.do_predict: - logger.info("*** Predict ***") - predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict") + file_prefix = os.path.splitext(os.path.basename(data_args.test_file))[0] + logger.info(f"*** Predict {file_prefix} ***") + predict_results = trainer.predict(predict_dataset, metric_key_prefix=file_prefix) metrics = predict_results.metrics max_predict_samples = ( data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) ) - metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + metrics[f"{file_prefix}_samples"] = min(max_predict_samples, len(predict_dataset)) -
trainer.log_metrics("predict", metrics) - trainer.save_metrics("predict", metrics) + trainer.log_metrics(file_prefix, metrics) + trainer.save_metrics(file_prefix, metrics) if trainer.is_world_process_zero(): if training_args.predict_with_generate: @@ -629,10 +650,13 @@ def main(): predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True ) predictions = [pred.strip() for pred in predictions] - output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.json") + output_prediction_file = os.path.join(training_args.output_dir, f"{file_prefix}_generated_predictions.json") with open(output_prediction_file, "w", encoding='utf-8') as writer: - for sample, pred in zip(raw_datasets["test"], predictions): - sample["predictions"] = pred + for idx, sample in enumerate(raw_datasets["test"]): + if training_args.num_return_sequences > 1: + sample["predictions"] = predictions[idx*training_args.num_return_sequences:(idx+1)*training_args.num_return_sequences] + else: + sample["predictions"] = predictions[idx] writer.write(json.dumps(sample, ensure_ascii=False)+'\n') kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": data_args.task_name} diff --git a/convlab/dst/rule/multiwoz/dst.py b/convlab/dst/rule/multiwoz/dst.py index eb7d31cf4d5fd59089164210a08c2bf10e51b6ad..57e438a88e8c214b995acbab7fec1210b4f63e5e 100755 --- a/convlab/dst/rule/multiwoz/dst.py +++ b/convlab/dst/rule/multiwoz/dst.py @@ -70,82 +70,88 @@ if __name__ == '__main__': dst = RuleDST() - # Action is a dict. Its keys are strings(domain-type pairs, both uppercase and lowercase is OK) and its values are list of lists. - # The domain may be one of ('Attraction', 'Hospital', 'Booking', 'Hotel', 'Restaurant', 'Taxi', 'Train', 'Police'). - # The type may be "inform" or "request". - - # For example, the action below has a key "Hotel-Inform", in which "Hotel" is domain and "Inform" is action type. - # Each list in the value of "Hotel-Inform" is a slot-value pair. "Area" is slot and "east" is value. "Star" is slot and "4" is value. + # Action (dialog acts) is a list of (intent, domain, slot, value) tuples. + # RuleDST will only handle `inform` and `request` actions action = [ - ["Inform", "Hotel", "Area", "east"], - ["Inform", "Hotel", "Stars", "4"] + ["inform", "hotel", "area", "east"], + ["inform", "hotel", "stars", "4"] ] # method `update` updates the attribute `state` of tracker, and returns it. 
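    # For example (a sketch, assuming the unified-format semantics of RuleDST):
    # ["inform", "hotel", "area", "east"] fills state['belief_state']['hotel']['area'] = 'east',
    # while a request act such as ["request", "hotel", "phone", ""] would be tracked
    # under state['request_state'] rather than the belief state.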
state = dst.update(action) assert state == dst.state - assert state == {'user_action': [], - 'system_action': [], - 'belief_state': {'police': {'book': {'booked': []}, 'semi': {}}, - 'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''}, - 'semi': {'name': '', - 'area': 'east', - 'parking': '', - 'pricerange': '', - 'stars': '4', - 'internet': '', - 'type': ''}}, - 'attraction': {'book': {'booked': []}, - 'semi': {'type': '', 'name': '', 'area': ''}}, - 'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''}, - 'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}}, - 'hospital': {'book': {'booked': []}, 'semi': {'department': ''}}, - 'taxi': {'book': {'booked': []}, - 'semi': {'leaveAt': '', - 'destination': '', - 'departure': '', - 'arriveBy': ''}}, - 'train': {'book': {'booked': [], 'people': ''}, - 'semi': {'leaveAt': '', - 'destination': '', - 'day': '', - 'arriveBy': '', - 'departure': ''}}}, + assert state == {'belief_state': {'attraction': {'area': '', 'name': '', 'type': ''}, + 'hospital': {'department': ''}, + 'hotel': {'area': 'east', + 'book day': '', + 'book people': '', + 'book stay': '', + 'internet': '', + 'name': '', + 'parking': '', + 'price range': '', + 'stars': '4', + 'type': ''}, + 'restaurant': {'area': '', + 'book day': '', + 'book people': '', + 'book time': '', + 'food': '', + 'name': '', + 'price range': ''}, + 'taxi': {'arrive by': '', + 'departure': '', + 'destination': '', + 'leave at': ''}, + 'train': {'arrive by': '', + 'book people': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leave at': ''}}, + 'booked': {}, + 'history': [], 'request_state': {}, + 'system_action': [], 'terminated': False, - 'history': []} + 'user_action': []} # Please call `init_session` before each new dialog. This resets the tracker's attribute `state` to the default state returned by `convlab.util.multiwoz.state.default_state`. You need not call it before the first dialog, because the tracker gets a default state in its constructor.
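    # Without such a reset, successive `update` calls keep merging acts into the same
    # state (a sketch, assuming the accumulation behavior shown above):
    #     dst.update([["inform", "hotel", "parking", "yes"]])
    #     # belief_state['hotel'] would then hold both area='east' and parking='yes'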
dst.init_session() - action = [["Inform", "Train", "Arrive", "19:45"]] + action = [["inform", "train", "arrive by", "19:45"]] state = dst.update(action) - assert state == {'user_action': [], - 'system_action': [], - 'belief_state': {'police': {'book': {'booked': []}, 'semi': {}}, - 'hotel': {'book': {'booked': [], 'people': '', 'day': '', 'stay': ''}, - 'semi': {'name': '', - 'area': '', - 'parking': '', - 'pricerange': '', - 'stars': '', - 'internet': '', - 'type': ''}}, - 'attraction': {'book': {'booked': []}, - 'semi': {'type': '', 'name': '', 'area': ''}}, - 'restaurant': {'book': {'booked': [], 'people': '', 'day': '', 'time': ''}, - 'semi': {'food': '', 'pricerange': '', 'name': '', 'area': ''}}, - 'hospital': {'book': {'booked': []}, 'semi': {'department': ''}}, - 'taxi': {'book': {'booked': []}, - 'semi': {'leaveAt': '', - 'destination': '', - 'departure': '', - 'arriveBy': ''}}, - 'train': {'book': {'booked': [], 'people': ''}, - 'semi': {'leaveAt': '', - 'destination': '', - 'day': '', - 'arriveBy': '19:45', - 'departure': ''}}}, + assert state == {'belief_state': {'attraction': {'area': '', 'name': '', 'type': ''}, + 'hospital': {'department': ''}, + 'hotel': {'area': '', + 'book day': '', + 'book people': '', + 'book stay': '', + 'internet': '', + 'name': '', + 'parking': '', + 'price range': '', + 'stars': '', + 'type': ''}, + 'restaurant': {'area': '', + 'book day': '', + 'book people': '', + 'book time': '', + 'food': '', + 'name': '', + 'price range': ''}, + 'taxi': {'arrive by': '', + 'departure': '', + 'destination': '', + 'leave at': ''}, + 'train': {'arrive by': '19:45', + 'book people': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leave at': ''}}, + 'booked': {}, + 'history': [], 'request_state': {}, + 'system_action': [], 'terminated': False, - 'history': []} + 'user_action': []} diff --git a/convlab/dst/rule/multiwoz/evaluate.py b/convlab/dst/rule/multiwoz/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..d25889c788a46d52e324beb8da7e2c5cb3023b1a --- /dev/null +++ b/convlab/dst/rule/multiwoz/evaluate.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 DSML Group, Heinrich Heine University, Düsseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
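+# Usage sketch (assumptions: the script is run from this directory, `load_dataset`
+# can download multiwoz21, and the evaluation command below is illustrative):
+#   python evaluate.py
+# This writes predictions_RuleDST.json (golden user acts fed to RuleDST) and
+# predictions_BERTNLU-RuleDST.json (BERTNLU-predicted acts fed to RuleDST) next to
+# this script, each sample carrying {'state': ..., 'predictions': {'state': ...}},
+# so the files can be scored with the unified DST evaluation script, e.g.
+#   python ../../evaluate_unified_datasets.py -p predictions_RuleDST.json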
+"""MultiWOZ Test data inference for RuleDST and BERTNLU+RuleDST""" + +import json +from copy import deepcopy +import os + +from tqdm import tqdm + +from convlab.util import load_dataset, load_dst_data +from convlab.dst.rule.multiwoz.dst import RuleDST +from convlab.nlu.jointBERT.unified_datasets.nlu import BERTNLU + +BERTNLU_PATH = "https://huggingface.co/ConvLab/bert-base-nlu/resolve/main/bertnlu_unified_multiwoz21_user_context3.zip" + + +def flatten_act(acts: dict) -> list: + acts_list = list() + for act_type, _acts in acts.items(): + for act in _acts: + if 'value' in act: + _act = [act['intent'], act['domain'], act['slot'], act['value']] + else: + _act = [act['intent'], act['domain'], act['slot'], ''] + acts_list.append(_act) + return acts_list + + +def load_act_data(dataset: dict) -> list: + data = list() + for dialogue in tqdm(dataset['test']): + dial = [] + for _turn in dialogue['turns']: + if _turn['speaker'] == 'user': + turn = {'user_acts': flatten_act(_turn['dialogue_acts']), + 'state': _turn['state']} + dial.append(turn) + data.append(dial) + return data + + +def load_text_data(dataset: dict) -> list: + data = list() + for dialogue in tqdm(dataset['test']): + dial = [] + turn = {'user': '', 'system': 'Start', 'state': None} + for _turn in dialogue['turns']: + if _turn['speaker'] == 'user': + turn['user'] = _turn['utterance'] + turn['state'] = _turn['state'] + elif _turn['speaker'] == 'system': + turn['system'] = _turn['utterance'] + if turn['user'] and turn['system']: + if turn['system'] == 'Start': + turn['system'] = '' + dial.append(deepcopy(turn)) + turn = {'user': '', 'system': '', 'state': None} + data.append(dial) + return data + + +def predict_acts(data: list, nlu: BERTNLU) -> list: + processed_data = list() + for dialogue in tqdm(data): + context = list() + dial = list() + for turn in dialogue: + context.append(['sys', turn['system']]) + acts = nlu.predict(turn['user'], context=context) + context.append(['usr', turn['user']]) + dial.append({'user_acts': deepcopy(acts), 'state': turn['state']}) + processed_data.append(dial) + return processed_data + + +def predict_states(data: list): + dst = RuleDST() + processed_data = list() + for dialogue in tqdm(data): + dst.init_session() + for turn in dialogue: + pred = dst.update(turn['user_acts']) + dial = {'state': turn['state'], + 'predictions': {'state': deepcopy(pred['belief_state'])}} + processed_data.append(dial) + return processed_data + + +if __name__ == '__main__': + dataset = load_dataset(dataset_name='multiwoz21') + dataset = load_dst_data(dataset, data_split='test', speaker='all', dialogue_acts=True, split_to_turn=False) + + data = load_text_data(dataset) + nlu = BERTNLU(mode='user', config_file='multiwoz21_user_context3.json', model_file=BERTNLU_PATH) + bertnlu_data = predict_acts(data, nlu) + + golden_data = load_act_data(dataset) + + bertnlu_data = predict_states(bertnlu_data) + golden_data = predict_states(golden_data) + + path = os.path.dirname(os.path.realpath(__file__)) + writer = open(os.path.join(path, f"predictions_BERTNLU-RuleDST.json"), 'w') + json.dump(bertnlu_data, writer) + writer.close() + + writer = open(os.path.join(path, f"predictions_RuleDST.json"), 'w') + json.dump(golden_data, writer) + writer.close() diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index e40332048188b7f5a1f43e397896a8b6b201553d..f56bbadc2f4d8fdca102b2bbc996acb0ae5a4a58 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -27,7 +27,7 @@ class SetSUMBTTracker(DST): 
confidence_threshold='auto', return_belief_state_entropy: bool = False, return_belief_state_mutual_info: bool = False, - store_full_belief_state: bool = False): + store_full_belief_state: bool = True): """ Args: model_path: Model path or download URL @@ -326,7 +326,9 @@ class SetSUMBTTracker(DST): dialogue_state[dom][slot] = val if self.store_full_belief_state: - self.full_belief_state = belief_state + self.info_dict['belief_state_distributions'] = belief_state + if state_mutual_info is not None: + self.info_dict['belief_state_knowledge_uncertainty'] = state_mutual_info # Obtain model output probabilities if self.return_confidence_scores: diff --git a/convlab/dst/trippy/modeling_dst.py b/convlab/dst/trippy/modeling_dst.py index 2828d17ed1e97ebb60b74ab999c28efb1e7bfa88..3bd875b60b5059ae00888756a926080295af8a9b 100644 --- a/convlab/dst/trippy/modeling_dst.py +++ b/convlab/dst/trippy/modeling_dst.py @@ -62,7 +62,7 @@ def TransformerForDST(parent_name): class TransformerForDST(PARENT_CLASSES[parent_name]): def __init__(self, config): assert config.model_type in PARENT_CLASSES - assert self.__class__.__bases__[0] in MODEL_CLASSES + # assert self.__class__.__bases__[0] in MODEL_CLASSES super(TransformerForDST, self).__init__(config) self.model_type = config.model_type self.slot_list = config.dst_slot_list @@ -82,7 +82,7 @@ def TransformerForDST(parent_name): self.refer_index = -1 # Make sure this module has the same name as in the pretrained checkpoint you want to load! - self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config)) + self.add_module(self.model_type, MODEL_CLASSES[PARENT_CLASSES[self.model_type]](config)) if self.model_type == "electra": self.pooler = ElectraPooler(config) diff --git a/convlab/dst/trippy/tracker.py b/convlab/dst/trippy/tracker.py index b0470266b2ce2c7d5cfba29d0270972b0c7cfa78..8ceaedddc5392c7a4ba9cdb7fe2f9ac9b39a72ba 100644 --- a/convlab/dst/trippy/tracker.py +++ b/convlab/dst/trippy/tracker.py @@ -30,10 +30,15 @@ from convlab.dst.trippy.modeling_dst import (TransformerForDST) from convlab.dst.trippy.dataset_interfacer import (create_dataset_interfacer) from convlab.util import relative_import_module_from_unified_datasets + +class BertForDST(TransformerForDST('bert')): pass +class RobertaForDST(TransformerForDST('roberta')): pass +class ElectraForDST(TransformerForDST('electra')): pass + MODEL_CLASSES = { - 'bert': (BertConfig, TransformerForDST('bert'), BertTokenizer), - 'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaTokenizer), - 'electra': (ElectraConfig, TransformerForDST('electra'), ElectraTokenizer), + 'bert': (BertConfig, BertForDST, BertTokenizer), + 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), + 'electra': (ElectraConfig, ElectraForDST, ElectraTokenizer), } diff --git a/convlab/e2e/soloist/multiwoz/soloist.py b/convlab/e2e/soloist/multiwoz/soloist.py index fea24f32dc1d56d205814164541e80ba0c48d321..580a14737d79b17b2e67aa3744611b3dd6272c53 100644 --- a/convlab/e2e/soloist/multiwoz/soloist.py +++ b/convlab/e2e/soloist/multiwoz/soloist.py @@ -10,7 +10,7 @@ from nltk.tokenize import word_tokenize from convlab.util.file_util import cached_path from convlab.e2e.soloist.multiwoz.config import global_config as cfg -from convlab.e2e.soloist.multiwoz.soloist_net import SOLOIST, cuda_ +from convlab.e2e.soloist.multiwoz.soloist_net import SOLOIST from convlab.dialog_agent import Agent from utils import MultiWozReader diff --git a/convlab/e2e/soloist/multiwoz/soloist_net.py 
b/convlab/e2e/soloist/multiwoz/soloist_net.py new file mode 100644 index 0000000000000000000000000000000000000000..3f23d106076603e6dcd8b33638195dfd8a51a424 --- /dev/null +++ b/convlab/e2e/soloist/multiwoz/soloist_net.py @@ -0,0 +1,48 @@ +import logging +import torch + +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer +) + +from convlab.e2e.soloist.multiwoz.config import global_config as cfg + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + +def cuda_(var): + return var.cuda() if cfg.cuda and torch.cuda.is_available() else var + + +def tensor(var): + return cuda_(torch.tensor(var)) + +class SOLOIST: + + def __init__(self) -> None: + + self.config = AutoConfig.from_pretrained(cfg.model_name_or_path) + self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name_or_path,config=self.config) + self.tokenizer = AutoTokenizer.from_pretrained('t5-base') + print('model loaded!') + + self.model = self.model.cuda() if torch.cuda.is_available() else self.model + + def generate(self, inputs): + + self.model.eval() + inputs = self.tokenizer([inputs]) + input_ids = tensor(inputs['input_ids']) + generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, top_p=cfg.top_p) + decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + + return decoded_preds[0] + + + \ No newline at end of file diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py index ee591e79ac0590582ec6a84be62bb19e31b31004..def3b2f3ae1f5a750049dc6af23e524a77789913 100644 --- a/convlab/nlg/scgpt/scgpt.py +++ b/convlab/nlg/scgpt/scgpt.py @@ -1,3 +1,4 @@ +import pdb import sys sys.path.append('../../..') @@ -6,17 +7,17 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config from torch.nn.parallel import DistributedDataParallel as DDP from convlab.nlg.nlg import NLG -from util import act2str -from scgpt_special_tokens import * +from convlab.nlg.scgpt.util import act2str class SCGPT(NLG): - def __init__(self, dataset_name, model_path, device='cpu'): + def __init__(self, dataset_name, model_path, device='cuda'): super(SCGPT, self).__init__() + self.dataset_name = dataset_name self.device = device self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device) self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') - self.model.load_state_dict(torch.load(model_path)) + self.model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device))) def generate(self, action): if isinstance(action, dict): @@ -50,5 +51,5 @@ class SCGPT(NLG): if self.tokenizer.eos_token in sent: sent = sent[:sent.index(self.tokenizer.eos_token)] return sent - output_strs = [clean_sentence(item) for item in outputs] + output_strs = [clean_sentence(self.tokenizer.decode(item, skip_special_tokens=True)) for item in outputs] return output_strs \ No newline at end of file diff --git a/convlab/nlu/jointBERT/unified_datasets/README.md b/convlab/nlu/jointBERT/unified_datasets/README.md index baf7bc126c24ab4938f74ad8de2173a0dc3de611..0ba2858902ea018d306f61129d4421cc339c160e 100755 --- a/convlab/nlu/jointBERT/unified_datasets/README.md +++ b/convlab/nlu/jointBERT/unified_datasets/README.md @@ -7,6 +7,10 @@ We support training BERTNLU on datasets that are in our unified format. ## Usage +#### Important note!
+ +The BERTNLU codebase uses the `speaker` parameter to identify whose utterances the model parses. A model trained with `speaker=user` predicts user actions, so the NLU of the system agent should use a checkpoint trained with `speaker=user`, and vice versa. Take care not to confuse the `speaker` with the name of the agent. The zipped downloadable model names have the following format: `bertnlu_unified_<dataset_name>_<speaker>_context<context_len>.zip`. + +#### Preprocess data ```sh diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 821f7271c6dadbf9c4f3ef5e766011513ac89be0..1c541857fe2d93a08684e5fcbc8dc4387fe2f7f4 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -3,6 +3,7 @@ import os import sys import numpy as np import logging +import json from copy import deepcopy from convlab.policy.vec import Vector @@ -79,8 +80,14 @@ class VectorBase(Vector): print("Load actions from file..") with open(os.path.join(dir_path, "sys_da_voc.txt")) as f: self.da_voc = f.read().splitlines() + if self.da_voc[0][0] != "[": + # if act is not a list, we still have the old action dict + self.load_actions_from_data() + else: + self.da_voc = [tuple(json.loads(act)) for act in self.da_voc] with open(os.path.join(dir_path, "user_da_voc.txt")) as f: self.da_voc_opp = f.read().splitlines() + self.da_voc_opp = [tuple(json.loads(act)) for act in self.da_voc_opp] self.generate_dict() @@ -104,14 +111,14 @@ class VectorBase(Vector): if turn['speaker'] == 'system': for act in delex_acts: - act = "_".join(act) + act = tuple([a.lower() for a in act]) if act not in system_dict: system_dict[act] = 1 else: system_dict[act] += 1 else: for act in delex_acts: - act = "_".join(act) + act = tuple([a.lower() for a in act]) if act not in user_dict: user_dict[act] = 1 else: @@ -138,10 +145,10 @@ class VectorBase(Vector): os.makedirs(dir_path, exist_ok=True) with open(os.path.join(dir_path, "sys_da_voc.txt"), "w") as f: for act in self.da_voc: - f.write(act + "\n") + f.write(json.dumps(act) + "\n") with open(os.path.join(dir_path, "user_da_voc.txt"), "w") as f: for act in self.da_voc_opp: - f.write(act + "\n") + f.write(json.dumps(act) + "\n") def load_actions_from_ontology(self): """ @@ -240,7 +247,7 @@ class VectorBase(Vector): for i in range(self.da_dim): action = self.vec2act[i] - action_domain = action.split('_')[0] + action_domain = action[0] if action_domain in domain_active_dict.keys(): if not domain_active_dict[action_domain]: mask_list[i] = 1.0 @@ -253,7 +260,7 @@ class VectorBase(Vector): for i in range(self.da_dim): action = self.vec2act[i] - domain, intent, slot, value = action.split('_') + domain, intent, slot, value = action # NoBook/NoOffer-SLOT does not make sense because policy can not know which constraint made offer impossible # If one wants to do it, lexicaliser needs to do it @@ -281,7 +288,7 @@ class VectorBase(Vector): return mask_list for i in range(self.da_dim): action = self.vec2act[i] - domain, intent, slot, value = action.split('_') + domain, intent, slot, value = action domain_entities = number_entities_dict.get(domain, 1) if intent in ['inform', 'select', 'recommend'] and value != None and value != 'none': @@ -349,10 +356,11 @@ class VectorBase(Vector): def action_vectorize(self, action): action = delexicalize_da(action, self.requestable) - action = flat_da(action) + #action =
flat_da(action) act_vec = np.zeros(self.da_dim) for da in action: + da = tuple([a.lower() for a in da]) if da in self.act2vec: act_vec[self.act2vec[da]] = 1. return act_vec @@ -376,29 +384,29 @@ class VectorBase(Vector): if len(act_array) == 0: if self.reqinfo_filler_action: - act_array.append('general_reqinfo_none_none') + act_array.append(("general", "reqinfo", "none", "none")) else: - act_array.append('general_reqmore_none_none') + act_array.append(("general", "reqmore", "none", "none")) action = deflat_da(act_array) entities = {} for domint in action: - domain, intent = domint.split('_') + domain, intent = domint if domain not in entities and domain not in ['general']: entities[domain] = self.dbquery_domain(domain) # From db query find which slot causes no_offer - nooffer = [domint for domint in action if 'nooffer' in domint] + nooffer = [domint for domint in action if 'nooffer' in domint[1]] for domint in nooffer: - domain, intent = domint.split('_') + domain, intent = domint slot = self.find_nooffer_slot(domain) action[domint] = [[slot, '1'] ] if slot != 'none' else [[slot, 'none']] # Randomly select booking constraint "causing" no_book - nobook = [domint for domint in action if 'nobook' in domint] + nobook = [domint for domint in action if 'nobook' in domint[1]] for domint in nobook: - domain, intent = domint.split('_') + domain, intent = domint if domain in self.state: slots = self.state[domain] slots = [slot for slot, i in slots.items() @@ -430,17 +438,19 @@ class VectorBase(Vector): if not self.use_none: # replace all occurences of "none" with an empty string "" - action = [[a_string.replace('none', '') for a_string in a_list] for a_list in action] + f = lambda x: x if x != "none" else "" + action = [[f(x) for x in a_list] for a_list in action] + #action = [[ for a_tuple in a_list] for a_list in action] return action def add_booking_reference(self, action): new_acts = {} for domint in action: - domain, intent = domint.split('_', 1) + domain, intent = domint if intent == 'book' and action[domint]: - ref_domint = f'{domain}_inform' + ref_domint = (domain, "inform") if ref_domint not in new_acts: new_acts[ref_domint] = [] new_acts[ref_domint].append(['ref', '1']) @@ -458,14 +468,14 @@ class VectorBase(Vector): name_inform = {domain: [] for domain in self.domains} # General Inform Condition for Naming - domains = [domint.split('_', 1)[0] for domint in action] + domains = [domint[0] for domint in action] domains = list(set([d for d in domains if d not in ['general']])) for domain in domains: contains_name = False if domain == 'none': raise NameError('Domain not defined') - cur_inform = domain + '_inform' - cur_request = domain + '_request' + cur_inform = (domain, "inform") + cur_request = (domain, "request") index = -1 if cur_inform in action: # Check if current inform within a domain is accompanied by a name inform diff --git a/convlab/policy/vector/vector_binary.py b/convlab/policy/vector/vector_binary.py index e780dc645043f4775b208479abf022dccce649a5..c6b02a1122adac3002ad1e32dd1f495046d0e6be 100755 --- a/convlab/policy/vector/vector_binary.py +++ b/convlab/policy/vector/vector_binary.py @@ -94,9 +94,10 @@ class VectorBinary(VectorBase): def vectorize_system_act(self, state): action = state['system_action'] if self.character == 'sys' else state['user_action'] action = delexicalize_da(action, self.requestable) - action = flat_da(action) + #action = flat_da(action) last_act_vec = np.zeros(self.da_dim) for da in action: + da = tuple(da) if da in self.act2vec: 
last_act_vec[self.act2vec[da]] = 1. return last_act_vec @@ -104,9 +105,10 @@ class VectorBinary(VectorBase): def vectorize_user_act(self, state): action = state['user_action'] if self.character == 'sys' else state['system_action'] opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) + #opp_action = flat_da(opp_action) opp_act_vec = np.zeros(self.da_opp_dim) for da in opp_action: + da = tuple(da) if da in self.opp2vec: prob = 1.0 opp_act_vec[self.opp2vec[da]] = prob diff --git a/convlab/policy/vector/vector_nodes.py b/convlab/policy/vector/vector_nodes.py index 24b1c1045a55960949c4d5747c066fff7c5906e9..2c7712bc9d7df74d9ae35a36bd8bb9edd4886c60 100644 --- a/convlab/policy/vector/vector_nodes.py +++ b/convlab/policy/vector/vector_nodes.py @@ -116,11 +116,12 @@ class VectorNodes(VectorBase): feature_type = 'last system act' action = state['system_action'] if self.character == 'sys' else state['user_action'] action = delexicalize_da(action, self.requestable) - action = flat_da(action) + #action = flat_da(action) for da in action: + da = tuple(da) if da in self.act2vec: - domain = da.split('_')[0] - description = "system-" + da + domain = da[0] + description = "system-" + "_".join(da) value = 1.0 self.add_graph_node(domain, feature_type, description.lower(), value) @@ -129,12 +130,13 @@ class VectorNodes(VectorBase): feature_type = 'user act' action = state['user_action'] if self.character == 'sys' else state['system_action'] opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) + #opp_action = flat_da(opp_action) for da in opp_action: + da = tuple(da) if da in self.opp2vec: - domain = da.split('_')[0] - description = "user-" + da + domain = da[0] + description = "user-" + "_".join(da) value = 1.0 self.add_graph_node(domain, feature_type, description.lower(), value) diff --git a/convlab/policy/vector/vector_uncertainty.py b/convlab/policy/vector/vector_uncertainty.py index afe8a5b89e2caac03217709f9d36632cbe3904c2..20bf9736b78ed75ae23741c141413aad749c8979 100644 --- a/convlab/policy/vector/vector_uncertainty.py +++ b/convlab/policy/vector/vector_uncertainty.py @@ -95,12 +95,13 @@ class VectorUncertainty(VectorBinary): self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None action = state['user_action'] if self.character == 'sys' else state['system_action'] opp_action = delexicalize_da(action, self.requestable) - opp_action = flat_da(opp_action) + #opp_action = flat_da(opp_action) opp_act_vec = np.zeros(self.da_opp_dim) for da in opp_action: + da = tuple(da) if da in self.opp2vec: if 'belief_state_probs' in state and self.use_confidence_scores: - domain, intent, slot, value = da.split('_') + domain, intent, slot, value = da if domain in state['belief_state_probs']: slot = slot if slot else 'none' if slot in state['belief_state_probs'][domain]: diff --git a/convlab/policy/vtrace_DPT/create_descriptions.py b/convlab/policy/vtrace_DPT/create_descriptions.py index c6e88daba8132dd30c0aaeeff23e6e2b619e1c92..138861262d20aa43901ef8774571d9babf1a94ed 100644 --- a/convlab/policy/vtrace_DPT/create_descriptions.py +++ b/convlab/policy/vtrace_DPT/create_descriptions.py @@ -20,14 +20,8 @@ def create_description_dicts(name='multiwoz21'): db = None db_domains = [] - root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - voc_file = os.path.join(root_dir, f'vector/action_dicts/{name}_VectorBinary/sys_da_voc.txt') - voc_opp_file = os.path.join(root_dir, 
f'vector/action_dicts/{name}_VectorBinary/user_da_voc.txt') - - with open(voc_file) as f: - da_voc = f.read().splitlines() - with open(voc_opp_file) as f: - da_voc_opp = f.read().splitlines() + da_voc = vector.da_voc + da_voc_opp = vector.da_voc_opp description_dict_semantic = {} @@ -47,13 +41,15 @@ def create_description_dicts(name='multiwoz21'): description_dict_semantic[f"general-{domain}"] = f"domain {domain}" for act in da_voc: - domain, intent, slot, value = act.split("_") + domain, intent, slot, value = act domain = domain.lower() + act = "_".join(act) description_dict_semantic["system-"+act.lower()] = f"last system act {domain} {intent} {slot} {value}" for act in da_voc_opp: - domain, intent, slot, value = [item.lower() for item in act.split("_")] + domain, intent, slot, value = [item.lower() for item in act] domain = domain.lower() + act = "_".join(act) description_dict_semantic["user-"+act.lower()] = f"user act {domain} {intent} {slot} {value}" root_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/convlab/policy/vtrace_DPT/supervised/train_supervised.py b/convlab/policy/vtrace_DPT/supervised/train_supervised.py index 1807a671da7e2938173a18277cd21980ee577a11..ccc407086e3399656d6ec7840e5c920779e2d058 100644 --- a/convlab/policy/vtrace_DPT/supervised/train_supervised.py +++ b/convlab/policy/vtrace_DPT/supervised/train_supervised.py @@ -182,7 +182,7 @@ if __name__ == '__main__': args = arg_parser() root_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - with open(os.path.join(root_directory, 'config.json'), 'r') as f: + with open(os.path.join(root_directory, 'configs/multiwoz21_dpt.json'), 'r') as f: cfg = json.load(f) cfg['dataset_name'] = args.dataset_name diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py index 8ec1d059388048be63be891806e9f73f724f531a..243ba344cba84b5a6f86288b5db45cf9acf9b695 100644 --- a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py +++ b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py @@ -24,7 +24,7 @@ class ActionEmbedder(nn.Module): = self.create_dicts(action_dict) #EOS token is considered a "domain" - self.action_dict = dict((key.lower(), value) for key, value in action_dict.items()) + self.action_dict = dict((key, value) for key, value in action_dict.items()) self.action_dict_reversed = dict((value, key) for key, value in self.action_dict.items()) self.embed_domain = torch.randn(len(self.domain_dict), embedding_dim) self.embed_intent = torch.randn(len(self.intent_dict), embedding_dim) @@ -88,19 +88,19 @@ class ActionEmbedder(nn.Module): elif not intent: # Domain was selected, check intents that are allowed for intent in self.intent_dict: - domain_intent = f"{domain}_{intent}" for idx, not_allow in enumerate(legal_mask): semantic_act = self.action_dict_reversed[idx] - if domain_intent in semantic_act and not_allow == 0: + if domain == semantic_act[0] and intent == semantic_act[1] and not_allow == 0: action_mask[self.small_action_dict[intent]] = 0 break else: # Selected domain and intent, need slot-value for slot_value in self.slot_value_dict: - domain_intent_slot = f"{domain}_{intent}_{slot_value}" + slot, value = slot_value for idx, not_allow in enumerate(legal_mask): semantic_act = self.action_dict_reversed[idx] - if domain_intent_slot in semantic_act and not_allow == 0: + if domain == semantic_act[0] and intent == semantic_act[1] \ + and slot == semantic_act[2] and value == semantic_act[3] and 
diff --git a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py
index 8ec1d059388048be63be891806e9f73f724f531a..243ba344cba84b5a6f86288b5db45cf9acf9b695 100644
--- a/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py
+++ b/convlab/policy/vtrace_DPT/transformer_model/action_embedder.py
@@ -24,7 +24,7 @@ class ActionEmbedder(nn.Module):
             = self.create_dicts(action_dict)

         #EOS token is considered a "domain"
-        self.action_dict = dict((key.lower(), value) for key, value in action_dict.items())
+        self.action_dict = dict((key, value) for key, value in action_dict.items())
         self.action_dict_reversed = dict((value, key) for key, value in self.action_dict.items())
         self.embed_domain = torch.randn(len(self.domain_dict), embedding_dim)
         self.embed_intent = torch.randn(len(self.intent_dict), embedding_dim)
@@ -88,19 +88,19 @@ class ActionEmbedder(nn.Module):
         elif not intent:
             # Domain was selected, check intents that are allowed
             for intent in self.intent_dict:
-                domain_intent = f"{domain}_{intent}"
                 for idx, not_allow in enumerate(legal_mask):
                     semantic_act = self.action_dict_reversed[idx]
-                    if domain_intent in semantic_act and not_allow == 0:
+                    if domain == semantic_act[0] and intent == semantic_act[1] and not_allow == 0:
                         action_mask[self.small_action_dict[intent]] = 0
                         break
         else:
             # Selected domain and intent, need slot-value
             for slot_value in self.slot_value_dict:
-                domain_intent_slot = f"{domain}_{intent}_{slot_value}"
+                slot, value = slot_value
                 for idx, not_allow in enumerate(legal_mask):
                     semantic_act = self.action_dict_reversed[idx]
-                    if domain_intent_slot in semantic_act and not_allow == 0:
+                    if domain == semantic_act[0] and intent == semantic_act[1] \
+                            and slot == semantic_act[2] and value == semantic_act[3] and not_allow == 0:
                         action_mask[self.small_action_dict[slot_value]] = 0
                         break
@@ -128,14 +128,15 @@ class ActionEmbedder(nn.Module):
         elif not intent:
             # Domain was selected, need intent now
             for intent in self.intent_dict:
-                domain_intent = f"{domain}_{intent}"
-                valid = self.is_valid(domain_intent + "_")
+                domain_intent = (domain, intent)
+                valid = self.is_valid(domain_intent)
                 if valid:
                     action_mask[self.small_action_dict[intent]] = 0
         else:
             # Selected domain and intent, need slot-value
             for slot_value in self.slot_value_dict:
-                domain_intent_slot = f"{domain}_{intent}_{slot_value}"
+                slot, value = slot_value
+                domain_intent_slot = (domain, intent, slot, value)
                 valid = self.is_valid(domain_intent_slot)
                 if valid:
                     action_mask[self.small_action_dict[slot_value]] = 0
@@ -160,9 +161,8 @@ class ActionEmbedder(nn.Module):

     def is_valid(self, part_action):
         for act in self.action_dict:
-            if act.startswith(part_action):
+            if part_action == act[:len(part_action)]:
                 return True
-
         return False

     def create_action_embeddings(self, embedding_dim):
@@ -178,7 +178,7 @@ class ActionEmbedder(nn.Module):
             action_embeddings[len(small_action_dict)] = self.embed_intent[idx]
             small_action_dict[intent] = len(small_action_dict)
         for slot_value in self.slot_value_dict:
-            slot, value = slot_value.split("_")
+            slot, value = slot_value
             slot_idx = self.slot_dict[slot]
             value_idx = self.value_dict[value]
             action_embeddings[len(small_action_dict)] = torch.cat(
@@ -201,7 +201,7 @@ class ActionEmbedder(nn.Module):
             action_embeddings.append(intent)
             small_action_dict[intent] = len(small_action_dict)
         for slot_value in self.slot_value_dict:
-            slot, value = slot_value.split("_")
+            slot, value = slot_value
             action_embeddings.append(f"{slot} {value}")
             small_action_dict[slot_value] = len(small_action_dict)

@@ -211,7 +211,7 @@ class ActionEmbedder(nn.Module):
         action_embeddings_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                               f'action_embeddings_{self.dataset_name}.pt')
         small_action_dict_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
-                                              f'small_action_dict_{self.dataset_name}.json')
+                                              f'small_action_dict_{self.dataset_name}.txt')

         if os.path.exists(action_embeddings_path):
             self.action_embeddings = torch.load(action_embeddings_path).to(DEVICE)
@@ -220,11 +220,15 @@ class ActionEmbedder(nn.Module):
             torch.save(self.action_embeddings, action_embeddings_path)

         if os.path.exists(small_action_dict_path):
-            self.small_action_dict = json.load(open(small_action_dict_path, 'r'))
+            with open(small_action_dict_path) as f:
+                self.small_action_dict = [json.loads(act) for act in f.read().splitlines()]
+            # only JSON arrays become tuples; string keys (domains, intents) stay strings
+            self.small_action_dict = dict((tuple(name) if isinstance(name, list) else name, idx)
+                                          for idx, name in enumerate(self.small_action_dict))
         else:
             self.small_action_dict = small_action_dict
             with open(small_action_dict_path, 'w') as f:
-                json.dump(self.small_action_dict, f)
+                for act in self.small_action_dict:
+                    f.write(json.dumps(act) + "\n")

         self.small_action_dict = small_action_dict
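The switch from `small_action_dict_*.json` to `small_action_dict_*.txt` is forced by the new keys: a dict mixing plain strings (domains, intents, `'eos'`, `'pad'`) with `(slot, value)` tuples cannot be dumped as a single JSON object. Each key is therefore written as one JSON value per line, and on load only JSON arrays are converted back to tuples while strings stay strings. A sketch of the round trip (file name hypothetical):

```python
import json

small_action_dict = {"restaurant": 0, "inform": 1, ("food", "1"): 2}

# Write: one JSON value per line; tuple keys serialize as JSON arrays.
with open("small_action_dict_demo.txt", "w") as f:
    for act in small_action_dict:
        f.write(json.dumps(act) + "\n")

# Read: restore arrays to tuples, keep strings as-is, rebuild the index map.
with open("small_action_dict_demo.txt") as f:
    keys = [json.loads(line) for line in f.read().splitlines()]
restored = {tuple(k) if isinstance(k, list) else k: idx for idx, k in enumerate(keys)}

assert restored == small_action_dict
```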
@@ -235,7 +239,7 @@ class ActionEmbedder(nn.Module):
         value_dict = {}
         slot_value_dict = {}
         for action in action_dict:
-            domain, intent, slot, value = [act.lower() for act in action.split('_')]
+            domain, intent, slot, value = [act.lower() for act in action]
             if domain not in domain_dict:
                 domain_dict[domain] = len(domain_dict)
             if intent not in intent_dict:
@@ -244,8 +248,8 @@ class ActionEmbedder(nn.Module):
                 slot_dict[slot] = len(slot_dict)
             if value not in value_dict:
                 value_dict[value] = len(value_dict)
-            if slot + "_" + value not in slot_value_dict:
-                slot_value_dict[slot + "_" + value] = len(slot_value_dict)
+            if (slot, value) not in slot_value_dict:
+                slot_value_dict[(slot, value)] = len(slot_value_dict)

         domain_dict['eos'] = len(domain_dict)
@@ -255,17 +259,17 @@ class ActionEmbedder(nn.Module):
         #print("SMALL ACTION LIST:", small_action_list)

         action_vector = torch.zeros(len(self.action_dict))
-        act_string = ""
+        act_list = []
         for idx, act in enumerate(small_action_list):
             if act == 'eos':
                 break
             if idx % 3 != 2:
-                act_string += f"{act}_"
+                act_list.append(act)
             else:
-                act_string += act
-                action_vector[self.action_dict[act_string]] = 1
-                act_string = ""
+                act_list += list(act)
+                action_vector[self.action_dict[tuple(act_list)]] = 1
+                act_list = []

         return action_vector
@@ -278,7 +282,8 @@ class ActionEmbedder(nn.Module):
         action_list = []
         for idx, i in enumerate(action):
             if i == 1:
-                action_list += self.action_dict_reversed[idx].split("_", 2)
+                d, i, s, v = self.action_dict_reversed[idx]
+                action_list += [d, i, (s, v)]

         if permute and len(action_list) > 3:
             action_list_new = deepcopy(action_list[-3:]) + deepcopy(action_list[:-3])
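The action-vector hunk above mirrors the same convention on the decoding side: the small action list now interleaves a domain, an intent and a `(slot, value)` pair, and every third element is unpacked before the tuple lookup into `action_dict`. A self-contained sketch with a toy vocabulary:

```python
import torch

# Toy vocabulary keyed by full (domain, intent, slot, value) tuples.
action_dict = {("restaurant", "inform", "food", "1"): 0,
               ("restaurant", "request", "area", "?"): 1}

# Small action list: domain, intent, (slot, value), ..., terminated by 'eos'.
small_action_list = ["restaurant", "inform", ("food", "1"), "eos"]

action_vector = torch.zeros(len(action_dict))
act_list = []
for idx, act in enumerate(small_action_list):
    if act == "eos":
        break
    if idx % 3 != 2:
        act_list.append(act)                  # collect domain, then intent
    else:
        act_list += list(act)                 # unpack the (slot, value) pair
        action_vector[action_dict[tuple(act_list)]] = 1
        act_list = []

print(action_vector)  # tensor([1., 0.])
```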
"destination-?": 81, "leave at-?": 82, "arrive by-2": 83, "day-1": 84, "duration-1": 85, "leave at-2": 86, "leave at-3": 87, "price-1": 88, "train id-1": 89, "day-?": 90, "pad": 91} \ No newline at end of file diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json deleted file mode 100644 index 0d4ec8d8a388236d6f65cf9b9c2c28560791b818..0000000000000000000000000000000000000000 --- a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.json +++ /dev/null @@ -1 +0,0 @@ -{"attraction": 0, "general": 1, "hospital": 2, "hotel": 3, "police": 4, "restaurant": 5, "taxi": 6, "train": 7, "eos": 8, "inform": 9, "nooffer": 10, "recommend": 11, "request": 12, "select": 13, "bye": 14, "greet": 15, "reqmore": 16, "welcome": 17, "book": 18, "offerbook": 19, "nobook": 20, "address_1": 21, "address_2": 22, "address_3": 23, "area_1": 24, "area_2": 25, "area_3": 26, "choice_1": 27, "choice_2": 28, "choice_3": 29, "entrance fee_1": 30, "entrance fee_2": 31, "name_1": 32, "name_2": 33, "name_3": 34, "name_4": 35, "phone_1": 36, "postcode_1": 37, "type_1": 38, "type_2": 39, "type_3": 40, "type_4": 41, "type_5": 42, "none_none": 43, "area_?": 44, "entrance fee_?": 45, "name_?": 46, "type_?": 47, "department_1": 48, "department_?": 49, "book day_1": 50, "book people_1": 51, "book stay_1": 52, "internet_1": 53, "parking_1": 54, "price range_1": 55, "price range_2": 56, "ref_1": 57, "stars_1": 58, "stars_2": 59, "book day_?": 60, "book people_?": 61, "book stay_?": 62, "internet_?": 63, "parking_?": 64, "price range_?": 65, "stars_?": 66, "book time_1": 67, "food_1": 68, "food_2": 69, "food_3": 70, "food_4": 71, "postcode_2": 72, "book time_?": 73, "food_?": 74, "arrive by_1": 75, "departure_1": 76, "destination_1": 77, "leave at_1": 78, "arrive by_?": 79, "departure_?": 80, "destination_?": 81, "leave at_?": 82, "arrive by_2": 83, "day_1": 84, "duration_1": 85, "leave at_2": 86, "leave at_3": 87, "price_1": 88, "train id_1": 89, "day_?": 90, "pad": 91} \ No newline at end of file diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.txt b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.txt new file mode 100644 index 0000000000000000000000000000000000000000..b368cd9fb7e978b36ed0fb261aafbe1baeed4ab4 --- /dev/null +++ b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_multiwoz21.txt @@ -0,0 +1,92 @@ +"attraction" +"general" +"hospital" +"hotel" +"police" +"restaurant" +"taxi" +"train" +"eos" +"inform" +"nooffer" +"recommend" +"request" +"select" +"bye" +"greet" +"reqmore" +"welcome" +"book" +"offerbook" +"nobook" +["address", "1"] +["address", "2"] +["address", "3"] +["area", "1"] +["area", "2"] +["area", "3"] +["choice", "1"] +["choice", "2"] +["choice", "3"] +["entrance fee", "1"] +["entrance fee", "2"] +["name", "1"] +["name", "2"] +["name", "3"] +["name", "4"] +["phone", "1"] +["postcode", "1"] +["type", "1"] +["type", "2"] +["type", "3"] +["type", "4"] +["type", "5"] +["none", "none"] +["area", "?"] +["entrance fee", "?"] +["name", "?"] +["type", "?"] +["department", "1"] +["department", "?"] +["book day", "1"] +["book people", "1"] +["book stay", "1"] +["internet", "1"] +["parking", "1"] +["price range", "1"] +["price range", "2"] +["ref", "1"] +["stars", "1"] +["stars", "2"] +["book day", "?"] +["book people", "?"] +["book stay", "?"] +["internet", "?"] +["parking", "?"] +["price range", "?"] +["stars", "?"] 
+["book time", "1"] +["food", "1"] +["food", "2"] +["food", "3"] +["food", "4"] +["postcode", "2"] +["book time", "?"] +["food", "?"] +["arrive by", "1"] +["departure", "1"] +["destination", "1"] +["leave at", "1"] +["arrive by", "?"] +["departure", "?"] +["destination", "?"] +["leave at", "?"] +["arrive by", "2"] +["day", "1"] +["duration", "1"] +["leave at", "2"] +["leave at", "3"] +["price", "1"] +["train id", "1"] +["day", "?"] +"pad" diff --git a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json b/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json deleted file mode 100644 index c48573262f50cabe1d4cd3811638ebf4c046a7e0..0000000000000000000000000000000000000000 --- a/convlab/policy/vtrace_DPT/transformer_model/small_action_dict_sgd.json +++ /dev/null @@ -1 +0,0 @@ -{"": 0, "alarm_1": 1, "banks_1": 2, "banks_2": 3, "buses_1": 4, "buses_2": 5, "buses_3": 6, "calendar_1": 7, "events_1": 8, "events_2": 9, "events_3": 10, "flights_1": 11, "flights_2": 12, "flights_3": 13, "flights_4": 14, "homes_1": 15, "homes_2": 16, "hotels_1": 17, "hotels_2": 18, "hotels_3": 19, "hotels_4": 20, "media_1": 21, "media_2": 22, "media_3": 23, "messaging_1": 24, "movies_1": 25, "movies_2": 26, "movies_3": 27, "music_1": 28, "music_2": 29, "music_3": 30, "payment_1": 31, "rentalcars_1": 32, "rentalcars_2": 33, "rentalcars_3": 34, "restaurants_1": 35, "restaurants_2": 36, "ridesharing_1": 37, "ridesharing_2": 38, "services_1": 39, "services_2": 40, "services_3": 41, "services_4": 42, "trains_1": 43, "travel_1": 44, "weather_1": 45, "eos": 46, "goodbye": 47, "req_more": 48, "confirm": 49, "inform_count": 50, "notify_success": 51, "offer": 52, "offer_intent": 53, "request": 54, "inform": 55, "notify_failure": 56, "none-none": 57, "new_alarm_name-1": 58, "new_alarm_time-1": 59, "count-1": 60, "alarm_name-1": 61, "alarm_time-1": 62, "addalarm-1": 63, "new_alarm_time-?": 64, "account_type-1": 65, "amount-1": 66, "recipient_account_name-1": 67, "recipient_account_type-1": 68, "balance-1": 69, "transfermoney-1": 70, "account_type-?": 71, "amount-?": 72, "recipient_account_name-?": 73, "recipient_name-1": 74, "transfer_amount-1": 75, "transfer_time-1": 76, "account_balance-1": 77, "recipient_name-?": 78, "transfer_amount-?": 79, "from_location-1": 80, "leaving_date-1": 81, "leaving_time-1": 82, "to_location-1": 83, "travelers-1": 84, "from_station-1": 85, "to_station-1": 86, "transfers-1": 87, "fare-1": 88, "buybusticket-1": 89, "from_location-?": 90, "leaving_date-?": 91, "leaving_time-?": 92, "to_location-?": 93, "travelers-?": 94, "departure_date-1": 95, "departure_time-1": 96, "destination-1": 97, "fare_type-1": 98, "group_size-1": 99, "origin-1": 100, "destination_station_name-1": 101, "origin_station_name-1": 102, "price-1": 103, "departure_date-?": 104, "departure_time-?": 105, "destination-?": 106, "group_size-?": 107, "origin-?": 108, "additional_luggage-1": 109, "from_city-1": 110, "num_passengers-1": 111, "to_city-1": 112, "category-1": 113, "from_city-?": 114, "num_passengers-?": 115, "to_city-?": 116, "event_date-1": 117, "event_location-1": 118, "event_name-1": 119, "event_time-1": 120, "available_end_time-1": 121, "available_start_time-1": 122, "addevent-1": 123, "event_date-?": 124, "event_location-?": 125, "event_name-?": 126, "event_time-?": 127, "city_of_event-1": 128, "date-1": 129, "number_of_seats-1": 130, "address_of_location-1": 131, "subcategory-1": 132, "time-1": 133, "buyeventtickets-1": 134, "category-?": 135, "city_of_event-?": 136, "date-?": 137, 
"number_of_seats-?": 138, "city-1": 139, "number_of_tickets-1": 140, "venue-1": 141, "venue_address-1": 142, "city-?": 143, "event_type-?": 144, "number_of_tickets-?": 145, "price_per_ticket-1": 146, "airlines-1": 147, "destination_city-1": 148, "inbound_departure_time-1": 149, "origin_city-1": 150, "outbound_departure_time-1": 151, "passengers-1": 152, "return_date-1": 153, "seating_class-1": 154, "destination_airport-1": 155, "inbound_arrival_time-1": 156, "number_stops-1": 157, "origin_airport-1": 158, "outbound_arrival_time-1": 159, "refundable-1": 160, "reserveonewayflight-1": 161, "reserveroundtripflights-1": 162, "airlines-?": 163, "destination_city-?": 164, "inbound_departure_time-?": 165, "origin_city-?": 166, "outbound_departure_time-?": 167, "return_date-?": 168, "is_redeye-1": 169, "arrives_next_day-1": 170, "destination_airport_name-1": 171, "origin_airport_name-1": 172, "is_nonstop-1": 173, "destination_airport-?": 174, "origin_airport-?": 175, "property_name-1": 176, "visit_date-1": 177, "furnished-1": 178, "pets_allowed-1": 179, "phone_number-1": 180, "address-1": 181, "number_of_baths-1": 182, "number_of_beds-1": 183, "rent-1": 184, "schedulevisit-1": 185, "area-?": 186, "number_of_beds-?": 187, "visit_date-?": 188, "has_garage-1": 189, "in_unit_laundry-1": 190, "intent-?": 191, "number_of_baths-?": 192, "check_in_date-1": 193, "hotel_name-1": 194, "number_of_days-1": 195, "number_of_rooms-1": 196, "has_wifi-1": 197, "price_per_night-1": 198, "street_address-1": 199, "star_rating-1": 200, "reservehotel-1": 201, "check_in_date-?": 202, "hotel_name-?": 203, "number_of_days-?": 204, "check_out_date-1": 205, "number_of_adults-1": 206, "where_to-1": 207, "has_laundry_service-1": 208, "total_price-1": 209, "rating-1": 210, "bookhouse-1": 211, "check_out_date-?": 212, "number_of_adults-?": 213, "where_to-?": 214, "location-1": 215, "pets_welcome-1": 216, "average_rating-1": 217, "location-?": 218, "place_name-1": 219, "stay_length-1": 220, "smoking_allowed-1": 221, "stay_length-?": 222, "subtitles-1": 223, "title-1": 224, "directed_by-1": 225, "genre-1": 226, "title-2": 227, "title-3": 228, "playmovie-1": 229, "genre-?": 230, "title-?": 231, "movie_name-1": 232, "subtitle_language-1": 233, "movie_name-2": 234, "movie_name-3": 235, "rentmovie-1": 236, "starring-1": 237, "contact_name-1": 238, "contact_name-?": 239, "show_date-1": 240, "show_time-1": 241, "show_type-1": 242, "theater_name-1": 243, "buymovietickets-1": 244, "movie_name-?": 245, "show_date-?": 246, "show_time-?": 247, "show_type-?": 248, "aggregate_rating-1": 249, "cast-1": 250, "movie_title-1": 251, "percent_rating-1": 252, "playback_device-1": 253, "song_name-1": 254, "album-1": 255, "year-1": 256, "artist-1": 257, "playsong-1": 258, "song_name-?": 259, "playmedia-1": 260, "device-1": 261, "track-1": 262, "payment_method-1": 263, "private_visibility-1": 264, "receiver-1": 265, "payment_method-?": 266, "receiver-?": 267, "dropoff_date-1": 268, "pickup_date-1": 269, "pickup_location-1": 270, "pickup_time-1": 271, "type-1": 272, "car_name-1": 273, "reservecar-1": 274, "dropoff_date-?": 275, "pickup_city-?": 276, "pickup_date-?": 277, "pickup_location-?": 278, "pickup_time-?": 279, "type-?": 280, "car_type-1": 281, "car_type-?": 282, "add_insurance-1": 283, "end_date-1": 284, "start_date-1": 285, "price_per_day-1": 286, "add_insurance-?": 287, "end_date-?": 288, "start_date-?": 289, "party_size-1": 290, "restaurant_name-1": 291, "cuisine-1": 292, "has_live_music-1": 293, "price_range-1": 294, "serves_alcohol-1": 295, 
"reserverestaurant-1": 296, "cuisine-?": 297, "restaurant_name-?": 298, "time-?": 299, "has_seating_outdoors-1": 300, "has_vegetarian_options-1": 301, "number_of_riders-1": 302, "shared_ride-1": 303, "approximate_ride_duration-1": 304, "ride_fare-1": 305, "number_of_riders-?": 306, "shared_ride-?": 307, "ride_type-1": 308, "wait_time-1": 309, "ride_type-?": 310, "appointment_date-1": 311, "appointment_time-1": 312, "stylist_name-1": 313, "is_unisex-1": 314, "bookappointment-1": 315, "appointment_date-?": 316, "appointment_time-?": 317, "dentist_name-1": 318, "offers_cosmetic_services-1": 319, "doctor_name-1": 320, "therapist_name-1": 321, "class-1": 322, "date_of_journey-1": 323, "from-1": 324, "journey_start_time-1": 325, "to-1": 326, "trip_protection-1": 327, "total-1": 328, "gettraintickets-1": 329, "date_of_journey-?": 330, "from-?": 331, "to-?": 332, "trip_protection-?": 333, "free_entry-1": 334, "good_for_kids-1": 335, "attraction_name-1": 336, "humidity-1": 337, "wind-1": 338, "precipitation-1": 339, "temperature-1": 340, "pad": 341} \ No newline at end of file diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py index 5c6b0d33f1755cf443521307024d6e17280f060d..3bc6550fe903b86f38240b0061f41f91d6326e3e 100644 --- a/convlab/util/custom_util.py +++ b/convlab/util/custom_util.py @@ -169,7 +169,7 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do single_domain_goals, allowed_domains) goals.append(goal[0]) - if conf['model']['process_num'] == 1: + if conf['model']['process_num'] == 1 or save_eval: complete_rate, success_rate, success_rate_strict, avg_return, turns, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \ select_acts, offer_acts, recommend_acts = evaluate(sess, @@ -330,7 +330,6 @@ def create_env(args, policy_sys): def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False, save_path=None, goals=None): - eval_save = {} turn_counter_dict = {} turn_counter = 0.0 @@ -426,10 +425,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False # print('length of dict ' + str(len(eval_save))) if save_flag: - # print("what are you doing") - save_file = open(os.path.join(save_path, 'evaluate_INFO.json'), 'w') - json.dump(eval_save, save_file, cls=NumpyEncoder) - save_file.close() + torch.save(eval_save, os.path.join(save_path, 'evaluate_INFO.pt')) # save dialogue_info and clear mem return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \ diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py index a8df10672e5fe6e1ea8631e632d71bd0c2c7ba51..1e5f7ce69eb046e391d752533d024cf9b766ff66 100755 --- a/convlab/util/multiwoz/lexicalize.py +++ b/convlab/util/multiwoz/lexicalize.py @@ -34,8 +34,8 @@ def deflat_da(meta): meta = deepcopy(meta) dialog_act = {} for da in meta: - d, i, s, v = da.split('_') - k = '_'.join((d, i)) + d, i, s, v = da + k = (d, i) if k not in dialog_act: dialog_act[k] = [] dialog_act[k].append([s, v]) @@ -45,7 +45,7 @@ def deflat_da(meta): def lexicalize_da(meta, entities, state, requestable): meta = deepcopy(meta) for k, v in meta.items(): - domain, intent = k.split('_') + domain, intent = k if domain in ['general']: continue elif intent in requestable: @@ -99,6 +99,6 @@ def lexicalize_da(meta, entities, state, requestable): tuples = [] for domain_intent, svs in meta.items(): for slot, value in svs: - domain, intent = domain_intent.split('_') + domain, intent = domain_intent 
diff --git a/setup.py b/setup.py
index c2cae13c857561899c9b71f68e236f30d705f5ce..0a94fb0e50cbc36f96b4b3102690656bfa10f2e1 100755
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ from setuptools import setup, find_packages

 setup(
     name='convlab',
-    version='3.0.0',
+    version='3.0.1',
     packages=find_packages(),
     license='Apache',
     description='An Open-source Dialog System Toolkit',