7 How To Do Hyperparameter Tuning
Load Package And Data
# Restore the workspace saved by the previous section.
# NOTE(review): presumably this is what supplies trainData / testData /
# testData4 used further down -- confirm against the earlier chapter.
load("../../data/craet_6.Rdata")
library(tidyverse)
library(caret)

# Set Parallel Processing - Decrease computation time
# requireNamespace() only checks that doMC is installed; unlike require() it
# does not attach the package as a side effect. Attaching is left to the
# library() call below, which fails loudly if the install did not succeed.
if (!requireNamespace("doMC", quietly = TRUE)) {
  install.packages("doMC")
}
library(doMC)
registerDoMC(cores = 4)
Hyper parameter tuning using tuneGrid
Model Tuning Parameter Set
Cross Validation Set
The cross validation `method` can be one amongst:
- ‘boot’: Bootstrap sampling
- ‘boot632’: Bootstrap sampling with 63.2% bias correction applied
- ‘optimism_boot’: The optimism bootstrap estimator
- ‘boot_all’: All boot methods.
- ‘cv’: k-Fold cross validation
- ‘repeatedcv’: Repeated k-Fold cross validation
- ‘oob’: Out-of-bag estimates (only available for bagging-based models such as random forests and bagged trees)
- ‘LOOCV’: Leave one out cross validation
- ‘LGOCV’: Leave group out cross validation
Training And Tuning
Predict
Confusion Matrix
# Step 1: Define the tuneGrid
# expand.grid() builds every combination of the two tuning parameters:
# 5 nprune values x 3 degree values = 15 candidate models.
marsGrid <- expand.grid(
  nprune = c(2, 4, 6, 8, 10),
  degree = c(1, 2, 3)
)

# Step 2: Define the training control
fitControl <- trainControl(
  method = "cv",                     # k-fold cross validation
  number = 5,                        # number of folds
  savePredictions = "final",         # saves predictions for optimal tuning parameter
  classProbs = TRUE,                 # should class probabilities be returned
  summaryFunction = twoClassSummary  # results summary function (ROC, Sens, Spec)
)

# Step 3: Training and Tuning hyper parameters by setting tuneGrid
# The seed is fixed immediately before train() so the fold assignment is
# reproducible. metric = "ROC" makes caret pick the grid row with the largest
# cross-validated ROC as the final model.
set.seed(100)
model_mars3 <- train(
  Purchase ~ .,
  data = trainData,
  method = "earth",
  metric = "ROC",
  tuneGrid = marsGrid,
  trControl = fitControl
)
# Print the resampling results for every grid row and the selected values.
model_mars3
## Multivariate Adaptive Regression Spline
##
## 857 samples
## 18 predictor
## 2 classes: 'CH', 'MM'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 685, 685, 687, 686, 685
## Resampling results across tuning parameters:
##
## degree nprune ROC Sens Spec
## 1 2 0.8745398 0.8700916 0.7006784
## 1 4 0.8924657 0.8662454 0.7394844
## 1 6 0.8912361 0.8719414 0.7334238
## 1 8 0.8886974 0.8661722 0.7334238
## 1 10 0.8879988 0.8623626 0.7423790
## 2 2 0.8745398 0.8700916 0.7006784
## 2 4 0.8953757 0.8739377 0.7454998
## 2 6 0.8917824 0.8681868 0.7515152
## 2 8 0.8904559 0.8624359 0.7574401
## 2 10 0.8932377 0.8547436 0.7784261
## 3 2 0.8582783 0.8777106 0.6618725
## 3 4 0.8914544 0.8662454 0.7544550
## 3 6 0.8910605 0.8586264 0.7665310
## 3 8 0.8838647 0.8452015 0.7456355
## 3 10 0.8827056 0.8471062 0.7426504
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nprune = 4 and degree = 2.
# Step 4: Score the held-out data with the tuned MARS model.
# NOTE(review): testData4 is not created in this section; presumably it is the
# preprocessed version of testData restored by load() -- confirm.
predicted3 <- predict(object = model_mars3, newdata = testData4)
# Step 5: Cross-tabulate predictions against the true labels, treating 'MM'
# as the positive class and reporting the full set of statistics.
confusionMatrix(
  data = predicted3,
  reference = testData[["Purchase"]],
  positive = "MM",
  mode = "everything"
)
## Confusion Matrix and Statistics
##
## Reference
## Prediction CH MM
## CH 117 21
## MM 13 62
##
## Accuracy : 0.8404
## 95% CI : (0.7841, 0.8869)
## No Information Rate : 0.6103
## P-Value [Acc > NIR] : 2.164e-13
##
## Kappa : 0.6585
## Mcnemar's Test P-Value : 0.2299
##
## Sensitivity : 0.7470
## Specificity : 0.9000
## Pos Pred Value : 0.8267
## Neg Pred Value : 0.8478
## Precision : 0.8267
## Recall : 0.7470
## F1 : 0.7848
## Prevalence : 0.3897
## Detection Rate : 0.2911
## Detection Prevalence : 0.3521
## Balanced Accuracy : 0.8235
##
## 'Positive' Class : MM
##
# Record the R version, platform, and attached/loaded package versions used
# for this run, so the results above can be reproduced.
sessionInfo()
## R version 3.4.3 (2017-11-30)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS Sierra 10.12.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] parallel methods stats graphics grDevices utils datasets
## [8] base
##
## other attached packages:
## [1] earth_4.6.1 plotmo_3.3.5 TeachingDemos_2.10
## [4] plotrix_3.7 doMC_1.3.5 iterators_1.0.9
## [7] foreach_1.4.4 caret_6.0-78 lattice_0.20-35
## [10] forcats_0.3.0 stringr_1.3.0 dplyr_0.7.4
## [13] purrr_0.2.4 readr_1.1.1 tidyr_0.8.0
## [16] tibble_1.4.2 ggplot2_2.2.1 tidyverse_1.2.1
##
## loaded via a namespace (and not attached):
## [1] nlme_3.1-131.1 lubridate_1.7.3 dimRed_0.1.0
## [4] httr_1.3.1 rprojroot_1.3-2 tools_3.4.3
## [7] backports_1.1.2 R6_2.2.2 rpart_4.1-13
## [10] lazyeval_0.2.1 colorspace_1.3-2 nnet_7.3-12
## [13] withr_2.1.1.9000 tidyselect_0.2.4 mnormt_1.5-5
## [16] compiler_3.4.3 cli_1.0.0 rvest_0.3.2
## [19] xml2_1.2.0 bookdown_0.7 scales_0.5.0.9000
## [22] sfsmisc_1.1-2 DEoptimR_1.0-8 psych_1.7.8
## [25] robustbase_0.92-8 digest_0.6.15 foreign_0.8-69
## [28] rmarkdown_1.9 pkgconfig_2.0.1 htmltools_0.3.6
## [31] rlang_0.2.0.9000 readxl_1.0.0 ddalpha_1.3.1.1
## [34] bindr_0.1.1 jsonlite_1.5 ModelMetrics_1.1.0
## [37] magrittr_1.5 Matrix_1.2-12 Rcpp_0.12.16
## [40] munsell_0.4.3 stringi_1.1.7 yaml_2.1.18
## [43] MASS_7.3-49 plyr_1.8.4 recipes_0.1.2
## [46] grid_3.4.3 crayon_1.3.4 haven_1.1.1
## [49] splines_3.4.3 hms_0.4.2 knitr_1.20
## [52] pillar_1.2.1 reshape2_1.4.3 codetools_0.2-15
## [55] stats4_3.4.3 CVST_0.2-1 glue_1.2.0
## [58] evaluate_0.10.1 blogdown_0.5 modelr_0.1.1
## [61] cellranger_1.1.0 gtable_0.2.0 kernlab_0.9-25
## [64] assertthat_0.2.0 DRR_0.0.3 xfun_0.1
## [67] gower_0.1.2 prodlim_1.6.1 broom_0.4.3
## [70] e1071_1.6-8 class_7.3-14 survival_2.41-3
## [73] timeDate_3043.102 RcppRoll_0.2.2 bindrcpp_0.2
## [76] lava_1.6 ipred_0.9-6
# Persist the entire workspace (data, grids, fitted models, predictions).
# NOTE(review): presumably loaded by the next section, mirroring the load()
# of craet_6.Rdata at the top of this one -- confirm.
save.image(file = "../../data/craet_7.Rdata")