ML Performance Engineer (TPU)

Location
Remote or Zurich
Employment Type
Full time
Location Type
Hybrid

About Pageshift
Pageshift is a research lab committed to pushing the frontier of AI storytelling and creativity. We envision a world in which most entertainment is personalized and AI-generated, and our goal is to build the underlying story engine that powers it all. To do this, we are not afraid to explore new approaches and create novel categories of model capability.

About the role
You will optimize JAX workloads on TPUs, with a strong focus on XLA behavior and Pallas kernel development. You will profile, benchmark, implement kernel-level optimizations, and validate improvements with data. The work supports long-context training workloads, so you should care about memory behavior, sharding, and end-to-end efficiency.

What we're looking for:
- Passion for entertainment and storytelling
- Basic understanding of TPUs, JAX and XLA
- Willingness to optimize against a black box

Nice to have:
- An example project to show off
- Prior experience writing Pallas kernels

Your responsibilities:
- Implementing and optimizing JAX code and Pallas kernels for TPUs

