Commit bd7a97b6 authored by zqwerty

add retrieval augmented NLU
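Adds a 'retnlu' task (and a 'retnlg' choice) that augments T5 NLU training with similar utterances retrieved from other unified datasets via sentence-transformers, either as extra training samples or as in-context examples prepended to the input, together with training/evaluation scripts for multiwoz21.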

parent 61b6c665
@@ -3,7 +3,7 @@ import json
from tqdm import tqdm
import re
from transformers import AutoTokenizer
-from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
+from convlab.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data, retrieve_utterances
from convlab.base_models.t5.nlu.serialization import serialize_dialogue_acts, deserialize_dialogue_acts, equal_da_seq
from convlab.base_models.t5.dst.serialization import serialize_dialogue_state, deserialize_dialogue_state, equal_state_seq
@@ -120,6 +120,59 @@ def create_goal2dialogue_data(dataset, data_dir, args):
data_by_split[data_split] = data
return data_by_split
def create_retnlu_data(dataset, data_dir, args):
dataset_name = dataset[list(dataset.keys())[0]][0]['dataset']
data_by_split = load_nlu_data(dataset, speaker=args.speaker, use_context=args.context_window_size>0, context_window_size=args.context_window_size)
data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}', \
f'in_context_{args.retrieval_in_context}', f'topk_{args.retrieval_topk}')
os.makedirs(data_dir, exist_ok=True)
# build the retrieval pool from the training splits of the given datasets,
# keeping only turns that have at least one dialogue act
turn_pool = []
for d in args.retrieval_datasets:
pool_dataset = load_dataset(d)
for turn in load_nlu_data(pool_dataset, data_split='train', speaker=args.speaker)['train']:
if any([len(das) > 0 for da_type, das in turn['dialogue_acts'].items()]):
turn_pool.append({'dataset': d, **turn})
# retrieve the top-k most similar pool turns for all splits in one batched call
data_splits = data_by_split.keys()
query_turns = []
for data_split in data_splits:
query_turns.extend(data_by_split[data_split])
augmented_dataset = retrieve_utterances(query_turns, turn_pool, args.retrieval_topk, 'all-MiniLM-L6-v2')
i = 0
for data_split in data_splits:
data = []
for j in tqdm(range(len(data_by_split[data_split])), desc=f'{data_split} sample', leave=False):
sample = augmented_dataset[i+j]
response = f"{sample['speaker']}: {sample['utterance']}"
if args.context_window_size>0:
context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
else:
context = response
context = ' '.join([dataset_name, context])
dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
retrieved_turns = sample['retrieved_turns']
for t in retrieved_turns:
retrieved_utterance = f"{t['dataset']} {t['speaker']}: {t['utterance']}"
retrieved_dialogue_acts_seq = serialize_dialogue_acts(t['dialogue_acts'])
if args.retrieval_in_context:
# in-context learning: prepend the retrieved (utterance => acts) pair to the input
context = f"{retrieved_utterance} => {retrieved_dialogue_acts_seq}\n\n" + context
elif data_split != 'test':
# otherwise add each retrieved turn as an extra train/validation sample
data.append(json.dumps({'context': retrieved_utterance, 'dialogue_acts_seq': retrieved_dialogue_acts_seq}, ensure_ascii=False)+'\n')
data.append(json.dumps({'context': context, 'dialogue_acts_seq': dialogue_acts_seq}, ensure_ascii=False)+'\n')
i += len(data_by_split[data_split])
file_name = os.path.join(data_dir, f"{data_split}.json")
with open(file_name, "w", encoding='utf-8') as f:
f.writelines(data)
data_by_split[data_split] = data
return data_by_split
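# Illustration (editor's sketch, not part of the commit): depending on
# --retrieval_in_context, create_retnlu_data writes one of two JSON-line formats.
# The utterances and act sequences below are hypothetical placeholders; real
# targets come from serialize_dialogue_acts.
# in-context mode: retrieved "utterance => acts" pairs are prepended to the input,
#   {"context": "sgd user: i need a cheap hotel => <acts>\n\nmultiwoz21 user: find me a budget hotel", "dialogue_acts_seq": "<acts>"}
# otherwise: each retrieved turn also becomes its own train/validation sample,
#   {"context": "sgd user: i need a cheap hotel", "dialogue_acts_seq": "<acts>"}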
def get_max_len(data_by_split, tokenizer):
for data_split in data_by_split.keys():
seq_len = {}
@@ -136,13 +189,16 @@ def get_max_len(data_by_split, tokenizer):
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description="create data for seq2seq training")
-    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue'], help='names of tasks')
+    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue', 'retnlu', 'retnlg'], help='names of tasks')
parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s)')
parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
parser.add_argument('--len_tokenizer', '-l', type=str, default=None, help='name or path of the tokenizer used to compute sequence lengths')
parser.add_argument('--ratio', '-r', type=float, default=None, help='fraction of the data used for training and evaluation')
parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
parser.add_argument('--retrieval_datasets', metavar='dataset_name', nargs='*', help='names of unified datasets used as the retrieval pool')
parser.add_argument('--retrieval_topk', type=int, default=3, help='how many utterances to retrieve per query turn')
parser.add_argument('--retrieval_in_context', action='store_true', default=False, help='whether to prepend the retrieved utterances to the input for in-context learning')
args = parser.parse_args()
print(args)
if args.len_tokenizer:
......
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
--early_stopping_patience 10 \
--prediction_loss_only \
--load_best_model_at_end \
......
@@ -40,7 +40,7 @@ python ../run_seq2seq.py \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
......
@@ -42,7 +42,7 @@ python ../run_seq2seq.py \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
-    --save_total_limit 3 \
+    --save_total_limit 1 \
--prediction_loss_only \
--load_best_model_at_end \
--cache_dir ${cache_dir} \
......
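# This script builds retrieval-augmented NLU data for multiwoz21 (retrieval pool:
# sgd, tm1, tm2, tm3), fine-tunes t5-small on it, decodes the test set, and
# evaluates the merged predictions.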
n_gpus=1
task_name="retnlu"
dataset_name="multiwoz21"
speaker="user"
context_window_size=0
retrieval_topk=1
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="nlu_metric.py"
metric_for_best_model="overall_f1"
source_column="context"
target_column="dialogue_acts_seq"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=10
python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk}
python ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
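# Few-shot variant of the script above: ratio, dial_ids_order and retrieval_topk
# are taken as positional arguments; the data-creation, training and prediction
# steps are commented out below so that only the evaluation re-runs.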
n_gpus=1
task_name="retnlu"
dataset_name="multiwoz21"
speaker="user"
context_window_size=0
ratio=$1
dial_ids_order=$2
retrieval_topk=$3
data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_False/topk_${retrieval_topk}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="nlu_metric.py"
metric_for_best_model="overall_f1"
source_column="context"
target_column="dialogue_acts_seq"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=100
# python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} -r ${ratio} -o ${dial_ids_order}
# python ../run_seq2seq.py \
# --task_name ${task_name} \
# --train_file ${train_file} \
# --validation_file ${validation_file} \
# --source_column ${source_column} \
# --target_column ${target_column} \
# --max_source_length ${max_source_length} \
# --max_target_length ${max_target_length} \
# --truncation_side ${truncation_side} \
# --model_name_or_path ${model_name_or_path} \
# --do_train \
# --do_eval \
# --save_strategy epoch \
# --evaluation_strategy epoch \
# --save_total_limit 1 \
# --prediction_loss_only \
# --load_best_model_at_end \
# --cache_dir ${cache_dir} \
# --output_dir ${output_dir} \
# --logging_dir ${logging_dir} \
# --overwrite_output_dir \
# --preprocessing_num_workers 4 \
# --per_device_train_batch_size ${per_device_train_batch_size} \
# --per_device_eval_batch_size ${per_device_eval_batch_size} \
# --gradient_accumulation_steps ${gradient_accumulation_steps} \
# --learning_rate ${lr} \
# --num_train_epochs ${num_train_epochs} \
# --adafactor \
# --gradient_checkpointing
# python ../run_seq2seq.py \
# --task_name ${task_name} \
# --test_file ${test_file} \
# --source_column ${source_column} \
# --target_column ${target_column} \
# --max_source_length ${max_source_length} \
# --max_target_length ${max_target_length} \
# --truncation_side ${truncation_side} \
# --model_name_or_path ${output_dir} \
# --do_predict \
# --predict_with_generate \
# --metric_name_or_path ${metric_name_or_path} \
# --cache_dir ${cache_dir} \
# --output_dir ${output_dir} \
# --logging_dir ${logging_dir} \
# --overwrite_output_dir \
# --preprocessing_num_workers 4 \
# --per_device_train_batch_size ${per_device_train_batch_size} \
# --per_device_eval_batch_size ${per_device_eval_batch_size} \
# --gradient_accumulation_steps ${gradient_accumulation_steps} \
# --learning_rate ${lr} \
# --num_train_epochs ${num_train_epochs} \
# --adafactor \
# --gradient_checkpointing
# python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
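# In-context variant: retrieved utterance/act pairs are prepended to the model
# input via --retrieval_in_context; retrieval_topk is taken as a positional argument.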
n_gpus=1
task_name="retnlu"
dataset_name="multiwoz21"
speaker="user"
context_window_size=0
retrieval_topk=$1
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="nlu_metric.py"
metric_for_best_model="overall_f1"
source_column="context"
target_column="dialogue_acts_seq"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=10
python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context
python ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
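# Few-shot in-context variant: combines the ratio/dial_ids_order few-shot setup
# with --retrieval_in_context.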
n_gpus=1
task_name="retnlu"
dataset_name="multiwoz21"
speaker="user"
context_window_size=0
ratio=$1
dial_ids_order=$2
retrieval_topk=$3
data_dir="data/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
output_dir="output/${task_name}/${dataset_name}_${ratio}_order${dial_ids_order}/${speaker}/context_${context_window_size}/in_context_True/topk_${retrieval_topk}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="nlu_metric.py"
metric_for_best_model="overall_f1"
source_column="context"
target_column="dialogue_acts_seq"
truncation_side="left"
max_source_length=512
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=100
python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} --retrieval_datasets sgd tm1 tm2 tm3 --retrieval_topk ${retrieval_topk} --retrieval_in_context -r ${ratio} -o ${dial_ids_order}
python ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--prediction_loss_only \
--load_best_model_at_end \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--adafactor \
--gradient_checkpointing
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json -o ${dial_ids_order}
python ../../../nlu/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
@@ -9,7 +9,9 @@ from abc import ABC, abstractmethod
from pprint import pprint
from convlab.util.file_util import cached_path
import shutil
import importlib
from sentence_transformers import SentenceTransformer, util
import torch
from tqdm import tqdm
class BaseDatabase(ABC):
@@ -433,6 +435,36 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
return dataset, sorted(list(delex_vocab))
def retrieve_utterances(query_turns, turn_pool, top_k, model_name):
"""
It takes a list of query turns, a list of turn pool, and a top_k value, and returns a list of query
turns with a new key called 'retrieve_utterances' that contains a list of top_k retrieved utterances
from the turn pool
:param query_turns: a list of turns that you want to retrieve utterances for
:param turn_pool: the pool of turns to retrieve from
:param top_k: the number of utterances to retrieve for each query turn
:param model_name: the name of the model you want to use
:return: A list of dictionaries, with a new key 'retrieve_utterances' that is a list of retrieved turns and similarity scores.
"""
embedder = SentenceTransformer(model_name)
# use the GPU when available instead of assuming one exists
device = 'cuda' if torch.cuda.is_available() else 'cpu'
corpus = [turn['utterance'] for turn in turn_pool]
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
corpus_embeddings = corpus_embeddings.to(device)
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
queries = [turn['utterance'] for turn in query_turns]
query_embeddings = embedder.encode(queries, convert_to_tensor=True)
query_embeddings = query_embeddings.to(device)
query_embeddings = util.normalize_embeddings(query_embeddings)
# with normalized embeddings, dot product is equivalent to cosine similarity
hits = util.semantic_search(query_embeddings, corpus_embeddings, score_function=util.dot_score, top_k=top_k)
for i, turn in enumerate(query_turns):
turn['retrieved_turns'] = [{'score': hit['score'], **turn_pool[hit['corpus_id']]} for hit in hits[i]]
return query_turns
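# Illustration (editor's sketch, hypothetical values): each returned query turn
# gains a 'retrieved_turns' list holding the similarity score plus every key of
# the corresponding pool turn, e.g.
# [{'score': 0.83, 'speaker': 'user', 'utterance': 'i want a cheap restaurant', 'dialogue_acts': {...}}, ...]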
if __name__ == "__main__":
dataset = load_dataset('multiwoz21', dial_ids_order=0)
train_ratio = 0.1
@@ -447,7 +479,11 @@ if __name__ == "__main__":
print(res[0], len(res))
data_by_split = load_nlu_data(dataset, data_split='test', speaker='user')
pprint(data_by_split['test'][0])
query_turns = data_by_split['test'][:10]
pool_dataset = load_dataset('camrest')
turn_pool = load_nlu_data(pool_dataset, data_split='train', speaker='user')['train']
augmented_dataset = retrieve_utterances(query_turns, turn_pool, 3, 'all-MiniLM-L6-v2')
pprint(augmented_dataset[0])
def delex_slot(domain, slot, value):
# only use slot name for delexicalization
......
@@ -79,6 +79,7 @@ s3transfer==0.6.0
sacrebleu==2.1.0
scikit-learn==1.1.1
scipy==1.8.1
sentence-transformers==2.2.2
seqeval==1.2.2
simplejson==3.17.6
six==1.16.0
......
@@ -39,6 +39,7 @@ setup(
'tensorboard',
'torch>=1.6',
'transformers>=4.0',
'sentence-transformers',
'datasets>=1.8',
'seqeval',
'spacy',
......