Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)
[RLlib] DatasetReader action normalization. (#27356)
parent 537f7c65c1
commit c358305ca6
4 changed files with 154 additions and 56 deletions
@@ -1,9 +1,11 @@
 from collections import Counter
 import gym
 from gym.spaces import Box, Discrete
+import json
 import numpy as np
 import os
 import random
+import tempfile
 import time
 import unittest

@@ -22,6 +24,8 @@ from ray.rllib.examples.env.mock_env import (
 )
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
 from ray.rllib.examples.policy.random_policy import RandomPolicy
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
+from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
@@ -358,6 +362,80 @@ class TestRolloutWorker(unittest.TestCase):
         self.assertLess(np.min(sample["actions"]), action_space.low[0])
         ev.stop()

+    def test_action_normalization_offline_dataset(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # create environment
+            env = gym.make("Pendulum-v1")
+
+            # create temp data with actions at min and max
+            data = {
+                "type": "SampleBatch",
+                "actions": [[2.0], [-2.0]],
+                "dones": [0.0, 0.0],
+                "rewards": [0.0, 0.0],
+                "obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+                "new_obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+            }
+
+            data_file = os.path.join(tmp_dir, "data.json")
+
+            with open(data_file, "w") as f:
+                json.dump(data, f)
+
+            # create input reader functions
+            def dataset_reader_creator(ioctx):
+                config = {
+                    "input": "dataset",
+                    "input_config": {"format": "json", "paths": data_file},
+                }
+                _, shards = get_dataset_and_shards(config, num_workers=0)
+                return DatasetReader(shards[0], ioctx)
+
+            def json_reader_creator(ioctx):
+                return JsonReader(data_file, ioctx)
+
+            input_creators = [dataset_reader_creator, json_reader_creator]
+
+            # (actions_in_input_normalized, normalize_actions)
+            parameters = [
+                (True, True),
+                (True, False),
+                (False, True),
+                (False, False),
+            ]
+
+            # check that samples from dataset will be normalized if and only if
+            # actions_in_input_normalized == False and
+            # normalize_actions == True
+            for input_creator in input_creators:
+                for actions_in_input_normalized, normalize_actions in parameters:
+                    ev = RolloutWorker(
+                        env_creator=lambda _: env,
+                        policy_spec=MockPolicy,
+                        policy_config=dict(
+                            actions_in_input_normalized=actions_in_input_normalized,
+                            normalize_actions=normalize_actions,
+                            clip_actions=False,
+                            offline_sampling=True,
+                            train_batch_size=1,
+                        ),
+                        rollout_fragment_length=1,
+                        input_creator=input_creator,
+                    )
+
+                    sample = ev.sample()
+
+                    if normalize_actions and not actions_in_input_normalized:
+                        # check if the samples from dataset are normalized properly
+                        self.assertLessEqual(np.max(sample["actions"]), 1.0)
+                        self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
+                    else:
+                        # check if the samples from dataset are not normalized
+                        self.assertGreater(np.max(sample["actions"]), 1.5)
+                        self.assertLess(np.min(sample["actions"]), -1.5)
+
+                    ev.stop()
+
     def test_action_immutability(self):
         from ray.rllib.examples.env.random_env import RandomEnv
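The bounds the new test asserts follow from the standard linear rescaling of a bounded Box action space into [-1.0, 1.0]. A minimal sketch of that mapping (the helper name is illustrative only, not RLlib's normalize_action API):

import numpy as np

def rescale_to_unit_range(action, low, high):
    # Linearly map an env-scale action from [low, high] into [-1.0, 1.0].
    return 2.0 * (action - low) / (high - low) - 1.0

# Pendulum-v1 actions live in Box(-2.0, 2.0, (1,)), so the recorded extremes
# [2.0] and [-2.0] map to [1.0] and [-1.0], which is what the test expects.
print(rescale_to_unit_range(np.array([2.0]), -2.0, 2.0))   # [1.]
print(rescale_to_unit_range(np.array([-2.0]), -2.0, 2.0))  # [-1.]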
@@ -8,7 +8,7 @@ import zipfile
 import ray.data
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
-from ray.rllib.offline.json_reader import from_json_data
+from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
 from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict

@@ -251,6 +251,7 @@ class DatasetReader(InputReader):
             d = next(self._iter).as_pydict()
             # Columns like obs are compressed when written by DatasetWriter.
             d = from_json_data(d, self._ioctx.worker)
+            d = postprocess_actions(d, self._ioctx)
             count += d.count
             ret.append(self._postprocess_if_needed(d))
         ret = concat_samples(ret)
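With this change, dataset-based offline input goes through the same action post-processing as the JSON path. A sketch of the config fragment that exercises it, mirroring the test above (the paths value is a placeholder):

config = {
    "input": "dataset",
    "input_config": {"format": "json", "paths": "/tmp/data.json"},
    # Recorded actions are in env scale, so ask RLlib to re-normalize them
    # on read; clipping stays off, as in the test.
    "actions_in_input_normalized": False,
    "normalize_actions": True,
    "clip_actions": False,
}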
@@ -83,6 +83,78 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
     return json_data


+@DeveloperAPI
+def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
+    # Clip actions (from any values into env's bounds), if necessary.
+    cfg = ioctx.config
+    # TODO(jungong) : we should not clip_action in input reader.
+    # Use connector to handle this.
+    if cfg.get("clip_actions"):
+        if ioctx.worker is None:
+            raise ValueError(
+                "clip_actions is True but cannot clip actions since no workers exist"
+            )
+
+        if isinstance(batch, SampleBatch):
+            default_policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            batch[SampleBatch.ACTIONS] = clip_action(
+                batch[SampleBatch.ACTIONS], default_policy.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                b[SampleBatch.ACTIONS] = clip_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+    # Re-normalize actions (from env's bounds to zero-centered), if
+    # necessary.
+    if (
+        cfg.get("actions_in_input_normalized") is False
+        and cfg.get("normalize_actions") is True
+    ):
+        if ioctx.worker is None:
+            raise ValueError(
+                "actions_in_input_normalized is False but "
+                "cannot normalize actions since no workers exist"
+            )
+
+        # If we have a complex action space and actions were flattened
+        # and we have to normalize -> Error.
+        error_msg = (
+            "Normalization of offline actions that are flattened is not "
+            "supported! Make sure that you record actions into offline "
+            "file with the `_disable_action_flattening=True` flag OR "
+            "as already normalized (between -1.0 and 1.0) values. "
+            "Also, when reading already normalized action values from "
+            "offline files, make sure to set "
+            "`actions_in_input_normalized=True` so that RLlib will not "
+            "perform normalization on top."
+        )
+
+        if isinstance(batch, SampleBatch):
+            pol = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            if isinstance(
+                pol.action_space_struct, (tuple, dict)
+            ) and not pol.config.get("_disable_action_flattening"):
+                raise ValueError(error_msg)
+            batch[SampleBatch.ACTIONS] = normalize_action(
+                batch[SampleBatch.ACTIONS], pol.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                pol = ioctx.worker.policy_map[pid]
+                if isinstance(
+                    pol.action_space_struct, (tuple, dict)
+                ) and not pol.config.get("_disable_action_flattening"):
+                    raise ValueError(error_msg)
+                b[SampleBatch.ACTIONS] = normalize_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+
+    return batch
+
+
 @DeveloperAPI
 def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
     # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
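The normalization branch fires for only one of the four flag combinations the test sweeps. A sketch of that predicate in isolation (not RLlib API, just the guard from the function above):

def will_renormalize(actions_in_input_normalized: bool, normalize_actions: bool) -> bool:
    # Mirrors the second `if` in postprocess_actions: re-normalize only when the
    # recorded actions are declared un-normalized and normalization is enabled.
    return normalize_actions and not actions_in_input_normalized

for recorded_normalized in (True, False):
    for normalize in (True, False):
        print(recorded_normalized, normalize, "->", will_renormalize(recorded_normalized, normalize))
# Only (False, True) prints True, matching the test's expectations.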
@@ -290,61 +362,8 @@ class JsonReader(InputReader):
             )
             return None

-        # Clip actions (from any values into env's bounds), if necessary.
-        cfg = self.ioctx.config
-        # TODO(jungong) : we should not clip_action in input reader.
-        # Use connector to handle this.
-        if cfg.get("clip_actions") and self.ioctx.worker is not None:
-            if isinstance(batch, SampleBatch):
-                batch[SampleBatch.ACTIONS] = clip_action(
-                    batch[SampleBatch.ACTIONS], self.default_policy.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    b[SampleBatch.ACTIONS] = clip_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
-        # Re-normalize actions (from env's bounds to zero-centered), if
-        # necessary.
-        if (
-            cfg.get("actions_in_input_normalized") is False
-            and self.ioctx.worker is not None
-        ):
-
-            # If we have a complex action space and actions were flattened
-            # and we have to normalize -> Error.
-            error_msg = (
-                "Normalization of offline actions that are flattened is not "
-                "supported! Make sure that you record actions into offline "
-                "file with the `_disable_action_flattening=True` flag OR "
-                "as already normalized (between -1.0 and 1.0) values. "
-                "Also, when reading already normalized action values from "
-                "offline files, make sure to set "
-                "`actions_in_input_normalized=True` so that RLlib will not "
-                "perform normalization on top."
-            )
-
-            if isinstance(batch, SampleBatch):
-                pol = self.default_policy
-                if isinstance(
-                    pol.action_space_struct, (tuple, dict)
-                ) and not pol.config.get("_disable_action_flattening"):
-                    raise ValueError(error_msg)
-                batch[SampleBatch.ACTIONS] = normalize_action(
-                    batch[SampleBatch.ACTIONS], pol.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    pol = self.policy_map[pid]
-                    if isinstance(
-                        pol.action_space_struct, (tuple, dict)
-                    ) and not pol.config.get("_disable_action_flattening"):
-                        raise ValueError(error_msg)
-                    b[SampleBatch.ACTIONS] = normalize_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
+        batch = postprocess_actions(batch, self.ioctx)

         return batch

     def _next_line(self) -> str:
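JsonReader now delegates clipping and normalization to the shared postprocess_actions helper. A minimal sketch of reading a batch directly, assuming an existing offline JSON file (the path is a placeholder) and a bare IOContext, i.e. no worker and an empty config, so neither clipping nor re-normalization is requested:

from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import JsonReader

# With an empty config, postprocess_actions leaves the recorded actions as-is.
reader = JsonReader("/tmp/data.json", IOContext())
batch = reader.next()
print(batch["actions"])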
@@ -20,7 +20,7 @@ pendulum_crr:
        actor_hidden_activation: 'relu'
        actor_hiddens: [256, 256]
        actor_lr: 0.0003
-       actions_in_input_normalized: True
+       actions_in_input_normalized: False
        clip_actions: True
        # Q function update setting
        twin_q: True