Commit a1a2f240 authored by zqwerty

update keywords extraction: remove punctuation and do not merge tokens into a span if they are separated by punctuation

parent beda1b76
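In short, the commit (a) strips leading/trailing punctuation from candidate keywords and drops candidates that are punctuation only, and (b) only merges two adjacent high-loss words into one keyword span when no punctuation sits between them. A minimal, self-contained sketch of that rule, not the committed code (it omits the stopword filter, loss threshold, and top-k selection; word_loss_pairs and its values are invented):

from string import punctuation

# hypothetical (word, language-model loss) pairs for one utterance
word_loss_pairs = [(' good', 2.1), (' ,', 0.3), (' restaurant', 2.5), (' nearby?', 2.4)]

keywords = []
prev_raw, prev_clean = None, None
for raw, loss in word_loss_pairs:
    clean = raw.strip(punctuation).strip()
    if not clean:
        # punctuation-only word: never a keyword, and it breaks any span
        prev_raw, prev_clean = None, None
        continue
    # merge with the previous keyword only when the current word has no leading
    # punctuation and the previous word has no trailing punctuation
    if keywords and prev_raw is not None and \
            raw.strip().startswith(clean) and prev_raw.strip().endswith(prev_clean):
        keywords[-1] += ' ' + clean
    else:
        keywords.append(clean)
    prev_raw, prev_clean = raw, clean

print(keywords)  # ['good', 'restaurant nearby']

Here ',' never becomes a keyword and blocks 'good' from merging with 'restaurant', while the trailing '?' on 'nearby?' is stripped but still allows the span 'restaurant nearby', mirroring the startswith/endswith guards added in keywords_filter below.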
@@ -26,16 +26,13 @@ def main(args):
             context.append({'speaker': speaker, 'utt':utt})
             fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
-            # min_neg = len(turn['keywords'])
-            # max_neg = 4 * min_neg
-            # negative_keywords = random.sample(keywords_set, random.randint(min_neg, max_neg))
-            # negative_keywords = random.sample(turn_keywords_set, 1)[0]
             negative_keywords = turn_keywords[cnt]
             cnt += 1
             possible_keywords = turn['keywords'] + list(negative_keywords)
             random.shuffle(possible_keywords)
             possible_keywords = ' | '.join(possible_keywords)
             input_seq = f'possible keywords: {possible_keywords}\n\ncontext: {context_seq}'
+            if args.noisy:
                 fout.write(json.dumps({'keywords+context': input_seq, 'response': utt}, ensure_ascii=False)+'\n')
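For illustration only (keywords, context formatting, and response are invented), a line written on the --noisy branch has the shape:

{"keywords+context": "possible keywords: train ticket | cheap | museum | sunday\n\ncontext: user: i need a train to cambridge on sunday.", "response": "sure, what time would you like to leave?"}

where the possible keywords mix the turn's true keywords with the sampled negative keywords, shuffled.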
@@ -44,6 +41,7 @@ if __name__ == '__main__':
     parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
     parser.add_argument('--input_dir', '-i', type=str, help='path to the input files')
     parser.add_argument('--output_dir', '-o', type=str, help='path to the output files')
+    parser.add_argument('--noisy', action='store_true', help='whether add noisy keywords samples')
     args = parser.parse_args()
     print(args)
     main(args)
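For reference, with the new flag the preprocessing call from the shell script below expands, for name=metalwoz, to:

python gen_pretraining_data.py -i data/lm/metalwoz/gpt -o data/key2gen_shuffle_noisy/gpt/metalwoz --noisy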
dataset_name="metalwoz+sgd+tm1+tm2+tm3" task_name="key2gen_shuffle_noisy"
dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
names=$(echo ${dataset_name} | tr "+" "\n") names=$(echo ${dataset_name} | tr "+" "\n")
model_type="gpt" model_type="gpt"
data_dir=data/key2gen_shuffle_noisy/${model_type}/${name}/${dataset_name} data_dir=data/${task_name}/${model_type}/${name}/${dataset_name}
rm -r ${data_dir} rm -r ${data_dir}
mkdir -p ${data_dir} mkdir -p ${data_dir}
train_file="${data_dir}/train.json" train_file="${data_dir}/train.json"
...@@ -10,11 +11,11 @@ test_file="${data_dir}/test.json" ...@@ -10,11 +11,11 @@ test_file="${data_dir}/test.json"
for name in ${names} for name in ${names}
do do
echo "preprocessing ${name}" echo "preprocessing ${name}"
python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/key2gen_shuffle_noisy/${model_type}/${name} python gen_pretraining_data.py -i data/lm/${name}/${model_type} -o data/${task_name}/${model_type}/${name} --noisy
if [ "${name}" != "${dataset_name}" ]; then if [ "${name}" != "${dataset_name}" ]; then
cat "data/key2gen_shuffle_noisy/gpt/${name}/train.json" >> ${train_file} cat "data/${task_name}/gpt/${name}/train.json" >> ${train_file}
cat "data/key2gen_shuffle_noisy/gpt/${name}/validation.json" >> ${validation_file} cat "data/${task_name}/gpt/${name}/validation.json" >> ${validation_file}
cat "data/key2gen_shuffle_noisy/gpt/${name}/test.json" >> ${test_file} cat "data/${task_name}/gpt/${name}/test.json" >> ${test_file}
fi fi
done done
python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/key2gen_shuffle_noisy/${model_type}/multiwoz21 python gen_pretraining_data.py -i data/lm/multiwoz21/${model_type} -o data/${task_name}/${model_type}/multiwoz21 --noisy
\ No newline at end of file \ No newline at end of file
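For orientation: names expands to the individual corpora (dailydialog, metalwoz, sgd, tm1, tm2, tm3); each is preprocessed into data/key2gen_shuffle_noisy/gpt/<name>/ and its train/validation/test files are appended to the combined files under ${data_dir}, while multiwoz21 is preprocessed separately and not concatenated into the combined splits.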
@@ -6,17 +6,20 @@ from tqdm import tqdm
 import numpy as np
 from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize
+from transformers import GPT2Tokenizer
+from string import punctuation


-def merge_tokens(tokens, losses, loss_merge_func=np.mean):
+def merge_tokens(tokens, losses):
+    """Merge tokens into words"""
     res = []
     i = 0
     while i < len(tokens):
         token = tokens[i]
         loss = losses[i]
         if token in ['Ġ', 'Ċ']:
-            if token == 'Ċ' and i < len(tokens) - 1:
+            # "Ġ" means " ", "Ċ" means "\n"
+            if token == 'Ċ' and i < len(tokens) - 1 and not tokens[i+1].startswith('Ġ'):
                 tokens[i+1] = 'Ġ'+tokens[i+1]
             i += 1
             continue
@@ -28,26 +31,23 @@ def merge_tokens(tokens, losses, loss_merge_func=np.mean):
             i += 2
             continue
         if token.startswith('Ġ'):
-            # Ġ means space
-            token = token.replace("Ġ", "")
-            res.append([token, loss])
+            # token = token.replace("Ġ", "")
+            res.append([[token], [loss]])
         elif token == '<|endoftext|>':
-            res.append([token, loss])
+            res.append([[token], [loss]])
         else:
             assert 'Ġ' not in token
             if len(res) > 0:
-                res[-1][0] += token
-                res[-1].append(loss)
+                res[-1][0].append(token)
+                res[-1][1].append(loss)
             else:
                 res.append([token, loss])
         i += 1
-    if loss_merge_func:
-        for i in range(len(res)):
-            res[i] = [res[i][0], loss_merge_func(res[i][1:])]
     return res


-def convert_token_loss2word_loss(token_loss_file, loss_merge_func=np.mean):
+def convert_token_loss2word_loss(token_loss_file):
+    """generate a word loss file according to the token loss file"""
     word_loss_file = os.path.join(os.path.dirname(token_loss_file), token_loss_file.split('/')[-1].replace('token', 'word'))
     fin = open(token_loss_file, 'rb')
     fout = open(word_loss_file, 'w', encoding='utf-8')
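The net effect on merge_tokens: a word is now kept as a (sub-token list, per-token loss list) pair instead of a merged string with an averaged loss, so the word-loss files written by convert_token_loss2word_loss carry token lists too. A rough before/after illustration (token and loss values invented):

# hypothetical GPT-2 BPE tokens and per-token losses for " the restaurant"
tokens = ['Ġthe', 'Ġrest', 'aur', 'ant', '<|endoftext|>']
losses = [0.5, 2.0, 1.0, 0.6, 0.1]

# before: merge_tokens(tokens, losses, loss_merge_func=np.mean)
#   -> [['the', 0.5], ['restaurant', 1.2], ['<|endoftext|>', 0.1]]
# after: merge_tokens(tokens, losses)
#   -> [[['Ġthe'], [0.5]],
#       [['Ġrest', 'aur', 'ant'], [2.0, 1.0, 0.6]],
#       [['<|endoftext|>'], [0.1]]]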
@@ -56,7 +56,7 @@ def convert_token_loss2word_loss(token_loss_file, loss_merge_func=np.mean):
     for item in tqdm(json_lines.reader(fin)):
         tokens, losses = item['tokens'], item['losses']
         assert len(tokens) == len(losses)
-        word2losses = merge_tokens(tokens, losses, loss_merge_func)
+        word2losses = merge_tokens(tokens, losses)
         lines.append({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]})
         fout.write(json.dumps(lines[-1], ensure_ascii=False)+'\n')
@@ -79,6 +79,7 @@ def main(args):
         return

    stop_words = set(stopwords.words('english'))
+    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
    if args.keywords_th_ratio > 0:
        losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
@@ -88,23 +89,33 @@ def main(args):
         loss_th = 0

     def keywords_filter(word_loss_pairs):
-        candidate_indexes = []
+        index2keyword = {}
         for i, word_loss_pair in enumerate(word_loss_pairs):
-            if args.stopwords and any([w.lower() in stop_words for w in word_tokenize(word_loss_pair[0])]):
+            words = word_tokenize(word_loss_pair[0])
+            if args.stopwords and any([w.lower() in stop_words for w in words]):
+                # skip stopwords
                 continue
             if word_loss_pair[1] <= loss_th:
+                # skip if loss is too small
                 continue
-            candidate_indexes.append(i)
+            # strip punctuation
+            strip_punctuation = word_loss_pair[0].strip(punctuation).strip()
+            if len(strip_punctuation) == 0:
+                # skip punctuation
+                continue
+            index2keyword[i] = strip_punctuation
+        candidate_indexes = list(index2keyword.keys())
         topk = min(round(args.keywords_ratio*len(word_loss_pairs)), args.keywords_num)
         topk_indexes = sorted(candidate_indexes, key=lambda x: word_loss_pairs[x][1], reverse=True)[:topk]
         topk_indexes = sorted(topk_indexes)
         keywords = []
         for i, index in enumerate(topk_indexes):
-            if i > 0 and index == topk_indexes[i-1] + 1:
-                keywords[-1]+= ' '+word_loss_pairs[index][0]
+            if i > 0 and index == topk_indexes[i-1] + 1 and \
+                word_loss_pairs[index][0].strip().startswith(index2keyword[index]) and \
+                word_loss_pairs[topk_indexes[i-1]][0].strip().endswith(index2keyword[topk_indexes[i-1]]):
+                keywords[-1]+= ' '+index2keyword[index]
             else:
-                keywords.append(word_loss_pairs[index][0])
+                keywords.append(index2keyword[index])
         return keywords
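The candidate filtering above hinges on str.strip(punctuation), which removes only leading and trailing punctuation characters; a quick sanity check of what index2keyword would hold (example inputs invented):

from string import punctuation

for raw in ['hello!', '...', "don't", ' ok.', ',']:
    print(repr(raw), '->', repr(raw.strip(punctuation).strip()))
# 'hello!' -> 'hello'
# '...'    -> ''        (punctuation only: skipped)
# "don't"  -> "don't"   (inner apostrophe is kept)
# ' ok.'   -> 'ok'      (leading space from the Ġ convention removed by the final .strip())
# ','      -> ''        (skipped)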
@@ -115,12 +126,13 @@ def main(args):
        turns = []
        turn = {'words': [], 'losses': []}
        for word, loss in zip(words, losses):
-           if word == '<|endoftext|>':
+           if word == ['<|endoftext|>']:
                # switch turn
-               turn['utterance'] = ' '.join(turn['words'])
+               turn['words'] = [tokenizer.convert_tokens_to_string(tokens) for tokens in turn['words']]
+               turn['losses'] = [np.mean(losses) for losses in turn['losses']]
+               turn['utterance'] = ''.join(turn['words']).strip()
                keywords = keywords_filter(list(zip(turn['words'], turn['losses'])))
                turn['keywords'] = keywords
-               # turn['keywords'] = ' | '.join([x[0] for x in keywords])
                turn.pop('words')
                turn.pop('losses')
                turns.append(turn)
...
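When a turn is closed, each grouped word is converted back to a string with the GPT-2 tokenizer and its token losses are averaged, so keywords_filter still receives (word, scalar loss) pairs. A small sketch of that step, mirroring the added lines (token groups and losses invented; this downloads the gpt2-large tokenizer):

import numpy as np
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

turn = {'words': [['Ġrest', 'aur', 'ant'], ['Ġnearby', '?']],
        'losses': [[2.0, 1.0, 0.6], [1.8, 0.2]]}

words = [tokenizer.convert_tokens_to_string(ts) for ts in turn['words']]
mean_losses = [float(np.mean(ls)) for ls in turn['losses']]
utterance = ''.join(words).strip()

print(words)        # [' restaurant', ' nearby?']
print(mean_losses)  # [1.2, 1.0]
print(utterance)    # 'restaurant nearby?'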
 set -e
-n_gpus=1
+n_gpus=2
 task_name="key2gen_shuffle_noisy"
-dataset_name="metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
 speaker="all"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
@@ -19,11 +19,11 @@ max_target_length=128
 model_name_or_path="t5-small"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=8
+gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=1

-python -m torch.distributed.launch \
+python -m torch.distributed.launch --master_port 23456\
     --nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
     --task_name ${task_name} \
     --train_file ${train_file} \
...
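The n_gpus and gradient_accumulation_steps changes offset each other: the effective batch size per_device_train_batch_size x n_gpus x gradient_accumulation_steps is 128 x 1 x 8 = 1024 before and 128 x 2 x 4 = 1024 after (assuming run_seq2seq.py applies no further scaling), so doubling the GPUs is compensated by halving the accumulation steps.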