Skittles Sorter: A note on machine learning

When designing the Skittles sorting machine, it seemed fairly obvious that we would need to take controlled sample data for each color of Skittle we were sorting and analyze the RGB results ourselves.

At first we went ahead and gathered the data, loaded it into Excel, and ended up with several graphs that looked like so…

The next step would have been to establish regression lines for the R, G, and B values in relation to the C value for every color of Skittle. Then, in our code, we would have had to take each reading and figure out which color of Skittle was most likely being read by applying all 15 of those equations (three channels times five colors).
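For a sense of what that would have involved, here is a minimal sketch of one of those 15 regressions, using the first five Red samples from the data listing at the bottom of this post. numpy's polyfit stands in for the Excel trendlines, and the classification step is only our guess at how the hand-rolled approach would have gone, not code we actually wrote:

import numpy as np

# C (brightness) and R readings for the first five Red samples from the
# data at the bottom of this post.
c_red = np.array([3049, 2150, 3862, 3932, 2132])
r_red = np.array([1740, 1886, 1987, 2226, 1801])

# Fit R = m*C + b for Red's R channel -- one of the 15 regressions.
m, b = np.polyfit(c_red, r_red, 1)

def r_residual(c_reading, r_reading):
    # How far the measured R sits from the R this line predicts.
    return abs(r_reading - (m * c_reading + b))

# Classifying a reading would mean summing the R, G, and B residuals
# for each of the five colors and picking the color with the smallest total.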

That seemed like a lot of work.

So we came up with an easier (and better!) solution: machine learning.

We plugged our data into a machine learning algorithm that generated a decision tree for us, which looked something like (well, a lot like) this.

There was no automatic way to transfer this decision tree into our code, so we had to type it out ourselves (more on that in a bit).

Now, it may seem odd that the decision-making process was based on relatively few hard-coded, discrete value tests when the RGB values behaved according to continuous relationships. The machine learning algorithm seemed to notice that even though the RGB values varied with brightness (the C value), they tended to fall within certain ‘bands’ for each Skittle color, since the brightness stayed within a reasonable range (thanks to the built-in LED on the chip).
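Those bands can be eyeballed straight from the sample data. A throwaway snippet like this (assuming the features and labels arrays from the listing below are in scope) prints the per-color range of the G channel:

colors = ["Red", "Orange", "Yellow", "Green", "Purple", "Nothing"]
for i, color in enumerate(colors):
    # Collect the G channel (second value) of every sample with this label.
    g_values = [f[1] for f, label in zip(features, labels) if label == i]
    print(f"{color}: G between {min(g_values)} and {max(g_values)}")

No single channel cleanly separates every color on its own (Red and Green overlap in G, for instance), which is why the tree chains a handful of tests rather than using one threshold per color.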

We translated this tree into code by typing out the if-else statements ourselves. The translated version of the tree pictured above turned out like this.
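As a rough illustration of the shape (the thresholds below are made up; our real values came straight off the tree's nodes):

def classify(r, g, b, c):
    # Each nested test is one internal node of the tree; each return is
    # a leaf. Thresholds here are hypothetical, for illustration only,
    # and the real tree may test the C channel as well.
    if r <= 2740:
        if g <= 915:
            if b <= 560:
                return "Purple"
            return "Nothing" if r <= 1100 else "Red"
        return "Green"
    return "Orange" if g <= 2490 else "Yellow"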

The actual machine learning algorithm was pulled from a Python library (scikit-learn's DecisionTreeClassifier). The script is pasted at the bottom of this post.

Within the code, we typed out our data points (plainly visible in the opening section). Each array within the larger array is one data point. The ‘labels’ array tells the algorithm which category each data point falls under, in sequence (each color is associated with a number); thankfully, we typed out our data in uniform sections of ten. The ‘feature names’ array gives aliases to each value in a data point, and the ‘class names’ array translates the numeric categories in the ‘labels’ array into names like ‘Red’ or ‘Purple’.

from sklearn import tree

# Each sample is one [R, G, B, C] reading from the color sensor
features = [
    [1740, 926, 744, 3049], # Red
    [1886, 318, 386, 2150],
    [1987, 1126, 964, 3862],
    [2226, 1077, 1029, 3932],
    [1801, 375, 436, 2132],
    [1522, 306, 356, 1776],
    [1611, 411, 427, 2042],
    [1723, 501, 505, 2274],
    [1626, 399, 424, 1997],
    [1617, 411, 430, 2048],
    [6073, 1410, 1282, 7949], # Orange
    [3743, 2115, 1558, 7304],
    [4135, 853, 675, 5256],
    [3578, 1139, 850, 5297],
    [4901, 1029, 811, 6338],
    [3430, 1860, 1293, 6562],
    [4362, 875, 698, 5611],
    [2938, 735, 568, 3982],
    [4096, 1200, 826, 5772],
    [4254, 2015, 1493, 7470],
    [3243, 3032, 1499, 8010], # Yellow
    [3150, 2880, 1288, 7540],
    [4148, 2887, 1035, 8054],
    [6079, 5697, 2387, 14570],
    [3778, 2928, 1025, 7648],
    [5326, 3857, 1311, 10652],
    [4609, 3200, 1071, 8958],
    [3883, 3004, 1132, 7940],
    [4340, 2969, 1013, 8285],
    [1849, 1772, 997, 4551],
    [2557, 3800, 1820, 7581], # Green
    [1171, 1510, 599, 2881],
    [1193, 1576, 587, 3085],
    [1572, 2067, 1083, 4241],
    [2035, 2771, 1451, 6236],
    [1837, 2315, 1340, 5459],
    [1309, 1734, 696, 3316],
    [1019, 1223, 520, 2396],
    [1058, 946, 743, 2575],
    [2540, 3732, 1831, 7464],
    [438, 195, 183, 585], # Purple
    [521, 314, 278, 842],
    [655, 521, 435, 1310],
    [476, 303, 263, 837],
    [525, 309, 273, 851],
    [1252, 759, 849, 2385],
    [678, 559, 476, 1414],
    [460, 232, 215, 690],
    [1223, 891, 933, 2649],
    [955, 559, 630, 1743],
    [977, 862, 619, 2460], # Nothing
    [977, 862, 619, 2460],
    [972, 858, 616, 2449],
    [972, 858, 616, 2449],
    [974, 860, 618, 2452],
    [976, 863, 620, 2461],
    [974, 861, 619, 2459],
    [973, 860, 619, 2452],
    [973, 860, 618, 2453],
    [974, 860, 618, 2457],
]
# Category per sample: 0=Red, 1=Orange, 2=Yellow, 3=Green, 4=Purple, 5=Nothing
labels = [
    0,0,0,0,0,0,0,0,0,0,
    1,1,1,1,1,1,1,1,1,1,
    2,2,2,2,2,2,2,2,2,2,
    3,3,3,3,3,3,3,3,3,3,
    4,4,4,4,4,4,4,4,4,4,
    5,5,5,5,5,5,5,5,5,5
]
# Fit a decision tree classifier to the labelled samples
clf = tree.DecisionTreeClassifier()
clf = clf.fit(features, labels)

# Visualization: render the fitted tree to a PDF via Graphviz
from io import StringIO  # sklearn.externals.six was removed from scikit-learn
import pydot

dot_data = StringIO()
tree.export_graphviz(clf,
                     out_file=dot_data,
                     feature_names=["R", "G", "B", "C"],
                     class_names=["Red", "Orange", "Yellow", "Green", "Purple", "Nothing"],
                     filled=True, rounded=True,
                     impurity=False)
# graph_from_dot_data returns a list of graphs in current pydot
(graph,) = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("tree.pdf")
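One convenience worth noting (this check isn't part of the script above): the fitted classifier can be queried directly, which makes for a quick sanity check before hand-translating the tree:

# A fully grown tree reproduces its training labels, so feeding back the
# first Red sample prints [0], which class_names maps to "Red".
print(clf.predict([[1740, 926, 744, 3049]]))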