K-means clustering in Java

K-means is the most common form of “centroid” clustering. Unlike classification, clustering is an unsupervised learning method. The categories are not predetermined. Instead, the goal is to search for natural groupings in the dataset, such that the members of each group are similar to each other and different from the members of the other groups. The K represents the number of groups to find.
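Before diving into the library, it's worth seeing how little machinery the algorithm itself needs. Below is a minimal plain-Java sketch (not Tablesaw code) of one iteration of the standard procedure, known as Lloyd's algorithm: assign each point to its nearest centroid, then move each centroid to the mean of its assigned points. Real implementations repeat these two steps until the assignments stop changing. The toy points and initial centroids here are made up for illustration.

```java
// One K-means (Lloyd's) iteration on toy 2-D points: assignment step,
// then update step. Plain Java, independent of any library.
public class LloydSketch {

    public static void main(String[] args) {
        double[][] points = {{0, 0}, {0, 1}, {10, 10}, {10, 11}};
        double[][] centroids = {{0, 0}, {10, 10}};   // initial guesses, k = 2

        // Assignment step: each point joins the cluster of its nearest centroid
        int[] assignment = new int[points.length];
        for (int i = 0; i < points.length; i++) {
            double best = Double.MAX_VALUE;
            for (int c = 0; c < centroids.length; c++) {
                double d = sqDist(points[i], centroids[c]);
                if (d < best) { best = d; assignment[i] = c; }
            }
        }

        // Update step: each centroid moves to the mean of its assigned points
        for (int c = 0; c < centroids.length; c++) {
            double sumX = 0, sumY = 0;
            int count = 0;
            for (int i = 0; i < points.length; i++) {
                if (assignment[i] == c) {
                    sumX += points[i][0];
                    sumY += points[i][1];
                    count++;
                }
            }
            if (count > 0) {
                centroids[c][0] = sumX / count;
                centroids[c][1] = sumY / count;
            }
        }
        System.out.println(centroids[0][1]); // mean y of {0, 1} -> 0.5
        System.out.println(centroids[1][1]); // mean y of {10, 11} -> 10.5
    }

    static double sqDist(double[] a, double[] b) {
        double dx = a[0] - b[0], dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }
}
```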

We’ll use a well-known Scotch whiskey dataset, in which each whiskey is scored on a set of taste attributes drawn from tasting notes; we’ll cluster the whiskeys according to those scores. As always, we start by loading the data and printing its structure.

Table t = Table.createFromCsv("data/whiskey.csv");
t.structure().print();
Index  Column Name  Column Type
0      RowID        SHORT_INT
1      Distillery   CATEGORY
2      Body         SHORT_INT
3      Sweetness    SHORT_INT
4      Smoky        SHORT_INT
5      Medicinal    SHORT_INT
6      Tobacco      SHORT_INT
7      Honey        SHORT_INT
8      Spicy        SHORT_INT
9      Winey        SHORT_INT
10     Nutty        SHORT_INT
11     Malty        SHORT_INT
12     Fruity       SHORT_INT
13     Floral       SHORT_INT
14     Postcode     CATEGORY
15     Latitude     FLOAT
16     Longitude    INTEGER

We create our model using a subset of the columns in the input table:

Kmeans model = new Kmeans(
    5,
    t.nCol(2), t.nCol(3), t.nCol(4), t.nCol(5), t.nCol(6), t.nCol(7),
    t.nCol(8), t.nCol(9), t.nCol(10), t.nCol(11), t.nCol(12), t.nCol(13)
);

Here the first argument, 5, is K, the number of clusters to form. We picked it on a whim; slightly more scientific approaches are discussed below. The other arguments are numeric columns, identified by their index in table t.

That’s all it takes to perform the clustering. To see what clusters were formed, we can use the clustered(Column column) method. The column parameter identifies the records (in this case, “Distillery”), and the return value is a table that maps each distillery to its cluster.

model.clustered(t.column("Distillery")).print();
Label Cluster
Aberfeldy 0
Aberlour 0
Belvenie 0
BenNevis 0
Benriach 0
Benrinnes 0
Benromach 0
BlairAthol 0
Bowmore 0
Bruichladdich 0
Craigallechie 0
Dailuaine 0
Dalmore 0
Deanston 0
GlenGarioch 0
GlenOrd 0
Glendullan 0
Glenlivet 0
Highland Park 0
Longmorn 0
Mortlach 0
RoyalBrackla 0
RoyalLochnagar 0
Scapa 0
Strathisla 0
AnCnoc 1
ArranIsleOf 1
Auchentoshan 1
Aultmore 1
Bladnoch 1
Bunnahabhain 1
Cardhu 1
Craigganmore 1
Dalwhinnie 1
Dufftown 1
GlenElgin 1
GlenGrant 1
GlenMoray 1
GlenSpey 1
Glenallachie 1
Glenfiddich 1
Glengoyne 1
Glenkinchie 1
Glenlossie 1
Glenmorangie 1
Inchgower 1
Linkwood 1
Loch Lomond 1
Mannochmore 1
Miltonduff 1
Speyburn 1
Speyside 1
Strathmill 1
Tamdhu 1
Tamnavulin 1
Teaninich 1
Tobermory 1
Tomintoul 1
Ardmore 2
Auchroisk 2
Balblair 2
Edradour 2
GlenDeveronMacduff 2
GlenKeith 2
Glenfarclas 2
Glenrothes 2
Glenturret 2
Knochando 2
OldFettercairn 2
Tomatin 2
Tormore 2
Tullibardine 2
Balmenach 3
Glendronach 3
Macallan 3
Ardbeg 4
Caol Ila 4
Clynelish 4
GlenScotia 4
Isle of Jura 4
Lagavulin 4
Laphroig 4
Oban 4
OldPulteney 4
Springbank 4
Talisker 4

A common question in K-means clustering is “What do the clusters actually mean?” You can get a sense for this by printing the centroids for each cluster.

model.labeledCentroids().print();

All the variables are coded on a scale from 0-4. You can see in the table below that the whiskeys in cluster 1 are the smokiest, have the most body, and are among the least sweet.

Cluster  Body  Sweetness  Smoky  Medicinal  Tobacco  Honey  Spicy  Winey  Nutty  Malty  Fruity  Floral
0        2.65  2.41       1.41   0.06       0.03     1.86   1.65   1.79   1.89   2.03   2.10    1.75
1        4.0   1.5        3.75   3.75       0.5      0.25   1.5    0.75   1.0    1.5    1.0     0.0
2        1.37  2.53       1.0    0.15       0.06     1.03   1.03   0.5    1.12   1.75   2.0     2.15
3        1.84  1.94       1.94   1.05       0.15     1.0    1.47   0.68   1.47   1.68   1.21    1.31
4        3.0   1.5        3.5    2.5        1.0      0.0    2.0    0.0    1.5    1.0    1.5     0.5
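If you wanted to answer a question like “which cluster is smokiest?” programmatically rather than by scanning the table, it only takes a loop over one attribute’s centroid values. A small plain-Java sketch, with the smoky array copied from the Smoky column above:

```java
// Find the cluster whose centroid has the highest Smoky score,
// using the centroid values printed above (clusters 0 through 4).
public class SmokiestCluster {

    public static void main(String[] args) {
        double[] smoky = {1.41, 3.75, 1.0, 1.94, 3.5};
        int best = 0;
        for (int c = 1; c < smoky.length; c++) {
            if (smoky[c] > smoky[best]) {
                best = c;
            }
        }
        System.out.println(best); // cluster 1 has the highest Smoky centroid
    }
}
```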

Another important question is “How do we know if we have chosen a good value for K?”.

One measure of goodness is distortion: the sum of the squared distances between each data point and the centroid of its assigned cluster. To get the distortion for our model, we use the method of the same name.

model.distortion()

which returns 397.71774891774896.
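The definition is easy to check by hand. Here is a plain-Java sketch (not Tablesaw API) that computes distortion directly from points, centroids, and assignments; the toy data is made up:

```java
// Distortion: sum of squared Euclidean distances from each point
// to the centroid of its assigned cluster.
public class Distortion {

    public static void main(String[] args) {
        double[][] points = {{0, 0}, {0, 1}, {10, 10}, {10, 11}};
        double[][] centroids = {{0, 0.5}, {10, 10.5}};
        int[] assignment = {0, 0, 1, 1};

        double distortion = 0;
        for (int i = 0; i < points.length; i++) {
            double[] c = centroids[assignment[i]];
            double dx = points[i][0] - c[0];
            double dy = points[i][1] - c[1];
            distortion += dx * dx + dy * dy;
        }
        // Each point is 0.5 away from its centroid: 4 * 0.25 = 1.0
        System.out.println(distortion);
    }
}
```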

To select a value for k, we can run K-means repeatedly with 2 to n clusters, plot the distortion against k, and see where adding more clusters has little impact. This is known as the elbow method, because the chart typically looks like an arm bent at the elbow: an initial steep segment, then a bend, then a nearly flat segment. Sometimes, however, the data doesn’t produce an elbow, and that’s what happens here:

int n = t.rowCount();
double[] kValues = new double[n - 2];
double[] distortions = new double[n - 2];

for (int k = 2; k < n; k++) {
  kValues[k - 2] = k;
  Kmeans kmeans = new Kmeans(k,
      t.nCol(2), t.nCol(3), t.nCol(4), t.nCol(5), t.nCol(6), t.nCol(7),
      t.nCol(8), t.nCol(9), t.nCol(10), t.nCol(11), t.nCol(12), t.nCol(13)
  );
  distortions[k - 2] = kmeans.distortion();
}
Scatter.show(kValues, "k",  distortions, "distortion");
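If you’d rather not eyeball the chart, one rough heuristic is to stop at the first k whose relative improvement over k - 1 falls below a threshold. A plain-Java sketch on made-up distortion values (the 10% cutoff is arbitrary):

```java
// Crude "elbow" heuristic: choose the smallest k whose relative drop in
// distortion versus k-1 is below a threshold. Distortion values are made up.
public class ElbowPick {

    public static void main(String[] args) {
        int[] ks =         {2,   3,   4,   5,   6,   7};
        double[] distort = {900, 600, 450, 420, 405, 398};

        int chosen = ks[ks.length - 1];   // fall back to the largest k tried
        for (int i = 1; i < ks.length; i++) {
            double improvement = (distort[i - 1] - distort[i]) / distort[i - 1];
            if (improvement < 0.10) {     // under 10% better: diminishing returns
                chosen = ks[i - 1];
                break;
            }
        }
        // 900 -> 600 (33%), 600 -> 450 (25%), 450 -> 420 (6.7%): stop at k = 4
        System.out.println(chosen);
    }
}
```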

[Scatter plot of distortion vs. k: the curve falls off gradually, with no clear elbow.]

Next week, I’ll add support for X-Means clustering, which automates the choice of how many clusters to use.
