[rllib] Don't reset envs when possible (#3290)

* lazy

* better errors
This commit is contained in:
Eric Liang 2018-11-11 01:45:37 -08:00 committed by GitHub
parent 463511f8a6
commit 49e2085d78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 6 deletions

View file

@ -214,12 +214,14 @@ class _VectorEnvToAsync(AsyncVectorEnv):
self.action_space = vector_env.action_space
self.observation_space = vector_env.observation_space
self.num_envs = vector_env.num_envs
self.new_obs = self.vector_env.vector_reset()
self.new_obs = None # lazily initialized
self.cur_rewards = [None for _ in range(self.num_envs)]
self.cur_dones = [False for _ in range(self.num_envs)]
self.cur_infos = [None for _ in range(self.num_envs)]
def poll(self):
if self.new_obs is None:
self.new_obs = self.vector_env.vector_reset()
new_obs = dict(enumerate(self.new_obs))
rewards = dict(enumerate(self.cur_rewards))
dones = dict(enumerate(self.cur_dones))

View file

@ -55,6 +55,18 @@ class BadPolicyGraph(PolicyGraph):
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
class FailOnStepEnv(gym.Env):
    """Environment stub whose ``reset`` and ``step`` always raise.

    Only the observation and action spaces are usable; any attempt to
    reset or step the environment fails with ``ValueError("kaboom")``.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(1)

    def step(self, action):
        """Always raise ``ValueError``; the action is ignored."""
        raise ValueError("kaboom")

    def reset(self):
        """Always raise ``ValueError``."""
        raise ValueError("kaboom")
class MockEnv(gym.Env):
def __init__(self, episode_length, config=None):
self.episode_length = episode_length
@ -151,6 +163,11 @@ class TestPolicyEvaluator(unittest.TestCase):
result2 = agent.train()
self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)
def testNoStepOnInit(self):
    """Training on an env that always fails must raise.

    The agent is built outside the assertion, so an error raised while
    constructing it fails the test instead of satisfying it.
    """
    register_env("fail", lambda _: FailOnStepEnv())
    agent = PGAgent(env="fail", config={"num_workers": 1})
    with self.assertRaises(Exception):
        agent.train()
def testCallbacks(self):
counts = Counter()
pg = PGAgent(

View file

@ -74,8 +74,10 @@ class RunningStat(object):
def push(self, x):
x = np.asarray(x)
# Unvectorized update of the running statistics.
assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}"
.format(x.shape, self._M.shape))
if x.shape != self._M.shape:
raise ValueError(
"Unexpected input shape {}, expected {}, value = {}".format(
x.shape, self._M.shape, x))
n1 = self._n
self._n += 1
if self._n == 1:

View file

@ -45,10 +45,9 @@ class TFRunBuilder(object):
self._executed = run_timeline(
self.session, self.fetches, self.debug_name,
self.feed_dict, os.environ.get("TF_TIMELINE_DIR"))
except Exception as e:
logger.error("Error fetching: {}, feed_dict={}".format(
except Exception:
raise ValueError("Error fetching: {}, feed_dict={}".format(
self.fetches, self.feed_dict))
raise e
if isinstance(to_fetch, int):
return self._executed[to_fetch]
elif isinstance(to_fetch, list):