Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] MB-MPO cleanup (comments, docstrings, type annotations). (#11033)

parent d2a0d23b0e
commit ce96b03b07

44 changed files with 706 additions and 319 deletions
|
@@ -289,7 +289,7 @@ class Trainable:
|
|||
return ""
|
||||
|
||||
def get_current_ip(self):
|
||||
self._local_ip = ray._private.services.get_node_ip_address()
|
||||
self._local_ip = ray.services.get_node_ip_address()
|
||||
return self._local_ip
|
||||
|
||||
def train(self):
|
||||
|
|
|
@@ -10,7 +10,7 @@ import traceback
|
|||
import types
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray._private.services import get_node_ip_address
|
||||
from ray.services import get_node_ip_address
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.stopper import NoopStopper
|
||||
from ray.tune.progress_reporter import trial_progress_str
|
||||
|
|
rllib/BUILD (20 changed lines)
|
@@ -222,6 +222,18 @@ py_test(
|
|||
args = ["--yaml-dir=tuned_examples/impala", "--torch"]
|
||||
)
|
||||
|
||||
# Working, but takes a long time to learn (>15min).
|
||||
## MB-MPO
|
||||
#py_test(
|
||||
# name = "run_regression_tests_pendulum_mbmpo_torch",
|
||||
# main = "tests/run_regression_tests.py",
|
||||
# tags = ["learning_tests_torch", "learning_tests_pendulum"],
|
||||
# size = "large",
|
||||
# srcs = ["tests/run_regression_tests.py"],
|
||||
# data = ["tuned_examples/mbmpo/pendulum-mbmpo.yaml"],
|
||||
# args = ["--torch", "--yaml-dir=tuned_examples/mbmpo"]
|
||||
#)
|
||||
|
||||
# PG
|
||||
py_test(
|
||||
name = "run_regression_tests_cartpole_pg_tf",
|
||||
|
@@ -486,6 +498,14 @@ py_test(
|
|||
srcs = ["agents/maml/tests/test_maml.py"]
|
||||
)
|
||||
|
||||
# MBMPOTrainer
|
||||
py_test(
|
||||
name = "test_mbmpo",
|
||||
tags = ["agents_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
|
||||
)
|
||||
|
||||
# PGTrainer
|
||||
py_test(
|
||||
name = "test_pg",
|
||||
|
|
|
@@ -1,3 +1,6 @@
|
|||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
|
||||
__all__ = ["Trainer", "with_common_config"]
|
||||
__all__ = [
|
||||
"Trainer",
|
||||
"with_common_config",
|
||||
]
|
||||
|
|
|
@@ -49,12 +49,18 @@ class DDPGTorchModel(TorchModelV2, nn.Module):
|
|||
super(DDPGTorchModel, self).__init__(obs_space, action_space,
|
||||
num_outputs, model_config, name)
|
||||
|
||||
self.bounded = np.logical_and(action_space.bounded_above,
|
||||
action_space.bounded_below).any()
|
||||
self.low_action = torch.tensor(action_space.low, dtype=torch.float32)
|
||||
self.action_range = torch.tensor(
|
||||
action_space.high - action_space.low, dtype=torch.float32)
|
||||
self.action_dim = np.product(action_space.shape)
|
||||
self.bounded = np.logical_and(self.action_space.bounded_above,
|
||||
self.action_space.bounded_below).any()
|
||||
low_action = nn.Parameter(
|
||||
torch.from_numpy(self.action_space.low).float())
|
||||
low_action.requires_grad = False
|
||||
self.register_parameter("low_action", low_action)
|
||||
action_range = nn.Parameter(
|
||||
torch.from_numpy(self.action_space.high -
|
||||
self.action_space.low).float())
|
||||
action_range.requires_grad = False
|
||||
self.register_parameter("action_range", action_range)
|
||||
self.action_dim = np.product(self.action_space.shape)
|
||||
|
||||
# Build the policy network.
|
||||
self.policy_model = nn.Sequential()
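The change above swaps plain torch.tensor constants for parameters registered on the module with requires_grad=False, so the action bounds follow the model when it is moved to another device. A minimal standalone sketch of the same pattern (the ConstHolder module and its bounds are illustrative, not part of this diff); in current PyTorch, register_buffer would express the same intent even more directly:

import numpy as np
import torch
import torch.nn as nn

class ConstHolder(nn.Module):
    def __init__(self, low, high):
        super().__init__()
        # Non-trainable parameters: `.to(device)` / `.cuda()` moves them
        # together with the rest of the model's weights.
        low_t = nn.Parameter(torch.from_numpy(np.asarray(low)).float())
        low_t.requires_grad = False
        self.register_parameter("low_action", low_t)
        rng_t = nn.Parameter(
            torch.from_numpy(np.asarray(high) - np.asarray(low)).float())
        rng_t.requires_grad = False
        self.register_parameter("action_range", rng_t)

    def squash(self, x):
        # Map a sigmoid output in [0, 1] into [low, high].
        return self.low_action + self.action_range * x

m = ConstHolder(low=[-2.0], high=[2.0])
print(m.squash(torch.sigmoid(torch.zeros(1))))  # -> tensor([0.]) for symmetric bounds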
|
||||
|
|
|
@@ -60,13 +60,16 @@ class TestDDPG(unittest.TestCase):
|
|||
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
|
||||
# Setting explore=False should always return the same action.
|
||||
a_ = trainer.compute_action(obs, explore=False)
|
||||
for _ in range(50):
|
||||
self.assertEqual(trainer.get_policy().global_timestep, 1)
|
||||
for i in range(50):
|
||||
a = trainer.compute_action(obs, explore=False)
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 2)
|
||||
check(a, a_)
|
||||
# explore=None (default: explore) should return different actions.
|
||||
actions = []
|
||||
for _ in range(50):
|
||||
for i in range(50):
|
||||
actions.append(trainer.compute_action(obs))
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 52)
|
||||
check(np.std(actions), 0.0, false=True)
|
||||
trainer.stop()
|
||||
|
||||
|
@@ -80,23 +83,27 @@ class TestDDPG(unittest.TestCase):
|
|||
"final_scale": 0.001,
|
||||
}
|
||||
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
|
||||
# ts=1 (get a deterministic action as per explore=False).
|
||||
# ts=0 (get a deterministic action as per explore=False).
|
||||
deterministic_action = trainer.compute_action(obs, explore=False)
|
||||
# ts=2-5 (in random window).
|
||||
self.assertEqual(trainer.get_policy().global_timestep, 1)
|
||||
# ts=1-49 (in random window).
|
||||
random_a = []
|
||||
for _ in range(49):
|
||||
for i in range(1, 50):
|
||||
random_a.append(trainer.compute_action(obs, explore=True))
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 1)
|
||||
check(random_a[-1], deterministic_action, false=True)
|
||||
self.assertTrue(np.std(random_a) > 0.5)
|
||||
|
||||
# ts > 50 (a=deterministic_action + scale * N[0,1])
|
||||
for _ in range(50):
|
||||
for i in range(50):
|
||||
a = trainer.compute_action(obs, explore=True)
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 51)
|
||||
check(a, deterministic_action, rtol=0.1)
|
||||
|
||||
# ts >> 50 (BUT: explore=False -> expect deterministic action).
|
||||
for _ in range(50):
|
||||
for i in range(50):
|
||||
a = trainer.compute_action(obs, explore=False)
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 101)
|
||||
check(a, deterministic_action)
|
||||
trainer.stop()
|
||||
|
||||
|
@@ -291,7 +298,7 @@ class TestDDPG(unittest.TestCase):
|
|||
]
|
||||
for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
|
||||
if tf_g.shape != torch_g.shape:
|
||||
check(tf_g, np.transpose(torch_g))
|
||||
check(tf_g, np.transpose(torch_g.cpu()))
|
||||
else:
|
||||
check(tf_g, torch_g)
|
||||
|
||||
|
@@ -313,7 +320,7 @@ class TestDDPG(unittest.TestCase):
|
|||
torch_c_grads = [v.grad for v in policy.model.q_variables()]
|
||||
for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
|
||||
if tf_g.shape != torch_g.shape:
|
||||
check(tf_g, np.transpose(torch_g))
|
||||
check(tf_g, np.transpose(torch_g.cpu()))
|
||||
else:
|
||||
check(tf_g, torch_g)
|
||||
# Compare (unchanged(!) actor grads) with tf ones.
|
||||
|
@@ -322,7 +329,7 @@ class TestDDPG(unittest.TestCase):
|
|||
]
|
||||
for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
|
||||
if tf_g.shape != torch_g.shape:
|
||||
check(tf_g, np.transpose(torch_g))
|
||||
check(tf_g, np.transpose(torch_g.cpu()))
|
||||
else:
|
||||
check(tf_g, torch_g)
|
||||
|
||||
|
@@ -379,7 +386,10 @@ class TestDDPG(unittest.TestCase):
|
|||
else:
|
||||
torch_var = policy.model.state_dict()[map_[tf_key]]
|
||||
if tf_var.shape != torch_var.shape:
|
||||
check(tf_var, np.transpose(torch_var), atol=0.1)
|
||||
check(
|
||||
tf_var,
|
||||
np.transpose(torch_var.cpu()),
|
||||
atol=0.1)
|
||||
else:
|
||||
check(tf_var, torch_var, atol=0.1)
|
||||
|
||||
|
@@ -516,6 +526,8 @@ class TestDDPG(unittest.TestCase):
|
|||
for k, v in weights_dict.items() if re.search(
|
||||
"default_policy/(actor_(hidden_0|out)|sequential(_1)?)/", k)
|
||||
}
|
||||
model_dict["low_action"] = convert_to_torch_tensor(np.array([0.0]))
|
||||
model_dict["action_range"] = convert_to_torch_tensor(np.array([1.0]))
|
||||
return model_dict
|
||||
|
||||
|
||||
|
|
|
@@ -38,13 +38,16 @@ class TestTD3(unittest.TestCase):
|
|||
trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v0")
|
||||
# Setting explore=False should always return the same action.
|
||||
a_ = trainer.compute_action(obs, explore=False)
|
||||
for _ in range(50):
|
||||
self.assertEqual(trainer.get_policy().global_timestep, 1)
|
||||
for i in range(50):
|
||||
a = trainer.compute_action(obs, explore=False)
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 2)
|
||||
check(a, a_)
|
||||
# explore=None (default: explore) should return different actions.
|
||||
actions = []
|
||||
for _ in range(50):
|
||||
for i in range(50):
|
||||
actions.append(trainer.compute_action(obs))
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 52)
|
||||
check(np.std(actions), 0.0, false=True)
|
||||
trainer.stop()
|
||||
|
||||
|
@@ -58,23 +61,27 @@ class TestTD3(unittest.TestCase):
|
|||
"final_scale": 0.001,
|
||||
}
|
||||
trainer = td3.TD3Trainer(config=lcl_config, env="Pendulum-v0")
|
||||
# ts=1 (get a deterministic action as per explore=False).
|
||||
# ts=0 (get a deterministic action as per explore=False).
|
||||
deterministic_action = trainer.compute_action(obs, explore=False)
|
||||
# ts=2-5 (in random window).
|
||||
self.assertEqual(trainer.get_policy().global_timestep, 1)
|
||||
# ts=1-29 (in random window).
|
||||
random_a = []
|
||||
for _ in range(29):
|
||||
for i in range(1, 30):
|
||||
random_a.append(trainer.compute_action(obs, explore=True))
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 1)
|
||||
check(random_a[-1], deterministic_action, false=True)
|
||||
self.assertTrue(np.std(random_a) > 0.5)
|
||||
|
||||
# ts > 30 (a=deterministic_action + scale * N[0,1])
|
||||
for _ in range(50):
|
||||
for i in range(50):
|
||||
a = trainer.compute_action(obs, explore=True)
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 31)
|
||||
check(a, deterministic_action, rtol=0.1)
|
||||
|
||||
# ts >> 30 (BUT: explore=False -> expect deterministic action).
|
||||
for _ in range(50):
|
||||
for i in range(50):
|
||||
a = trainer.compute_action(obs, explore=False)
|
||||
self.assertEqual(trainer.get_policy().global_timestep, i + 81)
|
||||
check(a, deterministic_action)
|
||||
trainer.stop()
|
||||
|
||||
|
|
|
@@ -20,7 +20,6 @@ class TestMAML(unittest.TestCase):
|
|||
config = maml.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 1
|
||||
config["horizon"] = 200
|
||||
config["rollout_fragment_length"] = 200
|
||||
num_iterations = 1
|
||||
|
||||
# Test for tf framework (torch not implemented yet).
|
||||
|
|
|
@@ -24,6 +24,10 @@ class TestMARWIL(unittest.TestCase):
|
|||
"""Test whether a MARWILTrainer can be built with all frameworks.
|
||||
|
||||
And learns from a historic-data file.
|
||||
To generate this data, first run:
|
||||
$ ./train.py --run=PPO --env=CartPole-v0 \
|
||||
--stop='{"timesteps_total": 50000}' \
|
||||
--config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
|
||||
"""
|
||||
rllib_dir = Path(__file__).parent.parent.parent.parent
|
||||
print("rllib dir={}".format(rllib_dir))
|
||||
|
|
|
@@ -1,93 +1,112 @@
|
|||
import logging
|
||||
"""
|
||||
Model-Based Meta Policy Optimization (MB-MPO)
|
||||
=============================================
|
||||
|
||||
This file defines the distributed Trainer class for model-based meta policy
|
||||
optimization.
|
||||
See `mbmpo_[tf|torch]_policy.py` for the definition of the policy loss.
|
||||
|
||||
Detailed documentation:
|
||||
https://docs.ray.io/en/master/rllib-algorithms.html#mbmpo
|
||||
"""
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.sgd import standardized
|
||||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.execution.metric_ops import CollectMetrics
|
||||
from ray.util.iter import from_actors
|
||||
from ray.rllib.agents.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
from ray.rllib.evaluation.metrics import collect_episodes
|
||||
from ray.rllib.agents.mbmpo.model_vector_env import custom_model_vector_env
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.agents.mbmpo.utils import calculate_gae_advantages, \
|
||||
MBMPOExploration
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.model_vector_env import model_vector_env
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, collect_metrics, \
|
||||
get_learner_stats
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
|
||||
from ray.rllib.execution.metric_ops import CollectMetrics
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.utils.sgd import standardized
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
from ray.rllib.utils.typing import EnvType, TrainerConfigDict
|
||||
from ray.util.iter import from_actors, LocalIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# Adds the following updates to the (base) `Trainer` config in
|
||||
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# If true, use the Generalized Advantage Estimator (GAE)
|
||||
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
|
||||
"use_gae": True,
|
||||
# GAE(lambda) parameter
|
||||
# GAE(lambda) parameter.
|
||||
"lambda": 1.0,
|
||||
# Initial coefficient for KL divergence
|
||||
# Initial coefficient for KL divergence.
|
||||
"kl_coeff": 0.0005,
|
||||
# Size of batches collected from each worker
|
||||
# Size of batches collected from each worker.
|
||||
"rollout_fragment_length": 200,
|
||||
# Stepsize of SGD
|
||||
# Step size of SGD.
|
||||
"lr": 1e-3,
|
||||
# Share layers for value function
|
||||
# Share layers for value function.
|
||||
"vf_share_layers": False,
|
||||
# Coefficient of the value function loss
|
||||
# Coefficient of the value function loss.
|
||||
"vf_loss_coeff": 0.5,
|
||||
# Coefficient of the entropy regularizer
|
||||
# Coefficient of the entropy regularizer.
|
||||
"entropy_coeff": 0.0,
|
||||
# PPO clip parameter
|
||||
# PPO clip parameter.
|
||||
"clip_param": 0.5,
|
||||
# Clip param for the value function. Note that this is sensitive to the
|
||||
# scale of the rewards. If your expected V is large, increase this.
|
||||
"vf_clip_param": 10.0,
|
||||
# If specified, clip the global norm of gradients by this amount
|
||||
# If specified, clip the global norm of gradients by this amount.
|
||||
"grad_clip": None,
|
||||
# Target value for KL divergence
|
||||
# Target value for KL divergence.
|
||||
"kl_target": 0.01,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes".
|
||||
"batch_mode": "complete_episodes",
|
||||
# Which observation filter to apply to the observation
|
||||
# Which observation filter to apply to the observation.
|
||||
"observation_filter": "NoFilter",
|
||||
# Number of Inner adaptation steps for the MAML algorithm
|
||||
# Number of Inner adaptation steps for the MAML algorithm.
|
||||
"inner_adaptation_steps": 1,
|
||||
# Number of MAML steps per meta-update iteration (PPO steps)
|
||||
# Number of MAML steps per meta-update iteration (PPO steps).
|
||||
"maml_optimizer_steps": 8,
|
||||
# Inner Adaptation Step size
|
||||
# Inner adaptation step size.
|
||||
"inner_lr": 1e-3,
|
||||
# Horizon of Environment (200 in MB-MPO paper)
|
||||
# Horizon of the environment (200 in MB-MPO paper).
|
||||
"horizon": 200,
|
||||
# Dynamics Ensemble Hyperparameters
|
||||
# Dynamics ensemble hyperparameters.
|
||||
"dynamics_model": {
|
||||
"custom_model": DynamicsEnsembleCustomModel,
|
||||
# Number of Transition-Dynamics Models for Ensemble
|
||||
# Number of Transition-Dynamics (TD) models in the ensemble.
|
||||
"ensemble_size": 5,
|
||||
# Hidden Layers for Model Ensemble
|
||||
# Hidden layers for each model in the TD-model ensemble.
|
||||
"fcnet_hiddens": [512, 512, 512],
|
||||
# Model Learning Rate
|
||||
# Model learning rate.
|
||||
"lr": 1e-3,
|
||||
# Max number of training epochs per MBMPO iter
|
||||
# Max number of training epochs per MBMPO iter.
|
||||
"train_epochs": 500,
|
||||
# Model Batch Size
|
||||
# Model batch size.
|
||||
"batch_size": 500,
|
||||
# Training/Validation Split
|
||||
# Training/validation split.
|
||||
"valid_split_ratio": 0.2,
|
||||
# Normalize Data (obs, action, and deltas)
|
||||
# Normalize data (obs, action, and deltas).
|
||||
"normalize_data": True,
|
||||
},
|
||||
# Exploration for MB-MPO is based on StochasticSampling, but uses 8000
|
||||
# random timesteps up-front for worker=0.
|
||||
"exploration_config": {
|
||||
"type": MBMPOExploration,
|
||||
"random_timesteps": 8000,
|
||||
},
|
||||
# Workers sample from dynamics models
|
||||
"custom_vector_env": custom_model_vector_env,
|
||||
# How many iterations through MAML per MBMPO iteration
|
||||
# Workers sample from dynamics models, not from actual envs.
|
||||
"custom_vector_env": model_vector_env,
|
||||
# How many iterations through MAML per MBMPO iteration.
|
||||
"num_maml_steps": 10,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
|
@@ -101,8 +120,9 @@ METRICS_KEYS = [
|
|||
|
||||
class MetaUpdate:
|
||||
def __init__(self, workers, num_steps, maml_steps, metric_gen):
|
||||
"""Computes the MetaUpdate step in MAML, adapted for MBMPO
|
||||
for multiple MAML Iterations
|
||||
"""Computes the MetaUpdate step in MAML.
|
||||
|
||||
Adapted for MBMPO for multiple MAML Iterations.
|
||||
|
||||
Args:
|
||||
workers (WorkerSet): Set of Workers
|
||||
|
@@ -111,7 +131,7 @@ class MetaUpdate:
|
|||
metric_gen (Iterator): Generates metrics dictionary
|
||||
|
||||
Returns:
|
||||
metrics (dict): MBMPO metrics for logging
|
||||
metrics (dict): MBMPO metrics for logging.
|
||||
"""
|
||||
self.workers = workers
|
||||
self.num_steps = num_steps
|
||||
|
@@ -125,19 +145,19 @@ class MetaUpdate:
|
|||
data_tuple (tuple): 1st element is samples collected from MAML
|
||||
Inner adaptation steps and 2nd element is accumulated metrics
|
||||
"""
|
||||
# Metaupdate Step
|
||||
# Metaupdate Step.
|
||||
print("Meta-Update Step")
|
||||
samples = data_tuple[0]
|
||||
adapt_metrics_dict = data_tuple[1]
|
||||
self.postprocess_metrics(
|
||||
adapt_metrics_dict, prefix="MAMLIter{}".format(self.step_counter))
|
||||
|
||||
# MAML Meta-update
|
||||
# MAML Meta-update.
|
||||
for i in range(self.maml_optimizer_steps):
|
||||
fetches = self.workers.local_worker().learn_on_batch(samples)
|
||||
fetches = get_learner_stats(fetches)
|
||||
|
||||
# Update KLS
|
||||
# Update KLs.
|
||||
def update(pi, pi_id):
|
||||
assert "inner_kl" not in fetches, (
|
||||
"inner_kl should be nested under policy id key", fetches)
|
||||
|
@@ -149,7 +169,7 @@ class MetaUpdate:
|
|||
|
||||
self.workers.local_worker().foreach_trainable_policy(update)
|
||||
|
||||
# Modify Reporting Metrics
|
||||
# Modify Reporting Metrics.
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.info[LEARNER_INFO] = fetches
|
||||
metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
|
||||
|
@@ -158,17 +178,17 @@ class MetaUpdate:
|
|||
td_metric = self.workers.local_worker().foreach_policy(
|
||||
fit_dynamics)[0]
|
||||
|
||||
# Sync workers with meta policy
|
||||
# Sync workers with meta policy.
|
||||
self.workers.sync_weights()
|
||||
|
||||
# Sync TD Models with workers
|
||||
# Sync TD Models with workers.
|
||||
sync_ensemble(self.workers)
|
||||
sync_stats(self.workers)
|
||||
|
||||
metrics.counters[STEPS_SAMPLED_COUNTER] = td_metric[
|
||||
STEPS_SAMPLED_COUNTER]
|
||||
|
||||
# Modify to CollectMetrics
|
||||
# Modify to CollectMetrics.
|
||||
res = self.metric_gen.__call__(None)
|
||||
res.update(self.metrics)
|
||||
self.step_counter = 0
|
||||
|
@@ -195,7 +215,7 @@ class MetaUpdate:
|
|||
|
||||
|
||||
def post_process_metrics(prefix, workers, metrics):
|
||||
"""Update Current Dataset Metrics and filter out specific keys
|
||||
"""Update current dataset metrics and filter out specific keys.
|
||||
|
||||
Args:
|
||||
prefix (str): Prefix string to be appended
|
||||
|
@@ -208,8 +228,15 @@ def post_process_metrics(prefix, workers, metrics):
|
|||
return metrics
|
||||
|
||||
|
||||
def inner_adaptation(workers, samples):
|
||||
# Each worker performs one gradient descent
|
||||
def inner_adaptation(workers: WorkerSet, samples: List[SampleBatch]):
|
||||
"""Performs one gradient descend step on each remote worker.
|
||||
|
||||
Args:
|
||||
workers (WorkerSet): The WorkerSet of the Trainer.
|
||||
samples (List[SampleBatch]): The list of SampleBatches to perform
|
||||
a training step on (one for each remote worker).
|
||||
"""
|
||||
|
||||
for i, e in enumerate(workers.remote_workers()):
|
||||
e.learn_on_batch.remote(samples[i])
|
||||
|
||||
|
@@ -218,11 +245,11 @@ def fit_dynamics(policy, pid):
|
|||
return policy.dynamics_model.fit()
|
||||
|
||||
|
||||
def sync_ensemble(workers):
|
||||
"""Syncs dynamics ensemble weights from main to workers
|
||||
def sync_ensemble(workers: WorkerSet) -> None:
|
||||
"""Syncs dynamics ensemble weights from driver (main) to workers.
|
||||
|
||||
Args:
|
||||
workers (WorkerSet): Set of workers, including main
|
||||
workers (WorkerSet): Set of workers, including driver (main).
|
||||
"""
|
||||
|
||||
def get_ensemble_weights(worker):
|
||||
|
@@ -254,7 +281,7 @@ def sync_ensemble(workers):
|
|||
e.foreach_policy.remote(set_func, weights=weights)
|
||||
|
||||
|
||||
def sync_stats(workers):
|
||||
def sync_stats(workers: WorkerSet) -> None:
|
||||
def get_normalizations(worker):
|
||||
policy = worker.policy_map[DEFAULT_POLICY_ID]
|
||||
return policy.dynamics_model.normalizations
|
||||
|
@@ -271,7 +298,7 @@ def sync_stats(workers):
|
|||
set_func, normalizations=normalization_dict)
|
||||
|
||||
|
||||
def post_process_samples(samples, config):
|
||||
def post_process_samples(samples, config: TrainerConfigDict):
|
||||
# Instead of using NN for value function, we use regression
|
||||
split_lst = []
|
||||
for sample in samples:
|
||||
|
@@ -297,12 +324,23 @@ def post_process_samples(samples, config):
|
|||
return samples, split_lst
|
||||
|
||||
|
||||
# Similar to MAML Execution Plan
|
||||
def execution_plan(workers, config):
|
||||
# Train TD Models
|
||||
def execution_plan(workers: WorkerSet,
|
||||
config: TrainerConfigDict) -> LocalIterator[dict]:
|
||||
"""Execution plan of the PPO algorithm. Defines the distributed dataflow.
|
||||
|
||||
Args:
|
||||
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
|
||||
of the Trainer.
|
||||
config (TrainerConfigDict): The trainer's configuration dict.
|
||||
|
||||
Returns:
|
||||
LocalIterator[dict]: The Policy class to use with PPOTrainer.
|
||||
If None, use `default_policy` provided in build_trainer().
|
||||
"""
|
||||
# Train TD Models on the driver.
|
||||
workers.local_worker().foreach_policy(fit_dynamics)
|
||||
|
||||
# Sync workers policy with workers
|
||||
# Sync driver's policy with workers.
|
||||
workers.sync_weights()
|
||||
|
||||
# Sync TD Models and normalization stats with workers
|
||||
|
@@ -310,18 +348,18 @@ def execution_plan(workers, config):
|
|||
sync_stats(workers)
|
||||
|
||||
# Dropping metrics from the first iteration
|
||||
episodes, to_be_collected = collect_episodes(
|
||||
_, _ = collect_episodes(
|
||||
workers.local_worker(),
|
||||
workers.remote_workers(), [],
|
||||
timeout_seconds=9999)
|
||||
|
||||
# Metrics Collector
|
||||
# Metrics Collector.
|
||||
metric_collect = CollectMetrics(
|
||||
workers,
|
||||
min_history=0,
|
||||
timeout_seconds=config["collect_metrics_timeout"])
|
||||
|
||||
inner_steps = config["inner_adaptation_steps"]
|
||||
num_inner_steps = config["inner_adaptation_steps"]
|
||||
|
||||
def inner_adaptation_steps(itr):
|
||||
buf = []
|
||||
|
@@ -339,7 +377,7 @@ def execution_plan(workers, config):
|
|||
prefix = "DynaTrajInner_" + str(adapt_iter)
|
||||
metrics = post_process_metrics(prefix, workers, metrics)
|
||||
|
||||
if len(split) > inner_steps:
|
||||
if len(split) > num_inner_steps:
|
||||
out = SampleBatch.concat_samples(buf)
|
||||
out["split"] = np.array(split)
|
||||
buf = []
|
||||
|
@@ -350,42 +388,66 @@ def execution_plan(workers, config):
|
|||
else:
|
||||
inner_adaptation(workers, samples)
|
||||
|
||||
# Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
|
||||
# Iterator for Inner Adaptation Data gathering (from pre->post adaptation).
|
||||
rollouts = from_actors(workers.remote_workers())
|
||||
rollouts = rollouts.batch_across_shards()
|
||||
rollouts = rollouts.transform(inner_adaptation_steps)
|
||||
|
||||
# Metaupdate Step with outer combine loop for multiple MAML iterations
|
||||
# Meta update step with outer combine loop for multiple MAML iterations.
|
||||
train_op = rollouts.combine(
|
||||
MetaUpdate(workers, config["num_maml_steps"],
|
||||
config["maml_optimizer_steps"], metric_collect))
|
||||
return train_op
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
return MBMPOTorchPolicy
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
config["framework"] = "torch"
|
||||
"""Validates the Trainer's config dict.
|
||||
|
||||
Args:
|
||||
config (TrainerConfigDict): The Trainer's config to check.
|
||||
|
||||
Raises:
|
||||
ValueError: In case something is wrong with the config.
|
||||
"""
|
||||
if config["framework"] != "torch":
|
||||
raise ValueError("MB-MPO not supported in Tensorflow yet!")
|
||||
logger.warning("MB-MPO only supported in PyTorch so far! Switching to "
|
||||
"`framework=torch`.")
|
||||
config["framework"] = "torch"
|
||||
if config["inner_adaptation_steps"] <= 0:
|
||||
raise ValueError("Inner Adaptation Steps must be >=1.")
|
||||
if config["maml_optimizer_steps"] <= 0:
|
||||
raise ValueError("PPO steps for meta-update needs to be >=0")
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise ValueError("entropy_coeff must be >=0")
|
||||
raise ValueError("`entropy_coeff` must be >=0.")
|
||||
if config["batch_mode"] != "complete_episodes":
|
||||
raise ValueError("truncate_episodes not supported")
|
||||
raise ValueError("`batch_mode=truncate_episodes` not supported.")
|
||||
if config["num_workers"] <= 0:
|
||||
raise ValueError("Must have at least 1 worker/task.")
|
||||
|
||||
|
||||
def validate_env(env: EnvType, env_context: EnvContext):
|
||||
"""Validates the local_worker's env object (after creation).
|
||||
|
||||
Args:
|
||||
env (EnvType): The env object to check (for worker=0 only).
|
||||
env_context (EnvContext): The env context used for the instantiation of
|
||||
the local worker's env (worker=0).
|
||||
|
||||
Raises:
|
||||
ValueError: In case something is wrong with the config.
|
||||
"""
|
||||
if not hasattr(env, "reward") or not callable(env.reward):
|
||||
raise ValueError("Env {} doest not have a `reward()` method, needed "
|
||||
"for MB-MPO!".format(env))
|
||||
|
||||
|
||||
# Build a child class of `Trainer`, which uses the default policy,
|
||||
# MBMPOTorchPolicy. A TensorFlow version is not available yet.
|
||||
MBMPOTrainer = build_trainer(
|
||||
name="MBMPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=MBMPOTorchPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
execution_plan=execution_plan,
|
||||
validate_config=validate_config)
|
||||
validate_config=validate_config,
|
||||
validate_env=validate_env,
|
||||
)
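For orientation, a minimal usage sketch of the trainer built above, mirroring the new compilation test added further down (config keys and the Pendulum wrapper env are taken from this diff; treat this as a smoke-test sketch, not a tuned setup):

import ray
import ray.rllib.agents.mbmpo as mbmpo

ray.init()
config = mbmpo.DEFAULT_CONFIG.copy()
config["num_workers"] = 2      # validate_config() above requires >= 1 worker.
config["horizon"] = 200        # Same horizon as in the MB-MPO paper.
config["dynamics_model"]["ensemble_size"] = 2  # Shrink the TD ensemble for a quick run.
# The env must expose a batched `reward()` method (checked by validate_env() above).
trainer = mbmpo.MBMPOTrainer(
    config=config, env="ray.rllib.examples.env.mbmpo_env.PendulumWrapper")
print(trainer.train())
trainer.stop()
ray.shutdown()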
|
||||
|
|
|
@@ -1,22 +1,46 @@
|
|||
import gym
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
import ray
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
|
||||
maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
|
||||
setup_config
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
|
||||
maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_model_and_action_dist(policy, obs_space, action_space, config):
|
||||
def make_model_and_action_dist(
|
||||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> \
|
||||
Tuple[ModelV2, Type[TorchDistributionWrapper]]:
|
||||
"""Constructs the necessary ModelV2 and action dist class for the Policy.
|
||||
|
||||
Args:
|
||||
policy (Policy): The TFPolicy that will use the models.
|
||||
obs_space (gym.spaces.Space): The observation space.
|
||||
action_space (gym.spaces.Space): The action space.
|
||||
config (TrainerConfigDict): The SAC trainer's config dict.
|
||||
|
||||
Returns:
|
||||
ModelV2: The ModelV2 to be used by the Policy. Note: An additional
|
||||
target model will be created in this function and assigned to
|
||||
`policy.target_model`.
|
||||
"""
|
||||
# Get the output distribution class for predicting rewards and next-obs.
|
||||
policy.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
|
||||
obs_space, config, dist_type="deterministic", framework="torch")
|
||||
|
@@ -50,6 +74,8 @@ def make_model_and_action_dist(policy, obs_space, action_space, config):
|
|||
return policy.pi, action_dist
|
||||
|
||||
|
||||
# Build a child class of `TorchPolicy`, given the custom functions defined
|
||||
# above.
|
||||
MBMPOTorchPolicy = build_torch_policy(
|
||||
name="MBMPOTorchPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
|
||||
|
|
|
@@ -19,7 +19,7 @@ class TDModel(nn.Module):
|
|||
def __init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
hidden_layers=[512, 512],
|
||||
hidden_layers=(512, 512),
|
||||
hidden_nonlinearity=None,
|
||||
output_nonlinearity=None,
|
||||
weight_normalization=False,
|
||||
|
@@ -118,7 +118,7 @@ def process_samples(samples: SampleBatchType):
|
|||
|
||||
|
||||
class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
|
||||
"""Represents a Transition Dyamics ensemble
|
||||
"""Represents an ensemble of transition dynamics (TD) models.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
|
@@ -139,6 +139,9 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
|
|||
super(DynamicsEnsembleCustomModel, self).__init__(
|
||||
input_space, action_space, num_outputs, model_config, name)
|
||||
|
||||
# Keep the original Env's observation space for possible clipping.
|
||||
self.env_obs_space = obs_space
|
||||
|
||||
self.num_models = model_config["ensemble_size"]
|
||||
self.max_epochs = model_config["train_epochs"]
|
||||
self.lr = model_config["lr"]
|
||||
|
@@ -317,10 +320,9 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
|
|||
val[key] = samples[key][idx_test, :]
|
||||
return SampleBatch(train), SampleBatch(val)
|
||||
|
||||
"""Used by worker who gather trajectories via TD models
|
||||
"""
|
||||
|
||||
def predict_model_batches(self, obs, actions, device=None):
|
||||
"""Used by worker who gather trajectories via TD models.
|
||||
"""
|
||||
pre_obs = obs
|
||||
if self.normalize_data:
|
||||
obs = normalize(obs, self.normalizations[SampleBatch.CUR_OBS])
|
||||
|
@@ -328,10 +330,13 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
|
|||
self.normalizations[SampleBatch.ACTIONS])
|
||||
x = np.concatenate([obs, actions], axis=-1)
|
||||
x = convert_to_torch_tensor(x, device=device)
|
||||
delta = self.forward(x).detach().numpy()
|
||||
delta = self.forward(x).detach().cpu().numpy()
|
||||
if self.normalize_data:
|
||||
delta = denormalize(delta, self.normalizations["delta"])
|
||||
return pre_obs + delta
|
||||
new_obs = pre_obs + delta
|
||||
clipped_obs = np.clip(new_obs, self.env_obs_space.low,
|
||||
self.env_obs_space.high)
|
||||
return clipped_obs
|
||||
|
||||
def set_norms(self, normalization_dict):
|
||||
self.normalizations = normalization_dict
|
||||
|
|
rllib/agents/mbmpo/tests/test_mbmpo.py (new file, 41 lines)
|
@@ -0,0 +1,41 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.mbmpo as mbmpo
|
||||
from ray.rllib.utils.test_utils import check_compute_single_action, \
|
||||
framework_iterator
|
||||
|
||||
|
||||
class TestMBMPO(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
ray.init()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
ray.shutdown()
|
||||
|
||||
def test_mbmpo_compilation(self):
|
||||
"""Test whether an MBMPOTrainer can be built with all frameworks."""
|
||||
config = mbmpo.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 2
|
||||
config["horizon"] = 200
|
||||
config["dynamics_model"]["ensemble_size"] = 2
|
||||
num_iterations = 1
|
||||
|
||||
# Test for torch framework (tf not implemented yet).
|
||||
for _ in framework_iterator(config, frameworks="torch"):
|
||||
trainer = mbmpo.MBMPOTrainer(
|
||||
config=config,
|
||||
env="ray.rllib.examples.env.mbmpo_env.PendulumWrapper")
|
||||
for i in range(num_iterations):
|
||||
trainer.train()
|
||||
check_compute_single_action(
|
||||
trainer, include_prev_action_reward=False)
|
||||
trainer.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@@ -1,13 +1,9 @@
|
|||
import numpy as np
|
||||
import scipy
|
||||
from typing import Union
|
||||
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.evaluation.postprocessing import discount_cumsum
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.exploration.exploration import Exploration
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
||||
TensorType
|
||||
from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@@ -69,58 +65,39 @@ def calculate_gae_advantages(paths, discount, gae_lambda):
|
|||
return paths
|
||||
|
||||
|
||||
def discount_cumsum(x, discount):
|
||||
"""
|
||||
Returns:
|
||||
(float) : y[t] - discount*y[t+1] = x[t] or rev(y)[t]
|
||||
- discount*rev(y)[t-1] = rev(x)[t]
|
||||
"""
|
||||
return scipy.signal.lfilter(
|
||||
[1], [1, float(-discount)], x[::-1], axis=0)[::-1]
|
||||
|
||||
|
||||
class MBMPOExploration(Exploration):
|
||||
"""An exploration that simply samples from a distribution.
|
||||
|
||||
The sampling can be made deterministic by passing explore=False into
|
||||
the call to `get_exploration_action`.
|
||||
Also allows for scheduled parameters for the distributions, such as
|
||||
lowering stddev, temperature, etc.. over time.
|
||||
class MBMPOExploration(StochasticSampling):
|
||||
"""Like StochasticSampling, but only worker=0 uses Random for n timesteps.
|
||||
"""
|
||||
|
||||
def __init__(self, action_space, *, framework: str, model: ModelV2,
|
||||
def __init__(self,
|
||||
action_space,
|
||||
*,
|
||||
framework: str,
|
||||
model: ModelV2,
|
||||
random_timesteps: int = 8000,
|
||||
**kwargs):
|
||||
"""Initializes a StochasticSampling Exploration object.
|
||||
"""Initializes a MBMPOExploration instance.
|
||||
|
||||
Args:
|
||||
action_space (Space): The gym action space used by the environment.
|
||||
framework (str): One of None, "tf", "torch".
|
||||
model (ModelV2): The ModelV2 used by the owning Policy.
|
||||
random_timesteps (int): The number of timesteps for which to act
|
||||
completely randomly. Only after this number of timesteps,
|
||||
actual samples will be drawn to get exploration actions.
|
||||
NOTE: For MB-MPO, only worker=0 will use this setting. All
|
||||
other workers will not use random actions ever.
|
||||
"""
|
||||
assert framework is not None
|
||||
self.timestep = 0
|
||||
self.worker_index = kwargs["worker_index"]
|
||||
super().__init__(
|
||||
action_space, model=model, framework=framework, **kwargs)
|
||||
action_space,
|
||||
model=model,
|
||||
framework=framework,
|
||||
random_timesteps=random_timesteps,
|
||||
**kwargs)
|
||||
|
||||
@override(Exploration)
|
||||
def get_exploration_action(self,
|
||||
*,
|
||||
action_distribution: ActionDistribution,
|
||||
timestep: Union[int, TensorType],
|
||||
explore: bool = True):
|
||||
assert self.framework == "torch"
|
||||
return self._get_torch_exploration_action(action_distribution, explore)
|
||||
assert self.framework == "torch", \
|
||||
"MBMPOExploration currently only supports torch!"
|
||||
|
||||
def _get_torch_exploration_action(self, action_dist, explore):
|
||||
action = action_dist.sample()
|
||||
logp = action_dist.sampled_action_logp()
|
||||
|
||||
batch_size = action.size()[0]
|
||||
|
||||
# Initial Random Exploration for Real Env Interaction
|
||||
if self.worker_index == 0 and self.timestep < 8000:
|
||||
print("Using Random")
|
||||
action = [self.action_space.sample() for _ in range(batch_size)]
|
||||
logp = [0.0 for _ in range(batch_size)]
|
||||
self.timestep += batch_size
|
||||
return action, logp
|
||||
# Switch off Random sampling for all non-driver workers.
|
||||
if self.worker_index > 0:
|
||||
self.random_timesteps = 0
|
||||
|
|
|
@@ -472,7 +472,7 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
|
|||
TargetNetworkMixin.__init__(policy)
|
||||
|
||||
|
||||
# Build a child class of `DynamicTFPolicy`, given the custom functions defined
|
||||
# Build a child class of `TorchPolicy`, given the custom functions defined
|
||||
# above.
|
||||
SACTorchPolicy = build_torch_policy(
|
||||
name="SACTorchPolicy",
|
||||
|
|
|
@@ -446,9 +446,6 @@ class Trainer(Trainable):
|
|||
# in self.setup().
|
||||
config = config or {}
|
||||
|
||||
# Vars to synchronize to workers on each train call
|
||||
self.global_vars = {"timestep": 0}
|
||||
|
||||
# Trainers allow env ids to be passed directly to the constructor.
|
||||
self._env_id = self._register_if_needed(env or config.get("env"))
|
||||
|
||||
|
@@ -641,9 +638,10 @@ class Trainer(Trainable):
|
|||
"using evaluation_config: {}".format(extra_config))
|
||||
|
||||
self.evaluation_workers = self._make_workers(
|
||||
self.env_creator,
|
||||
self._policy_class,
|
||||
merge_dicts(self.config, extra_config),
|
||||
env_creator=self.env_creator,
|
||||
validate_env=None,
|
||||
policy_class=self._policy_class,
|
||||
config=merge_dicts(self.config, extra_config),
|
||||
num_workers=self.config["evaluation_num_workers"])
|
||||
self.evaluation_metrics = {}
|
||||
|
||||
|
@@ -668,9 +666,11 @@ class Trainer(Trainable):
|
|||
self.__setstate__(extra_data)
|
||||
|
||||
@DeveloperAPI
|
||||
def _make_workers(self, env_creator: Callable[[EnvContext], EnvType],
|
||||
policy_class: Type[Policy], config: TrainerConfigDict,
|
||||
num_workers: int) -> WorkerSet:
|
||||
def _make_workers(
|
||||
self, *, env_creator: Callable[[EnvContext], EnvType],
|
||||
validate_env: Optional[Callable[[EnvType, EnvContext], None]],
|
||||
policy_class: Type[Policy], config: TrainerConfigDict,
|
||||
num_workers: int) -> WorkerSet:
|
||||
"""Default factory method for a WorkerSet running under this Trainer.
|
||||
|
||||
Override this method by passing a custom `make_workers` into
|
||||
|
@@ -679,6 +679,9 @@ class Trainer(Trainable):
|
|||
Args:
|
||||
env_creator (callable): A function that return and Env given an env
|
||||
config.
|
||||
validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
|
||||
Optional callable to validate the generated environment (only
|
||||
on worker=0).
|
||||
policy (Type[Policy]): The Policy class to use for creating the
|
||||
policies of the workers.
|
||||
config (TrainerConfigDict): The Trainer's config.
|
||||
|
@@ -690,6 +693,7 @@ class Trainer(Trainable):
|
|||
"""
|
||||
return WorkerSet(
|
||||
env_creator=env_creator,
|
||||
validate_env=validate_env,
|
||||
policy_class=policy_class,
|
||||
trainer_config=config,
|
||||
num_workers=num_workers,
|
||||
|
@@ -799,9 +803,6 @@ class Trainer(Trainable):
|
|||
filtered_obs = self.workers.local_worker().filters[policy_id](
|
||||
preprocessed, update=False)
|
||||
|
||||
# Figure out the current (sample) time step and pass it into Policy.
|
||||
self.global_vars["timestep"] += 1
|
||||
|
||||
result = self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs,
|
||||
state,
|
||||
|
@@ -809,8 +810,7 @@ class Trainer(Trainable):
|
|||
prev_reward,
|
||||
info,
|
||||
clip_actions=self.config["clip_actions"],
|
||||
explore=explore,
|
||||
timestep=self.global_vars["timestep"])
|
||||
explore=explore)
|
||||
|
||||
if state or full_fetch:
|
||||
return result
|
||||
|
@@ -876,9 +876,6 @@ class Trainer(Trainable):
|
|||
state = list(zip(*filtered_state))
|
||||
state = [np.stack(s) for s in state]
|
||||
|
||||
# Figure out the current (sample) time step and pass it into Policy.
|
||||
self.global_vars["timestep"] += 1
|
||||
|
||||
# Batch compute actions
|
||||
actions, states, infos = policy.compute_actions(
|
||||
obs_batch,
|
||||
|
@@ -887,8 +884,7 @@ class Trainer(Trainable):
|
|||
prev_reward,
|
||||
info,
|
||||
clip_actions=self.config["clip_actions"],
|
||||
explore=explore,
|
||||
timestep=self.global_vars["timestep"])
|
||||
explore=explore)
|
||||
|
||||
# Unbatch actions for the environment
|
||||
atns, actions = space_utils.unbatch(actions), {}
|
||||
|
|
|
@@ -2,6 +2,7 @@ import logging
|
|||
from typing import Callable, Iterable, List, Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.execution.train_ops import TrainOneStep
|
||||
|
@@ -40,6 +41,7 @@ def build_trainer(
|
|||
default_policy: Optional[Type[Policy]] = None,
|
||||
get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[
|
||||
Policy]]]] = None,
|
||||
validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
|
||||
before_init: Optional[Callable[[Trainer], None]] = None,
|
||||
after_init: Optional[Callable[[Trainer], None]] = None,
|
||||
before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
|
||||
|
@@ -68,6 +70,9 @@ def build_trainer(
|
|||
that takes a config and returns the policy class or None. If None
|
||||
is returned, will use `default_policy` (which must be provided
|
||||
then).
|
||||
validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
|
||||
Optional callable to validate the generated environment (only
|
||||
on worker=0).
|
||||
before_init (Optional[Callable[[Trainer], None]]): Optional callable to
|
||||
run before anything is constructed inside Trainer (Workers with
|
||||
Policies, execution plan, etc..). Takes the Trainer instance as
|
||||
|
@@ -106,12 +111,17 @@ def build_trainer(
|
|||
if validate_config:
|
||||
validate_config(config)
|
||||
|
||||
# No `get_policy_class` function.
|
||||
if get_policy_class is None:
|
||||
# Default_policy must be provided (unless in multi-agent mode,
|
||||
# where each policy can have its own default policy class.
|
||||
if not config["multiagent"]["policies"]:
|
||||
assert default_policy is not None
|
||||
self._policy_class = default_policy
|
||||
# Query the function for a class to use.
|
||||
else:
|
||||
self._policy_class = get_policy_class(config)
|
||||
# If None returned, use default policy (must be provided).
|
||||
if self._policy_class is None:
|
||||
assert default_policy is not None
|
||||
self._policy_class = default_policy
|
||||
|
@@ -120,9 +130,12 @@ def build_trainer(
|
|||
before_init(self)
|
||||
|
||||
# Creating all workers (excluding evaluation workers).
|
||||
self.workers = self._make_workers(env_creator, self._policy_class,
|
||||
config,
|
||||
self.config["num_workers"])
|
||||
self.workers = self._make_workers(
|
||||
env_creator=env_creator,
|
||||
validate_env=validate_env,
|
||||
policy_class=self._policy_class,
|
||||
config=config,
|
||||
num_workers=self.config["num_workers"])
|
||||
self.execution_plan = execution_plan
|
||||
self.train_exec_impl = execution_plan(self.workers, config)
|
||||
|
||||
|
|
rllib/env/base_env.py (2 changed lines)
|
@@ -84,7 +84,7 @@ class BaseEnv:
|
|||
make_env: Callable[[int], EnvType] = None,
|
||||
num_envs: int = 1,
|
||||
remote_envs: bool = False,
|
||||
remote_env_batch_wait_ms: bool = 0) -> "BaseEnv":
|
||||
remote_env_batch_wait_ms: int = 0) -> "BaseEnv":
|
||||
"""Wraps any env type as needed to expose the async interface."""
|
||||
|
||||
from ray.rllib.env.remote_vector_env import RemoteVectorEnv
|
||||
|
|
|
@@ -5,13 +5,23 @@ from ray.rllib.utils.annotations import override
|
|||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.utils.typing import EnvType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def custom_model_vector_env(env):
|
||||
"""Returns a VectorizedEnv wrapper around the current envioronment
|
||||
def model_vector_env(env: EnvType) -> BaseEnv:
|
||||
"""Returns a VectorizedEnv wrapper around the given environment.
|
||||
|
||||
To obtain worker configs, one can call get_global_worker().
|
||||
|
||||
Args:
|
||||
env (EnvType): The input environment (of any supported environment
|
||||
type) to be convert to a _VectorizedModelGymEnv (wrapped as
|
||||
an RLlib BaseEnv).
|
||||
|
||||
Returns:
|
||||
BaseEnv: The BaseEnv converted input `env`.
|
||||
"""
|
||||
worker = get_global_worker()
|
||||
worker_index = worker.worker_index
|
||||
|
@@ -32,8 +42,13 @@ def custom_model_vector_env(env):
|
|||
|
||||
|
||||
class _VectorizedModelGymEnv(VectorEnv):
|
||||
"""Vectorized Environment Wrapper for MB-MPO. Primary change is
|
||||
in the vector_step method, which calls the dynamics models for
|
||||
"""Vectorized Environment Wrapper for MB-MPO.
|
||||
|
||||
Primary change is in the `vector_step` method, which calls the dynamics
|
||||
models for next_obs "calculation" (instead of the actual env). Also, the
|
||||
actual envs need to have two extra methods implemented: `reward(obs)` and
|
||||
(optionally) `done(obs)`. If `done` is not implemented, we will assume
|
||||
that episodes in the env do not terminate, ever.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@@ -61,11 +76,15 @@ class _VectorizedModelGymEnv(VectorEnv):
|
|||
|
||||
@override(VectorEnv)
|
||||
def vector_reset(self):
|
||||
"""Override parent to store actual env obs for upcoming predictions.
|
||||
"""
|
||||
self.cur_obs = [e.reset() for e in self.envs]
|
||||
return self.cur_obs
|
||||
|
||||
@override(VectorEnv)
|
||||
def reset_at(self, index):
|
||||
"""Override parent to store actual env obs for upcoming predictions.
|
||||
"""
|
||||
obs = self.envs[index].reset()
|
||||
self.cur_obs[index] = obs
|
||||
return obs
|
||||
|
@@ -75,19 +94,24 @@ class _VectorizedModelGymEnv(VectorEnv):
|
|||
if self.cur_obs is None:
|
||||
raise ValueError("Need to reset env first")
|
||||
|
||||
# Batch the TD-model prediction.
|
||||
obs_batch = np.stack(self.cur_obs, axis=0)
|
||||
action_batch = np.stack(actions, axis=0)
|
||||
|
||||
# Predict the next observation, given previous a) real obs
|
||||
# (after a reset), b) predicted obs (any other time).
|
||||
next_obs_batch = self.model.predict_model_batches(
|
||||
obs_batch, action_batch, device=self.device)
|
||||
|
||||
next_obs_batch = np.clip(next_obs_batch, -1000, 1000)
|
||||
|
||||
# Call env's reward function.
|
||||
# Note: Each actual env must implement one to output exact rewards.
|
||||
rew_batch = self.envs[0].reward(obs_batch, action_batch,
|
||||
next_obs_batch)
|
||||
|
||||
# If env has a `done` method, use it.
|
||||
if hasattr(self.envs[0], "done"):
|
||||
dones_batch = self.envs[0].done(next_obs_batch)
|
||||
# Otherwise, assume the episode does not end.
|
||||
else:
|
||||
dones_batch = np.asarray([False for _ in range(self.num_envs)])
|
||||
|
|
@@ -4,8 +4,19 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
def discount(x: np.ndarray, gamma: float):
|
||||
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
|
||||
def discount_cumsum(x: np.ndarray, gamma: float) -> float:
|
||||
"""Calculates the discounted cumulative sum over a reward sequence `x`.
|
||||
|
||||
y[t] - discount*y[t+1] = x[t]
|
||||
reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
|
||||
|
||||
Args:
|
||||
gamma (float): The discount factor gamma.
|
||||
|
||||
Returns:
|
||||
float: The discounted cumulative sum over the reward sequence `x`.
|
||||
"""
|
||||
return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
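A quick numeric check of the recurrence stated in the docstring, using the implementation above copied verbatim (gamma = 0.5 is arbitrary):

import numpy as np
import scipy.signal

def discount_cumsum(x, gamma):
    return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]

x = np.array([1.0, 2.0, 3.0])
y = discount_cumsum(x, 0.5)
# y[2] = 3.0, y[1] = 2 + 0.5 * 3 = 3.5, y[0] = 1 + 0.5 * 3.5 = 2.75
print(y)                     # [2.75 3.5  3.  ]
print(y[:-1] - 0.5 * y[1:])  # [1. 2.] == x[:-1], i.e. y[t] - gamma * y[t+1] = x[t]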
|
||||
|
||||
|
||||
class Postprocessing:
|
||||
|
@@ -54,7 +65,8 @@ def compute_advantages(rollout: SampleBatch,
|
|||
rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
|
||||
# This formula for the advantage comes from:
|
||||
# "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
|
||||
rollout[Postprocessing.ADVANTAGES] = discount(delta_t, gamma * lambda_)
|
||||
rollout[Postprocessing.ADVANTAGES] = discount_cumsum(
|
||||
delta_t, gamma * lambda_)
|
||||
rollout[Postprocessing.VALUE_TARGETS] = (
|
||||
rollout[Postprocessing.ADVANTAGES] +
|
||||
rollout[SampleBatch.VF_PREDS]).astype(np.float32)
|
||||
|
@@ -62,8 +74,8 @@ def compute_advantages(rollout: SampleBatch,
|
|||
rewards_plus_v = np.concatenate(
|
||||
[rollout[SampleBatch.REWARDS],
|
||||
np.array([last_r])])
|
||||
discounted_returns = discount(rewards_plus_v,
|
||||
gamma)[:-1].astype(np.float32)
|
||||
discounted_returns = discount_cumsum(rewards_plus_v,
|
||||
gamma)[:-1].astype(np.float32)
|
||||
|
||||
if use_critic:
|
||||
rollout[Postprocessing.
|
||||
|
|
|
@ -131,7 +131,10 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
@DeveloperAPI
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env_creator: Callable[[EnvContext], EnvType],
|
||||
validate_env: Optional[Callable[[EnvType, EnvContext],
|
||||
None]] = None,
|
||||
policy: Union[type, Dict[str, Tuple[Optional[
|
||||
type], gym.Space, gym.Space, PartialTrainerConfigDict]]],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID] = None,
|
||||
|
@@ -175,6 +178,9 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
Args:
|
||||
env_creator (Callable[[EnvContext], EnvType]): Function that
|
||||
returns a gym.Env given an EnvContext wrapped configuration.
|
||||
validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
|
||||
Optional callable to validate the generated environment (only
|
||||
on worker=0).
|
||||
policy (Union[type, Dict[str, Tuple[Optional[type], gym.Space,
|
||||
gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
|
||||
or a dict of policy id strings to
|
||||
|
@@ -329,6 +335,9 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.fake_sampler: bool = fake_sampler
|
||||
|
||||
self.env = _validate_env(env_creator(env_context))
|
||||
if validate_env is not None:
|
||||
validate_env(self.env, self.env_context)
|
||||
|
||||
if isinstance(self.env, (BaseEnv, MultiAgentEnv)):
|
||||
|
||||
def wrap(env):
|
||||
|
@@ -338,7 +347,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
not model_config.get("custom_preprocessor") and \
|
||||
preprocessor_pref == "deepmind":
|
||||
|
||||
# Deepmind wrappers already handle all preprocessing
|
||||
# Deepmind wrappers already handle all preprocessing.
|
||||
self.preprocessing_enabled = False
|
||||
|
||||
# If clip_rewards not explicitly set to False, switch it
|
||||
|
@@ -1093,7 +1102,7 @@ def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
|
|||
|
||||
|
||||
def _validate_env(env: Any) -> EnvType:
|
||||
# allow this as a special case (assumed gym.Env)
|
||||
# Allow this as a special case (assumed gym.Env).
|
||||
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
|
||||
return env
|
||||
|
||||
|
|
|
@@ -32,6 +32,7 @@ class WorkerSet:
|
|||
def __init__(self,
|
||||
*,
|
||||
env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
|
||||
validate_env: Optional[Callable[[EnvType], None]] = None,
|
||||
policy_class: Optional[Type[Policy]] = None,
|
||||
trainer_config: Optional[TrainerConfigDict] = None,
|
||||
num_workers: int = 0,
|
||||
|
@@ -42,6 +43,9 @@ class WorkerSet:
|
|||
Args:
|
||||
env_creator (Optional[Callable[[EnvContext], EnvType]]): Function
|
||||
that returns env given env config.
|
||||
validate_env (Optional[Callable[[EnvType], None]]): Optional
|
||||
callable to validate the generated environment (only on
|
||||
worker=0).
|
||||
policy (Optional[Type[Policy]]): A rllib.policy.Policy class.
|
||||
trainer_config (Optional[TrainerConfigDict]): Optional dict that
|
||||
extends the common config of the Trainer class.
|
||||
|
@@ -69,9 +73,13 @@ class WorkerSet:
|
|||
self.add_workers(num_workers)
|
||||
|
||||
# Always create a local worker.
|
||||
self._local_worker = self._make_worker(RolloutWorker, env_creator,
|
||||
self._policy_class, 0,
|
||||
self._local_config)
|
||||
self._local_worker = self._make_worker(
|
||||
cls=RolloutWorker,
|
||||
env_creator=env_creator,
|
||||
validate_env=validate_env,
|
||||
policy=self._policy_class,
|
||||
worker_index=0,
|
||||
config=self._local_config)
|
||||
|
||||
def local_worker(self) -> RolloutWorker:
|
||||
"""Return the local rollout worker."""
|
||||
|
@@ -106,9 +114,13 @@ class WorkerSet:
|
|||
}
|
||||
cls = RolloutWorker.as_remote(**remote_args).remote
|
||||
self._remote_workers.extend([
|
||||
self._make_worker(cls, self._env_creator, self._policy_class,
|
||||
i + 1, self._remote_config)
|
||||
for i in range(num_workers)
|
||||
self._make_worker(
|
||||
cls=cls,
|
||||
env_creator=self._env_creator,
|
||||
validate_env=None,
|
||||
policy=self._policy_class,
|
||||
worker_index=i + 1,
|
||||
config=self._remote_config) for i in range(num_workers)
|
||||
])
|
||||
|
||||
def reset(self, new_remote_workers: List["ActorHandle"]) -> None:
|
||||
|
@@ -205,7 +217,9 @@ class WorkerSet:
|
|||
return workers
|
||||
|
||||
def _make_worker(
|
||||
self, cls: Callable, env_creator: Callable[[EnvContext], EnvType],
|
||||
self, *, cls: Callable,
|
||||
env_creator: Callable[[EnvContext], EnvType],
|
||||
validate_env: Optional[Callable[[EnvType], None]],
|
||||
policy: Type[Policy], worker_index: int,
|
||||
config: TrainerConfigDict) -> Union[RolloutWorker, "ActorHandle"]:
|
||||
def session_creator():
|
||||
|
@@ -266,8 +280,9 @@ class WorkerSet:
|
|||
"extra_python_environs_for_worker", None)
|
||||
|
||||
worker = cls(
|
||||
env_creator,
|
||||
policy,
|
||||
env_creator=env_creator,
|
||||
validate_env=validate_env,
|
||||
policy=policy,
|
||||
policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
|
||||
policies_to_train=config["multiagent"]["policies_to_train"],
|
||||
tf_session_creator=(session_creator
|
||||
|
|
|
@@ -4,7 +4,7 @@ import os
|
|||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.evaluation.postprocessing import discount
|
||||
from ray.rllib.evaluation.postprocessing import discount_cumsum
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
|
@ -26,7 +26,7 @@ def calculate_advantages(policy,
|
|||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
sample_batch["returns"] = discount(sample_batch["rewards"], 0.99)
|
||||
sample_batch["returns"] = discount_cumsum(sample_batch["rewards"], 0.99)
|
||||
return sample_batch
|
||||
|
||||
|
||||
|
|
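The rename from discount() to discount_cumsum() leaves the returns computation a discounted cumulative sum. For reference, a standalone sketch of what such a function computes (an illustrative reimplementation, not RLlib's actual code):

import numpy as np

def discount_cumsum(x, gamma):
    # out[t] = x[t] + gamma * x[t+1] + gamma^2 * x[t+2] + ...
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out

print(discount_cumsum(np.array([1.0, 1.0, 1.0]), 0.99))  # [2.9701 1.99 1.]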
rllib/examples/env/mbmpo_env.py

@@ -1,53 +1,82 @@
import gym
from gym.envs.classic_control import PendulumEnv
import numpy as np
from gym.envs.mujoco import HalfCheetahEnv, HopperEnv

# MuJoCo may not be installed.
HalfCheetahEnv = HopperEnv = None
try:
from gym.envs.mujoco import HalfCheetahEnv, HopperEnv
except (ImportError, gym.error.DependencyNotInstalled):
pass


class HalfCheetahWrapper(HalfCheetahEnv):
"""HalfCheetah Wrapper that wraps Mujoco Halfcheetah-v2 env
with an additional defined reward function for model-based RL.
class PendulumWrapper(PendulumEnv):
"""Wrapper for the Pendulum-v0 environment.

This is currently used for MBMPO.
Adds an additional `reward` method for some model-based RL algos (e.g.
MB-MPO).
"""

def __init__(self, *args, **kwargs):
HalfCheetahEnv.__init__(self, *args, **kwargs)

def reward(self, obs, action, obs_next):
if obs.ndim == 2 and action.ndim == 2:
assert obs.shape == obs_next.shape
forward_vel = obs_next[:, 8]
ctrl_cost = 0.1 * np.sum(np.square(action), axis=1)
reward = forward_vel - ctrl_cost
# obs = [cos(theta), sin(theta), dtheta/dt]
# To get the angle back from obs: atan2(sin(theta), cos(theta)).
theta = np.arctan2(
np.clip(obs[:, 1], -1.0, 1.0), np.clip(obs[:, 0], -1.0, 1.0))
# Do everything in (B,) space (single theta-, action- and
# reward values).
a = np.clip(action, -self.max_torque, self.max_torque)[0]
costs = self.angle_normalize(theta) ** 2 + \
0.1 * obs[:, 2] ** 2 + 0.001 * (a ** 2)
return -costs

@staticmethod
def angle_normalize(x):
return (((x + np.pi) % (2 * np.pi)) - np.pi)


if HalfCheetahEnv:

class HalfCheetahWrapper(HalfCheetahEnv):
"""Wrapper for the MuJoCo HalfCheetah-v2 environment.

Adds an additional `reward` method for some model-based RL algos (e.g.
MB-MPO).
"""

def reward(self, obs, action, obs_next):
if obs.ndim == 2 and action.ndim == 2:
assert obs.shape == obs_next.shape
forward_vel = obs_next[:, 8]
ctrl_cost = 0.1 * np.sum(np.square(action), axis=1)
reward = forward_vel - ctrl_cost
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
else:
forward_vel = obs_next[8]
ctrl_cost = 0.1 * np.square(action).sum()
reward = forward_vel - ctrl_cost
return np.minimum(np.maximum(-1000.0, reward), 1000.0)

class HopperWrapper(HopperEnv):
"""Wrapper for the MuJoCo Hopper-v2 environment.

Adds an additional `reward` method for some model-based RL algos (e.g.
MB-MPO).
"""

def reward(self, obs, action, obs_next):
alive_bonus = 1.0
assert obs.ndim == 2 and action.ndim == 2
assert (obs.shape == obs_next.shape
and action.shape[0] == obs.shape[0])
vel = obs_next[:, 5]
ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
reward = vel + alive_bonus - ctrl_cost
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
else:
forward_vel = obs_next[8]
ctrl_cost = 0.1 * np.square(action).sum()
reward = forward_vel - ctrl_cost
return np.minimum(np.maximum(-1000.0, reward), 1000.0)


class HopperWrapper(HopperEnv):
"""Hopper Wrapper that wraps Mujoco Hopper-v2 env
with an additional defined reward function for model-based RL.

This is currently used for MBMPO.
"""

def __init__(self, *args, **kwargs):
HopperEnv.__init__(self, *args, **kwargs)

def reward(self, obs, action, obs_next):
alive_bonus = 1.0
assert obs.ndim == 2 and action.ndim == 2
assert obs.shape == obs_next.shape and action.shape[0] == obs.shape[0]
vel = obs_next[:, 5]
ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
reward = vel + alive_bonus - ctrl_cost
return np.minimum(np.maximum(-1000.0, reward), 1000.0)


if __name__ == "__main__":
env = HopperWrapper()
env = PendulumWrapper()
env.reset()
for _ in range(1000):
for _ in range(100):
env.step(env.action_space.sample())
env.render()
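A quick way to sanity-check the new batched PendulumWrapper.reward() against the environment's own reward signal (a sketch; it assumes the pre-0.26 gym step API in use at the time and the import path shown in the tuned-example yaml further below):

import numpy as np
from ray.rllib.examples.env.mbmpo_env import PendulumWrapper

env = PendulumWrapper()
obs = env.reset()
action = env.action_space.sample()
next_obs, gym_reward, done, _ = env.step(action)
# reward() works on (B, obs_dim) / (B, act_dim) batches:
model_reward = env.reward(obs[None, :], action[None, :], next_obs[None, :])
print(gym_reward, model_reward[0])  # should agree up to torque clipping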
@@ -69,8 +69,8 @@ def training_workflow(config, reporter):
env = gym.make("CartPole-v0")
policy = CustomPolicy(env.observation_space, env.action_space, {})
workers = [
RolloutWorker.as_remote().remote(lambda c: gym.make("CartPole-v0"),
CustomPolicy)
RolloutWorker.as_remote().remote(
env_creator=lambda c: gym.make("CartPole-v0"), policy=CustomPolicy)
for _ in range(config["num_workers"])
]
@@ -121,7 +121,7 @@ class ModelCatalog:
action_space (Space): Action space of the target gym env.
config (Optional[dict]): Optional model config.
dist_type (Optional[str]): Identifier of the action distribution
interpreted as a hint.
type (str) interpreted as a hint.
framework (str): One of "tf", "tfe", or "torch".
kwargs (dict): Optional kwargs to pass on to the Distribution's
constructor.

@@ -134,20 +134,21 @@ class ModelCatalog:
distribution.
"""

dist = None
dist_cls = None
config = config or MODEL_DEFAULTS
# Custom distribution given.
if config.get("custom_action_dist"):
action_dist_name = config["custom_action_dist"]
logger.debug(
"Using custom action distribution {}".format(action_dist_name))
dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
action_dist_name)
# Dist_type is given directly as a class.
elif type(dist_type) is type and \
issubclass(dist_type, ActionDistribution) and \
dist_type not in (
MultiActionDistribution, TorchMultiActionDistribution):
dist = dist_type
dist_cls = dist_type
# Box space -> DiagGaussian OR Deterministic.
elif isinstance(action_space, gym.spaces.Box):
if len(action_space.shape) > 1:

@@ -159,14 +160,15 @@ class ModelCatalog:
"using a Tuple action space, or the multi-agent API.")
# TODO(sven): Check for bounds and return SquashedNormal, etc..
if dist_type is None:
dist = TorchDiagGaussian if framework == "torch" \
dist_cls = TorchDiagGaussian if framework == "torch" \
else DiagGaussian
elif dist_type == "deterministic":
dist = TorchDeterministic if framework == "torch" \
dist_cls = TorchDeterministic if framework == "torch" \
else Deterministic
# Discrete Space -> Categorical.
elif isinstance(action_space, gym.spaces.Discrete):
dist = TorchCategorical if framework == "torch" else Categorical
dist_cls = (TorchCategorical
if framework == "torch" else Categorical)
# Tuple/Dict Spaces -> MultiAction.
elif dist_type in (MultiActionDistribution,
TorchMultiActionDistribution) or \

@@ -189,19 +191,20 @@ class ModelCatalog:
# TODO(sven): implement
raise NotImplementedError(
"Simplex action spaces not supported for torch.")
dist = Dirichlet
dist_cls = Dirichlet
# MultiDiscrete -> MultiCategorical.
elif isinstance(action_space, gym.spaces.MultiDiscrete):
dist = TorchMultiCategorical if framework == "torch" else \
dist_cls = TorchMultiCategorical if framework == "torch" else \
MultiCategorical
return partial(dist, input_lens=action_space.nvec), \
return partial(dist_cls, input_lens=action_space.nvec), \
int(sum(action_space.nvec))
# Unknown type -> Error.
else:
raise NotImplementedError("Unsupported args: {} {}".format(
action_space, dist_type))

return dist, dist.required_model_output_shape(action_space, config)
return dist_cls, dist_cls.required_model_output_shape(
action_space, config)

@staticmethod
@DeveloperAPI
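The dist to dist_cls rename underlines that ModelCatalog.get_action_dist() returns a distribution class plus the number of model outputs it needs, not an instance. A minimal sketch (the call follows this docstring; treat it as illustrative rather than a verified example):

import gym
from ray.rllib.models import ModelCatalog

dist_cls, num_outputs = ModelCatalog.get_action_dist(
    gym.spaces.Discrete(4), config=None, framework="torch")
print(dist_cls.__name__, num_outputs)  # e.g. TorchCategorical 4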
@@ -59,8 +59,8 @@ class Preprocessor:
try:
if not self._obs_space.contains(observation):
raise ValueError(
"Observation outside expected value range",
self._obs_space, observation)
"Observation ({}) outside given space ({})!",
observation, self._obs_space)
except AttributeError:
raise ValueError(
"Observation for a Box/MultiBinary/MultiDiscrete space "
@@ -425,8 +425,8 @@ class MultiActionDistribution(TFActionDistribution):

self.action_space_struct = get_base_struct_from_space(action_space)

input_lens = np.array(input_lens, dtype=np.int32)
split_inputs = tf.split(inputs, input_lens, axis=1)
self.input_lens = np.array(input_lens, dtype=np.int32)
split_inputs = tf.split(inputs, self.input_lens, axis=1)
self.flat_child_distributions = tree.map_structure(
lambda dist, input_: dist(input_, model), child_distributions,
split_inputs)

@@ -492,6 +492,10 @@ class MultiActionDistribution(TFActionDistribution):
p += c.sampled_action_logp()
return p

@override(ActionDistribution)
def required_model_output_shape(self, action_space, model_config):
return np.sum(self.input_lens)


class Dirichlet(TFActionDistribution):
"""Dirichlet distribution for continuous actions that are between
@@ -350,9 +350,9 @@ class TorchMultiActionDistribution(TorchDistributionWrapper):

self.action_space_struct = get_base_struct_from_space(action_space)

input_lens = tree.flatten(input_lens)
self.input_lens = tree.flatten(input_lens)
flat_child_distributions = tree.flatten(child_distributions)
split_inputs = torch.split(inputs, input_lens, dim=1)
split_inputs = torch.split(inputs, self.input_lens, dim=1)
self.flat_child_distributions = tree.map_structure(
lambda dist, input_: dist(input_, model), flat_child_distributions,
list(split_inputs))

@@ -419,3 +419,7 @@ class TorchMultiActionDistribution(TorchDistributionWrapper):
for c in self.flat_child_distributions[1:]:
p += c.sampled_action_logp()
return p

@override(ActionDistribution)
def required_model_output_shape(self, action_space, model_config):
return np.sum(self.input_lens)
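Keeping input_lens on self is what makes the new required_model_output_shape() possible: a multi-action distribution consumes one flat model output that is later split per child distribution, so the model must emit sum(input_lens) values. An illustrative NumPy sketch (the numbers are made up, not taken from RLlib):

import numpy as np

# Say a Tuple action space holds a 4-way Discrete (4 logits) and a 2-dim
# diagonal Gaussian (2 means + 2 log-stds).
input_lens = [4, 4]
flat_out = np.zeros((1, int(np.sum(input_lens))))             # shape (1, 8)
splits = np.split(flat_out, np.cumsum(input_lens)[:-1], axis=1)
print([s.shape for s in splits])                              # [(1, 4), (1, 4)]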
@@ -48,9 +48,10 @@ class TorchModelV2(ModelV2):

@override(ModelV2)
def variables(self, as_dict: bool = False) -> List[TensorType]:
p = list(self.parameters())
if as_dict:
return self.state_dict()
return list(self.parameters())
return {k: p[i] for i, k in enumerate(self.state_dict().keys())}
return p

@override(ModelV2)
def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
@@ -393,8 +393,8 @@ def build_eager_tf_policy(name,
if extra_action_fetches_fn:
extra_fetches.update(extra_action_fetches_fn(self))

# Increase our global sampling timestep counter by 1.
self.global_timestep += 1
# Update our global timestep by the batch size.
self.global_timestep += len(obs_batch)

return actions, state_out, extra_fetches

@@ -667,6 +667,8 @@ def build_eager_tf_policy(name,
dummy_batch.get(SampleBatch.PREV_ACTIONS),
dummy_batch.get(SampleBatch.PREV_REWARDS),
explore=False)
# Got to reset global_timestep again after this fake run-through.
self.global_timestep = 0
dummy_batch.update(fetches)

postprocessed_batch = self.postprocess_trajectory(
@@ -330,6 +330,9 @@ class TFPolicy(Policy):
# Execute session run to get action (and other fetches).
fetched = builder.get(to_fetch)

# Update our global timestep by the batch size.
self.global_timestep += fetched[0].shape[0]

return fetched

@override(Policy)
@@ -274,6 +274,9 @@ class TorchPolicy(Policy):
if dist_inputs is not None:
extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

# Update our global timestep by the batch size.
self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

return actions, state_out, extra_fetches, logp

@override(Policy)
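Across the eager-TF, TF, and torch policies the same fix applies: global_timestep now advances by the number of observations in the batch rather than by one per compute_actions() call. A toy illustration of the difference (plain Python, not RLlib code):

batch_size = 32
old_count, new_count = 0, 0
for _ in range(10):          # ten batched compute_actions() calls
    old_count += 1           # old behavior
    new_count += batch_size  # new behavior
print(old_count, new_count)  # 10 vs. 320; the latter matches sampled env steps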
@@ -174,9 +174,11 @@ class TorchSpyModel(TorchModelV2, nn.Module):
action_space, num_outputs, model_config, name)

def forward(self, input_dict, state, seq_lens):
pos = input_dict["obs"]["sensors"]["position"].numpy()
front_cam = input_dict["obs"]["sensors"]["front_cam"][0].numpy()
task = input_dict["obs"]["inner_state"]["job_status"]["task"].numpy()
pos = input_dict["obs"]["sensors"]["position"].detach().cpu().numpy()
front_cam = input_dict["obs"]["sensors"]["front_cam"][
0].detach().cpu().numpy()
task = input_dict["obs"]["inner_state"]["job_status"][
"task"].detach().cpu().numpy()
ray.experimental.internal_kv._internal_kv_put(
"torch_spy_in_{}".format(TorchSpyModel.capture_index),
pickle.dumps((pos, front_cam, task)),

@@ -226,7 +228,7 @@ def to_list(value):
elif isinstance(value, int):
return value
else:
return value.numpy().tolist()
return value.detach().cpu().numpy().tolist()


class DictSpyModel(TFModelV2):
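The switch from .numpy() to .detach().cpu().numpy() (here and in check() further below) is the device-safe conversion pattern: .numpy() alone fails for tensors that require grad or live on a GPU. A small standalone illustration:

import torch

t = torch.randn(3, requires_grad=True)
# t.numpy() would raise: "Can't call numpy() on Tensor that requires grad".
arr = t.detach().cpu().numpy()   # works for CPU and CUDA tensors alike
print(arr.shape)                 # (3,)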
@@ -4,7 +4,7 @@ halfcheetah-mbmpo:
stop:
training_iteration: 500
config:
# Only supported in torch right now
# Only supported in torch right now.
framework: torch
# 200 in paper, 1000 will take forever
horizon: 200
@@ -4,7 +4,7 @@ hopper-mbmpo:
stop:
training_iteration: 500
config:
# Only supported in torch right now
# Only supported in torch right now.
framework: torch
# 200 in paper, 1000 will take forever
horizon: 200
rllib/tuned_examples/mbmpo/pendulum-mbmpo.yaml (new file)

@@ -0,0 +1,27 @@
pendulum-mbmpo:
    env: ray.rllib.examples.env.mbmpo_env.PendulumWrapper
    run: MBMPO
    stop:
        episode_reward_mean: -500
        training_iteration: 50
    config:
        # Only supported in torch right now.
        framework: torch
        #horizon: 200
        num_envs_per_worker: 20
        inner_adaptation_steps: 1
        maml_optimizer_steps: 8
        gamma: 0.99
        lambda: 1.0
        lr: 0.001
        clip_param: 0.5
        kl_target: 0.003
        kl_coeff: 0.0000000001
        num_workers: 10
        num_gpus: 0
        inner_lr: 0.001
        clip_actions: False
        num_maml_steps: 15
        model:
            fcnet_hiddens: [32, 32]
            free_log_std: True
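One way to launch this new tuned example from Python (a sketch; the usual route is the rllib CLI with this yaml file, and the relative path below assumes the repository root as working directory):

import yaml
from ray import tune

with open("rllib/tuned_examples/mbmpo/pendulum-mbmpo.yaml") as f:
    exp = list(yaml.safe_load(f).values())[0]

tune.run(
    exp["run"],                                  # "MBMPO"
    stop=exp["stop"],
    config=dict(exp["config"], env=exp["env"]))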
@@ -56,9 +56,12 @@ class GaussianNoise(Exploration):
super().__init__(
action_space, model=model, framework=framework, **kwargs)

# Create the Random exploration module (used for the first n
# timesteps).
self.random_timesteps = random_timesteps
self.random_exploration = Random(
action_space, model=self.model, framework=self.framework, **kwargs)

self.stddev = stddev
# The `scale` annealing schedule.
self.scale_schedule = scale_schedule or PiecewiseSchedule(

@@ -104,7 +107,7 @@ class GaussianNoise(Exploration):
self.random_exploration.get_tf_exploration_action_op(
action_dist, explore)
stochastic_actions = tf.cond(
pred=tf.convert_to_tensor(ts <= self.random_timesteps),
pred=tf.convert_to_tensor(ts < self.random_timesteps),
true_fn=lambda: random_actions,
false_fn=lambda: tf.clip_by_value(
deterministic_actions + gaussian_sample,

@@ -144,7 +147,7 @@ class GaussianNoise(Exploration):
# Apply exploration.
if explore:
# Random exploration phase.
if self.last_timestep <= self.random_timesteps:
if self.last_timestep < self.random_timesteps:
action, _ = \
self.random_exploration.get_torch_exploration_action(
action_dist, explore=True)
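The <= to < change in GaussianNoise (and, below, in OrnsteinUhlenbeckNoise and StochasticSampling) only moves the phase boundary by a single step; a tiny illustration in plain Python:

random_timesteps = 3
for ts in range(5):
    old_phase = "random" if ts <= random_timesteps else "noisy"
    new_phase = "random" if ts < random_timesteps else "noisy"
    print(ts, old_phase, new_phase)
# Only ts == 3 changes: random under the old check, noisy/sampled under the new one.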
@@ -114,7 +114,7 @@ class OrnsteinUhlenbeckNoise(GaussianNoise):
self.random_exploration.get_tf_exploration_action_op(
action_dist, explore)
exploration_actions = tf.cond(
pred=tf.convert_to_tensor(ts <= self.random_timesteps),
pred=tf.convert_to_tensor(ts < self.random_timesteps),
true_fn=lambda: random_actions,
false_fn=lambda: stochastic_actions)

@@ -133,7 +133,7 @@ class OrnsteinUhlenbeckNoise(GaussianNoise):
if timestep is None:
self.last_timestep.assign_add(1)
else:
self.last_timestep = timestep
self.last_timestep.assign(timestep)
return action, logp
else:
assign_op = (tf1.assign_add(self.last_timestep, 1)

@@ -151,7 +151,7 @@ class OrnsteinUhlenbeckNoise(GaussianNoise):
# Apply exploration.
if explore:
# Random exploration phase.
if self.last_timestep <= self.random_timesteps:
if self.last_timestep < self.random_timesteps:
action, _ = \
self.random_exploration.get_torch_exploration_action(
action_dist, explore=True)
@@ -65,7 +65,7 @@ class ParameterNoise(Exploration):
# This excludes any variable, whose name contains "LayerNorm" (those
# are BatchNormalization layers, which should not be perturbed).
self.model_variables = [
v for k, v in self.model.variables(as_dict=True).items()
v for k, v in self.model.trainable_variables(as_dict=True).items()
if "LayerNorm" not in k
]
# Our noise to be added to the weights. Each item in `self.noise`

@@ -296,7 +296,8 @@ class ParameterNoise(Exploration):
else:
for i in range(len(self.noise)):
self.noise[i] = torch.normal(
mean=torch.zeros(self.noise[i].size()), std=self.stddev)
mean=torch.zeros(self.noise[i].size()),
std=self.stddev).to(self.device)

def _tf_sample_new_noise_op(self):
added_noises = []

@@ -343,9 +344,11 @@ class ParameterNoise(Exploration):
elif self.framework in ["tf2", "tfe"]:
self._tf_add_stored_noise_op()
else:
for i in range(len(self.noise)):
for var, noise in zip(self.model_variables, self.noise):
# Add noise to weights in-place.
self.model_variables[i].add_(self.noise[i])
var.requires_grad = False
var.add_(noise)
var.requires_grad = True

self.weights_are_currently_noisy = True

@@ -383,7 +386,9 @@ class ParameterNoise(Exploration):
else:
for var, noise in zip(self.model_variables, self.noise):
# Remove noise from weights in-place.
var.requires_grad = False
var.add_(-noise)
var.requires_grad = True

self.weights_are_currently_noisy = False
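The requires_grad toggling around the in-place add is needed because PyTorch refuses in-place updates on leaf tensors that require grad (outside of a no_grad context). A standalone illustration:

import torch

w = torch.nn.Parameter(torch.zeros(3))
noise = torch.randn(3)
try:
    w.add_(noise)            # RuntimeError: leaf Variable that requires grad ...
except RuntimeError as e:
    print("in-place add refused:", e)

w.requires_grad = False
w.add_(noise)                # fine now (perturb); later add_(-noise) restores it
w.requires_grad = True

An equivalent approach would be to wrap the in-place update in a torch.no_grad() block.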
@@ -72,10 +72,14 @@ class Random(Exploration):
maxval=component.n,
dtype=component.dtype)
elif isinstance(component, MultiDiscrete):
return tf.random.uniform(
shape=(batch_size, ) + component.shape,
maxval=component.nvec,
dtype=component.dtype)
return tf.concat(
[
tf.random.uniform(
shape=(batch_size, 1),
maxval=n,
dtype=component.dtype) for n in component.nvec
],
axis=1)
elif isinstance(component, Box):
if component.bounded_above.all() and \
component.bounded_below.all():
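For MultiDiscrete spaces each sub-space now gets its own uniform draw, which is then concatenated along the last axis (tf.random.uniform's integer sampling expects a scalar maxval, so a per-column bound like nvec needs one draw per column). An equivalent NumPy sketch of the sampling pattern:

import numpy as np

nvec = [3, 5, 2]          # e.g. a MultiDiscrete([3, 5, 2]) action space
batch_size = 4
sample = np.concatenate(
    [np.random.randint(0, n, size=(batch_size, 1)) for n in nvec], axis=1)
print(sample.shape)       # (4, 3); column i lies in [0, nvec[i])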
@@ -5,8 +5,9 @@ from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType
from ray.rllib.utils.exploration.random import Random
from ray.rllib.utils.framework import get_variable, try_import_tf, \
try_import_torch, TensorType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

@@ -21,18 +22,37 @@ class StochasticSampling(Exploration):
lowering stddev, temperature, etc.. over time.
"""

def __init__(self, action_space, *, framework: str, model: ModelV2,
def __init__(self,
action_space,
*,
framework: str,
model: ModelV2,
random_timesteps: int = 0,
**kwargs):
"""Initializes a StochasticSampling Exploration object.

Args:
action_space (Space): The gym action space used by the environment.
framework (str): One of None, "tf", "torch".
model (ModelV2): The ModelV2 used by the owning Policy.
random_timesteps (int): The number of timesteps for which to act
completely randomly. Only after this number of timesteps,
actual samples will be drawn to get exploration actions.
"""
assert framework is not None
super().__init__(
action_space, model=model, framework=framework, **kwargs)

# Create the Random exploration module (used for the first n
# timesteps).
self.random_timesteps = random_timesteps
self.random_exploration = Random(
action_space, model=self.model, framework=self.framework, **kwargs)

# The current timestep value (tf-var or python int).
self.last_timestep = get_variable(
0, framework=self.framework, tf_name="timestep")

@override(Exploration)
def get_exploration_action(self,
*,

@@ -41,36 +61,73 @@ class StochasticSampling(Exploration):
explore: bool = True):
if self.framework == "torch":
return self._get_torch_exploration_action(action_distribution,
explore)
timestep, explore)
else:
return self._get_tf_exploration_action_op(action_distribution,
explore)
timestep, explore)

def _get_tf_exploration_action_op(self, action_dist, timestep, explore):
ts = timestep if timestep is not None else self.last_timestep + 1

stochastic_actions = tf.cond(
pred=tf.convert_to_tensor(ts < self.random_timesteps),
true_fn=lambda: (
self.random_exploration.get_tf_exploration_action_op(
action_dist,
explore=True)[0]),
false_fn=lambda: action_dist.sample(),
)
deterministic_actions = action_dist.deterministic_sample()

def _get_tf_exploration_action_op(self, action_dist, explore):
sample = action_dist.sample()
deterministic_sample = action_dist.deterministic_sample()
action = tf.cond(
tf.constant(explore) if isinstance(explore, bool) else explore,
true_fn=lambda: sample,
false_fn=lambda: deterministic_sample)
true_fn=lambda: stochastic_actions,
false_fn=lambda: deterministic_actions)

def logp_false_fn():
batch_size = tf.shape(tree.flatten(action)[0])[0]
return tf.zeros(shape=(batch_size, ), dtype=tf.float32)

logp = tf.cond(
tf.constant(explore) if isinstance(explore, bool) else explore,
tf.math.logical_and(
explore, tf.convert_to_tensor(ts >= self.random_timesteps)),
true_fn=lambda: action_dist.sampled_action_logp(),
false_fn=logp_false_fn)

return action, logp
# Increment `last_timestep` by 1 (or set to `timestep`).
if self.framework in ["tf2", "tfe"]:
if timestep is None:
self.last_timestep.assign_add(1)
else:
self.last_timestep.assign(timestep)
return action, logp
else:
assign_op = (tf1.assign_add(self.last_timestep, 1)
if timestep is None else tf1.assign(
self.last_timestep, timestep))
with tf1.control_dependencies([assign_op]):
return action, logp

@staticmethod
def _get_torch_exploration_action(action_dist, explore):
def _get_torch_exploration_action(self, action_dist, timestep, explore):
# Set last timestep or (if not given) increase by one.
self.last_timestep = timestep if timestep is not None else \
self.last_timestep + 1

# Apply exploration.
if explore:
action = action_dist.sample()
logp = action_dist.sampled_action_logp()
# Random exploration phase.
if self.last_timestep < self.random_timesteps:
action, logp = \
self.random_exploration.get_torch_exploration_action(
action_dist, explore=True)
# Take a sample from our distribution.
else:
action = action_dist.sample()
logp = action_dist.sampled_action_logp()

# No exploration -> Return deterministic actions.
else:
action = action_dist.deterministic_sample()
logp = torch.zeros_like(action_dist.sampled_action_logp())

return action, logp
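With the new random_timesteps option, StochasticSampling can mimic the random warm-up phase of the noise-based explorations. How it might be switched on from a Trainer config (a sketch; the key names follow this diff and RLlib's usual exploration_config pattern):

config = {
    "framework": "torch",
    "exploration_config": {
        "type": "StochasticSampling",
        # Act fully randomly for the first 1000 sampled timesteps, then
        # start sampling from the policy's action distribution.
        "random_timesteps": 1000,
    },
}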
@@ -178,12 +178,21 @@ class TestParameterNoise(unittest.TestCase):
noise = policy.exploration.noise[0][0][0]
if fw == "tf":
noise = policy.get_session().run(noise)
elif fw == "torch":
noise = noise.detach().cpu().numpy()
else:
noise = noise.numpy()
return noise

def _get_current_weight(self, policy, fw):
weights = policy.get_weights()
if fw == "torch":
# DQN model.
if "_hidden_layers.0._model.0.weight" in weights:
return weights["_hidden_layers.0._model.0.weight"][0][0]
# DDPG model.
else:
return weights["policy_model.action_0._model.0.weight"][0][0]
key = 0 if fw in ["tf2", "tfe"] else list(weights.keys())[0]
return weights[key][0][0]
@@ -204,9 +204,9 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
false=false)
if torch is not None:
if isinstance(x, torch.Tensor):
x = x.detach().numpy()
x = x.detach().cpu().numpy()
if isinstance(y, torch.Tensor):
y = y.detach().numpy()
y = y.detach().cpu().numpy()

# Using decimals.
if atol is None and rtol is None: