[RLlib] PyBullet Env native support via env str-specifier (if installed). (#12209)

This commit is contained in:
Sven Mika 2020-11-30 12:41:24 +01:00 committed by GitHub
parent b85c6abc3e
commit bb03e2499b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 5 deletions

View file

@ -283,7 +283,7 @@ install_dependencies() {
fi
fi
# Additional RLlib dependencies.
# Additional RLlib test dependencies.
if [ "${RLLIB_TESTING-}" = 1 ]; then
pip install -r "${WORKSPACE_DIR}"/python/requirements_rllib.txt
# install the following packages for testing on travis only

View file

@ -5,5 +5,7 @@ torch>=1.6.0
# Version requirement to match Tune
torchvision>=0.6.0
smart_open
# For testing in MuJoCo-like envs (in PyBullet).
pybullet
# For tests on PettingZoo's multi-agent envs.
pettingzoo>=1.4.0

View file

@ -339,6 +339,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/sac"]
)
py_test(
name = "run_regression_tests_cartpole_continuous_pybullet_sac_tf",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf", "learning_tests_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml"],
args = ["--yaml-dir=tuned_examples/sac"]
)
py_test(
name = "run_regression_tests_cartpole_sac_torch",
main = "tests/run_regression_tests.py",
@ -349,6 +359,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/sac", "--torch"]
)
py_test(
name = "run_regression_tests_cartpole_continuous_pybullet_sac_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml"],
args = ["--yaml-dir=tuned_examples/sac", "--torch"]
)
py_test(
name = "run_regression_tests_pendulum_sac_tf",
main = "tests/run_regression_tests.py",

View file

@ -553,11 +553,22 @@ class Trainer(Trainable):
elif "." in env:
self.env_creator = \
lambda env_context: from_config(env, env_context)
# Try gym.
# Try gym/PyBullet.
else:
import gym # soft dependency
self.env_creator = \
lambda env_context: gym.make(env, **env_context)
def _creator(env_context):
import gym
# Allow for PyBullet envs to be used as well (via string).
# This allows for doing things like
# `env=CartPoleContinuousBulletEnv-v0`.
try:
import pybullet_envs
pybullet_envs.getList()
except (ModuleNotFoundError, ImportError):
pass
return gym.make(env, **env_context)
self.env_creator = _creator
else:
self.env_creator = lambda env_config: None

View file

@ -0,0 +1,9 @@
cartpole-sac:
env: CartPoleContinuousBulletEnv-v0
run: SAC
stop:
episode_reward_mean: 100
timesteps_total: 100000
config:
# Works for both torch and tf.
framework: tf