mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
289 lines
11 KiB
Python
289 lines
11 KiB
Python
from typing import Callable, Optional, Type, Union
|
|
|
|
from ray.rllib.algorithms.algorithm import Algorithm
|
|
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
|
from ray.rllib.execution.rollout_ops import (
|
|
synchronous_parallel_sample,
|
|
)
|
|
from ray.rllib.execution.train_ops import (
|
|
multi_gpu_train_one_step,
|
|
train_one_step,
|
|
)
|
|
from ray.rllib.offline.estimators import ImportanceSampling, WeightedImportanceSampling
|
|
from ray.rllib.policy.policy import Policy
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.deprecation import (
|
|
Deprecated,
|
|
deprecation_warning,
|
|
)
|
|
from ray.rllib.utils.metrics import (
|
|
NUM_AGENT_STEPS_SAMPLED,
|
|
NUM_ENV_STEPS_SAMPLED,
|
|
SYNCH_WORKER_WEIGHTS_TIMER,
|
|
SAMPLE_TIMER,
|
|
)
|
|
from ray.rllib.utils.typing import (
|
|
AlgorithmConfigDict,
|
|
EnvType,
|
|
ResultDict,
|
|
)
|
|
from ray.tune.logger import Logger
|
|
|
|
|
|
class MARWILConfig(AlgorithmConfig):
|
|
"""Defines a configuration class from which a MARWIL Algorithm can be built.
|
|
|
|
|
|
Example:
|
|
>>> from ray.rllib.algorithms.marwil import MARWILConfig
|
|
>>> # Run this from the ray directory root.
|
|
>>> config = MARWILConfig().training(beta=1.0, lr=0.00001, gamma=0.99)\
|
|
... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
|
|
>>> print(config.to_dict())
|
|
>>> # Build a Algorithm object from the config and run 1 training iteration.
|
|
>>> algo = config.build()
|
|
>>> algo.train()
|
|
|
|
Example:
|
|
>>> from ray.rllib.algorithms.marwil import MARWILConfig
|
|
>>> from ray import tune
|
|
>>> config = MARWILConfig()
|
|
>>> # Print out some default values.
|
|
>>> print(config.beta)
|
|
>>> # Update the config object.
|
|
>>> config.training(lr=tune.grid_search([0.001, 0.0001]), beta=0.75)
|
|
>>> # Set the config object's data path.
|
|
>>> # Run this from the ray directory root.
|
|
>>> config.offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
|
|
>>> # Set the config object's env, used for evaluation.
|
|
>>> config.environment(env="CartPole-v0")
|
|
>>> # Use to_dict() to get the old-style python config dict
|
|
>>> # when running with tune.
|
|
>>> tune.run(
|
|
... "MARWIL",
|
|
... config=config.to_dict(),
|
|
... )
|
|
"""
|
|
|
|
def __init__(self, algo_class=None):
|
|
"""Initializes a MARWILConfig instance."""
|
|
super().__init__(algo_class=algo_class or MARWIL)
|
|
|
|
# fmt: off
|
|
# __sphinx_doc_begin__
|
|
# MARWIL specific settings:
|
|
self.beta = 1.0
|
|
self.bc_logstd_coeff = 0.0
|
|
self.moving_average_sqd_adv_norm_update_rate = 1e-8
|
|
self.moving_average_sqd_adv_norm_start = 100.0
|
|
self.use_gae = True
|
|
self.vf_coeff = 1.0
|
|
self.grad_clip = None
|
|
|
|
# Override some of AlgorithmConfig's default values with MARWIL-specific values.
|
|
|
|
# You should override input_ to point to an offline dataset
|
|
# (see trainer.py and trainer_config.py).
|
|
# The dataset may have an arbitrary number of timesteps
|
|
# (and even episodes) per line.
|
|
# However, each line must only contain consecutive timesteps in
|
|
# order for MARWIL to be able to calculate accumulated
|
|
# discounted returns. It is ok, though, to have multiple episodes in
|
|
# the same line.
|
|
self.input_ = "sampler"
|
|
self.postprocess_inputs = True
|
|
self.lr = 1e-4
|
|
self.train_batch_size = 2000
|
|
self.num_workers = 0
|
|
# __sphinx_doc_end__
|
|
# fmt: on
|
|
|
|
# TODO: Delete this and change off_policy_estimation_methods to {}
|
|
# Also remove the same section from BC
|
|
self.off_policy_estimation_methods = {
|
|
"is": {"type": ImportanceSampling},
|
|
"wis": {"type": WeightedImportanceSampling},
|
|
}
|
|
self._set_off_policy_estimation_methods = False
|
|
|
|
@override(AlgorithmConfig)
|
|
def training(
|
|
self,
|
|
*,
|
|
beta: Optional[float] = None,
|
|
bc_logstd_coeff: Optional[float] = None,
|
|
moving_average_sqd_adv_norm_update_rate: Optional[float] = None,
|
|
moving_average_sqd_adv_norm_start: Optional[float] = None,
|
|
use_gae: Optional[bool] = True,
|
|
vf_coeff: Optional[float] = None,
|
|
grad_clip: Optional[float] = None,
|
|
**kwargs,
|
|
) -> "MARWILConfig":
|
|
"""Sets the training related configuration.
|
|
|
|
Args:
|
|
beta: Scaling of advantages in exponential terms. When beta is 0.0,
|
|
MARWIL is reduced to behavior cloning (imitation learning);
|
|
see bc.py algorithm in this same directory.
|
|
bc_logstd_coeff: A coefficient to encourage higher action distribution
|
|
entropy for exploration.
|
|
moving_average_sqd_adv_norm_start: Starting value for the
|
|
squared moving average advantage norm (c^2).
|
|
use_gae: If true, use the Generalized Advantage Estimator (GAE)
|
|
with a value function, see https://arxiv.org/pdf/1506.02438.pdf in
|
|
case an input line ends with a non-terminal timestep.
|
|
vf_coeff: Balancing value estimation loss and policy optimization loss.
|
|
moving_average_sqd_adv_norm_update_rate: Update rate for the
|
|
squared moving average advantage norm (c^2).
|
|
grad_clip: If specified, clip the global norm of gradients by this amount.
|
|
|
|
Returns:
|
|
This updated AlgorithmConfig object.
|
|
"""
|
|
# Pass kwargs onto super's `training()` method.
|
|
super().training(**kwargs)
|
|
if beta is not None:
|
|
self.beta = beta
|
|
if bc_logstd_coeff is not None:
|
|
self.bc_logstd_coeff = bc_logstd_coeff
|
|
if moving_average_sqd_adv_norm_update_rate is not None:
|
|
self.moving_average_sqd_adv_norm_update_rate = (
|
|
moving_average_sqd_adv_norm_update_rate
|
|
)
|
|
if moving_average_sqd_adv_norm_start is not None:
|
|
self.moving_average_sqd_adv_norm_start = moving_average_sqd_adv_norm_start
|
|
if use_gae is not None:
|
|
self.use_gae = use_gae
|
|
if vf_coeff is not None:
|
|
self.vf_coeff = vf_coeff
|
|
if grad_clip is not None:
|
|
self.grad_clip = grad_clip
|
|
return self
|
|
|
|
def evaluation(
|
|
self,
|
|
**kwargs,
|
|
) -> "MARWILConfig":
|
|
"""Sets the evaluation related configuration.
|
|
|
|
Returns:
|
|
This updated AlgorithmConfig object.
|
|
"""
|
|
# Pass kwargs onto super's `evaluation()` method.
|
|
super().evaluation(**kwargs)
|
|
|
|
if "off_policy_estimation_methods" in kwargs:
|
|
# User specified their OPE methods.
|
|
self._set_off_policy_estimation_methods = True
|
|
|
|
return self
|
|
|
|
def build(
|
|
self,
|
|
env: Optional[Union[str, EnvType]] = None,
|
|
logger_creator: Optional[Callable[[], Logger]] = None,
|
|
) -> "Algorithm":
|
|
if not self._set_off_policy_estimation_methods:
|
|
deprecation_warning(
|
|
old="MARWIL currently uses off_policy_estimation_methods: "
|
|
f"{self.off_policy_estimation_methods} by default. This will"
|
|
"change to off_policy_estimation_methods: {} in a future release."
|
|
"If you want to use an off-policy estimator, specify it in"
|
|
".evaluation(off_policy_estimation_methods=...)",
|
|
error=False,
|
|
)
|
|
return super().build(env, logger_creator)
|
|
|
|
|
|
class MARWIL(Algorithm):
|
|
@classmethod
|
|
@override(Algorithm)
|
|
def get_default_config(cls) -> AlgorithmConfigDict:
|
|
return MARWILConfig().to_dict()
|
|
|
|
@override(Algorithm)
|
|
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
|
# Call super's validation method.
|
|
super().validate_config(config)
|
|
|
|
if config["beta"] < 0.0 or config["beta"] > 1.0:
|
|
raise ValueError("`beta` must be within 0.0 and 1.0!")
|
|
|
|
if config["num_gpus"] > 1:
|
|
raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")
|
|
|
|
if config["postprocess_inputs"] is False and config["beta"] > 0.0:
|
|
raise ValueError(
|
|
"`postprocess_inputs` must be True for MARWIL (to "
|
|
"calculate accum., discounted returns)!"
|
|
)
|
|
|
|
@override(Algorithm)
|
|
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
|
if config["framework"] == "torch":
|
|
from ray.rllib.algorithms.marwil.marwil_torch_policy import (
|
|
MARWILTorchPolicy,
|
|
)
|
|
|
|
return MARWILTorchPolicy
|
|
elif config["framework"] == "tf":
|
|
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
|
|
MARWILTF1Policy,
|
|
)
|
|
|
|
return MARWILTF1Policy
|
|
else:
|
|
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTF2Policy
|
|
|
|
return MARWILTF2Policy
|
|
|
|
@override(Algorithm)
|
|
def training_step(self) -> ResultDict:
|
|
# Collect SampleBatches from sample workers.
|
|
with self._timers[SAMPLE_TIMER]:
|
|
train_batch = synchronous_parallel_sample(worker_set=self.workers)
|
|
train_batch = train_batch.as_multi_agent()
|
|
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
|
|
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
|
|
# Train.
|
|
if self.config["simple_optimizer"]:
|
|
train_results = train_one_step(self, train_batch)
|
|
else:
|
|
train_results = multi_gpu_train_one_step(self, train_batch)
|
|
|
|
# TODO: Move training steps counter update outside of `train_one_step()` method.
|
|
# # Update train step counters.
|
|
# self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
|
|
# self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
|
|
|
|
global_vars = {
|
|
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
|
|
}
|
|
|
|
# Update weights - after learning on the local worker - on all remote
|
|
# workers.
|
|
if self.workers.remote_workers():
|
|
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
|
|
self.workers.sync_weights(global_vars=global_vars)
|
|
|
|
# Update global vars on local worker as well.
|
|
self.workers.local_worker().set_global_vars(global_vars)
|
|
|
|
return train_results
|
|
|
|
|
|
# Deprecated: Use ray.rllib.algorithms.marwil.MARWILConfig instead!
|
|
class _deprecated_default_config(dict):
|
|
def __init__(self):
|
|
super().__init__(MARWILConfig().to_dict())
|
|
|
|
@Deprecated(
|
|
old="ray.rllib.agents.marwil.marwil::DEFAULT_CONFIG",
|
|
new="ray.rllib.algorithms.marwil.marwil::MARWILConfig(...)",
|
|
error=False,
|
|
)
|
|
def __getitem__(self, item):
|
|
return super().__getitem__(item)
|
|
|
|
|
|
DEFAULT_CONFIG = _deprecated_default_config()
|