robomimic.algo package#
Submodules#
robomimic.algo.algo module#
This file contains base classes that other algorithm classes subclass. Each algorithm file also implements an algorithm factory function that takes in an algorithm config (config.algo) and returns the particular Algo subclass that should be instantiated, along with any extra kwargs. These factory functions are registered into a global dictionary with the @register_algo_factory_func decorator, which makes it easy for @algo_factory to instantiate the correct Algo subclass.
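A minimal sketch of this registration pattern, following the description above; the `MyAlgo` class and the `"my_algo"` name are hypothetical placeholders, not part of robomimic.

```python
from robomimic.algo import register_algo_factory_func
from robomimic.algo.algo import PolicyAlgo

class MyAlgo(PolicyAlgo):  # hypothetical Algo subclass
    pass

@register_algo_factory_func("my_algo")  # "my_algo" is an illustrative name
def algo_config_to_class(algo_config):
    # map an algorithm config (config.algo) to the Algo subclass that should
    # be instantiated, plus any extra kwargs for its constructor
    return MyAlgo, {}
```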
- class robomimic.algo.algo.Algo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
object
Base algorithm class that all other algorithms subclass. Defines several functions that should be overridden by subclasses in order to provide a standard API for training functions such as @run_epoch in utils/train_utils.py.
- deserialize(model_dict)#
Load model from a checkpoint.
- Parameters:
model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- on_epoch_end(epoch)#
Called at the end of each epoch.
- postprocess_batch_for_training(batch, obs_normalization_stats)#
Does some operations (like channel swap, uint8 to float conversion, normalization) after @process_batch_for_training is called, in order to ensure these operations take place on GPU.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader. Assumed to be on the device where training will occur (after @process_batch_for_training is called)
obs_normalization_stats (dict or None) – if provided, this should map observation keys to dicts with a “mean” and “std” of shape (1, …) where … is the default shape for the observation.
- Returns:
postprocessed batch
- Return type:
batch (dict)
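A hedged example of the @obs_normalization_stats structure described above; the observation key "robot0_eef_pos" is an illustrative placeholder.

```python
import numpy as np

# each observation key maps to a dict with "mean" and "std" of shape (1, ...),
# matching the default shape of that observation
obs_normalization_stats = {
    "robot0_eef_pos": {  # illustrative low-dim observation key
        "mean": np.zeros((1, 3), dtype=np.float32),
        "std": np.ones((1, 3), dtype=np.float32),
    },
}
```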
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- reset()#
Reset algo state to prepare for environment rollouts.
- serialize()#
Get dictionary of current model parameters.
- set_eval()#
Prepare networks for evaluation.
- set_train()#
Prepare networks for training.
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
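A minimal sketch of how these methods fit together in one training epoch, loosely following @run_epoch in utils/train_utils.py; `model`, `data_loader`, `epoch`, and `obs_normalization_stats` are assumed to already exist.

```python
# `model` is an Algo instance; `data_loader` yields batches of torch.Tensors
model.set_train()
for batch in data_loader:
    input_batch = model.process_batch_for_training(batch)
    input_batch = model.postprocess_batch_for_training(
        input_batch, obs_normalization_stats=obs_normalization_stats
    )
    info = model.train_on_batch(input_batch, epoch=epoch, validate=False)
    step_log = model.log_info(info)  # name -> summary statistic, for logging
model.on_epoch_end(epoch)
```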
- class robomimic.algo.algo.HierarchicalAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.Algo
Base class for all hierarchical algorithms that consist of (1) subgoal planning and (2) subgoal-conditioned policy learning.
- property current_subgoal#
Get the current subgoal for conditioning the low-level policy
- Returns:
predicted subgoal
- Return type:
current subgoal (dict)
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- get_subgoal_predictions(obs_dict, goal_dict=None)#
Get subgoal predictions from high-level subgoal planner.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
predicted subgoal
- Return type:
subgoal (dict)
- class robomimic.algo.algo.PlannerAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.Algo
Base class for all algorithms that can be used for planning subgoals conditioned on current observations and potential goal observations.
- get_subgoal_predictions(obs_dict, goal_dict=None)#
Get predicted subgoal outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, …]
- Return type:
subgoal prediction (dict)
- sample_subgoals(obs_dict, goal_dict, num_samples=1)#
Sample @num_samples subgoals per observation, for planners that rely on sampling subgoals.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, num_samples, …]
- Return type:
subgoals (dict)
- class robomimic.algo.algo.PolicyAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.Algo
Base class for all algorithms that can be used as policies.
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- class robomimic.algo.algo.RolloutPolicy(policy, obs_normalization_stats=None)#
Bases:
object
Wraps @Algo object to make it easy to run policies in a rollout loop.
- start_episode()#
Prepare the policy to start a new rollout.
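A hedged sketch of the rollout loop this wrapper is designed for; `model`, `env`, and `horizon` are assumptions for this example, not part of this module.

```python
# `model` is a trained Algo instance; `env` is a robomimic-style environment
policy = RolloutPolicy(model, obs_normalization_stats=obs_normalization_stats)
policy.start_episode()
obs = env.reset()
for _ in range(horizon):
    action = policy(ob=obs)  # wrapper handles batching and tensor conversion
    obs, reward, done, _ = env.step(action)
    if done:
        break
```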
- class robomimic.algo.algo.ValueAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.Algo
Base class for all algorithms that can learn a value function.
- get_state_action_value(obs_dict, actions, goal_dict=None)#
Get state-action value outputs.
- Parameters:
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)#
Get state value outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- robomimic.algo.algo.algo_factory(algo_name, config, obs_key_shapes, ac_dim, device)#
Factory function for creating algorithms based on the algorithm name and config.
- Parameters:
algo_name (str) – the algorithm name
config (BaseConfig instance) – config object
obs_key_shapes (OrderedDict) – dictionary that maps observation keys to shapes
ac_dim (int) – dimension of action space
device (torch.Device) – where the algo should live (i.e. cpu, gpu)
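A hedged usage sketch, assuming `config` is a loaded BaseConfig instance; the observation shapes and action dimension below are illustrative placeholders that would normally come from dataset metadata.

```python
import torch
from collections import OrderedDict
from robomimic.algo import algo_factory

model = algo_factory(
    algo_name=config.algo_name,
    config=config,
    obs_key_shapes=OrderedDict([("robot0_eef_pos", [3])]),  # illustrative
    ac_dim=7,                                               # illustrative
    device=torch.device("cpu"),
)
```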
- robomimic.algo.algo.algo_name_to_factory_func(algo_name)#
Uses registry to retrieve algo factory function from algo name.
- Parameters:
algo_name (str) – the algorithm name
- robomimic.algo.algo.register_algo_factory_func(algo_name)#
Function decorator to register algo factory functions that map algo configs to algo class names. Each algorithm implements such a function, and decorates it with this decorator.
- Parameters:
algo_name (str) – the algorithm name to register the algorithm under
robomimic.algo.bc module#
Implementation of Behavioral Cloning (BC).
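The core idea of BC, as a minimal sketch rather than this module's implementation: regress predicted actions onto demonstrated actions. The `policy_net`, `optimizer`, `obs`, and `demo_actions` names are illustrative.

```python
import torch.nn.functional as F

optimizer.zero_grad()
predicted_actions = policy_net(obs)                 # policy forward pass
loss = F.mse_loss(predicted_actions, demo_actions)  # imitate the demonstrator
loss.backward()
optimizer.step()
```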
- class robomimic.algo.bc.BC(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.PolicyAlgo
Normal BC training.
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
- class robomimic.algo.bc.BC_GMM(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.bc.BC_Gaussian
BC training with a Gaussian Mixture Model policy.
- class robomimic.algo.bc.BC_Gaussian(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.bc.BC
BC training with a Gaussian policy.
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- class robomimic.algo.bc.BC_RNN(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.bc.BC
BC training with an RNN policy.
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- reset()#
Reset algo state to prepare for environment rollouts.
- class robomimic.algo.bc.BC_RNN_GMM(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.bc.BC_RNN
BC training with an RNN GMM policy.
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- class robomimic.algo.bc.BC_Transformer(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.bc.BC
BC training with a Transformer policy.
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- class robomimic.algo.bc.BC_Transformer_GMM(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.bc.BC_Transformer
BC training with a Transformer GMM policy.
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- class robomimic.algo.bc.BC_VAE(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.bc.BC
BC training with a VAE policy.
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- train_on_batch(batch, epoch, validate=False)#
Updated from superclass to set the categorical temperature, for categorical VAEs.
robomimic.algo.bcq module#
Batch-Constrained Q-Learning (BCQ), with support for more general generative action models (the original paper uses a cVAE). (Paper - https://arxiv.org/abs/1812.02900).
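A hedged sketch of BCQ-style action selection under the stated assumptions: `action_sampler` (e.g. a cVAE over dataset-like actions), `perturbation`, and `critic` are illustrative stand-ins for BCQ's components, not this module's API.

```python
import torch

def bcq_select_action(obs, num_candidates=10):
    candidates = action_sampler.sample(obs, n=num_candidates)  # (N, ac_dim)
    candidates = candidates + perturbation(obs, candidates)    # small correction
    q_values = critic(obs, candidates)                         # (N,)
    return candidates[torch.argmax(q_values)]                  # best candidate
```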
- class robomimic.algo.bcq.BCQ(**kwargs)#
Bases:
robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo
Default BCQ training, based on https://arxiv.org/abs/1812.02900 and https://github.com/sfujim/BCQ
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- get_state_action_value(obs_dict, actions, goal_dict=None)#
Get state-action value outputs.
- Parameters:
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)#
Get state value outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- on_epoch_end(epoch)#
Called at the end of each epoch.
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- set_discount(discount)#
Useful function to modify discount factor if necessary (e.g. for n-step returns).
- set_train()#
Prepare networks for training. Updated from superclass to make sure target networks stay in evaluation mode at all times.
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
- class robomimic.algo.bcq.BCQ_Distributional(**kwargs)#
Bases:
robomimic.algo.bcq.BCQ
BCQ with distributional critics. Distributional critics output categorical distributions over a discrete set of values instead of expected returns. Some parts of this implementation were adapted from ACME (https://github.com/deepmind/acme).
- class robomimic.algo.bcq.BCQ_GMM(**kwargs)#
Bases:
robomimic.algo.bcq.BCQ
A simple modification to BCQ that replaces the VAE used to sample action proposals from the batch with a GMM.
robomimic.algo.cql module#
Implementation of Conservative Q-Learning (CQL). Based off of https://github.com/aviralkumar2907/CQL. (Paper - https://arxiv.org/abs/2006.04779).
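A hedged sketch of the conservative regularizer that gives CQL its name: push down Q-values on actions sampled away from the data (via a logsumexp) while pushing up Q-values on dataset actions. The tensors and weights below are illustrative, not this module's internals.

```python
import torch

# q_sampled: (batch, num_sampled) critic values for sampled/policy actions
# q_data:    (batch,) critic values for the dataset actions
# `td_loss` and `cql_weight` are assumed to exist for this sketch
cql_penalty = (torch.logsumexp(q_sampled, dim=1) - q_data).mean()
critic_loss = td_loss + cql_weight * cql_penalty
```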
- class robomimic.algo.cql.CQL(**kwargs)#
Bases:
robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo
CQL-extension of SAC for the off-policy, offline setting. See https://arxiv.org/abs/2006.04779
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- get_state_action_value(obs_dict, actions, goal_dict=None)#
Get state-action value outputs.
- Parameters:
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- property log_cql_weight#
- property log_entropy_weight#
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- on_epoch_end(epoch)#
Called at the end of each epoch.
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant info and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- set_train()#
Prepare networks for training. Updated from superclass to make sure target networks stay in evaluation mode at all times.
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
robomimic.algo.gl module#
Subgoal prediction models, used in HBC / IRIS.
- class robomimic.algo.gl.GL(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.PlannerAlgo
Implements goal prediction component for HBC and IRIS.
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs. Assumes one input observation (first dimension should be 1).
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- get_actor_goal_for_training_from_processed_batch(processed_batch, **kwargs)#
Retrieve subgoals from processed batch to use for training the actor. Subclasses can modify this function to change the subgoals.
- Parameters:
processed_batch (dict) – processed batch from @process_batch_for_training
- Returns:
subgoal observations to condition actor on
- Return type:
actor_subgoals (dict)
- get_subgoal_predictions(obs_dict, goal_dict=None)#
Takes a batch of observations and predicts a batch of subgoals.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, …]
- Return type:
subgoal prediction (dict)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- sample_subgoals(obs_dict, goal_dict=None, num_samples=1)#
Sample @num_samples subgoals from the network per observation. Since this class implements a deterministic subgoal prediction, this function returns identical subgoals for each input observation.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, num_samples, …]
- Return type:
subgoals (dict)
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
- class robomimic.algo.gl.GL_VAE(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.gl.GL
Implements goal prediction via VAE.
- get_actor_goal_for_training_from_processed_batch(processed_batch, use_latent_subgoals=False, use_prior_correction=False, num_prior_samples=100, **kwargs)#
Modify from superclass to support a @use_latent_subgoals option. The VAE can optionally return latent subgoals by passing the subgoal observations in the batch through the encoder.
- Parameters:
processed_batch (dict) – processed batch from @process_batch_for_training
use_latent_subgoals (bool) – if True, condition the actor on latent subgoals by using the VAE encoder to encode subgoal observations at train-time, and using the VAE prior to generate latent subgoals at test-time
use_prior_correction (bool) – if True, use a “prior correction” trick to choose a latent subgoal sampled from the prior that is close to the latent from the VAE encoder (posterior). This can help with issues at test-time where the encoder latent distribution might not match the prior latent distribution.
num_prior_samples (int) – number of VAE prior samples to take and choose among, if @use_prior_correction is true
- Returns:
subgoal observations to condition actor on
- Return type:
actor_subgoals (dict)
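A hedged sketch of the prior-correction trick described above: draw @num_prior_samples latents from the prior and keep, per batch element, the sample closest to the encoder (posterior) latent. `posterior_z` and `prior_z` are assumed tensors for this sketch.

```python
import torch

# posterior_z: (B, D) latents from the VAE encoder, one per batch element
# prior_z:     (B, N, D) latents from the VAE prior; N = num_prior_samples
dists = torch.norm(prior_z - posterior_z.unsqueeze(1), dim=-1)  # (B, N)
closest = torch.argmin(dists, dim=1)                            # (B,)
corrected_z = prior_z[torch.arange(prior_z.shape[0]), closest]  # (B, D)
```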
- get_subgoal_predictions(obs_dict, goal_dict=None)#
Takes a batch of observations and predicts a batch of subgoals.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, …]
- Return type:
subgoal prediction (dict)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- sample_subgoals(obs_dict, goal_dict=None, num_samples=1)#
Sample @num_samples subgoals from the VAE per observation.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, num_samples, …]
- Return type:
subgoals (dict)
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
- class robomimic.algo.gl.ValuePlanner(planner_algo_class, value_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.PlannerAlgo, robomimic.algo.algo.ValueAlgo
Base class for all algorithms that are used for planning subgoals based on (1) a @PlannerAlgo that is used to sample candidate subgoals and (2) a @ValueAlgo that is used to select one of the subgoals.
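A hedged sketch of the two-stage selection this class describes: sample candidate subgoals with the planner, score them with the value function, and keep the best one. `score_subgoals` is a hypothetical helper, not this module's API.

```python
import torch

# `planner` and `value_algo` follow the PlannerAlgo / ValueAlgo APIs above
subgoals = planner.sample_subgoals(obs_dict, goal_dict, num_samples=N)
# hypothetical helper that evaluates each candidate subgoal with the value
# function, returning values of shape (batch_size, N)
values = score_subgoals(value_algo, obs_dict, subgoals)
best = torch.argmax(values, dim=1)  # highest-value subgoal per observation
```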
- deserialize(model_dict)#
Load model from a checkpoint.
- Parameters:
model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes
- get_state_action_value(obs_dict, actions, goal_dict=None)#
Get state-action value outputs.
- Parameters:
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)#
Get state value outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- get_subgoal_predictions(obs_dict, goal_dict=None)#
Takes a batch of observations and predicts a batch of subgoals.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, …]
- Return type:
subgoal prediction (dict)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- on_epoch_end(epoch)#
Called at the end of each epoch.
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- reset()#
Reset algo state to prepare for environment rollouts.
- sample_subgoals(obs_dict, goal_dict, num_samples=1)#
Sample @num_samples subgoals from the planner algo per observation.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
name -> Tensor [batch_size, num_samples, …]
- Return type:
subgoals (dict)
- serialize()#
Get dictionary of current model parameters.
- set_eval()#
Prepare networks for evaluation.
- set_train()#
Prepare networks for training.
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
robomimic.algo.hbc module#
Implementation of Hierarchical Behavioral Cloning, where a planner model outputs subgoals (future observations), and an actor model is conditioned on the subgoals to try to reach them. Largely based on the Generalization Through Imitation (GTI) paper (see https://arxiv.org/abs/2003.06085).
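A hedged sketch of the hierarchy at rollout time: the planner periodically proposes a subgoal, and the low-level actor is conditioned on it as a goal. `planner`, `actor`, and `subgoal_update_interval` are illustrative stand-ins for HBC's internals.

```python
# re-plan every `subgoal_update_interval` environment steps
if step % subgoal_update_interval == 0:
    current_subgoal = planner.get_subgoal_predictions(obs_dict)
# the low-level policy treats the subgoal as its goal
action = actor.get_action(obs_dict, goal_dict=current_subgoal)
```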
- class robomimic.algo.hbc.HBC(planner_algo_class, policy_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.HierarchicalAlgo
Default HBC training, largely based on https://arxiv.org/abs/2003.06085
- property current_subgoal#
Return the current subgoal (at rollout time) with shape (batch, …)
- deserialize(model_dict)#
Load model from a checkpoint.
- Parameters:
model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- on_epoch_end(epoch)#
Called at the end of each epoch.
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- reset()#
Reset algo state to prepare for environment rollouts.
- serialize()#
Get dictionary of current model parameters.
- set_eval()#
Prepare networks for evaluation.
- set_train()#
Prepare networks for training.
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
robomimic.algo.iql module#
Implementation of Implicit Q-Learning (IQL). Based off of https://github.com/rail-berkeley/rlkit/blob/master/rlkit/torch/sac/iql_trainer.py. (Paper - https://arxiv.org/abs/2110.06169).
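A hedged sketch of IQL's two signature pieces, following the paper rather than this file's exact internals: expectile regression for the value function and advantage-weighted regression for policy extraction. `q_pred`, `v_pred`, `beta`, and the per-sample negative log-likelihoods are assumed to exist.

```python
import torch

def expectile_loss(diff, tau=0.7):
    # diff = Q(s, a) - V(s); tau > 0.5 biases V toward the upper expectile
    weight = torch.abs(tau - (diff < 0).float())
    return (weight * diff ** 2).mean()

# advantage-weighted policy extraction (AWR-style weights)
adv = q_pred - v_pred
weights = torch.exp(adv / beta).clamp(max=100.0).detach()
actor_loss = (weights * nll_of_dataset_actions).mean()
```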
- class robomimic.algo.iql.IQL(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- on_epoch_end(epoch)#
Called at the end of each epoch.
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant info and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)
robomimic.algo.iris module#
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
- class robomimic.algo.iris.IRIS(planner_algo_class, value_algo_class, policy_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)#
Bases:
robomimic.algo.hbc.HBC, robomimic.algo.algo.ValueAlgo
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
- get_state_action_value(obs_dict, actions, goal_dict=None)#
Get state-action value outputs.
- Parameters:
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)#
Get state value outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
robomimic.algo.td3_bc module#
Implementation of TD3-BC. Based on https://github.com/sfujim/TD3_BC (Paper - https://arxiv.org/abs/2106.06860).
Note that several parts are exactly the same as the BCQ implementation, such as @_create_critics, @process_batch_for_training, and @_train_critic_on_batch. They are replicated here (instead of subclassing from the BCQ algo class) to be explicit and have implementation details self-contained in this file.
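A hedged sketch of the TD3-BC actor objective from the paper (https://arxiv.org/abs/2106.06860): maximize the critic's value while staying close to dataset actions, with an adaptive trade-off weight. `actor`, `critic`, `alpha`, `obs`, and `dataset_actions` are illustrative names.

```python
import torch.nn.functional as F

pi = actor(obs)                         # policy actions for the batch
q = critic(obs, pi)                     # critic's evaluation of those actions
lam = alpha / q.abs().mean().detach()   # adaptive normalization from the paper
actor_loss = -lam * q.mean() + F.mse_loss(pi, dataset_actions)
```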
- class robomimic.algo.td3_bc.TD3_BC(**kwargs)#
Bases:
robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo
Default TD3_BC training, based on https://arxiv.org/abs/2106.06860 and https://github.com/sfujim/TD3_BC.
- get_action(obs_dict, goal_dict=None)#
Get policy action outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
action tensor
- Return type:
action (torch.Tensor)
- get_state_action_value(obs_dict, actions, goal_dict=None)#
Get state-action value outputs.
- Parameters:
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)#
Get state value outputs.
- Parameters:
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns:
value tensor
- Return type:
value (torch.Tensor)
- log_info(info)#
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters:
info (dict) – dictionary of info
- Returns:
name -> summary statistic
- Return type:
loss_log (dict)
- on_epoch_end(epoch)#
Called at the end of each epoch.
- process_batch_for_training(batch)#
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
Exactly the same as BCQ.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns:
processed and filtered batch that will be used for training
- Return type:
input_batch (dict)
- set_discount(discount)#
Useful function to modify discount factor if necessary (e.g. for n-step returns).
- set_train()#
Prepare networks for training. Updated from superclass to make sure target networks stay in evaluation mode at all times.
- train_on_batch(batch, epoch, validate=False)#
Training on a single batch of data.
- Parameters:
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns:
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type:
info (dict)