mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Switch Dreamer
to training_iteration
API. (#24488)
This commit is contained in:
parent
bf6b7f4395
commit
b5b48f6cc7
2 changed files with 88 additions and 22 deletions
|
@ -4,7 +4,7 @@
|
|||
|
||||
## Overview
|
||||
|
||||
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based enviornments. Like all model-based algorithms, Dreamer learns the environment's transiton dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
|
||||
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based environments. Like all model-based algorithms, Dreamer learns the environment's transiton dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
|
||||
|
||||
Dreamer is a gradient-based RL algorithm. This means that the agent imagines ahead using its learned transition dynamics model (PlaNet) to discover new rewards and states. Because imagining ahead is fully differentiable, the RL objective (maximizing the sum of rewards) is fully differentiable and does not need to be optimized indirectly such as policy gradient methods. This feature of gradient-based learning, in conjunction with PlaNet, enables the agent to learn in a latent space and achieves much better sample complexity and performance than other visual-based agents.
|
||||
|
||||
|
@ -14,6 +14,6 @@ For more details, there is a Ray/RLlib [blogpost](https://medium.com/distributed
|
|||
|
||||
Dreamer.
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#dqn)**
|
||||
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#dreamer)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/simple_q.py)**
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dreamer/dreamer.py)**
|
||||
|
|
|
@ -10,10 +10,18 @@ from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metric
|
|||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.rollout_ops import (
|
||||
ParallelRollouts,
|
||||
synchronous_parallel_sample,
|
||||
)
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
SampleBatchType,
|
||||
TrainerConfigDict,
|
||||
ResultDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -76,22 +84,26 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# Repeats action send by policy for frame_skip times in env
|
||||
"frame_skip": 2,
|
||||
},
|
||||
|
||||
# Use `execution_plan` instead of `training_iteration`.
|
||||
"_disable_execution_plan_api": False,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
||||
def _postprocess_gif(gif: np.ndarray):
|
||||
"""Process provided gif to a format that can be logged to Tensorboard."""
|
||||
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
|
||||
B, T, C, H, W = gif.shape
|
||||
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
|
||||
return frames
|
||||
|
||||
|
||||
class EpisodicBuffer(object):
|
||||
def __init__(self, max_length: int = 1000, length: int = 50):
|
||||
"""Data structure that stores episodes and samples chunks
|
||||
of size length from episodes
|
||||
"""Stores episodes and samples chunks of size ``length`` from episodes.
|
||||
|
||||
Args:
|
||||
max_length: Maximum episodes it can store
|
||||
length: Episode chunking lengh in sample()
|
||||
length: Episode chunking length in sample()
|
||||
"""
|
||||
|
||||
# Stores all episodes into a list: List[SampleBatchType]
|
||||
|
@ -101,8 +113,7 @@ class EpisodicBuffer(object):
|
|||
self.length = length
|
||||
|
||||
def add(self, batch: SampleBatchType):
|
||||
"""Splits a SampleBatch into episodes and adds episodes
|
||||
to the episode buffer
|
||||
"""Splits a SampleBatch into episodes and adds episodes to the episode buffer.
|
||||
|
||||
Args:
|
||||
batch: SampleBatch to be added
|
||||
|
@ -151,7 +162,6 @@ class DreamerIteration:
|
|||
self.batch_size = batch_size
|
||||
|
||||
def __call__(self, samples):
|
||||
|
||||
# Dreamer training loop.
|
||||
for n in range(self.dreamer_train_iters):
|
||||
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
|
||||
|
@ -161,7 +171,7 @@ class DreamerIteration:
|
|||
fetches = self.worker.learn_on_batch(batch)
|
||||
|
||||
# Custom Logging
|
||||
policy_fetches = self.policy_stats(fetches)
|
||||
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
|
||||
if "log_gif" in policy_fetches:
|
||||
gif = policy_fetches["log_gif"]
|
||||
policy_fetches["log_gif"] = self.postprocess_gif(gif)
|
||||
|
@ -180,13 +190,7 @@ class DreamerIteration:
|
|||
return res
|
||||
|
||||
def postprocess_gif(self, gif: np.ndarray):
|
||||
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
|
||||
B, T, C, H, W = gif.shape
|
||||
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
|
||||
return frames
|
||||
|
||||
def policy_stats(self, fetches):
|
||||
return fetches[DEFAULT_POLICY_ID]["learner_stats"]
|
||||
return _postprocess_gif(gif=gif)
|
||||
|
||||
|
||||
class DREAMERTrainer(Trainer):
|
||||
|
@ -211,6 +215,11 @@ class DREAMERTrainer(Trainer):
|
|||
raise ValueError("Distributed Dreamer not supported yet!")
|
||||
if config["clip_actions"]:
|
||||
raise ValueError("Clipping is done inherently via policy tanh!")
|
||||
if config["dreamer_train_iters"] <= 0:
|
||||
raise ValueError(
|
||||
"`dreamer_train_iters` must be a positive integer. "
|
||||
f"Received {config['dreamer_train_iters']} instead."
|
||||
)
|
||||
if config["action_repeat"] > 1:
|
||||
config["horizon"] = config["horizon"] / config["action_repeat"]
|
||||
|
||||
|
@ -218,6 +227,22 @@ class DREAMERTrainer(Trainer):
|
|||
def get_default_policy_class(self, config: TrainerConfigDict):
|
||||
return DreamerTorchPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
super().setup(config)
|
||||
# `training_iteration` implementation: Setup buffer in `setup`, not
|
||||
# in `execution_plan` (deprecated).
|
||||
if self.config["_disable_execution_plan_api"] is True:
|
||||
self.local_replay_buffer = EpisodicBuffer(length=config["batch_length"])
|
||||
|
||||
# Prefill episode buffer with initial exploration (uniform sampling)
|
||||
while (
|
||||
total_sampled_timesteps(self.workers.local_worker())
|
||||
< self.config["prefill_timesteps"]
|
||||
):
|
||||
samples = self.workers.local_worker().sample()
|
||||
self.local_replay_buffer.add(samples)
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
def execution_plan(workers, config, **kwargs):
|
||||
|
@ -250,3 +275,44 @@ class DREAMERTrainer(Trainer):
|
|||
)
|
||||
)
|
||||
return rollouts
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
local_worker = self.workers.local_worker()
|
||||
|
||||
# Number of sub-iterations for Dreamer
|
||||
dreamer_train_iters = self.config["dreamer_train_iters"]
|
||||
batch_size = self.config["batch_size"]
|
||||
action_repeat = self.config["action_repeat"]
|
||||
|
||||
# Collect SampleBatches from rollout workers.
|
||||
batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
|
||||
fetches = {}
|
||||
|
||||
# Dreamer training loop.
|
||||
# Run multiple sub-iterations for each training iteration.
|
||||
for n in range(dreamer_train_iters):
|
||||
print(f"sub-iteration={n}/{dreamer_train_iters}")
|
||||
batch = self.local_replay_buffer.sample(batch_size)
|
||||
fetches = local_worker.learn_on_batch(batch)
|
||||
|
||||
if fetches:
|
||||
# Custom Logging
|
||||
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
|
||||
if "log_gif" in policy_fetches:
|
||||
gif = policy_fetches["log_gif"]
|
||||
policy_fetches["log_gif"] = self._postprocess_gif(gif)
|
||||
|
||||
self._counters[STEPS_SAMPLED_COUNTER] = (
|
||||
self.local_replay_buffer.timesteps * action_repeat
|
||||
)
|
||||
|
||||
self.local_replay_buffer.add(batch)
|
||||
|
||||
return fetches
|
||||
|
||||
def _compile_step_results(self, *args, **kwargs):
|
||||
results = super()._compile_step_results(*args, **kwargs)
|
||||
results["timesteps_total"] = self._counters[STEPS_SAMPLED_COUNTER]
|
||||
return results
|
||||
|
|
Loading…
Add table
Reference in a new issue