Skip to content
Snippets Groups Projects
Commit f1ad5224 authored by Jakhes's avatar Jakhes
Browse files

Updating nbc

parent e3dd413e
No related branches found
No related tags found
No related merge requests found
...@@ -36,6 +36,13 @@ void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer da ...@@ -36,6 +36,13 @@ void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer da
{ {
// convert the Prolog array to arma::mat // convert the Prolog array to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum); mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// check if labels fit the data
if (data.n_cols != labelsArrSize)
{
raisePrologSystemExeption("The number of data points does not match the number of labels!");
return;
}
// convert the Prolog array to arma::rowvec // convert the Prolog array to arma::rowvec
Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize); Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
...@@ -186,6 +193,12 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo ...@@ -186,6 +193,12 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo
{ {
// convert the Prolog array to arma::mat // convert the Prolog array to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum); mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// check if labels fit the data
if (data.n_cols != labelsArrSize)
{
raisePrologSystemExeption("The number of data points does not match the number of labels!");
return;
}
// convert the Prolog array to arma::rowvec // convert the Prolog array to arma::rowvec
Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize); Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
...@@ -195,6 +208,11 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo ...@@ -195,6 +208,11 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo
{ {
naiveBayesClassifier.Train(data, labelsVector, numClasses, (incrementalVariance == 1)); naiveBayesClassifier.Train(data, labelsVector, numClasses, (incrementalVariance == 1));
} }
catch(const std::out_of_range& e)
{
raisePrologSystemExeption("The given Labels dont fit the format [0,Numclasses-1]!");
return;
}
catch(const std::exception& e) catch(const std::exception& e)
{ {
raisePrologSystemExeption(e.what()); raisePrologSystemExeption(e.what());
...@@ -215,6 +233,12 @@ void trainPoint(float *pointArr, SP_integer pointArrSize, ...@@ -215,6 +233,12 @@ void trainPoint(float *pointArr, SP_integer pointArrSize,
// convert the Prolog array to arma::rowvec // convert the Prolog array to arma::rowvec
rowvec pointVector = convertArrayToRowvec(pointArr, pointArrSize); rowvec pointVector = convertArrayToRowvec(pointArr, pointArrSize);
if(label < 0)
{
raisePrologSystemExeption("The given Label should be positive!");
return;
}
try try
{ {
......
...@@ -173,18 +173,21 @@ test(nbc_TrainMatrix_Negative_NumClasses, fail) :- ...@@ -173,18 +173,21 @@ test(nbc_TrainMatrix_Negative_NumClasses, fail) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], -2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], -2, 0).
test(nbc_TrainMatrix_Too_Short_Label, [error(_,system_error('Error'))]) :- test(nbc_TrainMatrix_Too_Short_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1], 2, 0).
test(nbc_TrainMatrix_Too_Long_Label, [error(_,system_error('Error'))]) :- test(nbc_TrainMatrix_Too_Long_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1,0,1], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1,0,1], 2, 0).
test(nbc_TrainMatrix_Too_Many_Label_Classes, [error(_,system_error('Error'))]) :- test(nbc_TrainMatrix_Too_Many_Label_Classes, [error(_,system_error('The given Labels dont fit the format [0,Numclasses-1]!'))]) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,2,3], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,2,3], 2, 0).
test(nbc_TrainMatrix_After_InitTrain, [error(_,system_error('addition: incompatible matrix dimensions: 3x1 and 4x1'))]) :-
reset_Model_WithTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, [0,1,0], 2, 0).
%% Successful Tests %% Successful Tests
...@@ -192,9 +195,7 @@ test(nbc_TrainMatrix_Normal_Use) :- ...@@ -192,9 +195,7 @@ test(nbc_TrainMatrix_Normal_Use) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, 0).
test(nbc_TrainMatrix_After_InitTrain) :-
reset_Model_WithTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, [0,1,0], 2, 0).
test(nbc_TrainMatrix_CSV_Input) :- test(nbc_TrainMatrix_CSV_Input) :-
reset_Model_NoTrain, reset_Model_NoTrain,
...@@ -224,9 +225,9 @@ test(nbc_TrainPoint_Too_Long_Point, [error(_,system_error('Error'))]) :- ...@@ -224,9 +225,9 @@ test(nbc_TrainPoint_Too_Long_Point, [error(_,system_error('Error'))]) :-
%% Successful Tests %% Successful Tests
test(nbc_TrainPoint_Normal_Use) :- %%test(nbc_TrainPoint_Normal_Use) :-
reset_Model_NoTrain, %% reset_Model_NoTrain,
trainPoint([5.1,3.5,1.4], 0). %% trainPoint([5.1,3.5,1.4], 1).
test(nbc_TrainPoint_After_InitTrain) :- test(nbc_TrainPoint_After_InitTrain) :-
reset_Model_WithTrain, reset_Model_WithTrain,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment