Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Faster remote worker space inference (don't infer if not required). (#18805)

parent 361cae4d1c
commit a2a077b874

2 changed files with 52 additions and 4 deletions
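In user terms, the expensive path is inferring the observation/action spaces from a remote worker at trainer-build time. As the diff below shows, that lookup is now skipped whenever the driver will have its own env, or the spaces are supplied explicitly in the config. A hedged sketch of the relevant config keys (the keys appear verbatim in the diff; everything else here is illustrative):

    import gym

    # Illustrative only: any Gym env with the desired spaces works.
    env = gym.make("CartPole-v0")

    # Either let the driver build its own env (spaces come from it) ...
    config["create_env_on_driver"] = True

    # ... or hand the spaces over directly, so neither a driver env nor
    # a remote-worker lookup is needed.
    config["observation_space"] = env.observation_space
    config["action_space"] = env.action_space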
@@ -1,6 +1,8 @@
 import copy
+import gym
 import numpy as np
 from random import choice
+import time
 import unittest

 import ray
@@ -184,6 +186,45 @@ class TestTrainer(unittest.TestCase):
         trainer_w_env_on_driver.stop()
         config["create_env_on_driver"] = False

+    def test_space_inference_from_remote_workers(self):
+        # Expect no space inference (and hence a faster build) when the
+        # learner already has an env or the spaces are given in the config.
+        env = gym.make("CartPole-v0")
+
+        config = pg.DEFAULT_CONFIG.copy()
+        config["env"] = "CartPole-v0"
+        config["num_workers"] = 1
+
+        # No env on driver -> expect longer build time due to space
+        # "lookup" from the remote worker.
+        t0 = time.time()
+        trainer = pg.PGTrainer(config=config)
+        w_lookup = time.time() - t0
+        print(f"No env on learner: {w_lookup}sec")
+        trainer.stop()
+
+        # Env on driver -> expect shorter build time due to no space
+        # "lookup" from the remote worker.
+        config["create_env_on_driver"] = True
+        t0 = time.time()
+        trainer = pg.PGTrainer(config=config)
+        wo_lookup = time.time() - t0
+        print(f"Env on learner: {wo_lookup}sec")
+        self.assertLess(wo_lookup, w_lookup)
+        trainer.stop()
+
+        # Spaces given -> expect shorter build time due to no space
+        # "lookup" from the remote worker.
+        config["create_env_on_driver"] = False
+        config["observation_space"] = env.observation_space
+        config["action_space"] = env.action_space
+        t0 = time.time()
+        trainer = pg.PGTrainer(config=config)
+        wo_lookup = time.time() - t0
+        print(f"Spaces given manually in config: {wo_lookup}sec")
+        self.assertLess(wo_lookup, w_lookup)
+        trainer.stop()
+

 if __name__ == "__main__":
     import pytest
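The test body uses pg.DEFAULT_CONFIG and pg.PGTrainer without their import appearing in this hunk; presumably the module already imports the PG agent. A minimal standalone sketch of the same timing comparison, under that assumption ("from ray.rllib.agents import pg" matches Ray of this era, but is not shown in the diff):

    import time

    import gym
    import ray
    from ray.rllib.agents import pg  # assumed import; not part of the diff

    ray.init()

    config = pg.DEFAULT_CONFIG.copy()
    config["env"] = "CartPole-v0"
    config["num_workers"] = 1

    # Slow path: spaces must be fetched from the remote worker.
    t0 = time.time()
    trainer = pg.PGTrainer(config=config)
    print("inferred from remote worker:", time.time() - t0, "sec")
    trainer.stop()

    # Fast path: spaces given up front, no remote lookup needed.
    env = gym.make("CartPole-v0")
    config["observation_space"] = env.observation_space
    config["action_space"] = env.action_space
    t0 = time.time()
    trainer = pg.PGTrainer(config=config)
    print("spaces given in config:", time.time() - t0, "sec")
    trainer.stop()

    ray.shutdown()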
@@ -77,9 +77,13 @@ class WorkerSet:
         self._remote_workers = []
         self.add_workers(num_workers)

-        # If num_workers > 0, get the action_spaces and observation_spaces
-        # to not be forced to create an Env on the local worker.
-        if self._remote_workers:
+        # If num_workers > 0 and we don't have an env on the local worker,
+        # get the observation- and action spaces for each policy from
+        # the first remote worker (which does have an env).
+        if self._remote_workers and \
+                not trainer_config.get("create_env_on_driver") and \
+                (not trainer_config.get("observation_space") or
+                 not trainer_config.get("action_space")):
             remote_spaces = ray.get(self.remote_workers(
             )[0].foreach_policy.remote(
                 lambda p, pid: (pid, p.observation_space, p.action_space)))
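The hunk cuts off before showing how remote_spaces is consumed. Given that foreach_policy here returns one (policy_id, observation_space, action_space) tuple per policy, a plausible continuation (an assumption; the actual follow-up code lies outside this hunk) folds the list into a dict keyed by policy ID:

    # Sketch only: turn [(pid, obs_space, act_space), ...] into
    # {pid: (obs_space, act_space), ...} for later worker construction.
    spaces = {
        pid: (obs_space, act_space)
        for pid, obs_space, act_space in remote_spaces
    }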
@@ -96,6 +100,9 @@ class WorkerSet:
                     spaces["__env__"] = env_spaces
                 except Exception:
                     pass
+
+            logger.info("Inferred observation/action spaces from remote "
+                        f"worker (local worker has no env): {spaces}")
         else:
             spaces = None

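For orientation, the logged spaces dict for a single-policy CartPole setup would plausibly look like the following (illustrative values only; the "__env__" key comes from the preceding hunk):

    # Illustrative shape of the inferred dict for CartPole-v0:
    # {
    #     "default_policy": (Box(4,), Discrete(2)),
    #     "__env__": (Box(4,), Discrete(2)),
    # }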
@@ -127,7 +134,7 @@ class WorkerSet:
             e.set_weights.remote(weights)

     def add_workers(self, num_workers: int) -> None:
-        """Creates and add a number of remote workers to this worker set.
+        """Creates and adds a number of remote workers to this worker set.

         Args:
             num_workers (int): The number of remote Workers to add to this
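For completeness, a hedged usage sketch of the API whose docstring is fixed here (how the WorkerSet was constructed is an assumption based on this diff):

    # Hypothetical: grow an already-constructed WorkerSet named `workers`
    # by two more remote rollout workers.
    workers.add_workers(2)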