A fused Lion optimizer kernel with Hugging Face kernels
I wrote a fused Lion optimizer step with Hugging Face kernels, packaged it with the standard kernel-builder flow, and tested it on CUDA and Metal/MPS. On an NVIDIA L4, the fused CUDA path is 2.5x to 6.7x faster than eager PyTorch and 2.7x to 3.7x faster than torch._foreach_* on the workloads I measured. On an Apple M5 Pro, the Metal/MPS build passes the standard kernels benchmark check and runs 3.27x faster than its PyTorch reference path.
The most interesting part was not writing the arithmetic. The arithmetic is a small loop. The useful parts were learning where a kernel project should draw its boundaries, what the builder generates, what the author still has to own, and how easy it is for a "correct" floating-point test to be stricter than the math allows.
Why Lion is a good kernel-shaped problem
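Lion is compact. Per element, the update is:
update = sign(beta1 * exp_avg + (1 - beta1) * grad)
param = param - lr * update - lr * weight_decay * param
exp_avg = beta2 * exp_avg + (1 - beta2) * grad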
That is exactly the kind of optimizer step that makes sense as a first custom kernel. There are no reductions, no inter-element communication, no shared memory, and no complicated shape semantics. Each element of param, exp_avg, and grad can be handled independently. The reference implementation is also tiny, which matters more than it sounds: a kernel with a five-line PyTorch reference is much easier to test than one whose correctness is mostly statistical.
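To make the "tiny reference" point concrete, here is a minimal sketch of the per-element update as a plain Python loop over lists rather than PyTorch tensors. The name `lion_step_reference` and its list-based signature are illustrative assumptions, not the project's actual API.

```python
def sign(x: float) -> float:
    """Return -1.0, 0.0, or +1.0, so that sign(0.0) is 0.0."""
    return float((x > 0.0) - (x < 0.0))


def lion_step_reference(param, exp_avg, grad, lr, beta1, beta2, weight_decay):
    """Apply one Lion step in place to equal-length lists of floats.

    Hypothetical helper for illustration; it mirrors the math only.
    """
    for i in range(len(param)):
        # Read every old value before writing anything back.
        p, m, g = param[i], exp_avg[i], grad[i]
        update = beta1 * m + (1.0 - beta1) * g
        param[i] = p - lr * sign(update) - lr * weight_decay * p
        exp_avg[i] = beta2 * m + (1.0 - beta2) * g


param, exp_avg, grad = [1.0, -2.0], [0.0, 0.0], [0.5, -0.25]
lion_step_reference(param, exp_avg, grad,
                    lr=0.5, beta1=0.5, beta2=0.5, weight_decay=0.0)
print(param, exp_avg)
```

Each iteration touches only index `i`, which is why the loop maps directly onto independent GPU threads.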
It is still worth fusing because optimizer steps are usually memory-bound. Eager PyTorch expresses the update as several tensor operations, so the same arrays get read and written repeatedly and intermediate tensors appear along the way. A fused fp32 Lion step reads param, exp_avg, and grad once, then writes param and exp_avg once. That is 3 reads plus 2 writes, or 5 x 4 = 20 bytes per element. The eager formulation is much closer to four times that traffic.
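The traffic model is simple enough to write down as arithmetic. The sketch below counts array passes for the fused step and for one assumed eager decomposition. The eager op list is a hypothetical sequence chosen for illustration, not a trace of what PyTorch actually executes.

```python
BYTES_FP32 = 4

# Fused step: read param, exp_avg, grad; write param, exp_avg.
fused_reads, fused_writes = 3, 2
fused_bytes = (fused_reads + fused_writes) * BYTES_FP32

# Assumed eager decomposition (hypothetical):
# (description, arrays read, arrays written)
eager_ops = [
    ("update = exp_avg * beta1", 1, 1),
    ("update += (1 - beta1) * grad", 2, 1),
    ("update = sign(update)", 1, 1),
    ("param *= 1 - lr * weight_decay", 1, 1),
    ("param -= lr * update", 2, 1),
    ("exp_avg *= beta2", 1, 1),
    ("exp_avg += (1 - beta2) * grad", 2, 1),
]
eager_bytes = sum(reads + writes for _, reads, writes in eager_ops) * BYTES_FP32

print(fused_bytes, eager_bytes, eager_bytes / fused_bytes)
```

Under this assumed sequence the eager path moves 68 bytes per element against 20 for the fused step, a ratio of 3.4. The exact figure depends on how the eager update is written and on whether intermediates are allocated out of place.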
That simple traffic model became the north star for the benchmark. If the fused kernel was actually doing the right thing, I expected the large-tensor CUDA speedup to land near 4x. It did: 4.2x on the L4. That is a good feeling, because the performance story is not magic. It is just fewer trips to memory.
For v1 I kept the kernel float32-only. That was not a technical limitation so much as a correctness boundary. fp16 and bf16 support should come with an explicit policy, probably fp32 compute with lower-precision storage, plus tests for that policy. Silently adding half precision because a template made it easy would make the API look more complete than the implementation really is.
What the kernels project looks like
The project lives outside the kernels repository. The wrapper project pins the dependency:
dependencies = [
"kernels[torch]>=0.16,<0.17",
]

That pin is worth keeping. kernels is still pre-1.0, and the docs explicitly warn that minor releases can break things. If the goal is to ship a kernel as a library artifact rather than keep a local experiment alive, an unbounded dependency is asking for surprise breakage.
The actual kernel project starts from build.toml. That file is the source of truth for the package name, version, Hub target, Torch sources, and backend-specific native sources. In this project it declares CPU, CUDA, Metal, ROCm, and XPU backends. The important mental model is that build.toml describes the kernel, while kernel-builder creates the local Python build project around it.
The standard development loop is:
cargo run --manifest-path kernels/kernel-builder/Cargo.toml -- \
create-pyproject -f --unique-id local lion-kernel/kernel
cd lion-kernel/kernel
CMAKE_ARGS=-DGPU_LANG=CUDA python setup.py build_kernel
LION_DEVICE=cuda python -m pytest tests -v

For Metal the build flag is -DGPU_LANG=METAL, and the tests run with LION_DEVICE=mps. The same source tree can produce different local build variants under build/.
One practical lesson: generated files are not source. setup.py, CMakeLists.txt, cmake/, metadata-*.json, _ops.py, torch-ext/registration.h, _cmake_build/, and compiled .so files are products of the builder or the build. I initially treated torch-ext/registration.h like something I had to own because torch_binding.cpp includes it. That was backwards. The include is real, but the file is generated. Once that clicked, the project became much easier to reason about.
The other practical rule is that a built extension belongs to the PyTorch it was compiled against. If the venv changes Torch versions, rebuild.
The op boundary
The Python API is intentionally small: lion_step mutates one (param, exp_avg, grad) triplet, lion_foreach applies that step over lists, and LionOptimizerStep wraps the operation as a small torch.nn.Module.
lion_foreach is deliberately boring. It is a Python loop over tensors, not a single multi-tensor CUDA launch. That sounds underpowered until you benchmark it. On realistic tensor lists, the simple per-tensor fused launch was already fast enough that a more complex pointer-array kernel did not earn its keep. Complexity should pay rent, especially in native code.
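A minimal sketch of that "deliberately boring" shape, assuming a per-tensor step function is passed in. The `step_fn` parameter and the keyword-argument forwarding are assumptions made to keep the example self-contained; the real lion_foreach signature may differ.

```python
def lion_foreach(params, exp_avgs, grads, step_fn, **hyper):
    """Apply step_fn to each (param, exp_avg, grad) triplet in order.

    Hypothetical sketch: one call (and so one kernel launch) per tensor.
    """
    if not (len(params) == len(exp_avgs) == len(grads)):
        raise ValueError("params, exp_avgs, and grads must have equal length")
    for p, m, g in zip(params, exp_avgs, grads):
        step_fn(p, m, g, **hyper)


# Stand-in step that records its arguments instead of launching a kernel.
calls = []


def fake_step(p, m, g, lr):
    calls.append((p, m, g, lr))


lion_foreach(["p0", "p1"], ["m0", "m1"], ["g0", "g1"], fake_step, lr=1e-4)
print(len(calls))
```

The length check is the only logic beyond the loop; everything else is delegated to the single-tensor step.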
The Torch schema is where the real contract starts:
lion_step(Tensor! p, Tensor! exp_avg, Tensor grad,
float lr, float beta1, float beta2, float weight_decay, float eps) -> ()

The exclamation marks on Tensor! matter. They tell the dispatcher that p and exp_avg are mutated in place. That aliasing information is not decorative; it is part of how PyTorch understands the op, especially when tracing, compiling, or scheduling around mutations.
The C++ binding registers the same logical op to different device dispatch keys. CPU uses torch::kCPU, CUDA and ROCm use torch::kCUDA, Metal uses torch::kMPS, and XPU uses torch::kXPU. The builder defines compile-time macros such as CUDA_KERNEL or METAL_KERNEL, so the same registration source can compile into the right variant.
Before the native loop runs, the binding validates the boring things: dtype is fp32, all tensors are contiguous, shapes match, devices match, and empty tensors are no-ops. This is the kind of code that feels like paperwork until you remove it. A bad optimizer kernel often will not crash. It will just train a model slightly wrong.
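The checks themselves live in the C++ binding, but their logic can be sketched in Python. The `TensorInfo` stand-in and the `validate_lion_inputs` name below are hypothetical, introduced only so the example runs without PyTorch; the error types are also an assumption.

```python
import math
from dataclasses import dataclass


@dataclass
class TensorInfo:
    """Stand-in for the tensor metadata the checks look at (hypothetical)."""
    dtype: str
    shape: tuple
    device: str
    contiguous: bool = True


def validate_lion_inputs(p, exp_avg, grad):
    """Raise on invalid inputs; return False for an empty no-op, else True."""
    for name, t in (("p", p), ("exp_avg", exp_avg), ("grad", grad)):
        if t.dtype != "float32":
            raise TypeError(f"{name} must be float32, got {t.dtype}")
        if not t.contiguous:
            raise ValueError(f"{name} must be contiguous")
        if t.shape != p.shape:
            raise ValueError(f"{name} shape {t.shape} != p shape {p.shape}")
        if t.device != p.device:
            raise ValueError(f"{name} is on {t.device}, p is on {p.device}")
    # An empty tensor is valid input, but there is nothing to launch.
    return math.prod(p.shape) > 0


ok = TensorInfo("float32", (4, 8), "cuda:0")
empty = TensorInfo("float32", (0,), "cuda:0")
print(validate_lion_inputs(ok, ok, ok), validate_lion_inputs(empty, empty, empty))
```

Returning a flag for the empty case keeps "nothing to do" distinct from "invalid", so the caller can skip the launch without treating it as an error.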
The API accepts eps and ignores it. Lion does not have an Adam-style epsilon, but optimizer integrations often pass an epsilon argument unconditionally. Accepting it keeps the interface easy to slot into existing code without changing the math.
The native loop
The CUDA kernel is a grid-stride loop. Each thread handles one or more flattened elements:
float par = p[i];
float mom = exp_avg[i];
float g = grad[i];
float update = beta1 * mom + (1.0f - beta1) * g;
float u = (update > 0.0f) - (update < 0.0f);
float decay = (weight_decay > 0.0f) ? (lr * weight_decay * par) : 0.0f;
p[i] = par - lr * u - decay;
exp_avg[i] = beta2 * mom + (1.0f - beta2) * g;

There are a few small details here that are easy to miss.
First, mom is loaded before either output is written. Both the direction computation and the final momentum update use the original momentum value. That preserves the Lion ordering: compute the signed update direction from old momentum and current gradient, update the parameter, then write the new momentum for the next step. If you rewrite this as scalar in-place code and update exp_avg[i] too early, you have changed the algorithm.
Second, the sign expression is branchless on CUDA:
(update > 0.0f) - (update < 0.0f)

That produces -1, 0, or +1, and it returns exactly 0 when update == 0.0f, matching torch.sign(0) == 0. Metal uses the same semantics with a ternary expression. The exact spelling is not important; matching the edge behavior is.
Third, launch semantics matter. The CUDA launcher uses PyTorch's current CUDA stream, installs an optional device guard, and calls C10_CUDA_KERNEL_LAUNCH_CHECK(). Launching on the wrong stream can pass a tiny unit test and still be wrong inside real training code. For MPS, the wrapper uses PyTorch's MPS command buffer and dispatch queue, then embeds the compiled Metal library into the extension.
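The first two details can be checked in plain Python. This is a sketch of the math only, with weight decay omitted for brevity; the function names are illustrative and nothing here touches the GPU.

```python
def sign(x):
    # Same branchless form as the kernel: bool minus bool gives -1, 0, or +1.
    return (x > 0.0) - (x < 0.0)


def step_correct(p, m, g, lr, beta1, beta2):
    update = beta1 * m + (1.0 - beta1) * g  # uses the OLD momentum
    new_p = p - lr * sign(update)
    new_m = beta2 * m + (1.0 - beta2) * g   # also uses the OLD momentum
    return new_p, new_m


def step_wrong_order(p, m, g, lr, beta1, beta2):
    m = beta2 * m + (1.0 - beta2) * g       # momentum overwritten too early
    update = beta1 * m + (1.0 - beta1) * g  # now reads the NEW momentum
    return p - lr * sign(update), m


args = dict(p=0.0, m=1.0, g=-1.0, lr=0.5, beta1=0.75, beta2=0.25)
print(step_correct(**args))      # (-0.5, -0.5)
print(step_wrong_order(**args))  # (0.5, -0.5)
print(sign(0.0), sign(-0.0), sign(3.5), sign(-2.0))  # 0 0 1 -1
```

With these inputs the two orderings agree on the new momentum but move the parameter in opposite directions, which is exactly the kind of bug that does not crash.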
The CPU backend is just a plain loop with the same validation and math. It is not there to win a benchmark. It is there because a CPU backend makes packaging, import behavior, API shape, and most correctness tests cheap to iterate on before touching GPU hardware.
The test that was too strict
The ordinary correctness tests are what you would expect: one-step equivalence against a PyTorch reference, randomized hyperparameters, shape sweeps, zero gradients, sign(0), empty tensors, in-place mutation, and rejection paths for unsupported dtype, shape, and layout.
The interesting failure came from the repeated-step test. After many steps, one element out of 4096 differed from the reference by exactly 2 * lr.
That value tells you the shape of the problem. A difference of 2 * lr means one implementation stepped +lr while the other stepped -lr. The only way that happens in Lion is a disagreement in the sign of:
beta1 * exp_avg + (1 - beta1) * grad
CUDA may contract that expression into a fused multiply-add. Eager PyTorch usually materializes the products as separately rounded fp32 tensors before adding them. The two evaluation orders can differ by around one ulp. Almost everywhere that is harmless. Near zero it is everything, because sign() is discontinuous. If the true value is within one ulp of zero, two reasonable fp32 computations can land on opposite sides.
The mistake would be to hide that with a big global tolerance. That would also hide real bugs, such as a wrong weight decay term accumulating over time. The better test is stricter in the right places: the momentum path has no sign(), so it should stay close to the reference; the parameter path may have a tiny number of isolated 2 * lr-shaped disagreements, but not larger or systematic drift.
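One way to encode that invariant is sketched below on plain lists of floats. The function name, the `atol` value, and the `max_flips` budget are hypothetical choices for illustration, not the thresholds used in the project's test suite. The sketch also assumes a flipped element differs by a single 2 * lr step, so an element that flips more than once would need a looser rule.

```python
def check_repeated_steps(param, param_ref, exp_avg, exp_avg_ref, lr,
                         atol=1e-6, max_flips=4):
    """Return the number of isolated sign flips; raise AssertionError on drift."""
    # Momentum has no sign() in its path, so it gets the tight tolerance.
    for m, m_ref in zip(exp_avg, exp_avg_ref):
        assert abs(m - m_ref) <= atol, "momentum drifted from reference"

    flips = 0
    for p, p_ref in zip(param, param_ref):
        diff = abs(p - p_ref)
        if diff <= atol:
            continue
        # Any remaining mismatch must look like one +lr versus -lr step.
        assert abs(diff - 2.0 * lr) <= atol, f"unexpected difference {diff}"
        flips += 1

    assert flips <= max_flips, f"too many sign flips: {flips}"
    return flips


lr = 1e-4
param_ref = [0.5, 0.25, -1.0]
param = [0.5, 0.25 + 2 * lr, -1.0]  # one isolated 2 * lr disagreement
flips = check_repeated_steps(param, param_ref, [0.1, 0.2, 0.3], [0.1, 0.2, 0.3], lr)
print(flips)
```

A mismatch of any other size, or too many flips, fails the check, so a wrong weight decay term would still be caught.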
The same idea showed up again on MPS. The exact number of isolated sign flips changed slightly, but the magnitude still matched the pattern. That is a good example of a test evolving from "bitwise-ish agreement forever" into "assert the invariant the math actually promises."
Benchmarks
For CUDA, I used a development benchmark that times GPU work with CUDA events and compares three implementations: the fused kernel, eager PyTorch, and torch._foreach_*. CUDA events matter because CUDA launches are asynchronous; wall-clock timing can otherwise measure how long Python took to enqueue work rather than how long the GPU spent executing.
The benchmark starts with a correctness check against the pure PyTorch reference. Benchmarking wrong code is worse than not benchmarking.
On an NVIDIA L4 with torch 2.10.0+cu128, the results were:
| Workload | Fused | Eager | _foreach | vs eager | vs _foreach |
|---|---|---|---|---|---|
| 1 x 67.1M params | 5.46 ms / 246 GB/s | 23.08 ms | 20.43 ms | 4.2x | 3.7x |
| 512 x 64K params | 3.68 ms / 182 GB/s | 24.52 ms | 10.08 ms | 6.7x | 2.7x |
| GPT-2 124M param list | 10.79 ms / 231 GB/s | 27.12 ms | 37.96 ms | 2.5x | 3.5x |
The large-tensor result is the cleanest one. The L4's peak memory bandwidth is around 300 GB/s, and the fused kernel reaches 246 GB/s under the 20-bytes-per-element model. That says the kernel is close enough to the memory wall that clever arithmetic changes are unlikely to matter.
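The bandwidth figures in the table follow from the 20-bytes-per-element model. The sketch below reproduces two of them, assuming that "67.1M" means 2**26 elements and "64K" means 2**16; those exact counts are my reading of the rounded labels, not values stated above.

```python
def effective_bandwidth_gbps(num_elements, seconds, bytes_per_element=20):
    """Bytes moved under the fused traffic model, in GB/s (1 GB = 1e9 bytes)."""
    return num_elements * bytes_per_element / seconds / 1e9


# Assumed element counts: 2**26 for "67.1M", 512 * 2**16 for "512 x 64K".
large = effective_bandwidth_gbps(2**26, 5.46e-3)
small = effective_bandwidth_gbps(512 * 2**16, 3.68e-3)

print(round(large), round(small))  # 246 182
print(round(large / 300 * 100))    # percent of a ~300 GB/s peak: 82
```

At roughly 82% of the quoted peak, the large-tensor case leaves little headroom for anything other than moving fewer bytes.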
The many-small-tensors case is useful for a different reason. It is the workload that should punish a one-launch-per-tensor API. Even there, the simple implementation is fast. The realistic GPT-2-shaped list sustains 231 GB/s, which is enough evidence that a single-launch multi-tensor v2 is not worth building yet.
I also ran the standard kernels benchmark path on Metal/MPS:
| Hardware | Workload | Fused | Ref | Speedup | Match |
|---|---|---|---|---|---|
| Apple M5 Pro (20 GPU cores), PyTorch 2.10.0 | LionBenchmark.step | 1.5522 ms | 5.0754 ms | 3.27x | yes |
That MPS result is not the same benchmark as the CUDA dev script. It is the standard kernels benchmark flow, using synchronized wall-clock timing and the benchmark class's reference check. I treat it as a strong smoke test for the Metal backend and a promising performance signal, not as a full MPS performance study.
One funny result from the CUDA table is that _foreach loses to eager PyTorch on the GPT-2-shaped list. My best guess is that multi_tensor_apply chunking gets awkward when one large embedding tensor dominates a list of smaller tensors. Either way, it is a reminder that "the optimized builtin path" can have its own cliffs.
The broader benchmark lesson is simple: quote the hardware, quote the timing method, and do not overgeneralize. I did not have the chance to test ROCm or XPU. CUDA and MPS both look promising, but every backend deserves its own run before anyone claims portability as a finished fact.
What kernel-builder bought me
The best part of kernels is that it draws a useful line between kernel authoring and package machinery. I still had to write the op schema, Python API, native backend implementations, validation, tests, examples, and benchmarks. The builder handled the repeated work around generated registration glue, CMake wiring, metadata, local build variants, and loading through the kernels runtime.
That division feels right. The author should own the math and the contract. The tooling should own the tedious shape of "turn these sources into a loadable kernel artifact for this backend and this Torch version."
The rough edges I hit were mostly first-run ergonomics. A kernel-builder doctor lion-kernel/kernel --backend cuda command would be useful: check Python, active Torch, CUDA or Metal toolchain, CMake, Ninja, stale generated files, and whether the compiled artifact matches the active Torch. A guided dev-build command could wrap create-pyproject, clean stale CMake state, build the kernel, and print the exact Python environment and import path being used.
The generated-file boundary could also be louder in the docs. New authors should not have to discover by accident that registration.h, _ops.py, setup.py, metadata files, and build directories are local products rather than source files.
Conclusions
The main lesson is not "write custom kernels for everything." Most code should stay in PyTorch until there is a concrete reason to leave it. Lion has that reason: the update is simple, elementwise, and memory-bound, so fusion directly reduces memory traffic and launch count. That is the happy path for a custom optimizer kernel.
The second lesson is that the native loop is only one part of the job. A useful kernel needs a clear Python API, correct dispatcher aliasing, defensive validation, backend-specific launch behavior, realistic tests, and benchmarks that measure the right thing. The loop may be 20 lines, but the contract around it is what makes it shippable.
The third lesson is to let the benchmark veto complexity. A multi-tensor single-launch Lion kernel sounds appealing, but the current lion_foreach path is already fast on the workloads I tested. Until a real workload shows launch overhead dominating, the simpler design is the better design.
Finally, kernels made the project feel like something that could actually be distributed rather than a one-off extension living in a notebook or an examples folder. CUDA and MPS both produced encouraging results. I still need hardware time for ROCm and XPU before making claims about those backends, but the core pattern feels solid: describe the kernel once, keep generated machinery generated, validate the op boundary aggressively, and only optimize the parts the measurements point at.