Machine learning for Java developers

Gregor Roth | Sept. 18, 2017
Set up a machine learning algorithm and develop your first prediction function in Java.

public static LinearRegressionFunction train(LinearRegressionFunction targetFunction,
                                             List<Double[]> dataset,
                                             List<Double> labels,
                                             double alpha) {
   int m = dataset.size();
   double[] thetaVector = targetFunction.getThetas();
   double[] newThetaVector = new double[thetaVector.length];

   // compute the new theta of each element of the theta array
   for (int j = 0; j < thetaVector.length; j++) {
      // sum up the error gap * feature over all training examples
      double sumErrors = 0;
      for (int i = 0; i < m; i++) {
         Double[] featureVector = dataset.get(i);
         double error = targetFunction.apply(featureVector) - labels.get(i);
         sumErrors += error * featureVector[j];
      }

      // compute the new theta value
      double gradient = (1.0 / m) * sumErrors;
      newThetaVector[j] = thetaVector[j] - alpha * gradient;
   }

   return new LinearRegressionFunction(newThetaVector);
}

To verify that the cost keeps decreasing, you can evaluate the cost function J(θ) after each training step. The cost must drop with every iteration. If it doesn't, the learning rate alpha is too large and the algorithm overshoots the minimum. In that case gradient descent fails to converge.
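
The check described above can be sketched roughly as follows. This is an illustration, not the article's own code: the nested LinearRegressionFunction is a minimal stand-in whose shape is assumed, the cost function is assumed to be half the mean squared error, and the class name CostMonitor, the helper costDecreasesEveryStep, and the four-point sample dataset are all made up for this example.

```java
import java.util.Arrays;
import java.util.List;

final class CostMonitor {

    // Made-up sample data (not from the article): bias 1.0 plus one feature, y = 1 + 2x.
    static final List<Double[]> DATASET = Arrays.asList(
            new Double[] {1.0, 1.0}, new Double[] {1.0, 2.0},
            new Double[] {1.0, 3.0}, new Double[] {1.0, 4.0});
    static final List<Double> LABELS = Arrays.asList(3.0, 5.0, 7.0, 9.0);

    /** Minimal stand-in for LinearRegressionFunction; its shape is an assumption. */
    static final class LinearRegressionFunction {
        private final double[] thetas;

        LinearRegressionFunction(double[] thetas) {
            this.thetas = Arrays.copyOf(thetas, thetas.length);
        }

        double apply(Double[] featureVector) {
            double prediction = 0;
            for (int j = 0; j < thetas.length; j++) {
                prediction += thetas[j] * featureVector[j];
            }
            return prediction;
        }

        double[] getThetas() {
            return Arrays.copyOf(thetas, thetas.length);
        }
    }

    /** Cost J(theta), assumed here to be half the mean squared error. */
    static double cost(LinearRegressionFunction f, List<Double[]> dataset, List<Double> labels) {
        double sumSquaredErrors = 0;
        for (int i = 0; i < dataset.size(); i++) {
            double error = f.apply(dataset.get(i)) - labels.get(i);
            sumSquaredErrors += error * error;
        }
        return sumSquaredErrors / (2.0 * dataset.size());
    }

    /** One gradient descent step, following the train method shown earlier. */
    static LinearRegressionFunction train(LinearRegressionFunction f, List<Double[]> dataset,
                                          List<Double> labels, double alpha) {
        int m = dataset.size();
        double[] thetas = f.getThetas();
        double[] newThetas = new double[thetas.length];
        for (int j = 0; j < thetas.length; j++) {
            double sumErrors = 0;
            for (int i = 0; i < m; i++) {
                Double[] featureVector = dataset.get(i);
                sumErrors += (f.apply(featureVector) - labels.get(i)) * featureVector[j];
            }
            newThetas[j] = thetas[j] - alpha * (1.0 / m) * sumErrors;
        }
        return new LinearRegressionFunction(newThetas);
    }

    /** Returns true only if the cost dropped after every single training step. */
    static boolean costDecreasesEveryStep(double alpha, int iterations) {
        LinearRegressionFunction f = new LinearRegressionFunction(new double[] {1.0, 1.0});
        double previousCost = cost(f, DATASET, LABELS);
        for (int n = 0; n < iterations; n++) {
            f = train(f, DATASET, LABELS, alpha);
            double currentCost = cost(f, DATASET, LABELS);
            if (currentCost >= previousCost) {
                return false; // cost rose or stalled: alpha is probably too large
            }
            previousCost = currentCost;
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println("alpha = 0.05 -> " + costDecreasesEveryStep(0.05, 100));
        System.out.println("alpha = 0.5  -> " + costDecreasesEveryStep(0.5, 100));
    }
}
```

On this sample data the small learning rate passes the check, while the large one already raises the cost on the first step. In practice you would log the cost values rather than only return a flag, so you can see how quickly the cost falls.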

The diagram below shows the target function with the newly computed theta parameters, starting from an initial theta vector of { 1.0, 1.0 }. The left column shows the prediction graph after 50 iterations, the middle column after 200 iterations, and the right column after 1,000 iterations. As you can see, the cost decreases with each iteration as the new target function fits the data better and better. After 500 to 600 iterations the theta parameters no longer change significantly and the cost reaches a stable plateau. From this point on, the accuracy of the target function no longer improves significantly.
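
Instead of fixing the number of iterations in advance, training can stop once the theta parameters reach such a plateau. The sketch below is one possible way to do that and is not taken from the article: it uses plain arrays instead of LinearRegressionFunction to stay short, and the class PlateauDetector, the Result holder, the tolerance parameter, and the sample data are all hypothetical.

```java
final class PlateauDetector {

    // Made-up sample data: bias 1.0 plus one feature, labels follow y = 1 + 2x.
    static final double[][] FEATURES = { {1.0, 1.0}, {1.0, 2.0}, {1.0, 3.0}, {1.0, 4.0} };
    static final double[] LABELS = { 3.0, 5.0, 7.0, 9.0 };

    /** Final thetas plus the number of training steps that were executed. */
    static final class Result {
        final double[] thetas;
        final int iterations;

        Result(double[] thetas, int iterations) {
            this.thetas = thetas;
            this.iterations = iterations;
        }
    }

    /** Linear prediction: the dot product of thetas and the feature vector. */
    static double predict(double[] thetas, double[] x) {
        double prediction = 0;
        for (int j = 0; j < thetas.length; j++) {
            prediction += thetas[j] * x[j];
        }
        return prediction;
    }

    /** One gradient descent step over the whole sample dataset. */
    static double[] step(double[] thetas, double alpha) {
        int m = FEATURES.length;
        double[] next = new double[thetas.length];
        for (int j = 0; j < thetas.length; j++) {
            double sumErrors = 0;
            for (int i = 0; i < m; i++) {
                sumErrors += (predict(thetas, FEATURES[i]) - LABELS[i]) * FEATURES[i][j];
            }
            next[j] = thetas[j] - alpha * (1.0 / m) * sumErrors;
        }
        return next;
    }

    /** Trains until the largest theta change in a step drops below tolerance. */
    static Result trainUntilPlateau(double[] initialThetas, double alpha,
                                    double tolerance, int maxIterations) {
        double[] current = initialThetas.clone();
        for (int n = 1; n <= maxIterations; n++) {
            double[] next = step(current, alpha);
            double maxChange = 0;
            for (int j = 0; j < next.length; j++) {
                maxChange = Math.max(maxChange, Math.abs(next[j] - current[j]));
            }
            current = next;
            if (maxChange < tolerance) {
                return new Result(current, n); // plateau reached
            }
        }
        return new Result(current, maxIterations); // gave up without reaching a plateau
    }

    public static void main(String[] args) {
        Result r = trainUntilPlateau(new double[] {1.0, 1.0}, 0.1, 1e-6, 10_000);
        System.out.println("stopped after " + r.iterations + " iterations, thetas = "
                + java.util.Arrays.toString(r.thetas));
    }
}
```

The tolerance is a judgment call: a smaller value trains longer for a marginally better fit, while a larger value stops earlier with a rougher one.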
