#' Predict from a random_gaussian_nb model
#'
#' Scores every row of `newdata` with each bootstrap model stored in
#' `object$.models`, converts each model's per-class log scores (Gaussian
#' log-likelihoods for numeric predictors, log level probabilities for
#' categorical predictors, plus log priors) into posterior probabilities, and
#' averages those probabilities over the ensemble.
#'
#' @param object A fitted `random_gaussian_nb` object.
#' @param newdata A data.frame of predictors. If NULL, uses training predictors
#'   (`object$X_train`).
#' @param type "class" (default) or "prob".
#' @param ... currently unused.
#'
#' @details
#' Categorical values that are `NA` or not among a model's stored levels are
#' scored with that model's unseen-level probability (`cat_unseen` when stored,
#' otherwise the smallest level probability). Level probabilities, unseen-level
#' probabilities and priors are all floored at `.Machine$double.eps` so that no
#' class can receive a log score of `-Inf`.
#'
#' A missing value in a numeric predictor used by a model makes that row's
#' posterior probabilities `NA`.
#'
#' @return
#' If `type = "prob"`, returns a data.frame with one column per class giving
#' posterior probabilities averaged over the bootstrap ensemble (rows correspond
#' to observations in `newdata`).
#'
#' If `type = "class"`, returns a factor of predicted class labels with levels
#' equal to the training classes. Ties are resolved in favour of the class that
#' comes first in the training class order.
#'
#' @export
predict.random_gaussian_nb <- function(object,
                                       newdata = NULL,
                                       type = c("class", "prob"),
                                       ...) {
  type <- match.arg(type)

  X <- if (is.null(newdata)) object$X_train else as.data.frame(newdata)

  models  <- object$.models
  classes <- object$.classes
  n_mod   <- length(models)

  if (n_mod < 1L) stop("No bootstrap models found in `object$.models`.")
  if (is.null(X)) {
    # Only reachable when `newdata` is NULL and no training data was stored.
    stop("`newdata` is NULL and `object$X_train` is not available.")
  }
  if (!is.data.frame(X)) stop("`newdata` must be coercible to a data.frame.")
  if (nrow(X) < 1L) stop("`newdata` has 0 rows.")

  # 1) Check required predictors exist.
  # Collect every feature any model will actually read, not just `feats`, so a
  # missing column is reported here rather than failing deep inside scoring.
  required <- unique(unlist(
    lapply(models, function(mdl) {
      c(mdl[["feats"]], mdl[["num_feats"]], mdl[["cat_feats"]])
    }),
    use.names = FALSE
  ))
  miss <- setdiff(required, names(X))
  if (length(miss) > 0L) {
    stop("Missing predictors in `newdata`: ", paste(miss, collapse = ", "))
  }

  # 2) Align predictor types with training
  if (!is.null(object$is_num) && any(object$is_num)) {
    num_names <- names(object$is_num)[object$is_num]
    num_names <- intersect(num_names, names(X))
    bad_num <- num_names[!vapply(X[num_names], is.numeric, logical(1))]
    if (length(bad_num) > 0L) {
      stop("Numeric predictors must be numeric in `newdata`. Non-numeric: ",
           paste(bad_num, collapse = ", "))
    }
  }

  if (!is.null(object$is_cat) && any(object$is_cat) &&
      !is.null(object$train_levels)) {
    cat_names <- names(object$is_cat)[object$is_cat]
    cat_names <- intersect(cat_names, names(X))
    for (f in cat_names) {
      lv <- object$train_levels[[f]]
      # Without stored levels, factor(levels = NULL) would turn the whole
      # column into NA; leave such columns untouched instead.
      if (!is.null(lv)) {
        X[[f]] <- factor(X[[f]], levels = lv)
      }
    }
  }

  # 3) Gaussian log-likelihood row sums: for each row, the sum over features of
  # the log normal density with per-feature mean `mu` and std. deviation `sd`.
  gauss_loglik_rowsum <- function(Xm, mu, sd) {
    sd <- pmax(sd, .Machine$double.eps)  # avoid division by zero / log(0)
    z  <- sweep(Xm, 2, mu, "-")
    z  <- sweep(z, 2, sd, "/")
    rowSums(-0.5 * z^2 - log(sd) - 0.5 * log(2 * pi))
  }

  eps   <- .Machine$double.eps
  n_new <- nrow(X)

  # 4) Per-model posterior probs (stable)
  post_list <- lapply(models, function(mdl) {
    logpost <- matrix(0, n_new, length(classes))
    colnames(logpost) <- classes

    ## numeric
    if (length(mdl$num_feats) > 0L) {
      Xm <- as.matrix(X[, mdl$num_feats, drop = FALSE])
      for (cl in classes) {
        logpost[, cl] <- logpost[, cl] +
          gauss_loglik_rowsum(Xm, mdl$mu[[cl]], mdl$sigma[[cl]])
      }
    }

    ## categorical
    for (f in mdl$cat_feats) {
      levs <- mdl$levels_map[[f]]
      # NA and levels unknown to this model both map to NA here.
      idx <- match(as.character(X[[f]]), levs)

      for (cl in classes) {
        probs <- pmax(mdl$catprob[[f]][[cl]], eps)
        p <- probs[idx]

        # Unseen-level probability stored during training; older models
        # without it fall back to the smallest level probability.
        has_stored <- !is.null(mdl$cat_unseen) &&
          !is.null(mdl$cat_unseen[[f]]) &&
          !is.null(mdl$cat_unseen[[f]][[cl]])
        unseen_p <- if (has_stored) {
          mdl$cat_unseen[[f]][[cl]]
        } else {
          min(probs, na.rm = TRUE)
        }
        # Floor like `probs` and the priors: a stored 0 would give log(0) =
        # -Inf, and a row that is -Inf for every class becomes NaN in softmax.
        unseen_p <- max(unseen_p, eps)

        p[is.na(p)] <- unseen_p
        logpost[, cl] <- logpost[, cl] + log(p)
      }
    }

    ## priors
    pr <- mdl$prior
    pr <- pr[classes]
    pr[is.na(pr) | pr <= 0] <- eps
    logpost <- sweep(logpost, 2, log(pr), "+")

    ## stable softmax: subtract the row maximum before exponentiating
    mx <- apply(logpost, 1, max)
    ex <- exp(sweep(logpost, 1, mx, "-"))
    ex / rowSums(ex)
  })

  avg_prob <- Reduce(`+`, post_list) / n_mod
  colnames(avg_prob) <- classes

  if (type == "prob") {
    as.data.frame(avg_prob)
  } else {
    factor(classes[max.col(avg_prob, ties.method = "first")], levels = classes)
  }
}
