Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
parent 463511f8a6
commit 49e2085d78

4 changed files with 26 additions and 6 deletions
python/ray/rllib/env/async_vector_env.py (vendored): 4 changes
@@ -214,12 +214,14 @@ class _VectorEnvToAsync(AsyncVectorEnv):
         self.action_space = vector_env.action_space
         self.observation_space = vector_env.observation_space
         self.num_envs = vector_env.num_envs
-        self.new_obs = self.vector_env.vector_reset()
+        self.new_obs = None  # lazily initialized
         self.cur_rewards = [None for _ in range(self.num_envs)]
         self.cur_dones = [False for _ in range(self.num_envs)]
         self.cur_infos = [None for _ in range(self.num_envs)]

     def poll(self):
+        if self.new_obs is None:
+            self.new_obs = self.vector_env.vector_reset()
         new_obs = dict(enumerate(self.new_obs))
         rewards = dict(enumerate(self.cur_rewards))
         dones = dict(enumerate(self.cur_dones))
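The hunk above makes _VectorEnvToAsync lazy: new_obs starts as None and the
initial vector_reset() is deferred until the first poll(), so constructing the
wrapper no longer touches the underlying environments. A minimal standalone
sketch of the same pattern, with a hypothetical stub in place of RLlib's real
VectorEnv:

class VectorEnvStub(object):
    """Hypothetical stand-in for a vectorized env (not RLlib's VectorEnv)."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.reset_called = False

    def vector_reset(self):
        self.reset_called = True
        return [0 for _ in range(self.num_envs)]


class LazyWrapper(object):
    def __init__(self, vector_env):
        self.vector_env = vector_env
        self.new_obs = None  # lazily initialized; no env interaction here

    def poll(self):
        # The first poll triggers the reset; construction had no side effects.
        if self.new_obs is None:
            self.new_obs = self.vector_env.vector_reset()
        return dict(enumerate(self.new_obs))


env = VectorEnvStub(num_envs=2)
wrapper = LazyWrapper(env)
assert not env.reset_called           # construction did not reset anything
assert wrapper.poll() == {0: 0, 1: 0}
assert env.reset_called               # reset happened on first poll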
@@ -55,6 +55,18 @@ class BadPolicyGraph(PolicyGraph):
         return compute_advantages(batch, 100.0, 0.9, use_gae=False)


+class FailOnStepEnv(gym.Env):
+    def __init__(self):
+        self.observation_space = gym.spaces.Discrete(1)
+        self.action_space = gym.spaces.Discrete(2)
+
+    def reset(self):
+        raise ValueError("kaboom")
+
+    def step(self, action):
+        raise ValueError("kaboom")
+
+
 class MockEnv(gym.Env):
     def __init__(self, episode_length, config=None):
         self.episode_length = episode_length
@@ -151,6 +163,11 @@ class TestPolicyEvaluator(unittest.TestCase):
         result2 = agent.train()
         self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)

+    def testNoStepOnInit(self):
+        register_env("fail", lambda _: FailOnStepEnv())
+        pg = PGAgent(env="fail", config={"num_workers": 1})
+        self.assertRaises(Exception, lambda: pg.train())
+
     def testCallbacks(self):
         counts = Counter()
         pg = PGAgent(
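The two test hunks above work together: FailOnStepEnv is a canary environment
that raises on any reset() or step(), and testNoStepOnInit checks that
constructing a PGAgent on it succeeds while only train() blows up. A
self-contained sketch of the same canary pattern, with hypothetical
build_agent/train stand-ins rather than RLlib's API:

class CanaryEnv(object):
    """Raises on any interaction; only construction may succeed."""

    def reset(self):
        raise ValueError("kaboom")

    def step(self, action):
        raise ValueError("kaboom")


def build_agent(env):
    # Hypothetical agent constructor: it must not reset or step the env.
    return {"env": env}


def train(agent):
    # Training legitimately interacts with the env, so this raises.
    return agent["env"].reset()


env = CanaryEnv()
agent = build_agent(env)   # passes: construction never touched the env
try:
    train(agent)
    raise AssertionError("train() should have raised")
except ValueError:
    pass                   # the canary fired exactly where expected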
@@ -74,8 +74,10 @@ class RunningStat(object):
     def push(self, x):
         x = np.asarray(x)
         # Unvectorized update of the running statistics.
-        assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}"
-                                          .format(x.shape, self._M.shape))
+        if x.shape != self._M.shape:
+            raise ValueError(
+                "Unexpected input shape {}, expected {}, value = {}".format(
+                    x.shape, self._M.shape, x))
         n1 = self._n
         self._n += 1
         if self._n == 1:
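Two things change in the hunk above: the shape check now survives python -O
(which strips assert statements), and the error message includes the offending
value. The guarded push() is a Welford-style running mean/variance update; a
simplified sketch of that update with the same validation (a stand-in, not
RLlib's full RunningStat):

import numpy as np


class SimpleRunningStat(object):
    """Welford-style running mean/variance over fixed-shape inputs."""

    def __init__(self, shape):
        self._n = 0
        self._M = np.zeros(shape)  # running mean
        self._S = np.zeros(shape)  # running sum of squared deviations

    def push(self, x):
        x = np.asarray(x)
        if x.shape != self._M.shape:
            raise ValueError(
                "Unexpected input shape {}, expected {}, value = {}".format(
                    x.shape, self._M.shape, x))
        self._n += 1
        if self._n == 1:
            self._M[...] = x
        else:
            delta = x - self._M
            self._M[...] += delta / self._n
            self._S[...] += delta * (x - self._M)

    @property
    def mean(self):
        return self._M


rs = SimpleRunningStat(shape=(2,))
for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
    rs.push(v)
assert np.allclose(rs.mean, [3.0, 4.0])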
@@ -45,10 +45,9 @@ class TFRunBuilder(object):
             self._executed = run_timeline(
                 self.session, self.fetches, self.debug_name,
                 self.feed_dict, os.environ.get("TF_TIMELINE_DIR"))
-        except Exception as e:
-            logger.error("Error fetching: {}, feed_dict={}".format(
+        except Exception:
+            raise ValueError("Error fetching: {}, feed_dict={}".format(
                 self.fetches, self.feed_dict))
-            raise e
         if isinstance(to_fetch, int):
             return self._executed[to_fetch]
         elif isinstance(to_fetch, list):
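The last hunk replaces log-then-reraise with a single ValueError that carries
the fetches and feed dict in its message. One caveat: raising a new exception
inside an except block relies on Python 3's implicit context chaining to keep
the original traceback visible; where Python 2 compatibility is not a concern,
an explicit "from" makes the link unambiguous. A hedged sketch of that idiom,
with a fake run_timeline standing in for the real session runner:

def run_timeline(fetches, feed_dict):
    # Fake stand-in for the real session runner; always fails.
    raise RuntimeError("session error")


def execute(fetches, feed_dict):
    try:
        return run_timeline(fetches, feed_dict)
    except Exception as e:
        # "from e" sets __cause__, so the original session error stays
        # attached to the wrapping ValueError's traceback.
        raise ValueError("Error fetching: {}, feed_dict={}".format(
            fetches, feed_dict)) from e


try:
    execute(["loss"], {"x": 1})
except ValueError as e:
    assert isinstance(e.__cause__, RuntimeError)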