Skip to content
Snippets Groups Projects
Commit 3dd7ae72 authored by zqwerty's avatar zqwerty
Browse files

merge keywords from different model for easy illustration

parent a15bedea
No related branches found
No related tags found
No related merge requests found
......@@ -20,7 +20,7 @@ def merge_tokens(tokens, losses, loss_merge_func=np.mean):
tokens[i+1] = 'Ġ'+tokens[i+1]
i += 1
continue
if token in ['user', 'system'] and i < len(tokens)-1 and tokens[i+1] == ':':
if token in ['user', 'system', 'Ġuser', 'Ġsystem'] and i < len(tokens)-1 and tokens[i+1] == ':':
if i > 0:
tokens[i+1] = '<|endoftext|>'
i += 1
......@@ -109,7 +109,7 @@ def main(args):
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
parser = ArgumentParser(description="extract keywords according to lm loss")
parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
parser.add_argument('--token_loss_file', '-t', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
parser.add_argument('--word_loss_file', '-w', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
......
import json
def main(args):
    """Merge keywords extracted by different models into one JSON file for side-by-side viewing.

    Args:
        args: namespace with
            keywords_files: list of paths to per-model keyword files. Each file
                is a JSON list of dialogs; each dialog is a list of turns; each
                turn is a dict with 'utterance' and 'keywords' (a list of
                [word, score] pairs). All files are assumed to cover the same
                dialogs/turns in the same order.
            output_file: path the merged JSON is written to.
    """
    filename2data = {}
    for path in args.keywords_files:
        # Key by basename; use a context manager so handles are closed
        # promptly (the original json.load(open(f)) leaked file handles).
        with open(path, encoding='utf-8') as fin:
            filename2data[path.split('/')[-1]] = json.load(fin)
    first_filename = args.keywords_files[0].split('/')[-1]
    dialogs = []
    for i in range(len(filename2data[first_filename])):
        turns = []
        for j in range(len(filename2data[first_filename][i])):
            utt = filename2data[first_filename][i][j]['utterance']
            # Column name is derived from the filename, e.g.
            # "lm_loss_gpt_topk_3.json" -> "gpt_nonstopword3.json"
            # (3rd '_'-separated field is the model name; last field is the suffix).
            keywords = {
                filename.split('_')[2] + '_nonstopword' + filename.split('_')[-1]:
                    ' | '.join(x[0] for x in filename2data[filename][i][j]['keywords'])
                for filename in filename2data
            }
            turns.append({
                "utterance": utt,
                **keywords
            })
        dialogs.append(turns)
    # Write with a context manager (original left the output handle open).
    with open(args.output_file, "w", encoding='utf-8') as fout:
        json.dump(dialogs, fout, indent=2, ensure_ascii=False)
if __name__ == '__main__':
    from argparse import ArgumentParser
    # Description fixed: the original was copy-pasted from the NLU-metrics
    # script ("calculate NLU metrics for unified datasets"); this script
    # merges per-model keyword files for easy illustration.
    parser = ArgumentParser(description="merge keywords from different models for easy illustration")
    parser.add_argument('--keywords_files', '-f', metavar='keywords_files', nargs='*', help='keywords files')
    parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
    args = parser.parse_args()
    print(args)
    main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment