Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[RLlib] Dreamer ConfigObject class. (#24650)
parent: 6d94b2acbe
commit: f243895ebb
4 changed files with 223 additions and 99 deletions
@@ -1,6 +1,11 @@
-from ray.rllib.agents.dreamer.dreamer import DREAMERTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.dreamer.dreamer import (
+    DREAMERConfig,
+    DREAMERTrainer,
+    DEFAULT_CONFIG,
+)
 
 __all__ = [
+    "DREAMERConfig",
     "DREAMERTrainer",
     "DEFAULT_CONFIG",
 ]
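As a quick orientation, here is a minimal usage sketch (not part of the patch) exercising the names re-exported above; it assumes a Ray/RLlib installation that already contains this commit.

# Sketch: both the new config object and the legacy dict remain importable
# from the dreamer package after this change.
from ray.rllib.agents.dreamer import DEFAULT_CONFIG, DREAMERConfig, DREAMERTrainer

config = DREAMERConfig()          # new-style config object
legacy_dict = DEFAULT_CONFIG      # old-style dict, kept for backward compatibility
print(type(config).__name__, isinstance(legacy_dict, dict))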
@@ -1,9 +1,9 @@
 import logging
 
-import random
 import numpy as np
+import random
+from typing import Optional
 
-from ray.rllib.agents import with_common_config
+from ray.rllib.agents.trainer_config import TrainerConfig
 from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
@@ -15,6 +15,7 @@ from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,
 )
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.typing import (
     PartialTrainerConfigDict,
@@ -25,68 +26,167 @@ from ray.rllib.utils.typing import (
 
 logger = logging.getLogger(__name__)
 
-# fmt: off
-# __sphinx_doc_begin__
-DEFAULT_CONFIG = with_common_config({
-    # PlaNET Model LR
-    "td_model_lr": 6e-4,
-    # Actor LR
-    "actor_lr": 8e-5,
-    # Critic LR
-    "critic_lr": 8e-5,
-    # Grad Clipping
-    "grad_clip": 100.0,
-    # Discount
-    "discount": 0.99,
-    # Lambda
-    "lambda": 0.95,
-    # Clipping is done inherently via policy tanh.
-    "clip_actions": False,
-    # Training iterations per data collection from real env
-    "dreamer_train_iters": 100,
-    # Horizon for Enviornment (1000 for Mujoco/DMC)
-    "horizon": 1000,
-    # Number of episodes to sample for Loss Calculation
-    "batch_size": 50,
-    # Length of each episode to sample for Loss Calculation
-    "batch_length": 50,
-    # Imagination Horizon for Training Actor and Critic
-    "imagine_horizon": 15,
-    # Free Nats
-    "free_nats": 3.0,
-    # KL Coeff for the Model Loss
-    "kl_coeff": 1.0,
-    # Distributed Dreamer not implemented yet
-    "num_workers": 0,
-    # Prefill Timesteps
-    "prefill_timesteps": 5000,
-    # This should be kept at 1 to preserve sample efficiency
-    "num_envs_per_worker": 1,
-    # Exploration Gaussian
-    "explore_noise": 0.3,
-    # Batch mode
-    "batch_mode": "complete_episodes",
-    # Custom Model
-    "dreamer_model": {
-        "custom_model": DreamerModel,
-        # RSSM/PlaNET parameters
-        "deter_size": 200,
-        "stoch_size": 30,
-        # CNN Decoder Encoder
-        "depth_size": 32,
-        # General Network Parameters
-        "hidden_size": 400,
-        # Action STD
-        "action_init_std": 5.0,
-    },
-
-    "env_config": {
-        # Repeats action send by policy for frame_skip times in env
-        "frame_skip": 2,
-    },
-})
-# __sphinx_doc_end__
-# fmt: on
+class DREAMERConfig(TrainerConfig):
+    """Defines a DREAMERTrainer configuration class from which a DREAMERTrainer can be built.
+
+    Example:
+        >>> from ray.rllib.agents.dreamer import DREAMERConfig
+        >>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
+        ...     .resources(num_gpus=0)\
+        ...     .rollouts(num_rollout_workers=4)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray import tune
+        >>> from ray.rllib.agents.dreamer import DREAMERConfig
+        >>> config = DREAMERConfig()
+        >>> # Print out some default values.
+        >>> print(config.td_model_lr)
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
+        >>> # Set the config object's env.
+        >>> config.environment(env="CartPole-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "DREAMER",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self):
+        """Initializes a DREAMERConfig instance."""
+        super().__init__(trainer_class=DREAMERTrainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+        # Dreamer specific settings:
+        self.td_model_lr = 6e-4
+        self.actor_lr = 8e-5
+        self.critic_lr = 8e-5
+        self.grad_clip = 100.0
+        self.lambda_ = 0.95
+        self.dreamer_train_iters = 100
+        self.batch_size = 50
+        self.batch_length = 50
+        self.imagine_horizon = 15
+        self.free_nats = 3.0
+        self.kl_coeff = 1.0
+        self.prefill_timesteps = 5000
+        self.explore_noise = 0.3
+        self.dreamer_model = {
+            "custom_model": DreamerModel,
+            # RSSM/PlaNET parameters
+            "deter_size": 200,
+            "stoch_size": 30,
+            # CNN Decoder Encoder
+            "depth_size": 32,
+            # General Network Parameters
+            "hidden_size": 400,
+            # Action STD
+            "action_init_std": 5.0,
+        }
+
+        # Override some of TrainerConfig's default values with Dreamer-specific values.
+        # .rollouts()
+        self.num_workers = 0
+        self.num_envs_per_worker = 1
+        self.horizon = 1000
+        self.batch_mode = "complete_episodes"
+        self.clip_actions = False
+
+        # .training()
+        self.gamma = 0.99
+
+        # .environment()
+        self.env_config = {
+            # Repeats action sent by policy for frame_skip times in env
+            "frame_skip": 2,
+        }
+        # __sphinx_doc_end__
+        # fmt: on
+
+    @override(TrainerConfig)
+    def training(
+        self,
+        *,
+        td_model_lr: Optional[float] = None,
+        actor_lr: Optional[float] = None,
+        critic_lr: Optional[float] = None,
+        grad_clip: Optional[float] = None,
+        lambda_: Optional[float] = None,
+        dreamer_train_iters: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        batch_length: Optional[int] = None,
+        imagine_horizon: Optional[int] = None,
+        free_nats: Optional[float] = None,
+        kl_coeff: Optional[float] = None,
+        prefill_timesteps: Optional[int] = None,
+        explore_noise: Optional[float] = None,
+        dreamer_model: Optional[dict] = None,
+        **kwargs,
+    ) -> "DREAMERConfig":
+        """Sets the training related configuration.
+
+        Args:
+            td_model_lr: PlaNET (transition dynamics) model learning rate.
+            actor_lr: Actor model learning rate.
+            critic_lr: Critic model learning rate.
+            grad_clip: If specified, clip the global norm of gradients by this amount.
+            lambda_: The GAE (lambda) parameter.
+            dreamer_train_iters: Training iterations per data collection from real env.
+            batch_size: Number of episodes to sample for loss calculation.
+            batch_length: Length of each episode to sample for loss calculation.
+            imagine_horizon: Imagination horizon for training Actor and Critic.
+            free_nats: Free nats.
+            kl_coeff: KL coefficient for the model loss.
+            prefill_timesteps: Prefill timesteps.
+            explore_noise: Exploration Gaussian noise.
+            dreamer_model: Custom model config.
+
+        Returns:
+            This updated DREAMERConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        if td_model_lr is not None:
+            self.td_model_lr = td_model_lr
+        if actor_lr is not None:
+            self.actor_lr = actor_lr
+        if critic_lr is not None:
+            self.critic_lr = critic_lr
+        if grad_clip is not None:
+            self.grad_clip = grad_clip
+        if lambda_ is not None:
+            self.lambda_ = lambda_
+        if dreamer_train_iters is not None:
+            self.dreamer_train_iters = dreamer_train_iters
+        if batch_size is not None:
+            self.batch_size = batch_size
+        if batch_length is not None:
+            self.batch_length = batch_length
+        if imagine_horizon is not None:
+            self.imagine_horizon = imagine_horizon
+        if free_nats is not None:
+            self.free_nats = free_nats
+        if kl_coeff is not None:
+            self.kl_coeff = kl_coeff
+        if prefill_timesteps is not None:
+            self.prefill_timesteps = prefill_timesteps
+        if explore_noise is not None:
+            self.explore_noise = explore_noise
+        if dreamer_model is not None:
+            self.dreamer_model = dreamer_model
+
+        return self
 
 
 def _postprocess_gif(gif: np.ndarray):
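To make the new workflow concrete, here is a short sketch (not part of the patch) of building a trainer through the fluent DREAMERConfig API added above; the environment name "pixel_env" is a placeholder for any image-observation env registered with RLlib, and the small training values are chosen only to keep the example cheap.

# Sketch: configure and build a DREAMERTrainer via the config object.
from ray.rllib.agents.dreamer import DREAMERConfig

config = (
    DREAMERConfig()
    # Dreamer-specific training settings go through the overridden .training().
    .training(td_model_lr=6e-4, batch_size=2, batch_length=20, dreamer_train_iters=4)
    # Generic settings use the TrainerConfig methods inherited from the base class.
    .rollouts(num_rollout_workers=0)
    .environment(env="pixel_env", env_config={"frame_skip": 2})
)
trainer = config.build()
print(trainer.train())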
@@ -197,7 +297,7 @@ class DREAMERTrainer(Trainer):
     @classmethod
     @override(Trainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return DEFAULT_CONFIG
+        return DREAMERConfig().to_dict()
 
     @override(Trainer)
     def validate_config(self, config: TrainerConfigDict) -> None:
@@ -316,3 +416,20 @@ class DREAMERTrainer(Trainer):
         results = super()._compile_step_results(*args, **kwargs)
         results["timesteps_total"] = self._counters[STEPS_SAMPLED_COUNTER]
         return results
+
+
+# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(DREAMERConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
+        new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
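For orientation, a small sketch (not from the patch) of what the deprecation shim above means for existing user code; my reading of @Deprecated(error=False) is that item access logs a deprecation warning rather than raising.

# Sketch: legacy dict-style access still works, but reading from DEFAULT_CONFIG
# goes through the decorated __getitem__ and is expected to log a deprecation
# warning pointing at DREAMERConfig.
from ray.rllib.agents.dreamer import DEFAULT_CONFIG

print(DEFAULT_CONFIG["batch_size"])   # expected to warn, then return the default (50)
config = DEFAULT_CONFIG.copy()        # plain dict copy; can be mutated as before
config["dreamer_train_iters"] = 4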
@@ -7,12 +7,13 @@ import ray
 from ray.rllib.agents.dreamer.utils import FreezeParameters
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_utils import apply_grad_clipping
-from ray.rllib.utils.typing import AgentID
+from ray.rllib.utils.typing import AgentID, TensorType
 
 torch, nn = try_import_torch()
 if torch:
@@ -23,30 +24,30 @@ logger = logging.getLogger(__name__)
 
 # This is the computation graph for workers (inner adaptation steps)
 def compute_dreamer_loss(
-    obs,
-    action,
-    reward,
-    model,
-    imagine_horizon,
-    discount=0.99,
-    lambda_=0.95,
-    kl_coeff=1.0,
-    free_nats=3.0,
-    log=False,
+    obs: TensorType,
+    action: TensorType,
+    reward: TensorType,
+    model: TorchModelV2,
+    imagine_horizon: int,
+    gamma: float = 0.99,
+    lambda_: float = 0.95,
+    kl_coeff: float = 1.0,
+    free_nats: float = 3.0,
+    log: bool = False,
 ):
-    """Constructs loss for the Dreamer objective
+    """Constructs loss for the Dreamer objective.
 
     Args:
-        obs (TensorType): Observations (o_t)
-        action (TensorType): Actions (a_(t-1))
-        reward (TensorType): Rewards (r_(t-1))
-        model (TorchModelV2): DreamerModel, encompassing all other models
-        imagine_horizon (int): Imagine horizon for actor and critic loss
-        discount (float): Discount
-        lambda_ (float): Lambda, like in GAE
-        kl_coeff (float): KL Coefficient for Divergence loss in model loss
-        free_nats (float): Threshold for minimum divergence in model loss
-        log (bool): If log, generate gifs
+        obs: Observations (o_t).
+        action: Actions (a_(t-1)).
+        reward: Rewards (r_(t-1)).
+        model: DreamerModel, encompassing all other models.
+        imagine_horizon: Imagine horizon for actor and critic loss.
+        gamma: Discount factor gamma.
+        lambda_: Lambda, like in GAE.
+        kl_coeff: KL Coefficient for Divergence loss in model loss.
+        free_nats: Threshold for minimum divergence in model loss.
+        log: If log, generate gifs.
     """
     encoder_weights = list(model.encoder.parameters())
     decoder_weights = list(model.decoder.parameters())
@@ -84,7 +85,7 @@ def compute_dreamer_loss(
     with FreezeParameters(model_weights + critic_weights):
         reward = model.reward(imag_feat).mean
         value = model.value(imag_feat).mean
-        pcont = discount * torch.ones_like(reward)
+        pcont = gamma * torch.ones_like(reward)
         returns = lambda_return(reward[:-1], value[:-1], pcont[:-1], value[-1], lambda_)
         discount_shape = pcont[:1].size()
         discount = torch.cumprod(
@@ -168,7 +169,7 @@ def dreamer_loss(policy, model, dist_class, train_batch):
         train_batch["rewards"],
         policy.model,
         policy.config["imagine_horizon"],
-        policy.config["discount"],
+        policy.config["gamma"],
         policy.config["lambda"],
         policy.config["kl_coeff"],
         policy.config["free_nats"],
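A brief sketch (not from the patch) of the practical effect of the discount-to-gamma rename shown above: the discount factor now lives under the standard TrainerConfig key that the torch policy reads.

# Sketch: set the discount factor via the standard "gamma" setting.
from ray.rllib.agents.dreamer import DREAMERConfig

config = DREAMERConfig().training(gamma=0.997)
print(config.gamma)                   # 0.997
print(config.to_dict()["gamma"])      # the key the policy looks up via policy.config["gamma"]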
@@ -18,23 +18,24 @@ class TestDreamer(unittest.TestCase):
 
     def test_dreamer_compilation(self):
         """Test whether an DreamerTrainer can be built with all frameworks."""
-        config = dreamer.DEFAULT_CONFIG.copy()
-        config["env_config"] = {
-            "observation_space": Box(-1.0, 1.0, (3, 64, 64)),
-            "action_space": Box(-1.0, 1.0, (3,)),
-        }
+        config = dreamer.DREAMERConfig()
+        config.environment(
+            env=RandomEnv,
+            env_config={
+                "observation_space": Box(-1.0, 1.0, (3, 64, 64)),
+                "action_space": Box(-1.0, 1.0, (3,)),
+            },
+        )
         # Num episode chunks per batch.
-        config["batch_size"] = 2
         # Length (ts) of an episode chunk in a batch.
-        config["batch_length"] = 20
         # Sub-iterations per .train() call.
-        config["dreamer_train_iters"] = 4
+        config.training(batch_size=2, batch_length=20, dreamer_train_iters=4)
 
         num_iterations = 1
 
         # Test against all frameworks.
         for _ in framework_iterator(config, frameworks="torch"):
-            trainer = dreamer.DREAMERTrainer(config=config, env=RandomEnv)
+            trainer = config.build()
             for i in range(num_iterations):
                 results = trainer.train()
                 print(results)