Saturday, 21 September 2013

Classification and Regression Trees (CART)

Classification and Regression Trees (CART) is a classification and regression method, technically known as binary recursive partitioning. It uses historical data to construct decision trees, which are then used to classify new data.

So the question is: where should we use CART?

Sometimes we have problems where we want a “yes/no” answer,
e.g. “Is the salary greater than 30000?” or “Is it going to rain today?”

CART asks “yes/no” questions. The algorithm searches all candidate variables and all candidate split values to find the best split, i.e. the question that divides the data into two parts with maximum homogeneity.
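To make “maximum homogeneity” concrete: for classification, rpart's default splitting criterion is Gini impurity. Below is a minimal sketch of that idea; the helper names gini and split_impurity and the toy salary data are illustrative assumptions, not part of rpart.

# Gini impurity of a vector of class labels: 1 - sum(p_k^2).
# 0 means the node is pure; larger values mean more mixed.
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Weighted impurity of the two children produced by a yes/no question.
# CART searches variables and cut points for the split that minimizes this.
split_impurity <- function(labels, condition) {
  n <- length(labels)
  (sum(condition) / n) * gini(labels[condition]) +
    (sum(!condition) / n) * gini(labels[!condition])
}

salary <- c(25000, 28000, 32000, 40000, 45000, 22000)
buys   <- c("no", "no", "yes", "yes", "yes", "no")
gini(buys)                           # impurity before splitting: 0.5
split_impurity(buys, salary > 30000) # both children are pure, so 0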
The key elements of a CART analysis are:
·         Splitting each node of the tree.
·         Deciding when the tree is complete.
·         Assigning each terminal node (leaf) to a class outcome.
The result is returned as a binary decision tree.
CART Modeling via rpart
Classification and regression trees can be generated using the rpart package in R.
The steps to build a CART model are as follows:
·         Grow a tree - Use the following call to grow a tree:
rpart(formula, data, weights, subset, na.action = na.rpart, method,
      model = FALSE, x = FALSE, y = TRUE, parms, control, cost, ...)
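Most of these arguments can be left at their defaults; the ones you will touch most often are formula, data, method and control. As a sketch (the formula y ~ . and the name dataset mirror the bank example further down), tree growth can be tuned through rpart.control:

# minsplit - minimum observations a node must contain before a split is attempted
# cp       - complexity parameter; a split must improve the fit by at least cp
# maxdepth - maximum depth of any node in the final tree
fit <- rpart(y ~ ., data = dataset, method = "class",
             control = rpart.control(minsplit = 20, cp = 0.01, maxdepth = 10))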

·         Examine the results - Several functions help to inspect the fitted model:
printcp(fit)      display the cp (complexity parameter) table
plotcp(fit)       plot cross-validation results
rsq.rpart(fit)    plot approximate R-squared and relative error for different
                  splits (2 plots; labels are only appropriate for the "anova"
                  method)
print(fit)        print results
summary(fit)      detailed results including surrogate splits
plot(fit)         plot the decision tree
text(fit)         label the decision tree plot
post(fit, file=)  create a PostScript plot of the decision tree
Here, fit is the model object returned by the rpart() call.

·         Prune the tree - Pruning helps avoid overfitting the data. Typically, you will want to select the tree size that minimizes the cross-validated error, the xerror column printed by printcp().
Prune the tree to the desired size using:
prune(fit, cp=)

Here is an example of a classification tree built on a bank marketing dataset:
library(rpart)
dataset <- read.table("C:\\Users\\Nishu\\Downloads\\bank\\bank.csv",header=T,sep=";")
# grow tree
fit <- rpart(y ~ ., method="class", data=dataset)
printcp(fit) # display the results
plotcp(fit) # visualize cross-validation results
summary(fit) # detailed summary of splits
# plot tree
plot(fit, uniform=TRUE, main="Classification Tree for Bank")
text(fit, use.n=TRUE, all=TRUE, cex=.8)
# create attractive postscript plot of tree
post(fit, file = "c:/tree.ps", title = "Classification Tree")
# prune the tree
pfit <- prune(fit, cp=fit$cptable[which.min(fit$cptable[,"xerror"]),"CP"])
# plot the pruned tree
plot(pfit, uniform=TRUE,
   main="Pruned Classification Tree")
text(pfit, use.n=TRUE, all=TRUE, cex=.8)
post(pfit, file = "c:/ptree.ps",
   title = "Pruned Classification Tree")

Now we have the model. The next step is to predict on unseen data using the trained model.
First we split the dataset into training and test sets with a fixed percentage.
library(rpart)
dataset<-read.table("C:\\Users\\Nishu\\Downloads\\bank\\bank.csv",header=T,sep=";")
set.seed(42) # any fixed seed makes the train/test split reproducible
sub <- sample(nrow(dataset), floor(nrow(dataset) * 0.9)) # take 90% as training data
training <- dataset[sub, ]
testing  <- dataset[-sub, ]
fit <- rpart(y ~ ., method="class", data=training) # train on the training set only
pred <- predict(fit, testing, type="class")
# build the confusion matrix: predicted class vs. actual class
out <- table(pred, testing$y)
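The overall accuracy can be read straight off this table: the diagonal holds the correctly classified test cases.

# overall accuracy = correctly classified cases / all test cases
sum(diag(out)) / sum(out)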


Here the confusion matrix is:

        no  yes
  no   391   25
  yes   13   24

# To get the accuracy and other statistics, use the confusionMatrix() function from the caret package

library(caret)
confusionMatrix(out)

The output would be:

        no  yes
  no   391   25
  yes   13   24

               Accuracy : 0.9101
                 95% CI : (0.8936, 0.9248)
    No Information Rate : 0.8968
    P-Value [Acc > NIR] : 0.05704
                  Kappa : 0.4005
 Mcnemar's Test P-Value : 9.213e-08

            Sensitivity : 0.9745
            Specificity : 0.3500
         Pos Pred Value : 0.9287
         Neg Pred Value : 0.6125
             Prevalence : 0.8968
         Detection Rate : 0.8740
   Detection Prevalence : 0.9410

       'Positive' Class : no
So here we have the desired output: the model's predictions and its accuracy, on the basis of which we can score future data.

Download this dataset, or another dataset from here, and test the algorithm yourself.

Here you go!
