Openkyber
Cybersecurity, Software Development, Blockchain.

Lead ML Engineer

Alaska, United States · Remote · Contract · Lead · Posted 1 month ago


100% remote role. Work hours follow EST. Agency and C2C applicants will NOT be considered, and visa sponsorship is not available.

Machine Learning Engineer: Framework Migration & Systems Optimization (PyTorch to JAX)

We are seeking a specialized Machine Learning Engineer with deep expertise in the high-performance AI stack. This role isn't just about "translating" code; it's about re-architecting Large Language Models (LLMs) to thrive in a JAX-native environment, specifically targeting TPU and GPU clusters at scale. You will bridge the gap between high-level PyTorch research implementations and the functional, XLA-optimized world of JAX, ensuring that our models achieve maximum throughput and hardware efficiency.

1. Core Framework Migration

  • Structural Porting: Manually migrate complex PyTorch LLM architectures (Transformers, MoE, SSMs) into JAX-based frameworks (Equinox, Flax, or Pax).
  • State Management: Transition imperative PyTorch state management to JAX's purely functional paradigm, handling PRNGKey management and immutable state updates with precision.
  • Weight Translation: Develop robust pipelines for checkpoint conversion, ensuring numerical parity between frameworks via rigorous unit testing and error-tolerance checks (see the sketch below).
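As a rough illustration of the weight-translation work, here is a minimal parity check for porting a single PyTorch linear layer to a JAX-style parameter dict. The function names, layer choice, and tolerance are hypothetical, not Openkyber's actual pipeline:

    import numpy as np
    import torch
    import jax.numpy as jnp

    def port_linear(pt_linear: torch.nn.Linear) -> dict:
        """Convert a PyTorch Linear layer to a JAX-style param dict.

        PyTorch stores Linear weights as (out, in); Flax/Equinox conventions
        compute x @ kernel, so the kernel is transposed to (in, out).
        """
        return {
            "kernel": jnp.asarray(pt_linear.weight.detach().numpy().T),
            "bias": jnp.asarray(pt_linear.bias.detach().numpy()),
        }

    def check_parity(pt_linear: torch.nn.Linear, params: dict, atol: float = 1e-5):
        """Assert numerical parity between the PyTorch and JAX forward passes."""
        x = np.random.randn(4, pt_linear.in_features).astype(np.float32)
        ref = pt_linear(torch.from_numpy(x)).detach().numpy()
        out = np.asarray(jnp.asarray(x) @ params["kernel"] + params["bias"])
        np.testing.assert_allclose(out, ref, atol=atol)

    layer = torch.nn.Linear(16, 32)
    check_parity(layer, port_linear(layer))

In a real migration the same pattern scales up: convert layer by layer, then assert that end-to-end logits match within a tolerance chosen for the target precision.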

2. Advanced Profiling & Numerical Stability

  • Bottleneck Analysis: Use NVIDIA Nsight and the TensorBoard Profiler to identify XLA compilation overheads, excessive rematerialization, or un-fused kernels.
  • Numerical Debugging: Implement precision-tracking tools to ensure that BF16 or FP8 training runs remain stable during the transition, preventing gradient divergence (a sketch follows below).
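One common building block for this kind of precision tracking is a lightweight gradient health check. This is a minimal sketch under assumed names, not a prescribed tool:

    import jax
    import jax.numpy as jnp

    def grad_health(grads):
        """Summarize a gradient pytree: global norm plus a non-finite count.

        A spiking global norm or any NaN/Inf leaf during a BF16/FP8 run is an
        early signal of divergence worth tracing back to the offending op.
        """
        leaves = jax.tree_util.tree_leaves(grads)
        sq_sum = sum(jnp.sum(jnp.square(g.astype(jnp.float32))) for g in leaves)
        nonfinite = sum(jnp.sum(~jnp.isfinite(g)) for g in leaves)
        return {"global_norm": jnp.sqrt(sq_sum), "nonfinite": nonfinite}

    def loss(w, x):
        return jnp.mean((x @ w) ** 2)

    w = jnp.ones((64, 64), dtype=jnp.bfloat16)
    x = jnp.ones((8, 64), dtype=jnp.bfloat16)
    stats = grad_health(jax.grad(loss)(w, x))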

3. Scaling & Distributed Training

  • Parallelism Strategies: Implement and optimize Fully Sharded Data Parallelism (FSDP) equivalents in JAX, using pjit or the jax.sharding APIs (see the sketch below).
  • Hybrid Parallelism: Design 3D parallelism strategies (Data, Pipeline, and Tensor) tailored to the interconnect topology (e.g., NVLink or TPU ICI) of the target hardware.
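To make the FSDP-equivalent point concrete, here is a minimal sketch using the jax.sharding APIs; the mesh axis name and array shapes are illustrative:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Build a 1-D mesh over all local devices; the "fsdp" axis name is illustrative.
    mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

    # FSDP-style layout: shard each weight along its first dimension across
    # devices, and shard the batch along the same axis for data parallelism.
    w = jax.device_put(jnp.zeros((512, 256)), NamedSharding(mesh, P("fsdp", None)))
    x = jax.device_put(jnp.ones((32, 512)), NamedSharding(mesh, P("fsdp", None)))

    @jax.jit
    def forward(x, w):
        # Under jit, XLA's GSPMD partitioner inserts the all-gather /
        # reduce-scatter collectives implied by the input shardings;
        # no manual communication code is written here.
        return x @ w

    y = forward(x, w)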

4. Hardware-Aware Optimization

  • XLA Mastery: Understand and influence the behavior of the XLA (Accelerated Linear Algebra) compiler. You will optimize HLO (High-Level Optimizer) graphs to minimize "jit-time" and maximize "run-time" efficiency.
  • Memory Management: Apply optimizations such as Selective Activation Checkpointing and memory-efficient attention (FlashAttention-2 JAX implementations) based on the specific HBM (High Bandwidth Memory) constraints of the hardware (see the sketch below).
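For a flavor of selective activation checkpointing in JAX, here is a minimal sketch using jax.checkpoint with a rematerialization policy; the block and policy choice are illustrative:

    from functools import partial
    import jax
    import jax.numpy as jnp

    # Save only matmul (dot) outputs across the block boundary and recompute
    # cheap elementwise ops on the backward pass, trading FLOPs for HBM.
    @partial(jax.checkpoint, policy=jax.checkpoint_policies.dots_saveable)
    def mlp_block(x, w1, w2):
        h = jax.nn.gelu(x @ w1)  # the GELU activation is rematerialized, not stored
        return h @ w2

    def loss(w1, w2, x):
        return jnp.sum(mlp_block(x, w1, w2) ** 2)

    x = jnp.ones((4, 64))
    w1 = 0.01 * jnp.ones((64, 256))
    w2 = 0.01 * jnp.ones((256, 64))
    grads = jax.grad(loss, argnums=(0, 1))(w1, w2, x)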

5. Fine-Tuning & Adaptation

  • Efficient Fine-Tuning: Port PyTorch-based PEFT methods (LoRA, DoRA) into the JAX stack (see the sketch below).
  • Architectural Evolution: Stay ahead of the curve by adapting JAX implementations for newer primitives such as Mamba/SSMs, Grouped-Query Attention (GQA), and Linear Attention as they emerge in the research space.
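As a reference point for the PEFT porting work, here is a minimal LoRA sketch in JAX; the shapes, init scales, and names are illustrative, not a specific codebase:

    import jax
    import jax.numpy as jnp

    # LoRA: y = x @ (W + (alpha / r) * A @ B), where the base weight W stays
    # frozen and only the low-rank factors A and B are trained.
    def lora_apply(x, w_frozen, lora_a, lora_b, alpha=16.0):
        r = lora_a.shape[1]
        delta = (alpha / r) * (lora_a @ lora_b)  # rank-r update to the weight
        return x @ (w_frozen + delta)

    key_a, _ = jax.random.split(jax.random.PRNGKey(0))
    d, r = 1024, 8
    w = jnp.zeros((d, d))                        # frozen base weight (placeholder)
    a = 0.01 * jax.random.normal(key_a, (d, r))  # trainable down-projection
    b = jnp.zeros((r, d))                        # trainable up-projection, zero init

    # Gradients flow only through the adapter parameters, not the frozen weight.
    adapter_loss = lambda a, b: jnp.sum(lora_apply(jnp.ones((2, d)), w, a, b))
    grad_a, grad_b = jax.grad(adapter_loss, argnums=(0, 1))(a, b)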

Familiarity with the following technical stack and tooling is expected:

  • Core Frameworks & Libraries:
      ◦ JAX Ecosystem: Expertise in Flax or Equinox (model definition), Optax (optimization/schedules), and Orbax (checkpointing; see the sketch after this list).
      ◦ PyTorch Ecosystem: Deep knowledge of PyTorch 2.x, including torch.compile, DistributedDataParallel (DDP), and FSDP.
      ◦ Intermediate Representations: Proficiency in HLO (High-Level Optimizer) and MLIR, to understand how JAX code translates to hardware instructions.
      ◦ Data Loaders: Experience migrating from torch.utils.data to Grain or tf.data for high-throughput JAX pipelines.
  • Profiling & Observability:
      ◦ JAX Profiler / TensorBoard: For identifying XLA compilation bottlenecks and tracing device memory traffic.
      ◦ NVIDIA Nsight Systems: To analyze GPU utilization, SM occupancy, and NVLink traffic.
      ◦ Perfetto: For deep-dive trace analysis across multi-node TPU/GPU clusters.
  • Infrastructure & Hardware:
      ◦ Accelerator Hardware: Strong understanding of NVIDIA H100/A100 (Hopper/Ampere) architecture and Google TPU (v4/v5p) topology.
      ◦ Orchestration: Experience with Slurm or Kubernetes (K8s) for managing large-scale training jobs.
      ◦ Cloud Providers: Proficiency in Google Cloud Platform (GCP) for TPUs, or AWS/Azure for high-end GPU instances.
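For illustration, the simplest Orbax round-trip looks roughly like this; the directory path and parameter tree are placeholders, and production pipelines typically add async saving and sharding-aware restore:

    import jax.numpy as jnp
    import orbax.checkpoint as ocp

    # A minimal Orbax save/restore round-trip. The target directory must not
    # already exist, and the path must be absolute.
    params = {"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.zeros((8,))}}

    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save("/tmp/demo_ckpt/step_100", params)
    restored = checkpointer.restore("/tmp/demo_ckpt/step_100")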

Core Skills & Competencies

  • Software Engineering Excellence:
      ◦ Functional Programming: A shift in mindset from OOP (object-oriented) code to pure functions, immutability, and stateless logic.
      ◦ Asynchronous Programming: Understanding JAX's asynchronous dispatch model and how to avoid "host-sync" bottlenecks.
      ◦ Testing Rigor: Ability to write property-based tests for numerical stability (see the sketch after this list).
  • Distributed Systems Knowledge:
      ◦ Collective Communications: Deep understanding of All-Reduce, All-Gather, and Reduce-Scatter primitives.
      ◦ Network Topology: Understanding how rack-level interconnects (e.g., InfiniBand) affect the choice of 3D parallelism strategies.
  • Mathematical & AI Domain Expertise (Desirable):
      ◦ Linear Algebra: Mastery of tensor contractions, Einstein summation (einsum), and matrix decompositions.
      ◦ Mixed-Precision Training: Expert-level knowledge of Stochastic Rounding, Loss Scaling, and the nuances of BF16 vs. FP8 training.
      ◦ Architecture Insight: Ability to decompose modern LLM components (KV Caches, Rotary Embeddings, Gated Linear Units) into their primitive mathematical operations.
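On the testing point, a property-based check of a numerically sensitive primitive might look like this (using Hypothesis; the property and tolerances are illustrative):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from hypothesis import given, strategies as st

    # Property: softmax is mathematically shift-invariant, so a numerically
    # stable implementation should agree (within tolerance) after a shift.
    @given(st.lists(st.floats(-100, 100), min_size=2, max_size=16),
           st.floats(-50, 50))
    def test_softmax_shift_invariant(xs, shift):
        x = jnp.asarray(xs, dtype=jnp.float32)
        np.testing.assert_allclose(jax.nn.softmax(x),
                                   jax.nn.softmax(x + shift),
                                   rtol=1e-4, atol=1e-6)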

For applications and inquiries, contact: hirings@openkyber.com
