Fruits Lovers: Solving A Simple Classification Problem with Python

Ocktavia Nurima Putri, Jan 9

In this post, we’ll implement several machine learning algorithms in Python using Scikit-learn, the most popular machine learning tool for Python. We’ll use a simple dataset to train a classifier that distinguishes between different types of fruit.

The purpose of this post is to identify the machine learning algorithm that is best-suited for the problem at hand; thus, we want to compare different algorithms, selecting the best-performing one.

Let’s get started!

Data

The fruits dataset was created by Dr. Iain Murray from the University of Edinburgh.

He bought a few dozen oranges, lemons and apples of different varieties, and recorded their measurements in a table.

Professors at the University of Michigan then formatted the fruits data slightly, and it can be downloaded from here.

Let’s have a look at the first few rows of the data.

    %matplotlib inline
    import pandas as pd
    import matplotlib.pyplot as plt

    fruits = pd.read_table(...)  # pass the path to the downloaded fruits data file
    fruits.head()

Figure 1

Each row of the dataset represents one piece of fruit, described by several features that are in the table’s columns.

We have 59 pieces of fruit and 7 features in the dataset:

    print(fruits.shape)

    (59, 7)

We have four types of fruit in the dataset:

    print(fruits['fruit_name'].unique())

    ['apple' 'mandarin' 'orange' 'lemon']

The data is fairly balanced, except for mandarin. We will just have to go with it.



    print(fruits.groupby('fruit_name').size())

Figure 2

    import seaborn as sns
    sns.countplot(x=fruits['fruit_name'])
    plt.show()

Figure 3

Visualization

A box plot for each numeric variable will give us a clearer idea of the distribution of the input variables:

    fruits.drop('fruit_label', axis=1).plot(kind='box', subplots=True, layout=(2,2), sharex=False, sharey=False, figsize=(9,9), title='Box Plot for each input variable')
    plt.show()

Figure 4

It looks like perhaps color score has a near-Gaussian distribution.

    import pylab as pl
    fruits.drop('fruit_label', axis=1).hist(bins=30, figsize=(9,9))
    pl.suptitle("Histogram for each numeric input variable")
    plt.show()

Figure 5

Some pairs of attributes, such as mass and width, are highly correlated, which suggests a predictable relationship between them.
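To put a number on a relationship like this, one option is the Pearson correlation coefficient, which pandas computes with `DataFrame.corr()`. The sketch below uses a handful of made-up measurements (hypothetical values, not rows from the fruits dataset) with the same column names, just to show the call:

```python
import pandas as pd

# Hypothetical measurements for illustration only; these are NOT rows
# from the fruits dataset.
toy = pd.DataFrame({
    'mass':   [100, 120, 150, 160, 180, 200],
    'width':  [6.0, 6.4, 7.0, 7.1, 7.5, 8.0],
    'height': [7.5, 6.8, 7.9, 7.0, 7.6, 7.2],
})

# Pairwise Pearson correlation between all numeric columns
corr = toy.corr()
print(corr.round(2))

# Correlation of one specific pair of columns
mass_width = toy['mass'].corr(toy['width'])
print('mass vs. width: {:.2f}'.format(mass_width))
```

On the real data, the same call would be applied to the numeric columns, for example `fruits[['mass', 'width', 'height', 'color_score']].corr()`.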

    from pandas.plotting import scatter_matrix

    feature_names = ['mass', 'width', 'height', 'color_score']
    X = fruits[feature_names]
    y = fruits['fruit_label']

    # Color each point by its fruit label
    cmap = plt.get_cmap('gnuplot')
    scatter = scatter_matrix(X, c=y, marker='o', s=40, hist_kwds={'bins':15}, figsize=(9,9), cmap=cmap)
    plt.suptitle('Scatter-matrix for each input variable')
    plt.savefig('fruits_scatter_matrix')

Figure 6

Statistical Summary

Figure 7

We can see that the numerical values do not have the same scale.

We need to apply to the test set the same scaling that we computed on the training set.
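Min-max scaling maps each feature value x to (x - min) / (max - min). Computing min and max on the training set only, and reusing them for the test set, keeps information from the test set out of the model. A minimal sketch with made-up numbers (the arrays are hypothetical, not taken from the fruits data):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical single-feature data, for illustration only
train = np.array([[100.0], [150.0], [200.0]])
test = np.array([[125.0], [250.0]])

scaler = MinMaxScaler()
train_scaled = scaler.fit_transform(train)  # learns min=100 and max=200 from train
test_scaled = scaler.transform(test)        # reuses the training min and max

# The same result by hand: (x - train_min) / (train_max - train_min)
manual = (test - train.min()) / (train.max() - train.min())

print(train_scaled.ravel())
print(test_scaled.ravel())
```

A test value outside the training range (250 here) ends up outside [0, 1] after scaling, which is expected with this approach.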

Create Training and Test Sets and Apply Scaling

    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    from sklearn.preprocessing import MinMaxScaler
    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

Build Models

Logistic Regression

    from sklearn.linear_model import LogisticRegression
    logreg = LogisticRegression()
    logreg.fit(X_train, y_train)
    print('Accuracy of Logistic regression classifier on training set: {:.2f}'
          .format(logreg.score(X_train, y_train)))
    print('Accuracy of Logistic regression classifier on test set: {:.2f}'
          .format(logreg.score(X_test, y_test)))

    Accuracy of Logistic regression classifier on training set: 0.70
    Accuracy of Logistic regression classifier on test set: 0.40

Decision Tree

    from sklearn.tree import DecisionTreeClassifier
    clf = DecisionTreeClassifier().fit(X_train, y_train)
    print('Accuracy of Decision Tree classifier on training set: {:.2f}'
          .format(clf.score(X_train, y_train)))
    print('Accuracy of Decision Tree classifier on test set: {:.2f}'
          .format(clf.score(X_test, y_test)))

    Accuracy of Decision Tree classifier on training set: 1.00
    Accuracy of Decision Tree classifier on test set: 0.73

K-Nearest Neighbors

    from sklearn.neighbors import KNeighborsClassifier
    knn = KNeighborsClassifier()
    knn.fit(X_train, y_train)
    print('Accuracy of K-NN classifier on training set: {:.2f}'
          .format(knn.score(X_train, y_train)))
    print('Accuracy of K-NN classifier on test set: {:.2f}'
          .format(knn.score(X_test, y_test)))

    Accuracy of K-NN classifier on training set: 0.95
    Accuracy of K-NN classifier on test set: 1.00

Linear Discriminant Analysis

    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    lda = LinearDiscriminantAnalysis()
    lda.fit(X_train, y_train)
    print('Accuracy of LDA classifier on training set: {:.2f}'
          .format(lda.score(X_train, y_train)))
    print('Accuracy of LDA classifier on test set: {:.2f}'
          .format(lda.score(X_test, y_test)))

    Accuracy of LDA classifier on training set: 0.86
    Accuracy of LDA classifier on test set: 0.67

Gaussian Naive Bayes

    from sklearn.naive_bayes import GaussianNB
    gnb = GaussianNB()
    gnb.fit(X_train, y_train)
    print('Accuracy of GNB classifier on training set: {:.2f}'
          .format(gnb.score(X_train, y_train)))
    print('Accuracy of GNB classifier on test set: {:.2f}'
          .format(gnb.score(X_test, y_test)))

    Accuracy of GNB classifier on training set: 0.86
    Accuracy of GNB classifier on test set: 0.67

Support Vector Machine

    from sklearn.svm import SVC
    svm = SVC()
    svm.fit(X_train, y_train)
    print('Accuracy of SVM classifier on training set: {:.2f}'
          .format(svm.score(X_train, y_train)))
    print('Accuracy of SVM classifier on test set: {:.2f}'
          .format(svm.score(X_test, y_test)))

    Accuracy of SVM classifier on training set: 0.61
    Accuracy of SVM classifier on test set: 0.33

The KNN algorithm was the most accurate model that we tried.

The confusion matrix shows that k-NN made no errors on the test set.

However, the test set was very small.

    from sklearn.metrics import classification_report
    from sklearn.metrics import confusion_matrix
    pred = knn.predict(X_test)
    print(confusion_matrix(y_test, pred))
    print(classification_report(y_test, pred))

Figure 7

Plot the Decision Boundary of the k-NN Classifier

    import numpy as np
    import matplotlib.cm as cm
    from matplotlib.colors import ListedColormap, BoundaryNorm
    import matplotlib.patches as mpatches

    X = fruits[['mass', 'width', 'height', 'color_score']]
    y = fruits['fruit_label']
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    def plot_fruit_knn(X, y, n_neighbors, weights):
        # Use only two features so the boundary can be drawn in 2D
        X_mat = X[['height', 'width']].to_numpy()
        y_mat = y.to_numpy()

        # Create color maps
        cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AFAFAF'])
        cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#AFAFAF'])

        clf = KNeighborsClassifier(n_neighbors=n_neighbors, weights=weights)
        clf.fit(X_mat, y_mat)

        # Plot the decision boundary by assigning a color in the color map
        # to each mesh point.
        mesh_step_size = .01  # step size in the mesh
        plot_symbol_size = 50

        x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1
        y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                             np.arange(y_min, y_max, mesh_step_size))
        Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

        # Put the result into a color plot
        Z = Z.reshape(xx.shape)
        plt.figure()
        plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

        # Plot training points
        plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y, cmap=cmap_bold, edgecolor='black')
        plt.xlim(xx.min(), xx.max())
        plt.ylim(yy.min(), yy.max())

        patch0 = mpatches.Patch(color='#FF0000', label='apple')
        patch1 = mpatches.Patch(color='#00FF00', label='mandarin')
        patch2 = mpatches.Patch(color='#0000FF', label='orange')
        patch3 = mpatches.Patch(color='#AFAFAF', label='lemon')
        plt.legend(handles=[patch0, patch1, patch2, patch3])

        plt.xlabel('height (cm)')
        plt.ylabel('width (cm)')
        plt.title("4-Class classification (k = %i, weights = '%s')" % (n_neighbors, weights))
        plt.show()

    plot_fruit_knn(X_train, y_train, 5, 'uniform')

Figure 8

    k_range = range(1, 20)
    scores = []
    for k in k_range:
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X_train, y_train)
        scores.append(knn.score(X_test, y_test))

    plt.figure()
    plt.xlabel('k')
    plt.ylabel('accuracy')
    plt.scatter(k_range, scores)
    plt.xticks([0,5,10,15,20])

Figure 9

For this particular dataset, we obtain the highest accuracy when k=5.
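Because this value of k was picked by looking at accuracy on the test set, which is very small, the choice may not hold up on new data. One common alternative is to pick k by cross-validation instead. The sketch below is an illustration under stated assumptions, not part of the original analysis: it uses scikit-learn's bundled wine dataset as a stand-in for the fruits data, and it wraps the scaler and classifier in a pipeline so that scaling is refit inside each fold. The names `X_demo`, `y_demo`, `cv_scores` and `best_k` are illustrative.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

# Stand-in data (assumption): the wine dataset shipped with scikit-learn,
# used only because the fruits file is not bundled with this sketch.
X_demo, y_demo = load_wine(return_X_y=True)

cv_scores = {}
for k in range(1, 20):
    # Scaling lives inside the pipeline, so each fold is scaled using
    # only its own training portion.
    model = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=k))
    cv_scores[k] = cross_val_score(model, X_demo, y_demo, cv=5).mean()

# Pick the k with the highest mean cross-validated accuracy
best_k = max(cv_scores, key=cv_scores.get)
print('best k: {}, mean CV accuracy: {:.2f}'.format(best_k, cv_scores[best_k]))
```

With a dataset as small as the fruits data, the fold count would likely need to be lowered so that every class still appears in each fold.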

Summary

In this post, we focused on prediction accuracy. Our objective is to learn a model that generalizes well, meaning one that maximizes prediction accuracy on unseen data.

To identify the machine learning algorithm best suited to the problem at hand (i.e. fruit type classification), we compared different algorithms and selected the best-performing one.
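A comparison like this can also be written as a single loop over candidate models. This is a sketch under stated assumptions rather than the code behind this post: it uses scikit-learn's bundled wine dataset as a stand-in for the fruits data, and the `models` dictionary, the `max_iter` and `random_state` settings, and the variable names are illustrative choices.

```python
from sklearn.datasets import load_wine
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Stand-in data (assumption): the wine dataset shipped with scikit-learn
X_demo, y_demo = load_wine(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, random_state=0)

# Fit the scaler on the training split only, then reuse it on the test split
scaler = MinMaxScaler()
X_tr = scaler.fit_transform(X_tr)
X_te = scaler.transform(X_te)

models = {
    'Logistic Regression': LogisticRegression(max_iter=1000),
    'Decision Tree': DecisionTreeClassifier(random_state=0),
    'K-NN': KNeighborsClassifier(),
    'LDA': LinearDiscriminantAnalysis(),
    'GNB': GaussianNB(),
    'SVM': SVC(),
}

# Train each model and record (training accuracy, test accuracy)
results = {}
for name, model in models.items():
    model.fit(X_tr, y_tr)
    results[name] = (model.score(X_tr, y_tr), model.score(X_te, y_te))
    print('{}: train {:.2f}, test {:.2f}'.format(name, *results[name]))

# Model with the highest test accuracy
best = max(results, key=lambda name: results[name][1])
print('Best on the test split:', best)
```

Keeping the models in one dictionary makes it easy to add or remove candidates without repeating the fit-and-score code.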

Source code that created this post can be found here.

I would be pleased to receive feedback or questions on any of the above.

