
Machine learning for Java developers

Gregor Roth | Sept. 18, 2017
Set up a machine learning algorithm and develop your first prediction function in Java.


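import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

// create the dataset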
List<Double[]> dataset = new ArrayList<>();
dataset.add(new Double[] { 1.0,  90.0,  8100.0 });   // feature vector of house#1
dataset.add(new Double[] { 1.0, 101.0, 10201.0 });   // feature vector of house#2
dataset.add(new Double[] { 1.0, 103.0, 10609.0 });   // ...
//...

// create the labels
List<Double> labels = new ArrayList<>();
labels.add(249.0);        // price label of house#1
labels.add(338.0);        // price label of house#2
labels.add(304.0);        // ...
//...

// scale the extended feature list
Function<Double[], Double[]> scalingFunc = FeaturesScaling.createFunction(dataset);
List<Double[]> scaledDataset = dataset.stream().map(scalingFunc).collect(Collectors.toList());

// create hypothesis function with initial thetas and train it with learning rate 0.1
LinearRegressionFunction targetFunction = new LinearRegressionFunction(new double[] { 1.0, 1.0, 1.0 });
for (int i = 0; i < 10000; i++) {
    targetFunction = Learner.train(targetFunction, scaledDataset, labels, 0.1);
}


// predict the price of a house with a size of 600 m2
Double[] scaledFeatureVector = scalingFunc.apply(new Double[] { 1.0, 600.0, 360000.0 });
double predictedPrice = targetFunction.apply(scaledFeatureVector);

As you add more and more features, you may find that the target function fits the training data better and better. But beware: if you add too many features, you could end up with a target function that overfits.
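The feature vectors in the listing above have the form { 1.0, size, size squared }. As a minimal sketch of how such extended vectors could be built for an arbitrary degree, consider the helper below. The class name PolynomialFeatures and its expand method are assumptions made for illustration; they are not part of the article's code.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of the article's code) for building extended feature vectors. */
class PolynomialFeatures {

    /**
     * Returns { 1.0, x, x^2, ..., x^degree }. The leading 1.0 is the
     * constant x0 term that every feature vector in the listing starts with.
     */
    static Double[] expand(double x, int degree) {
        if (degree < 1) {
            throw new IllegalArgumentException("degree must be >= 1");
        }
        Double[] features = new Double[degree + 1];
        double power = 1.0;
        for (int i = 0; i <= degree; i++) {
            features[i] = power;   // x^i
            power *= x;
        }
        return features;
    }

    public static void main(String[] args) {
        // Same shape as the feature vector of house#1 in the listing
        System.out.println(Arrays.toString(expand(90.0, 2)));
    }
}
```

Raising the degree adds one feature per step, which is exactly the kind of growth that makes the target function more flexible and more prone to overfitting.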

Overfitting and cross-validation

Overfitting occurs when the target function or model fits the training data too well, by capturing noise or random fluctuations in the training data. A pattern of overfitting behavior is shown in the graph on the far-right side below:

[Figure 4]

Although an overfitting model matches the training data very well, it will perform poorly on new, unseen data. There are a few ways to avoid overfitting:

  • Use a larger set of training data.
  • Use a machine learning algorithm that applies regularization.
  • Use fewer features, as shown in the middle diagram above.

If your predictive model overfits, you should remove any features that do not contribute to its accuracy. The challenge here is to find the features that contribute most meaningfully to your prediction output.
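Besides removing features, regularization (the second option in the list above) can be sketched in code. The example below shows one gradient descent step for linear regression with an L2 (ridge) penalty, which is one common form of regularization. The class RegularizedLearner, its method names, and the lambda parameter are assumptions for illustration; the article's own Learner class is not shown here and may work differently.

```java
import java.util.List;

/** Hypothetical sketch of L2 (ridge) regularization; not the article's Learner class. */
class RegularizedLearner {

    /** Linear hypothesis h(x) = theta[0]*x[0] + ... + theta[n]*x[n]. */
    static double predict(double[] thetas, Double[] features) {
        double sum = 0.0;
        for (int j = 0; j < thetas.length; j++) {
            sum += thetas[j] * features[j];
        }
        return sum;
    }

    /**
     * Performs one gradient descent step with learning rate alpha.
     * lambda controls the strength of the L2 penalty; lambda = 0 gives
     * plain, unregularized gradient descent.
     */
    static double[] trainStep(double[] thetas, List<Double[]> dataset, List<Double> labels,
                              double alpha, double lambda) {
        int m = dataset.size();
        double[] updated = new double[thetas.length];
        for (int j = 0; j < thetas.length; j++) {
            double gradient = 0.0;
            for (int i = 0; i < m; i++) {
                double error = predict(thetas, dataset.get(i)) - labels.get(i);
                gradient += error * dataset.get(i)[j];
            }
            gradient /= m;
            if (j > 0) {
                // Penalize large weights; the constant term theta[0] is left unpenalized
                gradient += (lambda / m) * thetas[j];
            }
            updated[j] = thetas[j] - alpha * gradient;
        }
        return updated;
    }
}
```

A larger lambda pulls the weights of the non-constant features toward zero, which limits how closely the target function can follow noise in the training data. A suitable value for lambda has to be found by experiment.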

As shown in the diagrams, overfitting can be identified by plotting the target function against the data. This works well for two-dimensional or three-dimensional graphs, but it becomes difficult if you use more than two domain-specific features. This is why cross-validation is often used to detect overfitting.

In cross-validation, you evaluate the trained models using an unseen validation data set after the learning process has completed. The available labeled data set is split into three parts:

  • The training data set.
  • The validation data set.
  • The test data set.

For example, 60 percent of the house example records may be used to train different variants of the target algorithm. After the learning process, half of the remaining, untouched example records (20 percent of the total) will be used to validate that the trained target algorithms work well for unseen data.
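The split described above can be sketched as follows. The class HoldOutSplit and its members are hypothetical names introduced for illustration. The sketch splits record indices rather than the records themselves, so that each feature vector stays paired with its label. It also assumes that the trained target function can be passed as a Function<Double[], Double>, which the article's listing suggests but does not confirm.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.Function;

/** Hypothetical helper for a 60/20/20 split; not part of the article's code. */
class HoldOutSplit {
    final List<Integer> training = new ArrayList<>();
    final List<Integer> validation = new ArrayList<>();
    final List<Integer> test = new ArrayList<>();

    /** Shuffles the record indices 0..recordCount-1 and splits them 60/20/20. */
    HoldOutSplit(int recordCount, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < recordCount; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));

        int trainEnd = (int) Math.round(recordCount * 0.6);
        // Half of the remaining records go to validation, the rest to test
        int validationEnd = trainEnd + (recordCount - trainEnd) / 2;
        training.addAll(indices.subList(0, trainEnd));
        validation.addAll(indices.subList(trainEnd, validationEnd));
        test.addAll(indices.subList(validationEnd, recordCount));
    }

    /** Mean squared error of the model over the records selected by indices. */
    static double meanSquaredError(Function<Double[], Double> model,
                                   List<Double[]> dataset, List<Double> labels,
                                   List<Integer> indices) {
        double sum = 0.0;
        for (int i : indices) {
            double error = model.apply(dataset.get(i)) - labels.get(i);
            sum += error * error;
        }
        return sum / indices.size();
    }
}
```

After training on the records in split.training only, you could compare meanSquaredError for split.training with the value for split.validation. A validation error that is much higher than the training error suggests that the model is overfitting.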

 
