diff --git a/include/pyoptinterface/cppad_interface.hpp b/include/pyoptinterface/cppad_interface.hpp index 86adb7c..b3b020e 100644 --- a/include/pyoptinterface/cppad_interface.hpp +++ b/include/pyoptinterface/cppad_interface.hpp @@ -33,7 +33,7 @@ ADFunDouble sparse_hessian(const ADFunDouble &f, const sparsity_pattern_t &patte // Transform ExpressionGraph to CppAD function ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph); -ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph); +ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate = true); struct CppADAutodiffGraph { diff --git a/lib/cppad_interface.cpp b/lib/cppad_interface.cpp index 9112553..26cd2b4 100644 --- a/lib/cppad_interface.cpp +++ b/lib/cppad_interface.cpp @@ -469,7 +469,7 @@ ADFunDouble cppad_trace_graph_constraints(const ExpressionGraph &graph) return f; } -ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph) +ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph, bool aggregate) { ankerl::unordered_dense::map> seen_expressions; @@ -503,15 +503,22 @@ ADFunDouble cppad_trace_graph_objective(const ExpressionGraph &graph) y[i] = cppad_trace_expression(graph, output, x, p, seen_expressions); } - CppAD::AD y_sum = 0.0; - for (size_t i = 0; i < N_outputs; i++) + ADFunDouble f; + + if (aggregate) + { + CppAD::AD y_sum = 0.0; + for (size_t i = 0; i < N_outputs; i++) + { + y_sum += y[i]; + } + f.Dependent(x, {y_sum}); + } + else { - y_sum += y[i]; + f.Dependent(x, y); } - ADFunDouble f; - f.Dependent(x, {y_sum}); - return f; } diff --git a/lib/cppad_interface_ext.cpp b/lib/cppad_interface_ext.cpp index 87ec39c..0175dda 100644 --- a/lib/cppad_interface_ext.cpp +++ b/lib/cppad_interface_ext.cpp @@ -183,6 +183,6 @@ NB_MODULE(cppad_interface_ext, m) .def_ro("hessian", &CppADAutodiffGraph::hessian_graph); m.def("cppad_trace_graph_constraints", cppad_trace_graph_constraints); - m.def("cppad_trace_graph_objective", cppad_trace_graph_objective); + m.def("cppad_trace_graph_objective", cppad_trace_graph_objective, nb::arg("graph"), nb::arg("aggregate") = true); m.def("cppad_autodiff", &cppad_autodiff); } diff --git a/lib/knitro_model.cpp b/lib/knitro_model.cpp index 18b3347..88fbd5c 100644 --- a/lib/knitro_model.cpp +++ b/lib/knitro_model.cpp @@ -806,7 +806,9 @@ void KNITROModel::_add_objective_callback(ExpressionGraph *graph, const Outputs evaluator->eval_hess(req->x, req->sigma, res->hess, true); return 0; }; - auto trace = cppad_trace_graph_objective; + auto trace = [](const ExpressionGraph &graph) { + return cppad_trace_graph_objective(graph, false); + }; _add_callback_impl(*graph, outputs.obj_idxs, {}, trace, f, g, h); }