Add tune_lgbm() and wire hyperparameter tuning into DAG

Converts scratch/tune_model.R into a pure tune_lgbm() function,
replacing hardcoded winning_params with a fully automated tar_target.
Best params (trees=844, depth=3, lr=0.0204, min_n=389) now flow
reproducibly into evaluate_final_model() and train_production_model().
PR-AUC improved from 0.165 to 0.198.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 03:25:35 -05:00
parent 33d0fc31c7
commit f47b2e1be2
7 changed files with 178 additions and 9 deletions

View File

@@ -23,6 +23,7 @@ export(save_report_figure)
export(save_report_table)
export(train_diag_model)
export(train_production_model)
export(tune_lgbm)
importFrom(arrow,S3FileSystem)
importFrom(arrow,open_dataset)
importFrom(arrow,read_csv_arrow)
@@ -43,6 +44,11 @@ importFrom(cowplot,theme_cowplot)
importFrom(cowplot,theme_half_open)
importFrom(cowplot,theme_minimal_grid)
importFrom(cowplot,theme_minimal_vgrid)
importFrom(dials,grid_space_filling)
importFrom(dials,learn_rate)
importFrom(dials,min_n)
importFrom(dials,tree_depth)
importFrom(dials,trees)
importFrom(dplyr,`%>%`)
importFrom(dplyr,across)
importFrom(dplyr,any_of)
@@ -105,6 +111,7 @@ importFrom(lubridate,"%m+%")
importFrom(parsnip,boost_tree)
importFrom(parsnip,set_engine)
importFrom(parsnip,set_mode)
importFrom(purrr,map)
importFrom(quarto,quarto_render)
importFrom(readr,write_rds)
importFrom(recipes,all_nominal_predictors)
@@ -120,6 +127,8 @@ importFrom(recipes,step_novel)
importFrom(recipes,step_unknown)
importFrom(recipes,step_zv)
importFrom(recipes,update_role)
importFrom(rsample,make_splits)
importFrom(rsample,manual_rset)
importFrom(scales,percent)
importFrom(stats,reorder)
importFrom(stats,sd)
@@ -131,9 +140,14 @@ importFrom(themis,smote)
importFrom(themis,step_tomek)
importFrom(tidyr,pivot_longer)
importFrom(tidyselect,where)
importFrom(tune,control_grid)
importFrom(tune,select_best)
importFrom(tune,tune)
importFrom(tune,tune_grid)
importFrom(workflows,add_model)
importFrom(workflows,add_recipe)
importFrom(workflows,extract_fit_engine)
importFrom(workflows,fit)
importFrom(workflows,workflow)
importFrom(yardstick,metric_set)
importFrom(yardstick,pr_auc)

View File

@@ -1208,3 +1208,123 @@ build_baf_recipe <- function(data) {
# Notice: NO prep() here!
}
#' Tune LightGBM Hyperparameters
#'
#' Performs a grid search over LightGBM hyperparameters using the same rolling
#' time windows as the imbalance tournament. Optimises PR-AUC on the pre-baked
#' baseline data stored in MinIO. Returns the best parameters as a named list
#' ready for use in \code{evaluate_final_model()} and
#' \code{train_production_model()}.
#'
#' @param imbalance_windows A tibble with columns \code{window_id},
#' \code{train_months}, and \code{test_month}, as produced by the
#' \code{imbalance_windows} target.
#' @param bucket_name Character. MinIO bucket name. Default \code{"baf-fraud"}.
#' @param inputs_prefix Character. Prefix for the model input layer.
#' Default \code{"05_model_input"}.
#' @param grid_size Integer. Number of space-filling candidates. Default \code{30}.
#' @param seed Integer. Random seed for reproducibility. Default \code{42}.
#'
#' @return A named list with elements \code{trees}, \code{tree_depth},
#' \code{learn_rate}, and \code{min_n}.
#' @export
#'
#' @importFrom arrow s3_bucket open_dataset
#' @importFrom dplyr filter collect mutate any_of
#' @importFrom purrr map
#' @importFrom rsample make_splits manual_rset
#' @importFrom recipes recipe update_role step_zv all_predictors
#' @importFrom parsnip boost_tree set_engine set_mode
#' @importFrom workflows workflow add_recipe add_model
#' @importFrom dials grid_space_filling trees tree_depth learn_rate min_n
#' @importFrom tune tune tune_grid control_grid select_best
#' @importFrom yardstick metric_set pr_auc
tune_lgbm <- function(
imbalance_windows,
bucket_name = "baf-fraud",
inputs_prefix = "05_model_input",
grid_size = 30L,
seed = 42L
) {
b <- arrow::s3_bucket(
bucket_name,
endpoint_override = Sys.getenv("BAF_ENDPOINT"),
scheme = "http",
access_key = Sys.getenv("BAF_KEY"),
secret_key = Sys.getenv("BAF_SECRET"),
region = "us-east-1"
)
message("Loading baseline data (months 0-5) for tuning...")
tune_data <- arrow::open_dataset(b$path(glue::glue("{inputs_prefix}/baseline"))) |>
dplyr::filter(month %in% 0:5) |>
dplyr::collect() |>
dplyr::mutate(outcome = factor(outcome, levels = c("Fraud", "Legit")))
message("Rows loaded: ", nrow(tune_data))
# Build rolling window resamples matching the tournament windows
splits <- purrr::map(
seq_len(nrow(imbalance_windows)),
function(i) {
win <- imbalance_windows[i, ]
train_idx <- which(tune_data$month %in% win$train_months[[1]])
test_idx <- which(tune_data$month == win$test_month)
rsample::make_splits(
list(analysis = train_idx, assessment = test_idx),
data = tune_data
)
}
)
rolling_cv <- rsample::manual_rset(splits, ids = imbalance_windows$window_id)
# Minimal recipe — data is already baked; just remove ID columns
tune_recipe <- recipes::recipe(outcome ~ ., data = tune_data) |>
recipes::update_role(dplyr::any_of(c("month", "month_date")), new_role = "ID") |>
recipes::step_zv(recipes::all_predictors())
lgbm_spec <- parsnip::boost_tree(
trees = tune::tune(),
tree_depth = tune::tune(),
learn_rate = tune::tune(),
min_n = tune::tune()
) |>
parsnip::set_engine("lightgbm", num_threads = parallel::detectCores()) |>
parsnip::set_mode("classification")
tune_wflow <- workflows::workflow() |>
workflows::add_recipe(tune_recipe) |>
workflows::add_model(lgbm_spec)
set.seed(seed)
lgbm_grid <- dials::grid_space_filling(
dials::trees(range = c(100L, 1000L)),
dials::tree_depth(range = c(3L, 8L)),
dials::learn_rate(range = c(-3, -1)),
dials::min_n(range = c(100L, 500L)),
size = grid_size
)
message("Starting hyperparameter tuning (", grid_size, " candidates x ",
nrow(imbalance_windows), " windows)...")
set.seed(seed)
tune_results <- tune::tune_grid(
tune_wflow,
resamples = rolling_cv,
grid = lgbm_grid,
metrics = yardstick::metric_set(yardstick::pr_auc),
control = tune::control_grid(verbose = TRUE, save_pred = FALSE)
)
best <- tune::select_best(tune_results, metric = "pr_auc")
message("Best PR-AUC params: trees=", best$trees, " tree_depth=", best$tree_depth,
" learn_rate=", round(best$learn_rate, 5), " min_n=", best$min_n)
list(
trees = best$trees,
tree_depth = best$tree_depth,
learn_rate = best$learn_rate,
min_n = best$min_n
)
}

View File

@@ -34,6 +34,7 @@ reference:
desc: "Cross-validation and imbalance strategy testing."
contents:
- run_imbalance_tournament
- tune_lgbm
- train_diag_model
- create_efficiency_plot # Moved here: Belongs with the tournament

View File

@@ -307,12 +307,7 @@ list(
),
tar_target(
winning_params,
list(
trees = 844,
tree_depth = 3,
learn_rate = 0.0204,
min_n = 389
)
tune_lgbm(imbalance_windows)
),
tar_target(
production_model_uri,

39
man/tune_lgbm.Rd Normal file
View File

@@ -0,0 +1,39 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/functions.R
\name{tune_lgbm}
\alias{tune_lgbm}
\title{Tune LightGBM Hyperparameters}
\usage{
tune_lgbm(
imbalance_windows,
bucket_name = "baf-fraud",
inputs_prefix = "05_model_input",
grid_size = 30L,
seed = 42L
)
}
\arguments{
\item{imbalance_windows}{A tibble with columns \code{window_id},
\code{train_months}, and \code{test_month}, as produced by the
\code{imbalance_windows} target.}
\item{bucket_name}{Character. MinIO bucket name. Default \code{"baf-fraud"}.}
\item{inputs_prefix}{Character. Prefix for the model input layer.
Default \code{"05_model_input"}.}
\item{grid_size}{Integer. Number of space-filling candidates. Default \code{30}.}
\item{seed}{Integer. Random seed for reproducibility. Default \code{42}.}
}
\value{
A named list with elements \code{trees}, \code{tree_depth},
\code{learn_rate}, and \code{min_n}.
}
\description{
Performs a grid search over LightGBM hyperparameters using the same rolling
time windows as the imbalance tournament. Optimises PR-AUC on the pre-baked
baseline data stored in MinIO. Returns the best parameters as a named list
ready for use in \code{evaluate_final_model()} and
\code{train_production_model()}.
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 56 KiB

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 151 KiB

After

Width:  |  Height:  |  Size: 151 KiB