Adithya V. Sastry

Robotics + Origami
Creator of Tin Can Linux
Research Science Institute 2025
Lover of all things tiny

"A dream is not that which
you see while sleeping, it
is that which does not let
you sleep."

-- Dr. APJ Abdul Kalam

Random notes on JAX                                                   06/13/2025
--------------------------------------------------------------------------------

So my mentor asked me to learn JAX before showing up to RSI.

What is JAX? Well, it's a goofy python library that allows you to do some
pretty interesting things. Seems pretty nifty on the surface, just looking at
the descriptions -- but the one time I saw JAX code, I didn't realize it was
python. Seemed like its own programming language or something.

Anyways, here's some stuff I found by reading through the tutorials. This also
functions as my notes so I don't forget everything I read.

Some basic stuff
----------------

A few intro things that seemed useful to know.

# Arrays

You can make arrays in JAX just like numpy arrays, torch tensors, etc. One
interesting thing though is that they can be split across multiple devices (so
part of the array could be on GPU #1, another part can be on GPU #2, etc). This
is called sharding. Hopefully I won't have to worry about this too much (but it
relates to parallelizing computations, so maybe).
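As a quick sanity check (this runs on a stock single-device install, no multi-GPU setup needed), you can ask an array which devices it lives on and how it's sharded:

```python
import jax.numpy as jnp

x = jnp.arange(6.0)   # a JAX array, much like a numpy array
print(x.devices())    # the set of devices holding this array
print(x.sharding)     # how the array is laid out across those devices
```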

# Tracing

JAX has this construct called a tracer that allows it to track the operations
that occur within a python function. It seems crazy that python code can
actually do this, but I suppose there are ways. Interesting to keep in mind.
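A minimal way to actually see a tracer is to put a print inside a jitted function (my own tiny sketch):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # during tracing, x is not a concrete array but a tracer object
    # standing in for "an array of this shape and dtype"
    print(x)
    return x * 2

f(jnp.arange(3))  # prints something like Traced<ShapedArray(int32[3])>
```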

# Jaxprs

So apparently JAX is kind of its own programming language, in a way. It can
convert python functions into a goofy looking language called jaxpr. I guess
this probably also relates to how tracing works (although still kind of beats me
how they manage to pull this off).
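You can look at the jaxpr for a function yourself with `jax.make_jaxpr()`; here's a small sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 1.0

# prints the function as a jaxpr: a lambda over typed inputs, with one
# primitive operation (sin, add, ...) per line
print(jax.make_jaxpr(f)(2.0))
```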

# RNG

JAX uses something called RNG keys. They're similar to seeds -- if you use a
certain random key with a random function, the output is always the same. To get
different outputs you split the key into new keys with `jax.random.split()`. The
docs recommend discarding a key once you've used or split it (never reuse it),
so that you keep generating fresh random values.
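A small sketch of the key workflow (using `jax.random.key`, the newer-style key constructor):

```python
import jax

key = jax.random.key(42)             # create a key from a seed
x = jax.random.normal(key)           # same key -> same sample
y = jax.random.normal(key)           # so x == y: reusing a key repeats values

key, subkey = jax.random.split(key)  # split into fresh keys and stop using
z = jax.random.normal(subkey)        # the old one; z is a new sample
```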

JIT compiling
-------------

Ok so this kind of explains where the jaxpr language comes into play. You can
use JIT compiling to take a python function and pre-compute a series of
transformations (which makes subsequent calls to the function execute much
faster). However there are some things to note:

- JIT compiled functions are compiled on first use and re-compiled afterwards
  any time the input shape (or dtype) changes. Compiling works by tracing the
  raw python code once and recording the operations it performs.

- The JIT compiler requires pure functions, which essentially means they have to
  produce the same output every time for a given input and cannot use or modify
  anything outside the function body. In a sense it's like a function from math.
  Things that would count as impure functions produce side effects, such as:

  - Uses or modifies a global variable. Whenever the function is compiled the
    current value of the global is baked in, so later changes to it are ignored.
    Assigning a traced value to a global will replace the global with a JAX
    tracer object.

  - Prints out a value. The print statement will execute while the function is
    being traced/compiled, but not on later calls to the compiled version.
    Useful to know for debugging.

  - Uses an iterator (these carry mutable internal state, apparently -- I also
    didn't know these really existed in python in the first place). This can
    cause an error, or it might just do something unexpected.

- If the function contains a python conditional that branches on the value of a
  traced input, it will fail to compile, since JAX can't decide the branch at
  trace time. If the condition only depends on static values, only the branch
  taken while compiling is saved -- all the other branches of the if/else are
  not compiled. JAX provides `jax.lax.cond()` as the JIT-friendly way to write
  conditionals.

- You can JIT compile part of a function: just isolate it into its own function,
  JIT it, and call it -- for example, if the whole function isn't pure:

--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp

@jax.jit
def inner(x: jax.Array, y: float):
    return x.at[0].add(y)

def conditional_fn(x):
    if jnp.all(jnp.greater(x, 0)):
        return inner(x, 7.0)
    else:
        return inner(x, -7.0)

print(conditional_fn(jnp.array([1.0, 2.0])))  # prints [8. 2.]
print(conditional_fn(jnp.array([1.0, -2.0]))) # prints [-6. -2.]
--------------------------------------------------------------------------------
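For comparison, the same conditional can live inside a single fully-jitted function by using `jax.lax.cond()` (my own sketch, reusing the same numbers):

```python
import jax
import jax.numpy as jnp

# jax.lax.cond evaluates a data-dependent condition inside compiled code,
# which a plain python if/else on a traced value can't do
@jax.jit
def conditional_fn(x):
    return jax.lax.cond(
        jnp.all(jnp.greater(x, 0)),
        lambda v: v.at[0].add(7.0),   # taken when every element is > 0
        lambda v: v.at[0].add(-7.0),  # taken otherwise
        x,
    )

print(conditional_fn(jnp.array([1.0, 2.0])))   # [8. 2.]
print(conditional_fn(jnp.array([1.0, -2.0])))  # [-6. -2.]
```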

There was some other stuff but I kind of skimmed over it for now.

Vectorization
-------------

JAX allows you to automatically vectorize a function that takes array inputs by
allowing parallel computation for batched inputs. This is another transform that
takes advantage of jaxprs and tracing. They provide the example of computing the
convolution of two vectors, but since I was confused by it, here's a simpler
example (multiplies the [0][0] element of each matrix):

--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp

def thing(x, y):
    return x.at[0, 0].multiply(y[0, 0])

# If we had two 2x2 matrices, this is easy:
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
print(thing(a, b))

# Output:
# Array([[5, 2],
#        [3, 4]], dtype=int32)

# But if we wanted to compute this for a batch:
c = jnp.tile(a, (3, 1, 1))
d = jnp.tile(b, (3, 1, 1))

# We can vectorize our function:
thing_vec = jax.vmap(thing)
print(thing_vec(c, d))

# Output:
# Array([[[5, 2],
#         [3, 4]],
# 
#        [[5, 2],
#         [3, 4]],
# 
#        [[5, 2],
#         [3, 4]]], dtype=int32)
--------------------------------------------------------------------------------

It uses 0 as the batch dimension by default, but you can specify a different one
for either the input or the output (or both) via the `in_axes` and `out_axes`
arguments.
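For instance, here's a toy sketch of my own using `in_axes` to batch over one argument while broadcasting the other:

```python
import jax
import jax.numpy as jnp

def scale(x, s):
    return x * s

xs = jnp.ones((3, 4))  # a batch of 3 vectors

# map over axis 0 of xs, but pass the scalar unchanged to every call
scaled = jax.vmap(scale, in_axes=(0, None))(xs, 2.0)
print(scaled.shape)  # (3, 4)
```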

You can also chain `jax.jit()` with `jax.vmap()`.
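A quick sketch of that composition, using the toy function from the vectorization example (redefined here so the snippet stands alone):

```python
import jax
import jax.numpy as jnp

def thing(x, y):
    return x.at[0, 0].multiply(y[0, 0])

# vectorize over the batch dimension, then JIT-compile the batched version
thing_fast = jax.jit(jax.vmap(thing))

c = jnp.tile(jnp.array([[1, 2], [3, 4]]), (3, 1, 1))
d = jnp.tile(jnp.array([[5, 6], [7, 8]]), (3, 1, 1))
print(thing_fast(c, d).shape)  # (3, 2, 2)
```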

Autodiff
--------

Ok, this is the cool stuff. You can use JAX to compute the gradient of a python
function.

For example, let's say we have a function `$f(x) = x^2 + 7^x$`. We then know that:
`$$
\begin{align}
f'(x) &= 2x + 7^x (\ln 7) \\
f''(x) &= 2 + 7^x (\ln 7)^2
\end{align}
$$`and so on. We can do this in python with JAX:

--------------------------------------------------------------------------------
import jax

def f(x):
    return x ** 2 + 7 ** x

fp = jax.grad(f)
fpp = jax.grad(fp)

# and so on.
--------------------------------------------------------------------------------

Of course the real utility is for computing gradients of more complex functions.
Pretty cool stuff!
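As a taste of that, here's a hypothetical least-squares loss (all the names and numbers here are made up for the sketch) whose gradient with respect to the weights comes straight out of `jax.grad`:

```python
import jax
import jax.numpy as jnp

# mean-squared-error loss over a tiny made-up dataset
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# jax.grad differentiates with respect to the first argument by default
g = jax.grad(loss)(w, x, y)
print(g)  # gradient is [-7., -10.] for this data
```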
