[RLlib] DatasetReader action normalization. (#27356)
Commit c358305ca6 (parent 537f7c65c1)
4 changed files with 154 additions and 56 deletions
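Summary: both offline input paths now share one action post-processing step. JsonReader's inline clip/normalize logic moves into a new postprocess_actions() helper in json_reader.py, and DatasetReader calls the same helper, so batches read through `input: dataset` get clipped and re-normalized exactly like JSON input. A minimal, illustrative sketch of the config keys the helper inspects (values below are examples, not RLlib defaults):

config = {
    "input": "dataset",
    "input_config": {"format": "json", "paths": "/tmp/offline/data.json"},  # illustrative path
    "clip_actions": False,                 # clip offline actions into the env's bounds
    "normalize_actions": True,             # policy expects actions squashed into [-1.0, 1.0]
    "actions_in_input_normalized": False,  # offline file stores raw (env-range) actions
}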
@@ -1,9 +1,11 @@
 from collections import Counter
 import gym
 from gym.spaces import Box, Discrete
+import json
 import numpy as np
 import os
 import random
+import tempfile
 import time
 import unittest
 
@@ -22,6 +24,8 @@ from ray.rllib.examples.env.mock_env import (
 )
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
 from ray.rllib.examples.policy.random_policy import RandomPolicy
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
+from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
@@ -358,6 +362,80 @@ class TestRolloutWorker(unittest.TestCase):
         self.assertLess(np.min(sample["actions"]), action_space.low[0])
         ev.stop()
 
+    def test_action_normalization_offline_dataset(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # create environment
+            env = gym.make("Pendulum-v1")
+
+            # create temp data with actions at min and max
+            data = {
+                "type": "SampleBatch",
+                "actions": [[2.0], [-2.0]],
+                "dones": [0.0, 0.0],
+                "rewards": [0.0, 0.0],
+                "obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+                "new_obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+            }
+
+            data_file = os.path.join(tmp_dir, "data.json")
+
+            with open(data_file, "w") as f:
+                json.dump(data, f)
+
+            # create input reader functions
+            def dataset_reader_creator(ioctx):
+                config = {
+                    "input": "dataset",
+                    "input_config": {"format": "json", "paths": data_file},
+                }
+                _, shards = get_dataset_and_shards(config, num_workers=0)
+                return DatasetReader(shards[0], ioctx)
+
+            def json_reader_creator(ioctx):
+                return JsonReader(data_file, ioctx)
+
+            input_creators = [dataset_reader_creator, json_reader_creator]
+
+            # actions_in_input_normalized, normalize_actions
+            parameters = [
+                (True, True),
+                (True, False),
+                (False, True),
+                (False, False),
+            ]
+
+            # check that samples from dataset will be normalized if and only if
+            # actions_in_input_normalized == False and
+            # normalize_actions == True
+            for input_creator in input_creators:
+                for actions_in_input_normalized, normalize_actions in parameters:
+                    ev = RolloutWorker(
+                        env_creator=lambda _: env,
+                        policy_spec=MockPolicy,
+                        policy_config=dict(
+                            actions_in_input_normalized=actions_in_input_normalized,
+                            normalize_actions=normalize_actions,
+                            clip_actions=False,
+                            offline_sampling=True,
+                            train_batch_size=1,
+                        ),
+                        rollout_fragment_length=1,
+                        input_creator=input_creator,
+                    )
+
+                    sample = ev.sample()
+
+                    if normalize_actions and not actions_in_input_normalized:
+                        # check if the samples from dataset are normalized properly
+                        self.assertLessEqual(np.max(sample["actions"]), 1.0)
+                        self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
+                    else:
+                        # check if the samples from dataset are not normalized
+                        self.assertGreater(np.max(sample["actions"]), 1.5)
+                        self.assertLess(np.min(sample["actions"]), -1.5)
+
+                    ev.stop()
+
     def test_action_immutability(self):
         from ray.rllib.examples.env.random_env import RandomEnv
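The Pendulum-v1 action space is Box(-2.0, 2.0, (1,)), so the test expects the raw ±2.0 actions to come back as exactly ±1.0 after re-normalization. A minimal sketch of that linear rescaling for Box spaces (a hypothetical stand-in that mirrors what RLlib's normalize_action utility does for a plain Box space, shown only to make the expected numbers concrete):

import numpy as np
from gym.spaces import Box

def normalize_box_action(action: np.ndarray, space: Box) -> np.ndarray:
    # Linearly map [low, high] into the zero-centered range [-1.0, 1.0].
    return (action - space.low) / (space.high - space.low) * 2.0 - 1.0

pendulum_act_space = Box(-2.0, 2.0, shape=(1,), dtype=np.float32)
print(normalize_box_action(np.array([2.0]), pendulum_act_space))   # -> [1.]
print(normalize_box_action(np.array([-2.0]), pendulum_act_space))  # -> [-1.]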
@@ -8,7 +8,7 @@ import zipfile
 import ray.data
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
-from ray.rllib.offline.json_reader import from_json_data
+from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
 from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
@@ -251,6 +251,7 @@ class DatasetReader(InputReader):
             d = next(self._iter).as_pydict()
             # Columns like obs are compressed when written by DatasetWriter.
             d = from_json_data(d, self._ioctx.worker)
+            d = postprocess_actions(d, self._ioctx)
             count += d.count
             ret.append(self._postprocess_if_needed(d))
         ret = concat_samples(ret)
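With the change above, DatasetReader applies postprocess_actions() to every decoded batch, so clipping and normalization follow whatever settings arrive through ioctx.config. A sketch of wiring the Dataset-based reader up as an input creator, mirroring the pattern used in the test (the file path is illustrative):

from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards

def dataset_input_creator(ioctx):
    # One JSON file of SampleBatch rows; "paths" may also be a list of files.
    config = {
        "input": "dataset",
        "input_config": {"format": "json", "paths": "/tmp/offline/data.json"},
    }
    _, shards = get_dataset_and_shards(config, num_workers=0)
    # Clipping/normalization now happens inside DatasetReader.next() via
    # postprocess_actions(), driven by ioctx.config.
    return DatasetReader(shards[0], ioctx)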
@@ -83,6 +83,78 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
     return json_data
 
 
+@DeveloperAPI
+def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
+    # Clip actions (from any values into env's bounds), if necessary.
+    cfg = ioctx.config
+    # TODO(jungong) : we should not clip_action in input reader.
+    # Use connector to handle this.
+    if cfg.get("clip_actions"):
+        if ioctx.worker is None:
+            raise ValueError(
+                "clip_actions is True but cannot clip actions since no workers exist"
+            )
+
+        if isinstance(batch, SampleBatch):
+            default_policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            batch[SampleBatch.ACTIONS] = clip_action(
+                batch[SampleBatch.ACTIONS], default_policy.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                b[SampleBatch.ACTIONS] = clip_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+    # Re-normalize actions (from env's bounds to zero-centered), if
+    # necessary.
+    if (
+        cfg.get("actions_in_input_normalized") is False
+        and cfg.get("normalize_actions") is True
+    ):
+        if ioctx.worker is None:
+            raise ValueError(
+                "actions_in_input_normalized is False but "
+                "cannot normalize actions since no workers exist"
+            )
+
+        # If we have a complex action space and actions were flattened
+        # and we have to normalize -> Error.
+        error_msg = (
+            "Normalization of offline actions that are flattened is not "
+            "supported! Make sure that you record actions into offline "
+            "file with the `_disable_action_flattening=True` flag OR "
+            "as already normalized (between -1.0 and 1.0) values. "
+            "Also, when reading already normalized action values from "
+            "offline files, make sure to set "
+            "`actions_in_input_normalized=True` so that RLlib will not "
+            "perform normalization on top."
+        )
+
+        if isinstance(batch, SampleBatch):
+            pol = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            if isinstance(
+                pol.action_space_struct, (tuple, dict)
+            ) and not pol.config.get("_disable_action_flattening"):
+                raise ValueError(error_msg)
+            batch[SampleBatch.ACTIONS] = normalize_action(
+                batch[SampleBatch.ACTIONS], pol.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                pol = ioctx.worker.policy_map[pid]
+                if isinstance(
+                    pol.action_space_struct, (tuple, dict)
+                ) and not pol.config.get("_disable_action_flattening"):
+                    raise ValueError(error_msg)
+                b[SampleBatch.ACTIONS] = normalize_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+
+    return batch
+
+
 @DeveloperAPI
 def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
     # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
@@ -290,61 +362,8 @@ class JsonReader(InputReader):
             )
             return None
 
-        # Clip actions (from any values into env's bounds), if necessary.
-        cfg = self.ioctx.config
-        # TODO(jungong) : we should not clip_action in input reader.
-        # Use connector to handle this.
-        if cfg.get("clip_actions") and self.ioctx.worker is not None:
-            if isinstance(batch, SampleBatch):
-                batch[SampleBatch.ACTIONS] = clip_action(
-                    batch[SampleBatch.ACTIONS], self.default_policy.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    b[SampleBatch.ACTIONS] = clip_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
-        # Re-normalize actions (from env's bounds to zero-centered), if
-        # necessary.
-        if (
-            cfg.get("actions_in_input_normalized") is False
-            and self.ioctx.worker is not None
-        ):
+        batch = postprocess_actions(batch, self.ioctx)
 
-            # If we have a complex action space and actions were flattened
-            # and we have to normalize -> Error.
-            error_msg = (
-                "Normalization of offline actions that are flattened is not "
-                "supported! Make sure that you record actions into offline "
-                "file with the `_disable_action_flattening=True` flag OR "
-                "as already normalized (between -1.0 and 1.0) values. "
-                "Also, when reading already normalized action values from "
-                "offline files, make sure to set "
-                "`actions_in_input_normalized=True` so that RLlib will not "
-                "perform normalization on top."
-            )
-
-            if isinstance(batch, SampleBatch):
-                pol = self.default_policy
-                if isinstance(
-                    pol.action_space_struct, (tuple, dict)
-                ) and not pol.config.get("_disable_action_flattening"):
-                    raise ValueError(error_msg)
-                batch[SampleBatch.ACTIONS] = normalize_action(
-                    batch[SampleBatch.ACTIONS], pol.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    pol = self.policy_map[pid]
-                    if isinstance(
-                        pol.action_space_struct, (tuple, dict)
-                    ) and not pol.config.get("_disable_action_flattening"):
-                        raise ValueError(error_msg)
-                    b[SampleBatch.ACTIONS] = normalize_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
         return batch
 
     def _next_line(self) -> str:
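The conditions that used to live inline in JsonReader now apply uniformly through the helper. A hypothetical distillation of the two independent decisions postprocess_actions() makes (not library code, just the branch logic restated):

def should_clip(cfg: dict) -> bool:
    # Mirrors `if cfg.get("clip_actions"):` above.
    return bool(cfg.get("clip_actions"))

def should_normalize(cfg: dict) -> bool:
    # Offline actions are re-normalized only when the file holds raw env-range
    # actions (actions_in_input_normalized is False) and the policy expects
    # zero-centered actions (normalize_actions is True).
    return (
        cfg.get("actions_in_input_normalized") is False
        and cfg.get("normalize_actions") is True
    )

# Matches the test matrix: only (False, True) triggers re-normalization.
print(should_normalize({"actions_in_input_normalized": False, "normalize_actions": True}))  # True
print(should_normalize({"actions_in_input_normalized": True, "normalize_actions": True}))   # False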
@@ -20,7 +20,7 @@ pendulum_crr:
     actor_hidden_activation: 'relu'
     actor_hiddens: [256, 256]
     actor_lr: 0.0003
-    actions_in_input_normalized: True
+    actions_in_input_normalized: False
     clip_actions: True
     # Q function update setting
     twin_q: True