robomimic.algo package#

Submodules#

robomimic.algo.algo module#

This file contains base classes that other algorithm classes subclass. Each algorithm file also implements an algorithm factory function that takes in an algorithm config (config.algo) and returns the particular Algo subclass that should be instantiated, along with any extra kwargs. These factory functions are registered into a global dictionary with the @register_algo_factory_func function decorator. This makes it easy for @algo_factory to instantiate the correct Algo subclass.

class robomimic.algo.algo.Algo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: object

Base algorithm class that all other algorithms subclass. Defines several functions that should be overridden by subclasses, in order to provide a standard API to be used by training functions such as @run_epoch in utils/train_utils.py.

deserialize(model_dict)#

Load model from a checkpoint.

Parameters

model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

on_epoch_end(epoch)#

Called at the end of each epoch.

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

reset()#

Reset algo state to prepare for environment rollouts.

serialize()#

Get dictionary of current model parameters.

set_eval()#

Prepare networks for evaluation.

set_train()#

Prepare networks for training.

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)
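The methods above are the hooks that a training function such as @run_epoch calls for every batch. The sketch below shows one plausible calling order, using a hypothetical ToyAlgo that works on plain Python floats instead of torch networks. ToyAlgo and run_epoch_sketch are illustrative names; the real @run_epoch in utils/train_utils.py may differ in details such as timing, number of gradient steps, and logging keys.

```python
class ToyAlgo:
    """Hypothetical stand-in that mimics the Algo training hooks with plain floats."""

    def __init__(self, lr=0.1):
        self.weight = 0.0  # single parameter of a linear policy: action = weight * obs
        self.lr = lr
        self.training = True

    def set_train(self):
        self.training = True

    def set_eval(self):
        self.training = False

    def process_batch_for_training(self, batch):
        # keep only the keys that the loss needs
        return {"obs": batch["obs"], "actions": batch["actions"]}

    def train_on_batch(self, batch, epoch, validate=False):
        errors = [self.weight * o - a for o, a in zip(batch["obs"], batch["actions"])]
        loss = sum(e * e for e in errors) / len(errors)
        if not validate:
            # one gradient-descent step on the mean squared error
            grad = sum(2.0 * e * o for e, o in zip(errors, batch["obs"])) / len(errors)
            self.weight -= self.lr * grad
        return {"loss": loss}

    def log_info(self, info):
        return {"Loss": info["loss"]}

    def on_epoch_end(self, epoch):
        pass  # e.g. learning-rate schedules could be stepped here


def run_epoch_sketch(algo, batches, epoch, validate=False):
    """Hypothetical epoch loop; returns each logged statistic averaged over batches."""
    if validate:
        algo.set_eval()
    else:
        algo.set_train()
    totals = {}
    for batch in batches:
        input_batch = algo.process_batch_for_training(batch)
        info = algo.train_on_batch(input_batch, epoch, validate=validate)
        for name, value in algo.log_info(info).items():
            totals[name] = totals.get(name, 0.0) + value
    if not validate:
        algo.on_epoch_end(epoch)
    return {name: total / len(batches) for name, total in totals.items()}
```

Passing validate=True computes and logs the loss without touching the weight, which matches the documented meaning of @validate.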

class robomimic.algo.algo.HierarchicalAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.Algo

Base class for all hierarchical algorithms that consist of (1) subgoal planning and (2) subgoal-conditioned policy learning.

property current_subgoal#

Get the current subgoal for conditioning the low-level policy

Returns

predicted subgoal

Return type

current subgoal (dict)

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

get_subgoal_predictions(obs_dict, goal_dict=None)#

Get subgoal predictions from high-level subgoal planner.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

predicted subgoal

Return type

subgoal (dict)

class robomimic.algo.algo.PlannerAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.Algo

Base class for all algorithms that can be used for planning subgoals conditioned on current observations and potential goal observations.

get_subgoal_predictions(obs_dict, goal_dict=None)#

Get predicted subgoal outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

name -> Tensor [batch_size, …]

Return type

subgoal prediction (dict)

sample_subgoals(obs_dict, goal_dict, num_samples=1)#

For planners that rely on sampling subgoals: samples @num_samples subgoals per observation.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

  • num_samples (int) – number of subgoals to sample per observation

Returns

name -> Tensor [batch_size, num_samples, …]

Return type

subgoals (dict)

class robomimic.algo.algo.PolicyAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.Algo

Base class for all algorithms that can be used as policies.

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

class robomimic.algo.algo.RolloutPolicy(policy, obs_normalization_stats=None)#

Bases: object

Wraps @Algo object to make it easy to run policies in a rollout loop.

start_episode()#

Prepare the policy to start a new rollout.
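A rollout loop typically calls start_episode() once per episode and then queries the wrapped policy at every step. In this sketch, ToyPolicyWrapper and ToyEnv are hypothetical stand-ins, and calling the wrapper directly on an observation dict to get an action is an assumption, since the call interface is not documented in this section.

```python
class ToyPolicyWrapper:
    """Hypothetical stand-in for a RolloutPolicy-like wrapper."""

    def __init__(self):
        self.steps_since_start = 0

    def start_episode(self):
        # clear per-episode state (e.g. an RNN hidden state or a cached subgoal)
        self.steps_since_start = 0

    def __call__(self, ob):
        self.steps_since_start += 1
        return -0.5 * ob["pos"]  # move halfway toward the origin


class ToyEnv:
    """Hypothetical 1-D task: success once |pos| < 0.1."""

    def reset(self):
        self.pos = 4.0
        return {"pos": self.pos}

    def step(self, action):
        self.pos += action
        done = abs(self.pos) < 0.1
        return {"pos": self.pos}, (1.0 if done else 0.0), done


def rollout(policy, env, horizon):
    policy.start_episode()  # called before the first action of each episode
    ob = env.reset()
    total_reward, steps, done = 0.0, 0, False
    for steps in range(1, horizon + 1):
        ob, reward, done = env.step(policy(ob))
        total_reward += reward
        if done:
            break
    return {"return": total_reward, "horizon": steps, "success": done}
```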

class robomimic.algo.algo.ValueAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.Algo

Base class for all algorithms that can learn a value function.

get_state_action_value(obs_dict, actions, goal_dict=None)#

Get state-action value outputs.

Parameters
  • obs_dict (dict) – current observation

  • actions (torch.Tensor) – action

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

get_state_value(obs_dict, goal_dict=None)#

Get state value outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

robomimic.algo.algo.algo_factory(algo_name, config, obs_key_shapes, ac_dim, device)#

Factory function for creating algorithms based on the algorithm name and config.

Parameters
  • algo_name (str) – the algorithm name

  • config (BaseConfig instance) – config object

  • obs_key_shapes (OrderedDict) – dictionary that maps observation keys to shapes

  • ac_dim (int) – dimension of action space

  • device (torch.device) – where the algo should live (e.g. cpu, gpu)

robomimic.algo.algo.algo_name_to_factory_func(algo_name)#

Uses registry to retrieve algo factory function from algo name.

Parameters

algo_name (str) – the algorithm name

robomimic.algo.algo.register_algo_factory_func(algo_name)#

Function decorator to register algo factory functions that map algo configs to algo classes. Each algorithm implements such a function and decorates it with this decorator.

Parameters

algo_name (str) – the algorithm name to register the algorithm under
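The registry pattern described for these three functions can be sketched as follows. The registry dictionary name, ToyBC, ToyBC_RNN, the config keys, and the simplified algo_factory signature are all hypothetical; the real functions take a full config object, observation key shapes, action dimension, and device.

```python
ALGO_FACTORY_REGISTRY = {}  # hypothetical global registry: algo name -> factory function


def register_algo_factory_func(algo_name):
    """Decorator that records a factory function under algo_name."""
    def decorator(factory_func):
        ALGO_FACTORY_REGISTRY[algo_name] = factory_func
        return factory_func
    return decorator


def algo_name_to_factory_func(algo_name):
    if algo_name not in ALGO_FACTORY_REGISTRY:
        raise KeyError(f"no factory registered for algo '{algo_name}'")
    return ALGO_FACTORY_REGISTRY[algo_name]


class ToyAlgo:
    def __init__(self, algo_config, ac_dim, **extra_kwargs):
        self.algo_config = algo_config
        self.ac_dim = ac_dim
        self.extra_kwargs = extra_kwargs


class ToyBC(ToyAlgo):
    pass


class ToyBC_RNN(ToyAlgo):
    pass


@register_algo_factory_func("toy_bc")
def toy_bc_factory(algo_config):
    # choose the subclass from the config and return any extra constructor kwargs
    if algo_config.get("rnn_enabled", False):
        return ToyBC_RNN, {"seq_len": algo_config["rnn_horizon"]}
    return ToyBC, {}


def algo_factory(algo_name, algo_config, ac_dim):
    factory_func = algo_name_to_factory_func(algo_name)
    algo_cls, extra_kwargs = factory_func(algo_config)
    return algo_cls(algo_config=algo_config, ac_dim=ac_dim, **extra_kwargs)
```

Keeping the subclass choice inside each algorithm's factory means the generic factory never needs to know about variants such as RNN or GMM policies.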

robomimic.algo.bc module#

Implementation of Behavioral Cloning (BC).

class robomimic.algo.bc.BC(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.PolicyAlgo

Standard BC training.

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

class robomimic.algo.bc.BC_GMM(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.bc.BC_Gaussian

BC training with a Gaussian Mixture Model policy.

class robomimic.algo.bc.BC_Gaussian(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.bc.BC

BC training with a Gaussian policy.

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

class robomimic.algo.bc.BC_RNN(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.bc.BC

BC training with an RNN policy.

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

reset()#

Reset algo state to prepare for environment rollouts.

class robomimic.algo.bc.BC_RNN_GMM(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.bc.BC_RNN

BC training with an RNN GMM policy.

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

class robomimic.algo.bc.BC_VAE(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.bc.BC

BC training with a VAE policy.

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

train_on_batch(batch, epoch, validate=False)#

Overrides the superclass method to set the categorical temperature for categorical VAEs.

robomimic.algo.bcq module#

Batch-Constrained Q-Learning (BCQ), with support for more general generative action models (the original paper uses a cVAE). (Paper - https://arxiv.org/abs/1812.02900).
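BCQ-style action selection draws several candidate actions from the generative action model and keeps the one the critic scores highest. The sketch below is a simplified illustration under stated assumptions: it omits the perturbation step used in the BCQ paper, uses scalar actions, and sample_action / q_value are hypothetical callables standing in for the VAE (or GMM) and the critic.

```python
import random


def bcq_select_action(obs, sample_action, q_value, num_samples=10, rng=None):
    """Draw candidate actions from a generative model and return the highest-valued one."""
    if rng is None:
        rng = random.Random(0)
    candidates = [sample_action(obs, rng) for _ in range(num_samples)]
    # the critic only ranks actions that the data-driven sampler proposes
    return max(candidates, key=lambda action: q_value(obs, action))
```

Restricting the argmax to sampled actions is what keeps the policy close to the data distribution, instead of maximizing Q over the whole action space.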

class robomimic.algo.bcq.BCQ(**kwargs)#

Bases: robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo

Default BCQ training, based on https://arxiv.org/abs/1812.02900 and https://github.com/sfujim/BCQ

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

get_state_action_value(obs_dict, actions, goal_dict=None)#

Get state-action value outputs.

Parameters
  • obs_dict (dict) – current observation

  • actions (torch.Tensor) – action

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

get_state_value(obs_dict, goal_dict=None)#

Get state value outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

on_epoch_end(epoch)#

Called at the end of each epoch.

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

set_discount(discount)#

Useful function to modify discount factor if necessary (e.g. for n-step returns).
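One reason to change the discount is n-step bootstrapping: an n-step target sums n discounted rewards and then bootstraps with a factor of discount**n rather than discount. A minimal sketch, assuming scalar rewards and a single bootstrap value; n_step_target is a hypothetical helper, not part of robomimic.

```python
def n_step_target(rewards, bootstrap_value, discount, done=False):
    """n-step TD target: discounted reward sum plus discount**n times the bootstrap value."""
    n = len(rewards)
    reward_sum = sum((discount ** k) * r for k, r in enumerate(rewards))
    bootstrap_discount = discount ** n  # effective discount applied to the bootstrap term
    return reward_sum + (0.0 if done else bootstrap_discount * bootstrap_value)
```

With n = 1 this reduces to the usual one-step target r + discount * V(s').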

set_train()#

Prepare networks for training. Overrides the superclass method so that target networks always stay in evaluation mode.

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

class robomimic.algo.bcq.BCQ_Distributional(**kwargs)#

Bases: robomimic.algo.bcq.BCQ

BCQ with distributional critics. Distributional critics output categorical distributions over a discrete set of values instead of expected returns. Some parts of this implementation were adapted from ACME (https://github.com/deepmind/acme).
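A categorical critic predicts probabilities over a fixed grid of return values (atoms); the scalar Q-value is the probability-weighted sum of the atoms. A minimal pure-Python sketch; the grid bounds and atom count are hypothetical parameters, not robomimic config keys.

```python
import math


def make_atoms(v_min, v_max, num_atoms):
    """Evenly spaced support of return values."""
    delta = (v_max - v_min) / (num_atoms - 1)
    return [v_min + i * delta for i in range(num_atoms)]


def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_value(logits, atoms):
    """Collapse a categorical return distribution to its mean."""
    probs = softmax(logits)
    return sum(p * z for p, z in zip(probs, atoms))
```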

class robomimic.algo.bcq.BCQ_GMM(**kwargs)#

Bases: robomimic.algo.bcq.BCQ

A simple modification to BCQ that replaces the VAE used to sample action proposals from the batch with a GMM.

robomimic.algo.cql module#

Implementation of Conservative Q-Learning (CQL). Based on https://github.com/aviralkumar2907/CQL (Paper - https://arxiv.org/abs/2006.04779).

class robomimic.algo.cql.CQL(**kwargs)#

Bases: robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo

CQL-extension of SAC for the off-policy, offline setting. See https://arxiv.org/abs/2006.04779
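The core CQL idea is a penalty that pushes down Q-values at sampled actions (through a log-sum-exp) while pushing up the Q-value at the dataset action. The sketch below computes a simplified version of that penalty from precomputed Q-values; it leaves out the importance-sampling corrections and the choice of action proposals used in practice, and cql_penalty is a hypothetical helper.

```python
import math


def logsumexp(values):
    shift = max(values)  # numerically stable log(sum(exp(v)))
    return shift + math.log(sum(math.exp(v - shift) for v in values))


def cql_penalty(q_sampled, q_data, cql_weight=1.0):
    """Simplified conservative penalty, averaged over a batch.

    q_sampled: for each state, Q-values at actions drawn from some proposal
    q_data: for each state, the Q-value at the dataset action
    """
    per_state = [logsumexp(qs) - qd for qs, qd in zip(q_sampled, q_data)]
    return cql_weight * sum(per_state) / len(per_state)
```

The penalty is large when out-of-distribution actions receive higher Q-values than the dataset action, which is the over-estimation CQL is designed to suppress.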

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

get_state_action_value(obs_dict, actions, goal_dict=None)#

Get state-action value outputs.

Parameters
  • obs_dict (dict) – current observation

  • actions (torch.Tensor) – action

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

property log_cql_weight#

property log_entropy_weight#

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

on_epoch_end(epoch)#

Called at the end of each epoch.

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant info and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

set_train()#

Prepare networks for training. Overrides the superclass method so that target networks always stay in evaluation mode.

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

robomimic.algo.gl module#

Subgoal prediction models, used in HBC / IRIS.

class robomimic.algo.gl.GL(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.PlannerAlgo

Implements the goal prediction component for HBC and IRIS.

get_action(obs_dict, goal_dict=None)#

Get policy action outputs. Assumes one input observation (first dimension should be 1).

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

get_actor_goal_for_training_from_processed_batch(processed_batch, **kwargs)#

Retrieve subgoals from processed batch to use for training the actor. Subclasses can modify this function to change the subgoals.

Parameters

processed_batch (dict) – processed batch from @process_batch_for_training

Returns

subgoal observations to condition actor on

Return type

actor_subgoals (dict)

get_subgoal_predictions(obs_dict, goal_dict=None)#

Takes a batch of observations and predicts a batch of subgoals.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

name -> Tensor [batch_size, …]

Return type

subgoal prediction (dict)

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

sample_subgoals(obs_dict, goal_dict=None, num_samples=1)#

Sample @num_samples subgoals from the network per observation. Since this class implements a deterministic subgoal prediction, this function returns identical subgoals for each input observation.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

  • num_samples (int) – number of subgoals to sample per observation

Returns

name -> Tensor [batch_size, num_samples, …]

Return type

subgoals (dict)
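Because GL predicts one subgoal per observation, sampling reduces to repeating that prediction along a new samples dimension, turning [batch_size, …] into [batch_size, num_samples, …]. A sketch with nested Python lists; with torch tensors the same reshaping could be written as x.unsqueeze(1).expand(-1, num_samples, *x.shape[1:]). tile_subgoals is a hypothetical helper.

```python
import copy


def tile_subgoals(subgoal_pred, num_samples):
    """Repeat each deterministic prediction: name -> [B, ...] becomes name -> [B, num_samples, ...]."""
    return {
        # deepcopy so that editing one sample does not silently alias the others
        name: [[copy.deepcopy(item) for _ in range(num_samples)] for item in values]
        for name, values in subgoal_pred.items()
    }
```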

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

class robomimic.algo.gl.GL_VAE(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.gl.GL

Implements goal prediction via VAE.

get_actor_goal_for_training_from_processed_batch(processed_batch, use_latent_subgoals=False, use_prior_correction=False, num_prior_samples=100, **kwargs)#

Modifies the superclass method to support a @use_latent_subgoals option. The VAE can optionally return latent subgoals by passing the subgoal observations in the batch through the encoder.

Parameters
  • processed_batch (dict) – processed batch from @process_batch_for_training

  • use_latent_subgoals (bool) – if True, condition the actor on latent subgoals by using the VAE encoder to encode subgoal observations at train-time, and using the VAE prior to generate latent subgoals at test-time

  • use_prior_correction (bool) – if True, use a “prior correction” trick to choose a latent subgoal sampled from the prior that is close to the latent from the VAE encoder (posterior). This can help with issues at test-time where the encoder latent distribution might not match the prior latent distribution.

  • num_prior_samples (int) – number of VAE prior samples to take and choose among, if @use_prior_correction is True

Returns

subgoal observations to condition actor on

Return type

actor_subgoals (dict)
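The prior correction trick described for @use_prior_correction can be sketched as a nearest-neighbour search over prior samples. This assumes squared Euclidean distance as the closeness measure and a caller-supplied prior sampler; both are illustrative assumptions, and prior_corrected_latent is a hypothetical name.

```python
import random


def prior_corrected_latent(posterior_latent, sample_prior, num_prior_samples=100, rng=None):
    """Return the prior sample closest (squared Euclidean distance) to the posterior latent."""
    if rng is None:
        rng = random.Random(0)
    candidates = [sample_prior(rng) for _ in range(num_prior_samples)]

    def sq_dist(z):
        return sum((a - b) ** 2 for a, b in zip(z, posterior_latent))

    return min(candidates, key=sq_dist)
```

Because the returned latent is an actual prior sample, the actor is trained on latents that look like the ones it will see at test time, when only the prior is available.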

get_subgoal_predictions(obs_dict, goal_dict=None)#

Takes a batch of observations and predicts a batch of subgoals.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

name -> Tensor [batch_size, …]

Return type

subgoal prediction (dict)

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

sample_subgoals(obs_dict, goal_dict=None, num_samples=1)#

Sample @num_samples subgoals from the VAE per observation.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

  • num_samples (int) – number of subgoals to sample per observation

Returns

name -> Tensor [batch_size, num_samples, …]

Return type

subgoals (dict)

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

class robomimic.algo.gl.ValuePlanner(planner_algo_class, value_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.PlannerAlgo, robomimic.algo.algo.ValueAlgo

Base class for all algorithms that are used for planning subgoals based on (1) a @PlannerAlgo that is used to sample candidate subgoals and (2) a @ValueAlgo that is used to select one of the subgoals.
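The two-stage selection can be sketched as: sample candidates from the planner, score each with the value function, and keep the best. sample_subgoals and value_fn here are hypothetical callables on scalars, standing in for @PlannerAlgo.sample_subgoals and a @ValueAlgo value method.

```python
def value_planner_select(obs, sample_subgoals, value_fn, num_samples=8):
    """Sample candidate subgoals, score each with a value function, keep the best."""
    candidates = sample_subgoals(obs, num_samples)
    scores = [value_fn(obs, subgoal) for subgoal in candidates]
    best_index = max(range(len(candidates)), key=lambda i: scores[i])
    return candidates[best_index], scores[best_index]
```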

deserialize(model_dict)#

Load model from a checkpoint.

Parameters

model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes

get_state_action_value(obs_dict, actions, goal_dict=None)#

Get state-action value outputs.

Parameters
  • obs_dict (dict) – current observation

  • actions (torch.Tensor) – action

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

get_state_value(obs_dict, goal_dict=None)#

Get state value outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

get_subgoal_predictions(obs_dict, goal_dict=None)#

Takes a batch of observations and predicts a batch of subgoals.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

name -> Tensor [batch_size, …]

Return type

subgoal prediction (dict)

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

on_epoch_end(epoch)#

Called at the end of each epoch.

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

reset()#

Reset algo state to prepare for environment rollouts.

sample_subgoals(obs_dict, goal_dict, num_samples=1)#

Sample @num_samples subgoals from the planner algo per observation.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

  • num_samples (int) – number of subgoals to sample per observation

Returns

name -> Tensor [batch_size, num_samples, …]

Return type

subgoals (dict)

serialize()#

Get dictionary of current model parameters.

set_eval()#

Prepare networks for evaluation.

set_train()#

Prepare networks for training.

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

robomimic.algo.hbc module#

Implementation of Hierarchical Behavioral Cloning, where a planner model outputs subgoals (future observations), and an actor model is conditioned on the subgoals to try and reach them. Largely based on the Generalization Through Imitation (GTI) paper (see https://arxiv.org/abs/2003.06085).
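At rollout time, the planner and actor interact roughly as below: the planner proposes a subgoal, and the actor acts toward it until the subgoal is refreshed. Refreshing every subgoal_update_interval steps is an assumption made for illustration, as are the toy 1-D planner, actor, and environment used in the test.

```python
def hierarchical_rollout(obs, goal, planner, actor, step_env, num_steps, subgoal_update_interval):
    """Alternate high-level subgoal proposals with low-level actions."""
    subgoal = None
    trajectory = []
    for t in range(num_steps):
        if t % subgoal_update_interval == 0:
            subgoal = planner(obs, goal)  # high level: propose a future observation
        action = actor(obs, subgoal)      # low level: act to reach the current subgoal
        obs = step_env(obs, action)
        trajectory.append((subgoal, obs))
    return obs, trajectory
```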

class robomimic.algo.hbc.HBC(planner_algo_class, policy_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.algo.HierarchicalAlgo

Default HBC training, largely based on https://arxiv.org/abs/2003.06085

property current_subgoal#

Return the current subgoal (at rollout time) with shape (batch, …)

deserialize(model_dict)#

Load model from a checkpoint.

Parameters

model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

on_epoch_end(epoch)#

Called at the end of each epoch.

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

reset()#

Reset algo state to prepare for environment rollouts.

serialize()#

Get dictionary of current model parameters.

set_eval()#

Prepare networks for evaluation.

set_train()#

Prepare networks for training.

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

robomimic.algo.iris module#

Implementation of IRIS (https://arxiv.org/abs/1911.05321).

class robomimic.algo.iris.IRIS(planner_algo_class, value_algo_class, policy_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#

Bases: robomimic.algo.hbc.HBC, robomimic.algo.algo.ValueAlgo

Implementation of IRIS (https://arxiv.org/abs/1911.05321).

get_state_action_value(obs_dict, actions, goal_dict=None)#

Get state-action value outputs.

Parameters
  • obs_dict (dict) – current observation

  • actions (torch.Tensor) – action

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

get_state_value(obs_dict, goal_dict=None)#

Get state value outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

robomimic.algo.td3_bc module#

Implementation of TD3-BC. Based on https://github.com/sfujim/TD3_BC (Paper - https://arxiv.org/abs/2106.06860).

Note that several parts are exactly the same as the BCQ implementation, such as @_create_critics, @process_batch_for_training, and @_train_critic_on_batch. They are replicated here (instead of subclassing from the BCQ algo class) to be explicit and have implementation details self-contained in this file.

class robomimic.algo.td3_bc.TD3_BC(**kwargs)#

Bases: robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo

Default TD3_BC training, based on https://arxiv.org/abs/2106.06860 and https://github.com/sfujim/TD3_BC.
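TD3-BC adds a behavior cloning term to the TD3 actor objective and rescales the Q term by lambda = alpha / mean(|Q|), so that the balance between the two terms does not depend on the scale of the Q-values. A minimal sketch on scalar actions; alpha = 2.5 follows the TD3-BC paper's default, and td3_bc_actor_loss is a hypothetical helper rather than robomimic's implementation.

```python
def td3_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha=2.5):
    """TD3-BC actor objective on a batch (scalar actions for simplicity).

    q_values: critic values Q(s, pi(s)) for the policy's actions
    """
    n = len(q_values)
    lmbda = alpha / (sum(abs(q) for q in q_values) / n)  # treated as a constant, no gradient
    q_term = -lmbda * sum(q_values) / n                    # maximize Q
    bc_term = sum((p - a) ** 2 for p, a in zip(policy_actions, dataset_actions)) / n
    return q_term + bc_term
```

Scaling all Q-values by a constant leaves the Q term unchanged, which is the point of the normalization.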

get_action(obs_dict, goal_dict=None)#

Get policy action outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

action tensor

Return type

action (torch.Tensor)

get_state_action_value(obs_dict, actions, goal_dict=None)#

Get state-action value outputs.

Parameters
  • obs_dict (dict) – current observation

  • actions (torch.Tensor) – action

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

get_state_value(obs_dict, goal_dict=None)#

Get state value outputs.

Parameters
  • obs_dict (dict) – current observation

  • goal_dict (dict) – (optional) goal

Returns

value tensor

Return type

value (torch.Tensor)

log_info(info)#

Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.

Parameters

info (dict) – dictionary of info

Returns

name -> summary statistic

Return type

loss_log (dict)

on_epoch_end(epoch)#

Called at the end of each epoch.

process_batch_for_training(batch)#

Processes input batch from a data loader to filter out relevant information and prepare the batch for training.

Exactly the same as BCQ.

Parameters

batch (dict) – dictionary with torch.Tensors sampled from a data loader

Returns

processed and filtered batch that will be used for training

Return type

input_batch (dict)

set_discount(discount)#

Useful function to modify discount factor if necessary (e.g. for n-step returns).

set_train()#

Prepare networks for training. Overrides the superclass method so that target networks always stay in evaluation mode.

train_on_batch(batch, epoch, validate=False)#

Training on a single batch of data.

Parameters
  • batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training

  • epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping

  • validate (bool) – if True, don’t perform any learning updates.

Returns

dictionary of relevant inputs, outputs, and losses that might be relevant for logging

Return type

info (dict)

Module contents#