Add tune_lgbm() and wire hyperparameter tuning into DAG
Converts scratch/tune_model.R into a pure tune_lgbm() function, replacing hardcoded winning_params with a fully automated tar_target. Best params (trees=844, depth=3, lr=0.0204, min_n=389) now flow reproducibly into evaluate_final_model() and train_production_model(). PR-AUC improved from 0.165 to 0.198.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
14
NAMESPACE
14
NAMESPACE
@@ -23,6 +23,7 @@ export(save_report_figure)
|
||||
export(save_report_table)
|
||||
export(train_diag_model)
|
||||
export(train_production_model)
|
||||
export(tune_lgbm)
|
||||
importFrom(arrow,S3FileSystem)
|
||||
importFrom(arrow,open_dataset)
|
||||
importFrom(arrow,read_csv_arrow)
|
||||
@@ -43,6 +44,11 @@ importFrom(cowplot,theme_cowplot)
|
||||
importFrom(cowplot,theme_half_open)
|
||||
importFrom(cowplot,theme_minimal_grid)
|
||||
importFrom(cowplot,theme_minimal_vgrid)
|
||||
importFrom(dials,grid_space_filling)
|
||||
importFrom(dials,learn_rate)
|
||||
importFrom(dials,min_n)
|
||||
importFrom(dials,tree_depth)
|
||||
importFrom(dials,trees)
|
||||
importFrom(dplyr,`%>%`)
|
||||
importFrom(dplyr,across)
|
||||
importFrom(dplyr,any_of)
|
||||
@@ -105,6 +111,7 @@ importFrom(lubridate,"%m+%")
|
||||
importFrom(parsnip,boost_tree)
|
||||
importFrom(parsnip,set_engine)
|
||||
importFrom(parsnip,set_mode)
|
||||
importFrom(purrr,map)
|
||||
importFrom(quarto,quarto_render)
|
||||
importFrom(readr,write_rds)
|
||||
importFrom(recipes,all_nominal_predictors)
|
||||
@@ -120,6 +127,8 @@ importFrom(recipes,step_novel)
|
||||
importFrom(recipes,step_unknown)
|
||||
importFrom(recipes,step_zv)
|
||||
importFrom(recipes,update_role)
|
||||
importFrom(rsample,make_splits)
|
||||
importFrom(rsample,manual_rset)
|
||||
importFrom(scales,percent)
|
||||
importFrom(stats,reorder)
|
||||
importFrom(stats,sd)
|
||||
@@ -131,9 +140,14 @@ importFrom(themis,smote)
|
||||
importFrom(themis,step_tomek)
|
||||
importFrom(tidyr,pivot_longer)
|
||||
importFrom(tidyselect,where)
|
||||
importFrom(tune,control_grid)
|
||||
importFrom(tune,select_best)
|
||||
importFrom(tune,tune)
|
||||
importFrom(tune,tune_grid)
|
||||
importFrom(workflows,add_model)
|
||||
importFrom(workflows,add_recipe)
|
||||
importFrom(workflows,extract_fit_engine)
|
||||
importFrom(workflows,fit)
|
||||
importFrom(workflows,workflow)
|
||||
importFrom(yardstick,metric_set)
|
||||
importFrom(yardstick,pr_auc)
|
||||
|
||||
120
R/functions.R
120
R/functions.R
@@ -1208,3 +1208,123 @@ build_baf_recipe <- function(data) {
|
||||
|
||||
# Notice: NO prep() here!
|
||||
}
|
||||
|
||||
#' Tune LightGBM Hyperparameters
#'
#' Performs a grid search over LightGBM hyperparameters using the same rolling
#' time windows as the imbalance tournament. Optimises PR-AUC on the pre-baked
#' baseline data stored in MinIO. Returns the best parameters as a named list
#' ready for use in \code{evaluate_final_model()} and
#' \code{train_production_model()}.
#'
#' @details
#' The object-store connection is configured from the environment variables
#' \code{BAF_ENDPOINT}, \code{BAF_KEY}, and \code{BAF_SECRET}. Only rows with
#' \code{month} in 0-5 are loaded from \code{<inputs_prefix>/baseline}, so
#' every window in \code{imbalance_windows} must reference months inside that
#' range; a window with no matching training or test rows is an error.
#'
#' The function calls \code{set.seed(seed)} (before building the grid and
#' again before the grid search) and therefore changes the caller's RNG state.
#'
#' @param imbalance_windows A tibble with columns \code{window_id},
#'   \code{train_months}, and \code{test_month}, as produced by the
#'   \code{imbalance_windows} target.
#' @param bucket_name Character. MinIO bucket name. Default \code{"baf-fraud"}.
#' @param inputs_prefix Character. Prefix for the model input layer.
#'   Default \code{"05_model_input"}.
#' @param grid_size Integer. Number of space-filling candidates. Default \code{30}.
#' @param seed Integer. Random seed for reproducibility. Default \code{42}.
#'
#' @return A named list with elements \code{trees}, \code{tree_depth},
#'   \code{learn_rate}, and \code{min_n}.
#' @export
#'
#' @importFrom arrow s3_bucket open_dataset
#' @importFrom dplyr filter collect mutate any_of
#' @importFrom purrr map
#' @importFrom rsample make_splits manual_rset
#' @importFrom recipes recipe update_role step_zv all_predictors
#' @importFrom parsnip boost_tree set_engine set_mode
#' @importFrom workflows workflow add_recipe add_model
#' @importFrom dials grid_space_filling trees tree_depth learn_rate min_n
#' @importFrom tune tune tune_grid control_grid select_best
#' @importFrom yardstick metric_set pr_auc
tune_lgbm <- function(
  imbalance_windows,
  bucket_name = "baf-fraud",
  inputs_prefix = "05_model_input",
  grid_size = 30L,
  seed = 42L
) {
  # Fail fast on malformed arguments, before any remote I/O is attempted.
  if (!is.data.frame(imbalance_windows)) {
    stop("`imbalance_windows` must be a data frame.", call. = FALSE)
  }
  required_cols <- c("window_id", "train_months", "test_month")
  missing_cols <- setdiff(required_cols, names(imbalance_windows))
  if (length(missing_cols) > 0L) {
    stop(
      "`imbalance_windows` is missing column(s): ",
      paste(missing_cols, collapse = ", "),
      call. = FALSE
    )
  }
  if (nrow(imbalance_windows) == 0L) {
    stop("`imbalance_windows` has no rows; nothing to tune on.", call. = FALSE)
  }
  grid_size_ok <- is.numeric(grid_size) &&
    length(grid_size) == 1L &&
    !is.na(grid_size) &&
    grid_size >= 1 &&
    grid_size == round(grid_size)
  if (!grid_size_ok) {
    stop("`grid_size` must be a single positive whole number.", call. = FALSE)
  }

  b <- arrow::s3_bucket(
    bucket_name,
    endpoint_override = Sys.getenv("BAF_ENDPOINT"),
    scheme = "http",
    access_key = Sys.getenv("BAF_KEY"),
    secret_key = Sys.getenv("BAF_SECRET"),
    region = "us-east-1"
  )

  message("Loading baseline data (months 0-5) for tuning...")
  # The month filter is applied before collect(), i.e. on the Arrow dataset
  # query rather than on an in-memory data frame.
  tune_data <- arrow::open_dataset(
    b$path(paste0(inputs_prefix, "/baseline"))
  ) |>
    dplyr::filter(month %in% 0:5) |>
    dplyr::collect() |>
    # "Fraud" is the first factor level, which yardstick treats as the event
    # class by default when computing PR-AUC.
    dplyr::mutate(outcome = factor(outcome, levels = c("Fraud", "Legit")))

  message("Rows loaded: ", nrow(tune_data))
  if (nrow(tune_data) == 0L) {
    stop(
      "No baseline rows found for months 0-5 under '",
      inputs_prefix, "/baseline'.",
      call. = FALSE
    )
  }

  # Build rolling window resamples matching the tournament windows
  splits <- purrr::map(
    seq_len(nrow(imbalance_windows)),
    function(i) {
      win <- imbalance_windows[i, ]
      # `[[1L]]` unwraps a list-column cell and is a no-op on an atomic one.
      train_idx <- which(tune_data$month %in% win$train_months[[1L]])
      test_idx <- which(tune_data$month %in% win$test_month[[1L]])

      # An empty side would make this window unfittable or unscorable; stop
      # with the window id instead of letting the grid search degrade.
      if (length(train_idx) == 0L || length(test_idx) == 0L) {
        stop(
          "Window '", win$window_id, "' has ", length(train_idx),
          " training row(s) and ", length(test_idx),
          " test row(s) in the loaded data (months 0-5).",
          call. = FALSE
        )
      }

      rsample::make_splits(
        list(analysis = train_idx, assessment = test_idx),
        data = tune_data
      )
    }
  )
  rolling_cv <- rsample::manual_rset(splits, ids = imbalance_windows$window_id)

  # Minimal recipe — data is already baked; just remove ID columns
  tune_recipe <- recipes::recipe(outcome ~ ., data = tune_data) |>
    recipes::update_role(
      dplyr::any_of(c("month", "month_date")),
      new_role = "ID"
    ) |>
    recipes::step_zv(recipes::all_predictors())

  # detectCores() returns NA when the core count cannot be determined; fall
  # back to a single thread rather than handing NA to the engine.
  n_threads <- parallel::detectCores()
  if (is.na(n_threads) || n_threads < 1L) {
    n_threads <- 1L
  }

  lgbm_spec <- parsnip::boost_tree(
    trees = tune::tune(),
    tree_depth = tune::tune(),
    learn_rate = tune::tune(),
    min_n = tune::tune()
  ) |>
    parsnip::set_engine("lightgbm", num_threads = n_threads) |>
    parsnip::set_mode("classification")

  tune_wflow <- workflows::workflow() |>
    workflows::add_recipe(tune_recipe) |>
    workflows::add_model(lgbm_spec)

  # Seed once for the grid construction ...
  set.seed(seed)
  lgbm_grid <- dials::grid_space_filling(
    dials::trees(range = c(100L, 1000L)),
    dials::tree_depth(range = c(3L, 8L)),
    # learn_rate's range is on the log10 scale: 10^-3 to 10^-1.
    dials::learn_rate(range = c(-3, -1)),
    dials::min_n(range = c(100L, 500L)),
    size = grid_size
  )

  message("Starting hyperparameter tuning (", grid_size, " candidates x ",
          nrow(imbalance_windows), " windows)...")
  # ... and again so the search starts from the same RNG state regardless of
  # how many draws the grid construction consumed.
  set.seed(seed)
  tune_results <- tune::tune_grid(
    tune_wflow,
    resamples = rolling_cv,
    grid = lgbm_grid,
    metrics = yardstick::metric_set(yardstick::pr_auc),
    control = tune::control_grid(verbose = TRUE, save_pred = FALSE)
  )

  best <- tune::select_best(tune_results, metric = "pr_auc")
  message("Best PR-AUC params: trees=", best$trees, " tree_depth=", best$tree_depth,
          " learn_rate=", round(best$learn_rate, 5), " min_n=", best$min_n)

  list(
    trees = best$trees,
    tree_depth = best$tree_depth,
    learn_rate = best$learn_rate,
    min_n = best$min_n
  )
}
|
||||
@@ -34,6 +34,7 @@ reference:
|
||||
desc: "Cross-validation and imbalance strategy testing."
|
||||
contents:
|
||||
- run_imbalance_tournament
|
||||
- tune_lgbm
|
||||
- train_diag_model
|
||||
- create_efficiency_plot # Moved here: Belongs with the tournament
|
||||
|
||||
|
||||
@@ -307,12 +307,7 @@ list(
|
||||
),
|
||||
tar_target(
|
||||
winning_params,
|
||||
list(
|
||||
trees = 844,
|
||||
tree_depth = 3,
|
||||
learn_rate = 0.0204,
|
||||
min_n = 389
|
||||
)
|
||||
tune_lgbm(imbalance_windows)
|
||||
),
|
||||
tar_target(
|
||||
production_model_uri,
|
||||
|
||||
39
man/tune_lgbm.Rd
Normal file
39
man/tune_lgbm.Rd
Normal file
@@ -0,0 +1,39 @@
|
||||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/functions.R
|
||||
\name{tune_lgbm}
|
||||
\alias{tune_lgbm}
|
||||
\title{Tune LightGBM Hyperparameters}
|
||||
\usage{
|
||||
tune_lgbm(
|
||||
imbalance_windows,
|
||||
bucket_name = "baf-fraud",
|
||||
inputs_prefix = "05_model_input",
|
||||
grid_size = 30L,
|
||||
seed = 42L
|
||||
)
|
||||
}
|
||||
\arguments{
|
||||
\item{imbalance_windows}{A tibble with columns \code{window_id},
|
||||
\code{train_months}, and \code{test_month}, as produced by the
|
||||
\code{imbalance_windows} target.}
|
||||
|
||||
\item{bucket_name}{Character. MinIO bucket name. Default \code{"baf-fraud"}.}
|
||||
|
||||
\item{inputs_prefix}{Character. Prefix for the model input layer.
|
||||
Default \code{"05_model_input"}.}
|
||||
|
||||
\item{grid_size}{Integer. Number of space-filling candidates. Default \code{30}.}
|
||||
|
||||
\item{seed}{Integer. Random seed for reproducibility. Default \code{42}.}
|
||||
}
|
||||
\value{
|
||||
A named list with elements \code{trees}, \code{tree_depth},
|
||||
\code{learn_rate}, and \code{min_n}.
|
||||
}
|
||||
\description{
|
||||
Performs a grid search over LightGBM hyperparameters using the same rolling
|
||||
time windows as the imbalance tournament. Optimises PR-AUC on the pre-baked
|
||||
baseline data stored in MinIO. Returns the best parameters as a named list
|
||||
ready for use in \code{evaluate_final_model()} and
|
||||
\code{train_production_model()}.
|
||||
}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 56 KiB After Width: | Height: | Size: 56 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 151 KiB After Width: | Height: | Size: 151 KiB |
Reference in New Issue
Block a user