From f641698f1cbe22db63da4afd0a4421dc4ce6a1c5 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 09:18:05 -0500 Subject: [PATCH 01/34] Updated partition_tracker to track auxiliary data for CLogLog Ordinal BART model --- include/stochtree/partition_tracker.h | 25 ++++++++++++++++ src/partition_tracker.cpp | 42 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index 0790d87a..a2f5dd70 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -93,6 +93,17 @@ class ForestTracker { int GetNumFeatures() {return num_features_;} bool Initialized() {return initialized_;} + /*! + * \brief Ordinal auxiliary data management methods + * Methods to initialize, get, and set auxiliary data for cloglog ordinal bart models + * n_levels is the number of outcome levels for the ordinal response + * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints) + */ + void InitializeOrdinalAuxData(data_size_t num_observations, int n_levels); + double GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const; + void SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value); + std::vector& GetOrdinalAuxDataVector(int type_idx); + private: /*! \brief Mapper from observations to predicted values summed over every tree in a forest */ std::vector sum_predictions_; @@ -121,6 +132,20 @@ class ForestTracker { void UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates); void UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); void UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); + + /*! 
+ * \brief Track auxiliary data for cloglog ordinal bart models + * Vector of vectors to store these auxiliary data + * Each inner vector holds one type of data (order: Latent variable Z, Forest predictions, Cutpoints gamma, Cumsum exp of cutpoints) + */ + std::vector> ordinal_aux_data_vec_; + + /*! + * \brief Private helper methods for ordinal auxiliary data + * n_levels is the number of outcome levels for the ordinal response + * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints) + */ + void ResizeOrdinalAuxData(data_size_t num_observations, int n_levels); }; /*! \brief Class storing sample-prediction map for each tree in an ensemble */ diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index 9d643380..9c2831fc 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -696,4 +696,46 @@ std::vector FeaturePresortPartition::NodeIndices(int node_id) { return out; } + +// ============================================================================ +// ORDINAL AUXILIARY DATA METHODS +// ============================================================================ + +double ForestTracker::GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); + return ordinal_aux_data_vec_[type_idx][obs_idx]; +} + +void ForestTracker::InitializeOrdinalAuxData(data_size_t num_observations, int n_levels) { + ResizeOrdinalAuxData(num_observations, n_levels); +} + +void ForestTracker::SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value) { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); + ordinal_aux_data_vec_[type_idx][obs_idx] = value; +} + +std::vector& ForestTracker::GetOrdinalAuxDataVector(int type_idx) { + // CHECK(IsValidOrdinalType(type_idx)); + return ordinal_aux_data_vec_[type_idx]; +} + +void ForestTracker::ResizeOrdinalAuxData(data_size_t num_observations, int n_levels) { + // 4 types of 
ordinal auxiliary data: latent Z, forest predictions, cutpoints gamma, cumsum exp of gammas + const int n_types = 4; + ordinal_aux_data_vec_.resize(n_types); + for (int i = 0; i < n_types; ++i) { + if (i < 2) { + // First two types (latent Z, forest predictions) are sized to num_observations + ordinal_aux_data_vec_[i].assign(num_observations, 0.0); + } else if (i == 2) { + // Cutpoints gamma: size n_levels - 1 + ordinal_aux_data_vec_[i].assign(n_levels - 1, 0.0); + } else if (i == 3) { + // Cumsum exp of gammas: size n_levels + ordinal_aux_data_vec_[i].assign(n_levels, 0.0); + } + } +} + } // namespace StochTree From e99791bf61546328946ddb7a3441c378b1dccd50 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 18:20:10 -0500 Subject: [PATCH 02/34] Added leaf model for CLogLog Ordinal BART --- include/stochtree/leaf_model.h | 252 ++++++++++++++++++++++++++++++++- src/leaf_model.cpp | 64 +++++++++ src/partition_tracker.cpp | 3 - 3 files changed, 310 insertions(+), 9 deletions(-) diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 5359775d..a563cfe1 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -347,12 +347,14 @@ namespace StochTree { * 2. `kUnivariateRegressionLeafGaussian`: Every leaf node has a zero-centered univariate normal prior and every leaf is a linear model, multiplying the leaf parameter by a (fixed) basis. * 3. `kMultivariateRegressionLeafGaussian`: Every leaf node has a multivariate normal prior, centered around the zero vector, and every leaf is a linear model, matrix-multiplying the leaf parameters by a (fixed) basis vector. * 4. `kLogLinearVariance`: Every leaf node has a inverse gamma prior and every leaf is constant. + * 5. `kCloglogOrdinal`: Every leaf node has a log-gamma prior and every leaf is constant. 
*/ enum ModelType { kConstantLeafGaussian, kUnivariateRegressionLeafGaussian, kMultivariateRegressionLeafGaussian, - kLogLinearVariance + kLogLinearVariance, + kCloglogOrdinal }; /*! \brief Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model */ @@ -969,6 +971,236 @@ class LogLinearVarianceLeafModel { GammaSampler gamma_sampler_; }; +/*! \brief Sufficient statistic and associated operations for complementary log-log ordinal BART model */ +class CloglogOrdinalSuffStat { + public: + data_size_t n; + double sum_Y_less_K; + double other_sum; + + /*! + * \brief Construct a new CloglogOrdinalSuffStat object, setting all sufficient statistics to zero + */ + CloglogOrdinalSuffStat() { + n = 0; + sum_Y_less_K = 0.0; + other_sum = 0.0; + } + + /*! + * \brief Accumulate data from observation `row_idx` into the sufficient statistics + * + * \param dataset Data object containing training data, including covariates + * \param outcome Data object containing the original ordinal outcome values, which are used to compute sufficient statistics + * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state + * \param row_idx Index of the training data observation from which the sufficient statistics should be updated + * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update + */ + void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { + n += 1; + + // Get ordinal outcome value for this observation + unsigned int y = static_cast(outcome(row_idx)); + + // Get auxiliary data from tracker (assuming types: 0=latents Z, 1=forest predictions, 2=cutpoints gamma, 3=cumsum exp of gamma) + double Z = tracker.GetOrdinalAuxData(0, row_idx); // latent variables Z + double lambda_minus = tracker.GetOrdinalAuxData(1, row_idx); // forest predictions 
excluding current tree + + // Get cutpoints gamma and cumulative sum of exp(gamma) + const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma + const std::vector& seg = tracker.GetOrdinalAuxDataVector(3); // cumsum exp of gamma + + int K = gamma.size() + 1; // Number of ordinal categories + + if (y == K - 1) { + other_sum += std::exp(lambda_minus) * seg[y]; // checked and it's correct + } else { + sum_Y_less_K += 1.0; + other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]); // checked and it's correct + } + } + + /*! + * \brief Reset all of the sufficient statistics to zero + */ + void ResetSuffStat() { + n = 0; + sum_Y_less_K = 0.0; + other_sum = 0.0; + } + + /*! + * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ + void AddSuffStat(CloglogOrdinalSuffStat& lhs, CloglogOrdinalSuffStat& rhs) { + n = lhs.n + rhs.n; + sum_Y_less_K = lhs.sum_Y_less_K + rhs.sum_Y_less_K; + other_sum = lhs.other_sum + rhs.other_sum; + } + + /*! + * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ + void SubtractSuffStat(CloglogOrdinalSuffStat& lhs, CloglogOrdinalSuffStat& rhs) { + n = lhs.n - rhs.n; + sum_Y_less_K = lhs.sum_Y_less_K - rhs.sum_Y_less_K; + other_sum = lhs.other_sum - rhs.other_sum; + } + + /*! + * \brief Check whether accumulated sample size, `n`, is greater than some threshold + * + * \param threshold Value used to compute `n > threshold` + */ + bool SampleGreaterThan(data_size_t threshold) { + return n > threshold; + } + + /*! 
+ * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold + * + * \param threshold Value used to compute `n >= threshold` + */ + bool SampleGreaterThanEqual(data_size_t threshold) { + return n >= threshold; + } + + /*! + * \brief Return the sample size accumulated by a sufficient stat object + */ + data_size_t SampleSize() { + return n; + } +}; + +/*! \brief Marginal likelihood and posterior computation for complementary log-log ordinal BART model */ +class CloglogOrdinalLeafModel { + public: + /*! + * \brief Construct a new CloglogOrdinalLeafModel object + * + * \param a Shape parameter for log-gamma prior on leaf parameters + * \param b rate parameter for log-gamma prior on leaf parameters + * Log-gamma density: f(x) = b^a / Gamma(a) * exp(a*x - b*exp(x)) + * Relationship to tau (scale of leaf parameters): tau^2 = trigamma(a) + */ + CloglogOrdinalLeafModel(double a, double b) { + a_ = a; + b_ = b; + gamma_sampler_ = GammaSampler(); + tau_ = std::sqrt(boost::math::trigamma(a_)); + } + ~CloglogOrdinalLeafModel() {} + + /*! + * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. + */ + double SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance); + + /*! + * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. + */ + double NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Helper function to compute log marginal likelihood from sufficient statistics + */ + double SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Posterior shape parameter for leaf node log-gamma distribution + */ + double PosteriorParameterShape(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! 
+ * \brief Posterior rate parameter for leaf node log-gamma distribution + */ + double PosteriorParameterRate(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters. + * Samples from log-gamma: sample from gamma, then take log. + */ + void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); + + void SetScale(double tau) {tau_ = tau;} + + /*! + * \brief Get the current scale parameter value (tau_) + * \return Current tau_ value + */ + double GetScale() const {return tau_;} + + inline bool RequiresBasis() {return false;} + + /*! + * \brief Convert tau_ (scale_lambda i.e. scale for leaf parameters) to alpha (shape) and beta (rate) parameters for the log-gamma prior + * + * \param alpha Output: shape parameter for log-gamma prior + * \param beta Output: rate parameter for log-gamma prior + * \param tau Scale parameter (tau_) for leaf parameters + */ + void ScaleTauToAlphaBeta(double& alpha, double& beta, const double tau) { + double tau_sq = tau * tau; + alpha = TrigammaInverse(tau_sq); + // Note: Using exponential of digamma function for beta calculation + beta = std::exp(boost::math::digamma(alpha)); + } + + /*! + * \brief Convert alpha (shape) and beta (rate) parameters (for the log-gamma prior) back to tau_ (scale_lambda i.e. scale for leaf parameters) + * + * \param alpha Shape parameter for log-gamma prior + * \param beta Rate parameter for log-gamma prior + * \return tau Scale parameter (tau_) for leaf parameters + */ + double AlphaBetaToScaleTau(double alpha, double beta) { + // Inverse of the transformation: tau_sq = trigamma(alpha) + double tau_sq = boost::math::trigamma(alpha); + return std::sqrt(tau_sq); + } + + private: + /*! 
+ * \brief Compute inverse trigamma function using Newton's method + * + * Implementation adapted from limma package in R, originally by Gordon Smyth + * + * \param x Input value for which to compute trigamma inverse + * \return Value y such that trigamma(y) = x + */ + double TrigammaInverse(double x) { + // Very large and very small values - deal with using asymptotics + if (x > 1E7) { + return 1.0 / std::sqrt(x); + } + if (x < 1E-6) { + return 1.0 / x; + } + + // Otherwise, use Newton's method + double y = 0.5 + 1.0 / x; + for (int i = 0; i < 50; i++) { + double tri = boost::math::trigamma(y); + double dif = tri * (1.0 - tri / x) / boost::math::polygamma(2, y); // tetragamma (2nd derivative of digamma) is boost polygamma(2, y); polygamma(3, y) would be pentagamma + y += dif; + if (-dif / y < 1E-8) break; + } + + return y; + } + double a_; + double b_; + GammaSampler gamma_sampler_; + double tau_; +}; + /*! * \brief Unifying layer for disparate sufficient statistic class types * @@ -980,7 +1212,8 @@ class LogLinearVarianceLeafModel { using SuffStatVariant = std::variant; + LogLinearVarianceSuffStat, + CloglogOrdinalSuffStat>; /*! * \brief Unifying layer for disparate leaf model class types @@ -993,7 +1226,8 @@ using SuffStatVariant = std::variant; + LogLinearVarianceLeafModel, + CloglogOrdinalLeafModel>; template static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... 
leaf_suff_stat_args) { @@ -1018,8 +1252,10 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di return createSuffStat(); } else if (model_type == kMultivariateRegressionLeafGaussian) { return createSuffStat(basis_dim); - } else { + } else if (model_type == kLogLinearVariance) { return createSuffStat(); + } else { + return createSuffStat(); } } @@ -1031,16 +1267,20 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di * \param Sigma0 Value of the leaf node prior covariance matrix, only used if `model_type = kMultivariateRegressionLeafGaussian` * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance` * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance` + * \param c Value of the leaf node log-gamma prior shape parameter, only used if `model_type = kCloglogOrdinal` + * \param d Value of the leaf node log-gamma prior rate parameter, only used if `model_type = kCloglogOrdinal` */ -static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) { +static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b, double c, double d) { if (model_type == kConstantLeafGaussian) { return createLeafModel(tau); } else if (model_type == kUnivariateRegressionLeafGaussian) { return createLeafModel(tau); } else if (model_type == kMultivariateRegressionLeafGaussian) { return createLeafModel(Sigma0); - } else { + } else if (model_type == kLogLinearVariance) { return createLeafModel(a, b); + } else { + return createLeafModel(c, d); } } diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index 3b59ab96..78d8da76 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -274,4 +274,68 @@ void LogLinearVarianceLeafModel::SetEnsembleRootPredictedValue(ForestDataset& da } } +// 
============================================================================ +// Cloglog Ordinal Leaf Model +// ============================================================================ + +double CloglogOrdinalLeafModel::SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance) { + double left_log_ml = SuffStatLogMarginalLikelihood(left_stat, global_variance); + double right_log_ml = SuffStatLogMarginalLikelihood(right_stat, global_variance); + return left_log_ml + right_log_ml; +} + +double CloglogOrdinalLeafModel::NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + return SuffStatLogMarginalLikelihood(suff_stat, global_variance); +} + +double CloglogOrdinalLeafModel::SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_); + double a_term = a_ + suff_stat.sum_Y_less_K; + double b_term = b_ + suff_stat.other_sum; + double log_b_term = std::log(b_term); + double lgamma_a_term = boost::math::lgamma(a_term); + double resid_term = a_term * log_b_term; + double log_ml = prior_terms + lgamma_a_term - resid_term; + return log_ml; +} + +double CloglogOrdinalLeafModel::PosteriorParameterShape(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + return a_ + suff_stat.sum_Y_less_K; +} + +double CloglogOrdinalLeafModel::PosteriorParameterRate(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + return b_ + suff_stat.other_sum; +} + +void CloglogOrdinalLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) { + // Vector of leaf indices for tree + std::vector tree_leaves = tree->GetLeaves(); + + // Initialize sufficient statistics + CloglogOrdinalSuffStat node_suff_stat = CloglogOrdinalSuffStat(); + + // Sample each leaf node parameter + 
double node_shape; + double node_rate; + double node_mu; + int32_t leaf_id; + for (int i = 0; i < tree_leaves.size(); i++) { + // Compute leaf node sufficient statistics + leaf_id = tree_leaves[i]; + node_suff_stat.ResetSuffStat(); + AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, leaf_id); + + // Compute posterior shape and rate + node_shape = PosteriorParameterShape(node_suff_stat, global_variance); + node_rate = PosteriorParameterRate(node_suff_stat, global_variance); + + // Draw from log-gamma dist(node_shape, node_rate) and set the leaf parameter with each draw + // std::gamma_distribution gamma_dist_(node_shape, 1.); + // node_mu = -std::log(gamma_sample / node_rate); + double gamma_sample = gamma_sampler_.Sample(node_shape, node_rate, gen); + node_mu = std::log(gamma_sample); + tree->SetLeaf(leaf_id, node_mu); + } +} + } // namespace StochTree diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index 9c2831fc..8359faed 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -702,7 +702,6 @@ std::vector FeaturePresortPartition::NodeIndices(int node_id) { // ============================================================================ double ForestTracker::GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const { - // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); return ordinal_aux_data_vec_[type_idx][obs_idx]; } @@ -711,12 +710,10 @@ void ForestTracker::InitializeOrdinalAuxData(data_size_t num_observations, int n } void ForestTracker::SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value) { - // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); ordinal_aux_data_vec_[type_idx][obs_idx] = value; } std::vector& ForestTracker::GetOrdinalAuxDataVector(int type_idx) { - // CHECK(IsValidOrdinalType(type_idx)); return ordinal_aux_data_vec_[type_idx]; } From 8f77e153370a634925278db1cb276ca65d750d46 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 18:34:44 -0500 Subject: 
[PATCH 03/34] Added ordinal_sampler --- include/stochtree/ordinal_sampler.h | 86 ++++++++++++++++++++++++++ src/ordinal_sampler.cpp | 94 +++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 include/stochtree/ordinal_sampler.h create mode 100644 src/ordinal_sampler.cpp diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h new file mode 100644 index 00000000..ec148b5b --- /dev/null +++ b/include/stochtree/ordinal_sampler.h @@ -0,0 +1,86 @@ +/*! + * Copyright (c) 2024 stochtree authors. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef STOCHTREE_ORDINAL_SAMPLER_H_ +#define STOCHTREE_ORDINAL_SAMPLER_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace StochTree { + +/*! + * \brief Sampler for ordinal model hyperparameters + * + * This class handles MCMC sampling for ordinal-specific parameters: + * - Truncated exponential latent variables (Z) + * - Cutpoint parameters (gamma) + * - Cumulative sum of exp(gamma) (seg) [derived parameter] + */ +class OrdinalSampler { + public: + OrdinalSampler() { + gamma_sampler_ = GammaSampler(); + } + ~OrdinalSampler() {} + + /*! + * \brief Sample from truncated exponential distribution + * + * Samples from exponential distribution truncated to [0,1] + * + * \param lambda Rate parameter for exponential distribution + * \param gen Random number generator + * \return Sampled value from truncated exponential + */ + static double SampleTruncatedExponential(double lambda, std::mt19937& gen); + + + /*! 
+ * \brief Update truncated exponential latent variables (Z) + * + * \param dataset Forest dataset containing training data (covariates) + * \param outcome Vector of outcome values + * \param tracker Forest tracker containing auxiliary data + * \param gen Random number generator + */ + void UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, + std::mt19937& gen); + + /*! + * \brief Update gamma cutpoint parameters + * + * \param dataset Forest dataset containing training data (covariates) + * \param outcome Vector of outcome values + * \param tracker Forest tracker containing auxiliary data + * \param alpha_gamma Shape parameter for log-gamma prior on cutpoints gamma + * \param beta_gamma Rate parameter for log-gamma prior on cutpoints gamma + * \param gamma_0 Fixed value for first cutpoint parameter (for identifiability) + * \param gen Random number generator + */ + void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, + double alpha_gamma, double beta_gamma, double gamma_0, + std::mt19937& gen); + + /*! 
+ * \brief Update cumulative exponential sums (seg) + * + * \param tracker Forest tracker containing auxiliary data + */ + void UpdateCumulativeExpSums(ForestTracker& tracker); + + private: + GammaSampler gamma_sampler_; +}; + +} // namespace StochTree + +#endif // STOCHTREE_ORDINAL_SAMPLER_H_ diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp new file mode 100644 index 00000000..27b11f63 --- /dev/null +++ b/src/ordinal_sampler.cpp @@ -0,0 +1,94 @@ +#include +#include + +namespace StochTree { + +double OrdinalSampler::SampleTruncatedExponential(double lambda, std::mt19937& gen) { + std::uniform_real_distribution unif(0.0, 1.0); + double u = unif(gen); + double a = u * std::expm1(-lambda); // a = -u * (1 - e^{-lambda}); expm1 avoids cancellation when lambda is tiny + return (lambda == 0.0) ? u : -std::log1p(a) / lambda; // inverse CDF; rate 0 is the Uniform(0, 1) limit and would otherwise give 0/0 +} + +void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, + std::mt19937& gen) { + // Get auxiliary data vectors + const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // gamma cutpoints + const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) + std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) + + int K = gamma.size() + 1; // Number of ordinal categories + int N = dataset.NumObservations(); + + // Update truncated exponentials (stored in latent auxiliary data slot 0) + // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1) + // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} + // and lambda_hat_i is the total forest prediction for observation i + // If y_i = K-1 (last category), then we set z_i = 1.0 deterministically just for bookkeeping, we don't need it + // We only need to sample latent z_i for y_i < K-1 (as z_i is only used in the likelihood for y_i < K-1) + for (int i = 0; i < N; i++) { + int y = static_cast(outcome(i)); + if (y == K - 1) { + Z[i] = 1.0; 
+ } else { + double rate = std::exp(gamma[y] + lambda_hat[i]); + Z[i] = SampleTruncatedExponential(rate, gen); + } + } +} + +void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, + double alpha_gamma, double beta_gamma, double gamma_0, + std::mt19937& gen) { + // Get auxiliary data vectors + std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's + const std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables z_i's + const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) + + int K = gamma.size() + 1; // Number of ordinal categories + int N = dataset.NumObservations(); + + // Compute sufficient statistics A[k] and B[k] for gamma[k] update + std::vector A(K - 1, 0.0); + std::vector B(K - 1, 0.0); + + for (int i = 0; i < N; i++) { + int y = static_cast(outcome(i)); + if (y < K - 1) { + A[y] += 1.0; + B[y] += Z[i] * std::exp(lambda_hat[i]); + } + for (int k = 0; k < y; k++) { + B[k] += std::exp(lambda_hat[i]); + } + } + + // Update gamma parameters using log-gamma sampling + // First sample all gamma parameters + for (int k = 0; k < static_cast(gamma.size()); k++) { + double shape = A[k] + alpha_gamma; + double rate = B[k] + beta_gamma; + double gamma_sample = gamma_sampler_.Sample(shape, rate, gen); + gamma[k] = std::log(gamma_sample); + } + + // Set the first gamma parameter to gamma_0 (e.g., 0) for identifiability + gamma[0] = gamma_0; +} + +void OrdinalSampler::UpdateCumulativeExpSums(ForestTracker& tracker) { + // Get auxiliary data vectors + const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's + std::vector& seg = tracker.GetOrdinalAuxDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j) + + // Update seg (sum of exponentials of gamma cutpoints) + for (int j = 0; j < static_cast(seg.size()); j++) { + if (j == 0) { + seg[j] = 0.0; // checked and it is correct + } 
else { + seg[j] = seg[j - 1] + std::exp(gamma[j - 1]); // checked and it is correct + } + } +} + +} // namespace StochTree From 8547425cf1abfb95c3c999316f0ffd5b147b7720 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 18:50:19 -0500 Subject: [PATCH 04/34] Updated tree_sampler.h Added functionality to adjust the model states before/after tree sampling for CLogLog Ordinal BART --- include/stochtree/tree_sampler.h | 42 ++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 68c9c15a..675ef6c0 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -394,6 +394,40 @@ static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dat } } +static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, + bool requires_basis, bool tree_new) { + data_size_t n = dataset.GetCovariates().rows(); + + double pred_value; + int32_t leaf_pred; + double pred_delta; + for (data_size_t i = 0; i < n; i++) { + if (tree_new) { + // If the tree has been newly sampled or adjusted, we must rerun the prediction + // method and update the SamplePredMapper stored in tracker + leaf_pred = tracker.GetNodeId(i, tree_num); + if (requires_basis) { + pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i); + } else { + pred_value = tree->PredictFromNode(leaf_pred); + } + pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num); + tracker.SetTreeSamplePrediction(i, tree_num, pred_value); + tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta); + // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num) + tracker.SetOrdinalAuxData(1, i, tracker.GetSamplePrediction(i) - pred_value); + } else { + // If the tree has not yet been modified via a sampling step, + // we can query 
its prediction directly from the SamplePredMapper stored in tracker + pred_value = tracker.GetTreeSamplePrediction(i, tree_num); + // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num): needed? since tree not changed? + double current_lambda_hat = tracker.GetSamplePrediction(i); + double lambda_minus = current_lambda_hat - pred_value; + tracker.SetOrdinalAuxData(1, i, lambda_minus); + } + } +} + template static inline std::tuple EvaluateProposedSplit( ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, @@ -448,7 +482,9 @@ static inline std::tuple EvaluateExist template static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { - if (backfitting) { + if constexpr (std::is_same_v) { + UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false); + } else if (backfitting) { UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); } else { // TODO: think about a generic way to store "state" corresponding to the other models? @@ -459,7 +495,9 @@ static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafMod template static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { - if (backfitting) { + if constexpr (std::is_same_v) { + UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true); + } else if (backfitting) { UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); } else { // TODO: think about a generic way to store "state" corresponding to the other models? 
From 6c1d3ce549af37a3cbe8e7e86372b37a79dfde1a Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Wed, 17 Sep 2025 18:57:10 -0500 Subject: [PATCH 05/34] Updated sampler.cpp --- src/sampler.cpp | 101 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/src/sampler.cpp b/src/sampler.cpp index 212ccb42..ee8bd6e6 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include [[cpp11::register]] @@ -326,3 +327,103 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( // Return result return(output); } + +// ============================================================================ +// ORDINAL AUXILIARY DATA FUNCTIONS +// ============================================================================ + +[[cpp11::register]] +void ordinal_aux_data_initialize_cpp(cpp11::external_pointer tracker_ptr, StochTree::data_size_t num_observations, int n_levels) { + tracker_ptr->InitializeOrdinalAuxData(num_observations, n_levels); +} + +[[cpp11::register]] +double ordinal_aux_data_get_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx) { + return tracker_ptr->GetOrdinalAuxData(type_idx, obs_idx); +} + +[[cpp11::register]] +void ordinal_aux_data_set_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx, double value) { + tracker_ptr->SetOrdinalAuxData(type_idx, obs_idx, value); +} + +[[cpp11::register]] +cpp11::writable::doubles ordinal_aux_data_get_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx) { + const std::vector& aux_vec = tracker_ptr->GetOrdinalAuxDataVector(type_idx); + cpp11::writable::doubles output(aux_vec.size()); + for (size_t i = 0; i < aux_vec.size(); i++) { + output[i] = aux_vec[i]; + } + return output; +} + +[[cpp11::register]] +void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx, cpp11::doubles values) { + std::vector& aux_vec = 
tracker_ptr->GetOrdinalAuxDataVector(type_idx); + if (aux_vec.size() != values.size()) { + cpp11::stop("Size mismatch between auxiliary data vector and input values"); + } + for (size_t i = 0; i < values.size(); i++) { + aux_vec[i] = values[i]; + } +} + +[[cpp11::register]] +void ordinal_aux_data_update_cumsum_exp_cpp(cpp11::external_pointer tracker_ptr) { + // Get auxiliary data vectors + const std::vector& gamma = tracker_ptr->GetOrdinalAuxDataVector(2); // cutpoints gamma + std::vector& seg = tracker_ptr->GetOrdinalAuxDataVector(3); // cumsum exp gamma + + // Update seg (cumulative sum of exp(gamma)) + for (size_t j = 0; j < seg.size(); j++) { + if (j == 0) { + seg[j] = 0.0; + } else { + seg[j] = seg[j - 1] + std::exp(gamma[j - 1]); + } + } +} + +// ============================================================================ +// ORDINAL SAMPLER FUNCTIONS +// ============================================================================ + +[[cpp11::register]] +cpp11::external_pointer ordinal_sampler_cpp() { + std::unique_ptr sampler_ptr = std::make_unique(); + return cpp11::external_pointer(sampler_ptr.release()); +} + +[[cpp11::register]] +void ordinal_sampler_update_latent_variables_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer data_ptr, + cpp11::external_pointer outcome_ptr, + cpp11::external_pointer tracker_ptr, + cpp11::external_pointer rng_ptr +) { + sampler_ptr->UpdateLatentVariables(*data_ptr, outcome_ptr->GetData(), *tracker_ptr, *rng_ptr); +} + +[[cpp11::register]] +void ordinal_sampler_update_gamma_params_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer data_ptr, + cpp11::external_pointer outcome_ptr, + cpp11::external_pointer tracker_ptr, + double alpha_gamma, + double beta_gamma, + double gamma_0, + cpp11::external_pointer rng_ptr +) { + sampler_ptr->UpdateGammaParams(*data_ptr, outcome_ptr->GetData(), *tracker_ptr, alpha_gamma, beta_gamma, gamma_0, *rng_ptr); +} + +[[cpp11::register]] +void 
ordinal_sampler_update_cumsum_exp_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer tracker_ptr +) { + sampler_ptr->UpdateCumulativeExpSums(*tracker_ptr); +} + From 084be881caa7202decfd505ed29a64f348e8ba21 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Sun, 28 Sep 2025 06:21:51 -0500 Subject: [PATCH 06/34] Added cloglog_ordinal_bart.R function --- R/cloglog_ordinal_bart.R | 180 ++++++ R/cpp11.R | 48 +- include/stochtree/category_tracker.h | 4 + include/stochtree/common.h | 4 + include/stochtree/container.h | 5 + include/stochtree/cutpoint_candidates.h | 2 + include/stochtree/data.h | 1 + include/stochtree/ensemble.h | 6 + include/stochtree/io.h | 2 + include/stochtree/leaf_model.h | 546 ++++++++++--------- include/stochtree/log.h | 2 + include/stochtree/meta.h | 1 + include/stochtree/ordinal_sampler.h | 1 - include/stochtree/partition_tracker.h | 29 +- include/stochtree/random.h | 1 + include/stochtree/random_effects.h | 3 + include/stochtree/slice_sampler.h | 180 ++++++ include/stochtree/tree.h | 3 + include/stochtree/tree_sampler.h | 365 +++++++------ include/stochtree/variance_model.h | 4 + man/bart.Rd | 8 +- man/bcf.Rd | 24 +- man/cloglog_ordinal_bart.Rd | 47 ++ man/createBARTModelFromCombinedJson.Rd | 8 +- man/createBARTModelFromCombinedJsonString.Rd | 8 +- man/createBARTModelFromJson.Rd | 8 +- man/createBARTModelFromJsonFile.Rd | 8 +- man/createBARTModelFromJsonString.Rd | 8 +- man/createBCFModelFromCombinedJson.Rd | 30 +- man/createBCFModelFromCombinedJsonString.Rd | 30 +- man/createBCFModelFromJson.Rd | 34 +- man/createBCFModelFromJsonFile.Rd | 34 +- man/createBCFModelFromJsonString.Rd | 30 +- man/createForestModel.Rd | 8 +- man/getRandomEffectSamples.bartmodel.Rd | 16 +- man/getRandomEffectSamples.bcfmodel.Rd | 34 +- man/predict.bartmodel.Rd | 8 +- man/predict.bcfmodel.Rd | 22 +- man/preprocessPredictionData.Rd | 2 +- man/resetForestModel.Rd | 22 +- man/resetRandomEffectsModel.Rd | 4 +- man/resetRandomEffectsTracker.Rd | 4 +- 
man/rootResetRandomEffectsModel.Rd | 4 +- man/rootResetRandomEffectsTracker.Rd | 4 +- man/saveBARTModelToJson.Rd | 8 +- man/saveBARTModelToJsonFile.Rd | 8 +- man/saveBARTModelToJsonString.Rd | 8 +- man/saveBCFModelToJson.Rd | 34 +- man/saveBCFModelToJsonFile.Rd | 34 +- man/saveBCFModelToJsonString.Rd | 34 +- src/Makevars.in | 1 + src/Makevars.win.in | 1 + src/R_data.cpp | 1 + src/R_random_effects.cpp | 2 + src/cpp11.cpp | 103 +++- src/cutpoint_candidates.cpp | 1 + src/data.cpp | 1 + src/forest.cpp | 2 + src/io.cpp | 2 + src/kernel.cpp | 2 + src/leaf_model.cpp | 2 + src/ordinal_sampler.cpp | 24 +- src/partition_tracker.cpp | 80 ++- src/py_stochtree.cpp | 19 +- src/sampler.cpp | 143 ++--- src/serialization.cpp | 3 + src/stochtree_types.h | 2 + src/tree.cpp | 4 + 68 files changed, 1503 insertions(+), 808 deletions(-) create mode 100644 R/cloglog_ordinal_bart.R create mode 100644 include/stochtree/slice_sampler.h create mode 100644 man/cloglog_ordinal_bart.Rd diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R new file mode 100644 index 00000000..9cc9b63a --- /dev/null +++ b/R/cloglog_ordinal_bart.R @@ -0,0 +1,180 @@ +#' Run the BART algorithm for ordinal outcomes using a complementary log-log link +#' +#' @param X A numeric matrix of predictors (training data). +#' @param y A numeric vector of ordinal outcomes (positive integers starting from 1). +#' @param X_test An optional numeric matrix of predictors (test data). +#' @param n_trees Number of trees in the BART ensemble. Default: `50`. +#' @param n_samples_mcmc Total number of MCMC samples to draw. Default: `500`. +#' @param n_burnin Number of burn-in samples to discard. Default: `250`. +#' @param n_thin Thinning interval for MCMC samples. Default: `1`. +#' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`. +#' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`. 
+#' @param variable_weights Optional vector of variable weights for splitting (default: equal weights). +#' @param feature_types Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). + + +cloglog_ordinal_bart <- function(X, y, X_test = NULL, + n_trees = 50, + n_samples_mcmc = 500, + n_burnin = 250, + n_thin = 1, + alpha_gamma = 2.0, + beta_gamma = 2.0, + variable_weights = NULL, + feature_types = NULL, + seed = NULL) { + + # BART parameters + alpha_bart <- 0.95 + beta_bart <- 2 + min_samples_in_leaf <- 5 + max_depth <- 10 + scale_leaf <- 2 / sqrt(n_trees) + cutpoint_grid_size <- 100 # Needed for stochtree:::sample_mcmc_one_iteration_cpp (for GFR), not used in ordinal BART + + # Fixed for identifiability (can be pass as argument later if desired) + gamma_0 = 0.0 # First gamma cutpoint fixed at gamma_0 = 0 + + # Determine whether a test dataset is provided + has_test <- !is.null(X_test) + # Data checks + if (!is.matrix(X)) X <- as.matrix(X) + if (!is.numeric(y)) y <- as.numeric(y) + if (has_test && !is.matrix(X_test)) X_test <- as.matrix(X_test) + + n_samples <- nrow(X) + n_features <- ncol(X) + + if (any(y < 1) || any(y != round(y))) { + stop("Ordinal outcome y must contain positive integers starting from 1") + } + + # Convert from 1-based (R) to 0-based (C++) indexing + ordinal_outcome <- as.integer(y - 1) + n_levels <- max(y) # Number of ordinal categories + + if (n_levels < 2) { + stop("Ordinal outcome must have at least 2 categories") + } + + if (is.null(variable_weights)) { + variable_weights <- rep(1.0, n_features) + } + + if (is.null(feature_types)) { + feature_types <- rep(0L, n_features) + } + + if (!is.null(seed)) { + set.seed(seed) + } + + keep_idx <- seq((n_burnin + 1), n_samples_mcmc, by = n_thin) + n_keep <- length(keep_idx) + + forest_pred_train <- matrix(0, n_samples, n_keep) + if (has_test) { + n_samples_test <- nrow(X_test) + forest_pred_test <- matrix(0, n_samples_test, n_keep) + } + gamma_samples 
<- matrix(0, n_levels - 1, n_keep) + latent_samples <- matrix(0, n_samples, n_keep) + + # Initialize other model structures as before + dataX <- stochtree::createForestDataset(X) + if (has_test) { + dataXtest <- stochtree::createForestDataset(X_test) + } + outcome_data <- stochtree::createOutcome(as.numeric(ordinal_outcome)) + active_forest <- stochtree::createForest(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves + active_forest$set_root_leaves(0.0) + split_prior <- stochtree:::tree_prior_cpp(alpha_bart, beta_bart, min_samples_in_leaf, max_depth) + forest_samples <- stochtree::createForestSamples(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves + forest_tracker <- stochtree:::forest_tracker_cpp( + dataX$data_ptr, + as.integer(feature_types), + as.integer(n_trees), + as.integer(n_samples) + ) + stochtree:::ordinal_aux_data_initialize_cpp(forest_tracker, as.integer(n_samples), as.integer(n_levels)) + + # Initialize gamma parameters to zero (slot 2) + initial_gamma <- rep(0.0, n_levels - 1) + for (i in seq_along(initial_gamma)) { + stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 2, i - 1, initial_gamma[i]) + } + stochtree:::ordinal_aux_data_update_cumsum_exp_cpp(forest_tracker) + + # Initialize forest predictions slot to zero (slot 1) + for (i in 1:n_samples) { + stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, i - 1, 0.0) + } + + ordinal_sampler <- stochtree:::ordinal_sampler_cpp() + rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed) + + # Set up sweep indices for tree updates (sample all trees each iteration) + sweep_indices <- as.integer(seq(0, n_trees - 1)) + + sample_counter <- 0 + for (i in 1:n_samples_mcmc) { + keep_sample <- i %in% keep_idx + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + + # 1. 
Sample forest using MCMC + stochtree:::sample_mcmc_one_iteration_cpp( + dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr, + active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr, + sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size), + scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample + ) + + # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions + # This is needed for updating gamma parameters, latent z_i's + forest_pred_current <- active_forest$predict(dataX) + for (j in 1:n_samples) { + stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, j - 1, forest_pred_current[j]) + } + + # 2. Sample latent z_i's using truncated exponential + stochtree:::ordinal_sampler_update_latent_variables_cpp( + ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker, rng$rng_ptr + ) + + # 3. Sample gamma parameters + stochtree:::ordinal_sampler_update_gamma_params_cpp( + ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker, + alpha_gamma, beta_gamma, gamma_0, rng$rng_ptr + ) + + # 4. 
Update cumulative sum of exp(gamma) values + stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, forest_tracker) + + if (keep_sample) { + forest_pred_train[, sample_counter] <- active_forest$predict(dataX) + if (has_test) { + forest_pred_test[, sample_counter] <- active_forest$predict(dataXtest) + } + gamma_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 2) + gamma_samples[, sample_counter] <- gamma_current + latent_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 0) + latent_samples[, sample_counter] <- latent_current + } + } + + result <- list( + forest_predictions_train = forest_pred_train, + forest_predictions_test = if (has_test) forest_pred_test else NULL, + gamma_samples = gamma_samples, + latent_samples = latent_samples, + scale_leaf = scale_leaf, + ordinal_outcome = ordinal_outcome, + n_trees = n_trees, + n_keep = n_keep + ) + + class(result) <- "cloglog_ordinal_bart" + return(result) +} diff --git a/R/cpp11.R b/R/cpp11.R index d77c7472..64db4be1 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -624,12 +624,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) .Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums) } -sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads) { - invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads)) +sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, 
tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample) { + invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample)) } -sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_threads) { - invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_threads)) +sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { + invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) } sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) { @@ -692,6 +692,46 @@ sample_without_replacement_integer_cpp <- function(population_vector, sampling_p .Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size) } 
+ordinal_aux_data_initialize_cpp <- function(tracker_ptr, num_observations, n_levels) { + invisible(.Call(`_stochtree_ordinal_aux_data_initialize_cpp`, tracker_ptr, num_observations, n_levels)) +} + +ordinal_aux_data_get_cpp <- function(tracker_ptr, type_idx, obs_idx) { + .Call(`_stochtree_ordinal_aux_data_get_cpp`, tracker_ptr, type_idx, obs_idx) +} + +ordinal_aux_data_set_cpp <- function(tracker_ptr, type_idx, obs_idx, value) { + invisible(.Call(`_stochtree_ordinal_aux_data_set_cpp`, tracker_ptr, type_idx, obs_idx, value)) +} + +ordinal_aux_data_get_vector_cpp <- function(tracker_ptr, type_idx) { + .Call(`_stochtree_ordinal_aux_data_get_vector_cpp`, tracker_ptr, type_idx) +} + +ordinal_aux_data_set_vector_cpp <- function(tracker_ptr, type_idx, values) { + invisible(.Call(`_stochtree_ordinal_aux_data_set_vector_cpp`, tracker_ptr, type_idx, values)) +} + +ordinal_aux_data_update_cumsum_exp_cpp <- function(tracker_ptr) { + invisible(.Call(`_stochtree_ordinal_aux_data_update_cumsum_exp_cpp`, tracker_ptr)) +} + +ordinal_sampler_cpp <- function() { + .Call(`_stochtree_ordinal_sampler_cpp`) +} + +ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_latent_variables_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr)) +} + +ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_gamma_params_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr)) +} + +ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, tracker_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_cumsum_exp_cpp`, sampler_ptr, tracker_ptr)) +} + init_json_cpp <- function() { .Call(`_stochtree_init_json_cpp`) } diff --git a/include/stochtree/category_tracker.h 
b/include/stochtree/category_tracker.h index 2ce44635..e5817419 100644 --- a/include/stochtree/category_tracker.h +++ b/include/stochtree/category_tracker.h @@ -29,8 +29,12 @@ #include #include +#include #include #include +#include +#include +#include #include namespace StochTree { diff --git a/include/stochtree/common.h b/include/stochtree/common.h index cd57eea2..c7aab3df 100644 --- a/include/stochtree/common.h +++ b/include/stochtree/common.h @@ -8,18 +8,22 @@ #include #include +#include #include #include #include #include #include +#include #include #include #include +#include #include #include #include #include +#include #include #include diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 4b75ef2f..bb0e7849 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -11,7 +11,12 @@ #include #include +#include +#include #include +#include +#include +#include namespace StochTree { diff --git a/include/stochtree/cutpoint_candidates.h b/include/stochtree/cutpoint_candidates.h index 76f1df4c..8c19013a 100644 --- a/include/stochtree/cutpoint_candidates.h +++ b/include/stochtree/cutpoint_candidates.h @@ -42,6 +42,8 @@ #include #include +#include + namespace StochTree { /*! 
\brief Computing and tracking cutpoints available for a given feature at a given node diff --git a/include/stochtree/data.h b/include/stochtree/data.h index a6061f4b..df232fb3 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -9,6 +9,7 @@ #include #include #include +#include namespace StochTree { diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index 4f6ddf42..4624b5a4 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -14,6 +14,12 @@ #include #include +#include +#include +#include +#include +#include + using json = nlohmann::json; namespace StochTree { diff --git a/include/stochtree/io.h b/include/stochtree/io.h index 55963946..3bc277fb 100644 --- a/include/stochtree/io.h +++ b/include/stochtree/io.h @@ -28,10 +28,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index a563cfe1..6adf9c23 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -16,318 +16,323 @@ #include #include #include +#include #include +#include +#include #include +#include +#include #include namespace StochTree { -/*! +/*! * \defgroup leaf_model_group Leaf Model API - * + * * \brief Classes / functions for implementing leaf models. - * - * Stochastic tree algorithms are all essentially hierarchical - * models with an adaptive group structure defined by an ensemble - * of decision trees. Each novel model is governed by - * + * + * Stochastic tree algorithms are all essentially hierarchical + * models with an adaptive group structure defined by an ensemble + * of decision trees. 
Each novel model is governed by + * * - A `LeafModel` class, defining the integrated likelihood and posterior, conditional on a particular tree structure * - A `SuffStat` class that tracks and accumulates sufficient statistics necessary for a `LeafModel` - * - * To provide a thorough overview of this interface (and, importantly, how to extend it), we must introduce some mathematical notation. + * + * To provide a thorough overview of this interface (and, importantly, how to extend it), we must introduce some mathematical notation. * Any forest-based regression model involves an outcome, which we'll call \f$y\f$, and features (or "covariates"), which we'll call \f$X\f$. - * Our goal is to predict \f$y\f$ as a function of \f$X\f$, which we'll call \f$f(X)\f$. - * - * NOTE: if we have a more complicated, but still additive, model, such as \f$y = X\beta + f(X)\f$, then we can just model + * Our goal is to predict \f$y\f$ as a function of \f$X\f$, which we'll call \f$f(X)\f$. + * + * NOTE: if we have a more complicated, but still additive, model, such as \f$y = X\beta + f(X)\f$, then we can just model * \f$y - X\beta = f(X)\f$, treating the residual \f$y - X\beta\f$ as the outcome data, and we are back to the general setting above. - * - * Now, since \f$f(X)\f$ is an additive tree ensemble, we can think of it as the sum of \f$b\f$ separate decision tree functions, + * + * Now, since \f$f(X)\f$ is an additive tree ensemble, we can think of it as the sum of \f$b\f$ separate decision tree functions, * where \f$b\f$ is the number of trees in an ensemble, so that - * + * * \f[ * f(X) = f_1(X) + \dots + f_b(X) * \f] - * - * and each decision tree function \f$f_j\f$ has the property that features \f$X\f$ are used to determine which leaf node an observation - * falls into, and then the parameters attached to that leaf node are used to compute \f$f_j(X)\f$. 
The exact mechanics of this process + * + * and each decision tree function \f$f_j\f$ has the property that features \f$X\f$ are used to determine which leaf node an observation + * falls into, and then the parameters attached to that leaf node are used to compute \f$f_j(X)\f$. The exact mechanics of this process * are model-dependent, so now we introduce the "leaf node" models that `stochtree` supports. * * \section gaussian_constant_leaf_model Gaussian Constant Leaf Model - * + * * The most standard and common tree ensemble is a sum of "constant leaf" trees, in which a leaf node's parameter uniquely determines the prediction - * for all observations that fall into that leaf. For example, if leaf 2 for a tree is reached by the conditions that \f$X_1 < 0.4 \; \& \; X_2 > 0.6\f$, then - * every observation whose first feature is less than 0.4 and whose second feature is greater than 0.6 will receive the same prediction. Mathematically, + * for all observations that fall into that leaf. For example, if leaf 2 for a tree is reached by the conditions that \f$X_1 < 0.4 \; \& \; X_2 > 0.6\f$, then + * every observation whose first feature is less than 0.4 and whose second feature is greater than 0.6 will receive the same prediction. Mathematically, * for an observation \f$i\f$ this looks like - * + * * \f[ * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \mu_{\ell} * \f] - * + * * where \f$L\f$ denotes the indices of every leaf node, \f$\mu_{\ell}\f$ is the parameter attached to leaf node \f$\ell\f$, and \f$\mathbb{1}(X \in \ell)\f$ * checks whether \f$X_i\f$ falls into leaf node \f$\ell\f$. - * + * * The way that we make such a model "stochastic" is by attaching to the leaf node parameters \f$\mu_{\ell}\f$ a "prior" distribution. - * This leaf model corresponds to the "classic" BART model of Chipman et al (2010) - * as well as its "XBART" extension (He and Hahn (2023)). 
+ * This leaf model corresponds to the "classic" BART model of Chipman et al (2010) + * as well as its "XBART" extension (He and Hahn (2023)). * We assign each leaf node parameter a prior - * + * * \f[ * \mu \sim N\left(0, \tau\right) * \f] - * - * Assuming a homoskedastic Gaussian outcome likelihood (i.e. \f$y_i \sim N\left(f(X_i),\sigma^2\right)\f$), - * the log marginal likelihood in this model, for the outcome data in node \f$\ell\f$ of tree \f$j\f$ is given by - * + * + * Assuming a homoskedastic Gaussian outcome likelihood (i.e. \f$y_i \sim N\left(f(X_i),\sigma^2\right)\f$), + * the log marginal likelihood in this model, for the outcome data in node \f$\ell\f$ of tree \f$j\f$ is given by + * * \f[ * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - n_{\ell}\log(\sigma) + \frac{1}{2} \log\left(\frac{\sigma^2}{n_{\ell} \tau + \sigma^2}\right) - \frac{s_{yy,\ell}}{2\sigma^2} + \frac{\tau s_{y,\ell}^2}{2\sigma^2(n_{\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * n_{\ell} = \sum_{i : X_i \in \ell} 1 * \f] - * + * * \f[ * s_{y,\ell} = \sum_{i : X_i \in \ell} r_i * \f] - * + * * \f[ * s_{yy,\ell} = \sum_{i : X_i \in \ell} r_i^2 * \f] - * + * * \f[ * r_i = y_i - \sum_{k \neq j} f_k(X_i) * \f] * - * In words, this model depends on the data for a given leaf node only through three sufficient statistics, \f$n_{\ell}\f$, \f$s_{y,\ell}\f$, and \f$s_{yy,\ell}\f$, - * and it only depends on the other trees in the ensemble through the "partial residual" \f$r_i\f$. The posterior distribution for + * In words, this model depends on the data for a given leaf node only through three sufficient statistics, \f$n_{\ell}\f$, \f$s_{y,\ell}\f$, and \f$s_{yy,\ell}\f$, + * and it only depends on the other trees in the ensemble through the "partial residual" \f$r_i\f$. 
The posterior distribution for * node \f$\ell\f$'s leaf parameter is similarly defined as: - * + * * \f[ * \mu_{\ell} \mid - \sim N\left(\frac{\tau s_{y,\ell}}{n_{\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{n_{\ell} \tau + \sigma^2}\right) * \f] - * - * Now, consider the possibility that each observation carries a unique weight \f$w_i\f$. These could be "case weights" in a survey context or + * + * Now, consider the possibility that each observation carries a unique weight \f$w_i\f$. These could be "case weights" in a survey context or * individual-level variances ("heteroskedasticity"). These case weights transform the outcome distribution (and associated likelihood) to - * + * * \f[ - * y_i \mid - \sim N\left(\mu(X_i), \frac{\sigma^2}{w_i}\right) + * y_i \mid - \sim N\left(\mu(X_i), \frac{\sigma^2}{w_i}\right) * \f] - * - * This gives a modified log marginal likelihood of - * + * + * This gives a modified log marginal likelihood of + * * \f[ * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - \frac{1}{2} \sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right) + \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) - \frac{s_{wyy,\ell}}{2\sigma^2} + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(s_{w,\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * s_{w,\ell} = \sum_{i : X_i \in \ell} w_i * \f] - * + * * \f[ * s_{wy,\ell} = \sum_{i : X_i \in \ell} w_i r_i * \f] - * + * * \f[ * s_{wyy,\ell} = \sum_{i : X_i \in \ell} w_i r_i^2 * \f] - * - * Finally, note that when we consider splitting leaf \f$\ell\f$ into new left and right leaves, or pruning two nodes into a single leaf node, - * we compute the log marginal likelihood of the combined data and the log marginal likelihoods of the left and right leaves and compare these three values. 
- * - * The terms \f$\frac{n_{\ell}}{2}\log(2\pi)\f$, \f$\sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right)\f$, and \f$\frac{s_{wyy,\ell}}{2\sigma^2}\f$ - * are such that their left and right node values will always sum to the respective value in the combined log marginal likelihood, so they can be ignored + * + * Finally, note that when we consider splitting leaf \f$\ell\f$ into new left and right leaves, or pruning two nodes into a single leaf node, + * we compute the log marginal likelihood of the combined data and the log marginal likelihoods of the left and right leaves and compare these three values. + * + * The terms \f$\frac{n_{\ell}}{2}\log(2\pi)\f$, \f$\sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right)\f$, and \f$\frac{s_{wyy,\ell}}{2\sigma^2}\f$ + * are such that their left and right node values will always sum to the respective value in the combined log marginal likelihood, so they can be ignored * when evaluating splits or prunes and thus the reduced log marginal likelihood is - * + * * \f[ * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(n_{\ell} \tau + \sigma^2)} * \f] - * + * * So the \ref StochTree::GaussianConstantSuffStat "GaussianConstantSuffStat" class tracks a generalized version of these three statistics * (which allows for each observation to have a weight \f$w_i \neq 1\f$): - * + * * - \f$n_{\ell}\f$: `data_size_t n` * - \f$s_{w,\ell}\f$: `double sum_w` * - \f$s_{wy,\ell}\f$: `double sum_yw` - * - * And these values are used by the \ref StochTree::GaussianConstantLeafModel "GaussianConstantLeafModel" class in the - * \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", - * \ref StochTree::GaussianConstantLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", - * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterMean "PosteriorParameterMean", and - * \ref 
StochTree::GaussianConstantLeafModel::PosteriorParameterVariance "PosteriorParameterVariance" methods. + * + * And these values are used by the \ref StochTree::GaussianConstantLeafModel "GaussianConstantLeafModel" class in the + * \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", + * \ref StochTree::GaussianConstantLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", + * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterMean "PosteriorParameterMean", and + * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterVariance "PosteriorParameterVariance" methods. * To give one example, below is the implementation of \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood": - * + * * \code{.cpp} * double left_log_ml = ( * -0.5*std::log(1 + tau_*(left_stat.sum_w/global_variance)) + ((tau_*left_stat.sum_yw*left_stat.sum_yw)/(2.0*global_variance*(tau_*left_stat.sum_w + global_variance))) * ); - * + * * double right_log_ml = ( * -0.5*std::log(1 + tau_*(right_stat.sum_w/global_variance)) + ((tau_*right_stat.sum_yw*right_stat.sum_yw)/(2.0*global_variance*(tau_*right_stat.sum_w + global_variance))) * ); - * + * * return left_log_ml + right_log_ml; - * \endcode - * + * \endcode + * * \section gaussian_multivariate_regression_leaf_model Gaussian Multivariate Regression Leaf Model - * - * In this model, the tree defines a "partitioned linear model" in which leaf node parameters define regression weights + * + * In this model, the tree defines a "partitioned linear model" in which leaf node parameters define regression weights * that are multiplied by a "basis" \f$\Omega\f$ to determine the prediction for an observation. 
- * + * * \f[ * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \Omega_i \vec{\beta_{\ell}} * \f] - * + * * and we assign \f$\beta_{\ell}\f$ a prior of - * + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\vec{\beta_0}, \Sigma_0\right) * \f] - * + * * where \f$\vec{\beta_0}\f$ is typically a vector of zeros. The outcome likelihood is still - * + * * \f[ * y_i \sim N\left(f(X_i), \sigma^2\right) * \f] - * + * * This gives a reduced log integrated likelihood of - * + * * \f[ * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \frac{\Sigma_0\Omega'\Omega}{\sigma^2}\right)\right) + \frac{1}{2}\frac{y'\Omega}{\sigma^2}\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\frac{\Omega'y}{\sigma^2} * \f] - * - * where \f$\Omega\f$ is a matrix of bases for every observation in leaf \f$\ell\f$ and \f$p\f$ is the dimension of \f$\Omega\f$. The posterior for \f$\vec{\beta_{\ell}}\f$ is - * + * + * where \f$\Omega\f$ is a matrix of bases for every observation in leaf \f$\ell\f$ and \f$p\f$ is the dimension of \f$\Omega\f$. The posterior for \f$\vec{\beta_{\ell}}\f$ is + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\left(\frac{\Omega'y}{\sigma^2}\right),\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\right) * \f] - * + * * This is an extension of the single-tree model of Chipman et al (2002), with: - * + * * - Support for using a separate basis for leaf model than the partitioning (i.e. tree) model (i.e. 
\f$X \neq \Omega\f$) * - Support for multiple trees and sampling via grow-from-root (GFR) or MCMC - * + * * We can also enable heteroskedasticity by defining a (diagonal) covariance matrix for the outcome likelihood - * + * * \f[ * \Sigma_y = \text{diag}\left(\sigma^2 / w_1,\sigma^2 / w_2,\dots,\sigma^2 / w_n\right) * \f] - * + * * This updates the reduced log integrated likelihood to - * + * * \f[ * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \Sigma_{0}\Omega'\Sigma_y^{-1}\Omega\right)\right) + \frac{1}{2}y'\Sigma_{y}^{-1}\Omega\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\Omega'\Sigma_{y}^{-1}y * \f] - * - * and a posterior for \f$\vec{\beta_{\ell}}\f$ of - * + * + * and a posterior for \f$\vec{\beta_{\ell}}\f$ of + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\left(\Omega'\Sigma_{y}^{-1}y\right),\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\right) * \f] - * + * * \section gaussian_univariate_regression_leaf_model Gaussian Univariate Regression Leaf Model - * + * * This specializes the Gaussian Multivariate Regression Leaf Model for a univariate leaf basis, which allows for several computational speedups (replacing generalized matrix operations with simple summation or sum-product operations). - * We simplify \f$\Omega\f$ to \f$\omega\f$, a univariate basis for every observation, so that \f$\Omega'\Omega = \sum_{i:i \in \ell}\omega_i^2\f$ and \f$\Omega'y = \sum_{i:i \in \ell}\omega_ir_i\f$. Similarly, the prior for the leaf - * parameter becomes univariate normal as in \ref gaussian_constant_leaf_model: - * + * We simplify \f$\Omega\f$ to \f$\omega\f$, a univariate basis for every observation, so that \f$\Omega'\Omega = \sum_{i:i \in \ell}\omega_i^2\f$ and \f$\Omega'y = \sum_{i:i \in \ell}\omega_ir_i\f$. 
Similarly, the prior for the leaf + * parameter becomes univariate normal as in \ref gaussian_constant_leaf_model: + * * \f[ * \beta \sim N\left(0, \tau\right) * \f] - * - * Allowing for case / variance weights \f$w_i\f$ as above, we derive a reduced log marginal likelihood of - * + * + * Allowing for case / variance weights \f$w_i\f$ as above, we derive a reduced log marginal likelihood of + * * \f[ * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wyx,\ell}^2}{2\sigma^2(s_{wxx,\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * s_{wxx,\ell} = \sum_{i : X_i \in \ell} w_i \omega_i \omega_i * \f] - * + * * \f[ * s_{wyx,\ell} = \sum_{i : X_i \in \ell} w_i r_i \omega_i * \f] - * - * and a posterior of - * + * + * and a posterior of + * * \f[ * \beta_{\ell} \mid - \sim N\left(\frac{\tau s_{wyx,\ell}}{s_{wxx,\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) * \f] - * + * * \section inverse_gamma_leaf_model Inverse Gamma Leaf Model - * - * Each of the above models is a variation on a theme: a conjugate, partitioned Gaussian leaf model. + * + * Each of the above models is a variation on a theme: a conjugate, partitioned Gaussian leaf model. 
* The inverse gamma leaf model allows for forest-based heteroskedasticity modeling using an inverse gamma prior on the exponentiated leaf parameter, as discussed in Murray (2021) * Define a variance function based on an ensemble of \f$b\f$ trees as - * + * * \f[ * \sigma^2(X) = \exp\left(s_1(X) + \dots + s_b(X)\right) * \f] - * - * where each tree function \f$s_j(X)\f$ is defined as - * + * + * where each tree function \f$s_j(X)\f$ is defined as + * * \f[ * s_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \lambda_{\ell} * \f] - * + * * We reparameterize \f$\lambda_{\ell} = \log(\mu_{\ell})\f$ and we place an inverse gamma prior on \f$\mu_{\ell}\f$ - * + * * \f[ * \mu_{\ell} \sim \text{IG}\left(a, b\right) * \f] - * - * As noted in Murray (2021), this model no longer enables the "Bayesian backfitting" simplification - * of conjugated Gaussian leaf models, in which sampling updates for a given tree only depend on other trees in the ensemble via their imprint on the partial residual - * \f$r_i = y_i - \sum_{k \neq j} \mu_k(X_i)\f$. + * + * As noted in Murray (2021), this model no longer enables the "Bayesian backfitting" simplification + * of conjugated Gaussian leaf models, in which sampling updates for a given tree only depend on other trees in the ensemble via their imprint on the partial residual + * \f$r_i = y_i - \sum_{k \neq j} \mu_k(X_i)\f$. * However, this model is part of a broader class of models with convenient "blocked MCMC" sampling updates (another important example being multinomial classification). 
- * + * * Under an outcome model - * + * * \f[ * y \sim N\left(f(X), \sigma_0^2 \sigma^2(X)\right) * \f] - * + * * updates to \f$\mu_{\ell}\f$ for a given tree \f$j\f$ are based on a reduced log marginal likelihood of - * + * * \f[ * L(y) \propto a \log (b) - \log \Gamma (a) + \log \Gamma \left(a + \frac{n_{\ell}}{2}\right) - \left(a + \frac{n_{\ell}}{2}\right) \left(b + \frac{s_{\sigma,\ell}}{2\sigma_0^2}\right) * \f] - * + * * where - * + * * \f[ * n_{\ell} = \sum_{i : X_i \in \ell} 1 * \f] - * + * * \f[ * s_{\sigma,\ell} = \sum_{i: i \in \ell} \frac{(y_i - f(X_i))^2}{\prod_{k \neq j} s_k(X_i)} * \f] - * - * and a posterior of - * + * + * and a posterior of + * * \f[ * \mu_{\ell} \mid - \sim \text{IG}\left( a + \frac{n_{\ell}}{2} , b + \frac{s_{\sigma,\ell}}{2\sigma_0^2} \right) * \f] - * + * * Thus, as above, we implement a sufficient statistic class (\ref StochTree::LogLinearVarianceSuffStat "LogLinearVarianceSuffStat"), which tracks - * + * * - \f$n_{\ell}\f$: `data_size_t n` * - \f$s_{\sigma,\ell}\f$: `double weighted_sum_ei` - * - * And these values are used by the \ref StochTree::LogLinearVarianceLeafModel "LogLinearVarianceLeafModel" class in the - * \ref StochTree::LogLinearVarianceLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", - * \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", - * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterShape "PosteriorParameterShape", and - * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterScale "PosteriorParameterScale" methods. 
+ * + * And these values are used by the \ref StochTree::LogLinearVarianceLeafModel "LogLinearVarianceLeafModel" class in the + * \ref StochTree::LogLinearVarianceLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterShape "PosteriorParameterShape", and + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterScale "PosteriorParameterScale" methods. * To give one example, below is the implementation of \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood": - * + * * \code{.cpp} * double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_); * double a_term = a_ + 0.5 * suff_stat.n; @@ -337,8 +342,8 @@ namespace StochTree { * double resid_term = a_term * log_b_term; * double log_ml = prior_terms + lgamma_a_term - resid_term; * return log_ml; - * \endcode - * + * \endcode + * * \{ */ @@ -350,9 +355,9 @@ namespace StochTree { * 5. `kCloglogOrdinal`: Every leaf node has a log-gamma prior and every leaf is constant. */ enum ModelType { - kConstantLeafGaussian, - kUnivariateRegressionLeafGaussian, - kMultivariateRegressionLeafGaussian, + kConstantLeafGaussian, + kUnivariateRegressionLeafGaussian, + kMultivariateRegressionLeafGaussian, kLogLinearVariance, kCloglogOrdinal }; @@ -373,7 +378,7 @@ class GaussianConstantSuffStat { } /*! 
* \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -398,9 +403,9 @@ class GaussianConstantSuffStat { sum_w = 0.0; sum_yw = 0.0; } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianConstantSuffStat& suff_stat) { @@ -410,7 +415,7 @@ class GaussianConstantSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -421,7 +426,7 @@ class GaussianConstantSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -432,7 +437,7 @@ class GaussianConstantSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -440,7 +445,7 @@ class GaussianConstantSuffStat { } /*! 
* \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -459,14 +464,14 @@ class GaussianConstantLeafModel { public: /*! * \brief Construct a new GaussianConstantLeafModel object - * + * * \param tau Leaf node prior scale parameter */ GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} ~GaussianConstantLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -474,28 +479,28 @@ class GaussianConstantLeafModel { double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance); /*! 
* \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -508,7 +513,7 @@ class GaussianConstantLeafModel { void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); /*! * \brief Set a new value for the leaf node scale parameter - * + * * \param tau Leaf node prior scale parameter */ void SetScale(double tau) {tau_ = tau;} @@ -537,7 +542,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -562,9 +567,9 @@ class GaussianUnivariateRegressionSuffStat { sum_xxw = 0.0; sum_yxw = 0.0; } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianUnivariateRegressionSuffStat& suff_stat) { @@ -574,7 +579,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! 
* \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -585,7 +590,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -596,7 +601,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -604,7 +609,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -625,7 +630,7 @@ class GaussianUnivariateRegressionLeafModel { ~GaussianUnivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -633,28 +638,28 @@ class GaussianUnivariateRegressionLeafModel { double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. 
- * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -681,7 +686,7 @@ class GaussianMultivariateRegressionSuffStat { Eigen::MatrixXd ytWX; /*! * \brief Construct a new GaussianMultivariateRegressionSuffStat object - * + * * \param basis_dim Size of the basis vector that defines the leaf regression */ GaussianMultivariateRegressionSuffStat(int basis_dim) { @@ -692,7 +697,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! 
* \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -717,9 +722,9 @@ class GaussianMultivariateRegressionSuffStat { XtWX = Eigen::MatrixXd::Zero(p, p); ytWX = Eigen::MatrixXd::Zero(1, p); } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianMultivariateRegressionSuffStat& suff_stat) { @@ -729,7 +734,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -740,7 +745,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -751,7 +756,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -759,7 +764,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! 
* \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -778,14 +783,14 @@ class GaussianMultivariateRegressionLeafModel { public: /*! * \brief Construct a new GaussianMultivariateRegressionLeafModel object - * + * * \param Sigma_0 Prior covariance, must have the same number of rows and columns as dimensions of the basis vector for the multivariate regression problem */ GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();} ~GaussianMultivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -793,28 +798,28 @@ class GaussianMultivariateRegressionLeafModel { double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! 
* \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -843,7 +848,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -863,7 +868,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(LogLinearVarianceSuffStat& suff_stat) { @@ -872,7 +877,7 @@ class LogLinearVarianceSuffStat { } /*! 
* \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -882,7 +887,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -892,7 +897,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -900,7 +905,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -921,7 +926,7 @@ class LogLinearVarianceLeafModel { ~LogLinearVarianceLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -929,7 +934,7 @@ class LogLinearVarianceLeafModel { double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. 
- * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ @@ -937,21 +942,21 @@ class LogLinearVarianceLeafModel { double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior shape parameter. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior scale parameter. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "full" residual net of all the model's mean terms @@ -971,13 +976,14 @@ class LogLinearVarianceLeafModel { GammaSampler gamma_sampler_; }; + /*! \brief Sufficient statistic and associated operations for complementary log-log ordinal BART model */ class CloglogOrdinalSuffStat { public: data_size_t n; double sum_Y_less_K; double other_sum; - + /*! * \brief Construct a new CloglogOrdinalSuffStat object, setting all sufficient statistics to zero */ @@ -986,10 +992,10 @@ class CloglogOrdinalSuffStat { sum_Y_less_K = 0.0; other_sum = 0.0; } - + /*! 
* \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containing training data, including covariates * \param outcome Data object containing the original ordinal outcome values, which are used to compute sufficient statistics * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -998,20 +1004,20 @@ class CloglogOrdinalSuffStat { */ void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) { n += 1; - + // Get ordinal outcome value for this observation unsigned int y = static_cast(outcome(row_idx)); - + // Get auxiliary data from tracker (assuming types: 0=latents Z, 1=forest predictions, 2=cutpoints gamma, 3=cumsum exp of gamma) double Z = tracker.GetOrdinalAuxData(0, row_idx); // latent variables Z - double lambda_minus = tracker.GetOrdinalAuxData(1, row_idx); // forest predictions excluding current tree + double lambda_minus = tracker.GetOrdinalAuxData(1, row_idx); // forest predictions excluding current tree // Get cutpoints gamma and cumulative sum of exp(gamma) const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma const std::vector& seg = tracker.GetOrdinalAuxDataVector(3); // cumsum exp of gamma int K = gamma.size() + 1; // Number of ordinal categories - + if (y == K - 1) { other_sum += std::exp(lambda_minus) * seg[y]; // checked and it's correct } else { @@ -1019,7 +1025,7 @@ class CloglogOrdinalSuffStat { other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]); // checked and it's correct } } - + /*! * \brief Reset all of the sufficient statistics to zero */ @@ -1028,10 +1034,21 @@ class CloglogOrdinalSuffStat { sum_Y_less_K = 0.0; other_sum = 0.0; } - + + /*! 
+ * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` + * + * \param suff_stat Sufficient statistic to be added to the current sufficient statistics + */ + void AddSuffStatInplace(CloglogOrdinalSuffStat& suff_stat) { + n += suff_stat.n; + sum_Y_less_K += suff_stat.sum_Y_less_K; + other_sum += suff_stat.other_sum; + } + /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -1040,10 +1057,10 @@ class CloglogOrdinalSuffStat { sum_Y_less_K = lhs.sum_Y_less_K + rhs.sum_Y_less_K; other_sum = lhs.other_sum + rhs.other_sum; } - + /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -1052,25 +1069,25 @@ class CloglogOrdinalSuffStat { sum_Y_less_K = lhs.sum_Y_less_K - rhs.sum_Y_less_K; other_sum = lhs.other_sum - rhs.other_sum; } - + /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { return n > threshold; } - + /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { return n >= threshold; } - + /*! * \brief Return the sample size accumulated by a sufficient stat object */ @@ -1084,8 +1101,8 @@ class CloglogOrdinalLeafModel { public: /*! 
* \brief Construct a new CloglogOrdinalLeafModel object - * - * \param a Shape parameter for log-gamma prior on leaf parameters + * + * \param a shape parameter for log-gamma prior on leaf parameters * \param b rate parameter for log-gamma prior on leaf parameters * Log-gamma density: f(x) = b^a / Gamma(a) * exp(a*x - b*exp(x)) * Relationship to tau (scale of leaf parameters): tau^2 = trigamma(a) @@ -1094,6 +1111,7 @@ class CloglogOrdinalLeafModel { a_ = a; b_ = b; gamma_sampler_ = GammaSampler(); + // slice_sampler_ = SliceSampler(); tau_ = std::sqrt(boost::math::trigamma(a_)); } ~CloglogOrdinalLeafModel() {} @@ -1139,9 +1157,39 @@ class CloglogOrdinalLeafModel { inline bool RequiresBasis() {return false;} + // /*! + // * \brief Update the scale parameter (tau_) using slice sampling + // * + // * \param lambda Vector of leaf parameter values from all trees + // * \param scale_sigma_lambda Prior scale parameter for scale parameter (tau_) of leaf parameters + // * \param gen Random number generator + // */ + // void UpdateScaleLambda(const std::vector<double>& lambda, double scale_sigma_lambda, std::mt19937& gen) { + // double n = static_cast<double>(lambda.size()); + // double sum_lambda = 0.0; + // double sum_exp_lambda = 0.0; + + // for (size_t i = 0; i < lambda.size(); i++) { + // sum_lambda += lambda[i]; + // sum_exp_lambda += std::exp(lambda[i]); + // } + + // // Create log-likelihood function + // ScaleLambdaLoglik loglik_func(n, sum_lambda, sum_exp_lambda, scale_sigma_lambda); + + // // Sample new scale parameter using slice sampling + // double current_tau = tau_; + // double w = 1.0; // Step size for slice sampler + // double lower = 1e-6; // Lower bound for tau + // double upper = std::numeric_limits<double>::infinity(); // Upper bound + + // double new_tau = slice_sampler_.Sample(current_tau, &loglik_func, w, lower, upper, gen); + // tau_ = new_tau; + // } + /*! * \brief Convert tau_ (scale_lambda i.e.
scale for leaf parameters) to alpha (shape) and beta (rate) parameters for the log-gamma prior - * + * * \param alpha Output: shape parameter for log-gamma prior * \param beta Output: rate parameter for log-gamma prior * \param tau Scale parameter (tau_) for leaf parameters @@ -1155,7 +1203,7 @@ class CloglogOrdinalLeafModel { /*! * \brief Convert alpha (shape) and beta (rate) parameters (for the log-gamma prior) back to tau_ (scale_lambda i.e. scale for leaf parameters) - * + * * \param alpha Shape parameter for log-gamma prior * \param beta Rate parameter for log-gamma prior * \return tau Scale parameter (tau_) for leaf parameters @@ -1169,9 +1217,9 @@ class CloglogOrdinalLeafModel { private: /*! * \brief Compute inverse trigamma function using Newton's method - * + * * Implementation adapted from limma package in R, originally by Gordon Smyth - * + * * \param x Input value for which to compute trigamma inverse * \return Value y such that trigamma(y) = x */ @@ -1198,34 +1246,34 @@ class CloglogOrdinalLeafModel { double a_; double b_; GammaSampler gamma_sampler_; + // SliceSampler slice_sampler_; double tau_; }; -/*! - * \brief Unifying layer for disparate sufficient statistic class types - * - * Joins together GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, - * GaussianMultivariateRegressionSuffStat, and LogLinearVarianceSuffStat - * as a combined "variant" type. See the std::variant documentation +/*! \brief Unifying layer for disparate sufficient statistic class types + * + * Joins together GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, + * GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, and CloglogOrdinalSuffStat + * as a combined "variant" type. See the std::variant documentation * for more detail. */ -using SuffStatVariant = std::variant; /*! 
* \brief Unifying layer for disparate leaf model class types - * - * Joins together GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, - * GaussianMultivariateRegressionLeafModel, and LogLinearVarianceLeafModel - * as a combined "variant" type. See the std::variant documentation + * + * Joins together GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, + * GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, and CloglogOrdinalLeafModel + * as a combined "variant" type. See the std::variant documentation * for more detail. */ -using LeafModelVariant = std::variant; @@ -1241,7 +1289,7 @@ static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_ /*! * \brief Factory function that creates a new `SuffStat` object for the specified model type - * + * * \param model_type Enumeration storing the model type * \param basis_dim [Optional] dimension of the basis vector, only used if `model_type = kMultivariateRegressionLeafGaussian` */ @@ -1261,16 +1309,14 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di /*! 
* \brief Factory function that creates a new `LeafModel` object for the specified model type - * + * * \param model_type Enumeration storing the model type - * \param tau Value of the leaf node prior scale parameter, only used if `model_type = kConstantLeafGaussian` or `model_type = kUnivariateRegressionLeafGaussian` + * \param tau Value of the leaf node prior scale parameter, only used if `model_type = kConstantLeafGaussian`, `model_type = kUnivariateRegressionLeafGaussian` * \param Sigma0 Value of the leaf node prior covariance matrix, only used if `model_type = kMultivariateRegressionLeafGaussian` - * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance` - * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance` - * \param c Value of the leaf node log-gamma prior shape parameter, only used if `model_type = kCloglogOrdinal` - * \param d Value of the leaf node log-gamma prior rate parameter, only used if `model_type = kCloglogOrdinal` + * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance` (or value of the leaf node log-gamma prior shape parameter, only used if `model_type = kCloglogOrdinal`) + * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance` (or value of the leaf node log-gamma prior rate parameter, only used if `model_type = kCloglogOrdinal`) */ -static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b, double c, double d) { +static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) { if (model_type == kConstantLeafGaussian) { return createLeafModel(tau); } else if (model_type == kUnivariateRegressionLeafGaussian) { @@ -1280,14 +1326,14 @@ static inline LeafModelVariant 
leafModelFactory(ModelType model_type, double tau } else if (model_type == kLogLinearVariance) { return createLeafModel(a, b); } else { - return createLeafModel(c, d); + return createLeafModel(a, b); } } template static inline void AccumulateSuffStatProposed( - SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, - ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads, + SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, + ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads, SuffStatConstructorArgs&... suff_stat_args ) { // Determine the position of the node's indices in the forest tracking data structure @@ -1312,13 +1358,13 @@ static inline void AccumulateSuffStatProposed( std::vector thread_suff_stats_left; std::vector thread_suff_stats_right; for (int i = 0; i < num_threads; i++) { - thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size, + thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size, node_begin_index + (i + 1) * chunk_size); thread_suff_stats_node.emplace_back(suff_stat_args...); thread_suff_stats_left.emplace_back(suff_stat_args...); thread_suff_stats_right.emplace_back(suff_stat_args...); } - + // Accumulate sufficient statistics StochTree::ParallelFor(0, num_threads, num_threads, [&](int i) { int start_idx = thread_ranges[i].first; @@ -1356,7 +1402,7 @@ static inline void AccumulateSuffStatProposed( } template -static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, +static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, 
SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) { // Acquire iterators auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id); @@ -1392,7 +1438,7 @@ static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, Fo node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id); node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id); } - + // Accumulate sufficient statistics for (auto i = node_begin_iter; i != node_end_iter; i++) { auto idx = *i; @@ -1401,13 +1447,13 @@ static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, Fo } template -static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container, - ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id, +static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container, + ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id, int feature_num, int cutpoint_num) { // Acquire iterators auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num); auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num); - + // Determine node start point data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num); diff --git a/include/stochtree/log.h b/include/stochtree/log.h index 9f64c31b..3a4c5600 100644 --- a/include/stochtree/log.h +++ b/include/stochtree/log.h @@ -15,6 +15,8 @@ #include #include #include +#include +#include #include #include diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h index d0aa4049..991c254f 100644 --- a/include/stochtree/meta.h +++ b/include/stochtree/meta.h 
@@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h index ec148b5b..054a14c7 100644 --- a/include/stochtree/ordinal_sampler.h +++ b/include/stochtree/ordinal_sampler.h @@ -43,7 +43,6 @@ class OrdinalSampler { */ static double SampleTruncatedExponential(double lambda, std::mt19937& gen); - /*! * \brief Update truncated exponential latent variables (Z) * diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index a2f5dd70..3f342f15 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -31,7 +31,11 @@ #include #include +#include #include +#include +#include +#include #include namespace StochTree { @@ -104,6 +108,7 @@ class ForestTracker { void SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value); std::vector& GetOrdinalAuxDataVector(int type_idx); + private: /*! \brief Mapper from observations to predicted values summed over every tree in a forest */ std::vector sum_predictions_; @@ -132,7 +137,7 @@ class ForestTracker { void UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates); void UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); void UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); - + /*! 
* \brief Track auxiliary data for cloglog ordinal bart models * Vector of vectors to store these auxiliary data @@ -146,6 +151,8 @@ class ForestTracker { * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints) */ void ResizeOrdinalAuxData(data_size_t num_observations, int n_levels); + // bool IsValidOrdinalType(int type_idx) const; + // bool IsValidOrdinalIndex(int type_idx, data_size_t obs_idx) const; }; /*! \brief Class storing sample-prediction map for each tree in an ensemble */ @@ -456,7 +463,7 @@ class UnsortedNodeSampleTracker { /*! \brief Number of trees */ int NumTrees() { return num_trees_; } - /*! \brief Return a pointer to the feature partition tracking tree i */ + /*! \brief Return a pointer to the feature partition tracking tree i */ FeatureUnsortedPartition* GetFeaturePartition(int i) { return feature_partitions_[i].get(); } private: @@ -637,24 +644,24 @@ class SortedNodeSampleTracker { } /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, TreeSplit& split, int num_threads = -1) { - StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) { + void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, TreeSplit& split) { + for (int i = 0; i < num_features_; i++) { feature_partitions_[i]->SplitFeature(covariates, node_id, feature_split, split); - }); + } } /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, double split_value, int num_threads = -1) { - StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) { + void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, double split_value) { + for (int i = 0; i < num_features_; i++) { feature_partitions_[i]->SplitFeatureNumeric(covariates, node_id, feature_split, split_value); - }); + } } /*! 
\brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, std::vector const& category_list, int num_threads = -1) { - StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) { + void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, std::vector const& category_list) { + for (int i = 0; i < num_features_; i++) { feature_partitions_[i]->SplitFeatureCategorical(covariates, node_id, feature_split, category_list); - }); + } } /*! \brief First index of data points contained in node_id */ diff --git a/include/stochtree/random.h b/include/stochtree/random.h index 3d39b647..a841f396 100644 --- a/include/stochtree/random.h +++ b/include/stochtree/random.h @@ -5,6 +5,7 @@ #ifndef STOCHTREE_RANDOM_H_ #define STOCHTREE_RANDOM_H_ +#include #include #include #include diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index b322a560..701ebeaa 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -17,11 +17,14 @@ #include #include +#include #include #include #include #include +#include #include +#include #include namespace StochTree { diff --git a/include/stochtree/slice_sampler.h b/include/stochtree/slice_sampler.h new file mode 100644 index 00000000..07fe5a26 --- /dev/null +++ b/include/stochtree/slice_sampler.h @@ -0,0 +1,180 @@ +/*! + * Copyright (c) 2024 stochtree authors. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef STOCHTREE_SLICE_SAMPLER_H_ +#define STOCHTREE_SLICE_SAMPLER_H_ + +#include +#include +#include +#include +#include + +#ifndef M_LN2 +#define M_LN2 0.6931471805599453 // ln(2) +#endif + +namespace StochTree { + +/*! + * \brief Abstract base class for log-likelihood functions used in slice sampling + */ +class LoglikFunction { + public: + virtual ~LoglikFunction() {} + + /*! 
+ * \brief Evaluate the log-likelihood function at point x + * \param x Input value + * \return Log-likelihood value + */ + virtual double Evaluate(double x) = 0; +}; + +/*! + * \brief Log-likelihood function for scale_lambda parameter in ordinal models + */ +class ScaleLambdaLoglik : public LoglikFunction { + public: + /*! + * \brief Constructor for scale lambda log-likelihood + * \param n Number of observations (lambda values) + * \param sum_lambda Sum of all lambda values + * \param sum_exp_lambda Sum of exp(lambda) values + * \param scale Prior scale parameter for scale_lambda + */ + ScaleLambdaLoglik(double n, double sum_lambda, double sum_exp_lambda, double scale) + : n_(n), sum_lambda_(sum_lambda), sum_exp_lambda_(sum_exp_lambda), scale_(scale) {} + + /*! + * \brief Evaluate log-likelihood of scale_lambda parameter + * \param sigma Input scale parameter value (scale_lambda) + * \return Log-likelihood value + */ + double Evaluate(double sigma) override { + if (sigma <= 0) return -std::numeric_limits::infinity(); + + // Convert scale_lambda to alpha and beta parameters + double alpha, beta; + ScaleLambdaToAlphaBeta(alpha, beta, sigma); + + // Log-likelihood contribution from lambda values (log-gamma prior) + double loglik = n_ * alpha * std::log(beta) + - n_ * boost::math::lgamma(alpha) + + alpha * sum_lambda_ + - beta * sum_exp_lambda_; + + // Add constants and prior terms + loglik += M_LN2 - 0.5 * std::log(2.0 * M_PI); // = 0.5 * log(2 / pi): half-normal normalizing constant (the -log(scale_) term is omitted; it does not depend on sigma) + + // Prior on scale_lambda (half-normal kernel) + double scale_ratio = sigma / scale_; + loglik -= 0.5 * scale_ratio * scale_ratio; + + return loglik; + } + + private: + double n_; + double sum_lambda_; + double sum_exp_lambda_; + double scale_; + + /*! 
+ * \brief Convert scale_lambda to alpha and beta parameters for the gamma prior + */ + void ScaleLambdaToAlphaBeta(double& alpha, double& beta, const double sigma) { + double sigma_sq = sigma * sigma; + alpha = TrigammaInverse(sigma_sq); + beta = std::exp(boost::math::digamma(alpha)); + } + + /*! + * \brief Compute inverse trigamma function using Newton's method + */ + double TrigammaInverse(double x) { + if (x > 1E7) return 1.0 / std::sqrt(x); + if (x < 1E-6) return 1.0 / x; + + double y = 0.5 + 1.0 / x; + for (int i = 0; i < 50; i++) { + double tri = boost::math::trigamma(y); + double dif = tri * (1.0 - tri / x) / boost::math::polygamma(3, y); + y += dif; + if (-dif / y < 1E-8) break; + } + return y; + } +}; + +/*! + * \brief Slice sampler implementation + */ +class SliceSampler { + public: + SliceSampler() {} + ~SliceSampler() {} + + /*! + * \brief Sample from a distribution using slice sampling + * \param x0 Initial value + * \param loglik_func Log-likelihood function + * \param w Step size for expanding interval + * \param lower Lower bound + * \param upper Upper bound + * \param gen Random number generator + * \return Sampled value + */ + double Sample(double x0, LoglikFunction* loglik_func, double w, + double lower, double upper, std::mt19937& gen) { + + std::uniform_real_distribution unif(0.0, 1.0); + std::exponential_distribution exp_dist(1.0); + + // Find the log density at the initial point + double gx0 = loglik_func->Evaluate(x0); + + // Determine the slice level, in log terms + double logy = gx0 - exp_dist(gen); + + // Find the initial interval to sample from + double u = w * unif(gen); + double L = x0 - u; + double R = x0 + (w - u); + + // Expand the interval until its ends are outside the slice + while (L > lower && loglik_func->Evaluate(L) > logy) { + L -= w; + } + + while (R < upper && loglik_func->Evaluate(R) > logy) { + R += w; + } + + // Shrink interval to bounds + if (L < lower) L = lower; + if (R > upper) R = upper; + + // Sample from the 
interval, shrinking it on each rejection + double x1; + do { + x1 = L + (R - L) * unif(gen); + double gx1 = loglik_func->Evaluate(x1); + + if (gx1 >= logy) break; + + if (x1 > x0) { + R = x1; + } else { + L = x1; + } + } while (true); + + return x1; + } +}; + +} // namespace StochTree + +#endif // STOCHTREE_SLICE_SAMPLER_H_ diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index 3810e3cb..85ce7191 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -13,6 +13,9 @@ #include #include +#include +#include +#include #include #include diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 675ef6c0..6b7579c6 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -8,13 +8,19 @@ #include #include #include -#include #include #include #include +#include +#include #include #include +#include +#include +#include +#include +#include #include namespace StochTree { @@ -22,7 +28,7 @@ namespace StochTree { /*! * \defgroup sampling_group Forest Sampler API * - * \brief Functions for sampling from a forest. The core interface of these functions, + * \brief Functions for sampling from a forest. 
The core interface of these functions, * as used by the R, Python, and standalone C++ program, is defined by * \ref MCMCSampleOneIter, which runs one iteration of the MCMC sampler for a * given forest, and \ref GFRSampleOneIter, which runs one iteration of the @@ -147,7 +153,7 @@ static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracke } static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree, - int tree_num, int leaf_node, int feature_split, bool keep_sorted = false, int num_threads = -1) { + int tree_num, int leaf_node, int feature_split, bool keep_sorted = false) { // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete if (tree->OutputDimension() > 1) { std::vector temp_leaf_values(tree->OutputDimension(), 0.); @@ -160,7 +166,7 @@ static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& datase int right_node = tree->RightChild(leaf_node); // Update the ForestTracker - tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted, num_threads); + tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted); } static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree, @@ -295,6 +301,8 @@ static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector } } + + static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function op, bool tree_new) { data_size_t n = dataset.GetCovariates().rows(); @@ -432,7 +440,7 @@ template EvaluateProposedSplit( ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model, TreeSplit& split, int 
tree_num, int leaf_num, int split_feature, double global_variance, - int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args + LeafSuffStatConstructorArgs&... leaf_suff_stat_args ) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -440,11 +448,8 @@ static inline std::tuple EvaluatePropo LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); // Accumulate sufficient statistics - AccumulateSuffStatProposed( - node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, split, tree_num, leaf_num, split_feature, num_threads, - leaf_suff_stat_args... - ); + AccumulateSuffStatProposed(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, + residual, global_variance, split, tree_num, leaf_num, split_feature, 1, leaf_suff_stat_args...); data_size_t left_n = left_suff_stat.n; data_size_t right_n = right_suff_stat.n; @@ -481,36 +486,164 @@ static inline std::tuple EvaluateExist template static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { + ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false); } else if (backfitting) { - UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); + UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); } else { - // TODO: think about a generic way to store "state" corresponding to the other models? 
- UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), false); + // TODO: think about a generic way to store "state" corresponding to the other models? + UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), false); } } template static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, - ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { + ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true); } else if (backfitting) { - UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); + UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); } else { - // TODO: think about a generic way to store "state" corresponding to the other models? - UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), true); + // TODO: think about a generic way to store "state" corresponding to the other models? 
+ UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), true); } } +template +static inline void EvaluateAllPossibleSplits( + ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id, + std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, + data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, + std::vector& feature_types, std::vector& feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args +) { + // Initialize sufficient statistics + LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); + LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); + LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); + + // Accumulate aggregate sufficient statistic for the node to be split + AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, split_node_id); + + // Compute the "no split" log marginal likelihood + double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); + + // Unpack data + Eigen::MatrixXd covariates = dataset.GetCovariates(); + Eigen::VectorXd outcome = residual.GetData(); + Eigen::VectorXd var_weights; + bool has_weights = dataset.HasVarWeights(); + if (has_weights) var_weights = dataset.GetVarWeights(); + + // Minimum size of newly created leaf nodes (used to rule out invalid splits) + int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); + + // Compute sufficient statistics for each possible split + data_size_t num_cutpoints = 0; + bool valid_split = false; + data_size_t node_row_iter; + data_size_t current_bin_begin, current_bin_size, next_bin_begin; + 
data_size_t feature_sort_idx; + data_size_t row_iter_idx; + double outcome_val, outcome_val_sq; + FeatureType feature_type; + double feature_value = 0.0; + double cutoff_value = 0.0; + double log_split_eval = 0.0; + double split_log_ml; + for (int j = 0; j < covariates.cols(); j++) { + + if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) { + // Enumerate cutpoint strides + cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), split_node_id, node_begin, node_end, j, feature_types); + + // Reset sufficient statistics + left_suff_stat.ResetSuffStat(); + right_suff_stat.ResetSuffStat(); + + // Iterate through possible cutpoints + int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j); + feature_type = feature_types[j]; + // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins + for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { + current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); + current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j); + next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); + + // Accumulate sufficient statistics for the left node + AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, + global_variance, tree_num, split_node_id, j, cutpoint_idx); + + // Compute the corresponding right node sufficient statistics + right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); + + // Store the bin index as the "cutpoint value" - we can use this to query the actual split + // value or the set of split categories later on once a split is chosen + cutoff_value = cutpoint_idx; + + // Only include cutpoint for consideration if it defines a valid split in the training data + valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && + 
right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); + if (valid_split) { + num_cutpoints++; + // Add to split rule vector + cutpoint_feature_types.push_back(feature_type); + cutpoint_features.push_back(j); + cutpoint_values.push_back(cutoff_value); + // Add the log marginal likelihood of the split to the split eval vector + split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); + log_cutpoint_evaluations.push_back(split_log_ml); + } + } + } + + } + + // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) + cutpoint_features.push_back(-1); + cutpoint_values.push_back(std::numeric_limits::max()); + cutpoint_feature_types.push_back(FeatureType::kNumeric); + log_cutpoint_evaluations.push_back(no_split_log_ml); + + // Update valid cutpoint count + valid_cutpoint_count = num_cutpoints; +} + +template +static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, + std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end, + std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, + std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, + std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, + std::vector& feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + // Evaluate all possible cutpoints according to the leaf node model, + // recording their log-likelihood and other split information in a series of vectors. + // The last element of these vectors concerns the "no-split" option. 
+ EvaluateAllPossibleSplits( + dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations, + cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, + node_begin, node_end, variable_weights, feature_types, feature_subset, leaf_suff_stat_args... + ); + + // Compute an adjustment to reflect the no split prior probability and the number of cutpoints + double bart_prior_no_split_adj; + double alpha = tree_prior.GetAlpha(); + double beta = tree_prior.GetBeta(); + int node_depth = tree->GetDepth(node_id); + if (valid_cutpoint_count == 0) { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); + } else { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); + } + log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj; +} + template static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, std::unordered_map>& node_index_map, std::deque& split_queue, int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types, std::vector feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& feature_types, std::vector feature_subset, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Leaf depth int leaf_depth = tree->GetDepth(node_id); @@ -518,153 +651,41 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel int32_t max_depth = tree_prior.GetMaxDepth(); if ((max_depth == -1) || (leaf_depth < max_depth)) { - - // Vector of vectors to store results for each feature - int p = dataset.NumCovariates(); - std::vector> feature_log_cutpoint_evaluations(p+1); - std::vector> feature_cutpoint_values(p+1); - std::vector feature_cutpoint_counts(p+1, 0); + + // Cutpoint enumeration + std::vector log_cutpoint_evaluations; + std::vector cutpoint_features; + std::vector cutpoint_values; + std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count; - - // Evaluate all possible cutpoints according to the leaf node model, - // recording their log-likelihood and other split information in a series of vectors. - - // Initialize node sufficient statistics - LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - - // Accumulate aggregate sufficient statistic for the node to be split - AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, node_id); - - // Compute the "no split" log marginal likelihood - double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - // Unpack data - Eigen::MatrixXd& covariates = dataset.GetCovariates(); - Eigen::VectorXd& outcome = residual.GetData(); - Eigen::VectorXd var_weights; - bool has_weights = dataset.HasVarWeights(); - if (has_weights) var_weights = dataset.GetVarWeights(); - - // Minimum size of newly created leaf nodes (used to rule out invalid splits) - int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); - - // Compute sufficient statistics for each possible split - data_size_t num_cutpoints = 0; - if (num_threads == -1) { - num_threads = GetOptimalThreadCount(static_cast(covariates.cols() * covariates.rows())); - } - - // Initialize cutpoint grid 
container - CutpointGridContainer cutpoint_grid_container(covariates, outcome, cutpoint_grid_size); - - // Evaluate all possible splits for each feature in parallel - StochTree::ParallelFor(0, covariates.cols(), num_threads, [&](int j) { - if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) { - // Enumerate cutpoint strides - cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types); - - // Left and right node sufficient statistics - LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - - // Iterate through possible cutpoints - int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j); - FeatureType feature_type = feature_types[j]; - // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins - for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { - data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); - data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j); - data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); - - // Accumulate sufficient statistics for the left node - AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, - global_variance, tree_num, node_id, j, cutpoint_idx); - - // Compute the corresponding right node sufficient statistics - right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); - - // Store the bin index as the "cutpoint value" - we can use this to query the actual split - // value or the set of split categories later on once a split is chose - double cutoff_value = cutpoint_idx; - - // Only include cutpoint for consideration if it defines a valid split in the training data - bool 
valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && - right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); - if (valid_split) { - feature_cutpoint_counts[j]++; - // Add to split rule vector - feature_cutpoint_values[j].push_back(cutoff_value); - // Add the log marginal likelihood of the split to the split eval vector - double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - feature_log_cutpoint_evaluations[j].push_back(split_log_ml); - } - } - } - }); - - // Compute total number of cutpoints - valid_cutpoint_count = std::accumulate(feature_cutpoint_counts.begin(), feature_cutpoint_counts.end(), 0); - - // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) - feature_log_cutpoint_evaluations[covariates.cols()].push_back(no_split_log_ml); - - // Compute an adjustment to reflect the no split prior probability and the number of cutpoints - double bart_prior_no_split_adj; - double alpha = tree_prior.GetAlpha(); - double beta = tree_prior.GetBeta(); - int node_depth = tree->GetDepth(node_id); - if (valid_cutpoint_count == 0) { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); - } else { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); - } - feature_log_cutpoint_evaluations[covariates.cols()][0] += bart_prior_no_split_adj; - + CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); + EvaluateCutpoints( + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, + cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, + cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, + cutpoint_grid_container, feature_subset, leaf_suff_stat_args... 
+ ); + // TODO: maybe add some checks here? // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood - double largest_ml = -std::numeric_limits::infinity(); - for (int j = 0; j < p + 1; j++) { - if (feature_log_cutpoint_evaluations[j].size() > 0) { - double feature_max_ml = *std::max_element(feature_log_cutpoint_evaluations[j].begin(), feature_log_cutpoint_evaluations[j].end());; - largest_ml = std::max(largest_ml, feature_max_ml); - } - } - std::vector> feature_cutpoint_evaluations(p+1); - for (int j = 0; j < p + 1; j++) { - if (feature_log_cutpoint_evaluations[j].size() > 0) { - feature_cutpoint_evaluations[j].resize(feature_log_cutpoint_evaluations[j].size()); - for (int i = 0; i < feature_log_cutpoint_evaluations[j].size(); i++) { - feature_cutpoint_evaluations[j][i] = std::exp(feature_log_cutpoint_evaluations[j][i] - largest_ml); - } - } + double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); + std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); + for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ + cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); } - - // Compute sum of marginal likelihoods for each feature - std::vector feature_total_cutpoint_evaluations(p+1, 0.0); - for (int j = 0; j < p + 1; j++) { - if (feature_log_cutpoint_evaluations[j].size() > 0) { - feature_total_cutpoint_evaluations[j] = std::accumulate(feature_cutpoint_evaluations[j].begin(), feature_cutpoint_evaluations[j].end(), 0.0); - } else { - feature_total_cutpoint_evaluations[j] = 0.0; - } - } - - // First, sample a feature according to feature_total_cutpoint_evaluations - std::discrete_distribution feature_dist(feature_total_cutpoint_evaluations.begin(), feature_total_cutpoint_evaluations.end()); - int feature_chosen = feature_dist(gen); - - // Then, sample a cutpoint according to feature_cutpoint_evaluations[feature_chosen] - 
std::discrete_distribution cutpoint_dist(feature_cutpoint_evaluations[feature_chosen].begin(), feature_cutpoint_evaluations[feature_chosen].end()); - data_size_t cutpoint_chosen = cutpoint_dist(gen); - if (feature_chosen == p){ + // Sample the split (including a "no split" option) + std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); + data_size_t split_chosen = split_dist(gen); + + if (split_chosen == valid_cutpoint_count){ // "No split" sampled, don't split or add any nodes to split queue return; } else { // Split sampled - int feature_split = feature_chosen; - FeatureType feature_type = feature_types[feature_split]; - double split_value = feature_cutpoint_values[feature_split][cutpoint_chosen]; + int feature_split = cutpoint_features[split_chosen]; + FeatureType feature_type = cutpoint_feature_types[split_chosen]; + double split_value = cutpoint_values[split_chosen]; // Perform all of the relevant "split" operations in the model, tree and training dataset // Compute node sample size @@ -699,7 +720,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel } // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads); + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); // Determine the number of observation in the newly created left node int left_node = tree->LeftChild(node_id); @@ -725,7 +746,7 @@ template & variable_weights, int tree_num, double global_variance, std::vector& feature_types, int cutpoint_grid_size, - int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + int num_features_subsample, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { int root_id = Tree::kRoot; int curr_node_id; data_size_t curr_node_begin; @@ -781,8 +802,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore SampleSplitRule( tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types, - feature_subset, num_threads, leaf_suff_stat_args... - ); + feature_subset, leaf_suff_stat_args...); } } @@ -820,7 +840,7 @@ template & variable_weights, std::vector& sweep_update_indices, double global_variance, std::vector& feature_types, int cutpoint_grid_size, - bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the GFR algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { @@ -840,7 +860,7 @@ static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& GFRSampleTreeOneIter( tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types, cutpoint_grid_size, - num_features_subsample, num_threads, leaf_suff_stat_args... + num_features_subsample, leaf_suff_stat_args... ); // Sample leaf parameters for tree i @@ -862,7 +882,7 @@ static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& template static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, - double global_variance, double prob_grow_old, int num_threads, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { + double global_variance, double prob_grow_old, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Extract dataset information data_size_t n = dataset.GetCovariates().rows(); @@ -907,7 +927,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM // Compute the marginal likelihood of split and no split, given the leaf prior std::tuple split_eval = EvaluateProposedSplit( - dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, num_threads, leaf_suff_stat_args... + dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, leaf_suff_stat_args... ); double split_log_marginal_likelihood = std::get<0>(split_eval); double no_split_log_marginal_likelihood = std::get<1>(split_eval); @@ -957,7 +977,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM double log_acceptance_prob = std::log(mh_accept(gen)); if (log_acceptance_prob <= log_mh_ratio) { accept = true; - AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false, num_threads); + AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); } else { accept = false; } @@ -970,7 +990,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM template static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Choose a "leaf parent" node at random int num_leaves = tree->NumLeaves(); int num_leaf_parents = tree->NumLeafParents(); @@ -1049,7 +1069,7 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf template static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Determine whether it is possible to grow any of the leaves bool grow_possible = false; std::vector leaves = tree->GetLeaves(); @@ -1089,11 +1109,11 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For if (step_chosen == 0) { MCMCGrowTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, num_threads, leaf_suff_stat_args... + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, leaf_suff_stat_args... ); } else { MCMCPruneTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, num_threads, leaf_suff_stat_args... + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, leaf_suff_stat_args... 
); } } @@ -1128,8 +1148,7 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For template static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, - LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the MCMC algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { @@ -1144,7 +1163,7 @@ static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tree = active_forest.GetTree(i); MCMCSampleTreeOneIter( tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, - global_variance, num_threads, leaf_suff_stat_args... + global_variance, leaf_suff_stat_args... 
); // Sample leaf parameters for tree i diff --git a/include/stochtree/variance_model.h b/include/stochtree/variance_model.h index b1c2dabe..79b8831f 100644 --- a/include/stochtree/variance_model.h +++ b/include/stochtree/variance_model.h @@ -12,7 +12,11 @@ #include #include +#include #include +#include +#include +#include namespace StochTree { diff --git a/man/bart.Rd b/man/bart.Rd index 66a9b9ad..c11c619b 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -136,9 +136,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -153,6 +153,6 @@ X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/bcf.Rd b/man/bcf.Rd index 01e5fab8..f7d42e93 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -162,21 +162,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - 
((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -199,8 +199,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/cloglog_ordinal_bart.Rd b/man/cloglog_ordinal_bart.Rd new file mode 100644 index 00000000..9c2aed51 --- /dev/null +++ b/man/cloglog_ordinal_bart.Rd @@ -0,0 +1,47 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cloglog_ordinal_bart.R +\name{cloglog_ordinal_bart} +\alias{cloglog_ordinal_bart} +\title{Run the BART algorithm for ordinal outcomes using a complementary log-log link} +\usage{ +cloglog_ordinal_bart( + X, + y, + X_test = NULL, + n_trees = 50, + n_samples_mcmc = 500, + n_burnin = 250, + n_thin = 1, + alpha_gamma = 2, + beta_gamma = 2, + variable_weights = NULL, + feature_types = NULL, + seed = NULL +) +} +\arguments{ +\item{X}{A numeric matrix of predictors (training data).} + +\item{y}{A numeric vector of ordinal outcomes (positive integers starting from 1).} + +\item{X_test}{An optional numeric matrix of predictors (test data).} + +\item{n_trees}{Number of trees in the BART ensemble. Default: \code{50}.} + +\item{n_samples_mcmc}{Total number of MCMC samples to draw. Default: \code{500}.} + +\item{n_burnin}{Number of burn-in samples to discard. 
Default: \code{250}.} + +\item{n_thin}{Thinning interval for MCMC samples. Default: \code{1}.} + +\item{alpha_gamma}{Shape parameter for the log-gamma prior on cutpoints. Default: \code{2.0}.} + +\item{beta_gamma}{Rate parameter for the log-gamma prior on cutpoints. Default: \code{2.0}.} + +\item{variable_weights}{Optional vector of variable weights for splitting (default: equal weights).} + +\item{feature_types}{Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).} +} +\description{ +Run the BART algorithm for ordinal outcomes using a complementary log-log link +} diff --git a/man/createBARTModelFromCombinedJson.Rd b/man/createBARTModelFromCombinedJson.Rd index 35d185c3..83d61d0d 100644 --- a/man/createBARTModelFromCombinedJson.Rd +++ b/man/createBARTModelFromCombinedJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- list(saveBARTModelToJson(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) diff --git a/man/createBARTModelFromCombinedJsonString.Rd b/man/createBARTModelFromCombinedJsonString.Rd index a8470dee..7a17484a 100644 --- a/man/createBARTModelFromCombinedJsonString.Rd +++ b/man/createBARTModelFromCombinedJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * 
(-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) diff --git a/man/createBARTModelFromJson.Rd b/man/createBARTModelFromJson.Rd index 57686122..68a02f0e 100644 --- a/man/createBARTModelFromJson.Rd +++ b/man/createBARTModelFromJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) bart_model_roundtrip <- createBARTModelFromJson(bart_json) diff --git a/man/createBARTModelFromJsonFile.Rd b/man/createBARTModelFromJsonFile.Rd index f714a94a..7608d8d2 100644 --- a/man/createBARTModelFromJsonFile.Rd +++ b/man/createBARTModelFromJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & 
(0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index 67068fd0..0748d97a 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJsonString(bart_model) bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) diff --git a/man/createBCFModelFromCombinedJson.Rd b/man/createBCFModelFromCombinedJson.Rd index 6f29569e..24c82e4f 100644 --- a/man/createBCFModelFromCombinedJson.Rd +++ b/man/createBCFModelFromCombinedJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= 
X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_list <- list(saveBCFModelToJson(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) diff --git a/man/createBCFModelFromCombinedJsonString.Rd 
b/man/createBCFModelFromCombinedJsonString.Rd index bd7e63f2..e0522f75 100644 --- a/man/createBCFModelFromCombinedJsonString.Rd +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = 
rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index a579b140..35cff7ce 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = 
y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index 2661d4de..a2496797 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- 
list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd index 5f34724c..cc944f85 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= 
X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json <- saveBCFModelToJsonString(bcf_model) bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index d9000925..d7a1adae 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -30,10 +30,10 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_features=p, - num_observations=n, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, leaf_model_type=1) global_model_config <- 
createGlobalModelConfig(global_error_variance=1.0) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 0da1eb98..149586a8 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -24,9 +24,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) snr <- 3 @@ -51,11 +51,11 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) rfx_samples <- getRandomEffectSamples(bart_model) } diff --git a/man/getRandomEffectSamples.bcfmodel.Rd b/man/getRandomEffectSamples.bcfmodel.Rd index 6769de62..08a8eae4 100644 --- a/man/getRandomEffectSamples.bcfmodel.Rd +++ b/man/getRandomEffectSamples.bcfmodel.Rd @@ -24,21 +24,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 
<= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -74,15 +74,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) rfx_samples <- getRandomEffectSamples(bcf_model) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 2afccbf6..8a0a47bf 100644 --- a/man/predict.bartmodel.Rd 
+++ b/man/predict.bartmodel.Rd @@ -40,9 +40,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -56,7 +56,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) y_hat_test <- predict(bart_model, X_test)$y_hat } diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index ff315808..907e5308 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -42,21 +42,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ 
-79,8 +79,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) preds <- predict(bcf_model, X_test, Z_test, pi_test) } diff --git a/man/preprocessPredictionData.Rd b/man/preprocessPredictionData.Rd index f881fda8..a6382e69 100644 --- a/man/preprocessPredictionData.Rd +++ b/man/preprocessPredictionData.Rd @@ -22,7 +22,7 @@ types. Matrices will be passed through assuming all columns are numeric. } \examples{ cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) X_preprocessed <- preprocessPredictionData(cov_df, metadata) } diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index f0fec6ca..b02158d4 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -48,23 +48,23 @@ y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_observations=n, - num_features=p, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, - max_depth=max_depth, - variable_weights=variable_weights, - cutpoint_grid_size=cutpoint_grid_size, - leaf_model_type=leaf_model, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_observations=n, + num_features=p, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + variable_weights=variable_weights, + 
cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -forest_samples <- createForestSamples(num_trees, leaf_dimension, +forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, global_model_config, + forest_dataset, outcome, forest_samples, active_forest, + rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) diff --git a/man/resetRandomEffectsModel.Rd b/man/resetRandomEffectsModel.Rd index fec99b77..b032ccc2 100644 --- a/man/resetRandomEffectsModel.Rd +++ b/man/resetRandomEffectsModel.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/resetRandomEffectsTracker.Rd b/man/resetRandomEffectsTracker.Rd index 5249ca96..c57af16a 100644 --- a/man/resetRandomEffectsTracker.Rd +++ b/man/resetRandomEffectsTracker.Rd @@ -57,8 +57,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - 
rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/rootResetRandomEffectsModel.Rd b/man/rootResetRandomEffectsModel.Rd index c58a09e9..4c3cc2f7 100644 --- a/man/rootResetRandomEffectsModel.Rd +++ b/man/rootResetRandomEffectsModel.Rd @@ -63,8 +63,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/rootResetRandomEffectsTracker.Rd b/man/rootResetRandomEffectsTracker.Rd index 8de2c514..6f2dc843 100644 --- a/man/rootResetRandomEffectsTracker.Rd +++ b/man/rootResetRandomEffectsTracker.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/saveBARTModelToJson.Rd b/man/saveBARTModelToJson.Rd index a617532e..054af24e 
100644 --- a/man/saveBARTModelToJson.Rd +++ b/man/saveBARTModelToJson.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) } diff --git a/man/saveBARTModelToJsonFile.Rd b/man/saveBARTModelToJsonFile.Rd index 46a3110e..62ef6ad7 100644 --- a/man/saveBARTModelToJsonFile.Rd +++ b/man/saveBARTModelToJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/saveBARTModelToJsonString.Rd b/man/saveBARTModelToJsonString.Rd index c83f9e5d..10927c20 100644 --- a/man/saveBARTModelToJsonString.Rd +++ b/man/saveBARTModelToJsonString.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X 
<- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string <- saveBARTModelToJsonString(bart_model) } diff --git a/man/saveBCFModelToJson.Rd b/man/saveBCFModelToJson.Rd index ae2c286d..2c04d76c 100644 --- a/man/saveBCFModelToJson.Rd +++ b/man/saveBCFModelToJson.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- 
rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) } diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index e6a9f0aa..584bbbba 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= 
X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index 4328e525..2182bbe3 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= 
X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) saveBCFModelToJsonString(bcf_model) } diff --git a/src/Makevars.in b/src/Makevars.in index 4eb970cb..850e2555 100644 --- a/src/Makevars.in +++ b/src/Makevars.in @@ -34,6 +34,7 @@ OBJECTS = \ data.o \ io.o \ leaf_model.o \ + ordinal_sampler.o \ partition_tracker.o \ random_effects.o \ tree.o diff --git a/src/Makevars.win.in 
b/src/Makevars.win.in index 95bff1dd..e9d54ab6 100644 --- a/src/Makevars.win.in +++ b/src/Makevars.win.in @@ -34,6 +34,7 @@ OBJECTS = \ data.o \ io.o \ leaf_model.o \ + ordinal_sampler.o \ partition_tracker.o \ random_effects.o \ tree.o diff --git a/src/R_data.cpp b/src/R_data.cpp index 39b77ab3..1396575f 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -5,6 +5,7 @@ #include #include #include +#include [[cpp11::register]] cpp11::external_pointer create_forest_dataset_cpp() { diff --git a/src/R_random_effects.cpp b/src/R_random_effects.cpp index e291121c..f627b3c5 100644 --- a/src/R_random_effects.cpp +++ b/src/R_random_effects.cpp @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include [[cpp11::register]] cpp11::external_pointer rfx_container_cpp(int num_components, int num_groups) { diff --git a/src/cpp11.cpp b/src/cpp11.cpp index ef98aac0..881c5314 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -1157,18 +1157,18 @@ extern "C" SEXP _stochtree_compute_leaf_indices_cpp(SEXP forest_container, SEXP END_CPP11 } // sampler.cpp -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_features_subsample, int num_threads); -extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP 
leaf_model_int, SEXP keep_forest, SEXP num_features_subsample, SEXP num_threads) { +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_features_subsample); +extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_features_subsample) { BEGIN_CPP11 - sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_features_subsample), cpp11::as_cpp>(num_threads)); + sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), 
cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_features_subsample)); return R_NilValue; END_CPP11 } // sampler.cpp -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_threads); -extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_threads) { +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, 
SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_threads)); + sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } @@ -1281,6 +1281,83 @@ extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP populatio return cpp11::as_sexp(sample_without_replacement_integer_cpp(cpp11::as_cpp>(population_vector), cpp11::as_cpp>(sampling_probs), cpp11::as_cpp>(sample_size))); END_CPP11 } +// sampler.cpp +void ordinal_aux_data_initialize_cpp(cpp11::external_pointer tracker_ptr, StochTree::data_size_t num_observations, int n_levels); +extern "C" SEXP _stochtree_ordinal_aux_data_initialize_cpp(SEXP tracker_ptr, SEXP num_observations, SEXP n_levels) { + BEGIN_CPP11 + ordinal_aux_data_initialize_cpp(cpp11::as_cpp>>(tracker_ptr), 
cpp11::as_cpp>(num_observations), cpp11::as_cpp>(n_levels)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +double ordinal_aux_data_get_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx); +extern "C" SEXP _stochtree_ordinal_aux_data_get_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP obs_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(ordinal_aux_data_get_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(obs_idx))); + END_CPP11 +} +// sampler.cpp +void ordinal_aux_data_set_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx, double value); +extern "C" SEXP _stochtree_ordinal_aux_data_set_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP obs_idx, SEXP value) { + BEGIN_CPP11 + ordinal_aux_data_set_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(obs_idx), cpp11::as_cpp>(value)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +cpp11::writable::doubles ordinal_aux_data_get_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx); +extern "C" SEXP _stochtree_ordinal_aux_data_get_vector_cpp(SEXP tracker_ptr, SEXP type_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(ordinal_aux_data_get_vector_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx))); + END_CPP11 +} +// sampler.cpp +void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx, cpp11::doubles values); +extern "C" SEXP _stochtree_ordinal_aux_data_set_vector_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP values) { + BEGIN_CPP11 + ordinal_aux_data_set_vector_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(values)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_aux_data_update_cumsum_exp_cpp(cpp11::external_pointer tracker_ptr); +extern "C" SEXP _stochtree_ordinal_aux_data_update_cumsum_exp_cpp(SEXP tracker_ptr) { + BEGIN_CPP11 + ordinal_aux_data_update_cumsum_exp_cpp(cpp11::as_cpp>>(tracker_ptr)); + return R_NilValue; + 
END_CPP11 +} +// sampler.cpp +cpp11::external_pointer ordinal_sampler_cpp(); +extern "C" SEXP _stochtree_ordinal_sampler_cpp() { + BEGIN_CPP11 + return cpp11::as_sexp(ordinal_sampler_cpp()); + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_latent_variables_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, cpp11::external_pointer rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_latent_variables_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP tracker_ptr, SEXP rng_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_latent_variables_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>>(rng_ptr)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_gamma_params_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, double alpha_gamma, double beta_gamma, double gamma_0, cpp11::external_pointer rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_gamma_params_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP tracker_ptr, SEXP alpha_gamma, SEXP beta_gamma, SEXP gamma_0, SEXP rng_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_gamma_params_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(alpha_gamma), cpp11::as_cpp>(beta_gamma), cpp11::as_cpp>(gamma_0), cpp11::as_cpp>>(rng_ptr)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_cumsum_exp_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer tracker_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_cumsum_exp_cpp(SEXP sampler_ptr, SEXP tracker_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_cumsum_exp_cpp(cpp11::as_cpp>>(sampler_ptr), 
cpp11::as_cpp>>(tracker_ptr)); + return R_NilValue; + END_CPP11 +} // serialization.cpp cpp11::external_pointer init_json_cpp(); extern "C" SEXP _stochtree_init_json_cpp() { @@ -1711,6 +1788,16 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, + {"_stochtree_ordinal_aux_data_get_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_get_cpp, 3}, + {"_stochtree_ordinal_aux_data_get_vector_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_get_vector_cpp, 2}, + {"_stochtree_ordinal_aux_data_initialize_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_initialize_cpp, 3}, + {"_stochtree_ordinal_aux_data_set_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_set_cpp, 4}, + {"_stochtree_ordinal_aux_data_set_vector_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_set_vector_cpp, 3}, + {"_stochtree_ordinal_aux_data_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_update_cumsum_exp_cpp, 1}, + {"_stochtree_ordinal_sampler_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_cpp, 0}, + {"_stochtree_ordinal_sampler_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_cumsum_exp_cpp, 2}, + {"_stochtree_ordinal_sampler_update_gamma_params_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_gamma_params_cpp, 8}, + {"_stochtree_ordinal_sampler_update_latent_variables_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_latent_variables_cpp, 5}, {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, @@ -1776,8 +1863,8 @@ static const 
R_CallMethodDef CallEntries[] = { {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 19}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 18}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 18}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3}, diff --git a/src/cutpoint_candidates.cpp b/src/cutpoint_candidates.cpp index e43b8219..4a0845c7 100644 --- a/src/cutpoint_candidates.cpp +++ b/src/cutpoint_candidates.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace StochTree { diff --git a/src/data.cpp b/src/data.cpp index e48e9255..cd2913cf 100644 --- a/src/data.cpp +++ b/src/data.cpp @@ -1,6 +1,7 @@ /*! 
Copyright (c) 2024 by stochtree authors */ #include #include +#include namespace StochTree { diff --git a/src/forest.cpp b/src/forest.cpp index 968fe95c..02757aa7 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include [[cpp11::register]] cpp11::external_pointer active_forest_cpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { diff --git a/src/io.cpp b/src/io.cpp index 50774d9b..1324957f 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -7,7 +7,9 @@ #include #include +#include #include +#include namespace StochTree { diff --git a/src/kernel.cpp b/src/kernel.cpp index 88f12c53..6b5867bb 100644 --- a/src/kernel.cpp +++ b/src/kernel.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include typedef Eigen::Map> DoubleMatrixType; typedef Eigen::Map> IntMatrixType; diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index 78d8da76..3f39fba5 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include namespace StochTree { diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp index 27b11f63..19a7c6b5 100644 --- a/src/ordinal_sampler.cpp +++ b/src/ordinal_sampler.cpp @@ -16,10 +16,10 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // gamma cutpoints const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) - + int K = gamma.size() + 1; // Number of ordinal categories - int N = dataset.NumObservations(); - + int N = dataset.NumObservations(); + // Update truncated exponentials (stored in latent auxiliary data slot 0) // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1) // where y_i is the ordinal outcome 
for observation i: make sure y_i converted to {0, 1, ..., K-1} @@ -27,7 +27,7 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector // If y_i = K-1 (last category), then we set z_i = 1.0 deterministically just for bookkeeping, we don't need it // We only need to sample latent z_i for y_i < K-1 (as z_i is only used in the likelihood for y_i < K-1) for (int i = 0; i < N; i++) { - int y = static_cast(outcome(i)); + int y = static_cast(outcome(i)); if (y == K - 1) { Z[i] = 1.0; } else { @@ -44,14 +44,14 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's const std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables z_i's const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) - + int K = gamma.size() + 1; // Number of ordinal categories int N = dataset.NumObservations(); // Compute sufficient statistics A[k] and B[k] for gamma[k] update std::vector A(K - 1, 0.0); std::vector B(K - 1, 0.0); - + for (int i = 0; i < N; i++) { int y = static_cast(outcome(i)); if (y < K - 1) { @@ -62,16 +62,16 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& B[k] += std::exp(lambda_hat[i]); } } - - // Update gamma parameters using log-gamma sampling + + // Update gamma parameters using log-gamma sampling // First sample all gamma parameters - for (int k = 0; k < static_cast(gamma.size()); k++) { + for (int k = 0; k < static_cast(gamma.size()); k++) { double shape = A[k] + alpha_gamma; - double rate = B[k] + beta_gamma; + double rate = B[k] + beta_gamma; double gamma_sample = gamma_sampler_.Sample(shape, rate, gen); gamma[k] = std::log(gamma_sample); } - + // Set the first gamma parameter to gamma_0 (e.g., 0) for identifiability gamma[0] = gamma_0; } @@ -80,7 +80,7 @@ void OrdinalSampler::UpdateCumulativeExpSums(ForestTracker& 
tracker) { // Get auxiliary data vectors const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's std::vector& seg = tracker.GetOrdinalAuxDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j) - + // Update seg (sum of exponentials of gamma cutpoints) for (int j = 0; j < static_cast(seg.size()); j++) { if (j == 0) { diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index 8359faed..bb35efd7 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -6,6 +6,12 @@ #include #include +#include +#include +#include +#include +#include + namespace StochTree { ForestTracker::ForestTracker(Eigen::MatrixXd& covariates, std::vector& feature_types, int num_trees, int num_observations) { @@ -28,15 +34,15 @@ void ForestTracker::ReconstituteFromForest(TreeEnsemble& forest, ForestDataset& // (1) Updates the residual by adding currently cached tree predictions and subtracting predictions from new tree // (2) Updates sample_node_mapper_, sample_pred_mapper_, and sum_predictions_ based on the new forest UpdateSampleTrackersResidual(forest, dataset, residual, is_mean_model); - + // Since GFR always starts over from root, this data structure can always simply be reset Eigen::MatrixXd& covariates = dataset.GetCovariates(); sorted_node_sample_tracker_.reset(new SortedNodeSampleTracker(presort_container_.get(), covariates, feature_types_)); - + // Reconstitute each of the remaining data structures in the tracker based on splits in the ensemble // UnsortedNodeSampleTracker unsorted_node_sample_tracker_->ReconstituteFromForest(forest, dataset); - + } void ForestTracker::ResetRoot(Eigen::MatrixXd& covariates, std::vector& feature_types, int32_t tree_num) { @@ -156,7 +162,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& fore for (int j = 0; j < num_trees_; j++) { // Query the previously cached prediction for tree j, observation i prev_tree_pred = sample_pred_mapper_->GetPred(i, j); - + // Compute the 
new prediction for tree j, observation i new_tree_pred = 0.0; Tree* tree = forest.GetTree(j); @@ -164,7 +170,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& fore for (int32_t k = 0; k < output_dim; k++) { new_tree_pred += tree->LeafValue(nidx, k) * basis(i, k); } - + if (is_mean_model) { // Adjust the residual by adding the previous prediction and subtracting the new prediction new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred; @@ -202,7 +208,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& fo Tree* tree = forest.GetTree(j); std::int32_t nidx = EvaluateTree(*tree, covariates, i); new_tree_pred = tree->LeafValue(nidx, 0); - + if (is_mean_model) { // Adjust the residual by adding the previous prediction and subtracting the new prediction new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred; @@ -211,7 +217,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& fo new_weight = std::log(dataset.VarWeightValue(i)) + new_tree_pred - prev_tree_pred; dataset.SetVarWeightValue(i, new_weight, true); } - + // Update the sample node mapper and sample prediction mapper sample_node_mapper_->SetNodeId(i, j, nidx); sample_pred_mapper_->SetPred(i, j, new_tree_pred); @@ -280,7 +286,7 @@ void ForestTracker::AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int3 sample_node_mapper_->AddSplit(covariates, split, split_feature, tree_id, split_node_id, left_node_id, right_node_id); unsorted_node_sample_tracker_->PartitionTreeNode(covariates, tree_id, split_node_id, left_node_id, right_node_id, split_feature, split); if (keep_sorted) { - sorted_node_sample_tracker_->PartitionNode(covariates, split_node_id, split_feature, split, num_threads); + sorted_node_sample_tracker_->PartitionNode(covariates, split_node_id, split_feature, split); } } @@ -346,21 +352,21 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d 
CHECK_EQ(num_deleted_nodes_, 0); data_size_t n = dataset.NumObservations(); CHECK_EQ(indices_.size(), n); - + // Extract covariates Eigen::MatrixXd& covariates = dataset.GetCovariates(); // Set node counters num_nodes_ = tree.NumNodes(); num_deleted_nodes_ = tree.NumDeletedNodes(); - + // Resize tracking vectors node_begin_.resize(num_nodes_); node_length_.resize(num_nodes_); parent_nodes_.resize(num_nodes_); left_nodes_.resize(num_nodes_); right_nodes_.resize(num_nodes_); - + // Unpack tree's splits into this data structure bool is_deleted; TreeNodeType node_type; @@ -399,11 +405,11 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d } else { continue; } - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[i]); auto node_end = (indices_.begin() + node_begin_[i] + node_length_[i]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split_rule.SplitTrue(covariates(row, split_index)); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[i]); num_true = std::distance(node_begin, right_node_begin); @@ -415,7 +421,7 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d parent_nodes_[left_nodes_[i]] = i; left_nodes_[left_nodes_[i]] = StochTree::Tree::kInvalidNodeId; left_nodes_[right_nodes_[i]] = StochTree::Tree::kInvalidNodeId; - + // Add right node tracking information node_begin_[right_nodes_[i]] = node_start_idx + num_true; node_length_[right_nodes_[i]] = num_false; @@ -455,11 +461,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + 
node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split.SplitTrue(covariates(row, feature_split)); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -474,11 +480,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_split, split_value); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -493,11 +499,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_split, category_list); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -536,7 +542,7 @@ void FeatureUnsortedPartition::ExpandNodeTrackingVectors(int node_id, int left_n parent_nodes_[left_node_id] = node_id; 
left_nodes_[left_node_id] = StochTree::Tree::kInvalidNodeId; left_nodes_[right_node_id] = StochTree::Tree::kInvalidNodeId; - + // Add right node tracking information right_nodes_[node_id] = right_node_id; node_begin_[right_node_id] = node_start_idx + num_left; @@ -578,7 +584,7 @@ bool FeatureUnsortedPartition::RightNodeIsLeaf(int node_id) { } void FeatureUnsortedPartition::PruneNodeToLeaf(int node_id) { - // No need to "un-sift" the indices in the newly pruned node, we don't depend on the indices + // No need to "un-sift" the indices in the newly pruned node, we don't depend on the indices // having any type of sort order, so the indices will simply be "re-sifted" if the node is later partitioned if (IsLeaf(node_id)) return; if (!LeftNodeIsLeaf(node_id)) { @@ -614,7 +620,7 @@ std::vector FeatureUnsortedPartition::NodeIndices(int node_id) { void FeaturePresortPartition::AddLeftRightNodes(data_size_t left_node_begin, data_size_t left_node_size, data_size_t right_node_begin, data_size_t right_node_size) { // Assumes that we aren't pruning / deleting nodes, since this is for use with recursive algorithms - + // Add the left ("true") node to the offset size vector node_offset_sizes_.emplace_back(left_node_begin, left_node_size); // Add the right ("false") node to the offset size vector @@ -627,11 +633,11 @@ void FeaturePresortPartition::SplitFeature(Eigen::MatrixXd& covariates, int32_t data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split.SplitTrue(covariates(row, feature_index)); }); - + // Add the left and right nodes to the offset size vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = 
std::distance(node_begin, right_node_begin); @@ -645,11 +651,11 @@ void FeaturePresortPartition::SplitFeatureNumeric(Eigen::MatrixXd& covariates, i data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_index, split_value); }); - + // Add the left and right nodes to the offset size vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -663,11 +669,11 @@ void FeaturePresortPartition::SplitFeatureCategorical(Eigen::MatrixXd& covariate data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_index, category_list); }); - + // Add the left and right nodes to the offset size vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -696,12 +702,12 @@ std::vector FeaturePresortPartition::NodeIndices(int node_id) { return out; } - // ============================================================================ // ORDINAL AUXILIARY DATA METHODS // ============================================================================ double ForestTracker::GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); return 
ordinal_aux_data_vec_[type_idx][obs_idx]; } @@ -710,10 +716,12 @@ void ForestTracker::InitializeOrdinalAuxData(data_size_t num_observations, int n } void ForestTracker::SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value) { + // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); ordinal_aux_data_vec_[type_idx][obs_idx] = value; } std::vector& ForestTracker::GetOrdinalAuxDataVector(int type_idx) { + // CHECK(IsValidOrdinalType(type_idx)); return ordinal_aux_data_vec_[type_idx]; } @@ -735,4 +743,16 @@ void ForestTracker::ResizeOrdinalAuxData(data_size_t num_observations, int n_lev } } +// bool ForestTracker::IsValidOrdinalType(int type_idx) const { +// return (type_idx >= 0 && type_idx < static_cast(ordinal_aux_data_vec_.size()) && +// !ordinal_aux_data_vec_.empty()); +// } + +// bool ForestTracker::IsValidOrdinalIndex(int type_idx, data_size_t obs_idx) const { +// if (!IsValidOrdinalType(type_idx)) { +// return false; +// } +// return (obs_idx >= 0 && obs_idx < ordinal_aux_data_vec_[type_idx].size()); +// } + } // namespace StochTree diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 950caeb8..34931fa9 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) @@ -1077,7 +1078,7 @@ class ForestSamplerCpp { void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng, py::array_t feature_types, py::array_t sweep_update_indices, int cutpoint_grid_size, py::array_t leaf_model_scale_input, py::array_t variable_weights, double a_forest, double b_forest, double global_variance, - int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true, int num_threads = -1) { + int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true) { // Refactoring completely out of the Python interface. 
// Intention to refactor out of the C++ interface in the future. bool pre_initialized = true; @@ -1139,23 +1140,23 @@ class ForestSamplerCpp { std::mt19937* rng_ptr = rng.GetRng(); if (gfr) { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, 
var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample); } } else { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == 
StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false); } } } diff --git a/src/sampler.cpp 
b/src/sampler.cpp index ee8bd6e6..255f6e7c 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -4,39 +4,41 @@ #include #include #include +#include #include #include #include -#include +#include #include +#include +#include [[cpp11::register]] -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, - cpp11::external_pointer split_prior, - cpp11::external_pointer rng, - cpp11::integers sweep_indices, - cpp11::integers feature_types, int cutpoint_grid_size, - cpp11::doubles_matrix<> leaf_model_scale_input, - cpp11::doubles variable_weights, +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, + cpp11::external_pointer split_prior, + cpp11::external_pointer rng, + cpp11::integers sweep_indices, + cpp11::integers feature_types, int cutpoint_grid_size, + cpp11::doubles_matrix<> leaf_model_scale_input, + cpp11::doubles variable_weights, double a_forest, double b_forest, - double global_variance, int leaf_model_int, - bool keep_forest, int num_features_subsample, - int num_threads + double global_variance, int leaf_model_int, + bool keep_forest, int num_features_subsample ) { // Refactoring completely out of the R interface. // Intention to refactor out of the C++ interface in the future. 
bool pre_initialized = true; - + // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Unpack sweep indices std::vector sweep_indices_(sweep_indices.size()); // if (sweep_indices.size() > 0) { @@ -45,19 +47,20 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { var_weights_vector[i] = variable_weights[i]; } - + // Prepare the samplers StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); int num_basis = data->NumBasis(); - + // Run one iteration of the sampler if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, 
feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample); } } [[cpp11::register]] -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, - cpp11::external_pointer split_prior, - cpp11::external_pointer rng, - cpp11::integers sweep_indices, - cpp11::integers feature_types, int cutpoint_grid_size, - cpp11::doubles_matrix<> leaf_model_scale_input, - cpp11::doubles variable_weights, +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, + 
cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, + cpp11::external_pointer split_prior, + cpp11::external_pointer rng, + cpp11::integers sweep_indices, + cpp11::integers feature_types, int cutpoint_grid_size, + cpp11::doubles_matrix<> leaf_model_scale_input, + cpp11::doubles variable_weights, double a_forest, double b_forest, - double global_variance, int leaf_model_int, - bool keep_forest, int num_threads + double global_variance, int leaf_model_int, + bool keep_forest ) { // Refactoring completely out of the R interface. // Intention to refactor out of the C++ interface in the future. bool pre_initialized = true; - + // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Unpack sweep indices std::vector sweep_indices_; if (sweep_indices.size() > 0) { @@ -127,19 +130,20 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { var_weights_vector[i] = variable_weights[i]; } - + // Prepare the samplers StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); int num_basis = data->NumBasis(); - + // Run one iteration of the sampler if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); } else if 
(model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false); } } [[cpp11::register]] -double 
sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, - cpp11::external_pointer dataset, - cpp11::external_pointer rng, +double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, + cpp11::external_pointer dataset, + cpp11::external_pointer rng, double a, double b ) { // Run one iteration of the sampler @@ -191,8 +197,8 @@ double sample_sigma2_one_iteration_cpp(cpp11::external_pointer active_forest, - cpp11::external_pointer rng, +double sample_tau_one_iteration_cpp(cpp11::external_pointer active_forest, + cpp11::external_pointer rng, double a, double b ) { // Run one iteration of the sampler @@ -209,7 +215,7 @@ cpp11::external_pointer rng_cpp(int random_seed = -1) { } else { rng_ = std::make_unique(random_seed); } - + // Release management of the pointer to R session return cpp11::external_pointer(rng_.release()); } @@ -218,7 +224,7 @@ cpp11::external_pointer rng_cpp(int random_seed = -1) { cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth = -1) { // Create smart pointer to newly allocated object std::unique_ptr prior_ptr_ = std::make_unique(alpha, beta, min_samples_leaf, max_depth); - + // Release management of the pointer to R session return cpp11::external_pointer(prior_ptr_.release()); } @@ -275,10 +281,10 @@ cpp11::external_pointer forest_tracker_cpp(cpp11::exte for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Create smart pointer to newly allocated object std::unique_ptr tracker_ptr_ = std::make_unique(data->GetCovariates(), feature_types_, num_trees, n); - + // Release management of the pointer to R session return cpp11::external_pointer(tracker_ptr_.release()); } @@ -295,8 +301,8 @@ cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_point [[cpp11::register]] cpp11::writable::integers sample_without_replacement_integer_cpp( - cpp11::integers population_vector, - cpp11::doubles sampling_probs, + 
cpp11::integers population_vector, + cpp11::doubles sampling_probs, int sample_size ) { // Unpack pointer to population vector @@ -308,14 +314,14 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( // Create output vector cpp11::writable::integers output(sample_size); - + // Unpack pointer to output vector int* output_ptr = INTEGER(PROTECT(output)); // Create C++ RNG std::random_device rd; std::mt19937 gen(rd()); - + // Run the sampler StochTree::sample_without_replacement( output_ptr, sampling_probs_ptr, population_vector_ptr, population_size, sample_size, gen @@ -372,8 +378,8 @@ void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr) { // Get auxiliary data vectors const std::vector& gamma = tracker_ptr->GetOrdinalAuxDataVector(2); // cutpoints gamma - std::vector& seg = tracker_ptr->GetOrdinalAuxDataVector(3); // cumsum exp gamma - + std::vector& seg = tracker_ptr->GetOrdinalAuxDataVector(3); // cumsum exp gamma + // Update seg (cumulative sum of exp(gamma)) for (size_t j = 0; j < seg.size(); j++) { if (j == 0) { @@ -397,7 +403,7 @@ cpp11::external_pointer ordinal_sampler_cpp() { [[cpp11::register]] void ordinal_sampler_update_latent_variables_cpp( cpp11::external_pointer sampler_ptr, - cpp11::external_pointer data_ptr, + cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, cpp11::external_pointer rng_ptr @@ -427,3 +433,4 @@ void ordinal_sampler_update_cumsum_exp_cpp( sampler_ptr->UpdateCumulativeExpSums(*tracker_ptr); } + diff --git a/src/serialization.cpp b/src/serialization.cpp index fb248f62..749395e8 100644 --- a/src/serialization.cpp +++ b/src/serialization.cpp @@ -8,6 +8,9 @@ #include #include #include +#include +#include +#include [[cpp11::register]] cpp11::external_pointer init_json_cpp() { diff --git a/src/stochtree_types.h b/src/stochtree_types.h index d3d6327c..9f4e77df 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -1,8 +1,10 @@ #include 
#include #include +#include #include #include +#include #include #include #include diff --git a/src/tree.cpp b/src/tree.cpp index 32c51475..fa6fd8f8 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -8,6 +8,9 @@ #include #include +#include +#include +#include namespace StochTree { @@ -665,6 +668,7 @@ void Tree::from_json(const json& tree_json) { tree_json.at("has_categorical_split").get_to(this->has_categorical_split_); tree_json.at("output_dimension").get_to(this->output_dimension_); tree_json.at("is_log_scale").get_to(this->is_log_scale_); + this->num_deleted_nodes = 0; // Unpack the array based fields JsonToTreeNodeVectors(tree_json, this); From c8492fb7650f7eccbb4e79fcc7b5797c5dbece80 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Sun, 28 Sep 2025 06:24:48 -0500 Subject: [PATCH 07/34] =?UTF-8?q?Tested=20CLogLog=20Ordinal=20BART=20?= =?UTF-8?q?=E2=80=94=20running=20successfully!?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/debug/testing_cloglog_ordinal_bart.R | 111 +++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tools/debug/testing_cloglog_ordinal_bart.R diff --git a/tools/debug/testing_cloglog_ordinal_bart.R b/tools/debug/testing_cloglog_ordinal_bart.R new file mode 100644 index 00000000..78f8d61c --- /dev/null +++ b/tools/debug/testing_cloglog_ordinal_bart.R @@ -0,0 +1,111 @@ +# Simulate ordinal data and run Cloglog Ordinal BART + +# Load +library(stochtree) + +set.seed(2025) + +# Simulation +n_samples <- 2000 +n_features <- 5 +n_categories <- 3 + +X <- matrix(rnorm(n_samples * n_features), n_samples, n_features) + +beta <- rep(1 / sqrt(n_features), n_features) +gamma_true <- c(-2, 1) + +linear_predictor <- X %*% beta + +# Transform linear predictor using the complementary log-log link function +p_0 <- 1 - exp(-exp(gamma_true[1] + linear_predictor)) +p_1 <- exp(-exp(gamma_true[1] + linear_predictor)) * + (1 - exp(-exp(gamma_true[2] + linear_predictor))) +p_2 <- 
exp(-exp(gamma_true[1] + linear_predictor)) * + exp(-exp(gamma_true[2] + linear_predictor)) + +true_probs <- cbind(p_0, p_1, p_2) + +# Get Outcomes +ordinal_outcome <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(ordinal_outcome), "\n") + +train_index <- 1:(n_samples/2) +test_index <- (1:n_samples)[- train_index] + +X_train <- X[train_index, ] +y_train <- ordinal_outcome[train_index] +X_test <- X[-train_index, ] +y_test <- ordinal_outcome[-train_index] + +out <- cloglog_ordinal_bart( + X = X_train, + y = y_train, + X_test = X_test, + n_samples_mcmc = 1000, + n_burnin = 500, + n_thin = 1 +) + + +# Inference and diagnostics +par(mfrow = c(2, 1)) +plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[1], col = 'red', lty = 2) +plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[2], col = 'red', lty = 2) + +gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) +summary(gamma1) +hist(gamma1) + +gamma2 <- out$gamma_samples[2,] + colMeans(out$forest_predictions_train) +summary(gamma2) +hist(gamma2) + +par(mfrow = c(3,2), mar = c(5,4,1,1)) +rowMeans(out$gamma_samples) +moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) +plot(moo[,1]) +abline(h = gamma_true[1] + mean(linear_predictor[train_index])) +plot(moo[,2]) +abline(h = gamma_true[2] + mean(linear_predictor[train_index])) +plot(out$gamma_samples[1,]) +plot(out$gamma_samples[2,]) + +# Compare forest predictions with the truth + +plot(rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train), linear_predictor[train_index]) +abline(a=0,b=1,col='blue', lwd=2) + +plot(rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test), linear_predictor[test_index]) +abline(a=0,b=1,col='blue', lwd=2) + +# Train set ordinal class 
probabilities + +p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[1, ]))) +p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_train + out$gamma_samples[1,]))) +p_hat_2 <- 1 - p_hat_1 - p_hat_0 + +mean(log(-log(1 - p_hat_0)) - rowMeans(out$forest_predictions_train)) + +plot(p_hat_0, p_0[train_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_1, p_1[train_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_2, p_2[train_index]) +abline(a=0,b=1,col='blue', lwd=2) + +# Test set ordinal class probabilities + +p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[1, ]))) +p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_test + out$gamma_samples[1,]))) +p_hat_2 <- 1 - p_hat_1 - p_hat_0 + +plot(p_hat_0, p_0[test_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_1, p_1[test_index]) +abline(a=0,b=1,col='blue', lwd=2) +plot(p_hat_2, p_2[test_index]) +abline(a=0,b=1,col='blue', lwd=2) + From 444c0674e8d8d538e044ecddcc7deea4570c3d55 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Sun, 28 Sep 2025 19:37:54 -0500 Subject: [PATCH 08/34] Added vignette for CLogLog Ordinal Bart --- NAMESPACE | 1 + R/cloglog_ordinal_bart.R | 1 + tools/debug/testing_cloglog_ordinal_bart.R | 159 +++++++++++-------- vignettes/CLogLogOrdinalBart.Rmd | 173 +++++++++++++++++++++ vignettes/vignettes.bib | 9 +- 5 files changed, 278 insertions(+), 65 deletions(-) create mode 100644 vignettes/CLogLogOrdinalBart.Rmd diff --git a/NAMESPACE b/NAMESPACE index 2f4103c0..a4062f5e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,6 +7,7 @@ S3method(predict,bcfmodel) export(bart) export(bcf) export(calibrateInverseGammaErrorVariance) +export(cloglog_ordinal_bart) export(computeForestLeafIndices) export(computeForestLeafVariances) export(computeForestMaxLeafIndex) diff --git a/R/cloglog_ordinal_bart.R 
b/R/cloglog_ordinal_bart.R index 9cc9b63a..a8117c77 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -11,6 +11,7 @@ #' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`. #' @param variable_weights Optional vector of variable weights for splitting (default: equal weights). #' @param feature_types Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). +#' @export cloglog_ordinal_bart <- function(X, y, X_test = NULL, diff --git a/tools/debug/testing_cloglog_ordinal_bart.R b/tools/debug/testing_cloglog_ordinal_bart.R index 78f8d61c..71ef790a 100644 --- a/tools/debug/testing_cloglog_ordinal_bart.R +++ b/tools/debug/testing_cloglog_ordinal_bart.R @@ -5,38 +5,47 @@ library(stochtree) set.seed(2025) -# Simulation -n_samples <- 2000 -n_features <- 5 -n_categories <- 3 - -X <- matrix(rnorm(n_samples * n_features), n_samples, n_features) - -beta <- rep(1 / sqrt(n_features), n_features) -gamma_true <- c(-2, 1) - -linear_predictor <- X %*% beta - -# Transform linear predictor using the complementary log-log link function -p_0 <- 1 - exp(-exp(gamma_true[1] + linear_predictor)) -p_1 <- exp(-exp(gamma_true[1] + linear_predictor)) * - (1 - exp(-exp(gamma_true[2] + linear_predictor))) -p_2 <- exp(-exp(gamma_true[1] + linear_predictor)) * - exp(-exp(gamma_true[2] + linear_predictor)) +# Sample size and number of predictors +n <- 2000 +p <- 5 -true_probs <- cbind(p_0, p_1, p_2) +# Design matrix and true lambda function +X <- matrix(rnorm(n * p), n, p) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta -# Get Outcomes -ordinal_outcome <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) -cat("Outcome distribution:", table(ordinal_outcome), "\n") -train_index <- 1:(n_samples/2) -test_index <- (1:n_samples)[- train_index] - -X_train <- X[train_index, ] -y_train <- ordinal_outcome[train_index] -X_test <- X[-train_index, ] -y_test <- 
ordinal_outcome[-train_index] +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories <- 3 +gamma_true <- c(-2, 1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) +ordinal_cutpoints + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(y), "\n") + +# CLogLog Ordinal BART model fitting +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, train_idx) + +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] out <- cloglog_ordinal_bart( X = X_train, @@ -67,45 +76,67 @@ par(mfrow = c(3,2), mar = c(5,4,1,1)) rowMeans(out$gamma_samples) moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) plot(moo[,1]) -abline(h = gamma_true[1] + mean(linear_predictor[train_index])) +abline(h = gamma_true[1] + mean(true_lambda_function[train_idx])) plot(moo[,2]) -abline(h = gamma_true[2] + mean(linear_predictor[train_index])) +abline(h = gamma_true[2] + mean(true_lambda_function[train_idx])) plot(out$gamma_samples[1,]) plot(out$gamma_samples[2,]) -# Compare forest predictions with the truth - -plot(rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train), linear_predictor[train_index]) -abline(a=0,b=1,col='blue', lwd=2) - -plot(rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test), linear_predictor[test_index]) +# Compare forest predictions with the truth function (for training and test 
sets) +lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) +plot(lambda_pred_train, true_lambda_function[train_idx]) abline(a=0,b=1,col='blue', lwd=2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') -# Train set ordinal class probabilities - -p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[1, ]))) -p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_train + out$gamma_samples[1,]))) -p_hat_2 <- 1 - p_hat_1 - p_hat_0 - -mean(log(-log(1 - p_hat_0)) - rowMeans(out$forest_predictions_train)) - -plot(p_hat_0, p_0[train_index]) -abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_1, p_1[train_index]) -abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_2, p_2[train_index]) -abline(a=0,b=1,col='blue', lwd=2) - -# Test set ordinal class probabilities - -p_hat_0 <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[1, ]))) -p_hat_1 <- rowMeans((1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[2,]))) * exp(-exp(out$forest_predictions_test + out$gamma_samples[1,]))) -p_hat_2 <- 1 - p_hat_1 - p_hat_0 - -plot(p_hat_0, p_0[test_index]) +lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) +plot(lambda_pred_test, true_lambda_function[test_idx]) abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_1, p_1[test_index]) -abline(a=0,b=1,col='blue', lwd=2) -plot(p_hat_2, p_2[test_index]) -abline(a=0,b=1,col='blue', lwd=2) - +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') + +# Estimated ordinal class probabilities for the training set +est_probs_train <- 
matrix(0, nrow=length(train_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) + +# Compare estimated vs true class probabilities for training set +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red') +} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) + +# Compare estimated vs true class probabilities for test set +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a 
= 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} diff --git a/vignettes/CLogLogOrdinalBart.Rmd b/vignettes/CLogLogOrdinalBart.Rmd new file mode 100644 index 00000000..a87b1ebb --- /dev/null +++ b/vignettes/CLogLogOrdinalBart.Rmd @@ -0,0 +1,173 @@ +--- +title: "Complementary Log-Log Ordinal BART in StochTree" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{CLogLog-Ordinal-BART} + %\VignetteEncoding{UTF-8} + %\VignetteEngine{knitr::rmarkdown} +bibliography: vignettes.bib +editor_options: + markdown: + wrap: 72 +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +This vignette demonstrates how to use the `cloglog_ordinal_bart()` function for modeling ordinal outcomes using a complementary log-log link function in the BART (Bayesian Additive Regression Trees) framework. + +To begin, we load the `stochtree` package. + +```{r setup} +library(stochtree) +``` + +# Introduction to Ordinal BART with CLogLog Link + +Ordinal data represents outcomes that have a natural ordering but undefined distances between categories. Examples include survey responses (strongly disagree, disagree, neutral, agree, strongly agree), severity ratings (mild, moderate, severe), or educational levels (elementary, high school, college, graduate). + +The complementary log-log (CLogLog) model uses the link function: +$$\text{cloglog}(p) = \log(-\log(1-p))$$ + +This link function is asymmetric and particularly appropriate when the probability of being in higher categories changes rapidly at certain thresholds, making it different from the symmetric probit or logit links commonly used in ordinal regression. 
+ +In the BART framework with CLogLog ordinal regression, we model: +$$P(Y = k \mid Y \geq k, X = x) = 1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right)$$ + +where $\lambda(x)$ is learned by the BART ensemble and $c_k = \log \sum_{j \leq k}e^{\gamma_j}$ are the cutpoints for the ordinal categories. + +## Data Simulation + +```{r demo1-simulation} +set.seed(2025) +# Sample size and number of predictors +n <- 2000 +p <- 5 + +# Design matrix and true lambda function +X <- matrix(rnorm(n * p), n, p) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta + + +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories <- 3 +gamma_true <- c(-2, 1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) +ordinal_cutpoints + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(y), "\n") +``` + +## Model Fitting + +Now let's fit the CLogLog Ordinal BART model: + +```{r demo1-model-fitting} +# Split data into train and test sets +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, train_idx) + +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] + +# Fit CLogLog Ordinal BART model +out <- stochtree::cloglog_ordinal_bart( + X = X_train, + y = y_train, + X_test = X_test, + n_samples_mcmc = 1000, + n_burnin = 500, + n_thin = 1 +) +``` + +## Model Results and Interpretation + +Let's examine the posterior samples and model performance: + 
+```{r demo1-results} +# Compare forest predictions with the truth function (for training and test sets) +lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) +plot(lambda_pred_train, true_lambda_function[train_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') + +lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) +plot(lambda_pred_test, true_lambda_function[test_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') + +# Estimated ordinal class probabilities for the training set +est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} + +# Compare estimated vs true class probabilities for training set +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red') 
+} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) + } +} + +# Compare estimated vs true class probabilities for test set +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} +``` + +# Conclusion + +The CLogLog Ordinal BART model in `stochtree` provides a flexible and powerful approach for modeling ordinal outcomes, especially better suited for asymmetric outcomes: Rare events (e.g., credit default, fraud detection, system failures, adverse drug reactions), Toxic thresholds (e.g., credit risk escalation, dose-response toxicity, engagement drop-offs), Discrete survival outcomes (e.g., time-to-default, customer churn, progression-free survival). + +The CLogLog Ordinal BART implementation in `stochtree` builds on the paper by @alam2025unified. 
+ +# References diff --git a/vignettes/vignettes.bib b/vignettes/vignettes.bib index a1b0a768..65a6f152 100644 --- a/vignettes/vignettes.bib +++ b/vignettes/vignettes.bib @@ -117,4 +117,11 @@ @book{scholkopf2002learning author={Sch{\"o}lkopf, Bernhard and Smola, Alexander J}, year={2002}, publisher={MIT press} -} \ No newline at end of file +} + +@article{alam2025unified, + title={A Unified Bayesian Nonparametric Framework for Ordinal, Survival, and Density Regression Using the Complementary Log-Log Link}, + author={Alam, Entejar and Linero, Antonio R}, + journal={arXiv preprint arXiv:2502.00606}, + year={2025} +} From 132071ebefe99cb657c3a9a95cc10d5609c0fd19 Mon Sep 17 00:00:00 2001 From: Entejar Alam Date: Tue, 30 Sep 2025 15:11:46 -0500 Subject: [PATCH 09/34] Update leaf_model.h --- include/stochtree/leaf_model.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 6adf9c23..02bf4d16 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -403,7 +403,7 @@ class GaussianConstantSuffStat { sum_w = 0.0; sum_yw = 0.0; } - /*! + /*! 
* \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics From 18c9e158455f3aedf015782bd1f161bfcbebe948 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 11:46:33 -0400 Subject: [PATCH 10/34] Migrated auxiliary data to ForestDataset from ForestTracker --- DESCRIPTION | 2 +- R/cloglog_ordinal_bart.R | 47 +- R/cpp11.R | 64 +-- R/data.R | 63 +++ R/model.R | 1 - include/stochtree/data.h | 47 ++ include/stochtree/leaf_model.h | 10 +- include/stochtree/ordinal_sampler.h | 19 +- include/stochtree/partition_tracker.h | 29 -- include/stochtree/tree_sampler.h | 4 +- man/ForestDataset.Rd | 167 ++++++ src/R_data.cpp | 46 ++ src/cpp11.cpp | 563 +++++++++++---------- src/ordinal_sampler.cpp | 27 +- src/partition_tracker.cpp | 53 -- src/sampler.cpp | 65 +-- tools/debug/testing_cloglog_ordinal_bart.R | 4 + 17 files changed, 713 insertions(+), 498 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 15e4c0c0..1e1b9806 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Description: Flexible stochastic tree ensemble software. 
License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 LinkingTo: cpp11, BH Suggests: diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index a8117c77..a6687de5 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -98,19 +98,40 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, as.integer(n_samples) ) stochtree:::ordinal_aux_data_initialize_cpp(forest_tracker, as.integer(n_samples), as.integer(n_levels)) - - # Initialize gamma parameters to zero (slot 2) + # Latent variable (Z in Alam et al (2025) notation) + dataX$add_auxiliary_dimension(nrow(X)) + # Forest predictions (eta in Alam et al (2025) notation) + dataX$add_auxiliary_dimension(nrow(X)) + # Log-scale non-cumulative cutpoint (gamma in Alam et al (2025) notation) + dataX$add_auxiliary_dimension(n_levels - 1) + # Exponentiated cumulative cutpoints (exp(c_k) in Alam et al (2025) notation) + # This auxiliary series is designed so that the element stored at position `i` + # corresponds to the sum of all exponentiated gamma_j values for j < i. + # It has n_levels elements instead of n_levels - 1 because even the largest + # categorical index has a valid value of sum_{j < i} exp(gamma_j) + dataX$add_auxiliary_dimension(n_levels) + + # Initialize gamma parameters to zero (3rd auxiliary data series, mapped to `dim_idx = 2` with 0-indexing) initial_gamma <- rep(0.0, n_levels - 1) for (i in seq_along(initial_gamma)) { - stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 2, i - 1, initial_gamma[i]) + dataX$set_auxiliary_data_value(2, i - 1, initial_gamma[i]) + } + + # Convert the log-scale parameters into cumulative exponentiated parameters. + # This is done under the hood in a C++ function for efficiency. 
+ dataX$update_auxiliary_data_vector_cumulative_exp_sum(2, 3) + + # Initialize forest predictions to zero (slot 1) + for (i in 1:n_samples) { + dataX$set_auxiliary_data_value(1, i - 1, 0.0) } - stochtree:::ordinal_aux_data_update_cumsum_exp_cpp(forest_tracker) - # Initialize forest predictions slot to zero (slot 1) + # Initialize latent variables to zero (slot 0) for (i in 1:n_samples) { - stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, i - 1, 0.0) + dataX$set_auxiliary_data_value(0, i - 1, 0.0) } + # Initialize samplers ordinal_sampler <- stochtree:::ordinal_sampler_cpp() rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed) @@ -135,32 +156,32 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions # This is needed for updating gamma parameters, latent z_i's forest_pred_current <- active_forest$predict(dataX) - for (j in 1:n_samples) { - stochtree:::ordinal_aux_data_set_cpp(forest_tracker, 1, j - 1, forest_pred_current[j]) + for (i in 1:n_samples) { + dataX$set_auxiliary_data_value(1, i - 1, forest_pred_current[i]); } # 2. Sample latent z_i's using truncated exponential stochtree:::ordinal_sampler_update_latent_variables_cpp( - ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker, rng$rng_ptr + ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, rng$rng_ptr ) # 3. Sample gamma parameters stochtree:::ordinal_sampler_update_gamma_params_cpp( - ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, forest_tracker, + ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, alpha_gamma, beta_gamma, gamma_0, rng$rng_ptr ) # 4. 
Update cumulative sum of exp(gamma) values - stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, forest_tracker) + dataX$update_auxiliary_data_vector_cumulative_exp_sum(2,3) if (keep_sample) { forest_pred_train[, sample_counter] <- active_forest$predict(dataX) if (has_test) { forest_pred_test[, sample_counter] <- active_forest$predict(dataXtest) } - gamma_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 2) + gamma_current <- dataX$get_auxiliary_data_vector(2) gamma_samples[, sample_counter] <- gamma_current - latent_current <- stochtree:::ordinal_aux_data_get_vector_cpp(forest_tracker, 0) + latent_current <- dataX$get_auxiliary_data_vector(0) latent_samples[, sample_counter] <- latent_current } } diff --git a/R/cpp11.R b/R/cpp11.R index 64db4be1..995dc82a 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -56,6 +56,34 @@ forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) { .Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr) } +forest_dataset_has_auxiliary_dimension_cpp <- function(dataset_ptr, dim_idx) { + .Call(`_stochtree_forest_dataset_has_auxiliary_dimension_cpp`, dataset_ptr, dim_idx) +} + +forest_dataset_add_auxiliary_dimension_cpp <- function(dataset_ptr, dim_size) { + invisible(.Call(`_stochtree_forest_dataset_add_auxiliary_dimension_cpp`, dataset_ptr, dim_size)) +} + +forest_dataset_get_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx) { + .Call(`_stochtree_forest_dataset_get_auxiliary_data_value_cpp`, dataset_ptr, dim_idx, element_idx) +} + +forest_dataset_set_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx, value) { + invisible(.Call(`_stochtree_forest_dataset_set_auxiliary_data_value_cpp`, dataset_ptr, dim_idx, element_idx, value)) +} + +forest_dataset_get_auxiliary_data_vector_cpp <- function(dataset_ptr, dim_idx) { + .Call(`_stochtree_forest_dataset_get_auxiliary_data_vector_cpp`, dataset_ptr, dim_idx) +} + 
+forest_dataset_store_auxiliary_data_vector_as_column_cpp <- function(dataset_ptr, output_matrix, dim_idx, matrix_col_idx) { + .Call(`_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp`, dataset_ptr, output_matrix, dim_idx, matrix_col_idx) +} + +forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum <- function(dataset_ptr, reference_vector_idx, target_vector_idx) { + invisible(.Call(`_stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum`, dataset_ptr, reference_vector_idx, target_vector_idx)) +} + create_column_vector_cpp <- function(outcome) { .Call(`_stochtree_create_column_vector_cpp`, outcome) } @@ -692,44 +720,20 @@ sample_without_replacement_integer_cpp <- function(population_vector, sampling_p .Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size) } -ordinal_aux_data_initialize_cpp <- function(tracker_ptr, num_observations, n_levels) { - invisible(.Call(`_stochtree_ordinal_aux_data_initialize_cpp`, tracker_ptr, num_observations, n_levels)) -} - -ordinal_aux_data_get_cpp <- function(tracker_ptr, type_idx, obs_idx) { - .Call(`_stochtree_ordinal_aux_data_get_cpp`, tracker_ptr, type_idx, obs_idx) -} - -ordinal_aux_data_set_cpp <- function(tracker_ptr, type_idx, obs_idx, value) { - invisible(.Call(`_stochtree_ordinal_aux_data_set_cpp`, tracker_ptr, type_idx, obs_idx, value)) -} - -ordinal_aux_data_get_vector_cpp <- function(tracker_ptr, type_idx) { - .Call(`_stochtree_ordinal_aux_data_get_vector_cpp`, tracker_ptr, type_idx) -} - -ordinal_aux_data_set_vector_cpp <- function(tracker_ptr, type_idx, values) { - invisible(.Call(`_stochtree_ordinal_aux_data_set_vector_cpp`, tracker_ptr, type_idx, values)) -} - -ordinal_aux_data_update_cumsum_exp_cpp <- function(tracker_ptr) { - invisible(.Call(`_stochtree_ordinal_aux_data_update_cumsum_exp_cpp`, tracker_ptr)) -} - ordinal_sampler_cpp <- function() { .Call(`_stochtree_ordinal_sampler_cpp`) } 
-ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr) { - invisible(.Call(`_stochtree_ordinal_sampler_update_latent_variables_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, rng_ptr)) +ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_latent_variables_cpp`, sampler_ptr, data_ptr, outcome_ptr, rng_ptr)) } -ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) { - invisible(.Call(`_stochtree_ordinal_sampler_update_gamma_params_cpp`, sampler_ptr, data_ptr, outcome_ptr, tracker_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr)) +ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_gamma_params_cpp`, sampler_ptr, data_ptr, outcome_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr)) } -ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, tracker_ptr) { - invisible(.Call(`_stochtree_ordinal_sampler_update_cumsum_exp_cpp`, sampler_ptr, tracker_ptr)) +ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, data_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_cumsum_exp_cpp`, sampler_ptr, data_ptr)) } init_json_cpp <- function() { diff --git a/R/data.R b/R/data.R index 13cd714f..f946a146 100644 --- a/R/data.R +++ b/R/data.R @@ -108,6 +108,69 @@ ForestDataset <- R6::R6Class( #' @return True if variance weights are loaded, false otherwise has_variance_weights = function() { return(dataset_has_variance_weights_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has auxiliary data stored at the dimension indicated + #' @param dim_idx Dimension of auxiliary data + #' @return True if auxiliary data has been allocated for `dim_idx` False 
otherwise + has_auxiliary_dimension = function(dim_idx) { + return(forest_dataset_has_auxiliary_dimension_cpp(self$data_ptr, dim_idx)) + }, + + #' @description + #' Initialize a new dimension / lane of auxiliary data and allocate data in its place + #' @param dim_size Size of the new vector of data to allocate + #' @return None + add_auxiliary_dimension = function(dim_size) { + return(forest_dataset_add_auxiliary_dimension_cpp(self$data_ptr, dim_size)) + }, + + #' @description + #' Retrieve auxiliary data value + #' @param dim_idx Dimension from which data value to be retrieved + #' @param element_idx Element to retrieve from dimension `dim_idx` + #' @return Floating point value stored in the requested auxiliary data space + get_auxiliary_data_value = function(dim_idx, element_idx) { + return(forest_dataset_get_auxiliary_data_value_cpp(self$data_ptr, dim_idx, element_idx)) + }, + + #' @description + #' Set auxiliary data value + #' @param dim_idx Dimension in which data value to be set + #' @param element_idx Element to set within dimension `dim_idx` + #' @param value Data value to set at auxiliary data dimension `dim_idx` and element `element_idx` + #' @return None + set_auxiliary_data_value = function(dim_idx, element_idx, value) { + return(forest_dataset_set_auxiliary_data_value_cpp(self$data_ptr, dim_idx, element_idx, value)) + }, + + #' @description + #' Retrieve entire auxiliary data vector + #' @param dim_idx Dimension to retrieve + #' @return Vector of all of the auxiliary data stored at dimension `dim_idx` + get_auxiliary_data_vector = function(dim_idx) { + return(forest_dataset_get_auxiliary_data_vector_cpp(self$data_ptr, dim_idx)) + }, + + #' @description + #' Retrieve auxiliary data vector and place it into a column of the supplied matrix + #' @param output_matrix Matrix to be overwritten + #' @param dim_idx Auxiliary data dimension to retrieve + #' @param matrix_col_idx Matrix column in which to copy auxiliary data + #' @return Vector of all of the 
auxiliary data stored at dimension `dim_idx` + store_auxiliary_data_vector_matrix = function(output_matrix, dim_idx, matrix_col_idx) { + return(forest_dataset_store_auxiliary_data_vector_as_column_cpp(self$data_ptr, output_matrix, dim_idx, matrix_col_idx)) + }, + + #' @description + #' Updates the elements of one auxiliary data vector based on the cumulative exponentiated values of elements of another vector. + #' If the target value has `k` elements, the reference vector must have `k - 1` elements. + #' @param reference_vector_idx Index of the auxiliary data vector to be exponentiated and scanned + #' @param target_vector_idx Index of the auxiliary data vector to be written with exponentiated and scanned values of `reference_vector_idx` + #' @return None + update_auxiliary_data_vector_cumulative_exp_sum = function(reference_vector_idx, target_vector_idx) { + return(forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(self$data_ptr, reference_vector_idx, target_vector_idx)) } ) ) diff --git a/R/model.R b/R/model.R index 38df5970..a6553dd6 100644 --- a/R/model.R +++ b/R/model.R @@ -378,7 +378,6 @@ createForestModel <- function( )) } - #' Draw `sample_size` samples from `population_vector` without replacement, weighted by `sampling_probabilities` #' #' @param population_vector Vector from which to draw samples. diff --git a/include/stochtree/data.h b/include/stochtree/data.h index df232fb3..1bc56a6e 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -470,6 +470,46 @@ class ForestDataset { if (exponentiate) var_weights_.SetElement(row_id, std::exp(new_value)); else var_weights_.SetElement(row_id, new_value); } + /*! 
+   * \brief Auxiliary data management methods
+   * Methods to initialize, get, and set auxiliary data for BART models with more structure than the ``classic`` conjugate-Gaussian leaf BART model. Each auxiliary "dimension" is an independent vector of doubles addressed by a 0-based `dim_idx`; the element accessors below do no bounds checking.
+   */
+  void AddAuxiliaryDimension(int dim_size) {  // Appends a new zero-filled auxiliary vector of length `dim_size`
+    if (!has_auxiliary_data_) has_auxiliary_data_ = true;
+    auxiliary_data_.resize(num_auxiliary_dims_ + 1);  // Grow the outer vector first so that slot `num_auxiliary_dims_` exists
+    auxiliary_data_[num_auxiliary_dims_].assign(dim_size, 0.0);
+    num_auxiliary_dims_++;
+  }
+  double GetAuxiliaryDataValue(int dim_idx, data_size_t element_idx) {  // Unchecked read of one element
+    return auxiliary_data_[dim_idx][element_idx];
+  }
+  void SetAuxiliaryDataValue(int dim_idx, data_size_t element_idx, double value) {  // Unchecked write of one element
+    auxiliary_data_[dim_idx][element_idx] = value;
+  }
+  std::vector<double>& GetAuxiliaryDataVector(int dim_idx) {  // Mutable reference to the whole vector at `dim_idx`
+    return auxiliary_data_[dim_idx];
+  }
+  const std::vector<double>& GetAuxiliaryDataVectorConst(int dim_idx) {  // Read-only reference to the whole vector at `dim_idx`
+    return auxiliary_data_[dim_idx];
+  }
+  bool HasAuxiliaryDimension(int dim_idx) {  // True when `dim_idx` refers to a dimension that has been added
+    return (dim_idx >= 0) && (dim_idx < num_auxiliary_dims_);
+  }
+
+  void UpdateAuxiliaryDataVectorCumulativeExpSum(int reference_vector_idx, int target_vector_idx) {  // Sets target[i] = sum_{j < i} exp(reference[j]), so target[0] = 0
+    CHECK(HasAuxiliaryDimension(reference_vector_idx));
+    CHECK(HasAuxiliaryDimension(target_vector_idx));
+    const std::vector<double>& reference_vector = GetAuxiliaryDataVectorConst(reference_vector_idx);
+    std::vector<double>& target_vector = GetAuxiliaryDataVector(target_vector_idx);
+    CHECK(!target_vector.empty() && reference_vector.size() + 1 >= target_vector.size());  // Every index read or written below must exist
+    double cumulative_exp_sum = 0.0;
+    target_vector[0] = cumulative_exp_sum;
+    for (std::size_t i = 1; i < target_vector.size(); i++) {  // Runs through the last target element
+      cumulative_exp_sum += std::exp(reference_vector[i - 1]);  // Element i accumulates reference entries 0..i-1
+      target_vector[i] = cumulative_exp_sum;
+    }
+  }
+
  private:
   ColumnMatrix covariates_;
   ColumnMatrix basis_;
@@ -480,6 +520,13 @@ class ForestDataset {
   bool has_covariates_{false};
   bool has_basis_{false};
   bool has_var_weights_{false};
+
+  /*!
+ * \brief Vector of vectors to track (potentially jagged) auxiliary data for complex BART models + */ + std::vector> auxiliary_data_; + int num_auxiliary_dims_{0}; + bool has_auxiliary_data_{false}; }; /*! \brief API for loading and accessing data used to sample (additive) random effects */ diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 02bf4d16..092989b3 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -1009,13 +1009,13 @@ class CloglogOrdinalSuffStat { unsigned int y = static_cast(outcome(row_idx)); // Get auxiliary data from tracker (assuming types: 0=latents Z, 1=forest predictions, 2=cutpoints gamma, 3=cumsum exp of gamma) - double Z = tracker.GetOrdinalAuxData(0, row_idx); // latent variables Z - double lambda_minus = tracker.GetOrdinalAuxData(1, row_idx); // forest predictions excluding current tree + double Z = dataset.GetAuxiliaryDataValue(0, row_idx); // latent variables Z + double lambda_minus = dataset.GetAuxiliaryDataValue(1, row_idx); // forest predictions excluding current tree // Get cutpoints gamma and cumulative sum of exp(gamma) - const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma - const std::vector& seg = tracker.GetOrdinalAuxDataVector(3); // cumsum exp of gamma - + const std::vector& gamma = dataset.GetAuxiliaryDataVectorConst(2); // cutpoints gamma + const std::vector& seg = dataset.GetAuxiliaryDataVectorConst(3); // cumsum exp of gamma + int K = gamma.size() + 1; // Number of ordinal categories if (y == K - 1) { diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h index 054a14c7..a83fdf9f 100644 --- a/include/stochtree/ordinal_sampler.h +++ b/include/stochtree/ordinal_sampler.h @@ -46,35 +46,32 @@ class OrdinalSampler { /*! 
* \brief Update truncated exponential latent variables (Z) * - * \param dataset Forest dataset containing training data (covariates) + * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling * \param outcome Vector of outcome values - * \param tracker Forest tracker containing auxiliary data * \param gen Random number generator */ - void UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, - std::mt19937& gen); + void UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, std::mt19937& gen); /*! * \brief Update gamma cutpoint parameters * - * \param dataset Forest dataset containing training data (covariates) + * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling * \param outcome Vector of outcome values - * \param tracker Forest tracker containing auxiliary data * \param alpha_gamma Shape parameter for log-gamma prior on cutpoints gamma * \param beta_gamma Rate parameter for log-gamma prior on cutpoints gamma * \param gamma_0 Fixed value for first cutpoint parameter (for identifiability) * \param gen Random number generator */ - void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, - double alpha_gamma, double beta_gamma, double gamma_0, - std::mt19937& gen); + void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, + double alpha_gamma, double beta_gamma, + double gamma_0, std::mt19937& gen); /*! 
* \brief Update cumulative exponential sums (seg) * - * \param tracker Forest tracker containing auxiliary data + * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling */ - void UpdateCumulativeExpSums(ForestTracker& tracker); + void UpdateCumulativeExpSums(ForestDataset& dataset); private: GammaSampler gamma_sampler_; diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index 3f342f15..a62121b1 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -96,19 +96,6 @@ class ForestTracker { int GetNumTrees() {return num_trees_;} int GetNumFeatures() {return num_features_;} bool Initialized() {return initialized_;} - - /*! - * \brief Ordinal auxiliary data management methods - * Methods to initialize, get, and set auxiliary data for cloglog ordinal bart models - * n_levels is the number of outcome levels for the ordinal response - * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints) - */ - void InitializeOrdinalAuxData(data_size_t num_observations, int n_levels); - double GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const; - void SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value); - std::vector& GetOrdinalAuxDataVector(int type_idx); - - private: /*! \brief Mapper from observations to predicted values summed over every tree in a forest */ std::vector sum_predictions_; @@ -137,22 +124,6 @@ class ForestTracker { void UpdateSampleTrackersInternal(TreeEnsemble& forest, Eigen::MatrixXd& covariates); void UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); void UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& forest, ForestDataset& dataset, ColumnVector& residual, bool is_mean_model); - - /*! 
- * \brief Track auxiliary data for cloglog ordinal bart models - * Vector of vectors to store these auxiliary data - * Each inner vector holds one type of data (order: Latent variable Z, Forest predictions, Cutpoints gamma, Cumsum exp of cutpoints) - */ - std::vector> ordinal_aux_data_vec_; - - /*! - * \brief Private helper methods for ordinal auxiliary data - * n_levels is the number of outcome levels for the ordinal response - * type_idx is the index of the type of auxiliary data (0: latent Z, 1: forest predictions, 2: cutpoints gamma, 3: cumsum exp of cutpoints) - */ - void ResizeOrdinalAuxData(data_size_t num_observations, int n_levels); - // bool IsValidOrdinalType(int type_idx) const; - // bool IsValidOrdinalIndex(int type_idx, data_size_t obs_idx) const; }; /*! \brief Class storing sample-prediction map for each tree in an ensemble */ diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 6b7579c6..8c41ff3d 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -423,7 +423,7 @@ static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& tracker.SetTreeSamplePrediction(i, tree_num, pred_value); tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta); // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num) - tracker.SetOrdinalAuxData(1, i, tracker.GetSamplePrediction(i) - pred_value); + dataset.SetAuxiliaryDataValue(1, i, tracker.GetSamplePrediction(i) - pred_value); } else { // If the tree has not yet been modified via a sampling step, // we can query its prediction directly from the SamplePredMapper stored in tracker @@ -431,7 +431,7 @@ static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num): needed? since tree not changed? 
double current_lambda_hat = tracker.GetSamplePrediction(i); double lambda_minus = current_lambda_hat - pred_value; - tracker.SetOrdinalAuxData(1, i, lambda_minus); + dataset.SetAuxiliaryDataValue(1, i, lambda_minus); } } } diff --git a/man/ForestDataset.Rd b/man/ForestDataset.Rd index dfd7760f..ac6e34d5 100644 --- a/man/ForestDataset.Rd +++ b/man/ForestDataset.Rd @@ -29,6 +29,13 @@ weights are optional. \item \href{#method-ForestDataset-get_variance_weights}{\code{ForestDataset$get_variance_weights()}} \item \href{#method-ForestDataset-has_basis}{\code{ForestDataset$has_basis()}} \item \href{#method-ForestDataset-has_variance_weights}{\code{ForestDataset$has_variance_weights()}} +\item \href{#method-ForestDataset-has_auxiliary_dimension}{\code{ForestDataset$has_auxiliary_dimension()}} +\item \href{#method-ForestDataset-add_auxiliary_dimension}{\code{ForestDataset$add_auxiliary_dimension()}} +\item \href{#method-ForestDataset-get_auxiliary_data_value}{\code{ForestDataset$get_auxiliary_data_value()}} +\item \href{#method-ForestDataset-set_auxiliary_data_value}{\code{ForestDataset$set_auxiliary_data_value()}} +\item \href{#method-ForestDataset-get_auxiliary_data_vector}{\code{ForestDataset$get_auxiliary_data_vector()}} +\item \href{#method-ForestDataset-store_auxiliary_data_vector_matrix}{\code{ForestDataset$store_auxiliary_data_vector_matrix()}} +\item \href{#method-ForestDataset-update_auxiliary_data_vector_cumulative_exp_sum}{\code{ForestDataset$update_auxiliary_data_vector_cumulative_exp_sum()}} } } \if{html}{\out{
}} @@ -195,4 +202,164 @@ Whether or not a dataset has variance weights True if variance weights are loaded, false otherwise } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-has_auxiliary_dimension}{}}} +\subsection{Method \code{has_auxiliary_dimension()}}{ +Whether or not a dataset has auxiliary data stored at the dimension indicated +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$has_auxiliary_dimension(dim_idx)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{dim_idx}}{Dimension of auxiliary data} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +\code{TRUE} if auxiliary data has been allocated for \code{dim_idx}, \code{FALSE} otherwise +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-add_auxiliary_dimension}{}}} +\subsection{Method \code{add_auxiliary_dimension()}}{ +Initialize a new dimension / lane of auxiliary data and allocate data in its place +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$add_auxiliary_dimension(dim_size)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{dim_size}}{Size of the new vector of data to allocate} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +None +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_auxiliary_data_value}{}}} +\subsection{Method \code{get_auxiliary_data_value()}}{ +Retrieve auxiliary data value +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_auxiliary_data_value(dim_idx, element_idx)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{dim_idx}}{Dimension from which data value to be retrieved} + +\item{\code{element_idx}}{Element to retrieve from dimension \code{dim_idx}} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Floating point value stored in the requested auxiliary data space +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-set_auxiliary_data_value}{}}} +\subsection{Method \code{set_auxiliary_data_value()}}{ +Set auxiliary data value +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$set_auxiliary_data_value(dim_idx, element_idx, value)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{dim_idx}}{Dimension in which data value to be set} + +\item{\code{element_idx}}{Element to set within dimension \code{dim_idx}} + +\item{\code{value}}{Data value to set at auxiliary data dimension \code{dim_idx} and element \code{element_idx}} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +None +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-get_auxiliary_data_vector}{}}} +\subsection{Method \code{get_auxiliary_data_vector()}}{ +Retrieve entire auxiliary data vector +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$get_auxiliary_data_vector(dim_idx)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{dim_idx}}{Dimension to retrieve} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Vector of all of the auxiliary data stored at dimension \code{dim_idx} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-store_auxiliary_data_vector_matrix}{}}} +\subsection{Method \code{store_auxiliary_data_vector_matrix()}}{ +Retrieve auxiliary data vector and place it into a column of the supplied matrix +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$store_auxiliary_data_vector_matrix( + output_matrix, + dim_idx, + matrix_col_idx +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{output_matrix}}{Matrix to be overwritten} + +\item{\code{dim_idx}}{Auxiliary data dimension to retrieve} + +\item{\code{matrix_col_idx}}{Matrix column in which to copy auxiliary data} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +The supplied \code{output_matrix}, with column \code{matrix_col_idx} overwritten by the auxiliary data stored at dimension \code{dim_idx} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestDataset-update_auxiliary_data_vector_cumulative_exp_sum}{}}} +\subsection{Method \code{update_auxiliary_data_vector_cumulative_exp_sum()}}{ +Updates the elements of one auxiliary data vector based on the cumulative exponentiated values of elements of another vector. +If the target value has \code{k} elements, the reference vector must have \code{k - 1} elements. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestDataset$update_auxiliary_data_vector_cumulative_exp_sum( + reference_vector_idx, + target_vector_idx +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{reference_vector_idx}}{Index of the auxiliary data vector to be exponentiated and scanned} + +\item{\code{target_vector_idx}}{Index of the auxiliary data vector to be written with exponentiated and scanned values of \code{reference_vector_idx}} +} +\if{html}{\out{
}}
+}
+\subsection{Returns}{
+None
+}
+}
 }
diff --git a/src/R_data.cpp b/src/R_data.cpp
index 1396575f..084e215c 100644
--- a/src/R_data.cpp
+++ b/src/R_data.cpp
@@ -149,6 +149,52 @@ cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external
     return output;
 }
 
+[[cpp11::register]]
+bool forest_dataset_has_auxiliary_dimension_cpp(cpp11::external_pointer dataset_ptr, int dim_idx) {  // Forwards to ForestDataset::HasAuxiliaryDimension
+    return dataset_ptr->HasAuxiliaryDimension(dim_idx);
+}
+
+[[cpp11::register]]
+void forest_dataset_add_auxiliary_dimension_cpp(cpp11::external_pointer dataset_ptr, int dim_size) {  // Forwards to ForestDataset::AddAuxiliaryDimension
+    dataset_ptr->AddAuxiliaryDimension(dim_size);
+}
+
+[[cpp11::register]]
+double forest_dataset_get_auxiliary_data_value_cpp(cpp11::external_pointer dataset_ptr, int dim_idx, int element_idx) {  // Indices are passed through unchanged
+    return dataset_ptr->GetAuxiliaryDataValue(dim_idx, element_idx);
+}
+
+[[cpp11::register]]
+void forest_dataset_set_auxiliary_data_value_cpp(cpp11::external_pointer dataset_ptr, int dim_idx, int element_idx, double value) {  // Indices are passed through unchanged
+    dataset_ptr->SetAuxiliaryDataValue(dim_idx, element_idx, value);
+}
+
+[[cpp11::register]]
+cpp11::writable::doubles forest_dataset_get_auxiliary_data_vector_cpp(cpp11::external_pointer dataset_ptr, int dim_idx) {  // Copies auxiliary vector `dim_idx` into a new R numeric vector
+    const std::vector<double>& output_raw = dataset_ptr->GetAuxiliaryDataVector(dim_idx);  // Bind by reference: the only copy needed is the one into `output`
+    int n = output_raw.size();
+    cpp11::writable::doubles output(n);
+    for (int i = 0; i < n; i++) {
+        output[i] = output_raw[i];
+    }
+    return output;
+}
+
+[[cpp11::register]]
+cpp11::writable::doubles_matrix<> forest_dataset_store_auxiliary_data_vector_as_column_cpp(cpp11::external_pointer dataset_ptr, cpp11::writable::doubles_matrix<> output_matrix, int dim_idx, int matrix_col_idx) {  // Writes auxiliary vector `dim_idx` into column `matrix_col_idx` and returns the matrix
+    const std::vector<double>& output_raw = dataset_ptr->GetAuxiliaryDataVector(dim_idx);  // Bind by reference to avoid copying the vector on every call
+    int n = output_raw.size();  // NOTE(review): nothing here checks that output_matrix has at least n rows or a column matrix_col_idx -- confirm callers guarantee this
+    for (int i = 0; i < n; i++) {
+        output_matrix(i, matrix_col_idx) = output_raw[i];
+    }
+    return output_matrix;
+}
+
+[[cpp11::register]]
+void 
forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(cpp11::external_pointer dataset_ptr, int reference_vector_idx, int target_vector_idx) { /* R binding: forwards both indices to dataset_ptr->UpdateAuxiliaryDataVectorCumulativeExpSum(); nothing is returned to R. NOTE(review): unlike the sibling forest_dataset_*_cpp bindings this name has no _cpp suffix - presumably unintentional; the generated cpp11 glue references this exact name, so confirm before renaming. */ + dataset_ptr->UpdateAuxiliaryDataVectorCumulativeExpSum(reference_vector_idx, target_vector_idx); +} + [[cpp11::register]] cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome) { // Unpack pointers to data and dimensions diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 881c5314..3336adca 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -109,6 +109,58 @@ extern "C" SEXP _stochtree_forest_dataset_get_variance_weights_cpp(SEXP dataset_ END_CPP11 } // R_data.cpp +bool forest_dataset_has_auxiliary_dimension_cpp(cpp11::external_pointer dataset_ptr, int dim_idx); +extern "C" SEXP _stochtree_forest_dataset_has_auxiliary_dimension_cpp(SEXP dataset_ptr, SEXP dim_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_has_auxiliary_dimension_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(dim_idx))); + END_CPP11 +} +// R_data.cpp +void forest_dataset_add_auxiliary_dimension_cpp(cpp11::external_pointer dataset_ptr, int dim_size); +extern "C" SEXP _stochtree_forest_dataset_add_auxiliary_dimension_cpp(SEXP dataset_ptr, SEXP dim_size) { + BEGIN_CPP11 + forest_dataset_add_auxiliary_dimension_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(dim_size)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +double forest_dataset_get_auxiliary_data_value_cpp(cpp11::external_pointer dataset_ptr, int dim_idx, int element_idx); +extern "C" SEXP _stochtree_forest_dataset_get_auxiliary_data_value_cpp(SEXP dataset_ptr, SEXP dim_idx, SEXP element_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_auxiliary_data_value_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(dim_idx), cpp11::as_cpp>(element_idx))); + END_CPP11 +} +// R_data.cpp +void forest_dataset_set_auxiliary_data_value_cpp(cpp11::external_pointer dataset_ptr, int dim_idx, int element_idx, double value); +extern "C" SEXP 
_stochtree_forest_dataset_set_auxiliary_data_value_cpp(SEXP dataset_ptr, SEXP dim_idx, SEXP element_idx, SEXP value) { + BEGIN_CPP11 + forest_dataset_set_auxiliary_data_value_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(dim_idx), cpp11::as_cpp>(element_idx), cpp11::as_cpp>(value)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles forest_dataset_get_auxiliary_data_vector_cpp(cpp11::external_pointer dataset_ptr, int dim_idx); +extern "C" SEXP _stochtree_forest_dataset_get_auxiliary_data_vector_cpp(SEXP dataset_ptr, SEXP dim_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_auxiliary_data_vector_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(dim_idx))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles_matrix<> forest_dataset_store_auxiliary_data_vector_as_column_cpp(cpp11::external_pointer dataset_ptr, cpp11::writable::doubles_matrix<> output_matrix, int dim_idx, int matrix_col_idx); +extern "C" SEXP _stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp(SEXP dataset_ptr, SEXP output_matrix, SEXP dim_idx, SEXP matrix_col_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_store_auxiliary_data_vector_as_column_cpp(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>>(output_matrix), cpp11::as_cpp>(dim_idx), cpp11::as_cpp>(matrix_col_idx))); + END_CPP11 +} +// R_data.cpp +void forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(cpp11::external_pointer dataset_ptr, int reference_vector_idx, int target_vector_idx); +extern "C" SEXP _stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(SEXP dataset_ptr, SEXP reference_vector_idx, SEXP target_vector_idx) { + BEGIN_CPP11 + forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(reference_vector_idx), cpp11::as_cpp>(target_vector_idx)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome); extern "C" SEXP 
_stochtree_create_column_vector_cpp(SEXP outcome) { BEGIN_CPP11 @@ -1282,52 +1334,6 @@ extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP populatio END_CPP11 } // sampler.cpp -void ordinal_aux_data_initialize_cpp(cpp11::external_pointer tracker_ptr, StochTree::data_size_t num_observations, int n_levels); -extern "C" SEXP _stochtree_ordinal_aux_data_initialize_cpp(SEXP tracker_ptr, SEXP num_observations, SEXP n_levels) { - BEGIN_CPP11 - ordinal_aux_data_initialize_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(num_observations), cpp11::as_cpp>(n_levels)); - return R_NilValue; - END_CPP11 -} -// sampler.cpp -double ordinal_aux_data_get_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx); -extern "C" SEXP _stochtree_ordinal_aux_data_get_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP obs_idx) { - BEGIN_CPP11 - return cpp11::as_sexp(ordinal_aux_data_get_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(obs_idx))); - END_CPP11 -} -// sampler.cpp -void ordinal_aux_data_set_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx, double value); -extern "C" SEXP _stochtree_ordinal_aux_data_set_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP obs_idx, SEXP value) { - BEGIN_CPP11 - ordinal_aux_data_set_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(obs_idx), cpp11::as_cpp>(value)); - return R_NilValue; - END_CPP11 -} -// sampler.cpp -cpp11::writable::doubles ordinal_aux_data_get_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx); -extern "C" SEXP _stochtree_ordinal_aux_data_get_vector_cpp(SEXP tracker_ptr, SEXP type_idx) { - BEGIN_CPP11 - return cpp11::as_sexp(ordinal_aux_data_get_vector_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx))); - END_CPP11 -} -// sampler.cpp -void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx, cpp11::doubles values); -extern "C" SEXP 
_stochtree_ordinal_aux_data_set_vector_cpp(SEXP tracker_ptr, SEXP type_idx, SEXP values) { - BEGIN_CPP11 - ordinal_aux_data_set_vector_cpp(cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(type_idx), cpp11::as_cpp>(values)); - return R_NilValue; - END_CPP11 -} -// sampler.cpp -void ordinal_aux_data_update_cumsum_exp_cpp(cpp11::external_pointer tracker_ptr); -extern "C" SEXP _stochtree_ordinal_aux_data_update_cumsum_exp_cpp(SEXP tracker_ptr) { - BEGIN_CPP11 - ordinal_aux_data_update_cumsum_exp_cpp(cpp11::as_cpp>>(tracker_ptr)); - return R_NilValue; - END_CPP11 -} -// sampler.cpp cpp11::external_pointer ordinal_sampler_cpp(); extern "C" SEXP _stochtree_ordinal_sampler_cpp() { BEGIN_CPP11 @@ -1335,26 +1341,26 @@ extern "C" SEXP _stochtree_ordinal_sampler_cpp() { END_CPP11 } // sampler.cpp -void ordinal_sampler_update_latent_variables_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, cpp11::external_pointer rng_ptr); -extern "C" SEXP _stochtree_ordinal_sampler_update_latent_variables_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP tracker_ptr, SEXP rng_ptr) { +void ordinal_sampler_update_latent_variables_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_latent_variables_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP rng_ptr) { BEGIN_CPP11 - ordinal_sampler_update_latent_variables_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>>(rng_ptr)); + ordinal_sampler_update_latent_variables_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>>(rng_ptr)); return R_NilValue; END_CPP11 } // sampler.cpp -void ordinal_sampler_update_gamma_params_cpp(cpp11::external_pointer sampler_ptr, 
cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, cpp11::external_pointer tracker_ptr, double alpha_gamma, double beta_gamma, double gamma_0, cpp11::external_pointer rng_ptr); -extern "C" SEXP _stochtree_ordinal_sampler_update_gamma_params_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP tracker_ptr, SEXP alpha_gamma, SEXP beta_gamma, SEXP gamma_0, SEXP rng_ptr) { +void ordinal_sampler_update_gamma_params_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, double alpha_gamma, double beta_gamma, double gamma_0, cpp11::external_pointer rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_gamma_params_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP alpha_gamma, SEXP beta_gamma, SEXP gamma_0, SEXP rng_ptr) { BEGIN_CPP11 - ordinal_sampler_update_gamma_params_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>>(tracker_ptr), cpp11::as_cpp>(alpha_gamma), cpp11::as_cpp>(beta_gamma), cpp11::as_cpp>(gamma_0), cpp11::as_cpp>>(rng_ptr)); + ordinal_sampler_update_gamma_params_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr), cpp11::as_cpp>>(outcome_ptr), cpp11::as_cpp>(alpha_gamma), cpp11::as_cpp>(beta_gamma), cpp11::as_cpp>(gamma_0), cpp11::as_cpp>>(rng_ptr)); return R_NilValue; END_CPP11 } // sampler.cpp -void ordinal_sampler_update_cumsum_exp_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer tracker_ptr); -extern "C" SEXP _stochtree_ordinal_sampler_update_cumsum_exp_cpp(SEXP sampler_ptr, SEXP tracker_ptr) { +void ordinal_sampler_update_cumsum_exp_cpp(cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_cumsum_exp_cpp(SEXP sampler_ptr, SEXP data_ptr) { BEGIN_CPP11 - ordinal_sampler_update_cumsum_exp_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(tracker_ptr)); + 
ordinal_sampler_update_cumsum_exp_cpp(cpp11::as_cpp>>(sampler_ptr), cpp11::as_cpp>>(data_ptr)); return R_NilValue; END_CPP11 } @@ -1659,229 +1665,230 @@ extern "C" SEXP _stochtree_json_load_string_cpp(SEXP json_ptr, SEXP json_string) extern "C" { static const R_CallMethodDef CallEntries[] = { - {"_stochtree_active_forest_cpp", (DL_FUNC) &_stochtree_active_forest_cpp, 4}, - {"_stochtree_add_numeric_split_tree_value_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_active_forest_cpp, 7}, - {"_stochtree_add_numeric_split_tree_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_forest_container_cpp, 8}, - {"_stochtree_add_numeric_split_tree_vector_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_active_forest_cpp, 7}, - {"_stochtree_add_numeric_split_tree_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_forest_container_cpp, 8}, - {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, - {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2}, - {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, - {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2}, - {"_stochtree_add_to_forest_forest_container_cpp", (DL_FUNC) &_stochtree_add_to_forest_forest_container_cpp, 3}, - {"_stochtree_adjust_residual_active_forest_cpp", (DL_FUNC) &_stochtree_adjust_residual_active_forest_cpp, 6}, - {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, - {"_stochtree_all_roots_active_forest_cpp", (DL_FUNC) &_stochtree_all_roots_active_forest_cpp, 1}, - {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, - {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) 
&_stochtree_average_max_depth_active_forest_cpp, 1}, - {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, - {"_stochtree_combine_forests_forest_container_cpp", (DL_FUNC) &_stochtree_combine_forests_forest_container_cpp, 2}, - {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, - {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, - {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, - {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) &_stochtree_create_rfx_dataset_cpp, 0}, - {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, - {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, - {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, - {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, - {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, - {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, - {"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2}, - {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, - {"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2}, - {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, - {"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, - {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 4}, - 
{"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, - {"_stochtree_forest_container_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_string_cpp, 2}, - {"_stochtree_forest_container_get_max_leaf_index_cpp", (DL_FUNC) &_stochtree_forest_container_get_max_leaf_index_cpp, 2}, - {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, - {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, - {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, - {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, - {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, - {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, - {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, - {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, - {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, - {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, - {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, - {"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1}, - {"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1}, - {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1}, - {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, - {"_stochtree_get_granular_split_count_array_active_forest_cpp", 
(DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, - {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, - {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1}, - {"_stochtree_get_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_get_max_depth_tree_prior_cpp, 1}, - {"_stochtree_get_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_get_min_samples_leaf_tree_prior_cpp, 1}, - {"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2}, - {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2}, - {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1}, - {"_stochtree_get_tree_leaves_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_active_forest_cpp, 2}, - {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3}, - {"_stochtree_get_tree_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_active_forest_cpp, 3}, - {"_stochtree_get_tree_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_forest_container_cpp, 4}, - {"_stochtree_init_json_cpp", (DL_FUNC) &_stochtree_init_json_cpp, 0}, - {"_stochtree_initialize_forest_model_active_forest_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_active_forest_cpp, 6}, - {"_stochtree_initialize_forest_model_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_cpp, 6}, - {"_stochtree_is_categorical_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_categorical_split_node_forest_container_cpp, 4}, - {"_stochtree_is_exponentiated_active_forest_cpp", (DL_FUNC) &_stochtree_is_exponentiated_active_forest_cpp, 1}, - {"_stochtree_is_exponentiated_forest_container_cpp", (DL_FUNC) 
&_stochtree_is_exponentiated_forest_container_cpp, 1}, - {"_stochtree_is_leaf_constant_active_forest_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_active_forest_cpp, 1}, - {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, - {"_stochtree_is_leaf_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_node_forest_container_cpp, 4}, - {"_stochtree_is_numeric_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_numeric_split_node_forest_container_cpp, 4}, - {"_stochtree_json_add_bool_cpp", (DL_FUNC) &_stochtree_json_add_bool_cpp, 3}, - {"_stochtree_json_add_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_bool_subfolder_cpp, 4}, - {"_stochtree_json_add_double_cpp", (DL_FUNC) &_stochtree_json_add_double_cpp, 3}, - {"_stochtree_json_add_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_double_subfolder_cpp, 4}, - {"_stochtree_json_add_forest_cpp", (DL_FUNC) &_stochtree_json_add_forest_cpp, 2}, - {"_stochtree_json_add_integer_cpp", (DL_FUNC) &_stochtree_json_add_integer_cpp, 3}, - {"_stochtree_json_add_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_subfolder_cpp, 4}, - {"_stochtree_json_add_integer_vector_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_cpp, 3}, - {"_stochtree_json_add_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_subfolder_cpp, 4}, - {"_stochtree_json_add_rfx_container_cpp", (DL_FUNC) &_stochtree_json_add_rfx_container_cpp, 2}, - {"_stochtree_json_add_rfx_groupids_cpp", (DL_FUNC) &_stochtree_json_add_rfx_groupids_cpp, 2}, - {"_stochtree_json_add_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_json_add_rfx_label_mapper_cpp, 2}, - {"_stochtree_json_add_string_cpp", (DL_FUNC) &_stochtree_json_add_string_cpp, 3}, - {"_stochtree_json_add_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_subfolder_cpp, 4}, - {"_stochtree_json_add_string_vector_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_cpp, 3}, - 
{"_stochtree_json_add_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_subfolder_cpp, 4}, - {"_stochtree_json_add_vector_cpp", (DL_FUNC) &_stochtree_json_add_vector_cpp, 3}, - {"_stochtree_json_add_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_vector_subfolder_cpp, 4}, - {"_stochtree_json_contains_field_cpp", (DL_FUNC) &_stochtree_json_contains_field_cpp, 2}, - {"_stochtree_json_contains_field_subfolder_cpp", (DL_FUNC) &_stochtree_json_contains_field_subfolder_cpp, 3}, - {"_stochtree_json_extract_bool_cpp", (DL_FUNC) &_stochtree_json_extract_bool_cpp, 2}, - {"_stochtree_json_extract_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_bool_subfolder_cpp, 3}, - {"_stochtree_json_extract_double_cpp", (DL_FUNC) &_stochtree_json_extract_double_cpp, 2}, - {"_stochtree_json_extract_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_double_subfolder_cpp, 3}, - {"_stochtree_json_extract_integer_cpp", (DL_FUNC) &_stochtree_json_extract_integer_cpp, 2}, - {"_stochtree_json_extract_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_subfolder_cpp, 3}, - {"_stochtree_json_extract_integer_vector_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_cpp, 2}, - {"_stochtree_json_extract_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_subfolder_cpp, 3}, - {"_stochtree_json_extract_string_cpp", (DL_FUNC) &_stochtree_json_extract_string_cpp, 2}, - {"_stochtree_json_extract_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_subfolder_cpp, 3}, - {"_stochtree_json_extract_string_vector_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_cpp, 2}, - {"_stochtree_json_extract_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_subfolder_cpp, 3}, - {"_stochtree_json_extract_vector_cpp", (DL_FUNC) &_stochtree_json_extract_vector_cpp, 2}, - {"_stochtree_json_extract_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_vector_subfolder_cpp, 3}, - 
{"_stochtree_json_increment_rfx_count_cpp", (DL_FUNC) &_stochtree_json_increment_rfx_count_cpp, 1}, - {"_stochtree_json_load_file_cpp", (DL_FUNC) &_stochtree_json_load_file_cpp, 2}, - {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, - {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2}, - {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2}, - {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, - {"_stochtree_leaf_dimension_active_forest_cpp", (DL_FUNC) &_stochtree_leaf_dimension_active_forest_cpp, 1}, - {"_stochtree_leaf_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_dimension_forest_container_cpp, 1}, - {"_stochtree_leaf_values_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_values_forest_container_cpp, 4}, - {"_stochtree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_leaves_forest_container_cpp, 3}, - {"_stochtree_left_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_left_child_node_forest_container_cpp, 4}, - {"_stochtree_multiply_forest_forest_container_cpp", (DL_FUNC) &_stochtree_multiply_forest_forest_container_cpp, 3}, - {"_stochtree_node_depth_forest_container_cpp", (DL_FUNC) &_stochtree_node_depth_forest_container_cpp, 4}, - {"_stochtree_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_nodes_forest_container_cpp, 3}, - {"_stochtree_num_leaf_parents_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaf_parents_forest_container_cpp, 3}, - {"_stochtree_num_leaves_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2}, - {"_stochtree_num_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_forest_container_cpp, 3}, - {"_stochtree_num_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_nodes_forest_container_cpp, 3}, - {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) 
&_stochtree_num_samples_forest_container_cpp, 1}, - {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, - {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, - {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, - {"_stochtree_ordinal_aux_data_get_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_get_cpp, 3}, - {"_stochtree_ordinal_aux_data_get_vector_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_get_vector_cpp, 2}, - {"_stochtree_ordinal_aux_data_initialize_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_initialize_cpp, 3}, - {"_stochtree_ordinal_aux_data_set_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_set_cpp, 4}, - {"_stochtree_ordinal_aux_data_set_vector_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_set_vector_cpp, 3}, - {"_stochtree_ordinal_aux_data_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_aux_data_update_cumsum_exp_cpp, 1}, - {"_stochtree_ordinal_sampler_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_cpp, 0}, - {"_stochtree_ordinal_sampler_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_cumsum_exp_cpp, 2}, - {"_stochtree_ordinal_sampler_update_gamma_params_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_gamma_params_cpp, 8}, - {"_stochtree_ordinal_sampler_update_latent_variables_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_latent_variables_cpp, 5}, - {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, - {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, - {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, - {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, - {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, - {"_stochtree_predict_forest_raw_single_forest_cpp", 
(DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, - {"_stochtree_predict_forest_raw_single_tree_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_tree_cpp, 4}, - {"_stochtree_predict_raw_active_forest_cpp", (DL_FUNC) &_stochtree_predict_raw_active_forest_cpp, 2}, - {"_stochtree_propagate_basis_update_active_forest_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_active_forest_cpp, 4}, - {"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5}, - {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, - {"_stochtree_remove_sample_forest_container_cpp", (DL_FUNC) &_stochtree_remove_sample_forest_container_cpp, 2}, - {"_stochtree_reset_active_forest_cpp", (DL_FUNC) &_stochtree_reset_active_forest_cpp, 3}, - {"_stochtree_reset_forest_model_cpp", (DL_FUNC) &_stochtree_reset_forest_model_cpp, 5}, - {"_stochtree_reset_rfx_model_cpp", (DL_FUNC) &_stochtree_reset_rfx_model_cpp, 3}, - {"_stochtree_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_reset_rfx_tracker_cpp, 4}, - {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, - {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, - {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, - {"_stochtree_rfx_container_delete_sample_cpp", (DL_FUNC) &_stochtree_rfx_container_delete_sample_cpp, 2}, - {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_cpp, 2}, - {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, - {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, - {"_stochtree_rfx_container_get_beta_cpp", (DL_FUNC) &_stochtree_rfx_container_get_beta_cpp, 1}, - 
{"_stochtree_rfx_container_get_sigma_cpp", (DL_FUNC) &_stochtree_rfx_container_get_sigma_cpp, 1}, - {"_stochtree_rfx_container_get_xi_cpp", (DL_FUNC) &_stochtree_rfx_container_get_xi_cpp, 1}, - {"_stochtree_rfx_container_num_components_cpp", (DL_FUNC) &_stochtree_rfx_container_num_components_cpp, 1}, - {"_stochtree_rfx_container_num_groups_cpp", (DL_FUNC) &_stochtree_rfx_container_num_groups_cpp, 1}, - {"_stochtree_rfx_container_num_samples_cpp", (DL_FUNC) &_stochtree_rfx_container_num_samples_cpp, 1}, - {"_stochtree_rfx_container_predict_cpp", (DL_FUNC) &_stochtree_rfx_container_predict_cpp, 3}, - {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, - {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, - {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, - {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, - {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, - {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, - {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, - {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, - {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, - {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, - {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, - {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, - {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2}, - 
{"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, - {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, - {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, - {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, - {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, - {"_stochtree_rfx_label_mapper_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_string_cpp, 2}, - {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, - {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, - {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, - {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 8}, - {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, - {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, - {"_stochtree_rfx_model_set_variance_prior_scale_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_scale_cpp, 2}, - {"_stochtree_rfx_model_set_variance_prior_shape_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_shape_cpp, 2}, - {"_stochtree_rfx_model_set_working_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_covariance_cpp, 2}, - {"_stochtree_rfx_model_set_working_parameter_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_cpp, 2}, - {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, - {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, - 
{"_stochtree_right_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_right_child_node_forest_container_cpp, 4}, - {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, - {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 18}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, - {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, - {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, - {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3}, - {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, - {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, - {"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2}, - {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, - {"_stochtree_split_categories_forest_container_cpp", (DL_FUNC) &_stochtree_split_categories_forest_container_cpp, 4}, - {"_stochtree_split_index_forest_container_cpp", (DL_FUNC) &_stochtree_split_index_forest_container_cpp, 4}, - {"_stochtree_split_theshold_forest_container_cpp", (DL_FUNC) &_stochtree_split_theshold_forest_container_cpp, 4}, - {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2}, - {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2}, - 
{"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, - {"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2}, - {"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2}, - {"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2}, - {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2}, + {"_stochtree_active_forest_cpp", (DL_FUNC) &_stochtree_active_forest_cpp, 4}, + {"_stochtree_add_numeric_split_tree_value_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_active_forest_cpp, 7}, + {"_stochtree_add_numeric_split_tree_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_forest_container_cpp, 8}, + {"_stochtree_add_numeric_split_tree_vector_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_active_forest_cpp, 7}, + {"_stochtree_add_numeric_split_tree_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_forest_container_cpp, 8}, + {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, + {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2}, + {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, + {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2}, + {"_stochtree_add_to_forest_forest_container_cpp", (DL_FUNC) &_stochtree_add_to_forest_forest_container_cpp, 3}, + {"_stochtree_adjust_residual_active_forest_cpp", (DL_FUNC) &_stochtree_adjust_residual_active_forest_cpp, 6}, + {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, + {"_stochtree_all_roots_active_forest_cpp", (DL_FUNC) 
&_stochtree_all_roots_active_forest_cpp, 1}, + {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_average_max_depth_active_forest_cpp, 1}, + {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, + {"_stochtree_combine_forests_forest_container_cpp", (DL_FUNC) &_stochtree_combine_forests_forest_container_cpp, 2}, + {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, + {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, + {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, + {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) &_stochtree_create_rfx_dataset_cpp, 0}, + {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, + {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, + {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, + {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, + {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, + {"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2}, + {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, + 
{"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, + {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 4}, + {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, + {"_stochtree_forest_container_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_string_cpp, 2}, + {"_stochtree_forest_container_get_max_leaf_index_cpp", (DL_FUNC) &_stochtree_forest_container_get_max_leaf_index_cpp, 2}, + {"_stochtree_forest_dataset_add_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_auxiliary_dimension_cpp, 2}, + {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, + {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, + {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, + {"_stochtree_forest_dataset_get_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_value_cpp, 3}, + {"_stochtree_forest_dataset_get_auxiliary_data_vector_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_vector_cpp, 2}, + {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, + {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, + {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, + {"_stochtree_forest_dataset_has_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_has_auxiliary_dimension_cpp, 2}, + {"_stochtree_forest_dataset_set_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_set_auxiliary_data_value_cpp, 4}, + {"_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp", (DL_FUNC) 
&_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp, 4}, + {"_stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum", (DL_FUNC) &_stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum, 3}, + {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, + {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, + {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, + {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, + {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, + {"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1}, + {"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1}, + {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1}, + {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, + {"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, + {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, + {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1}, + {"_stochtree_get_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_get_max_depth_tree_prior_cpp, 1}, + {"_stochtree_get_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_get_min_samples_leaf_tree_prior_cpp, 1}, + {"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2}, + {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) 
&_stochtree_get_overall_split_counts_forest_container_cpp, 2}, + {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1}, + {"_stochtree_get_tree_leaves_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_active_forest_cpp, 2}, + {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3}, + {"_stochtree_get_tree_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_active_forest_cpp, 3}, + {"_stochtree_get_tree_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_forest_container_cpp, 4}, + {"_stochtree_init_json_cpp", (DL_FUNC) &_stochtree_init_json_cpp, 0}, + {"_stochtree_initialize_forest_model_active_forest_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_active_forest_cpp, 6}, + {"_stochtree_initialize_forest_model_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_cpp, 6}, + {"_stochtree_is_categorical_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_categorical_split_node_forest_container_cpp, 4}, + {"_stochtree_is_exponentiated_active_forest_cpp", (DL_FUNC) &_stochtree_is_exponentiated_active_forest_cpp, 1}, + {"_stochtree_is_exponentiated_forest_container_cpp", (DL_FUNC) &_stochtree_is_exponentiated_forest_container_cpp, 1}, + {"_stochtree_is_leaf_constant_active_forest_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_active_forest_cpp, 1}, + {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, + {"_stochtree_is_leaf_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_node_forest_container_cpp, 4}, + {"_stochtree_is_numeric_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_numeric_split_node_forest_container_cpp, 4}, + {"_stochtree_json_add_bool_cpp", (DL_FUNC) &_stochtree_json_add_bool_cpp, 3}, + {"_stochtree_json_add_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_bool_subfolder_cpp, 4}, + {"_stochtree_json_add_double_cpp", (DL_FUNC) 
&_stochtree_json_add_double_cpp, 3}, + {"_stochtree_json_add_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_double_subfolder_cpp, 4}, + {"_stochtree_json_add_forest_cpp", (DL_FUNC) &_stochtree_json_add_forest_cpp, 2}, + {"_stochtree_json_add_integer_cpp", (DL_FUNC) &_stochtree_json_add_integer_cpp, 3}, + {"_stochtree_json_add_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_subfolder_cpp, 4}, + {"_stochtree_json_add_integer_vector_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_cpp, 3}, + {"_stochtree_json_add_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_subfolder_cpp, 4}, + {"_stochtree_json_add_rfx_container_cpp", (DL_FUNC) &_stochtree_json_add_rfx_container_cpp, 2}, + {"_stochtree_json_add_rfx_groupids_cpp", (DL_FUNC) &_stochtree_json_add_rfx_groupids_cpp, 2}, + {"_stochtree_json_add_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_json_add_rfx_label_mapper_cpp, 2}, + {"_stochtree_json_add_string_cpp", (DL_FUNC) &_stochtree_json_add_string_cpp, 3}, + {"_stochtree_json_add_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_subfolder_cpp, 4}, + {"_stochtree_json_add_string_vector_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_cpp, 3}, + {"_stochtree_json_add_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_subfolder_cpp, 4}, + {"_stochtree_json_add_vector_cpp", (DL_FUNC) &_stochtree_json_add_vector_cpp, 3}, + {"_stochtree_json_add_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_vector_subfolder_cpp, 4}, + {"_stochtree_json_contains_field_cpp", (DL_FUNC) &_stochtree_json_contains_field_cpp, 2}, + {"_stochtree_json_contains_field_subfolder_cpp", (DL_FUNC) &_stochtree_json_contains_field_subfolder_cpp, 3}, + {"_stochtree_json_extract_bool_cpp", (DL_FUNC) &_stochtree_json_extract_bool_cpp, 2}, + {"_stochtree_json_extract_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_bool_subfolder_cpp, 3}, + {"_stochtree_json_extract_double_cpp", (DL_FUNC) 
&_stochtree_json_extract_double_cpp, 2}, + {"_stochtree_json_extract_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_double_subfolder_cpp, 3}, + {"_stochtree_json_extract_integer_cpp", (DL_FUNC) &_stochtree_json_extract_integer_cpp, 2}, + {"_stochtree_json_extract_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_subfolder_cpp, 3}, + {"_stochtree_json_extract_integer_vector_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_cpp, 2}, + {"_stochtree_json_extract_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_subfolder_cpp, 3}, + {"_stochtree_json_extract_string_cpp", (DL_FUNC) &_stochtree_json_extract_string_cpp, 2}, + {"_stochtree_json_extract_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_subfolder_cpp, 3}, + {"_stochtree_json_extract_string_vector_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_cpp, 2}, + {"_stochtree_json_extract_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_subfolder_cpp, 3}, + {"_stochtree_json_extract_vector_cpp", (DL_FUNC) &_stochtree_json_extract_vector_cpp, 2}, + {"_stochtree_json_extract_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_vector_subfolder_cpp, 3}, + {"_stochtree_json_increment_rfx_count_cpp", (DL_FUNC) &_stochtree_json_increment_rfx_count_cpp, 1}, + {"_stochtree_json_load_file_cpp", (DL_FUNC) &_stochtree_json_load_file_cpp, 2}, + {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, + {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2}, + {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2}, + {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, + {"_stochtree_leaf_dimension_active_forest_cpp", (DL_FUNC) &_stochtree_leaf_dimension_active_forest_cpp, 1}, + {"_stochtree_leaf_dimension_forest_container_cpp", (DL_FUNC) 
&_stochtree_leaf_dimension_forest_container_cpp, 1}, + {"_stochtree_leaf_values_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_values_forest_container_cpp, 4}, + {"_stochtree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_leaves_forest_container_cpp, 3}, + {"_stochtree_left_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_left_child_node_forest_container_cpp, 4}, + {"_stochtree_multiply_forest_forest_container_cpp", (DL_FUNC) &_stochtree_multiply_forest_forest_container_cpp, 3}, + {"_stochtree_node_depth_forest_container_cpp", (DL_FUNC) &_stochtree_node_depth_forest_container_cpp, 4}, + {"_stochtree_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_nodes_forest_container_cpp, 3}, + {"_stochtree_num_leaf_parents_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaf_parents_forest_container_cpp, 3}, + {"_stochtree_num_leaves_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2}, + {"_stochtree_num_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_forest_container_cpp, 3}, + {"_stochtree_num_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_nodes_forest_container_cpp, 3}, + {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, + {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, + {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, + {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, + {"_stochtree_ordinal_sampler_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_cpp, 0}, + {"_stochtree_ordinal_sampler_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_cumsum_exp_cpp, 2}, + {"_stochtree_ordinal_sampler_update_gamma_params_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_gamma_params_cpp, 7}, + {"_stochtree_ordinal_sampler_update_latent_variables_cpp", (DL_FUNC) 
&_stochtree_ordinal_sampler_update_latent_variables_cpp, 4}, + {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, + {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, + {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, + {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, + {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, + {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, + {"_stochtree_predict_forest_raw_single_tree_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_tree_cpp, 4}, + {"_stochtree_predict_raw_active_forest_cpp", (DL_FUNC) &_stochtree_predict_raw_active_forest_cpp, 2}, + {"_stochtree_propagate_basis_update_active_forest_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_active_forest_cpp, 4}, + {"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5}, + {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, + {"_stochtree_remove_sample_forest_container_cpp", (DL_FUNC) &_stochtree_remove_sample_forest_container_cpp, 2}, + {"_stochtree_reset_active_forest_cpp", (DL_FUNC) &_stochtree_reset_active_forest_cpp, 3}, + {"_stochtree_reset_forest_model_cpp", (DL_FUNC) &_stochtree_reset_forest_model_cpp, 5}, + {"_stochtree_reset_rfx_model_cpp", (DL_FUNC) &_stochtree_reset_rfx_model_cpp, 3}, + {"_stochtree_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_reset_rfx_tracker_cpp, 4}, + {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, + {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, + {"_stochtree_rfx_container_cpp", (DL_FUNC) 
&_stochtree_rfx_container_cpp, 2}, + {"_stochtree_rfx_container_delete_sample_cpp", (DL_FUNC) &_stochtree_rfx_container_delete_sample_cpp, 2}, + {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_cpp, 2}, + {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, + {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, + {"_stochtree_rfx_container_get_beta_cpp", (DL_FUNC) &_stochtree_rfx_container_get_beta_cpp, 1}, + {"_stochtree_rfx_container_get_sigma_cpp", (DL_FUNC) &_stochtree_rfx_container_get_sigma_cpp, 1}, + {"_stochtree_rfx_container_get_xi_cpp", (DL_FUNC) &_stochtree_rfx_container_get_xi_cpp, 1}, + {"_stochtree_rfx_container_num_components_cpp", (DL_FUNC) &_stochtree_rfx_container_num_components_cpp, 1}, + {"_stochtree_rfx_container_num_groups_cpp", (DL_FUNC) &_stochtree_rfx_container_num_groups_cpp, 1}, + {"_stochtree_rfx_container_num_samples_cpp", (DL_FUNC) &_stochtree_rfx_container_num_samples_cpp, 1}, + {"_stochtree_rfx_container_predict_cpp", (DL_FUNC) &_stochtree_rfx_container_predict_cpp, 3}, + {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, + {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, + {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, + {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, + {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, + {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, + {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, + 
{"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, + {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, + {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, + {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2}, + {"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, + {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, + {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, + {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, + {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, + {"_stochtree_rfx_label_mapper_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_string_cpp, 2}, + {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, + {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, + {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, + {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 8}, + {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, + {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, + {"_stochtree_rfx_model_set_variance_prior_scale_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_scale_cpp, 2}, + {"_stochtree_rfx_model_set_variance_prior_shape_cpp", (DL_FUNC) 
&_stochtree_rfx_model_set_variance_prior_shape_cpp, 2}, + {"_stochtree_rfx_model_set_working_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_covariance_cpp, 2}, + {"_stochtree_rfx_model_set_working_parameter_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_cpp, 2}, + {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, + {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, + {"_stochtree_right_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_right_child_node_forest_container_cpp, 4}, + {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, + {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, + {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 18}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, + {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, + {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, + {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3}, + {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, + {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, + {"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2}, + {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, + {"_stochtree_split_categories_forest_container_cpp", (DL_FUNC) &_stochtree_split_categories_forest_container_cpp, 4}, 
+ {"_stochtree_split_index_forest_container_cpp", (DL_FUNC) &_stochtree_split_index_forest_container_cpp, 4}, + {"_stochtree_split_theshold_forest_container_cpp", (DL_FUNC) &_stochtree_split_theshold_forest_container_cpp, 4}, + {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2}, + {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2}, + {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, + {"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2}, + {"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2}, + {"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2}, + {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2}, {NULL, NULL, 0} }; } diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp index 19a7c6b5..54c212c7 100644 --- a/src/ordinal_sampler.cpp +++ b/src/ordinal_sampler.cpp @@ -10,12 +10,11 @@ double OrdinalSampler::SampleTruncatedExponential(double lambda, std::mt19937& g return -std::log(a) / lambda; } -void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, - std::mt19937& gen) { +void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, std::mt19937& gen) { // Get auxiliary data vectors - const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // gamma cutpoints - const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) - std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) + const std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // gamma cutpoints + const 
std::vector& lambda_hat = dataset.GetAuxiliaryDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) + std::vector& Z = dataset.GetAuxiliaryDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) int K = gamma.size() + 1; // Number of ordinal categories int N = dataset.NumObservations(); @@ -37,13 +36,13 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector } } -void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, - double alpha_gamma, double beta_gamma, double gamma_0, - std::mt19937& gen) { +void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome, + double alpha_gamma, double beta_gamma, + double gamma_0, std::mt19937& gen) { // Get auxiliary data vectors - std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's - const std::vector& Z = tracker.GetOrdinalAuxDataVector(0); // latent variables z_i's - const std::vector& lambda_hat = tracker.GetOrdinalAuxDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) + std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // cutpoints gamma_k's + const std::vector& Z = dataset.GetAuxiliaryDataVector(0); // latent variables z_i's + const std::vector& lambda_hat = dataset.GetAuxiliaryDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i) int K = gamma.size() + 1; // Number of ordinal categories int N = dataset.NumObservations(); @@ -76,10 +75,10 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& gamma[0] = gamma_0; } -void OrdinalSampler::UpdateCumulativeExpSums(ForestTracker& tracker) { +void OrdinalSampler::UpdateCumulativeExpSums(ForestDataset& dataset) { // Get auxiliary data vectors - const std::vector& gamma = tracker.GetOrdinalAuxDataVector(2); // cutpoints gamma_k's - std::vector& seg = tracker.GetOrdinalAuxDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j) 
+ const std::vector& gamma = dataset.GetAuxiliaryDataVector(2); // cutpoints gamma_k's + std::vector& seg = dataset.GetAuxiliaryDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j) // Update seg (sum of exponentials of gamma cutpoints) for (int j = 0; j < static_cast(seg.size()); j++) { diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index bb35efd7..d9faf57a 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -702,57 +702,4 @@ std::vector FeaturePresortPartition::NodeIndices(int node_id) { return out; } -// ============================================================================ -// ORDINAL AUXILIARY DATA METHODS -// ============================================================================ - -double ForestTracker::GetOrdinalAuxData(int type_idx, data_size_t obs_idx) const { - // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); - return ordinal_aux_data_vec_[type_idx][obs_idx]; -} - -void ForestTracker::InitializeOrdinalAuxData(data_size_t num_observations, int n_levels) { - ResizeOrdinalAuxData(num_observations, n_levels); -} - -void ForestTracker::SetOrdinalAuxData(int type_idx, data_size_t obs_idx, double value) { - // CHECK(IsValidOrdinalIndex(type_idx, obs_idx)); - ordinal_aux_data_vec_[type_idx][obs_idx] = value; -} - -std::vector& ForestTracker::GetOrdinalAuxDataVector(int type_idx) { - // CHECK(IsValidOrdinalType(type_idx)); - return ordinal_aux_data_vec_[type_idx]; -} - -void ForestTracker::ResizeOrdinalAuxData(data_size_t num_observations, int n_levels) { - // 4 types of ordinal auxiliary data: latent Z, forest predictions, cutpoints gamma, cumsum exp of gammas - const int n_types = 4; - ordinal_aux_data_vec_.resize(n_types); - for (int i = 0; i < n_types; ++i) { - if (i < 2) { - // First two types (latent Z, forest predictions) are sized to num_observations - ordinal_aux_data_vec_[i].assign(num_observations, 0.0); - } else if (i == 2) { - // Cutpoints gamma: size n_levels - 1 - 
ordinal_aux_data_vec_[i].assign(n_levels - 1, 0.0); - } else if (i == 3) { - // Cumsum exp of gammas: size n_levels - ordinal_aux_data_vec_[i].assign(n_levels, 0.0); - } - } -} - -// bool ForestTracker::IsValidOrdinalType(int type_idx) const { -// return (type_idx >= 0 && type_idx < static_cast(ordinal_aux_data_vec_.size()) && -// !ordinal_aux_data_vec_.empty()); -// } - -// bool ForestTracker::IsValidOrdinalIndex(int type_idx, data_size_t obs_idx) const { -// if (!IsValidOrdinalType(type_idx)) { -// return false; -// } -// return (obs_idx >= 0 && obs_idx < ordinal_aux_data_vec_[type_idx].size()); -// } - } // namespace StochTree diff --git a/src/sampler.cpp b/src/sampler.cpp index 255f6e7c..6cbbc77b 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -334,61 +334,6 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( return(output); } -// ============================================================================ -// ORDINAL AUXILIARY DATA FUNCTIONS -// ============================================================================ - -[[cpp11::register]] -void ordinal_aux_data_initialize_cpp(cpp11::external_pointer tracker_ptr, StochTree::data_size_t num_observations, int n_levels) { - tracker_ptr->InitializeOrdinalAuxData(num_observations, n_levels); -} - -[[cpp11::register]] -double ordinal_aux_data_get_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx) { - return tracker_ptr->GetOrdinalAuxData(type_idx, obs_idx); -} - -[[cpp11::register]] -void ordinal_aux_data_set_cpp(cpp11::external_pointer tracker_ptr, int type_idx, StochTree::data_size_t obs_idx, double value) { - tracker_ptr->SetOrdinalAuxData(type_idx, obs_idx, value); -} - -[[cpp11::register]] -cpp11::writable::doubles ordinal_aux_data_get_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx) { - const std::vector& aux_vec = tracker_ptr->GetOrdinalAuxDataVector(type_idx); - cpp11::writable::doubles output(aux_vec.size()); - for (size_t i = 0; i 
< aux_vec.size(); i++) { - output[i] = aux_vec[i]; - } - return output; -} - -[[cpp11::register]] -void ordinal_aux_data_set_vector_cpp(cpp11::external_pointer tracker_ptr, int type_idx, cpp11::doubles values) { - std::vector& aux_vec = tracker_ptr->GetOrdinalAuxDataVector(type_idx); - if (aux_vec.size() != values.size()) { - cpp11::stop("Size mismatch between auxiliary data vector and input values"); - } - for (size_t i = 0; i < values.size(); i++) { - aux_vec[i] = values[i]; - } -} - -[[cpp11::register]] -void ordinal_aux_data_update_cumsum_exp_cpp(cpp11::external_pointer tracker_ptr) { - // Get auxiliary data vectors - const std::vector& gamma = tracker_ptr->GetOrdinalAuxDataVector(2); // cutpoints gamma - std::vector& seg = tracker_ptr->GetOrdinalAuxDataVector(3); // cumsum exp gamma - - // Update seg (cumulative sum of exp(gamma)) - for (size_t j = 0; j < seg.size(); j++) { - if (j == 0) { - seg[j] = 0.0; - } else { - seg[j] = seg[j - 1] + std::exp(gamma[j - 1]); - } - } -} // ============================================================================ // ORDINAL SAMPLER FUNCTIONS @@ -405,10 +350,9 @@ void ordinal_sampler_update_latent_variables_cpp( cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, - cpp11::external_pointer tracker_ptr, cpp11::external_pointer rng_ptr ) { - sampler_ptr->UpdateLatentVariables(*data_ptr, outcome_ptr->GetData(), *tracker_ptr, *rng_ptr); + sampler_ptr->UpdateLatentVariables(*data_ptr, outcome_ptr->GetData(), *rng_ptr); } [[cpp11::register]] @@ -416,21 +360,20 @@ void ordinal_sampler_update_gamma_params_cpp( cpp11::external_pointer sampler_ptr, cpp11::external_pointer data_ptr, cpp11::external_pointer outcome_ptr, - cpp11::external_pointer tracker_ptr, double alpha_gamma, double beta_gamma, double gamma_0, cpp11::external_pointer rng_ptr ) { - sampler_ptr->UpdateGammaParams(*data_ptr, outcome_ptr->GetData(), *tracker_ptr, alpha_gamma, beta_gamma, gamma_0, *rng_ptr); + 
sampler_ptr->UpdateGammaParams(*data_ptr, outcome_ptr->GetData(), alpha_gamma, beta_gamma, gamma_0, *rng_ptr); } [[cpp11::register]] void ordinal_sampler_update_cumsum_exp_cpp( cpp11::external_pointer sampler_ptr, - cpp11::external_pointer tracker_ptr + cpp11::external_pointer data_ptr ) { - sampler_ptr->UpdateCumulativeExpSums(*tracker_ptr); + sampler_ptr->UpdateCumulativeExpSums(*data_ptr); } diff --git a/tools/debug/testing_cloglog_ordinal_bart.R b/tools/debug/testing_cloglog_ordinal_bart.R index 71ef790a..0eca735a 100644 --- a/tools/debug/testing_cloglog_ordinal_bart.R +++ b/tools/debug/testing_cloglog_ordinal_bart.R @@ -47,6 +47,8 @@ y_train <- y[train_idx] X_test <- X[test_idx, ] y_test <- y[test_idx] +start <- Sys.time() + out <- cloglog_ordinal_bart( X = X_train, y = y_train, @@ -56,6 +58,8 @@ out <- cloglog_ordinal_bart( n_thin = 1 ) +end <- Sys.time() +print(end - start) # Inference and diagnostics par(mfrow = c(2, 1)) From 36b6a9828335744e947274a34f0a408708eca3ae Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 11:48:02 -0400 Subject: [PATCH 11/34] Removed call to deprectated cpp function --- R/cloglog_ordinal_bart.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index a6687de5..d5a82db2 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -97,7 +97,6 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, as.integer(n_trees), as.integer(n_samples) ) - stochtree:::ordinal_aux_data_initialize_cpp(forest_tracker, as.integer(n_samples), as.integer(n_levels)) # Latent variable (Z in Alam et al (2025) notation) dataX$add_auxiliary_dimension(nrow(X)) # Forest predictions (eta in Alam et al (2025) notation) From 2de707c082e5ebcc7d3b5f17f0dbffe76ab5f349 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:01:30 -0400 Subject: [PATCH 12/34] Fixed indexing bug --- R/cloglog_ordinal_bart.R | 6 ++---- include/stochtree/data.h | 6 +++--- 
tools/debug/testing_cloglog_ordinal_bart.R | 5 ++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index d5a82db2..570bb944 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -12,8 +12,6 @@ #' @param variable_weights Optional vector of variable weights for splitting (default: equal weights). #' @param feature_types Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). #' @export - - cloglog_ordinal_bart <- function(X, y, X_test = NULL, n_trees = 50, n_samples_mcmc = 500, @@ -31,7 +29,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, min_samples_in_leaf <- 5 max_depth <- 10 scale_leaf <- 2 / sqrt(n_trees) - cutpoint_grid_size <- 100 # Needed for stochtree:::sample_mcmc_one_iteration_cpp (for GFR), not used in ordinal BART + cutpoint_grid_size <- 100 # Needed for stochtree:::sample_mcmc_one_iteration_cpp (for GFR), not used in MCMC BART # Fixed for identifiability (can be pass as argument later if desired) gamma_0 = 0.0 # First gamma cutpoint fixed at gamma_0 = 0 @@ -171,7 +169,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, ) # 4. 
Update cumulative sum of exp(gamma) values - dataX$update_auxiliary_data_vector_cumulative_exp_sum(2,3) + dataX$update_auxiliary_data_vector_cumulative_exp_sum(2, 3) if (keep_sample) { forest_pred_train[, sample_counter] <- active_forest$predict(dataX) diff --git a/include/stochtree/data.h b/include/stochtree/data.h index 1bc56a6e..8c3511d0 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -476,8 +476,8 @@ class ForestDataset { */ void AddAuxiliaryDimension(int dim_size) { if (!has_auxiliary_data_) has_auxiliary_data_ = true; - auxiliary_data_.resize(num_auxiliary_dims_); - auxiliary_data_[num_auxiliary_dims_].assign(dim_size, 0.0); + auxiliary_data_.resize(num_auxiliary_dims_ + 1); + auxiliary_data_[num_auxiliary_dims_].resize(dim_size); num_auxiliary_dims_++; } double GetAuxiliaryDataValue(int dim_idx, data_size_t element_idx) { @@ -505,7 +505,7 @@ class ForestDataset { double cumulative_exp_sum = 0.0; target_vector[0] = cumulative_exp_sum; for (int i = 1; i < num_levels - 1; i++) { - cumulative_exp_sum += std::exp(reference_vector[i]); + cumulative_exp_sum = cumulative_exp_sum + std::exp(reference_vector[i - 1]); target_vector[i] = cumulative_exp_sum; } } diff --git a/tools/debug/testing_cloglog_ordinal_bart.R b/tools/debug/testing_cloglog_ordinal_bart.R index 0eca735a..2d1c607f 100644 --- a/tools/debug/testing_cloglog_ordinal_bart.R +++ b/tools/debug/testing_cloglog_ordinal_bart.R @@ -14,7 +14,6 @@ X <- matrix(rnorm(n * p), n, p) beta <- rep(1 / sqrt(p), p) true_lambda_function <- X %*% beta - # Set cutpoints for ordinal categories (3 categories: 1, 2, 3) n_categories <- 3 gamma_true <- c(-2, 1) @@ -38,10 +37,9 @@ for (j in 1:n_categories) { y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) cat("Outcome distribution:", table(y), "\n") -# CLogLog Ordinal BART model fitting +# Train test split train_idx <- sample(1:n, size = floor(0.8 * n)) test_idx <- setdiff(1:n, train_idx) - X_train <- X[train_idx, ] 
y_train <- y[train_idx] X_test <- X[test_idx, ] @@ -49,6 +47,7 @@ y_test <- y[test_idx] start <- Sys.time() +# Sample the cloglog ordinal BART model out <- cloglog_ordinal_bart( X = X_train, y = y_train, From 2d19399150866add5833ae70c71709f76281ae30 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:14:15 -0400 Subject: [PATCH 13/34] Refactored and fixed bugs --- R/cloglog_ordinal_bart.R | 13 +++++++------ R/data.R | 10 ---------- include/stochtree/data.h | 14 -------------- src/R_data.cpp | 5 ----- src/sampler.cpp | 2 -- 5 files changed, 7 insertions(+), 37 deletions(-) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index 570bb944..9932df1c 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -78,6 +78,10 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, } gamma_samples <- matrix(0, n_levels - 1, n_keep) latent_samples <- matrix(0, n_samples, n_keep) + + # Initialize samplers + ordinal_sampler <- stochtree:::ordinal_sampler_cpp() + rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed) # Initialize other model structures as before dataX <- stochtree::createForestDataset(X) @@ -95,6 +99,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, as.integer(n_trees), as.integer(n_samples) ) + # Latent variable (Z in Alam et al (2025) notation) dataX$add_auxiliary_dimension(nrow(X)) # Forest predictions (eta in Alam et al (2025) notation) @@ -116,7 +121,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, # Convert the log-scale parameters into cumulative exponentiated parameters. # This is done under the hood in a C++ function for efficiency. 
- dataX$update_auxiliary_data_vector_cumulative_exp_sum(2, 3) + ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr) # Initialize forest predictions to zero (slot 1) for (i in 1:n_samples) { @@ -128,10 +133,6 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, dataX$set_auxiliary_data_value(0, i - 1, 0.0) } - # Initialize samplers - ordinal_sampler <- stochtree:::ordinal_sampler_cpp() - rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed) - # Set up sweep indices for tree updates (sample all trees each iteration) sweep_indices <- as.integer(seq(0, n_trees - 1)) @@ -169,7 +170,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, ) # 4. Update cumulative sum of exp(gamma) values - dataX$update_auxiliary_data_vector_cumulative_exp_sum(2, 3) + ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr) if (keep_sample) { forest_pred_train[, sample_counter] <- active_forest$predict(dataX) diff --git a/R/data.R b/R/data.R index f946a146..f2aa46aa 100644 --- a/R/data.R +++ b/R/data.R @@ -161,16 +161,6 @@ ForestDataset <- R6::R6Class( #' @return Vector of all of the auxiliary data stored at dimension `dim_idx` store_auxiliary_data_vector_matrix = function(output_matrix, dim_idx, matrix_col_idx) { return(forest_dataset_store_auxiliary_data_vector_as_column_cpp(self$data_ptr, output_matrix, dim_idx, matrix_col_idx)) - }, - - #' @description - #' Updates the elements of one auxiliary data vector based on the cumulative exponentiated values of elements of another vector. - #' If the target value has `k` elements, the reference vector must have `k - 1` elements. 
- #' @param reference_vector_idx Index of the auxiliary data vector to be exponentiated and scanned - #' @param target_vector_idx Index of the auxiliary data vector to be written with exponentiated and scanned values of `reference_vector_idx` - #' @return None - update_auxiliary_data_vector_cumulative_exp_sum = function(reference_vector_idx, target_vector_idx) { - return(forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(self$data_ptr, reference_vector_idx, target_vector_idx)) } ) ) diff --git a/include/stochtree/data.h b/include/stochtree/data.h index 8c3511d0..8cf16e4d 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -496,20 +496,6 @@ class ForestDataset { return (num_auxiliary_dims_ > dim_idx) & (dim_idx >= 0); } - void UpdateAuxiliaryDataVectorCumulativeExpSum(int reference_vector_idx, int target_vector_idx) { - CHECK(HasAuxiliaryDimension(reference_vector_idx)); - CHECK(HasAuxiliaryDimension(target_vector_idx)); - const std::vector& reference_vector = GetAuxiliaryDataVectorConst(reference_vector_idx); - std::vector& target_vector = GetAuxiliaryDataVector(target_vector_idx); - int num_levels = target_vector.size(); - double cumulative_exp_sum = 0.0; - target_vector[0] = cumulative_exp_sum; - for (int i = 1; i < num_levels - 1; i++) { - cumulative_exp_sum = cumulative_exp_sum + std::exp(reference_vector[i - 1]); - target_vector[i] = cumulative_exp_sum; - } - } - private: ColumnMatrix covariates_; ColumnMatrix basis_; diff --git a/src/R_data.cpp b/src/R_data.cpp index 084e215c..6ede6473 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -190,11 +190,6 @@ cpp11::writable::doubles_matrix<> forest_dataset_store_auxiliary_data_vector_as_ return output_matrix; } -[[cpp11::register]] -void forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(cpp11::external_pointer dataset_ptr, int reference_vector_idx, int target_vector_idx) { - dataset_ptr->UpdateAuxiliaryDataVectorCumulativeExpSum(reference_vector_idx, 
target_vector_idx); -} - [[cpp11::register]] cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome) { // Unpack pointers to data and dimensions diff --git a/src/sampler.cpp b/src/sampler.cpp index 6cbbc77b..d06f025e 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -375,5 +375,3 @@ void ordinal_sampler_update_cumsum_exp_cpp( ) { sampler_ptr->UpdateCumulativeExpSums(*data_ptr); } - - From f9a0b5a26561c84fce32c8f658b340bd7b7bd90f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:15:50 -0400 Subject: [PATCH 14/34] Updated multinomial cloglog vignette --- ..._cloglog_ordinal_bart.R => cloglog_ordinal_bart_multinomial.R} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tools/debug/{testing_cloglog_ordinal_bart.R => cloglog_ordinal_bart_multinomial.R} (100%) diff --git a/tools/debug/testing_cloglog_ordinal_bart.R b/tools/debug/cloglog_ordinal_bart_multinomial.R similarity index 100% rename from tools/debug/testing_cloglog_ordinal_bart.R rename to tools/debug/cloglog_ordinal_bart_multinomial.R From 66164f56a5f3bbd80c5efee572e11b62f1de85f2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:27:26 -0400 Subject: [PATCH 15/34] Added binary outcome cloglog model demo --- tools/debug/cloglog_ordinal_bart_binary.R | 138 ++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 tools/debug/cloglog_ordinal_bart_binary.R diff --git a/tools/debug/cloglog_ordinal_bart_binary.R b/tools/debug/cloglog_ordinal_bart_binary.R new file mode 100644 index 00000000..21f5e9cb --- /dev/null +++ b/tools/debug/cloglog_ordinal_bart_binary.R @@ -0,0 +1,138 @@ +# Simulate ordinal data and run Cloglog Ordinal BART + +# Load +library(stochtree) + +set.seed(2025) + +# Sample size and number of predictors +n <- 2000 +p_X <- 5 + +# Design matrix and true lambda function +X <- matrix(runif(n * p_X), ncol = p_X) +true_lambda_function <- ifelse(X[, 1] > 0.5, 2, -1) + +# Set cutpoints for ordinal categories (2 categories: 
1, 2) +n_categories <- 2 +gamma_true <- c(-1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) +ordinal_cutpoints + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(y), "\n") + +# Train test split +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, train_idx) +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] + +start <- Sys.time() + +# Sample the cloglog ordinal BART model +out <- cloglog_ordinal_bart( + X = X_train, + y = y_train, + X_test = X_test, + n_samples_mcmc = 1000, + n_burnin = 500, + n_thin = 1 +) + +end <- Sys.time() +print(end - start) + +# Inference and diagnostics +par(mfrow = c(2, 1)) +plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[1], col = 'red', lty = 2) +# plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample") +# abline(h = gamma_true[2], col = 'red', lty = 2) + +gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) +summary(gamma1) +hist(gamma1) + +par(mfrow = c(2,1), mar = c(5,4,1,1)) +rowMeans(out$gamma_samples) +moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) +plot(moo[,1]) +abline(h = gamma_true[1] + mean(true_lambda_function[train_idx])) +plot(out$gamma_samples[1,]) + +# Compare forest predictions with the truth function (for 
training and test sets) +par(mfrow = c(2,1)) +lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) +plot(lambda_pred_train, true_lambda_function[train_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') + +lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) +plot(lambda_pred_test, true_lambda_function[test_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') + +# Estimated ordinal class probabilities for the training set +est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) + +# Compare estimated vs true class probabilities for training set +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 
3)), adj = 0, col = 'red') +} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) + +# Compare estimated vs true class probabilities for test set +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} From d5de763864218cba1897a79c8c8a7bab904fbfed Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:47:08 -0400 Subject: [PATCH 16/34] Reworking sampler implementation to match current stochtree::main API --- include/stochtree/tree_sampler.h | 2 +- src/sampler.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 8c41ff3d..f4c65ba9 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -1148,7 +1148,7 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For template static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, 
ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the MCMC algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { diff --git a/src/sampler.cpp b/src/sampler.cpp index d06f025e..2e80aa4d 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -110,7 +110,7 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, 
sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { - StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false); + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); } } From 03024596a8989135ad1d5a6447d422422832cbaa Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:52:24 -0400 Subject: [PATCH 17/34] Reflecting num_threads further through the interface --- include/stochtree/tree_sampler.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index f4c65ba9..fdaf507f 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -440,7 +440,7 @@ template EvaluateProposedSplit( ForestDataset& dataset, ForestTracker& tracker, 
ColumnVector& residual, LeafModel& leaf_model, TreeSplit& split, int tree_num, int leaf_num, int split_feature, double global_variance, - LeafSuffStatConstructorArgs&... leaf_suff_stat_args + int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args ) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -882,7 +882,7 @@ static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& template static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, std::vector& variable_weights, - double global_variance, double prob_grow_old, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + double global_variance, double prob_grow_old, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Extract dataset information data_size_t n = dataset.GetCovariates().rows(); @@ -927,7 +927,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM // Compute the marginal likelihood of split and no split, given the leaf prior std::tuple split_eval = EvaluateProposedSplit( - dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, leaf_suff_stat_args... + dataset, tracker, residual, leaf_model, split, tree_num, leaf_chosen, var_chosen, global_variance, num_threads, leaf_suff_stat_args... 
); double split_log_marginal_likelihood = std::get<0>(split_eval); double no_split_log_marginal_likelihood = std::get<1>(split_eval); @@ -990,7 +990,8 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM template static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, - TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads, + LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Choose a "leaf parent" node at random int num_leaves = tree->NumLeaves(); int num_leaf_parents = tree->NumLeafParents(); @@ -1109,11 +1110,11 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For if (step_chosen == 0) { MCMCGrowTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, leaf_suff_stat_args... + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, variable_weights, global_variance, prob_grow, num_threads, leaf_suff_stat_args... ); } else { MCMCPruneTreeOneIter( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, leaf_suff_stat_args... + tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, num_threads, leaf_suff_stat_args... ); } } @@ -1143,6 +1144,7 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon). 
* \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered). + * \param num_threads * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object. */ template @@ -1163,7 +1165,7 @@ static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tree = active_forest.GetTree(i); MCMCSampleTreeOneIter( tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, - global_variance, leaf_suff_stat_args... + global_variance, num_threads, leaf_suff_stat_args... ); // Sample leaf parameters for tree i From a7c79d429306e8b7bbb3573176e8766d9ca85008 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:54:46 -0400 Subject: [PATCH 18/34] Refactoring out unused slice sampler for leaf scale parameter --- include/stochtree/leaf_model.h | 99 ---------------------------------- 1 file changed, 99 deletions(-) diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 092989b3..88719608 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -1105,14 +1105,11 @@ class CloglogOrdinalLeafModel { * \param a shape parameter for log-gamma prior on leaf parameters * \param b rate parameter for log-gamma prior on leaf parameters * Log-gamma density: f(x) = b^a / Gamma(a) * exp(a*x - b*exp(x)) - * Relationship to tau (scale of leaf parameters): tau^2 = trigamma(a) */ CloglogOrdinalLeafModel(double a, double b) { a_ = a; b_ = b; gamma_sampler_ = GammaSampler(); - // slice_sampler_ = SliceSampler(); - tau_ = std::sqrt(boost::math::trigamma(a_)); } ~CloglogOrdinalLeafModel() {} @@ -1146,108 +1143,12 @@ class CloglogOrdinalLeafModel { * Samples from log-gamma: sample from gamma, then take log. 
*/ void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); - - void SetScale(double tau) {tau_ = tau;} - - /*! - * \brief Get the current scale parameter value (tau_) - * \return Current tau_ value - */ - double GetScale() const {return tau_;} - inline bool RequiresBasis() {return false;} - // /*! - // * \brief Update the scale parameter (tau_) using slice sampling - // * - // * \param lambda Vector of leaf parameter values from all trees - // * \param scale_sigma_lambda Prior scale parameter for scale parameter (tau_) of leaf parameters - // * \param gen Random number generator - // */ - // void UpdateScaleLambda(const std::vector& lambda, double scale_sigma_lambda, std::mt19937& gen) { - // double n = static_cast(lambda.size()); - // double sum_lambda = 0.0; - // double sum_exp_lambda = 0.0; - - // for (size_t i = 0; i < lambda.size(); i++) { - // sum_lambda += lambda[i]; - // sum_exp_lambda += std::exp(lambda[i]); - // } - - // // Create log-likelihood function - // ScaleLambdaLoglik loglik_func(n, sum_lambda, sum_exp_lambda, scale_sigma_lambda); - - // // Sample new scale parameter using slice sampling - // double current_tau = tau_; - // double w = 1.0; // Step size for slice sampler - // double lower = 1e-6; // Lower bound for tau - // double upper = std::numeric_limits::infinity(); // Upper bound - - // double new_tau = slice_sampler_.Sample(current_tau, &loglik_func, w, lower, upper, gen); - // tau_ = new_tau; - // } - - /*! - * \brief Convert tau_ (scale_lambda i.e. 
scale for leaf parameters) to alpha (shape) and beta (rate) parameters for the log-gamma prior - * - * \param alpha Output: shape parameter for log-gamma prior - * \param beta Output: rate parameter for log-gamma prior - * \param tau Scale parameter (tau_) for leaf parameters - */ - void ScaleTauToAlphaBeta(double& alpha, double& beta, const double tau) { - double tau_sq = tau * tau; - alpha = TrigammaInverse(tau_sq); - // Note: Using exponential of digamma function for beta calculation - beta = std::exp(boost::math::digamma(alpha)); - } - - /*! - * \brief Convert alpha (shape) and beta (rate) parameters (for the log-gamma prior) back to tau_ (scale_lambda i.e. scale for leaf parameters) - * - * \param alpha Shape parameter for log-gamma prior - * \param beta Rate parameter for log-gamma prior - * \return tau Scale parameter (tau_) for leaf parameters - */ - double AlphaBetaToScaleTau(double alpha, double beta) { - // Inverse of the transformation: tau_sq = trigamma(alpha) - double tau_sq = boost::math::trigamma(alpha); - return std::sqrt(tau_sq); - } - private: - /*! - * \brief Compute inverse trigamma function using Newton's method - * - * Implementation adapted from limma package in R, originally by Gordon Smyth - * - * \param x Input value for which to compute trigamma inverse - * \return Value y such that trigamma(y) = x - */ - double TrigammaInverse(double x) { - // Very large and very small values - deal with using asymptotics - if (x > 1E7) { - return 1.0 / std::sqrt(x); - } - if (x < 1E-6) { - return 1.0 / x; - } - - // Otherwise, use Newton's method - double y = 0.5 + 1.0 / x; - for (int i = 0; i < 50; i++) { - double tri = boost::math::trigamma(y); - double dif = tri * (1.0 - tri / x) / boost::math::polygamma(3, y); // tetragamma is polygamma(3, x) - y += dif; - if (-dif / y < 1E-8) break; - } - - return y; - } double a_; double b_; GammaSampler gamma_sampler_; - // SliceSampler slice_sampler_; - double tau_; }; /*! 
\brief Unifying layer for disparate sufficient statistic class types From 6ffdef7fe61a540dd7a44ccffe1e06b154dc7e63 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:56:33 -0400 Subject: [PATCH 19/34] Adding num_threads (back) to GFR interface --- include/stochtree/tree_sampler.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index fdaf507f..0901c0e0 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -301,8 +301,6 @@ static inline void UpdateResidualNewOutcome(ForestTracker& tracker, ColumnVector } } - - static inline void UpdateMeanModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num, bool requires_basis, std::function op, bool tree_new) { data_size_t n = dataset.GetCovariates().rows(); @@ -840,7 +838,7 @@ template & variable_weights, std::vector& sweep_update_indices, double global_variance, std::vector& feature_types, int cutpoint_grid_size, - bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + bool keep_forest, bool pre_initialized, bool backfitting, int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Run the GFR algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { @@ -860,7 +858,7 @@ static inline void GFRSampleOneIter(TreeEnsemble& active_forest, ForestTracker& GFRSampleTreeOneIter( tree, tracker, forests, leaf_model, dataset, residual, tree_prior, gen, variable_weights, i, global_variance, feature_types, cutpoint_grid_size, - num_features_subsample, leaf_suff_stat_args... + num_features_subsample, num_threads, leaf_suff_stat_args... 
); // Sample leaf parameters for tree i From 853b129fc6f85fcc0e7cebd4436b7d69db5f3166 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 12:59:55 -0400 Subject: [PATCH 20/34] Continue building in multithreading support to cloglog branch --- include/stochtree/tree_sampler.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 0901c0e0..c35ef080 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -744,7 +744,7 @@ template & variable_weights, int tree_num, double global_variance, std::vector& feature_types, int cutpoint_grid_size, - int num_features_subsample, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + int num_features_subsample, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { int root_id = Tree::kRoot; int curr_node_id; data_size_t curr_node_begin; @@ -800,7 +800,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore SampleSplitRule( tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types, - feature_subset, leaf_suff_stat_args...); + feature_subset, num_threads, leaf_suff_stat_args...); } } @@ -832,6 +832,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered). * \param num_features_subsample How many features to subsample when running the GFR algorithm. + * \param num_threads Number of threads to use for split evaluations and other compute-intensive operations. 
* \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object. */ template @@ -1142,7 +1143,7 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon). * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered). - * \param num_threads + * \param num_threads Number of threads to use for split evaluations and other compute-intensive operations. * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object. */ template From 9edad367b5a6cc44e27d8e79aaddc2047abe663c Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 13:02:49 -0400 Subject: [PATCH 21/34] Update tree_sampler.h --- include/stochtree/tree_sampler.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index c35ef080..1bcd42d2 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -513,7 +513,7 @@ static inline void EvaluateAllPossibleSplits( ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id, std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types, std::vector& feature_subset, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args + std::vector& feature_types, std::vector& feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args ) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -613,14 +613,14 @@ static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafMod std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, - std::vector& feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Evaluate all possible cutpoints according to the leaf node model, // recording their log-likelihood and other split information in a series of vectors. // The last element of these vectors concerns the "no-split" option. EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, - node_begin, node_end, variable_weights, feature_types, feature_subset, leaf_suff_stat_args... + node_begin, node_end, variable_weights, feature_types, feature_subset, num_threads, leaf_suff_stat_args... 
); // Compute an adjustment to reflect the no split prior probability and the number of cutpoints @@ -641,7 +641,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, std::unordered_map>& node_index_map, std::deque& split_queue, int node_id, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types, std::vector feature_subset, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& feature_types, std::vector feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { // Leaf depth int leaf_depth = tree->GetDepth(node_id); @@ -661,7 +661,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, - cutpoint_grid_container, feature_subset, leaf_suff_stat_args... + cutpoint_grid_container, feature_subset, num_threads, leaf_suff_stat_args... ); // TODO: maybe add some checks here? 
From 04de102e2b0252b1aa78fe31d555b7bb9a609de5 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 13:05:15 -0400 Subject: [PATCH 22/34] Updating GFR to reflect multithreading capabilities in the main branch --- include/stochtree/tree_sampler.h | 294 +++++++++++++++---------------- 1 file changed, 139 insertions(+), 155 deletions(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 1bcd42d2..4ab82ac9 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -508,134 +508,6 @@ static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafMode } } -template -static inline void EvaluateAllPossibleSplits( - ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id, - std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, - data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types, std::vector& feature_subset, int num_threads, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args -) { - // Initialize sufficient statistics - LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); - - // Accumulate aggregate sufficient statistic for the node to be split - AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, split_node_id); - - // Compute the "no split" log marginal likelihood - double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); - - // Unpack data - Eigen::MatrixXd covariates = dataset.GetCovariates(); - Eigen::VectorXd outcome = residual.GetData(); - Eigen::VectorXd var_weights; - bool has_weights = dataset.HasVarWeights(); - if (has_weights) var_weights = dataset.GetVarWeights(); - - // Minimum size of newly created leaf nodes (used to rule out invalid splits) - int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); - - // Compute sufficient statistics for each possible split - data_size_t num_cutpoints = 0; - bool valid_split = false; - data_size_t node_row_iter; - data_size_t current_bin_begin, current_bin_size, next_bin_begin; - data_size_t feature_sort_idx; - data_size_t row_iter_idx; - double outcome_val, outcome_val_sq; - FeatureType feature_type; - double feature_value = 0.0; - double cutoff_value = 0.0; - double log_split_eval = 0.0; - double split_log_ml; - for (int j = 0; j < covariates.cols(); j++) { - - if ((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) { - // Enumerate cutpoint strides - cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), split_node_id, node_begin, node_end, j, feature_types); - - // Reset sufficient statistics - left_suff_stat.ResetSuffStat(); - right_suff_stat.ResetSuffStat(); - - // Iterate through possible cutpoints - int32_t num_feature_cutpoints = 
cutpoint_grid_container.NumCutpoints(j); - feature_type = feature_types[j]; - // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins - for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { - current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); - current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j); - next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); - - // Accumulate sufficient statistics for the left node - AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, - global_variance, tree_num, split_node_id, j, cutpoint_idx); - - // Compute the corresponding right node sufficient statistics - right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); - - // Store the bin index as the "cutpoint value" - we can use this to query the actual split - // value or the set of split categories later on once a split is chose - cutoff_value = cutpoint_idx; - - // Only include cutpoint for consideration if it defines a valid split in the training data - valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && - right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); - if (valid_split) { - num_cutpoints++; - // Add to split rule vector - cutpoint_feature_types.push_back(feature_type); - cutpoint_features.push_back(j); - cutpoint_values.push_back(cutoff_value); - // Add the log marginal likelihood of the split to the split eval vector - split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); - log_cutpoint_evaluations.push_back(split_log_ml); - } - } - } - - } - - // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) - cutpoint_features.push_back(-1); - cutpoint_values.push_back(std::numeric_limits::max()); - 
cutpoint_feature_types.push_back(FeatureType::kNumeric); - log_cutpoint_evaluations.push_back(no_split_log_ml); - - // Update valid cutpoint count - valid_cutpoint_count = num_cutpoints; -} - -template -static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, - std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end, - std::vector& log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, - std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, - std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, - std::vector& feature_subset, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { - // Evaluate all possible cutpoints according to the leaf node model, - // recording their log-likelihood and other split information in a series of vectors. - // The last element of these vectors concerns the "no-split" option. - EvaluateAllPossibleSplits( - dataset, tracker, residual, tree_prior, gen, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations, - cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, - node_begin, node_end, variable_weights, feature_types, feature_subset, num_threads, leaf_suff_stat_args... 
- ); - - // Compute an adjustment to reflect the no split prior probability and the number of cutpoints - double bart_prior_no_split_adj; - double alpha = tree_prior.GetAlpha(); - double beta = tree_prior.GetBeta(); - int node_depth = tree->GetDepth(node_id); - if (valid_cutpoint_count == 0) { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); - } else { - bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); - } - log_cutpoint_evaluations[log_cutpoint_evaluations.size()-1] += bart_prior_no_split_adj; -} - template static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, @@ -649,41 +521,153 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel int32_t max_depth = tree_prior.GetMaxDepth(); if ((max_depth == -1) || (leaf_depth < max_depth)) { - - // Cutpoint enumeration - std::vector log_cutpoint_evaluations; - std::vector cutpoint_features; - std::vector cutpoint_values; - std::vector cutpoint_feature_types; + + // Vector of vectors to store results for each feature + int p = dataset.NumCovariates(); + std::vector> feature_log_cutpoint_evaluations(p+1); + std::vector> feature_cutpoint_values(p+1); + std::vector feature_cutpoint_counts(p+1, 0); StochTree::data_size_t valid_cutpoint_count; - CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); - EvaluateCutpoints( - tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, - cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, - cutpoint_grid_container, feature_subset, num_threads, 
leaf_suff_stat_args... - ); - // TODO: maybe add some checks here? - // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood - double largest_mll = *std::max_element(log_cutpoint_evaluations.begin(), log_cutpoint_evaluations.end()); - std::vector cutpoint_evaluations(log_cutpoint_evaluations.size()); - for (data_size_t i = 0; i < log_cutpoint_evaluations.size(); i++){ - cutpoint_evaluations[i] = std::exp(log_cutpoint_evaluations[i] - largest_mll); + // Evaluate all possible cutpoints according to the leaf node model, + // recording their log-likelihood and other split information in a series of vectors. + + // Initialize node sufficient statistics + LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); + + // Accumulate aggregate sufficient statistic for the node to be split + AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, residual, tree_num, node_id); + + // Compute the "no split" log marginal likelihood + double no_split_log_ml = leaf_model.NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); + + // Unpack data + Eigen::MatrixXd& covariates = dataset.GetCovariates(); + Eigen::VectorXd& outcome = residual.GetData(); + Eigen::VectorXd var_weights; + bool has_weights = dataset.HasVarWeights(); + if (has_weights) var_weights = dataset.GetVarWeights(); + + // Minimum size of newly created leaf nodes (used to rule out invalid splits) + int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); + + // Compute sufficient statistics for each possible split + data_size_t num_cutpoints = 0; + if (num_threads == -1) { + num_threads = GetOptimalThreadCount(static_cast(covariates.cols() * covariates.rows())); + } + + // Initialize cutpoint grid container + CutpointGridContainer cutpoint_grid_container(covariates, outcome, cutpoint_grid_size); + + // Evaluate all possible splits for each feature in parallel + StochTree::ParallelFor(0, covariates.cols(), num_threads, [&](int j) { + if 
((std::abs(variable_weights.at(j)) > kEpsilon) && (feature_subset[j])) { + // Enumerate cutpoint strides + cutpoint_grid_container.CalculateStrides(covariates, outcome, tracker.GetSortedNodeSampleTracker(), node_id, node_begin, node_end, j, feature_types); + + // Left and right node sufficient statistics + LeafSuffStat left_suff_stat = LeafSuffStat(leaf_suff_stat_args...); + LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); + + // Iterate through possible cutpoints + int32_t num_feature_cutpoints = cutpoint_grid_container.NumCutpoints(j); + FeatureType feature_type = feature_types[j]; + // Since we partition an entire cutpoint bin to the left, we must stop one bin before the total number of cutpoint bins + for (data_size_t cutpoint_idx = 0; cutpoint_idx < (num_feature_cutpoints - 1); cutpoint_idx++) { + data_size_t current_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx, j); + data_size_t current_bin_size = cutpoint_grid_container.BinLength(cutpoint_idx, j); + data_size_t next_bin_begin = cutpoint_grid_container.BinStartIndex(cutpoint_idx + 1, j); + + // Accumulate sufficient statistics for the left node + AccumulateCutpointBinSuffStat(left_suff_stat, tracker, cutpoint_grid_container, dataset, residual, + global_variance, tree_num, node_id, j, cutpoint_idx); + + // Compute the corresponding right node sufficient statistics + right_suff_stat.SubtractSuffStat(node_suff_stat, left_suff_stat); + + // Store the bin index as the "cutpoint value" - we can use this to query the actual split + // value or the set of split categories later on once a split is chose + double cutoff_value = cutpoint_idx; + + // Only include cutpoint for consideration if it defines a valid split in the training data + bool valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && + right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); + if (valid_split) { + feature_cutpoint_counts[j]++; + // Add to split rule vector + 
feature_cutpoint_values[j].push_back(cutoff_value); + // Add the log marginal likelihood of the split to the split eval vector + double split_log_ml = leaf_model.SplitLogMarginalLikelihood(left_suff_stat, right_suff_stat, global_variance); + feature_log_cutpoint_evaluations[j].push_back(split_log_ml); + } + } + } + }); + + // Compute total number of cutpoints + valid_cutpoint_count = std::accumulate(feature_cutpoint_counts.begin(), feature_cutpoint_counts.end(), 0); + + // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) + feature_log_cutpoint_evaluations[covariates.cols()].push_back(no_split_log_ml); + + // Compute an adjustment to reflect the no split prior probability and the number of cutpoints + double bart_prior_no_split_adj; + double alpha = tree_prior.GetAlpha(); + double beta = tree_prior.GetBeta(); + int node_depth = tree->GetDepth(node_id); + if (valid_cutpoint_count == 0) { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0); + } else { + bart_prior_no_split_adj = std::log(((std::pow(1+node_depth, beta))/alpha) - 1.0) + std::log(valid_cutpoint_count); } + feature_log_cutpoint_evaluations[covariates.cols()][0] += bart_prior_no_split_adj; + - // Sample the split (including a "no split" option) - std::discrete_distribution split_dist(cutpoint_evaluations.begin(), cutpoint_evaluations.end()); - data_size_t split_chosen = split_dist(gen); + // Convert log marginal likelihood to marginal likelihood, normalizing by the maximum log-likelihood + double largest_ml = -std::numeric_limits::infinity(); + for (int j = 0; j < p + 1; j++) { + if (feature_log_cutpoint_evaluations[j].size() > 0) { + double feature_max_ml = *std::max_element(feature_log_cutpoint_evaluations[j].begin(), feature_log_cutpoint_evaluations[j].end());; + largest_ml = std::max(largest_ml, feature_max_ml); + } + } + std::vector> feature_cutpoint_evaluations(p+1); + for (int j = 0; j < p + 1; 
j++) { + if (feature_log_cutpoint_evaluations[j].size() > 0) { + feature_cutpoint_evaluations[j].resize(feature_log_cutpoint_evaluations[j].size()); + for (int i = 0; i < feature_log_cutpoint_evaluations[j].size(); i++) { + feature_cutpoint_evaluations[j][i] = std::exp(feature_log_cutpoint_evaluations[j][i] - largest_ml); + } + } + } + + // Compute sum of marginal likelihoods for each feature + std::vector feature_total_cutpoint_evaluations(p+1, 0.0); + for (int j = 0; j < p + 1; j++) { + if (feature_log_cutpoint_evaluations[j].size() > 0) { + feature_total_cutpoint_evaluations[j] = std::accumulate(feature_cutpoint_evaluations[j].begin(), feature_cutpoint_evaluations[j].end(), 0.0); + } else { + feature_total_cutpoint_evaluations[j] = 0.0; + } + } + + // First, sample a feature according to feature_total_cutpoint_evaluations + std::discrete_distribution feature_dist(feature_total_cutpoint_evaluations.begin(), feature_total_cutpoint_evaluations.end()); + int feature_chosen = feature_dist(gen); + + // Then, sample a cutpoint according to feature_cutpoint_evaluations[feature_chosen] + std::discrete_distribution cutpoint_dist(feature_cutpoint_evaluations[feature_chosen].begin(), feature_cutpoint_evaluations[feature_chosen].end()); + data_size_t cutpoint_chosen = cutpoint_dist(gen); - if (split_chosen == valid_cutpoint_count){ + if (feature_chosen == p){ // "No split" sampled, don't split or add any nodes to split queue return; } else { // Split sampled - int feature_split = cutpoint_features[split_chosen]; - FeatureType feature_type = cutpoint_feature_types[split_chosen]; - double split_value = cutpoint_values[split_chosen]; + int feature_split = feature_chosen; + FeatureType feature_type = feature_types[feature_split]; + double split_value = feature_cutpoint_values[feature_split][cutpoint_chosen]; // Perform all of the relevant "split" operations in the model, tree and training dataset // Compute node sample size @@ -718,7 +702,7 @@ static inline void 
SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel } // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads); // Determine the number of observation in the newly created left node int left_node = tree->LeftChild(node_id); From cdca9156e7609e521c5a767a26ef73055b0518fd Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 13:18:34 -0400 Subject: [PATCH 23/34] Reflecting num_threads through the MCMC and GFR interface --- R/cloglog_ordinal_bart.R | 12 +- R/cpp11.R | 12 +- include/stochtree/tree_sampler.h | 4 +- man/ForestDataset.Rd | 27 -- man/cloglog_ordinal_bart.Rd | 11 +- src/cpp11.cpp | 467 +++++++++++++++---------------- src/sampler.cpp | 13 +- 7 files changed, 259 insertions(+), 287 deletions(-) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index 9932df1c..e566fb51 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -9,8 +9,10 @@ #' @param n_thin Thinning interval for MCMC samples. Default: `1`. #' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`. #' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`. -#' @param variable_weights Optional vector of variable weights for splitting (default: equal weights). -#' @param feature_types Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). +#' @param variable_weights (Optional) vector of variable weights for splitting (default: equal weights). +#' @param feature_types (Optional) vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). +#' @param seed (Optional) random seed for reproducibility. 
+#' @param num_threads (Optional) Number of threads to use in split evaluations and other compute-intensive operations. Default: 1. #' @export cloglog_ordinal_bart <- function(X, y, X_test = NULL, n_trees = 50, @@ -21,7 +23,8 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, beta_gamma = 2.0, variable_weights = NULL, feature_types = NULL, - seed = NULL) { + seed = NULL, + num_threads = 1) { # BART parameters alpha_bart <- 0.95 @@ -148,7 +151,8 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr, active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr, sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size), - scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample + scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample, + num_threads ) # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions diff --git a/R/cpp11.R b/R/cpp11.R index 995dc82a..c714cc75 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -80,10 +80,6 @@ forest_dataset_store_auxiliary_data_vector_as_column_cpp <- function(dataset_ptr .Call(`_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp`, dataset_ptr, output_matrix, dim_idx, matrix_col_idx) } -forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum <- function(dataset_ptr, reference_vector_idx, target_vector_idx) { - invisible(.Call(`_stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum`, dataset_ptr, reference_vector_idx, target_vector_idx)) -} - create_column_vector_cpp <- function(outcome) { .Call(`_stochtree_create_column_vector_cpp`, outcome) } @@ -652,12 +648,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) .Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums) } -sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, 
active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample) { - invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample)) +sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads) { + invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_features_subsample, num_threads)) } -sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { - invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) +sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, 
global_variance, leaf_model_int, keep_forest, num_threads) { + invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, sweep_indices, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, num_threads)) } sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) { diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 4ab82ac9..7c0254c6 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -702,7 +702,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel } // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads); + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); // Determine the number of observation in the newly created left node int left_node = tree->LeftChild(node_id); @@ -1053,7 +1053,7 @@ static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, Leaf template static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - int tree_num, double global_variance, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Determine whether it is possible to grow any of the leaves bool grow_possible = false; std::vector leaves = tree->GetLeaves(); diff --git a/man/ForestDataset.Rd b/man/ForestDataset.Rd index ac6e34d5..e684cd18 100644 --- a/man/ForestDataset.Rd +++ b/man/ForestDataset.Rd @@ -35,7 +35,6 @@ weights are optional. \item \href{#method-ForestDataset-set_auxiliary_data_value}{\code{ForestDataset$set_auxiliary_data_value()}} \item \href{#method-ForestDataset-get_auxiliary_data_vector}{\code{ForestDataset$get_auxiliary_data_vector()}} \item \href{#method-ForestDataset-store_auxiliary_data_vector_matrix}{\code{ForestDataset$store_auxiliary_data_vector_matrix()}} -\item \href{#method-ForestDataset-update_auxiliary_data_vector_cumulative_exp_sum}{\code{ForestDataset$update_auxiliary_data_vector_cumulative_exp_sum()}} } } \if{html}{\out{
}} @@ -336,30 +335,4 @@ Retrieve auxiliary data vector and place it into a column of the supplied matrix Vector of all of the auxiliary data stored at dimension \code{dim_idx} } } -\if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ForestDataset-update_auxiliary_data_vector_cumulative_exp_sum}{}}} -\subsection{Method \code{update_auxiliary_data_vector_cumulative_exp_sum()}}{ -Updates the elements of one auxiliary data vector based on the cumulative exponentiated values of elements of another vector. -If the target value has \code{k} elements, the reference vector must have \code{k - 1} elements. -\subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ForestDataset$update_auxiliary_data_vector_cumulative_exp_sum( - reference_vector_idx, - target_vector_idx -)}\if{html}{\out{
}} -} - -\subsection{Arguments}{ -\if{html}{\out{
}} -\describe{ -\item{\code{reference_vector_idx}}{Index of the auxiliary data vector to be exponentiated and scanned} - -\item{\code{target_vector_idx}}{Index of the auxiliary data vector to be written with exponentiated and scanned values of \code{reference_vector_idx}} -} -\if{html}{\out{
}} -} -\subsection{Returns}{ -None -} -} } diff --git a/man/cloglog_ordinal_bart.Rd b/man/cloglog_ordinal_bart.Rd index 9c2aed51..6cdba1c2 100644 --- a/man/cloglog_ordinal_bart.Rd +++ b/man/cloglog_ordinal_bart.Rd @@ -16,7 +16,8 @@ cloglog_ordinal_bart( beta_gamma = 2, variable_weights = NULL, feature_types = NULL, - seed = NULL + seed = NULL, + num_threads = 1 ) } \arguments{ @@ -38,9 +39,13 @@ cloglog_ordinal_bart( \item{beta_gamma}{Rate parameter for the log-gamma prior on cutpoints. Default: \code{2.0}.} -\item{variable_weights}{Optional vector of variable weights for splitting (default: equal weights).} +\item{variable_weights}{(Optional) vector of variable weights for splitting (default: equal weights).} -\item{feature_types}{Optional vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).} +\item{feature_types}{(Optional) vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).} + +\item{seed}{(Optional) random seed for reproducibility.} + +\item{num_threads}{(Optional) Number of threads to use in split evaluations and other compute-intensive operations. 
Default: 1.} } \description{ Run the BART algorithm for ordinal outcomes using a complementary log-log link diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 3336adca..0f20cdcb 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -153,14 +153,6 @@ extern "C" SEXP _stochtree_forest_dataset_store_auxiliary_data_vector_as_column_ END_CPP11 } // R_data.cpp -void forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(cpp11::external_pointer dataset_ptr, int reference_vector_idx, int target_vector_idx); -extern "C" SEXP _stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(SEXP dataset_ptr, SEXP reference_vector_idx, SEXP target_vector_idx) { - BEGIN_CPP11 - forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum(cpp11::as_cpp>>(dataset_ptr), cpp11::as_cpp>(reference_vector_idx), cpp11::as_cpp>(target_vector_idx)); - return R_NilValue; - END_CPP11 -} -// R_data.cpp cpp11::external_pointer create_column_vector_cpp(cpp11::doubles outcome); extern "C" SEXP _stochtree_create_column_vector_cpp(SEXP outcome) { BEGIN_CPP11 @@ -1209,18 +1201,18 @@ extern "C" SEXP _stochtree_compute_leaf_indices_cpp(SEXP forest_container, SEXP END_CPP11 } // sampler.cpp -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_features_subsample); -extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP 
leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_features_subsample) { +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_features_subsample, int num_threads); +extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_features_subsample, SEXP num_threads) { BEGIN_CPP11 - sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_features_subsample)); + sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), 
cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_features_subsample), cpp11::as_cpp>(num_threads)); return R_NilValue; END_CPP11 } // sampler.cpp -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); -extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers sweep_indices, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, int num_threads); +extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, 
SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP sweep_indices, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP num_threads) { BEGIN_CPP11 - sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); + sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(sweep_indices), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(num_threads)); return R_NilValue; END_CPP11 } @@ -1665,230 +1657,229 @@ extern "C" SEXP _stochtree_json_load_string_cpp(SEXP json_ptr, SEXP json_string) extern "C" { static const R_CallMethodDef CallEntries[] = { - {"_stochtree_active_forest_cpp", (DL_FUNC) &_stochtree_active_forest_cpp, 4}, - {"_stochtree_add_numeric_split_tree_value_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_active_forest_cpp, 7}, - {"_stochtree_add_numeric_split_tree_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_forest_container_cpp, 
8}, - {"_stochtree_add_numeric_split_tree_vector_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_active_forest_cpp, 7}, - {"_stochtree_add_numeric_split_tree_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_forest_container_cpp, 8}, - {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, - {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2}, - {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, - {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2}, - {"_stochtree_add_to_forest_forest_container_cpp", (DL_FUNC) &_stochtree_add_to_forest_forest_container_cpp, 3}, - {"_stochtree_adjust_residual_active_forest_cpp", (DL_FUNC) &_stochtree_adjust_residual_active_forest_cpp, 6}, - {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, - {"_stochtree_all_roots_active_forest_cpp", (DL_FUNC) &_stochtree_all_roots_active_forest_cpp, 1}, - {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, - {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_average_max_depth_active_forest_cpp, 1}, - {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, - {"_stochtree_combine_forests_forest_container_cpp", (DL_FUNC) &_stochtree_combine_forests_forest_container_cpp, 2}, - {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, - {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, - {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, - {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) 
&_stochtree_create_rfx_dataset_cpp, 0}, - {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, - {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, - {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, - {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, - {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, - {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, - {"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2}, - {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, - {"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2}, - {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, - {"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, - {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 4}, - {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, - {"_stochtree_forest_container_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_string_cpp, 2}, - {"_stochtree_forest_container_get_max_leaf_index_cpp", (DL_FUNC) &_stochtree_forest_container_get_max_leaf_index_cpp, 2}, - {"_stochtree_forest_dataset_add_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_auxiliary_dimension_cpp, 2}, - {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, - {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) 
&_stochtree_forest_dataset_add_covariates_cpp, 2}, - {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, - {"_stochtree_forest_dataset_get_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_value_cpp, 3}, - {"_stochtree_forest_dataset_get_auxiliary_data_vector_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_vector_cpp, 2}, - {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, - {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, - {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, - {"_stochtree_forest_dataset_has_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_has_auxiliary_dimension_cpp, 2}, - {"_stochtree_forest_dataset_set_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_set_auxiliary_data_value_cpp, 4}, - {"_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp", (DL_FUNC) &_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp, 4}, - {"_stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum", (DL_FUNC) &_stochtree_forest_dataset_update_auxiliary_data_vector_cumulative_exp_sum, 3}, - {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, - {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, - {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, - {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, - {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, - {"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1}, - {"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) 
&_stochtree_get_beta_tree_prior_cpp, 1}, - {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1}, - {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, - {"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, - {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, - {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1}, - {"_stochtree_get_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_get_max_depth_tree_prior_cpp, 1}, - {"_stochtree_get_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_get_min_samples_leaf_tree_prior_cpp, 1}, - {"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2}, - {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2}, - {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1}, - {"_stochtree_get_tree_leaves_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_active_forest_cpp, 2}, - {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3}, - {"_stochtree_get_tree_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_active_forest_cpp, 3}, - {"_stochtree_get_tree_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_forest_container_cpp, 4}, - {"_stochtree_init_json_cpp", (DL_FUNC) &_stochtree_init_json_cpp, 0}, - {"_stochtree_initialize_forest_model_active_forest_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_active_forest_cpp, 6}, - {"_stochtree_initialize_forest_model_cpp", (DL_FUNC) 
&_stochtree_initialize_forest_model_cpp, 6}, - {"_stochtree_is_categorical_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_categorical_split_node_forest_container_cpp, 4}, - {"_stochtree_is_exponentiated_active_forest_cpp", (DL_FUNC) &_stochtree_is_exponentiated_active_forest_cpp, 1}, - {"_stochtree_is_exponentiated_forest_container_cpp", (DL_FUNC) &_stochtree_is_exponentiated_forest_container_cpp, 1}, - {"_stochtree_is_leaf_constant_active_forest_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_active_forest_cpp, 1}, - {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, - {"_stochtree_is_leaf_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_node_forest_container_cpp, 4}, - {"_stochtree_is_numeric_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_numeric_split_node_forest_container_cpp, 4}, - {"_stochtree_json_add_bool_cpp", (DL_FUNC) &_stochtree_json_add_bool_cpp, 3}, - {"_stochtree_json_add_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_bool_subfolder_cpp, 4}, - {"_stochtree_json_add_double_cpp", (DL_FUNC) &_stochtree_json_add_double_cpp, 3}, - {"_stochtree_json_add_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_double_subfolder_cpp, 4}, - {"_stochtree_json_add_forest_cpp", (DL_FUNC) &_stochtree_json_add_forest_cpp, 2}, - {"_stochtree_json_add_integer_cpp", (DL_FUNC) &_stochtree_json_add_integer_cpp, 3}, - {"_stochtree_json_add_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_subfolder_cpp, 4}, - {"_stochtree_json_add_integer_vector_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_cpp, 3}, - {"_stochtree_json_add_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_subfolder_cpp, 4}, - {"_stochtree_json_add_rfx_container_cpp", (DL_FUNC) &_stochtree_json_add_rfx_container_cpp, 2}, - {"_stochtree_json_add_rfx_groupids_cpp", (DL_FUNC) &_stochtree_json_add_rfx_groupids_cpp, 2}, - {"_stochtree_json_add_rfx_label_mapper_cpp", 
(DL_FUNC) &_stochtree_json_add_rfx_label_mapper_cpp, 2}, - {"_stochtree_json_add_string_cpp", (DL_FUNC) &_stochtree_json_add_string_cpp, 3}, - {"_stochtree_json_add_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_subfolder_cpp, 4}, - {"_stochtree_json_add_string_vector_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_cpp, 3}, - {"_stochtree_json_add_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_subfolder_cpp, 4}, - {"_stochtree_json_add_vector_cpp", (DL_FUNC) &_stochtree_json_add_vector_cpp, 3}, - {"_stochtree_json_add_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_vector_subfolder_cpp, 4}, - {"_stochtree_json_contains_field_cpp", (DL_FUNC) &_stochtree_json_contains_field_cpp, 2}, - {"_stochtree_json_contains_field_subfolder_cpp", (DL_FUNC) &_stochtree_json_contains_field_subfolder_cpp, 3}, - {"_stochtree_json_extract_bool_cpp", (DL_FUNC) &_stochtree_json_extract_bool_cpp, 2}, - {"_stochtree_json_extract_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_bool_subfolder_cpp, 3}, - {"_stochtree_json_extract_double_cpp", (DL_FUNC) &_stochtree_json_extract_double_cpp, 2}, - {"_stochtree_json_extract_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_double_subfolder_cpp, 3}, - {"_stochtree_json_extract_integer_cpp", (DL_FUNC) &_stochtree_json_extract_integer_cpp, 2}, - {"_stochtree_json_extract_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_subfolder_cpp, 3}, - {"_stochtree_json_extract_integer_vector_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_cpp, 2}, - {"_stochtree_json_extract_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_subfolder_cpp, 3}, - {"_stochtree_json_extract_string_cpp", (DL_FUNC) &_stochtree_json_extract_string_cpp, 2}, - {"_stochtree_json_extract_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_subfolder_cpp, 3}, - {"_stochtree_json_extract_string_vector_cpp", (DL_FUNC) 
&_stochtree_json_extract_string_vector_cpp, 2}, - {"_stochtree_json_extract_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_subfolder_cpp, 3}, - {"_stochtree_json_extract_vector_cpp", (DL_FUNC) &_stochtree_json_extract_vector_cpp, 2}, - {"_stochtree_json_extract_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_vector_subfolder_cpp, 3}, - {"_stochtree_json_increment_rfx_count_cpp", (DL_FUNC) &_stochtree_json_increment_rfx_count_cpp, 1}, - {"_stochtree_json_load_file_cpp", (DL_FUNC) &_stochtree_json_load_file_cpp, 2}, - {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, - {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2}, - {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2}, - {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, - {"_stochtree_leaf_dimension_active_forest_cpp", (DL_FUNC) &_stochtree_leaf_dimension_active_forest_cpp, 1}, - {"_stochtree_leaf_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_dimension_forest_container_cpp, 1}, - {"_stochtree_leaf_values_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_values_forest_container_cpp, 4}, - {"_stochtree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_leaves_forest_container_cpp, 3}, - {"_stochtree_left_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_left_child_node_forest_container_cpp, 4}, - {"_stochtree_multiply_forest_forest_container_cpp", (DL_FUNC) &_stochtree_multiply_forest_forest_container_cpp, 3}, - {"_stochtree_node_depth_forest_container_cpp", (DL_FUNC) &_stochtree_node_depth_forest_container_cpp, 4}, - {"_stochtree_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_nodes_forest_container_cpp, 3}, - {"_stochtree_num_leaf_parents_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaf_parents_forest_container_cpp, 3}, - {"_stochtree_num_leaves_ensemble_forest_container_cpp", 
(DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2}, - {"_stochtree_num_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_forest_container_cpp, 3}, - {"_stochtree_num_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_nodes_forest_container_cpp, 3}, - {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, - {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, - {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, - {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, - {"_stochtree_ordinal_sampler_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_cpp, 0}, - {"_stochtree_ordinal_sampler_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_cumsum_exp_cpp, 2}, - {"_stochtree_ordinal_sampler_update_gamma_params_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_gamma_params_cpp, 7}, - {"_stochtree_ordinal_sampler_update_latent_variables_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_latent_variables_cpp, 4}, - {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, - {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, - {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, - {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, - {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, - {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, - {"_stochtree_predict_forest_raw_single_tree_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_tree_cpp, 4}, - {"_stochtree_predict_raw_active_forest_cpp", (DL_FUNC) &_stochtree_predict_raw_active_forest_cpp, 2}, - 
{"_stochtree_propagate_basis_update_active_forest_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_active_forest_cpp, 4}, - {"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5}, - {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, - {"_stochtree_remove_sample_forest_container_cpp", (DL_FUNC) &_stochtree_remove_sample_forest_container_cpp, 2}, - {"_stochtree_reset_active_forest_cpp", (DL_FUNC) &_stochtree_reset_active_forest_cpp, 3}, - {"_stochtree_reset_forest_model_cpp", (DL_FUNC) &_stochtree_reset_forest_model_cpp, 5}, - {"_stochtree_reset_rfx_model_cpp", (DL_FUNC) &_stochtree_reset_rfx_model_cpp, 3}, - {"_stochtree_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_reset_rfx_tracker_cpp, 4}, - {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, - {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, - {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, - {"_stochtree_rfx_container_delete_sample_cpp", (DL_FUNC) &_stochtree_rfx_container_delete_sample_cpp, 2}, - {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_cpp, 2}, - {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, - {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, - {"_stochtree_rfx_container_get_beta_cpp", (DL_FUNC) &_stochtree_rfx_container_get_beta_cpp, 1}, - {"_stochtree_rfx_container_get_sigma_cpp", (DL_FUNC) &_stochtree_rfx_container_get_sigma_cpp, 1}, - {"_stochtree_rfx_container_get_xi_cpp", (DL_FUNC) &_stochtree_rfx_container_get_xi_cpp, 1}, - {"_stochtree_rfx_container_num_components_cpp", (DL_FUNC) &_stochtree_rfx_container_num_components_cpp, 1}, - 
{"_stochtree_rfx_container_num_groups_cpp", (DL_FUNC) &_stochtree_rfx_container_num_groups_cpp, 1}, - {"_stochtree_rfx_container_num_samples_cpp", (DL_FUNC) &_stochtree_rfx_container_num_samples_cpp, 1}, - {"_stochtree_rfx_container_predict_cpp", (DL_FUNC) &_stochtree_rfx_container_predict_cpp, 3}, - {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, - {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, - {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, - {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, - {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, - {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, - {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, - {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, - {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, - {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, - {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, - {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, - {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2}, - {"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, - {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, - {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, - 
{"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, - {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, - {"_stochtree_rfx_label_mapper_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_string_cpp, 2}, - {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, - {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, - {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, - {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 8}, - {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, - {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, - {"_stochtree_rfx_model_set_variance_prior_scale_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_scale_cpp, 2}, - {"_stochtree_rfx_model_set_variance_prior_shape_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_shape_cpp, 2}, - {"_stochtree_rfx_model_set_working_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_covariance_cpp, 2}, - {"_stochtree_rfx_model_set_working_parameter_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_cpp, 2}, - {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, - {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, - {"_stochtree_right_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_right_child_node_forest_container_cpp, 4}, - {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, - {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) 
&_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 18}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, - {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, - {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, - {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3}, - {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, - {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, - {"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2}, - {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, - {"_stochtree_split_categories_forest_container_cpp", (DL_FUNC) &_stochtree_split_categories_forest_container_cpp, 4}, - {"_stochtree_split_index_forest_container_cpp", (DL_FUNC) &_stochtree_split_index_forest_container_cpp, 4}, - {"_stochtree_split_theshold_forest_container_cpp", (DL_FUNC) &_stochtree_split_theshold_forest_container_cpp, 4}, - {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2}, - {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2}, - {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, - {"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2}, - {"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2}, - {"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) 
&_stochtree_update_max_depth_tree_prior_cpp, 2}, - {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2}, + {"_stochtree_active_forest_cpp", (DL_FUNC) &_stochtree_active_forest_cpp, 4}, + {"_stochtree_add_numeric_split_tree_value_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_active_forest_cpp, 7}, + {"_stochtree_add_numeric_split_tree_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_forest_container_cpp, 8}, + {"_stochtree_add_numeric_split_tree_vector_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_active_forest_cpp, 7}, + {"_stochtree_add_numeric_split_tree_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_forest_container_cpp, 8}, + {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, + {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2}, + {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, + {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2}, + {"_stochtree_add_to_forest_forest_container_cpp", (DL_FUNC) &_stochtree_add_to_forest_forest_container_cpp, 3}, + {"_stochtree_adjust_residual_active_forest_cpp", (DL_FUNC) &_stochtree_adjust_residual_active_forest_cpp, 6}, + {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, + {"_stochtree_all_roots_active_forest_cpp", (DL_FUNC) &_stochtree_all_roots_active_forest_cpp, 1}, + {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_average_max_depth_active_forest_cpp, 1}, + {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) 
&_stochtree_average_max_depth_forest_container_cpp, 1}, + {"_stochtree_combine_forests_forest_container_cpp", (DL_FUNC) &_stochtree_combine_forests_forest_container_cpp, 2}, + {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, + {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, + {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, + {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) &_stochtree_create_rfx_dataset_cpp, 0}, + {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, + {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, + {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, + {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, + {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, + {"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2}, + {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, + {"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, + {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 4}, + {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, + 
{"_stochtree_forest_container_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_string_cpp, 2}, + {"_stochtree_forest_container_get_max_leaf_index_cpp", (DL_FUNC) &_stochtree_forest_container_get_max_leaf_index_cpp, 2}, + {"_stochtree_forest_dataset_add_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_auxiliary_dimension_cpp, 2}, + {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, + {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, + {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, + {"_stochtree_forest_dataset_get_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_value_cpp, 3}, + {"_stochtree_forest_dataset_get_auxiliary_data_vector_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_vector_cpp, 2}, + {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, + {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, + {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, + {"_stochtree_forest_dataset_has_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_has_auxiliary_dimension_cpp, 2}, + {"_stochtree_forest_dataset_set_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_set_auxiliary_data_value_cpp, 4}, + {"_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp", (DL_FUNC) &_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp, 4}, + {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, + {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, + {"_stochtree_forest_merge_cpp", (DL_FUNC) 
&_stochtree_forest_merge_cpp, 2}, + {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, + {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, + {"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1}, + {"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1}, + {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1}, + {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, + {"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, + {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, + {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1}, + {"_stochtree_get_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_get_max_depth_tree_prior_cpp, 1}, + {"_stochtree_get_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_get_min_samples_leaf_tree_prior_cpp, 1}, + {"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2}, + {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2}, + {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1}, + {"_stochtree_get_tree_leaves_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_active_forest_cpp, 2}, + {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3}, + {"_stochtree_get_tree_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_active_forest_cpp, 3}, + {"_stochtree_get_tree_split_counts_forest_container_cpp", 
(DL_FUNC) &_stochtree_get_tree_split_counts_forest_container_cpp, 4}, + {"_stochtree_init_json_cpp", (DL_FUNC) &_stochtree_init_json_cpp, 0}, + {"_stochtree_initialize_forest_model_active_forest_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_active_forest_cpp, 6}, + {"_stochtree_initialize_forest_model_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_cpp, 6}, + {"_stochtree_is_categorical_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_categorical_split_node_forest_container_cpp, 4}, + {"_stochtree_is_exponentiated_active_forest_cpp", (DL_FUNC) &_stochtree_is_exponentiated_active_forest_cpp, 1}, + {"_stochtree_is_exponentiated_forest_container_cpp", (DL_FUNC) &_stochtree_is_exponentiated_forest_container_cpp, 1}, + {"_stochtree_is_leaf_constant_active_forest_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_active_forest_cpp, 1}, + {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, + {"_stochtree_is_leaf_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_node_forest_container_cpp, 4}, + {"_stochtree_is_numeric_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_numeric_split_node_forest_container_cpp, 4}, + {"_stochtree_json_add_bool_cpp", (DL_FUNC) &_stochtree_json_add_bool_cpp, 3}, + {"_stochtree_json_add_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_bool_subfolder_cpp, 4}, + {"_stochtree_json_add_double_cpp", (DL_FUNC) &_stochtree_json_add_double_cpp, 3}, + {"_stochtree_json_add_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_double_subfolder_cpp, 4}, + {"_stochtree_json_add_forest_cpp", (DL_FUNC) &_stochtree_json_add_forest_cpp, 2}, + {"_stochtree_json_add_integer_cpp", (DL_FUNC) &_stochtree_json_add_integer_cpp, 3}, + {"_stochtree_json_add_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_subfolder_cpp, 4}, + {"_stochtree_json_add_integer_vector_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_cpp, 3}, + 
{"_stochtree_json_add_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_subfolder_cpp, 4}, + {"_stochtree_json_add_rfx_container_cpp", (DL_FUNC) &_stochtree_json_add_rfx_container_cpp, 2}, + {"_stochtree_json_add_rfx_groupids_cpp", (DL_FUNC) &_stochtree_json_add_rfx_groupids_cpp, 2}, + {"_stochtree_json_add_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_json_add_rfx_label_mapper_cpp, 2}, + {"_stochtree_json_add_string_cpp", (DL_FUNC) &_stochtree_json_add_string_cpp, 3}, + {"_stochtree_json_add_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_subfolder_cpp, 4}, + {"_stochtree_json_add_string_vector_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_cpp, 3}, + {"_stochtree_json_add_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_subfolder_cpp, 4}, + {"_stochtree_json_add_vector_cpp", (DL_FUNC) &_stochtree_json_add_vector_cpp, 3}, + {"_stochtree_json_add_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_vector_subfolder_cpp, 4}, + {"_stochtree_json_contains_field_cpp", (DL_FUNC) &_stochtree_json_contains_field_cpp, 2}, + {"_stochtree_json_contains_field_subfolder_cpp", (DL_FUNC) &_stochtree_json_contains_field_subfolder_cpp, 3}, + {"_stochtree_json_extract_bool_cpp", (DL_FUNC) &_stochtree_json_extract_bool_cpp, 2}, + {"_stochtree_json_extract_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_bool_subfolder_cpp, 3}, + {"_stochtree_json_extract_double_cpp", (DL_FUNC) &_stochtree_json_extract_double_cpp, 2}, + {"_stochtree_json_extract_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_double_subfolder_cpp, 3}, + {"_stochtree_json_extract_integer_cpp", (DL_FUNC) &_stochtree_json_extract_integer_cpp, 2}, + {"_stochtree_json_extract_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_subfolder_cpp, 3}, + {"_stochtree_json_extract_integer_vector_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_cpp, 2}, + {"_stochtree_json_extract_integer_vector_subfolder_cpp", (DL_FUNC) 
&_stochtree_json_extract_integer_vector_subfolder_cpp, 3}, + {"_stochtree_json_extract_string_cpp", (DL_FUNC) &_stochtree_json_extract_string_cpp, 2}, + {"_stochtree_json_extract_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_subfolder_cpp, 3}, + {"_stochtree_json_extract_string_vector_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_cpp, 2}, + {"_stochtree_json_extract_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_subfolder_cpp, 3}, + {"_stochtree_json_extract_vector_cpp", (DL_FUNC) &_stochtree_json_extract_vector_cpp, 2}, + {"_stochtree_json_extract_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_vector_subfolder_cpp, 3}, + {"_stochtree_json_increment_rfx_count_cpp", (DL_FUNC) &_stochtree_json_increment_rfx_count_cpp, 1}, + {"_stochtree_json_load_file_cpp", (DL_FUNC) &_stochtree_json_load_file_cpp, 2}, + {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, + {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2}, + {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2}, + {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, + {"_stochtree_leaf_dimension_active_forest_cpp", (DL_FUNC) &_stochtree_leaf_dimension_active_forest_cpp, 1}, + {"_stochtree_leaf_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_dimension_forest_container_cpp, 1}, + {"_stochtree_leaf_values_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_values_forest_container_cpp, 4}, + {"_stochtree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_leaves_forest_container_cpp, 3}, + {"_stochtree_left_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_left_child_node_forest_container_cpp, 4}, + {"_stochtree_multiply_forest_forest_container_cpp", (DL_FUNC) &_stochtree_multiply_forest_forest_container_cpp, 3}, + {"_stochtree_node_depth_forest_container_cpp", (DL_FUNC) 
&_stochtree_node_depth_forest_container_cpp, 4}, + {"_stochtree_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_nodes_forest_container_cpp, 3}, + {"_stochtree_num_leaf_parents_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaf_parents_forest_container_cpp, 3}, + {"_stochtree_num_leaves_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2}, + {"_stochtree_num_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_forest_container_cpp, 3}, + {"_stochtree_num_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_nodes_forest_container_cpp, 3}, + {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, + {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, + {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, + {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, + {"_stochtree_ordinal_sampler_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_cpp, 0}, + {"_stochtree_ordinal_sampler_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_cumsum_exp_cpp, 2}, + {"_stochtree_ordinal_sampler_update_gamma_params_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_gamma_params_cpp, 7}, + {"_stochtree_ordinal_sampler_update_latent_variables_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_latent_variables_cpp, 4}, + {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, + {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, + {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, + {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, + {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, + 
{"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, + {"_stochtree_predict_forest_raw_single_tree_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_tree_cpp, 4}, + {"_stochtree_predict_raw_active_forest_cpp", (DL_FUNC) &_stochtree_predict_raw_active_forest_cpp, 2}, + {"_stochtree_propagate_basis_update_active_forest_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_active_forest_cpp, 4}, + {"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5}, + {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, + {"_stochtree_remove_sample_forest_container_cpp", (DL_FUNC) &_stochtree_remove_sample_forest_container_cpp, 2}, + {"_stochtree_reset_active_forest_cpp", (DL_FUNC) &_stochtree_reset_active_forest_cpp, 3}, + {"_stochtree_reset_forest_model_cpp", (DL_FUNC) &_stochtree_reset_forest_model_cpp, 5}, + {"_stochtree_reset_rfx_model_cpp", (DL_FUNC) &_stochtree_reset_rfx_model_cpp, 3}, + {"_stochtree_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_reset_rfx_tracker_cpp, 4}, + {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, + {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, + {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, + {"_stochtree_rfx_container_delete_sample_cpp", (DL_FUNC) &_stochtree_rfx_container_delete_sample_cpp, 2}, + {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_cpp, 2}, + {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, + {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, + {"_stochtree_rfx_container_get_beta_cpp", (DL_FUNC) 
&_stochtree_rfx_container_get_beta_cpp, 1}, + {"_stochtree_rfx_container_get_sigma_cpp", (DL_FUNC) &_stochtree_rfx_container_get_sigma_cpp, 1}, + {"_stochtree_rfx_container_get_xi_cpp", (DL_FUNC) &_stochtree_rfx_container_get_xi_cpp, 1}, + {"_stochtree_rfx_container_num_components_cpp", (DL_FUNC) &_stochtree_rfx_container_num_components_cpp, 1}, + {"_stochtree_rfx_container_num_groups_cpp", (DL_FUNC) &_stochtree_rfx_container_num_groups_cpp, 1}, + {"_stochtree_rfx_container_num_samples_cpp", (DL_FUNC) &_stochtree_rfx_container_num_samples_cpp, 1}, + {"_stochtree_rfx_container_predict_cpp", (DL_FUNC) &_stochtree_rfx_container_predict_cpp, 3}, + {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, + {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, + {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, + {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, + {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, + {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, + {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, + {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, + {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, + {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, + {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) 
&_stochtree_rfx_dataset_update_group_labels_cpp, 2}, + {"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, + {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, + {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, + {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, + {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, + {"_stochtree_rfx_label_mapper_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_string_cpp, 2}, + {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, + {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, + {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, + {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 8}, + {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, + {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, + {"_stochtree_rfx_model_set_variance_prior_scale_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_scale_cpp, 2}, + {"_stochtree_rfx_model_set_variance_prior_shape_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_shape_cpp, 2}, + {"_stochtree_rfx_model_set_working_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_covariance_cpp, 2}, + {"_stochtree_rfx_model_set_working_parameter_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_cpp, 2}, + {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, + {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) 
&_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, + {"_stochtree_right_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_right_child_node_forest_container_cpp, 4}, + {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, + {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, + {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 19}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 18}, + {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, + {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, + {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3}, + {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, + {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, + {"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2}, + {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, + {"_stochtree_split_categories_forest_container_cpp", (DL_FUNC) &_stochtree_split_categories_forest_container_cpp, 4}, + {"_stochtree_split_index_forest_container_cpp", (DL_FUNC) &_stochtree_split_index_forest_container_cpp, 4}, + {"_stochtree_split_theshold_forest_container_cpp", (DL_FUNC) &_stochtree_split_theshold_forest_container_cpp, 4}, + {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2}, + {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) 
&_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2}, + {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, + {"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2}, + {"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2}, + {"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2}, + {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2}, {NULL, NULL, 0} }; } diff --git a/src/sampler.cpp b/src/sampler.cpp index 2e80aa4d..91169e9a 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -27,7 +27,8 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); 
} else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_basis); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample); + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); } } From bf2144730adc676ab07636e364c449dd7871fb9d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 17:06:40 -0500 Subject: [PATCH 24/34] Set up cloglog to work with GFR and updated examples --- R/cloglog_ordinal_bart.R | 44 ++++++++++++------- man/cloglog_ordinal_bart.Rd | 11 
+++-- src/leaf_model.cpp | 4 -- src/ordinal_sampler.cpp | 2 +- src/sampler.cpp | 2 +- tools/debug/cloglog_ordinal_bart_binary.R | 17 +++---- .../debug/cloglog_ordinal_bart_multinomial.R | 5 ++- 7 files changed, 50 insertions(+), 35 deletions(-) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index e566fb51..4726260a 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -4,8 +4,9 @@ #' @param y A numeric vector of ordinal outcomes (positive integers starting from 1). #' @param X_test An optional numeric matrix of predictors (test data). #' @param n_trees Number of trees in the BART ensemble. Default: `50`. -#' @param n_samples_mcmc Total number of MCMC samples to draw. Default: `500`. -#' @param n_burnin Number of burn-in samples to discard. Default: `250`. +#' @param num_gfr Number of GFR samples to draw at the beginning of the sampler. Default: `10`. +#' @param num_burnin Number of burn-in MCMC samples to discard. Default: `0`. +#' @param num_mcmc Total number of MCMC samples to draw. Default: `500`. #' @param n_thin Thinning interval for MCMC samples. Default: `1`. #' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`. #' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`. 
@@ -16,8 +17,9 @@ #' @export cloglog_ordinal_bart <- function(X, y, X_test = NULL, n_trees = 50, - n_samples_mcmc = 500, - n_burnin = 250, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 500, n_thin = 1, alpha_gamma = 2.0, beta_gamma = 2.0, @@ -25,7 +27,6 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, feature_types = NULL, seed = NULL, num_threads = 1) { - # BART parameters alpha_bart <- 0.95 beta_bart <- 2 @@ -39,6 +40,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, # Determine whether a test dataset is provided has_test <- !is.null(X_test) + # Data checks if (!is.matrix(X)) X <- as.matrix(X) if (!is.numeric(y)) y <- as.numeric(y) @@ -71,9 +73,11 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, set.seed(seed) } - keep_idx <- seq((n_burnin + 1), n_samples_mcmc, by = n_thin) + # Indices of MCMC samples to keep after GFR, burn-in, and thinning + keep_idx <- seq(num_gfr + num_burnin + 1, num_gfr + num_burnin + num_mcmc, by = n_thin) n_keep <- length(keep_idx) + # Storage for MCMC samples forest_pred_train <- matrix(0, n_samples, n_keep) if (has_test) { n_samples_test <- nrow(X_test) @@ -82,7 +86,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, gamma_samples <- matrix(0, n_levels - 1, n_keep) latent_samples <- matrix(0, n_samples, n_keep) - # Initialize samplers + # Initialize samplers ordinal_sampler <- stochtree:::ordinal_sampler_cpp() rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed) @@ -140,20 +144,30 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, sweep_indices <- as.integer(seq(0, n_trees - 1)) sample_counter <- 0 - for (i in 1:n_samples_mcmc) { + for (i in 1:(num_mcmc + num_burnin + num_gfr)) { keep_sample <- i %in% keep_idx if (keep_sample) { sample_counter <- sample_counter + 1 } # 1. 
Sample forest using MCMC - stochtree:::sample_mcmc_one_iteration_cpp( - dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr, - active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr, - sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size), - scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample, - num_threads - ) + if (i > num_gfr) { + stochtree:::sample_mcmc_one_iteration_cpp( + dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr, + active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr, + sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size), + scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample, + num_threads + ) + } else { + stochtree:::sample_gfr_one_iteration_cpp( + dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr, + active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr, + sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size), + scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample, + ncol(X), num_threads + ) + } # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions # This is needed for updating gamma parameters, latent z_i's diff --git a/man/cloglog_ordinal_bart.Rd b/man/cloglog_ordinal_bart.Rd index 6cdba1c2..839c5a26 100644 --- a/man/cloglog_ordinal_bart.Rd +++ b/man/cloglog_ordinal_bart.Rd @@ -9,8 +9,9 @@ cloglog_ordinal_bart( y, X_test = NULL, n_trees = 50, - n_samples_mcmc = 500, - n_burnin = 250, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 500, n_thin = 1, alpha_gamma = 2, beta_gamma = 2, @@ -29,9 +30,11 @@ cloglog_ordinal_bart( \item{n_trees}{Number of trees in the BART ensemble. Default: \code{50}.} -\item{n_samples_mcmc}{Total number of MCMC samples to draw. Default: \code{500}.} +\item{num_gfr}{Number of GFR samples to draw at the beginning of the sampler. 
Default: \code{10}.} -\item{n_burnin}{Number of burn-in samples to discard. Default: \code{250}.} +\item{num_burnin}{Number of burn-in MCMC samples to discard. Default: \code{0}.} + +\item{num_mcmc}{Total number of MCMC samples to draw. Default: \code{500}.} \item{n_thin}{Thinning interval for MCMC samples. Default: \code{1}.} diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index 3f39fba5..5ba26f57 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -276,10 +276,6 @@ void LogLinearVarianceLeafModel::SetEnsembleRootPredictedValue(ForestDataset& da } } -// ============================================================================ -// Cloglog Ordinal Leaf Model -// ============================================================================ - double CloglogOrdinalLeafModel::SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance) { double left_log_ml = SuffStatLogMarginalLikelihood(left_stat, global_variance); double right_log_ml = SuffStatLogMarginalLikelihood(right_stat, global_variance); diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp index 54c212c7..ba9e9255 100644 --- a/src/ordinal_sampler.cpp +++ b/src/ordinal_sampler.cpp @@ -72,7 +72,7 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& } // Set the first gamma parameter to gamma_0 (e.g., 0) for identifiability - gamma[0] = gamma_0; + gamma[0] = gamma_0; } void OrdinalSampler::UpdateCumulativeExpSums(ForestDataset& dataset) { diff --git a/src/sampler.cpp b/src/sampler.cpp index 91169e9a..b843b98a 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -55,7 +55,7 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer 0.5, 2, -1) +X <- matrix(runif(n * p), ncol = p) +# true_lambda_function <- ifelse(X[, 1] > 0.5, 2, -1) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta # Set cutpoints for ordinal categories (2 categories: 1, 2) n_categories <- 2 -gamma_true <- c(-1) 
+gamma_true <- c(-2) ordinal_cutpoints <- log(cumsum(exp(gamma_true))) ordinal_cutpoints @@ -51,8 +53,9 @@ out <- cloglog_ordinal_bart( X = X_train, y = y_train, X_test = X_test, - n_samples_mcmc = 1000, - n_burnin = 500, + num_gfr = 0, + num_burnin = 5000, + num_mcmc = 1000, n_thin = 1 ) @@ -63,8 +66,6 @@ print(end - start) par(mfrow = c(2, 1)) plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") abline(h = gamma_true[1], col = 'red', lty = 2) -# plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample") -# abline(h = gamma_true[2], col = 'red', lty = 2) gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) summary(gamma1) diff --git a/tools/debug/cloglog_ordinal_bart_multinomial.R b/tools/debug/cloglog_ordinal_bart_multinomial.R index 2d1c607f..15ae9857 100644 --- a/tools/debug/cloglog_ordinal_bart_multinomial.R +++ b/tools/debug/cloglog_ordinal_bart_multinomial.R @@ -52,8 +52,9 @@ out <- cloglog_ordinal_bart( X = X_train, y = y_train, X_test = X_test, - n_samples_mcmc = 1000, - n_burnin = 500, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, n_thin = 1 ) From a5cee2b440c9a079bba8d1921f1599958154b525 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 19:40:54 -0500 Subject: [PATCH 25/34] Updating vignettes --- tools/debug/cloglog_ordinal_bart_binary.R | 8 +- .../cloglog_ordinal_bart_four_category.R | 159 ++++++++++++++++++ ... 
=> cloglog_ordinal_bart_three_category.R} | 1 + 3 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 tools/debug/cloglog_ordinal_bart_four_category.R rename tools/debug/{cloglog_ordinal_bart_multinomial.R => cloglog_ordinal_bart_three_category.R} (99%) diff --git a/tools/debug/cloglog_ordinal_bart_binary.R b/tools/debug/cloglog_ordinal_bart_binary.R index 20e3a272..feb1acdd 100644 --- a/tools/debug/cloglog_ordinal_bart_binary.R +++ b/tools/debug/cloglog_ordinal_bart_binary.R @@ -6,7 +6,7 @@ library(stochtree) set.seed(2025) # Sample size and number of predictors -n <- 2000 +n <- 10000 p <- 5 # Design matrix and true lambda function @@ -53,8 +53,8 @@ out <- cloglog_ordinal_bart( X = X_train, y = y_train, X_test = X_test, - num_gfr = 0, - num_burnin = 5000, + num_gfr = 10, + num_burnin = 0, num_mcmc = 1000, n_thin = 1 ) @@ -87,7 +87,7 @@ cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) -plot(lambda_pred_test, true_lambda_function[test_idx]) +plot(lambda_pred_test, gamma_true[1] + true_lambda_function[test_idx]) abline(a=0,b=1,col='blue', lwd=2) cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') diff --git a/tools/debug/cloglog_ordinal_bart_four_category.R b/tools/debug/cloglog_ordinal_bart_four_category.R new file mode 100644 index 00000000..56a99d30 --- /dev/null +++ b/tools/debug/cloglog_ordinal_bart_four_category.R @@ -0,0 +1,159 @@ +# Simulate ordinal data and run Cloglog Ordinal BART + +# Load +library(stochtree) + +set.seed(2025) + +# Sample size and number of predictors +n <- 2000 +p <- 5 + +# Design matrix and true lambda function 
+X <- matrix(rnorm(n * p), n, p) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta + +# Set cutpoints for ordinal categories (4 categories: 1, 2, 3, 4) +n_categories <- 4 +gamma_true <- c(-2, 0, 1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) +ordinal_cutpoints + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- apply(sapply(1:(j-1), function(k) exp(-exp(gamma_true[k] + true_lambda_function))), 1, prod) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(y), "\n") + +# Train test split +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, train_idx) +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] + +start <- Sys.time() + +# Sample the cloglog ordinal BART model +out <- cloglog_ordinal_bart( + X = X_train, + y = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + n_thin = 1 +) + +end <- Sys.time() +print(end - start) + +# Inference and diagnostics +par(mfrow = c(2, 2)) +plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[1], col = 'red', lty = 2) +plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[2], col = 'red', lty = 2) +plot(out$gamma_samples[3, ], type = 'l', main = expression(gamma[3]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[3], col = 'red', lty = 2) + +par(mfrow = c(2, 2)) +gamma1 <- 
out$gamma_samples[1,] + colMeans(out$forest_predictions_train) +summary(gamma1) +hist(gamma1) + +gamma2 <- out$gamma_samples[2,] + colMeans(out$forest_predictions_train) +summary(gamma2) +hist(gamma2) + +gamma3 <- out$gamma_samples[3,] + colMeans(out$forest_predictions_train) +summary(gamma3) +hist(gamma3) + +par(mfrow = c(2,3), mar = c(5,4,1,1)) +rowMeans(out$gamma_samples) +moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) +plot(moo[,1]) +abline(h = gamma_true[1] + mean(true_lambda_function[train_idx])) +plot(moo[,2]) +abline(h = gamma_true[2] + mean(true_lambda_function[train_idx])) +plot(moo[,3]) +abline(h = gamma_true[3] + mean(true_lambda_function[train_idx])) +plot(out$gamma_samples[1,]) +plot(out$gamma_samples[2,]) +plot(out$gamma_samples[3,]) + +# Compare forest predictions with the truth function (for training and test sets) +par(mfrow = c(2,1)) +lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) +plot(lambda_pred_train, true_lambda_function[train_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') + +lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) +plot(lambda_pred_test, true_lambda_function[test_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') + +# Estimated ordinal class probabilities for the training set +est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ]))) + } else if (j 
== n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) + +# Compare estimated vs true class probabilities for training set +par(mfrow = c(2,2)) +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red') +} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) + } +} + +mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) + +# Compare estimated vs true class probabilities for test set +par(mfrow = c(2,2)) +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), 
paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} diff --git a/tools/debug/cloglog_ordinal_bart_multinomial.R b/tools/debug/cloglog_ordinal_bart_three_category.R similarity index 99% rename from tools/debug/cloglog_ordinal_bart_multinomial.R rename to tools/debug/cloglog_ordinal_bart_three_category.R index 15ae9857..a62e4677 100644 --- a/tools/debug/cloglog_ordinal_bart_multinomial.R +++ b/tools/debug/cloglog_ordinal_bart_three_category.R @@ -87,6 +87,7 @@ plot(out$gamma_samples[1,]) plot(out$gamma_samples[2,]) # Compare forest predictions with the truth function (for training and test sets) +par(mfrow = c(2,1)) lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) plot(lambda_pred_train, true_lambda_function[train_idx]) abline(a=0,b=1,col='blue', lwd=2) From 815c53860fe57c5d8c22f9af8a8e2515af5c7955 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 19:41:21 -0500 Subject: [PATCH 26/34] WIP fix for data augmentation in the binary case --- src/ordinal_sampler.cpp | 48 +++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp index ba9e9255..9bf579b2 100644 --- a/src/ordinal_sampler.cpp +++ b/src/ordinal_sampler.cpp @@ -19,19 +19,37 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector int K = gamma.size() + 1; // Number of ordinal categories int N = dataset.NumObservations(); - // Update truncated exponentials (stored in latent auxiliary data slot 0) - // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1) - // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} - // and lambda_hat_i is the total forest prediction for observation i - // If y_i = K-1 (last category), then we set z_i = 1.0 deterministically just for bookkeeping, we don't need it - // We only need to sample latent z_i for y_i < K-1 (as z_i is only used 
in the likelihood for y_i < K-1) - for (int i = 0; i < N; i++) { - int y = static_cast(outcome(i)); - if (y == K - 1) { - Z[i] = 1.0; - } else { - double rate = std::exp(gamma[y] + lambda_hat[i]); - Z[i] = SampleTruncatedExponential(rate, gen); + // Handle data augmentation separately for binary and multinomial outcomes (as documented in each branch below) + if (K == 2) { + // Here we fix gamma_1 = exp(0) = 1 for identifiability and augment + // z_i ~ TExp(rate = e^{lambda_hat_i}; 0, 1) if y_i = 0 + // z_i ~ TExp(rate = e^{lambda_hat_i}; 1, infty) if y_i = 1 + // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} + // and lambda_hat_i is the total forest prediction for observation i + for (int i = 0; i < N; i++) { + int y = static_cast(outcome(i)); + double rate = std::exp(lambda_hat[i]); + if (y == 0) { + Z[i] = SampleTruncatedExponential(rate, gen); + } else { + Z[i] = SampleTruncatedExponential(rate, gen); + } + } + } else { + // Update truncated exponentials (stored in latent auxiliary data slot 0) + // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1) + // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} + // and lambda_hat_i is the total forest prediction for observation i + // If y_i = K-1 (last category), then we set z_i = 1.0 deterministically just for bookkeeping, we don't need it + // We only need to sample latent z_i for y_i < K-1 (as z_i is only used in the likelihood for y_i < K-1) + for (int i = 0; i < N; i++) { + int y = static_cast(outcome(i)); + if (y == K - 1) { + Z[i] = 1.0; + } else { + double rate = std::exp(gamma[y] + lambda_hat[i]); + Z[i] = SampleTruncatedExponential(rate, gen); + } } } } @@ -72,7 +90,9 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& } // Set the first gamma parameter to gamma_0 (e.g., 0) for identifiability - gamma[0] = gamma_0; + if (K > 2) { + gamma[0] = gamma_0; + } } void 
OrdinalSampler::UpdateCumulativeExpSums(ForestDataset& dataset) { From 2106f32988f40e321f77c1caf7ffd7fcb4b41ff6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 23:00:05 -0500 Subject: [PATCH 27/34] Updating sampler --- include/stochtree/ordinal_sampler.h | 24 +++++++- src/ordinal_sampler.cpp | 61 ++++++++----------- tools/debug/cloglog_ordinal_bart_binary.R | 19 +++--- .../cloglog_ordinal_bart_three_category.R | 4 +- 4 files changed, 60 insertions(+), 48 deletions(-) diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h index a83fdf9f..bfc474da 100644 --- a/include/stochtree/ordinal_sampler.h +++ b/include/stochtree/ordinal_sampler.h @@ -17,6 +17,22 @@ namespace StochTree { +static double sample_truncated_exponential_low_high(double u, double rate, double low, double high) { + return -std::log((1-u)*std::exp(-rate*low) + u*std::exp(-rate*high))/rate; +} + +static double sample_truncated_exponential_low(double u, double rate, double low) { + return -std::log((1-u)*std::exp(-rate*low))/rate; +} + +static double sample_truncated_exponential_high(double u, double rate, double high) { + return -std::log1p(u*std::expm1(-high*rate))/rate; +} + +static double sample_exponential(double u, double rate) { + return -std::log1p(-u)/rate; +} + /*! * \brief Sampler for ordinal model hyperparameters * @@ -35,13 +51,15 @@ class OrdinalSampler { /*! 
* \brief Sample from truncated exponential distribution * - * Samples from exponential distribution truncated to [0,1] + * Samples from exponential distribution truncated to [low,high] * - * \param lambda Rate parameter for exponential distribution * \param gen Random number generator + * \param rate Rate parameter for exponential distribution + * \param low Lower truncation bound + * \param high Upper truncation bound * \return Sampled value from truncated exponential */ - static double SampleTruncatedExponential(double lambda, std::mt19937& gen); + static double SampleTruncatedExponential(std::mt19937& gen, double rate, double low = 0.0, double high = 1.0); /*! * \brief Update truncated exponential latent variables (Z) diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp index 9bf579b2..6a09a000 100644 --- a/src/ordinal_sampler.cpp +++ b/src/ordinal_sampler.cpp @@ -3,11 +3,18 @@ namespace StochTree { -double OrdinalSampler::SampleTruncatedExponential(double lambda, std::mt19937& gen) { +double OrdinalSampler::SampleTruncatedExponential(std::mt19937& gen, double rate, double low, double high) { std::uniform_real_distribution unif(0.0, 1.0); double u = unif(gen); - double a = 1.0 - u * (1.0 - std::exp(-lambda)); - return -std::log(a) / lambda; + if ((low <= 0.0) && (high <= 0.0)) { + return sample_exponential(u, rate); + } else if ((low <= 0.0) && (high > 0.0)) { + return sample_truncated_exponential_high(u, rate, high); + } else if ((low > 0.0) && (high <= 0.0)) { + return sample_truncated_exponential_low(u, rate, low); + } else { + return sample_truncated_exponential_low_high(u, rate, low, high); + } } void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, std::mt19937& gen) { @@ -19,37 +26,19 @@ void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::Vector int K = gamma.size() + 1; // Number of ordinal categories int N = dataset.NumObservations(); - // Handle data augmentation separately for 
binary and multinomial outcomes (as documented in each branch below) - if (K == 2) { - // Here we fix gamma_1 = exp(0) = 1 for identifiability and augment - // z_i ~ TExp(rate = e^{lambda_hat_i}; 0, 1) if y_i = 0 - // z_i ~ TExp(rate = e^{lambda_hat_i}; 1, infty) if y_i = 1 - // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} - // and lambda_hat_i is the total forest prediction for observation i - for (int i = 0; i < N; i++) { - int y = static_cast(outcome(i)); - double rate = std::exp(lambda_hat[i]); - if (y == 0) { - Z[i] = SampleTruncatedExponential(rate, gen); - } else { - Z[i] = SampleTruncatedExponential(rate, gen); - } - } - } else { - // Update truncated exponentials (stored in latent auxiliary data slot 0) - // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1) - // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} - // and lambda_hat_i is the total forest prediction for observation i - // If y_i = K-1 (last category), then we set z_i = 1.0 deterministically just for bookkeeping, we don't need it - // We only need to sample latent z_i for y_i < K-1 (as z_i is only used in the likelihood for y_i < K-1) - for (int i = 0; i < N; i++) { - int y = static_cast(outcome(i)); - if (y == K - 1) { - Z[i] = 1.0; - } else { - double rate = std::exp(gamma[y] + lambda_hat[i]); - Z[i] = SampleTruncatedExponential(rate, gen); - } + // Update truncated exponentials (stored in latent auxiliary data slot 0) + // z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1) + // where y_i is the ordinal outcome for observation i: make sure y_i converted to {0, 1, ..., K-1} + // and lambda_hat_i is the total forest prediction for observation i + // If y_i = K-1 (last category), then we set z_i = 1.0 deterministically just for bookkeeping, we don't need it + // We only need to sample latent z_i for y_i < K-1 (as z_i is only used in the likelihood for y_i < K-1) + for (int i = 0; i < N; 
i++) { + int y = static_cast(outcome(i)); + if (y == K - 1) { + Z[i] = 1.0; + } else { + double rate = std::exp(gamma[y] + lambda_hat[i]); + Z[i] = SampleTruncatedExponential(gen, rate, 0.0, 1.0); } } } @@ -90,9 +79,9 @@ void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& } // Set the first gamma parameter to gamma_0 (e.g., 0) for identifiability - if (K > 2) { + // if (K > 2) { gamma[0] = gamma_0; - } + // } } void OrdinalSampler::UpdateCumulativeExpSums(ForestDataset& dataset) { diff --git a/tools/debug/cloglog_ordinal_bart_binary.R b/tools/debug/cloglog_ordinal_bart_binary.R index feb1acdd..25840b32 100644 --- a/tools/debug/cloglog_ordinal_bart_binary.R +++ b/tools/debug/cloglog_ordinal_bart_binary.R @@ -6,7 +6,7 @@ library(stochtree) set.seed(2025) # Sample size and number of predictors -n <- 10000 +n <- 2000 p <- 5 # Design matrix and true lambda function @@ -17,7 +17,7 @@ true_lambda_function <- X %*% beta # Set cutpoints for ordinal categories (2 categories: 1, 2) n_categories <- 2 -gamma_true <- c(-2) +gamma_true <- c(-1) ordinal_cutpoints <- log(cumsum(exp(gamma_true))) ordinal_cutpoints @@ -53,10 +53,12 @@ out <- cloglog_ordinal_bart( X = X_train, y = y_train, X_test = X_test, - num_gfr = 10, - num_burnin = 0, + num_gfr = 0, + num_burnin = 1000, num_mcmc = 1000, - n_thin = 1 + n_thin = 1, + alpha_gamma = 1.0, + beta_gamma = 1.0 ) end <- Sys.time() @@ -71,7 +73,7 @@ gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) summary(gamma1) hist(gamma1) -par(mfrow = c(2,1), mar = c(5,4,1,1)) +par(mfrow = c(2,1)) rowMeans(out$gamma_samples) moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) plot(moo[,1]) @@ -81,9 +83,10 @@ plot(out$gamma_samples[1,]) # Compare forest predictions with the truth function (for training and test sets) par(mfrow = c(2,1)) lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) -plot(lambda_pred_train, 
true_lambda_function[train_idx]) +# lambda_pred_train <- rowMeans(out$forest_predictions_train) +plot(lambda_pred_train, gamma_true[1] + true_lambda_function[train_idx]) abline(a=0,b=1,col='blue', lwd=2) -cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +cor_train <- cor(true_lambda_function[train_idx] + mean(out$gamma_samples[1,]), gamma_true[1] + lambda_pred_train) text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) diff --git a/tools/debug/cloglog_ordinal_bart_three_category.R b/tools/debug/cloglog_ordinal_bart_three_category.R index a62e4677..d50e9746 100644 --- a/tools/debug/cloglog_ordinal_bart_three_category.R +++ b/tools/debug/cloglog_ordinal_bart_three_category.R @@ -16,7 +16,7 @@ true_lambda_function <- X %*% beta # Set cutpoints for ordinal categories (3 categories: 1, 2, 3) n_categories <- 3 -gamma_true <- c(-2, 1) +gamma_true <- c(-2, 3) ordinal_cutpoints <- log(cumsum(exp(gamma_true))) ordinal_cutpoints @@ -116,6 +116,7 @@ for (j in 1:n_categories) { mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) # Compare estimated vs true class probabilities for training set +par(mfrow = c(2,2)) for (j in 1:n_categories) { plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) abline(a = 0, b = 1, col = 'blue', lwd = 2) @@ -139,6 +140,7 @@ for (j in 1:n_categories) { mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) # Compare estimated vs true class probabilities for test set +par(mfrow = c(2,2)) for (j in 1:n_categories) { plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) abline(a = 0, b = 1, col = 'blue', lwd = 2) From 
02dc2acf6727d3544c21ca11f7c99eae7fc2705c Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 23:02:29 -0500 Subject: [PATCH 28/34] Remove unused slice sampler code --- include/stochtree/slice_sampler.h | 180 ------------------------------ 1 file changed, 180 deletions(-) delete mode 100644 include/stochtree/slice_sampler.h diff --git a/include/stochtree/slice_sampler.h b/include/stochtree/slice_sampler.h deleted file mode 100644 index 07fe5a26..00000000 --- a/include/stochtree/slice_sampler.h +++ /dev/null @@ -1,180 +0,0 @@ -/*! - * Copyright (c) 2024 stochtree authors. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the project root for license information. - */ -#ifndef STOCHTREE_SLICE_SAMPLER_H_ -#define STOCHTREE_SLICE_SAMPLER_H_ - -#include -#include -#include -#include -#include - -#ifndef M_LN2 -#define M_LN2 0.6931471805599453 // ln(2) -#endif - -namespace StochTree { - -/*! - * \brief Abstract base class for log-likelihood functions used in slice sampling - */ -class LoglikFunction { - public: - virtual ~LoglikFunction() {} - - /*! - * \brief Evaluate the log-likelihood function at point x - * \param x Input value - * \return Log-likelihood value - */ - virtual double Evaluate(double x) = 0; -}; - -/*! - * \brief Log-likelihood function for scale_lambda parameter in ordinal models - */ -class ScaleLambdaLoglik : public LoglikFunction { - public: - /*! - * \brief Constructor for scale lambda log-likelihood - * \param n Number of observations (lambda values) - * \param sum_lambda Sum of all lambda values - * \param sum_exp_lambda Sum of exp(lambda) values - * \param scale Prior scale parameter for scale_lambda - */ - ScaleLambdaLoglik(double n, double sum_lambda, double sum_exp_lambda, double scale) - : n_(n), sum_lambda_(sum_lambda), sum_exp_lambda_(sum_exp_lambda), scale_(scale) {} - - /*! 
- * \brief Evaluate log-likelihood of scale_lambda parameter - * \param sigma Input scale parameter value (scale_lambda) - * \return Log-likelihood value - */ - double Evaluate(double sigma) override { - if (sigma <= 0) return -std::numeric_limits::infinity(); - - // Convert scale_lambda to alpha and beta parameters - double alpha, beta; - ScaleLambdaToAlphaBeta(alpha, beta, sigma); - - // Log-likelihood contribution from lambda values (gamma prior) - double loglik = n_ * alpha * std::log(beta) - - n_ * boost::math::lgamma(alpha) - + alpha * sum_lambda_ - - beta * sum_exp_lambda_; - - // Add constants and prior terms - loglik += M_LN2 - 0.5 * std::log(2.0 * M_PI); // M_LN2 - LN_2_BY_PI approximation - - // Prior on scale_lambda (half-normal or similar) - double scale_ratio = sigma / scale_; - loglik -= 0.5 * scale_ratio * scale_ratio; - - return loglik; - } - - private: - double n_; - double sum_lambda_; - double sum_exp_lambda_; - double scale_; - - /*! - * \brief Convert scale_lambda to alpha and beta parameters for the gamma prior - */ - void ScaleLambdaToAlphaBeta(double& alpha, double& beta, const double sigma) { - double sigma_sq = sigma * sigma; - alpha = TrigammaInverse(sigma_sq); - beta = std::exp(boost::math::digamma(alpha)); - } - - /*! - * \brief Compute inverse trigamma function using Newton's method - */ - double TrigammaInverse(double x) { - if (x > 1E7) return 1.0 / std::sqrt(x); - if (x < 1E-6) return 1.0 / x; - - double y = 0.5 + 1.0 / x; - for (int i = 0; i < 50; i++) { - double tri = boost::math::trigamma(y); - double dif = tri * (1.0 - tri / x) / boost::math::polygamma(3, y); - y += dif; - if (-dif / y < 1E-8) break; - } - return y; - } -}; - -/*! - * \brief Slice sampler implementation - */ -class SliceSampler { - public: - SliceSampler() {} - ~SliceSampler() {} - - /*! 
- * \brief Sample from a distribution using slice sampling - * \param x0 Initial value - * \param loglik_func Log-likelihood function - * \param w Step size for expanding interval - * \param lower Lower bound - * \param upper Upper bound - * \param gen Random number generator - * \return Sampled value - */ - double Sample(double x0, LoglikFunction* loglik_func, double w, - double lower, double upper, std::mt19937& gen) { - - std::uniform_real_distribution unif(0.0, 1.0); - std::exponential_distribution exp_dist(1.0); - - // Find the log density at the initial point - double gx0 = loglik_func->Evaluate(x0); - - // Determine the slice level, in log terms - double logy = gx0 - exp_dist(gen); - - // Find the initial interval to sample from - double u = w * unif(gen); - double L = x0 - u; - double R = x0 + (w - u); - - // Expand the interval until its ends are outside the slice - while (L > lower && loglik_func->Evaluate(L) > logy) { - L -= w; - } - - while (R < upper && loglik_func->Evaluate(R) > logy) { - R += w; - } - - // Shrink interval to bounds - if (L < lower) L = lower; - if (R > upper) R = upper; - - // Sample from the interval, shrinking it on each rejection - double x1; - do { - x1 = L + (R - L) * unif(gen); - double gx1 = loglik_func->Evaluate(x1); - - if (gx1 >= logy) break; - - if (x1 > x0) { - R = x1; - } else { - L = x1; - } - } while (true); - - return x1; - } -}; - -} // namespace StochTree - -#endif // STOCHTREE_SLICE_SAMPLER_H_ From 8f92425bac549d2dfb54898d4b13f4a495bf0646 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 23:23:07 -0500 Subject: [PATCH 29/34] Cleaned up PR --- include/stochtree/category_tracker.h | 4 --- include/stochtree/common.h | 4 --- include/stochtree/container.h | 5 --- include/stochtree/cutpoint_candidates.h | 2 -- include/stochtree/data.h | 1 - include/stochtree/ensemble.h | 6 ---- include/stochtree/io.h | 2 -- include/stochtree/leaf_model.h | 6 ---- include/stochtree/log.h | 2 -- include/stochtree/meta.h | 1 
- include/stochtree/ordinal_sampler.h | 2 +- include/stochtree/partition_tracker.h | 25 ++++++------- include/stochtree/random.h | 1 - include/stochtree/random_effects.h | 3 -- include/stochtree/tree.h | 3 -- include/stochtree/tree_sampler.h | 47 ++++++++++++------------- include/stochtree/variance_model.h | 4 --- src/R_data.cpp | 1 - src/R_random_effects.cpp | 2 -- src/cutpoint_candidates.cpp | 1 - src/data.cpp | 1 - src/forest.cpp | 2 -- src/io.cpp | 2 -- src/kernel.cpp | 3 -- src/leaf_model.cpp | 2 -- src/partition_tracker.cpp | 8 +---- src/py_stochtree.cpp | 18 +++++----- src/sampler.cpp | 8 ----- src/serialization.cpp | 3 -- src/tree.cpp | 4 --- 30 files changed, 44 insertions(+), 129 deletions(-) diff --git a/include/stochtree/category_tracker.h b/include/stochtree/category_tracker.h index e5817419..2ce44635 100644 --- a/include/stochtree/category_tracker.h +++ b/include/stochtree/category_tracker.h @@ -29,12 +29,8 @@ #include #include -#include #include #include -#include -#include -#include #include namespace StochTree { diff --git a/include/stochtree/common.h b/include/stochtree/common.h index c7aab3df..cd57eea2 100644 --- a/include/stochtree/common.h +++ b/include/stochtree/common.h @@ -8,22 +8,18 @@ #include #include -#include #include #include #include #include #include -#include #include #include #include -#include #include #include #include #include -#include #include #include diff --git a/include/stochtree/container.h b/include/stochtree/container.h index bb0e7849..4b75ef2f 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -11,12 +11,7 @@ #include #include -#include -#include #include -#include -#include -#include namespace StochTree { diff --git a/include/stochtree/cutpoint_candidates.h b/include/stochtree/cutpoint_candidates.h index 8c19013a..76f1df4c 100644 --- a/include/stochtree/cutpoint_candidates.h +++ b/include/stochtree/cutpoint_candidates.h @@ -42,8 +42,6 @@ #include #include -#include - namespace StochTree { 
/*! \brief Computing and tracking cutpoints available for a given feature at a given node diff --git a/include/stochtree/data.h b/include/stochtree/data.h index 8cf16e4d..393203b1 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -9,7 +9,6 @@ #include #include #include -#include namespace StochTree { diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index 4624b5a4..4f6ddf42 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -14,12 +14,6 @@ #include #include -#include -#include -#include -#include -#include - using json = nlohmann::json; namespace StochTree { diff --git a/include/stochtree/io.h b/include/stochtree/io.h index 3bc277fb..55963946 100644 --- a/include/stochtree/io.h +++ b/include/stochtree/io.h @@ -28,12 +28,10 @@ #include #include #include -#include #include #include #include #include -#include #include #include diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index 88719608..fc9ab0af 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -16,15 +16,9 @@ #include #include #include -#include #include -#include -#include #include -#include -#include -#include namespace StochTree { diff --git a/include/stochtree/log.h b/include/stochtree/log.h index 3a4c5600..9f64c31b 100644 --- a/include/stochtree/log.h +++ b/include/stochtree/log.h @@ -15,8 +15,6 @@ #include #include #include -#include -#include #include #include diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h index 991c254f..d0aa4049 100644 --- a/include/stochtree/meta.h +++ b/include/stochtree/meta.h @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h index bfc474da..d67563e2 100644 --- a/include/stochtree/ordinal_sampler.h +++ b/include/stochtree/ordinal_sampler.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2024 stochtree authors. 
All rights reserved. + * Copyright (c) 2025 stochtree authors. All rights reserved. * Licensed under the MIT License. See LICENSE file in the project root for license information. */ #ifndef STOCHTREE_ORDINAL_SAMPLER_H_ diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index a62121b1..f25c875c 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -31,12 +31,7 @@ #include #include -#include #include -#include -#include -#include -#include namespace StochTree { @@ -434,7 +429,7 @@ class UnsortedNodeSampleTracker { /*! \brief Number of trees */ int NumTrees() { return num_trees_; } - /*! \brief Number of trees */ + /*! \brief Return a pointer to the feature partition tracking tree i */ FeatureUnsortedPartition* GetFeaturePartition(int i) { return feature_partitions_[i].get(); } private: @@ -615,24 +610,24 @@ class SortedNodeSampleTracker { } /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, TreeSplit& split) { - for (int i = 0; i < num_features_; i++) { + void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, TreeSplit& split, int num_threads = -1) { + StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) { feature_partitions_[i]->SplitFeature(covariates, node_id, feature_split, split); - } + }); } /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, double split_value) { - for (int i = 0; i < num_features_; i++) { + void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, double split_value, int num_threads = -1) { + StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) { feature_partitions_[i]->SplitFeatureNumeric(covariates, node_id, feature_split, split_value); - } + }); } /*! 
\brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, std::vector const& category_list) { - for (int i = 0; i < num_features_; i++) { + void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, std::vector const& category_list, int num_threads = -1) { + StochTree::ParallelFor(0, num_features_, num_threads, [&](int i) { feature_partitions_[i]->SplitFeatureCategorical(covariates, node_id, feature_split, category_list); - } + }); } /*! \brief First index of data points contained in node_id */ diff --git a/include/stochtree/random.h b/include/stochtree/random.h index a841f396..3d39b647 100644 --- a/include/stochtree/random.h +++ b/include/stochtree/random.h @@ -5,7 +5,6 @@ #ifndef STOCHTREE_RANDOM_H_ #define STOCHTREE_RANDOM_H_ -#include #include #include #include diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index 701ebeaa..b322a560 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -17,14 +17,11 @@ #include #include -#include #include #include #include #include -#include #include -#include #include namespace StochTree { diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index 85ce7191..3810e3cb 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -13,9 +13,6 @@ #include #include -#include -#include -#include #include #include diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index 7c0254c6..b8101fd2 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -8,19 +8,13 @@ #include #include #include +#include #include #include #include -#include -#include #include #include -#include -#include -#include -#include -#include #include namespace StochTree { @@ -28,7 +22,7 @@ namespace StochTree { /*! * \defgroup sampling_group Forest Sampler API * - * \brief Functions for sampling from a forest. 
The core interfce of these functions, + * \brief Functions for sampling from a forest. The core interface of these functions, * as used by the R, Python, and standalone C++ program, is defined by * \ref MCMCSampleOneIter, which runs one iteration of the MCMC sampler for a * given forest, and \ref GFRSampleOneIter, which runs one iteration of the @@ -153,7 +147,7 @@ static inline bool NodeNonConstant(ForestDataset& dataset, ForestTracker& tracke } static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, TreeSplit& split, std::mt19937& gen, Tree* tree, - int tree_num, int leaf_node, int feature_split, bool keep_sorted = false) { + int tree_num, int leaf_node, int feature_split, bool keep_sorted = false, int num_threads = -1) { // Use zeros as a "temporary" leaf values since we draw leaf parameters after tree sampling is complete if (tree->OutputDimension() > 1) { std::vector temp_leaf_values(tree->OutputDimension(), 0.); @@ -166,7 +160,7 @@ static inline void AddSplitToModel(ForestTracker& tracker, ForestDataset& datase int right_node = tree->RightChild(leaf_node); // Update the ForestTracker - tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted); + tracker.AddSplit(dataset.GetCovariates(), split, feature_split, tree_num, leaf_node, left_node, right_node, keep_sorted, num_threads); } static inline void RemoveSplitFromModel(ForestTracker& tracker, ForestDataset& dataset, TreePrior& tree_prior, std::mt19937& gen, Tree* tree, @@ -446,8 +440,11 @@ static inline std::tuple EvaluatePropo LeafSuffStat right_suff_stat = LeafSuffStat(leaf_suff_stat_args...); // Accumulate sufficient statistics - AccumulateSuffStatProposed(node_suff_stat, left_suff_stat, right_suff_stat, dataset, tracker, - residual, global_variance, split, tree_num, leaf_num, split_feature, 1, leaf_suff_stat_args...); + AccumulateSuffStatProposed( + node_suff_stat, left_suff_stat, 
right_suff_stat, dataset, tracker, + residual, global_variance, split, tree_num, leaf_num, split_feature, num_threads, + leaf_suff_stat_args... + ); data_size_t left_n = left_suff_stat.n; data_size_t right_n = right_suff_stat.n; @@ -486,12 +483,12 @@ template static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { - UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false); + UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false); } else if (backfitting) { - UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); + UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), false); } else { - // TODO: think about a generic way to store "state" corresponding to the other models? - UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), false); + // TODO: think about a generic way to store "state" corresponding to the other models? 
+ UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), false); } } @@ -499,12 +496,12 @@ template static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) { if constexpr (std::is_same_v) { - UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true); + UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true); } else if (backfitting) { - UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); + UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus(), true); } else { - // TODO: think about a generic way to store "state" corresponding to the other models? - UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), true); + // TODO: think about a generic way to store "state" corresponding to the other models? 
+ UpdateVarModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus(), true); } } @@ -702,7 +699,7 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel } // Add split to tree and trackers - AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true); + AddSplitToModel(tracker, dataset, tree_prior, tree_split, gen, tree, tree_num, node_id, feature_split, true, num_threads); // Determine the number of observation in the newly created left node int left_node = tree->LeftChild(node_id); @@ -784,7 +781,8 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore SampleSplitRule( tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, node_index_map, split_queue, curr_node_id, curr_node_begin, curr_node_end, variable_weights, feature_types, - feature_subset, num_threads, leaf_suff_stat_args...); + feature_subset, num_threads, leaf_suff_stat_args... 
+ ); } } @@ -960,7 +958,7 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM double log_acceptance_prob = std::log(mh_accept(gen)); if (log_acceptance_prob <= log_mh_ratio) { accept = true; - AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false); + AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false, num_threads); } else { accept = false; } @@ -1133,7 +1131,8 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For template static inline void MCMCSampleOneIter(TreeEnsemble& active_forest, ForestTracker& tracker, ForestContainer& forests, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior, std::mt19937& gen, std::vector& variable_weights, - std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& sweep_update_indices, double global_variance, bool keep_forest, bool pre_initialized, bool backfitting, int num_threads, + LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Run the MCMC algorithm for each tree int num_trees = forests.NumTrees(); for (const int& i : sweep_update_indices) { diff --git a/include/stochtree/variance_model.h b/include/stochtree/variance_model.h index 79b8831f..b1c2dabe 100644 --- a/include/stochtree/variance_model.h +++ b/include/stochtree/variance_model.h @@ -12,11 +12,7 @@ #include #include -#include #include -#include -#include -#include namespace StochTree { diff --git a/src/R_data.cpp b/src/R_data.cpp index 6ede6473..caf3e9bc 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -5,7 +5,6 @@ #include #include #include -#include [[cpp11::register]] cpp11::external_pointer create_forest_dataset_cpp() { diff --git a/src/R_random_effects.cpp b/src/R_random_effects.cpp index f627b3c5..e291121c 100644 --- a/src/R_random_effects.cpp +++ b/src/R_random_effects.cpp @@ -7,9 +7,7 @@ #include #include #include -#include #include -#include [[cpp11::register]] cpp11::external_pointer rfx_container_cpp(int num_components, int num_groups) { diff --git a/src/cutpoint_candidates.cpp b/src/cutpoint_candidates.cpp index 4a0845c7..e43b8219 100644 --- a/src/cutpoint_candidates.cpp +++ b/src/cutpoint_candidates.cpp @@ -2,7 +2,6 @@ #include #include -#include namespace StochTree { diff --git a/src/data.cpp b/src/data.cpp index cd2913cf..e48e9255 100644 --- a/src/data.cpp +++ b/src/data.cpp @@ -1,7 +1,6 @@ /*! 
Copyright (c) 2024 by stochtree authors */ #include #include -#include namespace StochTree { diff --git a/src/forest.cpp b/src/forest.cpp index 2bc9f03a..ddd247a0 100644 --- a/src/forest.cpp +++ b/src/forest.cpp @@ -7,9 +7,7 @@ #include #include #include -#include #include -#include [[cpp11::register]] cpp11::external_pointer active_forest_cpp(int num_trees, int output_dimension = 1, bool is_leaf_constant = true, bool is_exponentiated = false) { diff --git a/src/io.cpp b/src/io.cpp index 1324957f..50774d9b 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -7,9 +7,7 @@ #include #include -#include #include -#include namespace StochTree { diff --git a/src/kernel.cpp b/src/kernel.cpp index 6b5867bb..38fdd35c 100644 --- a/src/kernel.cpp +++ b/src/kernel.cpp @@ -2,9 +2,6 @@ #include "stochtree_types.h" #include #include -#include -#include -#include typedef Eigen::Map> DoubleMatrixType; typedef Eigen::Map> IntMatrixType; diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index 5ba26f57..3c1f91c9 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -1,7 +1,5 @@ #include #include -#include -#include namespace StochTree { diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index d9faf57a..73b37fe8 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -6,12 +6,6 @@ #include #include -#include -#include -#include -#include -#include - namespace StochTree { ForestTracker::ForestTracker(Eigen::MatrixXd& covariates, std::vector& feature_types, int num_trees, int num_observations) { @@ -286,7 +280,7 @@ void ForestTracker::AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int3 sample_node_mapper_->AddSplit(covariates, split, split_feature, tree_id, split_node_id, left_node_id, right_node_id); unsorted_node_sample_tracker_->PartitionTreeNode(covariates, tree_id, split_node_id, left_node_id, right_node_id, split_feature, split); if (keep_sorted) { - sorted_node_sample_tracker_->PartitionNode(covariates, split_node_id, split_feature, 
split); + sorted_node_sample_tracker_->PartitionNode(covariates, split_node_id, split_feature, split, num_threads); } } diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 2c514cb2..a95ab0ee 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1079,7 +1079,7 @@ class ForestSamplerCpp { void SampleOneIteration(ForestContainerCpp& forest_samples, ForestCpp& forest, ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng, py::array_t feature_types, py::array_t sweep_update_indices, int cutpoint_grid_size, py::array_t leaf_model_scale_input, py::array_t variable_weights, double a_forest, double b_forest, double global_variance, - int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true) { + int leaf_model_int, int num_features_subsample, bool keep_forest = true, bool gfr = true, int num_threads = -1) { // Refactoring completely out of the Python interface. // Intention to refactor out of the C++ interface in the future. bool pre_initialized = true; @@ -1141,23 +1141,23 @@ class ForestSamplerCpp { std::mt19937* rng_ptr = rng.GetRng(); if (gfr) { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - 
StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_basis); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample); + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, 
std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); } } else { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); } else if (model_type == StochTree::ModelType::kUnivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); } else if (model_type == StochTree::ModelType::kMultivariateRegressionLeafGaussian) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_basis); + StochTree::MCMCSampleOneIter(*active_forest_ptr, 
*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { - StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false); + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); } } } diff --git a/src/sampler.cpp b/src/sampler.cpp index b843b98a..f356d968 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -8,10 +8,7 @@ #include #include #include -#include #include -#include -#include [[cpp11::register]] void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, @@ -337,11 +334,6 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( return(output); } - -// ============================================================================ -// ORDINAL SAMPLER FUNCTIONS -// ============================================================================ - [[cpp11::register]] cpp11::external_pointer ordinal_sampler_cpp() { std::unique_ptr sampler_ptr = std::make_unique(); diff --git a/src/serialization.cpp b/src/serialization.cpp index 749395e8..fb248f62 100644 --- a/src/serialization.cpp +++ b/src/serialization.cpp @@ -8,9 +8,6 @@ #include #include #include -#include -#include -#include [[cpp11::register]] cpp11::external_pointer init_json_cpp() { diff --git a/src/tree.cpp b/src/tree.cpp index fa6fd8f8..32c51475 100644 --- a/src/tree.cpp +++ 
b/src/tree.cpp @@ -8,9 +8,6 @@ #include #include -#include -#include -#include namespace StochTree { @@ -668,7 +665,6 @@ void Tree::from_json(const json& tree_json) { tree_json.at("has_categorical_split").get_to(this->has_categorical_split_); tree_json.at("output_dimension").get_to(this->output_dimension_); tree_json.at("is_log_scale").get_to(this->is_log_scale_); - this->num_deleted_nodes = 0; // Unpack the array based fields JsonToTreeNodeVectors(tree_json, this); From 95a0ce9738980c9fa51086fe9006c19bb2022985 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 23:26:25 -0500 Subject: [PATCH 30/34] Including variant in leaf model header file --- include/stochtree/leaf_model.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h index fc9ab0af..4ea38014 100644 --- a/include/stochtree/leaf_model.h +++ b/include/stochtree/leaf_model.h @@ -19,6 +19,7 @@ #include #include +#include namespace StochTree { From 42f9ac4042a65f8fac2d3af42ce815fcd499be9d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 27 Oct 2025 23:46:13 -0500 Subject: [PATCH 31/34] Updated vignettes and function defaults --- R/cloglog_ordinal_bart.R | 14 +++++++------- man/cloglog_ordinal_bart.Rd | 8 ++++---- tools/debug/cloglog_ordinal_bart_binary.R | 4 +--- tools/debug/cloglog_ordinal_bart_four_category.R | 4 ++-- tools/debug/cloglog_ordinal_bart_three_category.R | 6 +++--- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R index 4726260a..7220774d 100644 --- a/R/cloglog_ordinal_bart.R +++ b/R/cloglog_ordinal_bart.R @@ -4,8 +4,8 @@ #' @param y A numeric vector of ordinal outcomes (positive integers starting from 1). #' @param X_test An optional numeric matrix of predictors (test data). #' @param n_trees Number of trees in the BART ensemble. Default: `50`. -#' @param num_gfr Number of GFR samples to draw at the beginning of the sampler. Default: `10`. 
-#' @param num_burnin Number of burn-in MCMC samples to discard. Default: `0`. +#' @param num_gfr Number of GFR samples to draw at the beginning of the sampler. Default: `0`. +#' @param num_burnin Number of burn-in MCMC samples to discard. Default: `1000`. #' @param num_mcmc Total number of MCMC samples to draw. Default: `500`. #' @param n_thin Thinning interval for MCMC samples. Default: `1`. #' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`. @@ -17,8 +17,8 @@ #' @export cloglog_ordinal_bart <- function(X, y, X_test = NULL, n_trees = 50, - num_gfr = 10, - num_burnin = 0, + num_gfr = 0, + num_burnin = 1000, num_mcmc = 500, n_thin = 1, alpha_gamma = 2.0, @@ -33,7 +33,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, min_samples_in_leaf <- 5 max_depth <- 10 scale_leaf <- 2 / sqrt(n_trees) - cutpoint_grid_size <- 100 # Needed for stochtree:::sample_mcmc_one_iteration_cpp (for GFR), not used in MCMC BART + cutpoint_grid_size <- 100 # Needed for stochtree::sample_gfr_one_iteration_cpp, not used in MCMC BART # Fixed for identifiability (can be pass as argument later if desired) gamma_0 = 0.0 # First gamma cutpoint fixed at gamma_0 = 0 @@ -128,7 +128,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, # Convert the log-scale parameters into cumulative exponentiated parameters. # This is done under the hood in a C++ function for efficiency. - ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr) + stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr) # Initialize forest predictions to zero (slot 1) for (i in 1:n_samples) { @@ -188,7 +188,7 @@ cloglog_ordinal_bart <- function(X, y, X_test = NULL, ) # 4. 
Update cumulative sum of exp(gamma) values - ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr) + stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr) if (keep_sample) { forest_pred_train[, sample_counter] <- active_forest$predict(dataX) diff --git a/man/cloglog_ordinal_bart.Rd b/man/cloglog_ordinal_bart.Rd index 839c5a26..049aa532 100644 --- a/man/cloglog_ordinal_bart.Rd +++ b/man/cloglog_ordinal_bart.Rd @@ -9,8 +9,8 @@ cloglog_ordinal_bart( y, X_test = NULL, n_trees = 50, - num_gfr = 10, - num_burnin = 0, + num_gfr = 0, + num_burnin = 1000, num_mcmc = 500, n_thin = 1, alpha_gamma = 2, @@ -30,9 +30,9 @@ cloglog_ordinal_bart( \item{n_trees}{Number of trees in the BART ensemble. Default: \code{50}.} -\item{num_gfr}{Number of GFR samples to draw at the beginning of the sampler. Default: \code{10}.} +\item{num_gfr}{Number of GFR samples to draw at the beginning of the sampler. Default: \code{0}.} -\item{num_burnin}{Number of burn-in MCMC samples to discard. Default: \code{0}.} +\item{num_burnin}{Number of burn-in MCMC samples to discard. Default: \code{1000}.} \item{num_mcmc}{Total number of MCMC samples to draw. 
Default: \code{500}.} diff --git a/tools/debug/cloglog_ordinal_bart_binary.R b/tools/debug/cloglog_ordinal_bart_binary.R index 25840b32..1d62d3e5 100644 --- a/tools/debug/cloglog_ordinal_bart_binary.R +++ b/tools/debug/cloglog_ordinal_bart_binary.R @@ -56,9 +56,7 @@ out <- cloglog_ordinal_bart( num_gfr = 0, num_burnin = 1000, num_mcmc = 1000, - n_thin = 1, - alpha_gamma = 1.0, - beta_gamma = 1.0 + n_thin = 1 ) end <- Sys.time() diff --git a/tools/debug/cloglog_ordinal_bart_four_category.R b/tools/debug/cloglog_ordinal_bart_four_category.R index 56a99d30..02d41fcf 100644 --- a/tools/debug/cloglog_ordinal_bart_four_category.R +++ b/tools/debug/cloglog_ordinal_bart_four_category.R @@ -52,8 +52,8 @@ out <- cloglog_ordinal_bart( X = X_train, y = y_train, X_test = X_test, - num_gfr = 10, - num_burnin = 0, + num_gfr = 0, + num_burnin = 1000, num_mcmc = 1000, n_thin = 1 ) diff --git a/tools/debug/cloglog_ordinal_bart_three_category.R b/tools/debug/cloglog_ordinal_bart_three_category.R index d50e9746..a10072c5 100644 --- a/tools/debug/cloglog_ordinal_bart_three_category.R +++ b/tools/debug/cloglog_ordinal_bart_three_category.R @@ -16,7 +16,7 @@ true_lambda_function <- X %*% beta # Set cutpoints for ordinal categories (3 categories: 1, 2, 3) n_categories <- 3 -gamma_true <- c(-2, 3) +gamma_true <- c(-2, 1) ordinal_cutpoints <- log(cumsum(exp(gamma_true))) ordinal_cutpoints @@ -52,8 +52,8 @@ out <- cloglog_ordinal_bart( X = X_train, y = y_train, X_test = X_test, - num_gfr = 10, - num_burnin = 0, + num_gfr = 0, + num_burnin = 1000, num_mcmc = 1000, n_thin = 1 ) From 4f576a6eff04a7ff796619ce535666ce21464562 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 28 Oct 2025 00:03:54 -0500 Subject: [PATCH 32/34] Added a release candidate readme --- RC_README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 RC_README.md diff --git a/RC_README.md b/RC_README.md new file mode 100644 index 00000000..2d677fd0 --- /dev/null +++ b/RC_README.md @@ -0,0 +1,16 @@ 
+# Release Candidate for StochTree Cloglog BART + +This branch serves as a staging / testing zone for the planned incorporation of BART / BCF with a complementary log-log link function into `stochtree`. + +## Installation + +The cloglog release candidate version of `stochtree` can be installed from github via + +``` +remotes::install_github("StochasticTree/stochtree", ref="cloglog-bart-rc") +``` + +## Vignettes and Demos + +Before incorporating this functionality into `stochtree`, we intend to develop a rich set of vignettes. +We have included demo scripts for the cloglog model on synthetic ordinal data with 2, 3 and 4 categories in the `tools` subfolder of this branch. From dc25a1d75435ef1d94ca4300a24d62dd17662ab2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 28 Oct 2025 00:20:45 -0500 Subject: [PATCH 33/34] Updated demo scripts --- tools/debug/cloglog_ordinal_bart_binary.R | 29 ++++++++--------- .../cloglog_ordinal_bart_four_category.R | 31 +++++++++---------- .../cloglog_ordinal_bart_three_category.R | 28 ++++++++--------- 3 files changed, 40 insertions(+), 48 deletions(-) diff --git a/tools/debug/cloglog_ordinal_bart_binary.R b/tools/debug/cloglog_ordinal_bart_binary.R index 1d62d3e5..5f167e04 100644 --- a/tools/debug/cloglog_ordinal_bart_binary.R +++ b/tools/debug/cloglog_ordinal_bart_binary.R @@ -1,8 +1,7 @@ -# Simulate ordinal data and run Cloglog Ordinal BART - -# Load +# Load library library(stochtree) +# Set seed set.seed(2025) # Sample size and number of predictors @@ -11,7 +10,6 @@ p <- 5 # Design matrix and true lambda function X <- matrix(runif(n * p), ncol = p) -# true_lambda_function <- ifelse(X[, 1] > 0.5, 2, -1) beta <- rep(1 / sqrt(p), p) true_lambda_function <- X %*% beta @@ -46,8 +44,6 @@ y_train <- y[train_idx] X_test <- X[test_idx, ] y_test <- y[test_idx] -start <- Sys.time() - # Sample the cloglog ordinal BART model out <- cloglog_ordinal_bart( X = X_train, @@ -59,18 +55,17 @@ out <- cloglog_ordinal_bart( n_thin = 1 ) -end <- 
Sys.time() -print(end - start) - -# Inference and diagnostics +# Traceplot of cutoff parameters par(mfrow = c(2, 1)) plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") abline(h = gamma_true[1], col = 'red', lty = 2) +# Histogram of cutoff parameters gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) summary(gamma1) hist(gamma1) +# Traceplots of cutoff parameters combined with average forest predictions par(mfrow = c(2,1)) rowMeans(out$gamma_samples) moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) @@ -78,15 +73,17 @@ plot(moo[,1]) abline(h = gamma_true[1] + mean(true_lambda_function[train_idx])) plot(out$gamma_samples[1,]) -# Compare forest predictions with the truth function (for training and test sets) +# Compare forest predictions with the truth (for training and test sets) par(mfrow = c(2,1)) + +# Train set lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) -# lambda_pred_train <- rowMeans(out$forest_predictions_train) plot(lambda_pred_train, gamma_true[1] + true_lambda_function[train_idx]) abline(a=0,b=1,col='blue', lwd=2) cor_train <- cor(true_lambda_function[train_idx] + mean(out$gamma_samples[1,]), gamma_true[1] + lambda_pred_train) text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') +# Test set lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) plot(lambda_pred_test, gamma_true[1] + true_lambda_function[test_idx]) abline(a=0,b=1,col='blue', lwd=2) @@ -105,10 +102,10 @@ for (j in 1:n_categories) { (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) } } - +# Compute average difference mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) -# Compare estimated vs true class probabilities for training set +# Plot estimated vs true class 
probabilities for training set for (j in 1:n_categories) { plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) abline(a = 0, b = 1, col = 'blue', lwd = 2) @@ -128,10 +125,10 @@ for (j in 1:n_categories) { (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) } } - +# Compute average difference mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) -# Compare estimated vs true class probabilities for test set +# Plot estimated vs true class probabilities for test set for (j in 1:n_categories) { plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) abline(a = 0, b = 1, col = 'blue', lwd = 2) diff --git a/tools/debug/cloglog_ordinal_bart_four_category.R b/tools/debug/cloglog_ordinal_bart_four_category.R index 02d41fcf..36ee8710 100644 --- a/tools/debug/cloglog_ordinal_bart_four_category.R +++ b/tools/debug/cloglog_ordinal_bart_four_category.R @@ -1,8 +1,7 @@ -# Simulate ordinal data and run Cloglog Ordinal BART - -# Load +# Load library library(stochtree) +# Set seed set.seed(2025) # Sample size and number of predictors @@ -45,8 +44,6 @@ y_train <- y[train_idx] X_test <- X[test_idx, ] y_test <- y[test_idx] -start <- Sys.time() - # Sample the cloglog ordinal BART model out <- cloglog_ordinal_bart( X = X_train, @@ -58,10 +55,7 @@ out <- cloglog_ordinal_bart( n_thin = 1 ) -end <- Sys.time() -print(end - start) - -# Inference and diagnostics +# Traceplots of cutoff parameters par(mfrow = c(2, 2)) plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") abline(h = gamma_true[1], col = 'red', lty = 2) @@ -70,20 +64,20 @@ abline(h = gamma_true[2], col = 'red', lty = 2) plot(out$gamma_samples[3, ], type = 'l', main = expression(gamma[3]), ylab = "Value", xlab = "MCMC Sample") abline(h = gamma_true[3], col = 'red', lty = 2) 
+# Histograms of cutoff parameters par(mfrow = c(2, 2)) gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) summary(gamma1) hist(gamma1) - gamma2 <- out$gamma_samples[2,] + colMeans(out$forest_predictions_train) summary(gamma2) hist(gamma2) - gamma3 <- out$gamma_samples[3,] + colMeans(out$forest_predictions_train) summary(gamma3) hist(gamma3) -par(mfrow = c(2,3), mar = c(5,4,1,1)) +# Traceplots of cutoff parameters combined with average forest predictions +par(mfrow = c(2,3)) rowMeans(out$gamma_samples) moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) plot(moo[,1]) @@ -96,14 +90,17 @@ plot(out$gamma_samples[1,]) plot(out$gamma_samples[2,]) plot(out$gamma_samples[3,]) -# Compare forest predictions with the truth function (for training and test sets) +# Compare forest predictions with the truth (for training and test sets) par(mfrow = c(2,1)) + +# Train set lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) plot(lambda_pred_train, true_lambda_function[train_idx]) abline(a=0,b=1,col='blue', lwd=2) cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') +# Test set lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) plot(lambda_pred_test, true_lambda_function[test_idx]) abline(a=0,b=1,col='blue', lwd=2) @@ -122,10 +119,10 @@ for (j in 1:n_categories) { (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) } } - +# Compute average difference mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) -# Compare estimated vs true class probabilities for training set +# Plot estimated vs true class probabilities for training set par(mfrow = c(2,2)) for (j in 1:n_categories) { plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True 
Prob Category", j), ylab = paste("Estimated Prob Category", j)) @@ -146,10 +143,10 @@ for (j in 1:n_categories) { (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) } } - +# Average difference mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) -# Compare estimated vs true class probabilities for test set +# Plot estimated vs true class probabilities for test set par(mfrow = c(2,2)) for (j in 1:n_categories) { plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) diff --git a/tools/debug/cloglog_ordinal_bart_three_category.R b/tools/debug/cloglog_ordinal_bart_three_category.R index a10072c5..a7ba416d 100644 --- a/tools/debug/cloglog_ordinal_bart_three_category.R +++ b/tools/debug/cloglog_ordinal_bart_three_category.R @@ -1,8 +1,7 @@ -# Simulate ordinal data and run Cloglog Ordinal BART - -# Load +# Load library library(stochtree) +# Set seed set.seed(2025) # Sample size and number of predictors @@ -45,8 +44,6 @@ y_train <- y[train_idx] X_test <- X[test_idx, ] y_test <- y[test_idx] -start <- Sys.time() - # Sample the cloglog ordinal BART model out <- cloglog_ordinal_bart( X = X_train, @@ -58,25 +55,23 @@ out <- cloglog_ordinal_bart( n_thin = 1 ) -end <- Sys.time() -print(end - start) - -# Inference and diagnostics +# Traceplots of cutoff parameters par(mfrow = c(2, 1)) plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") abline(h = gamma_true[1], col = 'red', lty = 2) plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample") abline(h = gamma_true[2], col = 'red', lty = 2) +# Histograms of cutoff parameters gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) summary(gamma1) hist(gamma1) - gamma2 <- out$gamma_samples[2,] + colMeans(out$forest_predictions_train) summary(gamma2) hist(gamma2) -par(mfrow = c(3,2), mar = 
c(5,4,1,1)) +# Traceplots of cutoff parameters combined with average forest predictions +par(mfrow = c(3,2)) rowMeans(out$gamma_samples) moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) plot(moo[,1]) @@ -86,14 +81,17 @@ abline(h = gamma_true[2] + mean(true_lambda_function[train_idx])) plot(out$gamma_samples[1,]) plot(out$gamma_samples[2,]) -# Compare forest predictions with the truth function (for training and test sets) +# Compare forest predictions with the truth (for training and test sets) par(mfrow = c(2,1)) + +# Train set lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) plot(lambda_pred_train, true_lambda_function[train_idx]) abline(a=0,b=1,col='blue', lwd=2) cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') +# Test set lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) plot(lambda_pred_test, true_lambda_function[test_idx]) abline(a=0,b=1,col='blue', lwd=2) @@ -112,10 +110,10 @@ for (j in 1:n_categories) { (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) } } - +# Compute average difference mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) -# Compare estimated vs true class probabilities for training set +# Plot estimated vs true class probabilities for training set par(mfrow = c(2,2)) for (j in 1:n_categories) { plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) @@ -136,7 +134,7 @@ for (j in 1:n_categories) { (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) } } - +# Compute average difference mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) # Compare estimated vs true class probabilities for test 
set From e0ccb02b1011c012ab985b9ba8f15eaa7996a4a4 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 29 Oct 2025 10:55:31 -0500 Subject: [PATCH 34/34] WIP python frontend for cloglog ordinal BART --- CMakeLists.txt | 1 + .../cloglog_ordinary_bart_three_category.py | 229 ++++++ src/py_stochtree.cpp | 103 ++- stochtree/__init__.py | 2 + stochtree/cloglog_ordinal_bart.py | 698 ++++++++++++++++++ stochtree/config.py | 8 +- stochtree/data.py | 91 +++ .../cloglog_ordinal_bart_three_category.R | 2 + 8 files changed, 1129 insertions(+), 5 deletions(-) create mode 100644 demo/debug/cloglog_ordinary_bart_three_category.py create mode 100644 stochtree/cloglog_ordinal_bart.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d1efe55..08957562 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -139,6 +139,7 @@ file( src/io.cpp src/json11.cpp src/leaf_model.cpp + src/ordinal_sampler.cpp src/partition_tracker.cpp src/random_effects.cpp src/tree.cpp diff --git a/demo/debug/cloglog_ordinary_bart_three_category.py b/demo/debug/cloglog_ordinary_bart_three_category.py new file mode 100644 index 00000000..19e90a98 --- /dev/null +++ b/demo/debug/cloglog_ordinary_bart_three_category.py @@ -0,0 +1,229 @@ +# Load libraries +import numpy as np +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +from stochtree import CloglogOrdinalBARTModel + +# Set seed +seed = 2025 +rng = np.random.default_rng(seed) + +# Sample size and number of predictors +n = 2000 +p = 5 + +# Design matrix and true lambda function +X = rng.normal(0, 1, size=(n, p)) +beta = np.repeat(1 / np.sqrt(p), p) +true_lambda_function = X @ beta + +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories = 3 +gamma_true = np.array([-2, 1]) +ordinal_cutpoints = np.log(np.cumsum(np.exp(gamma_true))) +print("Ordinal cutpoints:", ordinal_cutpoints) + +# True ordinal class probabilities +true_probs = np.zeros((n, n_categories)) +for j in range(n_categories): + if j == 0: + 
true_probs[:, j] = 1 - np.exp(-np.exp(gamma_true[j] + true_lambda_function)) + elif j == n_categories - 1: + true_probs[:, j] = 1 - np.sum(true_probs[:, :j], axis=1) + else: + true_probs[:, j] = np.exp(-np.exp(gamma_true[j - 1] + true_lambda_function)) * ( + 1 - np.exp(-np.exp(gamma_true[j] + true_lambda_function)) + ) +print(f"Probability distribution: {np.mean(true_probs, axis=0)}") + +# Generate ordinal outcomes +y = np.zeros(n, dtype=int) +for i in range(n): + y[i] = rng.choice(np.arange(n_categories), p=true_probs[i, :]) +print(f"Outcome distribution: {np.bincount(y)}") + +# Train-test split +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.2) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] + +# Run cloglog ordinal BART model +bart_model = CloglogOrdinalBARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + n_trees=50, + num_gfr=0, + num_burnin=1000, + num_mcmc=500, + n_thin=1, +) + +# Traceplots of cutoff parameters +plt.subplot(1, 2, 1) +plt.plot(bart_model.gamma_samples[0, :], linestyle="-", label=r"$\gamma_1$") +plt.subplot(1, 2, 2) +plt.plot(bart_model.gamma_samples[1, :], linestyle="-", label=r"$\gamma_2$") +plt.show() + +# Histograms of cutoff parameters +plt.clf() +gamma1 = bart_model.gamma_samples[0, :] + np.mean(bart_model.forest_pred_train, axis=0) +plt.subplot(1, 2, 1) +plt.hist(gamma1, bins=30, edgecolor="black") +gamma2 = bart_model.gamma_samples[1, :] + np.mean(bart_model.forest_pred_train, axis=0) +plt.subplot(1, 2, 2) +plt.hist(gamma2, bins=30, edgecolor="black") +plt.show() + +# Traceplots of cutoff parameters combined with average forest predictions +plt.clf() +plt.subplot(1, 2, 1) +plt.plot(gamma1, linestyle="-", label=r"$\gamma_1$") +plt.axhline( + y=gamma_true[0] + np.mean(true_lambda_function[train_inds]), + color="red", + linestyle="--", +) +plt.subplot(1, 2, 2) +plt.plot(gamma2, linestyle="-", 
label=r"$\gamma_2$") +plt.axhline( + y=gamma_true[1] + np.mean(true_lambda_function[train_inds]), + color="red", + linestyle="--", +) +plt.show() + +# Compare forest predictions with the truth (for training and test sets) + +# Train set +plt.clf() +lambda_pred_train = np.mean(bart_model.forest_pred_train, axis=1) - np.mean( + bart_model.forest_pred_train +) +plt.subplot(1, 2, 1) +plt.plot(lambda_pred_train, true_lambda_function[train_inds], "o") +plt.xlabel("Predicted lambda (train)") +plt.ylabel("True lambda (train)") +plt.axline((0, 0), slope=1, color="blue", linestyle=(0, (3, 3))) +cor_train = np.corrcoef(true_lambda_function[train_inds], lambda_pred_train)[0, 1] +plt.text( + min(true_lambda_function[train_inds]), + max(true_lambda_function[train_inds]), + f"Correlation: {round(cor_train, 3)}", +) + +# Test set +lambda_pred_test = np.mean(bart_model.forest_pred_test, axis=1) - np.mean( + bart_model.forest_pred_test +) +plt.subplot(1, 2, 2) +plt.plot(lambda_pred_test, true_lambda_function[test_inds], "o") +plt.xlabel("Predicted lambda (test)") +plt.ylabel("True lambda (test)") +plt.axline((0, 0), slope=1, color="blue", linestyle=(0, (3, 3))) +cor_test = np.corrcoef(true_lambda_function[test_inds], lambda_pred_test)[0, 1] +plt.text( + min(true_lambda_function[test_inds]), + max(true_lambda_function[test_inds]), + f"Correlation: {round(cor_test, 3)}", +) +plt.show() + +# Estimated ordinal class probabilities for the training set +est_probs_train = np.zeros((len(train_inds), n_categories)) +for j in range(n_categories): + if j == 0: + est_probs_train[:, j] = np.mean( + 1 + - np.exp( + -np.exp(bart_model.forest_pred_train + bart_model.gamma_samples[j, :]) + ), + axis=1, + ) + elif j == n_categories - 1: + est_probs_train[:, j] = 1 - np.sum(est_probs_train[:, :j], axis=1) + else: + est_probs_train[:, j] = np.mean( + np.exp( + -np.exp( + bart_model.forest_pred_train + bart_model.gamma_samples[j - 1, :] + ) + ) + * ( + 1 + - np.exp( + -np.exp( + 
bart_model.forest_pred_train + bart_model.gamma_samples[j, :] + ) + ) + ), + axis=1, + ) + +# Plot estimated vs true class probabilities for training set +plt.clf() +for j in range(n_categories): + plt.subplot(1, n_categories, j + 1) + plt.plot(est_probs_train[:, j], true_probs[train_inds, j], "o") + plt.xlabel("Predicted prob (train)") + plt.ylabel("True prob (train)") + plt.axline((0, 0), slope=1, color="blue", linestyle=(0, (3, 3))) + cor_train = np.corrcoef(est_probs_train[:, j], true_probs[train_inds, j])[0, 1] + plt.text( + min(est_probs_train[:, j]), + max(true_probs[train_inds, j]), + f"Correlation: {round(cor_train, 3)}", + ) + plt.show() + +# Estimated ordinal class probabilities for the test set +est_probs_test = np.zeros((len(test_inds), n_categories)) +for j in range(n_categories): + if j == 0: + est_probs_test[:, j] = np.mean( + 1 + - np.exp( + -np.exp(bart_model.forest_pred_test + bart_model.gamma_samples[j, :]) + ), + axis=1, + ) + elif j == n_categories - 1: + est_probs_test[:, j] = 1 - np.sum(est_probs_test[:, :j], axis=1) + else: + est_probs_test[:, j] = np.mean( + np.exp( + -np.exp( + bart_model.forest_pred_test + bart_model.gamma_samples[j - 1, :] + ) + ) + * ( + 1 + - np.exp( + -np.exp( + bart_model.forest_pred_test + bart_model.gamma_samples[j, :] + ) + ) + ), + axis=1, + ) + +# Plot estimated vs true class probabilities for test set +plt.clf() +for j in range(n_categories): + plt.subplot(1, n_categories, j + 1) + plt.plot(est_probs_test[:, j], true_probs[test_inds, j], "o") + plt.xlabel("Predicted prob (test)") + plt.ylabel("True prob (test)") + plt.axline((0, 0), slope=1, color="blue", linestyle=(0, (3, 3))) + cor_test = np.corrcoef(est_probs_test[:, j], true_probs[test_inds, j])[0, 1] + plt.text( + min(est_probs_test[:, j]), + max(true_probs[test_inds, j]), + f"Correlation: {round(cor_test, 3)}", + ) + plt.show() diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index a95ab0ee..284fd211 100644 --- a/src/py_stochtree.cpp +++
b/src/py_stochtree.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -142,6 +143,42 @@ class ForestDatasetCpp { return dataset_.get(); } + bool HasAuxiliaryDimension(int dim_idx) { + return dataset_->HasAuxiliaryDimension(dim_idx); + } + + void AddAuxiliaryDimension(int dim_size) { + dataset_->AddAuxiliaryDimension(dim_size); + } + + double GetAuxiliaryDataValue(int dim_idx, data_size_t element_idx) { + return dataset_->GetAuxiliaryDataValue(dim_idx, element_idx); + } + + void SetAuxiliaryDataValue(int dim_idx, data_size_t element_idx, double value) { + dataset_->SetAuxiliaryDataValue(dim_idx, element_idx, value); + } + + py::array_t GetAuxiliaryDataArray(int dim_idx) { + std::vector output_vec = dataset_->GetAuxiliaryDataVector(dim_idx); + int n = output_vec.size(); + auto result = py::array_t(py::detail::any_container({n})); + auto accessor = result.mutable_unchecked<1>(); + for (size_t i = 0; i < n; i++) { + accessor(i) = output_vec[i]; + } + return result; + } + + void StoreAuxiliaryDataArrayMatrix(py::array_t output_matrix, int dim_idx, int matrix_col_idx) { + const std::vector output_raw = dataset_->GetAuxiliaryDataVector(dim_idx); + int n = output_raw.size(); + auto accessor = output_matrix.mutable_unchecked<2>(); + for (int i = 0; i < n; i++) { + accessor(i, matrix_col_idx) = output_raw[i]; + } + } + private: std::unique_ptr dataset_; }; @@ -1105,6 +1142,8 @@ class ForestSamplerCpp { else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; + else StochTree::Log::Fatal("Invalid model type"); // Unpack leaf model parameters double leaf_scale; @@ -1148,6 +1187,8 @@ class ForestSamplerCpp { 
StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); } } else { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { @@ -1158,6 +1199,8 @@ class ForestSamplerCpp { StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + } else if 
(model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); } } } @@ -1170,6 +1213,7 @@ class ForestSamplerCpp { else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; else StochTree::Log::Fatal("Invalid model type"); // Unpack initial value @@ -1214,6 +1258,10 @@ class ForestSamplerCpp { int n = forest_data_ptr->NumObservations(); std::vector initial_preds(n, init_val); forest_data_ptr->AddVarianceWeights(initial_preds.data(), n); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + leaf_init_val = std::log(init_val) / static_cast(num_trees); + forest_ptr->SetLeafValue(leaf_init_val); + tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); } } @@ -1276,6 +1324,46 @@ class ForestSamplerCpp { std::unique_ptr split_prior_; }; +class OrdinalSamplerCpp { + public: + OrdinalSamplerCpp() { + // Initialize pointer to C++ OrdinalSampler classes + ordinal_sampler_ = std::make_unique(); + } + ~OrdinalSamplerCpp() {} + + double SampleTruncatedExponential(RngCpp& rng, double rate, double lower_bound = 0.0, double upper_bound = 1.0) { + std::mt19937* rng_ptr = rng.GetRng(); + return ordinal_sampler_->SampleTruncatedExponential(*rng_ptr, rate, lower_bound, upper_bound); + } + + void UpdateLatentVariables(ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng) { + StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); + StochTree::ColumnVector* 
residual_ptr = residual.GetData(); + Eigen::VectorXd& residual_data_eigen = residual_ptr->GetData(); + std::mt19937* rng_ptr = rng.GetRng(); + ordinal_sampler_->UpdateLatentVariables(*dataset_ptr, residual_data_eigen, *rng_ptr); + } + + void UpdateGammaParams(ForestDatasetCpp& dataset, ResidualCpp& residual, + double alpha_gamma, double beta_gamma, + double gamma_0, RngCpp& rng) { + StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); + StochTree::ColumnVector* residual_ptr = residual.GetData(); + Eigen::VectorXd& residual_data_eigen = residual_ptr->GetData(); + std::mt19937* rng_ptr = rng.GetRng(); + ordinal_sampler_->UpdateGammaParams(*dataset_ptr, residual_data_eigen, alpha_gamma, beta_gamma, gamma_0, *rng_ptr); + } + + void UpdateCumulativeExpSums(ForestDatasetCpp& dataset) { + StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); + ordinal_sampler_->UpdateCumulativeExpSums(*dataset_ptr); + } + + private: + std::unique_ptr ordinal_sampler_; +}; + class GlobalVarianceModelCpp { public: GlobalVarianceModelCpp() { @@ -2149,7 +2237,13 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("GetBasis", &ForestDatasetCpp::GetBasis) .def("GetVarianceWeights", &ForestDatasetCpp::GetVarianceWeights) .def("HasBasis", &ForestDatasetCpp::HasBasis) - .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights); + .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights) + .def("HasAuxiliaryDimension", &ForestDatasetCpp::HasAuxiliaryDimension) + .def("AddAuxiliaryDimension", &ForestDatasetCpp::AddAuxiliaryDimension) + .def("GetAuxiliaryDataValue", &ForestDatasetCpp::GetAuxiliaryDataValue) + .def("SetAuxiliaryDataValue", &ForestDatasetCpp::SetAuxiliaryDataValue) + .def("GetAuxiliaryDataArray", &ForestDatasetCpp::GetAuxiliaryDataArray) + .def("StoreAuxiliaryDataArrayMatrix", &ForestDatasetCpp::StoreAuxiliaryDataArrayMatrix); py::class_(m, "ResidualCpp") .def(py::init,data_size_t>()) @@ -2340,6 +2434,13 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def(py::init<>()) 
.def("SampleOneIteration", &LeafVarianceModelCpp::SampleOneIteration); + py::class_(m, "OrdinalSamplerCpp") + .def(py::init<>()) + .def("SampleTruncatedExponential", &OrdinalSamplerCpp::SampleTruncatedExponential) + .def("UpdateLatentVariables", &OrdinalSamplerCpp::UpdateLatentVariables) + .def("UpdateGammaParams", &OrdinalSamplerCpp::UpdateGammaParams) + .def("UpdateCumulativeExpSums", &OrdinalSamplerCpp::UpdateCumulativeExpSums); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/stochtree/__init__.py b/stochtree/__init__.py index 318f6219..0b71557c 100644 --- a/stochtree/__init__.py +++ b/stochtree/__init__.py @@ -1,6 +1,7 @@ from .bart import BARTModel from .bcf import BCFModel from .calibration import calibrate_global_error_variance +from .cloglog_ordinal_bart import CloglogOrdinalBARTModel from .config import ForestModelConfig, GlobalModelConfig from .data import Dataset, Residual from .forest import Forest, ForestContainer @@ -39,6 +40,7 @@ __all__ = [ "BARTModel", "BCFModel", + "CloglogOrdinalBARTModel", "Dataset", "Residual", "ForestContainer", diff --git a/stochtree/cloglog_ordinal_bart.py b/stochtree/cloglog_ordinal_bart.py new file mode 100644 index 00000000..5137bf24 --- /dev/null +++ b/stochtree/cloglog_ordinal_bart.py @@ -0,0 +1,698 @@ +import warnings +from math import log +from numbers import Integral +from typing import Any, Dict, Optional, Union + +import numpy as np +import pandas as pd +from scipy.stats import norm + +from stochtree_cpp import OrdinalSamplerCpp +from .config import ForestModelConfig, GlobalModelConfig +from .data import Dataset, Residual +from .forest import Forest, ForestContainer +from .preprocessing import CovariatePreprocessor, _preprocess_params +from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel +from .serialization import JSONSerializer +from .utils import ( + NotSampledError, + _expand_dims_1d, + _expand_dims_2d, + _expand_dims_2d_diag, +) + + 
class CloglogOrdinalBARTModel:
    r"""
    Class that handles sampling and storage of BART models with a cloglog link for ordinal outcomes.
    This is an implementation of the model of Alam and Linero (2025), in which y is an ordinal outcome
    with K categories, ordered from 0 to K-1.

    JSON serialization is not implemented yet; the corresponding methods are kept below as
    commented-out placeholders.
    """

    def __init__(self) -> None:
        # Internal flag for whether the sample() method has been run
        self.sampled = False

    @staticmethod
    def _register_auxiliary_series(
        forest_dataset_train, n_train: int, n_levels: int
    ) -> None:
        """Register the four auxiliary data series on the training dataset and zero the cutpoints.

        The series are added in a fixed order, which determines the `dim_idx` used everywhere else:

        * 0: latent variables (Z in Alam et al (2025) notation), `n_train` elements
        * 1: forest predictions (eta in Alam et al (2025) notation), `n_train` elements
        * 2: log-scale non-cumulative cutpoints (gamma), `n_levels - 1` elements
        * 3: exponentiated cumulative cutpoints (exp(c_k)), `n_levels` elements. The element at
          position `i` corresponds to the sum of all exponentiated gamma_j values for j < i, which
          is why this series has `n_levels` elements rather than `n_levels - 1`.

        Parameters
        ----------
        forest_dataset_train : Dataset
            Training dataset; only `add_auxiliary_dimension` and `set_auxiliary_data_value` are used.
        n_train : int
            Number of training observations.
        n_levels : int
            Number of ordinal outcome categories (K).
        """
        forest_dataset_train.add_auxiliary_dimension(n_train)
        forest_dataset_train.add_auxiliary_dimension(n_train)
        forest_dataset_train.add_auxiliary_dimension(n_levels - 1)
        forest_dataset_train.add_auxiliary_dimension(n_levels)

        # Initialize gamma parameters to zero (3rd auxiliary series, `dim_idx = 2`).
        # Element indices are 0-based: writing to `k - 1` would target index -1 on the
        # first pass and never touch the last cutpoint.
        initial_gamma = np.zeros((n_levels - 1,), dtype=np.float64)
        for k in range(n_levels - 1):
            forest_dataset_train.set_auxiliary_data_value(2, k, float(initial_gamma[k]))

    def sample(
        self,
        X_train: Union[np.array, pd.DataFrame],
        y_train: np.array,
        X_test: Union[np.array, pd.DataFrame] = None,
        n_trees: int = 50,
        num_gfr: int = 0,
        num_burnin: int = 1000,
        num_mcmc: int = 500,
        n_thin: int = 1,
        alpha_gamma: float = 2.0,
        beta_gamma: float = 2.0,
        variable_weights: np.array = None,
        feature_types: np.array = None,
        seed: int = None,
        num_threads=1,
    ) -> None:
        """Runs a Cloglog BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set.

        Parameters
        ----------
        X_train : np.array or pd.DataFrame
            Training set covariates on which trees may be partitioned.
        y_train : np.array
            Training set outcome (must be integer-valued from 0 to K-1, where K is the number of outcome categories).
            Only `int32` and `int64` arrays are accepted.
        X_test : np.array or pd.DataFrame, optional
            Optional test set covariates.
        n_trees : int, optional
            Number of trees in the BART ensemble. Defaults to `50`.
        num_gfr : int, optional
            Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `0`.
        num_burnin : int, optional
            Number of "burn-in" iterations of the MCMC sampler. Defaults to `1000`.
        num_mcmc : int, optional
            Number of "retained" iterations of the MCMC sampler. Defaults to `500`.
        n_thin : int, optional
            Thinning interval for MCMC samples. Defaults to `1` (no thinning).
        alpha_gamma : float, optional
            Shape parameter for the log-gamma prior on cutpoints. Defaults to `2.0`.
        beta_gamma : float, optional
            Rate parameter for the log-gamma prior on cutpoints. Defaults to `2.0`.
        variable_weights : np.array, optional
            Variable weights for covariate selection probabilities. If `None`, uniform weights are used.
        feature_types : np.array, optional
            Currently ignored: the value is overwritten by the feature types detected by the covariate preprocessor.
        seed : int, optional
            Random seed for reproducibility. If `None`, a random seed is used.
        num_threads : int, optional
            Number of threads to use for parallel processing. Defaults to `1`.

        Returns
        -------
        None
            Results are stored on the instance: `forest_samples`, `forest_pred_train`,
            `forest_pred_test` (only when `X_test` is given), `gamma_samples` and `latent_samples`.

        Raises
        ------
        ValueError
            If the inputs have the wrong type, dtype or shape, if `variable_weights` contains
            negative entries, or if fewer than 2 outcome categories are implied by `y_train`.
        """
        # Check data inputs
        if not isinstance(X_train, (pd.DataFrame, np.ndarray)):
            raise ValueError("X_train must be a pandas dataframe or numpy array")
        if X_test is not None and not isinstance(X_test, (pd.DataFrame, np.ndarray)):
            raise ValueError("X_test must be a pandas dataframe or numpy array")
        if not isinstance(y_train, np.ndarray):
            raise ValueError("y_train must be a numpy array")
        if y_train.dtype not in [np.int32, np.int64]:
            raise ValueError("y_train must be an integer-valued numpy array")
        if np.any(y_train < 0):
            raise ValueError("y_train must be non-negative integer-valued")

        # Convert everything to standard shape (2-dimensional)
        if isinstance(X_train, np.ndarray) and X_train.ndim == 1:
            X_train = np.expand_dims(X_train, 1)
        if y_train.ndim == 1:
            y_train = np.expand_dims(y_train, 1)
        if isinstance(X_test, np.ndarray) and X_test.ndim == 1:
            X_test = np.expand_dims(X_test, 1)

        # Data checks
        if X_test is not None and X_test.shape[1] != X_train.shape[1]:
            raise ValueError("X_train and X_test must have the same number of columns")
        if y_train.shape[0] != X_train.shape[0]:
            raise ValueError("X_train and y_train must have the same number of rows")

        # Variable weight preprocessing (and initialization if necessary).
        # X_train is always 2-dimensional at this point, so uniform weights are 1 / p.
        p = X_train.shape[1]
        if variable_weights is None:
            variable_weights = np.repeat(1.0 / p, p)
        else:
            # Accept any array-like so the comparison and fancy indexing below work
            variable_weights = np.asarray(variable_weights)
        if np.any(variable_weights < 0):
            raise ValueError("variable_weights cannot have any negative weights")

        # Covariate preprocessing
        self._covariate_preprocessor = CovariatePreprocessor()
        self._covariate_preprocessor.fit(X_train)
        X_train_processed = self._covariate_preprocessor.transform(X_train)
        if X_test is not None:
            X_test_processed = self._covariate_preprocessor.transform(X_test)
        feature_types = np.asarray(
            self._covariate_preprocessor._processed_feature_types
        )
        original_var_indices = (
            self._covariate_preprocessor.fetch_original_feature_indices()
        )

        # Update variable weights if the covariates have been resized (by e.g. one-hot encoding):
        # each original variable's weight is split evenly over the columns derived from it.
        if X_train_processed.shape[1] != X_train.shape[1]:
            variable_counts = [
                original_var_indices.count(idx) for idx in original_var_indices
            ]
            variable_weights_adj = np.array([1 / count for count in variable_counts])
            variable_weights = (
                variable_weights[original_var_indices] * variable_weights_adj
            )

        # Determine whether a test set is provided
        self.has_test = X_test is not None

        # Unpack data dimensions
        self.n_train = y_train.shape[0]
        self.n_test = X_test_processed.shape[0] if self.has_test else 0
        self.num_covariates = X_train_processed.shape[1]

        # Determine number of outcome categories (categories are coded 0, ..., K-1)
        self.n_levels = int(np.max(y_train)) + 1

        # Check that there are at least 2 outcome categories
        if self.n_levels < 2:
            raise ValueError("y_train must have at least 2 outcome categories")

        # BART parameters
        alpha_bart = 0.95
        beta_bart = 2
        min_samples_in_leaf = 5
        max_depth = 10
        scale_leaf = 2 / np.sqrt(n_trees)
        cutpoint_grid_size = 100

        # Fixed for identifiability (can be passed as an argument later if desired)
        gamma_0 = 0.0  # First gamma cutpoint fixed at gamma_0 = 0

        # Indices of MCMC samples to keep after GFR, burn-in, and thinning
        keep_idx = np.arange(
            num_gfr + num_burnin, num_gfr + num_burnin + num_mcmc, n_thin
        )
        n_keep = len(keep_idx)
        # Set of plain ints so the per-iteration membership test is O(1)
        keep_set = set(keep_idx.tolist())

        # Container of parameter samples / model draws
        self.num_gfr = num_gfr
        self.num_burnin = num_burnin
        self.num_mcmc = num_mcmc
        self.forest_pred_train = np.empty((self.n_train, n_keep), dtype=np.float64)
        if self.has_test:
            self.forest_pred_test = np.empty((self.n_test, n_keep), dtype=np.float64)
        self.gamma_samples = np.empty((self.n_levels - 1, n_keep), dtype=np.float64)
        self.latent_samples = np.empty((self.n_train, n_keep), dtype=np.float64)

        # Initialize samplers
        ordinal_sampler_cpp = OrdinalSamplerCpp()
        if seed is None:
            cpp_rng = RNG(-1)
            self.rng = np.random.default_rng()
        else:
            cpp_rng = RNG(seed)
            self.rng = np.random.default_rng(seed)

        # Data structures
        forest_dataset_train = Dataset()
        forest_dataset_train.add_covariates(X_train_processed)
        if self.has_test:
            forest_dataset_test = Dataset()
            forest_dataset_test.add_covariates(X_test_processed)
        outcome_train = Residual(y_train)
        active_forest = Forest(n_trees, 1, True, False)
        active_forest.set_root_leaves(0.0)
        self.forest_samples = ForestContainer(n_trees, 1, True, False)
        global_model_config = GlobalModelConfig(global_error_variance=1.0)
        forest_model_config = ForestModelConfig(
            num_trees=n_trees,
            num_features=self.num_covariates,
            num_observations=self.n_train,
            feature_types=feature_types,
            variable_weights=variable_weights,
            leaf_dimension=1,
            alpha=alpha_bart,
            beta=beta_bart,
            min_samples_leaf=min_samples_in_leaf,
            max_depth=max_depth,
            leaf_model_type=4,  # cloglog ordinal regression leaf model
            cutpoint_grid_size=cutpoint_grid_size,
            leaf_model_scale=scale_leaf,
        )
        forest_sampler = ForestSampler(
            forest_dataset_train, global_model_config, forest_model_config
        )

        # Register the auxiliary series (0: latent Z, 1: forest predictions, 2: gamma,
        # 3: cumulative exp(gamma) sums) and zero-initialize the gamma cutpoints.
        self._register_auxiliary_series(
            forest_dataset_train, self.n_train, self.n_levels
        )

        # Convert the log-scale parameters into cumulative exponentiated parameters.
        # This is done under the hood in a C++ function for efficiency.
        ordinal_sampler_cpp.UpdateCumulativeExpSums(forest_dataset_train.dataset_cpp)

        # Initialize forest predictions (slot 1) and latent variables (slot 0) to zero
        for obs_idx in range(self.n_train):
            forest_dataset_train.set_auxiliary_data_value(1, obs_idx, 0.0)
        for obs_idx in range(self.n_train):
            forest_dataset_train.set_auxiliary_data_value(0, obs_idx, 0.0)

        # Run the algorithm
        sample_counter = -1
        for i in range(num_gfr + num_burnin + num_mcmc):
            keep_sample = i in keep_set
            if keep_sample:
                sample_counter += 1

            # 1. Sample forest
            # NOTE(review): the positional boolean after `keep_sample` is True once the
            # `num_gfr` warm-start iterations are over and False during them. If that
            # parameter is the sampler's grow-from-root flag, the two branches are
            # swapped -- confirm against ForestSampler.sample_one_iteration.
            if i > self.num_gfr - 1:
                forest_sampler.sample_one_iteration(
                    self.forest_samples,
                    active_forest,
                    forest_dataset_train,
                    outcome_train,
                    cpp_rng,
                    global_model_config,
                    forest_model_config,
                    keep_sample,
                    True,
                    num_threads,
                )
            else:
                forest_sampler.sample_one_iteration(
                    self.forest_samples,
                    active_forest,
                    forest_dataset_train,
                    outcome_train,
                    cpp_rng,
                    global_model_config,
                    forest_model_config,
                    keep_sample,
                    False,
                    num_threads,
                )

            # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions
            # This is needed for updating gamma parameters, latent z_i's
            forest_pred_current = active_forest.predict(forest_dataset_train)
            # Dedicated index name so the outer iteration counter `i` is not shadowed
            for obs_idx in range(self.n_train):
                forest_dataset_train.set_auxiliary_data_value(
                    1, obs_idx, forest_pred_current[obs_idx]
                )

            # 2. Sample latent z_i's using truncated exponential
            ordinal_sampler_cpp.UpdateLatentVariables(
                forest_dataset_train.dataset_cpp,
                outcome_train.residual_cpp,
                cpp_rng.rng_cpp,
            )

            # 3. Sample gamma cutpoints
            ordinal_sampler_cpp.UpdateGammaParams(
                forest_dataset_train.dataset_cpp,
                outcome_train.residual_cpp,
                alpha_gamma,
                beta_gamma,
                gamma_0,
                cpp_rng.rng_cpp,
            )

            # 4. Update cumulative sum of exp(gamma) values
            ordinal_sampler_cpp.UpdateCumulativeExpSums(
                forest_dataset_train.dataset_cpp
            )

            if keep_sample:
                # The forest has not changed since `forest_pred_current` was computed
                # (steps 2-4 only touch auxiliary data), so reuse it instead of re-predicting.
                self.forest_pred_train[:, sample_counter] = forest_pred_current
                if self.has_test:
                    self.forest_pred_test[:, sample_counter] = active_forest.predict(
                        forest_dataset_test
                    )
                gamma_current = forest_dataset_train.get_auxiliary_data_array(2)
                self.gamma_samples[:, sample_counter] = gamma_current
                latent_current = forest_dataset_train.get_auxiliary_data_array(0)
                self.latent_samples[:, sample_counter] = latent_current

        # Mark the model as sampled
        self.sampled = True

    def predict(
        self,
        X: Union[np.array, pd.DataFrame],
    ) -> np.array:
        """Return predictions from the cloglog forest.

        Parameters
        ----------
        X : np.array or pd.DataFrame
            Covariates to predict on. A 1-dimensional numpy array is treated as a single column.

        Returns
        -------
        np.array
            Cloglog forest predictions (lambda_x) for every retained forest sample, as returned
            by the underlying forest container's `Predict` method.

        Raises
        ------
        NotSampledError
            If `sample()` has not been run on this instance.
        ValueError
            If `X` is neither a pandas dataframe nor a numpy array, or cannot be used without a
            fitted covariate preprocessor.
        """
        if not self.sampled:
            msg = (
                "This CloglogOrdinalBARTModel instance is not fitted yet. Call 'sample' with "
                "appropriate arguments before using this model."
            )
            raise NotSampledError(msg)

        # Data checks
        if not isinstance(X, (pd.DataFrame, np.ndarray)):
            raise ValueError("X must be a pandas dataframe or numpy array")

        # Convert everything to standard shape (2-dimensional)
        if isinstance(X, np.ndarray) and X.ndim == 1:
            X = np.expand_dims(X, 1)

        # Covariate preprocessing
        if self._covariate_preprocessor._check_is_fitted():
            X_processed = self._covariate_preprocessor.transform(X)
        else:
            if not isinstance(X, np.ndarray):
                raise ValueError(
                    "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
                )
            warnings.warn(
                "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
                RuntimeWarning,
            )
            if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype(
                X.dtype, np.integer
            ):
                raise ValueError(
                    "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
                )
            X_processed = X

        # Dataset construction
        pred_dataset = Dataset()
        pred_dataset.add_covariates(X_processed)

        # Forest predictions
        forest_pred = self.forest_samples.forest_container_cpp.Predict(
            pred_dataset.dataset_cpp
        )

        return forest_pred

    # def to_json(self) -> str:
    #     """
    #     Converts a sampled BART model to JSON string representation (which can then be saved to a file or
    #     processed using the `json` library)

    #     Returns
    #     -------
    #     str
    #         JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests
    #     """
    #     if not self.is_sampled:
    #         msg = (
    #             "This BARTModel instance has not yet been sampled. "
    #             "Call 'fit' with appropriate arguments before using this model."
+ # ) + # raise NotSampledError(msg) + + # # Initialize JSONSerializer object + # bart_json = JSONSerializer() + + # # Add the forests + # if self.include_mean_forest: + # bart_json.add_forest(self.forest_container_mean) + # if self.include_variance_forest: + # bart_json.add_forest(self.forest_container_variance) + + # # Add the rfx + # if self.has_rfx: + # bart_json.add_random_effects(self.rfx_container) + + # # Add global parameters + # bart_json.add_scalar("outcome_scale", self.y_std) + # bart_json.add_scalar("outcome_mean", self.y_bar) + # bart_json.add_boolean("standardize", self.standardize) + # bart_json.add_scalar("sigma2_init", self.sigma2_init) + # bart_json.add_boolean("sample_sigma2_global", self.sample_sigma2_global) + # bart_json.add_boolean("sample_sigma2_leaf", self.sample_sigma2_leaf) + # bart_json.add_boolean("include_mean_forest", self.include_mean_forest) + # bart_json.add_boolean("include_variance_forest", self.include_variance_forest) + # bart_json.add_boolean("has_rfx", self.has_rfx) + # bart_json.add_integer("num_gfr", self.num_gfr) + # bart_json.add_integer("num_burnin", self.num_burnin) + # bart_json.add_integer("num_mcmc", self.num_mcmc) + # bart_json.add_integer("num_samples", self.num_samples) + # bart_json.add_integer("num_basis", self.num_basis) + # bart_json.add_boolean("requires_basis", self.has_basis) + # bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model) + + # # Add parameter samples + # if self.sample_sigma2_global: + # bart_json.add_numeric_vector( + # "sigma2_global_samples", self.global_var_samples, "parameters" + # ) + # if self.sample_sigma2_leaf: + # bart_json.add_numeric_vector( + # "sigma2_leaf_samples", self.leaf_scale_samples, "parameters" + # ) + + # # Add covariate preprocessor + # covariate_preprocessor_string = self._covariate_preprocessor.to_json() + # bart_json.add_string("covariate_preprocessor", covariate_preprocessor_string) + + # return bart_json.return_json_string() + + # def 
from_json(self, json_string: str) -> None: + # """ + # Converts a JSON string to an in-memory BART model. + + # Parameters + # ---------- + # json_string : str + # JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests + # """ + # # Parse string to a JSON object in C++ + # bart_json = JSONSerializer() + # bart_json.load_from_json_string(json_string) + + # # Unpack forests + # self.include_mean_forest = bart_json.get_boolean("include_mean_forest") + # self.include_variance_forest = bart_json.get_boolean("include_variance_forest") + # self.has_rfx = bart_json.get_boolean("has_rfx") + # if self.include_mean_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_mean = ForestContainer(0, 0, False, False) + # self.forest_container_mean.forest_container_cpp.LoadFromJson( + # bart_json.json_cpp, "forest_0" + # ) + # if self.include_variance_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # self.forest_container_variance.forest_container_cpp.LoadFromJson( + # bart_json.json_cpp, "forest_1" + # ) + # else: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # self.forest_container_variance.forest_container_cpp.LoadFromJson( + # bart_json.json_cpp, "forest_0" + # ) + + # # Unpack random effects + # if self.has_rfx: + # self.rfx_container = RandomEffectsContainer() + # self.rfx_container.load_from_json(bart_json, 0) + + # # Unpack global parameters + # self.y_std = bart_json.get_scalar("outcome_scale") + # self.y_bar = bart_json.get_scalar("outcome_mean") + # self.standardize = bart_json.get_boolean("standardize") + # self.sigma2_init = bart_json.get_scalar("sigma2_init") + # self.sample_sigma2_global = bart_json.get_boolean("sample_sigma2_global") + # self.sample_sigma2_leaf = 
bart_json.get_boolean("sample_sigma2_leaf") + # self.num_gfr = bart_json.get_integer("num_gfr") + # self.num_burnin = bart_json.get_integer("num_burnin") + # self.num_mcmc = bart_json.get_integer("num_mcmc") + # self.num_samples = bart_json.get_integer("num_samples") + # self.num_basis = bart_json.get_integer("num_basis") + # self.has_basis = bart_json.get_boolean("requires_basis") + # self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") + + # # Unpack parameter samples + # if self.sample_sigma2_global: + # self.global_var_samples = bart_json.get_numeric_vector( + # "sigma2_global_samples", "parameters" + # ) + # if self.sample_sigma2_leaf: + # self.leaf_scale_samples = bart_json.get_numeric_vector( + # "sigma2_leaf_samples", "parameters" + # ) + + # # Unpack covariate preprocessor + # covariate_preprocessor_string = bart_json.get_string("covariate_preprocessor") + # self._covariate_preprocessor = CovariatePreprocessor() + # self._covariate_preprocessor.from_json(covariate_preprocessor_string) + + # # Mark the deserialized model as "sampled" + # self.sampled = True + + # def from_json_string_list(self, json_string_list: list[str]) -> None: + # """ + # Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object + # which can be used for prediction, etc... 
+ + # Parameters + # ------- + # json_string_list : list of str + # List of JSON strings which can be parsed to objects of type `JSONSerializer` containing Json representation of a BART model + # """ + # # Convert strings to JSONSerializer + # json_object_list = [] + # for i in range(len(json_string_list)): + # json_string = json_string_list[i] + # json_object_list.append(JSONSerializer()) + # json_object_list[i].load_from_json_string(json_string) + + # # For scalar / preprocessing details which aren't sample-dependent, defer to the first json + # json_object_default = json_object_list[0] + + # # Unpack forests + # self.include_mean_forest = json_object_default.get_boolean( + # "include_mean_forest" + # ) + # self.include_variance_forest = json_object_default.get_boolean( + # "include_variance_forest" + # ) + # if self.include_mean_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_mean = ForestContainer(0, 0, False, False) + # for i in range(len(json_object_list)): + # if i == 0: + # self.forest_container_mean.forest_container_cpp.LoadFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + # else: + # self.forest_container_mean.forest_container_cpp.AppendFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + # if self.include_variance_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # for i in range(len(json_object_list)): + # if i == 0: + # self.forest_container_variance.forest_container_cpp.LoadFromJson( + # json_object_list[i].json_cpp, "forest_1" + # ) + # else: + # self.forest_container_variance.forest_container_cpp.AppendFromJson( + # json_object_list[i].json_cpp, "forest_1" + # ) + # else: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # for i in range(len(json_object_list)): + # if i == 0: + # 
self.forest_container_variance.forest_container_cpp.LoadFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + # else: + # self.forest_container_variance.forest_container_cpp.AppendFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + + # # Unpack random effects + # self.has_rfx = json_object_default.get_boolean("has_rfx") + # if self.has_rfx: + # self.rfx_container = RandomEffectsContainer() + # for i in range(len(json_object_list)): + # if i == 0: + # self.rfx_container.load_from_json(json_object_list[i], 0) + # else: + # self.rfx_container.append_from_json(json_object_list[i], 0) + + # # Unpack global parameters + # self.y_std = json_object_default.get_scalar("outcome_scale") + # self.y_bar = json_object_default.get_scalar("outcome_mean") + # self.standardize = json_object_default.get_boolean("standardize") + # self.sigma2_init = json_object_default.get_scalar("sigma2_init") + # self.sample_sigma2_global = json_object_default.get_boolean( + # "sample_sigma2_global" + # ) + # self.sample_sigma2_leaf = json_object_default.get_boolean("sample_sigma2_leaf") + # self.num_gfr = json_object_default.get_integer("num_gfr") + # self.num_burnin = json_object_default.get_integer("num_burnin") + # self.num_mcmc = json_object_default.get_integer("num_mcmc") + # self.num_basis = json_object_default.get_integer("num_basis") + # self.has_basis = json_object_default.get_boolean("requires_basis") + # self.probit_outcome_model = json_object_default.get_boolean( + # "probit_outcome_model" + # ) + + # # Unpack number of samples + # for i in range(len(json_object_list)): + # if i == 0: + # self.num_samples = json_object_list[i].get_integer("num_samples") + # else: + # self.num_samples += json_object_list[i].get_integer("num_samples") + + # # Unpack parameter samples + # if self.sample_sigma2_global: + # for i in range(len(json_object_list)): + # if i == 0: + # self.global_var_samples = json_object_list[i].get_numeric_vector( + # "sigma2_global_samples", "parameters" + 
# ) + # else: + # global_var_samples = json_object_list[i].get_numeric_vector( + # "sigma2_global_samples", "parameters" + # ) + # self.global_var_samples = np.concatenate( + # (self.global_var_samples, global_var_samples) + # ) + + # if self.sample_sigma2_leaf: + # for i in range(len(json_object_list)): + # if i == 0: + # self.leaf_scale_samples = json_object_list[i].get_numeric_vector( + # "sigma2_leaf_samples", "parameters" + # ) + # else: + # leaf_scale_samples = json_object_list[i].get_numeric_vector( + # "sigma2_leaf_samples", "parameters" + # ) + # self.leaf_scale_samples = np.concatenate( + # (self.leaf_scale_samples, leaf_scale_samples) + # ) + + # # Unpack covariate preprocessor + # covariate_preprocessor_string = json_object_default.get_string( + # "covariate_preprocessor" + # ) + # self._covariate_preprocessor = CovariatePreprocessor() + # self._covariate_preprocessor.from_json(covariate_preprocessor_string) + + # # Mark the deserialized model as "sampled" + # self.sampled = True + + def is_sampled(self) -> bool: + """Whether or not a BART model has been sampled. + + Returns + ------- + bool + `True` if a BART model has been sampled, `False` otherwise + """ + return self.sampled diff --git a/stochtree/config.py b/stochtree/config.py index 72cae512..3c9cd36b 100644 --- a/stochtree/config.py +++ b/stochtree/config.py @@ -44,7 +44,7 @@ class ForestModelConfig: max_depth : int, optional Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. leaf_model_type : int, optional - Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `0`. + Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression, 3 = log linear variance, 4 = cloglog ordinal regression). Default: `0`. 
leaf_model_scale : float or np.ndarray, optional Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. variance_forest_shape : int, optional @@ -110,9 +110,9 @@ def __init__( if leaf_model_type is None: leaf_model_type = 0 if not _check_is_int(leaf_model_type): - raise ValueError("`leaf_model_type` must be an integer between 0 and 3") - elif leaf_model_type < 0 or leaf_model_type > 3: - raise ValueError("`leaf_model_type` must be an integer between 0 and 3") + raise ValueError("`leaf_model_type` must be an integer between 0 and 4") + elif leaf_model_type < 0 or leaf_model_type > 4: + raise ValueError("`leaf_model_type` must be an integer between 0 and 4") if not _check_is_int(leaf_dimension): raise ValueError("`leaf_dimension` must be an integer greater than 0") elif leaf_dimension <= 0: diff --git a/stochtree/data.py b/stochtree/data.py index 4e40a282..7668ac73 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -205,6 +205,97 @@ def has_variance_weights(self) -> bool: `True` if the dataset has variance weights, `False` otherwise """ return self.dataset_cpp.HasVarianceWeights() + + def has_auxiliary_dimension(self, dim_idx: int) -> bool: + """ + Whether or not a dataset has an auxiliary dimension + + Parameters + ---------- + dim_idx : int + Index of the auxiliary dimension to check + + Returns + ------- + bool + `True` if the dataset has the specified auxiliary dimension, `False` otherwise + """ + return self.dataset_cpp.HasAuxiliaryDimension(dim_idx) + + def add_auxiliary_dimension(self, dim_size: int) -> None: + """ + Add an auxiliary dimension to a dataset + + Parameters + ---------- + dim_size : int + Size of the auxiliary dimension to add + """ + self.dataset_cpp.AddAuxiliaryDimension(dim_size) + + def 
get_auxiliary_data_value(self, dim_idx: int, element_idx: int) -> float: + """ + Get a value from an auxiliary dimension + + Parameters + ---------- + dim_idx : int + Index of the auxiliary dimension + element_idx : int + Index of the element within the auxiliary dimension + + Returns + ------- + float + Value at the specified index in the auxiliary dimension + """ + return self.dataset_cpp.GetAuxiliaryDataValue(dim_idx, element_idx) + + def set_auxiliary_data_value(self, dim_idx: int, element_idx: int, value: float) -> None: + """ + Set a value in an auxiliary dimension + + Parameters + ---------- + dim_idx : int + Index of the auxiliary dimension + element_idx : int + Index of the element within the auxiliary dimension + value : float + Value to set at the specified index in the auxiliary dimension + """ + self.dataset_cpp.SetAuxiliaryDataValue(dim_idx, element_idx, value) + + def get_auxiliary_data_array(self, dim_idx: int) -> np.array: + """ + Get an auxiliary dimension as a numpy array + + Parameters + ---------- + dim_idx : int + Index of the auxiliary dimension + + Returns + ------- + np.array + Numpy array of the specified auxiliary dimension + """ + return self.dataset_cpp.GetAuxiliaryDataArray(dim_idx) + + def store_auxiliary_data_array_matrix(self, output_matrix: np.array, dim_idx: int, matrix_col_idx: int) -> None: + """ + Store an auxiliary dimension into a specified column of a numpy matrix + + Parameters + ---------- + output_matrix : np.array + Numpy array to store the auxiliary dimension into + dim_idx : int + Index of the auxiliary dimension + matrix_col_idx : int + Column index in the output matrix to store the auxiliary dimension + """ + self.dataset_cpp.StoreAuxiliaryDataArrayMatrix(output_matrix, dim_idx, matrix_col_idx) class Residual: diff --git a/tools/debug/cloglog_ordinal_bart_three_category.R b/tools/debug/cloglog_ordinal_bart_three_category.R index a7ba416d..96eba51a 100644 --- a/tools/debug/cloglog_ordinal_bart_three_category.R +++ 
b/tools/debug/cloglog_ordinal_bart_three_category.R @@ -31,6 +31,8 @@ for (j in 1:n_categories) { (1 - exp(-exp(gamma_true[j] + true_lambda_function))) } } +apply(true_probs, 2, mean) +summary(true_lambda_function) # Generate ordinal outcomes y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ]))