Implementing Custom Algorithms
This tutorial provides an example of implementing a custom algorithm in robomimic. We choose to implement the recently proposed TD3-BC algorithm.
This consists of the following steps:

1. Implement a custom Config class for TD3-BC.
2. Implement a custom Algo class for TD3-BC.
Implementing the Config class

We will implement the config class in config/td3_bc_config.py. We implement a TD3_BCConfig config class that subclasses BaseConfig. Importantly, we set the class variable ALGO_NAME = "td3_bc" to register this config under that algo name. We implement the algo_config function to populate config.algo with the keys needed for the algorithm - it is extremely similar to the BCQConfig implementation. Portions of the code are reproduced below.
class TD3_BCConfig(BaseConfig):
    ALGO_NAME = "td3_bc"

    def algo_config(self):
        # optimization parameters
        self.algo.optim_params.critic.learning_rate.initial = 3e-4      # critic learning rate
        self.algo.optim_params.critic.learning_rate.decay_factor = 0.1  # factor to decay LR by (if epoch schedule non-empty)
        self.algo.optim_params.critic.learning_rate.epoch_schedule = [] # epochs where LR decay occurs
        self.algo.optim_params.critic.regularization.L2 = 0.00          # L2 regularization strength
        self.algo.optim_params.critic.start_epoch = -1                  # number of epochs before starting critic training (-1 means start right away)
        self.algo.optim_params.critic.end_epoch = -1                    # number of epochs before ending critic training (-1 means never end)

        self.algo.optim_params.actor.learning_rate.initial = 3e-4       # actor learning rate
        self.algo.optim_params.actor.learning_rate.decay_factor = 0.1   # factor to decay LR by (if epoch schedule non-empty)
        self.algo.optim_params.actor.learning_rate.epoch_schedule = []  # epochs where LR decay occurs
        self.algo.optim_params.actor.regularization.L2 = 0.00           # L2 regularization strength
        self.algo.optim_params.actor.start_epoch = -1                   # number of epochs before starting actor training (-1 means start right away)
        self.algo.optim_params.actor.end_epoch = -1                     # number of epochs before ending actor training (-1 means never end)

        # alpha value - for weighting critic loss vs. BC loss
        self.algo.alpha = 2.5

        # target network related parameters
        self.algo.discount = 0.99            # discount factor to use
        self.algo.n_step = 1                 # for using n-step returns in TD-updates
        self.algo.target_tau = 0.005         # update rate for target networks
        self.algo.infinite_horizon = False   # if True, scale terminal rewards by 1 / (1 - discount) to treat as infinite horizon

        # ================== Critic Network Config ===================
        self.algo.critic.use_huber = False           # Huber Loss instead of L2 for critic
        self.algo.critic.max_gradient_norm = None    # L2 gradient clipping for critic (None to use no clipping)
        self.algo.critic.value_bounds = None         # optional 2-tuple to ensure lower and upper bound on value estimates

        # critic ensemble parameters (TD3 trick)
        self.algo.critic.ensemble.n = 2              # number of Q networks in the ensemble
        self.algo.critic.ensemble.weight = 1.0       # weighting for mixing min and max for target Q value
        self.algo.critic.layer_dims = (256, 256, 256)  # size of critic MLP

        # ================== Actor Network Config ===================
        # update actor and target networks once every n critic gradient steps
        self.algo.actor.update_freq = 2

        # exploration noise used to form target action for Q-update - clipped Gaussian noise
        self.algo.actor.noise_std = 0.2              # zero-mean gaussian noise with this std is applied to actions
        self.algo.actor.noise_clip = 0.5             # noise is clipped in each dimension to (-noise_clip, noise_clip)

        self.algo.actor.layer_dims = (256, 256, 256) # size of actor MLP
Usually, we only need to implement the algo_config function to populate config.algo with the keys needed for the algorithm, but here we also update the experiment_config and observation_config functions to make it easier to reproduce experiments on gym environments from the paper. See the source file for more details.
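As a rough illustration, overrides of this kind can look like the sketch below. The specific keys and values are illustrative assumptions for low-dimensional gym tasks, not a verbatim copy of the source file; see config/td3_bc_config.py for the actual settings.

# Illustrative sketch only - the exact values are assumptions; see config/td3_bc_config.py for the real ones.
def experiment_config(self):
    super(TD3_BCConfig, self).experiment_config()
    self.experiment.rollout.horizon = 1000               # gym locomotion tasks use long rollout horizons

def observation_config(self):
    super(TD3_BCConfig, self).observation_config()
    self.observation.modalities.obs.low_dim = ["flat"]   # gym observations arrive as a single flat vector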
Finally, we add the line from robomimic.config.td3_bc_config import TD3_BCConfig to config/__init__.py to make sure this Config subclass is registered by robomimic.
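Once this import is in place, the new config can be created by name through robomimic's config factory. A minimal sketch, assuming the standard config_factory interface:

# Minimal sketch: create the registered config by algo name, then inspect / override settings.
from robomimic.config import config_factory

config = config_factory(algo_name="td3_bc")   # resolved via the ALGO_NAME registration above
print(config.algo.alpha)                      # 2.5 by default
config.algo.alpha = 1.0                       # settings can be adjusted before training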
Implementing the Algo class

We will implement the algo class in algo/td3_bc.py. As described in the Algorithm documentation, we first need to implement the algo_config_to_class method - this is straightforward since we don’t have multiple variants of this algorithm. We take special care to register this function under the same algo name that we used for defining the config ("td3_bc").
@register_algo_factory_func("td3_bc")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the TD3_BC algo class to instantiate, along with additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
    """
    # only one variant of TD3_BC for now
    return TD3_BC, {}
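For context, this registration is what lets the training script build the algorithm from the config's algo name. A rough sketch of how that lookup is typically used follows; the surrounding variables (such as shape_meta and device) are assumed to be prepared by the training script.

# Rough sketch: the registered factory function is looked up by algo name at train time.
# shape_meta and device are placeholders assumed to come from the training script.
from robomimic.algo import algo_factory

model = algo_factory(
    algo_name=config.algo_name,               # "td3_bc"
    config=config,
    obs_key_shapes=shape_meta["all_shapes"],
    ac_dim=shape_meta["ac_dim"],
    device=device,
)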
Next, we’ll describe how we implement the methods outlined in the Algorithm documentation. We omit several of the methods, since their implementation is extremely similar to the BCQ implementation. We start by defining the class and implementing _create_networks. The code uses the helper functions _create_critics and _create_actor to create the critic and actor networks, as in the BCQ implementation.
class TD3_BC(PolicyAlgo, ValueAlgo):
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        self.nets = nn.ModuleDict()
        self._create_critics()
        self._create_actor()

        # sync target networks at beginning of training
        with torch.no_grad():
            for critic_ind in range(len(self.nets["critic"])):
                TorchUtils.hard_update(
                    source=self.nets["critic"][critic_ind],
                    target=self.nets["critic_target"][critic_ind],
                )
            TorchUtils.hard_update(
                source=self.nets["actor"],
                target=self.nets["actor_target"],
            )
        self.nets = self.nets.float().to(self.device)

    def _create_critics(self):
        critic_class = ValueNets.ActionValueNetwork
        critic_args = dict(
            obs_shapes=self.obs_shapes,
            ac_dim=self.ac_dim,
            mlp_layer_dims=self.algo_config.critic.layer_dims,
            value_bounds=self.algo_config.critic.value_bounds,
            goal_shapes=self.goal_shapes,
            **ObsNets.obs_encoder_args_from_config(self.obs_config.encoder),
        )

        # Q network ensemble and target ensemble
        self.nets["critic"] = nn.ModuleList()
        self.nets["critic_target"] = nn.ModuleList()
        for _ in range(self.algo_config.critic.ensemble.n):
            critic = critic_class(**critic_args)
            self.nets["critic"].append(critic)

            critic_target = critic_class(**critic_args)
            self.nets["critic_target"].append(critic_target)

    def _create_actor(self):
        actor_class = PolicyNets.ActorNetwork
        actor_args = dict(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            mlp_layer_dims=self.algo_config.actor.layer_dims,
            **ObsNets.obs_encoder_args_from_config(self.obs_config.encoder),
        )
        self.nets["actor"] = actor_class(**actor_args)
        self.nets["actor_target"] = actor_class(**actor_args)
Next, we describe the train_on_batch function, which implements the main training logic. The function trains the critic using the _train_critic_on_batch helper function, and then the actor using the _train_actor_on_batch helper function (the actor is trained at a slower rate, according to the config.algo.actor.update_freq config variable, as in the original authors' implementation). Finally, the target network parameters are moved a little closer to the current network parameters using TorchUtils.soft_update.
def train_on_batch(self, batch, epoch, validate=False):
    """
    Training on a single batch of data.

    Args:
        batch (dict): dictionary with torch.Tensors sampled
            from a data loader and filtered by @process_batch_for_training

        epoch (int): epoch number - required by some Algos that need
            to perform staged training and early stopping

        validate (bool): if True, don't perform any learning updates.

    Returns:
        info (dict): dictionary of relevant inputs, outputs, and losses
            that might be relevant for logging
    """
    with TorchUtils.maybe_no_grad(no_grad=validate):
        info = PolicyAlgo.train_on_batch(self, batch, epoch, validate=validate)

        # Critic training
        no_critic_backprop = validate or (not self._check_epoch(net_name="critic", epoch=epoch))
        with TorchUtils.maybe_no_grad(no_grad=no_critic_backprop):
            critic_info = self._train_critic_on_batch(
                batch=batch,
                epoch=epoch,
                no_backprop=no_critic_backprop,
            )
        info.update(critic_info)

        # update actor and target networks at lower frequency
        if not no_critic_backprop:
            # update counter only on critic training gradient steps
            self.actor_update_counter += 1
        do_actor_update = (self.actor_update_counter % self.algo_config.actor.update_freq == 0)

        # Actor training
        no_actor_backprop = validate or (not self._check_epoch(net_name="actor", epoch=epoch))
        no_actor_backprop = no_actor_backprop or (not do_actor_update)
        with TorchUtils.maybe_no_grad(no_grad=no_actor_backprop):
            actor_info = self._train_actor_on_batch(
                batch=batch,
                epoch=epoch,
                no_backprop=no_actor_backprop,
            )
        info.update(actor_info)

        if not no_actor_backprop:
            # to match original implementation, only update target networks on
            # actor gradient steps
            with torch.no_grad():
                # update the target critic networks
                for critic_ind in range(len(self.nets["critic"])):
                    TorchUtils.soft_update(
                        source=self.nets["critic"][critic_ind],
                        target=self.nets["critic_target"][critic_ind],
                        tau=self.algo_config.target_tau,
                    )

                # update target actor network
                TorchUtils.soft_update(
                    source=self.nets["actor"],
                    target=self.nets["actor_target"],
                    tau=self.algo_config.target_tau,
                )

    return info
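TorchUtils.soft_update performs the usual Polyak averaging for target networks: each target parameter moves a small step (tau, config.algo.target_tau = 0.005 here) toward the corresponding current parameter. The snippet below is an illustrative sketch of the idea, not the robomimic implementation.

import torch

def soft_update_sketch(source, target, tau):
    # Illustrative only: Polyak-average target parameters toward source parameters.
    with torch.no_grad():
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.copy_(tau * param + (1.0 - tau) * target_param)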
Below, we show the helper functions for training the critics, to be explicit in how the Bellman backup is used to construct the TD loss. The target Q values for the TD loss are obtained in the same way as TD3.
def _train_critic_on_batch(self, batch, epoch, no_backprop=False):
    info = OrderedDict()

    # batch variables
    s_batch = batch["obs"]
    a_batch = batch["actions"]
    r_batch = batch["rewards"]
    ns_batch = batch["next_obs"]
    goal_s_batch = batch["goal_obs"]

    # 1 if not done, 0 otherwise
    done_mask_batch = 1. - batch["dones"]
    info["done_masks"] = done_mask_batch

    # Bellman backup for Q-targets
    q_targets = self._get_target_values(
        next_states=ns_batch,
        goal_states=goal_s_batch,
        rewards=r_batch,
        dones=done_mask_batch,
    )
    info["critic/q_targets"] = q_targets

    # Train all critics using this set of targets for regression
    for critic_ind, critic in enumerate(self.nets["critic"]):
        critic_loss = self._compute_critic_loss(
            critic=critic,
            states=s_batch,
            actions=a_batch,
            goal_states=goal_s_batch,
            q_targets=q_targets,
        )
        info["critic/critic{}_loss".format(critic_ind + 1)] = critic_loss

        if not no_backprop:
            critic_grad_norms = TorchUtils.backprop_for_loss(
                net=self.nets["critic"][critic_ind],
                optim=self.optimizers["critic"][critic_ind],
                loss=critic_loss,
                max_grad_norm=self.algo_config.critic.max_gradient_norm,
            )
            info["critic/critic{}_grad_norms".format(critic_ind + 1)] = critic_grad_norms

    return info

def _get_target_values(self, next_states, goal_states, rewards, dones):
    """
    Helper function to get target values for training Q-function with TD-loss.
    """
    with torch.no_grad():
        # get next actions via target actor and noise
        next_target_actions = self.nets["actor_target"](next_states, goal_states)
        noise = (
            torch.randn_like(next_target_actions) * self.algo_config.actor.noise_std
        ).clamp(-self.algo_config.actor.noise_clip, self.algo_config.actor.noise_clip)
        next_actions = (next_target_actions + noise).clamp(-1.0, 1.0)

        # TD3 trick to combine max and min over all Q-ensemble estimates into single target estimates
        all_value_targets = self.nets["critic_target"][0](next_states, next_actions, goal_states).reshape(-1, 1)
        max_value_targets = all_value_targets
        min_value_targets = all_value_targets
        for critic_target in self.nets["critic_target"][1:]:
            all_value_targets = critic_target(next_states, next_actions, goal_states).reshape(-1, 1)
            max_value_targets = torch.max(max_value_targets, all_value_targets)
            min_value_targets = torch.min(min_value_targets, all_value_targets)
        value_targets = self.algo_config.critic.ensemble.weight * min_value_targets + \
            (1. - self.algo_config.critic.ensemble.weight) * max_value_targets
        q_targets = rewards + dones * self.discount * value_targets

    return q_targets

def _compute_critic_loss(self, critic, states, actions, goal_states, q_targets):
    """
    Helper function to compute loss between estimated Q-values and target Q-values.
    """
    q_estimated = critic(states, actions, goal_states)
    if self.algo_config.critic.use_huber:
        critic_loss = nn.SmoothL1Loss()(q_estimated, q_targets)
    else:
        critic_loss = nn.MSELoss()(q_estimated, q_targets)
    return critic_loss
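Written out, the target computed by _get_target_values above is the clipped double-Q target from TD3, restated here from the code (with ensemble weight $w$ = config.algo.critic.ensemble.weight, discount $\gamma$, done mask $d$ equal to 1 when the transition is not terminal, target actor $\mu'$, and target critics $Q'_i$):

$$
\tilde{a} = \mathrm{clip}\big(\mu'(s') + \mathrm{clip}(\epsilon, -c, c),\ -1,\ 1\big), \qquad \epsilon \sim \mathcal{N}(0, \sigma^2)
$$
$$
y = r + d \, \gamma \Big[ w \min_i Q'_i(s', \tilde{a}) + (1 - w) \max_i Q'_i(s', \tilde{a}) \Big]
$$

where $\sigma$ = config.algo.actor.noise_std and $c$ = config.algo.actor.noise_clip.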
Next, we show the helper function for training the actor, which is trained using a weighted combination of the TD3 (DDPG) loss and a BC loss.
def _train_actor_on_batch(self, batch, epoch, no_backprop=False):
    info = OrderedDict()

    # Actor loss (update with mixture of DDPG loss and BC loss)
    s_batch = batch["obs"]
    a_batch = batch["actions"]
    goal_s_batch = batch["goal_obs"]

    # lambda mixture weight is combination of hyperparameter (alpha) and Q-value normalization
    actor_actions = self.nets["actor"](s_batch, goal_s_batch)
    Q_values = self.nets["critic"][0](s_batch, actor_actions, goal_s_batch)
    lam = self.algo_config.alpha / Q_values.abs().mean().detach()
    actor_loss = -lam * Q_values.mean() + nn.MSELoss()(actor_actions, a_batch)
    info["actor/loss"] = actor_loss

    if not no_backprop:
        actor_grad_norms = TorchUtils.backprop_for_loss(
            net=self.nets["actor"],
            optim=self.optimizers["actor"],
            loss=actor_loss,
        )
        info["actor/grad_norms"] = actor_grad_norms

    return info
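In equation form, the actor objective above is the TD3-BC policy loss, restated from the code (with $\alpha$ = config.algo.alpha, batch size $N$, and the first critic $Q_1$):

$$
\lambda = \frac{\alpha}{\frac{1}{N}\sum_j \big| Q_1(s_j, \pi(s_j)) \big|}, \qquad
\mathcal{L}_{\text{actor}} = -\lambda \cdot \frac{1}{N}\sum_j Q_1(s_j, \pi(s_j)) + \mathrm{MSE}\big(\pi(s), a\big)
$$

where the Q-value normalization term in $\lambda$ is detached from the computation graph, as in the code above.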
Finally, we describe the get_action implementation, which is used at test time during rollouts. The implementation is extremely simple - just query the actor for an action.
def get_action(self, obs_dict, goal_dict=None):
    """
    Get policy action outputs.

    Args:
        obs_dict (dict): current observation
        goal_dict (dict): (optional) goal

    Returns:
        action (torch.Tensor): action tensor
    """
    assert not self.nets.training
    return self.nets["actor"](obs_dict=obs_dict, goal_dict=goal_dict)
Finally, we add the line from robomimic.algo.td3_bc import TD3_BC to algo/__init__.py to make sure this Algo subclass is registered by robomimic.
That’s it! See algo/td3_bc.py for the complete implementation, and compare it to algo/bcq.py to see the similarity between the two implementations.
We can now run the generate_config_templates.py script to generate the json template for our new algorithm, and then launch training with it on our desired dataset.
# generate ../exps/templates/td3_bc.json
$ python generate_config_templates.py
# run training
$ python train.py --config ../exps/templates/td3_bc.json --dataset /path/to/walker2d_medium_expert.hdf5