Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Fix all example scripts to run on GPUs. (#11105)
This commit is contained in:
parent 5a42ed1848
commit c17169dc11
56 changed files with 221 additions and 98 deletions
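Nearly every example-script hunk below makes the same substitution: instead of a hard-coded `"num_gpus": 0`, the config now reads the GPU count from the `RLLIB_NUM_GPUS` environment variable. A minimal sketch of the pattern, assuming nothing beyond the standard library (the `"env"` key is only illustrative):

    import os

    # Default to 0 GPUs so the script still runs on CPU-only machines;
    # export RLLIB_NUM_GPUS=1 (or more) to run the very same script on a GPU box.
    num_gpus = int(os.environ.get("RLLIB_NUM_GPUS", "0"))

    config = {
        "env": "CartPole-v0",
        "num_gpus": num_gpus,
    }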
@@ -27,14 +27,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
     },
 )
 
-
-def validate_config(config):
-    if config.get("framework") == "tfe":
-        raise ValueError("APEX_DDPG does not support tf-eager yet!")
-
-
 ApexDDPGTrainer = DDPGTrainer.with_updates(
     name="APEX_DDPG",
     default_config=APEX_DDPG_DEFAULT_CONFIG,
-    validate_config=validate_config,
     execution_plan=apex_execution_plan)

@@ -251,7 +251,10 @@ class KLCoeffMixin:
         self.kl_coeff_val = config["kl_coeff"]
         # The current KL value (as tf Variable for in-graph operations).
         self.kl_coeff = get_variable(
-            float(self.kl_coeff_val), tf_name="kl_coeff", trainable=False)
+            float(self.kl_coeff_val),
+            tf_name="kl_coeff",
+            trainable=False,
+            framework=config["framework"])
         # Constant target value.
         self.kl_target = config["kl_target"]
 
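The new `framework=config["framework"]` argument lets `get_variable` hand back a variable of the right kind for the active framework. A rough sketch of the idea only, not RLlib's actual implementation (the helper name `make_coeff` is made up):

    def make_coeff(value, framework):
        # TF graph/eager modes want a tf.Variable so the coefficient can be
        # updated in-graph; other frameworks can keep a plain Python float.
        if framework in ["tf2", "tf", "tfe"]:
            import tensorflow as tf
            return tf.Variable(value, trainable=False, name="kl_coeff")
        return value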
@@ -37,7 +37,7 @@ FAKE_BATCH = {
 class TestPPO(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        ray.init(local_mode=True)
+        ray.init()
 
     @classmethod
     def tearDownClass(cls):

@@ -166,7 +166,8 @@ class TestSAC(unittest.TestCase):
 
         # Set all weights (of all nets) to fixed values.
         if weights_dict is None:
-            assert fw in ["tf", "tfe"]  # Start with the tf vars-dict.
+            # Start with the tf vars-dict.
+            assert fw in ["tf2", "tf", "tfe"]
             weights_dict = policy.get_weights()
             if fw == "tfe":
                 log_alpha = weights_dict[10]

@@ -1,4 +1,5 @@
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -42,6 +43,8 @@ if __name__ == "__main__":
             "repeat_delay": 2,
         },
         "gamma": 0.99,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
         "num_workers": 0,
         "num_envs_per_worker": 20,
         "entropy_coeff": 0.001,
@@ -11,6 +11,7 @@ This examples shows both.
 """
 
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -44,7 +45,8 @@ if __name__ == "__main__":
     config = {
         "env": CorrelatedActionsEnv,
         "gamma": 0.5,
-        "num_gpus": 0,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "model": {
             "custom_model": "autoregressive_model",
             "custom_action_dist": "binary_autoreg_dist",

@@ -58,7 +60,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
 
-    results = tune.run(args.run, stop=stop, config=config)
+    results = tune.run(args.run, stop=stop, config=config, verbose=1)
 
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)

@@ -1,6 +1,7 @@
 """Example of using a custom model with batch norm."""
 
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -32,6 +33,8 @@ if __name__ == "__main__":
         "model": {
             "custom_model": "bn_model",
         },
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "framework": "torch" if args.torch else "tf",
     }

@@ -42,7 +45,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
 
-    results = tune.run(args.run, stop=stop, config=config)
+    results = tune.run(args.run, stop=stop, config=config, verbose=1)
 
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)
@@ -1,4 +1,5 @@
 import argparse
+import os
 
 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
 from ray.rllib.utils.test_utils import check_learning_achieved

@@ -35,8 +36,11 @@ if __name__ == "__main__":
     }
 
     config = dict(
-        configs[args.run], **{
+        configs[args.run],
+        **{
             "env": StatelessCartPole,
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "model": {
                 "use_lstm": True,
                 "lstm_use_prev_action_reward": args.use_prev_action_reward,

@@ -16,6 +16,7 @@ modifies the environment.
 import argparse
 import numpy as np
 from gym.spaces import Discrete
+import os
 
 import ray
 from ray import tune

@@ -90,7 +91,7 @@ def centralized_critic_postprocessing(policy,
                     sample_batch[OPPONENT_OBS], policy.device),
                 convert_to_torch_tensor(
                     sample_batch[OPPONENT_ACTION], policy.device)) \
-                .detach().numpy()
+                .cpu().detach().numpy()
         else:
             sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(
                 sample_batch[SampleBatch.CUR_OBS], sample_batch[OPPONENT_OBS],
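The added `.cpu()` is what makes this postprocessing step work once the policy lives on a GPU: `torch.Tensor.numpy()` only accepts CPU tensors, and `.cpu()` is a cheap no-op when the tensor is already on the host, so the chained call is safe either way. A standalone illustration:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    t = torch.randn(4, device=device, requires_grad=True)

    # Detach from the autograd graph and move to host memory before converting.
    arr = t.cpu().detach().numpy()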
@@ -137,14 +138,22 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):
     return loss
 
 
-def setup_mixins(policy, obs_space, action_space, config):
-    # copied from PPO
+def setup_tf_mixins(policy, obs_space, action_space, config):
+    # Copied from PPOTFPolicy (w/o ValueNetworkMixin).
     KLCoeffMixin.__init__(policy, config)
     EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                   config["entropy_coeff_schedule"])
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
 
 
+def setup_torch_mixins(policy, obs_space, action_space, config):
+    # Copied from PPOTorchPolicy (w/o ValueNetworkMixin).
+    TorchKLCoeffMixin.__init__(policy, config)
+    TorchEntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
+                                       config["entropy_coeff_schedule"])
+    TorchLR.__init__(policy, config["lr"], config["lr_schedule"])
+
+
 def central_vf_stats(policy, train_batch, grads):
     # Report the explained variance of the central value function.
     return {

@@ -158,7 +167,7 @@ CCPPOTFPolicy = PPOTFPolicy.with_updates(
     name="CCPPOTFPolicy",
     postprocess_fn=centralized_critic_postprocessing,
     loss_fn=loss_with_central_critic,
-    before_loss_init=setup_mixins,
+    before_loss_init=setup_tf_mixins,
     grad_stats_fn=central_vf_stats,
     mixins=[
         LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,

@@ -169,7 +178,7 @@ CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
     name="CCPPOTorchPolicy",
     postprocess_fn=centralized_critic_postprocessing,
     loss_fn=loss_with_central_critic,
-    before_init=setup_mixins,
+    before_init=setup_torch_mixins,
     mixins=[
         TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
         CentralizedValueMixin

@@ -188,7 +197,7 @@ CCTrainer = PPOTrainer.with_updates(
 )
 
 if __name__ == "__main__":
-    ray.init(local_mode=True)
+    ray.init()
     args = parser.parse_args()
 
     ModelCatalog.register_custom_model(

@@ -198,6 +207,8 @@ if __name__ == "__main__":
     config = {
         "env": TwoStepGame,
         "batch_mode": "complete_episodes",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "multiagent": {
             "policies": {

@@ -222,7 +233,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
 
-    results = tune.run(CCTrainer, config=config, stop=stop)
+    results = tune.run(CCTrainer, config=config, stop=stop, verbose=1)
 
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)
@@ -12,6 +12,7 @@ modifies the policy to add a centralized value function.
 import numpy as np
 from gym.spaces import Dict, Discrete
 import argparse
+import os
 
 from ray import tune
 from ray.rllib.agents.callbacks import DefaultCallbacks

@@ -87,6 +88,8 @@ if __name__ == "__main__":
         "env": TwoStepGame,
         "batch_mode": "complete_episodes",
         "callbacks": FillInActions,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "multiagent": {
             "policies": {

@@ -8,7 +8,9 @@ For PyTorch / TF eager mode, use the --torch and --eager flags.
 """
 
 import argparse
+import os
 
+import ray
 from ray import tune
 from ray.rllib.models import ModelCatalog
 from ray.rllib.examples.env.simple_rpg import SimpleRPG

@@ -17,9 +19,10 @@ from ray.rllib.examples.models.simple_rpg_model import CustomTorchRPGModel, \
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--framework", choices=["tf", "tfe", "torch"], default="tf")
+    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf2")
 
 if __name__ == "__main__":
+    ray.init()
     args = parser.parse_args()
     if args.framework == "torch":
         ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)

@@ -31,6 +34,8 @@ if __name__ == "__main__":
         "env": SimpleRPG,
         "rollout_fragment_length": 1,
         "train_batch_size": 2,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "model": {
             "custom_model": "my_model",

@@ -11,6 +11,7 @@ import argparse
 import gym
 from gym.spaces import Discrete, Box
 import numpy as np
+import os
 
 import ray
 from ray import tune

@@ -114,6 +115,8 @@ if __name__ == "__main__":
         "env_config": {
             "corridor_length": 5,
         },
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "model": {
             "custom_model": "my_model",
         },
@@ -67,6 +67,7 @@ Result for PG_SimpleCorridor_0de4e686:
 """
 
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -137,7 +138,9 @@ if __name__ == "__main__":
             "corridor_length": 10,
         },
         "horizon": 20,
-        "log_level": "INFO",
+
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
 
         # Training rollouts will be collected using just the learner
         # process, but evaluation will be done in parallel with two

@@ -5,6 +5,7 @@ for running perf microbenchmarks.
 """
 
 import argparse
+import os
 
 import ray
 import ray.tune as tune

@@ -32,7 +33,8 @@ if __name__ == "__main__":
         "model": {
             "custom_model": "fast_model"
         },
-        "num_gpus": 0,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 2,
         "num_envs_per_worker": 10,
         "num_data_loader_buffers": 1,

@@ -40,7 +42,7 @@ if __name__ == "__main__":
         "broadcast_interval": 50,
         "rollout_fragment_length": 100,
         "train_batch_size": sample_from(
-            lambda spec: 1000 * max(1, spec.config.num_gpus)),
+            lambda spec: 1000 * max(1, spec.config.num_gpus or 1)),
         "fake_sampler": True,
         "framework": "torch" if args.torch else "tf",
     }

@@ -50,6 +52,6 @@ if __name__ == "__main__":
         "timesteps_total": args.stop_timesteps,
     }
 
-    tune.run("IMPALA", config=config, stop=stop)
+    tune.run("IMPALA", config=config, stop=stop, verbose=1)
 
     ray.shutdown()
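The `or 1` added to the `sample_from` lambda keeps the derived batch size sane when `num_gpus` resolves to `None` in the sampled config (`max(1, None)` would raise a TypeError); a value of 0 was already covered by `max`. The arithmetic, pulled out of the lambda:

    def derived_train_batch_size(num_gpus):
        # 1000 timesteps per GPU, with a floor of one GPU's worth.
        return 1000 * max(1, num_gpus or 1)

    assert derived_train_batch_size(None) == 1000
    assert derived_train_batch_size(0) == 1000
    assert derived_train_batch_size(4) == 4000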
@@ -1,6 +1,7 @@
 """Example of using a custom ModelV2 Keras-style model."""
 
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -119,11 +120,12 @@ if __name__ == "__main__":
         args.run,
         stop={"episode_reward_mean": args.stop},
         config=dict(
-            extra_config, **{
-                "log_level": "INFO",
+            extra_config,
+            **{
                 "env": "BreakoutNoFrameskip-v4"
                 if args.use_vision_network else "CartPole-v0",
-                "num_gpus": 0,
+                # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+                "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
                 "callbacks": {
                     "on_train_result": check_has_custom_metric,
                 },

@@ -50,6 +50,8 @@ if __name__ == "__main__":
 
     config = {
         "env": "CartPole-v0",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "model": {
             "custom_model": "custom_loss",

@@ -64,4 +66,4 @@ if __name__ == "__main__":
         "training_iteration": args.stop_iters,
     }
 
-    tune.run("PG", config=config, stop=stop)
+    tune.run("PG", config=config, stop=stop, verbose=1)

@@ -7,14 +7,19 @@ custom metric.
 from typing import Dict
 import argparse
 import numpy as np
+import os
 
 import ray
 from ray import tune
+from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.env import BaseEnv
+from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
-from ray.rllib.agents.callbacks import DefaultCallbacks
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--torch", action="store_true")
+parser.add_argument("--stop-iters", type=int, default=2000)
 
 
 class MyCallbacks(DefaultCallbacks):

@@ -65,8 +70,6 @@ class MyCallbacks(DefaultCallbacks):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--stop-iters", type=int, default=2000)
     args = parser.parse_args()
 
     ray.init()

@@ -79,7 +82,9 @@ if __name__ == "__main__":
             "env": "CartPole-v0",
             "num_envs_per_worker": 2,
             "callbacks": MyCallbacks,
-            "framework": "tf",
+            "framework": "torch" if args.torch else "tf",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }).trials
 
     # verify custom metrics for integration tests
@@ -2,6 +2,7 @@
 
 import argparse
 import numpy as np
+import os
 
 import ray
 from ray import tune

@@ -73,6 +74,8 @@ if __name__ == "__main__":
             "on_postprocess_traj": on_postprocess_traj,
         },
         "framework": "tf",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
     }).trials
 
     # verify custom metrics for integration tests

@@ -1,6 +1,7 @@
 """Example of using a custom RNN keras model."""
 
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -37,6 +38,8 @@ if __name__ == "__main__":
             "repeat_delay": 2,
         },
         "gamma": 0.9,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "num_envs_per_worker": 20,
         "entropy_coeff": 0.001,

@@ -1,4 +1,5 @@
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -50,6 +51,8 @@ if __name__ == "__main__":
         stop={"training_iteration": args.stop_iters},
         config={
             "env": "CartPole-v0",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "num_workers": 2,
             "framework": "tf",
         })

@@ -1,4 +1,5 @@
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -36,6 +37,8 @@ if __name__ == "__main__":
         stop={"training_iteration": args.stop_iters},
         config={
             "env": "CartPole-v0",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "num_workers": 2,
             "framework": "torch",
         })

@@ -6,6 +6,7 @@ This example shows:
 You can visualize experiment results in ~/ray_results using TensorBoard.
 """
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -43,6 +44,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     config = {
         "lr": 0.01,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "framework": "torch" if args.torch else "tf",
     }
@@ -1,4 +1,5 @@
 import argparse
+import os
 import random
 
 import ray

@@ -58,12 +59,14 @@ MyTrainer = build_trainer(
 )
 
 if __name__ == "__main__":
-    ray.init()
+    ray.init(local_mode=True)
     args = parser.parse_args()
     ModelCatalog.register_custom_model("eager_model", EagerModel)
 
     config = {
         "env": "CartPole-v0",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "model": {
             "custom_model": "eager_model"

@@ -76,7 +79,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
 
-    results = tune.run(MyTrainer, stop=stop, config=config)
+    results = tune.run(MyTrainer, stop=stop, config=config, verbose=1)
 
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)

@@ -25,6 +25,7 @@ using --flat in this example.
 import argparse
 from gym.spaces import Discrete, Tuple
 import logging
+import os
 
 import ray
 from ray import tune

@@ -75,7 +76,6 @@ if __name__ == "__main__":
         config = {
             "env": HierarchicalWindyMazeEnv,
             "num_workers": 0,
-            "log_level": "INFO",
             "entropy_coeff": 0.01,
             "multiagent": {
                 "policies": {

@@ -94,6 +94,8 @@ if __name__ == "__main__":
                 "policy_mapping_fn": function(policy_mapping_fn),
             },
             "framework": "torch" if args.torch else "tf",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }
 
         results = tune.run("PPO", stop=stop, config=config, verbose=1)
@@ -5,8 +5,9 @@
 import argparse
 from gym.spaces import Discrete, Box
 import numpy as np
+import os
 
-from ray.rllib.agents.ppo import PPOTrainer
+from ray import tune
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.examples.models.mobilenet_v2_with_lstm_models import \
     MobileV2PlusRNNModel, TorchMobileV2PlusRNNModel

@@ -21,6 +22,9 @@ cnn_shape_torch = (3, 224, 224)
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--torch", action="store_true")
+parser.add_argument("--stop-iters", type=int, default=200)
+parser.add_argument("--stop-reward", type=float, default=0.0)
+parser.add_argument("--stop-timesteps", type=int, default=100000)
 
 if __name__ == "__main__":
     args = parser.parse_args()

@@ -30,8 +34,15 @@ if __name__ == "__main__":
         "my_model", TorchMobileV2PlusRNNModel
         if args.torch else MobileV2PlusRNNModel)
 
+    stop = {
+        "training_iteration": args.stop_iters,
+        "timesteps_total": args.stop_timesteps,
+        "episode_reward_mean": args.stop_reward,
+    }
+
     # Configure our Trainer.
     config = {
+        "env": RandomEnv,
         "framework": "torch" if args.torch else "tf",
         "model": {
             "custom_model": "my_model",

@@ -42,6 +53,8 @@ if __name__ == "__main__":
             "max_seq_len": 20,
         },
         "vf_share_layers": True,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,  # no parallelism
         "env_config": {
             "action_space": Discrete(2),

@@ -54,5 +67,4 @@ if __name__ == "__main__":
         },
     }
 
-    trainer = PPOTrainer(config=config, env=RandomEnv)
-    print(trainer.train())
+    tune.run("PPO", config=config, stop=stop, verbose=1)
@@ -131,8 +131,8 @@ class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
 
     def _a1_distribution(self):
         BATCH = self.inputs.shape[0]
-        a1_logits, _ = self.model.action_module(self.inputs,
-                                                torch.zeros((BATCH, 1)))
+        zeros = torch.zeros((BATCH, 1)).to(self.inputs.device)
+        a1_logits, _ = self.model.action_module(self.inputs, zeros)
         a1_dist = TorchCategorical(a1_logits)
         return a1_dist
 
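Moving the freshly built zero tensor onto `self.inputs.device` keeps the whole autoregressive sampling path on one device; a bare `torch.zeros(...)` always lands on the CPU and would then be mixed with GPU inputs. A minimal illustration with arbitrary shapes:

    import torch

    inputs = torch.randn(8, 16)  # may live on CPU or GPU at runtime
    zeros = torch.zeros((inputs.shape[0], 1)).to(inputs.device)

    # Same-device tensors combine fine; mixing devices raises a RuntimeError.
    combined = torch.cat([inputs, zeros], dim=1)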
@@ -116,7 +116,7 @@ class TorchCustomLossModel(TorchModelV2, nn.Module):
 
         # Define a secondary loss by building a graph copy with weight sharing.
         obs = restore_original_dimensions(
-            torch.from_numpy(batch["obs"]).float(),
+            torch.from_numpy(batch["obs"]).float().to(policy_loss[0].device),
             self.obs_space,
             tensorlib="torch")
         logits, _ = self.forward({"obs": obs}, [], None)

@@ -130,8 +130,8 @@ class TorchCustomLossModel(TorchModelV2, nn.Module):
 
         # Compute the IL loss.
         action_dist = TorchCategorical(logits, self.model_config)
-        imitation_loss = torch.mean(
-            -action_dist.logp(torch.from_numpy(batch["actions"])))
+        imitation_loss = torch.mean(-action_dist.logp(
+            torch.from_numpy(batch["actions"]).to(policy_loss[0].device)))
         self.imitation_loss_metric = imitation_loss.item()
         self.policy_loss_metric = np.mean([l.item() for l in policy_loss])
 
@@ -57,8 +57,8 @@ class TorchFastModel(TorchModelV2, nn.Module):
                               model_config, name)
         nn.Module.__init__(self)
 
-        self.bias = torch.tensor(
-            [0.0], dtype=torch.float32, requires_grad=True)
+        self.bias = nn.Parameter(
+            torch.tensor([0.0], dtype=torch.float32, requires_grad=True))
 
         # Only needed to give some params to the optimizer (even though,
         # they are never used anywhere).

@@ -67,8 +67,9 @@ class TorchFastModel(TorchModelV2, nn.Module):
 
     @override(ModelV2)
     def forward(self, input_dict, state, seq_lens):
-        self._output = self.bias + \
-            torch.zeros(size=(input_dict["obs"].shape[0], self.num_outputs))
+        self._output = self.bias + torch.zeros(
+            size=(input_dict["obs"].shape[0], self.num_outputs)).to(
+                self.bias.device)
         return self._output, []
 
     @override(ModelV2)
|
||||||
return tf.reshape(self._value_out, [-1])
|
return tf.reshape(self._value_out, [-1])
|
||||||
|
|
||||||
|
|
||||||
class TorchMobileV2PlusRNNModel(TorchRNN):
|
class TorchMobileV2PlusRNNModel(TorchRNN, nn.Module):
|
||||||
"""A conv. + recurrent torch net example using a pre-trained MobileNet."""
|
"""A conv. + recurrent torch net example using a pre-trained MobileNet."""
|
||||||
|
|
||||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||||
name, cnn_shape):
|
name, cnn_shape):
|
||||||
|
|
||||||
super().__init__(obs_space, action_space, num_outputs, model_config,
|
TorchRNN.__init__(self, obs_space, action_space, num_outputs,
|
||||||
name)
|
model_config, name)
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
self.lstm_state_size = 16
|
self.lstm_state_size = 16
|
||||||
self.cnn_shape = list(cnn_shape)
|
self.cnn_shape = list(cnn_shape)
|
||||||
|
|
|
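Declaring `nn.Module` as an explicit base and calling `TorchRNN.__init__` and `nn.Module.__init__` one after the other guarantees the `nn.Module` machinery is initialized before any layers are assigned. The same shape with a stand-in base class instead of RLlib's `TorchRNN` (the names below are invented for the sketch):

    import torch.nn as nn

    class StandInRNN:  # stands in for TorchRNN in this sketch
        def __init__(self, num_outputs):
            self.num_outputs = num_outputs

    class MyModel(StandInRNN, nn.Module):
        def __init__(self, num_outputs):
            StandInRNN.__init__(self, num_outputs)
            nn.Module.__init__(self)  # must run before assigning nn layers
            self.head = nn.Linear(8, num_outputs)

    model = MyModel(2)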
@@ -125,12 +125,13 @@ class TorchSharedWeightsModel(TorchModelV2, nn.Module):
             activation_fn=None,
             initializer=torch.nn.init.xavier_uniform_,
         )
+        self._global_shared_layer = TORCH_GLOBAL_SHARED_LAYER
         self._output = None
 
     @override(ModelV2)
     def forward(self, input_dict, state, seq_lens):
         out = self.first_layer(input_dict["obs"])
-        self._output = TORCH_GLOBAL_SHARED_LAYER(out)
+        self._output = self._global_shared_layer(out)
         model_out = self.last_layer(self._output)
         return model_out, []
 
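Storing the globally shared layer on the instance (`self._global_shared_layer`) registers it as a submodule, so `model.to(device)` and `model.parameters()` now include the shared weights instead of leaving the module-level global behind on the CPU. A compact version of the pattern:

    import torch.nn as nn

    SHARED_LAYER = nn.Linear(4, 4)  # created once, shared by every model

    class SharingModel(nn.Module):
        def __init__(self):
            super().__init__()
            self._shared = SHARED_LAYER  # registered as a submodule
            self.head = nn.Linear(4, 2)

        def forward(self, x):
            return self.head(self._shared(x))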
@@ -11,6 +11,7 @@ execution, set the TF_TIMELINE_DIR environment variable.
 
 import argparse
 import gym
+import os
 import random
 
 import ray

@@ -75,6 +76,8 @@ if __name__ == "__main__":
             "num_agents": args.num_agents,
         },
         "simple_optimizer": args.simple,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_sgd_iter": 10,
         "multiagent": {
             "policies": policies,

@@ -15,6 +15,7 @@ Result for PG_multi_cartpole_0:
 
 import argparse
 import gym
+import os
 
 import ray
 from ray import tune

@@ -60,6 +61,8 @@ if __name__ == "__main__":
                 lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
         },
         "framework": "torch" if args.torch else "tf",
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
     }
 
     results = tune.run("PG", config=config, stop=stop, verbose=1)

@@ -10,6 +10,7 @@ For a simpler example, see also: multiagent_cartpole.py
 
 import argparse
 import gym
+import os
 
 import ray
 from ray.rllib.agents.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy

@@ -38,9 +39,9 @@ if __name__ == "__main__":
     # Simple environment with 4 independent cartpole entities
     register_env("multi_agent_cartpole",
                  lambda _: MultiAgentCartPole({"num_agents": 4}))
-    single_env = gym.make("CartPole-v0")
-    obs_space = single_env.observation_space
-    act_space = single_env.action_space
+    single_dummy_env = gym.make("CartPole-v0")
+    obs_space = single_dummy_env.observation_space
+    act_space = single_dummy_env.action_space
 
     # You can also have multiple policies per trainer, but here we just
     # show one each for PPO and DQN.

@@ -69,6 +70,8 @@ if __name__ == "__main__":
             # disable filters, otherwise we would need to synchronize those
             # as well to the DQN agent
             "observation_filter": "NoFilter",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": "torch" if args.torch else "tf",
         })
 

@@ -82,6 +85,8 @@ if __name__ == "__main__":
             },
             "gamma": 0.95,
             "n_step": 3,
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": "torch" if args.torch or args.mixed_torch_tf else "tf"
         })
 
@@ -1,5 +1,6 @@
 import argparse
 from gym.spaces import Dict, Tuple, Box, Discrete
+import os
 
 import ray
 import ray.tune as tune

@@ -40,6 +41,8 @@ if __name__ == "__main__":
         "gamma": 0.0,  # No history in Env (bandit problem).
         "lr": 0.0005,
         "num_envs_per_worker": 20,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_sgd_iter": 4,
         "num_workers": 0,
         "vf_loss_coeff": 0.01,

@@ -15,6 +15,7 @@ Working configurations are given below.
 """
 
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -55,14 +56,18 @@ if __name__ == "__main__":
     else:
         cfg = {}
 
-    config = dict({
-        "env": "pa_cartpole",
-        "model": {
-            "custom_model": "pa_model",
+    config = dict(
+        {
+            "env": "pa_cartpole",
+            "model": {
+                "custom_model": "pa_model",
+            },
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
+            "num_workers": 0,
+            "framework": "torch" if args.torch else "tf",
         },
-        "num_workers": 0,
-        "framework": "torch" if args.torch else "tf",
-    }, **cfg)
+        **cfg)
 
     stop = {
         "training_iteration": args.stop_iters,
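The reshaped `config = dict({...}, **cfg)` call is the standard dictionary-merge idiom: keys from `cfg` override keys in the literal dict. A tiny demonstration with made-up values:

    base = {"num_workers": 0, "framework": "tf"}
    cfg = {"num_workers": 2}

    config = dict(base, **cfg)
    assert config == {"num_workers": 2, "framework": "tf"}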
@@ -1,15 +1,13 @@
 from copy import deepcopy
-import ray
-try:
-    from ray.rllib.agents.agent import get_agent_class
-except ImportError:
-    from ray.rllib.agents.registry import get_agent_class
-from ray.tune.registry import register_env
-from ray.rllib.env import PettingZooEnv
+from numpy import float32
+import os
 from pettingzoo.butterfly import pistonball_v0
 from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0
 
-from numpy import float32
+import ray
+from ray.rllib.agents.registry import get_agent_class
+from ray.rllib.env import PettingZooEnv
+from ray.tune.registry import register_env
 
 if __name__ == "__main__":
     """For this script, you need:

@@ -37,7 +35,7 @@ if __name__ == "__main__":
     config = deepcopy(get_agent_class(alg_name)._default_config)
 
     # 2. Set environment config. This will be passed to
-    # the env_creator function via the register env lambda below
+    # the env_creator function via the register env lambda below.
     config["env_config"] = {"local_ratio": 0.5}
 
     # 3. Register env

@@ -58,6 +56,8 @@ if __name__ == "__main__":
         "policy_mapping_fn": lambda agent_id: "av"
     }
 
+    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+    config["num_gpus"] = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
     config["log_level"] = "DEBUG"
     config["num_workers"] = 1
     # Fragment length, collected at once from each worker and for each agent!

@@ -9,6 +9,7 @@ This demonstrates running the following policies in competition:
 
 import argparse
 from gym.spaces import Discrete
+import os
 import random
 
 from ray import tune

@@ -63,6 +64,8 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"):
     config = {
         "env": RockPaperScissors,
         "gamma": 0.9,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_workers": 0,
         "num_envs_per_worker": 4,
         "rollout_fragment_length": 10,

@@ -8,6 +8,7 @@ collection and policy optimization.
 import argparse
 import gym
 import numpy as np
+import os
 
 import ray
 from ray import tune

@@ -32,6 +33,7 @@ class CustomPolicy(Policy):
 
     def __init__(self, observation_space, action_space, config):
         super().__init__(observation_space, action_space, config)
+        self.config["framework"] = None
         # example parameter
         self.w = 1.0
 

@@ -107,7 +109,8 @@ if __name__ == "__main__":
     tune.run(
         training_workflow,
         resources_per_trial={
-            "gpu": 1 if args.gpu else 0,
+            "gpu": 1 if args.gpu
+            or int(os.environ.get("RLLIB_FORCE_NUM_GPUS", 0)) else 0,
            "cpu": 1,
             "extra_cpu": args.num_workers,
         },

@@ -115,4 +118,5 @@ if __name__ == "__main__":
             "num_workers": args.num_workers,
             "num_iters": args.num_iters,
         },
+        verbose=1,
     )
@@ -69,7 +69,7 @@ parser.add_argument(
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    ray.init(local_mode=True)
+    ray.init()
 
     # Create a fake-env for the server. This env will never be used (neither
     # for sampling, nor for evaluation) and its obs/action Spaces do not

@@ -10,6 +10,7 @@ See also: centralized_critic.py for centralized critic PPO on this game.
 
 import argparse
 from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete
+import os
 
 import ray
 from ray import tune

@@ -77,6 +78,8 @@ if __name__ == "__main__":
                 "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
             },
             "framework": "torch" if args.torch else "tf",
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         }
         group = False
     elif args.run == "QMIX":

@@ -93,11 +96,17 @@ if __name__ == "__main__":
                 "separate_state_space": True,
                 "one_hot_state_encoding": True
             },
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "framework": "torch" if args.torch else "tf",
         }
         group = True
     else:
-        config = {"framework": "torch" if args.torch else "tf"}
+        config = {
+            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
+            "framework": "torch" if args.torch else "tf",
+        }
         group = False
 
     ray.init(num_cpus=args.num_cpus or None)

@@ -7,6 +7,7 @@ via a custom training workflow.
 
 import argparse
 import gym
+import os
 
 import ray
 from ray import tune

@@ -139,6 +140,8 @@ if __name__ == "__main__":
             "policy_mapping_fn": policy_mapping_fn,
             "policies_to_train": ["dqn_policy", "ppo_policy"],
         },
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "framework": "torch" if args.torch else "tf",
     }
 

@@ -21,6 +21,7 @@ $ python unity3d_env_local.py --env 3DBall --stop-reward [..] [--torch]?
 """
 
 import argparse
+import os
 
 import ray
 from ray import tune

@@ -99,6 +100,8 @@ if __name__ == "__main__":
         "gamma": 0.99,
         "sgd_minibatch_size": 256,
         "train_batch_size": 4000,
+        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "num_sgd_iter": 20,
         "rollout_fragment_length": 200,
         "clip_param": 0.2,
@@ -307,7 +307,7 @@ class ModelCatalog:
             model_cls = ModelCatalog._wrap_if_needed(model_cls,
                                                      model_interface)
 
-            if framework in ["tf", "tfe"]:
+            if framework in ["tf2", "tf", "tfe"]:
                 # Track and warn if vars were created but not registered.
                 created = set()
 

@@ -423,7 +423,7 @@ class TestDistributions(unittest.TestCase):
     def test_gumbel_softmax(self):
         """Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
         for fw, sess in framework_iterator(
-                frameworks=["tf", "tfe"], session=True):
+                frameworks=("tf2", "tf", "tfe"), session=True):
             batch_size = 1000
             num_categories = 5
             input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))

@@ -200,7 +200,7 @@ def build_eager_tf_policy(name,
     class eager_policy_cls(base):
         def __init__(self, observation_space, action_space, config):
             assert tf.executing_eagerly()
-            self.framework = "tfe"
+            self.framework = config.get("framework", "tfe")
             Policy.__init__(self, observation_space, action_space, config)
             self._is_training = False
             self._loss_initialized = False

@@ -88,7 +88,7 @@ class SampleBatch:
 
     @staticmethod
     @PublicAPI
-    def concat_samples(samples: List[Dict[str, TensorType]]) -> \
+    def concat_samples(samples: List["SampleBatch"]) -> \
             Union["SampleBatch", "MultiAgentBatch"]:
         """Concatenates n data dicts or MultiAgentBatches.
 

@@ -392,7 +392,7 @@ class TorchPolicy(Policy):
 
             grad_info["allreduce_latency"] += time.time() - start
 
-        # Step the optimizer
+        # Step the optimizers.
         for i, opt in enumerate(self._optimizers):
             opt.step()
 
|
||||||
mixins: Optional[List[type]] = None,
|
mixins: Optional[List[type]] = None,
|
||||||
view_requirements_fn: Optional[Callable[[], Dict[
|
view_requirements_fn: Optional[Callable[[], Dict[
|
||||||
str, ViewRequirement]]] = None,
|
str, ViewRequirement]]] = None,
|
||||||
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None):
|
get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None
|
||||||
|
) -> Type[TorchPolicy]:
|
||||||
"""Helper function for creating a torch policy class at runtime.
|
"""Helper function for creating a torch policy class at runtime.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -167,7 +168,8 @@ def build_torch_policy(
|
||||||
sample batches. If None, will assume a value of 1.
|
sample batches. If None, will assume a value of 1.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
type: TorchPolicy child class constructed from the specified args.
|
Type[TorchPolicy]: TorchPolicy child class constructed from the
|
||||||
|
specified args.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
original_kwargs = locals().copy()
|
original_kwargs = locals().copy()
|
||||||
|
|
|
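The new `-> Type[TorchPolicy]` annotation records that `build_torch_policy` returns a policy class rather than an instance. The general shape of such a factory annotation, using a placeholder class instead of RLlib's own `TorchPolicy`:

    from typing import Type

    class BasePolicy:  # placeholder for the real policy base class
        pass

    def build_policy(name: str) -> Type[BasePolicy]:
        # Build and return a subclass: a class object, not an instance.
        return type(name, (BasePolicy,), {})

    PolicyCls = build_policy("MyPolicy")
    policy = PolicyCls()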
@@ -96,7 +96,7 @@ def ckpt_restore_test(alg_name, tfe=False):
         if optim_state:
             s2 = alg2.get_policy().get_state().get("_optimizer_variables")
             # Tf -> Compare states 1:1.
-            if fw in ["tf", "tfe"]:
+            if fw in ["tf2", "tf", "tfe"]:
                 check(s2, optim_state)
             # For torch, optimizers have state_dicts with keys=params,
             # which are different for the two models (ignore these

@@ -59,22 +59,22 @@ __all__ = [
     "add_mixins",
     "check",
     "check_compute_single_action",
+    "deep_update",
     "deprecation_warning",
     "fc",
     "force_list",
     "force_tuple",
     "framework_iterator",
     "lstm",
-    "one_hot",
-    "relu",
-    "sigmoid",
-    "softmax",
-    "deep_update",
     "merge_dicts",
+    "one_hot",
     "override",
+    "relu",
     "renamed_function",
     "renamed_agent",
     "renamed_class",
+    "sigmoid",
+    "softmax",
     "try_import_tf",
     "try_import_tfp",
     "try_import_torch",

@@ -211,7 +211,7 @@ class Curiosity(Exploration):
         })
         phi, next_phi = torch.chunk(phis, 2)
         actions_tensor = torch.from_numpy(
-            sample_batch[SampleBatch.ACTIONS]).long()
+            sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)
 
         # Predict next phi with forward model.
         predicted_next_phi = self.model._curiosity_forward_fcnet(

@@ -57,7 +57,7 @@ class EpsilonGreedy(Exploration):
             0, framework=framework, tf_name="timestep")
 
         # Build the tf-info-op.
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             self._tf_info_op = self.get_info()
 
     @override(Exploration)

@@ -68,7 +68,7 @@ class EpsilonGreedy(Exploration):
                    explore: bool = True):
 
         q_values = action_distribution.inputs
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             return self._get_tf_exploration_action_op(q_values, explore,
                                                       timestep)
         else:

@@ -72,7 +72,7 @@ class GaussianNoise(Exploration):
             0, framework=self.framework, tf_name="timestep")
 
         # Build the tf-info-op.
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             self._tf_info_op = self.get_info()
 
     @override(Exploration)

@@ -291,7 +291,7 @@ class ParameterNoise(Exploration):
         """Samples new noise and stores it in `self.noise`."""
         if self.framework == "tf":
             tf_sess.run(self.tf_sample_new_noise_op)
-        elif self.framework == "tfe":
+        elif self.framework in ["tfe", "tf2"]:
             self._tf_sample_new_noise_op()
         else:
             for i in range(len(self.noise)):

@@ -340,7 +340,7 @@ class ParameterNoise(Exploration):
         # Add stored noise to the model's parameters.
         if self.framework == "tf":
             tf_sess.run(self.tf_add_stored_noise_op)
-        elif self.framework == "tfe":
+        elif self.framework in ["tf2", "tfe"]:
             self._tf_add_stored_noise_op()
         else:
             for i in range(len(self.noise)):

@@ -378,7 +378,7 @@ class ParameterNoise(Exploration):
         # Removes the stored noise from the model's parameters.
         if self.framework == "tf":
             tf_sess.run(self.tf_remove_noise_op)
-        elif self.framework == "tfe":
+        elif self.framework in ["tf2", "tfe"]:
             self._tf_remove_noise_op()
         else:
             for var, noise in zip(self.model_variables, self.noise):

@@ -46,7 +46,7 @@ class Random(Exploration):
                            timestep: Union[int, TensorType],
                            explore: bool = True):
         # Instantiate the distribution object.
-        if self.framework in ["tf", "tfe"]:
+        if self.framework in ["tf2", "tf", "tfe"]:
             return self.get_tf_exploration_action_op(action_distribution,
                                                      explore)
         else:

@@ -190,7 +190,7 @@ def get_variable(value,
         any: A framework-specific variable (tf.Variable, torch.tensor, or
             python primitive).
     """
-    if framework in ["tf", "tfe"]:
+    if framework in ["tf2", "tf", "tfe"]:
         import tensorflow as tf
         dtype = dtype or getattr(
             value, "dtype", tf.float32