Skip to content
Snippets Groups Projects
Commit d9adbafd authored by Jakhes's avatar Jakhes
Browse files

Finishing perceptron tests

parent 2eb17d9d
No related branches found
No related tags found
No related merge requests found
......@@ -54,6 +54,12 @@ void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer da
{
// convert the Prolog array to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// check if labels and weights fit the data
if (data.n_cols != labelsArrSize)
{
raisePrologSystemExeption("The number of data points does not match the number of labels!");
return;
}
// convert the Prolog array to arma::rowvec
Row< size_t > labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
......@@ -63,6 +69,11 @@ void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer da
{
perceptronGlobal = Perceptron(data, labelsVector, numClasses, maxIterations);
}
catch(const std::out_of_range& e)
{
raisePrologSystemExeption("The given Labels dont fit the format [0,Numclasses-1]!");
return;
}
catch(const std::exception& e)
{
raisePrologSystemExeption(e.what());
......@@ -99,7 +110,8 @@ void classify(float *testMatArr, SP_integer testMatSize, SP_integer testMatRowNu
mat test = convertArrayToMat(testMatArr, testMatSize, testMatRowNum);
// create the ReturnVector
Row< size_t > predictLabelsReturnVector;
// need to be initialized with the amount of test cols or it will cause in classify a index out of bounds error
Row< size_t > predictLabelsReturnVector(test.n_cols);
try
......@@ -133,6 +145,17 @@ void train(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRowNum,
{
// convert the Prolog arrays to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// check if labels and weights fit the data
if (data.n_cols != labelsArrSize)
{
raisePrologSystemExeption("The number of data points does not match the number of labels!");
return;
}
if (data.n_cols != instanceWeightsArrSize)
{
raisePrologSystemExeption("The number of data points does not match the number of weights!");
return;
}
// convert the Prolog array
Row< size_t > labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
......@@ -144,6 +167,11 @@ void train(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRowNum,
{
perceptronGlobal.Train(data, labelsVector, numClasses, instanceWeightsVector);
}
catch(const std::out_of_range& e)
{
raisePrologSystemExeption("The given Labels dont fit the format [0,Numclasses-1]!");
return;
}
catch(const std::exception& e)
{
raisePrologSystemExeption(e.what());
......
......@@ -37,7 +37,7 @@ test(perceptron_InitModelNoTrain_Normal_Use) :-
initModelNoTrain(2, 3, 1000).
test(perceptron_InitModelNoTrain_Alternative_Input) :-
initModelNoTrain(0, 0, 1000).
initModelNoTrain(0, 1, 1000).
:- end_tests(initModelNoTrain).
......@@ -58,13 +58,13 @@ test(perceptron_InitModelWithTrain_Negative_MaxIterations, fail) :-
initModelWithTrain([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, -1000).
test(random_forest_InitModelWithTrainNoWeights_Too_Short_Label, [error(_,system_error('Error'))]) :-
test(random_forest_InitModelWithTrainNoWeights_Too_Short_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
initModelWithTrain([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1], 2, 1000).
test(random_forest_InitModelWithTrainNoWeights_Too_Long_Label, [error(_,system_error('Error'))]) :-
test(random_forest_InitModelWithTrainNoWeights_Too_Long_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
initModelWithTrain([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1,0,1], 2, 1000).
test(random_forest_InitModelWithTrainNoWeights_Too_Many_Label_Classes, [error(_,system_error('Error'))]) :-
test(random_forest_InitModelWithTrainNoWeights_Too_Many_Label_Classes, [error(_,system_error('The given Labels dont fit the format [0,Numclasses-1]!'))]) :-
initModelWithTrain([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,2,3], 2, 1000).
......@@ -89,9 +89,12 @@ test(perceptron_InitModelWithTrain_CSV_Input) :-
%% Failure Tests
test(perceptron_Biases_Before_Train, [error(_,system_error('Error'))]) :-
%% Doesnt cause an Error
test(perceptron_Biases_Before_Train) :-
reset_Model_NoTrain,
biases(_).
biases(Biases),
print('\nBiases: '),
print(Biases).
%% Successful Tests
......@@ -113,21 +116,21 @@ test(perceptron_Biases_AfterTrain) :-
%% Failure Tests
test(perceptron_Classify_Before_Train, [error(_,system_error('Error'))]) :-
reset_Model_NoTrain,
classify([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, _).
test(perceptron_Classify_Before_Train_Wrong_Dims, [error(_,system_error('Error'))]) :-
test(perceptron_Classify_Before_Train_Wrong_Dims, [error(_,system_error('matrix multiplication: incompatible matrix dimensions: 2x3 and 4x1'))]) :-
reset_Model_NoTrain,
classify([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, _).
test(perceptron_Classify_Different_Dims_Than_Train) :-
test(perceptron_Classify_Different_Dims_Than_Train, [error(_,system_error('matrix multiplication: incompatible matrix dimensions: 2x3 and 4x1'))]) :-
reset_Model_WithTrain,
classify([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, _).
%% Successful Tests
test(perceptron_Classify_Before_Train) :-
reset_Model_NoTrain,
classify([1.0,2.0,3.0], 3, _).
test(perceptron_Classify_Normal_Use) :-
reset_Model_WithTrain,
classify([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, PredictList),
......@@ -149,36 +152,33 @@ test(perceptron_Train_Negaitve_NumClasses, fail) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], -2, [0.1,0.2,0.3,0.4]).
test(perceptron_Train_Too_Small_Data_Dims, [error(_,system_error('Error'))]) :-
%% Seems to overide the dimensionality from reset_Model_NoTrain
test(perceptron_Train_Too_Small_Data_Dims) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, [0,1,0], 2, [0.1,0.2,0.3]).
test(perceptron_Train_Too_Short_Label, [error(_,system_error('Error'))]) :-
test(perceptron_Train_Too_Short_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1], 2, [0.1,0.2,0.3,0.4]).
test(perceptron_Train_Too_Long_Label, [error(_,system_error('Error'))]) :-
test(perceptron_Train_Too_Long_Label, [error(_,system_error('The number of data points does not match the number of labels!'))]) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1,0,1], 2, [0.1,0.2,0.3,0.4]).
test(perceptron_Train_Too_Many_Label_Classes, [error(_,system_error('Error'))]) :-
test(perceptron_Train_Too_Many_Label_Classes, [error(_,system_error('The given Labels dont fit the format [0,Numclasses-1]!'))]) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,2,3], 2, [0.1,0.2,0.3,0.4]).
test(perceptron_Train_Too_Short_Label, [error(_,system_error('Error'))]) :-
test(perceptron_Train_Too_Short_Weights, [error(_,system_error('The number of data points does not match the number of weights!'))]) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, [0.1,0.2]).
test(perceptron_Train_Too_Long_Label, [error(_,system_error('Error'))]) :-
test(perceptron_Train_Too_Long_Weights, [error(_,system_error('The number of data points does not match the number of weights!'))]) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, [0.1,0.2,0.3,0.4,0.6,0.7]).
test(perceptron_Train_Too_Many_Label_Classes, [error(_,system_error('Error'))]) :-
reset_Model_NoTrain,
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,1], 2, [-0.1,0.2,1.3,2.4]).
%% Successful Tests
......@@ -198,9 +198,12 @@ test(perceptron_Train_Normal_Use) :-
%% Failure Tests
test(perceptron_Weights_Before_Train, [error(_,system_error('Error'))]) :-
%% Doesnt cause an error
test(perceptron_Weights_Before_Train) :-
reset_Model_NoTrain,
weights(_, _).
weights(Weights, _),
print('\nWeights: '),
print(Weights).
%% Successful Tests
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment