Skip to content

Commit 7734a83

Browse files
committed
Link to JAX examples
1 parent ee7c7a6 commit 7734a83

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

doc/.template.coding.ipynb

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@
273273
"cell_type": "markdown",
274274
"metadata": {},
275275
"source": [
276-
"> You can also add arbitrary Python code to your module. This is added to the `swig.i-in` file using the SWIG `%pythoncode` directive. See the [PMI module](https://github.com/salilab/pmi/blob/develop/pyext/swig.i-in) for an example.\n",
276+
"> You can also add arbitrary Python code to your module. This is added to the `swig.i-in` file using the SWIG `pythoncode` directive. See the [PMI module](https://github.com/salilab/pmi/blob/develop/pyext/swig.i-in) for an example.\n",
277277
">\n",
278278
"> You can also add entire Python submodules by adding Python files to the `pyext/src` subdirectory. For example the file `pyext/src/my_python.py` can be imported in Python using `import IMP.foo.my_python`. This is also [used in the PMI module](https://github.com/salilab/pmi/tree/develop/pyext/src)."
279279
]
@@ -624,10 +624,13 @@
624624
"}\n",
625625
"```\n",
626626
"\n",
627-
"Our `_get_jax` method returns a new function which, given the current state of the model `X`, calculates and returns the same score as our original C++ implementation. `X` is provided by IMP and is a simple Python dict containing model data. For example `X['xyz']` contains all of the Cartesian\n",
628-
"coordinates. The additional parameters used by our `jax_restraint` function, namely the force constant `k` and the particle index `pi`, are extracted from the original C++ object using the corresponding getter methods and baked in to the function using `functools.partial`.\n",
627+
"Our `_get_jax` method returns a new function which, given the current state of the model `X`, calculates and returns the same score as our original C++ implementation. `X` is provided by IMP and is a simple Python dict containing model data. ",
628+
"For example `X['xyz']` contains all of the Cartesian\n",
629+
"coordinates as an N x 3 array. The additional parameters used by our `jax_restraint` function, namely the force constant `k` and the particle index `pi`, are extracted from the original C++ object using the corresponding getter methods and baked in to the function using ``functools.partial``.\n",
629630
"\n",
630-
"Note that one key feature of the JAX library is automatic differentiation by which analytic first derivatives of JAX functions are determined automatically by the JAX library. Thus, it is not necessary for us to calculate them explicitly (as is done in the C++ implementation)."
631+
"Note that one key feature of the JAX library is automatic differentiation by which analytic first derivatives of JAX functions are determined automatically by the JAX library. Thus, it is not necessary for us to calculate them explicitly (as is done in the C++ implementation).\n",
632+
"\n",
633+
"Other key IMP classes, such as ``IMP::Constraint``, ``IMP::PairScore``, and ``IMP::UnaryFunction`` can also be provided with a JAX implementation in the same fashion. See the [IMP.example module](https://github.com/salilab/imp/blob/develop/modules/example/pyext/IMP_example.jax.i) for some examples."
631634
]
632635
}
633636
],

doc/coding.ipynb

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@
289289
"cell_type": "markdown",
290290
"metadata": {},
291291
"source": [
292-
"> You can also add arbitrary Python code to your module. This is added to the `swig.i-in` file using the SWIG `%pythoncode` directive. See the [PMI module](https://github.com/salilab/pmi/blob/develop/pyext/swig.i-in) for an example.\n",
292+
"> You can also add arbitrary Python code to your module. This is added to the `swig.i-in` file using the SWIG `pythoncode` directive. See the [PMI module](https://github.com/salilab/pmi/blob/develop/pyext/swig.i-in) for an example.\n",
293293
">\n",
294294
"> You can also add entire Python submodules by adding Python files to the `pyext/src` subdirectory. For example the file `pyext/src/my_python.py` can be imported in Python using `import IMP.foo.my_python`. This is also [used in the PMI module](https://github.com/salilab/pmi/tree/develop/pyext/src)."
295295
]
@@ -640,10 +640,13 @@
640640
"}\n",
641641
"```\n",
642642
"\n",
643-
"Our `_get_jax` method returns a new function which, given the current state of the model `X`, calculates and returns the same score as our original C++ implementation. `X` is provided by IMP and is a simple Python dict containing model data. For example `X['xyz']` contains all of the Cartesian\n",
644-
"coordinates. The additional parameters used by our `jax_restraint` function, namely the force constant `k` and the particle index `pi`, are extracted from the original C++ object using the corresponding getter methods and baked in to the function using `functools.partial`.\n",
643+
"Our `_get_jax` method returns a new function which, given the current state of the model `X`, calculates and returns the same score as our original C++ implementation. `X` is provided by IMP and is a simple Python dict containing model data. ",
644+
"For example `X['xyz']` contains all of the Cartesian\n",
645+
"coordinates as an N x 3 array. The additional parameters used by our `jax_restraint` function, namely the force constant `k` and the particle index `pi`, are extracted from the original C++ object using the corresponding getter methods and baked in to the function using [functools.partial](https://docs.python.org/3/library/functools.html#functools.partial).\n",
645646
"\n",
646-
"Note that one key feature of the JAX library is automatic differentiation by which analytic first derivatives of JAX functions are determined automatically by the JAX library. Thus, it is not necessary for us to calculate them explicitly (as is done in the C++ implementation)."
647+
"Note that one key feature of the JAX library is automatic differentiation by which analytic first derivatives of JAX functions are determined automatically by the JAX library. Thus, it is not necessary for us to calculate them explicitly (as is done in the C++ implementation).\n",
648+
"\n",
649+
"Other key IMP classes, such as [IMP::Constraint](https://integrativemodeling.org/2.23.0/doc/ref/classIMP_1_1Constraint.html), [IMP::PairScore](https://integrativemodeling.org/2.23.0/doc/ref/classIMP_1_1PairScore.html), and [IMP::UnaryFunction](https://integrativemodeling.org/2.23.0/doc/ref/classIMP_1_1UnaryFunction.html) can also be provided with a JAX implementation in the same fashion. See the [IMP.example module](https://github.com/salilab/imp/blob/develop/modules/example/pyext/IMP_example.jax.i) for some examples."
647650
]
648651
}
649652
],

0 commit comments

Comments
 (0)