import logging
import numpy as np
import random
from typing import Optional

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, concat_samples
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.algorithms.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import (
    PartialAlgorithmConfigDict,
    AlgorithmConfigDict,
    ResultDict,
)
from ray.rllib.utils.replay_buffers import ReplayBuffer, StorageUnit

logger = logging.getLogger(__name__)


class DreamerConfig(AlgorithmConfig):
    """Defines a configuration class from which a Dreamer Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.dreamer import DreamerConfig
        >>> config = DreamerConfig().training(gamma=0.9, lr=0.01)\
        ...     .resources(num_gpus=0)\
        ...     .rollouts(num_rollout_workers=0)
        >>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray import tune
        >>> from ray.rllib.algorithms.dreamer import DreamerConfig
        >>> config = DreamerConfig()
        >>> # Print out some default values.
        >>> print(config.kl_coeff)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), kl_coeff=0.5)
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "Dreamer",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self):
        """Initializes a DreamerConfig instance."""
        super().__init__(algo_class=Dreamer)

        # fmt: off
        # __sphinx_doc_begin__
        # Dreamer specific settings:
        self.td_model_lr = 6e-4
        self.actor_lr = 8e-5
        self.critic_lr = 8e-5
        self.grad_clip = 100.0
        self.lambda_ = 0.95
        self.dreamer_train_iters = 100
        self.batch_size = 50
        self.batch_length = 50
        self.imagine_horizon = 15
        self.free_nats = 3.0
        self.kl_coeff = 1.0
        self.prefill_timesteps = 5000
        self.explore_noise = 0.3
        self.dreamer_model = {
            "custom_model": DreamerModel,
            # RSSM/PlaNET parameters
            "deter_size": 200,
            "stoch_size": 30,
            # CNN encoder/decoder
            "depth_size": 32,
            # General network parameters
            "hidden_size": 400,
            # Action STD
            "action_init_std": 5.0,
        }

        # Override some of AlgorithmConfig's default values with Dreamer-specific
        # values.
        # .rollouts()
        self.num_workers = 0
        self.num_envs_per_worker = 1
        self.horizon = 1000
        self.batch_mode = "complete_episodes"
        self.clip_actions = False

        # .training()
        self.gamma = 0.99
        # Number of timesteps to collect from rollout workers before we start
        # sampling from replay buffers for learning. Whether we count this in agent
        # steps or environment steps depends on config["multiagent"]["count_steps_by"].
        self.num_steps_sampled_before_learning_starts = 0

        # .environment()
        self.env_config = {
            # Repeats the action sent by the policy for `frame_skip` env frames.
            "frame_skip": 2,
        }

        # __sphinx_doc_end__
        # fmt: on
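
    # Note: with these defaults, each learner update samples `batch_size=50`
    # episode sub-sequences of `batch_length=50` timesteps each (2500 timesteps
    # per `learn_on_batch()` call), and `dreamer_train_iters=100` such updates
    # are run per training iteration.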

    @override(AlgorithmConfig)
    def training(
        self,
        *,
        td_model_lr: Optional[float] = None,
        actor_lr: Optional[float] = None,
        critic_lr: Optional[float] = None,
        grad_clip: Optional[float] = None,
        lambda_: Optional[float] = None,
        dreamer_train_iters: Optional[int] = None,
        batch_size: Optional[int] = None,
        batch_length: Optional[int] = None,
        imagine_horizon: Optional[int] = None,
        free_nats: Optional[float] = None,
        kl_coeff: Optional[float] = None,
        prefill_timesteps: Optional[int] = None,
        explore_noise: Optional[float] = None,
        dreamer_model: Optional[dict] = None,
        num_steps_sampled_before_learning_starts: Optional[int] = None,
        **kwargs,
    ) -> "DreamerConfig":
        """Sets the training-related configuration.

        Args:
            td_model_lr: PlaNET (transition dynamics) model learning rate.
            actor_lr: Actor model learning rate.
            critic_lr: Critic model learning rate.
            grad_clip: If specified, clip the global norm of gradients by this amount.
            lambda_: The GAE (lambda) parameter.
            dreamer_train_iters: Training iterations per data collection from real env.
            batch_size: Number of episodes to sample for loss calculation.
            batch_length: Length of each episode to sample for loss calculation.
            imagine_horizon: Imagination horizon for training Actor and Critic.
            free_nats: Free nats.
            kl_coeff: KL coefficient for the model loss.
            prefill_timesteps: Prefill timesteps.
            explore_noise: Exploration Gaussian noise.
            dreamer_model: Custom model config.
            num_steps_sampled_before_learning_starts: Number of timesteps to collect
                from rollout workers before we start sampling from replay buffers for
                learning. Whether we count this in agent steps or environment steps
                depends on config["multiagent"]["count_steps_by"].

        Returns:
            This updated DreamerConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if td_model_lr is not None:
            self.td_model_lr = td_model_lr
        if actor_lr is not None:
            self.actor_lr = actor_lr
        if critic_lr is not None:
            self.critic_lr = critic_lr
        if grad_clip is not None:
            self.grad_clip = grad_clip
        if lambda_ is not None:
            self.lambda_ = lambda_
        if dreamer_train_iters is not None:
            self.dreamer_train_iters = dreamer_train_iters
        if batch_size is not None:
            self.batch_size = batch_size
        if batch_length is not None:
            self.batch_length = batch_length
        if imagine_horizon is not None:
            self.imagine_horizon = imagine_horizon
        if free_nats is not None:
            self.free_nats = free_nats
        if kl_coeff is not None:
            self.kl_coeff = kl_coeff
        if prefill_timesteps is not None:
            self.prefill_timesteps = prefill_timesteps
        if explore_noise is not None:
            self.explore_noise = explore_noise
        if dreamer_model is not None:
            self.dreamer_model = dreamer_model
        if num_steps_sampled_before_learning_starts is not None:
            self.num_steps_sampled_before_learning_starts = (
                num_steps_sampled_before_learning_starts
            )

        return self


def _postprocess_gif(gif: np.ndarray):
    """Process provided gif to a format that can be logged to Tensorboard."""
    gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
    B, T, C, H, W = gif.shape
    frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
    return frames
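
# Shape example for `_postprocess_gif` (illustrative numbers only): an input of
# B=2 episodes, T=16 frames, C=3 channels and 64x64 images, i.e. shape
# (2, 16, 3, 64, 64), is converted to uint8 and tiled along the width axis into
# a single (1, 16, 3, 64, 128) video strip with the two episodes side by side.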


class EpisodeSequenceBuffer(ReplayBuffer):
    def __init__(self, capacity: int = 1000, replay_sequence_length: int = 50):
        """Stores episodes and samples sequences of size `replay_sequence_length`.

        Args:
            capacity: Maximum number of episodes this buffer can store.
            replay_sequence_length: Episode chunking length in sample().
        """
        super().__init__(capacity=capacity, storage_unit=StorageUnit.EPISODES)
        self.replay_sequence_length = replay_sequence_length

    def sample(self, num_items: int):
        """Samples [batch_size, length] sub-sequences from the stored episodes.

        Args:
            num_items: batch_size to be sampled.
        """
        episodes_buffer = []
        while len(episodes_buffer) < num_items:
            episode = super().sample(1)
            # Skip episodes that are too short to yield a full sequence.
            if episode.count < self.replay_sequence_length:
                continue
            available = episode.count - self.replay_sequence_length
            index = int(random.randint(0, available))
            episodes_buffer.append(episode[index : index + self.replay_sequence_length])

        return concat_samples(episodes_buffer)
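
# Usage sketch for EpisodeSequenceBuffer (hypothetical variable names): the
# buffer stores complete episodes (StorageUnit.EPISODES), and `sample(B)`
# returns B randomly positioned sub-sequences of `replay_sequence_length`
# timesteps, concatenated into one SampleBatch:
#
#   buffer = EpisodeSequenceBuffer(capacity=1000, replay_sequence_length=50)
#   buffer.add(rollout_batch)        # `rollout_batch` holds complete episodes.
#   train_batch = buffer.sample(50)  # -> 50 * 50 = 2500 timesteps total.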


def total_sampled_timesteps(worker):
    """Returns the default policy's global timestep count for `worker`."""
    return worker.policy_map[DEFAULT_POLICY_ID].global_timestep


class DreamerIteration:
    def __init__(
        self, worker, episode_buffer, dreamer_train_iters, batch_size, act_repeat
    ):
        self.worker = worker
        self.episode_buffer = episode_buffer
        self.dreamer_train_iters = dreamer_train_iters
        self.repeat = act_repeat
        self.batch_size = batch_size

    def __call__(self, samples):
        # Dreamer training loop. The episode buffer has already been prefilled
        # with `prefill_timesteps` of exploration data in `execution_plan`, so
        # learning can start right away.
        for n in range(self.dreamer_train_iters):
            print(f"sub-iteration={n}/{self.dreamer_train_iters}")
            batch = self.episode_buffer.sample(self.batch_size)
            fetches = self.worker.learn_on_batch(batch)

        # Custom logging.
        policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
        if "log_gif" in policy_fetches:
            gif = policy_fetches["log_gif"]
            policy_fetches["log_gif"] = self.postprocess_gif(gif)

        # Metrics calculation.
        metrics = _get_shared_metrics()
        metrics.info[LEARNER_INFO] = fetches
        metrics.counters[STEPS_SAMPLED_COUNTER] = self.episode_buffer.timesteps
        metrics.counters[STEPS_SAMPLED_COUNTER] *= self.repeat
        res = collect_metrics(local_worker=self.worker)
        res["info"] = metrics.info
        res["info"].update(metrics.counters)
        res["timesteps_total"] = metrics.counters[STEPS_SAMPLED_COUNTER]

        # Add the newly collected rollout to the episode buffer.
        self.episode_buffer.add(samples)
        return res

    def postprocess_gif(self, gif: np.ndarray):
        return _postprocess_gif(gif=gif)


class Dreamer(Algorithm):
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return DreamerConfig().to_dict()

    @override(Algorithm)
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        config["action_repeat"] = config["env_config"]["frame_skip"]
        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for Dreamer!")
        if config["framework"] != "torch":
            raise ValueError("Dreamer not supported in Tensorflow yet!")
        if config["batch_mode"] != "complete_episodes":
            raise ValueError("truncate_episodes not supported")
        if config["num_workers"] != 0:
            raise ValueError("Distributed Dreamer not supported yet!")
        if config["clip_actions"]:
            raise ValueError("Clipping is done inherently via policy tanh!")
        if config["dreamer_train_iters"] <= 0:
            raise ValueError(
                "`dreamer_train_iters` must be a positive integer. "
                f"Received {config['dreamer_train_iters']} instead."
            )
        if config["action_repeat"] > 1:
            config["horizon"] = config["horizon"] / config["action_repeat"]
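
    # Note: `action_repeat` is derived from `env_config["frame_skip"]` above.
    # With the defaults (frame_skip=2, horizon=1000), the horizon is rescaled to
    # 1000 / 2 = 500 agent steps, i.e. an episode still spans roughly 1000
    # underlying environment frames.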

    @override(Algorithm)
    def get_default_policy_class(self, config: AlgorithmConfigDict):
        return DreamerTorchPolicy

    @override(Algorithm)
    def setup(self, config: PartialAlgorithmConfigDict):
        super().setup(config)
        # `training_iteration` implementation: Setup buffer in `setup`, not
        # in `execution_plan` (deprecated).
        if self.config["_disable_execution_plan_api"] is True:
            self.local_replay_buffer = EpisodeSequenceBuffer(
                replay_sequence_length=self.config["batch_length"]
            )

            # Prefill episode buffer with initial exploration (uniform sampling).
            while (
                total_sampled_timesteps(self.workers.local_worker())
                < self.config["prefill_timesteps"]
            ):
                samples = self.workers.local_worker().sample()
                self.local_replay_buffer.add(samples)

    @staticmethod
    @override(Algorithm)
    def execution_plan(workers, config, **kwargs):
        assert (
            len(kwargs) == 0
        ), "Dreamer execution_plan does NOT take any additional parameters"

        # Special replay buffer for Dreamer agent.
        episode_buffer = EpisodeSequenceBuffer(
            replay_sequence_length=config["batch_length"]
        )

        local_worker = workers.local_worker()

        # Prefill episode buffer with initial exploration (uniform sampling).
        while total_sampled_timesteps(local_worker) < config["prefill_timesteps"]:
            samples = local_worker.sample()
            episode_buffer.add(samples)

        batch_size = config["batch_size"]
        dreamer_train_iters = config["dreamer_train_iters"]
        act_repeat = config["action_repeat"]

        rollouts = ParallelRollouts(workers)
        rollouts = rollouts.for_each(
            DreamerIteration(
                local_worker,
                episode_buffer,
                dreamer_train_iters,
                batch_size,
                act_repeat,
            )
        )
        return rollouts

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        local_worker = self.workers.local_worker()

        # Number of sub-iterations for Dreamer.
        dreamer_train_iters = self.config["dreamer_train_iters"]
        batch_size = self.config["batch_size"]

        # Collect SampleBatches from rollout workers.
        new_samples = synchronous_parallel_sample(worker_set=self.workers)
        self._counters[NUM_AGENT_STEPS_SAMPLED] += new_samples.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += new_samples.env_steps()

        fetches = {}

        # Only start the Dreamer update loop once we have sampled enough total
        # timesteps (`num_steps_sampled_before_learning_starts`).
        cur_ts = self._counters[
            NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
        ]

        if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
            # Dreamer training loop.
            # Run multiple sub-iterations for each training iteration.
            for n in range(dreamer_train_iters):
                print(f"sub-iteration={n}/{dreamer_train_iters}")
                batch = self.local_replay_buffer.sample(batch_size)
                fetches = local_worker.learn_on_batch(batch)

        if fetches:
            # Custom logging.
            policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
            if "log_gif" in policy_fetches:
                gif = policy_fetches["log_gif"]
                policy_fetches["log_gif"] = _postprocess_gif(gif)

        # Add the newly collected rollout (not the last replayed batch) to the
        # episode buffer.
        self.local_replay_buffer.add(new_samples)

        return fetches


# Deprecated: Use ray.rllib.algorithms.dreamer.DreamerConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(DreamerConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.dreamer.dreamer.DreamerConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
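
# Minimal usage sketch (assumes an image-observation environment compatible with
# DreamerModel's conv encoder/decoder; "my_pixel_env" is a placeholder name):
#
#   config = (
#       DreamerConfig()
#       .environment(env="my_pixel_env")
#       .framework("torch")
#       .training(dreamer_train_iters=100, batch_size=50, batch_length=50)
#   )
#   algo = config.build()
#   print(algo.train())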