import logging
import random

import numpy as np

from ray.rllib.agents import with_common_config
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    LEARNER_INFO, _get_shared_metrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # PlaNET Model LR
    "td_model_lr": 6e-4,
    # Actor LR
    "actor_lr": 8e-5,
    # Critic LR
    "critic_lr": 8e-5,
    # Grad Clipping
    "grad_clip": 100.0,
    # Discount
    "discount": 0.99,
    # Lambda
    "lambda": 0.95,
    # Training iterations per data collection from real env
    "dreamer_train_iters": 100,
    # Horizon for Environment (1000 for Mujoco/DMC)
    "horizon": 1000,
    # Number of episodes to sample for Loss Calculation
    "batch_size": 50,
    # Length of each episode to sample for Loss Calculation
    "batch_length": 50,
    # Imagination Horizon for Training Actor and Critic
    "imagine_horizon": 15,
    # Free Nats
    "free_nats": 3.0,
    # KL Coeff for the Model Loss
    "kl_coeff": 1.0,
    # Distributed Dreamer not implemented yet
    "num_workers": 0,
    # Prefill Timesteps
    "prefill_timesteps": 5000,
    # This should be kept at 1 to preserve sample efficiency
    "num_envs_per_worker": 1,
    # Exploration Gaussian
    "explore_noise": 0.3,
    # Batch mode
    "batch_mode": "complete_episodes",
    # Custom Model
    "dreamer_model": {
        "custom_model": DreamerModel,
        # RSSM/PlaNET parameters
        "deter_size": 200,
        "stoch_size": 30,
        # CNN Decoder Encoder
        "depth_size": 32,
        # General Network Parameters
        "hidden_size": 400,
        # Action STD
        "action_init_std": 5.0,
    },

    "env_config": {
        # Repeats the action sent by the policy for frame_skip steps in env
        "frame_skip": 2,
    },
})
# __sphinx_doc_end__
# yapf: enable


class EpisodicBuffer(object):
    def __init__(self, max_length: int = 1000, length: int = 50):
        """Stores episodes and samples chunks of size ``length`` from them.

        Args:
            max_length: Maximum number of episodes the buffer can store.
            length: Episode chunking length used in sample().
        """

        # Stores all episodes into a list: List[SampleBatchType]
        self.episodes = []
        self.max_length = max_length
        self.timesteps = 0
        self.length = length

    def add(self, batch: SampleBatchType):
        """Splits a SampleBatch into episodes and adds them to the buffer.

        Args:
            batch: SampleBatch to be added.
        """
        self.timesteps += batch.count
        episodes = batch.split_by_episode()

        for i, e in enumerate(episodes):
            episodes[i] = self.preprocess_episode(e)
        self.episodes.extend(episodes)

        if len(self.episodes) > self.max_length:
            delta = len(self.episodes) - self.max_length
            # Drop oldest episodes
            self.episodes = self.episodes[delta:]

    def preprocess_episode(self, episode: SampleBatchType):
        """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)).

        When t=0, the reset obs is paired with an action and reward of 0.

        Args:
            episode: SampleBatch representing an episode.
        """
        obs = episode["obs"]
        new_obs = episode["new_obs"]
        action = episode["actions"]
        reward = episode["rewards"]

        act_shape = action.shape
        act_reset = np.array([0.0] * act_shape[-1])[None]
        rew_reset = np.array(0.0)[None]
        obs_end = np.array(new_obs[act_shape[0] - 1])[None]

        batch_obs = np.concatenate([obs, obs_end], axis=0)
        batch_action = np.concatenate([act_reset, action], axis=0)
        batch_rew = np.concatenate([rew_reset, reward], axis=0)

        new_batch = {
            "obs": batch_obs,
            "rewards": batch_rew,
            "actions": batch_action,
        }
        return SampleBatch(new_batch)

    def sample(self, batch_size: int):
        """Samples [batch_size, length] from the list of episodes.

        Args:
            batch_size: batch_size to be sampled.
        """
        episodes_buffer = []
        while len(episodes_buffer) < batch_size:
            rand_index = random.randint(0, len(self.episodes) - 1)
            episode = self.episodes[rand_index]
            # Episodes shorter than the chunk length cannot be sliced.
            if episode.count < self.length:
                continue
            available = episode.count - self.length
            index = int(random.randint(0, available))
            episodes_buffer.append(episode.slice(index, index + self.length))

        batch = {}
        for k in episodes_buffer[0].keys():
            batch[k] = np.stack([e[k] for e in episodes_buffer], axis=0)

        return SampleBatch(batch)
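# Illustrative usage sketch for ``EpisodicBuffer`` (comments only; the buffer
# is normally driven by ``execution_plan`` below, and ``rollout_batch`` is a
# hypothetical SampleBatch returned by ``worker.sample()``):
#
#     buffer = EpisodicBuffer(max_length=1000, length=50)
#     buffer.add(rollout_batch)        # split into episodes, re-align, store
#     train_batch = buffer.sample(50)  # SampleBatch with [50, 50, ...] shapes
#
# After ``preprocess_episode``, an episode of T env steps is stored with T + 1
# rows aligned as (s_t, a_(t-1), r_(t-1)):
#     obs:     [s_0, s_1, ..., s_T]
#     actions: [0,   a_0, ..., a_(T-1)]
#     rewards: [0,   r_0, ..., r_(T-1)]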
def total_sampled_timesteps(worker):
    return worker.policy_map[DEFAULT_POLICY_ID].global_timestep


class DreamerIteration:
    def __init__(self, worker, episode_buffer, dreamer_train_iters,
                 batch_size, act_repeat):
        self.worker = worker
        self.episode_buffer = episode_buffer
        self.dreamer_train_iters = dreamer_train_iters
        self.repeat = act_repeat
        self.batch_size = batch_size

    def __call__(self, samples):
        # Dreamer Training Loop
        for n in range(self.dreamer_train_iters):
            print(n)
            batch = self.episode_buffer.sample(self.batch_size)
            if n == self.dreamer_train_iters - 1:
                batch["log_gif"] = True
            fetches = self.worker.learn_on_batch(batch)

        # Custom Logging
        policy_fetches = self.policy_stats(fetches)
        if "log_gif" in policy_fetches:
            gif = policy_fetches["log_gif"]
            policy_fetches["log_gif"] = self.postprocess_gif(gif)

        # Metrics Calculation
        metrics = _get_shared_metrics()
        metrics.info[LEARNER_INFO] = fetches
        metrics.counters[STEPS_SAMPLED_COUNTER] = self.episode_buffer.timesteps
        metrics.counters[STEPS_SAMPLED_COUNTER] *= self.repeat
        res = collect_metrics(local_worker=self.worker)
        res["info"] = metrics.info
        res["info"].update(metrics.counters)
        res["timesteps_total"] = metrics.counters[STEPS_SAMPLED_COUNTER]

        self.episode_buffer.add(samples)
        return res

    def postprocess_gif(self, gif: np.ndarray):
        gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
        B, T, C, H, W = gif.shape
        frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
        return frames

    def policy_stats(self, fetches):
        return fetches["default_policy"]["learner_stats"]


def execution_plan(workers, config):
    # Special Replay Buffer for Dreamer agent
    episode_buffer = EpisodicBuffer(length=config["batch_length"])

    local_worker = workers.local_worker()

    # Prefill episode buffer with initial exploration (uniform sampling)
    while total_sampled_timesteps(local_worker) < config["prefill_timesteps"]:
        samples = local_worker.sample()
        episode_buffer.add(samples)

    batch_size = config["batch_size"]
    dreamer_train_iters = config["dreamer_train_iters"]
    act_repeat = config["action_repeat"]

    rollouts = ParallelRollouts(workers)
    rollouts = rollouts.for_each(
        DreamerIteration(local_worker, episode_buffer, dreamer_train_iters,
                         batch_size, act_repeat))
    return rollouts


def get_policy_class(config):
    return DreamerTorchPolicy


def validate_config(config):
    config["action_repeat"] = config["env_config"]["frame_skip"]
    if config["framework"] != "torch":
        raise ValueError("Dreamer not supported in Tensorflow yet!")
    if config["batch_mode"] != "complete_episodes":
        raise ValueError("truncate_episodes not supported")
    if config["num_workers"] != 0:
        raise ValueError("Distributed Dreamer not supported yet!")
    if config["clip_actions"]:
        raise ValueError("Clipping is done inherently via policy tanh!")
    if config["action_repeat"] > 1:
        config["horizon"] = config["horizon"] / config["action_repeat"]
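# Worked example of the horizon adjustment in ``validate_config`` above: with
# the default ``env_config["frame_skip"] = 2``, ``action_repeat`` is 2, so a
# configured horizon of 1000 env steps becomes 1000 / 2 = 500 policy steps.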
config["env_config"]["frame_skip"] if config["framework"] != "torch": raise ValueError("Dreamer not supported in Tensorflow yet!") if config["batch_mode"] != "complete_episodes": raise ValueError("truncate_episodes not supported") if config["num_workers"] != 0: raise ValueError("Distributed Dreamer not supported yet!") if config["clip_actions"]: raise ValueError("Clipping is done inherently via policy tanh!") if config["action_repeat"] > 1: config["horizon"] = config["horizon"] / config["action_repeat"] DREAMERTrainer = build_trainer( name="Dreamer", default_config=DEFAULT_CONFIG, default_policy=DreamerTorchPolicy, get_policy_class=get_policy_class, execution_plan=execution_plan, validate_config=validate_config)