diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index 8012ce652..efe30a0a7 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -289,6 +289,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_legacy.py + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_io.py diff --git a/python/ray/rllib/tests/test_legacy.py b/python/ray/rllib/tests/test_legacy.py new file mode 100644 index 000000000..ae8285881 --- /dev/null +++ b/python/ray/rllib/tests/test_legacy.py @@ -0,0 +1,15 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.agents.ppo import PPOAgent +from ray import tune +import ray + +if __name__ == "__main__": + ray.init() + # Test legacy *Agent classes work (renamed to Trainer) + tune.run( + PPOAgent, + config={"env": "CartPole-v0"}, + stop={"training_iteration": 2}) diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index 5e7b2a141..7aab0f2a0 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -10,14 +10,18 @@ logger = logging.getLogger(__name__) def renamed_class(cls): + """Helper class for renaming Agent => Trainer with a warning.""" + class DeprecationWrapper(cls): - def __init__(self, *args, **kwargs): + def __init__(self, config=None, env=None, logger_creator=None): old_name = cls.__name__.replace("Trainer", "Agent") new_name = cls.__name__ logger.warn("DeprecationWarning: {} has been renamed to {}. ". format(old_name, new_name) + "This will raise an error in the future.") - cls.__init__(self, *args, **kwargs) + cls.__init__(self, config, env, logger_creator) + + DeprecationWrapper.__name__ = cls.__name__ return DeprecationWrapper