1010
1111from qutip_qoc ._optimizer import _global_local_optimization
1212from qutip_qoc ._time import _TimeInterval
13- from qutip_qoc ._rl import _RL
13+
14+ import qutip as qt
1415from qutip_qoc ._genetic import _GENETIC
1516
17+ try :
18+ from qutip_qoc ._rl import _RL
19+ _rl_available = True
20+ except ImportError :
21+ _rl_available = False
22+
1623__all__ = ["optimize_pulses" ]
1724
1825
@@ -24,6 +31,7 @@ def optimize_pulses(
2431 optimizer_kwargs = None ,
2532 minimizer_kwargs = None ,
2633 integrator_kwargs = None ,
34+ optimization_type = None ,
2735):
2836 """
2937 Run GOAT, JOPT, GRAPE, CRAB or RL optimization.
@@ -120,6 +128,11 @@ def optimize_pulses(
120128 Options for the solver, see :obj:`MESolver.options` and
121129 `Integrator <./classes.html#classes-ode>`_ for a list of all options.
122130
131+ optimization_type : str, optional
132+ Type of optimization. By default, QuTiP-QOC will try to automatically determine
133+ whether this is a *state transfer* or a *gate synthesis* problem. Set this
134+ flag to ``"state_transfer"`` or ``"gate_synthesis"`` to set the mode manually.
135+
123136 Returns
124137 -------
125138 result : :class:`qutip_qoc.Result`
@@ -183,10 +196,43 @@ def optimize_pulses(
183196 "maxiter" : algorithm_kwargs .get ("max_iter" , 1000 ),
184197 "gtol" : algorithm_kwargs .get ("min_grad" , 0.0 if alg == "CRAB" else 1e-8 ),
185198 }
199+ # Iterate over objectives and convert initial and target states based on the optimization type
200+ for objective in objectives :
201+ H_list = objective .H if isinstance (objective .H , list ) else [objective .H ]
202+ if any (qt .issuper (H_i ) for H_i in H_list ):
203+ if isinstance (optimization_type , str ) and optimization_type .lower () == "state_transfer" :
204+ if qt .isket (objective .initial ):
205+ objective .initial = qt .operator_to_vector (qt .ket2dm (objective .initial ))
206+ elif qt .isoper (objective .initial ):
207+ objective .initial = qt .operator_to_vector (objective .initial )
208+ if qt .isket (objective .target ):
209+ objective .target = qt .operator_to_vector (qt .ket2dm (objective .target ))
210+ elif qt .isoper (objective .target ):
211+ objective .target = qt .operator_to_vector (objective .target )
212+ elif isinstance (optimization_type , str ) and optimization_type .lower () == "gate_synthesis" :
213+ objective .initial = qt .to_super (objective .initial )
214+ objective .target = qt .to_super (objective .target )
215+ elif optimization_type is None :
216+ if qt .isoper (objective .initial ) and qt .isoper (objective .target ):
217+ if np .isclose ((objective .initial ).tr (), 1 ) and np .isclose ((objective .target ).tr (), 1 ):
218+ objective .initial = qt .operator_to_vector (objective .initial )
219+ objective .target = qt .operator_to_vector (objective .target )
220+ else :
221+ objective .initial = qt .to_super (objective .initial )
222+ objective .target = qt .to_super (objective .target )
223+ if qt .isket (objective .initial ):
224+ objective .initial = qt .operator_to_vector (qt .ket2dm (objective .initial ))
225+ if qt .isket (objective .target ):
226+ objective .target = qt .operator_to_vector (qt .ket2dm (objective .target ))
186227
187228 # prepare qtrl optimizers
188229 qtrl_optimizers = []
189230 if alg == "CRAB" or alg == "GRAPE" :
231+ dyn_type = "GEN_MAT"
232+ for objective in objectives :
233+ if any (qt .isoper (H_i ) for H_i in (objective .H if isinstance (objective .H , list ) else [objective .H ])):
234+ dyn_type = "UNIT"
235+
190236 if alg == "GRAPE" : # algorithm specific kwargs
191237 use_as_amps = True
192238 minimizer_kwargs .setdefault ("method" , "L-BFGS-B" ) # gradient
@@ -243,7 +289,7 @@ def optimize_pulses(
243289 "accuracy_factor" : None , # deprecated
244290 "alg_params" : alg_params ,
245291 "optim_params" : algorithm_kwargs .get ("optim_params" , None ),
246- "dyn_type" : algorithm_kwargs .get ("dyn_type" , "GEN_MAT" ),
292+ "dyn_type" : algorithm_kwargs .get ("dyn_type" , dyn_type ),
247293 "dyn_params" : algorithm_kwargs .get ("dyn_params" , None ),
248294 "prop_type" : algorithm_kwargs .get (
249295 "prop_type" , "DEF"
@@ -354,6 +400,12 @@ def optimize_pulses(
354400 qtrl_optimizers .append (qtrl_optimizer )
355401
356402 elif alg == "RL" :
403+ if not _rl_available :
404+ raise ImportError (
405+ "The required dependencies (gymnasium, stable-baselines3) for "
406+ "the reinforcement learning algorithm are not available."
407+ )
408+
357409 rl_env = _RL (
358410 objectives ,
359411 control_parameters ,
@@ -393,4 +445,4 @@ def optimize_pulses(
393445 minimizer_kwargs ,
394446 integrator_kwargs ,
395447 qtrl_optimizers ,
396- )
448+ )
0 commit comments