import logging
import math
from typing import Dict, List, Tuple, Any

import numpy as np

import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    AGENT_STEPS_TRAINED_COUNTER,
    APPLY_GRADS_TIMER,
    COMPUTE_GRADS_TIMER,
    LAST_TARGET_UPDATE_TS,
    LEARN_ON_BATCH_TIMER,
    LOAD_BATCH_TIMER,
    NUM_TARGET_UPDATES,
    STEPS_SAMPLED_COUNTER,
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
    WORKER_UPDATE_TIMER,
    _check_sample_batch_type,
    _get_global_vars,
    _get_shared_metrics,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import NUM_ENV_STEPS_TRAINED, NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


@ExperimentalAPI
def train_one_step(trainer, train_batch) -> Dict:
    config = trainer.config
    workers = trainer.workers
    local_worker = workers.local_worker()
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", 0)

    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Subsample minibatches (size=`sgd_minibatch_size`) from the
        # train batch and loop through the train batch `num_sgd_iter` times.
        if num_sgd_iter > 1 or sgd_minibatch_size > 0:
            info = do_minibatch_sgd(
                train_batch,
                {
                    pid: local_worker.get_policy(pid)
                    for pid in local_worker.get_policies_to_train(train_batch)
                },
                local_worker,
                num_sgd_iter,
                sgd_minibatch_size,
                [],
            )
        # Single update step using the train batch.
        else:
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    return info
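

# Example (illustrative sketch): `train_one_step` is typically invoked from a
# Trainer's training-iteration logic with a batch collected from the rollout
# workers, roughly like:
#
#     train_batch = synchronous_parallel_sample(worker_set=trainer.workers)
#     results = train_one_step(trainer, train_batch)
#
# `synchronous_parallel_sample` (from ray.rllib.execution.rollout_ops) is only
# one possible way to obtain `train_batch`; any SampleBatch/MultiAgentBatch
# works, and the exact call site depends on the algorithm in use.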


@ExperimentalAPI
def multi_gpu_train_one_step(trainer, train_batch) -> Dict:
    config = trainer.config
    workers = trainer.workers
    local_worker = workers.local_worker()
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", config["train_batch_size"])

    # Determine the number of devices (GPUs or 1 CPU) we use.
    num_devices = int(math.ceil(config["num_gpus"] or 1))

    # Make sure the total batch size is divisible by the number of devices.
    # Batch size per tower.
    per_device_batch_size = sgd_minibatch_size // num_devices
    # Total batch size.
    batch_size = per_device_batch_size * num_devices
    assert batch_size % num_devices == 0
    assert batch_size >= num_devices, "Batch size too small!"

    # Handle everything as if multi-agent.
    train_batch = train_batch.as_multi_agent()

    # Load data into GPUs.
    load_timer = trainer._timers[LOAD_BATCH_TIMER]
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in train_batch.policy_batches.items():
            # Not a policy-to-train.
            if not local_worker.is_policy_to_train(policy_id, train_batch):
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have more than one buffer if we are
            # training asynchronously.
            num_loaded_samples[policy_id] = local_worker.policy_map[
                policy_id
            ].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    learn_timer = trainer._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=num_devices)

        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = local_worker.policy_map[policy_id]
            num_batches = max(1, int(samples_per_device) // int(per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for _ in range(num_sgd_iter):
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        permutation[batch_index] * per_device_batch_size, buffer_index=0
                    )

                    learner_info_builder.add_learn_on_batch_results(results, policy_id)

        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(train_batch.count)
    learn_timer.push_units_processed(train_batch.count)

    trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if workers.remote_workers():
        with trainer._timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                local_worker.get_weights(
                    local_worker.get_policies_to_train(train_batch)
                )
            )
            for e in workers.remote_workers():
                e.set_weights.remote(weights)

    return learner_info
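

# Worked example of the batch-size arithmetic above (illustrative numbers):
# with `sgd_minibatch_size=500` and `num_gpus=2`, `per_device_batch_size`
# becomes 250 and the effective `batch_size` is 500; with `num_gpus=3`,
# `per_device_batch_size` is 166 and the effective `batch_size` drops to 498,
# i.e. a few samples per minibatch are dropped so that every tower receives an
# equally sized shard.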


class TrainOneStep:
    """Callable that improves the policy and updates workers.

    This should be used with the .for_each() operator. A tuple of the input
    and learner stats will be returned.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> train_op = rollouts.for_each(TrainOneStep(workers))
        >>> print(next(train_op))  # This trains the policy on one batch.
        SampleBatch(...), {"learner_stats": ...}

    Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    def __init__(
        self,
        workers: WorkerSet,
        policies: List[PolicyID] = frozenset([]),
        num_sgd_iter: int = 1,
        sgd_minibatch_size: int = 0,
    ):
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.policies = policies
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size

    def __call__(self, batch: SampleBatchType) -> Tuple[SampleBatchType, List[dict]]:
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        lw = self.local_worker
        with learn_timer:
            # Subsample minibatches (size=`sgd_minibatch_size`) from the
            # train batch and loop through the train batch `num_sgd_iter` times.
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                learner_info = do_minibatch_sgd(
                    batch,
                    {
                        pid: lw.get_policy(pid)
                        for pid in self.policies or lw.get_policies_to_train(batch)
                    },
                    lw,
                    self.num_sgd_iter,
                    self.sgd_minibatch_size,
                    [],
                )
            # Single update step using the train batch.
            else:
                learner_info = lw.learn_on_batch(batch)

        metrics.info[LEARNER_INFO] = learner_info
        learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()
        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(
                    lw.get_weights(self.policies or lw.get_policies_to_train(batch))
                )
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
        # Also update global vars of the local worker.
        lw.set_global_vars(_get_global_vars())
        return batch, learner_info


class MultiGPUTrainOneStep:
    """Multi-GPU version of TrainOneStep.

    This should be used with the .for_each() operator. A tuple of the input
    and learner stats will be returned.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> train_op = rollouts.for_each(MultiGPUTrainOneStep(workers, ...))
        >>> print(next(train_op))  # This trains the policy on one batch.
        SampleBatch(...), {"learner_stats": ...}

    Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    def __init__(
        self,
        *,
        workers: WorkerSet,
        sgd_minibatch_size: int,
        num_sgd_iter: int,
        num_gpus: int,
        _fake_gpus: bool = False,
        # Deprecated args.
        shuffle_sequences=DEPRECATED_VALUE,
        framework=DEPRECATED_VALUE,
    ):
        if framework != DEPRECATED_VALUE or shuffle_sequences != DEPRECATED_VALUE:
            deprecation_warning(
                old="MultiGPUTrainOneStep(framework=..., shuffle_sequences=...)",
                error=False,
            )

        self.workers = workers
        self.local_worker = workers.local_worker()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.shuffle_sequences = shuffle_sequences

        # Collect actual GPU devices to use.
        if not num_gpus:
            _fake_gpus = True
            num_gpus = 1
        type_ = "cpu" if _fake_gpus else "gpu"
        self.devices = [
            "/{}:{}".format(type_, 0 if _fake_gpus else i)
            for i in range(int(math.ceil(num_gpus)))
        ]

        # Make sure the total batch size is divisible by the number of devices.
        # Batch size per tower.
        self.per_device_batch_size = sgd_minibatch_size // len(self.devices)
        # Total batch size.
        self.batch_size = self.per_device_batch_size * len(self.devices)
        assert self.batch_size % len(self.devices) == 0
        assert self.batch_size >= len(self.devices), "Batch size too small!"

    def __call__(self, samples: SampleBatchType) -> Tuple[SampleBatchType, List[dict]]:
        _check_sample_batch_type(samples)

        # Handle everything as if multi-agent.
        samples = samples.as_multi_agent()

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        # Load data into GPUs.
        with load_timer:
            num_loaded_samples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if not self.local_worker.is_policy_to_train(policy_id, samples):
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have more than one buffer if we are
                # training asynchronously.
                num_loaded_samples[policy_id] = self.local_worker.policy_map[
                    policy_id
                ].load_batch_into_buffer(batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(num_devices=len(self.devices))

            for policy_id, samples_per_device in num_loaded_samples.items():
                policy = self.local_worker.policy_map[policy_id]
                num_batches = max(
                    1, int(samples_per_device) // int(self.per_device_batch_size)
                )
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        results = policy.learn_on_loaded_batch(
                            permutation[batch_index] * self.per_device_batch_size,
                            buffer_index=0,
                        )

                        learner_info_builder.add_learn_on_batch_results(
                            results, policy_id
                        )

            # Tower reduce and finalize results.
            learner_info = learner_info_builder.finalize()

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = learner_info

        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(
                    self.workers.local_worker().get_weights(
                        self.local_worker.get_policies_to_train()
                    )
                )
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, learner_info


# Backward compatibility.
TrainTFMultiGPU = MultiGPUTrainOneStep
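

# Illustrative sketch of constructing the multi-GPU op with its keyword-only
# arguments (values shown are placeholders, not recommendations):
#
#     train_op = rollouts.for_each(
#         MultiGPUTrainOneStep(
#             workers=workers,
#             sgd_minibatch_size=128,
#             num_sgd_iter=10,
#             num_gpus=2,
#         )
#     )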


class ComputeGradients:
    """Callable that computes gradients with respect to the policy loss.

    This should be used with the .for_each() operator.

    Examples:
        >>> grads_op = rollouts.for_each(ComputeGradients(workers))
        >>> print(next(grads_op))
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the LEARNER_INFO info field in the local iterator context.
    """

    def __init__(self, workers: WorkerSet):
        self.workers = workers

    def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]:
        _check_sample_batch_type(samples)
        metrics = _get_shared_metrics()
        with metrics.timers[COMPUTE_GRADS_TIMER]:
            grad, info = self.workers.local_worker().compute_gradients(samples)
        # RolloutWorker.compute_gradients returns pure single-agent stats
        # in a non-multi-agent setup.
        if isinstance(samples, MultiAgentBatch):
            metrics.info[LEARNER_INFO] = info
        else:
            metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: info}
        return grad, samples.count


class ApplyGradients:
    """Callable that applies gradients and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> apply_op = grads_op.for_each(ApplyGradients(workers))
        >>> print(next(apply_op))
        None

    Updates the STEPS_TRAINED_COUNTER counter in the local iterator context.
    """

    def __init__(
        self, workers, policies: List[PolicyID] = frozenset([]), update_all=True
    ):
        """Creates an ApplyGradients instance.

        Args:
            workers (WorkerSet): workers to apply gradients to.
            policies (List[PolicyID]): Policies to update. If empty, update
                all policies the local worker is set up to train.
            update_all (bool): If true, updates all workers. Otherwise, only
                update the worker that produced the sample batch we are
                currently processing (i.e., A3C style).
        """
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.policies = policies
        self.update_all = update_all

    def __call__(self, item: Tuple[ModelGradients, int]) -> None:
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(item)
            )
        gradients, count = item
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            self.local_worker.apply_gradients(gradients)
        apply_timer.push_units_processed(count)

        # Also update global vars of the local worker.
        self.local_worker.set_global_vars(_get_global_vars())

        if self.update_all:
            if self.workers.remote_workers():
                with metrics.timers[WORKER_UPDATE_TIMER]:
                    weights = ray.put(
                        self.local_worker.get_weights(
                            self.policies or self.local_worker.get_policies_to_train()
                        )
                    )
                    for e in self.workers.remote_workers():
                        e.set_weights.remote(weights, _get_global_vars())
        else:
            if metrics.current_actor is None:
                raise ValueError(
                    "Could not find actor to update. When "
                    "update_all=False, `current_actor` must be set "
                    "in the iterator context."
                )
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = self.local_worker.get_weights(
                    self.policies or self.local_worker.get_policies_to_train()
                )
                metrics.current_actor.set_weights.remote(weights, _get_global_vars())


class AverageGradients:
    """Callable that averages the gradients in a batch.

    This should be used with the .for_each() operator after a set of gradients
    have been batched with .batch().

    Examples:
        >>> batched_grads = grads_op.batch(32)
        >>> avg_grads = batched_grads.for_each(AverageGradients())
        >>> print(next(avg_grads))
        {"var_0": ..., ...}, 1600  # averaged grads, summed batch count
    """

    def __call__(
        self, gradients: List[Tuple[ModelGradients, int]]
    ) -> Tuple[ModelGradients, int]:
        acc = None
        sum_count = 0
        for grad, count in gradients:
            if acc is None:
                acc = grad
            else:
                acc = [a + b for a, b in zip(acc, grad)]
            sum_count += count
        logger.info(
            "Computing average of {} microbatch gradients "
            "({} samples total)".format(len(gradients), sum_count)
        )
        return acc, sum_count
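

# Illustrative sketch of chaining the gradient ops above into a single iterator
# pipeline (mirrors the per-class docstring examples; `rollouts` and `workers`
# are assumed to exist):
#
#     train_op = (
#         rollouts.for_each(ComputeGradients(workers))
#         .batch(4)
#         .for_each(AverageGradients())
#         .for_each(ApplyGradients(workers))
#     )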


class UpdateTargetNetwork:
    """Periodically call policy.update_target() on all trainable policies.

    This should be used with the .for_each() operator after a training step
    has been taken.

    Examples:
        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
        >>> update_op = train_op.for_each(
        ...     UpdateTargetNetwork(workers, target_update_freq=500))
        >>> print(next(update_op))
        None

    Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the
    local iterator context. The value of the last update counter is used to
    track when we should update the target next.
    """

    def __init__(
        self,
        workers: WorkerSet,
        target_update_freq: int,
        by_steps_trained: bool = False,
        policies: List[PolicyID] = frozenset([]),
    ):
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.target_update_freq = target_update_freq
        self.policies = policies
        if by_steps_trained:
            self.metric = STEPS_TRAINED_COUNTER
        else:
            self.metric = STEPS_SAMPLED_COUNTER

    def __call__(self, _: Any) -> None:
        metrics = _get_shared_metrics()
        cur_ts = metrics.counters[self.metric]
        last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update > self.target_update_freq:
            to_update = self.policies or self.local_worker.get_policies_to_train()
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, pid: pid in to_update and p.update_target()
            )
            metrics.counters[NUM_TARGET_UPDATES] += 1
            metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
|