# ray/rllib/agents/dreamer/dreamer.py

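"""RLlib's Dreamer trainer.

Dreamer learns a latent world model (the RSSM used in PlaNet) from replayed
image observations and trains an actor and a value (critic) network on
trajectories imagined inside that model. This file defines the trainer's
default config, the episodic replay buffer, the per-iteration training loop,
and the ``DREAMERTrainer`` built from ``DreamerTorchPolicy``.
"""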
import logging
import random

import numpy as np

from ray.rllib.agents import with_common_config
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    LEARNER_INFO, _get_shared_metrics
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # PlaNET Model LR
    "td_model_lr": 6e-4,
    # Actor LR
    "actor_lr": 8e-5,
    # Critic LR
    "critic_lr": 8e-5,
    # Grad Clipping
    "grad_clip": 100.0,
    # Discount
    "discount": 0.99,
    # Lambda
    "lambda": 0.95,
    # Training iterations per data collection from real env
    "dreamer_train_iters": 100,
    # Horizon for Environment (1000 for Mujoco/DMC)
    "horizon": 1000,
    # Number of episodes to sample for Loss Calculation
    "batch_size": 50,
    # Length of each episode to sample for Loss Calculation
    "batch_length": 50,
    # Imagination Horizon for Training Actor and Critic
    "imagine_horizon": 15,
    # Free Nats
    "free_nats": 3.0,
    # KL Coeff for the Model Loss
    "kl_coeff": 1.0,
    # Distributed Dreamer not implemented yet
    "num_workers": 0,
    # Prefill Timesteps
    "prefill_timesteps": 5000,
    # This should be kept at 1 to preserve sample efficiency
    "num_envs_per_worker": 1,
    # Exploration Gaussian
    "explore_noise": 0.3,
    # Batch mode
    "batch_mode": "complete_episodes",
    # Custom Model
    "dreamer_model": {
        "custom_model": DreamerModel,
        # RSSM/PlaNET parameters
        "deter_size": 200,
        "stoch_size": 30,
        # CNN Decoder Encoder
        "depth_size": 32,
        # General Network Parameters
        "hidden_size": 400,
        # Action STD
        "action_init_std": 5.0,
    },

    "env_config": {
        # Repeats the action sent by the policy for frame_skip steps in the env
        "frame_skip": 2,
    }
})
# __sphinx_doc_end__
# yapf: enable
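# Illustrative note (hypothetical values): as with other RLlib trainers, a
# partial user config is merged over DEFAULT_CONFIG, so e.g.
#
#     DREAMERTrainer(config={"batch_size": 32, "imagine_horizon": 10}, env=...)
#
# overrides only those two keys and keeps every other default above.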
class EpisodicBuffer(object):
    def __init__(self, max_length: int = 1000, length: int = 50):
        """Data structure that stores episodes and samples chunks of size
        ``length`` from those episodes.

        Args:
            max_length: Maximum number of episodes the buffer can store
            length: Episode chunking length used in sample()
        """

        # Stores all episodes into a list: List[SampleBatchType]
        self.episodes = []
        self.max_length = max_length
        self.timesteps = 0
        self.length = length

    def add(self, batch: SampleBatchType):
        """Splits a SampleBatch into episodes and adds these episodes
        to the episode buffer.

        Args:
            batch: SampleBatch to be added
        """
        self.timesteps += batch.count
        episodes = batch.split_by_episode()

        for i, e in enumerate(episodes):
            episodes[i] = self.preprocess_episode(e)
        self.episodes.extend(episodes)

        if len(self.episodes) > self.max_length:
            delta = len(self.episodes) - self.max_length
            # Drop the oldest episodes.
            self.episodes = self.episodes[delta:]

    def preprocess_episode(self, episode: SampleBatchType):
        """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)).

        When t=0, the reset obs is paired with an action and reward of 0.

        Args:
            episode: SampleBatch representing an episode
        """
        obs = episode["obs"]
        new_obs = episode["new_obs"]
        action = episode["actions"]
        reward = episode["rewards"]

        act_shape = action.shape
        # Prepend a zero action/reward for the reset obs and append the final
        # observation, so obs[t] is aligned with action[t-1] and reward[t-1].
        act_reset = np.array([0.0] * act_shape[-1])[None]
        rew_reset = np.array(0.0)[None]
        obs_end = np.array(new_obs[act_shape[0] - 1])[None]

        batch_obs = np.concatenate([obs, obs_end], axis=0)
        batch_action = np.concatenate([act_reset, action], axis=0)
        batch_rew = np.concatenate([rew_reset, reward], axis=0)

        new_batch = {
            "obs": batch_obs,
            "rewards": batch_rew,
            "actions": batch_action
        }
        return SampleBatch(new_batch)

    def sample(self, batch_size: int):
        """Samples [batch_size, length] from the list of episodes.

        Args:
            batch_size: batch_size to be sampled
        """
        episodes_buffer = []
        while len(episodes_buffer) < batch_size:
            rand_index = random.randint(0, len(self.episodes) - 1)
            episode = self.episodes[rand_index]
            # Skip episodes that are too short to yield a full chunk.
            if episode.count < self.length:
                continue
            available = episode.count - self.length
            index = int(random.randint(0, available))
            episodes_buffer.append(episode.slice(index, index + self.length))

        batch = {}
        for k in episodes_buffer[0].keys():
            batch[k] = np.stack([e[k] for e in episodes_buffer], axis=0)

        return SampleBatch(batch)
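# Illustrative usage sketch, assuming a rollout worker configured with
# batch_mode="complete_episodes" as in DEFAULT_CONFIG:
#
#     buf = EpisodicBuffer(max_length=1000, length=50)
#     buf.add(worker.sample())         # stores whole episodes
#     train_batch = buf.sample(50)     # -> arrays of shape [50, 50, ...]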


def total_sampled_timesteps(worker):
    return worker.policy_map[DEFAULT_POLICY_ID].global_timestep


class DreamerIteration:
    def __init__(self, worker, episode_buffer, dreamer_train_iters, batch_size,
                 act_repeat):
        self.worker = worker
        self.episode_buffer = episode_buffer
        self.dreamer_train_iters = dreamer_train_iters
        self.repeat = act_repeat
        self.batch_size = batch_size

    def __call__(self, samples):
        # Dreamer training loop: for each batch of real-env samples, run
        # `dreamer_train_iters` gradient updates on chunks drawn from the
        # episodic buffer.
        for n in range(self.dreamer_train_iters):
            print(n)
            batch = self.episode_buffer.sample(self.batch_size)
            if n == self.dreamer_train_iters - 1:
                batch["log_gif"] = True
            fetches = self.worker.learn_on_batch(batch)

        # Custom Logging
        policy_fetches = self.policy_stats(fetches)
        if "log_gif" in policy_fetches:
            gif = policy_fetches["log_gif"]
            policy_fetches["log_gif"] = self.postprocess_gif(gif)

        # Metrics Calculation
        metrics = _get_shared_metrics()
        metrics.info[LEARNER_INFO] = fetches
        metrics.counters[STEPS_SAMPLED_COUNTER] = self.episode_buffer.timesteps
        # Each policy step is repeated `action_repeat` times in the env, so
        # scale the sampled-steps counter accordingly.
        metrics.counters[STEPS_SAMPLED_COUNTER] *= self.repeat
        res = collect_metrics(local_worker=self.worker)
        res["info"] = metrics.info
        res["info"].update(metrics.counters)
        res["timesteps_total"] = metrics.counters[STEPS_SAMPLED_COUNTER]

        self.episode_buffer.add(samples)
        return res

    def postprocess_gif(self, gif: np.ndarray):
        # Scale [0, 1] predictions to uint8 pixels and lay the B episodes out
        # side by side along the width axis, producing a single
        # (1, T, C, H, B * W) image sequence.
        gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
        B, T, C, H, W = gif.shape
        frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
        return frames

    def policy_stats(self, fetches):
        return fetches["default_policy"]["learner_stats"]


def execution_plan(workers, config):
    # Special Replay Buffer for Dreamer agent
    episode_buffer = EpisodicBuffer(length=config["batch_length"])

    local_worker = workers.local_worker()

    # Prefill episode buffer with initial exploration (uniform sampling)
    while total_sampled_timesteps(local_worker) < config["prefill_timesteps"]:
        samples = local_worker.sample()
        episode_buffer.add(samples)

    batch_size = config["batch_size"]
    dreamer_train_iters = config["dreamer_train_iters"]
    act_repeat = config["action_repeat"]

    rollouts = ParallelRollouts(workers)
    rollouts = rollouts.for_each(
        DreamerIteration(local_worker, episode_buffer, dreamer_train_iters,
                         batch_size, act_repeat))
    return rollouts


def get_policy_class(config):
    return DreamerTorchPolicy


def validate_config(config):
    config["action_repeat"] = config["env_config"]["frame_skip"]
    if config["framework"] != "torch":
        raise ValueError("Dreamer not supported in Tensorflow yet!")
    if config["batch_mode"] != "complete_episodes":
        raise ValueError("truncate_episodes not supported")
    if config["num_workers"] != 0:
        raise ValueError("Distributed Dreamer not supported yet!")
    if config["clip_actions"]:
        raise ValueError("Clipping is done inherently via policy tanh!")
    if config["action_repeat"] > 1:
        config["horizon"] = config["horizon"] / config["action_repeat"]


DREAMERTrainer = build_trainer(
    name="Dreamer",
    default_config=DEFAULT_CONFIG,
    default_policy=DreamerTorchPolicy,
    get_policy_class=get_policy_class,
    execution_plan=execution_plan,
    validate_config=validate_config)
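

# A minimal launch sketch follows; the env name below is a placeholder
# assumption (Dreamer needs a pixel/image-observation env supplied by the
# user), and the overrides shown simply satisfy validate_config() above.
if __name__ == "__main__":
    import ray

    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["framework"] = "torch"   # validate_config() rejects non-torch
    config["clip_actions"] = False  # clipping is handled by the policy's tanh
    trainer = DREAMERTrainer(config=config, env="YourPixelEnv-v0")  # placeholder env
    print(trainer.train())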