[RLlib] Dreamer ConfigObject class. (#24650)

Sven Mika 2022-05-10 16:19:42 +02:00 committed by GitHub
parent 6d94b2acbe
commit f243895ebb
4 changed files with 223 additions and 99 deletions

rllib/agents/dreamer/__init__.py

@@ -1,6 +1,11 @@
from ray.rllib.agents.dreamer.dreamer import DREAMERTrainer, DEFAULT_CONFIG
from ray.rllib.agents.dreamer.dreamer import (
DREAMERConfig,
DREAMERTrainer,
DEFAULT_CONFIG,
)
__all__ = [
"DREAMERConfig",
"DREAMERTrainer",
"DEFAULT_CONFIG",
]

rllib/agents/dreamer/dreamer.py

@@ -1,9 +1,9 @@
import logging
import random
import numpy as np
import random
from typing import Optional
from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
@@ -15,6 +15,7 @@ from ray.rllib.execution.rollout_ops (
synchronous_parallel_sample,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
@@ -25,68 +26,167 @@ from ray.rllib.utils.typing (
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# PlaNET Model LR
"td_model_lr": 6e-4,
# Actor LR
"actor_lr": 8e-5,
# Critic LR
"critic_lr": 8e-5,
# Grad Clipping
"grad_clip": 100.0,
# Discount
"discount": 0.99,
# Lambda
"lambda": 0.95,
# Clipping is done inherently via policy tanh.
"clip_actions": False,
# Training iterations per data collection from real env
"dreamer_train_iters": 100,
# Horizon for the environment (1000 for Mujoco/DMC)
"horizon": 1000,
# Number of episodes to sample for Loss Calculation
"batch_size": 50,
# Length of each episode to sample for Loss Calculation
"batch_length": 50,
# Imagination Horizon for Training Actor and Critic
"imagine_horizon": 15,
# Free Nats
"free_nats": 3.0,
# KL Coeff for the Model Loss
"kl_coeff": 1.0,
# Distributed Dreamer not implemented yet
"num_workers": 0,
# Prefill Timesteps
"prefill_timesteps": 5000,
# This should be kept at 1 to preserve sample efficiency
"num_envs_per_worker": 1,
# Exploration Gaussian
"explore_noise": 0.3,
# Batch mode
"batch_mode": "complete_episodes",
# Custom Model
"dreamer_model": {
"custom_model": DreamerModel,
# RSSM/PlaNET parameters
"deter_size": 200,
"stoch_size": 30,
# CNN Decoder Encoder
"depth_size": 32,
# General Network Parameters
"hidden_size": 400,
# Action STD
"action_init_std": 5.0,
},
"env_config": {
# Repeats the action sent by the policy frame_skip times in the env
"frame_skip": 2,
},
})
# __sphinx_doc_end__
# fmt: on
class DREAMERConfig(TrainerConfig):
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
Example:
>>> from ray.rllib.agents.dreamer import DREAMERConfig
>>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
Example:
>>> from ray import tune
>>> from ray.rllib.agents.dreamer import DREAMERConfig
>>> config = DREAMERConfig()
>>> # Print out some default values.
>>> print(config.dreamer_train_iters)
>>> # Update the config object.
>>> config.training(td_model_lr=tune.grid_search([0.001, 0.0001]))
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "DREAMER",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
def __init__(self):
"""Initializes a PPOConfig instance."""
super().__init__(trainer_class=DREAMERTrainer)
# fmt: off
# __sphinx_doc_begin__
# Dreamer specific settings:
self.td_model_lr = 6e-4
self.actor_lr = 8e-5
self.critic_lr = 8e-5
self.grad_clip = 100.0
self.lambda_ = 0.95
self.dreamer_train_iters = 100
self.batch_size = 50
self.batch_length = 50
self.imagine_horizon = 15
self.free_nats = 3.0
self.kl_coeff = 1.0
self.prefill_timesteps = 5000
self.explore_noise = 0.3
self.dreamer_model = {
"custom_model": DreamerModel,
# RSSM/PlaNET parameters
"deter_size": 200,
"stoch_size": 30,
# CNN Decoder Encoder
"depth_size": 32,
# General Network Parameters
"hidden_size": 400,
# Action STD
"action_init_std": 5.0,
}
# Override some of TrainerConfig's default values with Dreamer-specific values.
# .rollouts()
self.num_workers = 0
self.num_envs_per_worker = 1
self.horizon = 1000
self.batch_mode = "complete_episodes"
self.clip_actions = False
# .training()
self.gamma = 0.99
# .environment()
self.env_config = {
# Repeats the action sent by the policy frame_skip times in the env
"frame_skip": 2,
}
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
def training(
self,
*,
td_model_lr: Optional[float] = None,
actor_lr: Optional[float] = None,
critic_lr: Optional[float] = None,
grad_clip: Optional[float] = None,
lambda_: Optional[float] = None,
dreamer_train_iters: Optional[int] = None,
batch_size: Optional[int] = None,
batch_length: Optional[int] = None,
imagine_horizon: Optional[int] = None,
free_nats: Optional[float] = None,
kl_coeff: Optional[float] = None,
prefill_timesteps: Optional[int] = None,
explore_noise: Optional[float] = None,
dreamer_model: Optional[dict] = None,
**kwargs,
) -> "DREAMERConfig":
"""
Args:
td_model_lr: PlaNET (transition dynamics) model learning rate.
actor_lr: Actor model learning rate.
critic_lr: Critic model learning rate.
grad_clip: If specified, clip the global norm of gradients by this amount.
lambda_: The GAE (lambda) parameter.
dreamer_train_iters: Training iterations per data collection from real env.
batch_size: Number of episodes to sample for loss calculation.
batch_length: Length of each episode to sample for loss calculation.
imagine_horizon: Imagination horizon for training Actor and Critic.
free_nats: Free nats (threshold below which the KL divergence in the model loss is not penalized).
kl_coeff: KL coefficient for the divergence term in the model loss.
prefill_timesteps: Number of env timesteps to collect before training starts, used to prefill the episode buffer.
explore_noise: Stddev of the Gaussian noise added to actions for exploration.
dreamer_model: Custom model config.
Returns:
This updated DREAMERConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if td_model_lr is not None:
self.td_model_lr = td_model_lr
if actor_lr is not None:
self.actor_lr = actor_lr
if critic_lr is not None:
self.critic_lr = critic_lr
if grad_clip is not None:
self.grad_clip = grad_clip
if lambda_ is not None:
self.lambda_ = lambda_
if dreamer_train_iters is not None:
self.dreamer_train_iters = dreamer_train_iters
if batch_size is not None:
self.batch_size = batch_size
if batch_length is not None:
self.batch_length = batch_length
if imagine_horizon is not None:
self.imagine_horizon = imagine_horizon
if free_nats is not None:
self.free_nats = free_nats
if kl_coeff is not None:
self.kl_coeff = kl_coeff
if prefill_timesteps is not None:
self.prefill_timesteps = prefill_timesteps
if explore_noise is not None:
self.explore_noise = explore_noise
if dreamer_model is not None:
self.dreamer_model = dreamer_model
return self
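For reference, a minimal usage sketch of the config API added above. It mirrors the settings used in the test further down; the gym and RandomEnv import paths are assumptions about the surrounding codebase, so adjust them to your setup.

from gym.spaces import Box

from ray.rllib.agents.dreamer import DREAMERConfig
from ray.rllib.examples.env.random_env import RandomEnv

# Build a DREAMERConfig, override a few Dreamer-specific training settings,
# and point it at an image-based environment (Dreamer expects pixel observations).
config = (
    DREAMERConfig()
    .training(batch_size=2, batch_length=20, dreamer_train_iters=4)
    .environment(
        env=RandomEnv,
        env_config={
            "observation_space": Box(-1.0, 1.0, (3, 64, 64)),
            "action_space": Box(-1.0, 1.0, (3,)),
        },
    )
    .rollouts(num_rollout_workers=0)
)

# Build the trainer from the config and run one training iteration.
trainer = config.build()
print(trainer.train())
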
def _postprocess_gif(gif: np.ndarray):
@@ -197,7 +297,7 @@ class DREAMERTrainer(Trainer):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return DREAMERConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@@ -316,3 +416,20 @@ class DREAMERTrainer(Trainer):
results = super()._compile_step_results(*args, **kwargs)
results["timesteps_total"] = self._counters[STEPS_SAMPLED_COUNTER]
return results
# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DREAMERConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
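The wrapper above keeps old-style dict access working while steering users toward the new class. A rough sketch of what that looks like from user code (not part of the diff):

from ray.rllib.agents.dreamer import DEFAULT_CONFIG, DREAMERConfig

# Old style: still returns the value, but every key lookup goes through the
# @Deprecated-decorated __getitem__ and logs a deprecation warning.
print(DEFAULT_CONFIG["dreamer_train_iters"])  # -> 100

# New style: read and update attributes on the config object directly.
config = DREAMERConfig()
print(config.dreamer_train_iters)  # -> 100
config.training(dreamer_train_iters=50)
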

rllib/agents/dreamer/dreamer_torch_policy.py

@@ -7,12 +7,13 @@ import ray
from ray.rllib.agents.dreamer.utils import FreezeParameters
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import AgentID
from ray.rllib.utils.typing import AgentID, TensorType
torch, nn = try_import_torch()
if torch:
@@ -23,30 +24,30 @@ logger = logging.getLogger(__name__)
# This is the computation graph for workers (inner adaptation steps)
def compute_dreamer_loss(
obs,
action,
reward,
model,
imagine_horizon,
discount=0.99,
lambda_=0.95,
kl_coeff=1.0,
free_nats=3.0,
log=False,
obs: TensorType,
action: TensorType,
reward: TensorType,
model: TorchModelV2,
imagine_horizon: int,
gamma: float = 0.99,
lambda_: float = 0.95,
kl_coeff: float = 1.0,
free_nats: float = 3.0,
log: bool = False,
):
"""Constructs loss for the Dreamer objective
"""Constructs loss for the Dreamer objective.
Args:
obs (TensorType): Observations (o_t)
action (TensorType): Actions (a_(t-1))
reward (TensorType): Rewards (r_(t-1))
model (TorchModelV2): DreamerModel, encompassing all other models
imagine_horizon (int): Imagine horizon for actor and critic loss
discount (float): Discount
lambda_ (float): Lambda, like in GAE
kl_coeff (float): KL Coefficient for Divergence loss in model loss
free_nats (float): Threshold for minimum divergence in model loss
log (bool): If log, generate gifs
obs: Observations (o_t).
action: Actions (a_(t-1)).
reward: Rewards (r_(t-1)).
model: DreamerModel, encompassing all other models.
imagine_horizon: Imagine horizon for actor and critic loss.
gamma: Discount factor gamma.
lambda_: Lambda, like in GAE.
kl_coeff: KL Coefficient for Divergence loss in model loss.
free_nats: Threshold for minimum divergence in model loss.
log: If log, generate gifs.
"""
encoder_weights = list(model.encoder.parameters())
decoder_weights = list(model.decoder.parameters())
@@ -84,7 +85,7 @@ def compute_dreamer_loss(
with FreezeParameters(model_weights + critic_weights):
reward = model.reward(imag_feat).mean
value = model.value(imag_feat).mean
pcont = discount * torch.ones_like(reward)
pcont = gamma * torch.ones_like(reward)
returns = lambda_return(reward[:-1], value[:-1], pcont[:-1], value[-1], lambda_)
discount_shape = pcont[:1].size()
discount = torch.cumprod(
@@ -168,7 +169,7 @@ def dreamer_loss(policy, model, dist_class, train_batch):
train_batch["rewards"],
policy.model,
policy.config["imagine_horizon"],
policy.config["discount"],
policy.config["gamma"],
policy.config["lambda"],
policy.config["kl_coeff"],
policy.config["free_nats"],

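To make the gamma/lambda_ usage in the loss above easier to follow, here is a small standalone sketch of the lambda-return recursion it relies on. This is a simplified reimplementation for illustration only, not RLlib's lambda_return helper:

import torch

def lambda_return_sketch(reward, value, pcont, bootstrap, lambda_):
    # reward, value, pcont: tensors of shape [T, B]; bootstrap: the value
    # estimate used to seed the recursion at the end of the horizon, shape [B].
    # Recursion: R_t = r_t + pcont_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}).
    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    last = bootstrap
    returns = []
    for t in reversed(range(reward.shape[0])):
        last = reward[t] + pcont[t] * ((1 - lambda_) * next_values[t] + lambda_ * last)
        returns.append(last)
    return torch.stack(list(reversed(returns)), dim=0)

# In the loss above, pcont = gamma * ones_like(reward), so pcont simply carries
# the discount factor gamma through this recursion.
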
rllib/agents/dreamer/tests/test_dreamer.py

@@ -18,23 +18,24 @@ class TestDreamer(unittest.TestCase):
def test_dreamer_compilation(self):
"""Test whether an DreamerTrainer can be built with all frameworks."""
config = dreamer.DEFAULT_CONFIG.copy()
config["env_config"] = {
"observation_space": Box(-1.0, 1.0, (3, 64, 64)),
"action_space": Box(-1.0, 1.0, (3,)),
}
config = dreamer.DREAMERConfig()
config.environment(
env=RandomEnv,
env_config={
"observation_space": Box(-1.0, 1.0, (3, 64, 64)),
"action_space": Box(-1.0, 1.0, (3,)),
},
)
# Num episode chunks per batch.
config["batch_size"] = 2
# Length (ts) of an episode chunk in a batch.
config["batch_length"] = 20
# Sub-iterations per .train() call.
config["dreamer_train_iters"] = 4
config.training(batch_size=2, batch_length=20, dreamer_train_iters=4)
num_iterations = 1
# Test against all frameworks.
for _ in framework_iterator(config, frameworks="torch"):
trainer = dreamer.DREAMERTrainer(config=config, env=RandomEnv)
trainer = config.build()
for i in range(num_iterations):
results = trainer.train()
print(results)