from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import gym
from gym.spaces.box import Box
from gym import spaces
import logging
import numpy as np
import time
import vectorized
from vectorized.wrappers import Unvectorize, Vectorize

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def create_env(env_id, client_id, remotes, **kwargs):
    """Create the environment for ``env_id``.

    This implementation only builds Atari-style gym environments;
    ``client_id``, ``remotes`` and any extra keyword arguments are accepted
    but ignored.
    """
    return create_atari_env(env_id)


def create_atari_env(env_id):
    """Build ``gym.make(env_id)`` with 42x42x1 float observations.

    The env is vectorized, rescaled with ``AtariRescale42x42``, instrumented
    with ``DiagnosticsInfo`` and finally wrapped with ``Unvectorize``.
    """
    env = gym.make(env_id)
    env = Vectorize(env)
    env = AtariRescale42x42(env)
    env = DiagnosticsInfo(env)
    env = Unvectorize(env)
    return env


def DiagnosticsInfo(env, *args, **kwargs):
    """Wrap the vectorized ``env`` in a ``DiagnosticsInfoI`` filter.

    Extra arguments are handed to ``vectorized.VectorizeFilter``, presumably
    for the ``DiagnosticsInfoI`` constructor (e.g. ``log_interval``).
    """
    return vectorized.VectorizeFilter(env, DiagnosticsInfoI, *args, **kwargs)


class DiagnosticsInfoI(vectorized.Filter):
    """Filter that replaces each step's ``info`` with a dict of metrics.

    Every ``log_interval`` steps it reports ``diagnostics/fps`` plus any of
    the known remote/VNC diagnostics present in ``info``. When an episode
    ends it reports ``global/episode_*`` statistics.
    """

    # info keys holding a (lower bound, upper bound) pair, with the to_log
    # key prefix they are reported under ("_lb" / "_ub" is appended).
    _BOUND_DIAGNOSTICS = (
        ("stats.gauges.diagnostics.lag.action", "diagnostics/action_lag"),
        ("stats.gauges.diagnostics.clock_skew", "diagnostics/clock_skew"),
        ("stats.gauges.diagnostics.lag.observation",
         "diagnostics/observation_lag"),
    )

    # info keys copied unchanged into to_log under a new key.
    _SCALAR_DIAGNOSTICS = (
        ("reward.count", "diagnostics/reward_count"),
        ("stats.vnc.updates.bytes", "diagnostics/vnc_updates_bytes"),
        ("stats.vnc.updates.pixels", "diagnostics/vnc_updates_pixels"),
        ("stats.vnc.updates.rectangles", "diagnostics/vnc_updates_rectangles"),
        ("env_status.state_id", "diagnostics/env_state_id"),
    )

    def __init__(self, log_interval=503):
        """Set up counters; ``log_interval`` is the step period of reports."""
        super(DiagnosticsInfoI, self).__init__()

        self._episode_time = time.time()
        self._last_time = time.time()
        self._local_t = 0
        self._log_interval = log_interval
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []
        self._num_vnc_updates = 0
        self._last_episode_id = -1

    def _reset_episode_stats(self):
        """Clear the per-episode reward/length accumulators."""
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []

    def _after_reset(self, observation):
        logger.info('Resetting environment')
        self._reset_episode_stats()
        return observation

    def _after_step(self, observation, reward, done, info):
        """Update counters and return the step with ``info`` replaced by metrics."""
        to_log = {}
        if self._episode_length == 0:
            # No step has been counted yet in this episode: (re)start its clock.
            self._episode_time = time.time()

        self._local_t += 1
        vnc_updates = info.get("stats.vnc.updates.n")
        if vnc_updates is not None:
            self._num_vnc_updates += vnc_updates

        if self._local_t % self._log_interval == 0:
            self._add_interval_diagnostics(info, to_log)

        if reward is not None:
            self._episode_reward += reward
            if observation is not None:
                self._episode_length += 1
            self._all_rewards.append(reward)

        if done:
            self._add_episode_summary(to_log)

        return observation, reward, done, to_log

    def _add_interval_diagnostics(self, info, to_log):
        """Add throughput and any known remote-env diagnostics to ``to_log``."""
        cur_time = time.time()
        elapsed = cur_time - self._last_time
        fps = self._log_interval / elapsed
        self._last_time = cur_time

        cur_episode_id = info.get('vectorized.episode_id', 0)
        to_log["diagnostics/fps"] = fps
        # Also report fps separately when this report point is in the same
        # episode as the previous one.
        if self._last_episode_id == cur_episode_id:
            to_log["diagnostics/fps_within_episode"] = fps
        self._last_episode_id = cur_episode_id

        for info_key, prefix in self._BOUND_DIAGNOSTICS:
            bounds = info.get(info_key)
            if bounds is not None:
                to_log[prefix + "_lb"] = bounds[0]
                to_log[prefix + "_ub"] = bounds[1]

        for info_key, log_key in self._SCALAR_DIAGNOSTICS:
            value = info.get(info_key)
            if value is not None:
                to_log[log_key] = value

        vnc_updates = info.get("stats.vnc.updates.n")
        if vnc_updates is not None:
            to_log["diagnostics/vnc_updates_n"] = vnc_updates
            # Rate of updates accumulated since the counter was last reset.
            to_log["diagnostics/vnc_updates_n_ps"] = self._num_vnc_updates / elapsed
            self._num_vnc_updates = 0

    def _add_episode_summary(self, to_log):
        """Add end-of-episode statistics to ``to_log`` and reset them."""
        logger.info('Episode terminating: episode_reward=%s episode_length=%s',
                    self._episode_reward, self._episode_length)
        total_time = time.time() - self._episode_time
        to_log["global/episode_reward"] = self._episode_reward
        to_log["global/episode_length"] = self._episode_length
        to_log["global/episode_time"] = total_time
        # A very short episode can measure 0.0 s with a coarse clock; the rate
        # is undefined then, so it is omitted instead of raising
        # ZeroDivisionError.
        if total_time > 0:
            to_log["global/reward_per_time"] = self._episode_reward / total_time
        self._reset_episode_stats()


def _process_frame42(frame):
    """Convert an RGB frame into a (42, 42, 1) float32 grayscale image.

    Rows 34-193 and columns 0-159 are kept (presumably the playfield of a
    210x160 Atari frame -- confirm for other games). Pixel values are scaled
    by 1/255, so the result lies in [0, 1] assuming 8-bit input.
    """
    frame = frame[34:34 + 160, :160]
    # Resize by half, then down to 42x42 (essentially mipmapping). If
    # we resize directly we lose pixels that, when mapped to 42x42,
    # aren't close enough to the pixel boundary.
    frame = cv2.resize(frame, (80, 80))
    frame = cv2.resize(frame, (42, 42))
    frame = frame.mean(2)  # average the channels -> grayscale
    frame = frame.astype(np.float32)
    frame *= (1.0 / 255.0)
    frame = np.reshape(frame, [42, 42, 1])
    return frame


class AtariRescale42x42(vectorized.ObservationWrapper):
    """Vectorized wrapper mapping each observation through ``_process_frame42``."""

    def __init__(self, env=None):
        super(AtariRescale42x42, self).__init__(env)
        self.observation_space = Box(0.0, 1.0, [42, 42, 1])

    def _observation(self, observation_n):
        # A None observation is passed through unchanged (as CropScreen does)
        # instead of crashing inside the frame processing.
        return [_process_frame42(observation) if observation is not None else None
                for observation in observation_n]


class CropScreen(vectorized.ObservationWrapper):
    """Crops out a [height]x[width] area starting from (top,left) """

    def __init__(self, env, height, width, top=0, left=0):
        super(CropScreen, self).__init__(env)
        self.height = height
        self.width = width
        self.top = top
        self.left = left
        self.observation_space = Box(0, 255, shape=(height, width, 3))

    def _observation(self, observation_n):
        # None observations are passed through unchanged.
        return [ob[self.top:self.top + self.height,
                   self.left:self.left + self.width, :]
                if ob is not None else None
                for ob in observation_n]


def _process_frame_flash(frame):
    """Convert an RGB frame into a (128, 200, 1) float32 grayscale image.

    Pixel values are scaled by 1/255 (range [0, 1] assuming 8-bit input).
    """
    # cv2.resize takes (width, height): the result is 128 rows x 200 columns.
    frame = cv2.resize(frame, (200, 128))
    frame = frame.mean(2).astype(np.float32)
    frame *= (1.0 / 255.0)
    frame = np.reshape(frame, [128, 200, 1])
    return frame


class FlashRescale(vectorized.ObservationWrapper):
    """Vectorized wrapper mapping each observation through ``_process_frame_flash``."""

    def __init__(self, env=None):
        super(FlashRescale, self).__init__(env)
        self.observation_space = Box(0.0, 1.0, [128, 200, 1])

    def _observation(self, observation_n):
        # A None observation is passed through unchanged (as CropScreen does)
        # instead of crashing inside the frame processing.
        return [_process_frame_flash(observation) if observation is not None else None
                for observation in observation_n]