Skip to content
Snippets Groups Projects
Unverified Commit e368deeb authored by Carrey Wang's avatar Carrey Wang Committed by GitHub
Browse files

Add warmup for DQN and fix minor bugs (#150)


* Initial commit

* Update README.md

* sync with commit aa1af0ee81ba591d1cf3c222c9d71963ed1dca98

* add gitignore

* update tutorial

* update mdrg, not use dbPointer

* update mdrg, not use dbPointer

* update mdrg, download before use dbPointer

* update analyzer

* update README

* update tutorial

* Fix dbquery when matching name

* move change to dev branch

* move dbquery change from master to dev branch

* disable travis for now

* do nothing in travis

* do nothing in travis for now

* not deploy now

* add docs

* update .travis.yml

* update .travis.yml

* update .travis.yml

* update rst files

* add alias center for centre in dbquery

* fix the policy training bug

* fix the bug of nan gradient

* add cross-lingual dst data

* Update README.md

* Update README.md

* Add CrossWoz Web support and some minor bug fix (#19)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

Co-authored-by: default avatarkflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: default avatarCarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>

* modify xdst data name

* Translation train on MultiWOZ (Chinese) nad CrossWOZ (English) of SUMBT (#17)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* updata xdst baseline

* Update README.md

* fix allennlp==0.9.0

* Update README.md

* modify build message function for goal generation

* Fix goal generator and dbquery for multiwoz (#32)

* move dbquery change from master to dev branch

* add alias center for centre in dbquery

* replace attraction type 'mutliple sports' to 'multiple sports', involving only one entity

* add depart and destination constraints for searching db (ignore=False), modify goal generator to draw the values of these two slots from database

* fix bug (#35)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* fix bug #34

* revert changes

* update demo video link

* Update README.md

* some changes in #36 (#37)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* fix bug #34

* revert changes

* revert changes

* some changes of #36

* fix analyzer example.py

* dst/evaluate.py: Use utf-8 encoding

* use transformers library to automate model caching

* Update README.md

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba7 (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* fix nlu max len

* update travis

* Update run_agent.py

* Create README.md

* Update README.md

* modify human_eval README

* fix sclstm crosswoz import issues

* update travis.yml

* try to fix deploy

* Update README.md

* Update README.md

* Improve agenda policy (#52)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba7 (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* Update README.md

#53

* Update README.md (#57)

* Improve agenda policy (#60)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba7 (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* fix self.cur_domain=None when system offer book

* Improve agenda policy (#62)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba7 (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* fix self.cur_domain=None when system offer book

* fix agenda for 0 choice

* fix sequicityy

* fix sequicityy

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* update setup.py:add tokenizers requirement

* fix typo

* update user nlg template

* Update README.md

* remove fail book in multiwoz goal generator

* fix taxi dontcare problem

* can manually set user goal in agenda now

* test goal overlap between generator and trainset

* change default taxi depart and destination from address to name/'the hotel/restaurant'

* change initiative from 4 to randint(2,4)

* agenda pop more da when only answer dontcare

* add 'the same area/pricerange/people/day' in agenda with 0.3 probability

* remove unnecessary thank you

* add domain for postcode and Phone in user templateNLG

* add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy

* add template for interent-no, parking-no in templatenlg

* update Evaluator: check whether final goal satisfies constraints

* update evaluator: check booked entity

* output goal analysis to file

* update goal analysis

* update

* Update analyzer.py

* Fix simulator (#83)

* remove fail book in multiwoz goal generator

* fix taxi dontcare problem

* can manually set user goal in agenda now

* test goal overlap between generator and trainset

* change default taxi depart and destination from address to name/'the hotel/restaurant'

* change initiative from 4 to randint(2,4)

* agenda pop more da when only answer dontcare

* add 'the same area/pricerange/people/day' in agenda with 0.3 probability

* remove unnecessary thank you

* add domain for postcode and Phone in user templateNLG

* add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy

* add template for interent-no, parking-no in templatenlg

* remove police and hospital domain in goal generator

* update multiwoz evaluator: adding 'internet/parking-none, 24:** to valid value

* fix nlg template (#88)

* add new_goal_model without police and hospital domain (#89)

* Normalize string comparisons in multiwoz template nlg to be case insensitive (#87)

* normalize template nlg keys to be lower case

* fix slot comparison in multiwoz nlg to be case insensitive

* use value_lower instead of calling .lower() on each comparison

* Add police n hospital (#95)

* add back police and hospital goal

* update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result

* update dbquery and session (#99)

* update dbquery: ? matches all; fix bug in init_session

* update multiwoz_eval, check Ref of booked

* filter domain in final_goal_analyze

Co-authored-by: default avatarnewRuntieException <wdz15@mails.tsinghua.edu.cn>

* Add dockerfile (#98)

* fix nlg template

* add dockerfile

* include missing packages at setup.py (#102)

* multiwoz dbquery doesnt require mutable constraints (#106)

* Add police n hospital (#107)

* add back police and hospital goal

* update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result

* update user templatenlg

* add test set example for dstc9 (multiwoz_zh, crosswoz_en) (#108)

* Add dockerfile (#110)

* fix nlg template

* add dockerfile

* add package for dockerfile

* update versions

* Update README.md

* Update versions in setup (#111)

* move dbquery change from master to dev branch

* add alias center for centre in dbquery

* fix sequicityy

* update versions

Co-authored-by: default avatarzqwerty <zhuq96@hotmail.com>
Co-authored-by: default avatarzhuqi <zqwerty@users.noreply.github.com>

* Update README.md

* Update README.md

* Update README.md

* fix system nlg template bug (#117)

* add 'book' in DST evaluation. (#85)

* Maintenance (#119)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* dstc9 eval

* dstc9 xldst evaluation

* Nlg template fix (#121)

* fix nlg template

* fix user nlg template issue

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* dstc9 xldst evaluation (#122)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* revise xldst evaluation (#124)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* Update dst.py

* update precision, recall, f1 calculation

* minor change

* fix policy evaluation

* Nlg template fix (#127)

* fix nlg template

* fix user nlg template issue

* fix system NLG template

* nlu update and bugfix (#118)

* jointBERT_new avaliable && fix milu dataset_reader && fix jointBERT/tag2id

* remove jointBERT_new

* update milu/multiwoz/nlu.py model_file path

* add metrics in XLDST evaluation (#126)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* add input reqt vals in human eval (#128)

* Maintenance (#129)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* Human (#131)

* change task config

* add final goal logging

* encapsule PipelineAgent internal state interface for return and
replacement

* Maintenance (#132)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix a database typo

* Maintenance (#134)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix goal generator for police domain message

* fix a minor typo in crosswoz database (#133)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* remove low performance baselines (#136)

* Human2 (#137)

* change task config

* add final goal logging

* encapsule PipelineAgent internal state interface for return and
replacement

* fix bug associted with the issue of strange user input

* Fix a bug in TRADE CrossWOZ training (#138)

* add 'book' in DST evaluation.

* Fix TRADE crosswoz training evaluation bug

Co-authored-by: default avatarzheng <zheng@zhangzheng-PC.lan>

* Maintenance (#140)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix goal generator for police domain message

* update template NLG

* Add note for deploy web service (#139)

* add 'book' in DST evaluation.

* Fix TRADE crosswoz training evaluation bug

* Add note on deploy

Co-authored-by: default avatarzheng <zheng@zhangzheng-PC.lan>

* add value unification

* fix XLDST evaluation (#141)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification

* fix user Nlg template (#142)

* fix system nlg template bug

* fix user nlg issue

* fix white character issue #144

* deal with white charater in XLDST evaluation (#145)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification

* fix white character issue #144

* DQN (#113)

* implemented script to extract all the statistics for all dialogue_act in data

* changed script for actions be compatible to sys_da_voc.txt actions

* multiwoz vector now supports composite actions

* implemented ReplayMemory and EpsilongGreedyPolicy

* implemented a basic version of dqn

* included some comments

* Add DQN Test and Change file structure  (#146)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

* Update server.py

* add test for DQN

* change server

Co-authored-by: default avatarCarrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: default avatarkflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: default avatarCarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: default avatarzimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: default avatarMR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>

* update eval

* dump dst eval results

* make value lower

* add progress bar

* fix bug in last commit

* Update policy_agenda_multiwoz.py

* remove unnecessary mapping (#147)

* udpate dstc9 eval

* make value lower

* add warm up for dqn and fix bugs

* rm unrelated files

Co-authored-by: default avatarzhuqi <zqwerty@users.noreply.github.com>
Co-authored-by: default avatarzqwerty <zhuq96@hotmail.com>
Co-authored-by: default avatarRyuichi Takanobu <truthless11@gmail.com>
Co-authored-by: default avatarnewRuntieException <wdz15@mails.tsinghua.edu.cn>
Co-authored-by: default avatarliangrz <liangrz15@mails.tsinghua.edu.cn>
Co-authored-by: default avatarCarrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: default avatarkflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: default avatarCarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: default avatar罗崚骁 <function2@qq.com>
Co-authored-by: default avatarmehrad <mehrad@stanford.edu>
Co-authored-by: default avatarpengbaolin <39398162+pengbaolin@users.noreply.github.com>
Co-authored-by: default avatarJinchao Li <38700695+jincli@users.noreply.github.com>
Co-authored-by: default avatarShahin Shayandeh <shahins@microsoft.com>
Co-authored-by: default avataraaa123git <43716234+aaa123git@users.noreply.github.com>
Co-authored-by: default avatarBruno Eidi Nishimoto <bruno_nishimoto@hotmail.com>
Co-authored-by: default avatarVojtěch Hudeček <vojta.hudecek@gmail.com>
Co-authored-by: default avatarzhangzthu <zhangz.goal@gmail.com>
Co-authored-by: default avatarxw <48146603+xwwwwww@users.noreply.github.com>
Co-authored-by: default avatarzheng <zheng@zhangzheng-PC.lan>
Co-authored-by: default avatarzimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: default avatarMR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>
parent 3811af82
Branches
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy ...@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
from convlab2.util.train_util import init_logging_handler from convlab2.util.train_util import init_logging_handler
from convlab2.policy.vector.vector_multiwoz import MultiWozVector from convlab2.policy.vector.vector_multiwoz import MultiWozVector
from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
from convlab2.util.file_util import cached_path from convlab2.util.file_util import cached_path
import zipfile import zipfile
import sys import sys
...@@ -32,6 +33,8 @@ class DQN(Policy): ...@@ -32,6 +33,8 @@ class DQN(Policy):
self.training_iter = cfg['training_iter'] self.training_iter = cfg['training_iter']
self.training_batch_iter = cfg['training_batch_iter'] self.training_batch_iter = cfg['training_batch_iter']
self.batch_size = cfg['batch_size'] self.batch_size = cfg['batch_size']
self.epsilon = cfg['epsilon_spec']['start']
self.rule_bot = RuleBasedMultiwozBot()
self.gamma = cfg['gamma'] self.gamma = cfg['gamma']
self.is_train = is_train self.is_train = is_train
if is_train: if is_train:
...@@ -58,9 +61,10 @@ class DQN(Policy): ...@@ -58,9 +61,10 @@ class DQN(Policy):
self.loss_fn = nn.MSELoss() self.loss_fn = nn.MSELoss()
def update_memory(self, sample): def update_memory(self, sample):
self.memory.reset()
self.memory.append(sample) self.memory.append(sample)
def predict(self, state): def predict(self, state, warm_up=False):
""" """
Predict an system action given state. Predict an system action given state.
Args: Args:
...@@ -68,14 +72,29 @@ class DQN(Policy): ...@@ -68,14 +72,29 @@ class DQN(Policy):
Returns: Returns:
action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...}) action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
""" """
if warm_up:
action = self.rule_action(state)
state['system_action'] = action
else:
s_vec = torch.Tensor(self.vector.state_vectorize(state)) s_vec = torch.Tensor(self.vector.state_vectorize(state))
a = self.net.select_action(s_vec.to(device=DEVICE)) a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
action = self.vector.action_devectorize(a.numpy()) action = self.vector.action_devectorize(a.numpy())
state['system_action'] = action state['system_action'] = action
return action return action
def rule_action(self, state):
if self.epsilon > np.random.rand():
a = torch.randint(self.vector.da_dim, (1, ))
# transforms action index to a vector action (one-hot encoding)
a_vec = torch.zeros(self.vector.da_dim)
a_vec[a] = 1.
action = self.vector.action_devectorize(a_vec.numpy())
else:
# rule-based warm up
action = self.rule_bot.predict(state)
return action
def init_session(self): def init_session(self):
""" """
Restore after one session Restore after one session
......
...@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz): ...@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
queue.put([pid, buff]) queue.put([pid, buff])
evt.wait() evt.wait()
def warmupsampler(pid, queue, evt, env, policy, batchsz):
"""
This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
processes.
:param pid: process id
:param queue: multiprocessing.Queue, to collect sampled data
:param evt: multiprocessing.Event, to keep the process alive
:param env: environment instance
:param policy: policy network, to generate action from current policy
:param batchsz: total sampled items
:return:
"""
buff = Memory()
# we need to sample batchsz of (state, action, next_state, reward, mask)
# each trajectory contains `trajectory_len` num of items, so we only need to sample
# `batchsz//trajectory_len` num of trajectory totally
# the final sampled number may be larger than batchsz.
sampled_num = 0
sampled_traj_num = 0
traj_len = 50
real_traj_len = 0
while sampled_num < batchsz:
# for each trajectory, we reset the env and get initial state
s = env.reset()
for t in range(traj_len):
# [s_dim] => [a_dim]
s_vec = torch.Tensor(policy.vector.state_vectorize(s))
a = policy.predict(s, warm_up=True)
# interact with env
next_s, r, done = env.step(a)
# a flag indicates ending or not
mask = 0 if done else 1
# get reward compared to demostrations
next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
# save to queue
buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
# update per step
s = next_s
real_traj_len = t
if done:
break
def sample(env, policy, batchsz, process_num): # this is end of one trajectory
sampled_num += real_traj_len
sampled_traj_num += 1
# t indicates the valid trajectory length
# this is end of sampling all batchsz of items.
# when sampling is over, push all buff data into queue
queue.put([pid, buff])
evt.wait()
def sample(env, policy, batchsz, process_num, warm_up=False):
""" """
Given batchsz number of task, the batchsz will be splited equally to each processes Given batchsz number of task, the batchsz will be splited equally to each processes
and when processes return, it merge all data and return and when processes return, it merge all data and return
...@@ -119,6 +182,9 @@ def sample(env, policy, batchsz, process_num): ...@@ -119,6 +182,9 @@ def sample(env, policy, batchsz, process_num):
processes = [] processes = []
for i in range(process_num): for i in range(process_num):
process_args = (i, queue, evt, env, policy, process_batchsz) process_args = (i, queue, evt, env, policy, process_batchsz)
if warm_up:
processes.append(mp.Process(target=warmupsampler, args=process_args))
else:
processes.append(mp.Process(target=sampler, args=process_args)) processes.append(mp.Process(target=sampler, args=process_args))
for p in processes: for p in processes:
# set the process as daemon, and it will be killed once the main process is stoped. # set the process as daemon, and it will be killed once the main process is stoped.
...@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num): ...@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
policy.update(epoch) policy.update(epoch)
def warm_start(env, policy, batchsz, epoch, process_num):
# sample data asynchronously
buff = sample(env, policy, batchsz, process_num, warm_up=True)
policy.update_memory(buff)
policy.update(epoch)
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--load_path", type=str, default="", help="path of model to load") parser.add_argument("--load_path", type=str, default="", help="path of model to load")
...@@ -170,6 +243,7 @@ if __name__ == '__main__': ...@@ -170,6 +243,7 @@ if __name__ == '__main__':
evaluator = MultiWozEvaluator() evaluator = MultiWozEvaluator()
env = Environment(None, simulator, None, dst_sys, evaluator) env = Environment(None, simulator, None, dst_sys, evaluator)
warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
for i in range(args.epoch): for i in range(args.epoch):
update(env, policy_sys, args.batchsz, i, args.process_num) update(env, policy_sys, args.batchsz, i, args.process_num)
...@@ -66,7 +66,7 @@ class Beam(object): ...@@ -66,7 +66,7 @@ class Beam(object):
# bestScoresId is flattened as a (beam x word) array, # bestScoresId is flattened as a (beam x word) array,
# so we need to calculate which word and beam each score came from # so we need to calculate which word and beam each score came from
prev_k = best_scores_id / num_words prev_k = best_scores_id // num_words
self.prev_ks.append(prev_k) self.prev_ks.append(prev_k)
self.next_ys.append(best_scores_id - prev_k * num_words) self.next_ys.append(best_scores_id - prev_k * num_words)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment