Google JAX

JAX
Developer(s): Google
Stable release: 0.4.24[1] / 6 February 2024
Repository: github.com/google/jax
Written in: Python, C++
Operating system: Linux, macOS, Windows
Platform: Python, NumPy
Size: 9.0 MB
Type: Machine learning
License: Apache 2.0
Website: jax.readthedocs.io/en/latest/

Google JAX is a machine learning framework for transforming numerical functions to be used in Python.[2][3][4] It is described as bringing together a modified version of autograd[5] (which automatically obtains the gradient function by differentiating a function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[6][7] The primary functions of JAX are:[2]

  1. grad: automatic differentiation
  2. jit: just-in-time compilation
  3. vmap: auto-vectorization
  4. pmap: SPMD (single program, multiple data) programming

grad

The code below demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

The final line should output:

0.19661194
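As a check (not part of the cited example), the same value can be computed from the closed-form derivative of the logistic function, σ(x)(1 − σ(x)), reusing the imports above:

# analytic derivative of the logistic function: sigma(x) * (1 - sigma(x))
def logistic_derivative(x):
    s = jnp.exp(x) / (jnp.exp(x) + 1)
    return s * (1 - s)

# agrees with grad_logistic(1.0), approximately 0.19661194
print(logistic_derivative(1.0))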

jit

The code below demonstrates the jit function's optimization through fusion.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

The computation time for jit_cube should be noticeably shorter than that for cube. Increasing the dimensions passed to jnp.ones when generating x will increase the difference.
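When measuring this, note that JAX dispatches work asynchronously and that jit compiles a function on its first call. A minimal timing sketch, assuming the definitions above (this comparison is illustrative and not from the original example):

import timeit

# warm-up call so that compilation time is not counted
jit_cube(x).block_until_ready()

# block on the results so that the timings reflect the actual computation
t_plain = timeit.timeit(lambda: cube(x).block_until_ready(), number=10)
t_jit = timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10)
print(t_plain, t_jit)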

vmap

The code below demonstrates the vmap function's vectorization.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
import numpy as np

# define function (a method of a class that is assumed to provide
# _net_params, _net_grads, and _flatten_batch)
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

The GIF on the right of this section illustrates the notion of vectorized addition.

(Figure: Illustration video of vectorized addition)
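Because the method above is an excerpt that depends on its enclosing class, a minimal self-contained sketch of vmap may be clearer (the function and data here are illustrative, not from the original example):

from jax import vmap
import jax.numpy as jnp

# a function written for a single pair of vectors
def dot(a, b):
    return jnp.dot(a, b)

# vmap maps the function over the leading axis of both arguments
batched_dot = vmap(dot)

a = jnp.arange(6.0).reshape(3, 2)
b = jnp.ones((3, 2))
print(batched_dot(a, b))  # one dot product per row, shape (3,)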

pmap

The code below demonstrates the pmap function's parallelization for matrix multiplication.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

The final line should print the values:

[1.1566595 1.1805978]
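Because pmap maps the computation across separate devices, running this example as written assumes that at least two devices are visible to JAX. The available device count can be checked with jax.device_count(); on a CPU-only machine, extra host devices can be emulated by setting an XLA flag before importing JAX, roughly as follows (a sketch, assuming a fresh Python process):

import os
# ask XLA to expose two host (CPU) devices; must be set before importing jax
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
print(jax.device_count())  # should report at least 2 for the pmap example above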

Libraries using JAX

Several Python libraries use JAX as a backend, including:

  • Flax, a neural network library developed by Google.[8][9]
  • Equinox, a neural network library built on JAX.[10]
  • Optax, a library for gradient processing and optimization developed by DeepMind.[11]
  • RLax, a library for reinforcement learning developed by DeepMind.[12]
  • Jraph, a library for graph neural networks developed by DeepMind.[13]
  • jaxtyping, a library for type annotations of array shapes and data types.[14][15]

Some R libraries use JAX as a backend as well, including:

  • fastrerandomize, a library that uses the linear-algebra optimized compiler in JAX to speed up selection of balanced randomizations in a design of experiments procedure known as rerandomization.[16]

See also

References

  1. ^ https://github.com/google/jax/releases/tag/jax-v0.4.24
  2. ^ a b Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, archived from the original on 2022-06-18, retrieved 2022-06-18
  3. ^ Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1–3. Archived (PDF) from the original on 2022-06-21.
  4. ^ "Using JAX to accelerate our research". www.deepmind.com. 4 December 2020. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  5. ^ HIPS/autograd, Formerly: Harvard Intelligent Probabilistic Systems Group -- Now at Princeton, 2024-03-27, retrieved 2024-03-28
  6. ^ Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21.
  7. ^ "Why is Google's JAX so popular?". Analytics India Magazine. 2022-04-25. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  8. ^ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, retrieved 2022-07-29
  9. ^ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, retrieved 2022-07-29
  10. ^ Kidger, Patrick (2022-07-29), Equinox, retrieved 2022-07-29
  11. ^ Optax, DeepMind, 2022-07-28, retrieved 2022-07-29
  12. ^ RLax, DeepMind, 2022-07-29, retrieved 2022-07-29
  13. ^ Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08, retrieved 2023-08-08
  14. ^ "typing — Support for type hints". Python documentation. Retrieved 2023-08-08.
  15. ^ jaxtyping, Google, 2023-08-08, retrieved 2023-08-08
  16. ^ Jerzak, Connor (2023-10-01), fastrerandomize, retrieved 2023-10-03