Skip to content
Snippets Groups Projects
Commit 14d5ae8d authored by Jakhes's avatar Jakhes
Browse files

Adding lsh tests

parent f89ae89f
No related branches found
No related tags found
No related merge requests found
:- module(lsh, [initModel/7,
:- module(lsh, [initAndTrainModel/7,
computeRecall/5,
searchWithQuery/9,
searchNoQuery/7,
train/7]).
searchNoQuery/7]).
%% requirements of library(struct)
:- load_files(library(str_decl),
......@@ -34,11 +33,16 @@
%% --Description--
%% Initiatzes the model and trains it.
%%
initModel(ReferenceList, ReferenceRows, NumProj, NumTables, HashWidth, SecondHashSize, BucketSize) :-
initAndTrainModel(ReferenceList, ReferenceRows, NumProj, NumTables, HashWidth, SecondHashSize, BucketSize) :-
NumProj >= 0,
NumTables >= 0,
HashWidth >= 0.0,
SecondHashSize > 0,
BucketSize > 0,
convert_list_to_float_array(ReferenceList, ReferenceRows, array(Xsize, Xrownum, X)),
initModelI(X, Xsize, Xrownum, NumProj, NumTables, HashWidth, SecondHashSize, BucketSize).
initAndTrainModelI(X, Xsize, Xrownum, NumProj, NumTables, HashWidth, SecondHashSize, BucketSize).
foreign(initModel, c, initModelI(+pointer(float_array), +integer, +integer,
foreign(initAndTrainModel, c, initAndTrainModelI(+pointer(float_array), +integer, +integer,
+integer, +integer, +float32, +integer, +integer)).
......@@ -77,6 +81,9 @@ foreign(computeRecall, c, computeRecallI(+pointer(float_array), +integer, +integ
%% The matrices will be set to the size of n columns by k rows, where n is the number of points in the query dataset and k is the number of neighbors being searched for.
%%
searchWithQuery(QueryList, QueryRows, K, ResultingNeighborsList, YCols, DistancesList, ZCols, NumTablesToSearch, T) :-
K > 0,
NumTablesToSearch >= 0,
T >= 0,
convert_list_to_float_array(QueryList, QueryRows, array(Xsize, Xrows, X)),
searchWithQueryI(X, Xsize, Xrows, K, Y, YCols, YRows, Z, ZCols, ZRows, NumTablesToSearch, T),
convert_float_array_to_2d_list(Y, YCols, YRows, ResultingNeighborsList),
......@@ -103,6 +110,9 @@ foreign(searchWithQuery, c, searchWithQueryI( +pointer(float_array), +integer,
%% The matrices will be set to the size of n columns by k rows, where n is the number of points in the query dataset and k is the number of neighbors being searched for.
%%
searchNoQuery(K, ResultingNeighborsList, YCols, DistancesList, ZCols, NumTablesToSearch, T) :-
K > 0,
NumTablesToSearch >= 0,
T >= 0,
searchNoQueryI(K, Y, YCols, YRows, Z, ZCols, ZRows, NumTablesToSearch, T),
convert_float_array_to_2d_list(Y, YCols, YRows, ResultingNeighborsList),
convert_float_array_to_2d_list(Z, ZCols, ZRows, DistancesList).
......@@ -113,32 +123,11 @@ foreign(searchNoQuery, c, searchNoQueryI( +integer,
+integer, +integer)).
%% --Input--
%% mat referenceSet,
%% int numProj => 10-50,
%% int numTables => 10-20,
%% float32 hashWidth => 0.0,
%% int secondHashSize => 99901,
%% int bucketSize => 500
%%
%% --Output--
%%
%% --Description--
%% Train the LSH model on the given dataset.
%%
train(ReferenceList, ReferenceRows, NumProj, NumTables, HashWidth, SecondHashSize, BucketSize) :-
convert_list_to_float_array(ReferenceList, ReferenceRows, array(Xsize, Xrownum, X)),
trainI(X, Xsize, Xrownum, NumProj, NumTables, HashWidth, SecondHashSize, BucketSize).
foreign(train, c, trainI(+pointer(float_array), +integer, +integer,
+integer, +integer, +float32, +integer, +integer)).
%% Defines the functions that get connected from main.cpp
foreign_resource(lsh, [ initModel,
foreign_resource(lsh, [ initAndTrainModel,
computeRecall,
searchWithQuery,
searchNoQuery,
train]).
searchNoQuery]).
:- load_foreign_resource(lsh).
......@@ -7,37 +7,146 @@
:- use_module('../../helper_files/helper.pl').
reset_Model :-
initModel(1,0,50,0.0001).
initAndTrainModel([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 20, 10, 0.0, 99901, 500).
%%
%% TESTING predicate predicate/10
%% TESTING predicate initAndTrainModel/7
%%
:- begin_tests(predicate).
:- begin_tests(initAndTrainModel).
%% Failure Tests
test(testDescription, [error(domain_error('expectation' , culprit), _)]) :-
reset_Model_No_Train(perceptron),
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,0,0,0], 2, culprit, 50, 0.0001, _).
test(testDescription2, [error(_,system_error('The values of the Label have to start at 0 and be >= 0 and < the given numClass!'))]) :-
reset_Model_No_Train(perceptron),
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,1,0,2], 2, perceptron, 50, 0.0001, _).
test(lsh_InitAndTrainModel_Negative_NumProj, fail) :-
initAndTrainModel([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, -20, 10, 0.0, 99901, 500).
test(lsh_InitAndTrainModel_Negative_NumTables, fail) :-
initAndTrainModel([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 20, -10, 0.0, 99901, 500).
test(lsh_InitAndTrainModel_Negative_HashWidth, fail) :-
initAndTrainModel([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 20, 10, -1.0, 99901, 500).
test(lsh_InitAndTrainModel_Negative_SecondHashSize, fail) :-
initAndTrainModel([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 20, 10, 0.0, -99901, 500).
test(lsh_InitAndTrainModel_Negative_BucketSize, fail) :-
initAndTrainModel([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 20, 10, 0.0, 99901, -500).
%% Successful Tests
test(testDescription3, [true(Error =:= 1)]) :-
reset_Model_No_Train(perceptron),
train([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [0,0,0,0], 2, perceptron, 50, 0.0001, Error).
test(lsh_InitAndTrainModel_Normal_Use) :-
initAndTrainModel([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 20, 10, 0.0, 99901, 500).
test(testDescription4, [true(Error =:= 0.9797958971132711)]) :-
reset_Model_No_Train(perceptron),
test(lsh_InitAndTrainModel_CSV_Input) :-
open('src/data_csv/iris2.csv', read, File),
take_csv_row(File, skipFirstRow,10, Data),
train(Data, 4, [0,1,0,1,1,0,1,1,1,0], 2, perceptron, 50, 0.0001, Error).
train(Data, 4, 25, 15, 1.5, 99901, 200).
:- end_tests(initAndTrainModel).
%%
%% TESTING predicate computeRecall/5
%%
:- begin_tests(computeRecall).
%% Failure Tests
test(lsh_ComputeRecall_Wrong_Dimensions, [error(_, system_error('Error'))]) :-
reset_Model,
computeRecall([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, _).
%% Successful Tests
test(lsh_ComputeRecall_Normal_Use) :-
reset_Model,
computeRecall([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, [5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, _).
:- end_tests(computeRecall).
%%
%% TESTING predicate searchWithQuery/10
%%
:- begin_tests(searchWithQuery).
%% Failure Tests
test(lsh_SearchWithQuery_Wrong_Dimensions, [error(_, system_error('Error'))]) :-
reset_Model,
searchWithQuery([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 4, 3, _, _, _, _, 0, 0).
test(lsh_SearchWithQuery_Too_High_K, [error(_, system_error('Error'))]) :-
reset_Model,
searchWithQuery([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 10, _, _, _, _, 0, 0).
test(lsh_SearchWithQuery_Negative_K, fail) :-
reset_Model,
searchWithQuery([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, -3, _, _, _, _, 0, 0).
test(lsh_SearchWithQuery_Negative_NumTablesToSearch, fail) :-
reset_Model,
searchWithQuery([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 3, _, _, _, _, -10, 0).
test(lsh_SearchWithQuery_Negative_T, fail) :-
reset_Model,
searchWithQuery([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 3, _, _, _, _, 0, -10).
%% Successful Tests
test(lsh_SearchWithQuery_Normal_Use) :-
reset_Model,
searchWithQuery([5.1,3.5,1.4,4.9,3.0,1.4,4.7,3.2,1.3,4.6,3.1,1.5], 3, 3, NeighborResultsList, _, DistancesList, _, 0, 0),
print('\nNeighborResults: '),
print(NeighborResultsList),
print('\nDistances: '),
print(DistancesList).
:- end_tests(searchWithQuery).
%%
%% TESTING predicate searchNoQuery/10
%%
:- begin_tests(searchNoQuery).
%% Failure Tests
test(lsh_SearchNoQuery_Too_High_K, [error(_, system_error('Error'))]) :-
reset_Model,
searchNoQuery(10, _, _, _, _, 0, 0).
test(lsh_SearchNoQuery_Negative_K, fail) :-
reset_Model,
searchNoQuery(-3, _, _, _, _, 0, 0).
test(lsh_SearchNoQuery_Negative_NumTablesToSearch, fail) :-
reset_Model,
searchNoQuery(3, _, _, _, _, -10, 0).
test(lsh_SearchNoQuery_Negative_T, fail) :-
reset_Model,
searchNoQuery(3, _, _, _, _, 0, -10).
%% Successful Tests
test(lsh_SearchNoQuery_Normal_Use) :-
reset_Model,
searchNoQuery(3, NeighborResultsList, _, DistancesList, _, 0, 0),
print('\nNeighborResults: '),
print(NeighborResultsList),
print('\nDistances: '),
print(DistancesList).
:- end_tests(predicate).
:- end_tests(searchNoQuery).
run_lsh_tests :-
run_tests.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment