@@ -1197,3 +1197,197 @@ function generate_update_b(sys::System, b::AbstractVector; expression = Val{true
11971197 return maybe_compile_function (expression, wrap_gfw, (1 , 1 , is_split (sys)), res;
11981198 eval_expression, eval_module)
11991199end
1200+
1201+ # f1 = rest
1202+ # f2 = A * x + B * x2 + C
1203+ function calculate_split_form (sys:: System ; sparse = false )
1204+ rhss = [eq. rhs for eq in full_equations (sys)]
1205+ dvs = unknowns (sys)
1206+ A, B, x2, C = semiquadratic_form (rhss, dvs)
1207+ if ! sparse
1208+ A = collect (A)
1209+ B = collect (B)
1210+ end
1211+ A = unwrap .(A)
1212+ B = unwrap .(B)
1213+ x2 = unwrap .(x2)
1214+ C = unwrap .(C)
1215+
1216+ return A, B, x2, C
1217+ end
1218+
1219+ const DIFFCACHE_PARAM_NAME = :__mtk_diffcache
1220+
1221+ function get_diffcache_param (:: Type{T} ) where {T}
1222+ toconstant (Symbolics. variable (
1223+ DIFFCACHE_PARAM_NAME; T = DiffCache{Vector{T}, Vector{T}}))
1224+ end
1225+
1226+ # x2
1227+ const BILINEAR_CACHEVAR = unwrap (only (@constants bilinear_xₘₜₖ:: Vector{Real} ))
1228+ # A
1229+ const LINEAR_MATRIX_PARAM_NAME = :linear_Aₘₜₖ
1230+ function get_linear_matrix_param (size:: NTuple{2, Int} )
1231+ m, n = size
1232+ unwrap (only (@constants linear_Aₘₜₖ[1 : m, 1 : n]))
1233+ end
1234+ # B
1235+ const BILINEAR_MATRIX_PARAM_NAME = :bilinear_Bₘₜₖ
1236+ function get_bilinear_matrix_param (size:: NTuple{2, Int} )
1237+ m, n = size
1238+ unwrap (only (@constants bilinear_Bₘₜₖ[1 : m, 1 : n]))
1239+ end
1240+
1241+ function generate_semiquadratic_functions (
1242+ sys:: System , A, B, x2, C; expression = Val{true }, wrap_gfw = Val{false },
1243+ eval_expression = false , eval_module = @__MODULE__ , kwargs... )
1244+ linear_matrix_param = unwrap (getproperty (sys, LINEAR_MATRIX_PARAM_NAME))
1245+ bilinear_matrix_param = unwrap (getproperty (sys, BILINEAR_MATRIX_PARAM_NAME))
1246+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
1247+ dvs = unknowns (sys)
1248+ ps = reorder_parameters (sys)
1249+ # Codegen is a bit manual, and we're manually creating an efficient IIP function.
1250+ # Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second
1251+ # argument.
1252+ iip_x = generated_argument_name (2 )
1253+ oop_x = generated_argument_name (1 )
1254+
1255+ f1_iip_ir = Assignment[Assignment (BILINEAR_CACHEVAR,
1256+ term (view,
1257+ term (PreallocationTools. get_tmp,
1258+ diffcache_par, Symbolics. DEFAULT_OUTSYM),
1259+ 1 : length (x2)))
1260+ # write to x2
1261+ Assignment (:__tmp1 , SetArray (false , BILINEAR_CACHEVAR, x2))
1262+ # out .= C
1263+ Assignment (
1264+ :__tmp2 , SetArray (false , Symbolics. DEFAULT_OUTSYM, C))
1265+ # mul!(out, B, x2, 1, 1)
1266+ Assignment (:__tmp3 ,
1267+ term (mul!, Symbolics. DEFAULT_OUTSYM, bilinear_matrix_param,
1268+ BILINEAR_CACHEVAR, true , true ))]
1269+ f1_iip = build_function_wrapper (
1270+ sys, nothing , Symbolics. DEFAULT_OUTSYM, dvs, ps... , get_iv (sys); p_start = 3 ,
1271+ extra_assignments = f1_iip_ir, expression = Val{true }, kwargs... )
1272+ f1_oop = build_function_wrapper (
1273+ sys, term (+ , term (* , bilinear_matrix_param, x2), C), dvs, ps... ,
1274+ get_iv (sys); expression = Val{true }, iip_config = (true , false ), kwargs... )
1275+
1276+ f2_iip_ir = Assignment[
1277+ Assignment (
1278+ :__tmp1 , term (mul!, Symbolics. DEFAULT_OUTSYM, linear_matrix_param, iip_x))
1279+ ]
1280+ f2_iip = build_function_wrapper (
1281+ sys, nothing , Symbolics. DEFAULT_OUTSYM, dvs, ps... , get_iv (sys); p_start = 3 ,
1282+ extra_assignments = f2_iip_ir, expression = Val{true }, kwargs... )
1283+ f2_oop = build_function_wrapper (
1284+ sys, term (* , linear_matrix_param, oop_x), dvs, ps... , get_iv (sys);
1285+ expression = Val{true }, iip_config = (true , false ), kwargs... )
1286+
1287+ f1 = maybe_compile_function (expression, wrap_gfw, (2 , 3 , is_split (sys)),
1288+ (f1_oop, f1_iip); eval_expression, eval_module)
1289+ f2 = maybe_compile_function (expression, wrap_gfw, (2 , 3 , is_split (sys)),
1290+ (f2_oop, f2_iip); eval_expression, eval_module)
1291+ return f1, f2
1292+ end
1293+
1294+ function calculate_semiquadratic_jacobian (
1295+ sys:: System , B, x2, C; sparse = false , massmatrix = calculate_massmatrix (sys))
1296+ dvs = unknowns (sys)
1297+ if sparse
1298+ x2jac = Symbolics. sparsejacobian (x2, dvs)
1299+ Cjac = Symbolics. sparsejacobian (C, dvs)
1300+ else
1301+ x2jac = Symbolics. jacobian (x2, dvs)
1302+ Cjac = Symbolics. jacobian (C, dvs)
1303+ end
1304+
1305+ f1jac = B * x2jac + Cjac
1306+
1307+ if sparse
1308+ for i in 1 : length (dvs)
1309+ massmatrix[i, i] == 0 && continue
1310+ _iszero (f1jac[i, i]) || continue
1311+ f1jac[i, i] = 1
1312+ f1jac[i, i] = 0
1313+ end
1314+ end
1315+
1316+ return f1jac, x2jac, Cjac
1317+ end
1318+
1319+ const COLPTR_PARAM = unwrap (only (@parameters __mtk_colptr:: Vector{Int} ))
1320+ const ROWVAL_PARAM = unwrap (only (@parameters __mtk_rowval:: Vector{Int} ))
1321+
1322+ function generate_semiquadratic_jacobian (
1323+ sys:: System , B, x2, C, f1jac, x2jac, Cjac; sparse = false ,
1324+ expression = Val{true }, wrap_gfw = Val{false },
1325+ eval_expression = false , eval_module = @__MODULE__ , kwargs... )
1326+ if sparse
1327+ @assert is_parameter (sys, COLPTR_PARAM)
1328+ @assert is_parameter (sys, ROWVAL_PARAM)
1329+ end
1330+ bilinear_matrix_param = unwrap (getproperty (sys, BILINEAR_MATRIX_PARAM_NAME))
1331+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
1332+ dvs = unknowns (sys)
1333+ ps = reorder_parameters (sys)
1334+ # Codegen is a bit manual, and we're manually creating an efficient IIP function.
1335+ # Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second
1336+ # argument.
1337+ iip_x = generated_argument_name (2 )
1338+ oop_x = generated_argument_name (1 )
1339+
1340+ iip_ir = Assignment[]
1341+ push! (iip_ir,
1342+ Assignment (:__mtk_preallocbuf ,
1343+ term (PreallocationTools. get_tmp, diffcache_par, Symbolics. DEFAULT_OUTSYM)))
1344+ if sparse
1345+ push! (
1346+ iip_ir, Assignment (:__mtk_nzvals , term (view, :__mtk_preallocbuf , 1 : nnz (x2jac))))
1347+ push! (iip_ir, Assignment (:__tmp1 , SetArray (false , :__mtk_nzvals , x2jac. nzvals)))
1348+ push! (iip_ir,
1349+ Assignment (:__mtk_x2jacbuf ,
1350+ term (SparseMatrixCSC, size (x2jac)... ,
1351+ COLPTR_PARAM, ROWVAL_PARAM, :__mtk_nzvals )))
1352+ cjac_idxs = AtIndex[]
1353+ for (i, j, v) in zip (findnz (Cjac)... )
1354+ push! (cjac_idxs, AtIndex (CartesianIndex (i, j), v))
1355+ end
1356+ else
1357+ push! (iip_ir,
1358+ Assignment (:__mtk_x2jacbuf ,
1359+ term (reshape, term (view, :__mtk_preallocbuf , 1 : length (x2jac)), size (x2jac))))
1360+ push! (iip_ir, Assignment (:__tmp1 , SetArray (false , :__mtk_x2jacbuf , x2jac)))
1361+ cjac_idxs = AtIndex[]
1362+ for i in eachindex (Cjac)
1363+ _iszero (Cjac[i]) && continue
1364+ push! (cjac_idxs, AtIndex (i, Cjac[i]))
1365+ end
1366+ end
1367+ push! (iip_ir, Assignment (:__tmp2 , SetArray (false , Symbolics. DEFAULT_OUTSYM, cjac_idxs)))
1368+ push! (iip_ir,
1369+ Assignment (:__tmp3 ,
1370+ term (mul!, Symbolics. DEFAULT_OUTSYM,
1371+ bilinear_matrix_param, :__mtk_x2jacbuf , true , true )))
1372+
1373+ jaciip = build_function_wrapper (
1374+ sys, nothing , Symbolics. DEFAULT_OUTSYM, dvs, ps... , get_iv (sys);
1375+ p_start = 3 , extra_assignments = iip_ir, expression = Val{true }, kwargs... )
1376+
1377+ make_x2 = if sparse
1378+ MakeSparseArray (x2jac)
1379+ else
1380+ MakeArray (x2jac, generated_argument_name (1 ))
1381+ end
1382+ make_cjac = if sparse
1383+ MakeSparseArray (Cjac)
1384+ else
1385+ MakeArray (Cjac, generated_argument_name (1 ))
1386+ end
1387+ oop_expr = term (+ , term (* , bilinear_matrix_param, make_x2), Cjac)
1388+ jacoop = build_function_wrapper (
1389+ sys, oop_expr, dvs, ps... , get_iv (sys); expression = Val{true }, kwargs... )
1390+
1391+ return maybe_compile_function (expression, wrap_gfw, (2 , 3 , is_split (sys)),
1392+ (jacoop, jaciip); eval_expression, eval_module)
1393+ end
0 commit comments