[RLlib] TF2/eager memory leak fixes. (#19198)

Author: Sven Mika
Committed: 2021-10-09 00:11:53 +02:00 (committed via GitHub)
parent 47447c71e0
commit d439fd7f17
8 changed files with 70 additions and 33 deletions

Changed file 1 of 8

@@ -10,9 +10,9 @@ kaggle_environments==1.7.11
 # Unity3D testing
 mlagents_envs==0.27.0
 # For tests on PettingZoo's multi-agent envs.
-pettingzoo==1.11.0
+pettingzoo==1.11.1
 pymunk==6.0.0
-supersuit
+supersuit==2.6.6
 # For testing in MuJoCo-like envs (in PyBullet).
 pybullet==3.1.7
 # For tests on RecSim and Kaggle envs.

Changed file 2 of 8

@@ -92,7 +92,7 @@ class TestPPO(unittest.TestCase):
         config["model"]["lstm_cell_size"] = 10
         config["model"]["max_seq_len"] = 20
         # Use default-native keras models whenever possible.
-        config["model"]["_use_default_native_models"] = True
+        # config["model"]["_use_default_native_models"] = True
         # Setup lr- and entropy schedules for testing.
         config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]]

Changed file 3 of 8

@@ -23,6 +23,7 @@ def env_creator(config):
+tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0"))
 tune.register_env("pistonball",
                   lambda config: PettingZooEnv(env_creator(config)))
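
For context: `tune.register_env` binds a string env ID to a creator callable,
and RLlib resolves `config["env"]` through this registry when a trainer is
built. A minimal usage sketch (the PPO trainer choice is illustrative only and
not part of this commit):

    import gym
    import ray
    from ray import tune
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    # The creator receives an EnvContext (a dict-like config) and must
    # return a gym.Env instance.
    tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0"))
    trainer = PPOTrainer(config={"framework": "tf2"}, env="cartpole")
    print(trainer.train()["episode_reward_mean"])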

Changed file 4 of 8

@@ -756,8 +756,11 @@ class SimpleListCollector(SampleCollector):
                     "True. Alternatively, set no_done_at_end=True to "
                     "allow this.")

-            other_batches = pre_batches.copy()
-            del other_batches[agent_id]
+            if len(pre_batches) > 1:
+                other_batches = pre_batches.copy()
+                del other_batches[agent_id]
+            else:
+                other_batches = {}
             pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
             policy = self.policy_map[pid]
             if any(pre_batch[SampleBatch.DONES][:-1]) or len(
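
The guard above avoids copying the whole per-agent batch dict on every
postprocessing call when there is only one agent (the common case). A
standalone sketch of the pattern, with simplified names:

    def split_other_batches(pre_batches, agent_id):
        # Only pay for the dict copy when there actually are other agents.
        if len(pre_batches) > 1:
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
        else:
            other_batches = {}
        return other_batches

    assert split_other_batches({"a": 1}, "a") == {}
    assert split_other_batches({"a": 1, "b": 2}, "a") == {"b": 2}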

Changed file 5 of 8

@@ -65,3 +65,25 @@
 # Multi-agent version of the RandomEnv.
 RandomMultiAgentEnv = make_multi_agent(lambda c: RandomEnv(c))
+
+
+# Large observation space "pre-compiled" random env (for testing).
+class RandomLargeObsSpaceEnv(RandomEnv):
+    def __init__(self, config=None):
+        config = config or {}
+        config.update({
+            "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, ))
+        })
+        super().__init__(config=config)
+
+
+# Large observation space + cont. actions "pre-compiled" random env
+# (for testing).
+class RandomLargeObsSpaceEnvContActions(RandomEnv):
+    def __init__(self, config=None):
+        config = config or {}
+        config.update({
+            "observation_space": gym.spaces.Box(-1.0, 1.0, (5000, )),
+            "action_space": gym.spaces.Box(-1.0, 1.0, (5, )),
+        })
+        super().__init__(config=config)
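
A rough usage sketch for the two new test envs (assuming the usual location
of RandomEnv in the RLlib examples tree):

    from ray.rllib.examples.env.random_env import RandomLargeObsSpaceEnv

    env = RandomLargeObsSpaceEnv()
    obs = env.reset()
    assert obs.shape == (5000,)
    obs, reward, done, info = env.step(env.action_space.sample())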

Changed file 6 of 8

@@ -66,14 +66,16 @@ def convert_eager_inputs(func):
     @functools.wraps(func)
     def _func(*args, **kwargs):
         if tf.executing_eagerly():
-            args = [_convert_to_tf(x) for x in args]
+            eager_args = [_convert_to_tf(x) for x in args]
             # TODO: (sven) find a way to remove key-specific hacks.
-            kwargs = {
+            eager_kwargs = {
                 k: _convert_to_tf(
                     v, dtype=tf.int64 if k == "timestep" else None)
                 for k, v in kwargs.items()
                 if k not in {"info_batch", "episodes"}
             }
-        return func(*args, **kwargs)
+            return func(*eager_args, **eager_kwargs)
+        else:
+            return func(*args, **kwargs)

     return _func
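
Two things changed here: inputs are converted only while TF is actually
executing eagerly (during `tf.function` tracing, `tf.executing_eagerly()`
returns False, and conversions at that point bake extra constants into the
traced graph), and the converted values get fresh names
(`eager_args`/`eager_kwargs`) instead of rebinding `args`/`kwargs`. A small
demo of the tracing-time behavior the branch keys on:

    import numpy as np
    import tensorflow as tf

    @tf.function
    def f(x):
        # Printed once, while f is being traced: eager mode is off here,
        # so input conversion is left to the tracing machinery instead.
        print("eager during trace?", tf.executing_eagerly())  # -> False
        return x + 1

    f(np.zeros(16))  # traced on first call; same-shape calls reuse the trace
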
@@ -183,6 +185,14 @@ def traced_eager_policy(eager_policy_cls):
     return TracedEagerPolicy


+class OptimizerWrapper:
+    def __init__(self, tape):
+        self.tape = tape
+
+    def compute_gradients(self, loss, var_list):
+        return list(zip(self.tape.gradient(loss, var_list), var_list))
+
+
 def build_eager_tf_policy(
     name,
     loss_fn,
@@ -433,6 +443,7 @@ def build_eager_tf_policy(
                         lambda s: tf.convert_to_tensor(s), obs_batch),
                 },
                 _is_training=tf.constant(False))
+            self._lazy_tensor_dict(input_dict)
             if prev_action_batch is not None:
                 input_dict[SampleBatch.PREV_ACTIONS] = \
                     tf.convert_to_tensor(prev_action_batch)
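
`_lazy_tensor_dict` wraps the input dict so that values are converted to
tensors on first access rather than all upfront. A simplified sketch of the
idea (not RLlib's actual implementation, which lives on SampleBatch):

    import tensorflow as tf

    class LazyTensorDict(dict):
        """Converts a value to a tensor the first time it is read."""

        def __getitem__(self, key):
            value = super().__getitem__(key)
            if not tf.is_tensor(value):
                value = tf.convert_to_tensor(value)
                super().__setitem__(key, value)  # cache the converted value
            return value
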
@@ -466,7 +477,6 @@ def build_eager_tf_policy(
                                              explore, timestep)

         @with_lock
-        @convert_eager_inputs
         @convert_eager_outputs
         def _compute_action_helper(self, input_dict, state_batches, episodes,
                                    explore, timestep):
@@ -482,7 +492,8 @@ def build_eager_tf_policy(
             self._is_training = False
             self._state_in = state_batches or []
             # Calculate RNN sequence lengths.
-            batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
+            batch_size = int(
+                tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0])
             seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches \
                 else None
@@ -567,7 +578,7 @@ def build_eager_tf_policy(
                 extra_fetches.update(extra_action_out_fn(self))

             # Update our global timestep by the batch size.
-            self.global_timestep += int(batch_size)
+            self.global_timestep += batch_size

             return actions, state_out, extra_fetches
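
These two hunks move the int() cast from the increment site to where
`batch_size` is computed, so the batch size is pinned to a plain Python int
once and every consumer (`seq_lens`, the `global_timestep` counter) stays on
the Python side. A minimal sketch of the difference:

    import tensorflow as tf

    x = tf.zeros((32, 4))
    batch_size = int(x.shape[0])   # plain Python int from the static shape
    global_timestep = 0
    global_timestep += batch_size  # pure-Python arithmetic, no tensors made
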
@@ -747,15 +758,6 @@ def build_eager_tf_policy(
             variables = self.model.trainable_variables()

             if compute_gradients_fn:
-
-                class OptimizerWrapper:
-                    def __init__(self, tape):
-                        self.tape = tape
-
-                    def compute_gradients(self, loss, var_list):
-                        return list(
-                            zip(self.tape.gradient(loss, var_list), var_list))
-
                 grads_and_vars = compute_gradients_fn(self,
                                                       OptimizerWrapper(tape),
                                                       loss)
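
The removed block is the same OptimizerWrapper this commit hoists to module
level (see the earlier hunk). Defining the class inside a method that runs on
every gradient update builds a brand-new class object per call, and anything
that caches per class (e.g. tracing machinery) can then never reuse its
cache. A condensed before/after sketch:

    # Before: a fresh class object is created on every call.
    def learn_on_batch_old(tape):
        class OptimizerWrapper:
            def __init__(self, tape):
                self.tape = tape
        return OptimizerWrapper(tape)

    # After: the class is created once at import time and only
    # instantiated per call.
    class OptimizerWrapper:
        def __init__(self, tape):
            self.tape = tape

    def learn_on_batch_new(tape):
        return OptimizerWrapper(tape)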

Changed file 7 of 8

@@ -1,7 +1,7 @@
 import functools
 import gym
 import numpy as np
-from typing import Union
+from typing import Optional, Union

 from ray.rllib.models.action_dist import ActionDistribution
 from ray.rllib.models.modelv2 import ModelV2
@@ -61,10 +61,11 @@ class StochasticSampling(Exploration):
             dtype=np.int64)

     @override(Exploration)
-    def get_exploration_action(self,
-                               *,
-                               action_distribution: ActionDistribution,
-                               timestep: Union[int, TensorType],
-                               explore: bool = True):
+    def get_exploration_action(
+            self,
+            *,
+            action_distribution: ActionDistribution,
+            timestep: Optional[Union[int, TensorType]] = None,
+            explore: bool = True):
         if self.framework == "torch":
             return self._get_torch_exploration_action(action_distribution,
@@ -74,7 +75,7 @@ class StochasticSampling(Exploration):
                                                      timestep, explore)

     def _get_tf_exploration_action_op(self, action_dist, timestep, explore):
-        ts = timestep if timestep is not None else self.last_timestep + 1
+        ts = self.last_timestep + 1

         stochastic_actions = tf.cond(
             pred=tf.convert_to_tensor(ts < self.random_timesteps),
@@ -100,10 +101,7 @@ class StochasticSampling(Exploration):
         # Increment `last_timestep` by 1 (or set to `timestep`).
         if self.framework in ["tf2", "tfe"]:
-            if timestep is None:
-                self.last_timestep.assign_add(1)
-            else:
-                self.last_timestep.assign(timestep)
+            self.last_timestep.assign_add(1)
             return action, logp
         else:
             assign_op = (tf1.assign_add(self.last_timestep, 1)
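
With the `timestep` argument now unused on the TF path, the counter is always
advanced in place via `assign_add` on a `tf.Variable`, which behaves the same
eagerly and under tracing and creates no new ops per call once traced. A
self-contained demo:

    import tensorflow as tf

    last_timestep = tf.Variable(0, dtype=tf.int64)

    @tf.function
    def sample_step():
        last_timestep.assign_add(1)  # in-place state update in the graph
        return last_timestep.value()

    for _ in range(3):
        sample_step()
    print(int(last_timestep))  # -> 3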

Changed file 8 of 8

@@ -31,7 +31,8 @@ logger = logging.getLogger(__name__)
 def framework_iterator(config=None,
                        frameworks=("tf2", "tf", "tfe", "torch"),
-                       session=False):
+                       session=False,
+                       with_eager_tracing=False):
     """A generator that allows for looping through n frameworks for testing.

     Provides the correct config entries ("framework") as well
@@ -46,6 +47,8 @@ def framework_iterator(config=None,
             and yield that as second return value (otherwise yield (fw, None)).
             Also sets a seed (42) on the session to make the test
             deterministic.
+        with_eager_tracing: Include `eager_tracing=True` in the returned
+            configs, when framework=[tfe|tf2].

     Yields:
         str: If enter_session is False:
@@ -105,6 +108,14 @@ def framework_iterator(config=None,
         elif fw == "tf":
             assert not tf1.executing_eagerly()

-        yield fw if session is False else (fw, sess)
+        # Additionally loop through eager_tracing=True + False, if necessary.
+        if fw in ["tf2", "tfe"] and with_eager_tracing:
+            for tracing in [True, False]:
+                config["eager_tracing"] = tracing
+                yield fw if session is False else (fw, sess)
+                config["eager_tracing"] = False
+        # Yield current framework + tf-session (if necessary).
+        else:
+            yield fw if session is False else (fw, sess)

         # Exit any context we may have entered.
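
A hypothetical test body showing the new flag in use; with
`with_eager_tracing=True`, the tf2/tfe frameworks are yielded twice, once per
`eager_tracing` setting:

    from ray.rllib.agents import ppo
    from ray.rllib.utils.test_utils import framework_iterator

    config = ppo.DEFAULT_CONFIG.copy()
    for fw in framework_iterator(config, frameworks=("tf2", "torch"),
                                 with_eager_tracing=True):
        trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
        trainer.train()
        trainer.stop()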