PPO Training
Pure-JAX PPO training for the Kolmogorov and turbulent channel environments, based on purejaxrl with HydroGym integrations (VecEnv, normalization wrappers, etc.).
Setting up the JAX Environment
We begin by setting up the JAX environment with all required software dependencies:
import argparse
import pickle
from typing import NamedTuple, Sequence

import distrax
import flax.linen as nn
import flax.serialization
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
From HydroGym, we will first need the functions to wrap the environment in a VecEnv and normalize the observations and rewards.
from hydrogym.jax.env_core import ClipAction, LogWrapper, NormalizeVecObservation, NormalizeVecReward, VecEnv
Constructing the Reinforcement Learning Environment
To construct the reinforcement learning environment, we first need a utility function which takes in the environment configuration and validates it for the chosen case.
def make_env(config):
    """Build the environment named by config["ENV_NAME"] and return (env, default_params).

    Supported names are "kolmogorov" and "channel"; the lookup is
    case-insensitive and defaults to "kolmogorov" when the key is absent.
    """
    name = config.get("ENV_NAME", "kolmogorov").lower()
    if name == "kolmogorov":
        from hydrogym.jax.envs.kolmogorov import KolmogorovFlow

        environment = KolmogorovFlow(env_config={}, flow_config={})
        return environment, environment.default_params
    if name == "channel":
        from hydrogym.jax.envs.channel import ChannelFlowSpectralEnv

        environment = ChannelFlowSpectralEnv(env_config={})
        return environment, environment.default_params
    raise ValueError(f"Unknown ENV_NAME: {name!r}. Choose 'kolmogorov' or 'channel'.")
In addition, we require utility functions for saving and loading the model:
def save_model(params, filepath):
    """Serialize model params to filepath (flax byte encoding wrapped in pickle)."""
    serialized = flax.serialization.to_bytes(params)
    with open(filepath, "wb") as f:
        pickle.dump(serialized, f)
def load_model(filepath):
    """Load model params previously written by save_model from filepath."""
    with open(filepath, "rb") as f:
        serialized = pickle.load(f)
    # Restore the parameter pytree from the flax byte encoding.
    return flax.serialization.from_bytes(None, serialized)
Defining Reinforcement Learning Training
For the reinforcement learning training, we first define an Actor-Critic network, then the transition structure, and conclude by defining the actual training loop.
class ActorCritic(nn.Module):
    """Actor-critic network: a Gaussian policy head and a scalar value head.

    Both heads are two-layer (256-unit) MLPs over the same observation input.
    """

    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        act = nn.relu if self.activation == "relu" else nn.tanh
        # Policy trunk: two hidden layers, then the action mean.
        h = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        h = act(h)
        h = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(h)
        h = act(h)
        mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(h)
        # State-independent log standard deviation, one entry per action dimension.
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))
        # Value trunk: mirrors the policy trunk but outputs a single scalar.
        v = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        v = act(v)
        v = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(v)
        v = act(v)
        v = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(v)
        return pi, jnp.squeeze(v, axis=-1)
The transition class is then defined as follows:
class Transition(NamedTuple):
    """One environment step collected during rollout (leading axis: num envs)."""
    done: jnp.ndarray      # episode-termination flags
    action: jnp.ndarray    # actions sampled from the policy
    value: jnp.ndarray     # critic value estimates for obs
    reward: jnp.ndarray    # rewards returned by the env step
    log_prob: jnp.ndarray  # log-probability of action under the policy
    obs: jnp.ndarray       # observations the actions were taken from
    info: jnp.ndarray      # wrapper info — in practice a dict of arrays; annotation may not match — TODO confirm
With the rollout function following the purejaxrl implementation:
def rollout(env, params, env_params, num_steps=10, num_envs=4, activation="tanh"):
    """Roll out a trained policy for num_steps across num_envs parallel envs.

    Args:
        env: unwrapped HydroGym environment (wrapped here before use).
        params: trained ActorCritic parameters.
        env_params: environment parameter struct.
        num_steps: number of environment steps to record.
        num_envs: number of parallel environments.
        activation: hidden-layer activation name ("tanh" or "relu").

    Returns:
        Dict of stacked observations, actions, rewards and dones, each with
        a leading time axis of length num_steps.
    """
    rng = jax.random.PRNGKey(30)
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, num_envs)
    observations = []
    actions = []
    rewards = []
    dones = []
    # Wrap before reset so the wrapped env is used throughout.
    # BUG FIX: VecEnv is required here — reset receives one RNG key per env
    # (and states/actions are batched), matching the wrapping in make_train.
    env = VecEnv(ClipAction(env))
    obs, env_state = env.reset(reset_rng, env_params)
    network = ActorCritic(env.action_space(env_params).shape[0], activation=activation)
    for _ in range(num_steps):
        observations.append(obs)
        rng, action_rng = jax.random.split(rng)
        pi, _ = network.apply(params, obs)
        action = pi.sample(seed=action_rng)
        actions.append(action)
        rng, step_rng = jax.random.split(rng)
        # BUG FIX: one step key per environment, matching the batched reset above.
        step_rngs = jax.random.split(step_rng, num_envs)
        obs, env_state, reward, done, _ = env.step(step_rngs, env_state, action, env_params)
        rewards.append(reward)
        dones.append(done)
    return {
        "observations": jnp.array(observations),
        "actions": jnp.array(actions),
        "rewards": jnp.array(rewards),
        "dones": jnp.array(dones),
    }
Culminating in the following training loop:
def make_train(config):
    """Construct the PPO training function for the environment selected in config.

    Follows the purejaxrl single-file PPO layout: validate the minibatch
    layout, build and wrap the environment, then return train(rng), which
    runs NUM_UPDATES iterations of rollout + GAE + clipped PPO updates
    entirely inside jax.lax.scan so the whole loop can be jit-compiled.
    """
    total_batch = config["NUM_ENVS"] * config["NUM_STEPS"]
    # The rollout batch is later reshaped into NUM_MINIBATCHES equal chunks,
    # so it must divide evenly; fail early and list the valid divisors.
    if total_batch % config["NUM_MINIBATCHES"] != 0:
        raise ValueError(
            f"NUM_ENVS * NUM_STEPS ({config['NUM_ENVS']} * {config['NUM_STEPS']} = {total_batch}) "
            f"must be divisible by NUM_MINIBATCHES ({config['NUM_MINIBATCHES']}). "
            f"Valid NUM_MINIBATCHES values for your settings: "
            f"{[d for d in range(1, total_batch + 1) if total_batch % d == 0]}"
        )
    config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    config["MINIBATCH_SIZE"] = total_batch // config["NUM_MINIBATCHES"]
    env, env_params = make_env(config)
    # Wrapper order: episode logging, action clipping, vectorization, then
    # optional observation/reward normalization on the vectorized env.
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])

    def linear_schedule(count):
        # `count` advances once per minibatch gradient step; dividing by
        # minibatch-steps-per-update maps it to the update index, giving a
        # linear decay from LR toward 0 over NUM_UPDATES.
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):
        """Run the full PPO training loop from the given PRNG key."""
        # INIT NETWORK
        network = ActorCritic(env.action_space(env_params).shape[0], activation=config["ACTIVATION"])
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        # Gradient clipping + Adam, with either a constant or annealed LR.
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )
        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state
                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                # Raw Gaussian sample; the ClipAction wrapper bounds it at step time.
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(rng_step, env_state, action, env_params)
                transition = Transition(done, action, value, reward, log_prob, last_obs, info)
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["NUM_STEPS"])

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Bootstrap value for the state following the last collected step.
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                # Generalized Advantage Estimation, scanned backwards in time.
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # TD residual; (1 - done) masks bootstrapping across episode ends.
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                # Value targets are advantages plus the on-rollout value estimates.
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK on stored observations with current params.
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)
                        # CALCULATE VALUE LOSS (clipped, PPO-style pessimistic max).
                        value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                            -config["CLIP_EPS"], config["CLIP_EPS"]
                        )
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        # CALCULATE ACTOR LOSS (clipped surrogate objective).
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # Normalize advantages per minibatch for stability.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()
                        total_loss = loss_actor + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(train_state.params, traj_batch, advantages, targets)
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # Flatten (NUM_STEPS, NUM_ENVS, ...) -> (batch_size, ...), shuffle,
                # then split into NUM_MINIBATCHES equal minibatches.
                batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
                shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["UPDATE_EPOCHS"])
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]
            if config.get("DEBUG"):

                def callback(info):
                    # Host-side progress printout, invoked via jax.debug.callback.
                    step = int(info["timestep"].max())
                    total = config["TOTAL_TIMESTEPS"]
                    pct = 100.0 * step / total
                    # Extra env-specific metrics
                    extras = []
                    if "mean_tke" in info:
                        extras.append(f"mean_tke={float(info['mean_tke'].mean()):.4f}")
                    # Completed episodes in this rollout batch
                    done_mask = info["returned_episode"]
                    if done_mask.any():
                        mean_return = float(info["returned_episode_returns"][done_mask].mean())
                        extras.append(f"return={mean_return:.4f}")
                    extra_str = " " + " ".join(extras) if extras else ""
                    print(f" step {step:>6}/{total} ({pct:5.1f}%){extra_str}")

                jax.debug.callback(callback, metric)
            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(_update_step, runner_state, None, config["NUM_UPDATES"])
        return {"runner_state": runner_state, "metrics": metric}

    return train
Performing the Training
At this point, we can now define the configuration of our training hyperparameters, and pull the individual pieces together
# PPO hyperparameters; several values carry tuning notes from earlier experiments.
config = {
    "LR": 1e-4,  # learning rate; 3e-4 down to 1e-5 is a reasonable sweep range
    "NUM_ENVS": 4,  # parallel environments per rollout
    "NUM_STEPS": 40,  # env steps per rollout per environment
    "TOTAL_TIMESTEPS": 100,  # short smoke-test budget; raise (e.g. 4000) for real runs
    "UPDATE_EPOCHS": 10,  # SGD passes over each rollout batch
    "NUM_MINIBATCHES": 8,  # must divide NUM_ENVS * NUM_STEPS
    "GAMMA": 0.99,  # discount factor
    "GAE_LAMBDA": 0.985,  # GAE smoothing; can be tuned up to ~0.995
    "CLIP_EPS": 0.2,  # PPO clipping epsilon
    "ENT_COEF": 0.0,  # entropy bonus; can be raised to ~0.1-0.2
    "VF_COEF": 0.5,  # value-loss weight
    "MAX_GRAD_NORM": 0.5,  # global gradient-clipping norm
    "ACTIVATION": "tanh",  # hidden activation; "relu" also supported (mish worth trying)
    "ANNEAL_LR": False,  # linearly decay LR when True
    "NORMALIZE_ENV": False,  # enable obs/reward normalization wrappers
    "DEBUG": True,  # print per-update progress via jax.debug.callback
}
We then define training parameters specific to HydroGym from the command line:
# Command-line interface mirroring the key config entries.
parser = argparse.ArgumentParser(description="PPO training for HydroGym JAX environments")
parser.add_argument(
    "--env",
    default="kolmogorov",
    choices=["kolmogorov", "channel"],
    help="Environment to train on (default: kolmogorov)",
)
parser.add_argument("--total-timesteps", type=int, default=4000)
parser.add_argument("--num-envs", type=int, default=4)
parser.add_argument("--num-steps", type=int, default=10)
parser.add_argument("--num-minibatches", type=int, default=8, help="Must divide NUM_ENVS * NUM_STEPS (default: 8)")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--model-save-path", default=None, help="Path to save trained model (.pkl)")
parser.add_argument("--plot-path", default=None, help="Path to save reward plot (.png)")
args = parser.parse_args()
# BUG FIX: the parsed arguments were previously never applied, so the CLI
# silently had no effect; propagate them into the training config.
config["ENV_NAME"] = args.env
config["TOTAL_TIMESTEPS"] = args.total_timesteps
config["NUM_ENVS"] = args.num_envs
config["NUM_STEPS"] = args.num_steps
config["NUM_MINIBATCHES"] = args.num_minibatches
config["LR"] = args.lr
Next, we set the paths where the trained model and the reward plot are to be saved:
# Fall back to env-specific default filenames when no explicit path was given.
model_save_path = args.model_save_path if args.model_save_path else f"trained_model_{args.env}.pkl"
plot_path = args.plot_path if args.plot_path else f"plot_reward_{args.env}.png"
Just for our own sanity, we inspect the configuration and paths to make sure they are set correctly before beginning the training.
# Echo the effective run configuration before training starts.
for line in (
    f"=== PPO Training: {args.env} environment ===",
    f" Total timesteps : {config['TOTAL_TIMESTEPS']}",
    f" Num envs : {config['NUM_ENVS']}",
    f" Num steps : {config['NUM_STEPS']}",
    f" Learning rate : {config['LR']}",
    f" Model save path : {model_save_path}",
    f" Plot save path : {plot_path}",
    "",
):
    print(line)
at which point we can run the full training
# Build the training function, jit-compile it, and run the full loop.
rng = jax.random.PRNGKey(30)
train_fn = make_train(config)
out = jax.jit(train_fn)(rng)
After the training is completed, we can save the trained model
# The final TrainState is the first element of the returned runner state.
trained_params = out["runner_state"][0].params
# BUG FIX: config has no "MODEL_SAVE_PATH" key (KeyError at runtime);
# use the path derived from the CLI arguments instead.
save_model(trained_params, model_save_path)
and plot the training results:
# Mean episodic return per update, flattened into a single curve.
episode_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)
plt.plot(episode_returns)
plt.xlabel("Updates")
plt.ylabel("Return")
# BUG FIX: save the figure BEFORE show() — in script mode show() blocks and the
# canvas may be empty afterwards — and use plot_path from the CLI arguments
# instead of the nonexistent config["PLOT_TRAINING_PATH"] key (KeyError).
plt.savefig(plot_path, format="png")
plt.show()
# Also persist the raw curve for later analysis (written as rewardovertime.npy).
jnp.save("rewardovertime", episode_returns)