From 16da949b71b5dd4cdca473e62d0b876bd9ebea57 Mon Sep 17 00:00:00 2001 From: rockclimber112358 Date: Mon, 25 Apr 2016 05:52:03 +0200 Subject: [PATCH 1/2] Flexible column names and default equal costs This change allows users to use different column names for the prediction and target columns, and it allows users to run the code without explicitly specifying fp/fn costs (assuming they're equal). --- roc-auc/calculate_roc.R | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/roc-auc/calculate_roc.R b/roc-auc/calculate_roc.R index f5e2565..0bdc453 100644 --- a/roc-auc/calculate_roc.R +++ b/roc-auc/calculate_roc.R @@ -1,15 +1,29 @@ -calculate_roc <- function(df, cost_of_fp, cost_of_fn, n=100) { +#' Calculate ROC +#' +#' @param df The data.frame containing the predictions and true values. +#' @param pred_column The column name of df containing the predictions. +#' @param target_column The column name of df containing the target. +#' @param cost_of_fp The cost of false positives. +#' @param cost_of_fn The cost of false negatives. +#' @param n The number of points to compute on the ROC curve. +#' +#' @return A data.frame containing the ROC curve at n points. 
+#' + +calculate_roc <- function(df, pred_column="pred", target_column="survived", + cost_of_fp=1, cost_of_fn=1, n=100) { tpr <- function(df, threshold) { - sum(df$pred >= threshold & df$survived == 1) / sum(df$survived == 1) + sum(df[[pred_column]] >= threshold & df[[target_column]] == 1) / + sum(df[[target_column]] == 1) } fpr <- function(df, threshold) { - sum(df$pred >= threshold & df$survived == 0) / sum(df$survived == 0) + sum(df[[pred_column]] >= threshold & df[[target_column]] == 0) / sum(df[[target_column]] == 0) } cost <- function(df, threshold, cost_of_fp, cost_of_fn) { - sum(df$pred >= threshold & df$survived == 0) * cost_of_fp + - sum(df$pred < threshold & df$survived == 1) * cost_of_fn + sum(df[[pred_column]] >= threshold & df[[target_column]] == 0) * cost_of_fp + + sum(df[[pred_column]] < threshold & df[[target_column]] == 1) * cost_of_fn } roc <- data.frame(threshold = seq(0,1,length.out=n), tpr=NA, fpr=NA) @@ -18,4 +32,4 @@ calculate_roc <- function(df, cost_of_fp, cost_of_fn, n=100) { roc$cost <- sapply(roc$threshold, function(th) cost(df, th, cost_of_fp, cost_of_fn)) return(roc) -} \ No newline at end of file +} From 94acc02fa5fc806506d88284758abfed176a6831 Mon Sep 17 00:00:00 2001 From: rockclimber112358 Date: Mon, 25 Apr 2016 05:52:27 +0200 Subject: [PATCH 2/2] Missing dependency --- roc-auc/plot_roc.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/roc-auc/plot_roc.R b/roc-auc/plot_roc.R index d5e11dc..8a6aa45 100644 --- a/roc-auc/plot_roc.R +++ b/roc-auc/plot_roc.R @@ -1,5 +1,6 @@ plot_roc <- function(roc, threshold, cost_of_fp, cost_of_fn) { library(gridExtra) + library(grid) norm_vec <- function(v) (v - min(v))/diff(range(v)) @@ -25,4 +26,4 @@ plot_roc <- function(roc, threshold, cost_of_fp, cost_of_fn) { sub_title <- sprintf("threshold at %.2f - cost of FP = %d, cost of FN = %d", threshold, cost_of_fp, cost_of_fn) grid.arrange(p_roc, p_cost, ncol=2, sub=textGrob(sub_title, gp=gpar(cex=1), just="bottom")) -} \ No newline 
at end of file +}