1. What is Decision Tree Regression?¶
- Decision Tree Regression is a method for predicting a continuous value (like a house price) from input features (like size or number of rooms). It uses a tree-like model of decisions to arrive at a prediction.
2. How Decision Trees Work¶
- The model splits the data into branches based on different features. Each branch leads to a decision about the output value.
- Think of it like a flowchart: starting from a question at the top, you follow the branches based on the answers until you reach a final prediction, as the sketch below illustrates.
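To make the flowchart intuition concrete, here is a minimal hand-written sketch of what a fitted regression tree computes. The thresholds and leaf values are invented for illustration; a real tree learns them from the training data.

def predict_price(size_sqft, bedrooms):
    # Each "if" is a split node; each "return" is a leaf whose value is
    # roughly the mean price of the training houses that reached it.
    if size_sqft <= 1200:
        if bedrooms <= 2:
            return 150_000
        return 180_000
    if size_sqft <= 2500:
        return 260_000
    return 400_000

print(predict_price(1000, 2))  # prints 150000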
3. Collect Your Data¶
- Gather the dataset you want to use. For example, if predicting house prices:
- Features might include size, number of bedrooms, age, etc.
- The target variable is the price of the house. A toy layout with these columns is sketched below.
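As a concrete sketch, the features and target could sit in a pandas DataFrame like the one below. The column names and every value are invented for illustration, not a real dataset.

import pandas as pd

# Hypothetical housing data: all values are made up for illustration.
housing = pd.DataFrame({
    "size":     [1400, 1600, 1700, 1875, 1100],   # square feet
    "bedrooms": [3, 3, 4, 4, 2],
    "age":      [10, 15, 5, 2, 30],               # years
    "price":    [245000, 312000, 279000, 308000, 199000],
})
X = housing[["size", "bedrooms", "age"]]  # feature matrix
y = housing["price"]                      # target variable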
4. Split the Data¶
- Divide your dataset into two parts:
- Training set: For training the decision tree.
- Test set: For checking how well the tree performs on new data (a scikit-learn split is sketched below).
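A common way to do this, assuming the X and y from the previous step, is scikit-learn's train_test_split:

from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for testing; random_state makes the split reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.2, random_state=42)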
5. Build the Decision Tree¶
- Use a machine learning library to create and train the decision tree regression model with your training data.
- The model will automatically decide how to split the data at each node based on the feature that provides the best prediction.
from sklearn.tree import DecisionTreeRegressor
# Create the decision tree model
model = DecisionTreeRegressor()
model.fit(X_train, Y_train) # Train the model
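By default the tree keeps splitting until it fits the training data as closely as possible, which can overfit. A common adjustment, sketched here rather than prescribed, is to cap the tree's depth:

# max_depth limits how many splits deep the tree can grow (regularization);
# random_state makes training reproducible.
model = DecisionTreeRegressor(max_depth=3, random_state=0)
model.fit(X_train, Y_train)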
6. Make Predictions¶
- Use the trained model to predict values for your test set or new data.
Y_pred = model.predict(X_test)
7. Evaluate the Model¶
- Check how well your model performed by comparing the predicted values to the actual values using metrics like:
- Mean Squared Error (MSE): The average of the squared differences between predicted and actual values; lower is better.
- R-squared (R²): The proportion of the variability in the target variable that the model explains; 1.0 is a perfect fit.
from sklearn.metrics import mean_squared_error, r2_score
mse = mean_squared_error(Y_test, Y_pred)
r2 = r2_score(Y_test, Y_pred)
print("Mean Squared Error:", mse)
print("R-squared:", r2)
8. Visualize the Decision Tree (Optional)¶
- You can visualize the decision tree to understand how decisions are made. This helps in understanding the model better.
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(12,8))
plot_tree(model, filled=True)
plt.show()
9. Analyze Errors¶
- Look at the errors (differences between predicted and actual values) to see if the model is doing well or if it needs adjustments, as in the residual sketch below.
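One simple check, sketched here using the Y_test and Y_pred from the earlier steps, is to plot the residuals against the predictions; visible structure in the residuals suggests the model is missing something:

import numpy as np
import matplotlib.pyplot as plt

residuals = Y_test - Y_pred                  # positive means the model under-predicted
print("Mean absolute error:", np.mean(np.abs(residuals)))

plt.scatter(Y_pred, residuals)
plt.axhline(0, color="red", linestyle="--")  # perfect predictions sit on this line
plt.xlabel("Predicted value")
plt.ylabel("Residual (actual - predicted)")
plt.show()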
10. Use the Model for Future Predictions¶
- After validating your model, you can use it to make predictions for new inputs.
new_data = [[1500, 3, 10]]  # illustrative values: size = 1500 sq ft, 3 bedrooms, 10 years old
predictions = model.predict(new_data)
11. Conclusion¶
- Summarize how well the decision tree regression model performed and what you learned from its predictions. Discuss which features mattered most and how the splits were made, for example by inspecting the tree's feature importances as sketched below.
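One way to ground that discussion is the fitted tree's feature_importances_ attribute, which reports how much each feature contributed to reducing the squared error across all splits. The feature names below assume the illustrative dataset from step 3:

# Importances sum to 1; a larger value means the feature drove more of the splits.
for name, importance in zip(["size", "bedrooms", "age"], model.feature_importances_):
    print(f"{name}: {importance:.2f}")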
Let’s review an example step by step.¶
Decision Tree Regression¶
We use decision trees because they are easy to understand and interpret, and they can handle both numerical and categorical data.¶
Import the Libraries
In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Make a List or Read the Data
In [3]:
l = [[1,45],[2,51],[3,60],[4,80],[5,110],[6,150],[7,200],[8,240]]
l
Out[3]:
[[1, 45], [2, 51], [3, 60], [4, 80], [5, 110], [6, 150], [7, 200], [8, 240]]
Convert List into DataFrame
In [4]:
df = pd.DataFrame(l,columns=['x','y'])
df
Out[4]:
|   | x | y |
|---|---|---|
| 0 | 1 | 45 |
| 1 | 2 | 51 |
| 2 | 3 | 60 |
| 3 | 4 | 80 |
| 4 | 5 | 110 |
| 5 | 6 | 150 |
| 6 | 7 | 200 |
| 7 | 8 | 240 |
Extract the values of x and y
In [5]:
x = df.iloc[:,:1].values
x
Out[5]:
array([[1], [2], [3], [4], [5], [6], [7], [8]], dtype=int64)
In [6]:
y = df.iloc[:,1].values
y
Out[6]:
array([ 45, 51, 60, 80, 110, 150, 200, 240], dtype=int64)
Plot a scatter of x and y
In [7]:
plt.scatter(x,y)
plt.show()
Apply the Algorithm
In [8]:
from sklearn.tree import DecisionTreeRegressor
reg = DecisionTreeRegressor(random_state=0)
reg.fit(x,y)
Out[8]:
DecisionTreeRegressor(random_state=0)
Predict y
In [9]:
y_pred = reg.predict(x)
y_pred
Out[9]:
array([ 45., 51., 60., 80., 110., 150., 200., 240.])
In [10]:
y
Out[10]:
array([ 45, 51, 60, 80, 110, 150, 200, 240], dtype=int64)
Plot the scatter of x and y together with the predicted line
In [11]:
plt.scatter(x,y)
plt.plot(x,y_pred)
plt.show()
Check the R² score on the training data
Note: score returns R², and an unconstrained tree memorizes every training point, so the training score is a perfect 100. It says nothing about performance on unseen data.
In [12]:
reg.score(x,y)*100
Out[12]:
100.0
Predict y for new values of x
A new x is routed down the tree to the leaf whose region contains it and receives that leaf's training value: 6.5 lands in the leaf of x = 6, giving 150, and 5.5 lands in the leaf of x = 5, giving 110.
In [13]:
reg.predict([[6.5]])
Out[13]:
array([150.])
In [14]:
reg.predict([[5.5]])
Out[14]:
array([110.])
In [15]:
x
Out[15]:
array([[1], [2], [3], [4], [5], [6], [7], [8]], dtype=int64)
In [16]:
X = np.arange(x.min(), x.max(), 0.01).reshape(-1,1)  # .min()/.max() avoid the deprecation warning that min() raises on a 2-D array
X
Out[16]:
array([[1.  ], [1.01], [1.02], ..., [7.97], [7.98], [7.99]])
(700 rows: a grid from 1.00 to 7.99 in steps of 0.01, truncated here for readability.)
In [17]:
Yp = reg.predict(X)
Yp
Out[17]:
array([ 45.,  45.,  45., ..., 240., 240., 240.])
(One constant plateau per leaf: 45, 51, 60, 80, 110, 150, 200, 240; truncated here for readability.)
Plot the step-function fit over the fine grid
In [18]:
plt.scatter(x,y)
plt.plot(X,Yp)
plt.show()
The fine grid makes the piecewise-constant nature of tree regression visible: the prediction is flat within each leaf region and jumps at the split boundaries.