A K-Nearest Neighbors Classifier in Java

Tablesaw is a platform for data science in Java designed for ease of use. It builds on some great open source libraries, including Smile, the machine learning library that provides the nuts and bolts of our machine learning capability.

Here we show a Tablesaw/Java version of a tutorial written by Mike de Waard using standalone Smile and Scala. You can find the original here. First we load the data:

Table example = Table.createFromCsv("data/KNN_Example_1.csv");

Next we output the table’s structure using example.structure().
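A minimal way to print what it returns is sketched below; the println() wrapper is our addition, and how the returned structure table renders as text may vary by Tablesaw version.

// Print the schema (column index, name, and type) of the example table.
// Assumption: the Table returned by structure() renders its rows via toString();
// if your Tablesaw version does not, use its table-printing method instead.
System.out.println(example.structure());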

Index   Column Name   Column Type
0       X             FLOAT
1       Y             FLOAT
2       Label         SHORT_INT

We want to predict the integer variable “Label” from the two floating point variables X and Y. Label takes on the values 0 and 1, which you can see using the asSet() method:

example.shortColumn("Label").asSet();

which returns {0, 1}

To get a sense of how the labels are distributed and how they relate to X and Y, we plot X and Y, coloring each point according to whether its Label is one or zero.

Scatter.show("Example data", example.nCol("X"), example.nCol("Y"), 
    example.splitOn(example.shortColumn("Label")));

[Figure: scatter plot of the example data, with points colored by Label]

As you can see, the green “ones” cluster to the upper right and the red “zeros” to the lower left.

To validate our model, we split the data in half, assigning rows to each half at random. Tablesaw can do this with the sampleSplit() method:

Table[] splits = example.sampleSplit(.5);
Table train = splits[0];
Table test = splits[1];

Next we build the model using the training data. The first argument to learn() is k, the number of nearest neighbors to consult (here 2); the short column at index 2 (“Label”) is the target, and X and Y are the predictors:

Knn knn = 
    Knn.learn(2, train.shortColumn(2), train.nCol("X"), train.nCol("Y"));

To see how well our model works, we produce a confusion matrix using the test dataset:

ConfusionMatrix matrix = 
    knn.predictMatrix(test.shortColumn(2), test.nCol("X"), test.nCol("Y"));

A confusion matrix summarizes the results of our classification. Correct classifications are shown on the diagonal. Our matrix for this sample looks like this:

n = 50        Actual 0   Actual 1
Predicted 0   22         4
Predicted 1   2          22

You can also get the overall accuracy of the test using the matrix’s accuracy() method:

matrix.accuracy();

which in this case is (22 + 22) / 50 = 0.88.

That’s all there is to creating a simple classifier. Tablesaw also supports LogisticRegression and LDA (Linear Discriminant Analysis) classifiers. Their use is nearly identical.
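For instance, an LDA version of the steps above might look like the sketch below. The Lda class name and the learn() and predictMatrix() signatures are assumptions modeled on the Knn API shown earlier, so check your Tablesaw version for the exact names.

// Sketch only: assumes an Lda class whose API mirrors Knn above.
// Train a linear discriminant analysis model on the training half,
// then evaluate it against the held-out test half.
Lda lda = Lda.learn(train.shortColumn(2), train.nCol("X"), train.nCol("Y"));
ConfusionMatrix ldaMatrix =
    lda.predictMatrix(test.shortColumn(2), test.nCol("X"), test.nCol("Y"));
ldaMatrix.accuracy();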
