[RLlib] Switch Dreamer to training_iteration API. (#24488)

Amog Kamsetty 2022-05-09 23:37:34 -07:00 committed by GitHub
parent bf6b7f4395
commit b5b48f6cc7
2 changed files with 88 additions and 22 deletions

View file

@@ -4,7 +4,7 @@
## Overview
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based enviornments. Like all model-based algorithms, Dreamer learns the environment's transiton dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based environments. Like all model-based algorithms, Dreamer learns the environment's transition dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
Dreamer is a gradient-based RL algorithm. This means that the agent imagines ahead using its learned transition dynamics model (PlaNet) to discover new rewards and states. Because imagining ahead is fully differentiable, the RL objective (maximizing the sum of rewards) is fully differentiable and does not need to be optimized indirectly, as in policy-gradient methods. This feature of gradient-based learning, in conjunction with PlaNet, enables the agent to learn in a latent space and achieve much better sample complexity and performance than other visual-based agents.
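The differentiable-imagination idea can be sketched in a few lines of PyTorch. This is only an illustrative toy under stated assumptions: `dynamics`, `reward_head`, and `actor` below are hypothetical stand-ins, not RLlib's `DreamerModel`.

```python
import torch
import torch.nn as nn

latent_dim, action_dim, horizon = 30, 4, 15

# Hypothetical stand-ins for the learned world model and actor.
dynamics = nn.Sequential(nn.Linear(latent_dim + action_dim, latent_dim), nn.ELU())
reward_head = nn.Linear(latent_dim, 1)
actor = nn.Sequential(nn.Linear(latent_dim, action_dim), nn.Tanh())

state = torch.zeros(1, latent_dim)  # an encoded (latent) starting state
imagined_return = torch.zeros(())
for _ in range(horizon):
    action = actor(state)                                 # differentiable action
    state = dynamics(torch.cat([state, action], dim=-1))  # differentiable latent step
    imagined_return = imagined_return + reward_head(state).sum()

# Every step of the imagined rollout is differentiable, so the actor can be
# updated by backpropagating the (negated) imagined return directly -- no
# score-function / policy-gradient estimator is required.
(-imagined_return).backward()
print(actor[0].weight.grad is not None)  # True: gradients flowed through imagination
```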
@@ -14,6 +14,6 @@ For more details, there is a Ray/RLlib [blogpost](https://medium.com/distributed
Dreamer.
**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#dqn)**
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#dreamer)**
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/simple_q.py)**
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dreamer/dreamer.py)**
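As a hedged usage sketch, the snippet below shows how the `training_iteration` path introduced by this commit could be enabled via the `_disable_execution_plan_api` flag (which still defaults to `False` here). The environment name is a placeholder; Dreamer expects a registered environment with image observations.

```python
from ray.rllib.agents.dreamer.dreamer import DEFAULT_CONFIG, DREAMERTrainer

config = DEFAULT_CONFIG.copy()
# Opt in to the `training_iteration` code path added in this commit
# (by default, this commit still runs the old `execution_plan`).
config["_disable_execution_plan_api"] = True

# "YourVisualEnv-v0" is a placeholder for a registered image-observation env.
trainer = DREAMERTrainer(config=config, env="YourVisualEnv-v0")
print(trainer.train())
```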

View file

@@ -10,10 +10,18 @@ from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metric
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
SampleBatchType,
TrainerConfigDict,
ResultDict,
)
logger = logging.getLogger(__name__)
@@ -76,22 +84,26 @@ DEFAULT_CONFIG = with_common_config({
# Repeats the action sent by the policy for frame_skip frames in the env
"frame_skip": 2,
},
# Use `execution_plan` instead of `training_iteration`.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on
def _postprocess_gif(gif: np.ndarray):
"""Process provided gif to a format that can be logged to Tensorboard."""
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
B, T, C, H, W = gif.shape
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
return frames
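For intuition, a quick shape check of the helper above with a random dummy batch (the array here is made up; real gifs come from the policy's `log_gif` fetches):

```python
import numpy as np

dummy_gif = np.random.rand(2, 16, 3, 64, 64)  # (B, T, C, H, W)
frames = _postprocess_gif(dummy_gif)
print(frames.shape)  # (1, 16, 3, 64, 128): the batch is tiled along the width axis
```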
class EpisodicBuffer(object):
def __init__(self, max_length: int = 1000, length: int = 50):
"""Data structure that stores episodes and samples chunks
of size length from episodes
"""Stores episodes and samples chunks of size ``length`` from episodes.
Args:
max_length: Maximum episodes it can store
length: Episode chunking lengh in sample()
length: Episode chunking length in sample()
"""
# Stores all episodes into a list: List[SampleBatchType]
@@ -101,8 +113,7 @@ class EpisodicBuffer(object):
self.length = length
def add(self, batch: SampleBatchType):
"""Splits a SampleBatch into episodes and adds episodes
to the episode buffer
"""Splits a SampleBatch into episodes and adds episodes to the episode buffer.
Args:
batch: SampleBatch to be added
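To make the "chunks of size `length`" sampling concrete, here is a hedged, buffer-free NumPy illustration; the 1-D array stands in for one stored episode, whereas the real buffer stores and slices `SampleBatch` objects:

```python
import numpy as np

length = 50               # chunk length used by sample()
episode = np.arange(200)  # stands in for one stored 200-step episode

# sample() picks a random start index such that a full chunk of `length`
# consecutive timesteps still fits inside the episode.
start = np.random.randint(0, episode.shape[0] - length + 1)
chunk = episode[start : start + length]
print(chunk.shape)  # (50,)
```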
@@ -151,7 +162,6 @@ class DreamerIteration:
self.batch_size = batch_size
def __call__(self, samples):
# Dreamer training loop.
for n in range(self.dreamer_train_iters):
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
@@ -161,7 +171,7 @@
fetches = self.worker.learn_on_batch(batch)
# Custom Logging
policy_fetches = self.policy_stats(fetches)
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self.postprocess_gif(gif)
@@ -180,13 +190,7 @@ class DreamerIteration:
return res
def postprocess_gif(self, gif: np.ndarray):
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
B, T, C, H, W = gif.shape
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
return frames
def policy_stats(self, fetches):
return fetches[DEFAULT_POLICY_ID]["learner_stats"]
return _postprocess_gif(gif=gif)
class DREAMERTrainer(Trainer):
@@ -211,6 +215,11 @@ class DREAMERTrainer(Trainer):
raise ValueError("Distributed Dreamer not supported yet!")
if config["clip_actions"]:
raise ValueError("Clipping is done inherently via policy tanh!")
if config["dreamer_train_iters"] <= 0:
raise ValueError(
"`dreamer_train_iters` must be a positive integer. "
f"Received {config['dreamer_train_iters']} instead."
)
if config["action_repeat"] > 1:
config["horizon"] = config["horizon"] / config["action_repeat"]
@@ -218,6 +227,22 @@
def get_default_policy_class(self, config: TrainerConfigDict):
return DreamerTorchPolicy
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
super().setup(config)
# `training_iteration` implementation: Setup buffer in `setup`, not
# in `execution_plan` (deprecated).
if self.config["_disable_execution_plan_api"] is True:
self.local_replay_buffer = EpisodicBuffer(length=self.config["batch_length"])
# Prefill episode buffer with initial exploration (uniform sampling)
while (
total_sampled_timesteps(self.workers.local_worker())
< self.config["prefill_timesteps"]
):
samples = self.workers.local_worker().sample()
self.local_replay_buffer.add(samples)
@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
@@ -250,3 +275,44 @@
)
)
return rollouts
@override(Trainer)
def training_iteration(self) -> ResultDict:
local_worker = self.workers.local_worker()
# Number of sub-iterations for Dreamer
dreamer_train_iters = self.config["dreamer_train_iters"]
batch_size = self.config["batch_size"]
action_repeat = self.config["action_repeat"]
# Collect SampleBatches from rollout workers.
batch = synchronous_parallel_sample(worker_set=self.workers)
fetches = {}
# Dreamer training loop.
# Run multiple sub-iterations for each training iteration.
for n in range(dreamer_train_iters):
print(f"sub-iteration={n}/{dreamer_train_iters}")
# Sample a training chunk from the episodic replay buffer. Use a separate
# variable so `batch` (the newly collected rollouts) is still intact when it
# is added to the buffer below.
replay_batch = self.local_replay_buffer.sample(batch_size)
fetches = local_worker.learn_on_batch(replay_batch)
if fetches:
# Custom Logging
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self._postprocess_gif(gif)
self._counters[STEPS_SAMPLED_COUNTER] = (
self.local_replay_buffer.timesteps * action_repeat
)
self.local_replay_buffer.add(batch)
return fetches
def _compile_step_results(self, *args, **kwargs):
results = super()._compile_step_results(*args, **kwargs)
results["timesteps_total"] = self._counters[STEPS_SAMPLED_COUNTER]
return results