ray/rllib/optimizers/rollout.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix two LINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py; it otherwise fails with a py2 print-related error.
2020-01-09 00:15:48 -08:00

36 lines
1.2 KiB
Python

import logging
import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.memory import ray_get_and_free
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any.

    One ``sample`` request is started on every agent up front. Each time a
    request completes, its batch is kept and the same agent is asked for
    another batch only while the timesteps already collected plus the
    minimum expected from the still-pending requests fall short of
    ``train_batch_size``. The loop ends once no request is pending.

    Args:
        agents: Iterable of handles exposing ``sample.remote()``, whose
            return value can be passed to ``ray.wait``.
        sample_batch_size: Together with ``num_envs_per_worker``, defines
            the minimum ``count`` every returned batch must have
            (``sample_batch_size * num_envs_per_worker``).
        num_envs_per_worker: See ``sample_batch_size``.
        train_batch_size: Timestep target that decides whether further
            ``sample`` requests are launched.

    Returns:
        The result of ``SampleBatch.concat_samples`` applied to all
        collected batches, in the order they completed.

    Raises:
        AssertionError: If a returned batch has a ``count`` below
            ``sample_batch_size * num_envs_per_worker``.
    """
    # Start one request per agent; remember which agent owns each future.
    in_flight = {worker.sample.remote(): worker for worker in agents}

    collected = []
    steps_collected = 0

    while in_flight:
        # The single-element unpacking requires exactly one ready future.
        ready, _ = ray.wait(list(in_flight))
        [done_ref] = ready
        worker = in_flight.pop(done_ref)

        batch = ray_get_and_free(done_ref)
        assert batch.count >= sample_batch_size * num_envs_per_worker
        steps_collected += batch.count
        collected.append(batch)

        # Only launch more tasks if we don't already have enough pending:
        # every outstanding request is counted at its minimum batch size.
        expected_from_pending = (
            len(in_flight) * sample_batch_size * num_envs_per_worker)
        if steps_collected + expected_from_pending < train_batch_size:
            in_flight[worker.sample.remote()] = worker

    return SampleBatch.concat_samples(collected)