diff --git a/Project.toml b/Project.toml index 23a8af183..3727d7191 100644 --- a/Project.toml +++ b/Project.toml @@ -64,7 +64,7 @@ Distributions = "0.25.77" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.38" +DynamicPPL = "0.39" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" Libtask = "0.9.3" @@ -90,3 +90,6 @@ julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" Optim = "429524aa-4258-5aef-a3af-852621145aeb" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/not-experimental"} diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 9e4c8b6ef..1f565559c 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -37,7 +37,6 @@ $(TYPEDFIELDS) """ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S} logdensity::L - vi::V "Cache of sample, log density, and gradient of log density evaluation." cache::C metric::M @@ -70,9 +69,8 @@ function Turing.Inference.initialstep( Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q) # Create first sample and state. - vi = DynamicPPL.unflatten(vi, Q.q) - sample = Turing.Inference.Transition(model, vi, nothing) - state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ) + sample = DynamicPPL.ParamsWithStats(Q.q, ℓ) + state = DynamicNUTSState(ℓ, Q, steps.H.κ, steps.ϵ) return sample, state end @@ -85,15 +83,13 @@ function AbstractMCMC.step( kwargs..., ) # Compute next sample. - vi = state.vi ℓ = state.logdensity steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize) Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache) # Create next sample and state. - vi = DynamicPPL.unflatten(vi, Q.q) - sample = Turing.Inference.Transition(model, vi, nothing) - newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize) + sample = DynamicPPL.ParamsWithStats(Q.q, ℓ) + newstate = DynamicNUTSState(ℓ, Q, state.metric, state.stepsize) return sample, newstate end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 695f9c3aa..29b490d57 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -6,7 +6,6 @@ using DynamicPPL: Metadata, VarInfo, LogDensityFunction, - SimpleVarInfo, AbstractVarInfo, # TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only # implemented for NTVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it @@ -92,9 +91,6 @@ function DynamicPPL.unflatten(vi::DynamicPPL.NTVarInfo, θ::NamedTuple) set_namedtuple!(deepcopy(vi), θ) return vi end -function DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) - return SimpleVarInfo(θ, vi.logp, vi.transformation) -end """ mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real) @@ -114,312 +110,6 @@ function mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio:: return log(rand()) + logp_current ≤ logp_proposal + log_proposal_ratio end -###################### -# Default Transition # -###################### -getstats(::Any) = NamedTuple() -getstats(nt::NamedTuple) = nt - -struct Transition{T,F<:AbstractFloat,N<:NamedTuple} - θ::T - logprior::F - loglikelihood::F - stat::N - - """ - Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true) - - Construct a new `Turing.Inference.Transition` object using the outputs of a - sampler step. - - Here, `vi` represents a VarInfo _for which the appropriate parameters have - already been set_. However, the accumulators (e.g. 
logp) may in general - have junk contents. The role of this method is to re-evaluate `model` and - thus set the accumulators to the correct values. - - `stats` is any object on which `Turing.Inference.getstats` can be called to - return a NamedTuple of statistics. This could be, for example, the transition - returned by an (unwrapped) external sampler. Or alternatively, it could - simply be a NamedTuple itself (for which `getstats` acts as the identity). - - By default, the model is re-evaluated in order to obtain values of: - - the values of the parameters as per user parameterisation (`vals_as_in_model`) - - the various components of the log joint probability (`logprior`, `loglikelihood`) - that are guaranteed to be correct. - - If you **know** for a fact that the VarInfo `vi` already contains this information, - then you can set `reevaluate=false` to skip the re-evaluation step. - - !!! warning - Note that in general this is unsafe and may lead to wrong results. - - If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that - the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`, - and `LogLikelihoodAccumulator` set up with the correct values. Note that the - `ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it - must be set up to track `x := y` statements. - """ - function Transition( - model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true - ) - # Avoid mutating vi as it may be used later e.g. when constructing - # sampler states. - vi = deepcopy(vi) - if reevaluate - vi = DynamicPPL.setaccs!!( - vi, - ( - DynamicPPL.ValuesAsInModelAccumulator(true), - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - ), - ) - _, vi = DynamicPPL.evaluate!!(model, vi) - end - - # Extract all the information we need - vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values - logprior = DynamicPPL.getlogprior(vi) - loglikelihood = DynamicPPL.getloglikelihood(vi) - - # Get additional statistics - stats = getstats(stats) - return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}( - vals_as_in_model, logprior, loglikelihood, stats - ) - end - - function Transition( - model::DynamicPPL.Model, - untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}, - stats; - reevaluate=true, - ) - # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's - # much faster to convert it to a typed varinfo first, hence this method. - # https://github.com/TuringLang/Turing.jl/issues/2604 - return Transition( - model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate - ) - end -end - -function getstats_with_lp(t::Transition) - return merge( - t.stat, - ( - lp=t.logprior + t.loglikelihood, - logprior=t.logprior, - loglikelihood=t.loglikelihood, - ), - ) -end -function getstats_with_lp(vi::AbstractVarInfo) - return ( - lp=DynamicPPL.getlogjoint(vi), - logprior=DynamicPPL.getlogprior(vi), - loglikelihood=DynamicPPL.getloglikelihood(vi), - ) -end - -########################## -# Chain making utilities # -########################## - -# TODO(penelopeysm): Separate Turing.Inference.getparams (should only be -# defined for AbstractVarInfo and Turing.Inference.Transition; returns varname -# => value maps) from AbstractMCMC.getparams (defined for any sampler transition, -# returns vector). 
-""" - Turing.Inference.getparams(model::DynamicPPL.Model, t::Any) - -Return a vector of parameter values from the given sampler transition `t` (i.e., -the first return value of AbstractMCMC.step). By default, returns the `t.θ` field. - -!!! note - This method only needs to be implemented for external samplers. It will be -removed in future releases and replaced with `AbstractMCMC.getparams`. -""" -getparams(::DynamicPPL.Model, t) = t.θ -""" - Turing.Inference.getparams(model::DynamicPPL.Model, t::AbstractVarInfo) - -Return a key-value map of parameters from the varinfo. -""" -function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo) - t = Transition(model, vi, nothing) - return getparams(model, t) -end -function _params_to_array(model::DynamicPPL.Model, ts::Vector) - names_set = OrderedSet{VarName}() - # Extract the parameter names and values from each transition. - dicts = map(ts) do t - # In general getparams returns a dict of VarName => values. We need to also - # split it up into constituent elements using - # `AbstractPPL.varname_and_value_leaves` because otherwise MCMCChains.jl - # won't understand it. - vals = getparams(model, t) - nms_and_vs = if isempty(vals) - Tuple{VarName,Any}[] - else - iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) - mapreduce(collect, vcat, iters) - end - nms = map(first, nms_and_vs) - vs = map(last, nms_and_vs) - for nm in nms - push!(names_set, nm) - end - # Convert the names and values to a single dictionary. - return OrderedDict(zip(nms, vs)) - end - names = collect(names_set) - vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names] - - return names, vals -end - -function get_transition_extras(ts::AbstractVector) - # Extract stats + log probabilities from each transition or VarInfo - extra_data = map(getstats_with_lp, ts) - return names_values(extra_data) -end - -function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names} - values = [getfield(data, name) for data in extra_data, name in names] - return collect(names), values -end - -function names_values(xs::AbstractVector{<:NamedTuple}) - # Obtain all parameter names. - names_set = Set{Symbol}() - for x in xs - for k in keys(x) - push!(names_set, k) - end - end - names_unique = collect(names_set) - - # Extract all values as matrix. - values = [haskey(x, name) ? x[name] : missing for x in xs, name in names_unique] - - return names_unique, values -end - -getlogevidence(transitions, sampler, state) = missing - -# Default MCMCChains.Chains constructor. -function AbstractMCMC.bundle_samples( - ts::Vector{<:Transition}, - model::DynamicPPL.Model, - spl::AbstractSampler, - state, - chain_type::Type{MCMCChains.Chains}; - save_state=false, - stats=missing, - sort_chain=false, - include_varname_to_symbol=true, - discard_initial=0, - thinning=1, - kwargs..., -) - # Convert transitions to array format. - # Also retrieve the variable names. - varnames, vals = _params_to_array(model, ts) - varnames_symbol = map(Symbol, varnames) - - # Get the values of the extra parameters in each transition. - extra_params, extra_values = get_transition_extras(ts) - - # Extract names & construct param array. - nms = [varnames_symbol; extra_params] - parray = hcat(vals, extra_values) - - # Get the average or final log evidence, if it exists. - le = getlogevidence(ts, spl, state) - - # Set up the info tuple. 
- info = NamedTuple() - - if include_varname_to_symbol - info = merge(info, (varname_to_symbol=OrderedDict(zip(varnames, varnames_symbol)),)) - end - - if save_state - info = merge(info, (model=model, sampler=spl, samplerstate=state)) - end - - # Merge in the timing info, if available - if !ismissing(stats) - info = merge(info, (start_time=stats.start, stop_time=stats.stop)) - end - - # Conretize the array before giving it to MCMCChains. - parray = MCMCChains.concretize(parray) - - # Chain construction. - chain = MCMCChains.Chains( - parray, - nms, - (internals=extra_params,); - evidence=le, - info=info, - start=discard_initial + 1, - thin=thinning, - ) - - return sort_chain ? sort(chain) : chain -end - -function AbstractMCMC.bundle_samples( - ts::Vector{<:Transition}, - model::DynamicPPL.Model, - spl::AbstractSampler, - state, - chain_type::Type{Vector{NamedTuple}}; - kwargs..., -) - return map(ts) do t - # Construct a dictionary of pairs `vn => value`. - params = OrderedDict(getparams(model, t)) - # Group the variable names by their symbol. - sym_to_vns = group_varnames_by_symbol(keys(params)) - # Convert the values to a vector. - vals = map(values(sym_to_vns)) do vns - map(Base.Fix1(getindex, params), vns) - end - return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t)) - end -end - -""" - group_varnames_by_symbol(vns) - -Group the varnames by their symbol. - -# Arguments -- `vns`: Iterable of `VarName`. - -# Returns -- `OrderedDict{Symbol, Vector{VarName}}`: A dictionary mapping symbol to a vector of varnames. -""" -function group_varnames_by_symbol(vns) - d = OrderedDict{Symbol,Vector{VarName}}() - for vn in vns - sym = DynamicPPL.getsym(vn) - if !haskey(d, sym) - d[sym] = VarName[] - end - push!(d[sym], vn) - end - return d -end - -function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples) - nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples)) - return setinfo(c, merge(nt, c.info)) -end - ####################################### # Concrete algorithm implementations. # ####################################### diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 226536aca..899749e28 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -26,8 +26,8 @@ function Emcee(n_walkers::Int, stretch_length=2.0) return Emcee{typeof(ensemble)}(ensemble) end -struct EmceeState{V<:AbstractVarInfo,S} - vi::V +struct EmceeState{L<:LogDensityFunction,S} + ldf::L states::S end @@ -65,11 +65,11 @@ function AbstractMCMC.step( end # Compute initial transition and states. - transition = [Transition(model, vi, nothing) for vi in vis] + transition = [DynamicPPL.ParamsWithStats(vi, model) for vi in vis] - # TODO: Make compatible with immutable `AbstractVarInfo`. + linked_vi = DynamicPPL.link!!(vis[1], model) state = EmceeState( - vis[1], + DynamicPPL.LogDensityFunction(model, getlogjoint_internal, linked_vi), map(vis) do vi vi = DynamicPPL.link!!(vi, model) AMH.Transition(vi[:], DynamicPPL.getlogjoint_internal(vi), false) @@ -83,23 +83,18 @@ function AbstractMCMC.step( rng::AbstractRNG, model::Model, spl::Emcee, state::EmceeState; kwargs... ) # Generate a log joint function. - vi = state.vi - densitymodel = AMH.DensityModel( - Base.Fix1( - LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), - ), - ) + densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, state.ldf)) # Compute the next states. 
- t, states = AbstractMCMC.step(rng, densitymodel, spl.ensemble, state.states) + _, states = AbstractMCMC.step(rng, densitymodel, spl.ensemble, state.states) # Compute the next transition and state. transition = map(states) do _state - vi = DynamicPPL.unflatten(vi, _state.params) - return Transition(model, vi, t) + return DynamicPPL.ParamsWithStats( + _state.params, state.ldf, AbstractMCMC.getstats(_state) + ) end - newstate = EmceeState(vi, states) + newstate = EmceeState(state.ldf, states) return transition, newstate end @@ -110,60 +105,14 @@ function AbstractMCMC.bundle_samples( spl::Emcee, state::EmceeState, chain_type::Type{MCMCChains.Chains}; - save_state=false, - sort_chain=false, - discard_initial=0, - thinning=1, kwargs..., ) - # Convert transitions to array format. - # Also retrieve the variable names. - params_vec = map(Base.Fix1(_params_to_array, model), samples) - - # Extract names and values separately. - varnames = params_vec[1][1] - varnames_symbol = map(Symbol, varnames) - vals_vec = [p[2] for p in params_vec] - - # Get the values of the extra parameters in each transition. - extra_vec = map(get_transition_extras, samples) - - # Get the extra parameter names & values. - extra_params = extra_vec[1][1] - extra_values_vec = [e[2] for e in extra_vec] - - # Extract names & construct param array. - nms = [varnames_symbol; extra_params] - # `hcat` first to ensure we get the right `eltype`. - x = hcat(first(vals_vec), first(extra_values_vec)) - # Pre-allocate to minimize memory usage. - parray = Array{eltype(x),3}(undef, length(vals_vec), size(x, 2), size(x, 1)) - for (i, (vals, extras)) in enumerate(zip(vals_vec, extra_values_vec)) - parray[i, :, :] = transpose(hcat(vals, extras)) - end - - # Get the average or final log evidence, if it exists. - le = getlogevidence(samples, state, spl) - - # Set up the info tuple. - info = (varname_to_symbol=OrderedDict(zip(varnames, varnames_symbol)),) - if save_state - info = merge(info, (model=model, sampler=spl, samplerstate=state)) + n_walkers = _get_n_walkers(spl) + chains = map(1:n_walkers) do i + this_walker_samples = [s[i] for s in samples] + AbstractMCMC.bundle_samples( + this_walker_samples, model, spl, state, chain_type; kwargs... + ) end - - # Concretize the array before giving it to MCMCChains. - parray = MCMCChains.concretize(parray) - - # Chain construction. - chain = MCMCChains.Chains( - parray, - nms, - (internals=extra_params,); - evidence=le, - info=info, - start=discard_initial + 1, - thin=thinning, - ) - - return sort_chain ? sort(chain) : chain + return AbstractMCMC.chainscat(chains...) 
end diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl index 18dbfa417..fa02b6222 100644 --- a/src/mcmc/ess.jl +++ b/src/mcmc/ess.jl @@ -31,7 +31,7 @@ function Turing.Inference.initialstep( EllipticalSliceSampling.isgaussian(typeof(dist)) || error("ESS only supports Gaussian prior distributions") end - return Transition(model, vi, nothing), vi + return DynamicPPL.ParamsWithStats(vi, model), vi end function AbstractMCMC.step( @@ -56,7 +56,7 @@ function AbstractMCMC.step( vi = DynamicPPL.unflatten(vi, sample) vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood) - return Transition(model, vi, nothing), vi + return DynamicPPL.ParamsWithStats(vi, model), vi end # Prior distribution of considered random variable diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl index 94e9e1706..85537d6ca 100644 --- a/src/mcmc/external_sampler.jl +++ b/src/mcmc/external_sampler.jl @@ -122,14 +122,12 @@ function externalsampler( end # TODO(penelopeysm): Can't we clean this up somehow? -struct TuringState{S,V1,M,V} +struct TuringState{S,V,L<:DynamicPPL.LogDensityFunction} state::S - # Note that this varinfo must have the correct parameters set; but logp - # does not matter as it will be re-evaluated - varinfo::V1 - # Note that in general the VarInfo inside this LogDensityFunction will have - # junk parameters and logp. It only exists to provide structure - ldf::DynamicPPL.LogDensityFunction{M,V} + # Note that this varinfo is used only for structure. Its parameters and other info do + # not need to be accurate + varinfo::V + ldf::L end # get_varinfo should return something from which the correct parameters can be @@ -187,11 +185,10 @@ function AbstractMCMC.step( end new_parameters = AbstractMCMC.getparams(f.model, state_inner) - new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters) new_stats = AbstractMCMC.getstats(state_inner) return ( - Turing.Inference.Transition(f.model, new_vi, new_stats), - TuringState(state_inner, new_vi, f), + DynamicPPL.ParamsWithStats(new_parameters, f, new_stats), + TuringState(state_inner, varinfo, f), ) end @@ -211,10 +208,9 @@ function AbstractMCMC.step( ) new_parameters = AbstractMCMC.getparams(f.model, state_inner) - new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters) new_stats = AbstractMCMC.getstats(state_inner) return ( - Turing.Inference.Transition(f.model, new_vi, new_stats), - TuringState(state_inner, new_vi, f), + DynamicPPL.ParamsWithStats(new_parameters, f, new_stats), + TuringState(state_inner, state.varinfo, f), ) end diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index 1ff50a646..2d86c5afc 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -17,10 +17,9 @@ isgibbscomponent(::SGLD) = false isgibbscomponent(::SGHMC) = false isgibbscomponent(::SMC) = false -function can_be_wrapped(ctx::DynamicPPL.AbstractContext) - return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf -end -can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context) +can_be_wrapped(::DynamicPPL.AbstractContext) = true +can_be_wrapped(::DynamicPPL.AbstractParentContext) = false +can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(DynamicPPL.childcontext(ctx)) # Basically like a `DynamicPPL.FixedContext` but # 1. Hijacks the tilde pipeline to fix variables. 
@@ -51,7 +50,7 @@ $(FIELDS) """ struct GibbsContext{ VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext -} <: DynamicPPL.AbstractContext +} <: DynamicPPL.AbstractParentContext """ the VarNames being sampled """ @@ -82,7 +81,6 @@ function GibbsContext(target_varnames, global_varinfo) return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext()) end -DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(context::GibbsContext) = context.context function DynamicPPL.setchildcontext(context::GibbsContext, childcontext) return GibbsContext( @@ -331,7 +329,7 @@ function AbstractMCMC.step( initial_params=initial_params, kwargs..., ) - return Transition(model, vi, nothing), GibbsState(vi, states) + return DynamicPPL.ParamsWithStats(vi, model), GibbsState(vi, states) end function AbstractMCMC.step_warmup( @@ -355,7 +353,7 @@ function AbstractMCMC.step_warmup( initial_params=initial_params, kwargs..., ) - return Transition(model, vi, nothing), GibbsState(vi, states) + return DynamicPPL.ParamsWithStats(vi, model), GibbsState(vi, states) end """ @@ -435,7 +433,7 @@ function AbstractMCMC.step( vi, states = gibbs_step_recursive( rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs... ) - return Transition(model, vi, nothing), GibbsState(vi, states) + return DynamicPPL.ParamsWithStats(vi, model), GibbsState(vi, states) end function AbstractMCMC.step_warmup( @@ -454,7 +452,7 @@ function AbstractMCMC.step_warmup( vi, states = gibbs_step_recursive( rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs... ) - return Transition(model, vi, nothing), GibbsState(vi, states) + return DynamicPPL.ParamsWithStats(vi, model), GibbsState(vi, states) end """ @@ -490,18 +488,12 @@ function setparams_varinfo!!( end function setparams_varinfo!!( - model::DynamicPPL.Model, - sampler::ExternalSampler, - state::TuringState, - params::AbstractVarInfo, + ::DynamicPPL.Model, ::ExternalSampler, state::TuringState, params::AbstractVarInfo ) - logdensity = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.adtype - ) new_inner_state = AbstractMCMC.setparams!!( - AbstractMCMC.LogDensityModel(logdensity), state.state, params[:] + AbstractMCMC.LogDensityModel(state.ldf), state.state, params[:] ) - return TuringState(new_inner_state, params, logdensity) + return TuringState(new_inner_state, params, state.ldf) end function setparams_varinfo!!( @@ -515,11 +507,11 @@ function setparams_varinfo!!( z = state.z resize!(z.θ, length(θ_new)) z.θ .= θ_new - return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor) + return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.ldf) end function setparams_varinfo!!( - model::DynamicPPL.Model, sampler::PG, state::PGState, params::AbstractVarInfo + ::DynamicPPL.Model, ::PG, state::PGState, params::AbstractVarInfo ) return PGState(params, state.rng) end diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 101847b75..6eca47c87 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -12,6 +12,7 @@ struct HMCState{ THam<:AHMC.Hamiltonian, PhType<:AHMC.PhasePoint, TAdapt<:AHMC.Adaptation.AbstractAdaptor, + L<:DynamicPPL.LogDensityFunction, } vi::TV i::Int @@ -19,6 +20,7 @@ struct HMCState{ hamiltonian::THam z::PhType adaptor::TAdapt + ldf::L end ### @@ -225,8 +227,8 @@ function Turing.Inference.initialstep( kernel = make_ahmc_kernel(spl, ϵ) adaptor = AHMCAdaptor(spl, 
hamiltonian.metric; ϵ=ϵ) - transition = Transition(model, vi, NamedTuple()) - state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor) + transition = DynamicPPL.ParamsWithStats(theta, ldf, NamedTuple()) + state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor, ldf) return transition, state end @@ -270,8 +272,8 @@ function AbstractMCMC.step( end # Compute next transition and state. - transition = Transition(model, vi, t) - newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor) + transition = DynamicPPL.ParamsWithStats(t.z.θ, state.ldf, t.stat) + newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, state.ldf) return transition, newstate end diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl index 88f915d1f..541d1eba5 100644 --- a/src/mcmc/is.jl +++ b/src/mcmc/is.jl @@ -29,7 +29,7 @@ struct IS <: AbstractSampler end function Turing.Inference.initialstep( rng::AbstractRNG, model::Model, spl::IS, vi::AbstractVarInfo; kwargs... ) - return Transition(model, vi, nothing), nothing + return DynamicPPL.ParamsWithStats(vi, model), nothing end function AbstractMCMC.step( @@ -38,18 +38,12 @@ function AbstractMCMC.step( model = DynamicPPL.setleafcontext(model, ISContext(rng)) _, vi = DynamicPPL.evaluate!!(model, DynamicPPL.VarInfo()) vi = DynamicPPL.typed_varinfo(vi) - return Transition(model, vi, nothing), nothing -end - -# Calculate evidence. -function getlogevidence(samples::Vector{<:Transition}, ::IS, state) - return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples)) + return DynamicPPL.ParamsWithStats(vi, model), nothing end struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext rng::R end -DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf() function DynamicPPL.tilde_assume!!( ctx::ISContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 833303b86..270b6327d 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -179,33 +179,33 @@ get_varinfo(s::MHState) = s.varinfo ##################### """ - set_namedtuple!(vi::VarInfo, nt::NamedTuple) + OldLogDensityFunction -Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. +This is a clone of pre-0.39 DynamicPPL.LogDensityFunction. It is needed for MH because MH +doesn't actually obey the LogDensityProblems.jl interface: it evaluates +'LogDensityFunctions' with a NamedTuple(!!) + +This means that we can't _really_ use DynamicPPL's LogDensityFunction, since that only +promises to obey the interface of being called with a vector. + +In particular, because `set_namedtuple!` acts on a VarInfo, we need to store the VarInfo +inside this struct (which DynamicPPL's LogDensityFunction no longer does). + +This SHOULD really be refactored to remove this requirement. 
""" -function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple) - for (n, vals) in pairs(nt) - vns = vi.metadata[n].vns - if vals isa AbstractVector - vals = unvectorize(vals) - end - if length(vns) == 1 - # Only one variable, assign the values to it - DynamicPPL.setindex!(vi, vals, vns[1]) - else - # Spread the values across the variables - length(vns) == length(vals) || error("Unequal number of variables and values") - for (vn, val) in zip(vns, vals) - DynamicPPL.setindex!(vi, val, vn) - end - end - end +struct OldLogDensityFunction{M<:DynamicPPL.Model,V<:DynamicPPL.AbstractVarInfo} + model::M + varinfo::V +end +function (f::OldLogDensityFunction)(x::AbstractVector) + vi = DynamicPPL.unflatten(f.varinfo, x) + _, vi = DynamicPPL.evaluate!!(f.model, vi) + return DynamicPPL.getlogjoint_internal(vi) end - # NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems # interface in that it gets evaluated with a NamedTuple. Hence we need this # method just to deal with MH. -function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) +function (f::OldLogDensityFunction)(x::NamedTuple) vi = deepcopy(f.varinfo) # Note that the NamedTuple `x` does NOT conform to the structure required for # `InitFromParams`. In particular, for models that look like this: @@ -223,8 +223,31 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple) set_namedtuple!(vi, x) # Update log probability. _, vi_new = DynamicPPL.evaluate!!(f.model, vi) - lj = f.getlogdensity(vi_new) - return lj + return DynamicPPL.getlogjoint_internal(vi_new) +end + +""" + set_namedtuple!(vi::VarInfo, nt::NamedTuple) + +Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. +""" +function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple) + for (n, vals) in pairs(nt) + vns = vi.metadata[n].vns + if vals isa AbstractVector + vals = unvectorize(vals) + end + if length(vns) == 1 + # Only one variable, assign the values to it + DynamicPPL.setindex!(vi, vals, vns[1]) + else + # Spread the values across the variables + length(vns) == length(vals) || error("Unequal number of variables and values") + for (vn, val) in zip(vns, vals) + DynamicPPL.setindex!(vi, val, vn) + end + end + end end # unpack a vector if possible @@ -335,12 +358,7 @@ function propose!!(rng::AbstractRNG, prev_state::MHState, model::Model, spl::MH, # Make a new transition. model = DynamicPPL.setleafcontext(model, MHContext(rng)) - densitymodel = AMH.DensityModel( - Base.Fix1( - LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), - ), - ) + densitymodel = AMH.DensityModel(OldLogDensityFunction(model, vi)) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) # trans.params isa NamedTuple set_namedtuple!(vi, trans.params) @@ -370,12 +388,7 @@ function propose!!( # Make a new transition. model = DynamicPPL.setleafcontext(model, MHContext(rng)) - densitymodel = AMH.DensityModel( - Base.Fix1( - LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), - ), - ) + densitymodel = AMH.DensityModel(OldLogDensityFunction(model, vi)) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) # trans.params isa AbstractVector vi = DynamicPPL.unflatten(vi, trans.params) @@ -393,7 +406,8 @@ function Turing.Inference.initialstep( # just link everything before sampling. 
vi = maybe_link!!(vi, spl, spl.proposals, model) - return Transition(model, vi, nothing), MHState(vi, DynamicPPL.getlogjoint_internal(vi)) + return DynamicPPL.ParamsWithStats(vi, model), + MHState(vi, DynamicPPL.getlogjoint_internal(vi)) end function AbstractMCMC.step( @@ -404,13 +418,12 @@ function AbstractMCMC.step( # 2. A bunch of NamedTuples that specify the proposal space new_state = propose!!(rng, state, model, spl, spl.proposals) - return Transition(model, new_state.varinfo, nothing), new_state + return DynamicPPL.ParamsWithStats(new_state.varinfo, model), new_state end struct MHContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext rng::R end -DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf() function DynamicPPL.tilde_assume!!( context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 7aadef09e..67ea770b8 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -7,7 +7,6 @@ struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext rng::R end -DynamicPPL.NodeTrait(::ParticleMCMCContext) = DynamicPPL.IsLeaf() struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel model::M @@ -101,10 +100,6 @@ struct SMCState{P,F<:AbstractFloat} average_logevidence::F end -function getlogevidence(samples, ::SMC, state::SMCState) - return state.average_logevidence -end - function AbstractMCMC.sample( rng::AbstractRNG, model::DynamicPPL.Model, @@ -159,7 +154,9 @@ function Turing.Inference.initialstep( # Compute the first transition and the first state. stats = (; weight=weight, logevidence=logevidence) - transition = Transition(model, particle.model.f.varinfo, stats) + transition = DynamicPPL.ParamsWithStats( + deepcopy(particle.model.f.varinfo), model, stats + ) state = SMCState(particles, 2, logevidence) return transition, state @@ -178,7 +175,9 @@ function AbstractMCMC.step( # Compute the transition and the next state. stats = (; weight=weight, logevidence=state.average_logevidence) - transition = Transition(model, particle.model.f.varinfo, stats) + transition = DynamicPPL.ParamsWithStats( + deepcopy(particle.model.f.varinfo), model, stats + ) nextstate = SMCState(state.particles, index + 1, state.average_logevidence) return transition, nextstate @@ -239,21 +238,6 @@ end get_varinfo(state::PGState) = state.vi -function getlogevidence( - transitions::AbstractVector{<:Turing.Inference.Transition}, ::PG, ::PGState -) - logevidences = map(transitions) do t - if haskey(t.stat, :logevidence) - return t.stat.logevidence - else - # This should not really happen, but if it does we can handle it - # gracefully - return missing - end - end - return mean(logevidences) -end - function Turing.Inference.initialstep( rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, vi::AbstractVarInfo; kwargs... ) @@ -280,7 +264,9 @@ function Turing.Inference.initialstep( # Compute the first transition. _vi = reference.model.f.varinfo - transition = Transition(model, _vi, (; logevidence=logevidence)) + transition = DynamicPPL.ParamsWithStats( + deepcopy(_vi), model, (; logevidence=logevidence) + ) return transition, PGState(_vi, reference.rng) end @@ -316,7 +302,9 @@ function AbstractMCMC.step( # Compute the transition. 
_vi = newreference.model.f.varinfo - transition = Transition(model, _vi, (; logevidence=logevidence)) + transition = DynamicPPL.ParamsWithStats( + deepcopy(_vi), model, (; logevidence=logevidence) + ) return transition, PGState(_vi, newreference.rng) end diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index c4ec6c6f3..fe05d9096 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -12,14 +12,11 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) - vi = DynamicPPL.setaccs!!( - DynamicPPL.VarInfo(), - ( - DynamicPPL.ValuesAsInModelAccumulator(true), - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - ), - ) - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior()) - return Transition(model, vi, nothing; reevaluate=false), nothing + accs = DynamicPPL.AccumulatorTuple(( + DynamicPPL.ValuesAsInModelAccumulator(true), + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + )) + _, vi = DynamicPPL.fast_evaluate!!(rng, model, DynamicPPL.InitFromPrior(), accs) + return DynamicPPL.ParamsWithStats(vi), nothing end diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 267a21620..f9d5d4ade 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -45,9 +45,9 @@ function SGHMC(; return SGHMC(_learning_rate, _momentum_decay, adtype) end -struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} +struct SGHMCState{L,V<:AbstractVector{<:Real},T<:AbstractVector{<:Real}} logdensity::L - vi::V + params::V velocity::T end @@ -60,11 +60,12 @@ function Turing.Inference.initialstep( end # Compute initial sample and state. - sample = Transition(model, vi, nothing) ℓ = DynamicPPL.LogDensityFunction( model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) - state = SGHMCState(ℓ, vi, zero(vi[:])) + initial_params = vi[:] + sample = DynamicPPL.ParamsWithStats(initial_params, ℓ) + state = SGHMCState(ℓ, initial_params, zero(vi[:])) return sample, state end @@ -74,8 +75,7 @@ function AbstractMCMC.step( ) # Compute gradient of log density. ℓ = state.logdensity - vi = state.vi - θ = vi[:] + θ = state.params grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) # Update latent variables and velocity according to @@ -86,12 +86,9 @@ function AbstractMCMC.step( α = spl.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) - # Save new variables. - vi = DynamicPPL.unflatten(vi, θ) - # Compute next sample and state. - sample = Transition(model, vi, nothing) - newstate = SGHMCState(ℓ, vi, newv) + sample = DynamicPPL.ParamsWithStats(θ, ℓ) + newstate = SGHMCState(ℓ, θ, newv) return sample, newstate end @@ -176,9 +173,9 @@ function SGLD(; return SGLD(stepsize, adtype) end -struct SGLDState{L,V<:AbstractVarInfo} +struct SGLDState{L,V<:AbstractVector{<:Real}} logdensity::L - vi::V + params::V step::Int end @@ -191,11 +188,13 @@ function Turing.Inference.initialstep( end # Create first sample and state. - transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.stepsize(0)))) ℓ = DynamicPPL.LogDensityFunction( model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) - state = SGLDState(ℓ, vi, 1) + initial_params = vi[:] + stats = (; SGLD_stepsize=zero(spl.stepsize(0))) + transition = DynamicPPL.ParamsWithStats(initial_params, ℓ, stats) + state = SGLDState(ℓ, initial_params, 1) return transition, state end @@ -205,19 +204,16 @@ function AbstractMCMC.step( ) # Perform gradient step. 
ℓ = state.logdensity - vi = state.vi - θ = vi[:] + θ = state.params grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) step = state.step stepsize = spl.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) - # Save new variables. - vi = DynamicPPL.unflatten(vi, θ) - # Compute next sample and state. - transition = Transition(model, vi, (; SGLD_stepsize=stepsize)) - newstate = SGLDState(ℓ, vi, state.step + 1) + stats = (; SGLD_stepsize=stepsize) + transition = DynamicPPL.ParamsWithStats(θ, ℓ, stats) + newstate = SGLDState(ℓ, θ, state.step + 1) return transition, newstate end diff --git a/test/Project.toml b/test/Project.toml index 73361d794..3b9dcb4c2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,7 +53,7 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.38" +DynamicPPL = "0.39" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" @@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" julia = "1.10" + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/not-experimental"} diff --git a/test/ad.jl b/test/ad.jl index 287c92834..1fa8003fd 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -94,7 +94,7 @@ encountered. """ struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: - DynamicPPL.AbstractContext + DynamicPPL.AbstractParentContext child::ChildContext function ADTypeCheckContext(adbackend, child) @@ -108,7 +108,6 @@ end adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType -DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent() DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child) return ADTypeCheckContext(adtype(c), child) @@ -138,14 +137,25 @@ Check that the element types in `vi` are compatible with the ADType of `context` Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. """ function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) + # If we are using InitFromPrior or InitFromUniform to generate new values, + # then the parameter type will be Any, so we should skip the check. + lc = DynamicPPL.leafcontext(context) + if lc isa DynamicPPL.InitContext{ + <:Any,<:Union{DynamicPPL.InitFromPrior,DynamicPPL.InitFromUniform} + } + return nothing + end + # Note that `get_param_eltype` will return `Any` with e.g. InitFromPrior or + # InitFromUniform, so this will fail. But on the bright side, you would never _really_ + # use AD with those strategies, so that's fine. The cases where you do want to + # use this are DefaultContext (i.e., old, slow, LogDensityFunction) and + # InitFromParams{<:VectorWithRanges} (i.e., new, fast, LogDensityFunction), and + # both of those give you sensible results for `get_param_eltype`. 
+ param_eltype = DynamicPPL.get_param_eltype(vi, context) valids = valid_eltypes(context) - for val in vi[:] - valtype = typeof(val) - if !any(valtype .<: valids) - throw(IncompatibleADTypeError(valtype, adtype(context))) - end + if !(any(param_eltype .<: valids)) + throw(IncompatibleADTypeError(param_eltype, adtype(context))) end - return nothing end # A bunch of tilde_assume/tilde_observe methods that just call the same method on the child @@ -200,10 +210,10 @@ end @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin if actual_adtype == expected_adtype # Check that this does not throw an error. - sample(contextualised_tm, sampler, 2) + sample(contextualised_tm, sampler, 2; check_model=false) else @test_throws AbstractWrongADBackendError sample( - contextualised_tm, sampler, 2 + contextualised_tm, sampler, 2; check_model=false ) end end diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 6918eaddf..5f577818c 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -71,12 +71,12 @@ using Turing @testset "save/resume correctly reloads state" begin struct StaticSampler <: AbstractMCMC.AbstractSampler end function Turing.Inference.initialstep(rng, model, ::StaticSampler, vi; kwargs...) - return Turing.Inference.Transition(model, vi, nothing), vi + return DynamicPPL.ParamsWithStats(vi, model), vi end function AbstractMCMC.step( rng, model, ::StaticSampler, vi::DynamicPPL.AbstractVarInfo; kwargs... ) - return Turing.Inference.Transition(model, vi, nothing), vi + return DynamicPPL.ParamsWithStats(vi, model), vi end @model demo() = x ~ Normal() @@ -174,24 +174,10 @@ using Turing @test mean(chains, :m) ≈ 0 atol = 0.1 end - @testset "Vector chain_type" begin - chains = sample( - StableRNG(seed), gdemo_d(), Prior(), N; chain_type=Vector{NamedTuple} - ) - @test chains isa Vector{<:NamedTuple} - @test length(chains) == N - @test all(haskey(x, :lp) for x in chains) - @test all(haskey(x, :logprior) for x in chains) - @test all(haskey(x, :loglikelihood) for x in chains) - @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11 - @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 - end - @testset "accumulators are set correctly" begin - # Prior() uses `reevaluate=false` when constructing a - # `Turing.Inference.Transition`, so we had better make sure that it - # does capture colon-eq statements, as we can't rely on the default - # `Transition` constructor to do this for us. + # Prior() does not reevaluate the model when constructing a + # `DynamicPPL.ParamsWithStats`, so we had better make sure that it does capture + # colon-eq statements, and that the logp components are correctly calculated. 
@model function coloneq() x ~ Normal() 10.0 ~ Normal(x) @@ -639,32 +625,6 @@ using Turing ) end - @testset "getparams" begin - @model function e(x=1.0) - return x ~ Normal() - end - evi = DynamicPPL.VarInfo(e()) - @test isempty(Turing.Inference.getparams(e(), evi)) - - @model function f() - return x ~ Normal() - end - fvi = DynamicPPL.VarInfo(f()) - fparams = Turing.Inference.getparams(f(), fvi) - @test fparams[@varname(x)] == fvi[@varname(x)] - @test length(fparams) == 1 - - @model function g() - x ~ Normal() - return y ~ Poisson() - end - gvi = DynamicPPL.VarInfo(g()) - gparams = Turing.Inference.getparams(g(), gvi) - @test gparams[@varname(x)] == gvi[@varname(x)] - @test gparams[@varname(y)] == gvi[@varname(y)] - @test length(gparams) == 2 - end - @testset "empty model" begin @model function e(x=1.0) return x ~ Normal() diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index 36f53462e..8ca8ead06 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -187,8 +187,8 @@ function test_initial_params(model, sampler; kwargs...) transition2, _ = AbstractMCMC.step( rng2, model, sampler; initial_params=init_strategy, kwargs... ) - vn_to_val1 = DynamicPPL.OrderedDict(transition1.θ) - vn_to_val2 = DynamicPPL.OrderedDict(transition2.θ) + vn_to_val1 = transition1.params + vn_to_val2 = transition2.params for vn in union(keys(vn_to_val1), keys(vn_to_val2)) @test vn_to_val1[vn] ≈ vn_to_val2[vn] end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 1e3d5856c..d5b36dec1 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -22,12 +22,9 @@ using Turing: Inference using Turing.Inference: AdvancedHMC, AdvancedMH using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess -function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) - transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val - [first(vn_and_val)] - end +function check_transition_varnames(transition::DynamicPPL.ParamsWithStats, parent_varnames) # Varnames in `transition` should be subsumed by those in `parent_varnames`. - for vn in transition_varnames + for vn in keys(transition.params) @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) end end @@ -159,15 +156,15 @@ end # It is modified by the capture_targets_and_algs function. targets_and_algs = Any[] - function capture_targets_and_algs(sampler, context) - if DynamicPPL.NodeTrait(context) == DynamicPPL.IsLeaf() - return nothing - end + function capture_targets_and_algs(sampler, context::DynamicPPL.AbstractParentContext) if context isa Inference.GibbsContext push!(targets_and_algs, (context.target_varnames, sampler)) end return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context)) end + function capture_targets_and_algs(sampler, ::DynamicPPL.AbstractContext) + return nothing # Leaf context. + end # The methods that capture testing information for us. 
function AbstractMCMC.step( @@ -306,7 +303,7 @@ end ) spl.non_warmup_init_count += 1 vi = DynamicPPL.VarInfo(model) - return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) + return (DynamicPPL.ParamsWithStats(vi, model), VarInfoState(vi)) end function AbstractMCMC.step_warmup( @@ -314,7 +311,7 @@ end ) spl.warmup_init_count += 1 vi = DynamicPPL.VarInfo(model) - return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) + return (DynamicPPL.ParamsWithStats(vi, model), VarInfoState(vi)) end function AbstractMCMC.step( @@ -325,7 +322,7 @@ end kwargs..., ) spl.non_warmup_count += 1 - return Turing.Inference.Transition(model, s.vi, nothing), s + return DynamicPPL.ParamsWithStats(s.vi, model, nothing), s end function AbstractMCMC.step_warmup( @@ -336,7 +333,7 @@ end kwargs..., ) spl.warmup_count += 1 - return Turing.Inference.Transition(model, s.vi, nothing), s + return DynamicPPL.ParamsWithStats(s.vi, model, nothing), s end @model f() = x ~ Normal() @@ -481,12 +478,13 @@ end ::Type{MCMCChains.Chains}; kwargs..., ) - samples isa Vector{<:Inference.Transition} || error("incorrect transitions") + samples isa Vector{<:DynamicPPL.ParamsWithStats} || + error("incorrect transitions") return nothing end function callback(rng, model, sampler, sample, state, i; kwargs...) - sample isa Inference.Transition || error("incorrect sample") + sample isa DynamicPPL.ParamsWithStats || error("incorrect sample") return nothing end
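
A minimal sketch of the new transition flow, assuming the DynamicPPL 0.39 API pinned in [sources] above: sampler steps now return `DynamicPPL.ParamsWithStats` rather than `Turing.Inference.Transition`. The sketch mirrors the `StaticSampler` in test/mcmc/Inference.jl and the custom samplers in test/mcmc/gibbs.jl; the sampler name and the direct `step` call here are illustrative only, and the constructor and `params` field are used exactly as exercised in the updated tests.

using AbstractMCMC, Distributions, DynamicPPL, Random

@model demo() = x ~ Normal()

struct SketchSampler <: AbstractMCMC.AbstractSampler end

# Initial step: draw values for the model's variables and package them, together
# with the model, as the transition that Turing later bundles into a chain.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, model::DynamicPPL.Model, ::SketchSampler; kwargs...
)
    vi = DynamicPPL.VarInfo(model)
    return DynamicPPL.ParamsWithStats(vi, model), vi
end

transition, state = AbstractMCMC.step(Random.default_rng(), demo(), SketchSampler())
# `transition.params` is a VarName-keyed map of parameter values, as used in
# test/mcmc/gibbs.jl and test/mcmc/external_sampler.jl.
collect(keys(transition.params))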