[RLlib] R2D2 training iteration fn AND switch off execution_plan API by default. (#24165)

Sven Mika 2022-05-03 07:59:26 +02:00 committed by GitHub
parent e8fc66af34
commit 1bc6419e0e
24 changed files with 58 additions and 76 deletions
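
For orientation, here is a minimal sketch (not part of the commit) of what the new default means for plain config dicts. `with_common_config` and the `_disable_execution_plan_api` key are taken from this diff; the concrete values are illustrative:

from ray.rllib.agents import with_common_config

# As of this commit, COMMON_CONFIG ships `_disable_execution_plan_api: True`,
# so already-migrated algorithms run their `training_iteration()` method
# without needing any per-algorithm override.
config = with_common_config({
    "num_workers": 2,  # illustrative value
})
assert config["_disable_execution_plan_api"] is True

# Opting back into the old execution_plan API (only sensible for algorithms
# that still ship one, e.g. IMPALA/APPO) means flipping the flag explicitly:
legacy_config = with_common_config({
    "_disable_execution_plan_api": False,
})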

View file

@@ -86,7 +86,6 @@ class A3CConfig(TrainerConfig):
# but to wait until n seconds have passed and then to summarize the
# thus far collected results.
self.min_time_s_per_reporting = 5
self._disable_execution_plan_api = True
# __sphinx_doc_end__
# fmt: on

View file

@@ -116,9 +116,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# Reporting interval.
"min_time_s_per_reporting": 2,
# Use the `training_iteration` method instead of an execution plan.
"_disable_execution_plan_api": True,
},
_allow_unknown_configs=True,
)

View file

@@ -19,6 +19,12 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import FilterManager
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.torch_utils import set_torch_seed
from ray.rllib.utils.typing import TrainerConfigDict
@@ -407,6 +413,9 @@ class ARSTrainer(Trainer):
results, num_episodes, num_timesteps = self._collect_results(
theta_id, config["num_rollouts"]
)
# Update our sample steps counters.
self._counters[NUM_AGENT_STEPS_SAMPLED] += num_timesteps
self._counters[NUM_ENV_STEPS_SAMPLED] += num_timesteps
all_noise_indices = []
all_training_returns = []
@@ -465,6 +474,11 @@
assert g.shape == (self.policy.num_params,) and g.dtype == np.float32
# Compute the new weights theta.
theta, update_ratio = self.optimizer.update(-g)
# Update our train steps counters.
self._counters[NUM_AGENT_STEPS_TRAINED] += num_timesteps
self._counters[NUM_ENV_STEPS_TRAINED] += num_timesteps
# Set the new weights in the local copy of the policy.
self.policy.set_flat_weights(theta)
# update the reward list
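
The counter bookkeeping added above, pulled out into a hedged sketch. The metric constants and the Trainer-owned `self._counters` mapping are the ones used in this diff; the two helper functions are hypothetical and only mirror where the increments happen:

from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_TRAINED,
)

def count_sampled_steps(trainer, num_timesteps: int) -> None:
    # Booked right after the rollout results come back. ARS/ES are
    # single-agent, so agent steps and env steps are the same number here.
    trainer._counters[NUM_AGENT_STEPS_SAMPLED] += num_timesteps
    trainer._counters[NUM_ENV_STEPS_SAMPLED] += num_timesteps

def count_trained_steps(trainer, num_timesteps: int) -> None:
    # Booked after the optimizer update that consumed those rollouts.
    trainer._counters[NUM_AGENT_STEPS_TRAINED] += num_timesteps
    trainer._counters[NUM_ENV_STEPS_TRAINED] += num_timesteps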

View file

@@ -69,9 +69,6 @@ CQL_DEFAULT_CONFIG = merge_dicts(
"min_sample_timesteps_per_reporting": 0,
"min_train_timesteps_per_reporting": 100,
# Use the Trainer's `training_iteration` function instead of `execution_plan`.
"_disable_execution_plan_api": True,
# Deprecated keys.
# Use `replay_buffer_config.capacity` instead.
"buffer_size": DEPRECATED_VALUE,

View file

@@ -132,8 +132,6 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# TODO: Find a way to support None again as a means to replay
# proceeding as fast as possible.
"training_intensity": 1,
# Use `training_iteration` instead of `execution_plan` by default.
"_disable_execution_plan_api": True,
},
)
# __sphinx_doc_end__

View file

@@ -143,12 +143,6 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# === Parallelism ===
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": True,
},
_allow_unknown_configs=True,
)

View file

@@ -75,12 +75,6 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 2500,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
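
Because R2D2_DEFAULT_CONFIG is built with Trainer.merge_trainer_configs on top of the common defaults, removing the per-algorithm `False` override means R2D2 now inherits the new global default and runs the training-iteration path. A hedged check; the module path is an assumption about RLlib's layout at the time of this commit:

# Assumed import path; the config name and key come from this diff.
from ray.rllib.agents.dqn.r2d2 import R2D2_DEFAULT_CONFIG

assert R2D2_DEFAULT_CONFIG["_disable_execution_plan_api"] is True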

View file

@@ -146,12 +146,6 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 0,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on

View file

@@ -75,7 +75,10 @@ DEFAULT_CONFIG = with_common_config({
"env_config": {
# Repeats action send by policy for frame_skip times in env
"frame_skip": 2,
}
},
# Use `execution_plan` instead of `training_iteration`.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on

View file

@@ -17,6 +17,12 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import FilterManager
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.torch_utils import set_torch_seed
from ray.rllib.utils.typing import TrainerConfigDict
@@ -413,6 +419,9 @@ class ESTrainer(Trainer):
results, num_episodes, num_timesteps = self._collect_results(
theta_id, config["episodes_per_batch"], config["train_batch_size"]
)
# Update our sample steps counters.
self._counters[NUM_AGENT_STEPS_SAMPLED] += num_timesteps
self._counters[NUM_ENV_STEPS_SAMPLED] += num_timesteps
all_noise_indices = []
all_training_returns = []
@@ -462,6 +471,11 @@
)
# Compute the new weights theta.
theta, update_ratio = self.optimizer.update(-g + config["l2_coeff"] * theta)
# Update our train steps counters.
self._counters[NUM_AGENT_STEPS_TRAINED] += num_timesteps
self._counters[NUM_ENV_STEPS_TRAINED] += num_timesteps
# Set the new weights in the local copy of the policy.
self.policy.set_flat_weights(theta)
# Store the rewards

View file

@@ -104,6 +104,8 @@ class ImpalaConfig(TrainerConfig):
self.num_gpus = 1
self.lr = 0.0005
self.min_time_s_per_reporting = 10
# IMPALA and APPO are not on the new training_iteration API yet.
self._disable_execution_plan_api = False
# __sphinx_doc_end__
# fmt: on
@@ -183,24 +185,26 @@ class ImpalaConfig(TrainerConfig):
max_sample_requests_in_flight_per_worker: Level of queuing for sampling.
broadcast_interval: Max number of workers to broadcast one set of
weights to.
num_aggregation_workers: Use n (`num_aggregation_workers`) extra Actors for
multi-level aggregation of the data produced by the m RolloutWorkers
(`num_workers`). Note that n should be much smaller than m.
This can make sense if ingesting >2GB/s of samples, or if
the data requires decompression.
grad_clip:
grad_clip: If specified, clip the global norm of gradients by this amount.
opt_type: Either "adam" or "rmsprop".
lr_schedule:
decay: `opt_type=rmsprop` settings.
momentum:
epsilon:
lr_schedule: Learning rate schedule. In the format of
[[timestep, lr-value], [timestep, lr-value], ...]
Intermediary timesteps will be assigned to interpolated learning rate
values. A schedule should normally start from timestep 0.
decay: Decay setting for the RMSProp optimizer, in case `opt_type=rmsprop`.
momentum: Momentum setting for the RMSProp optimizer, in case
`opt_type=rmsprop`.
epsilon: Epsilon setting for the RMSProp optimizer, in case
`opt_type=rmsprop`.
vf_loss_coeff: Coefficient for the value function term in the loss function.
entropy_coeff: Coefficient for the entropy regularizer term in the loss
function.
entropy_coeff_schedule:
entropy_coeff_schedule: Decay schedule for the entropy regularizer.
_separate_vf_optimizer: Set this to True to have two separate optimizers
optimize the policy and value networks.
_lr_vf: If _separate_vf_optimizer is True, define separate learning rate
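
A hedged usage sketch for the parameters documented above, assuming (as the surrounding ImpalaConfig class suggests) that they are keyword arguments of a chainable ImpalaConfig.training() method; the import path and concrete values are illustrative, not recommendations:

from ray.rllib.agents.impala import ImpalaConfig  # assumed import path

config = ImpalaConfig().training(
    opt_type="rmsprop",
    # RMSProp-only settings (ignored for opt_type="adam"):
    decay=0.99,
    momentum=0.0,
    epsilon=0.1,
    # Piecewise schedule: [[timestep, lr-value], ...]; intermediary timesteps
    # get interpolated values and the schedule normally starts at timestep 0.
    lr_schedule=[[0, 0.0005], [1_000_000, 0.00005]],
    grad_clip=40.0,
    vf_loss_coeff=0.5,
    entropy_coeff=0.01,
)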

View file

@@ -78,6 +78,9 @@ DEFAULT_CONFIG = with_common_config({
# to tune vf_loss_coeff.
# Use config.model.vf_share_layers instead.
"vf_share_layers": DEPRECATED_VALUE,
# Use `execution_plan` instead of `training_iteration`.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on

View file

@@ -90,9 +90,6 @@ DEFAULT_CONFIG = with_common_config({
# === Parallelism ===
"num_workers": 0,
# Use new `training_iteration` API (instead of `execution_plan` method).
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on

View file

@@ -113,6 +113,9 @@ DEFAULT_CONFIG = with_common_config({
# to tune vf_loss_coeff.
# Use config.model.vf_share_layers instead.
"vf_share_layers": DEPRECATED_VALUE,
# Use `execution_plan` instead of `training_iteration`.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on

View file

@@ -48,7 +48,6 @@ class PGConfig(TrainerConfig):
# Override some of TrainerConfig's default values with PG-specific values.
self.num_workers = 0
self.lr = 0.0004
self._disable_execution_plan_api = True
self._disable_preprocessor_api = True
# __sphinx_doc_end__
# fmt: on

View file

@@ -107,7 +107,6 @@ class PPOConfig(TrainerConfig):
self.train_batch_size = 4000
self.lr = 5e-5
self.model["vf_share_layers"] = False
self._disable_execution_plan_api = True
# __sphinx_doc_end__
# fmt: on

View file

@@ -133,12 +133,6 @@ DEFAULT_CONFIG = with_common_config({
# Only torch supported so far.
"framework": "torch",
# === Experimental Flags ===
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": True,
# Deprecated keys:
# Use `replay_buffer_config.learning_starts` instead.
"learning_starts": DEPRECATED_VALUE,

View file

@@ -174,11 +174,6 @@ DEFAULT_CONFIG = with_common_config({
# Use a Beta-distribution instead of a SquashedGaussian for bounded,
# continuous action spaces (not recommended, for debugging only).
"_use_beta_distribution": False,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on

View file

@@ -142,8 +142,6 @@ DEFAULT_CONFIG = with_common_config({
# Switch on no-preprocessors for easier Q-model coding.
"_disable_preprocessor_api": True,
# Use `training_iteration()` instead of `execution_plan()` by default.
"_disable_execution_plan_api": True,
# Deprecated keys:
# Use `capacity` in `replay_buffer_config` instead.

View file

@@ -643,6 +643,11 @@ COMMON_CONFIG: TrainerConfigDict = {
"logger_config": None,
# === API deprecations/simplifications/changes ===
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration()` method will be called on each
# training iteration.
"_disable_execution_plan_api": True,
# Experimental flag.
# If True, TFPolicy will handle more than one loss/optimizer.
# Set this to True, if you would like to return more than
@@ -663,11 +668,6 @@ COMMON_CONFIG: TrainerConfigDict = {
# - Models that have the previous action(s) as part of their input.
# - Algorithms reading from offline files (incl. action information).
"_disable_action_flattening": False,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
# If True, disable the environment pre-checking module.
"disable_env_checking": False,

View file

@@ -214,7 +214,7 @@ class TrainerConfig:
self._tf_policy_handles_more_than_one_loss = False
self._disable_preprocessor_api = False
self._disable_action_flattening = False
self._disable_execution_plan_api = False
self._disable_execution_plan_api = True
def to_dict(self) -> TrainerConfigDict:
"""Converts all settings into a legacy config dict for backward compatibility.

View file

@@ -5,7 +5,6 @@ import ray
import numpy as np
from ray.rllib import Policy
from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer import Trainer
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.examples.env.parametric_actions_cartpole import ParametricActionsCartPole
@@ -14,13 +13,6 @@ from ray.rllib.utils import override
from ray.rllib.utils.typing import ResultDict
from ray.tune.registry import register_env
DEFAULT_CONFIG = with_common_config(
{
# Run with new `training_iteration` API.
"_disable_execution_plan_api": True,
}
)
class RandomParametricPolicy(Policy, ABC):
"""
@@ -73,10 +65,6 @@
rollout and performs no learning.
"""
@classmethod
def get_default_config(cls):
return DEFAULT_CONFIG
def get_default_policy_class(self, config):
return RandomParametricPolicy
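
With the flag defaulting to True, the example trainer above no longer needs its own DEFAULT_CONFIG or get_default_config(). A hedged sketch of the minimal pattern that remains; the class below is hypothetical (it only mirrors RandomParametricTrainer), and the exact synchronous_parallel_sample signature is an assumption:

class MinimalCustomTrainer(Trainer):
    @override(Trainer)
    def get_default_policy_class(self, config):
        return RandomParametricPolicy

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        # Collect one round of rollouts from all workers; the random policy
        # does no learning, so only sampling metrics are produced.
        synchronous_parallel_sample(worker_set=self.workers)
        return {}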

View file

@@ -71,7 +71,6 @@ class MyTrainer(Trainer):
# parameters.
return with_common_config(
{
"_disable_execution_plan_api": True,
"num_sgd_iter": 10,
"sgd_minibatch_size": 128,
}

View file

@@ -38,8 +38,7 @@ def evaluate_test(algo, env="CartPole-v0", test_episode_rollout=False):
+ ', "min_sample_timesteps_per_reporting": 5,'
'"min_time_s_per_reporting": 0.1, '
'"model": {"fcnet_hiddens": [10]}'
"}' --stop='{\"training_iteration\": 1}'"
+ " --env={} --no-ray-ui".format(env)
"}' --stop='{\"training_iteration\": 1}'" + " --env={}".format(env)
)
checkpoint_path = os.popen(