TL;DR

A tutorial shows how to build and evaluate a simple Deep Q Network (DQN) that learns Tic-Tac-Toe using Jax and the PGX game library. The article walks through game representation, batched random play, a compact neural architecture with Flax/nnx, and an evaluation loop; the author says a model can reach perfect play in roughly 15 seconds on a laptop.

What happened

The author demonstrates how to implement a Tic-Tac-Toe agent in Jax, using PGX to represent and batch game states. PGX encodes states with a dataclass exposing fields such as current_player, a 3×3×2 observation array (channels for the current player and the opponent), a flat legal_action_mask, per-player rewards, and a terminated flag. The write-up includes a JIT-compiled random-action selector and a batched random-play loop for exploring the environment's behavior. It then defines a small DQN in Flax/nnx: the board is flattened to nine values (1 for X, -1 for O), passed through two hidden layers, and mapped to nine tanh outputs interpreted as values per square. Action selection masks out occupied squares and picks the highest-valued legal move. The article also provides code to measure wins, losses, and ties against a random opponent, and links to a GitHub repo and a Colab notebook; the training section begins in the excerpt but is not included in full.
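To make the batched random-play loop concrete, here is a minimal sketch along the lines of PGX's documented vectorized API; it is not the author's code, and the batch size, seed, and the helper name random_actions are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import pgx

env = pgx.make("tic_tac_toe")
init = jax.jit(jax.vmap(env.init))   # initialize a whole batch of games at once
step = jax.jit(jax.vmap(env.step))   # advance every game in the batch by one move

@jax.jit
def random_actions(key, legal_action_mask):
    # Give illegal squares -inf logits so categorical sampling can only
    # pick empty squares; legal squares share equal probability.
    logits = jnp.where(legal_action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits, axis=-1)

batch_size = 128
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
state = init(jax.random.split(subkey, batch_size))

while not state.terminated.all():
    key, subkey = jax.random.split(key)
    actions = random_actions(subkey, state.legal_action_mask)
    state = step(state, actions)   # state.rewards has shape (batch_size, 2)
```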

Why it matters

  • Shows a concise, runnable pipeline for reinforcement learning in Jax using a simple game as a pedagogical example.
  • PGX’s pure-Jax implementation and batching make it straightforward to run many parallel games and leverage JIT compilation for speed.
  • A minimal DQN architecture is sufficient to reach strong play for a small, discrete game, illustrating how model complexity can match task difficulty.
  • Provides reproducible artifacts (GitHub and Colab) so readers can try the examples locally or in the cloud.

Key facts

  • PGX represents a game state with fields including current_player, observation, legal_action_mask, rewards, and terminated.
  • The observation is a boolean array shaped (3, 3, 2); the first channel marks the current player’s pieces and the second channel marks the opponent’s.
  • legal_action_mask is a flat boolean array with True for empty squares and False for filled ones.
  • Rewards are a length-2 array giving per-player rewards; once a game has terminated, subsequent transitions return zero rewards.
  • PGX provides a step function that transitions a state given an action; environments can be batched to run many games in parallel.
  • The example random-action policy samples legal moves via jax.random.categorical over logits derived from the legal_action_mask.
  • DQN architecture (using flax/nnx): flatten board to 9 inputs (X=1, O=-1), two hidden layers, and a tanh output of length 9.
  • Action selection masks the network outputs with legal_action_mask, setting illegal positions to -inf, and then takes the argmax of the result (see the sketch after this list).
  • The author supplies evaluation code that mixes network-chosen moves and random moves depending on current_player, and accumulates wins/ties/losses.
  • Code is available on a GitHub repository and as a Colab notebook (the Colab run is noted to be slower).
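As a concrete illustration of the architecture and masked action-selection bullets above, here is a minimal flax.nnx sketch; the hidden width, the hidden-layer activation, and the names DQN and greedy_action are assumptions not stated in the excerpt.

```python
import jax.numpy as jnp
from flax import nnx

class DQN(nnx.Module):
    # Nine signed inputs (X=1, O=-1, empty=0), two hidden layers,
    # nine tanh outputs interpreted as a value per square.
    def __init__(self, hidden: int = 64, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(9, hidden, rngs=rngs)
        self.fc2 = nnx.Linear(hidden, hidden, rngs=rngs)
        self.out = nnx.Linear(hidden, 9, rngs=rngs)

    def __call__(self, board):
        x = jnp.tanh(self.fc1(board))   # hidden activation is an assumption
        x = jnp.tanh(self.fc2(x))
        return jnp.tanh(self.out(x))    # values per square, each in [-1, 1]

def greedy_action(q_values, legal_action_mask):
    # Set occupied squares to -inf so argmax can only return a legal move.
    return jnp.argmax(jnp.where(legal_action_mask, q_values, -jnp.inf))

# Usage sketch: derive the 9-value board from PGX's (3, 3, 2) observation,
# then pick the best legal square.
# model = DQN(rngs=nnx.Rngs(0))
# board = (obs[..., 0].astype(jnp.float32) - obs[..., 1].astype(jnp.float32)).reshape(9)
# action = greedy_action(model(board), legal_action_mask)
```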

What to watch next

  • Full training loop details, hyperparameters and optimizer choices: not confirmed in the source
  • How the trained model performs against stronger, non-random opponents: not confirmed in the source
  • Reproducibility across hardware of the stated ~15-second convergence time: the article asserts roughly 15 seconds on a laptop, but the excerpt provides no independent verification

Quick glossary

  • Jax: A numerical computing library that provides composable function transformations such as JIT compilation and automatic differentiation.
  • PGX: A library that implements game environments in pure Jax and exposes batched state representations and transition functions.
  • DQN (Deep Q Network): A neural network that estimates value (or Q) scores for actions in a given state; used to select actions that maximize expected returns.
  • legal_action_mask: A boolean mask indicating which actions (board positions) are currently available to play.
  • Flax/nnx: A neural network library used with Jax for defining and composing model layers; nnx is a lightweight module API used in the example.

Reader FAQ

Is the example code available?
Yes; the article links to a GitHub repository and a Colab notebook.

How long does training take?
The author states the model can reach perfect play in about 15 seconds on a laptop.

Will the trained model always beat a random player?
The article says the model should never lose to a random player, though it may tie occasionally.

Are full training hyperparameters and optimizer settings provided here?
Not confirmed in the source.

Sources

  • Learning to Play Tic-Tac-Toe with Jax (January 3, 2026)