This reverts commit 0eb0e0ff58.

parent 77f28f1c30
commit 0b8489dcc6

4 changed files with 36 additions and 74 deletions
@@ -169,6 +169,8 @@ def validate_config(config: TrainerConfigDict) -> None:
 
     Rewrites rollout_fragment_length to take into account n_step truncation.
     """
+    if config["num_gpus"] > 1:
+        raise ValueError("`num_gpus` > 1 not yet supported for DDPG!")
     if config["model"]["custom_model"]:
         logger.warning(
             "Setting use_state_preprocessor=True since a custom model "
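For context, a minimal standalone sketch of the guard this hunk restores (plain Python, not RLlib code; check_num_gpus is an illustrative stand-in): reject any config that asks for more than one GPU before training starts.

def check_num_gpus(config: dict) -> None:
    # Mirrors the restored validate_config check above.
    if config.get("num_gpus", 0) > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for DDPG!")

check_num_gpus({"num_gpus": 1})       # passes silently
try:
    check_num_gpus({"num_gpus": 2})
except ValueError as err:
    print(err)                        # `num_gpus` > 1 not yet supported for DDPG!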
@@ -253,6 +253,23 @@ def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
     return policy.critic_loss + policy.actor_loss
 
 
+def make_ddpg_optimizers(policy: Policy, config: TrainerConfigDict) -> None:
+    # Create separate optimizers for actor & critic losses.
+    if policy.config["framework"] in ["tf2", "tfe"]:
+        policy._actor_optimizer = tf.keras.optimizers.Adam(
+            learning_rate=config["actor_lr"])
+        policy._critic_optimizer = tf.keras.optimizers.Adam(
+            learning_rate=config["critic_lr"])
+    else:
+        policy._actor_optimizer = tf1.train.AdamOptimizer(
+            learning_rate=config["actor_lr"])
+        policy._critic_optimizer = tf1.train.AdamOptimizer(
+            learning_rate=config["critic_lr"])
+    # TODO: (sven) make this function return both optimizers and
+    # TFPolicy handle optimizers vs loss terms correctly (like torch).
+    return None
+
+
 def build_apply_op(policy: Policy, optimizer: LocalOptimizer,
                    grads_and_vars: ModelGradients) -> TensorType:
     # For policy gradient, update policy net one time v.s.
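As a rough illustration of the two-optimizer setup this function restores, here is a self-contained TF2 sketch (not RLlib code: the Dense layers stand in for the real actor/critic models, and the learning rates are placeholders for actor_lr / critic_lr). Each Adam optimizer is applied only to its own network's variables.

import tensorflow as tf

actor = tf.keras.layers.Dense(2)     # stand-in for the actor network
critic = tf.keras.layers.Dense(1)    # stand-in for the critic (Q) network

actor_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)   # actor_lr
critic_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # critic_lr

obs = tf.random.normal([8, 4])
with tf.GradientTape(persistent=True) as tape:
    act = actor(obs)
    q = critic(tf.concat([obs, act], axis=-1))
    actor_loss = -tf.reduce_mean(q)               # maximize Q under the actor
    critic_loss = tf.reduce_mean(tf.square(q))    # placeholder target of 0

critic_grads = tape.gradient(critic_loss, critic.trainable_variables)
actor_grads = tape.gradient(actor_loss, actor.trainable_variables)
critic_optimizer.apply_gradients(zip(critic_grads, critic.trainable_variables))
actor_optimizer.apply_gradients(zip(actor_grads, actor.trainable_variables))
del tape

The eager ("tf2"/"tfe") branch above corresponds to the Keras optimizers in this sketch; the static-graph branch uses tf1.train.AdamOptimizer instead.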
@@ -326,44 +343,14 @@ def build_ddpg_stats(policy: Policy,
     return stats
 
 
-class ActorCriticOptimizerMixin:
-    """Mixin class to generate the necessary optimizers for actor-critic algos.
-
-    - Creates global step for counting the number of update operations.
-    - Creates separate optimizers for actor, critic, and alpha.
-    """
-
-    def __init__(self, config):
-        # Eager mode.
-        if config["framework"] in ["tf2", "tfe"]:
-            self.global_step = get_variable(0, tf_name="global_step")
-            self._actor_optimizer = tf.keras.optimizers.Adam(
-                learning_rate=config["actor_lr"])
-            self._critic_optimizer = \
-                tf.keras.optimizers.Adam(learning_rate=config["critic_lr"])
-        # Static graph mode.
-        else:
-            self.global_step = tf1.train.get_or_create_global_step()
-            self._actor_optimizer = tf1.train.AdamOptimizer(
-                learning_rate=config["actor_lr"])
-            self._critic_optimizer = \
-                tf1.train.AdamOptimizer(learning_rate=config["critic_lr"])
-
-
-def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
+def before_init_fn(policy: Policy, obs_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> None:
-    """Call mixin classes' constructors before Policy's initialization.
-
-    Adds the necessary optimizers to the given Policy.
-
-    Args:
-        policy (Policy): The Policy object.
-        obs_space (gym.spaces.Space): The Policy's observation space.
-        action_space (gym.spaces.Space): The Policy's action space.
-        config (TrainerConfigDict): The Policy's config.
-    """
-    ActorCriticOptimizerMixin.__init__(policy, config)
+    # Create global step for counting the number of update operations.
+    if config["framework"] in ["tf2", "tfe"]:
+        policy.global_step = get_variable(0, tf_name="global_step")
+    else:
+        policy.global_step = tf1.train.get_or_create_global_step()
 
 
 class ComputeTDErrorMixin:
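Both variants rely on a before_init-style hook that simply attaches attributes to the policy before its constructor finishes. A tiny sketch of that pattern with made-up stand-ins (FakePolicy and before_init_hook are illustrative only, not RLlib classes):

class FakePolicy:
    def __init__(self, config):
        self.config = config

def before_init_hook(policy, config):
    # Like before_init_fn above: create a counter for update operations.
    policy.global_step = 0

policy = FakePolicy({"framework": "torch"})
before_init_hook(policy, policy.config)
print(policy.global_step)  # -> 0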
@@ -452,15 +439,15 @@ DDPGTFPolicy = build_tf_policy(
     loss_fn=ddpg_actor_critic_loss,
     stats_fn=build_ddpg_stats,
     postprocess_fn=postprocess_nstep_and_prio,
+    optimizer_fn=make_ddpg_optimizers,
     compute_gradients_fn=gradients_fn,
     apply_gradients_fn=build_apply_op,
     extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
     validate_spaces=validate_spaces,
-    before_init=setup_early_mixins,
+    before_init=before_init_fn,
     before_loss_init=setup_mid_mixins,
     after_init=setup_late_mixins,
     mixins=[
         TargetNetworkMixin,
-        ActorCriticOptimizerMixin,
         ComputeTDErrorMixin,
     ])
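Purely as an illustration of how such a policy "builder" wires callbacks together (this is not build_tf_policy's real signature), here is a toy builder that just records the hooks it is handed; swapping before_init or adding optimizer_fn changes which callbacks run at construction time.

def build_policy(name, *, loss_fn, optimizer_fn=None, before_init=None, mixins=()):
    # Record which hooks the policy was built with.
    return {"name": name, "loss_fn": loss_fn, "optimizer_fn": optimizer_fn,
            "before_init": before_init, "mixins": tuple(mixins)}

policy_spec = build_policy(
    "DDPGTFPolicy",
    loss_fn=lambda *args: 0.0,                  # stand-in for ddpg_actor_critic_loss
    optimizer_fn=lambda policy, config: None,   # stand-in for make_ddpg_optimizers
    before_init=lambda *args: None,             # stand-in for before_init_fn
)
print(policy_spec["optimizer_fn"] is not None)  # True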
@@ -16,8 +16,7 @@ from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.simplex import Simplex
-from ray.rllib.utils.torch_ops import apply_grad_clipping, \
-    concat_multi_gpu_td_errors, huber_loss, l2_loss
+from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss, l2_loss
 from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \
     LocalOptimizer, GradInfoDict
@@ -176,14 +175,10 @@ def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
         [actor_loss, critic_loss], input_dict)
 
     # Store values for stats function.
-    policy.actor_loss = actor_loss
-    policy.critic_loss = critic_loss
-
-    # Store td-error in model, such that for multi-GPU, we do not override
-    # them during the parallel loss phase. TD-error tensor in final stats
-    # can then be concatenated and retrieved for each individual batch item.
-    model.td_error = td_error
-    policy.q_t = q_t
+    policy.q_t = q_t
+    policy.actor_loss = actor_loss
+    policy.critic_loss = critic_loss
+    policy.td_error = td_error
 
     # Return two loss terms (corresponding to the two optimizers, we create).
     return policy.actor_loss, policy.critic_loss
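For reference, a standalone PyTorch sketch of the TD error that gets stored on the policy above: the gap between the online critic's Q-value and the n-step bootstrapped target. The constants and random tensors are placeholders, and RLlib additionally supports twin-Q and a Huber loss, so this is a conceptual sketch, not the exact RLlib code.

import torch

gamma, n_step = 0.99, 1
q_t = torch.randn(8)          # Q(s_t, a_t) from the online critic
q_tp1_best = torch.randn(8)   # target critic's Q for the target policy's action
rewards = torch.randn(8)
dones = torch.zeros(8)

q_t_target = rewards + (1.0 - dones) * gamma ** n_step * q_tp1_best
td_error = q_t - q_t_target.detach()
critic_loss = torch.mean(td_error ** 2)
print(float(critic_loss))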
@@ -226,6 +221,8 @@ def build_ddpg_stats(policy: Policy,
         "mean_q": torch.mean(policy.q_t),
         "max_q": torch.max(policy.q_t),
         "min_q": torch.min(policy.q_t),
+        "mean_td_error": torch.mean(policy.td_error),
+        "td_error": policy.td_error,
     }
     return stats
@@ -255,7 +252,7 @@ class ComputeTDErrorMixin:
             loss_fn(self, self.model, None, input_dict)
 
             # Self.td_error is set within actor_critic_loss call.
-            return self.model.td_error
+            return self.td_error
 
         self.compute_td_error = compute_td_error
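The mixin's trick is to call the loss function purely for its side effects and then read the stored TD error back. A minimal sketch of that closure pattern with invented names (TDErrorHolder and fake_loss are not RLlib code):

class TDErrorHolder:
    def __init__(self, loss_fn):
        def compute_td_error(batch):
            loss_fn(self, batch)        # sets self.td_error as a side effect
            return self.td_error
        self.compute_td_error = compute_td_error

def fake_loss(policy, batch):
    policy.td_error = [r - 0.5 for r in batch]   # stand-in TD errors
    return sum(policy.td_error)

holder = TDErrorHolder(fake_loss)
print(holder.compute_td_error([1.0, 0.0, 0.5]))  # [0.5, -0.5, 0.0]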
@@ -304,7 +301,6 @@ DDPGTorchPolicy = build_policy_class(
     before_loss_init=setup_late_mixins,
     action_distribution_fn=get_distribution_inputs_and_class,
     make_model_and_action_dist=build_ddpg_models_and_action_dist,
-    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
     apply_gradients_fn=apply_gradients_fn,
     mixins=[
         TargetNetworkMixin,
@@ -56,29 +56,6 @@ class TestDDPG(unittest.TestCase):
             check(a, 500)
             trainer.stop()
 
-    def test_ddpg_fake_multi_gpu_learning(self):
-        """Test whether DDPGTrainer can learn CartPole w/ faked multi-GPU."""
-        config = ddpg.DEFAULT_CONFIG.copy()
-        # Fake GPU setup.
-        config["num_gpus"] = 2
-        config["_fake_gpus"] = True
-        env = "ray.rllib.agents.sac.tests.test_sac.SimpleEnv"
-        config["env_config"] = {"config": {"repeat_delay": 0}}
-
-        for _ in framework_iterator(config, frameworks=("tf", "torch")):
-            trainer = ddpg.DDPGTrainer(config=config, env=env)
-            num_iterations = 50
-            learnt = False
-            for i in range(num_iterations):
-                results = trainer.train()
-                print(f"R={results['episode_reward_mean']}")
-                if results["episode_reward_mean"] > 75.0:
-                    learnt = True
-                    break
-            assert learnt, \
-                f"DDPG multi-GPU (with fake-GPUs) did not learn {env}!"
-            trainer.stop()
-
     def test_ddpg_checkpoint_save_and_restore(self):
         """Test whether a DDPGTrainer can save and load checkpoints."""
         config = ddpg.DEFAULT_CONFIG.copy()
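The remaining tests keep the basic build/train/stop smoke-test shape. A hedged sketch of that pattern (the environment name and num_workers setting are assumptions for illustration, not taken from the test file):

import ray
import ray.rllib.agents.ddpg as ddpg

ray.init(ignore_reinit_error=True)
config = ddpg.DEFAULT_CONFIG.copy()
config["num_workers"] = 0                 # keep the run lightweight
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
for _ in range(2):
    results = trainer.train()
    print(results["episode_reward_mean"])
trainer.stop()
ray.shutdown()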
@@ -312,7 +289,7 @@ class TestDDPG(unittest.TestCase):
             elif fw == "torch":
                 loss_torch(policy, policy.model, None, input_)
                 c, a, t = policy.critic_loss, policy.actor_loss, \
-                    policy.model.td_error
+                    policy.td_error
                 # Check pure loss values.
                 check(c, expect_c)
                 check(a, expect_a)