diff --git a/convlab/base_models/t5/create_data.py b/convlab/base_models/t5/create_data.py
index 75c282848fa30deca8ca66ecd176df35d8588ed0..a1b10adb2ef0440868f4575e2ce2ce76508f13c4 100644
--- a/convlab/base_models/t5/create_data.py
+++ b/convlab/base_models/t5/create_data.py
@@ -3,7 +3,7 @@
 import json
 from tqdm import tqdm
 import re
 from transformers import AutoTokenizer
-from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
+from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, retrieve_utterances
 from convlab.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq
 from convlab.base_models.t5.dst.serialization import serialize_dialogue_state, deserialize_dialogue_state, equal_state_seq
@@ -120,6 +120,59 @@ def create_goal2dialogue_data(dataset, data_dir, args):
         data_by_split[data_split] = data
     return data_by_split
 
+def create_retnlu_data(dataset, data_dir, args):
+    dataset_name = dataset[list(dataset.keys())[0]][0]['dataset']
+    data_by_split = load_nlu_data(dataset, speaker=args.speaker, use_context=args.context_window_size>0, context_window_size=args.context_window_size)
+    data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}', \
+        f'in_context_{args.retrieval_in_context}', f'topk_{args.retrieval_topk}')
+    os.makedirs(data_dir, exist_ok=True)
+
+    turn_pool = []
+    for d in args.retrieval_datasets:
+        pool_dataset = load_dataset(d)
+        for turn in load_nlu_data(pool_dataset, data_split='train', speaker=args.speaker)['train']:
+            if any([len(das) > 0 for da_type, das in turn['dialogue_acts'].items()]):
+                turn_pool.append({'dataset': d, **turn})
+
+    data_splits = data_by_split.keys()
+    query_turns = []
+    for data_split in data_splits:
+        query_turns.extend(data_by_split[data_split])
+    augmented_dataset = retrieve_utterances(query_turns, turn_pool, args.retrieval_topk, 'all-MiniLM-L6-v2')
+
+    i = 0
+    for data_split in data_splits:
+        data = []
+        for j in tqdm(range(len(data_by_split[data_split])), desc=f'{data_split} sample', leave=False):
+            sample = augmented_dataset[i+j]
+            response = f"{sample['speaker']}: {sample['utterance']}"
+            if args.context_window_size>0:
+                context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
+            else:
+                context = response
+            context = ' '.join([dataset_name, context])
+            dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
+            assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
+
+            retrieved_turns = sample['retrieved_turns']
+            for t in retrieved_turns:
+                # in-context learning
+                retrieved_utterance = f"{t['dataset']} {t['speaker']}: {t['utterance']}"
+                retrieved_dialogue_acts_seq = serialize_dialogue_acts(t['dialogue_acts'])
+                if args.retrieval_in_context:
+                    context = f"{retrieved_utterance} => {retrieved_dialogue_acts_seq}\n\n" + context
+                elif data_split != 'test':
+                    data.append(json.dumps({'context': retrieved_utterance, 'dialogue_acts_seq': retrieved_dialogue_acts_seq}, ensure_ascii=False)+'\n')
+
+            data.append(json.dumps({'context': context, 'dialogue_acts_seq': dialogue_acts_seq}, ensure_ascii=False)+'\n')
+        i += len(data_by_split[data_split])
+
+        file_name = os.path.join(data_dir, f"{data_split}.json")
+        with open(file_name, "w", encoding='utf-8') as f:
+            f.writelines(data)
+        data_by_split[data_split] = data
+    return data_by_split
+
 def get_max_len(data_by_split, tokenizer):
     for data_split in data_by_split.keys():
         seq_len = {}
@@ -136,13 +189,16 @@ def get_max_len(data_by_split, tokenizer):
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="create data for seq2seq training")
-    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue'], help='names of tasks')
+    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue', 'retnlu', 'retnlg'], help='names of tasks')
     parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
     parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s)')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
     parser.add_argument('--len_tokenizer', '-l', type=str, default=None, help='name or path of tokenizer that used to get seq len')
     parser.add_argument('--ratio', '-r', type=float, default=None, help='how many data is used for training and evaluation')
     parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
+    parser.add_argument('--retrieval_datasets', metavar='dataset_name', nargs='*', help='names of unified datasets used as the retrieval pool')
+    parser.add_argument('--retrieval_topk', type=int, default=3, help='how many utterances to retrieve per query turn')
+    parser.add_argument('--retrieval_in_context', action='store_true', default=False, help='whether to use the retrieved utterances as in-context examples')
     args = parser.parse_args()
     print(args)
     if args.len_tokenizer:
diff --git a/convlab/base_models/t5/dst/run_dst_fewshot.sh b/convlab/base_models/t5/dst/run_dst_fewshot.sh
index d45719112e50dd44672ab52b28c04014cb5d6e5c..f548c053b544b51101f0cfbcc0b1a7b3a09c8088 100644
--- a/convlab/base_models/t5/dst/run_dst_fewshot.sh
+++ b/convlab/base_models/t5/dst/run_dst_fewshot.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --early_stopping_patience 10 \
     --prediction_loss_only \
     --load_best_model_at_end \
diff --git a/convlab/base_models/t5/nlu/run_nlu.sh b/convlab/base_models/t5/nlu/run_nlu.sh
index fb9be0227b3cced261ed6ccbffa9857e477012a2..8cba74aca0510464d176aa44ef0388c914796f5f 100644
--- a/convlab/base_models/t5/nlu/run_nlu.sh
+++ b/convlab/base_models/t5/nlu/run_nlu.sh
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --cache_dir ${cache_dir} \
     --output_dir ${output_dir} \
diff --git a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
index 568c271323cf2472f7989e0cb68e9af051bcc89b..8da69801df77d8f72d23204c8cf008ea7512d10c 100644
--- a/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
+++ b/convlab/base_models/t5/nlu/run_nlu_fewshot.sh
@@ -42,7 +42,7 @@ python ../run_seq2seq.py \
     --do_eval \
     --save_strategy epoch \
     --evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
     --prediction_loss_only \
     --load_best_model_at_end \
     --cache_dir ${cache_dir} \
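For intuition, the sketch below shows the two kinds of samples that create_retnlu_data (above) writes, assuming context_window_size=0. The utterances are invented, and <da_seq> stands in for the output of serialize_dialogue_acts:

    # Hypothetical sketch of the retnlu sample formats (invented utterances;
    # <da_seq> stands for the serialized dialogue acts of the marked turn).
    query = "multiwoz21 user: i need a cheap hotel in the north"
    retrieved = "sgd user: find me an inexpensive hotel up north"

    # With --retrieval_in_context, each retrieved turn and its acts are
    # prepended to the source sequence of the query turn:
    in_context_sample = {
        "context": f"{retrieved} => <da_seq of retrieved turn>\n\n{query}",
        "dialogue_acts_seq": "<da_seq of query turn>",
    }

    # Without the flag, the query sample is unchanged and each retrieved turn
    # becomes an extra training sample (train/validation splits only):
    plain_sample = {"context": query, "dialogue_acts_seq": "<da_seq of query turn>"}
    extra_sample = {"context": retrieved, "dialogue_acts_seq": "<da_seq of retrieved turn>"}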
diff --git a/convlab/base_models/t5/nlu/run_retnlu.sh b/convlab/base_models/t5/nlu/run_retnlu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b45a0e45643fd5a2247633305df7e0c1f11ce848
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu.sh
@@ -0,0 +1,86 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+retrieval_topk=1
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d165859b01485b7885f88e5b1ae3a279e41f4caf
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_fewshot.sh
@@ -0,0 +1,89 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+ratio=$1
+dial_ids_order=$2
+retrieval_topk=$3
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=100
+
+# python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} -r ${ratio} -o ${dial_ids_order}
+
+# python ../run_seq2seq.py \
+#     --task_name ${task_name} \
+#     --train_file ${train_file} \
+#     --validation_file ${validation_file} \
+#     --source_column ${source_column} \
+#     --target_column ${target_column} \
+#     --max_source_length ${max_source_length} \
+#     --max_target_length ${max_target_length} \
+#     --truncation_side ${truncation_side} \
+#     --model_name_or_path ${model_name_or_path} \
+#     --do_train \
+#     --do_eval \
+#     --save_strategy epoch \
+#     --evaluation_strategy epoch \
+#     --save_total_limit 1 \
+#     --prediction_loss_only \
+#     --load_best_model_at_end \
+#     --cache_dir ${cache_dir} \
+#     --output_dir ${output_dir} \
+#     --logging_dir ${logging_dir} \
+#     --overwrite_output_dir \
+#     --preprocessing_num_workers 4 \
+#     --per_device_train_batch_size ${per_device_train_batch_size} \
+#     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+#     --gradient_accumulation_steps ${gradient_accumulation_steps} \
+#     --learning_rate ${lr} \
+#     --num_train_epochs ${num_train_epochs} \
+#     --adafactor \
+#     --gradient_checkpointing
+
+# python ../run_seq2seq.py \
+#     --task_name ${task_name} \
+#     --test_file ${test_file} \
+#     --source_column ${source_column} \
+#     --target_column ${target_column} \
+#     --max_source_length ${max_source_length} \
+#     --max_target_length ${max_target_length} \
+#     --truncation_side ${truncation_side} \
+#     --model_name_or_path ${output_dir} \
+#     --do_predict \
+#     --predict_with_generate \
+#     --metric_name_or_path ${metric_name_or_path} \
+#     --cache_dir ${cache_dir} \
+#     --output_dir ${output_dir} \
+#     --logging_dir ${logging_dir} \
+#     --overwrite_output_dir \
+#     --preprocessing_num_workers 4 \
+#     --per_device_train_batch_size ${per_device_train_batch_size} \
+#     --per_device_eval_batch_size ${per_device_eval_batch_size} \
+#     --gradient_accumulation_steps ${gradient_accumulation_steps} \
+#     --learning_rate ${lr} \
+#     --num_train_epochs ${num_train_epochs} \
+#     --adafactor \
+#     --gradient_checkpointing
+
+# python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
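The few-shot scripts are driven by positional arguments (ratio=$1, dial_ids_order=$2, retrieval_topk=$3), e.g. "bash run_retnlu_fewshot.sh 0.1 0 3". A hypothetical sketch of how these arguments select the data directory, mirroring the variables at the top of the script:

    # Hypothetical mapping of the positional arguments to data_dir.
    ratio, dial_ids_order, retrieval_topk = 0.1, 0, 3
    data_dir = (f"data/retnlu/multiwoz21_{ratio}_order{dial_ids_order}"
                f"/user/context_0/in_context_False/topk_{retrieval_topk}")
    print(data_dir)  # data/retnlu/multiwoz21_0.1_order0/user/context_0/in_context_False/topk_3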
diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh
new file mode 100644
index 0000000000000000000000000000000000000000..82dae873ebb419d1d311f347a813f1da6071dccb
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_in_context.sh
@@ -0,0 +1,86 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+retrieval_topk=$1
+data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=10
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh
new file mode 100644
index 0000000000000000000000000000000000000000..836152f80e9d21695aaadde0016aa7399eedbdf2
--- /dev/null
+++ b/convlab/base_models/t5/nlu/run_retnlu_in_context_fewshot.sh
@@ -0,0 +1,89 @@
+n_gpus=1
+task_name="retnlu"
+dataset_name="multiwoz21"
+speaker="user"
+context_window_size=0
+ratio=$1
+dial_ids_order=$2
+retrieval_topk=$3
+data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
+cache_dir="../cache"
+logging_dir="${output_dir}/runs"
+train_file="${data_dir}/train.json"
+validation_file="${data_dir}/validation.json"
+test_file="${data_dir}/test.json"
+metric_name_or_path="nlu_metric.py"
+metric_for_best_model="overall_f1"
+source_column="context"
+target_column="dialogue_acts_seq"
+truncation_side="left"
+max_source_length=512
+max_target_length=512
+model_name_or_path="t5-small"
+per_device_train_batch_size=128
+per_device_eval_batch_size=64
+gradient_accumulation_steps=2
+lr=1e-3
+num_train_epochs=100
+
+python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context -r ${ratio} -o ${dial_ids_order}
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --train_file ${train_file} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${model_name_or_path} \
+    --do_train \
+    --do_eval \
+    --save_strategy epoch \
+    --evaluation_strategy epoch \
+    --save_total_limit 1 \
+    --prediction_loss_only \
+    --load_best_model_at_end \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python ../run_seq2seq.py \
+    --task_name ${task_name} \
+    --test_file ${test_file} \
+    --source_column ${source_column} \
+    --target_column ${target_column} \
+    --max_source_length ${max_source_length} \
+    --max_target_length ${max_target_length} \
+    --truncation_side ${truncation_side} \
+    --model_name_or_path ${output_dir} \
+    --do_predict \
+    --predict_with_generate \
+    --metric_name_or_path ${metric_name_or_path} \
+    --cache_dir ${cache_dir} \
+    --output_dir ${output_dir} \
+    --logging_dir ${logging_dir} \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 4 \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --per_device_eval_batch_size ${per_device_eval_batch_size} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --learning_rate ${lr} \
+    --num_train_epochs ${num_train_epochs} \
+    --adafactor \
+    --gradient_checkpointing
+
+python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
+
+python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index e24658410738b290da97149382c8c89030936679..a35199bf344efa75e1736c2a46f20a33f5d8d28c 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -9,7 +9,9 @@ from abc import ABC, abstractmethod
 from pprint import pprint
 from convlab.util.file_util import cached_path
 import shutil
-import importlib
+from sentence_transformers import SentenceTransformer, util
+import torch
+from tqdm import tqdm
 
 
 class BaseDatabase(ABC):
@@ -433,6 +435,36 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
     return dataset, sorted(list(delex_vocab))
 
 
+def retrieve_utterances(query_turns, turn_pool, top_k, model_name):
+    """
+    Retrieve the top_k most similar utterances from turn_pool for each query
+    turn, and return the query turns with a new key 'retrieved_turns' that
+    holds the retrieved turns and their similarity scores.
+
+    :param query_turns: a list of turns that you want to retrieve utterances for
+    :param turn_pool: the pool of turns to retrieve from
+    :param top_k: the number of utterances to retrieve for each query turn
+    :param model_name: the name of the sentence-transformers model to use
+    :return: the query turns, each augmented with the 'retrieved_turns' list
+    """
+    embedder = SentenceTransformer(model_name)
+    corpus = [turn['utterance'] for turn in turn_pool]
+    corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
+    corpus_embeddings = corpus_embeddings.to('cuda' if torch.cuda.is_available() else 'cpu')
+    corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+
+    queries = [turn['utterance'] for turn in query_turns]
+    query_embeddings = embedder.encode(queries, convert_to_tensor=True)
+    query_embeddings = query_embeddings.to('cuda' if torch.cuda.is_available() else 'cpu')
+    query_embeddings = util.normalize_embeddings(query_embeddings)
+
+    hits = util.semantic_search(query_embeddings, corpus_embeddings, score_function=util.dot_score, top_k=top_k)
+
+    for i, turn in enumerate(query_turns):
+        turn['retrieved_turns'] = [{'score': hit['score'], **turn_pool[hit['corpus_id']]} for hit in hits[i]]
+    return query_turns
+
+
 if __name__ == "__main__":
     dataset = load_dataset('multiwoz21', dial_ids_order=0)
     train_ratio = 0.1
@@ -447,7 +479,11 @@ if __name__ == "__main__":
     print(res[0], len(res))
 
     data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
-    pprint(data_by_split['test'][0])
+    query_turns = data_by_split['test'][:10]
+    pool_dataset = load_dataset('camrest')
+    turn_pool = load_nlu_data(pool_dataset, data_split='train', speaker='user')['train']
+    augmented_dataset = retrieve_utterances(query_turns, turn_pool, 3, 'all-MiniLM-L6-v2')
+    pprint(augmented_dataset[0])
 
     def delex_slot(domain, slot, value):
         # only use slot name for delexicalization
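The __main__ block above already exercises retrieve_utterances end to end. For a self-contained sketch of the retrieval core it wraps (sentences are invented; normalizing the embeddings makes util.dot_score equivalent to cosine similarity):

    from sentence_transformers import SentenceTransformer, util

    # Minimal sketch of the retrieval core wrapped by retrieve_utterances.
    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    pool = ["i want to book a table for two", "find me a cheap hotel in the north"]
    queries = ["book a restaurant for 2 people, please"]

    pool_emb = util.normalize_embeddings(embedder.encode(pool, convert_to_tensor=True))
    query_emb = util.normalize_embeddings(embedder.encode(queries, convert_to_tensor=True))

    # top_k hits per query, each a dict with 'corpus_id' and 'score'
    hits = util.semantic_search(query_emb, pool_emb, score_function=util.dot_score, top_k=1)
    print(hits[0])  # e.g. [{'corpus_id': 0, 'score': 0.55}]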
diff --git a/requirements.txt b/requirements.txt
index eba77a0fab5eac6f59af19f4456a9e35c2ca5834..380a19d536a5770658d9c6da82e883f65e9a9961 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -79,6 +79,7 @@ s3transfer==0.6.0
 sacrebleu==2.1.0
 scikit-learn==1.1.1
 scipy==1.8.1
+sentence-transformers==2.2.2
 seqeval==1.2.2
 simplejson==3.17.6
 six==1.16.0
diff --git a/setup.py b/setup.py
index 953e5486fbc68a02c2db65a4f4760076681f78f6..07d966b7a758e700fba48e16ec5d9e32cf62991a 100755
--- a/setup.py
+++ b/setup.py
@@ -39,6 +39,7 @@ setup(
         'tensorboard',
         'torch>=1.6',
         'transformers>=4.0',
+        'sentence-transformers',
         'datasets>=1.8',
         'seqeval',
         'spacy',
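Note that requirements.txt pins sentence-transformers==2.2.2 while setup.py leaves it unpinned. A quick, hypothetical sanity check of the new dependency (all-MiniLM-L6-v2 produces 384-dimensional embeddings):

    # Verify the sentence-transformers install used for retrieval.
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer('all-MiniLM-L6-v2')
    print(model.encode(["hello world"]).shape)  # expected: (1, 384)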