Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)
[RLlib] Trajectory view API: enable by default for SAC, DDPG, DQN, SimpleQ (#11827)
parent 8609e2dd90
commit b6b54f1c81
9 changed files with 62 additions and 32 deletions
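The flag being flipped here is a plain config key, so the previous behavior stays one override away. A minimal sketch of toggling it from user code (assumes this Ray/RLlib revision and the standard CartPole-v0 Gym environment; the snippet itself is not part of the commit):

import ray
import ray.rllib.agents.dqn as dqn

ray.init()
config = dqn.DEFAULT_CONFIG.copy()
# After this commit the key below defaults to True for DQN, SimpleQ, DDPG
# and SAC; set it to False to fall back to the previous collection path.
config["_use_trajectory_view_api"] = False
trainer = dqn.DQNTrainer(config=config, env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])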
@@ -145,6 +145,10 @@ DEFAULT_CONFIG = with_common_config({
     "worker_side_prioritization": False,
     # Prevent iterations from going lower than this time span
     "min_iter_time_s": 1,
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -132,6 +132,10 @@ DEFAULT_CONFIG = with_common_config({
     "worker_side_prioritization": False,
     # Prevent iterations from going lower than this time span
     "min_iter_time_s": 1,
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -90,6 +90,10 @@ DEFAULT_CONFIG = with_common_config({
     "num_workers": 0,
     # Prevent iterations from going lower than this time span
     "min_iter_time_s": 1,
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -134,6 +134,10 @@ DEFAULT_CONFIG = with_common_config({
     # Use a Beta-distribution instead of a SquashedGaussian for bounded,
     # continuous action spaces (not recommended, for debugging only).
     "_use_beta_distribution": False,
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable

@@ -4,6 +4,7 @@ import time
 import unittest

 import ray
+import ray.rllib.agents.dqn as dqn
 import ray.rllib.agents.ppo as ppo
 from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
 from ray.rllib.evaluation.rollout_worker import RolloutWorker

@@ -26,19 +27,25 @@ class TestTrajectoryViewAPI(unittest.TestCase):
     def test_traj_view_normal_case(self):
         """Tests, whether Model and Policy return the correct ViewRequirements.
         """
-        config = ppo.DEFAULT_CONFIG.copy()
-        for _ in framework_iterator(config, frameworks="torch"):
-            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
+        config = dqn.DEFAULT_CONFIG.copy()
+        for _ in framework_iterator(config):
+            trainer = dqn.DQNTrainer(
+                config,
+                env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
             policy = trainer.get_policy()
             view_req_model = policy.model.inference_view_requirements
             view_req_policy = policy.view_requirements
             assert len(view_req_model) == 1, view_req_model
-            assert len(view_req_policy) == 12, view_req_policy
+            assert len(view_req_policy) == 8, view_req_policy
             for key in [
-                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
-                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
-                    SampleBatch.VF_PREDS, "advantages", "value_targets",
-                    SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
+                    SampleBatch.OBS,
+                    SampleBatch.ACTIONS,
+                    SampleBatch.REWARDS,
+                    SampleBatch.DONES,
+                    SampleBatch.NEXT_OBS,
+                    SampleBatch.EPS_ID,
+                    SampleBatch.AGENT_INDEX,
+                    "weights",
             ]:
                 assert key in view_req_policy
             # None of the view cols has a special underlying data_col,

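The assertions above can also be reproduced interactively. A hedged sketch of that inspection (the attribute names are taken from the test; it assumes the framework configured in the DQN defaults is installed):

import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.policy.sample_batch import SampleBatch

ray.init()
trainer = dqn.DQNTrainer(config=dqn.DEFAULT_CONFIG.copy(), env="CartPole-v0")
policy = trainer.get_policy()
# The model itself only needs the current observation for inference.
print(list(policy.model.inference_view_requirements.keys()))
# The policy additionally tracks actions, rewards, dones, next obs,
# episode ids, agent indices, and replay "weights".
assert SampleBatch.OBS in policy.view_requirements
for key, view_req in policy.view_requirements.items():
    print(key, view_req.data_col)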
@@ -457,7 +457,7 @@ class DynamicTFPolicy(TFPolicy):
         logger.info("Testing `compute_actions` w/ dummy batch.")
         actions, state_outs, extra_fetches = \
             self.compute_actions_from_input_dict(
-                self._dummy_batch, explore=True, timestep=0)
+                self._dummy_batch, explore=False, timestep=0)
         for key, value in extra_fetches.items():
             self._dummy_batch[key] = np.zeros_like(value)
             self._input_dict[key] = get_placeholder(value=value, name=key)

@@ -97,7 +97,7 @@ if __name__ == "__main__":
             print("Regression test PASSED")
             break
         else:
-            print("Regression test FAILED on attempt {}", i + 1)
+            print("Regression test FAILED on attempt {}".format(i + 1))

     if not passed:
         print("Overall regression FAILED: Exiting with Error.")

@@ -21,10 +21,12 @@ def one_hot(i, n):


 class TestMultiAgentEnv(unittest.TestCase):
-    def setUp(self) -> None:
+    @classmethod
+    def setUpClass(cls) -> None:
         ray.init(num_cpus=4)

-    def tearDown(self) -> None:
+    @classmethod
+    def tearDownClass(cls) -> None:
         ray.shutdown()

     def test_basic_mock(self):

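The same class-level lifecycle can be sketched as a standalone module; the class and test names below are hypothetical, only the setUpClass/tearDownClass pattern mirrors the change above:

import unittest

import ray


class TestWithSharedRayCluster(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # Start Ray once for the whole test class instead of once per test.
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_ray_is_up(self):
        self.assertTrue(ray.is_initialized())


if __name__ == "__main__":
    unittest.main()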
@@ -327,24 +329,27 @@ class TestMultiAgentEnv(unittest.TestCase):
                                 prev_reward_batch=None,
                                 episodes=None,
                                 **kwargs):
-                # Pretend we did a model-based rollout and want to return
-                # the extra trajectory.
-                builder = episodes[0].new_batch_builder()
-                rollout_id = random.randint(0, 10000)
-                for t in range(5):
-                    builder.add_values(
-                        agent_id="extra_0",
-                        policy_id="p1", # use p1 so we can easily check it
-                        t=t,
-                        eps_id=rollout_id, # new id for each rollout
-                        obs=obs_batch[0],
-                        actions=0,
-                        rewards=0,
-                        dones=t == 4,
-                        infos={},
-                        new_obs=obs_batch[0])
-                batch = builder.build_and_reset(episode=None)
-                episodes[0].add_extra_batch(batch)
+                # In policy loss initialization phase, no episodes are passed
+                # in.
+                if episodes is not None:
+                    # Pretend we did a model-based rollout and want to return
+                    # the extra trajectory.
+                    builder = episodes[0].new_batch_builder()
+                    rollout_id = random.randint(0, 10000)
+                    for t in range(5):
+                        builder.add_values(
+                            agent_id="extra_0",
+                            policy_id="p1", # use p1 so we can easily check it
+                            t=t,
+                            eps_id=rollout_id, # new id for each rollout
+                            obs=obs_batch[0],
+                            actions=0,
+                            rewards=0,
+                            dones=t == 4,
+                            infos={},
+                            new_obs=obs_batch[0])
+                    batch = builder.build_and_reset(episode=None)
+                    episodes[0].add_extra_batch(batch)

                 # Just return zeros for actions
                 return [0] * len(obs_batch), [], {}

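The guard is needed because the policy loss initialization phase calls compute_actions with a dummy batch and no Episode objects. A stripped-down, hypothetical sketch of the same pattern (only the zero-action return and the None check mirror the hunk above):

def compute_actions(obs_batch, episodes=None, **kwargs):
    # Episode-dependent work (e.g. attaching an extra model-based rollout
    # batch) must be skipped when no episodes are passed in.
    if episodes is not None:
        pass  # safe to touch episodes[0] here, as in the test above
    # Just return zeros for actions.
    return [0] * len(obs_batch), [], {}


# Loss-initialization-style call: no episodes.
print(compute_actions([[0.0, 0.0, 0.0, 0.0]]))
# Rollout-style call: episode objects are provided.
print(compute_actions([[0.0, 0.0, 0.0, 0.0]], episodes=[object()]))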
@@ -3,9 +3,11 @@ cartpole-dqn:
     run: DQN
     stop:
         episode_reward_mean: 150
-        timesteps_total: 50000
+        timesteps_total: 100000
     config:
         # Works for both torch and tf.
         framework: tf
         model:
             fcnet_hiddens: [64]
             fcnet_activation: linear
+        n_step: 3
+        gamma: 0.95
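For reference, a rough Python equivalent of launching this tuned example through Tune (assuming the example's env is CartPole-v0, which lies outside the hunk, and that ray[rllib] from the same release is installed):

import ray
from ray import tune

ray.init()
tune.run(
    "DQN",
    stop={"episode_reward_mean": 150, "timesteps_total": 100000},
    config={
        "env": "CartPole-v0",  # assumed; the env line is not shown in the hunk
        "framework": "tf",     # works for both torch and tf per the YAML comment
        "model": {"fcnet_hiddens": [64], "fcnet_activation": "linear"},
        "n_step": 3,
        "gamma": 0.95,
    },
)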