mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[rllib] Fix tune.run(Agent class) (#4630)
* update * Update __init__.py
This commit is contained in:
parent
56a78baf67
commit
3e234fe937
3 changed files with 24 additions and 2 deletions
|
@ -289,6 +289,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
|||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_legacy.py
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_io.py
|
||||
|
||||
|
|
15
python/ray/rllib/tests/test_legacy.py
Normal file
15
python/ray/rllib/tests/test_legacy.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.ppo import PPOAgent
|
||||
from ray import tune
|
||||
import ray
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
# Test legacy *Agent classes work (renamed to Trainer)
|
||||
tune.run(
|
||||
PPOAgent,
|
||||
config={"env": "CartPole-v0"},
|
||||
stop={"training_iteration": 2})
|
|
@ -10,14 +10,18 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def renamed_class(cls):
|
||||
"""Helper class for renaming Agent => Trainer with a warning."""
|
||||
|
||||
class DeprecationWrapper(cls):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, config=None, env=None, logger_creator=None):
|
||||
old_name = cls.__name__.replace("Trainer", "Agent")
|
||||
new_name = cls.__name__
|
||||
logger.warn("DeprecationWarning: {} has been renamed to {}. ".
|
||||
format(old_name, new_name) +
|
||||
"This will raise an error in the future.")
|
||||
cls.__init__(self, *args, **kwargs)
|
||||
cls.__init__(self, config, env, logger_creator)
|
||||
|
||||
DeprecationWrapper.__name__ = cls.__name__
|
||||
|
||||
return DeprecationWrapper
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue