---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
# Explainable Ensemble Trees (e2tree)
<!-- badges: start -->
[![R-CMD-check](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml)
<!-- badges: end -->
The key idea of Explainable Ensemble Trees (e2tree) is an algorithm that represents any ensemble of decision trees with a single tree-like structure. The goal is to explain the results of the ensemble while preserving its accuracy, which is typically higher than that of a single decision tree. The method identifies a tree-like structure that explains the classification or regression paths and summarizes the whole ensemble process. e2tree has two main advantages:

- it builds an explainable tree that preserves the predictive performance of a Random Forest (RF) model;
- it gives the decision-maker an intuitive, tree-like structure to work with.

The examples below focus on Random Forest, but the algorithm can be generalized to any ensemble approach based on decision trees.
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%",
dpi = 300
)
```
## Setup
You can install the development version of e2tree from [GitHub](https://github.com) with:
```{r eval=FALSE}
install.packages("remotes")
remotes::install_github("massimoaria/e2tree")
```
```{r warning=FALSE, message=FALSE}
library(e2tree)
library(randomForest)
library(dplyr)
library(ggplot2)
# Install rsample on first use if it is missing
if (!requireNamespace("rsample", quietly = TRUE)) install.packages("rsample")
library(rsample)
options(dplyr.summarise.inform = FALSE)
```
```{r set-theme, include=FALSE}
theme_set(
theme_classic() +
theme(
plot.background = element_rect(fill = "transparent", colour = NA),
panel.background = element_rect(fill = "transparent", colour = NA)
)
)
knitr::opts_chunk$set(dev.args = list(bg = "transparent"))
```
## Warnings
The package is still under development and currently has the following limitations:
- Only ensembles trained with the randomForest package are supported. Additional packages and approaches will be supported in the future.
- e2tree currently works only for classification problems. It will gradually be extended to other types of response variable: regression, count data, multivariate responses, etc.
## Example 1: IRIS dataset
This example shows the main functions of the e2tree package.
Starting from the IRIS dataset, we train an ensemble with the randomForest package and then use e2tree to obtain an explainable tree that summarizes the ensemble classifier.
```{r}
# Set random seed to make results reproducible:
set.seed(0)
# Initialize the split
iris_split <- iris %>% initial_split(prop = 0.6)
iris_split
# Assign the data to the correct sets
training <- iris_split %>% training()
validation <- iris_split %>% testing()
response_training <- training[,5]
response_validation <- validation[,5]
```
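For readers without rsample, the same kind of 60/40 random split can be sketched in base R. This is only an illustration: the names `train_idx`, `train_toy`, and `valid_toy` are hypothetical, and the rows selected will generally differ from those chosen by `initial_split()`.

```r
# Illustrative base-R split; not rsample's implementation.
set.seed(0)
n <- nrow(iris)

# Draw 60% of the row indices without replacement for training
train_idx <- sample(n, size = round(0.6 * n))
train_toy <- iris[train_idx, ]
valid_toy <- iris[-train_idx, ]

# The two sets are disjoint and together cover the whole dataset
c(training = nrow(train_toy), validation = nrow(valid_toy))
```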
Train a Random Forest model. The call below does not set `ntree`, so randomForest grows its default of 500 trees.
```{r}
# Perform training:
ensemble <- randomForest(Species ~ ., data = training, importance = TRUE, proximity = TRUE)
```
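The number of trees is set through randomForest's `ntree` argument, which defaults to 500. As a sketch, a forest with 1000 trees could be requested explicitly as follows; `rf_1000` is a hypothetical name, and the model is fitted on the full iris data only to keep the snippet self-contained.

```r
library(randomForest)

set.seed(0)
# ntree controls how many trees the forest grows (default: 500)
rf_1000 <- randomForest(Species ~ ., data = iris, ntree = 1000,
                        importance = TRUE, proximity = TRUE)

# The fitted object records the number of trees that were grown
rf_1000$ntree
```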
Here, we create the dissimilarity matrix between observations through the createDisMatrix function
```{r}
D <- createDisMatrix(ensemble, data = training, label = "Species", parallel = TRUE)
#dis <- 1-rf$proximity
```
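To give an intuition for what a dissimilarity between observations can look like in a tree ensemble, the sketch below computes the classical co-occurrence measure: one minus the share of trees in which two observations fall in the same terminal node. This is an assumption made for illustration, not a description of how createDisMatrix works internally, and the `nodes` matrix is made-up data.

```r
# Made-up terminal-node IDs: 4 observations (rows) x 3 trees (columns)
nodes <- matrix(c(1, 1, 2, 2,
                  3, 3, 3, 4,
                  5, 6, 6, 6), nrow = 4)

# Count, for every pair of observations, the trees in which they share a leaf
prox <- matrix(0, nrow(nodes), nrow(nodes))
for (t in seq_len(ncol(nodes))) {
  prox <- prox + outer(nodes[, t], nodes[, t], "==")
}
prox <- prox / ncol(nodes)   # share of trees with a common leaf

# Dissimilarity: 0 on the diagonal, 1 for pairs that never share a leaf
dis_toy <- 1 - prox
dis_toy
```

Observations 1 and 2 share a leaf in two of the three trees, so their dissimilarity is 1/3, while observations 1 and 4 never share a leaf and get a dissimilarity of 1.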
Set the e2tree parameters
```{r}
setting <- list(impTotal = 0.1, maxDec = 0.01, n = 5, level = 5, tMax = 5)
```
Build an explainable tree for RF
```{r}
tree <- e2tree(Species ~ ., data = training, D, ensemble, setting)
```
Plot the Explainable Ensemble Tree
```{r}
expl_plot <- rpart2Tree(tree, ensemble)
rpart.plot::rpart.plot(expl_plot)
```
Let's have a look at the output
```{r}
tree %>% glimpse()
```
Prediction with the new tree (example on training)
```{r}
pred <- ePredTree(tree, training[,-5], target="virginica")
```
Comparison of predictions (training sample) of RF and e2tree
```{r}
table(pred$fit, ensemble$predicted)
```
Comparison of predictions (training sample) of RF and correct response
```{r}
table(ensemble$predicted, response_training)
```
Comparison of predictions (training sample) of e2tree and correct response
```{r}
table(pred$fit,response_training)
```
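Each of the confusion tables above can be condensed into a single agreement rate: the share of observations on which two label vectors coincide. A minimal sketch with made-up labels (the helper `agreement()` and the vectors are hypothetical, not part of e2tree):

```r
# Share of positions where two label vectors agree
agreement <- function(a, b) {
  stopifnot(length(a) == length(b))
  mean(as.character(a) == as.character(b))
}

# Made-up labels for five observations
truth  <- c("setosa", "versicolor", "virginica", "virginica", "setosa")
fitted <- c("setosa", "versicolor", "versicolor", "virginica", "setosa")

table(fitted, truth)      # confusion table
agreement(fitted, truth)  # 4 of the 5 labels match
```

Comparing labels as character strings avoids problems when the two vectors are factors with different level sets.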
Variable Importance
```{r}
ensemble_imp <- ensemble$importance %>% as.data.frame %>%
mutate(Variable = rownames(ensemble$importance),
RF_Var_Imp = round(MeanDecreaseAccuracy,2)) %>%
select(Variable, RF_Var_Imp)
V <- vimp(tree, training)
#V <- V$vimp %>%
# select(Variable,MeanImpurityDecrease, `ImpDec_ setosa`, `ImpDec_ versicolor`,`ImpDec_ virginica`) %>%
# mutate_at(c("MeanImpurityDecrease","ImpDec_ setosa", "ImpDec_ versicolor","ImpDec_ virginica"), round,2) %>%
# left_join(ensemble_imp, by = "Variable") %>%
# select(Variable, RF_Var_Imp, MeanImpurityDecrease, starts_with("ImpDec")) %>%
# rename(ETree_Var_Imp = MeanImpurityDecrease)
V
```
Comparison with the validation sample
```{r}
ensemble.pred <- predict(ensemble, validation[,-5], proximity = TRUE)
pred_val <- ePredTree(tree, validation[,-5], target="virginica")
```
Comparison of predictions (validation sample) of RF and e2tree
```{r}
table(pred_val$fit, ensemble.pred$predicted)
```
Comparison of predictions (validation sample) of RF and correct response
```{r}
table(ensemble.pred$predicted, response_validation)
ensemble.prob <- predict(ensemble, validation[,-5], proximity = TRUE, type="prob")
roc_ensemble <- roc(response_validation, ensemble.prob$predicted[,"virginica"], target="virginica")
roc_ensemble$auc
```
Comparison of predictions (validation sample) of e2tree and correct response
```{r}
table(pred_val$fit, response_validation)
roc_res <- roc(response_validation, pred_val$score, target="virginica")
roc_res$auc
```
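The AUC reported above can be read as the probability that a randomly chosen virginica observation receives a higher score than a randomly chosen non-virginica one. The sketch below computes that quantity from ranks (the Mann-Whitney formulation) on made-up scores; `auc_rank()` is a hypothetical helper and is not the `roc()` function used above.

```r
# Rank-based AUC for a binary target (one class versus the rest)
auc_rank <- function(score, is_positive) {
  r <- rank(score)               # average ranks, so ties count as 1/2
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Made-up scores for four observations, two of them virginica
score        <- c(0.9, 0.4, 0.6, 0.2)
is_virginica <- c(TRUE, TRUE, FALSE, FALSE)

# 3 of the 4 positive/negative pairs are ordered correctly
auc_rank(score, is_virginica)
```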