
Machine learning for Java developers

Gregor Roth | Sept. 18, 2017
Set up a machine learning algorithm and develop your first prediction function in Java.

Typically, the best-fitting target algorithm is then selected. The other, untouched half of the example data is used to calculate error metrics for the final, selected model. While I won't introduce them here, there are other variations of this technique, such as k-fold cross-validation.
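The holdout idea can be sketched in plain Java, without any framework. This is a minimal sketch under a few assumptions: the examples are stored as parallel `sizes`/`prices` arrays, a model is any function from a size to a predicted price, and root mean squared error (RMSE) stands in for "error metrics". The class `HoldoutSplit`, its methods, and the fixed seed are hypothetical names chosen for illustration; frameworks such as Weka provide their own utilities for splitting and evaluating data.

```java
import java.util.Random;
import java.util.function.DoubleUnaryOperator;

/**
 * Hypothetical helper: shuffles example indices for a holdout split and
 * computes the root mean squared error (RMSE) of a model on a slice of them.
 */
final class HoldoutSplit {

    private HoldoutSplit() {
    }

    /** Returns the indices 0..n-1 in random order (Fisher-Yates shuffle, seeded for reproducibility). */
    static int[] shuffledIndices(int n, long seed) {
        Random random = new Random(seed);
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = i;
        }
        for (int i = n - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        return indices;
    }

    /** RMSE of the model over the examples referenced by indices[from] .. indices[to - 1]. */
    static double rmse(DoubleUnaryOperator model, double[] sizes, double[] prices,
                       int[] indices, int from, int to) {
        double sumOfSquaredErrors = 0.0;
        for (int k = from; k < to; k++) {
            int i = indices[k];
            double error = model.applyAsDouble(sizes[i]) - prices[i];
            sumOfSquaredErrors += error * error;
        }
        return Math.sqrt(sumOfSquaredErrors / (to - from));
    }
}
```

With `n` examples, the indices in `[0, n/2)` would be used for training and for choosing between candidate algorithms, and `rmse(finalModel, sizes, prices, indices, n / 2, n)` would be computed once, on the untouched half, for the selected model only.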

 

Machine learning tools and frameworks: Weka

As you've seen, developing and testing a target function requires well-tuned configuration parameters, such as the proper learning rate or iteration count. The example code I've shown reflects a very small set of the possible configuration parameters, and the examples have been simplified to keep the code readable. In practice, you will likely rely on machine learning frameworks, libraries, and tools.

Most frameworks or libraries implement an extensive collection of machine learning algorithms. Additionally, they provide convenient high-level APIs to train, validate, and process data models. Weka is one of the most popular frameworks for the JVM.

Weka provides a Java library for programmatic usage, as well as a graphical workbench to train and validate data models. In the code below, the Weka library is used to create a training data set, which includes features and a label. The setClassIndex() method is used to mark the label column. In Weka, the label is defined as a class:


import java.util.ArrayList;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

// define the feature and label attributes
ArrayList<Attribute> attributes = new ArrayList<>();
Attribute sizeAttribute = new Attribute("sizeFeature");
attributes.add(sizeAttribute);
Attribute squaredSizeAttribute = new Attribute("squaredSizeFeature");
attributes.add(squaredSizeAttribute);
Attribute priceAttribute = new Attribute("priceLabel");
attributes.add(priceAttribute);

// create the data set, sized for 5000 examples, and mark the last attribute as the label (class)
Instances trainingDataset = new Instances("trainData", attributes, 5000);
trainingDataset.setClassIndex(trainingDataset.numAttributes() - 1);

// fill it with examples, one DenseInstance per example
Instance instance = new DenseInstance(3);
instance.setValue(sizeAttribute, 90.0);
instance.setValue(squaredSizeAttribute, Math.pow(90.0, 2));
instance.setValue(priceAttribute, 249.0);
trainingDataset.add(instance);

// reuse the variable for the next example instead of redeclaring it
instance = new DenseInstance(3);
instance.setValue(sizeAttribute, 101.0);
...

The data set, or Instances object, can also be stored in and loaded from a file. Weka uses the ARFF (Attribute-Relation File Format), which is also supported by the graphical Weka workbench. This data set is used to train the target function, which Weka calls a classifier.
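ARFF is a plain-text format: a `@relation` line names the data set, each `@attribute` line declares one column (here all numeric), and the rows follow the `@data` line as comma-separated values. The hypothetical helper `ArffSketch.toArff` below renders the house-price examples in that layout, purely to show what such a file looks like. It assumes the relation name contains no spaces and hard-codes the three attributes from the training code. In practice you would let Weka write the file, for example with its `ArffSaver` converter, or by printing an `Instances` object, whose `toString()` returns ARFF text.

```java
/** Hypothetical helper that writes size/price examples in ARFF text format. */
final class ArffSketch {

    private ArffSketch() {
    }

    /** Renders the examples as ARFF; assumes relation contains no spaces (no quoting is done). */
    static String toArff(String relation, double[] sizes, double[] prices) {
        if (sizes.length != prices.length) {
            throw new IllegalArgumentException("sizes and prices must have the same length");
        }
        StringBuilder arff = new StringBuilder();
        arff.append("@relation ").append(relation).append("\n\n");
        // one numeric attribute per column; the last one is the label (the "class" in Weka)
        arff.append("@attribute sizeFeature numeric\n");
        arff.append("@attribute squaredSizeFeature numeric\n");
        arff.append("@attribute priceLabel numeric\n\n");
        arff.append("@data\n");
        for (int i = 0; i < sizes.length; i++) {
            // the squared-size feature is derived from the raw size, as in the training code
            arff.append(sizes[i]).append(',')
                .append(sizes[i] * sizes[i]).append(',')
                .append(prices[i]).append('\n');
        }
        return arff.toString();
    }
}
```

For the first example from the training code (size 90.0, price 249.0), the data row reads `90.0,8100.0,249.0`.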

Recall that in order to train a target function, you first have to choose a machine learning algorithm. In the code below, an instance of the LinearRegression classifier is created. The classifier is trained by calling its buildClassifier() method, which tunes the theta parameters based on the training data to find the best-fitting model. Using Weka, you do not have to worry about setting a learning rate or iteration count. Weka also performs feature scaling internally.


import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
Classifier targetFunction = new LinearRegression();
targetFunction.buildClassifier(trainingDataset);
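Weka hides how the theta parameters are found. For linear regression, one way to fit them without any learning rate or iteration count is to solve the least-squares normal equations, θ = (XᵀX)⁻¹Xᵀy, directly. The plain-Java sketch below does this for the hypothesis θ0 + θ1·size + θ2·size², using Gaussian elimination on the 3×3 system. The class `NormalEquationSketch` is hypothetical, and this is not Weka's code: Weka's LinearRegression has its own details (it offers attribute selection and a ridge parameter, for example), so treat it only as an illustration of the idea.

```java
/** Hypothetical sketch: fits theta for h(x) = theta0 + theta1*x + theta2*x^2 by least squares. */
final class NormalEquationSketch {

    private NormalEquationSketch() {
    }

    /** Builds X^T X and X^T y from the examples and solves for theta (needs >= 3 distinct sizes). */
    static double[] fit(double[] sizes, double[] prices) {
        int p = 3;
        double[][] xtx = new double[p][p];
        double[] xty = new double[p];
        for (int i = 0; i < sizes.length; i++) {
            double[] row = {1.0, sizes[i], sizes[i] * sizes[i]}; // bias, size, squared size
            for (int r = 0; r < p; r++) {
                xty[r] += row[r] * prices[i];
                for (int c = 0; c < p; c++) {
                    xtx[r][c] += row[r] * row[c];
                }
            }
        }
        return solve(xtx, xty);
    }

    /** Solves a * x = b by Gaussian elimination with partial pivoting; modifies a and b in place. */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            // move the row with the largest pivot candidate into place for numerical stability
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) {
                    pivot = r;
                }
            }
            double[] tmpRow = a[col];
            a[col] = a[pivot];
            a[pivot] = tmpRow;
            double tmp = b[col];
            b[col] = b[pivot];
            b[pivot] = tmp;
            // eliminate the entries below the pivot
            for (int r = col + 1; r < n; r++) {
                double factor = a[r][col] / a[col][col];
                for (int c = col; c < n; c++) {
                    a[r][c] -= factor * a[col][c];
                }
                b[r] -= factor * b[col];
            }
        }
        // back substitution on the upper-triangular system
        double[] x = new double[n];
        for (int r = n - 1; r >= 0; r--) {
            double sum = b[r];
            for (int c = r + 1; c < n; c++) {
                sum -= a[r][c] * x[c];
            }
            x[r] = sum / a[r][r];
        }
        return x;
    }
}
```

With realistic house sizes, the squared-size column is orders of magnitude larger than the other columns, which makes XᵀX poorly conditioned; that is one reason libraries scale features before fitting.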

Once trained, the target function can be used to predict the price of a house, as shown below:


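// the prediction set uses the same attributes (and label position) as the training set
Instances unlabeledInstances = new Instances("predictionset", attributes, 1);
unlabeledInstances.setClassIndex(unlabeledInstances.numAttributes() - 1);
// the label (priceLabel) is never set, so it stays missing
Instance unlabeled = new DenseInstance(3);
unlabeled.setValue(sizeAttribute, 1330.0);
unlabeled.setValue(squaredSizeAttribute, Math.pow(1330.0, 2));
unlabeledInstances.add(unlabeled);

// predict the price of the first (and only) unlabeled example
double prediction = targetFunction.classifyInstance(unlabeledInstances.get(0));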
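Conceptually, a trained linear model's prediction is the dot product of its theta parameters with the example's feature vector. The minimal sketch below, with a hypothetical `LinearHypothesis` class and caller-supplied theta values, shows that computation for the bias, size, and squared-size features. It is not how Weka stores or applies its coefficients internally; it only illustrates what a call such as classifyInstance() amounts to for this kind of model.

```java
/** Hypothetical sketch of a linear hypothesis h(x) = theta0 + theta1*size + theta2*size^2. */
final class LinearHypothesis {

    private final double[] theta;

    LinearHypothesis(double... theta) {
        if (theta.length != 3) {
            throw new IllegalArgumentException("expected theta0, theta1 and theta2");
        }
        this.theta = theta.clone();
    }

    /** Derives the same features used for training (size and size squared) and applies theta. */
    double predict(double size) {
        double[] features = {1.0, size, Math.pow(size, 2)};
        double prediction = 0.0;
        for (int j = 0; j < features.length; j++) {
            prediction += theta[j] * features[j];
        }
        return prediction;
    }
}
```

The point mirrored from the Weka code is that an unlabeled example must be transformed exactly like the training examples: a model trained on size and size² needs both values at prediction time, too.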

Weka provides an Evaluation class to validate the trained classifier or model. In the code below, a dedicated validation data set is used to avoid biased results. Measures such as the cost or error rate are printed to the console. Typically, evaluation results are used to compare models that have been trained using different machine learning algorithms, or different variants of the same algorithm:
