diff --git a/convlab2/dst/trade/crosswoz/models/TRADE.py b/convlab2/dst/trade/crosswoz/models/TRADE.py index e70e877724d2eacb623010a9c8b57d2ca2bf8b1b..d0b94233d5cd0c9e5e6f7ded554b2348bd3c8f7e 100755 --- a/convlab2/dst/trade/crosswoz/models/TRADE.py +++ b/convlab2/dst/trade/crosswoz/models/TRADE.py @@ -235,7 +235,6 @@ class TRADE(nn.Module): predict_belief_bsz_ptr.append(slot_temp[si] + "-" + str(st)) all_prediction[data_dev["ID"][bi]][data_dev["turn_id"][bi]]["pred_bs_ptr"] = predict_belief_bsz_ptr - return predict_belief_bsz_ptr if set(data_dev["turn_belief"][bi]) != set(predict_belief_bsz_ptr) and args["genSample"]: print("True", set(data_dev["turn_belief"][bi]))