[RLlib] DatasetReader action normalization. (#27356)

Charles Sun 2022-08-09 07:54:03 -07:00 committed by GitHub
parent 537f7c65c1
commit c358305ca6
4 changed files with 154 additions and 56 deletions


@@ -1,9 +1,11 @@
 from collections import Counter
 import gym
 from gym.spaces import Box, Discrete
+import json
 import numpy as np
 import os
 import random
+import tempfile
 import time
 import unittest
@@ -22,6 +24,8 @@ from ray.rllib.examples.env.mock_env import (
 )
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
 from ray.rllib.examples.policy.random_policy import RandomPolicy
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
+from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
@@ -358,6 +362,80 @@ class TestRolloutWorker(unittest.TestCase):
         self.assertLess(np.min(sample["actions"]), action_space.low[0])
         ev.stop()
 
+    def test_action_normalization_offline_dataset(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Create environment.
+            env = gym.make("Pendulum-v1")
+
+            # Create temp data with actions at min and max.
+            data = {
+                "type": "SampleBatch",
+                "actions": [[2.0], [-2.0]],
+                "dones": [0.0, 0.0],
+                "rewards": [0.0, 0.0],
+                "obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+                "new_obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+            }
+
+            data_file = os.path.join(tmp_dir, "data.json")
+
+            with open(data_file, "w") as f:
+                json.dump(data, f)
+
+            # Create input reader functions.
+            def dataset_reader_creator(ioctx):
+                config = {
+                    "input": "dataset",
+                    "input_config": {"format": "json", "paths": data_file},
+                }
+                _, shards = get_dataset_and_shards(config, num_workers=0)
+                return DatasetReader(shards[0], ioctx)
+
+            def json_reader_creator(ioctx):
+                return JsonReader(data_file, ioctx)
+
+            input_creators = [dataset_reader_creator, json_reader_creator]
+
+            # (actions_in_input_normalized, normalize_actions)
+            parameters = [
+                (True, True),
+                (True, False),
+                (False, True),
+                (False, False),
+            ]
+
+            # Check that samples from the dataset are normalized if and only if
+            # actions_in_input_normalized == False and
+            # normalize_actions == True.
+            for input_creator in input_creators:
+                for actions_in_input_normalized, normalize_actions in parameters:
+                    ev = RolloutWorker(
+                        env_creator=lambda _: env,
+                        policy_spec=MockPolicy,
+                        policy_config=dict(
+                            actions_in_input_normalized=actions_in_input_normalized,
+                            normalize_actions=normalize_actions,
+                            clip_actions=False,
+                            offline_sampling=True,
+                            train_batch_size=1,
+                        ),
+                        rollout_fragment_length=1,
+                        input_creator=input_creator,
+                    )
+
+                    sample = ev.sample()
+
+                    if normalize_actions and not actions_in_input_normalized:
+                        # Check that the samples from the dataset were normalized.
+                        self.assertLessEqual(np.max(sample["actions"]), 1.0)
+                        self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
+                    else:
+                        # Check that the samples from the dataset were not normalized.
+                        self.assertGreater(np.max(sample["actions"]), 1.5)
+                        self.assertLess(np.min(sample["actions"]), -1.5)
+
+                    ev.stop()
+
     def test_action_immutability(self):
         from ray.rllib.examples.env.random_env import RandomEnv
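For context on the assertions in the new test: Pendulum-v1 has a Box(-2.0, 2.0, (1,)) action space, so the recorded actions 2.0 and -2.0 sit exactly on the env bounds. The sketch below is illustrative only; the zero-centered mapping it shows is an assumption about what RLlib's normalize_action does for a simple Box space, not code from this commit.

import numpy as np

low, high = -2.0, 2.0            # Pendulum-v1 action bounds
actions = np.array([2.0, -2.0])  # the two actions written to data.json

# Env bounds -> zero-centered [-1.0, 1.0] range.
normalized = (actions - low) / (high - low) * 2.0 - 1.0
print(normalized)  # [ 1. -1.] -> satisfies the <= 1.0 / >= -1.0 assertions

# Without normalization the raw values stay at the env bounds, which is
# what the > 1.5 / < -1.5 branch of the test asserts.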


@@ -8,7 +8,7 @@ import zipfile
 import ray.data
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
-from ray.rllib.offline.json_reader import from_json_data
+from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
 from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
@@ -251,6 +251,7 @@ class DatasetReader(InputReader):
             d = next(self._iter).as_pydict()
             # Columns like obs are compressed when written by DatasetWriter.
             d = from_json_data(d, self._ioctx.worker)
+            d = postprocess_actions(d, self._ioctx)
             count += d.count
             ret.append(self._postprocess_if_needed(d))
         ret = concat_samples(ret)
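With this change, every block pulled from the Ray Dataset shard is decompressed by from_json_data and then run through postprocess_actions, so DatasetReader applies the same clipping/normalization rules as JsonReader. A minimal config sketch of an offline run that exercises this path (illustrative only; the path is a placeholder, and only the keys already used in the test above are assumed):

config = {
    "input": "dataset",
    "input_config": {"format": "json", "paths": "/path/to/data.json"},
    # Stored actions are raw env actions, so let RLlib re-normalize them:
    "actions_in_input_normalized": False,
    "normalize_actions": True,
    "clip_actions": False,
}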


@@ -83,6 +83,78 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
     return json_data
 
 
+@DeveloperAPI
+def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
+    # Clip actions (from any values into env's bounds), if necessary.
+    cfg = ioctx.config
+    # TODO(jungong): we should not clip_action in the input reader.
+    # Use a connector to handle this.
+    if cfg.get("clip_actions"):
+        if ioctx.worker is None:
+            raise ValueError(
+                "clip_actions is True but cannot clip actions since no workers exist"
+            )
+
+        if isinstance(batch, SampleBatch):
+            default_policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            batch[SampleBatch.ACTIONS] = clip_action(
+                batch[SampleBatch.ACTIONS], default_policy.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                b[SampleBatch.ACTIONS] = clip_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+
+    # Re-normalize actions (from env's bounds to zero-centered), if
+    # necessary.
+    if (
+        cfg.get("actions_in_input_normalized") is False
+        and cfg.get("normalize_actions") is True
+    ):
+        if ioctx.worker is None:
+            raise ValueError(
+                "actions_in_input_normalized is False but "
+                "cannot normalize actions since no workers exist"
+            )
+
+        # If we have a complex action space and actions were flattened
+        # and we have to normalize -> Error.
+        error_msg = (
+            "Normalization of offline actions that are flattened is not "
+            "supported! Make sure that you record actions into offline "
+            "file with the `_disable_action_flattening=True` flag OR "
+            "as already normalized (between -1.0 and 1.0) values. "
+            "Also, when reading already normalized action values from "
+            "offline files, make sure to set "
+            "`actions_in_input_normalized=True` so that RLlib will not "
+            "perform normalization on top."
+        )
+
+        if isinstance(batch, SampleBatch):
+            pol = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            if isinstance(
+                pol.action_space_struct, (tuple, dict)
+            ) and not pol.config.get("_disable_action_flattening"):
+                raise ValueError(error_msg)
+            batch[SampleBatch.ACTIONS] = normalize_action(
+                batch[SampleBatch.ACTIONS], pol.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                pol = ioctx.worker.policy_map[pid]
+                if isinstance(
+                    pol.action_space_struct, (tuple, dict)
+                ) and not pol.config.get("_disable_action_flattening"):
+                    raise ValueError(error_msg)
+                b[SampleBatch.ACTIONS] = normalize_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+
+    return batch
+
+
 @DeveloperAPI
 def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
     # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
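The normalization branch above fires for exactly one of the four flag combinations that the new test enumerates. A small distillation of that decision logic (an illustrative helper, not part of the commit):

def will_normalize(actions_in_input_normalized: bool, normalize_actions: bool) -> bool:
    # Mirrors the condition in postprocess_actions().
    return actions_in_input_normalized is False and normalize_actions is True

for combo in [(True, True), (True, False), (False, True), (False, False)]:
    print(combo, "->", "normalize" if will_normalize(*combo) else "leave as-is")
# Only (False, True) normalizes; clipping is controlled separately by clip_actions.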
@@ -290,61 +362,8 @@ class JsonReader(InputReader):
             )
             return None
 
-        # Clip actions (from any values into env's bounds), if necessary.
-        cfg = self.ioctx.config
-        # TODO(jungong) : we should not clip_action in input reader.
-        # Use connector to handle this.
-        if cfg.get("clip_actions") and self.ioctx.worker is not None:
-            if isinstance(batch, SampleBatch):
-                batch[SampleBatch.ACTIONS] = clip_action(
-                    batch[SampleBatch.ACTIONS], self.default_policy.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    b[SampleBatch.ACTIONS] = clip_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
-
-        # Re-normalize actions (from env's bounds to zero-centered), if
-        # necessary.
-        if (
-            cfg.get("actions_in_input_normalized") is False
-            and self.ioctx.worker is not None
-        ):
-            # If we have a complex action space and actions were flattened
-            # and we have to normalize -> Error.
-            error_msg = (
-                "Normalization of offline actions that are flattened is not "
-                "supported! Make sure that you record actions into offline "
-                "file with the `_disable_action_flattening=True` flag OR "
-                "as already normalized (between -1.0 and 1.0) values. "
-                "Also, when reading already normalized action values from "
-                "offline files, make sure to set "
-                "`actions_in_input_normalized=True` so that RLlib will not "
-                "perform normalization on top."
-            )
-            if isinstance(batch, SampleBatch):
-                pol = self.default_policy
-                if isinstance(
-                    pol.action_space_struct, (tuple, dict)
-                ) and not pol.config.get("_disable_action_flattening"):
-                    raise ValueError(error_msg)
-                batch[SampleBatch.ACTIONS] = normalize_action(
-                    batch[SampleBatch.ACTIONS], pol.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    pol = self.policy_map[pid]
-                    if isinstance(
-                        pol.action_space_struct, (tuple, dict)
-                    ) and not pol.config.get("_disable_action_flattening"):
-                        raise ValueError(error_msg)
-                    b[SampleBatch.ACTIONS] = normalize_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
+        batch = postprocess_actions(batch, self.ioctx)
 
         return batch
 
     def _next_line(self) -> str:


@@ -20,7 +20,7 @@ pendulum_crr:
        actor_hidden_activation: 'relu'
        actor_hiddens: [256, 256]
        actor_lr: 0.0003
-       actions_in_input_normalized: True
+       actions_in_input_normalized: False
        clip_actions: True
        # Q function update setting
        twin_q: True
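The tuned CRR example now declares its Pendulum offline data as unnormalized, so the input reader re-normalizes the raw [-2, 2] actions (assuming normalize_actions remains at its default of True for this example). A quick, illustrative way to check how an offline file should be flagged, assuming the plain JSON-lines SampleBatch format used by the test earlier in this commit (path is a placeholder):

import json
import numpy as np

with open("/path/to/offline/data.json") as f:
    for line in f:
        batch = json.loads(line)
        max_abs = np.abs(np.asarray(batch["actions"], dtype=np.float32)).max()
        # max |action| > 1.0 suggests raw env actions:
        # set actions_in_input_normalized: False so RLlib normalizes them.
        print("max |action| =", max_abs)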