diff --git a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
index 3c04ea8a29aebecf06dffac16fdf79fef41c763a..6b1068cef045550f57621fe0ab4aad8a4047cfbb 100644
--- a/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
+++ b/convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
@@ -4,46 +4,51 @@ from tabulate import tabulate
 
 def main(predict_result):
     data = {
-        "keywords": {
+        "grounded keywords": {
             "positive_keywords": [], "negative_keywords": None,
             "predictions": [], "references": []
         },
-        "possible keywords": {
+        "all keywords": {
             "positive_keywords": [], "negative_keywords": [],
             "predictions": [], "references": []
+        },
+        "no keywords": {
+            "positive_keywords": None, "negative_keywords": None,
+            "predictions": [], "references": []
         }
     }
 
     with open(predict_result) as f:
         for line in f:
             item = json.loads(line)
 
-            if item["keywords+context"].startswith("keywords"):
-                data["keywords"]["predictions"].append(item['predictions'].strip())
-                data["keywords"]["references"].append(item['response'].strip())
-                positive_keywords = [k.strip() for k in item['keywords+context'].split('\n\n')[0][len("keywords: "):].split('|')[1].split(' : ') if len(k) > 0]
-                data["keywords"]["positive_keywords"].append(positive_keywords)
-            elif item["keywords+context"].startswith("possible keywords"):
-                data["possible keywords"]["predictions"].append(item['predictions'].strip())
-                data["possible keywords"]["references"].append(item['response'].strip())
-                possible_keywords = [k.strip() for ks in item['keywords+context'].split('\n\n')[0][len("possible keywords: "):].split('|') for k in ks.split(' : ') if len(k) > 0]
-                has_positive = True
+            prediction = item['predictions'].strip()
+            reference = item['target'].strip()
+            if 'all_keywords' in item and item['all_keywords']:
+                sample_type = 'all keywords'
+
+                positive_keywords = [k for g in item['keywords'] for k in g]
+                data[sample_type]["positive_keywords"].append(positive_keywords)
+
+                all_keywords = [k for g in item['all_keywords'] for k in g]
                 for keyword in positive_keywords:
-                    if keyword in possible_keywords:
-                        possible_keywords.remove(keyword)
-                    else:
-                        has_positive = False
-                        break
-                if has_positive:
-                    data["possible keywords"]["positive_keywords"].append(positive_keywords)
-                else:
-                    data["possible keywords"]["positive_keywords"].append([])
-                data["possible keywords"]["negative_keywords"].append(possible_keywords)
-                # print(data)
-                # if len(data["possible keywords"]["positive_keywords"])>0:
-                #     break
+                    all_keywords.remove(keyword)
+                data[sample_type]["negative_keywords"].append(all_keywords)
+
+            elif 'keywords' in item and item['keywords']:
+                sample_type = 'grounded keywords'
+
+                positive_keywords = [k for g in item['keywords'] for k in g]
+                data[sample_type]["positive_keywords"].append(positive_keywords)
+
+            else:
+                sample_type = 'no keywords'
+
+            data[sample_type]["predictions"].append(prediction)
+            data[sample_type]["references"].append(reference)
+
     metric = datasets.load_metric('./key2gen_metric.py')
-    table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
-    if len(data["possible keywords"]["predictions"]) > 0:
-        table.append({'prompt': "possible keywords", **metric.compute(**data["possible keywords"])})
+    table = []
+    for sample_type in data:
+        table.append({'sample_type': sample_type, **metric.compute(**data[sample_type])})
     print(tabulate(table, headers='keys', tablefmt='github'))
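Note: with this change `eval_key2gen.py` no longer parses the prompt prefix; it infers the sample type from the `keywords` / `all_keywords` fields carried in each prediction line (written by `gen_pretraining_data.py` below). A minimal sketch of the record layout it assumes, with made-up values — only the field names come from the diff:

```python
import json

# Hypothetical line of the predict_result file read by main().
# 'keywords' holds per-sentence groups of grounded keywords; 'all_keywords' additionally
# contains keyword groups sampled from other turns (only present for noisy samples).
example = {
    "predictions": "there is a cheap chinese restaurant in the centre .",
    "target": "i found a cheap chinese restaurant in the centre of town .",
    "keywords": [["cheap", "chinese restaurant"], ["centre"]],
    "all_keywords": [["cheap", "chinese restaurant"], ["centre"], ["parking", "hotel"]],
}
print(json.dumps(example, ensure_ascii=False))
```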
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
index 0a1d63457a19c3ff42cd77a04c3d4505ad60a9bf..e7e40112c9ccaff28b98e220ae5e1e24bba4ebf6 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.py
@@ -12,7 +12,7 @@ def main(args):
         fout = open(os.path.join(args.output_dir, f"{filename.split('/')[-1].split('_')[1]}.json"), 'w', encoding='utf-8')
         for dial in tqdm(data):
             context = []
-            turn_keywords = [turn['keywords'] for turn in dial]
+            turns_keywords = [turn['keywords'] for turn in dial]
             for i, turn in enumerate(dial):
                 speaker = 'user' if i % 2 == 0 else 'system'
                 utt = turn['utterance']
@@ -27,21 +27,22 @@
                     continue
 
                 random.shuffle(turn['keywords'])
-                keywords = ' : '.join(turn['keywords'])
+                for j in range(len(turn['keywords'])):
+                    random.shuffle(turn['keywords'][j])
+                keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in turn['keywords']])
                 input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
-                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords']}, ensure_ascii=False)+'\n')
 
                 if args.mode == 'key2gen':
                     continue
-                possible_keywords_turns = [turn['keywords']]
-                num_possible_keywords_turns = min(random.randint(1, 5), len(turn_keywords) - 1)
-                possible_keywords_turns += random.sample(turn_keywords[:i] + turn_keywords[i+1:], num_possible_keywords_turns)
-                random.shuffle(possible_keywords_turns)
-                for possible_keywords_turn in possible_keywords_turns:
-                    random.shuffle(possible_keywords_turn)
-                possible_keywords = ' | '.join([' : '.join(possible_keywords_turn) for possible_keywords_turn in possible_keywords_turns])
+                possible_keywords_sents = turn['keywords'][:]
+                num_possible_keywords_turns = min(random.randint(1, 5), len(turns_keywords) - 1)
+                for turn_keywords in random.sample(turns_keywords[:i] + turns_keywords[i+1:], num_possible_keywords_turns):
+                    possible_keywords_sents.extend(turn_keywords)
+                random.shuffle(possible_keywords_sents)
+                possible_keywords = ' | '.join([' : '.join(sent_keywords) for sent_keywords in possible_keywords_sents])
                 input_seq = f'generate a response: all knowledge: | {possible_keywords} | context:\n\n{context_seq}'
-                fout.write(json.dumps({'source': input_seq, 'target': utt}, ensure_ascii=False)+'\n')
+                fout.write(json.dumps({'source': input_seq, 'target': utt, 'keywords': turn['keywords'], 'all_keywords': possible_keywords_sents}, ensure_ascii=False)+'\n')
 
                 if args.mode == 'key2gen_noisy':
                     continue
diff --git a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
index ca3017eaac6864a13c0ef055b589b3d177417518..f24058ecfa63c40c9100f03f061d64b58946796f 100644
--- a/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/gen_pretraining_data.sh
@@ -1,5 +1,5 @@
 task_name="key2gen_noisy"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3"
 names=$(echo ${dataset_name} | tr "+" "\n")
 model_type="gpt"
 data_dir=data/${task_name}/${model_type}/${name}/${dataset_name}
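Note: the new serialization keeps keywords grouped by sentence; groups are joined with ` | ` and keywords within a group with ` : `. A short sketch of the resulting `source` prompt, using made-up keywords and context:

```python
# Illustrative reconstruction of the prompt format written by gen_pretraining_data.py.
turn_keywords = [["cheap", "chinese restaurant"], ["centre"]]   # per-sentence keyword groups
context_seq = "user: i want a cheap chinese restaurant in the centre ."

keywords = ' | '.join(' : '.join(sent_keywords) for sent_keywords in turn_keywords)
input_seq = f'generate a response: grounded knowledge: | {keywords} | context:\n\n{context_seq}'
print(input_seq)
# generate a response: grounded knowledge: | cheap : chinese restaurant | centre | context:
#
# user: i want a cheap chinese restaurant in the centre .
```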
index cffa944b0374cff67abe223b4c2ea252ebd889f4..c060c9a03f7799ae59407c86c7c356788b3f529d 100644
--- a/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -4,7 +4,7 @@ model_type="gpt"
 data_dir="data/${task_name}/${dataset_name}/${model_type}"
 model_name_or_path="gpt2-large"
 keywords_num=100
-keywords_ratio=0.3
+keywords_ratio=0.4
 keywords_th_ratio=0
 stopwords=True
 for data_split in validation test train
diff --git a/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
index 418f8d78af372b893e82f67273bc46010260cfc2..d9722d96ca71a961dc7ad837191fa202848111f3 100644
--- a/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
+++ b/convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
@@ -69,25 +69,25 @@ class Key2GenMetrics(datasets.Metric):
 
     def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
        """Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall"""
-        # rouge-1/2/L bleu-1/2 distinct-1/2
-        if not negative_keywords:
-            negative_keywords = [[]] * len(positive_keywords)
         bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score
         cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0}
-        for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
-            cnt['pos'] += len(poskeys)
-            cnt['neg'] += len(negkeys)
+        if positive_keywords:
+            if not negative_keywords:
+                negative_keywords = [[]] * len(positive_keywords)
+            for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
+                cnt['pos'] += len(poskeys)
+                cnt['neg'] += len(negkeys)
 
-            prediction = prediction.lower()
-            for key in poskeys:
-                key = key.lower()
-                if key in prediction:
-                    cnt['pos_recall'] += 1
-
-            for key in negkeys:
-                key = key.lower()
-                if key in prediction:
-                    cnt['neg_recall'] += 1
+                prediction = prediction.lower()
+                for key in poskeys:
+                    key = key.lower()
+                    if key in prediction:
+                        cnt['pos_recall'] += 1
+
+                for key in negkeys:
+                    key = key.lower()
+                    if key in prediction:
+                        cnt['neg_recall'] += 1
 
         return {
             "bleu": bleu,
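Note: the metric change only wraps the existing counting loop in `if positive_keywords:` so that "no keywords" samples still get a BLEU score without keyword statistics. For reference, a standalone sketch of the recall bookkeeping; the final normalization is an assumption, since the hunk above only shows the counters:

```python
def keyword_recall(predictions, positive_keywords, negative_keywords=None):
    """Count keywords that appear verbatim (lowercased) in the corresponding prediction."""
    if not negative_keywords:
        negative_keywords = [[]] * len(positive_keywords)
    cnt = {'pos': 0, 'neg': 0, 'pos_recall': 0, 'neg_recall': 0}
    for poskeys, negkeys, prediction in zip(positive_keywords, negative_keywords, predictions):
        prediction = prediction.lower()
        cnt['pos'] += len(poskeys)
        cnt['neg'] += len(negkeys)
        cnt['pos_recall'] += sum(key.lower() in prediction for key in poskeys)
        cnt['neg_recall'] += sum(key.lower() in prediction for key in negkeys)
    return {  # assumed normalization, not shown in the diff
        'positive_keywords_recall': cnt['pos_recall'] / cnt['pos'] if cnt['pos'] else 0.0,
        'negative_keywords_recall': cnt['neg_recall'] / cnt['neg'] if cnt['neg'] else 0.0,
    }

print(keyword_recall(["there is a cheap chinese restaurant in the centre ."],
                     positive_keywords=[["cheap", "centre"]],
                     negative_keywords=[["parking"]]))
# {'positive_keywords_recall': 1.0, 'negative_keywords_recall': 0.0}
```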
diff --git a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
index 256c15952b2365c793205ae96783c3dc232ddb72..01581d2d34dee0b1f64b8ee65ab68375e6c7d5db 100644
--- a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -5,7 +5,7 @@ import os
 from tqdm import tqdm
 import numpy as np
 from nltk.corpus import stopwords
-from nltk.tokenize import word_tokenize
+from nltk.tokenize import word_tokenize, PunktSentenceTokenizer
 from transformers import GPT2Tokenizer
 from string import punctuation
 
@@ -34,7 +34,7 @@ def merge_tokens(tokens, losses):
             # token = token.replace("Ġ", "")
             res.append([[token], [loss]])
         elif token == '<|endoftext|>':
-            res.append([[token], [loss]])
+            res.append([[token], [0.]])
         else:
             assert 'Ġ' not in token
             if len(res) > 0:
@@ -80,6 +80,7 @@ def main(args):
 
     stop_words = set(stopwords.words('english'))
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
+    sent_tokenizer = PunktSentenceTokenizer()
 
     if args.keywords_th_ratio > 0:
         losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
@@ -88,9 +89,19 @@
     else:
         loss_th = 0
 
-    def keywords_filter(word_loss_pairs):
+    def keywords_filter(words, losses):
+        word_loss_pairs = list(zip(words, losses))
         index2keyword = {}
+        index2turn_sent = {}
+        num_turns = 0
+        turns_sent_spans = [list(sent_tokenizer.span_tokenize(utt)) for utt in ''.join(words).strip().split('<|endoftext|>')]
+        utt = ''
         for i, word_loss_pair in enumerate(word_loss_pairs):
+            if word_loss_pair[0].startswith('<|endoftext|>'):
+                num_turns += 1
+                utt = ''
+                continue
+            utt += word_loss_pair[0]
             words = word_tokenize(word_loss_pair[0])
             if args.stopwords and any([w.lower() in stop_words for w in words]):
                 # skip stopwords
@@ -104,42 +115,54 @@
                 # skip punctuation
                 continue
             index2keyword[i] = strip_punctuation
+            for sent_idx, (sent_start, sent_end) in enumerate(turns_sent_spans[num_turns]):
+                if len(utt.strip()) <= sent_end:
+                    index2turn_sent[i] = (num_turns, sent_idx)
+                    break
 
         candidate_indexes = list(index2keyword.keys())
-        topk = min(round(args.keywords_ratio*len(word_loss_pairs)), args.keywords_num)
+        topk = min(round(args.keywords_ratio*(len(word_loss_pairs)-num_turns)), args.keywords_num)
         topk_indexes = sorted(candidate_indexes, key=lambda x: word_loss_pairs[x][1], reverse=True)[:topk]
         topk_indexes = sorted(topk_indexes)
         keywords = []
+        keywords_turn_sent2idx = {}
         for i, index in enumerate(topk_indexes):
             if i > 0 and index == topk_indexes[i-1] + 1 and \
                 word_loss_pairs[index][0].strip().startswith(index2keyword[index]) and \
                 word_loss_pairs[topk_indexes[i-1]][0].strip().endswith(index2keyword[topk_indexes[i-1]]):
                 keywords[-1]+= ' '+index2keyword[index]
             else:
+                keywords_turn_sent2idx.setdefault(index2turn_sent[index][0], {})
+                keywords_turn_sent2idx[index2turn_sent[index][0]].setdefault(index2turn_sent[index][1], [])
+                keywords_turn_sent2idx[index2turn_sent[index][0]][index2turn_sent[index][1]].append(len(keywords))
                 keywords.append(index2keyword[index])
-        return keywords
+        return keywords, keywords_turn_sent2idx
 
     dialogs = []
     for item in tqdm(word_loss_list):
-        words = item['words']
-        losses = item['losses']
+        words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
+        losses = [np.mean(loss) for loss in item['losses']]
+        dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
+        # print(keywords_turn_sent2idx)
         turns = []
         turn = {'words': [], 'losses': []}
-        for word, loss in zip(words, losses):
-            if word == ['<|endoftext|>']:
+        for i, (word, loss) in enumerate(zip(words, losses)):
+            if word != '<|endoftext|>':
+                turn['words'].append(word)
+                turn['losses'].append(loss)
+            if word == '<|endoftext|>' or i == len(words) - 1:
                 # switch turn
-                turn['words'] = [tokenizer.convert_tokens_to_string(tokens) for tokens in turn['words']]
-                turn['losses'] = [np.mean(losses) for losses in turn['losses']]
                 turn['utterance'] = ''.join(turn['words']).strip()
-                keywords = keywords_filter(list(zip(turn['words'], turn['losses'])))
-                turn['keywords'] = keywords
+                # 1) extract keywords according to LM loss within the turn
+                # keywords, _ = keywords_filter(turn['words'], turn['losses'])
+                # turn['turn-level_keywords'] = keywords
+                # 2) extract keywords according to LM loss over the dialog, and group them by sentence
+                turn['keywords'] = [[dialog_keywords[idx] for idx in k_idxes] for sent_idx, k_idxes in keywords_turn_sent2idx.get(len(turns), {}).items()]
                 turn.pop('words')
                 turn.pop('losses')
                 turns.append(turn)
                 turn = {'words': [], 'losses': []}
-            else:
-                turn['words'].append(word)
-                turn['losses'].append(loss)
+
         dialogs.append(turns)
 
     json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)
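Note: the sentence grouping relies on NLTK's Punkt tokenizer: `span_tokenize` yields character offsets for each sentence of a turn, and `keywords_filter` compares the running length of `utt` against those offsets to assign every keyword a `(turn, sentence)` index. A quick illustration of the call (no pre-trained punkt model is needed when the tokenizer is instantiated directly):

```python
from nltk.tokenize import PunktSentenceTokenizer

sent_tokenizer = PunktSentenceTokenizer()
utt = "I want a cheap restaurant. It should be in the centre."
print(list(sent_tokenizer.span_tokenize(utt)))
# [(0, 26), (27, 54)] -- (start, end) character offsets of each sentence
```

In the diff, each dialog is re-joined with `''.join(words)` and split on `<|endoftext|>`, so every turn gets its own list of sentence spans.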
diff --git a/convlab2/base_models/gpt/keyword_extraction/run.sh b/convlab2/base_models/gpt/keyword_extraction/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f35c2403ce21f9450d3d7a84dc8e7076ee6f5f89
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/run.sh
@@ -0,0 +1,5 @@
+set -e
+for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd multiwoz21
+do
+    bash get_keywords.sh ${dataset_name}
+done
\ No newline at end of file
diff --git a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
index 69d4c9980addceab7d01b4ba13e8069107a9555e..faaef560c20bd1a928f9c99503277780c4e8c26d 100644
--- a/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
@@ -1,8 +1,8 @@
 set -e
-n_gpus=4
+n_gpus=2
 master_port=23457
 task_name="key2gen_noisy"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name=$1
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
@@ -16,7 +16,7 @@ target_column="target"
 truncation_side="left"
 max_source_length=512
 max_target_length=128
-model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+model_name_or_path="output/${task_name}/${model_type}/dailydialog+metalwoz+tm1+tm2+tm3"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
 gradient_accumulation_steps=2
@@ -39,7 +39,7 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
-    --preprocessing_num_workers 4 \
+    --preprocessing_num_workers 16 \
     --per_device_train_batch_size ${per_device_train_batch_size} \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \
diff --git a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
index 8d9a4490b6e85039041e2ec05fef22128e99a18e..9a413f2dd91381cd0a05b13fdc79e5be588f8cc2 100644
--- a/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_t5_key2gen.sh
@@ -1,8 +1,8 @@
 set -e
-n_gpus=4
+n_gpus=2
 master_port=23457
 task_name="key2gen_noisy"
-dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
+dataset_name="dailydialog+metalwoz+tm1+tm2+tm3"
 model_type="gpt"
 data_dir="data/${task_name}/${model_type}/${dataset_name}"
 output_dir="output/${task_name}/${model_type}/${dataset_name}"
@@ -19,7 +19,7 @@ max_target_length=128
 model_name_or_path="t5-small"
 per_device_train_batch_size=128
 per_device_eval_batch_size=128
-gradient_accumulation_steps=2
+gradient_accumulation_steps=4
 lr=1e-3
 num_train_epochs=3
 
@@ -46,7 +46,7 @@ python -m torch.distributed.launch --master_port ${master_port} \
     --output_dir ${output_dir} \
     --logging_dir ${logging_dir} \
     --overwrite_output_dir \
-    --preprocessing_num_workers 4 \
+    --preprocessing_num_workers 16 \
     --per_device_train_batch_size ${per_device_train_batch_size} \
     --per_device_eval_batch_size ${per_device_eval_batch_size} \
     --gradient_accumulation_steps ${gradient_accumulation_steps} \
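Note: in the training script, halving `n_gpus` (4 to 2) while doubling `gradient_accumulation_steps` (2 to 4) keeps the effective batch size unchanged, assuming standard data-parallel training:

```python
# effective batch = n_gpus * per_device_train_batch_size * gradient_accumulation_steps
assert 4 * 128 * 2 == 2 * 128 * 4 == 1024
```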