Skip to content
Snippets Groups Projects
Commit f1ad5224 authored by Jakhes's avatar Jakhes
Browse files

Updating nbc

parent e3dd413e
Branches
No related tags found
No related merge requests found
...@@ -36,6 +36,13 @@ void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer da ...@@ -36,6 +36,13 @@ void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer da
{ {
// convert the Prolog array to arma::mat // convert the Prolog array to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum); mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// check if labels fit the data
if (data.n_cols != labelsArrSize)
{
raisePrologSystemExeption("The number of data points does not match the number of labels!");
return;
}
// convert the Prolog array to arma::rowvec // convert the Prolog array to arma::rowvec
Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize); Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
...@@ -186,6 +193,12 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo ...@@ -186,6 +193,12 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo
{ {
// convert the Prolog array to arma::mat // convert the Prolog array to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum); mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// check if labels fit the data
if (data.n_cols != labelsArrSize)
{
raisePrologSystemExeption("The number of data points does not match the number of labels!");
return;
}
// convert the Prolog array to arma::rowvec // convert the Prolog array to arma::rowvec
Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize); Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
...@@ -195,6 +208,11 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo ...@@ -195,6 +208,11 @@ void trainMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRo
{ {
naiveBayesClassifier.Train(data, labelsVector, numClasses, (incrementalVariance == 1)); naiveBayesClassifier.Train(data, labelsVector, numClasses, (incrementalVariance == 1));
} }
catch(const std::out_of_range& e)
{
raisePrologSystemExeption("The given Labels dont fit the format [0,Numclasses-1]!");
return;
}
catch(const std::exception& e) catch(const std::exception& e)
{ {
raisePrologSystemExeption(e.what()); raisePrologSystemExeption(e.what());
...@@ -215,6 +233,12 @@ void trainPoint(float *pointArr, SP_integer pointArrSize, ...@@ -215,6 +233,12 @@ void trainPoint(float *pointArr, SP_integer pointArrSize,
// convert the Prolog array to arma::rowvec // convert the Prolog array to arma::rowvec
rowvec pointVector = convertArrayToRowvec(pointArr, pointArrSize); rowvec pointVector = convertArrayToRowvec(pointArr, pointArrSize);
if(label < 0)
{
raisePrologSystemExeption("The given Label should be positive!");
return;
}
try try
{ {
......
...@@ -173,18 +173,21 @@ test(nbc_TrainMatrix_Negative_NumClasses, fail) :- ...@@ -173,18 +173,21 @@ test(nbc_TrainMatrix_Negative_NumClasses, fail) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], -2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], -2, 0).
test(nbc_TrainMatrix_Too_Short_Label, [error(_,system_error('Error'))]) :- test(nbc_TrainMatrix_Too_Short_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1], 2, 0).
test(nbc_TrainMatrix_Too_Long_Label, [error(_,system_error('Error'))]) :- test(nbc_TrainMatrix_Too_Long_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1,0,1], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1,0,1], 2, 0).
test(nbc_TrainMatrix_Too_Many_Label_Classes, [error(_,system_error('Error'))]) :- test(nbc_TrainMatrix_Too_Many_Label_Classes, [error(_,system_error('The given Labels dont fit the format [0,Numclasses-1]!'))]) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,2,3], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,2,3], 2, 0).
test(nbc_TrainMatrix_After_InitTrain, [error(_,system_error('addition: incompatible matrix dimensions: 3x1 and 4x1'))]) :-
reset_Model_WithTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, [0,1,0], 2, 0).
%% Successful Tests %% Successful Tests
...@@ -192,9 +195,7 @@ test(nbc_TrainMatrix_Normal_Use) :- ...@@ -192,9 +195,7 @@ test(nbc_TrainMatrix_Normal_Use) :-
reset_Model_NoTrain, reset_Model_NoTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, 0). trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, 0).
test(nbc_TrainMatrix_After_InitTrain) :-
reset_Model_WithTrain,
trainMatrix([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, [0,1,0], 2, 0).
test(nbc_TrainMatrix_CSV_Input) :- test(nbc_TrainMatrix_CSV_Input) :-
reset_Model_NoTrain, reset_Model_NoTrain,
...@@ -224,9 +225,9 @@ test(nbc_TrainPoint_Too_Long_Point, [error(_,system_error('Error'))]) :- ...@@ -224,9 +225,9 @@ test(nbc_TrainPoint_Too_Long_Point, [error(_,system_error('Error'))]) :-
%% Successful Tests %% Successful Tests
test(nbc_TrainPoint_Normal_Use) :- %%test(nbc_TrainPoint_Normal_Use) :-
reset_Model_NoTrain, %% reset_Model_NoTrain,
trainPoint([5.1,3.5,1.4], 0). %% trainPoint([5.1,3.5,1.4], 1).
test(nbc_TrainPoint_After_InitTrain) :- test(nbc_TrainPoint_After_InitTrain) :-
reset_Model_WithTrain, reset_Model_WithTrain,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment