diff --git a/src/core/defcomp.jl b/src/core/defcomp.jl index 2b8da496d..1f8e387e9 100644 --- a/src/core/defcomp.jl +++ b/src/core/defcomp.jl @@ -25,7 +25,8 @@ function _generate_run_func(comp_name, module_name, args, body) global function $(func_name)($(p), #::Mimi.ComponentInstanceParameters, $(v), #::Mimi.ComponentInstanceVariables $(d), #::NamedTuple - $(t)) #::T <: Mimi.AbstractTimestep + $(t); #::T <: Mimi.AbstractTimestep + get_dim_keys::Function) #::Function $(body...) return nothing end diff --git a/src/core/instances.jl b/src/core/instances.jl index 4ba0cad1a..13e1fd578 100644 --- a/src/core/instances.jl +++ b/src/core/instances.jl @@ -312,18 +312,18 @@ function get_shifted_ts(ci, ts::VariableTimestep{TIMES}) where {TIMES} end end -function run_timestep(ci::AbstractComponentInstance, clock::Clock, dims::NamedTuple) +function run_timestep(ci::AbstractComponentInstance, clock::Clock, dims::NamedTuple, get_dim_keys::Function) if ci.run_timestep !== nothing && _runnable(ci, clock) - ci.run_timestep(parameters(ci), variables(ci), dims, get_shifted_ts(ci, clock.ts)) + ci.run_timestep(parameters(ci), variables(ci), dims, get_shifted_ts(ci, clock.ts); get_dim_keys=get_dim_keys) end return nothing end -function run_timestep(cci::AbstractCompositeComponentInstance, clock::Clock, dims::NamedTuple) +function run_timestep(cci::AbstractCompositeComponentInstance, clock::Clock, dims::NamedTuple, get_dim_keys::Function) if _runnable(cci, clock) for ci in components(cci) - run_timestep(ci, clock, dims) + run_timestep(ci, clock, dims, get_dim_keys) end end return nothing @@ -359,11 +359,19 @@ function Base.run(mi::ModelInstance, ntimesteps::Int=typemax(Int), # into timestep arrays. dim_val_named_tuple = NamedTuple(name => (name == :time ? timesteps(clock) : collect(values(dim))) for (name, dim) in dim_dict(mi.md)) + # Define dim_keys, a function that allows the component to return the keys + # of a given dimension. This is passed through and serves as a keyword argument + # to the run_timestep function of each component so they may access the dimension + # key information. + function get_dim_keys(dim_name::Symbol) + return dim_keys(mi, dim_name) + end + # recursively initializes all components init(mi, dim_val_named_tuple) while ! finished(clock) - run_timestep(mi, clock, dim_val_named_tuple) + run_timestep(mi, clock, dim_val_named_tuple, get_dim_keys) advance(clock) end diff --git a/test/test_components.jl b/test/test_components.jl index e5aba0d22..2b7da2c6f 100644 --- a/test/test_components.jl +++ b/test/test_components.jl @@ -160,4 +160,35 @@ ci = compinstance(m, :C) # Get the component instance @test ci.first == 2010 # The component instance's first and last values are the same as the comp def @test ci.last == 2090 + + # 3. Test using the dim_keys function in run_timestep + my_model = Model() + set_dimension!(my_model, :time, 2000:2010) # Set the time dimension + set_dimension!(my_model, :region, [:A, :B, :C]) # Set the region dimension + + @defcomp testcomp5 begin + + var1 = Variable(index=[time]) + par1 = Parameter(index=[time]) + + var2 = Variable{Int}(index=[region]) # hold region idxs 1:n + var3 = Variable{Symbol}(index=[region]) # hold region keys + + function run_timestep(p, v, d, t) + v.var1[t] = p.par1[t] + for region in d.region + v.var2[region] = region + v.var3[region] = get_dim_keys(:region)[region] + end + end + end + + add_comp!(my_model, testcomp5) + update_param!(my_model, :testcomp5, :par1, collect(1:11)) + + run(my_model) + + @test my_model[:testcomp5, :var2] == [1,2,3] + @test my_model[:testcomp5, :var3] == [:A, :B, :C] + end