[RLlib] Faster remote worker space inference (don't infer if not required). (#18805)

Sven Mika 2021-09-23 10:54:37 +02:00 committed by GitHub
parent 361cae4d1c
commit a2a077b874
2 changed files with 52 additions and 4 deletions
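
The practical effect: when a Trainer is built with remote workers but no env on the local worker, the observation/action spaces needed to construct the local policies must be fetched from a remote worker, costing a ray.get() round-trip. Below is a minimal sketch of how a user can avoid that lookup entirely; it assumes the Ray 1.x pg.PGTrainer / pg.DEFAULT_CONFIG API that the test in this diff also uses, and is not part of the commit itself:

import gym
import ray
from ray.rllib.agents import pg

ray.init()
env = gym.make("CartPole-v0")

config = pg.DEFAULT_CONFIG.copy()
config["env"] = "CartPole-v0"
config["num_workers"] = 1
# Either of the following lets WorkerSet skip the remote space lookup:
config["observation_space"] = env.observation_space
config["action_space"] = env.action_space
# ... or alternatively: config["create_env_on_driver"] = True

trainer = pg.PGTrainer(config=config)  # builds without the remote round-trip
trainer.stop()
ray.shutdown()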


@@ -1,6 +1,8 @@
 import copy
+import gym
 import numpy as np
 from random import choice
+import time
 import unittest
 
 import ray
@@ -184,6 +186,45 @@ class TestTrainer(unittest.TestCase):
         trainer_w_env_on_driver.stop()
         config["create_env_on_driver"] = False
 
+    def test_space_inference_from_remote_workers(self):
+        # Expect no space inference from a remote worker if the learner
+        # already has an env or if the spaces are given in the config.
+        env = gym.make("CartPole-v0")
+
+        config = pg.DEFAULT_CONFIG.copy()
+        config["env"] = "CartPole-v0"
+        config["num_workers"] = 1
+
+        # No env on driver -> expect longer build time due to space
+        # "lookup" from remote worker.
+        t0 = time.time()
+        trainer = pg.PGTrainer(config=config)
+        w_lookup = time.time() - t0
+        print(f"No env on learner: {w_lookup}sec")
+        trainer.stop()
+
+        # Env on driver -> expect shorter build time, since no space
+        # "lookup" from a remote worker is required.
+        config["create_env_on_driver"] = True
+        t0 = time.time()
+        trainer = pg.PGTrainer(config=config)
+        wo_lookup = time.time() - t0
+        print(f"Env on learner: {wo_lookup}sec")
+        self.assertLess(wo_lookup, w_lookup)
+        trainer.stop()
+
+        # Spaces given -> expect shorter build time due to no space
+        # "lookup" from remote worker.
+        config["create_env_on_driver"] = False
+        config["observation_space"] = env.observation_space
+        config["action_space"] = env.action_space
+        t0 = time.time()
+        trainer = pg.PGTrainer(config=config)
+        wo_lookup = time.time() - t0
+        print(f"Spaces given manually in config: {wo_lookup}sec")
+        self.assertLess(wo_lookup, w_lookup)
+        trainer.stop()
+
 
 if __name__ == "__main__":
     import pytest


@@ -77,9 +77,13 @@ class WorkerSet:
         self._remote_workers = []
         self.add_workers(num_workers)
 
-        # If num_workers > 0, get the action_spaces and observation_spaces
-        # to not be forced to create an Env on the local worker.
-        if self._remote_workers:
+        # If num_workers > 0 and there is no env on the local worker,
+        # get the observation and action spaces for each policy from
+        # the first remote worker (which does have an env).
+        if self._remote_workers and \
+                not trainer_config.get("create_env_on_driver") and \
+                (not trainer_config.get("observation_space") or
+                 not trainer_config.get("action_space")):
             remote_spaces = ray.get(self.remote_workers(
             )[0].foreach_policy.remote(
                 lambda p, pid: (pid, p.observation_space, p.action_space)))
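
Restated outside the class for readability, the added gate amounts to the predicate below. This is a hypothetical helper for illustration only (needs_remote_space_inference does not exist in the commit); it assumes the same trainer_config keys used in the hunk above:

def needs_remote_space_inference(remote_workers, trainer_config):
    # Only ask a remote worker for spaces when (a) remote workers exist,
    # (b) the local worker has no env of its own, and (c) the spaces were
    # not already provided explicitly in the config.
    return bool(remote_workers) \
        and not trainer_config.get("create_env_on_driver") \
        and (not trainer_config.get("observation_space")
             or not trainer_config.get("action_space"))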
@@ -96,6 +100,9 @@ class WorkerSet:
                     spaces["__env__"] = env_spaces
             except Exception:
                 pass
+
+            logger.info("Inferred observation/action spaces from remote "
+                        f"worker (local worker has no env): {spaces}")
         else:
             spaces = None
 
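
For context, a sketch of what the inferred spaces dict logged above could look like for a single-policy CartPole setup. This is an assumption for illustration: the policy ID and the exact Box bounds are illustrative (RLlib's default policy ID is "default_policy"; CartPole-v0's real Box limits differ), and only the mapping shape, policy ID to (observation_space, action_space) plus the special "__env__" entry, is taken from the code above:

import numpy as np
from gym.spaces import Box, Discrete

# Illustrative shape of the inferred spaces dict: each policy ID maps to a
# (observation_space, action_space) tuple; "__env__" carries the env's own
# spaces when they could be retrieved from the remote worker.
spaces = {
    "default_policy": (
        Box(-np.inf, np.inf, shape=(4,), dtype=np.float32),  # observations
        Discrete(2),                                         # actions
    ),
    "__env__": (
        Box(-np.inf, np.inf, shape=(4,), dtype=np.float32),
        Discrete(2),
    ),
}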
@@ -127,7 +134,7 @@ class WorkerSet:
             e.set_weights.remote(weights)
 
     def add_workers(self, num_workers: int) -> None:
-        """Creates and add a number of remote workers to this worker set.
+        """Creates and adds a number of remote workers to this worker set.
 
         Args:
             num_workers (int): The number of remote Workers to add to this