Commit ec2167a3 authored by zqwerty

update pre-training data

parent 4a5a0d4b
@@ -4,46 +4,51 @@ from tabulate import tabulate
def main(predict_result):
data = {
"keywords": {
"grounded keywords": {
"positive_keywords": [], "negative_keywords": None,
"predictions": [], "references": []
},
"possible keywords": {
"all keywords": {
"positive_keywords": [], "negative_keywords": [],
"predictions": [], "references": []
},
"no keywords": {
"positive_keywords": None, "negative_keywords": None,
"predictions": [], "references": []
}
}
with open(predict_result) as f:
for line in f:
item = json.loads(line)
if item["keywords+context"].startswith("keywords"):
data["keywords"]["predictions"].append(item['predictions'].strip())
data["keywords"]["references"].append(item['response'].strip())
positive_keywords = [k.strip() for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split('|')[1].split(' : ') if len(k) > 0]
data["keywords"]["positive_keywords"].append(positive_keywords)
elif item["keywords+context"].startswith("possible keywords"):
data["possible keywords"]["predictions"].append(item['predictions'].strip())
data["possible keywords"]["references"].append(item['response'].strip())
possible_keywords = [k.strip() for ks in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split('|') for k in ks.split(' : ') if len(k) > 0]
has_positive = True
prediction = item['predictions'].strip()
reference = item['target'].strip()
if 'all_keywords' in item and item['all_keywords']:
sample_type = 'all keywords'
positive_keywords = [k for g in item['keywords'] for k in g]
data[sample_type]["positive_keywords"].append(positive_keywords)
all_keywords = [k for g in item['all_keywords'] for k in g]
for keyword in positive_keywords:
if keyword in possible_keywords:
possible_keywords.remove(keyword)
else:
has_positive = False
break
if has_positive:
data["possible keywords"]["positive_keywords"].append(positive_keywords)
all_keywords.remove(keyword)
data[sample_type]["negative_keywords"].append(all_keywords)
elif 'keywords' in item and item['keywords']:
sample_type = 'grounded keywords'
positive_keywords = [k for g in item['keywords'] for k in g]
data[sample_type]["positive_keywords"].append(positive_keywords)
else:
data["possible keywords"]["positive_keywords"].append([])
data["possible keywords"]["negative_keywords"].append(possible_keywords)
# print(data)
# if len(data["possible keywords"]["positive_keywords"])>0:
# break
sample_type = 'no keywords'
data[sample_type]["predictions"].append(prediction)
data[sample_type]["references"].append(reference)
metric = datasets.load_metric('./key2gen_metric.py')
table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
if len(data["possible keywords"]["predictions"]) > 0:
table.append({'prompt': "possible keywords", **metric.compute(**data["possible keywords"])})
table = []
for sample_type in data:
table.append({'sample_type': sample_type, **metric.compute(**data[sample_type])})
print(tabulate(table, headers='keys', tablefmt='github'))
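As a reference for how the rewritten evaluation buckets each JSONL line, the routing depends only on which keyword fields the record carries ('keywords' and 'all_keywords', written by the updated data script below). A minimal sketch of that routing; the example record and its values are purely illustrative:

def sample_type(item):
    # noisy variant: the turn's keywords mixed with keywords sampled from other turns
    if item.get('all_keywords'):
        return 'all keywords'
    # grounded variant: only the current turn's keywords are given
    if item.get('keywords'):
        return 'grounded keywords'
    # plain response generation without any keywords
    return 'no keywords'

example = {
    'predictions': 'sure, what time works for you?',  # model output (illustrative)
    'target': 'sure, what time works for you?',       # reference response (illustrative)
    'keywords': [['time', 'works']],                   # per-sentence keyword groups
}
print(sample_type(example))  # grounded keywords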
......
@@ -12,7 +12,7 @@ def main(args):
fout = open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')
for dial in tqdm(data):
context = []
turn_keywords = [turn['keywords'] for turn in dial]
turns_keywords = [turn['keywords'] for turn in dial]
for i, turn in enumerate(dial):
speaker = 'user' if i % 2 == 0 else 'system'
utt = turn['utterance']
@@ -27,21 +27,22 @@
continue
random.shuffle(turn['keywords'])
keywords = ' : '.join(turn['keywords'])
for j in range(len(turn['keywords'])):
random.shuffle(turn['keywords'][j])
keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
if args.mode == 'key2gen':
continue
possible_keywords_turns = [turn['keywords']]
num_possible_keywords_turns = min(random.randint(1, 5), len(turn_keywords) - 1)
possible_keywords_turns += random.sample(turn_keywords[:i] + turn_keywords[i+1:], num_possible_keywords_turns)
random.shuffle(possible_keywords_turns)
for possible_keywords_turn in possible_keywords_turns:
random.shuffle(possible_keywords_turn)
possible_keywords = ' | '.join([' : '.join(possible_keywords_turn) for possible_keywords_turn in possible_keywords_turns])
possible_keywords_sents = turn['keywords'][:]
num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
possible_keywords_sents.extend(turn_keywords)
random.shuffle(possible_keywords_sents)
possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords'], 'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
if args.mode == 'key2gen_noisy':
continue
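To make the new serialization concrete: the 'all knowledge' prompt joins keyword groups with ' | ' between sentences and ' : ' within a sentence, mixing the current turn's groups with groups sampled from other turns. A small sketch of that assembly; the keywords and context below are made up:

import random

turn_keywords = [['movie', 'tickets'], ['saturday']]   # current turn, grouped by sentence
other_turns = [['pizza'], ['downtown', 'theater']]      # groups sampled from other turns

possible_keywords_sents = turn_keywords[:]
possible_keywords_sents.extend(other_turns)
random.shuffle(possible_keywords_sents)

possible_keywords = ' | '.join(' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents)
context_seq = 'user: can i book movie tickets for saturday?'
input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
print(input_seq)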
......
task_name="key2gen_noisy"
dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3"
names=$(echo ${dataset_name} | tr "+" "\n")
model_type="gpt"
data_dir=data/${task_name}/${model_type}/${name}/${dataset_name}
......
@@ -4,7 +4,7 @@ model_type="gpt"
data_dir="data/${task_name}/${dataset_name}/${model_type}"
model_name_or_path="gpt2-large"
keywords_num=100
keywords_ratio=0.3
keywords_ratio=0.4
keywords_th_ratio=0
stopwords=True
for data_split in validation test train
......
@@ -69,11 +69,11 @@ class Key2GenMetrics(datasets.Metric):
def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
"""Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall"""
# rouge-1/2/L bleu-1/2 distinct-1/2
if not negative_keywords:
negative_keywords = [[]] * len(positive_keywords)
bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0}
if positive_keywords:
if not negative_keywords:
negative_keywords = [[]] * len(positive_keywords)
for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
cnt['pos'] += len(poskeys)
cnt['neg'] += len(negkeys)
......
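The metric's docstring promises positive/negative keyword recall; the hunk above only shows the totals being accumulated, so the matching rule in this sketch (case-insensitive substring containment in the prediction) is an assumption, not necessarily what key2gen_metric.py does:

predictions = ['sure, the downtown theater works on saturday']
positive_keywords = [['downtown', 'saturday']]
negative_keywords = [['pizza']]

cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0}
for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
    cnt['pos'] += len(poskeys)
    cnt['neg'] += len(negkeys)
    pred = prediction.lower()
    cnt['pos_recall'] += sum(k.lower() in pred for k in poskeys)  # assumed matching rule
    cnt['neg_recall'] += sum(k.lower() in pred for k in negkeys)  # assumed matching rule

print(cnt['pos_recall'] / cnt['pos'])          # 1.0 on this toy example
print(cnt['neg_recall'] / max(cnt['neg'], 1))  # 0.0 on this toy example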
@@ -5,7 +5,7 @@ import os
from tqdm import tqdm
import numpy as np
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.tokenize import word_tokenize, PunktSentenceTokenizer
from transformers import GPT2Tokenizer
from string import punctuation
@@ -34,7 +34,7 @@ def merge_tokens(tokens, losses):
# token = token.replace("Ġ", "")
res.append([[token], [loss]])
elif token == '<|endoftext|>':
res.append([[token], [loss]])
res.append([[token], [0.]])
else:
assert 'Ġ' not in token
if len(res) > 0:
@@ -80,6 +80,7 @@ def main(args):
stop_words = set(stopwords.words('english'))
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
sent_tokenizer = PunktSentenceTokenizer()
if args.keywords_th_ratio > 0:
losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
@@ -88,9 +89,19 @@
else:
loss_th = 0
def keywords_filter(word_loss_pairs):
def keywords_filter(words, losses):
word_loss_pairs = list(zip(words, losses))
index2keyword = {}
index2turn_sent = {}
num_turns = 0
turns_sent_spans = [list(sent_tokenizer.span_tokenize(utt)) for utt in ''.join(words).strip().split('<|endoftext|>')]
utt = ''
for i, word_loss_pair in enumerate(word_loss_pairs):
if word_loss_pair[0].startswith('<|endoftext|>'):
num_turns += 1
utt = ''
continue
utt += word_loss_pair[0]
words = word_tokenize(word_loss_pair[0])
if args.stopwords and any([w.lower() in stop_words for w in words]):
# skip stopwords
@@ -104,42 +115,54 @@
# skip punctuation
continue
index2keyword[i] = strip_punctuation
for sent_idx, (sent_start, sent_end) in enumerate(turns_sent_spans[num_turns]):
if len(utt.strip()) <= sent_end:
index2turn_sent[i] = (num_turns, sent_idx)
break
candidate_indexes = list(index2keyword.keys())
topk = min(round(args.keywords_ratio*len(word_loss_pairs)), args.keywords_num)
topk = min(round(args.keywords_ratio*(len(word_loss_pairs)-num_turns)), args.keywords_num)
topk_indexes = sorted(candidate_indexes, key=lambda x: word_loss_pairs[x][1], reverse=True)[:topk]
topk_indexes = sorted(topk_indexes)
keywords = []
keywords_turn_sent2idx = {}
for i, index in enumerate(topk_indexes):
if i > 0 and index == topk_indexes[i-1] + 1 and \
word_loss_pairs[index][0].strip().startswith(index2keyword[index]) and \
word_loss_pairs[topk_indexes[i-1]][0].strip().endswith(index2keyword[topk_indexes[i-1]]):
keywords[-1]+= ' '+index2keyword[index]
else:
keywords_turn_sent2idx.setdefault(index2turn_sent[index][0], {})
keywords_turn_sent2idx[index2turn_sent[index][0]].setdefault(index2turn_sent[index][1], [])
keywords_turn_sent2idx[index2turn_sent[index][0]][index2turn_sent[index][1]].append(len(keywords))
keywords.append(index2keyword[index])
return keywords
return keywords, keywords_turn_sent2idx
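The sentence grouping above relies on NLTK's PunktSentenceTokenizer.span_tokenize, which yields (start, end) character offsets for each sentence; a keyword is assigned to the first sentence whose end offset its running character position does not pass. A minimal sketch of that span lookup with a made-up utterance:

from nltk.tokenize import PunktSentenceTokenizer

sent_tokenizer = PunktSentenceTokenizer()
utt = 'That sounds great. Which theater is closer to you?'

spans = list(sent_tokenizer.span_tokenize(utt))  # one (start, end) pair per sentence
print(spans)  # [(0, 18), (19, 50)]

def sent_index(char_pos, spans):
    # mirrors the `len(utt.strip()) <= sent_end` check in keywords_filter
    for idx, (sent_start, sent_end) in enumerate(spans):
        if char_pos <= sent_end:
            return idx
    return len(spans) - 1

print(sent_index(len('That sounds great'), spans))  # 0 -> keyword belongs to the first sentence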
dialogs = []
for item in tqdm(word_loss_list):
words = item['words']
losses = item['losses']
words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
losses = [np.mean(loss) for loss in item['losses']]
dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
# print(keywords_turn_sent2idx)
turns = []
turn = {'words': [], 'losses': []}
for word, loss in zip(words, losses):
if word == ['<|endoftext|>']:
for i, (word, loss) in enumerate(zip(words, losses)):
if word != '<|endoftext|>':
turn['words'].append(word)
turn['losses'].append(loss)
if word == '<|endoftext|>' or i == len(words) - 1:
# switch turn
turn['words'] = [tokenizer.convert_tokens_to_string(tokens) for tokens in turn['words']]
turn['losses'] = [np.mean(losses) for losses in turn['losses']]
turn['utterance'] = ''.join(turn['words']).strip()
keywords = keywords_filter(list(zip(turn['words'], turn['losses'])))
turn['keywords'] = keywords
# 1) extract keywords according to LM loss within the turn
# keywords, _ = keywords_filter(turn['words'], turn['losses'])
# turn['turn-level_keywords'] = keywords
# 1) extract keywords according to LM loss over the dialog, and group them by sentence
turn['keywords'] = [[dialog_keywords[idx] for idx in k_idxes] for sent_idx, k_idxes in keywords_turn_sent2idx.get(len(turns), {}).items()]
turn.pop('words')
turn.pop('losses')
turns.append(turn)
turn = {'words': [], 'losses': []}
else:
turn['words'].append(word)
turn['losses'].append(loss)
dialogs.append(turns)
json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)
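For reference, the rewritten loop walks the detokenized word/loss stream once and closes a turn whenever it reaches the '<|endoftext|>' separator (or the last word). A stripped-down sketch with made-up words and losses:

words = ['hi', ' there', '<|endoftext|>', 'hello', ' again', '!', '<|endoftext|>']
losses = [0.2, 1.3, 0.0, 0.7, 0.9, 0.1, 0.0]

turns, turn = [], {'words': [], 'losses': []}
for i, (word, loss) in enumerate(zip(words, losses)):
    if word != '<|endoftext|>':
        turn['words'].append(word)
        turn['losses'].append(loss)
    if word == '<|endoftext|>' or i == len(words) - 1:
        # switch turn: detokenized words are concatenated back into the utterance
        turn['utterance'] = ''.join(turn['words']).strip()
        turns.append(turn)
        turn = {'words': [], 'losses': []}

print([t['utterance'] for t in turns])  # ['hi there', 'hello again!']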
......
set -e
for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd multiwoz21
do
bash get_keywords.sh ${dataset_name}
done
\ No newline at end of file
set -e
n_gpus=4
n_gpus=2
master_port=23457
task_name="key2gen_noisy"
dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
dataset_name=$1
model_type="gpt"
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}"
@@ -16,7 +16,7 @@ target_column="target"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+sgd+tm1+tm2+tm3"
model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+tm1+tm2+tm3"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=2
@@ -39,7 +39,7 @@ python -m torch.distributed.launch --master_port ${master_port} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
......
set -e
n_gpus=4
n_gpus=2
master_port=23457
task_name="key2gen_noisy"
dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
dataset_name="dailydialog+metalwoz+tm1+tm2+tm3"
model_type="gpt"
data_dir="data/${task_name}/${model_type}/${dataset_name}"
output_dir="output/${task_name}/${model_type}/${dataset_name}"
@@ -19,7 +19,7 @@ max_target_length=128
model_name_or_path="t5-small"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=2
gradient_accumulation_steps=4
lr=1e-3
num_train_epochs=3
@@ -46,7 +46,7 @@ python -m torch.distributed.launch --master_port ${master_port} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
......