    Implementing Deep Q-Learning (DQN) from Scratch Using RLax, JAX, Haiku, and Optax to Train a CartPole Reinforcement Learning Agent

    By Naveed Ahmad | 23/03/2026 (updated 23/03/2026) | 7 min read


    In this tutorial, we implement a reinforcement learning agent using RLax, a research-oriented library developed by Google DeepMind for building reinforcement learning algorithms with JAX. We combine RLax with JAX, Haiku, and Optax to build a Deep Q-Learning (DQN) agent that learns to solve the CartPole environment. Instead of using a fully packaged RL framework, we assemble the training pipeline ourselves so we can clearly see how the core components of reinforcement learning interact. We define the neural network, build a replay buffer, compute temporal difference (TD) errors with RLax, and train the agent with gradient-based optimization. Along the way, we focus on how RLax provides reusable RL primitives that can be integrated into custom reinforcement learning pipelines. We use JAX for efficient numerical computation, Haiku for neural network modeling, and Optax for optimization.

    !pip -q install "jax[cpu]" dm-haiku optax rlax gymnasium matplotlib numpy
    
    
    import os
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    
    
    import random
    import time
    from dataclasses import dataclass
    from collections import deque
    
    
    import gymnasium as gym
    import haiku as hk
    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import numpy as np
    import optax
    import rlax
    
    
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    
    
    env = gym.make("CartPole-v1")
    eval_env = gym.make("CartPole-v1")
    
    
    obs_dim = env.observation_space.shape[0]
    num_actions = env.action_space.n
    
    
    def q_network(x):
       mlp = hk.Sequential([
           hk.Linear(128), jax.nn.relu,
           hk.Linear(128), jax.nn.relu,
           hk.Linear(num_actions),
       ])
       return mlp(x)
    
    
    q_net = hk.without_apply_rng(hk.transform(q_network))
    
    
    dummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32)
    rng = jax.random.PRNGKey(seed)
    params = q_net.init(rng, dummy_obs)
    target_params = params
    
    
    optimizer = optax.chain(
       optax.clip_by_global_norm(10.0),
       optax.adam(3e-4),
    )
    opt_state = optimizer.init(params)

    We install the required libraries and import all the modules needed for the reinforcement learning pipeline. We create the training and evaluation environments, define the neural network architecture with Haiku, and set up the Q-network that predicts action values. We also initialize the online and target network parameters, along with the optimizer used during training (gradient clipping at a global norm of 10.0 followed by Adam with a learning rate of 3e-4).
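    To make the Haiku model concrete: hk.transform turns q_network into a pair of pure functions, init (which builds the parameters) and apply (which runs the forward pass given those parameters), and hk.without_apply_rng drops the unused random-key argument from apply. The minimal NumPy sketch below spells out the forward computation that apply performs for this architecture and counts its parameters. It assumes CartPole-v1's 4-dimensional observation and 2 actions; the helpers init_mlp_params and q_forward are hypothetical, and the random initializer only approximates Haiku's default, so the values will not match the JAX run.

```python
import numpy as np

# Assumed sizes for CartPole-v1: 4 observation features, 2 discrete actions.
OBS_DIM, NUM_ACTIONS, HIDDEN = 4, 2, 128


def init_mlp_params(rng, sizes):
    """Hypothetical initializer: one (weight, bias) pair per linear layer.

    Weights use stddev 1/sqrt(fan_in) and biases start at zero; this only
    approximates Haiku's default hk.Linear initializer.
    """
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        w = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, fan_out))
        b = np.zeros(fan_out)
        params.append((w, b))
    return params


def q_forward(params, x):
    """Linear -> ReLU -> Linear -> ReLU -> Linear, mirroring q_network."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:  # no activation on the output layer
            x = np.maximum(x, 0.0)
    return x


rng = np.random.default_rng(0)
params = init_mlp_params(rng, [OBS_DIM, HIDDEN, HIDDEN, NUM_ACTIONS])

batch = np.zeros((1, OBS_DIM), dtype=np.float32)  # same shape as dummy_obs
q_values = q_forward(params, batch)
num_params = sum(w.size + b.size for w, b in params)

print(q_values.shape)  # (1, 2): one Q-value per action for each observation
print(num_params)      # 17410 = (4*128 + 128) + (128*128 + 128) + (128*2 + 2)
```

    The network always expects a leading batch dimension, which is why dummy_obs has shape (1, obs_dim) and why select_action and the evaluation code index a single observation as obs[None, :].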

    @dataclass
    class Transition:
       obs: np.ndarray
       action: int
       reward: float
       discount: float
       next_obs: np.ndarray
       done: float


    class ReplayBuffer:
       def __init__(self, capacity):
           self.buffer = deque(maxlen=capacity)


       def add(self, *args):
           self.buffer.append(Transition(*args))


       def sample(self, batch_size):
           batch = random.sample(self.buffer, batch_size)
           obs = np.stack([t.obs for t in batch]).astype(np.float32)
           action = np.array([t.action for t in batch], dtype=np.int32)
           reward = np.array([t.reward for t in batch], dtype=np.float32)
           discount = np.array([t.discount for t in batch], dtype=np.float32)
           next_obs = np.stack([t.next_obs for t in batch]).astype(np.float32)
           done = np.array([t.done for t in batch], dtype=np.float32)
           return {
               "obs": obs,
               "action": action,
               "reward": reward,
               "discount": discount,
               "next_obs": next_obs,
               "done": done,
           }


       def __len__(self):
           return len(self.buffer)


    replay = ReplayBuffer(capacity=50000)
    
    
    def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000):
       mix = min(frame_idx / decay_frames, 1.0)
       return eps_start + mix * (eps_end - eps_start)
    
    
    def select_action(params, obs, epsilon):
       if random.random() < epsilon:
           return env.action_space.sample()
       q_values = q_net.apply(params, obs[None, :])
       return int(jnp.argmax(q_values[0]))

    We define the transition structure and implement a replay buffer that stores past experience from the environment. We write methods to add transitions and to sample mini-batches that are later used to train the agent. We also implement the epsilon-greedy exploration strategy, with epsilon decaying linearly from 1.0 to 0.05 over the first 20,000 frames.
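    The exploration schedule can be checked in isolation. The sketch below repeats epsilon_by_frame unchanged and adds a hypothetical epsilon_greedy helper that receives Q-values directly instead of calling the network; it uses Python's random module and NumPy, so it runs without JAX.

```python
import random

import numpy as np


def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000):
    """Linear decay from eps_start to eps_end, then held constant."""
    mix = min(frame_idx / decay_frames, 1.0)
    return eps_start + mix * (eps_end - eps_start)


def epsilon_greedy(q_values, epsilon, rng):
    """Hypothetical helper: random action with probability epsilon, else argmax."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return int(np.argmax(q_values))


schedule = {f: round(epsilon_by_frame(f), 3) for f in (0, 10000, 20000, 40000)}
print(schedule)  # {0: 1.0, 10000: 0.525, 20000: 0.05, 40000: 0.05}

# Empirical check: with epsilon = 0.1 and 2 actions, the greedy action should
# be chosen about 1 - 0.1 + 0.1 / 2 = 95% of the time.
rng = random.Random(0)
q = np.array([0.2, 0.7])  # hypothetical Q-values; action 1 is greedy
picks = [epsilon_greedy(q, 0.1, rng) for _ in range(10000)]
greedy_rate = picks.count(1) / len(picks)
print(greedy_rate)
```

    With only two actions, a random pick still lands on the greedy action half of the time, so the effective greedy rate is 1 - epsilon / 2 rather than 1 - epsilon.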

    @jax.jit
    def soft_update(target_params, online_params, tau):
       return jax.tree_util.tree_map(lambda t, s: (1.0 - tau) * t + tau * s, target_params, online_params)
    
    
    def batch_td_errors(params, target_params, batch):
       q_tm1 = q_net.apply(params, batch["obs"])
       q_t = q_net.apply(target_params, batch["next_obs"])
       td_errors = jax.vmap(
           lambda q1, a, r, d, q2: rlax.q_learning(q1, a, r, d, q2)
       )(q_tm1, batch["action"], batch["reward"], batch["discount"], q_t)
       return td_errors
    
    
    @jax.jit
    def train_step(params, target_params, opt_state, batch):
       def loss_fn(p):
           td_errors = batch_td_errors(p, target_params, batch)
           loss = jnp.mean(rlax.huber_loss(td_errors, delta=1.0))
           metrics = {
               "loss": loss,
               "td_abs_mean": jnp.mean(jnp.abs(td_errors)),
               "q_mean": jnp.mean(q_net.apply(p, batch["obs"])),
           }
           return loss, metrics
    
    
       (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
       updates, opt_state = optimizer.update(grads, opt_state, params)
       params = optax.apply_updates(params, updates)
       return params, opt_state, metrics

    We define the core learning functions used during training. We compute temporal difference errors with RLax's Q-learning primitive, vectorized over the batch with jax.vmap, and turn them into a loss with the Huber loss function. We then implement the training step, which computes gradients, applies the optimizer updates, and returns training metrics. A separate soft_update function moves the target network parameters a small step (tau) toward the online parameters.
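    For checking numbers by hand, the arithmetic behind rlax.q_learning and rlax.huber_loss (with delta = 1.0) can be reproduced in NumPy. This is an illustrative sketch, not a drop-in replacement: the helpers q_learning_td and huber are hypothetical, the transitions are made up, and RLax additionally stops gradients through the bootstrap target so that only the online Q-values are trained.

```python
import numpy as np


def q_learning_td(q_tm1, a_tm1, r_t, discount_t, q_t):
    """TD error for one transition (mirrors the math of rlax.q_learning)."""
    target = r_t + discount_t * np.max(q_t)  # bootstrap from the best next action
    return target - q_tm1[a_tm1]


def huber(x, delta=1.0):
    """0.5 * x**2 for |x| <= delta, delta * (|x| - 0.5 * delta) beyond it."""
    abs_x = np.abs(x)
    quadratic = np.minimum(abs_x, delta)
    linear = abs_x - quadratic
    return 0.5 * quadratic**2 + delta * linear


# Two hypothetical transitions with 2 actions; the second one is terminal.
q_tm1 = np.array([[1.0, 2.0], [1.0, 2.0]])  # online Q-values at obs
action = np.array([1, 1])
reward = np.array([1.0, 1.0])
discount = np.array([0.99, 0.0])            # 0.0 marks a terminal step
q_t = np.array([[0.5, 3.0], [0.5, 3.0]])    # target-network Q-values at next_obs

# The list comprehension plays the role of jax.vmap over the batch axis.
td_errors = np.array([
    q_learning_td(*args) for args in zip(q_tm1, action, reward, discount, q_t)
])
loss = np.mean(huber(td_errors))

print(td_errors)  # approximately [1.97, -1.0]
print(loss)       # approximately 0.985
```

    The Huber loss grows only linearly for TD errors larger than delta, which keeps a single large error from dominating the gradient the way a squared loss would.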

    def evaluate_agent(params, episodes=5):
       returns = []
       for ep in range(episodes):
           obs, _ = eval_env.reset(seed=seed + 1000 + ep)
           done = False
           truncated = False
           total_reward = 0.0
           while not (done or truncated):
               q_values = q_net.apply(params, obs[None, :])
               action = int(jnp.argmax(q_values[0]))
               next_obs, reward, done, truncated, _ = eval_env.step(action)
               total_reward += reward
               obs = next_obs
           returns.append(total_reward)
       return float(np.mean(returns))
    
    
    num_frames = 40000
    batch_size = 128
    warmup_steps = 1000
    train_every = 4
    eval_every = 2000
    gamma = 0.99
    tau = 0.01
    max_grad_updates_per_step = 1
    
    
    obs, _ = env.reset(seed=seed)
    episode_return = 0.0
    episode_returns = []
    eval_returns = []
    losses = []
    td_means = []
    q_means = []
    eval_steps = []
    
    
    start_time = time.time()

    We define the evaluation function, which measures the agent's performance by running greedy episodes on a separate environment. We configure the training hyperparameters, including the number of frames, batch size, discount factor, and target network update rate. We also initialize the variables that track training statistics, including episode returns, losses, and evaluation metrics.
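    Two quantities implied by these settings are worth checking: how many gradient updates the run performs, and how quickly the soft (Polyak) target update with tau = 0.01 forgets old weights. The sketch below recomputes both with plain Python and NumPy; its dict-of-arrays soft_update is a hypothetical stand-in for the jax.tree_util.tree_map version used in the tutorial.

```python
import math

import numpy as np


def soft_update(target, online, tau):
    """Polyak averaging over a dict of arrays: target <- (1 - tau) * target + tau * online."""
    return {k: (1.0 - tau) * target[k] + tau * online[k] for k in target}


# Hyperparameters copied from the tutorial.
num_frames, warmup_steps, train_every, tau = 40000, 1000, 4, 0.01

# The buffer holds frame_idx transitions (capacity 50000 > num_frames), so
# training runs on every multiple of train_every from warmup_steps onward.
num_updates = sum(
    1 for f in range(1, num_frames + 1) if f >= warmup_steps and f % train_every == 0
)

# After n soft updates, the initial target weights keep a share of (1 - tau) ** n.
half_life = math.log(0.5) / math.log(1.0 - tau)

target = {"w": np.zeros(3)}
online = {"w": np.ones(3)}
for _ in range(100):
    target = soft_update(target, online, tau)

print(num_updates)          # 9751
print(round(half_life, 2))  # 68.97
print(target["w"])          # each entry equals 1 - 0.99 ** 100, about 0.634
```

    In other words, the target network lags the online network by roughly 100 updates (1 / tau), or about 400 environment steps at train_every = 4, which is the stabilizing delay DQN relies on.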

    for frame_idx in range(1, num_frames + 1):
       epsilon = epsilon_by_frame(frame_idx)
       action = select_action(params, obs.astype(np.float32), epsilon)
       next_obs, reward, done, truncated, _ = env.step(action)
       # Reset on either flag, but only stop bootstrapping on true termination:
       # hitting CartPole-v1's 500-step time limit is not a failure state.
       terminal = done or truncated
       discount = 0.0 if done else gamma


       replay.add(
           obs.astype(np.float32),
           action,
           float(reward),
           float(discount),
           next_obs.astype(np.float32),
           float(terminal),
       )
    
    
       obs = next_obs
       episode_return += reward
    
    
       if terminal:
           episode_returns.append(episode_return)
           obs, _ = env.reset()
           episode_return = 0.0
    
    
       if len(replay) >= warmup_steps and frame_idx % train_every == 0:
           for _ in range(max_grad_updates_per_step):
               batch_np = replay.sample(batch_size)
               batch = {k: jnp.asarray(v) for k, v in batch_np.items()}
               params, opt_state, metrics = train_step(params, target_params, opt_state, batch)
               target_params = soft_update(target_params, params, tau)
               losses.append(float(metrics["loss"]))
               td_means.append(float(metrics["td_abs_mean"]))
               q_means.append(float(metrics["q_mean"]))
    
    
       if frame_idx % eval_every == 0:
           avg_eval_return = evaluate_agent(params, episodes=5)
           eval_returns.append(avg_eval_return)
           eval_steps.append(frame_idx)
           recent_train = np.mean(episode_returns[-10:]) if episode_returns else 0.0
           recent_loss = np.mean(losses[-100:]) if losses else 0.0
           print(
               f"step={frame_idx:6d} | epsilon={epsilon:.3f} | "
               f"recent_train_return={recent_train:7.2f} | "
               f"eval_return={avg_eval_return:7.2f} | "
               f"recent_loss={recent_loss:.5f} | buffer={len(replay)}"
           )
    
    
    elapsed = time.time() - start_time
    final_eval = evaluate_agent(params, episodes=10)
    
    
    print("\nTraining complete")
    print(f"Elapsed time: {elapsed:.1f} seconds")
    print(f"Final 10-episode evaluation return: {final_eval:.2f}")
    
    
    plt.figure(figsize=(14, 4))
    plt.subplot(1, 3, 1)
    plt.plot(episode_returns)
    plt.title("Training Episode Returns")
    plt.xlabel("Episode")
    plt.ylabel("Return")
    
    
    plt.subplot(1, 3, 2)
    plt.plot(eval_steps, eval_returns)
    plt.title("Evaluation Returns")
    plt.xlabel("Environment Steps")
    plt.ylabel("Avg Return")
    
    
    plt.subplot(1, 3, 3)
    plt.plot(losses, label="Loss")
    plt.plot(td_means, label="Mean |TD Error|")
    plt.title("Optimization Metrics")
    plt.xlabel("Gradient Updates")
    plt.legend()
    
    
    plt.tight_layout()
    plt.show()
    
    
    frames = []
    done = False
    truncated = False
    total_reward = 0.0


    render_env = gym.make("CartPole-v1", render_mode="rgb_array")
    obs, _ = render_env.reset(seed=999)


    while not (done or truncated):
       frame = render_env.render()
       frames.append(frame)
       q_values = q_net.apply(params, obs[None, :])
       action = int(jnp.argmax(q_values[0]))
       obs, reward, done, truncated, _ = render_env.step(action)
       total_reward += reward


    render_env.close()
    
    
    print(f"Demo episode return: {total_reward:.2f}")
    
    
    try:
       import matplotlib.animation as animation
       from IPython.display import HTML, display


       fig = plt.figure(figsize=(6, 4))
       patch = plt.imshow(frames[0])
       plt.axis("off")


       def animate(i):
           patch.set_data(frames[i])
           return (patch,)


       anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=30, blit=True)
       display(HTML(anim.to_jshtml()))
       plt.close(fig)
    except Exception as e:
       print("Animation display skipped:", e)

    We run the full reinforcement learning training loop. Every few steps, we sample a mini-batch from the replay buffer, update the network parameters, and nudge the target network toward the online network; at regular intervals, we evaluate the agent's performance and record metrics for visualization. Finally, we plot the training results and render a demonstration episode to observe how the trained agent behaves.
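    One detail in the loop deserves care: Gymnasium's step returns separate terminated and truncated flags. Both should end the episode and trigger env.reset(), but only true termination means there is no future value to bootstrap from; treating a time-limit truncation as terminal slightly biases the Q-learning targets. The hypothetical helper below makes that mapping explicit as a sketch, assuming the tutorial's gamma = 0.99.

```python
def step_bookkeeping(terminated, truncated, gamma=0.99):
    """Map Gymnasium's end-of-episode flags to (reset_env, bootstrap_discount).

    terminated: the MDP reached a terminal state (pole fell, cart left the
                track), so there is no future value to bootstrap from.
    truncated:  an external limit (e.g. CartPole-v1's 500-step cap) stopped
                the episode; the state is not terminal, so the target may
                still bootstrap from next_obs.
    """
    reset_env = terminated or truncated
    discount = 0.0 if terminated else gamma
    return reset_env, discount


for flags in [(False, False), (True, False), (False, True)]:
    print(flags, "->", step_bookkeeping(*flags))
# (False, False) -> (False, 0.99)
# (True, False) -> (True, 0.0)
# (False, True) -> (True, 0.99)
```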

    In conclusion, we built a complete Deep Q-Learning agent by combining RLax with the modern JAX-based machine learning ecosystem. We designed a neural network to estimate action values, implemented experience replay to stabilize learning, and computed TD errors with RLax's Q-learning primitive. During training, we updated the network parameters with gradient-based optimization and periodically evaluated the agent to track its progress. We also saw how RLax enables a modular approach to reinforcement learning by providing reusable algorithmic components rather than complete algorithms. This flexibility makes it easy to experiment with different architectures, learning rules, and optimization strategies. Building on this foundation, we can construct more advanced agents, such as Double DQN, distributional reinforcement learning models, and actor-critic methods, using the same RLax primitives.
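    As a concrete starting point for the Double DQN extension, here is a NumPy sketch of how its target differs from the one used in this tutorial. RLax provides an rlax.double_q_learning primitive for this, which, as far as we know, takes separate q_t_value and q_t_selector arguments; in the training code, the switch would roughly amount to also passing the online network's Q-values at next_obs as the selector. The function names and numbers below are hypothetical and only illustrate the arithmetic.

```python
import numpy as np


def q_learning_td(q_tm1, a_tm1, r_t, discount_t, q_t):
    """Standard DQN target: the same Q-values both select and evaluate a_t."""
    return r_t + discount_t * np.max(q_t) - q_tm1[a_tm1]


def double_q_learning_td(q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector):
    """Double DQN target: q_t_selector picks a_t, q_t_value scores it."""
    a_t = int(np.argmax(q_t_selector))
    return r_t + discount_t * q_t_value[a_t] - q_tm1[a_tm1]


q_tm1 = np.array([1.0, 2.0])       # online net at s_{t-1}
q_t_target = np.array([4.0, 2.0])  # target net at s_t (overestimates action 0)
q_t_online = np.array([1.0, 3.0])  # online net at s_t (prefers action 1)

dqn = q_learning_td(q_tm1, 0, 1.0, 0.9, q_t_target)
ddqn = double_q_learning_td(q_tm1, 0, 1.0, 0.9, q_t_target, q_t_online)
print(dqn, ddqn)  # 3.6 vs 1.8, up to floating-point rounding
```

    Decoupling action selection from action evaluation is what reduces the overestimation bias of the max operator in the standard DQN target.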

