WEBVTT
1
00:00:02.660 --> 00:00:04.020
In this guide,
2
00:00:04.020 --> 00:00:05.950
we're gonna be talking about how we can use
3
00:00:05.950 --> 00:00:08.890
a random forest for classification.
4
00:00:08.890 --> 00:00:10.590
In terms of how it works,
5
00:00:10.590 --> 00:00:14.000
there isn't really anything new that we need to talk about.
6
00:00:14.000 --> 00:00:15.970
Because the same concepts that we used
7
00:00:15.970 --> 00:00:18.720
for the regression forest can be applied
8
00:00:18.720 --> 00:00:20.473
to a classification forest.
9
00:00:22.070 --> 00:00:24.260
But to do a quick recap,
10
00:00:24.260 --> 00:00:28.490
the whole idea behind a random forest is power in numbers.
11
00:00:28.490 --> 00:00:31.500
Because it's essentially just a bunch of decision trees
12
00:00:31.500 --> 00:00:32.623
combined together.
13
00:00:33.730 --> 00:00:36.630
And what really makes a forest so powerful
14
00:00:36.630 --> 00:00:39.300
is that instead of using just one algorithm,
15
00:00:39.300 --> 00:00:41.950
it uses a whole ensemble of trees to help increase
16
00:00:41.950 --> 00:00:44.983
the predictive power while reducing overfitting.
17
00:00:46.350 --> 00:00:48.650
Then to ensure all the trees are acting
18
00:00:48.650 --> 00:00:50.720
as independently as possible,
19
00:00:50.720 --> 00:00:54.760
a technique called bootstrap aggregation, or bagging, is used.
20
00:00:54.760 --> 00:00:58.240
And the goal of bagging is to make sure that each tree
21
00:00:58.240 --> 00:01:00.940
is responsible for taking its own randomly chosen
22
00:01:00.940 --> 00:01:03.330
sample of training data.
23
00:01:03.330 --> 00:01:05.760
Once a tree is done with its random sampling,
24
00:01:05.760 --> 00:01:08.040
it puts whatever information it used
25
00:01:08.040 --> 00:01:09.900
back into the training set,
26
00:01:09.900 --> 00:01:13.163
and then the next tree gathers its own random sample.
27
00:01:14.320 --> 00:01:16.080
Doing it that way
28
00:01:16.080 --> 00:01:19.570
allows each tree to stay as independent as possible,
29
00:01:19.570 --> 00:01:21.893
which helps ensure varying results.
30
00:01:23.350 --> 00:01:24.830
Along with bootstrapping,
31
00:01:24.830 --> 00:01:28.330
a random forest will also utilize feature randomness.
32
00:01:28.330 --> 00:01:31.073
And that's to add even more variation.
33
00:01:32.010 --> 00:01:34.230
So instead of using every feature,
34
00:01:34.230 --> 00:01:36.080
like we did in the last guide,
35
00:01:36.080 --> 00:01:39.130
a random forest only allows the trees to partition
36
00:01:39.130 --> 00:01:42.799
based on a handful of randomly selected features.
37
00:01:42.799 --> 00:01:46.790
And ideally, that results in a lower correlation
38
00:01:46.790 --> 00:01:50.083
from tree to tree, and helps create more diversity.
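As a rough illustration of those two ideas, here's a minimal sketch in plain NumPy of what one tree's bootstrap sample and split-time feature subset might look like; the dataset sizes here are made up, and real forests handle this internally:

```python
import numpy as np

rng = np.random.default_rng(42)

n_samples, n_features = 500, 10          # hypothetical training set dimensions
max_features = int(np.sqrt(n_features))  # a common default for classification

# Bootstrapping: each tree draws its own sample *with replacement*,
# so rows go back into the pool before the next tree samples.
bootstrap_rows = rng.choice(n_samples, size=n_samples, replace=True)

# Feature randomness: at each split, only a random subset of features
# is considered as a candidate for partitioning.
candidate_features = rng.choice(n_features, size=max_features, replace=False)

print(bootstrap_rows[:10], candidate_features)
```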
39
00:01:51.470 --> 00:01:54.520
Now, when it comes to building a random forest model,
40
00:01:54.520 --> 00:01:56.630
it's really not any different from what we did
41
00:01:56.630 --> 00:01:58.000
in the last guide,
42
00:01:58.000 --> 00:02:00.493
with the obvious exception of the algorithm.
43
00:02:01.720 --> 00:02:05.110
So what I'll do is just run through the code with you,
44
00:02:05.110 --> 00:02:06.980
but what I'd really like to focus on
45
00:02:06.980 --> 00:02:08.390
for the rest of the guide,
46
00:02:08.390 --> 00:02:11.180
is how we can assess a classification model
47
00:02:11.180 --> 00:02:13.803
by using something called the ROC curve.
48
00:02:16.980 --> 00:02:18.030
First and foremost,
49
00:02:18.030 --> 00:02:21.510
you can see that everything I did was pretty standard.
50
00:02:21.510 --> 00:02:24.740
I started off by importing the core libraries
51
00:02:24.740 --> 00:02:27.990
along with the train_test_split function,
52
00:02:27.990 --> 00:02:29.840
and the random forest class
53
00:02:29.840 --> 00:02:32.023
from the sklearn.ensemble module.
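For reference, that setup would look roughly like this; which "core libraries" get imported may differ slightly, and the matplotlib import is my addition since we'll plot the ROC curve later:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
```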
54
00:02:34.450 --> 00:02:37.720
We're using the same tumor data frame as before,
55
00:02:37.720 --> 00:02:41.550
then it's segmented into feature and target variables,
56
00:02:41.550 --> 00:02:44.193
and applied to the train test split function.
57
00:02:45.950 --> 00:02:47.950
The one small change I did make,
58
00:02:47.950 --> 00:02:49.743
was changing the class labels.
59
00:02:50.710 --> 00:02:52.950
So to align a little more closely
60
00:02:52.950 --> 00:02:54.890
with machine learning standards,
61
00:02:54.890 --> 00:02:59.053
I went ahead and switched the two and four to zero and one.
62
00:03:00.040 --> 00:03:03.360
Then the other change that I made was that by assigning
63
00:03:03.360 --> 00:03:05.260
the benign class to one,
64
00:03:05.260 --> 00:03:06.980
I'm signaling that I want that class
65
00:03:06.980 --> 00:03:08.353
to be the positive result.
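A sketch of that preparation step, assuming the tumor DataFrame is loaded as df and uses the original 2 (benign) / 4 (malignant) coding of the Wisconsin breast cancer data; the split ratio and variable names are my own:

```python
# Relabel so that benign becomes the positive class (1)
# and malignant becomes 0, per machine learning convention.
df['Class'] = df['Class'].map({2: 1, 4: 0})

# Segment into feature and target variables.
X = df.drop(columns='Class')
y = df['Class']

# Apply the train test split function.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)
```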
66
00:03:09.700 --> 00:03:12.530
Next up, we have the random forest classifier
67
00:03:12.530 --> 00:03:15.123
and that's assigned to the classifier object.
68
00:03:16.110 --> 00:03:19.180
Then we use the fit function along with the training data
69
00:03:19.180 --> 00:03:21.223
to create the random forest model.
70
00:03:22.320 --> 00:03:26.170
Then the last thing we need to do is check the accuracy.
71
00:03:26.170 --> 00:03:28.110
So to do that,
72
00:03:28.110 --> 00:03:30.570
all we had to do was call the score function
73
00:03:30.570 --> 00:03:34.453
with both of the test sets and assign the result to the accuracy object.
74
00:03:35.710 --> 00:03:37.223
And after we run it,
75
00:03:40.950 --> 00:03:44.363
you can see we end up with a pretty solid accuracy score.
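Those three steps, sketched out with the hyperparameters left at their defaults:

```python
# Build the forest and fit it on the training data.
classifier = RandomForestClassifier(random_state=42)
classifier.fit(X_train, y_train)

# score() predicts on the feature test set internally and
# compares the results against the target test set.
accuracy = classifier.score(X_test, y_test)
print(accuracy)
```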
76
00:03:45.560 --> 00:03:48.140
Now, before we get into the ROC curve,
77
00:03:48.140 --> 00:03:51.573
I'd like to do a quick refresher of the confusion matrix.
78
00:03:52.890 --> 00:03:55.100
Depending on the number of possible cases,
79
00:03:55.100 --> 00:03:57.690
a confusion matrix can vary in size,
80
00:03:57.690 --> 00:03:59.680
but for now we're really only working
81
00:03:59.680 --> 00:04:01.680
with binary classification.
82
00:04:01.680 --> 00:04:05.107
So we'll end up with a two by two matrix.
83
00:04:05.107 --> 00:04:08.610
The columns are going to represent the predicted labels.
84
00:04:08.610 --> 00:04:10.260
With the column on the left,
85
00:04:10.260 --> 00:04:12.630
containing all of the negative results
86
00:04:12.630 --> 00:04:14.530
and the column on the right containing
87
00:04:14.530 --> 00:04:16.163
all of the positive results.
88
00:04:17.200 --> 00:04:19.040
Then moving on to the rows,
89
00:04:19.040 --> 00:04:22.560
those are going to represent the true or correct label.
90
00:04:22.560 --> 00:04:26.200
And here, the top row contains the negative label,
91
00:04:26.200 --> 00:04:29.920
and the bottom row contains the positive label.
92
00:04:29.920 --> 00:04:33.270
So if an observation is accurately predicted
93
00:04:33.270 --> 00:04:35.300
to belong to the negative class,
94
00:04:35.300 --> 00:04:38.133
the result would be described as a true negative.
95
00:04:38.990 --> 00:04:42.560
But in contrast, if the observation was incorrectly labeled,
96
00:04:42.560 --> 00:04:45.153
we would consider that to be a false negative.
97
00:04:46.490 --> 00:04:48.930
And then the same basic principles will apply
98
00:04:48.930 --> 00:04:51.640
to the positive column on the right.
99
00:04:51.640 --> 00:04:54.160
If the predicted label turns out to be correct,
100
00:04:54.160 --> 00:04:55.960
it's a true positive,
101
00:04:55.960 --> 00:04:59.163
but if it's an incorrect result, it's a false positive.
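If you'd like to see that layout in code, scikit-learn's confusion_matrix happens to use exactly this convention, with rows as the true labels, columns as the predicted labels, and the negative class first:

```python
from sklearn.metrics import confusion_matrix

# Predicted labels for the feature test set.
y_pred = classifier.predict(X_test)

# Rows are the true labels, columns the predicted labels, and the
# negative class (0) comes first -- the same 2x2 layout described above.
cm = confusion_matrix(y_test, y_pred)
tn, fp, fn, tp = cm.ravel()
print(cm)
```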
102
00:05:00.320 --> 00:05:03.900
Now we talked about this a little bit in an earlier guide,
103
00:05:03.900 --> 00:05:07.540
but the two categories we obviously want to minimize,
104
00:05:07.540 --> 00:05:11.040
are the false negatives and the false positives.
105
00:05:11.040 --> 00:05:14.793
But generally speaking, it's impossible to minimize both.
106
00:05:16.060 --> 00:05:19.250
So the reality is, eventually you'll need to choose
107
00:05:19.250 --> 00:05:20.450
which category you think
108
00:05:20.450 --> 00:05:23.060
is gonna be more important to the model.
109
00:05:23.060 --> 00:05:24.940
And for the purpose of this example,
110
00:05:24.940 --> 00:05:28.230
what we really want to avoid are false positives,
111
00:05:28.230 --> 00:05:30.160
which happen when a malignant tumor
112
00:05:30.160 --> 00:05:32.563
is incorrectly classified as benign.
113
00:05:33.660 --> 00:05:36.230
And the reason I'm bringing all of this up again
114
00:05:36.230 --> 00:05:39.560
is because we can take advantage of the ROC curve
115
00:05:39.560 --> 00:05:42.230
to do a direct comparison of true positives
116
00:05:42.230 --> 00:05:43.553
and false positives.
117
00:05:44.750 --> 00:05:47.220
So moving into the ROC curve,
118
00:05:47.220 --> 00:05:50.880
by definition, the receiver operating characteristic
119
00:05:50.880 --> 00:05:54.910
or ROC curve is a graphical plot that illustrates
120
00:05:54.910 --> 00:05:58.430
the diagnostic ability of a binary classifier system
121
00:05:58.430 --> 00:06:00.963
as its discrimination threshold is varied.
122
00:06:01.890 --> 00:06:04.910
And basically what that means is that we'll have a graph
123
00:06:04.910 --> 00:06:07.670
that shows us how observations will be classified
124
00:06:07.670 --> 00:06:11.430
if we make changes to the decision boundary.
125
00:06:11.430 --> 00:06:13.640
And like most things that we've worked with,
126
00:06:13.640 --> 00:06:16.470
this is way easier to explain with a visual.
127
00:06:16.470 --> 00:06:19.470
So let's start off with this empty graph.
128
00:06:19.470 --> 00:06:23.270
The first thing to notice is that the false positive rate,
129
00:06:23.270 --> 00:06:26.640
which is the same thing as 1-specificity,
130
00:06:26.640 --> 00:06:29.500
is plotted along the X axis.
131
00:06:29.500 --> 00:06:31.190
And then along the Y axis,
132
00:06:31.190 --> 00:06:34.403
we have the true positive rate or sensitivity.
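Using the counts from the confusion matrix sketch above, those two axis quantities work out to:

```python
# y-axis: sensitivity, the share of actual positives the model catches.
tpr = tp / (tp + fn)

# x-axis: 1 - specificity, the share of actual negatives flagged positive.
fpr = fp / (fp + tn)
```

Keep in mind a single threshold gives just one (fpr, tpr) point; the ROC curve traces those points out as the threshold is swept.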
133
00:06:35.520 --> 00:06:37.700
The next characteristic of the ROC curve
134
00:06:37.700 --> 00:06:39.030
that you'll need to know,
135
00:06:39.030 --> 00:06:41.740
is that there's always gonna be a diagonal line
136
00:06:41.740 --> 00:06:44.579
that cuts right through the middle of the graph.
137
00:06:44.579 --> 00:06:47.310
And at every point along that line,
138
00:06:47.310 --> 00:06:49.770
the false positive and true positive rates
139
00:06:49.770 --> 00:06:51.003
are gonna be equal.
140
00:06:51.900 --> 00:06:54.010
And the reason that we have that line
141
00:06:54.010 --> 00:06:56.970
is to mimic what the ROC curve would look like
142
00:06:56.970 --> 00:06:59.630
if the model was just making blind guesses
143
00:06:59.630 --> 00:07:02.990
as to what class an observation belonged to.
144
00:07:02.990 --> 00:07:06.773
So it basically represents the worst case scenario.
145
00:07:08.290 --> 00:07:09.580
Now, in contrast,
146
00:07:09.580 --> 00:07:12.790
if you somehow manage to create the perfect model,
147
00:07:12.790 --> 00:07:15.600
the ROC curve will take on this form:
148
00:07:15.600 --> 00:07:19.560
it'll go directly up the Y axis and then go straight across
149
00:07:19.560 --> 00:07:21.873
running parallel with the X axis.
150
00:07:23.160 --> 00:07:25.580
We'll be talking about thresholds in more detail
151
00:07:25.580 --> 00:07:27.340
over the next few guides,
152
00:07:27.340 --> 00:07:31.070
but for now we're just gonna keep it fairly simple.
153
00:07:31.070 --> 00:07:32.350
So for this curve,
154
00:07:32.350 --> 00:07:35.640
the optimal threshold is gonna be right here,
155
00:07:35.640 --> 00:07:38.070
because this is the point where we maximize
156
00:07:38.070 --> 00:07:40.530
the true positive rate while eliminating
157
00:07:40.530 --> 00:07:42.113
all of the false positives.
158
00:07:44.140 --> 00:07:45.790
To be a little more realistic,
159
00:07:45.790 --> 00:07:49.170
let's say the first data point on the ROC curve
160
00:07:49.170 --> 00:07:51.553
is at (0.75, 1).
161
00:07:52.780 --> 00:07:54.880
Well, based on this location,
162
00:07:54.880 --> 00:07:57.630
which is to the left of this dotted line,
163
00:07:57.630 --> 00:08:00.690
we know the model was able to produce a greater proportion
164
00:08:00.690 --> 00:08:03.880
of true positives than false positives.
165
00:08:03.880 --> 00:08:05.560
And given that information,
166
00:08:05.560 --> 00:08:08.700
we can also make the assumption that the new threshold
167
00:08:08.700 --> 00:08:10.630
is better at making class predictions
168
00:08:10.630 --> 00:08:12.473
than the random guess threshold.
169
00:08:13.440 --> 00:08:15.460
If we increase the threshold again,
170
00:08:15.460 --> 00:08:17.770
and it produces this outcome,
171
00:08:17.770 --> 00:08:20.060
we see an even greater proportion
172
00:08:20.060 --> 00:08:23.090
of true positives to false positives.
173
00:08:23.090 --> 00:08:26.080
Which means out of the first three threshold choices,
174
00:08:26.080 --> 00:08:28.293
this one is gonna be the best so far.
175
00:08:29.200 --> 00:08:31.410
Now let's go ahead and pretend that we increase
176
00:08:31.410 --> 00:08:34.833
the threshold two more times and get these two results.
177
00:08:35.720 --> 00:08:38.210
Now, this turns out to be an important point.
178
00:08:38.210 --> 00:08:39.880
Because it's the first threshold
179
00:08:39.880 --> 00:08:42.343
where we have a false positive rate of zero.
180
00:08:43.740 --> 00:08:45.430
Eventually we'll end up with a model
181
00:08:45.430 --> 00:08:49.493
that produces zero true positives and zero false positives.
182
00:08:50.335 --> 00:08:53.720
Okay, so by looking at the final graph,
183
00:08:53.720 --> 00:08:56.770
we obviously don't have a threshold at (0, 1),
184
00:08:56.770 --> 00:08:58.960
like we had in the perfect model.
185
00:08:58.960 --> 00:09:02.715
But we do have a few thresholds with zero false positives.
186
00:09:02.715 --> 00:09:04.050
And out of those,
187
00:09:04.050 --> 00:09:07.675
we have just the one that produces the most true positives,
188
00:09:07.675 --> 00:09:09.200
which is gonna indicate
189
00:09:09.200 --> 00:09:11.723
that that's the optimum threshold for the model.
190
00:09:12.680 --> 00:09:16.040
But I would also like to add that depending on the project,
191
00:09:16.040 --> 00:09:20.230
it might be more beneficial to maximize the true positives.
192
00:09:20.230 --> 00:09:22.350
So if you wanted to do it that way,
193
00:09:22.350 --> 00:09:24.700
this would be your optimal threshold.
194
00:09:24.700 --> 00:09:26.130
Because at that threshold,
195
00:09:26.130 --> 00:09:29.400
you're gonna end up with the maximum true positive rate
196
00:09:29.400 --> 00:09:32.093
along with a fairly small false positive rate.
197
00:09:33.640 --> 00:09:36.290
Now, before we do our quick run-through of all of this,
198
00:09:36.290 --> 00:09:38.520
with the random forest classifier,
199
00:09:38.520 --> 00:09:41.160
there's one last thing that we need to talk about,
200
00:09:41.160 --> 00:09:43.713
and that's the area under the ROC curve.
201
00:09:44.840 --> 00:09:46.520
If you remember back to calculus,
202
00:09:46.520 --> 00:09:48.360
the area under the curve could be found
203
00:09:48.360 --> 00:09:50.620
by taking the integral of the function.
204
00:09:50.620 --> 00:09:52.270
And depending on the function,
205
00:09:52.270 --> 00:09:55.530
it could give us a lot of really helpful information.
206
00:09:55.530 --> 00:09:58.330
Like when we integrate acceleration over time,
207
00:09:58.330 --> 00:10:01.310
we're able to figure out the change in velocity.
208
00:10:01.310 --> 00:10:05.220
Well, the same basic idea is true with the ROC curve.
209
00:10:05.220 --> 00:10:07.890
Because, oftentimes the area under the curve
210
00:10:07.890 --> 00:10:10.673
can be related to the overall accuracy of the model.
211
00:10:11.640 --> 00:10:13.570
So when you have more area,
212
00:10:13.570 --> 00:10:15.943
that indicates an increase in accuracy.
213
00:10:18.150 --> 00:10:20.340
To avoid making this guide way too long,
214
00:10:20.340 --> 00:10:22.790
I'm just gonna run through all the code with you.
215
00:10:24.210 --> 00:10:25.660
So to start this off,
216
00:10:25.660 --> 00:10:29.100
we're really only gonna need two new functions.
217
00:10:29.100 --> 00:10:32.130
The first is the roc_curve function,
218
00:10:32.130 --> 00:10:33.750
which is gonna give us the ability
219
00:10:33.750 --> 00:10:36.230
to create a true positive, false positive,
220
00:10:36.230 --> 00:10:37.970
and threshold object.
221
00:10:37.970 --> 00:10:40.880
Which we're gonna need to build the graph.
222
00:10:40.880 --> 00:10:45.060
Then the second is the roc_auc_score function.
223
00:10:45.060 --> 00:10:46.630
And that's what we're gonna use
224
00:10:46.630 --> 00:10:48.823
to estimate the area under the curve.
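In scikit-learn, both of those live in the metrics module, so the new imports would look something like this:

```python
from sklearn.metrics import roc_curve, roc_auc_score
```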
225
00:10:50.150 --> 00:10:52.350
Now, the first thing that we're gonna need to do
226
00:10:52.350 --> 00:10:54.420
is create an object that contains
227
00:10:54.420 --> 00:10:56.640
all of the predicted probabilities
228
00:10:56.640 --> 00:10:58.243
from the feature test set.
229
00:10:59.100 --> 00:11:01.700
And by using the predict_proba function,
230
00:11:01.700 --> 00:11:04.020
we're able to generate a matrix that contains
231
00:11:04.020 --> 00:11:07.010
the class probability for each observation
232
00:11:07.010 --> 00:11:08.423
in the feature test set.
233
00:11:10.450 --> 00:11:13.310
I'm gonna go ahead and open this up in the variable pane.
234
00:11:13.310 --> 00:11:15.000
But the matrix that you're using
235
00:11:15.000 --> 00:11:17.573
is probably gonna be a little different than mine.
236
00:11:18.800 --> 00:11:20.370
But regardless of that,
237
00:11:20.370 --> 00:11:22.950
the first column indicates the likelihood
238
00:11:22.950 --> 00:11:25.603
the observation belongs to the malignant class.
239
00:11:26.610 --> 00:11:29.300
Then the second column indicates the likelihood
240
00:11:29.300 --> 00:11:32.053
of the observation belonging to the benign class.
241
00:11:33.010 --> 00:11:35.930
And if it makes more sense thinking of this as a percentage,
242
00:11:35.930 --> 00:11:37.023
you can do that too.
243
00:11:37.920 --> 00:11:41.150
So for all of the ones that you see in the second column,
244
00:11:41.150 --> 00:11:43.140
the model is essentially saying that
245
00:11:43.140 --> 00:11:46.560
all of those observations belong to the benign class.
246
00:11:46.560 --> 00:11:50.333
And it's making the prediction with 100% certainty.
247
00:11:53.860 --> 00:11:55.040
Now, after that,
248
00:11:55.040 --> 00:11:57.280
the next step was to slice up the matrix
249
00:11:57.280 --> 00:12:00.340
so that we only have the benign class.
250
00:12:00.340 --> 00:12:03.270
And that's because when we're building the ROC curve,
251
00:12:03.270 --> 00:12:05.030
we're really only concerned with all
252
00:12:05.030 --> 00:12:06.453
of the positive results.
253
00:12:08.130 --> 00:12:10.390
Finally, the third object in the cell
254
00:12:10.390 --> 00:12:13.200
is just an array made up of all zeros
255
00:12:13.200 --> 00:12:16.380
with the same length as the target test set.
256
00:12:16.380 --> 00:12:18.460
And when we get closer to the plotting part,
257
00:12:18.460 --> 00:12:20.073
I'll explain why we need this.
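Putting those three objects together, the cell probably looks something like this, continuing with the classifier and test sets from earlier; the variable names are my own:

```python
import numpy as np

# Class probabilities for every observation in the feature test set.
# Column 0 is the malignant (0) class, column 1 the benign (1) class.
probs = classifier.predict_proba(X_test)

# Slice out just the second column -- the positive (benign) probabilities.
pos_probs = probs[:, 1]

# An all-zeros array the same length as the target test set,
# used later to draw the random-guess diagonal.
zero_probs = np.zeros(len(y_test))
```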
258
00:12:21.600 --> 00:12:22.810
Now, moving on,
259
00:12:22.810 --> 00:12:24.200
the first thing we're doing
260
00:12:24.200 --> 00:12:26.960
is utilizing the ROC curve function.
261
00:12:26.960 --> 00:12:29.670
And we're doing that to create a true positive,
262
00:12:29.670 --> 00:12:32.880
false positive, and threshold object.
263
00:12:32.880 --> 00:12:35.300
And the arguments we're gonna need to pass in
264
00:12:35.300 --> 00:12:37.100
are gonna be the target test set,
265
00:12:37.100 --> 00:12:39.770
which contains all of the true labels,
266
00:12:39.770 --> 00:12:41.810
the positive probability object
267
00:12:41.810 --> 00:12:44.300
that contains all of the probability estimates
268
00:12:44.300 --> 00:12:46.070
of the positive class.
269
00:12:46.070 --> 00:12:48.390
And finally, we're gonna pass in which label
270
00:12:48.390 --> 00:12:50.990
indicates the positive class.
271
00:12:50.990 --> 00:12:52.440
Then in the next line,
272
00:12:52.440 --> 00:12:55.370
we're gonna follow the exact same steps.
273
00:12:55.370 --> 00:12:58.700
But instead of using the positive probability array,
274
00:12:58.700 --> 00:13:01.733
we're gonna use the array containing all of the zeros.
275
00:13:02.860 --> 00:13:05.180
Some of you might have figured this out already,
276
00:13:05.180 --> 00:13:08.290
but the purpose of this step is to create the objects
277
00:13:08.290 --> 00:13:11.100
that we're gonna need to graph that diagonal line
278
00:13:11.100 --> 00:13:15.660
representing the random guesses or worst case scenario line.
279
00:13:15.660 --> 00:13:17.370
And by using the zero array,
280
00:13:17.370 --> 00:13:19.720
we're essentially simulating a model that predicts
281
00:13:19.720 --> 00:13:21.820
that every observation is gonna belong
282
00:13:21.820 --> 00:13:23.113
to the negative class.
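As a sketch, those two calls might look like this; note that scikit-learn returns the arrays in the order false positive rate, true positive rate, thresholds:

```python
from sklearn.metrics import roc_curve

# ROC points for the random forest: true labels, positive-class
# probabilities, and which label counts as positive.
rf_fpr, rf_tpr, rf_thresholds = roc_curve(y_test, pos_probs, pos_label=1)

# Same call with the all-zeros array, which simulates a model that
# always predicts the negative class -- the worst-case diagonal.
base_fpr, base_tpr, base_thresholds = roc_curve(y_test, zero_probs, pos_label=1)
```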
283
00:13:24.600 --> 00:13:27.200
Finally, we have the last two objects
284
00:13:27.200 --> 00:13:29.710
and those are for the AUC scores.
285
00:13:29.710 --> 00:13:32.460
And we'll be adding those to the graph at the very end.
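Those two objects would be created along these lines; a no-skill baseline like the zero array should score right around 0.5:

```python
from sklearn.metrics import roc_auc_score

# Area under the curve for the model and for the zero-array baseline.
rf_auc = roc_auc_score(y_test, pos_probs)
base_auc = roc_auc_score(y_test, zero_probs)
```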
286
00:13:34.030 --> 00:13:35.570
Now, before I run this,
287
00:13:35.570 --> 00:13:37.850
the next few lines of code are gonna be
288
00:13:37.850 --> 00:13:39.501
for the diagonal line.
289
00:13:39.501 --> 00:13:42.090
Then after that, the next few lines
290
00:13:42.090 --> 00:13:44.133
are for the actual ROC curve.
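A minimal version of that plotting code, assuming the objects from the sketches above; the styling choices are guesses:

```python
import matplotlib.pyplot as plt

# Diagonal worst-case line first, then the model's ROC curve,
# with both AUC scores shown in the legend.
plt.plot(base_fpr, base_tpr, linestyle='--',
         label=f'Random guess (AUC = {base_auc:.2f})')
plt.plot(rf_fpr, rf_tpr, marker='.',
         label=f'Random forest (AUC = {rf_auc:.2f})')
plt.xlabel('False positive rate (1 - specificity)')
plt.ylabel('True positive rate (sensitivity)')
plt.legend()
plt.show()
```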
291
00:13:46.150 --> 00:13:49.333
Okay, so let's go ahead and run this one last time.
292
00:13:54.870 --> 00:13:56.470
And there you have it,
293
00:13:56.470 --> 00:13:58.493
your very own ROC curve.
294
00:13:59.460 --> 00:14:01.840
And not only did you build your own curve,
295
00:14:01.840 --> 00:14:05.790
it also has an AUC score really close to one.
296
00:14:05.790 --> 00:14:08.253
Which indicates a highly accurate model.
297
00:14:09.640 --> 00:14:13.470
All righty, that finally brings us to the end of the guide.
298
00:14:13.470 --> 00:14:15.260
But before I let you go,
299
00:14:15.260 --> 00:14:18.370
know that we'll be talking about this topic again really soon.
300
00:14:18.370 --> 00:14:22.160
So try your best to not forget any of this.
301
00:14:22.160 --> 00:14:23.800
And as always,
302
00:14:23.800 --> 00:14:27.133
I will wrap this up and I will see you in the next guide.