Data Science for Java Developers With Tablesaw (a new article/tutorial)

Tablesaw is a gateway drug for doing data science in Java. For Java developers new to data science, it’s the easiest way to get started. A new article published on DZone explains why and provides a short tutorial to show you how. Check out the article: Data Science for Java Developers With Tablesaw.

You can find the example code here: https://github.com/jtablesaw/tablesaw/blob/master/core/src/examples/java/tech/tablesaw/Example1.java

Tablesaw welcomes Ben McCann

I’m very pleased to report that the Tablesaw project (https://github.com/jtablesaw/tablesaw) has added Ben McCann as a co-maintainer.

Ben is a former Googler (like me) and is now a Senior Staff Engineer at LinkedIn. He was also a co-founder of Connectifier, which was acquired by LinkedIn in 2016. Ben has already had a tremendous impact on the project, making contributions at every level. It’s great to have him on board.

An update

It’s been a while, so I thought an update was in order.

I’ve been a bit overwhelmed at work (a new job), so I haven’t made the anticipated progress, but I have decided to push forward to a 1.0 release with no new features.

The two main activities between now and then will be:

  1. redoing the implementation of missing-data handling, and
  2. testing and bug-fixing

I’ll also be looking at some of the interfaces to see what kinds of changes are needed.

What I won’t be doing until after 1.0 is an implementation of join logic. Today, if you want to join tables you must either write the join yourself in Java or, if your data source is an RDBMS, do the join in your query logic using Tablesaw’s database interface, and bring in the data already joined.
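As a sketch of that workaround, assuming a JDBC source and Tablesaw’s SqlResultSetReader (the connection URL, table names, and columns below are hypothetical):

try (Connection conn = DriverManager.getConnection("jdbc:h2:mem:demo");
     Statement stmt = conn.createStatement();
     ResultSet rs = stmt.executeQuery(
         "SELECT o.id, o.total, c.name "
       + "FROM orders o JOIN customers c ON o.customer_id = c.id")) {
  // The rows arrive already joined; Tablesaw just loads the result set
  Table joined = SqlResultSetReader.read(rs, "orders_with_customers");
}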

Join support will be added after the release, but doing it right is a considerable amount of real engineering work. Beyond that I have one more feature in mind for an upcoming release that I think will be pretty cool, but I’ll write more on that later.

Thanks for your support.

Larry

K-means clustering in Java

K-means is the most common form of “centroid” clustering. Unlike classification, clustering is an unsupervised learning method. The categories are not predetermined. Instead, the goal is to search for natural groupings in the dataset, such that the members of each group are similar to each other and different from the members of the other groups. The K represents the number of groups to find.
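Under the hood, most K-means implementations follow the classic two-step procedure known as Lloyd’s algorithm. Here is a minimal, library-free sketch for intuition; Tablesaw delegates the real work to the Smile library:

// Illustrative only: the classic K-means loop (Lloyd's algorithm)
static double[][] cluster(double[][] points, int k, int iterations) {
  java.util.Random random = new java.util.Random(42);
  int dims = points[0].length;

  // Initialize centroids by copying k randomly chosen points
  double[][] centroids = new double[k][];
  for (int i = 0; i < k; i++) {
    centroids[i] = points[random.nextInt(points.length)].clone();
  }

  int[] assignment = new int[points.length];
  for (int iter = 0; iter < iterations; iter++) {
    // Assignment step: attach each point to its nearest centroid
    for (int p = 0; p < points.length; p++) {
      double best = Double.MAX_VALUE;
      for (int c = 0; c < k; c++) {
        double d = squaredDistance(points[p], centroids[c]);
        if (d < best) { best = d; assignment[p] = c; }
      }
    }
    // Update step: move each centroid to the mean of its assigned points
    double[][] sums = new double[k][dims];
    int[] counts = new int[k];
    for (int p = 0; p < points.length; p++) {
      counts[assignment[p]]++;
      for (int d = 0; d < dims; d++) sums[assignment[p]][d] += points[p][d];
    }
    for (int c = 0; c < k; c++) {
      if (counts[c] == 0) continue;  // empty cluster: keep the old centroid
      for (int d = 0; d < dims; d++) centroids[c][d] = sums[c][d] / counts[c];
    }
  }
  return centroids;
}

static double squaredDistance(double[] a, double[] b) {
  double sum = 0;
  for (int i = 0; i < a.length; i++) sum += (a[i] - b[i]) * (a[i] - b[i]);
  return sum;
}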

We’ll use a well-known Scotch whiskey dataset, which is used to cluster whiskeys according to their taste, based on data collected from tasting notes. As always, we start by loading the data and printing its structure.

Table t = Table.createFromCsv("data/whiskey.csv");
t.structure().print();
Index  Column Name  Column Type
0      RowID        SHORT_INT
1      Distillery   CATEGORY
2      Body         SHORT_INT
3      Sweetness    SHORT_INT
4      Smoky        SHORT_INT
5      Medicinal    SHORT_INT
6      Tobacco      SHORT_INT
7      Honey        SHORT_INT
8      Spicy        SHORT_INT
9      Winey        SHORT_INT
10     Nutty        SHORT_INT
11     Malty        SHORT_INT
12     Fruity       SHORT_INT
13     Floral       SHORT_INT
14     Postcode     CATEGORY
15     Latitude     FLOAT
16     Longitude    INTEGER

We create our model using a subset of the columns in the input table:

Kmeans model = new Kmeans(
    5,
    t.nCol(2), t.nCol(3), t.nCol(4), t.nCol(5), t.nCol(6), t.nCol(7),
    t.nCol(8), t.nCol(9), t.nCol(10), t.nCol(11), t.nCol(12), t.nCol(13)
);

Here the argument 5 is K, the number of clusters to form. We picked it on a whim; we’ll discuss some slightly more scientific approaches shortly. The other arguments are numeric columns, identified by their index in table t.

That’s all it takes to perform the clustering. To see what clusters were formed, we can use the clustered(Column column) method. Here, the column parameter is a column that identifies the records (in this case “Distillery”), and the return value is a table that maps the distilleries to their clusters.

model.clustered(t.column("Distillery")).print();
Label Cluster
Aberfeldy 0
Aberlour 0
Belvenie 0
BenNevis 0
Benriach 0
Benrinnes 0
Benromach 0
BlairAthol 0
Bowmore 0
Bruichladdich 0
Craigallechie 0
Dailuaine 0
Dalmore 0
Deanston 0
GlenGarioch 0
GlenOrd 0
Glendullan 0
Glenlivet 0
Highland Park 0
Longmorn 0
Mortlach 0
RoyalBrackla 0
RoyalLochnagar 0
Scapa 0
Strathisla 0
AnCnoc 1
ArranIsleOf 1
Auchentoshan 1
Aultmore 1
Bladnoch 1
Bunnahabhain 1
Cardhu 1
Craigganmore 1
Dalwhinnie 1
Dufftown 1
GlenElgin 1
GlenGrant 1
GlenMoray 1
GlenSpey 1
Glenallachie 1
Glenfiddich 1
Glengoyne 1
Glenkinchie 1
Glenlossie 1
Glenmorangie 1
Inchgower 1
Linkwood 1
Loch Lomond 1
Mannochmore 1
Miltonduff 1
Speyburn 1
Speyside 1
Strathmill 1
Tamdhu 1
Tamnavulin 1
Teaninich 1
Tobermory 1
Tomintoul 1
Ardmore 2
Auchroisk 2
Balblair 2
Edradour 2
GlenDeveronMacduff 2
GlenKeith 2
Glenfarclas 2
Glenrothes 2
Glenturret 2
Knochando 2
OldFettercairn 2
Tomatin 2
Tormore 2
Tullibardine 2
Balmenach 3
Glendronach 3
Macallan 3
Ardbeg 4
Caol Ila 4
Clynelish 4
GlenScotia 4
Isle of Jura 4
Lagavulin 4
Laphroig 4
Oban 4
OldPulteney 4
Springbank 4
Talisker 4

A common question in K-means clustering is “What do the clusters actually mean?” You can get a sense for this by printing the centroids for each cluster.

model.labeledCentroids().print();

All the variables are coded on a scale from 0-4. You can see in the table below that the whiskeys in cluster 1 are the smokiest, have the most body, and are among the least sweet.

Cluster  Body  Sweetness  Smoky  Medicinal  Tobacco  Honey  Spicy  Winey  Nutty  Malty  Fruity  Floral
0        2.65  2.41       1.41   0.06       0.03     1.86   1.65   1.79   1.89   2.03   2.10    1.75
1        4.0   1.5        3.75   3.75       0.5      0.25   1.5    0.75   1.0    1.5    1.0     0.0
2        1.37  2.53       1.0    0.15       0.06     1.03   1.03   0.5    1.12   1.75   2.0     2.15
3        1.84  1.94       1.94   1.05       0.15     1.0    1.47   0.68   1.47   1.68   1.21    1.31
4        3.0   1.5        3.5    2.5        1.0      0.0    2.0    0.0    1.5    1.0    1.5     0.5

Another important question is “How do we know if we have chosen a good value for K?”

One measure of goodness is distortion: the sum of the squared distances between each data point and the centroid of its cluster. To get the distortion for our model, we use the method of the same name.

 model.distortion()

which returns 397.71774891774896.
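Conceptually, the computation is simple. Here is a sketch, reusing the squaredDistance helper from the K-means sketch above:

// Sum of squared distances from each point to its assigned centroid
static double distortion(double[][] points, double[][] centroids, int[] assignment) {
  double total = 0;
  for (int p = 0; p < points.length; p++) {
    total += squaredDistance(points[p], centroids[assignment[p]]);
  }
  return total;
}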

To select a value for k, we can run K-means repeatedly using 2 to n clusters, plot the distortion against k, and see where adding more clusters has little impact. This is known as the elbow method, because the chart typically looks like an arm bent at the elbow: an initial steep segment followed by a bend and a more horizontal segment. Sometimes, however, the data doesn’t produce an elbow, and that’s what happens here:

int n = t.rowCount();
double[] kValues = new double[n - 2];
double[] distortions = new double[n - 2];

for (int k = 2; k < n; k++) {
  kValues[k - 2] = k;
  Kmeans kmeans = new Kmeans(k,
      t.nCol(2), t.nCol(3), t.nCol(4), t.nCol(5), t.nCol(6), t.nCol(7),
      t.nCol(8), t.nCol(9), t.nCol(10), t.nCol(11), t.nCol(12), t.nCol(13)
  );
  distortions[k - 2] = kmeans.distortion();
}
Scatter.show(kValues, "k", distortions, "distortion");

[Plot: distortion vs. k]

Next week, I’ll add support for X-Means clustering, which automates the choice of how many clusters to use.


A K-Nearest Neighbors Classifier in Java

Tablesaw is a platform for data science in Java designed for ease of use. It builds on some great open source libraries, including Smile, the machine learning library that provides the nuts and bolts of our machine learning capability.

Here we show a Tablesaw/Java version of a tutorial written by Mike de Waard using standalone Smile and Scala. You can find the original here. First we load the data:

Table example = Table.createFromCsv("data/KNN_Example_1.csv");

Next, we output the table’s structure with example.structure():

Index  Column Name  Column Type
0      X            FLOAT
1      Y            FLOAT
2      Label        SHORT_INT

We want to predict the integer variable “Label” from the two floating-point variables X and Y. Label takes on the value 1 or 0, which you can see using the asSet() method:

example.shortColumn("Label").asSet();

which returns {0, 1}

To get a sense for their distribution and how they relate to X and Y, we plot X and Y, coloring the points according to whether the corresponding Label is one or zero.

Scatter.show("Example data", example.nCol("X"), example.nCol("Y"), 
    example.splitOn(example.shortColumn("Label")));

[Plot: KNN example data, colored by label]

As you can see, the green “ones” cluster to the upper right and the red “zeros” to the lower left.

To validate our model we split the data in half, randomly assigning rows to each half. Tablesaw can do this with the sampleSplit() method:

Table[] splits = example.sampleSplit(.5);
Table train = splits[0];
Table test = splits[1];
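Conceptually, a random split just shuffles the row indices and partitions them; sampleSplit handles this, and the copying, for us. A library-free sketch of the idea:

// Illustrative only: shuffle row indices and split them in half
// (uses java.util.ArrayList and java.util.Collections)
java.util.List<Integer> rows = new java.util.ArrayList<>();
for (int i = 0; i < example.rowCount(); i++) rows.add(i);
java.util.Collections.shuffle(rows);
java.util.List<Integer> trainRows = rows.subList(0, rows.size() / 2);
java.util.List<Integer> testRows = rows.subList(rows.size() / 2, rows.size());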

Next we build the model using the training data:

Knn knn = 
    Knn.learn(2, train.shortColumn(2), train.nCol("X"), train.nCol("Y"));
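For intuition, k-nearest neighbors classifies a point by majority vote among its k closest training points (here k = 2). A library-free sketch of the idea for our two-dimensional, binary-label case; the real work is done by Smile:

// Illustrative only: classify a 2-D query point by majority vote among
// its k nearest training points (binary 0/1 labels)
static int knnClassify(double[][] trainX, int[] trainY, double[] query, int k) {
  double[] dist = new double[trainX.length];
  Integer[] order = new Integer[trainX.length];
  for (int i = 0; i < trainX.length; i++) {
    double dx = trainX[i][0] - query[0];
    double dy = trainX[i][1] - query[1];
    dist[i] = dx * dx + dy * dy;   // squared Euclidean distance
    order[i] = i;
  }
  java.util.Arrays.sort(order, (a, b) -> Double.compare(dist[a], dist[b]));
  int votesForOne = 0;
  for (int i = 0; i < k; i++) votesForOne += trainY[order[i]];
  return 2 * votesForOne > k ? 1 : 0;  // majority vote; ties go to 0
}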

To see how well our model works, we produce a confusion matrix using the test dataset:

ConfusionMatrix matrix = 
    knn.predictMatrix(test.shortColumn(2), test.nCol("X"), test.nCol("Y"));

A confusion matrix summarizes the results of our classification. Correct classifications are shown on the diagonal. Our matrix for this sample looks like this:

n = 50       Actual 0  Actual 1
Predicted 0  22        4
Predicted 1  2         22
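Reading the diagonal: 22 + 22 = 44 of the 50 test rows were classified correctly, for an accuracy of 44/50 = 0.88.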

You can also get the overall accuracy of the test using the matrix’s accuracy() method:

matrix.accuracy();

which in this case is 0.88.

That’s all there is to creating a simple classifier. Tablesaw also supports LogisticRegression and LDA (Linear Discriminant Analysis) classifiers. Their use is nearly identical.


Play (Money)ball! Data Science in Tablesaw

Linear regression analysis has been called the “Hello World” of machine learning, because it’s both easy to understand and very powerful. And it’s the first machine learning algorithm we’ve added to Tablesaw.

One of the best known applications of regression comes from the book Moneyball, which describes the use of data science at the Oakland A’s baseball team. My analysis is based on a lecture given in the EdX course: MITx: 15.071x The Analytics Edge.  If you’re new to data analytics, I would strongly recommend this course.

Moneyball is a great example of how to apply data science to solve a business problem. For the A’s, the business problem was “How do we make the playoffs?” They broke that problem down into simpler problems that can be solved with data science. Their approach is summarized in the diagram below:

[Diagram: the Moneyball approach]


In baseball, you make the playoffs by winning more games than your rivals. The number of games your rivals win is out of your control, so the A’s looked instead at how many wins it took historically to make the playoffs. They decided that 95 wins would give them a strong chance. Here’s how we might check that assumption in Tablesaw.

// Get the data
Table baseball = Table.createFromCsv("data/baseball.csv");

// filter to the data available at the start of the 2002 season
Table moneyball = baseball.selectWhere(column("year").isLessThan(2002));

We can check the assumption visually by plotting wins per year in a way that separates the teams who make the playoffs from those who don’t. This code produces the chart below:

NumericColumn wins = moneyball.nCol("W");
NumericColumn year = moneyball.nCol("Year");
Column playoffs = moneyball.column("Playoffs");
XchartScatter.show("Regular season wins by year", wins, year, moneyball.splitOn(playoffs));

[Plot: regular-season wins by year, split on playoff appearance]

Teams that made the playoffs are shown as greenish points. If you draw a vertical line at 95 wins, you can see that it’s very likely that a team that wins over 95 games will make the playoffs. So far so good.

Unfortunately, you can’t directly control the number of games you win. We need to go deeper. At the next level, we hypothesize that the number of wins can be predicted by the number of Runs Scored during the season, combined with the number of Runs Allowed.

To check this assumption we compute Run Difference as Runs Scored – Runs Allowed:

IntColumn runDifference = moneyball.shortColumn("RS").subtract(moneyball.shortColumn("RA"));
moneyball.addColumn(runDifference);
runDifference.setName("RD");

Now let’s see if Run Difference is correlated with Wins. We use a scatter plot again:

Scatter.show("RD x Wins", moneyball.numericColumn("RD"), moneyball.numericColumn("W"));

[Plot: Run Difference (RD) vs. Wins]

Our plot shows a strong linear relationship between the two. Let’s create our first predictive model using linear regression, with runDifference as the explanatory variable.

LeastSquares winsModel = LeastSquares.train(wins, runDifference);

If we print our “winsModel”, it produces the output below:

Linear Model:

Residuals:
     Min       1Q    Median       3Q       Max
-14.2662  -2.6511    0.1282   2.9365   11.6570

Coefficients:
             Estimate   Std.Error   t value  Pr(>|t|)
 (Intercept)  80.8814      0.1312  616.6747    0.0000 ***
 RD            0.1058      0.0013   81.5536    0.0000 ***
---------------------------------------------------------------------
Significance codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 3.9391 on 900 degrees of freedom
Multiple R-squared: 0.8808, Adjusted R-squared: 0.8807
F-statistic: 6650.9926 on 1 and 900 DF, p-value: 0.000

If you’re new to regression, here are some takeaways from the output:

  • The R-squared of 0.88 can be interpreted to mean that roughly 88% of the variance in Wins can be explained by the Run Difference variable. The rest may be determined by some other variable, or it may be pure chance.
  • The estimate for the Intercept is the average number of wins independent of Run Difference. In baseball, we have a 162-game season, so we expect this value to be about 81, as it is.
  • The estimate of 0.1058 for the RD variable suggests that an increase of 10 in Run Difference should produce about one additional win over the course of the season (see the worked example below).
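Putting the coefficients together gives the fitted line:

wins ≈ 80.8814 + 0.1058 × RD

so a Run Difference of 135 predicts roughly 80.8814 + 0.1058 × 135 ≈ 95.2 wins. We can verify that with the model itself.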

Of course, this model is not simply descriptive. We can use it to make predictions. In the code below, we predict how many games we will win if we score 135 more runs than our opponents. To do this, we pass an array of doubles, one for each explanatory variable in our model, to the predict() method. In this case, there’s just one variable: run difference.

double[] runDifference = new double[1];
runDifference[0] = 135;
double expectedWins = winsModel.predict(runDifference);

In this case, expectedWins is 95.2 when we outscore opponents by 135 runs.

It’s time to go deeper again and see how we can model Runs Scored and Runs Allowed. The approach the A’s took was to model Runs Scored using team on-base percentage (OBP) and team slugging percentage (SLG). In Tablesaw, we write:

LeastSquares runsScored2 = 
    LeastSquares.train(moneyball.nCol("RS"), moneyball.nCol("OBP"), moneyball.nCol("SLG"));

Once again the first parameter takes a Tablesaw column containing the values we want to predict (Runs scored). The next two parameters take the explanatory variables OBP and SLG.

    Linear Model:

    Residuals:
               Min          1Q      Median          3Q         Max
          -70.8379    -17.1810     -1.0917     16.7812     90.0358

    Coefficients:
                Estimate        Std. Error        t value        Pr(>|t|)
    (Intercept)  -804.6271           18.9208       -42.5261          0.0000 ***
    OBP          2737.7682           90.6846        30.1900          0.0000 ***
    SLG          1584.9085           42.1556        37.5966          0.0000 ***
    ---------------------------------------------------------------------
    Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

    Residual standard error: 24.7900 on 899 degrees of freedom
    Multiple R-squared: 0.9296,    Adjusted R-squared: 0.9294
    F-statistic: 5933.7256 on 2 and 899 DF,  p-value: 0.000
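As before, we can make predictions by passing one value per explanatory variable to predict(). The OBP and SLG values below are hypothetical:

double[] stats = {0.339, 0.430};   // hypothetical team OBP and SLG
double expectedRuns = runsScored2.predict(stats);
// roughly -804.63 + 2737.77 * 0.339 + 1584.91 * 0.430 ≈ 805 runs scored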

Again we have a model with excellent explanatory power, this time with an R-squared of 0.93. Now we’ll check the model visually to see if it violates any assumptions. First, our residuals should be normally distributed. We can use a histogram to verify:

Histogram.show(runsScored2.residuals());

[Plot: histogram of residuals]

This looks great. It’s also important to plot the predicted (or fitted) values against the residuals. We want to see if the model fits some values better than others, which will influence whether we can trust its predictions or not. We want to see a cloud of random dots around zero on the y axis.

Our Scatter class can create this plot directly from the model:

Scatter.showFittedVsResidual(runsScored2);

[Plot: fitted values vs. residuals]

Again, the plot looks good.

Let’s review. We’ve created a model of baseball that predicts entry into the playoffs based on batting stats, with the influence of the variables flowing as:

SLG & OBP -> Runs Scored -> Run Difference -> Regular Season Wins

Of course, we haven’t modeled the Runs Allowed side of Run Difference. We could use pitching and fielding stats to do this, but the A’s cleverly reused the same two variables (SLG and OBP), this time looking at how their opponents performed against the A’s. We could do the same, as these data are encoded in the dataset as OOBP and OSLG.
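Here is a sketch of the corresponding Runs Allowed model, mirroring the runsScored2 call above (note that OOBP and OSLG may contain missing values for earlier seasons, which would need filtering first):

// Model Runs Allowed from opponents' on-base and slugging percentages
LeastSquares runsAllowed = 
    LeastSquares.train(moneyball.nCol("RA"), moneyball.nCol("OOBP"), moneyball.nCol("OSLG"));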

We used regression to build predictive models, and visualizations to check our assumptions and validate our models.

However, we still haven’t shown how this knowledge can be applied. That step involves predicting how the current team will perform using historical data, and evaluating the available talent to see who could raise the team’s average OBP or SLG numbers, or reduce the opponents’ values of the same stats. The team can then weigh scenarios covering various trades and expected salary costs. Taking it to that level requires individual player stats that aren’t in our dataset, so we’ll leave it here, but I hope this post has shown how Tablesaw makes regression analysis in Java easy and practical.


New Plot Types in Tablesaw

In a prior post, I showed how to create some native Java scatter plots and a quantile plot in Tablesaw. Since then, I’ve added a few more plot types.

When it comes to plotting, Tablesaw integrates other libraries and tries to make their use as consistent as possible. Like the earlier scatter plots, this line chart is rendered using XChart under the covers:

[Plot: monthly Boston armed robberies, Jan. 1966 - Oct. 1975]

The dramatic increase in armed robberies is shown by plotting the monthly counts against their position in the sequence. The code looks like this:

Table baseball = Table.createFromCsv("data/boston-robberies.csv");
NumericColumn x = baseball.nCol("Record");
NumericColumn y = baseball.nCol("Robberies");
Line.show("Monthly Boston Armed Robberies Jan. 1966 - Oct. 1975", x, y);

Histograms are a must-have. We use the plotting capabilities of the Smile machine learning library to create the one below.

[Plot: histogram of team batting averages]

Although they’re from different libraries, the Tablesaw API is similar:

Table baseball = Table.createFromCsv("data/baseball.csv");
NumericColumn x = baseball.nCol("BA");
Histogram.show("Distribution of team batting averages", x);

This is currently the only Smile plot we’re using, but there’s more to come. Heatmaps, Contour plots and QQ plots are coming soon. We’re also starting to integrate Smile’s machine learning capabilities, which will be a huge step forward for Tablesaw.

Bar plots are unglamorous, but very useful. Tablesaw can produce both horizontal and vertical bar plots, and also creates Pareto charts directly as a convenience. They’re all based on the JavaFX chart library, and like the other Tablesaw plots, they’re rendered in Swing windows. Here we show a Pareto chart of tornado fatalities by US state.

[Plot: Pareto chart of tornado fatalities by state]

The code to produce this chart, including a filter to remove tornadoes with three or fewer fatalities, is shown below. The grouping is done using the summarize method, which produces tabular summaries that can be passed directly to the plotting API.

Note the use of the #sum method. Any numerical summary supported by Tablesaw (standard deviation, median, sumOfLogs, etc.) can be substituted for easy plotting.

Table table = Table.createFromCsv("data/tornadoes_1950-2014.csv");
table = table.selectWhere(column("Fatalities").isGreaterThan(3));
Pareto.show("Tornado Fatalities by State", 
    table.summarize("fatalities", sum).by("State"));

As you can see, loading from a CSV, filtering the data, grouping, summing, sorting, and plotting is all done in three lines of code.
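For example, substituting median for sum, as mentioned above, is a one-word change (a hypothetical variation on the same chart):

// Hypothetical variation: plot median injuries per state
Pareto.show("Median Tornado Injuries by State", 
    table.summarize("injuries", median).by("State"));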

Finally, we have a BoxPlot.

[Plot: box plot of tornado injuries by scale]

For Boxplots, the groups are formed using Table’s splitOn() method, or simply by passing the names of the summary and grouping columns along with the Table:

Table table = Table.createFromCsv("data/tornadoes_1950-2014.csv");
Box.show("Tornado Injuries by Scale", table, "injuries", "scale");

I hope you’ll find Tablesaw useful for your data analytics work.


Tablesaw gets Graphic

Today we introduced the first elements of what will be Tablesaw’s support for exploratory data visualization in pure Java. As Tablesaw expands its scope to integrate statistical and machine learning capabilities, this kind of visualization will be critical.

[Plot: US tornado starting points, 1950-2014]

This slightly ghostly US map image was created as a simple scatter plot of the starting latitude and longitude of every US tornado between 1950 and 2014. The code below loads the data, filters out missing records, and renders the plot:

Table tornado = Table.createFromCsv("data/tornadoes_1950-2014.csv");

tornado = tornado.selectWhere(
    both(column("Start Lat").isGreaterThan(0f),
         column("Scale").isGreaterThanOrEqualTo(0)));

Scatter.show("US Tornados 1950-2014",
    tornado.numericColumn("Start Lon"),
    tornado.numericColumn("Start Lat"));

These plots provide visual feedback to the analyst while she’s working. They’re for discovery, rather than for presentation, and ease of use is stressed over beauty. Behind the scenes, the charts are created with Tim Molter’s awesome XChart library:  https://github.com/timmolter/XChart.

The following chart is taken from a baseball data set. It shows how to split a table on the values of one or more columns, producing a series for each group. In this case, we color the marks differently depending on whether the team made the playoffs.

[Plot: regular-season wins by year, split on playoff appearance]

Here’s the code:

Table baseball = Table.createFromCsv("data/baseball.csv");
Scatter.show("Regular season wins by year",
    baseball.numericColumn("W"),
    baseball.numericColumn("Year"),
    baseball.splitOn(baseball.column("Playoffs")));

A chart that looks like a scatter plot and works like a histogram is a Quantile Plot. The plot below presents the distribution of public opinion poll ratings for one US president.

[Plot: quantiles of presidential approval ratings]

This chart was built using the Quantile class:

// Load the polling data first (the file path here is assumed)
Table bush = Table.createFromCsv("data/bush.csv");
String title = "Quantiles: George W. Bush (Feb. 2001 - Feb. 2004)";
Quantile.show(title, bush.numericColumn("approval"));

Further down the line, I expect to add JavaScript plot support based on D3. These plots will be focused more on presentation, especially Web-based presentation, as Tablesaw becomes a complete platform for data science.