Creating a well-formatted decision tree with partykit and listing the rules of the nodes

When we would like to create a decision tree, we can choose from many possibilities, but formatting the tree is not always that easy. I will show a way to create and format a tree and to list the rules of its inner nodes. Being able to read these rules can be very useful in an analysis. I will use the partykit package. Besides that, we will need the data.table package for listing the rules of the inner nodes and, optionally, the xlsx package if we want to save these rules as an xls file.

> library(xlsx)
> library(partykit)
> library(data.table)
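
If any of these packages are missing, they can be installed first; note that the xlsx package additionally needs a working Java installation (it depends on rJava):

> # only needed once; xlsx pulls in rJava, which requires Java
> install.packages(c("partykit", "data.table", "xlsx"))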

 

I am using randomly generated data. In this example I am selling jeans and I am curious about what the selling period of different types of jeans depends on. I have data about 10,000 sales, including the fit (straight, bootcut or skinny), length (short or normal), color (blue, black, red or yellow) and price (10-124 $) of the jeans, and the number of days it took to sell (0-3 days, 4-14 days or more).

> head(data)
  Days      Fit Price $ Color Length
2  0-3 Straight      81   Red Normal
3  0-3 Straight      61   Red  Short
1 4-14  Bootcut      87   Red Normal
4 4-14  Bootcut      47   Red Normal
5 4-14   Skinny      82   Red Normal
6 4-14 Straight      88 Black Normal
> nrow(data)
[1] 10000
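
The original dataset is not published; as a rough sketch, a random dataset with the same structure could be simulated like this (the variable names match the ones above, but the distributions, and the lack of any real relationship to Days, are my own assumptions, so a tree fitted on it would not reproduce the splits shown later):

> set.seed(1)
> n <- 10000
> data <- data.frame(Days      = sample(c("0-3", "4-14", "More"), n, replace = TRUE),
+                    Fit       = sample(c("Straight", "Bootcut", "Skinny"), n, replace = TRUE),
+                    `Price $` = sample(10:124, n, replace = TRUE),
+                    Color     = sample(c("Blue", "Black", "Red", "Yellow"), n, replace = TRUE),
+                    Length    = sample(c("Short", "Normal"), n, replace = TRUE),
+                    stringsAsFactors = TRUE, check.names = FALSE)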

 

In order to estimate the accuracy of our model, I will split my data into a training and a test dataset.

> set.seed(9898)
> testind<-sample(c(1:nrow(data)), nrow(data)*0.2)
> test<-data[testind,]
> train<-data[-testind,]
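
As a quick sanity check (not part of the original script), we can verify that the class proportions are similar in the two sets:

> # the Days distribution should look roughly the same in both subsets
> prop.table(table(train$Days))
> prop.table(table(test$Days))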

 

The tree and the prediction:

> tree<-ctree(Days~., data=train, mincriterion = 0.99)
>
> test$predict<-predict(tree, newdata = test, type = "response")

When fitting the model we could also control the test type, the size of the tree, the splitting criteria, etc. with a list of parameters in the control = ctree_control() argument. Here I only set mincriterion, which is the value of 1 - p: a split is implemented only if this value is exceeded, so 0.99 is stricter than the default.
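
For illustration, a hypothetical call that sets a few more of these control parameters could look like this (the values are arbitrary, not tuned):

> tree2 <- ctree(Days ~ ., data = train,
+                control = ctree_control(mincriterion = 0.99,  # 1 - p threshold for splitting
+                                        minbucket = 100,      # minimum observations in a terminal node
+                                        maxdepth = 4))        # limit the depth of the tree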

The accuracy of the model is 0.631, which is quite good. Without the model, only by guessing, our accuracy would be about 0.33 (since the explained variable has 3 values), so 0.631 is a great improvement.

> test$match<-0
> test$match[which(test$predict==test$Days)]<-1
> accuracy<-sum(test$match)/nrow(test)
> accuracy
[1] 0.631
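
To see where the misclassifications occur, a simple cross-tabulation of the observed and predicted classes can be added (this check is my addition, not shown in the original output):

> # confusion matrix: rows are observed classes, columns are predictions
> table(observed = test$Days, predicted = test$predict)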

Now I am visualizing my tree. First I define the colors: columncol holds the colors of the bars in the terminal nodes, labelcol the background color of the labels in the inner nodes, and indexcol the background color of the node indices.

> columncol<-hcl(c(270, 260, 250), 200, 30, 0.6)
> labelcol<-hcl(200, 200, 50, 0.2)
> indexcol<-hcl(150, 200, 50, 0.4)

The image will be saved directly as Tree.jpg in the working directory.

> jpeg("Tree.jpg", width=2000, height=750)
> plot.new()
> plot(tree, type = c("simple"), gp = gpar(fontsize = 15),
+      drop_terminal = TRUE, tnex=1,
+      inner_panel = node_inner(tree, abbreviate = TRUE,
+                               fill = c(labelcol, indexcol), pval = TRUE, id = TRUE),
+      terminal_panel=node_barplot(tree, col = "black", fill = columncol, beside = TRUE,
+                                  ymax = 1, ylines = TRUE, widths = 1, gap = 0.1,
+                                  reverse = FALSE, id = TRUE))
> title(main="My Beautiful Decision Tree", cex.main=3, line=1) 
> dev.off()

 

And here is the tree:

[Tree.jpg: the plotted and formatted decision tree]

 

It may also be important to see according to which rules the data was partitioned. Now I list the rules produced for each terminal node. With the script below we get the results in a formatted form, saved as an xls file.

 

First I save the predictions and the rules by which we got them into the terminals data frame.

> terminals <- data.frame(response = predict(tree, type = "response"))
> terminals$prob <- predict(tree, type = "prob")
> 
> rules<-partykit:::.list.rules.party(tree)
> terminals$rule <- rules[as.character(predict(tree, type = "node"))]
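
.list.rules.party is an internal, unexported partykit helper (hence the ::: operator): it returns a named character vector whose names are the terminal node ids and whose values are the split conditions leading to each node, so indexing it with predict(tree, type = "node") attaches the matching rule to every observation. A quick, optional way to inspect it (the output depends on your data):

> # names(rules) are the terminal node ids, the values are the rule strings
> head(names(rules))
> head(rules, 1)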

Then I format and aggregate the values and the variable names, and I save the result as Rules.xls.

> # keep only the rule and the class probability columns
> value<-cbind(terminals$rule, terminals[,which(grepl("prob", colnames(terminals)))])
> colnames(value)[1]<-"rules"
> # collapse to one row per rule: observations in the same terminal node share the same probabilities, so max() just removes the duplicates
> DF<-data.table(value)  
> terminals<-as.data.frame(DF[,sapply(.SD, function(x) list(max=max(x))),by=list(rules)])
> # attach the terminal node ids (the names of the rules vector)
> terminalsname<-cbind(names(rules), rules)
> terminals<-merge(terminals, terminalsname, by=c("rules"))
> # clean the ".max" suffix from the aggregated column names and rename the id and rule columns
> names<-colnames(terminals)[which(grepl(".max", colnames(terminals)))]
> names<-unlist(strsplit(names, "[.]"))[-which(unlist(strsplit(names, "[.]"))=="max")]
> colnames(terminals)[which(grepl(".max", colnames(terminals)))]<-names
> colnames(terminals)[which(colnames(terminals)=="V1")]<-"Terminal_Node"
> colnames(terminals)[which(colnames(terminals)=="rules")]<-"Rule"
> # reorder the columns, drop the literal "NA" placeholders from the rule strings, sort by node id and write the xls
> terminals<-terminals[,c(5,1,2,3,4)]
> terminals$Rule[which(grepl("\"NA\"", terminals$Rule))]<-gsub("\"NA\"", "", terminals$Rule[which(grepl("\"NA\"", terminals$Rule))])
> terminals$Terminal_Node<-as.numeric(as.character(terminals$Terminal_Node))
> terminals<-terminals[order(terminals$Terminal_Node),]
> write.xlsx(terminals, "Rules.xls", row.names=FALSE)
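
If the Java-dependent xlsx package causes trouble, the same table can of course be written to a plain csv file instead:

> # alternative without the xlsx/Java dependency
> write.csv(terminals, "Rules.csv", row.names = FALSE)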

 

So here are the decision rules.

You can learn more about decision trees here or here.

 

Terminal_Node | Rule | 0-3 | 4-14 | More
4 | Fit %in% c("Bootcut", "Skinny") & Price $ <= 65 & Price $ <= 50 | 0.77 | 0.23 | 0.00
6 | Fit %in% c("Bootcut", "Skinny") & Price $ <= 65 & Price $ > 50 & Fit %in% c("Bootcut", ) | 0.62 | 0.37 | 0.00
7 | Fit %in% c("Bootcut", "Skinny") & Price $ <= 65 & Price $ > 50 & Fit %in% c("Skinny", ) | 0.96 | 0.04 | 0.00
11 | Fit %in% c("Bootcut", "Skinny") & Price $ > 65 & Price $ <= 99 & Fit %in% c("Bootcut", ) & Price $ <= 76 | 0.54 | 0.45 | 0.01
12 | Fit %in% c("Bootcut", "Skinny") & Price $ > 65 & Price $ <= 99 & Fit %in% c("Bootcut", ) & Price $ > 76 | 0.45 | 0.52 | 0.03
13 | Fit %in% c("Bootcut", "Skinny") & Price $ > 65 & Price $ <= 99 & Fit %in% c("Skinny", ) | 0.92 | 0.08 | 0.00
14 | Fit %in% c("Bootcut", "Skinny") & Price $ > 65 & Price $ > 99 | 0.25 | 0.63 | 0.13
16 | Fit %in% c("Straight") & Price $ <= 56 | 0.27 | 0.64 | 0.09
18 | Fit %in% c("Straight") & Price $ > 56 & Price $ <= 74 | 0.17 | 0.70 | 0.13
20 | Fit %in% c("Straight") & Price $ > 56 & Price $ > 74 & Price $ <= 84 | 0.13 | 0.66 | 0.21
21 | Fit %in% c("Straight") & Price $ > 56 & Price $ > 74 & Price $ > 84 | 0.07 | 0.68 | 0.25
