---
title: "ISLR, Chapter 4 - Classification"
output: github_document
---
```{r import libraries, message=FALSE, warning=FALSE}
library(MASS)
library(ISLR)
library(tidyverse)
library(tidymodels)
```
# Multiple Logistic Regression
## `Default` data
```{r credit cards}
# look at the details of the dataset
# ?Default
# fit the model
log_reg <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(default ~ balance + income + student, data = Default)
tidy(log_reg$fit)
summary(log_reg$fit)
# look at the difference in distribution of the balance for students & non-students
Default %>%
  ggplot(aes(student, balance)) +
  geom_boxplot()
# get predictors
X <- Default %>%
  select(-default)
head(X)
# compare prediction with the true y
predictions <- log_reg %>%
  predict(new_data = X) %>%
  bind_cols(Default %>% select(default))
predictions
```
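A quick sketch of summarising these training-set predictions with yardstick's `conf_mat()` and `accuracy()` (both attached with `tidymodels`):
```{r default evaluation}
# confusion matrix and overall accuracy of the predicted vs. observed defaults
predictions %>%
  conf_mat(truth = default, estimate = .pred_class)
predictions %>%
  accuracy(truth = default, estimate = .pred_class)
```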
## Lab - `Smarket` data
Using the dataset `Smarket` from the `ISLR` package.
`Direction` is going to be the response, it indicates whether the market went up or down since the previous day.
```{r logistic - fit}
summary(Smarket)
names(Smarket)
help("Smarket") ## to get the details of the dataset
# look at the pairplot of all the variables using GGally
# pairs(Smarket, col = Smarket$Direction)
GGally::ggpairs(Smarket, aes(color = Direction))
# fit the model
smarket_model <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + Lag5 + Volume, data = Smarket)
# look at the coefficient estimates
tidy(smarket_model)
# look at the full summary of the fitted model
summary(smarket_model$fit)
```
None of the coefficients are significant, and the drop in deviance is modest: only about four units between the null and residual deviance.

- Null deviance: the deviance of the intercept-only (mean-only) model.
- Residual deviance: the deviance of the model with all the predictors.
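A minimal sketch of turning that drop in deviance into a likelihood-ratio style check, using the components stored on the underlying `glm` fit:
```{r deviance check}
glm_fit <- smarket_model$fit
# drop in deviance from the null (mean-only) model to the full model
glm_fit$null.deviance - glm_fit$deviance
# chi-squared p-value for that drop; df = number of predictors added
pchisq(glm_fit$null.deviance - glm_fit$deviance,
       df = glm_fit$df.null - glm_fit$df.residual, lower.tail = FALSE)
```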
```{r logistic - predictions}
# predicting the response of the data used to train/fit the model
predictions <- smarket_model$fit %>%
  predict(type = "response")
# output is a vector of fitted probabilities
predictions[1:5]
```
The fitted probabilities are all close to 50%, so these are not strong predictions, which is expected with stock market data.
Turning the predictions into classifications by thresholding at 0.5:
```{r logistic - classifying}
predictions <- ifelse(predictions > 0.5, "Up", "Down")
predictions
# confusion matrix
table(predictions,Smarket$Direction)
# accuracy
mean(predictions == Smarket$Direction)
```
Many of the observations were misclassified, and since these are predictions on the same data used to fit the model, even this is an optimistic estimate.
### Train/test Split -> New fit -> New predictions
```{r}
train <- Smarket %>% filter(Year < 2005)
test <- Smarket %>% filter(Year >= 2005)
smarket_model <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(Direction ~ Lag1 + Lag2 + Lag3 + Lag4 + Lag5 + Volume, data = train)
predictions <- smarket_model %>%
  predict(new_data = test) %>%
  bind_cols(test %>% select(Direction))
# confusion matrix
predictions %>%
  conf_mat(Direction, .pred_class)
# accuracy
mean(predictions$Direction == predictions$.pred_class)
```
The model does worse than the null rate (50%), with only 48% accuracy on the held-out 2005 data. We're probably overfitting, so we can try reducing the predictors to just the most important ones, `Lag1` and `Lag2`.
```{r}
smarket_model <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(Direction ~ Lag1 + Lag2, data = train)
predictions <- smarket_model %>%
  predict(new_data = test) %>%
  bind_cols(test %>% select(Direction))
# confusion matrix
predictions %>%
  conf_mat(Direction, .pred_class) %>%
  autoplot()
# accuracy
mean(predictions$Direction == predictions$.pred_class)
summary(smarket_model$fit)
```
The predictors are still not significant, but the prediction performance improved slightly.
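As a small illustration, we can ask the reduced model for class probabilities on a couple of made-up days (the `Lag1`/`Lag2` values below are just example inputs):
```{r predict specific days}
# predicted probabilities of Up/Down for two hypothetical days
new_days <- tibble(Lag1 = c(1.2, 1.5), Lag2 = c(1.1, -0.8))
smarket_model %>%
  predict(new_data = new_days, type = "prob")
```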
# Linear Discriminant Analysis
We will use the same response and predictor variables and see if LDA can do a better job.
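As a reminder of what `lda()` estimates, the discriminant score for class $k$, with shared covariance $\Sigma$, class mean $\mu_k$, and prior $\pi_k$, is
$$\delta_k(x) = x^\top \Sigma^{-1} \mu_k - \tfrac{1}{2}\mu_k^\top \Sigma^{-1} \mu_k + \log \pi_k,$$
and each observation is assigned to the class with the largest $\delta_k(x)$.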
```{r}
lda_smarket <- lda(Direction ~ Lag1 + Lag2, data = train)
# summarize results
lda_smarket
# plot the lda fit
# it plots the linear discriminant function for the two y groups
plot(lda_smarket)
```
The two histograms look very similar, and the prior probabilities are almost equal.
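We can check this directly: the fitted `lda` object stores the estimated priors and group means.
```{r lda priors}
# estimated prior probabilities and per-class means of Lag1 and Lag2
lda_smarket$prior
lda_smarket$means
```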
Now we'll predict `Direction` for the year 2005:
```{r}
test <- test %>% filter(Year == 2005)
predictions <- predict(lda_smarket, test)
# predict() outputs a list, which we can flatten into a data frame
predictions <- data.frame(predictions)
# Predicted class, posterior probabilities, values of the LDA score
str(predictions)
# confusion matrix
table(predictions$class, test$Direction)
# accuracy
mean(predictions$class == test$Direction)
```
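A small sketch of using a stricter posterior threshold instead of 0.5 (assuming the flattened data frame names the posterior columns `posterior.Down` and `posterior.Up`, as `data.frame()` does here):
```{r lda posterior threshold}
# how many days get called "Up" with at least 90% posterior probability?
sum(predictions$posterior.Up > 0.9)
# classify as Up only when the posterior probability for Up exceeds 0.9
strict_pred <- ifelse(predictions$posterior.Up > 0.9, "Up", "Down")
table(strict_pred, test$Direction)
```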
# K Nearest Neighbors
This time we'll predict `Direction` with K-nearest neighbors, first via parsnip and then directly with `knn()` from the `class` package.
```{r}
# ?knn
Xlag_train <- train %>% select(Direction, Lag1, Lag2)
Xlag_test <- test %>% select(Direction, Lag1, Lag2)
Xlag_train2 <- train %>% select(Lag1, Lag2)
Xlag_test2 <- test %>% select(Lag1, Lag2)
# using parsnip, whose "kknn" engine wraps kknn() - Weighted K-Nearest Neighbor Classifier
knn_smarket <- nearest_neighbor() %>%
  set_engine("kknn") %>%
  set_mode("classification") %>%
  fit(Direction ~ Lag1 + Lag2, data = Xlag_train)
predictions <- knn_smarket %>%
  predict(new_data = Xlag_test) %>%
  mutate(truth = test$Direction)
predictions %>%
  conf_mat(truth, .pred_class)
mean(predictions$truth == predictions$.pred_class)
# using knn() from the class package directly
library(class)
set.seed(1)
knn_pred <- knn(train = Xlag_train2, test = Xlag_test2, cl = train$Direction, k = 3)
table(knn_pred, test$Direction)
mean(knn_pred == test$Direction)
```
KNN with a handful of neighbors does little better than flipping a coin on this data.
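A quick way to compare a few values of k with `knn()` and see whether a different neighborhood size helps:
```{r knn over k}
# test-set accuracy of knn() for a handful of k values
set.seed(1)
sapply(c(1, 3, 5, 10), function(k) {
  pred <- knn(train = Xlag_train2, test = Xlag_test2, cl = train$Direction, k = k)
  mean(pred == test$Direction)
})
```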