From f1ad5224f9336b85162e4ca6d735aee36b70bb47 Mon Sep 17 00:00:00 2001 From: Jakhes <dean.schmitz@schmitzbauer.de> Date: Mon, 7 Nov 2022 23:51:09 +0100 Subject: [PATCH] Updating nbc --- .../naive_bayes_classifier.cpp | 24 +++++++++++++++++++ .../naive_bayes_classifier_test.pl | 19 ++++++++------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/methods/naive_bayes_classifier/naive_bayes_classifier.cpp b/src/methods/naive_bayes_classifier/naive_bayes_classifier.cpp index acf8b74..e06b44f 100644 --- a/src/methods/naive_bayes_classifier/naive_bayes_classifier.cpp +++ b/src/methods/naive_bayes_classifier/naive_bayes_classifier.cpp @@ -36,6 +36,13 @@ void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer da { // convert the Prolog array to arma::mat mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum); + // check if labels fit the data + if (data.n_cols != labelsArrSize) + { + raisePrologSystemExeption("The number of data points does not match the number of labels!"); + return; + } + // convert the Prolog array to arma::rowvec Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize); @@ -186,6 +193,12 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo { // convert the Prolog array to arma::mat mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum); + // check if labels fit the data + if (data.n_cols != labelsArrSize) + { + raisePrologSystemExeption("The number of data points does not match the number of labels!"); + return; + } // convert the Prolog array to arma::rowvec Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize); @@ -195,6 +208,11 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo { naiveBayesClassifier.Train(data, labelsVector, numClasses, (incrementalVariance == 1)); } + catch(const std::out_of_range& e) + { + raisePrologSystemExeption("The given Labels dont fit the format [0,Numclasses-1]!"); + return; + } catch(const std::exception& e) { raisePrologSystemExeption(e.what()); @@ -215,6 +233,12 @@ void trainPoint(float *pointArr, SP_integer pointArrSize, // convert the Prolog array to arma::rowvec rowvec pointVector = convertArrayToRowvec(pointArr, pointArrSize); + if(label < 0) + { + raisePrologSystemExeption("The given Label should be positive!"); + return; + } + try { diff --git a/src/methods/naive_bayes_classifier/naive_bayes_classifier_test.pl b/src/methods/naive_bayes_classifier/naive_bayes_classifier_test.pl index a334ac2..8bc7b33 100644 --- a/src/methods/naive_bayes_classifier/naive_bayes_classifier_test.pl +++ b/src/methods/naive_bayes_classifier/naive_bayes_classifier_test.pl @@ -173,18 +173,21 @@ test(nbc_TrainMatrix_Negative_NumClasses, fail) :- reset_Model_NoTrain, trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], -2, 0). -test(nbc_TrainMatrix_Too_Short_Label, [error(_,system_error('Error'))]) :- +test(nbc_TrainMatrix_Too_Short_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :- reset_Model_NoTrain, trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1], 2, 0). -test(nbc_TrainMatrix_Too_Long_Label, [error(_,system_error('Error'))]) :- +test(nbc_TrainMatrix_Too_Long_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :- reset_Model_NoTrain, trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1,0,1], 2, 0). -test(nbc_TrainMatrix_Too_Many_Label_Classes, [error(_,system_error('Error'))]) :- +test(nbc_TrainMatrix_Too_Many_Label_Classes, [error(_,system_error('The given Labels dont fit the format [0,Numclasses-1]!'))]) :- reset_Model_NoTrain, trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,2,3], 2, 0). +test(nbc_TrainMatrix_After_InitTrain, [error(_,system_error('addition: incompatible matrix dimensions: 3x1 and 4x1'))]) :- + reset_Model_WithTrain, + trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, [0,1,0], 2, 0). %% Successful Tests @@ -192,9 +195,7 @@ test(nbc_TrainMatrix_Normal_Use) :- reset_Model_NoTrain, trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, 0). -test(nbc_TrainMatrix_After_InitTrain) :- - reset_Model_WithTrain, - trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, [0,1,0], 2, 0). + test(nbc_TrainMatrix_CSV_Input) :- reset_Model_NoTrain, @@ -224,9 +225,9 @@ test(nbc_TrainPoint_Too_Long_Point, [error(_,system_error('Error'))]) :- %% Successful Tests -test(nbc_TrainPoint_Normal_Use) :- - reset_Model_NoTrain, - trainPoint([5.1,3.5,1.4], 0). +%%test(nbc_TrainPoint_Normal_Use) :- +%% reset_Model_NoTrain, +%% trainPoint([5.1,3.5,1.4], 1). test(nbc_TrainPoint_After_InitTrain) :- reset_Model_WithTrain, -- GitLab