Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)

[RLlib] Make Dataset reader default reader and enable CRR to use dataset (#26304)

Co-authored-by: avnish <avnish@avnishs-MBP.local.meter>

Parent: 61c9e761f3
Commit: 1243ed62bf

18 changed files with 277 additions and 216 deletions
rllib/BUILD (22 changed lines)

@@ -254,17 +254,17 @@ py_test(
 
 # CRR
 py_test(
     name = "learning_tests_pendulum_crr",
     main = "tests/run_regression_tests.py",
     tags = ["team:rllib", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     # Include an offline json data file as well.
     data = [
         "tuned_examples/crr/pendulum-v1-crr.yaml",
         "tests/data/pendulum/pendulum_replay_v1.1.0.zip",
     ],
     args = ["--yaml-dir=tuned_examples/crr"]
 )
 
 py_test(
@@ -14,7 +14,6 @@ from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,
     train_one_step,
 )
-from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import (
@@ -32,8 +31,8 @@ from ray.rllib.utils.metrics import (
     NUM_TARGET_UPDATES,
     TARGET_NET_UPDATE_TIMER,
     SYNCH_WORKER_WEIGHTS_TIMER,
+    SAMPLE_TIMER,
 )
-from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
 from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict
 
 tf1, tf, tfv = try_import_tf()
@@ -177,23 +176,11 @@ class CQL(SAC):
     @override(SAC)
     def training_step(self) -> ResultDict:
         # Collect SampleBatches from sample workers.
-        batch = synchronous_parallel_sample(worker_set=self.workers)
-        batch = batch.as_multi_agent()
-        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
-        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
-        # Add batch to replay buffer.
-        self.local_replay_buffer.add(batch)
-
-        # Sample training batch from replay buffer.
-        train_batch = sample_min_n_steps_from_buffer(
-            self.local_replay_buffer,
-            self.config["train_batch_size"],
-            count_by_agent_steps=self._by_agent_steps,
-        )
-
-        # Old-style replay buffers return None if learning has not started
-        if not train_batch:
-            return {}
-
+        with self._timers[SAMPLE_TIMER]:
+            train_batch = synchronous_parallel_sample(worker_set=self.workers)
+        train_batch = train_batch.as_multi_agent()
+        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
+        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
+
         # Postprocess batch before we learn on it.
         post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
@@ -207,14 +194,6 @@ class CQL(SAC):
         else:
             train_results = multi_gpu_train_one_step(self, train_batch)
 
-        # Update replay buffer priorities.
-        update_priorities_in_replay_buffer(
-            self.local_replay_buffer,
-            self.config,
-            train_batch,
-            train_results,
-        )
-
         # Update target network every `target_network_update_freq` training steps.
         cur_ts = self._counters[
             NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
@@ -69,7 +69,8 @@ class TestCQL(unittest.TestCase):
                 evaluation_parallel_to_training=False,
                 evaluation_num_workers=2,
             )
-            .rollouts(rollout_fragment_length=1)
+            .rollouts(num_rollout_workers=0)
+            .reporting(min_time_s_per_iteration=0.0)
         )
         num_iterations = 4
 
@@ -85,7 +86,6 @@ class TestCQL(unittest.TestCase):
                     f"iter={trainer.iteration} "
                     f"R={eval_results['episode_reward_mean']}"
                 )
-
             check_compute_single_action(trainer)
 
             # Get policy and model.
@@ -97,9 +97,9 @@ class TestCQL(unittest.TestCase):
             # Example on how to do evaluation on the trained Trainer
             # using the data from CQL's global replay buffer.
             # Get a sample (MultiAgentBatch).
-            multi_agent_batch = trainer.local_replay_buffer.sample(
-                num_items=config.train_batch_size
-            )
+            batch = trainer.workers.local_worker().input_reader.next()
+            multi_agent_batch = batch.as_multi_agent()
             # All experiences have been buffered for `default_policy`
             batch = multi_agent_batch.policy_batches["default_policy"]
 
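For context, a minimal sketch of the evaluation pattern the updated test above relies on. It assumes `algo` is an already-built offline algorithm (e.g. CQL or CRR) whose `input` points at offline data; the variable names are illustrative and not part of this commit.

# Hedged sketch, not part of this commit: with the replay-buffer pre-loading gone,
# offline samples come straight from the local worker's input reader.
batch = algo.workers.local_worker().input_reader.next()
multi_agent_batch = batch.as_multi_agent()
# Offline experiences are keyed under the default policy.
policy_batch = multi_agent_batch.policy_batches["default_policy"]
print(policy_batch.count, "timesteps pulled for offline evaluation")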
@@ -1,21 +1,21 @@
 import logging
 from typing import List, Optional, Type
 
-import numpy as np
-import tree
-
 from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
+from ray.rllib.execution import synchronous_parallel_sample
 from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
-from ray.rllib.offline.shuffled_input import ShuffledInput
 from ray.rllib.policy import Policy
-from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.metrics import (
     LAST_TARGET_UPDATE_TS,
+    NUM_AGENT_STEPS_TRAINED,
+    NUM_ENV_STEPS_TRAINED,
     NUM_TARGET_UPDATES,
     TARGET_NET_UPDATE_TIMER,
+    NUM_AGENT_STEPS_SAMPLED,
+    NUM_ENV_STEPS_SAMPLED,
+    SAMPLE_TIMER,
 )
-from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
 from ray.rllib.utils.typing import (
     AlgorithmConfigDict,
     PartialAlgorithmConfigDict,
@@ -38,19 +38,13 @@ class CRRConfig(AlgorithmConfig):
         self.advantage_type = "mean"
         self.n_action_sample = 4
         self.twin_q = True
-        self.target_update_grad_intervals = 100
+        self.train_batch_size = 128
 
+        # target_network_update_freq by default is 100 * train_batch_size
+        # if target_network_update_freq is not set. See self.setup for code.
+        self.target_network_update_freq = None
         # __sphinx_doc_end__
         # fmt: on
-        self.replay_buffer_config = {
-            "type": MultiAgentReplayBuffer,
-            "capacity": 50000,
-            # How many steps of the model to sample before learning starts.
-            "learning_starts": 1000,
-            "replay_batch_size": 32,
-            # The number of contiguous environment steps to replay at once. This
-            # may be set to greater than 1 to support recurrent models.
-            "replay_sequence_length": 1,
-        }
         self.actor_hiddens = [256, 256]
         self.actor_hidden_activation = "relu"
         self.critic_hiddens = [256, 256]
@@ -60,7 +54,10 @@ class CRRConfig(AlgorithmConfig):
         self.tau = 5e-3
 
         # overriding the trainer config default
-        self.num_workers = 0  # offline RL does not need rollout workers
+        # If data ingestion/sample_time is slow, increase this
+        self.num_workers = 4
+        self.offline_sampling = True
+        self.min_iter_time_s = 10.0
 
     def training(
         self,
@@ -71,8 +68,7 @@ class CRRConfig(AlgorithmConfig):
         advantage_type: Optional[str] = None,
         n_action_sample: Optional[int] = None,
         twin_q: Optional[bool] = None,
-        target_update_grad_intervals: Optional[int] = None,
-        replay_buffer_config: Optional[dict] = None,
+        target_network_update_freq: Optional[int] = None,
         actor_hiddens: Optional[List[int]] = None,
         actor_hidden_activation: Optional[str] = None,
         critic_hiddens: Optional[List[int]] = None,
@@ -110,10 +106,9 @@ class CRRConfig(AlgorithmConfig):
                 a^j)]
             n_action_sample: the number of actions to sample for v_t estimation.
             twin_q: if True, uses pessimistic q estimation.
-            target_update_grad_intervals: The frequency at which we update the
+            target_network_update_freq: The frequency at which we update the
                 target copy of the model in terms of the number of gradient updates
                 applied to the main model.
-            replay_buffer_config: The config dictionary for replay buffer.
             actor_hiddens: The number of hidden units in the actor's fc network.
             actor_hidden_activation: The activation used in the actor's fc network.
             critic_hiddens: The number of hidden units in the critic's fc network.
@@ -139,10 +134,8 @@ class CRRConfig(AlgorithmConfig):
             self.n_action_sample = n_action_sample
         if twin_q is not None:
             self.twin_q = twin_q
-        if target_update_grad_intervals is not None:
-            self.target_update_grad_intervals = target_update_grad_intervals
-        if replay_buffer_config is not None:
-            self.replay_buffer_config = replay_buffer_config
+        if target_network_update_freq is not None:
+            self.target_network_update_freq = target_network_update_freq
         if actor_hiddens is not None:
             self.actor_hiddens = actor_hiddens
         if actor_hidden_activation is not None:
@@ -168,44 +161,10 @@ class CRR(Algorithm):
 
     def setup(self, config: PartialAlgorithmConfigDict):
         super().setup(config)
-        # initial setup for handling the offline data in form of a replay buffer
-        # Add the entire dataset to Replay Buffer (global variable)
-        reader = self.workers.local_worker().input_reader
-
-        # For d4rl, add the D4RLReaders' dataset to the buffer.
-        if isinstance(self.config["input"], str) and "d4rl" in self.config["input"]:
-            dataset = reader.dataset
-            self.local_replay_buffer.add(dataset)
-        # For a list of files, add each file's entire content to the buffer.
-        elif isinstance(reader, ShuffledInput):
-            num_batches = 0
-            total_timesteps = 0
-            for batch in reader.child.read_all_files():
-                num_batches += 1
-                total_timesteps += len(batch)
-                # Add NEXT_OBS if not available. This is slightly hacked
-                # as for the very last time step, we will use next-obs=zeros
-                # and therefore force-set DONE=True to avoid this missing
-                # next-obs to cause learning problems.
-                if SampleBatch.NEXT_OBS not in batch:
-                    obs = batch[SampleBatch.OBS]
-                    batch[SampleBatch.NEXT_OBS] = np.concatenate(
-                        [obs[1:], np.zeros_like(obs[0:1])]
-                    )
-                    batch[SampleBatch.DONES][-1] = True
-                self.local_replay_buffer.add(batch)
-            print(
-                f"Loaded {num_batches} batches ({total_timesteps} ts) into the"
-                " replay buffer, which has capacity "
-                f"{self.local_replay_buffer.capacity}."
-            )
-        else:
-            raise ValueError(
-                "Unknown offline input! config['input'] must either be list of"
-                " offline files (json) or a D4RL-specific InputReader "
-                "specifier (e.g. 'd4rl.hopper-medium-v0')."
-            )
-
+        if self.config.get("target_network_update_freq", None) is None:
+            self.config["target_network_update_freq"] = (
+                self.config["train_batch_size"] * 100
+            )
         # added a counter key for keeping track of number of gradient updates
         self._counters[NUM_GRADIENT_UPDATES] = 0
         # if I don't set this here to zero I won't see zero in the logs (defaultdict)
@@ -227,47 +186,39 @@ class CRR(Algorithm):
 
     @override(Algorithm)
     def training_step(self) -> ResultDict:
-        total_transitions = len(self.local_replay_buffer)
-        bsize = self.config["train_batch_size"]
-        n_batches_per_epoch = total_transitions // bsize
-
-        results = []
-        for batch_iter in range(n_batches_per_epoch):
-            # Sample training batch from replay buffer.
-            train_batch = self.local_replay_buffer.sample(bsize)
-
-            # Postprocess batch before we learn on it.
-            post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
-            train_batch = post_fn(train_batch, self.workers, self.config)
-
-            # Learn on training batch.
-            # Use simple optimizer (only for multi-agent or tf-eager; all other
-            # cases should use the multi-GPU optimizer, even if only using 1 GPU)
-            if self.config.get("simple_optimizer", False):
-                train_results = train_one_step(self, train_batch)
-            else:
-                train_results = multi_gpu_train_one_step(self, train_batch)
-
-            # update target every few gradient updates
-            cur_ts = self._counters[NUM_GRADIENT_UPDATES]
-            last_update = self._counters[LAST_TARGET_UPDATE_TS]
-
-            if cur_ts - last_update >= self.config["target_update_grad_intervals"]:
-                with self._timers[TARGET_NET_UPDATE_TIMER]:
-                    to_update = self.workers.local_worker().get_policies_to_train()
-                    self.workers.local_worker().foreach_policy_to_train(
-                        lambda p, pid: pid in to_update and p.update_target()
-                    )
-                self._counters[NUM_TARGET_UPDATES] += 1
-                self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
-
-            self._counters[NUM_GRADIENT_UPDATES] += 1
-
-            results.append(train_results)
-
-        summary = tree.map_structure_with_path(
-            lambda path, *v: float(np.mean(v)), *results
-        )
-
-        return summary
+        with self._timers[SAMPLE_TIMER]:
+            train_batch = synchronous_parallel_sample(worker_set=self.workers)
+        train_batch = train_batch.as_multi_agent()
+        self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
+        self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
+
+        # Postprocess batch before we learn on it.
+        post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
+        train_batch = post_fn(train_batch, self.workers, self.config)
+
+        # Learn on training batch.
+        # Use simple optimizer (only for multi-agent or tf-eager; all other
+        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
+        if self.config.get("simple_optimizer", False):
+            train_results = train_one_step(self, train_batch)
+        else:
+            train_results = multi_gpu_train_one_step(self, train_batch)
+
+        # update target every few gradient updates
+        # Update target network every `target_network_update_freq` training steps.
+        cur_ts = self._counters[
+            NUM_AGENT_STEPS_TRAINED if self._by_agent_steps else NUM_ENV_STEPS_TRAINED
+        ]
+        last_update = self._counters[LAST_TARGET_UPDATE_TS]
+
+        if cur_ts - last_update >= self.config["target_network_update_freq"]:
+            with self._timers[TARGET_NET_UPDATE_TIMER]:
+                to_update = self.workers.local_worker().get_policies_to_train()
+                self.workers.local_worker().foreach_policy_to_train(
+                    lambda p, pid: pid in to_update and p.update_target()
+                )
+            self._counters[NUM_TARGET_UPDATES] += 1
+            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
+
+        self._counters[NUM_GRADIENT_UPDATES] += 1
+        return train_results
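To see the new CRR defaults in one place, here is a hedged configuration sketch. The builder methods are the ones exercised by the tests and tuned examples in this commit; the data path is a placeholder, and the final build/train calls assume the usual AlgorithmConfig entry points rather than anything introduced here.

from ray.rllib.algorithms.crr import CRRConfig

config = (
    CRRConfig()
    .environment(env="Pendulum-v1", clip_actions=True)
    .framework("torch")
    # Plain .json/.zip/.parquet paths are now routed through the Ray Dataset
    # reader; the path below is a placeholder, not part of this commit.
    .offline_data(
        input_="dataset",
        input_config={
            "format": "json",
            "paths": "tests/data/pendulum/pendulum_replay_v1.1.0.zip",
        },
        actions_in_input_normalized=True,
    )
    .training(
        twin_q=True,
        train_batch_size=1024,
        target_network_update_freq=1,  # renamed from target_update_grad_intervals
        tau=0.0001,
    )
)
algo = config.build()  # assumption: standard AlgorithmConfig.build() usage
print(algo.train())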
@@ -4,8 +4,8 @@ import unittest
 
 import ray
 from ray.rllib.algorithms.crr import CRRConfig
+from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
 from ray.rllib.utils.test_utils import (
     check_compute_single_action,
     check_train_results,
@@ -32,24 +32,31 @@ class TestCRR(unittest.TestCase):
         print("rllib dir={}".format(rllib_dir))
         data_file = os.path.join(rllib_dir, "tests/data/pendulum/large.json")
         print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
+        # Will use the Json Reader in this example until we convert over the example
+        # files over to Parquet, since the dataset json reader cannot handle large
+        # block sizes.
+
+        def input_reading_fn(ioctx):
+            return JsonReader(ioctx.config["input_config"]["paths"], ioctx)
+
+        input_config = {"paths": data_file}
+
         config = (
             CRRConfig()
             .environment(env="Pendulum-v1", clip_actions=True)
             .framework("torch")
-            .offline_data(input_=[data_file], actions_in_input_normalized=True)
+            .offline_data(
+                input_=input_reading_fn,
+                input_config=input_config,
+                actions_in_input_normalized=True,
+            )
             .training(
                 twin_q=True,
                 train_batch_size=256,
-                replay_buffer_config={
-                    "type": MultiAgentReplayBuffer,
-                    "learning_starts": 0,
-                    "capacity": 100000,
-                },
                 weight_type="bin",
                 advantage_type="mean",
                 n_action_sample=4,
-                target_update_grad_intervals=10000,
+                target_network_update_freq=10000,
                 tau=1.0,
             )
             .evaluation(
@@ -18,6 +18,7 @@ from ray.rllib.utils.metrics import (
     NUM_AGENT_STEPS_SAMPLED,
     NUM_ENV_STEPS_SAMPLED,
     SYNCH_WORKER_WEIGHTS_TIMER,
+    SAMPLE_TIMER,
 )
 from ray.rllib.utils.typing import (
     ResultDict,
@@ -253,7 +254,8 @@ class MARWIL(Algorithm):
     @override(Algorithm)
     def training_step(self) -> ResultDict:
         # Collect SampleBatches from sample workers.
-        batch = synchronous_parallel_sample(worker_set=self.workers)
+        with self._timers[SAMPLE_TIMER]:
+            batch = synchronous_parallel_sample(worker_set=self.workers)
         batch = batch.as_multi_agent()
         self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
         self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
@@ -863,7 +863,6 @@ class RolloutWorker(ParallelIteratorWorker):
             >>> print(worker.sample()) # doctest: +SKIP
             SampleBatch({"obs": [...], "action": [...], ...})
         """
-
         if self.fake_sampler and self.last_batch is not None:
             return self.last_batch
         elif self.input_reader is None:
@@ -893,9 +892,8 @@ class RolloutWorker(ParallelIteratorWorker):
             max_batches = self.num_envs
         else:
             max_batches = float("inf")
-        while (
-            steps_so_far < self.rollout_fragment_length and len(batches) < max_batches
+        while steps_so_far < self.rollout_fragment_length and (
+            len(batches) < max_batches or self.policy_config.get("offline_sampling")
         ):
             batch = self.input_reader.next()
             steps_so_far += (
@@ -1,3 +1,6 @@
+from pathlib import Path
+import re
+
 import gym
 import logging
 import importlib.util
@@ -99,14 +102,59 @@ class WorkerSet:
         }
         self._cls = RolloutWorker.as_remote(**self._remote_args).remote
         self._logdir = logdir
 
         if _setup:
             # Force a local worker if num_workers == 0 (no remote workers).
             # Otherwise, this WorkerSet would be empty.
             self._local_worker = None
             if num_workers == 0:
                 local_worker = True
+            if (
+                (
+                    isinstance(trainer_config["input"], str)
+                    or isinstance(trainer_config["input"], list)
+                )
+                and ("d4rl" not in trainer_config["input"])
+                and (not "sampler" == trainer_config["input"])
+                and (not "dataset" == trainer_config["input"])
+                and (
+                    not (
+                        isinstance(trainer_config["input"], str)
+                        and registry_contains_input(trainer_config["input"])
+                    )
+                )
+                and (
+                    not (
+                        isinstance(trainer_config["input"], str)
+                        and self._valid_module(trainer_config["input"])
+                    )
+                )
+            ):
+                paths = trainer_config["input"]
+                if isinstance(paths, str):
+                    inputs = Path(paths).absolute()
+                    if inputs.is_dir():
+                        paths = list(inputs.glob("*.json")) + list(inputs.glob("*.zip"))
+                        paths = [str(path) for path in paths]
+                    else:
+                        paths = [paths]
+                ends_with_zip_or_json = all(
+                    re.search("\\.zip$", path) or re.search("\\.json$", path)
+                    for path in paths
+                )
+                ends_with_parquet = all(
+                    re.search("\\.parquet$", path) for path in paths
+                )
+                trainer_config["input"] = "dataset"
+                input_config = {"paths": paths}
+                if ends_with_zip_or_json:
+                    input_config["format"] = "json"
+                elif ends_with_parquet:
+                    input_config["format"] = "parquet"
+                else:
+                    raise ValueError(
+                        "Input path must end with .zip, .parquet, or .json"
+                    )
+                trainer_config["input_config"] = input_config
             self._local_config = merge_dicts(
                 trainer_config,
                 {"tf_session_args": trainer_config["local_tf_session_args"]},
@@ -737,6 +785,26 @@ class WorkerSet:
 
         return faulty_worker_indices
 
+    @classmethod
+    def _valid_module(cls, class_path):
+        del cls
+        if (
+            isinstance(class_path, str)
+            and not os.path.isfile(class_path)
+            and "." in class_path
+        ):
+            module_path, class_name = class_path.rsplit(".", 1)
+            try:
+                spec = importlib.util.find_spec(module_path)
+                if spec is not None:
+                    return True
+            except (ModuleNotFoundError, ValueError):
+                print(
+                    f"module {module_path} not found while trying to get "
+                    f"input {class_path}"
+                )
+        return False
+
     @Deprecated(new="WorkerSet.foreach_policy_to_train", error=False)
     def foreach_trainable_policy(self, func):
         return self.foreach_policy_to_train(func)
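As a standalone illustration (simplified, not the RLlib code itself), the branch added above amounts to roughly the following rewrite of the trainer config; the registry and `_valid_module` checks are omitted here for brevity, and the function name is hypothetical.

import re
from pathlib import Path


def rewrite_offline_input(config: dict) -> dict:
    """Simplified sketch of the path handling WorkerSet now does at setup."""
    paths = config["input"]
    if isinstance(paths, str):
        p = Path(paths).absolute()
        if p.is_dir():
            # A directory expands to every *.json / *.zip file inside it.
            paths = [str(f) for f in list(p.glob("*.json")) + list(p.glob("*.zip"))]
        else:
            paths = [paths]
    if all(re.search(r"\.(zip|json)$", str(x)) for x in paths):
        fmt = "json"  # zip archives get extracted to json by the dataset reader
    elif all(re.search(r"\.parquet$", str(x)) for x in paths):
        fmt = "parquet"
    else:
        raise ValueError("Input path must end with .zip, .parquet, or .json")
    config["input"] = "dataset"
    config["input_config"] = {"paths": paths, "format": fmt}
    return config


# Example: {"input": "offline/out.json"} becomes
# {"input": "dataset", "input_config": {"paths": ["offline/out.json"], "format": "json"}}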
@@ -1,15 +1,17 @@
 import logging
 import math
+import re
+import zipfile
+from pathlib import Path
 
 import ray.data
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_reader import from_json_data
-from ray.rllib.policy.sample_batch import concat_samples
+from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
-from typing import List
+from typing import List, Tuple
 
 
 logger = logging.getLogger(__name__)
 
@@ -25,11 +27,31 @@ def _get_resource_bundles(config: AlgorithmConfigDict):
     return [{"CPU": math.ceil(parallelism * cpus_per_task)}]
 
 
+def _unzip_if_needed(paths: List[str], format: str):
+    """If a path in paths is a zip file, unzip it and use path of the unzipped file"""
+    ret = []
+    for path in paths:
+        fpath = Path(path).absolute()
+        if not fpath.exists():
+            fpath = Path(__file__).parent.parent / path
+            if not fpath.exists():
+                raise FileNotFoundError(f"File not found: {path}")
+        if re.search("\\.zip$", str(fpath)):
+            with zipfile.ZipFile(str(fpath), "r") as zip_ref:
+                zip_ref.extractall(str(fpath.parent))
+            fpath = re.sub("\\.zip$", f".{format}", str(fpath))
+        fpath = str(fpath)
+        ret.append(fpath)
+    return ret
+
+
 @PublicAPI
 def get_dataset_and_shards(
     config: AlgorithmConfigDict, num_workers: int, local_worker: bool
-) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
-    assert config["input"] == "dataset"
+) -> Tuple[ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]]:
+    assert config["input"] == "dataset", (
+        "Must specify input as dataset if" " calling `get_dataset_and_shards`"
+    )
     assert (
         "input_config" in config
     ), "Must specify input_config dict if using Dataset input."
@@ -37,36 +59,47 @@ def get_dataset_and_shards(
     input_config = config["input_config"]
 
     format = input_config.get("format")
-    path = input_config.get("path")
+    assert format in ("json", "parquet"), (
+        "Offline input data format must be " "parquet " "or json"
+    )
+    paths = input_config.get("paths")
     loader_fn = input_config.get("loader_fn")
-    if loader_fn and (format or path):
+    if loader_fn and (format or paths):
         raise ValueError(
             "When using a `loader_fn`, you cannot specify a `format` or `path`."
         )
 
-    if not (format and path) and not loader_fn:
+    if not (format and paths) and not loader_fn:
         raise ValueError(
             "Must specify format and path, or a loader_fn via input_config key"
             " when using Ray dataset input."
        )
 
+    if not isinstance(paths, (list, str)):
+        raise ValueError("Paths must be a list of path strings or a path string")
+    if isinstance(paths, str):
+        paths = [paths]
+    paths = _unzip_if_needed(paths, format)
+
     parallelism = input_config.get("parallelism", num_workers or 1)
     cpus_per_task = input_config.get(
         "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
     )
 
-    assert loader_fn or (format and path)
+    assert loader_fn or (format and paths), (
+        f"If using a loader_fn: {loader_fn} that constructs a dataset, "
+        "format: {format} and paths: {paths} must be specified. If format and "
+        "paths are specified, a loader_fn must not be specified."
+    )
     if loader_fn:
         dataset = loader_fn()
     elif format == "json":
         dataset = ray.data.read_json(
-            path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
+            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
         )
     elif format == "parquet":
         dataset = ray.data.read_parquet(
-            path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
+            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
         )
     else:
         raise ValueError("Un-supported Ray dataset format: ", format)
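A hedged usage sketch of the updated signature, mirroring the `TestDatasetReader` change further down. The file path is a placeholder and the import location is assumed to be `ray.rllib.offline.dataset_reader`.

from ray.rllib.offline.dataset_reader import get_dataset_and_shards

config = {
    "input": "dataset",
    # `path` is now `paths`: a single string or a list of .json/.zip/.parquet
    # files (zip archives are extracted by _unzip_if_needed first).
    "input_config": {"format": "json", "paths": "rllib/tests/data/pendulum/large.json"},
}
# The second return value holds the per-worker dataset shards.
dataset, shards = get_dataset_and_shards(config, num_workers=0, local_worker=True)
print(dataset.count())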
@@ -96,7 +129,7 @@ class DatasetReader(InputReader):
             "format": "json",
             # A single data file, a directory, or anything
             # that ray.data.dataset recognizes.
-            "path": "/tmp/sample_batches/",
+            "paths": "/tmp/sample_batches/",
             # By default, parallelism=num_workers.
             "parallelism": 3,
             # Dataset allocates 0.5 CPU for each reader by default.
@@ -114,20 +147,33 @@ class DatasetReader(InputReader):
             ds: Ray dataset to sample from.
         """
         self._ioctx = ioctx
+        self._default_policy = self.policy_map = None
         self._dataset = ds
+        self.count = None if not self._dataset else self._dataset.count()
+        # do this to disable the ray data stdout logging
+        ray.data.set_progress_bars(enabled=False)
+
         # the number of rows to return per call to next()
         if self._ioctx:
-            self.batch_size = ioctx.config.get("train_batch_size", 1)
-            num_workers = ioctx.config.get("num_workers", 0)
+            self.batch_size = self._ioctx.config.get("train_batch_size", 1)
+            num_workers = self._ioctx.config.get("num_workers", 0)
+            seed = self._ioctx.config.get("seed", None)
             if num_workers:
                 self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
         # We allow the creation of a non-functioning None DatasetReader.
         # It's useful for example for a non-rollout local worker.
         if ds:
+            if self._ioctx.worker is not None:
+                self._policy_map = self._ioctx.worker.policy_map
+                self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
+            self._dataset.random_shuffle(seed=seed)
             print(
-                "DatasetReader ", ioctx.worker_index, " has ", ds.count(), " samples."
+                f"DatasetReader {self._ioctx.worker_index} has {ds.count()}, samples."
             )
-            self._iter = self._dataset.repeat().iter_rows()
+            # TODO: @avnishn make this call seeded.
+            # calling random_shuffle_each_window shuffles the dataset after
+            # each time the whole dataset has been read.
+            self._iter = self._dataset.repeat().random_shuffle_each_window().iter_rows()
         else:
             self._iter = None
 
@@ -142,6 +188,22 @@ class DatasetReader(InputReader):
             # Columns like obs are compressed when written by DatasetWriter.
             d = from_json_data(d, self._ioctx.worker)
             count += d.count
-            ret.append(d)
+            ret.append(self._postprocess_if_needed(d))
         ret = concat_samples(ret)
         return ret
 
+    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
+        if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"):
+            return batch
+
+        if isinstance(batch, SampleBatch):
+            out = []
+            for sub_batch in batch.split_by_episode():
+                out.append(self._default_policy.postprocess_trajectory(sub_batch))
+            return SampleBatch.concat_samples(out)
+        else:
+            # TODO(ekl) this is trickier since the alignments between agent
+            # trajectories in the episode are not available any more.
+            raise NotImplementedError(
+                "Postprocessing of multi-agent data not implemented yet."
+            )
@@ -22,7 +22,7 @@ class TestDatasetReader(unittest.TestCase):
         print("rllib dir={}".format(rllib_dir))
         data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
         print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
-        input_config = {"format": "json", "path": data_file}
+        input_config = {"format": "json", "paths": data_file}
         dataset, _ = get_dataset_and_shards(
             {"input": "dataset", "input_config": input_config}, 0, True
         )
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -106,7 +106,13 @@ class NestedActionSpacesTest(unittest.TestCase):
 
         # Test, whether offline data can be properly read by
         # BC, configured accordingly.
-        config["input"] = config["output"]
+        # doing this for backwards compatibility until we move to parquet
+        # as default output
+        config["input"] = lambda ioctx: JsonReader(
+            ioctx.config["input_config"]["paths"], ioctx
+        )
+        config["input_config"] = {"paths": config["output"]}
         del config["output"]
         bc = BC(config=config)
         bc.train()
@@ -7,7 +7,7 @@ pendulum-cql:
     run: CQL
     stop:
        evaluation/episode_reward_mean: -700
-       timesteps_total: 200000
+       timesteps_total: 800000
    config:
        # Works for both torch and tf.
        framework: tf
@@ -21,10 +21,9 @@ pendulum-cql:
 
        twin_q: true
        train_batch_size: 2000
-       replay_buffer_config:
-         type: MultiAgentReplayBuffer
-         learning_starts: 0
        bc_iters: 100
+       num_workers: 2
+       min_time_s_per_iteration: 10
 
        metrics_smoothing_episodes: 5
@@ -7,6 +7,7 @@ cartpole_crr:
    config:
        input:
          - 'tests/data/cartpole/large.json'
+       num_workers: 3
        framework: torch
        gamma: 0.99
        train_batch_size: 2048
@@ -20,7 +21,7 @@ cartpole_crr:
        clip_actions: True
        # Q function update setting
        twin_q: True
-       target_update_grad_intervals: 1
+       target_network_update_freq: 1
        tau: 0.0005
        # evaluation
        evaluation_config:
@@ -31,11 +32,6 @@ cartpole_crr:
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True
-       # replay buffer
-       replay_buffer_config:
-         type: ray.rllib.utils.replay_buffers.MultiAgentReplayBuffer
-         learning_starts: 0
-         capacity: 100000
        # specific to CRR
        temperature: 1.0
        weight_type: bin
@@ -8,6 +8,7 @@ cartpole_crr:
        input:
          - 'tests/data/cartpole/large.json'
        framework: torch
+       num_workers: 3
        gamma: 0.99
        train_batch_size: 2048
        critic_hidden_activation: 'tanh'
@@ -20,7 +21,7 @@ cartpole_crr:
        clip_actions: True
        # Q function update setting
        twin_q: True
-       target_update_grad_intervals: 1
+       target_network_update_freq: 1
        tau: 0.0005
        # evaluation
        evaluation_config:
@@ -31,11 +32,6 @@ cartpole_crr:
        evaluation_interval: 1
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True
-       # replay buffer
-       replay_buffer_config:
-         type: ray.rllib.utils.replay_buffers.MultiAgentReplayBuffer
-         learning_starts: 0
-         capacity: 100000
        # specific to CRR
        temperature: 1.0
        weight_type: bin
@@ -2,12 +2,13 @@ pendulum_crr:
    env: 'Pendulum-v1'
    run: CRR
    stop:
-       evaluation/episode_reward_mean: -200
-       training_iteration: 500
+       # We could make this -200, but given that we have 4 cpus for our tests, we will have to settle for -300.
+       evaluation/episode_reward_mean: -300
+       timesteps_total: 2000000
    config:
-       input:
-         - 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
+       input: 'tests/data/pendulum/pendulum_replay_v1.1.0.zip'
        framework: torch
+       num_workers: 3
        gamma: 0.99
        train_batch_size: 1024
        critic_hidden_activation: 'relu'
@@ -20,7 +21,7 @@ pendulum_crr:
        clip_actions: True
        # Q function update setting
        twin_q: True
-       target_update_grad_intervals: 1
+       target_network_update_freq: 1
        tau: 0.0001
        # evaluation
        evaluation_config:
@@ -32,10 +33,6 @@ pendulum_crr:
        evaluation_num_workers: 1
        evaluation_parallel_to_training: True
        # replay buffer
-       replay_buffer_config:
-         type: ray.rllib.utils.replay_buffers.MultiAgentReplayBuffer
-         learning_starts: 0
-         capacity: 100000
        # specific to CRR
        temperature: 1.0
        weight_type: exp