ray/rllib/utils/exploration/per_worker_gaussian_noise.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

50 lines
1.7 KiB
Python
Raw Normal View History

from gym.spaces import Space
from typing import Optional
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.exploration.gaussian_noise import GaussianNoise
from ray.rllib.utils.schedules import ConstantSchedule
@PublicAPI
class PerWorkerGaussianNoise(GaussianNoise):
    """A per-worker Gaussian noise class for distributed algorithms.

    Sets the `scale` schedule of each individual worker to a constant:

    * worker_index > 0:
      0.4 ^ (1 + [worker-index] / float([num-workers] - 1) * 7)
      (with a single worker the divisor is taken as 1.0 to avoid a
      division by zero).
    * worker_index <= 0 (the local worker): 0.0, i.e. no noise.

    When `num_workers` is None or not positive, or `worker_index` is None,
    no per-worker schedule is created and `scale_schedule=None` is passed
    on to `GaussianNoise`.

    See Ape-X paper.
    """

    def __init__(
        self,
        action_space: Space,
        *,
        framework: Optional[str],
        num_workers: Optional[int],
        worker_index: Optional[int],
        **kwargs
    ):
        """Builds the constant per-worker scale schedule and initializes the parent.

        Args:
            action_space: The gym action space used by the environment.
            num_workers: The overall number of workers used. None or a
                non-positive value disables the per-worker schedule.
            worker_index: The index of the Worker using this
                Exploration. None disables the per-worker schedule.
            framework: One of None, "tf", "torch".
            **kwargs: Forwarded unchanged to `GaussianNoise.__init__`.
        """
        # Both values are annotated Optional; comparing None with `>` raises
        # TypeError in Python 3, so check for None before any ordering test.
        per_worker = (
            num_workers is not None
            and worker_index is not None
            and num_workers > 0
        )

        # None means "no per-worker override".
        # NOTE(review): presumably GaussianNoise then falls back to its own
        # scale settings -- confirm against the parent class.
        scale_schedule = None

        # Use a fixed, different noise scale per worker. See: Ape-X paper.
        if per_worker:
            if worker_index > 0:
                # Guard the single-worker case against a division by zero.
                num_workers_minus_1 = (
                    float(num_workers - 1) if num_workers > 1 else 1.0
                )
                exponent = 1 + (worker_index / num_workers_minus_1) * 7
                scale_schedule = ConstantSchedule(
                    0.4 ** exponent, framework=framework
                )
            # Local worker should have zero exploration so that eval
            # rollouts run properly.
            else:
                scale_schedule = ConstantSchedule(0.0, framework=framework)

        # NOTE(review): `num_workers` and `worker_index` are consumed here and
        # are not forwarded to the parent constructor -- confirm the parent
        # chain does not require them.
        super().__init__(
            action_space, scale_schedule=scale_schedule, framework=framework, **kwargs
        )