mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] PyBullet Env native support via env str-specifier (if installed). (#12209)
This commit is contained in:
parent
b85c6abc3e
commit
bb03e2499b
5 changed files with 47 additions and 5 deletions
|
@ -283,7 +283,7 @@ install_dependencies() {
|
|||
fi
|
||||
fi
|
||||
|
||||
# Additional RLlib dependencies.
|
||||
# Additional RLlib test dependencies.
|
||||
if [ "${RLLIB_TESTING-}" = 1 ]; then
|
||||
pip install -r "${WORKSPACE_DIR}"/python/requirements_rllib.txt
|
||||
# install the following packages for testing on travis only
|
||||
|
|
|
@ -5,5 +5,7 @@ torch>=1.6.0
|
|||
# Version requirement to match Tune
|
||||
torchvision>=0.6.0
|
||||
smart_open
|
||||
# For testing in MuJoCo-like envs (in PyBullet).
|
||||
pybullet
|
||||
# For tests on PettingZoo's multi-agent envs.
|
||||
pettingzoo>=1.4.0
|
||||
|
|
20
rllib/BUILD
20
rllib/BUILD
|
@ -339,6 +339,16 @@ py_test(
|
|||
args = ["--yaml-dir=tuned_examples/sac"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "run_regression_tests_cartpole_continuous_pybullet_sac_tf",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_tf", "learning_tests_cartpole"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/sac"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "run_regression_tests_cartpole_sac_torch",
|
||||
main = "tests/run_regression_tests.py",
|
||||
|
@ -349,6 +359,16 @@ py_test(
|
|||
args = ["--yaml-dir=tuned_examples/sac", "--torch"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "run_regression_tests_cartpole_continuous_pybullet_sac_torch",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_torch", "learning_tests_cartpole"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/sac", "--torch"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "run_regression_tests_pendulum_sac_tf",
|
||||
main = "tests/run_regression_tests.py",
|
||||
|
|
|
@ -553,11 +553,22 @@ class Trainer(Trainable):
|
|||
elif "." in env:
|
||||
self.env_creator = \
|
||||
lambda env_context: from_config(env, env_context)
|
||||
# Try gym.
|
||||
# Try gym/PyBullet.
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = \
|
||||
lambda env_context: gym.make(env, **env_context)
|
||||
|
||||
def _creator(env_context):
|
||||
import gym
|
||||
# Allow for PyBullet envs to be used as well (via string).
|
||||
# This allows for doing things like
|
||||
# `env=CartPoleContinuousBulletEnv-v0`.
|
||||
try:
|
||||
import pybullet_envs
|
||||
pybullet_envs.getList()
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
pass
|
||||
return gym.make(env, **env_context)
|
||||
|
||||
self.env_creator = _creator
|
||||
else:
|
||||
self.env_creator = lambda env_config: None
|
||||
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
cartpole-sac:
|
||||
env: CartPoleContinuousBulletEnv-v0
|
||||
run: SAC
|
||||
stop:
|
||||
episode_reward_mean: 100
|
||||
timesteps_total: 100000
|
||||
config:
|
||||
# Works for both torch and tf.
|
||||
framework: tf
|
Loading…
Add table
Reference in a new issue