mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Raise worker TF intra_op threads to 2, lower driver intra_op threads to 8 (#3299)
This commit is contained in:
parent
c0423db05c
commit
6ee7a3b571
1 changed files with 17 additions and 11 deletions
|
@ -83,9 +83,9 @@ COMMON_CONFIG = {
|
|||
"synchronize_filters": True,
|
||||
# Configure TF for single-process operation by default
|
||||
"tf_session_args": {
|
||||
# note: parallelism_threads is set to auto for the local evaluator
|
||||
"intra_op_parallelism_threads": 1,
|
||||
"inter_op_parallelism_threads": 1,
|
||||
# note: overriden by `local_evaluator_tf_session_args`
|
||||
"intra_op_parallelism_threads": 2,
|
||||
"inter_op_parallelism_threads": 2,
|
||||
"gpu_options": {
|
||||
"allow_growth": True,
|
||||
},
|
||||
|
@ -95,6 +95,13 @@ COMMON_CONFIG = {
|
|||
},
|
||||
"allow_soft_placement": True, # required by PPO multi-gpu
|
||||
},
|
||||
# Override the following tf session args on the local evaluator
|
||||
"local_evaluator_tf_session_args": {
|
||||
# Allow a higher level of parallelism by default, but not unlimited
|
||||
# since that can cause crashes with many concurrent drivers.
|
||||
"intra_op_parallelism_threads": 8,
|
||||
"inter_op_parallelism_threads": 8,
|
||||
},
|
||||
# Whether to LZ4 compress observations
|
||||
"compress_observations": False,
|
||||
# Allocate a fraction of a GPU instead of one (e.g., 0.3 GPUs)
|
||||
|
@ -150,14 +157,11 @@ class Agent(Trainable):
|
|||
env_creator,
|
||||
policy_graph,
|
||||
0,
|
||||
# important: allow local tf to use multiple CPUs for optimization
|
||||
merge_dicts(
|
||||
self.config, {
|
||||
"tf_session_args": {
|
||||
"intra_op_parallelism_threads": None,
|
||||
"inter_op_parallelism_threads": None,
|
||||
}
|
||||
}))
|
||||
# important: allow local tf to use more CPUs for optimization
|
||||
merge_dicts(self.config, {
|
||||
"tf_session_args": self.
|
||||
config["local_evaluator_tf_session_args"]
|
||||
}))
|
||||
|
||||
def make_remote_evaluators(self, env_creator, policy_graph, count,
|
||||
remote_args):
|
||||
|
@ -172,6 +176,8 @@ class Agent(Trainable):
|
|||
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
|
||||
config):
|
||||
def session_creator():
|
||||
logger.debug("Creating TF session {}".format(
|
||||
config["tf_session_args"]))
|
||||
return tf.Session(
|
||||
config=tf.ConfigProto(**config["tf_session_args"]))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue