[rllib] Raise worker TF intra_op threads to 2, lower driver intra_op threads to 8 (#3299)

This commit is contained in:
Eric Liang 2018-11-13 11:41:58 -08:00 committed by GitHub
parent c0423db05c
commit 6ee7a3b571
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -83,9 +83,9 @@ COMMON_CONFIG = {
"synchronize_filters": True,
# Configure TF for single-process operation by default
"tf_session_args": {
# note: parallelism_threads is set to auto for the local evaluator
"intra_op_parallelism_threads": 1,
"inter_op_parallelism_threads": 1,
# note: overriden by `local_evaluator_tf_session_args`
"intra_op_parallelism_threads": 2,
"inter_op_parallelism_threads": 2,
"gpu_options": {
"allow_growth": True,
},
@ -95,6 +95,13 @@ COMMON_CONFIG = {
},
"allow_soft_placement": True, # required by PPO multi-gpu
},
# Override the following tf session args on the local evaluator
"local_evaluator_tf_session_args": {
# Allow a higher level of parallelism by default, but not unlimited
# since that can cause crashes with many concurrent drivers.
"intra_op_parallelism_threads": 8,
"inter_op_parallelism_threads": 8,
},
# Whether to LZ4 compress observations
"compress_observations": False,
# Allocate a fraction of a GPU instead of one (e.g., 0.3 GPUs)
@ -150,13 +157,10 @@ class Agent(Trainable):
env_creator,
policy_graph,
0,
# important: allow local tf to use multiple CPUs for optimization
merge_dicts(
self.config, {
"tf_session_args": {
"intra_op_parallelism_threads": None,
"inter_op_parallelism_threads": None,
}
# important: allow local tf to use more CPUs for optimization
merge_dicts(self.config, {
"tf_session_args": self.
config["local_evaluator_tf_session_args"]
}))
def make_remote_evaluators(self, env_creator, policy_graph, count,
@ -172,6 +176,8 @@ class Agent(Trainable):
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
return tf.Session(
config=tf.ConfigProto(**config["tf_session_args"]))