[RLlib] TF2/eager memory leak fixes. (#19198)
Parent: 47447c71e0
Commit: d439fd7f17

8 changed files with 70 additions and 33 deletions
@@ -10,9 +10,9 @@ kaggle_environments==1.7.11
 # Unity3D testing
 mlagents_envs==0.27.0
 # For tests on PettingZoo's multi-agent envs.
-pettingzoo==1.11.0
+pettingzoo==1.11.1
 pymunk==6.0.0
-supersuit
+supersuit==2.6.6
 # For testing in MuJoCo-like envs (in PyBullet).
 pybullet==3.1.7
 # For tests on RecSim and Kaggle envs.
@@ -92,7 +92,7 @@ class TestPPO(unittest.TestCase):
         config["model"]["lstm_cell_size"] = 10
         config["model"]["max_seq_len"] = 20
         # Use default-native keras models whenever possible.
-        config["model"]["_use_default_native_models"] = True
+        # config["model"]["_use_default_native_models"] = True

         # Setup lr- and entropy schedules for testing.
         config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]]
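Note on the lr_schedule entry above: RLlib takes a list of [timestep, value] breakpoints and anneals the learning rate between them. A minimal sketch of the format follows; the base learning rate of 3e-4 is only an illustrative assumption, not a value from this diff.

    # Piecewise lr schedule: start at the base lr, anneal to 0.0 by timestep 128.
    config = {"lr": 3e-4}  # illustrative base value
    config["lr_schedule"] = [
        [0, config["lr"]],
        [128, 0.0],
    ]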
rllib/env/tests/test_remote_worker_envs.py

@@ -23,6 +23,7 @@ def env_creator(config):
 
 
+tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0"))
 
 tune.register_env("pistonball",
                   lambda config: PettingZooEnv(env_creator(config)))
 
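For context, a hedged sketch of how a name registered via tune.register_env is consumed: the registered string can be passed as the env of an RLlib trainer. The PPOTrainer choice and config values below are only illustrative assumptions.

    import gym
    from ray import tune
    from ray.rllib.agents.ppo import PPOTrainer  # illustrative algorithm choice

    # Register a named env creator, then refer to it by name.
    tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0"))
    trainer = PPOTrainer(config={"framework": "tf2", "num_workers": 0}, env="cartpole")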
@@ -756,8 +756,11 @@ class SimpleListCollector(SampleCollector):
                     "True. Alternatively, set no_done_at_end=True to "
                     "allow this.")

-            other_batches = pre_batches.copy()
-            del other_batches[agent_id]
+            if len(pre_batches) > 1:
+                other_batches = pre_batches.copy()
+                del other_batches[agent_id]
+            else:
+                other_batches = {}
             pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
             policy = self.policy_map[pid]
             if any(pre_batch[SampleBatch.DONES][:-1]) or len(
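The collector change above avoids copying the pre_batches dict (and deleting one key from the copy) when there is only a single agent, in which case other_batches can simply be empty. A standalone sketch of the same guard with hypothetical data:

    pre_batches = {"agent_0": ["step_data"]}  # hypothetical single-agent case
    agent_id = "agent_0"

    if len(pre_batches) > 1:
        other_batches = pre_batches.copy()
        del other_batches[agent_id]
    else:
        other_batches = {}  # nothing to copy in the single-agent case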
rllib/examples/env/random_env.py

@@ -65,3 +65,25 @@ class RandomEnv(gym.Env):
 
 # Multi-agent version of the RandomEnv.
 RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c))
+
+
+# Large observation space "pre-compiled" random env (for testing).
+class RandomLargeObsSpaceEnv(RandomEnv):
+    def __init__(self, config=None):
+        config = config or {}
+        config.update({
+            "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, ))
+        })
+        super().__init__(config=config)
+
+
+# Large observation space + cont. actions "pre-compiled" random env
+# (for testing).
+class RandomLargeObsSpaceEnvContActions(RandomEnv):
+    def __init__(self, config=None):
+        config = config or {}
+        config.update({
+            "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )),
+            "action_space": gym.spaces.Box(-1.0, 1.0, (5, )),
+        })
+        super().__init__(config=config)
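A hedged usage sketch for the two new test envs, assuming a Ray/RLlib installation that already contains this change and the classic gym step/reset API of that era:

    from ray.rllib.examples.env.random_env import (
        RandomLargeObsSpaceEnv, RandomLargeObsSpaceEnvContActions)

    env = RandomLargeObsSpaceEnvContActions()
    obs = env.reset()  # random observation sampled from Box(-1, 1, (5000,))
    obs, reward, done, info = env.step(env.action_space.sample())

    # The other variant only enlarges the observation space and keeps
    # RandomEnv's default action space.
    env2 = RandomLargeObsSpaceEnv()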
@@ -66,15 +66,17 @@ def convert_eager_inputs(func):
     @functools.wraps(func)
     def _func(*args, **kwargs):
         if tf.executing_eagerly():
-            args = [_convert_to_tf(x) for x in args]
+            eager_args = [_convert_to_tf(x) for x in args]
             # TODO: (sven) find a way to remove key-specific hacks.
-            kwargs = {
+            eager_kwargs = {
                 k: _convert_to_tf(
                     v, dtype=tf.int64 if k == "timestep" else None)
                 for k, v in kwargs.items()
+                if k not in {"info_batch", "episodes"}
             }
-            return func(*args, **kwargs)
+            return func(*eager_args, **eager_kwargs)
         else:
             return func(*args, **kwargs)

     return _func
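The convert_eager_inputs change binds the converted tensors to fresh names (eager_args / eager_kwargs) instead of overwriting the wrapper's own args / kwargs, keeping the eager-converted values separate from the original Python inputs. A minimal standalone sketch of the same decorator pattern (not RLlib's actual helper, and without its key-specific handling):

    import functools

    import tensorflow as tf

    def convert_inputs_to_tf(func):
        """Sketch: convert args/kwargs to tensors under separate names."""

        @functools.wraps(func)
        def _func(*args, **kwargs):
            if tf.executing_eagerly():
                eager_args = [tf.convert_to_tensor(a) for a in args]
                eager_kwargs = {
                    k: tf.convert_to_tensor(v) for k, v in kwargs.items()
                }
                return func(*eager_args, **eager_kwargs)
            return func(*args, **kwargs)

        return _func

    @convert_inputs_to_tf
    def weighted_sum(x, weight=1.0):
        return tf.reduce_sum(x) * weight

    print(weighted_sum([1.0, 2.0, 3.0], weight=2.0))  # tf.Tensor(12.0, ...)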
@@ -183,6 +185,14 @@ def traced_eager_policy(eager_policy_cls):
     return TracedEagerPolicy


+class OptimizerWrapper:
+    def __init__(self, tape):
+        self.tape = tape
+
+    def compute_gradients(self, loss, var_list):
+        return list(zip(self.tape.gradient(loss, var_list), var_list))
+
+
 def build_eager_tf_policy(
         name,
         loss_fn,
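OptimizerWrapper is now defined once at module level (it was previously re-created inside the gradient computation, see the hunk at -747 further down) and simply adapts a tf.GradientTape to an optimizer-style compute_gradients(loss, var_list) interface. A hedged usage sketch with a toy variable:

    import tensorflow as tf

    w = tf.Variable([1.0, 2.0, 3.0])
    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(w * w)

    # Returns the same (grad, var) pair format an optimizer's
    # compute_gradients() would produce.
    grads_and_vars = OptimizerWrapper(tape).compute_gradients(loss, [w])
    grad, var = grads_and_vars[0]  # grad == 2 * w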
@@ -433,6 +443,7 @@ def build_eager_tf_policy(
                     lambda s: tf.convert_to_tensor(s), obs_batch),
                 },
                 _is_training=tf.constant(False))
+            self._lazy_tensor_dict(input_dict)
             if prev_action_batch is not None:
                 input_dict[SampleBatch.PREV_ACTIONS] = \
                     tf.convert_to_tensor(prev_action_batch)
@@ -466,7 +477,6 @@ def build_eager_tf_policy(
                                                explore, timestep)

         @with_lock
-        @convert_eager_inputs
         @convert_eager_outputs
         def _compute_action_helper(self, input_dict, state_batches, episodes,
                                    explore, timestep):
@@ -482,7 +492,8 @@ def build_eager_tf_policy(
            self._is_training = False
            self._state_in = state_batches or []
            # Calculate RNN sequence lengths.
-           batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
+           batch_size = int(
+               tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0])
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
                else None

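The helper now casts the leading observation dimension to a plain Python int before it is used downstream (for seq_lens here and for the global timestep update in the next hunk). A small sketch of what that line computes, using dm-tree and a hypothetical observation dict:

    import numpy as np
    import tree  # dm-tree, imported as `tree`

    input_obs = {"cam": np.zeros((32, 84, 84, 3)), "state": np.zeros((32, 4))}
    batch_size = int(tree.flatten(input_obs)[0].shape[0])  # -> 32, a plain int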
@@ -567,7 +578,7 @@ def build_eager_tf_policy(
                extra_fetches.update(extra_action_out_fn(self))

            # Update our global timestep by the batch size.
-           self.global_timestep += int(batch_size)
+           self.global_timestep += batch_size

            return actions, state_out, extra_fetches

@@ -747,15 +758,6 @@ def build_eager_tf_policy(
            variables = self.model.trainable_variables()

            if compute_gradients_fn:
-
-               class OptimizerWrapper:
-                   def __init__(self, tape):
-                       self.tape = tape
-
-                   def compute_gradients(self, loss, var_list):
-                       return list(
-                           zip(self.tape.gradient(loss, var_list), var_list))
-
                grads_and_vars = compute_gradients_fn(self,
                                                      OptimizerWrapper(tape),
                                                      loss)
@@ -1,7 +1,7 @@
 import functools
 import gym
 import numpy as np
-from typing import Union
+from typing import Optional, Union

 from ray.rllib.models.action_dist import ActionDistribution
 from ray.rllib.models.modelv2 import ModelV2
@@ -61,11 +61,12 @@ class StochasticSampling(Exploration):
                 dtype=np.int64)

     @override(Exploration)
-    def get_exploration_action(self,
-                               *,
-                               action_distribution: ActionDistribution,
-                               timestep: Union[int, TensorType],
-                               explore: bool = True):
+    def get_exploration_action(
+            self,
+            *,
+            action_distribution: ActionDistribution,
+            timestep: Optional[Union[int, TensorType]] = None,
+            explore: bool = True):
         if self.framework == "torch":
             return self._get_torch_exploration_action(action_distribution,
                                                       timestep, explore)
@@ -74,7 +75,7 @@ class StochasticSampling(Exploration):
                                                       timestep, explore)

     def _get_tf_exploration_action_op(self, action_dist, timestep, explore):
-        ts = timestep if timestep is not None else self.last_timestep + 1
+        ts = self.last_timestep + 1

         stochastic_actions = tf.cond(
             pred=tf.convert_to_tensor(ts < self.random_timesteps),
@@ -100,10 +101,7 @@ class StochasticSampling(Exploration):

         # Increment `last_timestep` by 1 (or set to `timestep`).
         if self.framework in ["tf2", "tfe"]:
-            if timestep is None:
-                self.last_timestep.assign_add(1)
-            else:
-                self.last_timestep.assign(timestep)
+            self.last_timestep.assign_add(1)
             return action, logp
         else:
             assign_op = (tf1.assign_add(self.last_timestep, 1)
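With timestep now optional and no longer written into the counter on the tf-eager path, the exploration's own last_timestep variable is simply incremented unconditionally. A minimal sketch of that tf.Variable counter pattern:

    import tensorflow as tf

    last_timestep = tf.Variable(0, dtype=tf.int64)
    last_timestep.assign_add(1)  # in-place increment, no new variable created
    ts = last_timestep + 1       # next timestep used by the sampling op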
@@ -31,7 +31,8 @@ logger = logging.getLogger(__name__)

 def framework_iterator(config=None,
                        frameworks=("tf2", "tf", "tfe", "torch"),
-                       session=False):
+                       session=False,
+                       with_eager_tracing=False):
     """An generator that allows for looping through n frameworks for testing.

     Provides the correct config entries ("framework") as well
@@ -46,6 +47,8 @@ def framework_iterator(config=None,
             and yield that as second return value (otherwise yield (fw, None)).
             Also sets a seed (42) on the session to make the test
             deterministic.
+        with_eager_tracing: Include `eager_tracing=True` in the returned
+            configs, when framework=[tfe|tf2].

     Yields:
         str: If enter_session is False:
@@ -105,7 +108,15 @@ def framework_iterator(config=None,
         elif fw == "tf":
             assert not tf1.executing_eagerly()

-        yield fw if session is False else (fw, sess)
+        # Additionally loop through eager_tracing=True + False, if necessary.
+        if fw in ["tf2", "tfe"] and with_eager_tracing:
+            for tracing in [True, False]:
+                config["eager_tracing"] = tracing
+                yield fw if session is False else (fw, sess)
+            config["eager_tracing"] = False
+        # Yield current framework + tf-session (if necessary).
+        else:
+            yield fw if session is False else (fw, sess)

         # Exit any context we may have entered.
         if eager_ctx:
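A hedged usage sketch of the extended framework_iterator, assuming a Ray version containing this change: with with_eager_tracing=True, the tf2/tfe iterations are additionally run with eager_tracing switched on and off. The PPO config choice below is only illustrative.

    from ray.rllib.agents.ppo import ppo
    from ray.rllib.utils.test_utils import framework_iterator

    config = ppo.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0
    for fw in framework_iterator(
            config, frameworks=("tf2", "torch"), with_eager_tracing=True):
        # framework_iterator sets config["framework"] (and, for tf2,
        # config["eager_tracing"]) on each iteration.
        trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
        trainer.train()
        trainer.stop()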