About Pageshift
Pageshift is a research lab committed to pushing the frontier of AI storytelling and creativity. We envision a world in which most entertainment is personalized and AI-generated. Our goal is to build the underlying story engine that powers it all. To get there, we are not afraid to explore new approaches and create novel categories of model capability.

About the role
You will optimize JAX workloads on TPUs, with a strong focus on XLA behavior and Pallas kernel development. You will profile, benchmark, implement kernel-level optimizations, and validate improvements with data. The work supports long-context training workloads, so you are expected to care about memory behavior, sharding, and end-to-end efficiency.

What we're looking for:
- Passion for entertainment and storytelling
- Basic understanding of TPUs, JAX, and XLA
- Willingness to optimize against a black box

Nice to have:
- An example project to show off
- Experience writing Pallas kernels

Your responsibilities:
- Implementing and optimizing JAX code and Pallas kernels for TPUs
