robomimic.algo package
Submodules
robomimic.algo.algo module
This file contains base classes that other algorithm classes subclass. Each algorithm file also implements an algorithm factory function that takes in an algorithm config (config.algo) and returns the particular Algo subclass that should be instantiated, along with any extra kwargs. These factory functions are registered into a global dictionary with the @register_algo_factory_func function decorator. This makes it easy for @algo_factory to instantiate the correct Algo subclass.
- class robomimic.algo.algo.Algo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
object
Base algorithm class that all other algorithms subclass. Defines several functions that should be overridden by subclasses in order to provide a standard API for training functions such as @run_epoch in utils/train_utils.py.
- deserialize(model_dict)
Load model from a checkpoint.
- Parameters
model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- on_epoch_end(epoch)
Called at the end of each epoch.
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- reset()
Reset algo state to prepare for environment rollouts.
- serialize()
Get dictionary of current model parameters.
- set_eval()
Prepare networks for evaluation.
- set_train()
Prepare networks for training.
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
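The base-class methods above are meant to be called in a fixed order by a training loop such as @run_epoch. The following is a hypothetical, self-contained illustration of that call order. `ToyAlgo` and `run_epoch_sketch` are made-up names that only mimic the documented method names, and the real @run_epoch in utils/train_utils.py may differ (for example in timing, logging, and gradient-step counting).

```python
class ToyAlgo:
    """Hypothetical stand-in that only mimics the documented Algo method names."""

    def __init__(self, lr=0.1):
        self.training = False
        self.weight = 0.0  # a single scalar "network parameter"
        self.lr = lr

    def set_train(self):
        self.training = True

    def set_eval(self):
        self.training = False

    def process_batch_for_training(self, batch):
        # keep only the keys this toy algorithm needs
        return {"obs": batch["obs"], "actions": batch["actions"]}

    def train_on_batch(self, batch, epoch, validate=False):
        # toy objective: fit actions ~= weight * obs with a mean squared error
        obs, actions = batch["obs"], batch["actions"]
        errors = [self.weight * o - a for o, a in zip(obs, actions)]
        loss = sum(e * e for e in errors) / len(errors)
        if not validate:
            # one gradient step on the squared error (skipped when validating)
            grad = sum(2.0 * e * o for e, o in zip(errors, obs)) / len(errors)
            self.weight -= self.lr * grad
        return {"loss": loss}

    def log_info(self, info):
        return {"Loss": info["loss"]}

    def on_epoch_end(self, epoch):
        pass  # e.g. a learning-rate schedule could step here


def run_epoch_sketch(algo, data_loader, epoch, validate=False):
    """Hypothetical loop showing the order in which the Algo API is called."""
    if validate:
        algo.set_eval()
    else:
        algo.set_train()
    step_logs = []
    for batch in data_loader:
        input_batch = algo.process_batch_for_training(batch)
        info = algo.train_on_batch(input_batch, epoch, validate=validate)
        step_logs.append(algo.log_info(info))
    # average every logged statistic over the epoch
    return {k: sum(log[k] for log in step_logs) / len(step_logs) for k in step_logs[0]}


algo = ToyAlgo()
data = [{"obs": [1.0, 2.0], "actions": [2.0, 4.0], "rewards": [0.0, 1.0]}] * 20
for epoch in range(1, 4):
    train_log = run_epoch_sketch(algo, data, epoch)
    algo.on_epoch_end(epoch)
valid_log = run_epoch_sketch(algo, data, epoch=3, validate=True)
```

Passing `validate=True` reuses the same code path but skips the update. That is the behavior the `validate` parameter describes.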
- class robomimic.algo.algo.HierarchicalAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.Algo
Base class for all hierarchical algorithms that consist of (1) subgoal planning and (2) subgoal-conditioned policy learning.
- property current_subgoal
Get the current subgoal for conditioning the low-level policy
- Returns
predicted subgoal
- Return type
current subgoal (dict)
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- get_subgoal_predictions(obs_dict, goal_dict=None)
Get subgoal predictions from high-level subgoal planner.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
predicted subgoal
- Return type
subgoal (dict)
- class robomimic.algo.algo.PlannerAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.Algo
Base class for all algorithms that can be used for planning subgoals conditioned on current observations and potential goal observations.
- get_subgoal_predictions(obs_dict, goal_dict=None)
Get predicted subgoal outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
name -> Tensor [batch_size, …]
- Return type
subgoal prediction (dict)
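- sample_subgoals(obs_dict, goal_dict, num_samples=1)
Sample @num_samples subgoals per observation, for planners that rely on sampling subgoals.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
num_samples (int) – number of subgoals to sample per observation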
- Returns
name -> Tensor [batch_size, num_samples, …]
- Return type
subgoals (dict)
- class robomimic.algo.algo.PolicyAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.Algo
Base class for all algorithms that can be used as policies.
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- class robomimic.algo.algo.RolloutPolicy(policy, obs_normalization_stats=None)
Bases:
object
Wraps an @Algo object to make it easy to run policies in a rollout loop.
- start_episode()
Prepare the policy to start a new rollout.
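The following sketch shows the kind of bookkeeping a rollout wrapper performs. At the start of an episode it puts the policy in evaluation mode and resets its state. On every step it normalizes the observation before querying the policy. `RolloutPolicySketch`, `ConstantPolicy`, the `__call__` entry point, and the `{"mean": ..., "std": ...}` layout of the normalization statistics are all assumptions for illustration. This is not the actual robomimic implementation, which also handles torch tensors, batching, and devices.

```python
class ConstantPolicy:
    """Hypothetical policy exposing set_eval / reset / get_action."""

    def __init__(self):
        self.eval_mode = False
        self.num_resets = 0

    def set_eval(self):
        self.eval_mode = True

    def reset(self):
        self.num_resets += 1

    def get_action(self, obs_dict, goal_dict=None):
        # toy action: negate the normalized "pos" observation
        return [-x for x in obs_dict["pos"]]


class RolloutPolicySketch:
    """Hypothetical wrapper mirroring the role of RolloutPolicy."""

    def __init__(self, policy, obs_normalization_stats=None):
        self.policy = policy
        self.obs_normalization_stats = obs_normalization_stats

    def start_episode(self):
        # evaluation mode plus fresh internal state (e.g. an RNN hidden state)
        self.policy.set_eval()
        self.policy.reset()

    def _normalize(self, obs_dict):
        if self.obs_normalization_stats is None:
            return obs_dict
        out = {}
        for key, value in obs_dict.items():
            stats = self.obs_normalization_stats.get(key)
            if stats is None:
                out[key] = value  # keys without stats pass through unchanged
            else:
                out[key] = [(v - m) / s for v, m, s in zip(value, stats["mean"], stats["std"])]
        return out

    def __call__(self, ob, goal=None):
        return self.policy.get_action(self._normalize(ob), goal_dict=goal)


stats = {"pos": {"mean": [1.0, 2.0], "std": [2.0, 4.0]}}
rollout_policy = RolloutPolicySketch(ConstantPolicy(), obs_normalization_stats=stats)
rollout_policy.start_episode()
action = rollout_policy({"pos": [3.0, 6.0]})
```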
- class robomimic.algo.algo.ValueAlgo(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.Algo
Base class for all algorithms that can learn a value function.
- get_state_action_value(obs_dict, actions, goal_dict=None)
Get state-action value outputs.
- Parameters
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)
Get state value outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- robomimic.algo.algo.algo_factory(algo_name, config, obs_key_shapes, ac_dim, device)
Factory function for creating algorithms based on the algorithm name and config.
- Parameters
algo_name (str) – the algorithm name
config (BaseConfig instance) – config object
obs_key_shapes (OrderedDict) – dictionary that maps observation keys to shapes
ac_dim (int) – dimension of action space
device (torch.device) – where the algo should live (e.g. CPU or GPU)
- robomimic.algo.algo.algo_name_to_factory_func(algo_name)
Uses registry to retrieve algo factory function from algo name.
- Parameters
algo_name (str) – the algorithm name
- robomimic.algo.algo.register_algo_factory_func(algo_name)
Function decorator to register algo factory functions that map algo configs to Algo subclasses (along with any extra kwargs). Each algorithm implements such a function and decorates it with this decorator.
- Parameters
algo_name (str) – the algorithm name to register the algorithm under
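The registry pattern described at the top of this module can be sketched in a few lines. This is a hypothetical, self-contained illustration. The dictionary name `REGISTERED_ALGO_FACTORY_FUNCS`, the toy `BCSketch`/`BCRNNSketch` classes, and the `rnn_enabled` config field are assumptions. The simplified `algo_factory` also drops the `obs_key_shapes` and `device` arguments and takes a plain dict instead of a full robomimic config.

```python
REGISTERED_ALGO_FACTORY_FUNCS = {}


def register_algo_factory_func(algo_name):
    """Decorator that records a factory function under @algo_name."""
    def decorator(factory_func):
        REGISTERED_ALGO_FACTORY_FUNCS[algo_name] = factory_func
        return factory_func
    return decorator


def algo_name_to_factory_func(algo_name):
    return REGISTERED_ALGO_FACTORY_FUNCS[algo_name]


class BCSketch:
    def __init__(self, algo_config, ac_dim, **kwargs):
        self.algo_config = algo_config
        self.ac_dim = ac_dim
        self.extra = kwargs


class BCRNNSketch(BCSketch):
    pass


@register_algo_factory_func("bc")
def bc_factory(algo_config):
    # choose the Algo subclass from the config and return any extra kwargs
    if algo_config.get("rnn_enabled", False):
        return BCRNNSketch, {}
    return BCSketch, {}


def algo_factory(algo_name, algo_config, ac_dim):
    factory_func = algo_name_to_factory_func(algo_name)
    algo_cls, algo_kwargs = factory_func(algo_config)
    return algo_cls(algo_config=algo_config, ac_dim=ac_dim, **algo_kwargs)


algo = algo_factory("bc", {"rnn_enabled": True}, ac_dim=7)
```

Because each factory function registers itself on import, the generic factory only needs the algorithm name to find the right subclass.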
robomimic.algo.bc module
Implementation of Behavioral Cloning (BC).
- class robomimic.algo.bc.BC(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.PolicyAlgo
Normal BC training.
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
- class robomimic.algo.bc.BC_GMM(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.bc.BC_Gaussian
BC training with a Gaussian Mixture Model policy.
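A Gaussian Mixture Model policy is typically trained by minimizing the negative log-likelihood of dataset actions under the predicted mixture. The NumPy sketch below computes that loss for diagonal-covariance components. The function name and argument layout are assumptions for illustration. robomimic's own GMM policy network is implemented in PyTorch and may parameterize the mixture differently (for example, with a different activation on the standard deviations).

```python
import numpy as np


def gmm_negative_log_likelihood(actions, logits, means, log_stds):
    """
    Mean negative log-likelihood of @actions under a diagonal GMM.

    actions:  [B, D]      dataset actions
    logits:   [B, K]      unnormalized mixture weights
    means:    [B, K, D]   component means
    log_stds: [B, K, D]   component log standard deviations
    """
    stds = np.exp(log_stds)
    diff = (actions[:, None, :] - means) / stds
    # log N(a | mu_k, diag(sigma_k^2)) for every component k -> [B, K]
    component_log_probs = -0.5 * np.sum(diff ** 2 + 2.0 * log_stds + np.log(2.0 * np.pi), axis=-1)
    # numerically stable log-softmax of the mixture logits -> [B, K]
    max_logit = logits.max(axis=-1, keepdims=True)
    log_weights = logits - max_logit - np.log(np.sum(np.exp(logits - max_logit), axis=-1, keepdims=True))
    # log-sum-exp over components -> [B]
    joint = log_weights + component_log_probs
    max_joint = joint.max(axis=-1, keepdims=True)
    log_probs = max_joint[:, 0] + np.log(np.sum(np.exp(joint - max_joint), axis=-1))
    return -np.mean(log_probs)


# one action, one component: a standard normal evaluated at its mean
nll = gmm_negative_log_likelihood(
    np.zeros((1, 1)), np.zeros((1, 1)), np.zeros((1, 1, 1)), np.zeros((1, 1, 1))
)
```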
- class robomimic.algo.bc.BC_Gaussian(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.bc.BC
BC training with a Gaussian policy.
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- class robomimic.algo.bc.BC_RNN(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.bc.BC
BC training with an RNN policy.
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- reset()
Reset algo state to prepare for environment rollouts.
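An RNN policy carries a hidden state from one environment step to the next. That is why @reset matters for BC_RNN: it clears the state so that a new rollout does not inherit memory from the previous one. The toy class below illustrates only that bookkeeping. `RecurrentPolicySketch` and its scalar "hidden state" update are hypothetical. BC_RNN itself uses a learned recurrent network in PyTorch.

```python
class RecurrentPolicySketch:
    """Hypothetical policy whose output depends on a running hidden state."""

    def __init__(self, decay=0.5):
        self.decay = decay
        self._hidden_state = None

    def reset(self):
        # forget everything from the previous rollout
        self._hidden_state = None

    def get_action(self, obs):
        if self._hidden_state is None:
            self._hidden_state = 0.0  # lazily initialize at the first step of an episode
        # toy recurrence: exponential moving average of the observations
        self._hidden_state = self.decay * self._hidden_state + (1.0 - self.decay) * obs
        return self._hidden_state


policy = RecurrentPolicySketch()
policy.reset()
first_episode = [policy.get_action(o) for o in (1.0, 1.0)]
policy.reset()
second_episode = [policy.get_action(o) for o in (1.0, 1.0)]
```

Both episodes produce `[0.5, 0.75]`. If the second @reset were skipped, the second episode would start from the leftover state instead.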
- class robomimic.algo.bc.BC_RNN_GMM(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.bc.BC_RNN
BC training with an RNN GMM policy.
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- class robomimic.algo.bc.BC_VAE(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.bc.BC
BC training with a VAE policy.
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- train_on_batch(batch, epoch, validate=False)
Update from superclass to set categorical temperature, for categorical VAEs.
robomimic.algo.bcq module
Batch-Constrained Q-Learning (BCQ), with support for more general generative action models (the original paper uses a cVAE). (Paper - https://arxiv.org/abs/1812.02900).
- class robomimic.algo.bcq.BCQ(**kwargs)
Bases:
robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo
Default BCQ training, based on https://arxiv.org/abs/1812.02900 and https://github.com/sfujim/BCQ
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- get_state_action_value(obs_dict, actions, goal_dict=None)
Get state-action value outputs.
- Parameters
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)
Get state value outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- on_epoch_end(epoch)
Called at the end of each epoch.
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- set_discount(discount)
Modify the discount factor if necessary (e.g. for n-step returns).
- set_train()
Prepare networks for training. Updated from the superclass to make sure target networks stay in evaluation mode at all times.
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
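At action-selection time, BCQ (as described in the paper linked above) works in three steps. It samples several candidate actions from a generative model fit to the dataset, applies a small learned perturbation to them, and executes the candidate with the highest Q-value. The NumPy sketch below shows only that selection step. `sample_actions`, `perturb`, and `q_value` are hypothetical stand-ins for the trained networks. The candidate count and the `[-1, 1]` action bounds are also assumptions made here.

```python
import numpy as np


def bcq_select_action(obs, sample_actions, perturb, q_value, num_candidates=10):
    """Pick the highest-value action among perturbed generative-model samples."""
    # 1. sample candidate actions that stay close to the data distribution
    candidates = sample_actions(obs, num_candidates)  # [N, ac_dim]
    # 2. apply the (learned) correction, then clip to the assumed action bounds
    candidates = np.clip(candidates + perturb(obs, candidates), -1.0, 1.0)
    # 3. score each candidate with the critic and keep the best one
    values = q_value(obs, candidates)  # [N]
    return candidates[np.argmax(values)]


rng = np.random.default_rng(0)
target = np.array([0.3, -0.2])


def sample_actions(obs, n):
    return rng.uniform(-1.0, 1.0, size=(n, 2))


def perturb(obs, actions):
    return np.zeros_like(actions)  # toy: no correction


def q_value(obs, actions):
    # toy critic that prefers actions close to @target
    return -np.sum((actions - target) ** 2, axis=-1)


action = bcq_select_action(None, sample_actions, perturb, q_value, num_candidates=64)
```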
- class robomimic.algo.bcq.BCQ_Distributional(**kwargs)
Bases:
robomimic.algo.bcq.BCQ
BCQ with distributional critics. Distributional critics output categorical distributions over a discrete set of values instead of expected returns. Some parts of this implementation were adapted from ACME (https://github.com/deepmind/acme).
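With a distributional critic, each (state, action) pair maps to a categorical distribution over a fixed set of return values ("atoms"). A scalar Q-value can then be recovered as the expectation of that distribution. The following sketch assumes evenly spaced atoms and softmax-normalized logits. The function name and atom range are illustrative assumptions, not the exact parameterization used by BCQ_Distributional.

```python
import numpy as np


def expected_value(logits, v_min, v_max):
    """Expected return of categorical distributions over evenly spaced atoms."""
    num_atoms = logits.shape[-1]
    atoms = np.linspace(v_min, v_max, num_atoms)  # support of the distribution
    # softmax over the atom dimension (shifted for numerical stability)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
    return np.sum(probs * atoms, axis=-1)


# row 0: uniform over {-1, 0, 1}; row 1: almost all mass on the last atom
logits = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 50.0]])
q = expected_value(logits, v_min=-1.0, v_max=1.0)
```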
- class robomimic.algo.bcq.BCQ_GMM(**kwargs)
Bases:
robomimic.algo.bcq.BCQ
A simple modification to BCQ that replaces the VAE used to sample action proposals from the batch with a GMM.
robomimic.algo.cql module
Implementation of Conservative Q-Learning (CQL). Based on https://github.com/aviralkumar2907/CQL. (Paper - https://arxiv.org/abs/2006.04779).
- class robomimic.algo.cql.CQL(**kwargs)
Bases:
robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo
CQL extension of SAC for the off-policy, offline setting. See https://arxiv.org/abs/2006.04779
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- get_state_action_value(obs_dict, actions, goal_dict=None)
Get state-action value outputs.
- Parameters
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- property log_cql_weight
- property log_entropy_weight
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- on_epoch_end(epoch)
Called at the end of each epoch.
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant info and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- set_train()
Prepare networks for training. Updated from the superclass to make sure target networks stay in evaluation mode at all times.
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
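CQL's key addition to SAC is a regularizer on the critic. It pushes Q-values down on sampled actions that the critic might overestimate and up on actions that appear in the dataset. The NumPy sketch below computes a simplified log-sum-exp form of that penalty for one batch. The function name, the `cql_weight` value, and how the sampled actions are obtained are assumptions. Details such as importance-sampling corrections or a learned weight are omitted.

```python
import numpy as np


def cql_penalty(q_sampled, q_data, cql_weight=1.0):
    """
    Simplified CQL(H)-style regularizer.

    q_sampled: [B, N]  Q-values of N sampled actions per state
    q_data:    [B]     Q-values of the dataset actions
    """
    # numerically stable log-sum-exp over the sampled actions
    m = q_sampled.max(axis=1, keepdims=True)
    logsumexp = m[:, 0] + np.log(np.sum(np.exp(q_sampled - m), axis=1))
    # minimizing this lowers Q on sampled actions and raises Q on dataset actions
    return cql_weight * np.mean(logsumexp - q_data)


penalty = cql_penalty(np.zeros((2, 4)), np.zeros(2))
```

In a full critic update, this term would be added to the usual Bellman error before taking a gradient step.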
robomimic.algo.gl module
Subgoal prediction models, used in HBC / IRIS.
- class robomimic.algo.gl.GL(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.PlannerAlgo
Implements goal prediction component for HBC and IRIS.
- get_action(obs_dict, goal_dict=None)
Get policy action outputs. Assumes one input observation (first dimension should be 1).
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- get_actor_goal_for_training_from_processed_batch(processed_batch, **kwargs)
Retrieve subgoals from processed batch to use for training the actor. Subclasses can modify this function to change the subgoals.
- Parameters
processed_batch (dict) – processed batch from @process_batch_for_training
- Returns
subgoal observations to condition actor on
- Return type
actor_subgoals (dict)
- get_subgoal_predictions(obs_dict, goal_dict=None)
Takes a batch of observations and predicts a batch of subgoals.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
name -> Tensor [batch_size, …]
- Return type
subgoal prediction (dict)
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- sample_subgoals(obs_dict, goal_dict=None, num_samples=1)
Sample @num_samples subgoals from the network per observation. Since this class implements deterministic subgoal prediction, this function returns @num_samples identical copies of the predicted subgoal for each input observation.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
num_samples (int) – number of subgoals to sample per observation
- Returns
name -> Tensor [batch_size, num_samples, …]
- Return type
subgoals (dict)
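Because GL predicts a single subgoal deterministically, producing @num_samples "samples" amounts to repeating each prediction along a new sample dimension. This turns [batch_size, …] arrays into [batch_size, num_samples, …]. The following NumPy sketch shows that reshaping. It assumes subgoals are stored as a dict of arrays. The helper name is hypothetical, and `robot0_eef_pos` is only an example observation key.

```python
import numpy as np


def repeat_subgoals(subgoal_dict, num_samples):
    """Expand each [B, ...] array to [B, num_samples, ...] by repetition."""
    return {
        key: np.repeat(value[:, None, ...], num_samples, axis=1)
        for key, value in subgoal_dict.items()
    }


subgoals = {"robot0_eef_pos": np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])}
samples = repeat_subgoals(subgoals, num_samples=4)
```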
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
- class robomimic.algo.gl.GL_VAE(algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.gl.GL
Implements goal prediction via VAE.
- get_actor_goal_for_training_from_processed_batch(processed_batch, use_latent_subgoals=False, use_prior_correction=False, num_prior_samples=100, **kwargs)
Modify from superclass to support a @use_latent_subgoals option. The VAE can optionally return latent subgoals by passing the subgoal observations in the batch through the encoder.
- Parameters
processed_batch (dict) – processed batch from @process_batch_for_training
use_latent_subgoals (bool) – if True, condition the actor on latent subgoals by using the VAE encoder to encode subgoal observations at train-time, and using the VAE prior to generate latent subgoals at test-time
use_prior_correction (bool) – if True, use a “prior correction” trick to choose a latent subgoal sampled from the prior that is close to the latent from the VAE encoder (posterior). This can help with issues at test-time where the encoder latent distribution might not match the prior latent distribution.
num_prior_samples (int) – number of VAE prior samples to take and choose among, if @use_prior_correction is True
- Returns
subgoal observations to condition actor on
- Return type
actor_subgoals (dict)
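The "prior correction" option described above can be read as a nearest-neighbor selection in latent space. Draw @num_prior_samples latents from the prior and keep the one closest to the encoder (posterior) latent. The following sketch uses a standard normal prior and Euclidean distance. Both are assumptions made for illustration, and the prior and distance measure actually used in GL_VAE may differ (for example, a learned, observation-conditioned prior).

```python
import numpy as np


def prior_corrected_latent(posterior_latent, num_prior_samples=100, rng=None):
    """Return the prior sample closest to @posterior_latent (shape [latent_dim])."""
    rng = np.random.default_rng() if rng is None else rng
    # assumed prior: standard normal over the latent space
    prior_samples = rng.standard_normal((num_prior_samples, posterior_latent.shape[0]))
    distances = np.linalg.norm(prior_samples - posterior_latent, axis=1)
    return prior_samples[np.argmin(distances)]


posterior = np.array([0.5, -0.5])
corrected = prior_corrected_latent(posterior, num_prior_samples=100, rng=np.random.default_rng(0))
```

The returned latent always comes from the prior, so it lies in a region the decoder also sees at test time. At the same time, it stays close to what the encoder produced for the true subgoal.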
- get_subgoal_predictions(obs_dict, goal_dict=None)
Takes a batch of observations and predicts a batch of subgoals.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
name -> Tensor [batch_size, …]
- Return type
subgoal prediction (dict)
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- sample_subgoals(obs_dict, goal_dict=None, num_samples=1)
Sample @num_samples subgoals from the VAE per observation.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
num_samples (int) – number of subgoals to sample per observation
- Returns
name -> Tensor [batch_size, num_samples, …]
- Return type
subgoals (dict)
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
- class robomimic.algo.gl.ValuePlanner(planner_algo_class, value_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.PlannerAlgo, robomimic.algo.algo.ValueAlgo
Base class for all algorithms that are used for planning subgoals based on (1) a @PlannerAlgo that is used to sample candidate subgoals and (2) a @ValueAlgo that is used to select one of the subgoals.
- deserialize(model_dict)
Load model from a checkpoint.
- Parameters
model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes
- get_state_action_value(obs_dict, actions, goal_dict=None)
Get state-action value outputs.
- Parameters
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)
Get state value outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- get_subgoal_predictions(obs_dict, goal_dict=None)
Takes a batch of observations and predicts a batch of subgoals.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
name -> Tensor [batch_size, …]
- Return type
subgoal prediction (dict)
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- on_epoch_end(epoch)
Called at the end of each epoch.
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- reset()
Reset algo state to prepare for environment rollouts.
- sample_subgoals(obs_dict, goal_dict, num_samples=1)
Sample @num_samples subgoals from the planner algo per observation.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
num_samples (int) – number of subgoals to sample per observation
- Returns
name -> Tensor [batch_size, num_samples, …]
- Return type
subgoals (dict)
- serialize()
Get dictionary of current model parameters.
- set_eval()
Prepare networks for evaluation.
- set_train()
Prepare networks for training.
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
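The two-stage design of ValuePlanner (sample candidates with the planner, rank them with the value function) can be sketched as follows. `sample_subgoals` and `subgoal_value` are hypothetical callables standing in for the @PlannerAlgo and @ValueAlgo. Taking the argmax of a per-candidate score is one plausible selection rule, assumed here for illustration. In a goal-conditioned setting the value would also depend on the goal, and that argument is dropped for brevity.

```python
import numpy as np


def plan_subgoal(obs, sample_subgoals, subgoal_value, num_samples=8):
    """Pick, per observation, the sampled subgoal with the highest value."""
    candidates = sample_subgoals(obs, num_samples)  # [B, N, D]
    values = subgoal_value(candidates)  # [B, N]
    best = np.argmax(values, axis=1)  # [B]
    return candidates[np.arange(candidates.shape[0]), best]  # [B, D]


def sample_subgoals(obs, n):
    # toy planner: evenly spaced offsets around the current observation
    offsets = np.linspace(-1.0, 1.0, n)
    return obs[:, None, :] + offsets[None, :, None]


def subgoal_value(candidates):
    # toy value function that prefers subgoals near 0.5
    return -np.abs(candidates[..., 0] - 0.5)


obs = np.array([[0.0], [1.0]])
chosen = plan_subgoal(obs, sample_subgoals, subgoal_value, num_samples=5)
```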
robomimic.algo.hbc module
Implementation of Hierarchical Behavioral Cloning, where a planner model outputs subgoals (future observations), and an actor model is conditioned on the subgoals to try and reach them. Largely based on the Generalization Through Imitation (GTI) paper (see https://arxiv.org/abs/2003.06085).
- class robomimic.algo.hbc.HBC(planner_algo_class, policy_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.algo.HierarchicalAlgo
Default HBC training, largely based on https://arxiv.org/abs/2003.06085
- property current_subgoal
Return the current subgoal (at rollout time) with shape (batch, …)
- deserialize(model_dict)
Load model from a checkpoint.
- Parameters
model_dict (dict) – a dictionary saved by self.serialize() that contains the same keys as @self.network_classes
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- on_epoch_end(epoch)
Called at the end of each epoch.
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- reset()
Reset algo state to prepare for environment rollouts.
- serialize()
Get dictionary of current model parameters.
- set_eval()
Prepare networks for evaluation.
- set_train()
Prepare networks for training.
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
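At rollout time, a hierarchical agent like HBC alternates between its two levels. The planner proposes a subgoal, and the subgoal-conditioned actor produces low-level actions until a new subgoal is requested. The following sketch re-plans every `subgoal_horizon` steps. That fixed re-planning interval, the class name, and both toy callables are assumptions for illustration, not a description of HBC's exact schedule.

```python
class HierarchicalAgentSketch:
    """Hypothetical two-level agent: planner -> subgoal -> conditioned actor."""

    def __init__(self, planner, actor, subgoal_horizon=3):
        self.planner = planner
        self.actor = actor
        self.subgoal_horizon = subgoal_horizon
        self.reset()

    def reset(self):
        self._step = 0
        self.current_subgoal = None

    def get_action(self, obs, goal=None):
        # ask the planner for a fresh subgoal every @subgoal_horizon steps
        if self._step % self.subgoal_horizon == 0:
            self.current_subgoal = self.planner(obs, goal)
        self._step += 1
        return self.actor(obs, self.current_subgoal)


planner_calls = []


def planner(obs, goal):
    planner_calls.append(obs)
    return obs + 1.0  # toy subgoal: "one unit ahead"


def actor(obs, subgoal):
    return subgoal - obs  # toy action: move straight to the subgoal


agent = HierarchicalAgentSketch(planner, actor, subgoal_horizon=3)
obs = 0.0
for _ in range(6):
    obs += agent.get_action(obs)
```

Over six steps the planner is queried twice (at steps 0 and 3), and the toy state advances by one unit per subgoal.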
robomimic.algo.iris module
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
- class robomimic.algo.iris.IRIS(planner_algo_class, value_algo_class, policy_algo_class, algo_config, obs_config, global_config, obs_key_shapes, ac_dim, device)
Bases:
robomimic.algo.hbc.HBC, robomimic.algo.algo.ValueAlgo
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
- get_state_action_value(obs_dict, actions, goal_dict=None)
Get state-action value outputs.
- Parameters
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)
Get state value outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
robomimic.algo.td3_bc module
Implementation of TD3-BC. Based on https://github.com/sfujim/TD3_BC (Paper - https://arxiv.org/abs/2106.06860).
Note that several parts are exactly the same as the BCQ implementation, such as @_create_critics, @process_batch_for_training, and @_train_critic_on_batch. They are replicated here (instead of subclassing from the BCQ algo class) to be explicit and have implementation details self-contained in this file.
- class robomimic.algo.td3_bc.TD3_BC(**kwargs)
Bases:
robomimic.algo.algo.PolicyAlgo, robomimic.algo.algo.ValueAlgo
Default TD3_BC training, based on https://arxiv.org/abs/2106.06860 and https://github.com/sfujim/TD3_BC.
- get_action(obs_dict, goal_dict=None)
Get policy action outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
action tensor
- Return type
action (torch.Tensor)
- get_state_action_value(obs_dict, actions, goal_dict=None)
Get state-action value outputs.
- Parameters
obs_dict (dict) – current observation
actions (torch.Tensor) – action
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- get_state_value(obs_dict, goal_dict=None)
Get state value outputs.
- Parameters
obs_dict (dict) – current observation
goal_dict (dict) – (optional) goal
- Returns
value tensor
- Return type
value (torch.Tensor)
- log_info(info)
Process info dictionary from @train_on_batch to summarize information to pass to tensorboard for logging.
- Parameters
info (dict) – dictionary of info
- Returns
name -> summary statistic
- Return type
loss_log (dict)
- on_epoch_end(epoch)
Called at the end of each epoch.
- process_batch_for_training(batch)
Processes input batch from a data loader to filter out relevant information and prepare the batch for training.
Exactly the same as BCQ.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader
- Returns
processed and filtered batch that will be used for training
- Return type
input_batch (dict)
- set_discount(discount)
Modify the discount factor if necessary (e.g. for n-step returns).
- set_train()
Prepare networks for training. Updated from the superclass to make sure target networks stay in evaluation mode at all times.
- train_on_batch(batch, epoch, validate=False)
Training on a single batch of data.
- Parameters
batch (dict) – dictionary with torch.Tensors sampled from a data loader and filtered by @process_batch_for_training
epoch (int) – epoch number - required by some Algos that need to perform staged training and early stopping
validate (bool) – if True, don’t perform any learning updates.
- Returns
dictionary of relevant inputs, outputs, and losses that might be relevant for logging
- Return type
info (dict)
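The TD3+BC paper (https://arxiv.org/abs/2106.06860) adds a behavior-cloning term to the TD3 actor objective. It also rescales the Q-term by λ = α / mean|Q|, with α = 2.5 as the paper's default. The NumPy sketch below evaluates that actor loss for one batch. The function name and array layout are assumptions. robomimic's TD3_BC may compute the normalization differently or take α from its config.

```python
import numpy as np


def td3_bc_actor_loss(q_pi, policy_actions, dataset_actions, alpha=2.5):
    """
    q_pi:            [B]          critic values Q(s, pi(s)) for the policy's actions
    policy_actions:  [B, ac_dim]  pi(s)
    dataset_actions: [B, ac_dim]  actions stored in the dataset
    """
    # normalize the RL term so its scale does not depend on the scale of Q
    lmbda = alpha / np.mean(np.abs(q_pi))
    # behavior-cloning term: mean squared error to the dataset actions
    bc_term = np.mean((policy_actions - dataset_actions) ** 2)
    return -lmbda * np.mean(q_pi) + bc_term


q = np.array([2.0, 2.0])
actions = np.zeros((2, 3))
loss = td3_bc_actor_loss(q, actions, actions)
```

Because of the λ normalization, multiplying every Q-value by a constant leaves the loss unchanged. This keeps the balance between the RL and BC terms stable across tasks with different reward scales.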