[RLlib] DatasetReader action normalization. (#27356)

Charles Sun 2022-08-09 07:54:03 -07:00 committed by GitHub
parent 537f7c65c1
commit c358305ca6
4 changed files with 154 additions and 56 deletions
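
For context (not part of the diff): action normalization maps actions stored in the environment's native bounds into the zero-centered [-1.0, 1.0] range that RLlib policies train on. A minimal sketch of that mapping, assuming the normalize_action / clip_action helpers from ray.rllib.utils.spaces.space_utils (the same helpers the offline readers call below):

import numpy as np
from gym.spaces import Box
from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action

# Pendulum-v1's action space: torque in [-2.0, 2.0].
action_space = Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32)

# An action recorded at the env's bound ...
raw_action = np.array([2.0], dtype=np.float32)

# ... gets squashed to the unit range when normalize_actions=True and
# actions_in_input_normalized=False (expected result: [1.0]).
print(normalize_action(raw_action, action_space))

# clip_actions=True instead clips out-of-range values into the env's bounds
# (expected result: [2.0]).
print(clip_action(np.array([3.0], dtype=np.float32), action_space))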


@@ -1,9 +1,11 @@
from collections import Counter
import gym
from gym.spaces import Box, Discrete
import json
import numpy as np
import os
import random
import tempfile
import time
import unittest
@@ -22,6 +24,8 @@ from ray.rllib.examples.env.mock_env import (
)
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
@@ -358,6 +362,80 @@ class TestRolloutWorker(unittest.TestCase):
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

    def test_action_normalization_offline_dataset(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # create environment
            env = gym.make("Pendulum-v1")

            # create temp data with actions at min and max
            data = {
                "type": "SampleBatch",
                "actions": [[2.0], [-2.0]],
                "dones": [0.0, 0.0],
                "rewards": [0.0, 0.0],
                "obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                "new_obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            }

            data_file = os.path.join(tmp_dir, "data.json")
            with open(data_file, "w") as f:
                json.dump(data, f)

            # create input reader functions
            def dataset_reader_creator(ioctx):
                config = {
                    "input": "dataset",
                    "input_config": {"format": "json", "paths": data_file},
                }
                _, shards = get_dataset_and_shards(config, num_workers=0)
                return DatasetReader(shards[0], ioctx)

            def json_reader_creator(ioctx):
                return JsonReader(data_file, ioctx)

            input_creators = [dataset_reader_creator, json_reader_creator]

            # actions_in_input_normalized, normalize_actions
            parameters = [
                (True, True),
                (True, False),
                (False, True),
                (False, False),
            ]

            # check that samples from dataset will be normalized if and only if
            # actions_in_input_normalized == False and
            # normalize_actions == True
            for input_creator in input_creators:
                for actions_in_input_normalized, normalize_actions in parameters:
                    ev = RolloutWorker(
                        env_creator=lambda _: env,
                        policy_spec=MockPolicy,
                        policy_config=dict(
                            actions_in_input_normalized=actions_in_input_normalized,
                            normalize_actions=normalize_actions,
                            clip_actions=False,
                            offline_sampling=True,
                            train_batch_size=1,
                        ),
                        rollout_fragment_length=1,
                        input_creator=input_creator,
                    )

                    sample = ev.sample()

                    if normalize_actions and not actions_in_input_normalized:
                        # check if the samples from dataset are normalized properly
                        self.assertLessEqual(np.max(sample["actions"]), 1.0)
                        self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
                    else:
                        # check if the samples from dataset are not normalized
                        self.assertGreater(np.max(sample["actions"]), 1.5)
                        self.assertLess(np.min(sample["actions"]), -1.5)

                    ev.stop()

    def test_action_immutability(self):
        from ray.rllib.examples.env.random_env import RandomEnv


@@ -8,7 +8,7 @@ import zipfile
import ray.data
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import from_json_data
from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
@@ -251,6 +251,7 @@ class DatasetReader(InputReader):
            d = next(self._iter).as_pydict()
            # Columns like obs are compressed when written by DatasetWriter.
            d = from_json_data(d, self._ioctx.worker)
            d = postprocess_actions(d, self._ioctx)
            count += d.count
            ret.append(self._postprocess_if_needed(d))
        ret = concat_samples(ret)
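
For reference, a minimal sketch of the offline input config under which DatasetReader now re-normalizes recorded actions (the "data.json" path is a placeholder; the flags mirror the new test above):

config = {
    "input": "dataset",
    "input_config": {"format": "json", "paths": "data.json"},
    # Actions were recorded in the env's native bounds (e.g. [-2, 2] for
    # Pendulum-v1), i.e. they are not yet normalized ...
    "actions_in_input_normalized": False,
    # ... so squash them into [-1, 1] when reading them back.
    "normalize_actions": True,
    "clip_actions": False,
    "offline_sampling": True,
}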


@@ -83,6 +83,78 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
    return json_data


@DeveloperAPI
def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
    # Clip actions (from any values into env's bounds), if necessary.
    cfg = ioctx.config
    # TODO(jungong) : we should not clip_action in input reader.
    # Use connector to handle this.
    if cfg.get("clip_actions"):
        if ioctx.worker is None:
            raise ValueError(
                "clip_actions is True but cannot clip actions since no workers exist"
            )

        if isinstance(batch, SampleBatch):
            default_policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
            batch[SampleBatch.ACTIONS] = clip_action(
                batch[SampleBatch.ACTIONS], default_policy.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                b[SampleBatch.ACTIONS] = clip_action(
                    b[SampleBatch.ACTIONS],
                    ioctx.worker.policy_map[pid].action_space_struct,
                )

    # Re-normalize actions (from env's bounds to zero-centered), if
    # necessary.
    if (
        cfg.get("actions_in_input_normalized") is False
        and cfg.get("normalize_actions") is True
    ):
        if ioctx.worker is None:
            raise ValueError(
                "actions_in_input_normalized is False but "
                "cannot normalize actions since no workers exist"
            )
        # If we have a complex action space and actions were flattened
        # and we have to normalize -> Error.
        error_msg = (
            "Normalization of offline actions that are flattened is not "
            "supported! Make sure that you record actions into offline "
            "file with the `_disable_action_flattening=True` flag OR "
            "as already normalized (between -1.0 and 1.0) values. "
            "Also, when reading already normalized action values from "
            "offline files, make sure to set "
            "`actions_in_input_normalized=True` so that RLlib will not "
            "perform normalization on top."
        )

        if isinstance(batch, SampleBatch):
            pol = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
            if isinstance(
                pol.action_space_struct, (tuple, dict)
            ) and not pol.config.get("_disable_action_flattening"):
                raise ValueError(error_msg)

            batch[SampleBatch.ACTIONS] = normalize_action(
                batch[SampleBatch.ACTIONS], pol.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                pol = ioctx.worker.policy_map[pid]
                if isinstance(
                    pol.action_space_struct, (tuple, dict)
                ) and not pol.config.get("_disable_action_flattening"):
                    raise ValueError(error_msg)

                b[SampleBatch.ACTIONS] = normalize_action(
                    b[SampleBatch.ACTIONS],
                    ioctx.worker.policy_map[pid].action_space_struct,
                )

    return batch


@DeveloperAPI
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
@@ -290,61 +362,8 @@ class JsonReader(InputReader):
            )
            return None

        # Clip actions (from any values into env's bounds), if necessary.
        cfg = self.ioctx.config
        # TODO(jungong) : we should not clip_action in input reader.
        # Use connector to handle this.
        if cfg.get("clip_actions") and self.ioctx.worker is not None:
            if isinstance(batch, SampleBatch):
                batch[SampleBatch.ACTIONS] = clip_action(
                    batch[SampleBatch.ACTIONS], self.default_policy.action_space_struct
                )
            else:
                for pid, b in batch.policy_batches.items():
                    b[SampleBatch.ACTIONS] = clip_action(
                        b[SampleBatch.ACTIONS],
                        self.ioctx.worker.policy_map[pid].action_space_struct,
                    )

        # Re-normalize actions (from env's bounds to zero-centered), if
        # necessary.
        if (
            cfg.get("actions_in_input_normalized") is False
            and self.ioctx.worker is not None
        ):
        batch = postprocess_actions(batch, self.ioctx)
            # If we have a complex action space and actions were flattened
            # and we have to normalize -> Error.
            error_msg = (
                "Normalization of offline actions that are flattened is not "
                "supported! Make sure that you record actions into offline "
                "file with the `_disable_action_flattening=True` flag OR "
                "as already normalized (between -1.0 and 1.0) values. "
                "Also, when reading already normalized action values from "
                "offline files, make sure to set "
                "`actions_in_input_normalized=True` so that RLlib will not "
                "perform normalization on top."
            )
            if isinstance(batch, SampleBatch):
                pol = self.default_policy
                if isinstance(
                    pol.action_space_struct, (tuple, dict)
                ) and not pol.config.get("_disable_action_flattening"):
                    raise ValueError(error_msg)
                batch[SampleBatch.ACTIONS] = normalize_action(
                    batch[SampleBatch.ACTIONS], pol.action_space_struct
                )
            else:
                for pid, b in batch.policy_batches.items():
                    pol = self.policy_map[pid]
                    if isinstance(
                        pol.action_space_struct, (tuple, dict)
                    ) and not pol.config.get("_disable_action_flattening"):
                        raise ValueError(error_msg)
                    b[SampleBatch.ACTIONS] = normalize_action(
                        b[SampleBatch.ACTIONS],
                        self.ioctx.worker.policy_map[pid].action_space_struct,
                    )

        return batch

    def _next_line(self) -> str:
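
To illustrate the refactor (the helper name _decode_record below is hypothetical, not part of RLlib): any input reader holding an IOContext can now reuse the shared post-processing instead of duplicating the clip/normalize logic that used to live only in JsonReader:

from ray.rllib.offline.json_reader import from_json_data, postprocess_actions


def _decode_record(record: dict, ioctx):
    # Hypothetical helper: rebuild a SampleBatch from a raw JSON dict
    # (decompressing columns such as obs) ...
    batch = from_json_data(record, ioctx.worker)
    # ... then clip and/or re-normalize its actions according to ioctx.config
    # ("clip_actions", "actions_in_input_normalized", "normalize_actions").
    return postprocess_actions(batch, ioctx)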


@@ -20,7 +20,7 @@ pendulum_crr:
        actor_hidden_activation: 'relu'
        actor_hiddens: [256, 256]
        actor_lr: 0.0003
        actions_in_input_normalized: True
        actions_in_input_normalized: False
        clip_actions: True
        # Q function update setting
        twin_q: True