diff --git a/roc-auc/calculate_roc.R b/roc-auc/calculate_roc.R index f5e2565..0bdc453 100644 --- a/roc-auc/calculate_roc.R +++ b/roc-auc/calculate_roc.R @@ -1,15 +1,29 @@ -calculate_roc <- function(df, cost_of_fp, cost_of_fn, n=100) { +#' Calculate ROC +#' +#' @param df The data.frame containing the predictions and true values. +#' @param pred_column The column name of df containing the predictions. +#' @param target_column The column name of df containing the target. +#' @param cost_of_fp The cost of false positives. +#' @param cost_of_fn The cost of false negatives. +#' @param n The number of points to compute on the ROC curve. +#' +#' @return A data.frame containing the ROC curve at n points. +#' + +calculate_roc <- function(df, pred_column="pred", target_column="survived", + cost_of_fp=1, cost_of_fn=1, n=100) { tpr <- function(df, threshold) { - sum(df$pred >= threshold & df$survived == 1) / sum(df$survived == 1) + sum(df[[pred_column]] >= threshold & df[[target_column]] == 1) / + sum(df[[target_column]] == 1) } fpr <- function(df, threshold) { - sum(df$pred >= threshold & df$survived == 0) / sum(df$survived == 0) + sum(df[[pred_column]] >= threshold & df[[target_column]] == 0) / sum(df[[target_column]] == 0) } cost <- function(df, threshold, cost_of_fp, cost_of_fn) { - sum(df$pred >= threshold & df$survived == 0) * cost_of_fp + - sum(df$pred < threshold & df$survived == 1) * cost_of_fn + sum(df[[pred_column]] >= threshold & df[[target_column]] == 0) * cost_of_fp + + sum(df[[pred_column]] < threshold & df[[target_column]] == 1) * cost_of_fn } roc <- data.frame(threshold = seq(0,1,length.out=n), tpr=NA, fpr=NA) @@ -18,4 +32,4 @@ calculate_roc <- function(df, cost_of_fp, cost_of_fn, n=100) { roc$cost <- sapply(roc$threshold, function(th) cost(df, th, cost_of_fp, cost_of_fn)) return(roc) -} \ No newline at end of file +} diff --git a/roc-auc/plot_roc.R b/roc-auc/plot_roc.R index d5e11dc..8a6aa45 100644 --- a/roc-auc/plot_roc.R +++ b/roc-auc/plot_roc.R @@ -1,5 +1,6 @@ plot_roc <- function(roc, threshold, cost_of_fp, cost_of_fn) { library(gridExtra) + library(grid) norm_vec <- function(v) (v - min(v))/diff(range(v)) @@ -25,4 +26,4 @@ plot_roc <- function(roc, threshold, cost_of_fp, cost_of_fn) { sub_title <- sprintf("threshold at %.2f - cost of FP = %d, cost of FN = %d", threshold, cost_of_fp, cost_of_fn) grid.arrange(p_roc, p_cost, ncol=2, sub=textGrob(sub_title, gp=gpar(cex=1), just="bottom")) -} \ No newline at end of file +}