2020-04-30 01:18:09 -07:00
|
|
|
import logging
|
2022-06-14 01:57:27 -07:00
|
|
|
from typing import List, Optional, Union
|
2020-04-10 00:56:08 -07:00
|
|
|
|
2021-12-21 08:39:05 +01:00
|
|
|
import ray
|
2020-04-10 00:56:08 -07:00
|
|
|
from ray.rllib.evaluation.worker_set import WorkerSet
|
2021-03-18 20:27:41 +01:00
|
|
|
from ray.rllib.execution.common import (
|
|
|
|
_check_sample_batch_type,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2020-04-30 01:18:09 -07:00
|
|
|
from ray.rllib.policy.sample_batch import (
|
|
|
|
SampleBatch,
|
|
|
|
DEFAULT_POLICY_ID,
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2022-01-13 10:52:55 +01:00
|
|
|
from ray.rllib.utils.annotations import ExperimentalAPI
|
2020-04-30 01:18:09 -07:00
|
|
|
from ray.rllib.utils.sgd import standardized
|
2022-06-14 01:57:27 -07:00
|
|
|
from ray.rllib.utils.typing import SampleBatchType
|
2022-01-13 10:52:55 +01:00
|
|
|
|
2020-04-30 01:18:09 -07:00
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
2020-04-10 00:56:08 -07:00
|
|
|
|
|
|
|
|
2022-01-13 10:52:55 +01:00
|
|
|
@ExperimentalAPI
def synchronous_parallel_sample(
    *,
    worker_set: WorkerSet,
    max_agent_steps: Optional[int] = None,
    max_env_steps: Optional[int] = None,
    concat: bool = True,
) -> Union[List[SampleBatchType], SampleBatchType]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (num_workers == 0), use the local worker
    for sampling.

    Args:
        worker_set: The WorkerSet to use for sampling.
        max_agent_steps: Optional number of agent steps to be included in the
            final batch. Mutually exclusive with `max_env_steps`.
        max_env_steps: Optional number of environment steps to be included in
            the final batch. Mutually exclusive with `max_agent_steps`.
        concat: Whether to concat all resulting batches at the end and return
            the concat'd batch.

    Returns:
        The concat'd SampleBatch/MultiAgentBatch if `concat` is True,
        otherwise the list of collected sample batch types (one for each
        parallel rollout worker in the given `worker_set`).

    Raises:
        ValueError: If both `max_agent_steps` and `max_env_steps` are given.

    Examples:
        >>> # Define an RLlib Algorithm.
        >>> algorithm = ... # doctest: +SKIP
        >>> # 2 remote workers (num_workers=2):
        >>> batches = synchronous_parallel_sample(algorithm.workers) # doctest: +SKIP
        >>> print(len(batches)) # doctest: +SKIP
        2
        >>> print(batches[0]) # doctest: +SKIP
        SampleBatch(16: ['obs', 'actions', 'rewards', 'dones'])
        >>> # 0 remote workers (num_workers=0): Using the local worker.
        >>> batches = synchronous_parallel_sample(algorithm.workers) # doctest: +SKIP
        >>> print(len(batches)) # doctest: +SKIP
        1
    """
    # Only allow one of `max_agent_steps` or `max_env_steps` to be defined.
    # Use an explicit raise (not `assert`) so the check survives `python -O`.
    if max_agent_steps is not None and max_env_steps is not None:
        raise ValueError(
            "Only one of `max_agent_steps` or `max_env_steps` may be defined!"
        )

    agent_or_env_steps = 0
    max_agent_or_env_steps = max_agent_steps or max_env_steps or None
    all_sample_batches = []

    # Stop collecting batches as soon as one criterium is met:
    # Either a step limit was given and reached, or no limit was given
    # and we have completed exactly one round of sampling.
    while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or (
        max_agent_or_env_steps is not None
        and agent_or_env_steps < max_agent_or_env_steps
    ):
        # No remote workers in the set -> Use local worker for collecting
        # samples.
        if not worker_set.remote_workers():
            sample_batches = [worker_set.local_worker().sample()]
        # Loop over remote workers' `sample()` method in parallel.
        else:
            sample_batches = ray.get(
                [worker.sample.remote() for worker in worker_set.remote_workers()]
            )
        # Update our counters for the stopping criterion of the while loop.
        for b in sample_batches:
            if max_agent_steps:
                agent_or_env_steps += b.agent_steps()
            else:
                agent_or_env_steps += b.env_steps()
        all_sample_batches.extend(sample_batches)

    if concat is True:
        full_batch = SampleBatch.concat_samples(all_sample_batches)
        return full_batch
    else:
        return all_sample_batches
|
2021-12-21 08:39:05 +01:00
|
|
|
|
|
|
|
|
2022-04-11 08:39:10 +02:00
|
|
|
def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
    """Standardize fields of the given SampleBatch"""
    _check_sample_batch_type(samples)

    # Normalize to a multi-agent view so we can treat both input kinds
    # uniformly; remember whether we need to unwrap again at the end.
    was_single_agent = isinstance(samples, SampleBatch)
    if was_single_agent:
        samples = samples.as_multi_agent()

    # Standardize each requested field in every policy's batch (in place).
    for policy_batch in samples.policy_batches.values():
        for field in fields:
            if field in policy_batch:
                policy_batch[field] = standardized(policy_batch[field])

    # Return the same format that was passed in.
    if was_single_agent:
        return samples.policy_batches[DEFAULT_POLICY_ID]
    return samples
|