Richard Liaw b463d9e5c7 Initial A3C Example - PongDeterministic-v3 (#331)
* Initializing A3C code

* Modifications for Ray usage

* cleanup

* removing universe dependency

* fixes (not yet working)

* hack

* documentation

* Cleanup

* Preliminary Portion

Make sure to change when merging

* RL part

* Cleaning up Driver and Worker code

* Updating driver code

* instructions...

* fixed

* Minor changes.

* Fixing cmake issues

* ray instruction

* updating port to new universe

* Fix for env.configure

* redundant commands

* Revert scipy.misc -> cv2 and raise exception for wrong gym version.
2017-03-11 00:57:53 -08:00

166 lines
6.8 KiB

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import gym
from gym.spaces.box import Box
from gym import spaces
import logging
import numpy as np
import time
import vectorized
from vectorized.wrappers import Unvectorize, Vectorize
# Module-level logger; DiagnosticsInfoI reports environment resets and
# episode terminations through it.
logger = logging.getLogger(__name__)
def create_env(env_id, client_id, remotes, **kwargs):
    """Create the environment identified by ``env_id``.

    Only ``env_id`` is used; ``client_id``, ``remotes`` and any extra
    keyword arguments are accepted but ignored by this implementation,
    which always builds the Atari pipeline.

    Returns:
        The wrapped environment produced by ``create_atari_env``.
    """
    atari_env = create_atari_env(env_id)
    return atari_env
def create_atari_env(env_id):
    """Build a gym environment and wrap it for training.

    The environment returned by ``gym.make(env_id)`` is passed through
    the following wrappers, innermost first: ``Vectorize``,
    ``AtariRescale42x42``, ``DiagnosticsInfo`` and ``Unvectorize``.

    Returns:
        The outermost (``Unvectorize``) wrapper.
    """
    env = gym.make(env_id)
    # Order matters: the rescale and diagnostics wrappers are applied
    # between Vectorize and Unvectorize.
    wrapper_chain = (Vectorize, AtariRescale42x42, DiagnosticsInfo, Unvectorize)
    for wrap in wrapper_chain:
        env = wrap(env)
    return env
def DiagnosticsInfo(env, *args, **kwargs):
    """Wrap ``env`` in a ``vectorized.VectorizeFilter`` using ``DiagnosticsInfoI``.

    Extra positional and keyword arguments are forwarded unchanged to
    ``vectorized.VectorizeFilter`` (presumably on to the
    ``DiagnosticsInfoI`` constructor, e.g. ``log_interval`` -- confirm
    against the ``vectorized`` package).
    """
    filter_factory = DiagnosticsInfoI
    return vectorized.VectorizeFilter(env, filter_factory, *args, **kwargs)
class DiagnosticsInfoI(vectorized.Filter):
    """Filter that tracks timing and per-episode statistics.

    After every step the incoming ``info`` dict is replaced by a freshly
    built diagnostics dict containing:

    * ``diagnostics/*`` entries, written once every ``log_interval``
      steps (frames per second plus any lag / clock-skew / VNC counters
      present in ``info``);
    * ``global/*`` entries, written on the step where ``done`` is true
      (episode reward, length, duration and reward per unit time).
    """

    def __init__(self, log_interval=503):
        """Initialise the counters.

        Args:
            log_interval: number of steps between two periodic
                ``diagnostics/*`` reports.
        """
        super(DiagnosticsInfoI, self).__init__()
        self._log_interval = log_interval
        self._local_t = 0
        self._episode_time = time.time()
        self._last_time = time.time()
        self._num_vnc_updates = 0
        self._last_episode_id = -1
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []

    def _clear_episode_stats(self):
        """Reset the accumulators that describe the current episode."""
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []

    def _after_reset(self, observation):
        """Clear the episode accumulators; the observation passes through."""
        logger.info('Resetting environment')
        self._clear_episode_stats()
        return observation

    def _record_periodic_diagnostics(self, info, to_log):
        """Write the ``diagnostics/*`` entries for this reporting step."""
        now = time.time()
        elapsed = now - self._last_time
        fps = self._log_interval / elapsed
        self._last_time = now

        cur_episode_id = info.get('vectorized.episode_id', 0)
        to_log["diagnostics/fps"] = fps
        # Only reported when the whole interval fell within one episode id.
        if self._last_episode_id == cur_episode_id:
            to_log["diagnostics/fps_within_episode"] = fps
        self._last_episode_id = cur_episode_id

        # Gauges are indexed at [0] and [1], logged as lower/upper bound.
        action_lag = info.get("stats.gauges.diagnostics.lag.action")
        if action_lag is not None:
            to_log["diagnostics/action_lag_lb"] = action_lag[0]
            to_log["diagnostics/action_lag_ub"] = action_lag[1]

        reward_count = info.get("reward.count")
        if reward_count is not None:
            to_log["diagnostics/reward_count"] = reward_count

        clock_skew = info.get("stats.gauges.diagnostics.clock_skew")
        if clock_skew is not None:
            to_log["diagnostics/clock_skew_lb"] = clock_skew[0]
            to_log["diagnostics/clock_skew_ub"] = clock_skew[1]

        observation_lag = info.get("stats.gauges.diagnostics.lag.observation")
        if observation_lag is not None:
            to_log["diagnostics/observation_lag_lb"] = observation_lag[0]
            to_log["diagnostics/observation_lag_ub"] = observation_lag[1]

        vnc_updates = info.get("stats.vnc.updates.n")
        if vnc_updates is not None:
            to_log["diagnostics/vnc_updates_n"] = vnc_updates
            # Rate over the interval, using the count accumulated on
            # every step since the last report.
            to_log["diagnostics/vnc_updates_n_ps"] = self._num_vnc_updates / elapsed
            self._num_vnc_updates = 0

        # Remaining counters are copied verbatim when present.
        passthrough = (
            ("stats.vnc.updates.bytes", "diagnostics/vnc_updates_bytes"),
            ("stats.vnc.updates.pixels", "diagnostics/vnc_updates_pixels"),
            ("stats.vnc.updates.rectangles", "diagnostics/vnc_updates_rectangles"),
            ("env_status.state_id", "diagnostics/env_state_id"),
        )
        for info_key, log_key in passthrough:
            value = info.get(info_key)
            if value is not None:
                to_log[log_key] = value

    def _record_episode_summary(self, to_log):
        """Write the ``global/*`` entries and reset the episode accumulators."""
        logger.info('Episode terminating: episode_reward=%s episode_length=%s', self._episode_reward, self._episode_length)
        total_time = time.time() - self._episode_time
        to_log["global/episode_reward"] = self._episode_reward
        to_log["global/episode_length"] = self._episode_length
        to_log["global/episode_time"] = total_time
        to_log["global/reward_per_time"] = self._episode_reward / total_time
        self._clear_episode_stats()

    def _after_step(self, observation, reward, done, info):
        """Update the statistics and return the diagnostics dict as ``info``.

        Returns:
            ``(observation, reward, done, to_log)`` where ``to_log`` is
            the newly built diagnostics dict (the incoming ``info`` is
            not forwarded).
        """
        to_log = {}

        # The episode clock restarts for as long as no step has been
        # counted towards the current episode.
        if self._episode_length == 0:
            self._episode_time = time.time()

        self._local_t += 1
        vnc_updates = info.get("stats.vnc.updates.n")
        if vnc_updates is not None:
            self._num_vnc_updates += vnc_updates

        if self._local_t % self._log_interval == 0:
            self._record_periodic_diagnostics(info, to_log)

        if reward is not None:
            self._episode_reward += reward
            # A step counts towards the episode length only when it
            # carries an observation.
            if observation is not None:
                self._episode_length += 1

        if done:
            self._record_episode_summary(to_log)

        return observation, reward, done, to_log
def _process_frame42(frame):
    """Crop, downscale and grey-scale a frame into a ``[42, 42, 1]`` float32 array.

    Assumes ``frame`` is an H x W x C image array with at least 194 rows
    and 160 columns (e.g. a raw Atari screen) -- TODO confirm with caller.
    """
    # Keep rows 34..193 and the first 160 columns.
    cropped = frame[34:34 + 160, :160]
    # Resize by half, then down to 42x42 (essentially mipmapping). If
    # we resize directly we lose pixels that, when mapped to 42x42,
    # aren't close enough to the pixel boundary.
    half_size = cv2.resize(cropped, (80, 80))
    small = cv2.resize(half_size, (42, 42))
    # Average over the channel axis, then scale by 1/255.
    gray = small.mean(2).astype(np.float32)
    gray *= (1.0 / 255.0)
    return np.reshape(gray, [42, 42, 1])
class AtariRescale42x42(vectorized.ObservationWrapper):
    """Observation wrapper that runs every frame through ``_process_frame42``.

    The advertised observation space is ``Box(0.0, 1.0, [42, 42, 1])``.
    """

    def __init__(self, env=None):
        super(AtariRescale42x42, self).__init__(env)
        self.observation_space = Box(0.0, 1.0, [42, 42, 1])

    def _observation(self, observation_n):
        """Return a list with each observation in ``observation_n`` processed."""
        return list(map(_process_frame42, observation_n))
class CropScreen(vectorized.ObservationWrapper):
    """Crops out a [height]x[width] area starting from (top,left) """

    def __init__(self, env, height, width, top=0, left=0):
        """Store the crop rectangle and set the observation space.

        Args:
            env: environment to wrap.
            height: number of rows to keep.
            width: number of columns to keep.
            top: index of the first row kept.
            left: index of the first column kept.
        """
        super(CropScreen, self).__init__(env)
        self.height = height
        self.width = width
        self.top = top
        self.left = left
        self.observation_space = Box(0, 255, shape=(height, width, 3))

    def _observation(self, observation_n):
        """Crop every observation; ``None`` entries are passed through."""
        row_span = slice(self.top, self.top + self.height)
        col_span = slice(self.left, self.left + self.width)
        return [None if ob is None else ob[row_span, col_span, :]
                for ob in observation_n]
def _process_frame_flash(frame):
    """Resize and grey-scale a frame into a ``[128, 200, 1]`` float32 array.

    Assumes ``frame`` is an H x W x C image array -- TODO confirm with caller.
    """
    # cv2.resize takes the target size as (width, height), so this
    # yields 128 rows by 200 columns, matching the reshape below.
    resized = cv2.resize(frame, (200, 128))
    # Average over the channel axis, then scale by 1/255.
    gray = resized.mean(2).astype(np.float32)
    gray *= (1.0 / 255.0)
    return np.reshape(gray, [128, 200, 1])
class FlashRescale(vectorized.ObservationWrapper):
    """Observation wrapper that runs every frame through ``_process_frame_flash``.

    The advertised observation space is ``Box(0.0, 1.0, [128, 200, 1])``.
    """

    def __init__(self, env=None):
        super(FlashRescale, self).__init__(env)
        self.observation_space = Box(0.0, 1.0, [128, 200, 1])

    def _observation(self, observation_n):
        """Return a list with each observation in ``observation_n`` processed."""
        return list(map(_process_frame_flash, observation_n))