Understanding Logistic Regression step by step
Training a logistic regression classifier to predict people’s gender based on their weight and height.
Gustavo Chávez
Feb 21
Logistic regression is a popular statistical model used for binary classification, that is, for predictions of the type this or that, yes or no, A or B, etc.
Logistic regression can, however, be used for multiclass classification, but here we will focus on its simplest application.
As an example, consider the task of predicting someone’s gender (Male/Female) based on their Weight and Height.
For this, we will train a machine learning model from a data set of 10,000 samples of people’s weight and height.
The data set is taken from Chapter 2 of Drew Conway and John Myles White’s book Machine Learning for Hackers, and it can be downloaded directly here.
This is a preview of what the data looks like:
[Figure: preview of the data set]
Each sample contains three columns: Height, Weight, and Male.
- Height: in inches
- Weight: in pounds
- Male: 1 means that the measurement corresponds to a male person, and 0 means that it corresponds to a female person.
There are 5,000 samples from males and 5,000 samples from females, so the data set is balanced and we can proceed to training.
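With the data set in hand, the training and prediction steps can be sketched with scikit-learn as follows. This is a minimal sketch, not the author’s LogisticRegression.py script. It assumes the data has been saved locally as a CSV file with the Height, Weight, and Male columns described above, and load_heights_weights and train_classifier are hypothetical helper names. The solver is set to "newton-cg", scikit-learn’s Newton conjugate gradient solver, to match the method discussed later in this post.

```python
import csv

import numpy as np
from sklearn.linear_model import LogisticRegression


def load_heights_weights(path):
    """Read a CSV with Height, Weight and Male columns (assumed layout).

    Returns a (n_samples, 2) feature array [height, weight] and a 0/1 label array.
    """
    features, labels = [], []
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            features.append([float(row["Height"]), float(row["Weight"])])
            labels.append(int(float(row["Male"])))  # accepts "1" as well as "1.0"
    return np.array(features), np.array(labels)


def train_classifier(X, y):
    """Fit a logistic regression classifier on (height, weight) features."""
    clf = LogisticRegression(solver="newton-cg")
    return clf.fit(X, y)  # fit() returns the fitted estimator itself
```

Assuming such a file exists at a path of your choosing, usage would be: X, y = load_heights_weights(path), then fitted_model = train_classifier(X, y), and finally fitted_model.predict([[70, 180]]), which returns 1 for Male or 0 for Female.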
The Python scikit-learn code to train a logistic regression classifier and make a prediction is very straightforward:
[Code: LogisticRegression.py]
The general workflow is:
1. Get a data set.
2. Train a classifier.
3. Make a prediction using that classifier.
Logistic regression hypothesis
The logistic regression classifier can be derived by analogy to the linear regression hypothesis, which is:
$$h_\theta(x) = \theta^T x$$
However, the logistic regression hypothesis generalizes the linear regression hypothesis by passing $\theta^T x$ through the logistic function:
$$g(z) = \frac{1}{1 + e^{-z}}$$
The result is the logistic regression hypothesis:
$$h_\theta(x) = g(\theta^T x) = \frac{1}{1 + e^{-\theta^T x}}$$
The function g(z) is the logistic function, also known as the sigmoid function.
The logistic function has horizontal asymptotes at 0 (as $z \to -\infty$) and 1 (as $z \to +\infty$), and it crosses the y-axis at 0.5, since $g(0) = 1/(1 + e^{0}) = 1/2$.
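The hypothesis and the two properties of g(z) above can be checked numerically. Below is a minimal NumPy sketch. The names sigmoid and hypothesis are illustrative, and the matrix X is assumed to carry a leading column of ones so that $\theta_0$ plays the role of the intercept.

```python
import numpy as np


def sigmoid(z):
    """Logistic function g(z) = 1 / (1 + exp(-z)), applied element-wise."""
    return 1.0 / (1.0 + np.exp(-np.asarray(z, dtype=float)))


def hypothesis(theta, X):
    """h_theta(x) = g(theta^T x) for every row x of X (first column of X is all ones)."""
    return sigmoid(X @ theta)


print(sigmoid(0.0))            # 0.5: the curve crosses the y-axis at one half
print(sigmoid([-20.0, 20.0]))  # close to 0 and close to 1, the two asymptotes
```

For very negative z, np.exp(-z) can overflow to infinity. NumPy emits a warning in that case, and the result still correctly rounds to 0. scipy.special.expit is a numerically safer drop-in for the same function.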
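Logistic regression decision boundary
Since our data set has two features, height and weight, the logistic regression hypothesis is the following:
$$h_\theta(x) = g(\theta_0 + \theta_1 x_1 + \theta_2 x_2)$$
where $x_1$ is the height and $x_2$ is the weight. The logistic regression classifier will predict “Male” if:
$$\theta_0 + \theta_1 x_1 + \theta_2 x_2 \geq 0$$
This is because the logistic regression “threshold” is set at $g(z) = 0.5$, and $g(z) \geq 0.5$ exactly when $z \geq 0$. See the plot of the logistic function above for verification.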
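The equivalence between the 0.5 threshold and a sign test can be worked out in a few steps. Since $1 + e^{-z} > 0$ for every $z$:
$$g(z) \geq \tfrac{1}{2} \iff \frac{1}{1 + e^{-z}} \geq \frac{1}{2} \iff 1 + e^{-z} \leq 2 \iff e^{-z} \leq 1 \iff z \geq 0$$
So, with $z = \theta_0 + \theta_1 \cdot \text{Height} + \theta_2 \cdot \text{Weight}$, predicting Male whenever $h_\theta(x) \geq 0.5$ is the same as predicting Male whenever $z \geq 0$. This is why the decision boundary is the straight line $z = 0$ in the height-weight plane.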
For our data set, the values of θ are:
[Figure: fitted values of θ]
To access the θ parameters computed by scikit-learn, one can do:
```python
# fitted_model is the LogisticRegression instance fitted in the script above
# For theta_0:
print(fitted_model.intercept_)
# For theta_1 and theta_2:
print(fitted_model.coef_)
```
With the coefficients at hand, a manual prediction (that is, one made without calling the predict() method) only requires computing the product
$$\theta^T x = \theta_0 + \theta_1 \cdot \text{Height} + \theta_2 \cdot \text{Weight}$$
and checking whether the resulting scalar is greater than or equal to zero (predict Male) or negative (predict Female).
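The manual rule can be checked against scikit-learn. In the sketch below, the small training arrays are made-up placeholders rather than the article’s data, and manual_predict is an illustrative name. The point is only that intercept_ plus the features weighted by coef_ reproduces decision_function(), and that the sign of this value reproduces predict().

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Made-up (height in inches, weight in pounds) samples; 1 = Male, 0 = Female.
X = np.array([[60, 110], [62, 120], [64, 140], [66, 150],
              [67, 160], [69, 170], [71, 190], [73, 210]], dtype=float)
y = np.array([0, 0, 0, 1, 0, 1, 1, 1])

fitted_model = LogisticRegression(solver="newton-cg").fit(X, y)

theta_0 = fitted_model.intercept_[0]  # scalar intercept
theta_12 = fitted_model.coef_[0]      # array [theta_1, theta_2] for [height, weight]


def manual_predict(height, weight):
    """Return 1 (Male) if theta^T x >= 0, else 0 (Female)."""
    z = theta_0 + theta_12[0] * height + theta_12[1] * weight
    return int(z >= 0)


query = np.array([[70.0, 180.0]])
# The two printed labels should agree.
print(manual_predict(70, 180), fitted_model.predict(query)[0])
```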
As an example, say we want to predict the gender of someone with Height = 70 inches and Weight = 180 pounds, as on line 14 of the script LogisticRegression.py above. One can simply do:
[Figure: making a prediction using the logistic regression parameter θ]
Since the result of the product is greater than zero, the classifier will predict Male.
A visualization of the decision boundary and the complete data set can be seen here:
[Figure: decision boundary over the complete data set]
As you can see, most of the blue points, which correspond to the Male class, lie above the decision boundary. Most of the pink points, which correspond to the Female class, lie below it.
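The boundary is the set of points where $\theta_0 + \theta_1 \cdot \text{Height} + \theta_2 \cdot \text{Weight} = 0$. It can therefore be drawn as a straight line by solving for the weight: $\text{Weight} = -(\theta_0 + \theta_1 \cdot \text{Height}) / \theta_2$. This assumes $\theta_2 \neq 0$ and that $\theta_1$ and $\theta_2$ multiply height and weight, respectively. In the sketch below, boundary_weight is an illustrative name, and the θ values are arbitrary placeholders, not the fitted values.

```python
import numpy as np


def boundary_weight(theta, heights):
    """Weight on the decision boundary theta0 + theta1*h + theta2*w = 0 for each height h."""
    theta0, theta1, theta2 = theta
    if theta2 == 0:
        raise ValueError("theta2 must be non-zero to solve for the weight")
    return -(theta0 + theta1 * np.asarray(heights, dtype=float)) / theta2


heights = np.linspace(55, 80, 6)
theta = (-10.0, -0.5, 0.2)  # arbitrary placeholder parameters
weights = boundary_weight(theta, heights)
# Every (height, weight) pair on the line gives theta^T x == 0 (up to rounding).
print(theta[0] + theta[1] * heights + theta[2] * weights)
```

Plotting weights against heights (for example, with matplotlib) on top of the scatter of the data reproduces the kind of boundary line shown in the figure.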
Also, from just looking at the data you can tell that the predictions won’t be perfect.
This can be improved by including more features (beyond weight and height), and by potentially using a different decision boundary.
Logistic regression decision boundaries can also be non-linear functions, such as higher degree polynomials.
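One common way to get a non-linear boundary in scikit-learn is to expand the features with PolynomialFeatures before the logistic regression. This is not necessarily the approach the author had in mind. The boundary $\theta^T \phi(x) = 0$ then becomes a polynomial curve in the original feature plane. The minimal sketch below uses degree 2 with feature standardization. The toy data is made up so that the classes are separated by a circle, which no straight line can do.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

rng = np.random.default_rng(0)
# Made-up 2-D points: class 1 inside the unit circle, class 0 outside it.
X = rng.uniform(-2, 2, size=(400, 2))
y = (np.sum(X**2, axis=1) < 1.0).astype(int)

# Straight-line boundary: cannot follow the circle.
linear = LogisticRegression().fit(X, y)

# Degree-2 boundary: x1, x2 are expanded with x1^2, x1*x2 and x2^2.
quadratic = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False),
    StandardScaler(),
    LogisticRegression(C=100.0, max_iter=1000),
).fit(X, y)

print("linear accuracy:   ", linear.score(X, y))
print("quadratic accuracy:", quadratic.score(X, y))
```

Higher degrees give more flexible boundaries but also make overfitting easier, so the degree is usually chosen with held-out data.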
Computing the logistic regression parameter
The scikit-learn library abstracts away the computation of the logistic regression parameter θ, which it obtains by solving an optimization problem.
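Let’s start by defining the logistic regression cost for the two cases of interest, y = 1 and y = 0, that is, when the true label is Male or Female:
$$\mathrm{Cost}(h_\theta(x), y) = \begin{cases} -\log\left(h_\theta(x)\right) & \text{if } y = 1 \\ -\log\left(1 - h_\theta(x)\right) & \text{if } y = 0 \end{cases}$$
Then, we take a convex combination in y of these two terms and average over the m training samples to obtain the logistic regression cost function:
$$J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log h_\theta(x^{(i)}) + \left(1 - y^{(i)}\right) \log\left(1 - h_\theta(x^{(i)})\right) \right]$$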
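The cost function translates directly into NumPy. In this sketch, the design matrix X is assumed to include a leading column of ones (for $\theta_0$), and y holds 0/1 labels. Probabilities are clipped away from 0 and 1 so that the logarithms stay finite; this is an implementation choice, not part of the formula.

```python
import numpy as np


def cost(theta, X, y, eps=1e-12):
    """J(theta) = -(1/m) * sum(y*log(h) + (1 - y)*log(1 - h)), with h = g(X @ theta)."""
    h = 1.0 / (1.0 + np.exp(-(X @ theta)))
    h = np.clip(h, eps, 1.0 - eps)  # keep log() finite
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))


X = np.array([[1.0, 0.5], [1.0, -1.5], [1.0, 2.0]])
y = np.array([1.0, 0.0, 1.0])
print(cost(np.zeros(2), X, y))  # log(2), about 0.693: every h equals 0.5 when theta = 0
```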
The logistic regression cost function is convex, so any local minimum is also a global minimum.
Thus, in order to compute θ, one needs to solve the following (unconstrained) optimization problem:
$$\min_{\theta} J(\theta)$$
There is a variety of methods that can be used to solve this unconstrained optimization problem. One is the first-order method gradient descent, which requires the gradient of the logistic regression cost function. Another is a second-order method such as Newton’s conjugate gradient (Newton-CG), which requires the gradient and the Hessian of the cost function. Newton-CG was the method prescribed in the scikit-learn script above.
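To see the optimization problem solved end to end without scikit-learn, the sketch below hands the cost, its gradient, and its Hessian to SciPy’s scipy.optimize.minimize with method="Newton-CG". This is SciPy’s own Newton conjugate gradient routine, used here for illustration. scikit-learn’s newton-cg solver is a separate implementation that also applies L2 regularization by default, so the two θ values will generally differ. The data are synthetic, drawn from an assumed “true” θ.

```python
import numpy as np
from scipy.optimize import minimize


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cost(theta, X, y):
    h = np.clip(sigmoid(X @ theta), 1e-12, 1 - 1e-12)
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))


def gradient(theta, X, y):
    # (1/m) * X^T (h - y)
    return X.T @ (sigmoid(X @ theta) - y) / len(y)


def hessian(theta, X, y):
    # (1/m) * X^T diag(h * (1 - h)) X, computed without building the diagonal matrix
    h = sigmoid(X @ theta)
    return (X.T * (h * (1 - h))) @ X / len(y)


rng = np.random.default_rng(1)
m = 500
X = np.column_stack([np.ones(m), rng.normal(size=(m, 2))])  # leading column of ones
true_theta = np.array([0.5, 2.0, -1.0])                      # assumed, for simulation only
y = (rng.uniform(size=m) < sigmoid(X @ true_theta)).astype(float)

result = minimize(cost, x0=np.zeros(3), args=(X, y),
                  method="Newton-CG", jac=gradient, hess=hessian)
print(result.success, result.x)
```

Because the data are random samples, result.x should land near true_theta but not exactly on it.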
For the case of gradient descent, the search direction is the negative partial derivative of the logistic regression cost function with respect to the parameter θ:
$$\frac{\partial J(\theta)}{\partial \theta_j} = \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$$
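A quick way to confirm the partial-derivative formula is to compare it with a central finite-difference approximation of J(θ). The sketch below does that on small random data. The data, the step size 1e-6, and the helper names are illustrative choices.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cost(theta, X, y):
    h = sigmoid(X @ theta)
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))


def analytic_gradient(theta, X, y):
    # dJ/dtheta_j = (1/m) * sum_i (h(x_i) - y_i) * x_ij, for all j at once
    return X.T @ (sigmoid(X @ theta) - y) / len(y)


def numeric_gradient(theta, X, y, step=1e-6):
    # Central differences: (J(theta + step*e_j) - J(theta - step*e_j)) / (2*step)
    grad = np.zeros_like(theta)
    for j in range(len(theta)):
        e = np.zeros_like(theta)
        e[j] = step
        grad[j] = (cost(theta + e, X, y) - cost(theta - e, X, y)) / (2 * step)
    return grad


rng = np.random.default_rng(2)
X = np.column_stack([np.ones(20), rng.normal(size=(20, 2))])
y = rng.integers(0, 2, size=20).astype(float)
theta = rng.normal(size=3)
# Largest absolute disagreement between the two gradients; expected to be tiny.
print(np.max(np.abs(analytic_gradient(theta, X, y) - numeric_gradient(theta, X, y))))
```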
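In its most basic form, gradient descent iterates along the negative gradient direction, producing a sequence of θ values (known as a minimizing sequence) until it reaches convergence:
$$\text{repeat until convergence:} \quad \theta_j := \theta_j - \alpha \frac{\partial J(\theta)}{\partial \theta_j} \quad \text{(simultaneously for all } j\text{)}$$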
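A bare-bones batch gradient descent implementing the update above is sketched below. It uses a fixed learning rate α and stops when the gradient norm falls below a tolerance; both the default α = 0.5 and the tolerance are arbitrary illustrative values. A fixed α behaves well only when the features are on comparable scales, so raw heights and weights would typically need standardizing first.

```python
import numpy as np


def gradient_descent(X, y, alpha=0.5, tol=1e-6, max_iter=10_000):
    """Minimize the logistic regression cost J(theta) with fixed-step gradient descent.

    X must include a leading column of ones; y holds 0/1 labels.
    Returns the final theta and the number of iterations performed.
    """
    theta = np.zeros(X.shape[1])
    for it in range(max_iter):
        h = 1.0 / (1.0 + np.exp(-(X @ theta)))
        grad = X.T @ (h - y) / len(y)      # partial derivatives dJ/dtheta_j
        if np.linalg.norm(grad) < tol:     # converged: gradient is (nearly) zero
            return theta, it
        theta = theta - alpha * grad       # simultaneous update of every theta_j
    return theta, max_iter


# Synthetic, already standardized features (not the height/weight data).
rng = np.random.default_rng(3)
m = 300
features = rng.normal(size=(m, 2))
X = np.column_stack([np.ones(m), features])
y = (features[:, 0] - features[:, 1] + rng.normal(size=m) > 0).astype(float)
theta, iterations = gradient_descent(X, y)
print(theta, iterations)
```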
Notice that the constant α is usually called the learning rate or the search step, and it has to be tuned carefully. Too large a value can make the iterates diverge, while too small a value makes convergence slow. Algorithms such as backtracking line search help determine α.
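Backtracking line search with the Armijo sufficient-decrease condition is one standard way to pick α at each iteration: start from a trial step and shrink it by a factor β until J decreases enough. The sketch below uses the common textbook constants c = 1e-4 and β = 0.5, which are illustrative choices. It redefines the cost and gradient so that the block runs on its own, and the data are synthetic.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cost(theta, X, y):
    h = np.clip(sigmoid(X @ theta), 1e-12, 1 - 1e-12)
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))


def gradient(theta, X, y):
    return X.T @ (sigmoid(X @ theta) - y) / len(y)


def backtracking_step(theta, X, y, alpha0=1.0, beta=0.5, c=1e-4):
    """Return a step size alpha satisfying the Armijo condition along -gradient."""
    g = gradient(theta, X, y)
    j0 = cost(theta, X, y)
    alpha = alpha0
    # Shrink alpha until J(theta - alpha*g) <= J(theta) - c * alpha * ||g||^2
    # (the alpha > 1e-12 guard stops the loop if no acceptable step is found).
    while cost(theta - alpha * g, X, y) > j0 - c * alpha * (g @ g) and alpha > 1e-12:
        alpha *= beta
    return alpha


def gradient_descent_backtracking(X, y, tol=1e-6, max_iter=10_000):
    theta = np.zeros(X.shape[1])
    for it in range(max_iter):
        g = gradient(theta, X, y)
        if np.linalg.norm(g) < tol:
            return theta, it
        theta = theta - backtracking_step(theta, X, y) * g
    return theta, max_iter


rng = np.random.default_rng(4)
m = 300
features = rng.normal(size=(m, 2))
X = np.column_stack([np.ones(m), features])
y = (features[:, 0] + 0.5 * features[:, 1] + rng.normal(size=m) > 0).astype(float)
theta, iterations = gradient_descent_backtracking(X, y)
print(theta, iterations)
```

Compared with a fixed α, this costs extra evaluations of J per iteration but removes the need to hand-tune the step.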
In summary, these are the three fundamental concepts to remember the next time you use or implement a logistic regression classifier:
1. Logistic regression hypothesis
2. Logistic regression decision boundary
3. Logistic regression cost function
For a discussion of the logistic regression classifier applied to a data set with more features (also using Python), I recommend this Medium post by Susan Li.
References and further reading:
- Andrew Ng’s lectures on Logistic Regression
- scikit-learn’s logistic regression class
Gustavo Chávez is a postdoctoral fellow at the Lawrence Berkeley National Laboratory where he works at the intersection of machine learning and high-performance computing.