# ray/rllib/algorithms/dreamer/dreamer.py

import logging
import numpy as np
import random
from typing import Optional
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, concat_samples
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.algorithms.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import (
PartialAlgorithmConfigDict,
AlgorithmConfigDict,
ResultDict,
)
from ray.rllib.utils.replay_buffers import ReplayBuffer, StorageUnit
logger = logging.getLogger(__name__)
class DreamerConfig(AlgorithmConfig):
"""Defines a configuration class from which a Dreamer Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.dreamer import DreamerConfig
>>> config = DreamerConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
Example:
>>> from ray import tune
>>> from ray.rllib.algorithms.dreamer import DreamerConfig
>>> config = DreamerConfig()
>>> # Print out some default values.
        >>> print(config.kl_coeff)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), kl_coeff=0.5)
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "Dreamer",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
def __init__(self):
"""Initializes a PPOConfig instance."""
super().__init__(algo_class=Dreamer)
# fmt: off
# __sphinx_doc_begin__
# Dreamer specific settings:
self.td_model_lr = 6e-4
self.actor_lr = 8e-5
self.critic_lr = 8e-5
self.grad_clip = 100.0
self.lambda_ = 0.95
self.dreamer_train_iters = 100
self.batch_size = 50
self.batch_length = 50
self.imagine_horizon = 15
self.free_nats = 3.0
self.kl_coeff = 1.0
self.prefill_timesteps = 5000
self.explore_noise = 0.3
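        # World-model / actor / critic network options (consumed by DreamerModel).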
self.dreamer_model = {
"custom_model": DreamerModel,
# RSSM/PlaNET parameters
"deter_size": 200,
"stoch_size": 30,
# CNN Decoder Encoder
"depth_size": 32,
# General Network Parameters
"hidden_size": 400,
# Action STD
"action_init_std": 5.0,
}
        # Override some of AlgorithmConfig's default values with Dreamer-specific values.
# .rollouts()
self.num_workers = 0
self.num_envs_per_worker = 1
self.horizon = 1000
self.batch_mode = "complete_episodes"
self.clip_actions = False
# .training()
self.gamma = 0.99
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 0
# .environment()
self.env_config = {
            # Repeat the action chosen by the policy `frame_skip` times in the env.
"frame_skip": 2,
}
# __sphinx_doc_end__
# fmt: on
@override(AlgorithmConfig)
def training(
self,
*,
td_model_lr: Optional[float] = None,
actor_lr: Optional[float] = None,
critic_lr: Optional[float] = None,
grad_clip: Optional[float] = None,
lambda_: Optional[float] = None,
dreamer_train_iters: Optional[int] = None,
batch_size: Optional[int] = None,
batch_length: Optional[int] = None,
imagine_horizon: Optional[int] = None,
free_nats: Optional[float] = None,
kl_coeff: Optional[float] = None,
prefill_timesteps: Optional[int] = None,
explore_noise: Optional[float] = None,
dreamer_model: Optional[dict] = None,
num_steps_sampled_before_learning_starts: Optional[int] = None,
**kwargs,
) -> "DreamerConfig":
"""
Args:
td_model_lr: PlaNET (transition dynamics) model learning rate.
actor_lr: Actor model learning rate.
critic_lr: Critic model learning rate.
grad_clip: If specified, clip the global norm of gradients by this amount.
lambda_: The GAE (lambda) parameter.
dreamer_train_iters: Training iterations per data collection from real env.
            batch_size: Number of episode chunks (sequences) sampled per
                training batch.
            batch_length: Length of each sampled episode chunk.
            imagine_horizon: Imagination horizon for training Actor and Critic.
            free_nats: Free nats: KL divergence below this many nats is not
                penalized in the model loss.
            kl_coeff: KL coefficient for the model Loss.
            prefill_timesteps: Number of env timesteps to collect with random
                actions to prefill the episode buffer before learning starts.
explore_noise: Exploration Gaussian noise.
dreamer_model: Custom model config.
num_steps_sampled_before_learning_starts: Number of timesteps to collect
from rollout workers before we start sampling from replay buffers for
learning. Whether we count this in agent steps or environment steps
depends on config["multiagent"]["count_steps_by"].

        Returns:
            This updated AlgorithmConfig object.
        """
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if td_model_lr is not None:
self.td_model_lr = td_model_lr
if actor_lr is not None:
self.actor_lr = actor_lr
if critic_lr is not None:
self.critic_lr = critic_lr
if grad_clip is not None:
self.grad_clip = grad_clip
if lambda_ is not None:
self.lambda_ = lambda_
if dreamer_train_iters is not None:
self.dreamer_train_iters = dreamer_train_iters
if batch_size is not None:
self.batch_size = batch_size
if batch_length is not None:
self.batch_length = batch_length
if imagine_horizon is not None:
self.imagine_horizon = imagine_horizon
if free_nats is not None:
self.free_nats = free_nats
if kl_coeff is not None:
self.kl_coeff = kl_coeff
if prefill_timesteps is not None:
self.prefill_timesteps = prefill_timesteps
if explore_noise is not None:
self.explore_noise = explore_noise
if dreamer_model is not None:
self.dreamer_model = dreamer_model
if num_steps_sampled_before_learning_starts is not None:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self
def _postprocess_gif(gif: np.ndarray):
"""Process provided gif to a format that can be logged to Tensorboard."""
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
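    # (B, T, C, H, W) -> (1, T, C, H, B * W): tile the batch along the width
    # axis so the whole batch renders as a single video strip.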
B, T, C, H, W = gif.shape
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
return frames
class EpisodeSequenceBuffer(ReplayBuffer):
def __init__(self, capacity: int = 1000, replay_sequence_length: int = 50):
"""Stores episodes and samples sequences of size `replay_sequence_length`.
Args:
capacity: Maximum number of episodes this buffer can store
replay_sequence_length: Episode chunking length in sample()
"""
super().__init__(capacity=capacity, storage_unit=StorageUnit.EPISODES)
self.replay_sequence_length = replay_sequence_length
def sample(self, num_items: int):
"""Samples [batch_size, length] from the list of episodes
Args:
num_items: batch_size to be sampled
"""
episodes_buffer = []
while len(episodes_buffer) < num_items:
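            # Draw one stored episode; skip it if it is shorter than the
            # requested sequence length, otherwise take a random contiguous
            # chunk of length `replay_sequence_length` from it.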
episode = super().sample(1)
if episode.count < self.replay_sequence_length:
continue
available = episode.count - self.replay_sequence_length
index = int(random.randint(0, available))
episodes_buffer.append(episode[index : index + self.replay_sequence_length])
return concat_samples(episodes_buffer)
def total_sampled_timesteps(worker):
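    """Returns the default policy's global (env) timestep counter for `worker`."""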
return worker.policy_map[DEFAULT_POLICY_ID].global_timestep
class DreamerIteration:
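    """Dreamer training update used by the (deprecated) `execution_plan` path.

    For each collected rollout, runs `dreamer_train_iters` learning
    sub-iterations on episode chunks drawn from the episode buffer, logs
    metrics, and finally adds the new samples to the buffer.
    """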
def __init__(
self, worker, episode_buffer, dreamer_train_iters, batch_size, act_repeat
):
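        """
        Args:
            worker: Local RolloutWorker whose policy is trained via
                `learn_on_batch()`.
            episode_buffer: EpisodeSequenceBuffer holding collected episodes.
            dreamer_train_iters: Number of learning sub-iterations to run per
                collected rollout.
            batch_size: Number of episode chunks sampled per sub-iteration.
            act_repeat: Env action repeat (frame skip); used to scale the
                sampled-timesteps counter.
        """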
self.worker = worker
self.episode_buffer = episode_buffer
self.dreamer_train_iters = dreamer_train_iters
self.repeat = act_repeat
self.batch_size = batch_size
    def __call__(self, samples):
        # Dreamer training loop: the episode buffer has already been prefilled
        # in `execution_plan`, so run `dreamer_train_iters` learning
        # sub-iterations on sampled episode chunks right away.
        fetches = {}
        for n in range(self.dreamer_train_iters):
            print(f"sub-iteration={n}/{self.dreamer_train_iters}")
            batch = self.episode_buffer.sample(self.batch_size)
            fetches = self.worker.learn_on_batch(batch)

        # Custom logging: convert the model's reconstruction GIF (if present)
        # into a Tensorboard-compatible video tensor.
        if fetches:
            policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
            if "log_gif" in policy_fetches:
                gif = policy_fetches["log_gif"]
                policy_fetches["log_gif"] = self.postprocess_gif(gif)
# Metrics Calculation
metrics = _get_shared_metrics()
metrics.info[LEARNER_INFO] = fetches
metrics.counters[STEPS_SAMPLED_COUNTER] = self.episode_buffer.timesteps
metrics.counters[STEPS_SAMPLED_COUNTER] *= self.repeat
res = collect_metrics(local_worker=self.worker)
res["info"] = metrics.info
res["info"].update(metrics.counters)
res["timesteps_total"] = metrics.counters[STEPS_SAMPLED_COUNTER]
self.episode_buffer.add(samples)
return res
def postprocess_gif(self, gif: np.ndarray):
return _postprocess_gif(gif=gif)
class Dreamer(Algorithm):
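    """Dreamer (https://arxiv.org/abs/1912.01603).

    A model-based algorithm that learns a latent world model (RSSM/PlaNET)
    from image observations and trains actor and critic networks on
    trajectories imagined in latent space.
    """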
@classmethod
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return DreamerConfig().to_dict()
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
config["action_repeat"] = config["env_config"]["frame_skip"]
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for Dreamer!")
if config["framework"] != "torch":
raise ValueError("Dreamer not supported in Tensorflow yet!")
if config["batch_mode"] != "complete_episodes":
raise ValueError("truncate_episodes not supported")
if config["num_workers"] != 0:
raise ValueError("Distributed Dreamer not supported yet!")
if config["clip_actions"]:
raise ValueError("Clipping is done inherently via policy tanh!")
if config["dreamer_train_iters"] <= 0:
raise ValueError(
"`dreamer_train_iters` must be a positive integer. "
f"Received {config['dreamer_train_iters']} instead."
)
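        # Each policy step is repeated `action_repeat` times in the env, so
        # shorten the effective horizon accordingly.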
if config["action_repeat"] > 1:
config["horizon"] = config["horizon"] / config["action_repeat"]
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict):
return DreamerTorchPolicy
@override(Algorithm)
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
# `training_iteration` implementation: Setup buffer in `setup`, not
# in `execution_plan` (deprecated).
if self.config["_disable_execution_plan_api"] is True:
            self.local_replay_buffer = EpisodeSequenceBuffer(
                replay_sequence_length=self.config["batch_length"]
            )
# Prefill episode buffer with initial exploration (uniform sampling)
while (
total_sampled_timesteps(self.workers.local_worker())
< self.config["prefill_timesteps"]
):
samples = self.workers.local_worker().sample()
self.local_replay_buffer.add(samples)
@staticmethod
@override(Algorithm)
def execution_plan(workers, config, **kwargs):
assert (
len(kwargs) == 0
), "Dreamer execution_plan does NOT take any additional parameters"
# Special replay buffer for Dreamer agent.
episode_buffer = EpisodeSequenceBuffer(
replay_sequence_length=config["batch_length"]
)
local_worker = workers.local_worker()
# Prefill episode buffer with initial exploration (uniform sampling)
while total_sampled_timesteps(local_worker) < config["prefill_timesteps"]:
samples = local_worker.sample()
episode_buffer.add(samples)
batch_size = config["batch_size"]
dreamer_train_iters = config["dreamer_train_iters"]
act_repeat = config["action_repeat"]
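
        # Collect rollouts in parallel and apply a full Dreamer update
        # (DreamerIteration) to each batch of env samples.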
rollouts = ParallelRollouts(workers)
rollouts = rollouts.for_each(
DreamerIteration(
local_worker,
episode_buffer,
dreamer_train_iters,
batch_size,
act_repeat,
)
)
return rollouts
@override(Algorithm)
def training_step(self) -> ResultDict:
local_worker = self.workers.local_worker()
# Number of sub-iterations for Dreamer
dreamer_train_iters = self.config["dreamer_train_iters"]
batch_size = self.config["batch_size"]
# Collect SampleBatches from rollout workers.
batch = synchronous_parallel_sample(worker_set=self.workers)
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
        fetches = {}

        # Start learning only once enough timesteps have been sampled
        # (see `num_steps_sampled_before_learning_starts`).
        cur_ts = self._counters[
            NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
        ]
        if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
            # Dreamer training loop.
            # Run multiple sub-iterations for each training iteration.
            for n in range(dreamer_train_iters):
                print(f"sub-iteration={n}/{dreamer_train_iters}")
                # Sample replayed episode chunks into a separate variable so
                # the freshly collected env samples in `batch` are not
                # overwritten (they are added to the buffer below).
                replay_batch = self.local_replay_buffer.sample(batch_size)
                fetches = local_worker.learn_on_batch(replay_batch)
if fetches:
# Custom logging.
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self._postprocess_gif(gif)
self.local_replay_buffer.add(batch)
return fetches
# Deprecated: Use ray.rllib.algorithms.dreamer.DreamerConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DreamerConfig().to_dict())
@Deprecated(
old="ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG",
new="ray.rllib.algorithms.dreamer.dreamer.DreamerConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()