// softmax_regression.cpp
// SICStus Prolog foreign-resource glue for mlpack's SoftmaxRegression model.
#include <sicstus/sicstus.h>
/* ex_glue.h is generated by splfr from the foreign/[2,3] facts.
Always include the glue header in your foreign resource code.
*/
#include "softmax_regression_glue.h"
// C++ standard library
#include <stdexcept>
#include <mlpack/core.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
// including helper functions for converting between arma structures and arrays
#include "../../helper_files/helper.hpp"
// some of the most used namespaces
using namespace arma;
using namespace mlpack;
using namespace std;
using namespace mlpack::regression;
// Global Variable of the SoftmaxRegression object so it can be accessed from all functions
SoftmaxRegression softmaxRegression;
// input: const size_t inputSize = 0,
// const size_t numClasses = 0,
// const bool fitIntercept = false//
// output:
// description:
// Initializes the softmax_regression model without training.
//
void initModelNoTrain(SP_integer inputSize, SP_integer numClasses,
                      SP_integer fitIntercept)
{
    // Prolog passes booleans as integers; treat exactly 1 as "true".
    const bool useIntercept = (fitIntercept == 1);
    // Replace the global model with a freshly constructed, untrained instance.
    softmaxRegression = SoftmaxRegression(inputSize, numClasses, useIntercept);
}
// input: const arma::mat & data,
// const arma::Row< size_t > & labels,
// const size_t numClasses,
// const double lambda = 0.0001,
// const bool fitIntercept = false,
// OptimizerType optimizer = OptimizerType()
// output:
// description:
// Initializes the softmax_regression model and trains it.
//
void initModelWithTrain(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRowNum,
float *labelsArr, SP_integer labelsArrSize,
SP_integer numClasses, double lambda,
SP_integer fitIntercept)
{
// convert the Prolog array to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// convert the Prolog array to arma::rowvec
Row< size_t > labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
try
{
softmaxRegression = SoftmaxRegression(data, labelsVector, numClasses, lambda, (fitIntercept == 1));
}
catch(const std::exception& e)
{
raisePrologSystemExeption(e.what());
}
}
// input: const VecType & point
//
// output: size_t predicted label of point
// description:
// Classify the given point.
//
SP_integer classifyPoint(float *pointArr, SP_integer pointArrSize)
{
    // Convert the flat Prolog array into the column vector Classify() expects.
    vec query = conv_to<vec>::from(convertArrayToRowvec(pointArr, pointArrSize));
    try
    {
        // Returns the predicted class label for the single point.
        return softmaxRegression.Classify(query);
    }
    catch (const std::exception &e)
    {
        raisePrologSystemExeption(e.what());
        return 0;
    }
}
// input: const arma::mat & dataset,
// arma::Row< size_t > & labels <-,
// arma::mat & probabilities <-
// output:
// description:
// Classify the given points, returning class probabilities and predicted class label for each point.
//
void classifyMatrix(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRowNum, float **labelsArr, SP_integer *labelsArrSize, float **probsMatArr, SP_integer *probsMatColNum, SP_integer *probsMatRowNum)
{
// convert the Prolog arrays to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// create the ReturnVector
Row< size_t > labelsReturnVector;
// create the ReturnMat
mat probsReturnMat;
try
{
softmaxRegression.Classify(data, labelsReturnVector, probsReturnMat);
}
catch(const std::exception& e)
{
raisePrologSystemExeption(e.what());
return;
}
// check for nan elements
if (labelsReturnVector.has_nan())
{
raisePrologSystemExeption("Labels return Vector contains nan!");
return;
}
if (probsReturnMat.has_nan())
{
raisePrologSystemExeption("Probabilities return Matrix contains nan!");
return;
}
// return the Vector
returnVectorInformation(labelsReturnVector, labelsArr, labelsArrSize);
// return the Matrix
returnMatrixInformation(probsReturnMat, probsMatArr, probsMatColNum, probsMatRowNum);
}
// input: const arma::mat & testData,
// const arma::Row< size_t > & labels
//
// output: double accuracy
// description:
// Computes accuracy of the learned model given the feature data and the labels associated with each data point.
//
double computeAccuracy(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRowNum,
float *labelsArr, SP_integer labelsArrSize)
{
// convert the Prolog array to arma::mat
mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
// convert the Prolog array to arma::rowvec
Row< size_t > labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
try
{
return softmaxRegression.ComputeAccuracy(data, labelsVector);
}
catch(const std::out_of_range& e)
{
raisePrologSystemExeption("The Labels Vector has the wrong Dimension!");
return 0.0;
}
catch(const std::exception& e)
{
raisePrologSystemExeption(e.what());
return 0.0;
}
}
// input:
// output: size_t
// description:
// Gets the features size of the training data.
//
SP_integer featureSize()
{
    // Expose the model's input feature dimensionality to Prolog.
    const size_t numFeatures = softmaxRegression.FeatureSize();
    return numFeatures;
}
// input:
// output: arma::mat&
// description:
// Get the model parameters.
//
void parameters(float **parametersMatArr, SP_integer *parametersMatColNum, SP_integer *parametersMatRowNum)
{
// create the ReturnMat
mat parametersReturnMat = softmaxRegression.Parameters();
// return the Matrix
returnMatrixInformation(parametersReturnMat, parametersMatArr, parametersMatColNum, parametersMatRowNum);
}
// input: const arma::mat & data,
// const arma::Row< size_t > & labels,
// const size_t numClasses,
// OptimizerType optimizer = OptimizerType()
//
// output: double objective value of final point
// description:
// Trains the softmax regression model with the given training data.
//
// Trains the global softmax regression model on the given data/labels and
// returns the objective value of the final point. On error a Prolog system
// exception is raised and 0.0 is returned.
double train(float *dataMatArr, SP_integer dataMatSize, SP_integer dataMatRowNum,
             float *labelsArr, SP_integer labelsArrSize, SP_integer numClasses)
{
    // convert the Prolog array to arma::mat
    mat data = convertArrayToMat(dataMatArr, dataMatSize, dataMatRowNum);
    // convert the Prolog array to arma::rowvec
    Row<size_t> labelsVector = convertArrayToVec(labelsArr, labelsArrSize);
    try
    {
        return softmaxRegression.Train(data, labelsVector, numClasses);
    }
    catch (const std::out_of_range &)
    {
        // Consistent with computeAccuracy(): a data/labels dimension mismatch
        // gets a clear message instead of the raw what() text.
        raisePrologSystemExeption("The Labels Vector has the wrong Dimension!");
        return 0.0;
    }
    catch (const std::exception &e)
    {
        raisePrologSystemExeption(e.what());
        return 0.0;
    }
}