Add tune_lgbm() and wire hyperparameter tuning into DAG
Converts scratch/tune_model.R into a pure tune_lgbm() function, replacing hardcoded winning_params with a fully automated tar_target. Best params (trees=844, depth=3, lr=0.0204, min_n=389) now flow reproducibly into evaluate_final_model() and train_production_model(). PR-AUC improved from 0.165 to 0.198. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
126
R/functions.R
126
R/functions.R
@@ -1199,12 +1199,132 @@ train_production_model <- function(data, recipe, best_params, model_filename = "
|
||||
build_baf_recipe <- function(data) {
|
||||
recipes::recipe(outcome ~ ., data = data) |>
|
||||
recipes::update_role(month, new_role = "ID") |>
|
||||
recipes::step_novel(recipes::all_nominal_predictors()) |>
|
||||
recipes::step_novel(recipes::all_nominal_predictors()) |>
|
||||
recipes::step_unknown(recipes::all_nominal_predictors()) |>
|
||||
recipes::step_indicate_na(recipes::all_numeric_predictors()) |>
|
||||
recipes::step_impute_median(recipes::all_numeric_predictors()) |>
|
||||
recipes::step_dummy(recipes::all_nominal_predictors(), one_hot = TRUE) |>
|
||||
recipes::step_dummy(recipes::all_nominal_predictors(), one_hot = TRUE) |>
|
||||
recipes::step_zv(recipes::all_predictors())
|
||||
|
||||
|
||||
# Notice: NO prep() here!
|
||||
}
|
||||
|
||||
#' Tune LightGBM Hyperparameters
|
||||
#'
|
||||
#' Performs a grid search over LightGBM hyperparameters using the same rolling
|
||||
#' time windows as the imbalance tournament. Optimises PR-AUC on the pre-baked
|
||||
#' baseline data stored in MinIO. Returns the best parameters as a named list
|
||||
#' ready for use in \code{evaluate_final_model()} and
|
||||
#' \code{train_production_model()}.
|
||||
#'
|
||||
#' @param imbalance_windows A tibble with columns \code{window_id},
|
||||
#' \code{train_months}, and \code{test_month}, as produced by the
|
||||
#' \code{imbalance_windows} target.
|
||||
#' @param bucket_name Character. MinIO bucket name. Default \code{"baf-fraud"}.
|
||||
#' @param inputs_prefix Character. Prefix for the model input layer.
|
||||
#' Default \code{"05_model_input"}.
|
||||
#' @param grid_size Integer. Number of space-filling candidates. Default \code{30}.
|
||||
#' @param seed Integer. Random seed for reproducibility. Default \code{42}.
|
||||
#'
|
||||
#' @return A named list with elements \code{trees}, \code{tree_depth},
|
||||
#' \code{learn_rate}, and \code{min_n}.
|
||||
#' @export
|
||||
#'
|
||||
#' @importFrom arrow s3_bucket open_dataset
|
||||
#' @importFrom dplyr filter collect mutate any_of
|
||||
#' @importFrom purrr map
|
||||
#' @importFrom rsample make_splits manual_rset
|
||||
#' @importFrom recipes recipe update_role step_zv all_predictors
|
||||
#' @importFrom parsnip boost_tree set_engine set_mode
|
||||
#' @importFrom workflows workflow add_recipe add_model
|
||||
#' @importFrom dials grid_space_filling trees tree_depth learn_rate min_n
|
||||
#' @importFrom tune tune tune_grid control_grid select_best
|
||||
#' @importFrom yardstick metric_set pr_auc
|
||||
tune_lgbm <- function(
|
||||
imbalance_windows,
|
||||
bucket_name = "baf-fraud",
|
||||
inputs_prefix = "05_model_input",
|
||||
grid_size = 30L,
|
||||
seed = 42L
|
||||
) {
|
||||
b <- arrow::s3_bucket(
|
||||
bucket_name,
|
||||
endpoint_override = Sys.getenv("BAF_ENDPOINT"),
|
||||
scheme = "http",
|
||||
access_key = Sys.getenv("BAF_KEY"),
|
||||
secret_key = Sys.getenv("BAF_SECRET"),
|
||||
region = "us-east-1"
|
||||
)
|
||||
|
||||
message("Loading baseline data (months 0-5) for tuning...")
|
||||
tune_data <- arrow::open_dataset(b$path(glue::glue("{inputs_prefix}/baseline"))) |>
|
||||
dplyr::filter(month %in% 0:5) |>
|
||||
dplyr::collect() |>
|
||||
dplyr::mutate(outcome = factor(outcome, levels = c("Fraud", "Legit")))
|
||||
|
||||
message("Rows loaded: ", nrow(tune_data))
|
||||
|
||||
# Build rolling window resamples matching the tournament windows
|
||||
splits <- purrr::map(
|
||||
seq_len(nrow(imbalance_windows)),
|
||||
function(i) {
|
||||
win <- imbalance_windows[i, ]
|
||||
train_idx <- which(tune_data$month %in% win$train_months[[1]])
|
||||
test_idx <- which(tune_data$month == win$test_month)
|
||||
rsample::make_splits(
|
||||
list(analysis = train_idx, assessment = test_idx),
|
||||
data = tune_data
|
||||
)
|
||||
}
|
||||
)
|
||||
rolling_cv <- rsample::manual_rset(splits, ids = imbalance_windows$window_id)
|
||||
|
||||
# Minimal recipe — data is already baked; just remove ID columns
|
||||
tune_recipe <- recipes::recipe(outcome ~ ., data = tune_data) |>
|
||||
recipes::update_role(dplyr::any_of(c("month", "month_date")), new_role = "ID") |>
|
||||
recipes::step_zv(recipes::all_predictors())
|
||||
|
||||
lgbm_spec <- parsnip::boost_tree(
|
||||
trees = tune::tune(),
|
||||
tree_depth = tune::tune(),
|
||||
learn_rate = tune::tune(),
|
||||
min_n = tune::tune()
|
||||
) |>
|
||||
parsnip::set_engine("lightgbm", num_threads = parallel::detectCores()) |>
|
||||
parsnip::set_mode("classification")
|
||||
|
||||
tune_wflow <- workflows::workflow() |>
|
||||
workflows::add_recipe(tune_recipe) |>
|
||||
workflows::add_model(lgbm_spec)
|
||||
|
||||
set.seed(seed)
|
||||
lgbm_grid <- dials::grid_space_filling(
|
||||
dials::trees(range = c(100L, 1000L)),
|
||||
dials::tree_depth(range = c(3L, 8L)),
|
||||
dials::learn_rate(range = c(-3, -1)),
|
||||
dials::min_n(range = c(100L, 500L)),
|
||||
size = grid_size
|
||||
)
|
||||
|
||||
message("Starting hyperparameter tuning (", grid_size, " candidates x ",
|
||||
nrow(imbalance_windows), " windows)...")
|
||||
set.seed(seed)
|
||||
tune_results <- tune::tune_grid(
|
||||
tune_wflow,
|
||||
resamples = rolling_cv,
|
||||
grid = lgbm_grid,
|
||||
metrics = yardstick::metric_set(yardstick::pr_auc),
|
||||
control = tune::control_grid(verbose = TRUE, save_pred = FALSE)
|
||||
)
|
||||
|
||||
best <- tune::select_best(tune_results, metric = "pr_auc")
|
||||
message("Best PR-AUC params: trees=", best$trees, " tree_depth=", best$tree_depth,
|
||||
" learn_rate=", round(best$learn_rate, 5), " min_n=", best$min_n)
|
||||
|
||||
list(
|
||||
trees = best$trees,
|
||||
tree_depth = best$tree_depth,
|
||||
learn_rate = best$learn_rate,
|
||||
min_n = best$min_n
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user