Custom Brax Environments

Published:

Context

I think we all recognize brax environments are useful for very fast RL training. How do we create custom envs which are useful?

Setup

Have this at the top:


import jax
import jax.numpy as jnp
from brax.envs.base import Env, State

And is setup with class EnvName(Env):

Init function

Anything which is assigned to self stays the same throughout training.

Obs and action space

In the init, you must define the action space, observation space and backend for jax

e.g.

 @property
def observation_size(self) -> int:
    return 4

@property
def action_size(self) -> int:
    return 4

@property
def backend(self) -> str:
    return "abstract"

For hybrid observations, we need to do something special. First, pass the state out of step flattened. Then, reshape it back to the shape we want. For example of a encoder:

import flax.linen as nn
import jax.numpy as jnp

class HybridEncoder(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # Reshape back to image
        image_part = x[..., :3072].reshape((-1, 32, 32, 3))
        symbolic_part = x[..., 3072:]

        # Pass through cnn
        x_img = nn.Conv(features=32, kernel_size=(3, 3))(image_part)
        x_img = nn.relu(x_img)
        x_img = nn.max_pool(x_img, window_shape=(2, 2), strides=(2, 2))
        x_img = x_img.reshape((x_img.shape[0], -1))  # Flatten
        
        # Concat and pass to mlp
        combined = jnp.concatenate([x_img, symbolic_part], axis=-1)
        x_mlp = nn.Dense(features=256)(combined)
        return nn.relu(x_mlp)

Then we have to do something like this (TODO: UPDATE):

def make_my_networks(
    observation_size, 
    action_size, 
    preprocess_observations_fn):
    
    return ppo_networks.make_ppo_networks(
        observation_size=observation_size,
        action_size=action_size,
        preprocess_observations_fn=preprocess_observations_fn,
        # Brax will append policy/value heads to this encoder
        policy_network_factory=lambda: HybridEncoder() 
    )

# Pass the factory to ppo.train
train_fn = functools.partial(
    ppo.train,
    num_timesteps=100_000,
    network_factory=make_my_networks, # Use the custom factory here
    # ... other hyperparams
)

Reset function

Reset returns a state object. Here is an example:

def reset(self, rng: jnp.ndarray) -> State:
    initial_obs = jnp.zeros(self.observation_size) 

    return State(
        pipeline_state=None,  # Set to none for custom envs
        obs=initial_obs,
        reward=jnp.array(0.0),
        done=jnp.array(0.0),
        metrics={'episode_reward': 0.0}, # Accumulates over training automatically. Use for wandb. Only for floats
        info={} # For things like task id or single step metrics
    )

Step function

Step returns a state object:

def step(self, state: State, action: jnp.ndarray) -> State:
    # Calculate new state
    new_obs = state.obs + action
    
    # Calcuate reward 
    reward = -jnp.sum(jnp.square(new_obs))

    # Done is 1 for done, 0 for not
    done = jnp.where(jnp.abs(new_obs) > 10.0, 1.0, 0.0)
    
    # Return new state by doing state.replace
    return state.replace(
        obs=new_obs,
        reward=reward,
        done=done
    )