Functions: prepare_eda_recipe -> build_eda_recipe,
create_efficiency_plot -> plot_efficiency,
format_class_imbalance_tourney_gt -> format_tournament_gt
Targets: model_inputs_prefix -> baf_model_input_prefix,
tbl_fraud_by_month_data -> fraud_by_month_summary,
model_diag -> diag_fit, winning_params -> best_params,
production_recipe_blueprint -> prod_recipe,
final_eval_data -> test_predictions
pkgdown: restructured reference index into 6 logical sections,
removed stale names and development comments.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1330 lines
46 KiB
R
1330 lines
46 KiB
R
#' Convert BAF CSV to partitioned Parquet in MinIO (S3)
#'
#' Reads `Base.csv` from a MinIO/S3 bucket prefix (e.g., `"01_raw"`) and writes a
#' Hive-style partitioned Parquet dataset to another prefix (e.g., `"02_intermediate"`),
#' partitioned by `variant` (e.g., `variant=Base/part-*.parquet`).
#'
#' The variant name is derived from the CSV file name: the `.csv` extension is
#' removed case-insensitively (matching how the file is selected) and spaces are
#' replaced with underscores, so `Base.csv` is written under `variant=Base`.
#'
#' Connection settings are taken from environment variables:
#' \itemize{
#'   \item \code{BAF_ENDPOINT} (e.g. \code{"minio:9000"} or \code{"192.168.4.xx:9000"})
#'   \item \code{BAF_KEY} (MinIO access key)
#'   \item \code{BAF_SECRET} (MinIO secret key)
#' }
#'
#' @param from_prefix Character. Prefix/key under the bucket containing CSVs (e.g. \code{"01_raw"}).
#' @param to_prefix Character. Prefix/key under the bucket to write Parquet dataset (e.g. \code{"02_intermediate"}).
#' @param bucket_name Character. Bucket name. Default \code{"baf-fraud"}.
#'
#' @return Character. The prefix of the partition that was written, i.e.
#'   \code{file.path(to_prefix, "variant=<variant>")} (e.g.
#'   \code{"02_intermediate/variant=Base"}); one element per converted file,
#'   normally exactly one.
#'
#' @export
#'
#' @importFrom arrow s3_bucket read_csv_arrow write_dataset
#' @importFrom dplyr mutate
#' @importFrom stringr str_remove str_replace_all
#'
#' @examples
#' \dontrun{
#' Sys.setenv(
#'   BAF_ENDPOINT = "minio:9000",
#'   BAF_KEY = "YOUR_ACCESS_KEY",
#'   BAF_SECRET = "YOUR_SECRET_KEY"
#' )
#' convert_to_parquet(from_prefix = "01_raw", to_prefix = "02_intermediate", bucket_name = "baf-fraud")
#' }
convert_to_parquet <- function(
  from_prefix,
  to_prefix,
  bucket_name = "baf-fraud"
) {
  endpoint <- Sys.getenv("BAF_ENDPOINT")
  access_key <- Sys.getenv("BAF_KEY")
  secret_key <- Sys.getenv("BAF_SECRET")

  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")
  if (access_key == "") stop("Missing env var: BAF_KEY")
  if (secret_key == "") stop("Missing env var: BAF_SECRET")

  bucket <- s3_bucket(
    bucket_name,
    endpoint_override = endpoint,
    scheme = "http",
    access_key = access_key,
    secret_key = secret_key,
    region = "us-east-1"
  )

  path_raw <- bucket$path(from_prefix)
  path_out <- bucket$path(to_prefix)

  # List CSVs (Arrow may return full keys; basename() normalizes to file name)
  file_list <- basename(path_raw$ls())
  file_list <- file_list[grepl("\\.csv$", file_list, ignore.case = TRUE)]

  # Current mode: only Base.csv is converted (matched case-insensitively)
  file_list <- file_list[tolower(file_list) == "base.csv"]

  if (length(file_list) == 0) {
    stop("No Base.csv found under ", bucket_name, "/", from_prefix, "/")
  }

  message("Found ", length(file_list), " file(s) to process.")

  out_prefixes <- character(length(file_list))

  for (i in seq_along(file_list)) {
    file_name <- file_list[[i]]

    # Strip the extension with the same case-insensitivity used to select the
    # file, so "BASE.CSV" yields "BASE" rather than "BASE.CSV".
    variant_name <- file_name |>
      str_remove(stringr::regex("\\.csv$", ignore_case = TRUE)) |>
      str_replace_all(" ", "_") # e.g., "Variant I.csv" -> "Variant_I"

    message("\u2714 Processing: ", variant_name, "...")

    df <- read_csv_arrow(path_raw$path(file_name)) |>
      mutate(variant = variant_name)

    write_dataset(
      df,
      path = path_out,
      format = "parquet",
      partitioning = "variant"
    )

    # Report on the partition that was actually written, not a hard-coded one
    partition_dir <- paste0("variant=", variant_name)
    message(
      "\u2714 Converted ", file_name, " to Parquet on MinIO at s3://",
      bucket_name, "/", to_prefix, "/", partition_dir, "/"
    )

    n_parquet <- sum(grepl("\\.parquet$", path_out$path(partition_dir)$ls(), ignore.case = TRUE))
    message("\u2714 Wrote ", n_parquet, " parquet file(s) under ", partition_dir, "/")

    out_prefixes[[i]] <- file.path(to_prefix, partition_dir)
  }

  # Return a stable "artifact pointer" for targets
  out_prefixes
}
|
|
|
|
#' Connect to BAF dataset on MinIO (Arrow or DuckDB)
#'
#' Opens the Parquet dataset stored under `prefix` in a MinIO/S3 bucket. Connection
#' settings come from the `BAF_ENDPOINT`, `BAF_KEY` and `BAF_SECRET` environment
#' variables; each must be non-empty.
#'
#' @param prefix Character. Dataset prefix inside the bucket
#'   (e.g., "02_intermediate/variant=Base").
#' @param bucket_name Character. Bucket name. Defaults to env var BAF_BUCKET.
#' @param use_duckdb Logical. If TRUE (the default), return a DuckDB-backed lazy
#'   tbl; if FALSE, return the Arrow Dataset.
#'
#' @return A DuckDB-backed lazy table (default, \code{use_duckdb = TRUE}) or an
#'   Arrow Dataset (\code{use_duckdb = FALSE}).
#' @export
#'
#' @importFrom arrow s3_bucket open_dataset to_duckdb
connect_baf <- function(prefix, bucket_name = Sys.getenv("BAF_BUCKET"), use_duckdb = TRUE) {
  endpoint <- Sys.getenv("BAF_ENDPOINT")
  key <- Sys.getenv("BAF_KEY")
  secret <- Sys.getenv("BAF_SECRET")

  # NULL / NA / "" bucket names are all configuration errors; fail with one clear message
  if (length(bucket_name) != 1L || is.na(bucket_name) || bucket_name == "") {
    stop("Missing env var or arg: BAF_BUCKET / bucket_name")
  }
  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")
  if (key == "") stop("Missing env var: BAF_KEY")
  if (secret == "") stop("Missing env var: BAF_SECRET")

  b <- arrow::s3_bucket(
    bucket_name,
    endpoint_override = endpoint,
    scheme = "http",
    access_key = key,
    secret_key = secret,
    region = "us-east-1"
  )

  ds <- arrow::open_dataset(b$path(prefix), format = "parquet")

  if (isTRUE(use_duckdb)) {
    ds <- arrow::to_duckdb(ds)
    message("\u2714 Connected to s3://", bucket_name, "/", prefix, " via DuckDB Engine")
  } else {
    message("\u2714 Connected to s3://", bucket_name, "/", prefix, " via Arrow Engine")
  }

  ds
}
|
|
|
|
#' Clean the BAF Base dataset and write to 03_primary
#'
#' Opens the intermediate dataset through the Arrow engine and applies the
#' following cleaning steps lazily:
#' \enumerate{
#'   \item Replace `fraud_bool` with a character `outcome` column ("Fraud"/"Legit").
#'   \item Rename `device_distinct_emails_8w` to `device_distinct_emails` when only
#'     the former exists.
#'   \item Recode the sentinel `-1` to `NA` in the tenure/session/email columns.
#'   \item Set negative `intended_balcon_amount` values to `NA`.
#' }
#' The result is then written back to MinIO as Parquet.
#'
#' @param in_prefix Character. Input dataset prefix inside bucket (e.g. "02_intermediate/variant=Base").
#' @param out_prefix Character. Output dataset prefix inside bucket (e.g. "03_primary/variant=Base").
#' @param bucket_name Character. Bucket name. Default "baf-fraud".
#' @param partitioning Character vector of columns to partition by. Default "month". Set NULL to disable.
#' @param existing_data_behavior One of "overwrite", "error", "delete_matching". Default "overwrite".
#' @param verbose Logical. Emit progress messages. Default TRUE.
#'
#' @return Character. out_prefix (for downstream targets).
#' @export
#'
#' @importFrom dplyr mutate if_else select rename tbl_vars
#' @importFrom arrow s3_bucket write_dataset
clean_baf_base <- function(
  in_prefix,
  out_prefix = "03_primary/variant=Base",
  bucket_name = "baf-fraud",
  partitioning = "month",
  existing_data_behavior = c("overwrite", "error", "delete_matching"),
  verbose = TRUE
) {
  existing_data_behavior <- match.arg(existing_data_behavior)

  # Connection settings, checked in order: endpoint, key, secret
  creds <- c(
    BAF_ENDPOINT = Sys.getenv("BAF_ENDPOINT"),
    BAF_KEY = Sys.getenv("BAF_KEY"),
    BAF_SECRET = Sys.getenv("BAF_SECRET")
  )
  for (env_var in names(creds)) {
    if (creds[[env_var]] == "") stop("Missing env var: ", env_var)
  }

  # Progress messages only when verbose
  say <- function(...) {
    if (verbose) message(...)
  }

  say("Beginning cleaning...")

  # Arrow-native dataset (required for Arrow write_dataset)
  raw_ds <- connect_baf(in_prefix, bucket_name = bucket_name, use_duckdb = FALSE)

  # Step 1: replace the numeric fraud flag with a readable label
  labeled <- raw_ds |>
    mutate(outcome = if_else(fraud_bool == 1L, "Fraud", "Legit")) |>
    select(-fraud_bool)

  say("\u2714 Outcome column created as `outcome`")

  # Normalize email column name to match datasheet
  col_names <- dplyr::tbl_vars(labeled)
  has_8w_only <- "device_distinct_emails_8w" %in% col_names &&
    !("device_distinct_emails" %in% col_names)
  if (has_8w_only) {
    labeled <- rename(labeled, device_distinct_emails = device_distinct_emails_8w)
    col_names <- dplyr::tbl_vars(labeled)
  }

  expected_cols <- c(
    "prev_address_months_count",
    "current_address_months_count",
    "bank_months_count",
    "session_length_in_minutes",
    "device_distinct_emails",
    "intended_balcon_amount"
  )
  absent_cols <- expected_cols[!(expected_cols %in% col_names)]
  if (length(absent_cols) > 0) {
    stop("Missing expected columns: ", paste(absent_cols, collapse = ", "))
  }

  # Step 2: sentinel -1 -> NA (Arrow-friendly: explicit if_else per column)
  cleaned <- labeled |>
    mutate(
      prev_address_months_count = if_else(prev_address_months_count == -1L, NA_integer_, prev_address_months_count),
      current_address_months_count = if_else(current_address_months_count == -1L, NA_integer_, current_address_months_count),
      bank_months_count = if_else(bank_months_count == -1L, NA_integer_, bank_months_count),
      session_length_in_minutes = if_else(session_length_in_minutes == -1L, NA_integer_, session_length_in_minutes),
      device_distinct_emails = if_else(device_distinct_emails == -1L, NA_integer_, device_distinct_emails)
    )

  say("\u2714 Sentinel (-1) values converted to NA")

  # Step 3: intended_balcon_amount -- negatives mean "missing"
  cleaned <- cleaned |>
    mutate(
      intended_balcon_amount = if_else(intended_balcon_amount < 0, NA_real_, intended_balcon_amount)
    )

  say("\u2714 intended_balcon_amount constrained to values >= 0 (negatives set to NA)")

  # Step 4: write to MinIO via an s3_bucket path (keeps endpoint_override)
  out_fs <- arrow::s3_bucket(
    bucket_name,
    endpoint_override = creds[["BAF_ENDPOINT"]],
    scheme = "http",
    access_key = creds[["BAF_KEY"]],
    secret_key = creds[["BAF_SECRET"]],
    region = "us-east-1"
  )

  write_args <- list(
    cleaned,
    path = out_fs$path(out_prefix),
    format = "parquet",
    existing_data_behavior = existing_data_behavior
  )
  # Only pass partitioning when requested; otherwise write_dataset's default applies
  if (!is.null(partitioning)) {
    write_args$partitioning <- partitioning
  }
  do.call(arrow::write_dataset, write_args)

  say("\u2714 Wrote cleaned dataset to s3://", bucket_name, "/", out_prefix)

  out_prefix
}
|
|
#' Plot applications by month (Legit vs Fraud) on a log scale
#'
#' Counts applications per month for each outcome (aggregated lazily through
#' DuckDB, then collected) and draws one line per outcome. The log10 y-axis
#' keeps the much rarer fraud series visible next to the legitimate one.
#'
#' Data source: expects a cleaned "primary" dataset prefix (e.g. 03_primary/variant=Base)
#' stored in MinIO/S3, accessed via \code{connect_baf()}.
#'
#' @param dataset_prefix Character. Prefix inside the bucket, e.g. "03_primary/variant=Base".
#' @param bucket_name Character. Bucket name. Default "baf-fraud".
#' @param palette Character. colorspace qualitative palette name. Default "Dark 3".
#' @param title Character. Plot title. Default "".
#'
#' @return A ggplot object.
#' @export
#'
#' @importFrom dplyr group_by summarise mutate arrange collect rename n
#' @importFrom tidyr pivot_longer
#' @importFrom ggplot2 ggplot aes geom_line geom_point scale_y_log10 labs theme scale_color_manual
#' @importFrom cowplot theme_cowplot
#' @importFrom colorspace qualitative_hcl
plot_fraud_by_month <- function(
  dataset_prefix,
  bucket_name = "baf-fraud",
  palette = "Dark 3",
  title = ""
) {
  lazy_tbl <- connect_baf(dataset_prefix, bucket_name = bucket_name, use_duckdb = TRUE)

  # Monthly counts are computed in the database; only the summary is collected
  monthly_counts <- lazy_tbl |>
    dplyr::group_by(month) |>
    dplyr::summarise(
      Fraud = sum(outcome == "Fraud", na.rm = TRUE),
      Legit = sum(outcome == "Legit", na.rm = TRUE),
      Total = dplyr::n(),
      .groups = "drop"
    ) |>
    dplyr::mutate(Pct_Fraud = 100 * Fraud / Total) |>
    dplyr::arrange(month) |>
    dplyr::collect()

  # Long format: one row per (Month, Outcome); Legit is the first factor level
  plot_rows <- monthly_counts |>
    dplyr::rename(Month = month) |>
    tidyr::pivot_longer(c(Fraud, Legit), names_to = "Outcome", values_to = "n") |>
    dplyr::mutate(Outcome = factor(Outcome, levels = c("Legit", "Fraud")))

  outcome_colors <- stats::setNames(
    colorspace::qualitative_hcl(2, palette = palette),
    levels(plot_rows$Outcome)
  )

  ggplot2::ggplot(plot_rows, ggplot2::aes(x = factor(Month), y = n, group = Outcome, color = Outcome)) +
    ggplot2::geom_line(linewidth = 1) +
    ggplot2::geom_point(size = 2) +
    ggplot2::scale_y_log10(
      breaks = c(1e3, 1e4, 1e5),
      labels = c("1k", "10k", "100k")
    ) +
    ggplot2::labs(
      title = title,
      x = "Month",
      y = "Applications (log10 scale)",
      color = "Outcome"
    ) +
    ggplot2::scale_color_manual(values = outcome_colors) +
    cowplot::theme_cowplot(font_size = 20)
}
|
|
|
|
#' Fraud prevalence by month (counts + percent)
#'
#' Computes monthly counts of Fraud/Legit, totals, and percent fraud. The
#' aggregation runs lazily on the backend chosen by `use_duckdb`; only the
#' monthly summary is collected into R.
#'
#' @param in_prefix Character. Dataset prefix inside the bucket, e.g. "03_primary/variant=Base".
#' @param use_duckdb Logical. Use DuckDB for lazy querying. Default TRUE.
#' @param bucket_name Character. Bucket name, forwarded to \code{connect_baf()}.
#'   Defaults to the BAF_BUCKET env var (the previous, implicit behaviour).
#'
#' @return A tibble with Month, Fraud, Legit, Total, Pct_Fraud.
#' @export
#'
#' @importFrom dplyr group_by summarise n mutate arrange rename
#' @importFrom dplyr `%>%`
compute_fraud_by_month <- function(in_prefix, use_duckdb = TRUE, bucket_name = Sys.getenv("BAF_BUCKET")) {
  ds <- connect_baf(in_prefix, bucket_name = bucket_name, use_duckdb = use_duckdb)

  ds %>%
    dplyr::group_by(month) %>%
    dplyr::summarise(
      Fraud = sum(outcome == "Fraud", na.rm = TRUE),
      Legit = sum(outcome == "Legit", na.rm = TRUE),
      Total = dplyr::n(),
      .groups = "drop"
    ) %>%
    dplyr::mutate(Pct_Fraud = 100 * Fraud / Total) %>%
    dplyr::arrange(month) %>%
    dplyr::collect() %>%
    dplyr::rename(Month = month)
}
|
|
#' Format fraud-by-month table as a gt object
#'
#' Relabels the percent column, formats counts as whole numbers and the fraud
#' percentage to two decimals, and applies a compact font size and row padding.
#'
#' @param x Tibble from compute_fraud_by_month().
#'
#' @return A gt table.
#' @export
#'
#' @importFrom gt gt fmt_number cols_label tab_options
format_fraud_by_month_gt <- function(x) {
  tbl <- gt::gt(x)

  tbl <- gt::cols_label(
    tbl,
    Month = "Month",
    Fraud = "Fraud",
    Legit = "Legit",
    Total = "Total",
    Pct_Fraud = "% Fraud"
  )

  # Counts: no decimals; prevalence: two decimals
  tbl <- gt::fmt_number(tbl, columns = c(Fraud, Legit, Total), decimals = 0)
  tbl <- gt::fmt_number(tbl, columns = Pct_Fraud, decimals = 2)

  gt::tab_options(
    tbl,
    table.font.size = "80%",
    data_row.padding = gt::px(2)
  )
}
|
|
#' Save a report table artifact
#'
#' Serializes `x` to an `.rds` file with \code{readr::write_rds()}, creating the
#' output directory (recursively) if needed.
#'
#' @param x Object to save.
#' @param filename Output filename, e.g. "tbl_fraud_by_month.rds".
#' @param out_dir Output directory. Default "reports/tables".
#'
#' @return Character path to saved file (normalized, forward slashes).
#' @export
#'
#' @importFrom readr write_rds
save_report_table <- function(x, filename, out_dir = "reports/tables") {
  dest <- file.path(out_dir, filename)

  # Silently a no-op when the directory already exists
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

  readr::write_rds(x, dest)

  dest |> normalizePath(winslash = "/", mustWork = FALSE)
}
|
|
|
|
|
|
|
|
#' Save a report figure artifact
#'
#' Writes a ggplot object to disk with \code{ggplot2::ggsave()}, creating the
#' output directory (recursively) if needed. Intended for use in `targets`
#' pipelines as a file-producing target.
#'
#' @param plot A ggplot object.
#' @param filename Character. Output filename, e.g. \code{"fig_fraud_by_month.png"}.
#' @param out_dir Character. Output directory. Default \code{"reports/figures"}.
#' @param width,height,dpi Numeric. Passed to \code{ggplot2::ggsave()}.
#'
#' @return Character. Normalized path to the saved file.
#' @export
#'
#' @importFrom ggplot2 ggsave
save_report_figure <- function(
  plot,
  filename,
  out_dir = "reports/figures",
  width = 12,
  height = 6.75,
  dpi = 300
) {
  dest <- file.path(out_dir, filename)

  # Silently a no-op when the directory already exists
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

  ggplot2::ggsave(
    filename = dest,
    plot = plot,
    width = width,
    height = height,
    dpi = dpi
  )

  dest |> normalizePath(winslash = "/", mustWork = FALSE)
}
|
|
#' Render Quarto revealjs slideshow after required assets exist
#'
#' Fails fast if any asset listed in `assets` or the Quarto source `qmd` is
#' missing. Otherwise it creates `output_dir` and renders `qmd` into it.
#'
#' @param qmd Character. Input Quarto file (e.g. "index.qmd").
#' @param assets Character vector. File paths that must exist before rendering.
#'   Defaults to none.
#' @param output_dir Character. Output directory for rendered slides.
#'
#' @return Character path to the rendered HTML file
#'   (\code{output_dir/<qmd basename with .html extension>}).
#' @export
#'
#' @importFrom quarto quarto_render
render_slides <- function(qmd = "index.qmd", assets = character(), output_dir = "reports/slides") {
  # Named `absent` (not `missing`) to avoid shadowing base::missing()
  absent <- assets[!file.exists(assets)]
  if (length(absent) > 0) {
    stop("Missing report assets:\n", paste(absent, collapse = "\n"), call. = FALSE)
  }

  # Give a clear error instead of an opaque Quarto CLI failure
  if (!file.exists(qmd)) {
    stop("Quarto source not found: ", qmd, call. = FALSE)
  }

  dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

  quarto::quarto_render(
    input = qmd,
    quiet = FALSE,
    quarto_args = c("--output-dir", output_dir)
  )

  # Extension swap is case-insensitive so "Deck.QMD" maps to "Deck.html"
  file.path(output_dir, sub("\\.qmd$", ".html", basename(qmd), ignore.case = TRUE))
}
|
|
#' Run Class Imbalance Tournament
#'
#' For every (task, window) pair, this function:
#' \enumerate{
#'   \item Trains a LightGBM binary classifier on the task's resampled training
#'     data for that window's training months.
#'   \item Scores the model on the untouched `baseline` data for the window's
#'     test month.
#'   \item Records PR-AUC (with "Fraud" as the event level) and the training
#'     wall-clock time.
#' }
#' Fixed regularizing hyperparameters (depth, leaf size, bagging) are used to
#' limit overfitting. Significance testing is not done here; see
#' \code{format_tournament_gt()}.
#'
#' @param tasks A tibble with one row per strategy, containing `recipe_name`,
#'   `data_folder` (sub-folder of `inputs_prefix` with that strategy's training
#'   data) and `scale_pos_weight`.
#' @param windows A tibble containing `window_id`, `train_months` (a list-column
#'   of month values) and `test_month`.
#' @param feature_prefix Character. The upstream dependency prefix (used to force DAG execution);
#'   not otherwise read by this function.
#' @param bucket_name Character. Bucket name. Default "baf-fraud".
#' @param inputs_prefix Character. The folder containing the sampled data. Default "05_model_input".
#'
#' @return A data frame with one row per (task, window): `recipe`, `window`,
#'   `pr_auc` and `runtime_sec`.
#' @export
#'
#' @importFrom arrow s3_bucket open_dataset
#' @importFrom dplyr filter collect select any_of all_of bind_rows
#' @importFrom lightgbm lgb.Dataset lgb.train
#' @importFrom yardstick pr_auc
#' @importFrom glue glue
run_imbalance_tournament <- function(
  tasks,
  windows,
  feature_prefix,
  bucket_name = "baf-fraud",
  inputs_prefix = "05_model_input"
) {
  endpoint <- Sys.getenv("BAF_ENDPOINT")
  key <- Sys.getenv("BAF_KEY")
  secret <- Sys.getenv("BAF_SECRET")

  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")
  if (key == "") stop("Missing env var: BAF_KEY")
  if (secret == "") stop("Missing env var: BAF_SECRET")

  b <- arrow::s3_bucket(
    bucket_name,
    endpoint_override = endpoint,
    scheme = "http",
    access_key = key,
    secret_key = secret,
    region = "us-east-1"
  )

  # Every model is evaluated on the same baseline data, so open it once
  test_ds <- arrow::open_dataset(b$path(glue::glue("{inputs_prefix}/baseline")))

  # Identifier columns that must never be used as model features
  id_cols <- c("month", "month_date")

  results_log <- vector("list", nrow(tasks) * nrow(windows))
  counter <- 1L

  for (i in seq_len(nrow(tasks))) {
    task <- tasks[i, ]
    train_ds <- arrow::open_dataset(b$path(glue::glue("{inputs_prefix}/{task$data_folder}")))

    for (j in seq_len(nrow(windows))) {
      win <- windows[j, ]

      message(glue::glue("\n\u2699\ufe0f {task$recipe_name} | {win$window_id}"))

      # Load Training Data
      train_df <- train_ds |>
        dplyr::filter(month %in% win$train_months[[1]]) |>
        dplyr::collect()

      X_train <- train_df |>
        dplyr::select(-outcome, -dplyr::any_of(id_cols)) |>
        as.matrix()
      y_train <- as.numeric(train_df$outcome == "Fraud")

      # Train Model (with strict overfitting brakes)
      dtrain <- lightgbm::lgb.Dataset(data = X_train, label = y_train)
      start_time <- Sys.time()

      model <- lightgbm::lgb.train(
        params = list(
          objective = "binary",
          metric = "auc",
          learning_rate = 0.05,

          # --- The Common Sense Defaults ---
          max_depth = 6,
          num_leaves = 31,
          min_data_in_leaf = 250,
          feature_fraction = 0.8,
          bagging_fraction = 0.8,
          bagging_freq = 1,
          # ---------------------------------

          device = "cpu",
          scale_pos_weight = task$scale_pos_weight
        ),
        data = dtrain,
        nrounds = 500,
        verbose = -1
      )

      runtime <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))

      # Load Testing Data (Always evaluate on the baseline)
      test_df <- test_ds |>
        dplyr::filter(month == win$test_month) |>
        dplyr::collect()

      # LightGBM scores a matrix by column position, and resampled training
      # inputs may order their columns differently from the baseline. Reorder
      # the test features to exactly match the training matrix.
      test_features <- test_df |>
        dplyr::select(-outcome, -dplyr::any_of(id_cols))
      absent_cols <- setdiff(colnames(X_train), names(test_features))
      if (length(absent_cols) > 0) {
        stop(
          "Baseline test data is missing training feature(s): ",
          paste(absent_cols, collapse = ", ")
        )
      }
      X_test <- test_features |>
        dplyr::select(dplyr::all_of(colnames(X_train))) |>
        as.matrix()

      preds <- predict(model, X_test)

      # Score Model ("Fraud" is the first level, i.e. the event for pr_auc)
      eval_df <- data.frame(
        truth = factor(test_df$outcome, levels = c("Fraud", "Legit")),
        prob = preds
      )

      score <- yardstick::pr_auc(eval_df, truth, prob)$.estimate

      message(glue::glue("   -> PR-AUC: {round(score, 4)} | Time: {round(runtime, 2)}s"))

      results_log[[counter]] <- data.frame(
        recipe = task$recipe_name,
        window = win$window_id,
        pr_auc = score,
        runtime_sec = runtime
      )
      counter <- counter + 1L

      # Cleanup: free large per-iteration objects before the next fit
      rm(train_df, X_train, y_train, dtrain, model, test_df, test_features, X_test, preds, eval_df)
      gc()
    }
  }

  # Return the raw log for downstream targets to handle
  dplyr::bind_rows(results_log)
}
|
|
|
|
#' Format Tournament Results Table
#'
#' Aggregates per-window results from the model tournament by recipe (mean
#' PR-AUC and mean runtime). For each non-Standard recipe it runs a paired
#' t-test of PR-AUC against the 'Standard' recipe, pairing scores on matching
#' `window` ids.
#'
#' The significance column is:
#' \itemize{
#'   \item "-" for Standard;
#'   \item "n/a" when no test could be computed (e.g. no shared windows, or
#'     constant differences);
#'   \item "Yes (*)" for p < 0.05;
#'   \item "No (ns)" otherwise.
#' }
#'
#' @param results_df The data frame output from `run_imbalance_tournament`
#'   (columns `recipe`, `window`, `pr_auc`, `runtime_sec`).
#'
#' @importFrom dplyr filter arrange select inner_join group_by summarize mutate case_when desc
#' @importFrom gt gt tab_header fmt_number data_color
#' @importFrom stats t.test
#'
#' @return A formatted gt table object.
#' @export
format_tournament_gt <- function(results_df) {
  # Standard-recipe scores keyed by window: the baseline for every paired test
  standard_scores <- results_df |>
    dplyr::filter(recipe == "Standard") |>
    dplyr::select(window, standard = pr_auc)

  # Internal helper: p-value of a paired t-test vs the Standard baseline
  get_p_value <- function(target_recipe) {
    if (target_recipe == "Standard") return(1.0)

    # Pair on window id rather than on position, so a recipe with a missing or
    # extra window is never compared against the wrong baseline window.
    paired <- results_df |>
      dplyr::filter(recipe == target_recipe) |>
      dplyr::select(window, target = pr_auc) |>
      dplyr::inner_join(standard_scores, by = "window")

    tryCatch(
      stats::t.test(paired$target, paired$standard, paired = TRUE)$p.value,
      error = function(e) NA_real_
    )
  }

  # Aggregating window results into a final summary
  final_stats <- results_df |>
    dplyr::group_by(recipe) |>
    dplyr::summarize(
      avg_pr_auc = mean(pr_auc),
      avg_runtime = mean(runtime_sec),
      p_val_vs_std = get_p_value(unique(recipe))
    ) |>
    dplyr::mutate(
      significance = dplyr::case_when(
        recipe == "Standard" ~ "-",
        is.na(p_val_vs_std) ~ "n/a",
        p_val_vs_std < 0.05 ~ "Yes (*)",
        TRUE ~ "No (ns)"
      )
    ) |>
    dplyr::arrange(dplyr::desc(avg_pr_auc))

  # Formatting with gt for the Quarto presentation
  final_stats |>
    gt::gt() |>
    gt::tab_header(
      title = "Class Imbalance Strategy Showdown",
      subtitle = "Paired t-test comparison against 'Standard' baseline"
    ) |>
    gt::fmt_number(columns = c(avg_pr_auc, p_val_vs_std), decimals = 4) |>
    gt::fmt_number(columns = avg_runtime, decimals = 2) |>
    gt::data_color(
      columns = avg_pr_auc,
      palette = c("#ffcccc", "#ffffff", "#ccffcc") # Red-White-Green scale
    )
}
|
|
|
|
#' Plot Effectiveness vs Efficiency
#'
#' Averages PR-AUC and training time per recipe and plots them against each
#' other, highlighting the 'Standard' baseline in red with a bold label.
#'
#' @param results_df Data frame from \code{run_imbalance_tournament()} with
#'   columns `recipe`, `pr_auc` and `runtime_sec`.
#'
#' @return A ggplot object.
#' @export
#'
#' @importFrom ggplot2 ggplot aes geom_point scale_color_manual labs
#' @importFrom ggrepel geom_text_repel
#' @importFrom cowplot theme_half_open background_grid
plot_efficiency <- function(results_df) {
  # Aggregate by recipe
  plot_data <- results_df |>
    dplyr::group_by(recipe) |>
    dplyr::summarize(
      avg_pr_auc = mean(pr_auc),
      avg_time = mean(runtime_sec)
    )

  ggplot2::ggplot(plot_data, ggplot2::aes(x = avg_time, y = avg_pr_auc)) +
    # Colour only marks the baseline; every point is already labelled, so a
    # TRUE/FALSE legend titled `recipe == "Standard"` adds nothing.
    ggplot2::geom_point(ggplot2::aes(color = recipe == "Standard"), size = 5, show.legend = FALSE) +
    ggplot2::scale_color_manual(values = c("TRUE" = "#E74C3C", "FALSE" = "#2C3E50")) +
    ggrepel::geom_text_repel(
      ggplot2::aes(label = recipe, fontface = ifelse(recipe == "Standard", "bold", "plain")),
      family = "Atkinson Hyperlegible"
    ) +
    ggplot2::labs(
      title = "Strategy Showdown",
      x = "Avg Training Time (s)",
      y = "PR-AUC"
    ) +
    cowplot::theme_half_open(font_family = "Atkinson Hyperlegible") +
    cowplot::background_grid(major = "y")
}
|
|
#' Build EDA Recipe
#'
#' Builds and preps a recipe predicting `outcome` from all other columns. The
#' recipe:
#' \itemize{
#'   \item keeps `month` with the "ID" role;
#'   \item adds novel/unknown levels to nominal predictors;
#'   \item median-imputes numeric predictors;
#'   \item one-hot encodes nominal predictors.
#' }
#'
#' @param eda_data Raw EDA data
#' @return A prepped recipe.
#' @importFrom recipes recipe update_role step_novel step_unknown step_impute_median step_dummy all_nominal_predictors all_numeric_predictors prep
#' @export
build_eda_recipe <- function(eda_data) {
  rec <- recipe(outcome ~ ., data = eda_data)

  # month stays in the data but is not used as a predictor
  rec <- update_role(rec, month, new_role = "ID")

  # Levels unseen at prep time -> "new"; NA levels -> "unknown"
  rec <- step_novel(rec, all_nominal_predictors())
  rec <- step_unknown(rec, all_nominal_predictors())

  rec <- step_impute_median(rec, all_numeric_predictors())
  rec <- step_dummy(rec, all_nominal_predictors(), one_hot = TRUE)

  prep(rec)
}
|
|
|
|
#' Train Diagnostic Model
#'
#' Fits a quick 100-round LightGBM binary classifier on every column except
#' `outcome` and `month`, with 1 = "Fraud".
#'
#' @param baked_data Baked EDA data
#' @return A trained LightGBM booster.
#' @importFrom dplyr select
#' @importFrom lightgbm lgb.Dataset lgb.train
#' @export
train_diag_model <- function(baked_data) {
  fraud_label <- as.numeric(baked_data$outcome == "Fraud")
  feature_matrix <- as.matrix(select(baked_data, -outcome, -month))

  lgb.train(
    params = list(objective = "binary", metric = "auc", device = "cpu"),
    data = lgb.Dataset(data = feature_matrix, label = fraud_label),
    nrounds = 100,
    verbose = -1
  )
}
|
|
|
|
#' Plot Variable Importance
#'
#' Lollipop chart of the top features by LightGBM gain, shown as a share of
#' total gain (\code{lgb.importance(percentage = TRUE)}).
#'
#' @param model Trained LightGBM model
#' @param title Character. Plot title. Default "".
#' @param top_n Integer. Number of highest-gain features to show (ties may add
#'   rows, as in \code{dplyr::slice_max()}). Default 15.
#' @return A ggplot object.
#' @importFrom lightgbm lgb.importance
#' @importFrom dplyr slice_max
#' @importFrom ggplot2 ggplot aes geom_segment geom_point coord_flip scale_y_continuous labs expansion
#' @importFrom cowplot theme_minimal_vgrid
#' @importFrom stats reorder
#' @importFrom scales percent
#' @export
plot_var_imp <- function(model, title = "", top_n = 15) {
  importance_df <- lgb.importance(model, percentage = TRUE)
  plot_data <- slice_max(importance_df, Gain, n = top_n)

  ggplot(plot_data, aes(x = reorder(Feature, Gain), y = Gain)) +
    geom_segment(aes(xend = reorder(Feature, Gain), yend = 0), linewidth = 0.8) +
    geom_point(size = 3.5) +
    coord_flip() +
    scale_y_continuous(labels = percent, expand = expansion(mult = c(0, 0.05))) +
    labs(
      title = title,
      x = NULL, y = "Relative Importance"
    ) +
    theme_minimal_vgrid(font_family = "Atkinson Hyperlegible")
}
|
|
|
|
#' Plot Hexbin Interaction
#'
#' Hexbin map of fraud rate over months at current address vs credit risk
#' score. Hexes with fewer than 50 applications are left transparent so
#' sparse cells do not show noisy rates.
#'
#' @param baked_data Baked EDA data
#' @param title Character. Plot title. Default "".
#' @return A ggplot object.
#' @importFrom dplyr mutate
#' @importFrom ggplot2 ggplot aes stat_summary_hex labs
#' @importFrom colorspace scale_fill_continuous_sequential
#' @importFrom cowplot theme_minimal_grid
#' @importFrom scales percent
#' @export
plot_hexbin_interaction <- function(baked_data, title = "") {
  # Mean fraud flag per hex, or NA (drawn transparent) when the hex is sparse
  fraud_rate_if_dense <- function(z) {
    if (length(z) < 50) {
      return(NA_real_)
    }
    mean(z)
  }

  hex_data <- mutate(baked_data, fraud_flag = as.numeric(outcome == "Fraud"))

  ggplot(hex_data, aes(x = current_address_months_count, y = credit_risk_score, z = fraud_flag)) +
    stat_summary_hex(bins = 30, fun = fraud_rate_if_dense) +
    scale_fill_continuous_sequential(
      palette = "Viridis",
      labels = percent,
      na.value = "transparent",
      rev = FALSE
    ) +
    labs(
      title = title,
      x = "Months at Current Address", y = "Credit Risk Score", fill = "Fraud Rate"
    ) +
    theme_minimal_grid(font_family = "Atkinson Hyperlegible")
}
|
|
|
|
#' Plot Missingness Signal
#'
#' Compares the share of missing values per feature between outcomes
#' (Fraud vs Legit). A feature is shown when its missing share exceeds
#' `min_missing` for at least one outcome. Both outcomes are then kept for
#' that feature, so every row of the chart is a side-by-side comparison.
#'
#' @param eda_data Raw EDA data
#' @param title Character. Plot title. Default "".
#' @param min_missing Numeric in [0, 1]. Minimum missing share (in either
#'   outcome) for a feature to be plotted. Default 0.05.
#' @return A ggplot object.
#' @importFrom dplyr group_by summarise across everything filter ungroup
#' @importFrom tidyr pivot_longer
#' @importFrom ggplot2 ggplot aes geom_linerange geom_point coord_flip scale_y_continuous labs position_dodge theme expansion
#' @importFrom colorspace scale_color_discrete_qualitative
#' @importFrom cowplot theme_minimal_vgrid
#' @importFrom scales percent
#' @importFrom stats reorder
#' @export
plot_missingness <- function(eda_data, title = "", min_missing = 0.05) {
  missing_summary <- eda_data |>
    group_by(outcome) |>
    summarise(across(everything(), ~ mean(is.na(.x))), .groups = "drop") |>
    pivot_longer(cols = -outcome, names_to = "feature", values_to = "pct_missing") |>
    # Filter per feature, not per row: dropping only the outcome that falls
    # under the threshold would leave a one-sided, misleading comparison.
    group_by(feature) |>
    filter(any(pct_missing > min_missing)) |>
    dplyr::ungroup()

  ggplot(missing_summary, aes(x = reorder(feature, pct_missing), y = pct_missing, color = outcome)) +
    geom_linerange(
      aes(ymin = 0, ymax = pct_missing),
      position = position_dodge(width = 0.5),
      linewidth = 0.8
    ) +
    geom_point(position = position_dodge(width = 0.5), size = 4) +
    coord_flip() +
    scale_y_continuous(
      labels = percent,
      expand = expansion(mult = c(0, 0.1)),
      breaks = seq(0, 1, by = 0.25)
    ) +
    scale_color_discrete_qualitative(palette = "Dark 3") +
    labs(
      title = title,
      x = NULL, y = "Percent Missing", color = "Outcome"
    ) +
    theme_minimal_vgrid(font_family = "Atkinson Hyperlegible") +
    theme(legend.position = "right")
}
|
|
|
|
#' Plot Numeric Correlation Matrix
#'
#' Lower-triangle heatmap of Pearson correlations between numeric columns
#' (excluding `month` and zero-variance columns), ordered with
#' \code{corrr::rearrange()}. Cells with |r| above `label_threshold` are labelled.
#'
#' @param eda_data Raw EDA data
#' @param title Character. Plot title. Default "".
#' @param label_threshold Numeric. Only correlations with absolute value above
#'   this are printed in the tiles. Default 0.2.
#' @return A ggplot object.
#' @importFrom dplyr select
#' @importFrom tidyselect where
#' @importFrom stats sd
#' @importFrom corrr correlate rearrange shave stretch
#' @importFrom ggplot2 ggplot aes geom_tile geom_text labs theme element_text element_blank
#' @importFrom colorspace scale_fill_continuous_diverging
#' @importFrom cowplot theme_minimal_vgrid
#' @export
plot_num_cor <- function(eda_data, title = "", label_threshold = 0.2) {
  cor_numeric_only <- eda_data |>
    select(where(is.numeric), -month) |>
    # Zero-variance columns would yield undefined correlations
    select(where(~ isTRUE(sd(.x, na.rm = TRUE) > 0))) |>
    correlate(quiet = TRUE) |>
    rearrange() |>
    shave()

  cor_long <- stretch(cor_numeric_only, na.rm = TRUE)

  ggplot(cor_long, aes(x = x, y = y, fill = r)) +
    geom_tile(color = "white", linewidth = 0.5) +
    # Spelled-out `limits` (the previous `limit` only worked via partial
    # matching); fixed to [-1, 1] so colour intensity is comparable across runs.
    scale_fill_continuous_diverging(
      palette = "Green-Brown",
      mid = 0,
      limits = c(-1, 1),
      name = "Pearson (r)"
    ) +
    geom_text(
      aes(label = ifelse(abs(r) > label_threshold, round(r, 2), "")),
      color = "black",
      size = 3.5,
      family = "Atkinson Hyperlegible"
    ) +
    labs(
      title = title,
      x = NULL, y = NULL
    ) +
    theme_minimal_vgrid(font_family = "Atkinson Hyperlegible") +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1, size = 10, face = "bold"),
      axis.text.y = element_text(size = 10, face = "bold"),
      panel.grid.major = element_blank()
    )
}
|
|
|
|
#' Engineer features for the BAF dataset
#'
#' Reads the primary BAF dataset and engineers new features, such as
#' `n_missing`, which counts the number of missing values across key
#' tenure and financial columns. This calculation is performed out-of-memory
#' using Arrow compute.
#'
#' Connection settings come from the `BAF_ENDPOINT`, `BAF_KEY` and `BAF_SECRET`
#' environment variables; each must be non-empty.
#'
#' @param in_prefix Character. Input dataset prefix (e.g., "03_primary/variant=Base").
#' @param out_prefix Character. Output dataset prefix (e.g., "04_feature/variant=Base").
#' @param bucket_name Character. The S3/MinIO bucket name. Default "baf-fraud".
#' @param partitioning Character vector. Columns to partition by. Default "month".
#'   Set NULL to disable.
#' @param existing_data_behavior One of "delete_matching", "overwrite", "error".
#'   Default "delete_matching".
#' @param verbose Logical. Whether to print progress messages. Default TRUE.
#'
#' @return Character. The output prefix path for downstream targets.
#' @export
#'
#' @importFrom arrow s3_bucket open_dataset write_dataset
#' @importFrom dplyr mutate
engineer_features <- function(
  in_prefix = "03_primary/variant=Base",
  out_prefix = "04_feature/variant=Base",
  bucket_name = "baf-fraud",
  partitioning = "month",
  existing_data_behavior = c("delete_matching", "overwrite", "error"),
  verbose = TRUE
) {
  # Validate early (consistent with clean_baf_base); default is "delete_matching"
  existing_data_behavior <- match.arg(existing_data_behavior)

  endpoint <- Sys.getenv("BAF_ENDPOINT")
  key <- Sys.getenv("BAF_KEY")
  secret <- Sys.getenv("BAF_SECRET")

  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")
  if (key == "") stop("Missing env var: BAF_KEY")
  if (secret == "") stop("Missing env var: BAF_SECRET")

  if (verbose) message("Connecting to MinIO bucket: ", bucket_name)

  b <- arrow::s3_bucket(
    bucket_name,
    endpoint_override = endpoint,
    scheme = "http",
    access_key = key,
    secret_key = secret,
    region = "us-east-1"
  )

  if (verbose) message("Opening primary dataset: ", in_prefix)
  ds_primary <- arrow::open_dataset(b$path(in_prefix), format = "parquet")

  if (verbose) message("Engineering 'n_missing' feature...")
  ds_feature <- ds_primary |>
    dplyr::mutate(
      n_missing = as.integer(is.na(prev_address_months_count)) +
        as.integer(is.na(current_address_months_count)) +
        as.integer(is.na(bank_months_count)) +
        as.integer(is.na(session_length_in_minutes)) +
        as.integer(is.na(device_distinct_emails)) +
        as.integer(is.na(intended_balcon_amount))
    )

  if (verbose) message("Writing feature dataset to: ", out_prefix)

  write_args <- list(
    dataset = ds_feature,
    path = b$path(out_prefix),
    format = "parquet",
    existing_data_behavior = existing_data_behavior
  )
  # Only pass partitioning when requested; NULL writes an unpartitioned dataset
  if (!is.null(partitioning)) {
    write_args$partitioning <- partitioning
  }
  do.call(arrow::write_dataset, write_args)

  if (verbose) message("\u2714 Feature engineering complete!")

  out_prefix
}
|
|
|
|
#' Generate Resampled Model Inputs
#'
#' Reads the engineered feature layer and prepares a base tidymodels recipe on
#' a 5,000-row sample of month 0. It then bakes each requested month with that
#' recipe and writes five variants of it under \code{out_prefix}:
#' \itemize{
#'   \item \code{baseline}: the baked month as-is (still carries \code{month_date}).
#'   \item \code{under}: every minority-class row plus a random
#'     \code{under_prop} fraction of the majority-class rows.
#'   \item \code{smote}: \code{themis::smote()} with \code{over_ratio = 0.5}.
#'   \item \code{adasyn}: \code{themis::adasyn()} with \code{over_ratio = 0.5}
#'     and \code{k = 5}.
#'   \item \code{tomek}: Tomek-link removal via \code{themis::step_tomek()}.
#' }
#' Files are written as \code{<out_prefix>/<variant>/month=<m>/part-0.parquet}.
#' Every variant except \code{baseline} has \code{month_date} removed before
#' resampling.
#'
#' Connection settings are read from \code{BAF_ENDPOINT} (required),
#' \code{BAF_KEY} and \code{BAF_SECRET}.
#'
#' @param feature_prefix Character. Input prefix (e.g., "04_feature/variant=Base").
#' @param out_prefix Character. Output prefix base (e.g., "05_model_input").
#' @param bucket_name Character. Bucket name. Default "baf-fraud".
#' @param months_to_process Integer vector. Months to bake and resample.
#'   Default \code{0:7}. The recipe itself is always prepared on month 0.
#' @param under_prop Numeric in (0, 1]. Fraction of majority-class rows kept in
#'   the \code{under} variant. Default \code{0.25}.
#' @param seed Integer or \code{NULL}. When not \code{NULL}, \code{set.seed()}
#'   is called with it before the month loop. This makes the random variants
#'   (under, smote, adasyn) reproducible. Default \code{NULL}, which leaves the
#'   RNG state untouched.
#' @param verbose Logical. Print progress messages? Default \code{TRUE}.
#'
#' @return Character. \code{out_prefix} (for targets dependency tracking).
#' @export
#'
#' @importFrom arrow s3_bucket open_dataset write_parquet
#' @importFrom dplyr filter collect mutate group_by slice_sample ungroup select
#' @importFrom recipes recipe update_role step_novel step_unknown step_indicate_na step_impute_median step_dummy step_zv all_nominal_predictors all_numeric_predictors all_predictors prep bake
#' @importFrom themis smote adasyn step_tomek
#' @importFrom lubridate %m+%
#' @importFrom glue glue
generate_model_inputs <- function(
  feature_prefix = "04_feature/variant=Base",
  out_prefix = "05_model_input",
  bucket_name = "baf-fraud",
  months_to_process = 0:7,
  under_prop = 0.25,
  seed = NULL,
  verbose = TRUE
) {
  if (!is.numeric(under_prop) || length(under_prop) != 1L ||
      is.na(under_prop) || under_prop <= 0 || under_prop > 1) {
    stop("`under_prop` must be a single number in (0, 1].")
  }

  endpoint <- Sys.getenv("BAF_ENDPOINT")
  key <- Sys.getenv("BAF_KEY")
  secret <- Sys.getenv("BAF_SECRET")

  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")

  # Progress messages are optional so the function can run quietly.
  say <- function(...) {
    if (verbose) message(...)
  }

  b <- s3_bucket(
    bucket_name,
    endpoint_override = endpoint,
    scheme = "http",
    access_key = key,
    secret_key = secret,
    region = "us-east-1"
  )

  say("Opening feature dataset: ", feature_prefix)
  ds_feature <- open_dataset(b$path(feature_prefix))

  # Calendar anchor: month index m maps to `base_date` plus m months.
  base_date <- as.Date("2025-02-01")
  add_month_date <- function(df) {
    mutate(df, month_date = base_date %m+% months(month))
  }

  # Write one variant of month `m` to its Hive-style partition directory.
  write_variant <- function(df, variant, m) {
    write_parquet(
      df,
      b$path(glue("{out_prefix}/{variant}/month={m}/part-0.parquet"))
    )
  }

  # 1. Prepare the recipe on a small sample of month 0
  say("Preparing base recipe on Month 0 sample...")
  sample_data <- ds_feature |>
    filter(month == 0) |>
    head(5000) |>
    collect() |>
    add_month_date()

  # NOTE(review): `month` gets no ID role here, unlike in build_baf_recipe().
  # It is constant within this month-0 sample, so step_zv() appears to drop
  # it. The baked files then carry no `month` column, and it is recovered
  # from the month=<m> directory when the dataset is reopened. Confirm this
  # before changing the roles.
  rec_base <- recipe(outcome ~ ., data = sample_data) |>
    update_role(month_date, new_role = "ID") |>
    step_novel(all_nominal_predictors()) |>
    step_unknown(all_nominal_predictors()) |>
    step_indicate_na(all_numeric_predictors()) |>
    step_impute_median(all_numeric_predictors()) |>
    step_dummy(all_nominal_predictors(), one_hot = TRUE) |>
    step_zv(all_predictors()) |>
    prep()

  if (!is.null(seed)) set.seed(seed)

  # 2. Bake, resample and write each month (S3 -> S3)
  for (m in months_to_process) {
    say("Baking and sampling month ", m, "...")

    raw_df <- ds_feature |>
      filter(month == m) |>
      collect() |>
      add_month_date()

    baked_df <- bake(rec_base, new_data = raw_df)
    write_variant(baked_df, "baseline", m)

    # The resamplers need numeric predictors only, so drop the Date ID column.
    numeric_only_df <- select(baked_df, -month_date)

    # Under: keep every minority row and downsample only the majority class.
    # Sampling each class at the same rate would leave the imbalance unchanged.
    class_counts <- table(numeric_only_df$outcome)
    majority_class <- names(class_counts)[which.max(class_counts)]
    is_majority <- numeric_only_df$outcome %in% majority_class
    baked_under <- dplyr::bind_rows(
      numeric_only_df[!is_majority, ],
      slice_sample(numeric_only_df[is_majority, ], prop = under_prop)
    )
    write_variant(baked_under, "under", m)

    baked_smote <- smote(numeric_only_df, var = "outcome", over_ratio = 0.5)
    write_variant(baked_smote, "smote", m)

    baked_adasyn <- adasyn(numeric_only_df, var = "outcome", over_ratio = 0.5, k = 5)
    write_variant(baked_adasyn, "adasyn", m)

    baked_tomek <- recipe(outcome ~ ., data = numeric_only_df) |>
      step_tomek(outcome) |>
      prep() |>
      bake(new_data = NULL)
    write_variant(baked_tomek, "tomek", m)

    # Free this month's data before loading the next one.
    rm(
      raw_df, baked_df, numeric_only_df, baked_under,
      baked_smote, baked_adasyn, baked_tomek
    )
    gc()
  }

  say("\u2714 Model inputs generated successfully!")

  out_prefix
}
|
|
|
|
#' Final Model Evaluation (Months 6 & 7)
#'
#' Trains a LightGBM booster with the supplied hyperparameters on the
#' \code{baseline} model inputs for months 0-5, then scores the unseen months
#' 6-7.
#'
#' Connection settings are read from \code{BAF_ENDPOINT} (required),
#' \code{BAF_KEY} and \code{BAF_SECRET}.
#'
#' @param params A named list (or one-row tibble) of hyperparameters with
#'   elements \code{trees}, \code{tree_depth}, \code{learn_rate} and
#'   \code{min_n}, as returned by \code{tune_lgbm()}. An optional
#'   \code{loss_reduction} element is passed to LightGBM as
#'   \code{min_gain_to_split}.
#' @param bucket_name Character. Bucket name. Default "baf-fraud".
#' @param inputs_prefix Character. Model input prefix. Default "05_model_input".
#' @param threshold Numeric. Probability at or above which a row is classified
#'   as \code{"Fraud"}. Default \code{0.05}.
#'
#' @return A tibble with columns \code{truth} (factor, levels Fraud/Legit),
#'   \code{prob} (predicted probability of "Fraud") and \code{pred_class}
#'   (factor, same levels).
#' @export
evaluate_final_model <- function(params,
                                 bucket_name = "baf-fraud",
                                 inputs_prefix = "05_model_input",
                                 threshold = 0.05) {
  if (!is.numeric(threshold) || length(threshold) != 1L || is.na(threshold)) {
    stop("`threshold` must be a single number.")
  }

  endpoint <- Sys.getenv("BAF_ENDPOINT")
  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")

  b <- arrow::s3_bucket(
    bucket_name,
    endpoint_override = endpoint,
    scheme = "http",
    access_key = Sys.getenv("BAF_KEY"),
    secret_key = Sys.getenv("BAF_SECRET"),
    region = "us-east-1"
  )

  baseline <- arrow::open_dataset(b$path(glue::glue("{inputs_prefix}/baseline")))

  # Feature matrix: drop the label and the ID/partition columns.
  to_matrix <- function(df) {
    as.matrix(dplyr::select(df, -outcome, -dplyr::any_of(c("month", "month_date"))))
  }

  # 1. FULL TRAIN (Months 0-5)
  train_df <- baseline |>
    dplyr::filter(month %in% 0:5) |>
    dplyr::collect()

  X_train <- to_matrix(train_df)
  y_train <- as.numeric(train_df$outcome == "Fraud")

  lgb_params <- list(
    objective = "binary",
    metric = "auc",
    learning_rate = params$learn_rate,
    max_depth = params$tree_depth,
    # NOTE(review): tune_lgbm() does not set num_leaves, so candidates were
    # scored at the engine default rather than 2^depth - 1. Confirm this
    # mismatch is intended.
    num_leaves = 2^params$tree_depth - 1L,
    min_data_in_leaf = params$min_n
  )
  # `[[` avoids partial matching on lists and unknown-column warnings on tibbles.
  if (!is.null(params[["loss_reduction"]])) {
    lgb_params$min_gain_to_split <- params[["loss_reduction"]]
  }

  model <- lightgbm::lgb.train(
    params = lgb_params,
    data = lightgbm::lgb.Dataset(X_train, label = y_train),
    nrounds = params$trees,
    verbose = -1
  )

  # 2. FINAL EXAM (Months 6-7)
  test_df <- baseline |>
    dplyr::filter(month %in% 6:7) |>
    dplyr::collect()

  # predict() matches features by position, so pin test columns to train order.
  X_test <- to_matrix(test_df)[, colnames(X_train), drop = FALSE]
  preds <- predict(model, X_test)

  # 3. GENERATE METRICS
  dplyr::tibble(
    truth = factor(test_df$outcome, levels = c("Fraud", "Legit")),
    prob = preds,
    pred_class = factor(
      ifelse(prob >= threshold, "Fraud", "Legit"),
      levels = c("Fraud", "Legit")
    )
  )
}
|
|
|
|
#' Plot Confusion Matrix Heatmap
#'
#' Draws a yardstick confusion matrix as a heatmap using its \code{autoplot()}
#' method. The heatmap uses a light-grey-to-blue fill, a minimal theme, a bold
#' title and no legend.
#'
#' @param cm A yardstick \code{conf_mat} object.
#' @param title Character. The main title of the plot. Default \code{""}.
#' @param subtitle Character or \code{NULL}. The subtitle of the plot;
#'   \code{NULL} (default) draws no subtitle.
#'
#' @return A ggplot object.
#' @export
#'
#' @importFrom ggplot2 autoplot scale_fill_gradient labs theme_minimal theme element_text
plot_conf_mat_heatmap <- function(
  cm,
  title = "",
  subtitle = NULL
) {
  ggplot2::autoplot(cm, type = "heatmap") +
    ggplot2::scale_fill_gradient(low = "#F3F4F6", high = "#1D4ED8") +
    ggplot2::labs(
      title = title,
      subtitle = subtitle
    ) +
    ggplot2::theme_minimal(base_size = 14) +
    ggplot2::theme(
      legend.position = "none",
      plot.title = ggplot2::element_text(face = "bold")
    )
}
|
|
#' Train and Serialize Production LightGBM Model
#'
#' Fits a workflow on \code{data} with the tuned hyperparameters. The workflow
#' combines the supplied recipe with a LightGBM \code{boost_tree} that has
#' \code{is_unbalance = TRUE}. The fitted booster is saved as a LightGBM text
#' model and uploaded to
#' \code{<bucket_name>/06_models/<model_filename>} through Arrow's S3
#' filesystem.
#'
#' Connection settings are read from \code{BAF_ENDPOINT} (required),
#' \code{BAF_KEY} and \code{BAF_SECRET}.
#'
#' @param data A data frame containing the full BAF dataset (Months 0-7).
#' @param recipe A tidymodels recipe. The workflow estimates it on \code{data}
#'   during fitting, so pass an untrained recipe such as the one returned by
#'   \code{build_baf_recipe()}.
#' @param best_params A list or one-row tibble with \code{trees},
#'   \code{tree_depth}, \code{learn_rate} and \code{min_n}.
#' @param model_filename Character. The target filename. Defaults to "lgbm_prod.txt".
#' @param bucket_name Character. Destination bucket. Defaults to the
#'   \code{BAF_BUCKET} environment variable, or "baf-fraud" when it is unset.
#'
#' @return Character. The URI \code{"minio://<bucket_name>/06_models/<model_filename>"}.
#' @export
#'
#' @importFrom parsnip boost_tree set_engine set_mode
#' @importFrom workflows workflow add_recipe add_model fit extract_fit_engine
#' @importFrom lightgbm lgb.save
#' @importFrom arrow S3FileSystem
train_production_model <- function(data,
                                   recipe,
                                   best_params,
                                   model_filename = "lgbm_prod.txt",
                                   bucket_name = Sys.getenv("BAF_BUCKET", unset = "baf-fraud")) {
  # Validate the destination before the expensive fit. An empty bucket would
  # otherwise turn "06_models" into the bucket name.
  if (!is.character(bucket_name) || length(bucket_name) != 1L || !nzchar(bucket_name)) {
    stop("`bucket_name` is empty; set BAF_BUCKET or pass `bucket_name`.")
  }
  endpoint <- Sys.getenv("BAF_ENDPOINT")
  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")

  # 1. Define the production model specification
  lgbm_spec <- parsnip::boost_tree(
    trees = best_params$trees,
    tree_depth = best_params$tree_depth,
    learn_rate = best_params$learn_rate,
    min_n = best_params$min_n
  ) |>
    parsnip::set_engine("lightgbm", is_unbalance = TRUE) |>
    parsnip::set_mode("classification")

  # 2. Bundle the workflow and fit to the ENTIRE dataset
  fitted_prod_wflow <- workflows::workflow() |>
    workflows::add_recipe(recipe) |>
    workflows::add_model(lgbm_spec) |>
    workflows::fit(data = data)

  # 3. Extract the raw LightGBM booster object
  lgbm_booster <- workflows::extract_fit_engine(fitted_prod_wflow)

  # 4. Serialize to a private temp file (avoids collisions between parallel
  #    workers) that is removed on exit, even if the upload fails.
  local_path <- tempfile(fileext = ".txt")
  on.exit(unlink(local_path), add = TRUE)
  lightgbm::lgb.save(lgbm_booster, local_path)

  # 5. Connect to MinIO via Arrow, matching the other pipeline connections
  s3 <- arrow::S3FileSystem$create(
    access_key = Sys.getenv("BAF_KEY"),
    secret_key = Sys.getenv("BAF_SECRET"),
    endpoint_override = endpoint,
    scheme = "http",
    region = "us-east-1"
  )

  # 6. Push the serialized model bytes to MinIO
  s3_path <- file.path(bucket_name, "06_models", model_filename)
  raw_bytes <- readBin(local_path, what = "raw", n = file.size(local_path))

  out_stream <- s3$OpenOutputStream(s3_path)
  tryCatch(
    {
      out_stream$write(raw_bytes)
      out_stream$close()
    },
    error = function(e) {
      # Release the stream, then re-raise the original upload error.
      try(out_stream$close(), silent = TRUE)
      stop(e)
    }
  )

  # 7. Return the storage URI for pipeline tracking
  paste0("minio://", s3_path)
}
|
|
#' Build Untrained BAF Recipe
#'
#' Declares the preprocessing for the BAF model. \code{outcome} is the label,
#' \code{month} is kept as an ID column, and every other column is a
#' predictor. The steps, in order:
#' \enumerate{
#'   \item nominal predictors get a level for unseen values (\code{step_novel})
#'     and an explicit level for missing values (\code{step_unknown});
#'   \item numeric predictors get missing-value indicator columns
#'     (\code{step_indicate_na}) and are then median-imputed;
#'   \item nominal predictors are one-hot encoded;
#'   \item zero-variance predictors are removed.
#' }
#' The recipe is returned untrained; it is estimated later, e.g. when a
#' workflow containing it is fit.
#'
#' @param data A data frame with an \code{outcome} column and a \code{month} column.
#'
#' @return An untrained tidymodels recipe.
#' @export
#'
#' @importFrom recipes recipe update_role step_novel step_unknown step_indicate_na step_impute_median step_dummy step_zv all_nominal_predictors all_numeric_predictors all_predictors
build_baf_recipe <- function(data) {
  rec <- recipes::recipe(outcome ~ ., data = data)
  rec <- recipes::update_role(rec, month, new_role = "ID")

  # Nominal predictors: tolerate unseen levels and give NA an explicit level.
  rec <- recipes::step_novel(rec, recipes::all_nominal_predictors())
  rec <- recipes::step_unknown(rec, recipes::all_nominal_predictors())

  # Numeric predictors: record where values were missing, then fill them.
  rec <- recipes::step_indicate_na(rec, recipes::all_numeric_predictors())
  rec <- recipes::step_impute_median(rec, recipes::all_numeric_predictors())

  # Encode nominal predictors and prune constant columns.
  rec <- recipes::step_dummy(rec, recipes::all_nominal_predictors(), one_hot = TRUE)
  rec <- recipes::step_zv(rec, recipes::all_predictors())

  # Intentionally not prep()'d: training happens inside the consuming workflow.
  rec
}
|
|
|
|
#' Tune LightGBM Hyperparameters
#'
#' Performs a space-filling grid search over LightGBM hyperparameters. It uses
#' the same rolling time windows as the imbalance tournament and optimises
#' PR-AUC (with "Fraud" as the first factor level) on the pre-baked baseline
#' data stored in MinIO. The best parameters are returned as a named list,
#' ready for use in \code{evaluate_final_model()} and
#' \code{train_production_model()}.
#'
#' Only months 0-5 are loaded, so months 6-7 stay untouched for the final
#' evaluation. Every window must therefore draw its training and test months
#' from 0-5; a window with no matching rows is an error.
#'
#' NOTE(review): candidates are fit without \code{is_unbalance} and without an
#' explicit \code{num_leaves}. \code{train_production_model()} sets
#' \code{is_unbalance = TRUE}, and \code{evaluate_final_model()} sets
#' \code{num_leaves = 2^tree_depth - 1}. Confirm these differences are intended.
#'
#' @param imbalance_windows A tibble with columns \code{window_id},
#'   \code{train_months} (list column) and \code{test_month}, as produced by
#'   the \code{imbalance_windows} target.
#' @param bucket_name Character. MinIO bucket name. Default \code{"baf-fraud"}.
#' @param inputs_prefix Character. Prefix for the model input layer.
#'   Default \code{"05_model_input"}.
#' @param grid_size Integer. Number of space-filling candidates. Default \code{30}.
#' @param seed Integer. Random seed passed to \code{set.seed()} before building
#'   the grid and before tuning. Default \code{42}.
#' @param num_threads Integer or \code{NULL}. LightGBM threads per fit.
#'   \code{NULL} (default) uses \code{parallel::detectCores()}, falling back to
#'   1 when the core count cannot be detected.
#'
#' @return A named list with elements \code{trees}, \code{tree_depth},
#'   \code{learn_rate}, and \code{min_n}.
#' @export
#'
#' @importFrom arrow s3_bucket open_dataset
#' @importFrom dplyr filter collect mutate any_of
#' @importFrom purrr map
#' @importFrom rsample make_splits manual_rset
#' @importFrom recipes recipe update_role step_zv all_predictors
#' @importFrom parsnip boost_tree set_engine set_mode
#' @importFrom workflows workflow add_recipe add_model
#' @importFrom dials grid_space_filling trees tree_depth learn_rate min_n
#' @importFrom tune tune tune_grid control_grid select_best
#' @importFrom yardstick metric_set pr_auc
tune_lgbm <- function(
  imbalance_windows,
  bucket_name = "baf-fraud",
  inputs_prefix = "05_model_input",
  grid_size = 30L,
  seed = 42L,
  num_threads = NULL
) {
  if (nrow(imbalance_windows) == 0L) {
    stop("`imbalance_windows` has no rows; nothing to tune on.", call. = FALSE)
  }

  endpoint <- Sys.getenv("BAF_ENDPOINT")
  if (endpoint == "") stop("Missing env var: BAF_ENDPOINT")

  # detectCores() may return NA on some platforms; LightGBM needs a count.
  if (is.null(num_threads)) {
    num_threads <- parallel::detectCores()
    if (is.na(num_threads)) num_threads <- 1L
  }

  b <- arrow::s3_bucket(
    bucket_name,
    endpoint_override = endpoint,
    scheme = "http",
    access_key = Sys.getenv("BAF_KEY"),
    secret_key = Sys.getenv("BAF_SECRET"),
    region = "us-east-1"
  )

  message("Loading baseline data (months 0-5) for tuning...")
  tune_data <- arrow::open_dataset(b$path(glue::glue("{inputs_prefix}/baseline"))) |>
    dplyr::filter(month %in% 0:5) |>
    dplyr::collect() |>
    dplyr::mutate(outcome = factor(outcome, levels = c("Fraud", "Legit")))

  message("Rows loaded: ", nrow(tune_data))

  # Build rolling window resamples matching the tournament windows
  splits <- purrr::map(
    seq_len(nrow(imbalance_windows)),
    function(i) {
      win <- imbalance_windows[i, ]
      train_idx <- which(tune_data$month %in% win$train_months[[1]])
      test_idx <- which(tune_data$month %in% win$test_month)
      # An empty side means the window points outside the loaded months 0-5.
      if (length(train_idx) == 0L || length(test_idx) == 0L) {
        stop(
          "Window ", win$window_id,
          " has no analysis or assessment rows within months 0-5.",
          call. = FALSE
        )
      }
      rsample::make_splits(
        list(analysis = train_idx, assessment = test_idx),
        data = tune_data
      )
    }
  )

  rolling_cv <- rsample::manual_rset(splits, ids = imbalance_windows$window_id)

  # Minimal recipe: the data is already baked, so only mark the ID columns.
  tune_recipe <- recipes::recipe(outcome ~ ., data = tune_data) |>
    recipes::update_role(dplyr::any_of(c("month", "month_date")), new_role = "ID") |>
    recipes::step_zv(recipes::all_predictors())

  # `!!` stores the thread count in the spec instead of a lazy reference.
  lgbm_spec <- parsnip::boost_tree(
    trees = tune::tune(),
    tree_depth = tune::tune(),
    learn_rate = tune::tune(),
    min_n = tune::tune()
  ) |>
    parsnip::set_engine("lightgbm", num_threads = !!num_threads) |>
    parsnip::set_mode("classification")

  tune_wflow <- workflows::workflow() |>
    workflows::add_recipe(tune_recipe) |>
    workflows::add_model(lgbm_spec)

  set.seed(seed)
  lgbm_grid <- dials::grid_space_filling(
    dials::trees(range = c(100L, 1000L)),
    dials::tree_depth(range = c(3L, 8L)),
    dials::learn_rate(range = c(-3, -1)),
    dials::min_n(range = c(100L, 500L)),
    size = grid_size
  )

  message("Starting hyperparameter tuning (", grid_size, " candidates x ",
          nrow(imbalance_windows), " windows)...")
  set.seed(seed)
  tune_results <- tune::tune_grid(
    tune_wflow,
    resamples = rolling_cv,
    grid = lgbm_grid,
    metrics = yardstick::metric_set(yardstick::pr_auc),
    control = tune::control_grid(verbose = TRUE, save_pred = FALSE)
  )

  best <- tune::select_best(tune_results, metric = "pr_auc")
  message("Best PR-AUC params: trees=", best$trees, " tree_depth=", best$tree_depth,
          " learn_rate=", round(best$learn_rate, 5), " min_n=", best$min_n)

  list(
    trees = best$trees,
    tree_depth = best$tree_depth,
    learn_rate = best$learn_rate,
    min_n = best$min_n
  )
}