[RLlib] Support for D4RL + Semi-working CQL Benchmark (#13550)

Michael Luo 2021-01-21 07:43:55 -08:00 committed by GitHub
parent d11e62f9e6
commit 587f207c2f
5 changed files with 61 additions and 1 deletion


@@ -15,6 +15,8 @@ CQL_DEFAULT_CONFIG = merge_dicts(
SAC_CONFIG, {
# You should override this to point to an offline dataset.
"input": "sampler",
# Offline RL does not need IS (importance sampling) estimators.
"input_evaluation": [],
# Number of iterations with Behavior Cloning Pretraining
"bc_iters": 20000,
# CQL Loss Temperature
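
The new offline-specific keys above are layered on top of the SAC defaults via merge_dicts. A minimal sketch of that override pattern, using a toy base dict in place of SAC_CONFIG (only the merge behavior is the point here):

from ray.rllib.utils import merge_dicts

# Toy stand-in for SAC_CONFIG; real CQL inherits the full SAC defaults.
base = {"input": "sampler", "tau": 0.005}
overrides = {
    "input": "sampler",        # override this to point to an offline dataset
    "input_evaluation": [],    # no IS estimators for offline RL
    "bc_iters": 20000,         # behavior-cloning pretraining iterations
}
cql_config = merge_dicts(base, overrides)
assert cql_config["bc_iters"] == 20000 and cql_config["tau"] == 0.005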


@@ -8,7 +8,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
_validate_multiagent_config
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
ShuffledInput, D4RLReader
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy import Policy
from ray.rllib.utils import merge_dicts
@@ -266,6 +266,9 @@ class WorkerSet:
input_creator = (
lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
config["shuffle_buffer_size"]))
elif "d4rl" in config["input"]:
env_name = config["input"].split(".")[1]
input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
else:
input_creator = (
lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
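
For reference, a small illustration (with a hypothetical config value) of how the new "d4rl." branch above turns the input string into an environment name before constructing the D4RLReader:

# Everything after the first dot of a "d4rl."-prefixed input string is
# treated as the D4RL environment name; any other input value falls through
# to the JsonReader/ShuffledInput path shown below it.
config = {"input": "d4rl.halfcheetah-medium-v0"}
env_name = config["input"].split(".")[1]
assert env_name == "halfcheetah-medium-v0"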


@@ -5,6 +5,7 @@ from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.mixed_input import MixedInput
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.offline.d4rl_reader import D4RLReader
__all__ = [
"IOContext",
@@ -15,4 +16,5 @@ __all__ = [
"InputReader",
"MixedInput",
"ShuffledInput",
"D4RLReader",
]
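
With the export above, the reader is importable straight from ray.rllib.offline. A minimal standalone sketch, assuming the d4rl package is installed and provides the halfcheetah-medium-v0 dataset:

from ray.rllib.offline import D4RLReader

reader = D4RLReader("halfcheetah-medium-v0")
batch = reader.next()            # a single-transition SampleBatch
print(batch["obs"].shape, batch["rewards"])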


@@ -0,0 +1,52 @@
import logging
from typing import Dict

import gym

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)


@PublicAPI
class D4RLReader(InputReader):
    """Reader object that loads the dataset from the D4RL dataset."""

    @PublicAPI
    def __init__(self, inputs: str, ioctx: IOContext = None):
        """Initializes a D4RLReader.

        Args:
            inputs (str): Name of the D4RL environment whose dataset to load.
            ioctx (IOContext): Current IO context object.
        """
        # Imported here so that d4rl is only required when this reader is used.
        import d4rl
        self.env = gym.make(inputs)
        self.dataset = convert_to_batch(d4rl.qlearning_dataset(self.env))
        assert self.dataset.count >= 1
        self.dataset.shuffle()
        self.counter = 0

    @override(InputReader)
    def next(self) -> SampleBatchType:
        # Restart (and reshuffle) once every row has been served.
        if self.counter >= self.dataset.count:
            self.counter = 0
            self.dataset.shuffle()

        # Serve the row at the current position, then advance the cursor.
        batch = self.dataset.slice(start=self.counter, end=self.counter + 1)
        self.counter += 1
        return batch


def convert_to_batch(dataset: Dict) -> SampleBatchType:
    """Converts a D4RL dataset dict into an RLlib SampleBatch."""
    d = {}
    d[SampleBatch.OBS] = dataset["observations"]
    d[SampleBatch.ACTIONS] = dataset["actions"]
    d[SampleBatch.NEXT_OBS] = dataset["next_observations"]
    d[SampleBatch.REWARDS] = dataset["rewards"]
    d[SampleBatch.DONES] = dataset["terminals"]
    return SampleBatch(d)
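
To make the column mapping in convert_to_batch concrete, here is a toy example (hypothetical array shapes) showing how a D4RL-style dict becomes an RLlib SampleBatch:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Tiny stand-in for d4rl.qlearning_dataset() output: 4 transitions of a
# 17-dim observation / 6-dim action task.
toy_dataset = {
    "observations": np.zeros((4, 17)),
    "actions": np.zeros((4, 6)),
    "next_observations": np.zeros((4, 17)),
    "rewards": np.zeros(4),
    "terminals": np.zeros(4, dtype=bool),
}
batch = convert_to_batch(toy_dataset)
assert batch.count == 4
assert batch[SampleBatch.OBS].shape == (4, 17)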


@@ -5,6 +5,7 @@ halfcheetah_cql:
        episode_reward_mean: 9000
    config:
        # SAC Configs
        input: d4rl.halfcheetah-medium-v0
        framework: torch
        horizon: 1000
        soft_horizon: false
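
For completeness, the benchmark above can also be launched from Python instead of the tuned-example YAML. A hedged sketch, assuming d4rl is installed, the env id is registered with Gym by d4rl, and the new agent is registered with Tune under the name "CQL":

import ray
from ray import tune

ray.init()
tune.run(
    "CQL",
    stop={"episode_reward_mean": 9000},
    config={
        "env": "halfcheetah-medium-v0",        # Gym id registered by d4rl (assumed)
        "input": "d4rl.halfcheetah-medium-v0",  # route input through D4RLReader
        "input_evaluation": [],
        "framework": "torch",
        "horizon": 1000,
        "soft_horizon": False,
    },
)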