Skip to content

Commit 6b866e9

Browse files
Merge pull request #110 from SciML/gd/forward_mooncake
feat: Add forward mode Mooncake
2 parents ffd55ca + 97ffd68 commit 6b866e9

File tree

6 files changed

+49
-2
lines changed

6 files changed

+49
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ADTypes"
22
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors"]
4-
version = "1.16.0"
4+
version = "1.17.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Algorithmic differentiation:
2121
```@docs
2222
AutoForwardDiff
2323
AutoPolyesterForwardDiff
24+
AutoMooncakeForward
2425
```
2526

2627
Finite differences:

src/ADTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export AutoChainRules,
3838
AutoGTPSA,
3939
AutoModelingToolkit,
4040
AutoMooncake,
41+
AutoMooncakeForward,
4142
AutoPolyesterForwardDiff,
4243
AutoReverseDiff,
4344
AutoSymbolics,

src/dense.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,16 @@ end
276276
"""
277277
AutoMooncake
278278
279-
Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation.
279+
Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation in reverse mode.
280280
281281
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
282282
283+
!!! info
284+
285+
When forward mode became available in Mooncake.jl v0.4.147, another struct called [`AutoMooncakeForward`](@ref) was introduced.
286+
It was kept separate to avoid a breaking release of ADTypes.jl.
287+
[`AutoMooncake`](@ref) remains for reverse mode only.
288+
283289
# Constructors
284290
285291
AutoMooncake(; config=nothing)
@@ -294,6 +300,33 @@ end
294300

295301
mode(::AutoMooncake) = ReverseMode()
296302

303+
"""
304+
AutoMooncakeForward
305+
306+
Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation in forward mode.
307+
308+
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
309+
310+
!!! info
311+
312+
This struct was introduced when forward mode became available in Mooncake.jl v0.4.147.
313+
It was kept separate from [`AutoMooncake`](@ref) to avoid a breaking release of ADTypes.jl.
314+
[`AutoMooncake`](@ref) remains for reverse mode only.
315+
316+
# Constructors
317+
318+
AutoMooncakeForward(; config=nothing)
319+
320+
# Fields
321+
322+
- `config`: either `nothing` or an instance of `Mooncake.Config` -- see the docstring of `Mooncake.Config` for more information. `AutoForwardMooncake(; config=nothing)` is equivalent to `AutoForwardMooncake(; config=Mooncake.Config())`, i.e. the default configuration.
323+
"""
324+
Base.@kwdef struct AutoMooncakeForward{Tconfig} <: AbstractADType
325+
config::Tconfig = nothing
326+
end
327+
328+
mode(::AutoMooncakeForward) = ForwardMode()
329+
297330
"""
298331
AutoPolyesterForwardDiff{chunksize,T}
299332

test/dense.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,16 @@ end
136136
@test ad.config === nothing
137137
end
138138

139+
@testset "AutoMooncakeForward" begin
140+
ad = AutoMooncakeForward(; config = :config)
141+
@test ad isa AbstractADType
142+
@test ad isa AutoMooncakeForward
143+
@test mode(ad) isa ForwardMode
144+
@test ad.config === :config
145+
ad = AutoMooncakeForward()
146+
@test ad.config === nothing
147+
end
148+
139149
@testset "AutoPolyesterForwardDiff" begin
140150
ad = AutoPolyesterForwardDiff()
141151
@test ad isa AbstractADType

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ function every_ad_with_options()
7070
AutoForwardDiff(chunksize = 3, tag = :tag),
7171
AutoGTPSA(),
7272
AutoGTPSA(descriptor = Val(:descriptor)),
73+
AutoMooncake(; config = :config),
74+
AutoMooncakeForward(; config = :config),
7375
AutoPolyesterForwardDiff(),
7476
AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
7577
AutoReverseDiff(),

0 commit comments

Comments
 (0)