Google JAX: Difference between revisions
m →jit: bad spelling, weasel language, citation needed! |
Viktor Guer (talk | contribs) |
||
(43 intermediate revisions by 29 users not shown) | |||
Line 17: | Line 17: | ||
| latest release version = |
| latest release version = |
||
| latest release date = <!-- {{Start date and age|YYYY|MM|DD|df=yes/no}} --> |
| latest release date = <!-- {{Start date and age|YYYY|MM|DD|df=yes/no}} --> |
||
| latest preview version = v0. |
| latest preview version = v0.4.31 |
||
| latest preview date = {{Start date and age| |
| latest preview date = {{Start date and age|2024|07|30|df=yes}} |
||
| repo = {{URL|https://github.com/google/jax}} |
| repo = {{URL|https://github.com/google/jax}} |
||
| programming language = [[Python (programming language)|Python]], [[C++]] |
| programming language = [[Python (programming language)|Python]], [[C++]] |
||
Line 31: | Line 31: | ||
| website = <!-- {{URL|example.org}} or {{official URL}} --> |
| website = <!-- {{URL|example.org}} or {{official URL}} --> |
||
}} |
}} |
||
'''Google JAX''' is a machine learning framework for transforming numerical functions.<ref name=":0">{{Citation |title=JAX: Autograd and XLA |date=2022-06-18 |url=https://github.com/google/jax |archive-url=https://web.archive.org/web/20220618205214/https://github.com/google/jax |publisher=Google |bibcode=2021ascl.soft11002B |access-date=2022-06-18 |archive-date=2022-06-18|last1=Bradbury |first1=James |last2=Frostig |first2=Roy |last3=Hawkins |first3=Peter |last4=Johnson |first4=Matthew James |last5=Leary |first5=Chris |last6=MacLaurin |first6=Dougal |last7=Necula |first7=George |last8=Paszke |first8=Adam |last9=Vanderplas |first9=Jake |last10=Wanderman-Milne |first10=Skye |last11=Zhang |first11=Qiao |journal=Astrophysics Source Code Library }}</ref><ref>{{Cite journal |last1=Frostig |first1=Roy |last2=Johnson |first2=Matthew James |last3=Leary |first3=Chris |date=2018-02-02 |year=2018 |title=Compiling machine learning programs via high-level tracing |url=https://mlsys.org/Conferences/doc/2018/146.pdf |url-status=live |journal=MLsys |pages=1–3 |archive-url=https://web.archive.org/web/20220621153349/https://mlsys.org/Conferences/doc/2018/146.pdf |archive-date=2022-06-21}}</ref><ref>{{Cite web |title=Using JAX to accelerate our research |url=https://www.deepmind.com/blog/using-jax-to-accelerate-our-research |url-status=live |archive-url=https://web.archive.org/web/20220618205746/https://www.deepmind.com/blog/using-jax-to-accelerate-our-research |archive-date=2022-06-18 |access-date=2022-06-18 |website=www.deepmind.com |language=en}}</ref> It is described as bringing together a modified version of [https://github.com/HIPS/autograd autograd] (automatic obtaining of the gradient function through differentiation of a function) and [[TensorFlow]]'s [https://www.tensorflow.org/xla XLA] (Accelerated Linear Algebra). It is designed to follow the structure and workflow of [[NumPy]] as closely as possible and works with various existing frameworks such as [[TensorFlow]] and [[PyTorch]].<ref>{{Cite web |last=Lynley |first=Matthew |title=Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta |url=https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6 |archive-url=https://web.archive.org/web/20220621143905/https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6 |archive-date=2022-06-21 |access-date=2022-06-21 |website=Business Insider |language=en-US}}</ref><ref>{{Cite web |date=2022-04-25 |title=Why is Google's JAX so popular? |url=https://analyticsindiamag.com/why-is-googles-jax-so-popular/ |url-status=live |archive-url=https://web.archive.org/web/20220618210503/https://analyticsindiamag.com/why-is-googles-jax-so-popular/ |archive-date=2022-06-18 |access-date=2022-06-18 |website=Analytics India Magazine |language=en-US}}</ref> The primary functions of JAX are:<ref name=":0" /> |
'''Google JAX''' is a machine learning framework for transforming numerical functions.<ref name=":0">{{Citation |title=JAX: Autograd and XLA |date=2022-06-18 |url=https://github.com/google/jax |archive-url=https://web.archive.org/web/20220618205214/https://github.com/google/jax |publisher=Google |bibcode=2021ascl.soft11002B |access-date=2022-06-18 |archive-date=2022-06-18 |last1=Bradbury |first1=James |last2=Frostig |first2=Roy |last3=Hawkins |first3=Peter |last4=Johnson |first4=Matthew James |last5=Leary |first5=Chris |last6=MacLaurin |first6=Dougal |last7=Necula |first7=George |last8=Paszke |first8=Adam |last9=Vanderplas |first9=Jake |last10=Wanderman-Milne |first10=Skye |last11=Zhang |first11=Qiao |journal=Astrophysics Source Code Library }}</ref><ref>{{Cite journal |last1=Frostig |first1=Roy |last2=Johnson |first2=Matthew James |last3=Leary |first3=Chris |date=2018-02-02 |year=2018 |title=Compiling machine learning programs via high-level tracing |url=https://mlsys.org/Conferences/doc/2018/146.pdf |url-status=live |journal=MLsys |pages=1–3 |archive-url=https://web.archive.org/web/20220621153349/https://mlsys.org/Conferences/doc/2018/146.pdf |archive-date=2022-06-21}}</ref><ref>{{Cite web |title=Using JAX to accelerate our research |url=https://www.deepmind.com/blog/using-jax-to-accelerate-our-research |url-status=live |archive-url=https://web.archive.org/web/20220618205746/https://www.deepmind.com/blog/using-jax-to-accelerate-our-research |archive-date=2022-06-18 |access-date=2022-06-18 |website=www.deepmind.com |language=en}}</ref> It is described as bringing together a modified version of [https://github.com/HIPS/autograd autograd] (automatic obtaining of the gradient function through differentiation of a function) and [[TensorFlow]]'s [https://www.tensorflow.org/xla XLA] (Accelerated Linear Algebra). It is designed to follow the structure and workflow of [[NumPy]] as closely as possible and works with various existing frameworks such as [[TensorFlow]] and [[PyTorch]].<ref>{{Cite web |last=Lynley |first=Matthew |title=Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta |url=https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6 |archive-url=https://web.archive.org/web/20220621143905/https://www.businessinsider.com/facebook-pytorch-beat-google-tensorflow-jax-meta-ai-2022-6 |archive-date=2022-06-21 |access-date=2022-06-21 |website=Business Insider |language=en-US}}</ref><ref>{{Cite web |date=2022-04-25 |title=Why is Google's JAX so popular? |url=https://analyticsindiamag.com/why-is-googles-jax-so-popular/ |url-status=live |archive-url=https://web.archive.org/web/20220618210503/https://analyticsindiamag.com/why-is-googles-jax-so-popular/ |archive-date=2022-06-18 |access-date=2022-06-18 |website=Analytics India Magazine |language=en-US}}</ref> The primary functions of JAX are:<ref name=":0" /> |
||
# grad: automatic differentiation |
# grad: [[automatic differentiation]] |
||
# jit: compilation |
# jit: compilation |
||
# vmap: auto-vectorization |
# vmap: [[auto-vectorization]] |
||
# pmap: SPMD programming |
# pmap: [[Single program, multiple data]] (SPMD) programming |
||
== grad == |
== grad == |
||
Line 42: | Line 42: | ||
The below code demonstrates the '''grad''' function's automatic differentiation. |
The below code demonstrates the '''grad''' function's automatic differentiation. |
||
<syntaxhighlight lang=" |
<syntaxhighlight lang="numpy" line="1"> |
||
# imports |
# imports |
||
from jax import grad |
from jax import grad |
||
Line 61: | Line 61: | ||
The final line should outputː |
The final line should outputː |
||
<syntaxhighlight lang=" |
<syntaxhighlight lang="output"> |
||
0.19661194 |
0.19661194 |
||
</syntaxhighlight> |
</syntaxhighlight> |
||
== jit == |
== jit == |
||
{{Main|Just-in-time compilation}} |
|||
The below code demonstrates the '''jit''' function's optimization through fusion. |
The below code demonstrates the '''jit''' function's optimization through fusion. |
||
<syntaxhighlight lang=" |
<syntaxhighlight lang="numpy" line="1"> |
||
# imports |
# imports |
||
from jax import jit |
from jax import jit |
||
Line 89: | Line 88: | ||
</syntaxhighlight> |
</syntaxhighlight> |
||
The computation time for |
The computation time for {{code|jit_cube}} (line #17) should be noticeably shorter than that for {{code|cube}} (line #16). Increasing the values on line #7, will further exacerbate the difference. |
||
== vmap == |
== vmap == |
||
{{Main| |
{{Main|Array programming}} |
||
The below code demonstrates the '''vmap''' function's vectorization. |
The below code demonstrates the '''vmap''' function's vectorization. |
||
<syntaxhighlight lang=" |
<syntaxhighlight lang="numpy" line="1"> |
||
# imports |
# imports |
||
from |
from jax import vmap partial |
||
from jax import vmap |
|||
import jax.numpy as jnp |
import jax.numpy as jnp |
||
# define function |
# define function |
||
def grads(self, inputs): |
def grads(self, inputs): |
||
in_grad_partial = partial(self._net_grads, self._net_params) |
in_grad_partial = jax.partial(self._net_grads, self._net_params) |
||
grad_vmap = jax.vmap(in_grad_partial) |
grad_vmap = jax.vmap(in_grad_partial) |
||
rich_grads = grad_vmap(inputs) |
rich_grads = grad_vmap(inputs) |
||
Line 116: | Line 114: | ||
== pmap == |
== pmap == |
||
{{Main|Automatic parallelization}} |
|||
The below code demonstrates the '''pmap''' function's parallelization for matrix multiplication. |
The below code demonstrates the '''pmap''' function's parallelization for matrix multiplication. |
||
<syntaxhighlight lang=" |
<syntaxhighlight lang="numpy" line="1"> |
||
# import pmap and random from JAX; import JAX NumPy |
# import pmap and random from JAX; import JAX NumPy |
||
from jax import pmap, random |
from jax import pmap, random |
||
Line 138: | Line 135: | ||
The final line should print the valuesː |
The final line should print the valuesː |
||
<syntaxhighlight lang=" |
<syntaxhighlight lang="output"> |
||
[1.1566595 1.1805978] |
[1.1566595 1.1805978] |
||
</syntaxhighlight> |
</syntaxhighlight> |
||
== Libraries using Jax == |
|||
Several python libraries use Jax as a backend, including: |
|||
* Flax, a high level [[neural network]] library initially developed by [[Google Brain]].<ref>{{Citation |title=Flax: A neural network library and ecosystem for JAX designed for flexibility |date=2022-07-29 |url=https://github.com/google/flax |publisher=Google |access-date=2022-07-29}}</ref> |
|||
* Haiku, an [[Object-oriented programming|object-oriented]] library for [[Neural network|neural networks]] developed by [[DeepMind]].<ref>{{Citation |title=Haiku: Sonnet for JAX |date=2022-07-29 |url=https://github.com/deepmind/dm-haiku |publisher=DeepMind |access-date=2022-07-29}}</ref> |
|||
* Equinox, a library that revolves around the idea of representing parameterised functions (including [[Neural network|neural networks]]) as PyTrees. It was created by Patrick Kidger.<ref>{{Citation |last=Kidger |first=Patrick |title=Equinox |date=2022-07-29 |url=https://github.com/patrick-kidger/equinox |access-date=2022-07-29}}</ref> |
|||
* Optax, a library for gradient processing and [[Mathematical optimization|optimisation]] developed by [[DeepMind]].<ref>{{Citation |title=Optax |date=2022-07-28 |url=https://github.com/deepmind/optax |publisher=DeepMind |access-date=2022-07-29}}</ref> |
|||
* RLax, a library for developing [[reinforcement learning]] agents developed by [[DeepMind]].<ref>{{Citation |title=RLax |date=2022-07-29 |url=https://github.com/deepmind/rlax |publisher=DeepMind |access-date=2022-07-29}}</ref> |
|||
== See also == |
== See also == |
||
Line 157: | Line 145: | ||
* [[PyTorch]] |
* [[PyTorch]] |
||
* [[CUDA]] |
* [[CUDA]] |
||
* [[Automatic differentiation]] |
|||
* [[Just-in-time compilation]] |
|||
* [[Vectorization (disambiguation)|Vectorization]] |
|||
* [[Automatic parallelization]] |
|||
== External links == |
== External links == |
||
Line 167: | Line 151: | ||
}} |
}} |
||
* [[TensorFlow]]'s XLAː {{URL|https://www.tensorflow.org/xla}} (Accelerated Linear Algebra) |
* [[TensorFlow]]'s XLAː {{URL|https://www.tensorflow.org/xla}} (Accelerated Linear Algebra) |
||
* |
* [[YouTube]] TensorFlow Channel "Intro to JAX: Accelerating Machine Learning research": {{URL|https://www.youtube.com/watch?v=WdTeDXsOSj4}} |
||
* Original paperː {{URL|https://mlsys.org/Conferences/doc/2018/146.pdf}} |
* Original paperː {{URL|https://mlsys.org/Conferences/doc/2018/146.pdf}} |
||
Line 173: | Line 157: | ||
{{reflist}} |
{{reflist}} |
||
{{ |
{{differentiable computing}} |
||
{{Google LLC}} |
|||
[[Category:Machine learning]] |
[[Category:Machine learning]] |
Latest revision as of 21:09, 25 November 2024
Developer(s) | |
---|---|
Preview release | v0.4.31
/ 30 July 2024 |
Repository | github |
Written in | Python, C++ |
Operating system | Linux, macOS, Windows |
Platform | Python, NumPy |
Size | 9.0 MB |
Type | Machine learning |
License | Apache 2.0 |
Website | jax |
Google JAX is a machine learning framework for transforming numerical functions.[1][2][3] It is described as bringing together a modified version of autograd (automatic obtaining of the gradient function through differentiation of a function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[4][5] The primary functions of JAX are:[1]
- grad: automatic differentiation
- jit: compilation
- vmap: auto-vectorization
- pmap: Single program, multiple data (SPMD) programming
grad
[edit]The below code demonstrates the grad function's automatic differentiation.
# imports
from jax import grad
import jax.numpy as jnp
# define the logistic function
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)
# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
The final line should outputː
0.19661194
jit
[edit]The below code demonstrates the jit function's optimization through fusion.
# imports
from jax import jit
import jax.numpy as jnp
# define the cube function
def cube(x):
return x * x * x
# generate data
x = jnp.ones((10000, 10000))
# create the jit version of the cube function
jit_cube = jit(cube)
# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
The computation time for jit_cube
(line #17) should be noticeably shorter than that for cube
(line #16). Increasing the values on line #7, will further exacerbate the difference.
vmap
[edit]The below code demonstrates the vmap function's vectorization.
# imports
from jax import vmap partial
import jax.numpy as jnp
# define function
def grads(self, inputs):
in_grad_partial = jax.partial(self._net_grads, self._net_params)
grad_vmap = jax.vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
The GIF on the right of this section illustrates the notion of vectorized addition.
pmap
[edit]The below code demonstrates the pmap function's parallelization for matrix multiplication.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
The final line should print the valuesː
[1.1566595 1.1805978]
See also
[edit]External links
[edit]- Documentationː jax
.readthedocs .io - Colab (Jupyter/iPython) Quickstart Guideː colab
.research .google .com /github /google /jax /blob /main /docs /notebooks /quickstart .ipynb - TensorFlow's XLAː www
.tensorflow .org /xla (Accelerated Linear Algebra) - YouTube TensorFlow Channel "Intro to JAX: Accelerating Machine Learning research": www
.youtube .com /watch?v=WdTeDXsOSj4 - Original paperː mlsys
.org /Conferences /doc /2018 /146 .pdf
References
[edit]- ^ a b Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, archived from the original on 2022-06-18, retrieved 2022-06-18
- ^ Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1–3. Archived (PDF) from the original on 2022-06-21.
{{cite journal}}
: CS1 maint: date and year (link) - ^ "Using JAX to accelerate our research". www.deepmind.com. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
- ^ Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21.
- ^ "Why is Google's JAX so popular?". Analytics India Magazine. 2022-04-25. Archived from the original on 2022-06-18. Retrieved 2022-06-18.