There are enough Python libraries out there that you’ll never understand or use them all. The more pertinent task is choosing the right one for your specific project. At Shakudo it’s pretty common for our team to begin using NumPy or another library, only to figure out halfway through that it’s not effective for our use case.
Shakudo provides data teams with a platform and managed cloud, and lately we've taken a liking to JAX for machine learning and data processing. We'll explain why in this quick intro to JAX.
What is JAX and what does it do?
Google's JAX is a high-performance Python package built to accelerate machine learning research. Like NumPy, JAX provides a lightweight API for array-based computing. On top of that, it adds a set of composable function transformations for automatic differentiation, just-in-time (JIT) compilation, and automatic vectorization and parallelization of your code. We'll talk about those more later on.
JAX runs on CPU, GPU, or TPU with few or no changes to your code, which makes it easy to speed up big projects in a short amount of time. We've seen it used for some really cool projects, including protein folding research, robotics control, and physics simulations.
Automatic differentiation is a procedure for computing derivatives that avoids the pitfalls of both numerical differentiation (expensive and numerically unstable) and symbolic differentiation (an exponential increase in the number of expressions). It breaks a function (program) down into a sequence of primitive operations whose derivatives are easy to compute, then combines those derivatives with the chain rule. Backpropagation, the algorithm used to train neural networks, is the reverse-mode form of automatic differentiation.
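To make the contrast with numerical differentiation concrete, here is a small sketch comparing a forward finite-difference estimate with `jax.grad` on f(x) = x³, whose exact derivative is 3x². The helper `forward_difference` and the step sizes are illustrative choices, not something from JAX itself.

```python
import jax

def f(x):
    return x ** 3  # exact derivative: 3 * x**2

def forward_difference(fn, x, h):
    # Numerical differentiation (hypothetical helper): approximate the slope with a step h
    return (fn(x + h) - fn(x)) / h

x = 2.0
exact = 3 * x ** 2

# Automatic differentiation: JAX differentiates the primitive operations inside f
ad_value = float(jax.grad(f)(x))

for h in (1e-1, 1e-6, 1e-13):
    # Large h -> truncation error; tiny h -> floating-point cancellation error
    err = abs(forward_difference(f, x, h) - exact)
    print(f"h={h:g}  finite-difference error={err:.2e}")

print("jax.grad error:", abs(ad_value - exact))
```

No single step size makes the finite-difference estimate exact, while `jax.grad` recovers the derivative from the program's structure.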
Why use JAX?
Because JAX's syntax is so similar to NumPy's, a few code changes are enough to use it in projects where NumPy just isn't cutting it performance-wise, or where you need extra features that JAX supports. Data- and compute-heavy fields such as machine learning and blockchain benefit from JAX's improved performance. Maybe you're researching JAX because you've hit a wall scaling your data project; a lot of Shakudo users had hit one before they tried our platform.
Beyond speed, JAX is a great all-around tool for prototyping: if you already work with NumPy, its syntax will feel familiar. It also offers features that many other ML libraries don't expose as directly, such as composable transformations (`grad`, `jit`, `vmap`) that you can stack on ordinary Python functions.
In some tests, JAX has run basic functions 8600% (roughly 87x) faster, which is highly valuable for data-heavy, application-facing models, or just for getting more machine learning experiments done in a day. Most real-world applications won't see this kind of jump, but it does show the potential value of switching.
NumPy vs. JAX kernel density estimation (KDE) results from our R&D team, showing a 1500x speedup.
JAX gets its crazy-high speeds and its flexibility from the following features:
Vectorization: processing multiple data elements with a single instruction. It works when the same simple operation is applied to an entire dataset. Since most matrix operations apply the same operation across the rows and columns of a matrix, they are very amenable to vectorization, which gives great speedups for linear algebra computations and machine learning.
JAX's `jax.vmap` automatically generates a vectorized implementation of a function written for a single example.
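As a minimal sketch (the function `dot` and the arrays are illustrative, not from the original), `jax.vmap` turns a function written for one pair of vectors into one that works on a whole batch, and its `in_axes` argument controls which inputs are batched.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for single vectors of shape (n,)
    return jnp.dot(a, b)

# Map `dot` over the leading axis of both arguments
batched_dot = jax.vmap(dot)

xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.ones((4, 3))
result = batched_dot(xs, ys)  # shape (4,), one dot product per row

# Equivalent explicit Python loop, for comparison
looped = jnp.stack([dot(x, y) for x, y in zip(xs, ys)])

# in_axes=(0, None): batch over rows of xs, reuse the same w for every row
w = jnp.array([1.0, 0.0, -1.0])
per_row = jax.vmap(dot, in_axes=(0, None))(xs, w)
print(result, per_row)
```

The batched version gives the same numbers as the loop, but JAX builds it as one vectorized computation instead of four separate calls.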
Code parallelization: taking serial code that runs on a single processor and spreading the work across multiple processors. The problem is broken into smaller pieces that are processed simultaneously, which is much more efficient than waiting for each piece to finish before starting the next one.
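One of JAX's parallelization APIs is `jax.pmap`, which runs the same function on each device with one shard of the data per device. The sketch below assumes nothing about your hardware: on a plain CPU machine `jax.local_device_count()` is typically 1, so it runs as a single shard. The `normalize` function is a hypothetical example, and the JAX documentation also describes sharding-based alternatives for multi-device work.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def normalize(chunk):
    # Runs independently on each device's shard of the data
    return (chunk - chunk.mean()) / (chunk.std() + 1e-8)

data = jnp.arange(n_devices * 8, dtype=jnp.float32)
shards = data.reshape(n_devices, 8)  # leading axis: one shard per device

out = jax.pmap(normalize)(shards)
print(n_devices, out.shape)
```

The leading axis of the input must match the number of devices used; each device then works on its row at the same time.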
Automatic differentiation: a set of techniques for evaluating the derivative of a function by breaking it into a sequence of elementary arithmetic operations. Differentiation in JAX is pretty straightforward: `jax.grad` takes a function and returns a new function that computes its gradient.
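Here is a short sketch of `jax.grad` on a scalar function and on a mean-squared-error loss. The names `func`, `d_func`, `loss`, and the data values are illustrative choices, not from the original.

```python
import jax
import jax.numpy as jnp

def func(x):
    return x ** 2 + 3.0 * x  # derivative: 2x + 3

d_func = jax.grad(func)  # a new function computing d(func)/dx
print(d_func(2.0))       # 2*2 + 3

def loss(w, X, y):
    # Mean squared error of a linear model
    return jnp.mean((X @ w - y) ** 2)

X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.array([0.5, -0.5])

# By default, grad differentiates with respect to the first argument (w)
grad_w = jax.grad(loss)(w, X, y)
print(grad_w)
```

For the loss, the result matches the hand-derived gradient (2/N)·Xᵀ(Xw − y), with no manual calculus in the code.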
You can also apply `grad` repeatedly to get higher-order derivatives. For example, if `d_func = grad(func)`, then `grad(d_func)` is the second derivative of `func`.
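A minimal sketch of higher-order derivatives, using f(x) = x³ as an illustrative `func`. For functions of a vector, `jax.hessian` returns the full matrix of second derivatives.

```python
import jax
import jax.numpy as jnp

def func(x):
    return x ** 3  # f'(x) = 3x^2, f''(x) = 6x, f'''(x) = 6

d_func = jax.grad(func)
dd_func = jax.grad(d_func)    # second derivative
ddd_func = jax.grad(dd_func)  # third derivative

x = 2.0
print(d_func(x), dd_func(x), ddd_func(x))

def g(v):
    return jnp.sum(v ** 2)

# Hessian of sum(v**2) is 2 * identity
H = jax.hessian(g)(jnp.array([1.0, 2.0]))
print(H)
```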
How JAX is built
JAX is built on Accelerated Linear Algebra (XLA) and just-in-time (JIT) compilation. XLA is a domain-specific compiler for linear algebra that fuses operations together, so intermediate results don't have to be written out to memory, which improves overall speed. JAX uses XLA to compile and run NumPy-style programs on GPUs and TPUs without changes to your code: it traces your Python function into an intermediate representation, which is then just-in-time compiled.
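You can inspect both stages yourself. In this sketch (the function `f` is illustrative), `jax.make_jaxpr` prints the traced intermediate representation, called a jaxpr, and `jax.jit(f).lower(x).as_text()` prints the program handed to XLA; the exact text format varies between JAX releases.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.arange(3.0)

# The jaxpr: a list of primitive operations recorded by tracing f
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# The lowered program that XLA compiles (format depends on the JAX version)
print(jax.jit(f).lower(x).as_text())
```

The jaxpr shows `sin`, `mul`, and `add` as separate primitives; XLA is then free to fuse them into a single kernel.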
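With JIT compilation, the first time a function is called it is traced and compiled to machine code for the given input shapes and types, so subsequent calls with the same shapes and types run faster. Using it takes a single call to `jax.jit`, which also works as a decorator.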
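A minimal sketch of both styles, using a SELU activation as the example function (its constants are illustrative). JAX dispatches work asynchronously, so `block_until_ready()` is used to wait for results, which matters if you time the calls yourself.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)  # wrap an existing function

x = jnp.linspace(-3.0, 3.0, 1_000_000)

# First call: trace and compile for this shape/dtype
y = selu_jit(x).block_until_ready()
# Later calls with the same shape/dtype reuse the compiled code
y = selu_jit(x).block_until_ready()

@jax.jit  # or apply jit as a decorator
def scaled_sum(v):
    return jnp.sum(v * 2.0)

print(scaled_sum(jnp.arange(4.0)))
```

Calling `selu_jit` with a new shape or dtype triggers a fresh compile, so JIT pays off most when a function is called many times with consistent inputs.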
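Although it's a powerful tool, `jax.jit` doesn't work for every function. For example, Python control flow that depends on the values of traced arrays can't be compiled as-is. The JAX documentation on JIT compilation explains in more detail what it can and can't compile.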
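The sketch below shows one such failure and two common workarounds: replacing the Python `if` with `jnp.where`, and marking an argument static with `static_argnames` so it is a plain Python value at trace time. The functions `relu_python`, `relu_where`, and `power` are hypothetical examples.

```python
import functools
import jax
import jax.numpy as jnp

def relu_python(x):
    # Python `if` needs a concrete bool; under jit, x is an abstract tracer
    if x > 0:
        return x
    return 0.0 * x

def relu_where(x):
    # Traceable alternative: jnp.where selects elementwise, no Python branching
    return jnp.where(x > 0, x, 0.0)

@functools.partial(jax.jit, static_argnames="n")
def power(x, n):
    # n is static: a plain Python int, so the loop is unrolled at trace time
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

try:
    jax.jit(relu_python)(jnp.array(2.0))
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

print("relu_python under jit failed:", failed)
print(jax.jit(relu_where)(jnp.array([-1.0, 2.0])))
print(power(jnp.array(3.0), 2))
```

Static arguments are convenient, but each distinct value of `n` triggers a new compilation.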
How to Install and Import JAX
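To install the CPU-only version of JAX, use the following:
pip install --upgrade pip
pip install --upgrade jax
For NVIDIA GPU support, install a CUDA-enabled build instead (on current releases, for example, pip install --upgrade "jax[cuda12]"); the JAX installation guide lists the variant that matches your setup.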
Finally, import the NumPy interface and the most important JAX functions, for example with `import jax.numpy as jnp` and `from jax import grad, jit, vmap`.
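A typical set of imports looks like the sketch below; the exact list is a common convention, not the only option. Printing `jax.devices()` is a quick way to confirm which backend (CPU, GPU, or TPU) JAX picked up.

```python
import jax
import jax.numpy as jnp          # NumPy-compatible API
from jax import grad, jit, vmap  # core function transformations
from jax import random           # explicit, key-based random numbers

print(jax.__version__)
print(jax.devices())  # devices JAX can see
```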
If you've already begun your project in NumPy, you can usually port it by replacing `import numpy as np` with `import jax.numpy as jnp` and running the same operations through `jnp`.
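A small side-by-side sketch (the expression is illustrative): the same code written against `numpy` and `jax.numpy` gives matching results. One difference to expect is precision, since JAX defaults to 32-bit floats while NumPy defaults to 64-bit.

```python
import numpy as np
import jax.numpy as jnp

# The same expression, written once against NumPy and once against jax.numpy
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

y_np = np.sum(np.sin(x_np) ** 2) / x_np.size
y_jnp = jnp.sum(jnp.sin(x_jnp) ** 2) / x_jnp.size

print(y_np, float(y_jnp), x_np.dtype, x_jnp.dtype)

# JAX arrays convert back to NumPy arrays when needed
back = np.asarray(x_jnp)
```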
Note that there are two constraints for your NumPy code to work with JAX:
- Use pure functions. Calling a function twice with the same inputs has to return the same result, and you can't update arrays in place: JAX arrays are immutable, so instead of `x[0] = 1.0` you write `x = x.at[0].set(1.0)`.
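- Random number generation is explicit: instead of relying on a global seed, you create a PRNG key (for example with `jax.random.PRNGKey(0)`) and split it with `jax.random.split` to get fresh keys.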
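Both constraints in a short sketch (the array sizes and seed are arbitrary): in-place assignment raises a `TypeError`, `.at[...].set(...)` returns an updated copy, and random numbers depend only on the key you pass in.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)

# In-place assignment fails: JAX arrays are immutable
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# Functional update: returns a new array and leaves x unchanged
y = x.at[0].set(1.0)

# Explicit PRNG: the key is state you pass around and split
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (2,))
b = jax.random.normal(subkey, (2,))  # same subkey -> same numbers

key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (2,))  # fresh subkey -> new numbers
print(x, y, a, b, c)
```

Reusing a key always reproduces the same numbers, so split off a new subkey whenever you need fresh randomness.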
And there you go! You're off on your first JAX project. If you want to try it out using Shakudo, we have a free sandbox (no credit card required) for developing, deploying, and troubleshooting data-heavy projects, all in an easily configurable workspace that pre-integrates the open source tools and data frameworks you want. Have fun!