Commit 3dd7ae72 authored by zqwerty

merge keywords from different models for easy illustration

parent a15bedea
@@ -20,7 +20,7 @@ def merge_tokens(tokens, losses, loss_merge_func=np.mean):
             tokens[i+1] = 'Ġ'+tokens[i+1]
             i += 1
             continue
-        if token in ['user', 'system'] and i < len(tokens)-1 and tokens[i+1] == ':':
+        if token in ['user', 'system', 'Ġuser', 'Ġsystem'] and i < len(tokens)-1 and tokens[i+1] == ':':
             if i > 0:
                 tokens[i+1] = '<|endoftext|>'
             i += 1
@@ -109,7 +109,7 @@ def main(args):
 if __name__ == '__main__':
     from argparse import ArgumentParser
-    parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
+    parser = ArgumentParser(description="extract keywords according to lm loss")
     parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
     parser.add_argument('--token_loss_file', '-t', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
     parser.add_argument('--word_loss_file', '-w', type=str, help='path to the word loss file that contains two columns: [tokens, losses]')
...
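For context on the first hunk: GPT-2 style BPE marks a preceding space by prefixing the token with 'Ġ', so a speaker tag that does not open the string is tokenized as 'Ġuser'/'Ġsystem' rather than bare 'user'/'system'; the widened membership test catches both spellings. A minimal standalone sketch of the fixed boundary check (an illustrative reimplementation, not the repo's full merge_tokens):

# Illustrative reimplementation of the fixed check; the token list is
# hypothetical, not taken from the repo's data.
SPEAKER_TOKENS = ['user', 'system', 'Ġuser', 'Ġsystem']

def is_speaker_boundary(tokens, i):
    """True if tokens[i] starts a 'user:' / 'system:' speaker tag."""
    return tokens[i] in SPEAKER_TOKENS and i < len(tokens) - 1 and tokens[i + 1] == ':'

tokens = ['user', ':', 'Ġhi', 'Ġsystem', ':', 'Ġhello']
print([i for i in range(len(tokens)) if is_speaker_boundary(tokens, i)])
# -> [0, 3]; the old ['user', 'system'] check would miss index 3

The commit also adds a new script that merges per-model keyword files into a single JSON for side-by-side illustration: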
import json


def main(args):
    # Load each keywords file, keyed by its base file name.
    filename2data = {f.split('/')[-1]: json.load(open(f)) for f in args.keywords_files}
    first_filename = args.keywords_files[0].split('/')[-1]
    dialogs = []
    # All input files are assumed to cover the same dialogs and turns in the same order.
    for i in range(len(filename2data[first_filename])):
        turns = []
        for j in range(len(filename2data[first_filename][i])):
            utt = filename2data[first_filename][i][j]['utterance']
            # One keywords column per input file; the key combines the model name
            # (third underscore-separated field of the file name) with the
            # keyword ratio (last field).
            keywords = {filename.split('_')[2]+'_nonstopword'+filename.split('_')[-1]: ' | '.join([x[0] for x in filename2data[filename][i][j]['keywords']]) for filename in filename2data}
            turns.append({
                "utterance": utt,
                **keywords
            })
        dialogs.append(turns)
    json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)


if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="merge keywords from different models for easy illustration")
    parser.add_argument('--keywords_files', '-f', metavar='keywords_files', nargs='*', help='keywords files')
    parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
    args = parser.parse_args()
    print(args)
    main(args)
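As a usage sketch, the merged column keys come from the input file names; the name below is hypothetical and only illustrates the underscore-separated scheme the dict comprehension assumes (model name in the third field, keyword ratio in the last):

# Hypothetical file name, not a path from the repo.
filename = 'token_loss_gpt_keywords_0.3.json'
parts = filename.split('_')
print(parts[2] + '_nonstopword' + parts[-1])  # -> gpt_nonstopword0.3.json

Passing, say, matching GPT and DialoGPT keyword files via -f and an output path via -o would then yield, for every turn, one "utterance" field plus one such keywords column per input file.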