"""Utils for minibatch SGD across multiple RLlib policies."""

import logging
import random

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder

logger = logging.getLogger(__name__)

@DeveloperAPI
def standardized(array: np.ndarray):
    """Shift *array* to zero mean and scale it to unit standard deviation.

    Args:
        array (np.ndarray): Array of values to normalize.

    Returns:
        The normalized array. The divisor is clamped at 1e-4, so an
        all-constant input yields zeros rather than a division by zero.
    """
    centered = array - array.mean()
    scale = max(1e-4, array.std())
    return centered / scale


@DeveloperAPI
def minibatches(samples: SampleBatch, sgd_minibatch_size: int, shuffle: bool = True):
    """Return a generator yielding minibatches from a sample batch.

    Args:
        samples: SampleBatch to split up.
        sgd_minibatch_size: Size of minibatches to return.
        shuffle: Whether to shuffle the order of the generated minibatches.
            Note that in case of a non-recurrent policy, the incoming batch
            is globally shuffled first regardless of this setting, before
            the minibatches are generated from it!

    Yields:
        SampleBatch: Each of size `sgd_minibatch_size`.
    """
    # A falsy minibatch size means: use the whole batch as one minibatch.
    if not sgd_minibatch_size:
        yield samples
        return

    if isinstance(samples, MultiAgentBatch):
        raise NotImplementedError(
            "Minibatching not implemented for multi-agent in simple mode"
        )

    # Non-recurrent case: globally shuffle all rows up-front.
    if "state_in_0" not in samples and "state_out_0" not in samples:
        samples.shuffle()

    data_slices, state_slices = samples._get_slice_indices(sgd_minibatch_size)

    if state_slices:
        # Recurrent case: shuffle data- and state-slices while keeping
        # them linked together.
        linked_slices = list(zip(data_slices, state_slices))
        if shuffle:
            random.shuffle(linked_slices)
        for (start, stop), (s_start, s_stop) in linked_slices:
            yield samples.slice(start, stop, s_start, s_stop)
    else:
        if shuffle:
            random.shuffle(data_slices)
        for start, stop in data_slices:
            yield samples.slice(start, stop)


@DeveloperAPI
def do_minibatch_sgd(
    samples,
    policies,
    local_worker,
    num_sgd_iter,
    sgd_minibatch_size,
    standardize_fields,
):
    """Execute minibatch SGD.

    Args:
        samples (SampleBatch): Batch of samples to optimize.
        policies (dict): Dictionary of policies to optimize.
        local_worker (RolloutWorker): Master rollout worker instance.
        num_sgd_iter (int): Number of epochs of optimization to take.
        sgd_minibatch_size (int): Size of minibatches to use for optimization.
        standardize_fields (list): List of sample field names that should be
            normalized prior to optimization.

    Returns:
        averaged info fetches over the last SGD epoch taken.

    Raises:
        ValueError: If a recurrent policy's `max_seq_len` exceeds
            `sgd_minibatch_size` (slicing would break up sequences).
    """
    # Handle everything as if multi-agent.
    samples = samples.as_multi_agent()

    # Use LearnerInfoBuilder as a unified way to build the final
    # results dict from `learn_on_loaded_batch` call(s).
    # This makes sure results dicts always have the same structure
    # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
    # tf vs torch).
    learner_info_builder = LearnerInfoBuilder(num_devices=1)
    for policy_id, policy in policies.items():
        # Skip policies that received no samples this round.
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in standardize_fields:
            batch[field] = standardized(batch[field])

        # Check to make sure that the sgd_minibatch_size is not smaller
        # than max_seq_len otherwise this will cause indexing errors while
        # performing sgd when using a RNN or Attention model
        if (
            policy.is_recurrent()
            and policy.config["model"]["max_seq_len"] > sgd_minibatch_size
        ):
            # Bug fix: the two concatenated literals previously lacked a
            # separating space, producing "than`max_seq_len`" in the message.
            raise ValueError(
                "`sgd_minibatch_size` ({}) cannot be smaller than "
                "`max_seq_len` ({}).".format(
                    sgd_minibatch_size, policy.config["model"]["max_seq_len"]
                )
            )

        for _ in range(num_sgd_iter):
            for minibatch in minibatches(batch, sgd_minibatch_size):
                results = local_worker.learn_on_batch(
                    MultiAgentBatch({policy_id: minibatch}, minibatch.count)
                )[policy_id]
                learner_info_builder.add_learn_on_batch_results(results, policy_id)

    learner_info = learner_info_builder.finalize()
    return learner_info
|