[rllib] Rename optimizers for clarity (#2303)
* rename
* fix
* update
* mgpu
* Update a3c.py
* Update bc.py
* Update a3c.py
* Update test_optimizers.py
* Update a3c.py
This commit is contained in:
commit 44f5f0520b
parent e657497225
20 changed files with 39 additions and 38 deletions
@@ -7,7 +7,7 @@ import os
 import ray
 from ray.rllib.agent import Agent
-from ray.rllib.optimizers import AsyncOptimizer
+from ray.rllib.optimizers import AsyncGradientsOptimizer
 from ray.rllib.utils import FilterManager
 from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \
     collect_metrics

@@ -131,7 +131,7 @@ class A3CAgent(Agent):
                 worker_index=i+1)
             for i in range(self.config["num_workers"])]

-        self.optimizer = AsyncOptimizer(
+        self.optimizer = AsyncGradientsOptimizer(
             self.config["optimizer"], self.local_evaluator,
             self.remote_evaluators)

@@ -6,7 +6,7 @@ import ray
 from ray.rllib.agent import Agent
 from ray.rllib.bc.bc_evaluator import BCEvaluator, GPURemoteBCEvaluator, \
     RemoteBCEvaluator
-from ray.rllib.optimizers import AsyncOptimizer
+from ray.rllib.optimizers import AsyncGradientsOptimizer
 from ray.tune.result import TrainingResult
 from ray.tune.trial import Resources

@@ -71,7 +71,7 @@ class BCAgent(Agent):
         self.remote_evaluators = [
             remote_cls.remote(self.env_creator, self.config, self.logdir)
             for _ in range(self.config["num_workers"])]
-        self.optimizer = AsyncOptimizer(
+        self.optimizer = AsyncGradientsOptimizer(
             self.config["optimizer"], self.local_evaluator,
             self.remote_evaluators)

@@ -8,7 +8,7 @@ from ray.utils import merge_dicts
 APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
     DDPG_CONFIG,
     {
-        "optimizer_class": "ApexOptimizer",
+        "optimizer_class": "AsyncSamplesOptimizer",
         "optimizer_config":
             merge_dicts(
                 DDPG_CONFIG["optimizer_config"], {

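The Ape-X config above is built by layering a few overrides on top of the base DDPG config. As a rough illustration of that override pattern only, here is a toy stand-in for ray.utils.merge_dicts; its exact semantics are an assumption here ("second dict wins, nested dicts merged recursively"), and the key values below are purely illustrative.

import copy

def merge_dicts_sketch(base, overrides):
    # hypothetical stand-in for ray.utils.merge_dicts, not the real implementation
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_dicts_sketch(merged[key], value)
        else:
            merged[key] = value
    return merged

BASE = {"optimizer_class": "SyncReplayOptimizer",
        "optimizer_config": {"buffer_size": 10000}}          # illustrative values
APEX = merge_dicts_sketch(BASE, {"optimizer_class": "AsyncSamplesOptimizer",
                                 "optimizer_config": {"buffer_size": 200000}})
print(APEX["optimizer_class"])   # AsyncSamplesOptimizer
print(APEX["optimizer_config"])  # nested dict merged, buffer_size overridden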
@@ -102,7 +102,7 @@ DEFAULT_CONFIG = {
     # Whether to allocate CPUs for workers (if > 0).
     "num_cpus_per_worker": 1,
     # Optimizer class to use.
-    "optimizer_class": "LocalSyncReplayOptimizer",
+    "optimizer_class": "SyncReplayOptimizer",
     # Config to pass to the optimizer.
     "optimizer_config": {},
     # Whether to use a distribution of epsilons across workers for exploration.

@@ -9,7 +9,7 @@ from ray.utils import merge_dicts
 APEX_DEFAULT_CONFIG = merge_dicts(
     DQN_CONFIG,
     {
-        "optimizer_class": "ApexOptimizer",
+        "optimizer_class": "AsyncSamplesOptimizer",
         "optimizer_config":
             merge_dicts(
                 DQN_CONFIG["optimizer_config"], {

@@ -96,7 +96,7 @@ DEFAULT_CONFIG = {
     # Whether to allocate CPUs for workers (if > 0).
     "num_cpus_per_worker": 1,
     # Optimizer class to use.
-    "optimizer_class": "LocalSyncReplayOptimizer",
+    "optimizer_class": "SyncReplayOptimizer",
     # Config to pass to the optimizer.
     "optimizer_config": {},
     # Whether to use a distribution of epsilons across workers for exploration.

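Because "optimizer_class" is stored as a string, the agent has to map it back to a class at some point. One plausible lookup against the renamed exports is sketched below; the real resolution code is not part of this diff, so treat the helper as illustrative only, and note it assumes rllib at this commit is importable.

import ray.rllib.optimizers as optimizers

def resolve_optimizer_class(name):
    # sketch: look the string up among the names exported by the optimizers package
    return getattr(optimizers, name)

print(resolve_optimizer_class("SyncReplayOptimizer"))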
@@ -1,14 +1,15 @@
-from ray.rllib.optimizers.apex_optimizer import ApexOptimizer
-from ray.rllib.optimizers.async_optimizer import AsyncOptimizer
-from ray.rllib.optimizers.local_sync import LocalSyncOptimizer
-from ray.rllib.optimizers.local_sync_replay import LocalSyncReplayOptimizer
-from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer
+from ray.rllib.optimizers.async_samples_optimizer import AsyncSamplesOptimizer
+from ray.rllib.optimizers.async_gradients_optimizer import \
+    AsyncGradientsOptimizer
+from ray.rllib.optimizers.sync_samples_optimizer import SyncSamplesOptimizer
+from ray.rllib.optimizers.sync_replay_optimizer import SyncReplayOptimizer
+from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
 from ray.rllib.optimizers.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator, \
     TFMultiGPUSupport


 __all__ = [
-    "ApexOptimizer", "AsyncOptimizer", "LocalSyncOptimizer",
-    "LocalSyncReplayOptimizer", "LocalMultiGPUOptimizer", "SampleBatch",
+    "AsyncSamplesOptimizer", "AsyncGradientsOptimizer", "SyncSamplesOptimizer",
+    "SyncReplayOptimizer", "LocalMultiGPUOptimizer", "SampleBatch",
     "PolicyEvaluator", "TFMultiGPUSupport", "MultiAgentBatch"]

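Downstream code that imported the old names needs the new ones after this commit; only the import line changes, since the constructor arguments and call sites stay the same throughout the diff. A before/after sketch (old form commented out):

# from ray.rllib.optimizers import AsyncOptimizer, LocalSyncOptimizer, \
#     LocalSyncReplayOptimizer, ApexOptimizer
from ray.rllib.optimizers import (AsyncGradientsOptimizer, SyncSamplesOptimizer,
                                  SyncReplayOptimizer, AsyncSamplesOptimizer)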
@@ -7,7 +7,7 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.utils.timer import TimerStat


-class AsyncOptimizer(PolicyOptimizer):
+class AsyncGradientsOptimizer(PolicyOptimizer):
     """An asynchronous RL optimizer, e.g. for implementing A3C.

     This optimizer asynchronously pulls and applies gradients from remote

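The docstring above describes the async-gradients pattern: remote evaluators compute gradients on their own samples, and the driver applies each gradient as it arrives and ships fresh weights back. Below is a toy, ray-free sketch of that loop, with a thread pool standing in for the remote evaluators; the real optimizer's weight syncing, timers, and stats tracking are omitted.

import random
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

weights = 0.0

def compute_gradient(current_weights):
    # stand-in for "set weights on the evaluator, sample, and compute a gradient"
    return random.uniform(-1.0, 1.0) + 0.1 * current_weights

with ThreadPoolExecutor(max_workers=4) as pool:
    pending = {pool.submit(compute_gradient, weights) for _ in range(4)}
    for _ in range(20):                        # roughly "grads_per_step" updates
        done, pending = wait(pending, return_when=FIRST_COMPLETED)
        for fut in done:
            weights -= 0.05 * fut.result()     # apply the gradient locally
            # re-dispatch the worker with the freshly updated weights
            pending.add(pool.submit(compute_gradient, weights))

print("weights after async updates:", weights)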
@@ -135,8 +135,8 @@ class LearnerThread(threading.Thread):
         self.weights_updated = True


-class ApexOptimizer(PolicyOptimizer):
-    """Main event loop of the Ape-X optimizer.
+class AsyncSamplesOptimizer(PolicyOptimizer):
+    """Main event loop of the Ape-X optimizer (async sampling with replay).

     This class coordinates the data transfers between the learner thread,
     remote evaluators (Ape-X actors), and replay buffer actors.

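As the docstring says, the renamed AsyncSamplesOptimizer is the Ape-X event loop: sampling actors produce experience continuously, replay buffers store it, and a learner thread trains from replay. A toy, ray-free sketch of that division of labour follows; queues and a plain list stand in for the actors, and prioritization and batching are omitted.

import queue
import random
import threading

sample_queue = queue.Queue()   # stand-in for batches arriving from sampling actors
replay_buffer = []             # stand-in for the replay buffer actors

def sampler(num_batches):
    # stand-in for a remote evaluator generating experience in the background
    for _ in range(num_batches):
        sample_queue.put([random.random() for _ in range(4)])

def learner(num_sgd_steps):
    # stand-in for the LearnerThread: drain new samples into replay,
    # then repeatedly train on batches drawn from the replay buffer
    steps = 0
    while steps < num_sgd_steps:
        while not sample_queue.empty():
            replay_buffer.append(sample_queue.get())
        if replay_buffer:
            batch = random.choice(replay_buffer)
            _ = sum(batch) / len(batch)        # pretend this is an SGD step
            steps += 1

workers = [threading.Thread(target=sampler, args=(50,)) for _ in range(2)]
workers.append(threading.Thread(target=learner, args=(100,)))
for t in workers:
    t.start()
for t in workers:
    t.join()
print("batches stored in replay:", len(replay_buffer))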
@@ -16,7 +16,7 @@ from ray.rllib.utils.filter import RunningStat
 from ray.rllib.utils.timer import TimerStat


-class LocalSyncReplayOptimizer(PolicyOptimizer):
+class SyncReplayOptimizer(PolicyOptimizer):
     """Variant of the local sync optimizer that supports replay (for DQN).

     This optimizer requires that policy evaluators return an additional

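Per the docstring, the sync-replay step is: sample synchronously from the evaluators, store into replay, then train on a batch drawn from the buffer. A bare-bones sketch of one such step, with plain functions standing in for the evaluators and the gradient update:

import random

replay_buffer = []

def sample_from_evaluators(num_evaluators=2, batch_size=4):
    # stand-in for calling evaluator.sample() on each evaluator
    return [[random.random() for _ in range(batch_size)]
            for _ in range(num_evaluators)]

def sync_replay_step(train_batch_size=2):
    # 1. collect new experience synchronously
    replay_buffer.extend(sample_from_evaluators())
    # 2. draw a training batch from replay
    batch = random.sample(replay_buffer, min(train_batch_size, len(replay_buffer)))
    # 3. pretend to compute and apply gradients on the local evaluator
    return sum(sum(b) for b in batch)

for _ in range(10):
    sync_replay_step()
print("buffer size after 10 steps:", len(replay_buffer))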
@@ -9,7 +9,7 @@ from ray.rllib.utils.filter import RunningStat
 from ray.rllib.utils.timer import TimerStat


-class LocalSyncOptimizer(PolicyOptimizer):
+class SyncSamplesOptimizer(PolicyOptimizer):
     """A simple synchronous RL optimizer.

     In each step, this optimizer pulls samples from a number of remote

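The sync-samples step, by contrast, has no replay buffer: pull a batch from every evaluator, concatenate, and do one local update. A minimal sketch under the same toy stand-ins as above:

import random

weights = 0.0

def sync_samples_step(num_evaluators=3, batch_size=8):
    global weights
    # pull samples from all evaluators (synchronously)
    batches = [[random.uniform(-1, 1) for _ in range(batch_size)]
               for _ in range(num_evaluators)]
    merged = [x for b in batches for x in b]       # concatenate the batches
    grad = sum(merged) / len(merged) + 0.1 * weights
    weights -= 0.05 * grad                         # single local update
    return weights

for _ in range(5):
    sync_samples_step()
print("weights:", weights)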
@@ -3,7 +3,7 @@ from __future__ import division
 from __future__ import print_function

 from ray.rllib.agent import Agent
-from ray.rllib.optimizers import LocalSyncOptimizer
+from ray.rllib.optimizers import SyncSamplesOptimizer
 from ray.rllib.pg.pg_policy_graph import PGPolicyGraph
 from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \
     collect_metrics

@@ -54,7 +54,7 @@ class PGAgent(Agent):
         return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])

     def _init(self):
-        self.optimizer = LocalSyncOptimizer.make(
+        self.optimizer = SyncSamplesOptimizer.make(
             evaluator_cls=CommonPolicyEvaluator,
             evaluator_args={
                 "env_creator": self.env_creator,

@@ -13,7 +13,7 @@ from ray.tune.trial import Resources
 from ray.rllib.agent import Agent
 from ray.rllib.utils import FilterManager
 from ray.rllib.ppo.ppo_evaluator import PPOEvaluator
-from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer
+from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer

 DEFAULT_CONFIG = {
     # Discount factor of the MDP

@@ -10,8 +10,8 @@ import ray
 from ray.rllib.pg import PGAgent
 from ray.rllib.pg.pg_policy_graph import PGPolicyGraph
 from ray.rllib.dqn.dqn_policy_graph import DQNPolicyGraph
-from ray.rllib.optimizers import LocalSyncOptimizer, \
-    LocalSyncReplayOptimizer, AsyncOptimizer
+from ray.rllib.optimizers import SyncSamplesOptimizer, \
+    SyncReplayOptimizer, AsyncGradientsOptimizer
 from ray.rllib.test.test_common_policy_evaluator import MockEnv, MockEnv2, \
     MockPolicyGraph
 from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \

@@ -270,7 +270,7 @@ class TestMultiAgentEnv(unittest.TestCase):
         act_space = env.action_space
         obs_space = env.observation_space
         dqn_config = {"gamma": 0.95, "n_step": 3}
-        if optimizer_cls == LocalSyncReplayOptimizer:
+        if optimizer_cls == SyncReplayOptimizer:
             # TODO: support replay with non-DQN graphs. Currently this can't
             # happen since the replay buffer doesn't encode extra fields like
             # "advantages" that PG uses.

@@ -288,7 +288,7 @@ class TestMultiAgentEnv(unittest.TestCase):
             policy_graph=policies,
             policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
             batch_steps=50)
-        if optimizer_cls == AsyncOptimizer:
+        if optimizer_cls == AsyncGradientsOptimizer:
             remote_evs = [CommonPolicyEvaluator.as_remote().remote(
                 env_creator=lambda _: MultiCartpole(n),
                 policy_graph=policies,

@@ -315,13 +315,13 @@ class TestMultiAgentEnv(unittest.TestCase):
             raise Exception("failed to improve reward")

     def testMultiAgentSyncOptimizer(self):
-        self._testWithOptimizer(LocalSyncOptimizer)
+        self._testWithOptimizer(SyncSamplesOptimizer)

-    def testMultiAgentAsyncOptimizer(self):
-        self._testWithOptimizer(AsyncOptimizer)
+    def testMultiAgentAsyncGradientsOptimizer(self):
+        self._testWithOptimizer(AsyncGradientsOptimizer)

     def testMultiAgentReplayOptimizer(self):
-        self._testWithOptimizer(LocalSyncReplayOptimizer)
+        self._testWithOptimizer(SyncReplayOptimizer)

     def testTrainMultiCartpoleManyPolicies(self):
         n = 20

@@ -338,7 +338,7 @@ class TestMultiAgentEnv(unittest.TestCase):
             policy_graph=policies,
             policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
             batch_steps=100)
-        optimizer = LocalSyncOptimizer({}, ev, [])
+        optimizer = SyncSamplesOptimizer({}, ev, [])
         for i in range(100):
             optimizer.step()
             result = collect_metrics(ev)

@@ -8,7 +8,7 @@ import numpy as np

 import ray
 from ray.rllib.test.mock_evaluator import _MockEvaluator
-from ray.rllib.optimizers import AsyncOptimizer, SampleBatch
+from ray.rllib.optimizers import AsyncGradientsOptimizer, SampleBatch


 class AsyncOptimizerTest(unittest.TestCase):

@@ -20,7 +20,7 @@ class AsyncOptimizerTest(unittest.TestCase):
         local = _MockEvaluator()
         remotes = ray.remote(_MockEvaluator)
         remote_evaluators = [remotes.remote() for i in range(5)]
-        test_optimizer = AsyncOptimizer({
+        test_optimizer = AsyncGradientsOptimizer({
             "grads_per_step": 10
         }, local, remote_evaluators)
         test_optimizer.step()

@@ -46,7 +46,7 @@ halfcheetah-ddpg:
         # === Parallelism ===
         num_workers: 0
         num_gpus_per_worker: 0
-        optimizer_class: "LocalSyncReplayOptimizer"
+        optimizer_class: "SyncReplayOptimizer"
         optimizer_config: {}
         per_worker_exploration: False
         worker_side_prioritization: False

@@ -46,7 +46,7 @@ mountaincarcontinuous-ddpg:
         # === Parallelism ===
         num_workers: 0
         num_gpus_per_worker: 0
-        optimizer_class: "LocalSyncReplayOptimizer"
+        optimizer_class: "SyncReplayOptimizer"
         optimizer_config: {}
         per_worker_exploration: False
         worker_side_prioritization: False

@@ -46,7 +46,7 @@ pendulum-ddpg:
         # === Parallelism ===
         num_workers: 0
         num_gpus_per_worker: 0
-        optimizer_class: "LocalSyncReplayOptimizer"
+        optimizer_class: "SyncReplayOptimizer"
         optimizer_config: {}
         per_worker_exploration: False
         worker_side_prioritization: False

@@ -83,7 +83,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
         "dones": [[...]], "new_obs": [[...]]})

     # Creating policy evaluators using optimizer_cls.make().
-    >>> optimizer = LocalSyncOptimizer.make(
+    >>> optimizer = SyncSamplesOptimizer.make(
             evaluator_cls=CommonPolicyEvaluator,
             evaluator_args={
                 "env_creator": lambda _: gym.make("CartPole-v0"),