import logging

import ray
from ray.rllib.policy.sample_batch import SampleBatch

logger = logging.getLogger(__name__)


def collect_samples(agents, rollout_fragment_length, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""
    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    # Kick off one rollout task per agent.
    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        # Block until any one pending rollout task completes, then collect
        # its result.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending.
        # Each pending task is expected to yield roughly
        # rollout_fragment_length * num_envs_per_worker timesteps.
        pending = len(
            agent_dict) * rollout_fragment_length * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
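

if __name__ == "__main__":
    # A minimal usage sketch (illustration only, not part of RLlib's public
    # API): `_DummyWorker` below is a hypothetical stand-in for a rollout
    # worker. The only contract collect_samples relies on is that each agent
    # exposes a remote `sample()` method returning a SampleBatch whose
    # `count` is the number of timesteps in the fragment.
    import numpy as np

    @ray.remote
    class _DummyWorker:
        def sample(self):
            # One fragment: 50 timesteps of dummy 4-dim observations.
            return SampleBatch({"obs": np.zeros((50, 4))})

    ray.init()
    workers = [_DummyWorker.remote() for _ in range(4)]
    batch = collect_samples(
        workers,
        rollout_fragment_length=50,  # matches the fragment size above
        num_envs_per_worker=1,
        train_batch_size=500)
    # collect_samples returns at least train_batch_size timesteps and never
    # discards a completed fragment, so the result may overshoot slightly.
    assert batch.count >= 500
    ray.shutdown()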