mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Support for D4RL + Semi-working CQL Benchmark (#13550)
This commit is contained in:
parent
d11e62f9e6
commit
587f207c2f
5 changed files with 61 additions and 1 deletions
|
@ -15,6 +15,8 @@ CQL_DEFAULT_CONFIG = merge_dicts(
|
|||
SAC_CONFIG, {
|
||||
# You should override this to point to an offline dataset.
|
||||
"input": "sampler",
|
||||
# Offline RL does not need IS estimators
|
||||
"input_evaluation": [],
|
||||
# Number of iterations with Behavior Cloning Pretraining
|
||||
"bc_iters": 20000,
|
||||
# CQL Loss Temperature
|
||||
|
|
|
@ -8,7 +8,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
|
|||
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
ShuffledInput, D4RLReader
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
@ -266,6 +266,9 @@ class WorkerSet:
|
|||
input_creator = (
|
||||
lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
|
||||
config["shuffle_buffer_size"]))
|
||||
elif "d4rl" in config["input"]:
|
||||
env_name = config["input"].split(".")[1]
|
||||
input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
|
||||
else:
|
||||
input_creator = (
|
||||
lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
|
||||
|
|
|
@ -5,6 +5,7 @@ from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
|
|||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.mixed_input import MixedInput
|
||||
from ray.rllib.offline.shuffled_input import ShuffledInput
|
||||
from ray.rllib.offline.d4rl_reader import D4RLReader
|
||||
|
||||
__all__ = [
|
||||
"IOContext",
|
||||
|
@ -15,4 +16,5 @@ __all__ = [
|
|||
"InputReader",
|
||||
"MixedInput",
|
||||
"ShuffledInput",
|
||||
"D4RLReader",
|
||||
]
|
||||
|
|
52
rllib/offline/d4rl_reader.py
Normal file
52
rllib/offline/d4rl_reader.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import logging
|
||||
import gym
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from typing import Dict
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI
class D4RLReader(InputReader):
    """Reader object that loads the dataset from the D4RL dataset.

    Loads the full offline dataset for a given D4RL environment into memory
    once at construction time, then serves it one transition at a time,
    re-shuffling after each full pass.
    """

    @PublicAPI
    def __init__(self, inputs: str, ioctx: IOContext = None):
        """Initialize a D4RLReader.

        Args:
            inputs (str): String corresponding to the D4RL environment name.
            ioctx (IOContext): Current IO context object.

        Raises:
            ImportError: If the optional `d4rl` package is not installed.
            ValueError: If the loaded dataset contains no transitions.
        """
        # `d4rl` is an optional dependency; fail with an actionable message
        # instead of a bare ImportError.
        try:
            import d4rl
        except ImportError as e:
            raise ImportError(
                "D4RLReader requires the `d4rl` package to be installed "
                "(e.g. `pip install d4rl`).") from e
        self.env = gym.make(inputs)
        self.dataset = convert_to_batch(d4rl.qlearning_dataset(self.env))
        # Explicit error rather than `assert`: asserts are stripped when
        # Python runs with `-O`, silently disabling this validation.
        if self.dataset.count < 1:
            raise ValueError(
                "D4RL dataset for '{}' contains no transitions.".format(
                    inputs))
        self.dataset.shuffle()
        # Index of the next transition to emit from the shuffled dataset.
        self.counter = 0

    @override(InputReader)
    def next(self) -> SampleBatchType:
        # Re-shuffle and restart once the entire dataset has been emitted.
        if self.counter >= self.dataset.count:
            self.counter = 0
            self.dataset.shuffle()

        # Emit the sample at the current position, THEN advance the counter.
        # The original code incremented first, which skipped index 0 every
        # epoch and requested the out-of-range slice [count, count + 1) on
        # the final step of each pass.
        batch = self.dataset.slice(start=self.counter, end=self.counter + 1)
        self.counter += 1
        return batch
|
||||
|
||||
|
||||
def convert_to_batch(dataset: Dict) -> SampleBatchType:
    """Convert a D4RL q-learning dataset dict into a SampleBatch.

    Maps D4RL's field names ("observations", "actions", etc.) onto the
    corresponding SampleBatch column keys.

    Args:
        dataset (Dict): Dict as returned by `d4rl.qlearning_dataset()`.

    Returns:
        SampleBatchType: A SampleBatch holding the full dataset.
    """
    columns = {
        SampleBatch.OBS: dataset["observations"],
        SampleBatch.ACTIONS: dataset["actions"],
        SampleBatch.NEXT_OBS: dataset["next_observations"],
        SampleBatch.REWARDS: dataset["rewards"],
        SampleBatch.DONES: dataset["terminals"],
    }
    return SampleBatch(columns)
|
|
@ -5,6 +5,7 @@ halfcheetah_cql:
|
|||
episode_reward_mean: 9000
|
||||
config:
|
||||
# SAC Configs
|
||||
input: d4rl.halfcheetah-medium-v0
|
||||
framework: torch
|
||||
horizon: 1000
|
||||
soft_horizon: false
|
||||
|
|
Loading…
Add table
Reference in a new issue