Mirror of https://github.com/vale981/ray (synced 2025-03-04 17:41:43 -05:00)
[rllib] Revert [rllib] Port DDPG to the build_tf_policy pattern (#5626)
parent 1823ea74e3
commit dcff263ce9
6 changed files with 658 additions and 499 deletions
@@ -54,7 +54,7 @@ PICKLE_OBJECT_WARNING_SIZE = 10**7
 # The maximum resource quantity that is allowed. TODO(rkn): This could be
 # relaxed, but the current implementation of the node manager will be slower
 # for large resource quantities due to bookkeeping of specific resource IDs.
-MAX_RESOURCE_QUANTITY = 10000
+MAX_RESOURCE_QUANTITY = 20000

 # Each memory "resource" counts as this many bytes of memory.
 MEMORY_RESOURCE_UNIT_BYTES = 50 * 1024 * 1024
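The two constants in this hunk interact: memory is accounted in 50 MiB units, and no single resource quantity may exceed MAX_RESOURCE_QUANTITY. A minimal sketch of that arithmetic follows; the helper name is hypothetical and not part of Ray's API.

# Hedged sketch (not Ray's actual helper): how the two constants above
# interact when a byte amount is expressed as a "memory resource" count.
MEMORY_RESOURCE_UNIT_BYTES = 50 * 1024 * 1024  # 50 MiB per memory unit
MAX_RESOURCE_QUANTITY = 20000                  # cap on any resource quantity


def bytes_to_memory_units(num_bytes):
    """Convert a byte request into 50 MiB memory-resource units."""
    units = num_bytes // MEMORY_RESOURCE_UNIT_BYTES
    if units > MAX_RESOURCE_QUANTITY:
        raise ValueError(
            "{} bytes exceeds the maximum representable memory of {} bytes"
            .format(num_bytes,
                    MAX_RESOURCE_QUANTITY * MEMORY_RESOURCE_UNIT_BYTES))
    return units


print(bytes_to_memory_units(8 * 1024 ** 3))  # 8 GiB -> 163 units

With the cap raised to 20000 units, up to 20000 * 50 MiB = 1,000,000 MiB (roughly 977 GiB) of memory can be expressed as a single resource, versus roughly 488 GiB at the previous cap of 10000.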
@@ -41,7 +41,7 @@ DEFAULT_CONFIG = with_common_config({
     # === Model ===
     # Apply a state preprocessor with spec given by the "model" config option
     # (like other RL algorithms). This is mostly useful if you have a weird
-    # observation shape, like an image. Auto-enabled if a custom model is set.
+    # observation shape, like an image. Disabled by default.
     "use_state_preprocessor": False,
     # Postprocess the policy network model output with these hidden layers. If
     # use_state_preprocessor is False, then these will be the *only* hidden
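The "use_state_preprocessor" option controls whether the "model"-configured preprocessor network runs before the actor/critic heads, which matters mainly for non-vector observations such as images. A minimal usage sketch, assuming the 0.7.x-era DDPGTrainer entry point; the environment name and override value are illustrative only, not part of the diff.

# Hedged sketch: overriding the DEFAULT_CONFIG entry shown above.
import ray
from ray.rllib.agents.ddpg import DDPGTrainer

ray.init()
trainer = DDPGTrainer(
    env="Pendulum-v0",
    config={
        # Enable the state preprocessor, e.g. for image observations.
        "use_state_preprocessor": True,
    })
print(trainer.train()["episode_reward_mean"])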
@@ -173,7 +173,7 @@ def make_exploration_schedule(config, worker_index):
     if config["per_worker_exploration"]:
         assert config["num_workers"] > 1, "This requires multiple workers"
         if worker_index >= 0:
-            # Exploration constants from the Ape-X paper
+            # FIXME: what do magic constants mean? (0.4, 7)
            max_index = float(config["num_workers"] - 1)
            exponent = 1 + worker_index / max_index * 7
            return ConstantSchedule(0.4**exponent)
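The schedule in this hunk gives each worker a fixed epsilon of 0.4 ** (1 + i / (N - 1) * 7), the per-worker exploration scheme from the Ape-X paper. A quick standalone check of the values it produces, treating worker indices as 0..N-1 for illustration (RLlib's actual worker indexing may differ, since index 0 is the local worker):

# Standalone illustration of the per-worker epsilon formula above.
num_workers = 8
max_index = float(num_workers - 1)
for worker_index in range(num_workers):
    exponent = 1 + worker_index / max_index * 7
    print(worker_index, round(0.4 ** exponent, 5))
# Worker 0 explores the most (epsilon = 0.4), the last worker the least
# (0.4 ** 8 ~= 0.00066); workers in between are spaced geometrically.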
(File diff suppressed because it is too large.)
@@ -186,8 +186,9 @@ def check_config_and_setup_param_noise(config):
             # between noisy policy and original policy
             policies = info["policy"]
             episode = info["episode"]
-            episode.custom_metrics["policy_distance"] = policies[
-                DEFAULT_POLICY_ID].model.pi_distance
+            model = policies[DEFAULT_POLICY_ID].model
+            if hasattr(model, "pi_distance"):
+                episode.custom_metrics["policy_distance"] = model.pi_distance
             if end_callback:
                 end_callback(info)

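The new version only records the custom metric when the policy's model actually exposes pi_distance, so the callback no longer fails for models without parameter-noise tracking. A hedged sketch of the same guarded-metric pattern as a user callback, assuming the dict-style "callbacks" config of that era; everything except the attribute names shown in the diff is illustrative.

# Hedged sketch: an episode-end callback using the same hasattr guard.
def on_episode_end(info):
    policies = info["policy"]
    episode = info["episode"]
    model = policies["default_policy"].model  # DEFAULT_POLICY_ID
    if hasattr(model, "pi_distance"):
        episode.custom_metrics["policy_distance"] = model.pi_distance


config = {
    "callbacks": {
        "on_episode_end": on_episode_end,
    },
}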
@@ -10,15 +10,13 @@ import ray
 import ray.experimental.tf_utils
 from ray.rllib.agents.sac.sac_model import SACModel
 from ray.rllib.agents.ddpg.noop_model import NoopModel
-from ray.rllib.agents.ddpg.ddpg_policy import ComputeTDErrorMixin, \
-    TargetNetworkMixin
 from ray.rllib.agents.dqn.dqn_policy import _postprocess_dqn, PRIO_WEIGHTS
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils import try_import_tf, try_import_tfp
-from ray.rllib.utils.tf_ops import minimize_and_clip
+from ray.rllib.utils.tf_ops import minimize_and_clip, make_tf_callable

 tf = try_import_tf()
 tfp = try_import_tfp()
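The try_import_tf and try_import_tfp helpers return None when the library is not installed rather than raising at import time, so callers are expected to check before using the modules. A minimal guard, assuming only that behavior; the error messages are illustrative.

# Hedged sketch of the soft-import pattern used above.
from ray.rllib.utils import try_import_tf, try_import_tfp

tf = try_import_tf()
tfp = try_import_tfp()

if tf is None:
    raise ImportError("TensorFlow is required for this policy.")
if tfp is None:
    raise ImportError("tensorflow_probability is required for this policy.")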
@@ -287,6 +285,55 @@ class ActorCriticOptimizerMixin(object):
             learning_rate=config["optimization"]["entropy_learning_rate"])


+class ComputeTDErrorMixin(object):
+    def __init__(self):
+        @make_tf_callable(self.get_session(), dynamic_shape=True)
+        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
+                             importance_weights):
+            if not self.loss_initialized():
+                return tf.zeros_like(rew_t)
+
+            # Do forward pass on loss to update td error attribute
+            actor_critic_loss(
+                self, self.model, None, {
+                    SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
+                    SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
+                    SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
+                    SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
+                    SampleBatch.DONES: tf.convert_to_tensor(done_mask),
+                    PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
+                })
+
+            return self.td_error
+
+        self.compute_td_error = compute_td_error
+
+
+class TargetNetworkMixin(object):
+    def __init__(self, config):
+        @make_tf_callable(self.get_session())
+        def update_target_fn(tau):
+            tau = tf.convert_to_tensor(tau, dtype=tf.float32)
+            update_target_expr = []
+            model_vars = self.model.trainable_variables()
+            target_model_vars = self.target_model.trainable_variables()
+            assert len(model_vars) == len(target_model_vars), \
+                (model_vars, target_model_vars)
+            for var, var_target in zip(model_vars, target_model_vars):
+                update_target_expr.append(
+                    var_target.assign(tau * var + (1.0 - tau) * var_target))
+                logger.debug("Update target op {}".format(var_target))
+            return tf.group(*update_target_expr)
+
+        # Hard initial update
+        self._do_update = update_target_fn
+        self.update_target(tau=1.0)
+
+    # support both hard and soft sync
+    def update_target(self, tau=None):
+        self._do_update(np.float32(tau or self.config.get("tau")))
+
+
 def setup_early_mixins(policy, obs_space, action_space, config):
     ExplorationStateMixin.__init__(policy, obs_space, action_space, config)
     ActorCriticOptimizerMixin.__init__(policy, config)
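TargetNetworkMixin implements Polyak averaging: each target variable moves toward its online counterpart via var_target <- tau * var + (1 - tau) * var_target, and tau = 1.0 performs a hard copy (the "Hard initial update"). A NumPy sketch of the same update, independent of the TF graph machinery; all names and values are illustrative.

# NumPy sketch of the soft target update performed by the mixin above.
import numpy as np


def soft_update(online_vars, target_vars, tau):
    # target <- tau * online + (1 - tau) * target
    return [tau * w + (1.0 - tau) * wt
            for w, wt in zip(online_vars, target_vars)]


online = [np.array([1.0, -2.0])]
target = [np.zeros(2)]

target = soft_update(online, target, tau=1.0)  # hard copy: [1., -2.]
online = [np.array([2.0, 0.0])]                # online net keeps training
for _ in range(10):
    target = soft_update(online, target, tau=0.005)
print(target[0])  # drifts slowly from [1., -2.] toward [2., 0.]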
@@ -56,30 +56,6 @@ class TestEagerSupport(unittest.TestCase):
             "timesteps_per_iteration": 100
         })

-    def testDDPG(self):
-        check_support("DDPG", {
-            "num_workers": 0,
-            "learning_starts": 0,
-            "timesteps_per_iteration": 10
-        })
-
-    def testTD3(self):
-        check_support("TD3", {
-            "num_workers": 0,
-            "learning_starts": 0,
-            "timesteps_per_iteration": 10
-        })
-
-    def testAPEX_DDPG(self):
-        check_support(
-            "APEX_DDPG", {
-                "num_workers": 2,
-                "learning_starts": 0,
-                "num_gpus": 0,
-                "min_iter_time_s": 1,
-                "timesteps_per_iteration": 100
-            })
-
     def testSAC(self):
         check_support("SAC", {
             "num_workers": 0,
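The removed cases exercised the DDPG-family trainers under TF eager mode; they are dropped here because the restored (pre-build_tf_policy) DDPG policy is graph-only. A rough sketch of what such an eager smoke test does: only the check_support name, trainer names, and config keys come from the diff, while the helper body, the "eager" flag, and the environment are assumptions about the test harness, not Ray's actual code.

# Hedged sketch of an eager smoke test in the spirit of check_support().
import ray
from ray.rllib.agents.registry import get_agent_class
from tensorflow.python.eager.context import eager_mode


def check_support(alg, config):
    config["eager"] = True  # assumed flag for building the policy eagerly
    cls = get_agent_class(alg)
    with eager_mode():
        trainer = cls(config=config, env="Pendulum-v0")
        trainer.train()  # one iteration is enough for a smoke test
        trainer.stop()


if __name__ == "__main__":
    ray.init()
    check_support("SAC", {"num_workers": 0, "timesteps_per_iteration": 10})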