Logistic Regression is a method used in machine learning to predict whether something belongs to one of two categories, like “Yes” or “No.”
Key Ideas:
What does it do?
- Logistic regression helps answer questions like, “Will a customer buy this product?” or “Is this email spam?”
- It takes input data (like age, salary, etc.) and predicts whether the outcome is one category (like “Yes”) or the other (like “No”).
How does it work?
- The model combines the input data into a linear equation, z = w1*x1 + w2*x2 + ... + b, where the weights w and the bias b are learned from the training data. This is the same kind of straight-line equation used in linear regression.
- Instead of using z directly as the prediction, it passes z through a special “S” shaped curve called the sigmoid function, which converts it into a probability between 0 and 1.
- If the probability is higher than 0.5, it predicts “Yes” (or 1), and if it’s lower, it predicts “No” (or 0).
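The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not how scikit-learn implements it internally; the helper names and the weights w and bias b are hypothetical, whereas a real model learns w and b from training data.

```python
import numpy as np

def sigmoid(z):
    # S-shaped curve: maps any real number into the interval (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

def predict_probability(X, w, b):
    # Step 1: linear equation z = w1*x1 + w2*x2 + b for every row of X
    z = X @ w + b
    # Step 2: convert z into a probability with the sigmoid
    return sigmoid(z)

def predict_label(X, w, b, threshold=0.5):
    # Step 3: probability above the threshold -> 1 ("Yes"), otherwise 0 ("No")
    return (predict_probability(X, w, b) > threshold).astype(int)

# Hypothetical weights, bias, and inputs, for illustration only
w = np.array([2.0, -1.0])
b = -0.5
X = np.array([[1.0, 0.0],    # z = 2.0 - 0.0 - 0.5 = 1.5
              [0.0, 1.0]])   # z = 0.0 - 1.0 - 0.5 = -1.5

print(predict_probability(X, w, b))  # about [0.818, 0.182]
print(predict_label(X, w, b))        # [1 0]
```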
Example:
- Imagine you’re trying to predict if someone will buy a car based on their age and income.
- If the model says the probability is 0.8 (or 80%), we predict they will buy the car.
- If the probability is 0.3 (or 30%), we predict they won’t buy the car.
Why use it?
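- It’s simple, fast to train, and outputs a probability, not just a label.
- Its learned weights show how each input pushes the prediction toward “Yes” or “No.”
- It’s designed for problems with only two possible outcomes (Yes/No, True/False).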
Summary:
Logistic regression helps predict which of two categories something belongs to by turning a linear equation into a probability and making decisions based on that.
Let’s walk through an example step by step.
Practical Explanation:
- Importing Libraries:
In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
- numpy: For numerical operations like arrays and mathematical functions.
- pandas: For data manipulation (used later to load the dataset).
- matplotlib.pyplot: For plotting graphs and data visualizations.
- Creating an Array of Values:
In [2]:
X = np.arange(-10,10,0.1)
- np.arange(-10, 10, 0.1) creates an array X of 200 values from -10 up to, but not including, 10 in steps of 0.1.
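A quick check of what np.arange produces. The stop value is excluded, and because 0.1 is not exactly representable in floating point, the last value is only approximately 9.9. np.linspace, which takes a number of points instead of a step, is an alternative when an exact endpoint matters.

```python
import numpy as np

X = np.arange(-10, 10, 0.1)

# The stop value (10) is excluded: (10 - (-10)) / 0.1 = 200 values
print(len(X))         # 200
print(X[0], X[-1])    # -10.0 and a value very close to 9.9

# np.linspace takes a count instead of a step and includes the endpoint exactly
X_lin = np.linspace(-10, 10, 201)
print(X_lin[-1])      # 10.0
```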
- Plotting the Array:
In [3]:
plt.plot(X)
plt.show()
- Plots the values of X against their index (0 to 199), which shows up as a straight line.
- Sigmoid Function Implementation:
In [14]:
import math

def sigmoid(items):
    # Apply 1 / (1 + e^(-x)) to every element; the list is created fresh on each call
    y = []
    for i in items:
        v = 1 / (1 + math.exp(-i))
        y.append(v)
    return y

val = sigmoid(X)
- This function applies the sigmoid function, 1 / (1 + e^(-x)), to each element of X. The sigmoid compresses any real number into the range (0, 1): large negative inputs approach 0, an input of 0 maps to 0.5, and large positive inputs approach 1.
- The transformed values are returned as a list and stored in val.
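The loop version calls math.exp once per element, and math.exp raises OverflowError for arguments above roughly 709, so an input such as -1000 would crash it. A vectorized NumPy sketch that avoids this rewrites the sigmoid as exp(-log(1 + e^(-x))) and computes the logarithm with np.logaddexp. SciPy also provides this function as scipy.special.expit, if SciPy is available.

```python
import numpy as np

def sigmoid(x):
    # 1 / (1 + e^(-x)) equals exp(-log(1 + e^(-x))). np.logaddexp(0, -x) computes
    # log(e^0 + e^(-x)) without overflowing, so large positive or negative x is safe
    x = np.asarray(x, dtype=float)
    return np.exp(-np.logaddexp(0.0, -x))

X = np.arange(-10, 10, 0.1)
val = sigmoid(X)                            # NumPy array, same values as the loop version
extremes = sigmoid([-1000.0, 0.0, 1000.0])  # approximately [0, 0.5, 1], no overflow error
print(extremes)
```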
- Plotting Sigmoid Values:
In [15]:
plt.plot(val)
plt.show()
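- Plots the sigmoid values val against their index, producing the characteristic S-shaped curve that shows how the sigmoid compresses the values of X. Using plt.plot(X, val) instead would put the original inputs (-10 to 10) on the x-axis.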
Social Network Ads Section:
- Reading the CSV File:
In [16]:
data = pd.read_csv('Social_Network_Ads.csv')
- The dataset 'Social_Network_Ads.csv' is loaded into a pandas DataFrame named data.
- Extracting Features and Target:
In [17]:
X = data.iloc[:,2:4].values
y = data.iloc[:,4].values
- X contains the feature columns at positions 2 and 3 (Age, Estimated Salary).
- y contains the target column at position 4 (Purchased: 0 for no, 1 for yes).
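Positional indexing with iloc depends on the column order of the CSV. The sketch below assumes the file has the columns User ID, Gender, Age, EstimatedSalary, and Purchased (a common layout for this dataset, but not shown in this notebook) and uses a few made-up rows to show that selecting by column name gives the same arrays while surviving a reordering of columns.

```python
import pandas as pd

# Made-up rows with the assumed column layout of Social_Network_Ads.csv
data = pd.DataFrame({
    "User ID": [101, 102, 103],
    "Gender": ["Male", "Female", "Female"],
    "Age": [19, 35, 47],
    "EstimatedSalary": [19000, 20000, 25000],
    "Purchased": [0, 0, 1],
})

# Positional selection, as in the notebook: columns 2 and 3 are features, 4 is the target
X_pos = data.iloc[:, 2:4].values
y_pos = data.iloc[:, 4].values

# Equivalent selection by (assumed) column names
X = data[["Age", "EstimatedSalary"]].values
y = data["Purchased"].values

print(X.shape, y.shape)  # (3, 2) (3,)
```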
- Visualizing Data:
In [18]:
plt.scatter(X[:,0], X[:,1])
plt.show()
- This scatter plot shows each person as a point, with Age on the x-axis and Estimated Salary on the y-axis.
- Visualizing Data by Target Class:
In [19]:
plt.scatter(X[y==0, 0], X[y==0, 1], label='No')
plt.scatter(X[y==1, 0], X[y==1, 1], label='Yes')
plt.legend()
plt.show()
- This scatter plot separates the data into two groups (No purchase and Yes purchase) based on the target variable y.
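The expression X[y==0, 0] combines a boolean mask (which rows to keep) with a column index (which feature to take). A small sketch with made-up numbers:

```python
import numpy as np

# Made-up features (Age, Salary) and labels, for illustration only
X = np.array([[25, 30000],
              [40, 80000],
              [30, 40000],
              [55, 120000]])
y = np.array([0, 1, 0, 1])

mask = (y == 0)              # boolean array: True where the label is 0
no_ages = X[mask, 0]         # rows where y == 0, column 0 (Age)
no_salaries = X[mask, 1]     # rows where y == 0, column 1 (Salary)
print(no_ages, no_salaries)  # [25 30] [30000 40000]
```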
Splitting and Visualizing the Training Data:
- Splitting Data into Train and Test Sets:
In [20]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
- Splits the data into training (75%) and testing (25%) sets using train_test_split(). Setting random_state=0 makes the split reproducible.
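train_test_split also accepts a stratify argument, which keeps the share of each class the same in the training and test sets; this matters when one class is less common. The sketch below uses synthetic data (400 rows with 35% positives, chosen only for illustration), not the Social Network Ads file.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in data: 400 rows, 2 features, 260 zeros and 140 ones
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 2))
y = np.array([0] * 260 + [1] * 140)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

print(len(X_train), len(X_test))      # 300 100
print(y_train.mean(), y_test.mean())  # both about 0.35, the positive rate of y
```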
- Visualizing Training Data:
In [21]:
plt.scatter(X_train[y_train==0,0], X_train[y_train==0,1], label='No')
plt.scatter(X_train[y_train==1,0], X_train[y_train==1,1], label='Yes')
plt.legend()
plt.show()
- Visualizes the training data for both classes (No and Yes) in a scatter plot.
- Visualizing Test Data:
In [22]:
plt.scatter(X_test[y_test==0,0], X_test[y_test==0,1], label='No')
plt.scatter(X_test[y_test==1,0], X_test[y_test==1,1], label='Yes')
plt.legend()
plt.show()
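- Feature Scaling:
In [23]:
from sklearn.preprocessing import StandardScaler
ss = StandardScaler()
X_train = ss.fit_transform(X_train)
X_test = ss.transform(X_test)
- StandardScaler() standardizes each feature to a mean of 0 and a standard deviation of 1. Age (tens of years) and salary (tens of thousands) are on very different scales; standardizing them helps the solver converge and keeps the default regularization from penalizing the features unequally.
- The scaler is fitted on the training set only and then applied, with the same mean and standard deviation, to the test set. This keeps information from the test set out of training.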
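What fit_transform and transform compute can be reproduced by hand. A minimal NumPy sketch with made-up numbers, assuming the population standard deviation (ddof=0), which is what StandardScaler uses:

```python
import numpy as np

# Made-up (Age, Salary) rows, for illustration only
X_train = np.array([[20.0, 20000.0],
                    [30.0, 50000.0],
                    [40.0, 80000.0]])
X_test = np.array([[35.0, 65000.0]])

# "fit": learn per-column mean and standard deviation from the training set only
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)   # ddof=0 (population standard deviation)

# "transform": apply the same training statistics to both sets
X_train_scaled = (X_train - mean) / std
X_test_scaled = (X_test - mean) / std

print(X_train_scaled.mean(axis=0))  # approximately [0, 0]
print(X_train_scaled.std(axis=0))   # approximately [1, 1]
print(X_test_scaled)                # about [[0.612, 0.612]]
```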
- Training Logistic Regression Model:
In [24]:
from sklearn.linear_model import LogisticRegression
classifier = LogisticRegression()
classifier.fit(X_train, y_train)
Out[24]:
LogisticRegression()
- A logistic regression model is trained on the training data X_train and y_train.
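After fitting, the model's linear equation is stored in coef_ (the weights) and intercept_ (the bias). The sketch below, on synthetic standardized data rather than the notebook's dataset, checks that computing z = w1*x1 + w2*x2 + b and passing it through the sigmoid reproduces predict_proba for a binary LogisticRegression with default settings.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic, already standardized data standing in for (Age, Salary)
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 2))
noise = rng.normal(scale=0.5, size=200)
y = (X[:, 0] + 0.5 * X[:, 1] + noise > 0).astype(int)

clf = LogisticRegression().fit(X, y)

# The learned linear equation z = w1*x1 + w2*x2 + b
w = clf.coef_[0]       # weights, one per feature
b = clf.intercept_[0]  # bias
z = X @ w + b

# Passing z through the sigmoid gives the probability of class 1 ("Yes")
p_manual = 1.0 / (1.0 + np.exp(-z))
p_sklearn = clf.predict_proba(X)[:, 1]

print(np.allclose(z, clf.decision_function(X)))  # True
print(np.allclose(p_manual, p_sklearn))          # True
```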
- Making Predictions:
In [25]:
y_pred = classifier.predict(X_test)
- The model predicts a class (0 or 1) for each sample in the test data X_test.
- Evaluating the Model:
In [26]:
classifier.score(X_test, y_test) * 100
Out[26]:
89.0
- The model’s accuracy on the test data is computed as a percentage: 89.0 means 89 of the 100 test samples were classified correctly.
- Confusion Matrix:
In [31]:
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
cm
Out[31]:
array([[65, 3], [ 8, 24]], dtype=int64)
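- A confusion matrix compares the actual labels (rows) with the predicted labels (columns). Here, 65 non-buyers and 24 buyers are classified correctly (true negatives and true positives), while 3 non-buyers are wrongly predicted to buy (false positives) and 8 buyers are missed (false negatives).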
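For a binary problem, scikit-learn lays out the confusion matrix as [[TN, FP], [FN, TP]], so other common metrics follow directly from the four counts. Using the counts from the output above:

```python
import numpy as np

# Confusion matrix from the notebook: rows = actual class, columns = predicted class
cm = np.array([[65, 3],
               [8, 24]])
tn, fp, fn, tp = cm.ravel()

accuracy = (tp + tn) / cm.sum()   # (24 + 65) / 100 = 0.89
precision = tp / (tp + fp)        # 24 / 27, about 0.889: predicted buyers who really bought
recall = tp / (tp + fn)           # 24 / 32 = 0.75: real buyers the model found
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean, about 0.814

print(accuracy, precision, recall, f1)
```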
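- Accuracy Score:
In [30]:
from sklearn.metrics import accuracy_score
accuracy_score(y_test,y_pred)*100
Out[30]:
89.0
- accuracy_score() compares y_test with y_pred and returns the fraction of correct predictions. It matches classifier.score(): (65 + 24) / 100 = 89%.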
- Predicting on a Single Input:
In [28]:
val = ss.transform([[36, 50000]])
classifier.predict(val)
Out[28]:
array([0], dtype=int64)
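- The model predicts whether a 36-year-old earning $50,000 will purchase the product; the output 0 means “No.” The raw input must first go through the same fitted scaler, ss.transform(), because the model was trained on standardized features.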
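Forgetting to apply the scaler before predict is an easy mistake. One way to avoid it is scikit-learn's make_pipeline, which chains the scaler and the model so raw values can be passed in directly. The sketch trains on synthetic (Age, Salary) data with a made-up purchase rule, so its prediction for [36, 50000] says nothing about the real dataset.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for (Age, EstimatedSalary) with a made-up purchase rule
rng = np.random.default_rng(0)
age = rng.integers(18, 60, size=300)
salary = rng.integers(15000, 150000, size=300)
X = np.column_stack([age, salary]).astype(float)
y = ((age > 42) | (salary > 110000)).astype(int)

# The pipeline fits StandardScaler and LogisticRegression together and
# reapplies the same scaling at prediction time
model = make_pipeline(StandardScaler(), LogisticRegression())
model.fit(X, y)

new_person = np.array([[36, 50000]], dtype=float)
print(model.predict(new_person))        # predicted class: 0 ("No") or 1 ("Yes")
print(model.predict_proba(new_person))  # [[P(No), P(Yes)]]
```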
Decision Boundary Plot:
- Meshgrid for Decision Boundary:
1. Plotting x1 and y1:
In [36]:
x1 = np.array([1,2,3,4,5])
y1 = np.array([11,22,33,44,55])
In [37]:
x1
Out[37]:
array([1, 2, 3, 4, 5])
In [38]:
y1
Out[38]:
array([11, 22, 33, 44, 55])
In [39]:
plt.plot(x1,y1,marker='o')
plt.show()
Creating Arrays:
- x1 = np.array([1, 2, 3, 4, 5]): Creates an array x1 with values [1, 2, 3, 4, 5].
- y1 = np.array([11, 22, 33, 44, 55]): Creates an array y1 with values [11, 22, 33, 44, 55].
Plotting x1 vs y1:
- plt.plot(x1, y1, marker='o'): Plots x1 against y1 and marks each data point with a circle ('o').
- plt.show(): Displays the plot.
2. Creating a Meshgrid:
In [40]:
xx, yy = np.meshgrid(x1, y1)
In [41]:
xx
Out[41]:
array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
In [42]:
yy
Out[42]:
array([[11, 11, 11, 11, 11], [22, 22, 22, 22, 22], [33, 33, 33, 33, 33], [44, 44, 44, 44, 44], [55, 55, 55, 55, 55]])
np.meshgrid(x1, y1):
- Creates two 2D arrays, xx and yy, from the 1D arrays x1 and y1. Every row of xx is a copy of x1, and every column of yy is a copy of y1.
- Together, xx[i, j] and yy[i, j] give the x- and y-coordinates of one point on a 5 x 5 grid, so the grid contains 25 points.
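The structure of xx and yy can be checked directly: rows of xx repeat x1, columns of yy repeat y1, and reading both arrays at the same position gives one grid point.

```python
import numpy as np

x1 = np.array([1, 2, 3, 4, 5])
y1 = np.array([11, 22, 33, 44, 55])
xx, yy = np.meshgrid(x1, y1)

# Both grids have shape (len(y1), len(x1))
print(xx.shape, yy.shape)         # (5, 5) (5, 5)

# Every row of xx equals x1, and every column of yy equals y1
print((xx == x1).all())           # True
print((yy == y1[:, None]).all())  # True

# Position (row 2, column 3) is the grid point x = 4, y = 33
print(xx[2, 3], yy[2, 3])         # 4 33
```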
3. Plotting the Meshgrid:
In [43]:
plt.plot(xx, yy)
plt.show()
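- plt.plot(xx, yy): When given 2D arrays, plt.plot draws one line per column. Each column of xx holds a single x-value, so this draws five vertical lines at x = 1, 2, 3, 4, 5, each running through the five grid points in that column. Passing marker='o', or using plt.scatter(xx, yy), makes all 25 grid points visible.
- plt.show(): Displays the grid plot.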
4. Plotting Training Data:
In [44]:
X_set, y_set = X_train, y_train
In [45]:
plt.title("Social Network Ads by Age and Salary")
plt.xlabel("Age")
plt.ylabel("Salary")
plt.scatter(X_set[y_set==0, 0], X_set[y_set==0, 1], label='No')
plt.scatter(X_set[y_set==1, 0], X_set[y_set==1, 1], label='Yes')
plt.legend()
plt.show()
Assigning Training Data:
- X_set, y_set = X_train, y_train: Assigns the training features X_train and labels y_train to X_set and y_set, so the same plotting code can later be reused for the test set.
Plotting Scatter Plot:
- plt.scatter(X_set[y_set==0, 0], X_set[y_set==0, 1], label='No'): Plots the points where y_set == 0 (no purchase).
- plt.scatter(X_set[y_set==1, 0], X_set[y_set==1, 1], label='Yes'): Plots the points where y_set == 1 (purchase).
Plot Titles and Labels:
- Titles the plot “Social Network Ads by Age and Salary” and labels the x-axis “Age” and the y-axis “Salary”. Because X_train was standardized, both axes are in standardized units, not years or dollars.
Legend and Show:
- plt.legend(): Adds a legend to distinguish between “No” and “Yes”.
- plt.show(): Displays the plot.
5. Defining X1 and X2 for Contour Plot:
In [48]:
X_set[:,0].min()
Out[48]:
-1.9931891594584856
In [49]:
X_set[:,0].max()
Out[49]:
2.166165495920269
In [50]:
X1 = np.arange(X_set[:,0].min()-1, X_set[:,0].max()+1, 0.01)
X2 = np.arange(X_set[:,1].min()-1, X_set[:,1].max()+1, 0.01)
In [52]:
X1
Out[52]:
array([-2.99318916, -2.98318916, -2.97318916, ..., 3.13681084, 3.14681084, 3.15681084])
In [53]:
X2
Out[53]:
array([-2.58254245e+00, -2.57254245e+00, -2.56254245e+00, ..., 3.30745755e+00, 3.31745755e+00, 3.32745755e+00])
X1 and X2 Arrays:
- X1: An array of values from X_set[:, 0].min() - 1 (the minimum of feature 1, Age, minus 1) up to X_set[:, 0].max() + 1 (the maximum plus 1), in steps of 0.01. The margin of 1 extends the plot slightly beyond the data, and because the features are standardized, a step of 0.01 gives a fine grid. X1 has 616 values.
- X2: The same construction for feature 2 (Salary); X2 has 592 values.
6. Meshgrid for X1 and X2:
In [54]:
xx, yy = np.meshgrid(X1, X2)
np.meshgrid(X1, X2):
- Creates two 2D arrays, xx and yy, from the 1D arrays X1 and X2. Together they define the grid of points on which the classifier is evaluated to draw the decision boundary.
In [55]:
xx
Out[55]:
array([[-2.99318916, -2.98318916, -2.97318916, ..., 3.13681084, 3.14681084, 3.15681084], [-2.99318916, -2.98318916, -2.97318916, ..., 3.13681084, 3.14681084, 3.15681084], [-2.99318916, -2.98318916, -2.97318916, ..., 3.13681084, 3.14681084, 3.15681084], ..., [-2.99318916, -2.98318916, -2.97318916, ..., 3.13681084, 3.14681084, 3.15681084], [-2.99318916, -2.98318916, -2.97318916, ..., 3.13681084, 3.14681084, 3.15681084], [-2.99318916, -2.98318916, -2.97318916, ..., 3.13681084, 3.14681084, 3.15681084]])
In [56]:
yy
Out[56]:
array([[-2.58254245, -2.58254245, -2.58254245, ..., -2.58254245, -2.58254245, -2.58254245], [-2.57254245, -2.57254245, -2.57254245, ..., -2.57254245, -2.57254245, -2.57254245], [-2.56254245, -2.56254245, -2.56254245, ..., -2.56254245, -2.56254245, -2.56254245], ..., [ 3.30745755, 3.30745755, 3.30745755, ..., 3.30745755, 3.30745755, 3.30745755], [ 3.31745755, 3.31745755, 3.31745755, ..., 3.31745755, 3.31745755, 3.31745755], [ 3.32745755, 3.32745755, 3.32745755, ..., 3.32745755, 3.32745755, 3.32745755]])
In [57]:
xx.shape
Out[57]:
(592, 616)
In [58]:
yy.shape
Out[58]:
(592, 616)
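- The shape is (len(X2), len(X1)) = (592, 616): one row per value of X2 and one column per value of X1, for 592 x 616 = 364,672 grid points.
7. Reshaping and Predicting: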
In [61]:
XX = np.array([xx.ravel(), yy.ravel()]).T
XX
Out[61]:
array([[-2.99318916, -2.58254245], [-2.98318916, -2.58254245], [-2.97318916, -2.58254245], ..., [ 3.13681084, 3.32745755], [ 3.14681084, 3.32745755], [ 3.15681084, 3.32745755]])
In [62]:
zz = classifier.predict(XX).reshape(xx.shape)
zz
Out[62]:
array([[0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1], ..., [0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1]], dtype=int64)
Flattening and Combining Coordinates:
- xx.ravel() and yy.ravel() flatten the xx and yy grids into 1D arrays.
- np.array([xx.ravel(), yy.ravel()]).T creates an array XX in which each row is one (Age, Salary) coordinate pair, in standardized units, to be passed to the classifier.
Predicting for All Grid Points:
- classifier.predict(XX): Predicts a label for every grid point (every Age and Salary combination).
- reshape(xx.shape): Reshapes the predicted labels back into the shape of the xx grid, so they can be visualized as a 2D grid.
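Steps 5 to 7 can be wrapped in one function that works with any model exposing a predict method. predict_on_grid and toy_rule below are hypothetical helpers (not part of any library); a hand-written rule stands in for the trained classifier so the sketch runs on its own.

```python
import numpy as np

def predict_on_grid(predict, X_set, step=0.01, margin=1.0):
    """Evaluate predict() on a regular grid covering X_set plus a margin.

    predict: any function mapping an (n, 2) array to n labels,
             for example classifier.predict.
    Returns xx, yy (grid coordinates) and zz (labels), all with the same 2D shape.
    """
    X1 = np.arange(X_set[:, 0].min() - margin, X_set[:, 0].max() + margin, step)
    X2 = np.arange(X_set[:, 1].min() - margin, X_set[:, 1].max() + margin, step)
    xx, yy = np.meshgrid(X1, X2)
    # One (x, y) row per grid point; np.c_ gives the same result as np.array([...]).T
    grid = np.c_[xx.ravel(), yy.ravel()]
    # ravel() and reshape() both use row-major order, so each label returns to its cell
    zz = predict(grid).reshape(xx.shape)
    return xx, yy, zz

def toy_rule(points):
    # Hypothetical stand-in for classifier.predict: 1 ("Yes") when x + y > 0
    return (points[:, 0] + points[:, 1] > 0).astype(int)

X_set = np.array([[-1.0, -1.0], [1.0, 1.0]])
xx, yy, zz = predict_on_grid(toy_rule, X_set, step=0.1)
print(xx.shape == zz.shape)                           # True
print(np.array_equal(zz, (xx + yy > 0).astype(int)))  # True
```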
8. Plotting the Decision Boundary:
In [60]:
plt.contourf(xx, yy, zz)
plt.show()
- plt.contourf(xx, yy, zz): Creates a filled contour plot of the predictions zz. Regions where the classifier predicts “No” and “Yes” are shaded in different colors, and the edge between them is the decision boundary.
- plt.show(): Displays the contour plot of the logistic regression classifier's decision regions.
- np.meshgrid() creates a grid of points over the feature space (Age, Salary).
- The model predicts a label for each grid point, and the labels (zz) are reshaped to match the grid shape.
- Plotting Decision Boundary:
In [65]:
plt.contourf(xx, yy, zz)
plt.scatter(X_set[y_set==0,0], X_set[y_set==0,1], label='No')
plt.scatter(X_set[y_set==1,0], X_set[y_set==1,1], label='Yes')
plt.legend()
plt.show()
- A filled contour plot (plt.contourf()) shows the decision regions separating the No and Yes classes.
- The scatter plot overlays the actual training data points on top of the decision regions.
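With two features, the decision boundary of logistic regression is always a straight line: the predicted probability is exactly 0.5 where z = w1*x1 + w2*x2 + b = 0, which gives x2 = -(w1*x1 + b) / w2 (assuming w2 is not 0). The sketch below fits a model on synthetic data, not the notebook's dataset, and checks that points on that line get probability 0.5.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic standardized data standing in for (Age, Salary)
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 2))
y = (X[:, 0] + X[:, 1] + rng.normal(scale=0.3, size=300) > 0).astype(int)
clf = LogisticRegression().fit(X, y)

w1, w2 = clf.coef_[0]
b = clf.intercept_[0]

# On the boundary z = w1*x1 + w2*x2 + b = 0; solve for x2 (assumes w2 != 0)
x1_line = np.linspace(-3, 3, 50)
x2_line = -(w1 * x1_line + b) / w2

# Every point on the line should get a probability of 0.5
on_line = np.column_stack([x1_line, x2_line])
print(clf.predict_proba(on_line)[:, 1].round(3))
```

With the notebook's classifier, the same line computed from classifier.coef_ and classifier.intercept_ could be drawn with plt.plot(x1_line, x2_line) on top of the contour plot.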
Testing:
- Visualizing Decision Boundary on Test Data:
In [66]:
plt.contourf(xx, yy, zz)
plt.scatter(X_test[y_test==0, 0], X_test[y_test==0, 1], label='No')
plt.scatter(X_test[y_test==1, 0], X_test[y_test==1, 1], label='Yes')
plt.legend()
plt.show()
- This step checks the trained model visually by drawing the test data points on top of the decision regions computed from the training-data grid. X_set and y_set still hold the training data at this point, so the test arrays are referenced directly.
Testing Part
In [67]:
X_set,y_set = X_test, y_test
plt.title("Social Network Ads by Age and Salary")
plt.xlabel("Age")
plt.ylabel("Salary")
X1 = np.arange(X_set[:,0].min()-1,X_set[:,0].max()+1,0.01)
X2 = np.arange(X_set[:,1].min()-1,X_set[:,1].max()+1,0.01)
xx,yy = np.meshgrid(X1,X2)
X3 = np.array([xx.ravel(),yy.ravel()]).T
zz = classifier.predict(X3).reshape(xx.shape)
plt.contourf(xx,yy,zz)
plt.scatter(X_set[y_set==0,0],X_set[y_set==0,1],label='No')
plt.scatter(X_set[y_set==1,0],X_set[y_set==1,1],label='Yes')
plt.legend()
plt.show()
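- Repeats the decision-boundary steps for the test set: the grid is rebuilt from the range of the test data, every grid point is classified, and the test points are drawn over the shaded No and Yes regions.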
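The training and test plots repeat the same grid, predict, and scatter code. A sketch of a reusable function follows; plot_decision_regions is a hypothetical name, and the alpha=0.3 shading is a choice made here so the points stay visible over the regions.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_decision_regions(classifier, X_set, y_set, title, step=0.01):
    """Shade the classifier's predicted regions and overlay the labeled points.

    Assumes X_set has two (standardized) feature columns and y_set holds 0/1 labels.
    Returns the figure, the axes, and the grid of predicted labels.
    """
    # Grid covering the data plus a margin of 1 on each side, as in the notebook
    X1 = np.arange(X_set[:, 0].min() - 1, X_set[:, 0].max() + 1, step)
    X2 = np.arange(X_set[:, 1].min() - 1, X_set[:, 1].max() + 1, step)
    xx, yy = np.meshgrid(X1, X2)
    zz = classifier.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    fig, ax = plt.subplots()
    ax.contourf(xx, yy, zz, alpha=0.3)  # light shading keeps the points visible
    ax.scatter(X_set[y_set == 0, 0], X_set[y_set == 0, 1], label="No")
    ax.scatter(X_set[y_set == 1, 0], X_set[y_set == 1, 1], label="Yes")
    ax.set_title(title)
    ax.set_xlabel("Age (standardized)")
    ax.set_ylabel("Salary (standardized)")
    ax.legend()
    return fig, ax, zz

# Usage, assuming classifier, X_train, y_train, X_test, y_test from the cells above:
# plot_decision_regions(classifier, X_train, y_train, "Training set")
# plot_decision_regions(classifier, X_test, y_test, "Test set")
# plt.show()
```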