[rllib] Fix tune.run(Agent class) (#4630)

* update

* Update __init__.py
This commit is contained in:
Eric Liang 2019-04-15 09:12:23 -07:00 committed by GitHub
parent 776a7308c8
commit 3fd9dea721
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 2 deletions

View file

@ -289,6 +289,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_legacy.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_io.py /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_io.py

View file

@ -0,0 +1,15 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.ppo import PPOAgent
from ray import tune
import ray
if __name__ == "__main__":
ray.init()
# Test legacy *Agent classes work (renamed to Trainer)
tune.run(
PPOAgent,
config={"env": "CartPole-v0"},
stop={"training_iteration": 2})

View file

@ -10,14 +10,18 @@ logger = logging.getLogger(__name__)
def renamed_class(cls): def renamed_class(cls):
"""Helper class for renaming Agent => Trainer with a warning."""
class DeprecationWrapper(cls): class DeprecationWrapper(cls):
def __init__(self, *args, **kwargs): def __init__(self, config=None, env=None, logger_creator=None):
old_name = cls.__name__.replace("Trainer", "Agent") old_name = cls.__name__.replace("Trainer", "Agent")
new_name = cls.__name__ new_name = cls.__name__
logger.warn("DeprecationWarning: {} has been renamed to {}. ". logger.warn("DeprecationWarning: {} has been renamed to {}. ".
format(old_name, new_name) + format(old_name, new_name) +
"This will raise an error in the future.") "This will raise an error in the future.")
cls.__init__(self, *args, **kwargs) cls.__init__(self, config, env, logger_creator)
DeprecationWrapper.__name__ = cls.__name__
return DeprecationWrapper return DeprecationWrapper