K-means clustering in Java

K-means is the most common form of “centroid” clustering. Unlike classification, clustering is an unsupervised learning method. The categories are not predetermined. Instead, the goal is to search for natural groupings in the dataset, such that the members of each group are similar to each other and different from the members of the other groups. The K represents the number of groups to find.
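Under the hood, the algorithm alternates between two steps: assign each record to its nearest centroid, then move each centroid to the mean of the records assigned to it, repeating until the assignments settle down. Here is a minimal sketch of that loop in plain Java, just to make the idea concrete; Tablesaw itself delegates the real work to the Smile machine learning library, and the names below are purely illustrative.

// A bare-bones k-means loop, for illustration only. A production implementation
// would typically pick smarter starting centroids and stop early once the
// assignments stop changing.
import java.util.Random;

public class KmeansSketch {

    // points: n records with d numeric features each; k: the number of clusters
    static int[] cluster(double[][] points, int k, int iterations) {
        Random random = new Random();
        int n = points.length;
        int d = points[0].length;

        // Start with k centroids drawn from the data at random
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = points[random.nextInt(n)].clone();
        }

        int[] assignment = new int[n];
        for (int iteration = 0; iteration < iterations; iteration++) {
            // Assignment step: attach each record to its nearest centroid
            for (int i = 0; i < n; i++) {
                double best = Double.MAX_VALUE;
                for (int c = 0; c < k; c++) {
                    double distance = 0;
                    for (int j = 0; j < d; j++) {
                        double diff = points[i][j] - centroids[c][j];
                        distance += diff * diff;
                    }
                    if (distance < best) {
                        best = distance;
                        assignment[i] = c;
                    }
                }
            }
            // Update step: move each centroid to the mean of its members
            double[][] sums = new double[k][d];
            int[] counts = new int[k];
            for (int i = 0; i < n; i++) {
                counts[assignment[i]]++;
                for (int j = 0; j < d; j++) {
                    sums[assignment[i]][j] += points[i][j];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue;  // keep an empty cluster's centroid where it is
                }
                for (int j = 0; j < d; j++) {
                    centroids[c][j] = sums[c][j] / counts[c];
                }
            }
        }
        return assignment;
    }
}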

We’ll use a well-known Scotch whiskey dataset, in which each whiskey is scored on a dozen taste attributes drawn from tasting notes; our goal is to cluster the whiskeys by taste. As always, we start by loading the data and printing its structure.

Table t = Table.createFromCsv("data/whiskey.csv");
t.structure().print();
Index Column Name Column Type
0 RowID SHORT_INT
1 Distillery CATEGORY
2 Body SHORT_INT
3 Sweetness SHORT_INT
4 Smoky SHORT_INT
5 Medicinal SHORT_INT
6 Tobacco SHORT_INT
7 Honey SHORT_INT
8 Spicy SHORT_INT
9 Winey SHORT_INT
10 Nutty SHORT_INT
11 Malty SHORT_INT
12 Fruity SHORT_INT
13 Floral SHORT_INT
14 Postcode CATEGORY
15 Latitude FLOAT
16 Longitude INTEGER

We create our model using a subset of the columns in the input table:

Kmeans model = new Kmeans(
    5,
    t.nCol(2), t.nCol(3), t.nCol(4), t.nCol(5), t.nCol(6), t.nCol(7),
    t.nCol(8), t.nCol(9), t.nCol(10), t.nCol(11), t.nCol(12), t.nCol(13)
);

Here the first argument, 5, is K, the number of clusters to form. We picked it on a whim; we’ll discuss some slightly more scientific approaches shortly. The other arguments are the numeric columns to cluster on, identified by their index in table t.
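As a variation, the same columns can be selected by name with the nCol(String) accessor used later in this post, which reads a little more clearly:

Kmeans model = new Kmeans(
    5,
    t.nCol("Body"), t.nCol("Sweetness"), t.nCol("Smoky"), t.nCol("Medicinal"),
    t.nCol("Tobacco"), t.nCol("Honey"), t.nCol("Spicy"), t.nCol("Winey"),
    t.nCol("Nutty"), t.nCol("Malty"), t.nCol("Fruity"), t.nCol("Floral")
);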

That’s all it takes to perform the clustering. To see what clusters it formed, we can use the clustered(Column column) method. Here, the column parameter is a column that identifies the records (in this case “Distillery”), and the return value is a table that maps the distilleries to their clusters.

model.clustered(t.column("Distillery")).print();
Label Cluster
Aberfeldy 0
Aberlour 0
Belvenie 0
BenNevis 0
Benriach 0
Benrinnes 0
Benromach 0
BlairAthol 0
Bowmore 0
Bruichladdich 0
Craigallechie 0
Dailuaine 0
Dalmore 0
Deanston 0
GlenGarioch 0
GlenOrd 0
Glendullan 0
Glenlivet 0
Highland Park 0
Longmorn 0
Mortlach 0
RoyalBrackla 0
RoyalLochnagar 0
Scapa 0
Strathisla 0
AnCnoc 1
ArranIsleOf 1
Auchentoshan 1
Aultmore 1
Bladnoch 1
Bunnahabhain 1
Cardhu 1
Craigganmore 1
Dalwhinnie 1
Dufftown 1
GlenElgin 1
GlenGrant 1
GlenMoray 1
GlenSpey 1
Glenallachie 1
Glenfiddich 1
Glengoyne 1
Glenkinchie 1
Glenlossie 1
Glenmorangie 1
Inchgower 1
Linkwood 1
Loch Lomond 1
Mannochmore 1
Miltonduff 1
Speyburn 1
Speyside 1
Strathmill 1
Tamdhu 1
Tamnavulin 1
Teaninich 1
Tobermory 1
Tomintoul 1
Ardmore 2
Auchroisk 2
Balblair 2
Edradour 2
GlenDeveronMacduff 2
GlenKeith 2
Glenfarclas 2
Glenrothes 2
Glenturret 2
Knochando 2
OldFettercairn 2
Tomatin 2
Tormore 2
Tullibardine 2
Balmenach 3
Glendronach 3
Macallan 3
Ardbeg 4
Caol Ila 4
Clynelish 4
GlenScotia 4
Isle of Jura 4
Lagavulin 4
Laphroig 4
Oban 4
OldPulteney 4
Springbank 4
Talisker 4

A common question in K-means clustering is “What do the clusters actually mean?” You can get a sense for this by printing the centroids for each cluster.

model.labeledCentroids().print();

All the variables are coded on a scale from 0-4. You can see in the table below that the whiskeys in cluster 1 are the smokiest, have the most body, and are among the least sweet.

Cluster  Body  Sweetness  Smoky  Medicinal  Tobacco  Honey  Spicy  Winey  Nutty  Malty  Fruity  Floral
0        2.65  2.41       1.41   0.06       0.03     1.86   1.65   1.79   1.89   2.03   2.10    1.75
1        4.0   1.5        3.75   3.75       0.5      0.25   1.5    0.75   1.0    1.5    1.0     0.0
2        1.37  2.53       1.0    0.15       0.06     1.03   1.03   0.5    1.12   1.75   2.0     2.15
3        1.84  1.94       1.94   1.05       0.15     1.0    1.47   0.68   1.47   1.68   1.21    1.31
4        3.0   1.5        3.5    2.5        1.0      0.0    2.0    0.0    1.5    1.0    1.5     0.5

Another important question is “How do we know if we have chosen a good value for K?”.

One measure of goodness is distortion: the sum of the squared distances between each data point and the centroid of its cluster. To get the distortion for our model, we use the method of the same name.

 model.distortion()

which returns 397.71774891774896.
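In other words, distortion is just the total squared distance from each record to the centroid of the cluster it was assigned to. A hand-rolled version might look like the sketch below; the points, assignment, and centroids arrays are hypothetical stand-ins for the model’s internal state, not part of the Tablesaw API.

// Distortion computed by hand: the sum of squared distances from each point to
// the centroid of its assigned cluster. The parameters are hypothetical stand-ins
// for the model's internal state.
static double distortion(double[][] points, int[] assignment, double[][] centroids) {
    double total = 0;
    for (int i = 0; i < points.length; i++) {
        double[] centroid = centroids[assignment[i]];
        for (int j = 0; j < points[i].length; j++) {
            double diff = points[i][j] - centroid[j];
            total += diff * diff;
        }
    }
    return total;
}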

To select a value for k, we can run K-means repeatedly with 2 to n clusters, plot the distortion against k, and see where adding more clusters stops paying off. This is known as the elbow method, because the chart typically looks like an arm bent at the elbow: a steep initial segment, a bend, and then a nearly flat segment. Sometimes, however, the data doesn’t produce a clear elbow, and that’s what happens here:

// Fit a model for each candidate k and record its distortion
int n = t.rowCount();
double[] kValues = new double[n - 2];
double[] distortions = new double[n - 2];

for (int k = 2; k < n; k++) {
  kValues[k - 2] = k;
  Kmeans kmeans = new Kmeans(k,
      t.nCol(2), t.nCol(3), t.nCol(4), t.nCol(5), t.nCol(6), t.nCol(7),
      t.nCol(8), t.nCol(9), t.nCol(10), t.nCol(11), t.nCol(12), t.nCol(13)
  );
  distortions[k - 2] = kmeans.distortion();
}
Scatter.show(kValues, "k", distortions, "distortion");

[Chart: distortion plotted against k, with no clear elbow]

Next week, I’ll add support for X-Means clustering, which automates the choice of how many clusters to use.

A K-Nearest Neighbors Classifier in Java

Tablesaw is a platform for data science in Java designed for ease-of-use. It builds on some great open source libraries, including Smile, the machine learning library that provides the nuts and bolts of our machine learning capability.

Here we show a Tablesaw/Java version of a tutorial written by Mike de Waard using standalone Smile and Scala. You can find the original here. First we load the data:

Table example = Table.createFromCsv("data/KNN_Example_1.csv");

Next we look at the table’s structure, using example.structure().print():

Index Column Name Column Type
0 X FLOAT
1 Y FLOAT
2 Label SHORT_INT

We want to predict the integer variable “Label” from the two floating point variables X and Y.  Label takes on the value 1 or 0, which you can see using the asSet() method:

example.shortColumn("Label").asSet();

which returns {0, 1}

To get a sense for their distribution and how they relate to X and Y, we plot X and Y, coloring the points according to whether the corresponding Label is one or zero.

Scatter.show("Example data", example.nCol("X"), example.nCol("Y"), 
    example.splitOn(example.shortColumn("Label")));

[Scatter plot of the example data, colored by Label]

As you can see, the green “ones” cluster to the upper right and the red “zeros” to the lower left.

To validate our model we split the data in half, randomly assigning rows to each half. Tablesaw can do this with the sampleSplit() method:

Table[] splits = example.sampleSplit(.5);
Table train = splits[0];
Table test = splits[1];

Next we build the model using the training data:

Knn knn = 
    Knn.learn(2, train.shortColumn(2), train.nCol("X"), train.nCol("Y"));

To see how well our model works, we produce a confusion matrix using the test dataset:

ConfusionMatrix matrix = 
    knn.predictMatrix(test.shortColumn(2), test.nCol("X"), test.nCol("Y"));

A confusion matrix summarizes the results of our classification. Correct classifications are shown on the diagonal. Our matrix for this sample looks like this:

n = 50         Actual 0   Actual 1
Predicted 0    22         4
Predicted 1    2          22

You can also get the overall accuracy of the test using the matrix’s accuracy() method:

matrix.accuracy();

which in this case is 0.88: the model classified 44 of the 50 test records correctly.

That’s all there is to creating a simple classifier. Tablesaw also supports LogisticRegression and LDA (Linear Discriminant Analysis) classifiers. Their use is nearly identical.

Play (Money)ball! Data Science in Tablesaw

Linear regression analysis has been called the “Hello World” of machine learning, because it’s both easy to understand and very powerful. And it’s the first machine learning algorithm we’ve added to Tablesaw.

One of the best known applications of regression comes from the book Moneyball, which describes the use of data science at the Oakland A’s baseball team. My analysis is based on a lecture given in the EdX course: MITx: 15.071x The Analytics Edge.  If you’re new to data analytics, I would strongly recommend this course.

Moneyball is a great example of how to apply data science to solve a business problem. For the A’s, the business problem was “How do we make the playoffs?” They broke that problem down into simpler problems that could be solved with data science. Their approach is summarized in the diagram below:

[Diagram: the A’s approach to making the playoffs]

In baseball, you make the playoffs by winning more games than your rivals. The number of games the rivals win is out of your control, so the A’s looked instead at how many wins it took historically to make the playoffs. They decided that 95 wins would give them a strong chance. Here’s how we might check that assumption in Tablesaw.

// Get the data
Table baseball = Table.createFromCsv("data/baseball.csv");

// filter to the data available at the start of the 2002 season
Table moneyball = baseball.selectWhere(column("year").isLessThan(2002));

We can check the assumption visually by plotting wins per year in a way that separates the teams who make the playoffs from those who don’t. This code produces the chart below:

NumericColumn wins = moneyball.nCol("W");
NumericColumn year = moneyball.nCol("Year");
Column playoffs = moneyball.column("Playoffs");
XchartScatter.show("Regular season wins by year", wins, year, moneyball.splitOn(playoffs));

[Chart: regular season wins by year, with playoff teams highlighted]

Teams that made the playoffs are shown as greenish points. If you draw a vertical line at 95 wins, you can see that it’s very likely that a team that wins over 95 games will make the playoffs. So far so good.

Unfortunately, you can’t directly control the number of games you win. We need to go deeper. At the next level, we hypothesize that the number of wins can be predicted by the number of Runs Scored during the season, combined with the number of Runs Allowed.

To check this assumption we compute Run Difference as Runs Scored – Runs Allowed:

IntColumn runDifference = moneyball.shortColumn("RS").subtract(moneyball.shortColumn("RA"));
moneyball.addColumn(runDifference);
runDifference.setName("RD");

Now let’s see whether Run Difference is correlated with Wins. We use a scatter plot again:

Scatter.show("RD x Wins", moneyball.numericColumn("RD"), moneyball.numericColumn("W"));

[Scatter plot: Run Difference vs. Wins]

Our plot shows a strong linear relationship between the two. Let’s create our first predictive model using linear regression, with runDifference as the explanatory variable.

LeastSquares winsModel = LeastSquares.train(wins, runDifference);

If we print our “winsModel”, it produces the output below:

Linear Model:

Residuals:
     Min       1Q    Median       3Q       Max
-14.2662  -2.6511    0.1282   2.9365   11.6570

Coefficients:
             Estimate   Std.Error   t value  Pr(>|t|)
 (Intercept)  80.8814      0.1312  616.6747    0.0000 ***
 RD            0.1058      0.0013   81.5536    0.0000 ***
---------------------------------------------------------------------
Significance codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 3.9391 on 900 degrees of freedom
Multiple R-squared: 0.8808, Adjusted R-squared: 0.8807
F-statistic: 6650.9926 on 1 and 900 DF, p-value: 0.000

If you’re new to regression, here are some take-aways from the output:

  • The R-squared of 0.88 means that roughly 88% of the variance in Wins is explained by the Run Difference variable. The rest may be explained by other variables, or it may be chance.
  • The estimate for the Intercept is the expected number of wins when Run Difference is zero. With a 162-game season, a team that allows as many runs as it scores should win about half its games, or 81, and that is what we see.
  • The estimate of about 0.1 for the RD variable suggests that an increase of 10 in Run Difference should produce about one additional win over the course of the season (0.1058 × 10 ≈ 1.06).

Of course, this model is not simply descriptive. We can use it to make predictions. In the code below, we predict how many games we will win if we score 135 more runs than our opponents.  To do this, we pass an array of doubles, one for each explanatory variable in our model, to the predict() method. In this case, there’s just one variable – run difference.

// Use a fresh array name so we don't collide with the runDifference column above
double[] predictors = new double[1];
predictors[0] = 135;
double expectedWins = winsModel.predict(predictors);

In this case, expectedWins is 95.2 when we outscore opponents by 135 runs.

It’s time to go deeper again and see how we can model Runs Scored and Runs Allowed. The approach the A’s took was to model Runs Scored using team On-base percent (OBP) and team Slugging Average (SLG). In Tablesaw, we write:

LeastSquares runsScored2 = 
    LeastSquares.train(moneyball.nCol("RS"), moneyball.nCol("OBP"), moneyball.nCol("SLG"));

Once again the first parameter takes the Tablesaw column containing the values we want to predict (Runs Scored). The next two parameters take the explanatory variables, OBP and SLG. Printing the model produces the output below:

    Linear Model:

    Residuals:
               Min          1Q      Median          3Q         Max
          -70.8379    -17.1810     -1.0917     16.7812     90.0358

    Coefficients:
                Estimate        Std. Error        t value        Pr(>|t|)
    (Intercept)  -804.6271           18.9208       -42.5261          0.0000 ***
    OBP          2737.7682           90.6846        30.1900          0.0000 ***
    SLG          1584.9085           42.1556        37.5966          0.0000 ***
    ---------------------------------------------------------------------
    Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

    Residual standard error: 24.7900 on 899 degrees of freedom
    Multiple R-squared: 0.9296,    Adjusted R-squared: 0.9294
    F-statistic: 5933.7256 on 2 and 899 DF,  p-value: 0.000

Again we have a model with excellent explanatory power: the R-squared is 0.92. Now we’ll check the model visually to see if it violates any assumptions. First, the residuals should be normally distributed; we can verify that with a histogram:

Histogram.show(runsScored2.residuals());

[Histogram of the residuals]

This looks great. It’s also important to plot the predicted (or fitted) values against the residuals. This shows whether the model fits some values better than others, which affects how much we can trust its predictions. Ideally, we see a cloud of random dots centered around zero on the y axis.

Our Scatter class can create this plot directly from the model:

Scatter.showFittedVsResidual(runsScored2);

[Scatter plot: fitted values vs. residuals]

Again, the plot looks good.

Let’s review. We’ve created a model of baseball that predicts entry into the playoffs based on batting stats, with the variables influencing each other like this:

SLG & OBP -> Runs Scored -> Run Difference -> Regular Season Wins
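To make the chain concrete, here’s a sketch of how the models we’ve trained could be composed. The team OBP, SLG, and runs-allowed figures are hypothetical values chosen for illustration, not numbers from the dataset.

// Hypothetical team stats for an upcoming season (illustrative values only)
double teamOBP = 0.339;
double teamSLG = 0.430;
double projectedRunsAllowed = 650;  // assumed for now; a runs-allowed model is sketched below

// OBP & SLG -> Runs Scored (the model expects OBP first, then SLG, as trained above)
double projectedRunsScored = runsScored2.predict(new double[] {teamOBP, teamSLG});

// Runs Scored - Runs Allowed -> Run Difference -> Regular Season Wins
double projectedRunDifference = projectedRunsScored - projectedRunsAllowed;
double projectedWins = winsModel.predict(new double[] {projectedRunDifference});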

Of course, we haven’t modeled the Runs Allowed side of Run Difference. We could use pitching and fielding stats to do this, but the A’s cleverly reused the same two variables (OBP and SLG), this time measuring how their opponents performed against them. We could do the same, since these data are encoded in the dataset as OOBP and OSLG.
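Here’s a sketch of what that model could look like, following the same pattern as the Runs Scored model. It assumes the OOBP and OSLG columns load as numeric; if they’re missing for some seasons, those rows would need to be filtered out before training.

// Model Runs Allowed from how opposing teams got on base and slugged against us,
// mirroring the Runs Scored model. Assumes OOBP and OSLG parse as numeric columns;
// rows with missing values may need to be filtered out first.
LeastSquares runsAllowed = 
    LeastSquares.train(moneyball.nCol("RA"), moneyball.nCol("OOBP"), moneyball.nCol("OSLG"));

Its prediction could then replace the assumed runs-allowed figure in the chain sketched above.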

We used regression to build predictive models, and visualizations to check our assumptions and validate our models.

However, we still haven’t shown how this knowledge can be applied. That step involves predicting how the current team will perform using historical data, and considering the available talent to see who could raise the team’s OBP or SLG numbers, or reduce the opponents’ values of the same stats. The team can then build scenarios that weigh various trades against their expected salary costs. Taking it to that level requires individual player stats that aren’t in our dataset, so we’ll leave it here, but I hope this post has shown how Tablesaw makes regression analysis in Java easy and practical.