mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
202 lines
7.4 KiB
Python
202 lines
7.4 KiB
Python
![]() |
import logging
|
||
|
from typing import List
|
||
|
|
||
|
import ray
|
||
|
from ray.util.iter import LocalIterator
|
||
|
from ray.rllib.evaluation.metrics import get_learner_stats
|
||
|
from ray.rllib.evaluation.worker_set import WorkerSet
|
||
|
from ray.rllib.execution.common import SampleBatchType, \
|
||
|
STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, LEARNER_INFO, \
|
||
|
APPLY_GRADS_TIMER, COMPUTE_GRADS_TIMER, WORKER_UPDATE_TIMER, \
|
||
|
LEARN_ON_BATCH_TIMER, LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES, \
|
||
|
_get_global_vars, _check_sample_batch_type
|
||
|
|
||
|
# Module-level logger named after this module; used by the operators below.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class TrainOneStep:
    """Callable that improves the policy and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> train_op = rollouts.for_each(TrainOneStep(workers))
        >>> print(next(train_op))  # This trains the policy on one batch.
        {"learner_stats": ...}

    Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    def __init__(self, workers: WorkerSet):
        """Creates a TrainOneStep instance.

        Arguments:
            workers (WorkerSet): worker set whose local worker learns on
                each batch and whose remote workers (if any) then receive
                the local worker's weights.
        """
        self.workers = workers

    def __call__(self, batch: SampleBatchType) -> dict:
        """Trains the local worker on `batch` and syncs the other workers.

        Arguments:
            batch (SampleBatchType): the batch to learn on; validated with
                `_check_sample_batch_type` first.

        Returns:
            The value returned by the local worker's `learn_on_batch`
            (shown as a dict in the class-level example).
        """
        _check_sample_batch_type(batch)
        metrics = LocalIterator.get_metrics()
        local_worker = self.workers.local_worker()

        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        with learn_timer:
            info = local_worker.learn_on_batch(batch)
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.info[LEARNER_INFO] = get_learner_stats(info)

        # Fetched once and reused for every worker instead of being
        # recomputed per remote worker inside the loop.
        # NOTE(review): assumes _get_global_vars() is a side-effect-free
        # read of the iterator context -- confirm.
        global_vars = _get_global_vars()

        remote_workers = self.workers.remote_workers()
        if remote_workers:
            with metrics.timers[WORKER_UPDATE_TIMER]:
                # Put the weights in the object store once so that all
                # remote workers share the same object reference.
                weights = ray.put(local_worker.get_weights())
                for worker in remote_workers:
                    worker.set_weights.remote(weights, global_vars)
        # Also update global vars of the local worker.
        local_worker.set_global_vars(global_vars)
        return info
|
||
|
|
||
|
|
||
|
class ComputeGradients:
    """Callable that computes policy-loss gradients on the local worker.

    Intended to be passed to the .for_each() operator.

    Examples:
        >>> grads_op = rollouts.for_each(ComputeGradients(workers))
        >>> print(next(grads_op))
        {"var_0": ..., ...}, 50  # grads, batch count

    Side effect: stores the learner stats of the computation in the
    LEARNER_INFO info field of the local iterator context.
    """

    def __init__(self, workers):
        """Creates a ComputeGradients instance.

        Arguments:
            workers: worker set whose local worker computes the gradients.
        """
        self.workers = workers

    def __call__(self, samples: SampleBatchType):
        """Computes gradients for `samples`.

        Arguments:
            samples (SampleBatchType): batch to differentiate on; validated
                with `_check_sample_batch_type` first.

        Returns:
            A `(grad, samples.count)` tuple.
        """
        _check_sample_batch_type(samples)
        iterator_metrics = LocalIterator.get_metrics()
        grads_timer = iterator_metrics.timers[COMPUTE_GRADS_TIMER]
        with grads_timer:
            # Only the gradient computation itself is timed.
            computed = self.workers.local_worker().compute_gradients(samples)
            grad, info = computed
        learner_stats = get_learner_stats(info)
        iterator_metrics.info[LEARNER_INFO] = learner_stats
        return grad, samples.count
|
||
|
|
||
|
|
||
|
class ApplyGradients:
    """Callable that applies gradients and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> apply_op = grads_op.for_each(ApplyGradients(workers))
        >>> print(next(apply_op))
        None

    Updates the STEPS_TRAINED_COUNTER counter in the local iterator context.
    """

    def __init__(self, workers, update_all=True):
        """Creates an ApplyGradients instance.

        Arguments:
            workers (WorkerSet): workers to apply gradients to.
            update_all (bool): If true, updates all workers. Otherwise, only
                update the worker that produced the sample batch we are
                currently processing (i.e., A3C style).
        """
        self.workers = workers
        self.update_all = update_all

    def __call__(self, item):
        """Applies one set of gradients and pushes the new weights out.

        Arguments:
            item (tuple): a 2-tuple `(gradients, count)`; `count` is added
                to STEPS_TRAINED_COUNTER and reported to the apply timer.

        Raises:
            ValueError: if `item` is not a 2-tuple, or if
                `update_all=False` and the iterator context has no
                `current_actor` set. Note that the latter is only detected
                after the gradients have already been applied locally.
        """
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(
                    item))
        gradients, count = item
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        local_worker = self.workers.local_worker()

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            local_worker.apply_gradients(gradients)
            apply_timer.push_units_processed(count)

        # Fetched once and reused for the local and all remote workers
        # instead of being recomputed per remote worker inside the loop.
        # NOTE(review): assumes _get_global_vars() is a side-effect-free
        # read of the iterator context -- confirm.
        global_vars = _get_global_vars()

        # Also update global vars of the local worker.
        local_worker.set_global_vars(global_vars)

        if self.update_all:
            remote_workers = self.workers.remote_workers()
            if remote_workers:
                with metrics.timers[WORKER_UPDATE_TIMER]:
                    # Put the weights in the object store once so that all
                    # remote workers share the same object reference.
                    weights = ray.put(local_worker.get_weights())
                    for worker in remote_workers:
                        worker.set_weights.remote(weights, global_vars)
        else:
            if metrics.current_actor is None:
                raise ValueError(
                    "Could not find actor to update. When "
                    "update_all=False, `current_actor` must be set "
                    "in the iterator context.")
            with metrics.timers[WORKER_UPDATE_TIMER]:
                # Single recipient: send the weights directly.
                weights = local_worker.get_weights()
                metrics.current_actor.set_weights.remote(
                    weights, global_vars)
|
||
|
|
||
|
|
||
|
class AverageGradients:
    """Callable that combines the gradients in a batch.

    This should be used with the .for_each() operator after a set of gradients
    have been batched with .batch().

    Examples:
        >>> batched_grads = grads_op.batch(32)
        >>> avg_grads = batched_grads.for_each(AverageGradients())
        >>> print(next(avg_grads))
        [...], 1600  # combined grads, summed batch count

    Note: despite the name, the gradients are summed element-wise; the result
    is not divided by the number of microbatches.
    """

    def __call__(self, gradients):
        """Accumulates a batch of `(grad, count)` pairs.

        Arguments:
            gradients: iterable of `(grad, count)` pairs. Each `grad` is
                walked positionally with `zip()`, so all of them are
                expected to be same-length sequences of addable values
                (extra trailing entries are dropped by `zip()`).

        Returns:
            A tuple `(acc, sum_count)`: the element-wise sum of all `grad`
            sequences and the sum of all counts. With a single input pair
            `acc` is that pair's `grad` object itself (not a copy); with no
            input pairs the result is `(None, 0)`.
        """
        acc = None
        sum_count = 0
        # Counted while iterating (instead of len()) so that any iterable,
        # not just a sized container, is accepted.
        num_grads = 0
        for grad, count in gradients:
            if acc is None:
                acc = grad
            else:
                acc = [a + b for a, b in zip(acc, grad)]
            sum_count += count
            num_grads += 1
        # Lazy %-args: the message is only built if INFO is enabled.
        logger.info(
            "Computing average of %s microbatch gradients "
            "(%s samples total)", num_grads, sum_count)
        return acc, sum_count
|
||
|
|
||
|
|
||
|
class UpdateTargetNetwork:
    """Callable that periodically calls update_target() on trainable policies.

    Meant to be chained with the .for_each() operator after a training step
    has been taken.

    Examples:
        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
        >>> update_op = train_op.for_each(
        ...     UpdateTargetNetwork(workers, target_update_freq=500))
        >>> print(next(update_op))
        None

    Maintains the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in
    the local iterator context. The former records the step count at the
    last update and determines when the next update is due.
    """

    def __init__(self, workers, target_update_freq, by_steps_trained=False):
        """Creates an UpdateTargetNetwork instance.

        Arguments:
            workers: worker set whose local worker's trainable policies
                get their targets updated.
            target_update_freq: number of steps that must be exceeded
                since the last update before the next one is triggered.
            by_steps_trained (bool): if true, measure progress with
                STEPS_TRAINED_COUNTER, otherwise with STEPS_SAMPLED_COUNTER.
        """
        self.workers = workers
        self.target_update_freq = target_update_freq
        self.metric = (STEPS_TRAINED_COUNTER
                       if by_steps_trained else STEPS_SAMPLED_COUNTER)

    def __call__(self, _):
        """Updates the targets if enough steps have passed; returns None.

        The incoming item is ignored; only the counters of the local
        iterator context are consulted.
        """
        counters = LocalIterator.get_metrics().counters
        current_ts = counters[self.metric]
        steps_since_update = current_ts - counters[LAST_TARGET_UPDATE_TS]
        if steps_since_update <= self.target_update_freq:
            # Not due yet.
            return

        def _update_target(policy, _unused):
            return policy.update_target()

        self.workers.local_worker().foreach_trainable_policy(_update_target)
        counters[NUM_TARGET_UPDATES] += 1
        counters[LAST_TARGET_UPDATE_TS] = current_ts
|