[RLlib] Fix all example scripts to run on GPUs. (#11105)

Sven Mika 2020-10-02 23:07:44 +02:00 committed by GitHub
parent 5a42ed1848
commit c17169dc11
56 changed files with 221 additions and 98 deletions

View file

@@ -27,14 +27,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
     },
 )
-def validate_config(config):
-    if config.get("framework") == "tfe":
-        raise ValueError("APEX_DDPG does not support tf-eager yet!")
 ApexDDPGTrainer = DDPGTrainer.with_updates(
     name="APEX_DDPG",
     default_config=APEX_DDPG_DEFAULT_CONFIG,
-    validate_config=validate_config,
     execution_plan=apex_execution_plan)

View file

@@ -251,7 +251,10 @@ class KLCoeffMixin:
         self.kl_coeff_val = config["kl_coeff"]
         # The current KL value (as tf Variable for in-graph operations).
         self.kl_coeff = get_variable(
-            float(self.kl_coeff_val), tf_name="kl_coeff", trainable=False)
+            float(self.kl_coeff_val),
+            tf_name="kl_coeff",
+            trainable=False,
+            framework=config["framework"])
         # Constant target value.
         self.kl_target = config["kl_target"]
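Why passing `framework` matters here: as the `get_variable` hunk near the end of this diff shows, the utility only builds a `tf.Variable` for the tf/tf2/tfe frameworks and otherwise hands back a framework-appropriate value, so without the argument the torch and eager code paths would receive the wrong kind of object. A rough sketch of that dispatch (simplified, not RLlib's exact implementation):

def get_variable_sketch(value, framework=None, tf_name="var", trainable=False):
    # Simplified illustration of RLlib's get_variable() dispatch.
    if framework in ["tf2", "tf", "tfe"]:
        import tensorflow as tf
        return tf.Variable(value, name=tf_name, trainable=trainable)
    # torch / no framework: the caller just keeps the raw Python value.
    return value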

View file

@@ -37,7 +37,7 @@ FAKE_BATCH = {
 class TestPPO(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        ray.init(local_mode=True)
+        ray.init()
     @classmethod
     def tearDownClass(cls):

View file

@@ -166,7 +166,8 @@ class TestSAC(unittest.TestCase):
         # Set all weights (of all nets) to fixed values.
         if weights_dict is None:
-            assert fw in ["tf", "tfe"]  # Start with the tf vars-dict.
+            # Start with the tf vars-dict.
+            assert fw in ["tf2", "tf", "tfe"]
             weights_dict = policy.get_weights()
             if fw == "tfe":
                 log_alpha = weights_dict[10]

View file

@@ -1,4 +1,5 @@
 import argparse
+import os
 import ray
 from ray import tune
@@ -42,6 +43,8 @@ if __name__ == "__main__":
             "repeat_delay": 2,
         },
         "gamma": 0.99,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
         "num_workers": 0,
         "num_envs_per_worker": 20,
         "entropy_coeff": 0.001,

View file

@@ -11,6 +11,7 @@ This examples shows both.
 """
 import argparse
+import os
 import ray
 from ray import tune
@@ -44,7 +45,8 @@ if __name__ == "__main__":
     config = {
         "env": CorrelatedActionsEnv,
         "gamma": 0.5,
-        "num_gpus": 0,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "model": {
             "custom_model": "autoregressive_model",
             "custom_action_dist": "binary_autoreg_dist",
@@ -58,7 +60,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
-    results = tune.run(args.run, stop=stop, config=config)
+    results = tune.run(args.run, stop=stop, config=config, verbose=1)
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)

View file

@@ -1,6 +1,7 @@
 """Example of using a custom model with batch norm."""
 import argparse
+import os
 import ray
 from ray import tune
@@ -32,6 +33,8 @@ if __name__ == "__main__":
         "model": {
             "custom_model": "bn_model",
         },
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "framework": "torch" if args.torch else "tf",
     }
@@ -42,7 +45,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
-    results = tune.run(args.run, stop=stop, config=config)
+    results = tune.run(args.run, stop=stop, config=config, verbose=1)
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)

View file

@@ -1,4 +1,5 @@
 import argparse
+import os
 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
 from ray.rllib.utils.test_utils import check_learning_achieved
@@ -35,8 +36,11 @@ if __name__ == "__main__":
     }
     config = dict(
-        configs[args.run], **{
+        configs[args.run],
+        **{
             "env": StatelessCartPole,
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "model": {
                 "use_lstm": True,
                 "lstm_use_prev_action_reward": args.use_prev_action_reward,

View file

@@ -16,6 +16,7 @@ modifies the environment.
 import argparse
 import numpy as np
 from gym.spaces import Discrete
+import os
 import ray
 from ray import tune
@@ -90,7 +91,7 @@ def centralized_critic_postprocessing(policy,
                 sample_batch[OPPONENT_OBS], policy.device),
             convert_to_torch_tensor(
                 sample_batch[OPPONENT_ACTION], policy.device)) \
-            .detach().numpy()
+            .cpu().detach().numpy()
     else:
         sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
             sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
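The extra `.cpu()` is what makes this postprocessing step GPU-safe: `.numpy()` raises a TypeError on a CUDA tensor, so the value-function output has to be detached and copied back to host memory first. A small illustration (guarded so it also runs without a GPU):

import torch

vf_preds = torch.zeros(4)
if torch.cuda.is_available():
    vf_preds = vf_preds.cuda()

# Safe on both CPU and CUDA tensors; calling vf_preds.numpy() directly
# would fail for a CUDA tensor.
print(vf_preds.cpu().detach().numpy())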
@@ -137,14 +138,22 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):
     return loss
-def setup_mixins(policy, obs_space, action_space, config):
-    # copied from PPO
+def setup_tf_mixins(policy, obs_space, action_space, config):
+    # Copied from PPOTFPolicy (w/o ValueNetworkMixin).
     KLCoeffMixin.__init__(policy, config)
     EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                   config["entropy_coeff_schedule"])
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
+def setup_torch_mixins(policy, obs_space, action_space, config):
+    # Copied from PPOTorchPolicy (w/o ValueNetworkMixin).
+    TorchKLCoeffMixin.__init__(policy, config)
+    TorchEntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
+                                       config["entropy_coeff_schedule"])
+    TorchLR.__init__(policy, config["lr"], config["lr_schedule"])
 def central_vf_stats(policy, train_batch, grads):
     # Report the explained variance of the central value function.
     return {
@@ -158,7 +167,7 @@ CCPPOTFPolicy = PPOTFPolicy.with_updates(
     name="CCPPOTFPolicy",
     postprocess_fn=centralized_critic_postprocessing,
     loss_fn=loss_with_central_critic,
-    before_loss_init=setup_mixins,
+    before_loss_init=setup_tf_mixins,
     grad_stats_fn=central_vf_stats,
     mixins=[
         LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
@@ -169,7 +178,7 @@ CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
     name="CCPPOTorchPolicy",
     postprocess_fn=centralized_critic_postprocessing,
     loss_fn=loss_with_central_critic,
-    before_init=setup_mixins,
+    before_init=setup_torch_mixins,
     mixins=[
         TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
         CentralizedValueMixin
@@ -188,7 +197,7 @@ CCTrainer = PPOTrainer.with_updates(
 )
 if __name__ == "__main__":
-    ray.init(local_mode=True)
+    ray.init()
     args = parser.parse_args()
     ModelCatalog.register_custom_model(
@@ -198,6 +207,8 @@ if __name__ == "__main__":
     config = {
         "env": TwoStepGame,
         "batch_mode": "complete_episodes",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "multiagent": {
             "policies": {
@@ -222,7 +233,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
-    results = tune.run(CCTrainer, config=config, stop=stop)
+    results = tune.run(CCTrainer, config=config, stop=stop, verbose=1)
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)

View file

@@ -12,6 +12,7 @@ modifies the policy to add a centralized value function.
 import numpy as np
 from gym.spaces import Dict, Discrete
 import argparse
+import os
 from ray import tune
 from ray.rllib.agents.callbacks import DefaultCallbacks
@@ -87,6 +88,8 @@ if __name__ == "__main__":
         "env": TwoStepGame,
         "batch_mode": "complete_episodes",
         "callbacks": FillInActions,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "multiagent": {
             "policies": {

View file

@@ -8,7 +8,9 @@ For PyTorch / TF eager mode, use the --torch and --eager flags.
 """
 import argparse
+import os
+import ray
 from ray import tune
 from ray.rllib.models import ModelCatalog
 from ray.rllib.examples.env.simple_rpg import SimpleRPG
@@ -17,9 +19,10 @@ from ray.rllib.examples.models.simple_rpg_model import CustomTorchRPGModel, \
 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--framework", choices=["tf", "tfe", "torch"], default="tf")
+    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf2")
 if __name__ == "__main__":
+    ray.init()
     args = parser.parse_args()
     if args.framework == "torch":
         ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
@@ -31,6 +34,8 @@ if __name__ == "__main__":
         "env": SimpleRPG,
         "rollout_fragment_length": 1,
         "train_batch_size": 2,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "model": {
             "custom_model": "my_model",

View file

@@ -11,6 +11,7 @@ import argparse
 import gym
 from gym.spaces import Discrete, Box
 import numpy as np
+import os
 import ray
 from ray import tune
@@ -114,6 +115,8 @@ if __name__ == "__main__":
         "env_config": {
             "corridor_length": 5,
         },
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "model": {
             "custom_model": "my_model",
         },

View file

@@ -67,6 +67,7 @@ Result for PG_SimpleCorridor_0de4e686:
 """
 import argparse
+import os
 import ray
 from ray import tune
@@ -137,7 +138,9 @@ if __name__ == "__main__":
             "corridor_length": 10,
         },
         "horizon": 20,
+        "log_level": "INFO",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         # Training rollouts will be collected using just the learner
         # process, but evaluation will be done in parallel with two

View file

@@ -5,6 +5,7 @@ for running perf microbenchmarks.
 """
 import argparse
+import os
 import ray
 import ray.tune as tune
@@ -32,7 +33,8 @@ if __name__ == "__main__":
         "model": {
             "custom_model": "fast_model"
         },
-        "num_gpus": 0,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 2,
         "num_envs_per_worker": 10,
         "num_data_loader_buffers": 1,
@@ -40,7 +42,7 @@ if __name__ == "__main__":
         "broadcast_interval": 50,
         "rollout_fragment_length": 100,
         "train_batch_size": sample_from(
-            lambda spec: 1000 * max(1, spec.config.num_gpus)),
+            lambda spec: 1000 * max(1, spec.config.num_gpus or 1)),
         "fake_sampler": True,
         "framework": "torch" if args.torch else "tf",
     }
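The added `or 1` is not redundant with the `max(1, ...)`: `max()` cannot compare an int against `None`, so the fallback keeps the lambda from crashing when `num_gpus` resolves to None in the spec, while `max(1, ...)` still bounds the multiplier from below when it is 0. A quick check:

num_gpus = None
print(1000 * max(1, num_gpus or 1))  # -> 1000
# print(1000 * max(1, num_gpus))     # TypeError: comparison of NoneType and int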
@@ -50,6 +52,6 @@ if __name__ == "__main__":
         "timesteps_total": args.stop_timesteps,
     }
-    tune.run("IMPALA", config=config, stop=stop)
+    tune.run("IMPALA", config=config, stop=stop, verbose=1)
     ray.shutdown()

View file

@@ -1,6 +1,7 @@
 """Example of using a custom ModelV2 Keras-style model."""
 import argparse
+import os
 import ray
 from ray import tune
@@ -119,11 +120,12 @@ if __name__ == "__main__":
         args.run,
         stop={"episode_reward_mean": args.stop},
         config=dict(
-            extra_config, **{
-                "log_level": "INFO",
+            extra_config,
+            **{
                 "env": "BreakoutNoFrameskip-v4"
                 if args.use_vision_network else "CartPole-v0",
-                "num_gpus": 0,
+                # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+                "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
                 "callbacks": {
                     "on_train_result": check_has_custom_metric,
                 },

View file

@@ -50,6 +50,8 @@ if __name__ == "__main__":
     config = {
         "env": "CartPole-v0",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "model": {
             "custom_model": "custom_loss",
@@ -64,4 +66,4 @@ if __name__ == "__main__":
         "training_iteration": args.stop_iters,
     }
-    tune.run("PG", config=config, stop=stop)
+    tune.run("PG", config=config, stop=stop, verbose=1)

View file

@@ -7,14 +7,19 @@ custom metric.
 from typing import Dict
 import argparse
 import numpy as np
+import os
 import ray
 from ray import tune
+from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.env import BaseEnv
+from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
-from ray.rllib.agents.callbacks import DefaultCallbacks
+parser = argparse.ArgumentParser()
+parser.add_argument("--torch", action="store_true")
+parser.add_argument("--stop-iters", type=int, default=2000)
 class MyCallbacks(DefaultCallbacks):
@@ -65,8 +70,6 @@ class MyCallbacks(DefaultCallbacks):
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--stop-iters", type=int, default=2000)
     args = parser.parse_args()
     ray.init()
@@ -79,7 +82,9 @@ if __name__ == "__main__":
             "env": "CartPole-v0",
             "num_envs_per_worker": 2,
             "callbacks": MyCallbacks,
-            "framework": "tf",
+            "framework": "torch" if args.torch else "tf",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }).trials
     # verify custom metrics for integration tests

View file

@@ -2,6 +2,7 @@
 import argparse
 import numpy as np
+import os
 import ray
 from ray import tune
@@ -73,6 +74,8 @@ if __name__ == "__main__":
                 "on_postprocess_traj": on_postprocess_traj,
             },
             "framework": "tf",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }).trials
     # verify custom metrics for integration tests

View file

@@ -1,6 +1,7 @@
 """Example of using a custom RNN keras model."""
 import argparse
+import os
 import ray
 from ray import tune
@@ -37,6 +38,8 @@ if __name__ == "__main__":
             "repeat_delay": 2,
         },
         "gamma": 0.9,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "num_envs_per_worker": 20,
         "entropy_coeff": 0.001,

View file

@@ -1,4 +1,5 @@
 import argparse
+import os
 import ray
 from ray import tune
@@ -50,6 +51,8 @@ if __name__ == "__main__":
         stop={"training_iteration": args.stop_iters},
         config={
             "env": "CartPole-v0",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "num_workers": 2,
             "framework": "tf",
         })

View file

@@ -1,4 +1,5 @@
 import argparse
+import os
 import ray
 from ray import tune
@@ -36,6 +37,8 @@ if __name__ == "__main__":
         stop={"training_iteration": args.stop_iters},
         config={
             "env": "CartPole-v0",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "num_workers": 2,
             "framework": "torch",
         })

View file

@@ -6,6 +6,7 @@ This example shows:
 You can visualize experiment results in ~/ray_results using TensorBoard.
 """
 import argparse
+import os
 import ray
 from ray import tune
@@ -43,6 +44,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     config = {
         "lr": 0.01,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "framework": "torch" if args.torch else "tf",
     }

View file

@@ -1,4 +1,5 @@
 import argparse
+import os
 import random
 import ray
@@ -58,12 +59,14 @@ MyTrainer = build_trainer(
 )
 if __name__ == "__main__":
-    ray.init()
+    ray.init(local_mode=True)
     args = parser.parse_args()
     ModelCatalog.register_custom_model("eager_model", EagerModel)
     config = {
         "env": "CartPole-v0",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "model": {
             "custom_model": "eager_model"
@@ -76,7 +79,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
-    results = tune.run(MyTrainer, stop=stop, config=config)
+    results = tune.run(MyTrainer, stop=stop, config=config, verbose=1)
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)

View file

@@ -25,6 +25,7 @@ using --flat in this example.
 import argparse
 from gym.spaces import Discrete, Tuple
 import logging
+import os
 import ray
 from ray import tune
@@ -75,7 +76,6 @@ if __name__ == "__main__":
         config = {
             "env": HierarchicalWindyMazeEnv,
             "num_workers": 0,
-            "log_level": "INFO",
             "entropy_coeff": 0.01,
             "multiagent": {
                 "policies": {
@@ -94,6 +94,8 @@ if __name__ == "__main__":
                 "policy_mapping_fn": function(policy_mapping_fn),
             },
             "framework": "torch" if args.torch else "tf",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }
         results = tune.run("PPO", stop=stop, config=config, verbose=1)

View file

@@ -5,8 +5,9 @@
 import argparse
 from gym.spaces import Discrete, Box
 import numpy as np
-from ray.rllib.agents.ppo import PPOTrainer
+import os
+from ray import tune
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.examples.models.mobilenet_v2_with_lstm_models import \
     MobileV2PlusRNNModel, TorchMobileV2PlusRNNModel
@@ -21,6 +22,9 @@ cnn_shape_torch = (3, 224, 224)
 parser = argparse.ArgumentParser()
 parser.add_argument("--torch", action="store_true")
+parser.add_argument("--stop-iters", type=int, default=200)
+parser.add_argument("--stop-reward", type=float, default=0.0)
+parser.add_argument("--stop-timesteps", type=int, default=100000)
 if __name__ == "__main__":
     args = parser.parse_args()
@@ -30,8 +34,15 @@ if __name__ == "__main__":
         "my_model", TorchMobileV2PlusRNNModel
         if args.torch else MobileV2PlusRNNModel)
+    stop = {
+        "training_iteration": args.stop_iters,
+        "timesteps_total": args.stop_timesteps,
+        "episode_reward_mean": args.stop_reward,
+    }
     # Configure our Trainer.
     config = {
+        "env": RandomEnv,
         "framework": "torch" if args.torch else "tf",
         "model": {
             "custom_model": "my_model",
@@ -42,6 +53,8 @@ if __name__ == "__main__":
             "max_seq_len": 20,
         },
         "vf_share_layers": True,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,  # no parallelism
         "env_config": {
             "action_space": Discrete(2),
@@ -54,5 +67,4 @@ if __name__ == "__main__":
         },
     }
-    trainer = PPOTrainer(config=config, env=RandomEnv)
-    print(trainer.train())
+    tune.run("PPO", config=config, stop=stop, verbose=1)

View file

@@ -131,8 +131,8 @@ class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
     def _a1_distribution(self):
         BATCH = self.inputs.shape[0]
-        a1_logits, _ = self.model.action_module(self.inputs,
-                                                torch.zeros((BATCH, 1)))
+        zeros = torch.zeros((BATCH, 1)).to(self.inputs.device)
+        a1_logits, _ = self.model.action_module(self.inputs, zeros)
         a1_dist = TorchCategorical(a1_logits)
         return a1_dist
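Creating the zeros directly on `self.inputs.device` keeps every tensor that feeds the action module on the same device; on a GPU-placed policy, mixing a CPU-allocated helper tensor with CUDA inputs raises a device-mismatch RuntimeError. A minimal illustration of the pattern:

import torch
import torch.nn as nn

action_module = nn.Linear(2, 2)
if torch.cuda.is_available():
    action_module = action_module.cuda()

device = next(action_module.parameters()).device
inputs = torch.randn(4, 1, device=device)
# Helper tensors must follow the inputs' device, or the cat/linear below
# fails with a device-mismatch error when the model lives on the GPU.
zeros = torch.zeros((inputs.shape[0], 1)).to(inputs.device)
logits = action_module(torch.cat([inputs, zeros], dim=1))
print(logits.shape)  # torch.Size([4, 2])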

View file

@@ -116,7 +116,7 @@ class TorchCustomLossModel(TorchModelV2, nn.Module):
         # Define a secondary loss by building a graph copy with weight sharing.
         obs = restore_original_dimensions(
-            torch.from_numpy(batch["obs"]).float(),
+            torch.from_numpy(batch["obs"]).float().to(policy_loss[0].device),
             self.obs_space,
             tensorlib="torch")
         logits, _ = self.forward({"obs": obs}, [], None)
@@ -130,8 +130,8 @@ class TorchCustomLossModel(TorchModelV2, nn.Module):
         # Compute the IL loss.
         action_dist = TorchCategorical(logits, self.model_config)
-        imitation_loss = torch.mean(
-            -action_dist.logp(torch.from_numpy(batch["actions"])))
+        imitation_loss = torch.mean(-action_dist.logp(
+            torch.from_numpy(batch["actions"]).to(policy_loss[0].device)))
         self.imitation_loss_metric = imitation_loss.item()
         self.policy_loss_metric = np.mean([l.item() for l in policy_loss])

View file

@@ -57,8 +57,8 @@ class TorchFastModel(TorchModelV2, nn.Module):
                               model_config, name)
         nn.Module.__init__(self)
-        self.bias = torch.tensor(
-            [0.0], dtype=torch.float32, requires_grad=True)
+        self.bias = nn.Parameter(
+            torch.tensor([0.0], dtype=torch.float32, requires_grad=True))
         # Only needed to give some params to the optimizer (even though,
         # they are never used anywhere).
@@ -67,8 +67,9 @@ class TorchFastModel(TorchModelV2, nn.Module):
     @override(ModelV2)
     def forward(self, input_dict, state, seq_lens):
-        self._output = self.bias + \
-            torch.zeros(size=(input_dict["obs"].shape[0], self.num_outputs))
+        self._output = self.bias + torch.zeros(
+            size=(input_dict["obs"].shape[0], self.num_outputs)).to(
+                self.bias.device)
         return self._output, []
     @override(ModelV2)
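Wrapping the bias in `nn.Parameter` (instead of keeping a bare `torch.tensor` attribute) registers it with the module, so it shows up in `parameters()` for the optimizer and is moved by `.to(device)` together with the model; the `forward()` change then reuses `self.bias.device` to allocate the zeros on whatever device the model ended up on. A quick check of the registration difference:

import torch
import torch.nn as nn

class PlainTensorBias(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.tensor([0.0], requires_grad=True)  # not registered

class ParameterBias(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.tensor([0.0]))  # registered

print(len(list(PlainTensorBias().parameters())))  # 0 -> optimizer sees nothing
print(len(list(ParameterBias().parameters())))    # 1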

View file

@@ -89,14 +89,15 @@ class MobileV2PlusRNNModel(RecurrentNetwork):
         return tf.reshape(self._value_out, [-1])
-class TorchMobileV2PlusRNNModel(TorchRNN):
+class TorchMobileV2PlusRNNModel(TorchRNN, nn.Module):
     """A conv. + recurrent torch net example using a pre-trained MobileNet."""
     def __init__(self, obs_space, action_space, num_outputs, model_config,
                  name, cnn_shape):
-        super().__init__(obs_space, action_space, num_outputs, model_config,
-                         name)
+        TorchRNN.__init__(self, obs_space, action_space, num_outputs,
+                          model_config, name)
+        nn.Module.__init__(self)
         self.lstm_state_size = 16
         self.cnn_shape = list(cnn_shape)
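The two explicit `__init__` calls follow the usual pattern for RLlib torch models that mix a ModelV2-style base class with `nn.Module`: `nn.Module.__init__` has to run before any submodules are assigned, otherwise PyTorch raises "cannot assign module before Module.__init__() call". The general shape of the pattern, with placeholder class names:

import torch.nn as nn

class ModelBase:  # stand-in for an RLlib model base class
    def __init__(self, name):
        self.name = name

class MyRNNModel(ModelBase, nn.Module):
    def __init__(self, name):
        ModelBase.__init__(self, name)
        nn.Module.__init__(self)      # must come before submodule assignment
        self.lstm = nn.LSTM(8, 16, batch_first=True)

print(sum(p.numel() for p in MyRNNModel("m").parameters()) > 0)  # True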

View file

@@ -125,12 +125,13 @@ class TorchSharedWeightsModel(TorchModelV2, nn.Module):
             activation_fn=None,
             initializer=torch.nn.init.xavier_uniform_,
         )
+        self._global_shared_layer = TORCH_GLOBAL_SHARED_LAYER
         self._output = None
     @override(ModelV2)
     def forward(self, input_dict, state, seq_lens):
         out = self.first_layer(input_dict["obs"])
-        self._output = TORCH_GLOBAL_SHARED_LAYER(out)
+        self._output = self._global_shared_layer(out)
         model_out = self.last_layer(self._output)
         return model_out, []
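Storing the globally shared layer on `self` (rather than only calling the module-level object inside `forward`) registers it as a submodule, so its weights are part of `parameters()`/`state_dict()` and it is moved by `.to(device)` along with the rest of the model. A compact demonstration of that registration (the layer here is illustrative):

import torch.nn as nn

SHARED_LAYER = nn.Linear(4, 4)  # module-level, shared by all model instances

class SharedWeightsModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._shared = SHARED_LAYER  # attribute assignment registers it

m1, m2 = SharedWeightsModel(), SharedWeightsModel()
print("_shared" in dict(m1.named_modules()))   # True: tracked as a submodule
print(m1._shared.weight is m2._shared.weight)  # True: weights are shared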

View file

@@ -11,6 +11,7 @@ execution, set the TF_TIMELINE_DIR environment variable.
 import argparse
 import gym
+import os
 import random
 import ray
@@ -75,6 +76,8 @@ if __name__ == "__main__":
             "num_agents": args.num_agents,
         },
         "simple_optimizer": args.simple,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_sgd_iter": 10,
         "multiagent": {
             "policies": policies,

View file

@@ -15,6 +15,7 @@ Result for PG_multi_cartpole_0:
 import argparse
 import gym
+import os
 import ray
 from ray import tune
@@ -60,6 +61,8 @@ if __name__ == "__main__":
                 lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
         },
         "framework": "torch" if args.torch else "tf",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
     }
     results = tune.run("PG", config=config, stop=stop, verbose=1)

View file

@@ -10,6 +10,7 @@ For a simpler example, see also: multiagent_cartpole.py
 import argparse
 import gym
+import os
 import ray
 from ray.rllib.agents.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy
@@ -38,9 +39,9 @@ if __name__ == "__main__":
     # Simple environment with 4 independent cartpole entities
     register_env("multi_agent_cartpole",
                  lambda _: MultiAgentCartPole({"num_agents": 4}))
-    single_env = gym.make("CartPole-v0")
-    obs_space = single_env.observation_space
-    act_space = single_env.action_space
+    single_dummy_env = gym.make("CartPole-v0")
+    obs_space = single_dummy_env.observation_space
+    act_space = single_dummy_env.action_space
     # You can also have multiple policies per trainer, but here we just
     # show one each for PPO and DQN.
@@ -69,6 +70,8 @@ if __name__ == "__main__":
             # disable filters, otherwise we would need to synchronize those
             # as well to the DQN agent
             "observation_filter": "NoFilter",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": "torch" if args.torch else "tf",
         })
@@ -82,6 +85,8 @@ if __name__ == "__main__":
             },
             "gamma": 0.95,
             "n_step": 3,
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": "torch" if args.torch or args.mixed_torch_tf else "tf"
         })

View file

@@ -1,5 +1,6 @@
 import argparse
 from gym.spaces import Dict, Tuple, Box, Discrete
+import os
 import ray
 import ray.tune as tune
@@ -40,6 +41,8 @@ if __name__ == "__main__":
         "gamma": 0.0,  # No history in Env (bandit problem).
         "lr": 0.0005,
         "num_envs_per_worker": 20,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_sgd_iter": 4,
         "num_workers": 0,
         "vf_loss_coeff": 0.01,

View file

@@ -15,6 +15,7 @@ Working configurations are given below.
 """
 import argparse
+import os
 import ray
 from ray import tune
@@ -55,14 +56,18 @@ if __name__ == "__main__":
     else:
         cfg = {}
-    config = dict({
-        "env": "pa_cartpole",
-        "model": {
-            "custom_model": "pa_model",
-        },
-        "num_workers": 0,
-        "framework": "torch" if args.torch else "tf",
-    }, **cfg)
+    config = dict(
+        {
+            "env": "pa_cartpole",
+            "model": {
+                "custom_model": "pa_model",
+            },
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
+            "num_workers": 0,
+            "framework": "torch" if args.torch else "tf",
+        },
+        **cfg)
     stop = {
         "training_iteration": args.stop_iters,

View file

@@ -1,15 +1,13 @@
 from copy import deepcopy
-import ray
-try:
-    from ray.rllib.agents.agent import get_agent_class
-except ImportError:
-    from ray.rllib.agents.registry import get_agent_class
-from ray.tune.registry import register_env
-from ray.rllib.env import PettingZooEnv
+from numpy import float32
+import os
 from pettingzoo.butterfly import pistonball_v0
 from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0
-from numpy import float32
+import ray
+from ray.rllib.agents.registry import get_agent_class
+from ray.rllib.env import PettingZooEnv
+from ray.tune.registry import register_env
 if __name__ == "__main__":
     """For this script, you need:
@@ -37,7 +35,7 @@ if __name__ == "__main__":
     config = deepcopy(get_agent_class(alg_name)._default_config)
     # 2. Set environment config. This will be passed to
-    # the env_creator function via the register env lambda below
+    # the env_creator function via the register env lambda below.
     config["env_config"] = {"local_ratio": 0.5}
     # 3. Register env
@@ -58,6 +56,8 @@ if __name__ == "__main__":
         "policy_mapping_fn": lambda agent_id: "av"
     }
+    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+    config["num_gpus"] = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
     config["log_level"] = "DEBUG"
     config["num_workers"] = 1
     # Fragment length, collected at once from each worker and for each agent!

View file

@@ -9,6 +9,7 @@ This demonstrates running the following policies in competition:
 import argparse
 from gym.spaces import Discrete
+import os
 import random
 from ray import tune
@@ -63,6 +64,8 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"):
     config = {
         "env": RockPaperScissors,
         "gamma": 0.9,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "num_envs_per_worker": 4,
        "rollout_fragment_length": 10,

View file

@@ -8,6 +8,7 @@ collection and policy optimization.
 import argparse
 import gym
 import numpy as np
+import os
 import ray
 from ray import tune
@@ -32,6 +33,7 @@ class CustomPolicy(Policy):
     def __init__(self, observation_space, action_space, config):
         super().__init__(observation_space, action_space, config)
+        self.config["framework"] = None
         # example parameter
         self.w = 1.0
@@ -107,7 +109,8 @@ if __name__ == "__main__":
     tune.run(
         training_workflow,
         resources_per_trial={
-            "gpu": 1 if args.gpu else 0,
+            "gpu": 1 if args.gpu
+            or int(os.environ.get("RLLIB_FORCE_NUM_GPUS", 0)) else 0,
             "cpu": 1,
             "extra_cpu": args.num_workers,
         },
@@ -115,4 +118,5 @@ if __name__ == "__main__":
             "num_workers": args.num_workers,
            "num_iters": args.num_iters,
        },
+        verbose=1,
     )

View file

@@ -69,7 +69,7 @@ parser.add_argument(
 if __name__ == "__main__":
     args = parser.parse_args()
-    ray.init(local_mode=True)
+    ray.init()
     # Create a fake-env for the server. This env will never be used (neither
     # for sampling, nor for evaluation) and its obs/action Spaces do not

View file

@@ -10,6 +10,7 @@ See also: centralized_critic.py for centralized critic PPO on this game.
 import argparse
 from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete
+import os
 import ray
 from ray import tune
@@ -77,6 +78,8 @@ if __name__ == "__main__":
                 "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
             },
             "framework": "torch" if args.torch else "tf",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }
         group = False
     elif args.run == "QMIX":
@@ -93,11 +96,17 @@ if __name__ == "__main__":
                 "separate_state_space": True,
                 "one_hot_state_encoding": True
             },
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": "torch" if args.torch else "tf",
         }
         group = True
     else:
-        config = {"framework": "torch" if args.torch else "tf"}
+        config = {
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
+            "framework": "torch" if args.torch else "tf",
+        }
         group = False
     ray.init(num_cpus=args.num_cpus or None)

View file

@@ -7,6 +7,7 @@ via a custom training workflow.
 import argparse
 import gym
+import os
 import ray
 from ray import tune
@@ -139,6 +140,8 @@ if __name__ == "__main__":
             "policy_mapping_fn": policy_mapping_fn,
             "policies_to_train": ["dqn_policy", "ppo_policy"],
         },
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "framework": "torch" if args.torch else "tf",
     }

View file

@@ -21,6 +21,7 @@ $ python unity3d_env_local.py --env 3DBall --stop-reward [..] [--torch]?
 """
 import argparse
+import os
 import ray
 from ray import tune
@@ -99,6 +100,8 @@ if __name__ == "__main__":
         "gamma": 0.99,
         "sgd_minibatch_size": 256,
         "train_batch_size": 4000,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_sgd_iter": 20,
         "rollout_fragment_length": 200,
         "clip_param": 0.2,

View file

@@ -307,7 +307,7 @@ class ModelCatalog:
             model_cls = ModelCatalog._wrap_if_needed(model_cls,
                                                      model_interface)
-            if framework in ["tf", "tfe"]:
+            if framework in ["tf2", "tf", "tfe"]:
                 # Track and warn if vars were created but not registered.
                 created = set()

View file

@@ -423,7 +423,7 @@ class TestDistributions(unittest.TestCase):
     def test_gumbel_softmax(self):
         """Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
         for fw, sess in framework_iterator(
-                frameworks=["tf", "tfe"], session=True):
+                frameworks=("tf2", "tf", "tfe"), session=True):
             batch_size = 1000
             num_categories = 5
             input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))

View file

@@ -200,7 +200,7 @@ def build_eager_tf_policy(name,
     class eager_policy_cls(base):
         def __init__(self, observation_space, action_space, config):
             assert tf.executing_eagerly()
-            self.framework = "tfe"
+            self.framework = config.get("framework", "tfe")
             Policy.__init__(self, observation_space, action_space, config)
             self._is_training = False
             self._loss_initialized = False

View file

@@ -88,7 +88,7 @@ class SampleBatch:
     @staticmethod
     @PublicAPI
-    def concat_samples(samples: List[Dict[str, TensorType]]) -> \
+    def concat_samples(samples: List["SampleBatch"]) -> \
             Union["SampleBatch", "MultiAgentBatch"]:
         """Concatenates n data dicts or MultiAgentBatches.

View file

@@ -392,7 +392,7 @@ class TorchPolicy(Policy):
             grad_info["allreduce_latency"] += time.time() - start
-        # Step the optimizer
+        # Step the optimizers.
         for i, opt in enumerate(self._optimizers):
             opt.step()

View file

@@ -66,7 +66,8 @@ def build_torch_policy(
         mixins: Optional[List[type]] = None,
         view_requirements_fn: Optional[Callable[[], Dict[
             str, ViewRequirement]]] = None,
-        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None):
+        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
+) -> Type[TorchPolicy]:
     """Helper function for creating a torch policy class at runtime.
     Args:
@@ -167,7 +168,8 @@ def build_torch_policy(
             sample batches. If None, will assume a value of 1.
     Returns:
-        type: TorchPolicy child class constructed from the specified args.
+        Type[TorchPolicy]: TorchPolicy child class constructed from the
+            specified args.
     """
     original_kwargs = locals().copy()
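For context on the new return annotation, a skeletal and purely illustrative use of the template: `build_torch_policy` returns a `TorchPolicy` subclass assembled from the supplied functions, not a policy instance. The loss function and its signature below are placeholders, not RLlib code.

import torch
from ray.rllib.policy.torch_policy_template import build_torch_policy

def dummy_loss(policy, model, dist_class, train_batch):
    # Illustrative only: a constant, differentiable scalar "loss".
    return torch.zeros((), requires_grad=True)

# The result is a class (a TorchPolicy subclass) that trainers instantiate later.
MyTorchPolicy = build_torch_policy(name="MyTorchPolicy", loss_fn=dummy_loss)
print(MyTorchPolicy)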

View file

@@ -96,7 +96,7 @@ def ckpt_restore_test(alg_name, tfe=False):
         if optim_state:
             s2 = alg2.get_policy().get_state().get("_optimizer_variables")
             # Tf -> Compare states 1:1.
-            if fw in ["tf", "tfe"]:
+            if fw in ["tf2", "tf", "tfe"]:
                 check(s2, optim_state)
             # For torch, optimizers have state_dicts with keys=params,
             # which are different for the two models (ignore these

View file

@@ -59,22 +59,22 @@ __all__ = [
     "add_mixins",
     "check",
     "check_compute_single_action",
+    "deep_update",
     "deprecation_warning",
     "fc",
     "force_list",
     "force_tuple",
     "framework_iterator",
     "lstm",
-    "one_hot",
-    "relu",
-    "sigmoid",
-    "softmax",
-    "deep_update",
     "merge_dicts",
+    "one_hot",
     "override",
+    "relu",
     "renamed_function",
     "renamed_agent",
     "renamed_class",
+    "sigmoid",
+    "softmax",
     "try_import_tf",
     "try_import_tfp",
     "try_import_torch",

View file

@@ -211,7 +211,7 @@ class Curiosity(Exploration):
         })
         phi, next_phi = torch.chunk(phis, 2)
         actions_tensor = torch.from_numpy(
-            sample_batch[SampleBatch.ACTIONS]).long()
+            sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)
         # Predict next phi with forward model.
         predicted_next_phi = self.model._curiosity_forward_fcnet(

View file

@@ -57,7 +57,7 @@ class EpsilonGreedy(Exploration):
             0, framework=framework, tf_name="timestep")
         # Build the tf-info-op.
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             self._tf_info_op = self.get_info()
     @override(Exploration)
@@ -68,7 +68,7 @@ class EpsilonGreedy(Exploration):
                    explore: bool = True):
         q_values = action_distribution.inputs
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             return self._get_tf_exploration_action_op(q_values, explore,
                                                       timestep)
         else:

View file

@@ -72,7 +72,7 @@ class GaussianNoise(Exploration):
             0, framework=self.framework, tf_name="timestep")
         # Build the tf-info-op.
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             self._tf_info_op = self.get_info()
     @override(Exploration)

View file

@@ -291,7 +291,7 @@ class ParameterNoise(Exploration):
         """Samples new noise and stores it in `self.noise`."""
         if self.framework == "tf":
             tf_sess.run(self.tf_sample_new_noise_op)
-        elif self.framework == "tfe":
+        elif self.framework in ["tfe", "tf2"]:
             self._tf_sample_new_noise_op()
         else:
             for i in range(len(self.noise)):
@@ -340,7 +340,7 @@ class ParameterNoise(Exploration):
         # Add stored noise to the model's parameters.
         if self.framework == "tf":
             tf_sess.run(self.tf_add_stored_noise_op)
-        elif self.framework == "tfe":
+        elif self.framework in ["tf2", "tfe"]:
             self._tf_add_stored_noise_op()
         else:
             for i in range(len(self.noise)):
@@ -378,7 +378,7 @@ class ParameterNoise(Exploration):
         # Removes the stored noise from the model's parameters.
         if self.framework == "tf":
             tf_sess.run(self.tf_remove_noise_op)
-        elif self.framework == "tfe":
+        elif self.framework in ["tf2", "tfe"]:
             self._tf_remove_noise_op()
         else:
             for var, noise in zip(self.model_variables, self.noise):

View file

@@ -46,7 +46,7 @@ class Random(Exploration):
                        timestep: Union[int, TensorType],
                        explore: bool = True):
         # Instantiate the distribution object.
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             return self.get_tf_exploration_action_op(action_distribution,
                                                      explore)
         else:

View file

@@ -190,7 +190,7 @@ def get_variable(value,
         any: A framework-specific variable (tf.Variable, torch.tensor, or
             python primitive).
     """
-    if framework in ["tf", "tfe"]:
+    if framework in ["tf2", "tf", "tfe"]:
         import tensorflow as tf
         dtype = dtype or getattr(
             value, "dtype", tf.float32