TorchTPU: PyTorch를 TPU에서 네이티브로 실행하는 구글의 전략

PyTorch + TPU, 왜 지금 중요한가

대규모 AI 모델 학습의 시대가 본격화되면서, 수천~수만 개의 가속기를 효율적으로 활용하는 소프트웨어 스택의 중요성이 급격히 커지고 있어요. Google의 TPU(Tensor Processing Unit)는 Gemini, Veo 같은 자체 AI 플랫폼은 물론 Cloud 고객의 대규모 워크로드를 처리하는 핵심 인프라인데요. 문제는 ML 커뮤니티의 상당수가 PyTorch 생태계에 익숙하다는 점이에요. TorchTPU는 바로 이 간극을 메우기 위해 설계된 스택으로, 기존 PyTorch 코드를 최소한의 변경만으로 TPU에서 네이티브로 실행할 수 있게 해줘요. 이 글에서는 TorchTPU의 핵심 아키텍처, Eager 모드 전략, 정적 컴파일 파이프라인, 그리고 2026년 로드맵까지 개발자 관점에서 깊이 살펴볼게요.

TorchTPU 메타 이미지 — PyTorch와 TPU 통합 개념도

TPU 하드웨어 아키텍처 이해하기

TorchTPU를 제대로 이해하려면 먼저 TPU의 하드웨어 구조를 알아야 해요. TPU는 단순한 칩이 아니라 통합 네트워크 시스템이에요.

ICI와 토러스 토폴로지

각 호스트에는 여러 개의 TPU 칩이 연결되어 있고, 칩들은 Inter-Chip Interconnect(ICI)를 통해 2D 또는 3D 토러스(Torus) 토폴로지로 연결돼요. 이 구조 덕분에 전통적인 네트워크 병목 없이 대규모 스케일업이 가능해요. GPU 클러스터에서 InfiniBand나 NVLink에 의존하는 것과는 근본적으로 다른 접근이에요.

TensorCore와 SparseCore

칩 내부 실행은 TensorCore와 SparseCore로 나뉘어요.

  • TensorCore: 단일 스레드 유닛으로 밀집 행렬 연산(dense matrix math)에 특화
  • SparseCore: 임베딩, gather/scatter 같은 불규칙 메모리 접근 패턴과 collective 오프로딩 처리

이 이중 구조가 TPU의 핵심 강점이고, TorchTPU는 이 두 코어를 모두 활용할 수 있도록 설계되었어요.

TorchTPU 블로그 아키텍처 다이어그램

Eager First 철학: 세 가지 실행 모드

TorchTPU의 핵심 설계 원칙은 "PyTorch처럼 느껴져야 한다"예요. 기존 PyTorch 스크립트에서 디바이스만 "tpu"로 바꾸면 코어 로직 한 줄 수정 없이 학습 루프가 돌아가야 한다는 거죠. 이를 위해 PyTorch의 PrivateUse1 인터페이스를 활용해 서브클래스나 래퍼 없이 일반 PyTorch Tensor로 동작하도록 구현했어요.

Debug Eager

연산 하나마다 디스패치하고 CPU와 동기화해요. 느리지만 shape mismatch, NaN, OOM 디버깅에 필수적이에요.

# Debug Eager 모드 예시
import torch
import torch_tpu

device = torch.device("tpu")
x = torch.randn(4, 4, device=device)
# 한 연산마다 동기화 — 디버깅 시 사용

Strict Eager

단일 연산 디스패치를 유지하되 비동기로 실행해요. 기본 PyTorch 경험을 그대로 미러링하면서 CPU와 TPU가 동기화 포인트까지 동시에 실행돼요.

Fused Eager — 진짜 혁신

연산 스트림을 자동으로 리플렉션하여 여러 스텝을 계산 밀도가 높은 청크로 퓨징한 뒤 TPU에 전달해요. 사용자가 별도 설정 없이 Strict Eager 대비 50%~100% 이상의 성능 향상을 얻을 수 있어요. TensorCore 활용을 극대화하고 메모리 대역폭 오버헤드를 최소화하는 게 핵심 메커니즘이에요.

세 모드 모두 공유 Compilation Cache를 사용하며, 단일 호스트 또는 멀티 호스트 영속 캐시로 설정할 수 있어서 워크로드를 학습할수록 컴파일 시간이 줄어들어요.

TorchTPU 파트1 커버 이미지

정적 컴파일: Dynamo → XLA → StableHLO 파이프라인

최고 성능을 원한다면 torch.compile 인터페이스를 통한 풀 그래프 컴파일을 사용할 수 있어요. 파이프라인은 다음과 같아요.

  • Torch Dynamo로 FX 그래프를 캡처
  • Torch Inductor 대신 XLA를 백엔드 컴파일러로 사용
  • PyTorch 연산자를 XLA의 기본 IR인 StableHLO로 매핑
  • XLA가 ICI 상의 밀집 연산과 collective 통신 간 오버랩을 최적화하여 TPU 바이너리 생성

XLA를 선택한 건 의도적인 결정이에요. XLA는 TPU 토폴로지에 대해 실전에서 검증된 최적화 역량을 갖추고 있고, Eager 모드에서 확립된 실행 경로를 재활용할 수 있어요.

커스텀 커널 확장성

성능을 깨뜨리지 않는 확장성도 보장해요. TorchTPU는 Pallas와 JAX로 작성된 커스텀 커널을 네이티브로 지원해요.

@torch_tpu.pallas.custom_jax_kernel
def my_custom_op(x):
    # JAX 기반 저수준 하드웨어 명령어 작성
    # lowering path와 직접 인터페이스
    return result

이 데코레이터를 통해 엔지니어가 직접 lowering path에 접근하는 저수준 커널을 작성할 수 있어요.

마무리

TorchTPU는 PyTorch 생태계와 Google TPU 인프라 사이의 간극을 실질적으로 해소하는 스택이에요. Eager First 철학으로 진입 장벽을 낮추면서도 Fused Eager와 XLA 기반 정적 컴파일로 하드웨어 성능을 최대한 끌어내는 설계가 인상적이에요. 2026년 로드맵에서 Pallas 커스텀 커널 지원 확대와 함께 TPU 기반 대규모 학습이 더욱 접근 가능해질 것으로 기대돼요.

이 블로그의 인기 게시물

가상 파일시스템으로 AI 어시스턴트 비용·속도 최적화하기

gemma4 vllm 실행 방법: 설치·최적화·멀티GPU 완전 가이드