import gym
import numpy as np
import unittest

from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.typing import (
    ActionConnectorDataType,
    AgentConnectorDataType,
    AgentConnectorsOutput,
)


class TestAgentConnector(unittest.TestCase):
    def test_connector_pipeline(self):
        ctx = ConnectorContext()
        connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
        pipeline = AgentConnectorPipeline(ctx, connectors)
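        # Round-trip the pipeline through its serialized (name, params) form.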
        name, params = pipeline.to_config()
        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, AgentConnectorPipeline))
        self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))

    def test_obs_preprocessor_connector(self):
        obs_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=0, high=1, shape=(1,)),
                "b": gym.spaces.Tuple(
                    [gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
                ),
            }
        )
        ctx = ConnectorContext(config={}, observation_space=obs_space)

        c = ObsPreprocessorConnector(ctx)
        name, params = c.to_config()

        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, ObsPreprocessorConnector))

        obs = obs_space.sample()
        # Fake deterministic data.
        obs["a"][0] = 0.5
        obs["b"] = (1, np.array([0, 2]))

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.OBS: obs,
            },
        )
        preprocessed = c([d])

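        # One-hot encoding: Discrete(2)=1 -> [0, 1]; MultiDiscrete([2, 3])=[0, 2]
        # -> [1, 0] and [0, 0, 1]; the Box entry 0.5 passes through unchanged.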
        # obs is completely flattened.
        self.assertTrue(
            (preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
        )

    def test_clip_reward_connector(self):
        ctx = ConnectorContext()

        c = ClipRewardAgentConnector(ctx, limit=2.0)
        name, params = c.to_config()

        self.assertEqual(name, "ClipRewardAgentConnector")
        self.assertAlmostEqual(params["limit"], 2.0)

        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, ClipRewardAgentConnector))

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.REWARDS: 5.8,
            },
        )
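        # The restored connector should clip the 5.8 reward down to the 2.0 limit.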
        clipped = restored([d])

        self.assertEqual(len(clipped), 1)
        self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)

    def test_flatten_data_connector(self):
        ctx = ConnectorContext()

        c = FlattenDataAgentConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, FlattenDataAgentConnector))

        sample_batch = {
            SampleBatch.NEXT_OBS: {
                "sensor1": [[1, 1], [2, 2]],
                "sensor2": 8.8,
            },
            SampleBatch.REWARDS: 5.8,
            SampleBatch.ACTIONS: [[1, 1], [2]],
            SampleBatch.INFOS: {"random": "info"},
        }

        d = AgentConnectorDataType(
            0,
            1,
            # FlattenDataAgentConnector does NOT touch for_training dict,
            # so simply pass None here.
            AgentConnectorsOutput(None, sample_batch),
        )

        flattened = c([d])
        self.assertEqual(len(flattened), 1)

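        # The nested NEXT_OBS dict is flattened into a single vector, while
        # ACTIONS and INFOS keep their original structure.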
        batch = flattened[0].data.for_action
        self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
        self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
        # Not flattened.
        self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
        self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")

    def test_state_buffer_connector(self):
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)),
        )
        c = StateBufferConnector(ctx)

        # Reset without any buffered data should do nothing.
        c.reset(env_id=0)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.NEXT_OBS: {
                    "sensor1": [[1, 1], [2, 2]],
                    "sensor2": 8.8,
                },
            },
        )

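        # With no policy output buffered yet, a dummy all-zeros action matching
        # the (3,)-shaped action space is returned.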
        with_buffered = c([d])
        self.assertEqual(len(with_buffered), 1)
        self.assertTrue((with_buffered[0].data[SampleBatch.ACTIONS] == [0, 0, 0]).all())

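        # Buffer an actual policy output (action [1, 2, 3]); the next call
        # should return it instead of the zero dummy.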
        c.on_policy_output(ActionConnectorDataType(0, 1, ([1, 2, 3], [], {})))

        with_buffered = c([d])
        self.assertEqual(len(with_buffered), 1)
        self.assertEqual(with_buffered[0].data[SampleBatch.ACTIONS], [1, 2, 3])


class TestViewRequirementConnector(unittest.TestCase):
    def test_vr_connector_respects_training_or_inference_vr_flags(self):
        """Tests that the connector respects the flags within view_requirements
        (i.e. used_for_training, used_for_compute_actions) under different
        is_training modes.

        is_training = False (inference mode):
            The returned data is a SampleBatch that can be used to run the
            corresponding policy.
        is_training = True (training mode):
            The returned data is the input dict itself, which the policy collector
            in env_runner will use to construct the episode, plus a SampleBatch
            that can be used to run the corresponding policy.
        """
        view_rq_dict = {
            "both": ViewRequirement(
                "obs", used_for_training=True, used_for_compute_actions=True
            ),
            "only_inference": ViewRequirement(
                "obs", used_for_training=False, used_for_compute_actions=True
            ),
            "none": ViewRequirement(
                "obs", used_for_training=False, used_for_compute_actions=False
            ),
            "only_training": ViewRequirement(
                "obs", used_for_training=True, used_for_compute_actions=False
            ),
        }

        obs_arr = np.array([0, 1, 2, 3])
        agent_data = {SampleBatch.NEXT_OBS: obs_arr}
        data = AgentConnectorDataType(0, 1, agent_data)

        config = PPOConfig().to_dict()
        ctx = ConnectorContext(
            view_requirements=view_rq_dict,
            config=config,
            is_policy_recurrent=True,
        )

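        # Only views with used_for_compute_actions=True should make it into
        # for_action; seq_lens shows up presumably because is_policy_recurrent
        # is set to True above.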
        for_action_expected = SampleBatch(
            {
                "both": obs_arr[None],
                "only_inference": obs_arr[None],
                "seq_lens": np.array([1]),
            }
        )

        for_training_expected_list = [
            # is_training = False
            None,
            # is_training = True
            agent_data,
        ]

        for is_training in [True, False]:
            c = ViewRequirementAgentConnector(ctx)
            c.is_training(is_training)
            processed = c([data])

            for_training = processed[0].data.for_training
            for_training_expected = for_training_expected_list[is_training]
            for_action = processed[0].data.for_action

            print("-" * 30)
            print(f"is_training = {is_training}")
            print("for action:")
            print(for_action)
            print("for training:")
            print(for_training)

            check(for_training, for_training_expected)
            check(for_action, for_action_expected)

    def test_vr_connector_shift_by_one(self):
        """Tests that the ViewRequirementConnector handles a shift of one correctly
        and ignores future-referencing view_requirements to respect causality."""
        view_rq_dict = {
            "state": ViewRequirement("obs"),
            "next_state": ViewRequirement(
                "obs", shift=1, used_for_compute_actions=False
            ),
            "prev_state": ViewRequirement("obs", shift=-1),
        }

        obs_arrs = np.arange(10)[:, None] + 1
        config = PPOConfig().to_dict()
        ctx = ConnectorContext(
            view_requirements=view_rq_dict, config=config, is_policy_recurrent=True
        )
        c = ViewRequirementAgentConnector(ctx)

        # Keep a running list of observations.
        obs_list = []
        for t, obs in enumerate(obs_arrs):
            # t=0 is the next state of t=-1.
            data = AgentConnectorDataType(
                0, 1, {SampleBatch.NEXT_OBS: obs, SampleBatch.T: t - 1}
            )
            processed = c([data])  # env.reset() for t == -1 else env.step()
            for_action = processed[0].data.for_action
            # Add the current obs to the list.
            obs_list.append(obs)

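            # At t=0 there is no previous obs yet, so prev_state falls back to
            # the current state.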
            if t == 0:
                check(for_action["prev_state"], for_action["state"])
            else:
                # prev_state should equal the obs from the previous timestep.
                check(for_action["prev_state"], obs_list[-2][None])

    def test_vr_connector_causal_slice(self):
        """Test that the ViewRequirementConnector can handle slice shifts correctly."""
        view_rq_dict = {
            "state": ViewRequirement("obs"),
            # shift array should be [-2, -1, 0]
            "prev_states": ViewRequirement("obs", shift="-2:0"),
            # shift array should be [-4, -2, 0]
            "prev_strided_states_even": ViewRequirement("obs", shift="-4:0:2"),
            # shift array should be [-3, -1]
            "prev_strided_states_odd": ViewRequirement("obs", shift="-3:0:2"),
        }

        obs_arrs = np.arange(10)[:, None] + 1
        config = PPOConfig().to_dict()
        ctx = ConnectorContext(
            view_requirements=view_rq_dict, config=config, is_policy_recurrent=True
        )
        c = ViewRequirementAgentConnector(ctx)

        # Keep a queue of observations.
        obs_list = []
        for t, obs in enumerate(obs_arrs):
            # t=0 is the next state of t=-1.
            data = AgentConnectorDataType(
                0, 1, {SampleBatch.NEXT_OBS: obs, SampleBatch.T: t - 1}
            )
            processed = c([data])
            for_action = processed[0].data.for_action

            if t == 0:
                obs_list.extend([obs for _ in range(5)])
            else:
                # Remove the first obs and add the current obs to the end.
                obs_list.pop(0)
                obs_list.append(obs)

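            # The queue holds the last 5 observations, newest last, so shift
            # sets map onto negative indices (e.g. shifts [-2, -1, 0] are
            # indices [-3, -2, -1]).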
            # Check state.
            check(for_action["state"], obs[None])

            # Check prev_states.
            check(
                for_action["prev_states"],
                np.stack(obs_list)[np.array([-3, -2, -1])][None],
            )

            # Check prev_strided_states_even.
            check(
                for_action["prev_strided_states_even"],
                np.stack(obs_list)[np.array([-5, -3, -1])][None],
            )

            check(
                for_action["prev_strided_states_odd"],
                np.stack(obs_list)[np.array([-4, -2])][None],
            )

    def test_vr_connector_with_multiple_buffers(self):
        """Test that the ViewRequirementConnector can handle slice shifts correctly
        when it has multiple buffers to shift."""
        context_len = 5
        # This view requirement simulates the use-case of a decision transformer
        # without reward-to-go.
        view_rq_dict = {
            # obs[t-context_len+1:t]
            "context_obs": ViewRequirement("obs", shift=f"-{context_len-1}:0"),
            # next_obs[t-context_len+1:t]
            "context_next_obs": ViewRequirement(
                "obs", shift=f"-{context_len}:1", used_for_compute_actions=False
            ),
            # act[t-context_len+1:t]
            "context_act": ViewRequirement(
                SampleBatch.ACTIONS, shift=f"-{context_len-1}:-1"
            ),
        }

        obs_arrs = np.arange(10)[:, None] + 1
        act_arrs = (np.arange(10)[:, None] + 1) * 100
        n_steps = obs_arrs.shape[0]
        config = PPOConfig().to_dict()
        ctx = ConnectorContext(
            view_requirements=view_rq_dict, config=config, is_policy_recurrent=True
        )
        c = ViewRequirementAgentConnector(ctx)

        # Keep queues of length context_len for observations and actions.
        obs_list, act_list = [], []
        for t in range(n_steps):
            # The next obs and the action taken at time t-1 are the following.
            timestep_data = {
                SampleBatch.NEXT_OBS: obs_arrs[t],
                SampleBatch.ACTIONS: (
                    np.zeros_like(act_arrs[0]) if t == 0 else act_arrs[t - 1]
                ),
                SampleBatch.T: t - 1,
            }
            data = AgentConnectorDataType(0, 1, timestep_data)
            processed = c([data])
            for_action = processed[0].data.for_action

            if t == 0:
                obs_list.extend([obs_arrs[0] for _ in range(context_len)])
                act_list.extend(
                    [np.zeros_like(act_arrs[0]) for _ in range(context_len)]
                )
            else:
                obs_list.pop(0)
                act_list.pop(0)
                obs_list.append(obs_arrs[t])
                act_list.append(act_arrs[t - 1])

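            # context_next_obs has used_for_compute_actions=False, so it must
            # not show up in the action inputs.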
            self.assertTrue("context_next_obs" not in for_action)
            check(for_action["context_obs"], np.stack(obs_list)[None])
            check(for_action["context_act"], np.stack(act_list[:-1])[None])

    def test_connector_pipeline_with_view_requirement(self):
        """A very minimal test that checks whether pipeline connectors work in a
        simulated rollout."""
        # TODO: make this test beefier and more comprehensive
        config = (
            PPOConfig()
            .framework("torch")
            .environment(env="CartPole-v0")
            .rollouts(create_env_on_local_worker=True)
        )
        algo = PPO(config)
        rollout_worker = algo.workers.local_worker()
        policy = rollout_worker.get_policy()
        env = rollout_worker.env

        # Create a connector context.
        ctx = ConnectorContext(
            view_requirements=policy.view_requirements,
            config=policy.config,
            initial_states=policy.get_initial_state(),
            is_policy_recurrent=policy.is_recurrent(),
            observation_space=policy.observation_space,
            action_space=policy.action_space,
        )

        # Build a chain of connectors.
        connectors = [
            ObsPreprocessorConnector(ctx),
            StateBufferConnector(ctx),
            ViewRequirementAgentConnector(ctx),
        ]
        agent_connector = AgentConnectorPipeline(ctx, connectors)

        name, params = agent_connector.to_config()
        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, AgentConnectorPipeline))
        for cidx, c in enumerate(connectors):
            check(restored.connectors[cidx].to_config(), c.to_config())

        # Simulate a rollout.
        n_steps = 10
        obs = env.reset()
        env_out = AgentConnectorDataType(
            0, 1, {SampleBatch.NEXT_OBS: obs, SampleBatch.T: -1}
        )
        agent_obs = agent_connector([env_out])[0]
        t = 0
        total_rewards = 0

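        # Each step: compute an action from the connector's for_action batch,
        # report it back via on_policy_output (so StateBufferConnector can
        # buffer it), then feed the env outputs through the pipeline again.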
        while t < n_steps:
            policy_output = policy.compute_actions_from_input_dict(
                agent_obs.data.for_action
            )
            agent_connector.on_policy_output(
                ActionConnectorDataType(0, 1, policy_output)
            )
            action = policy_output[0][0]

            next_obs, rewards, dones, info = env.step(action)
            env_out_dict = {
                SampleBatch.NEXT_OBS: next_obs,
                SampleBatch.REWARDS: rewards,
                SampleBatch.DONES: dones,
                SampleBatch.INFOS: info,
                SampleBatch.ACTIONS: action,
                SampleBatch.T: t,
                # state_out
            }
            env_out = AgentConnectorDataType(0, 1, env_out_dict)
            agent_obs = agent_connector([env_out])[0]
            total_rewards += rewards
            t += 1
        print(total_rewards)


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))