# ray/rllib/agents/marwil/marwil.py

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # You should override this to point to an offline dataset (see agent.py).
    "input": "sampler",
    # Use importance sampling estimators for reward.
    "input_evaluation": ["is", "wis"],
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # Scaling of advantages in the exponential term of the loss.
    # When beta is 0.0, MARWIL reduces to plain imitation learning.
    "beta": 1.0,
    # Balances the value-estimation loss against the policy optimization loss.
    "vf_coeff": 1.0,
    # If specified, clip the global norm of gradients by this amount.
    "grad_clip": None,
    # Whether to calculate cumulative rewards.
    "postprocess_inputs": True,
    # Whether to roll out "complete_episodes" or "truncate_episodes".
    "batch_mode": "complete_episodes",
    # Learning rate for the Adam optimizer.
    "lr": 1e-4,
    # Number of timesteps collected for each SGD round.
    "train_batch_size": 2000,
    # Size of the replay buffer in batches (not timesteps!).
    "replay_buffer_size": 1000,
    # Number of steps to read before learning starts.
    "learning_starts": 0,

    # === Parallelism ===
    "num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable
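
# How `beta` and `vf_coeff` enter the loss (a hedged sketch of the MARWIL
# objective from Wang et al., 2018; not a verbatim copy of the actual
# marwil_tf_policy loss code):
#
#     # Advantage-weighted imitation: with beta == 0.0 the weights collapse
#     # to 1.0 and the policy term becomes plain behavior cloning.
#     weights = exp(beta * normalized_advantages)
#     policy_loss = -mean(weights * log_prob(actions))
#     total_loss = policy_loss + vf_coeff * value_function_loss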


def get_policy_class(config):
    # Return the Torch policy when requested; returning None makes
    # build_trainer fall back to the default (TF) policy below.
    if config["framework"] == "torch":
        from ray.rllib.agents.marwil.marwil_torch_policy import \
            MARWILTorchPolicy
        return MARWILTorchPolicy


def execution_plan(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])

    # Write each incoming batch into the local replay buffer.
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # Read batches back out, concatenate them up to `train_batch_size`
    # steps, and perform one SGD step on the local worker.
    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers))

    # Alternate between storing and training; only the training results
    # (output index 1) are passed on for metrics reporting.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
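
# Conceptual illustration of Concurrently(mode="round_robin",
# output_indexes=[1]) using plain Python generators; this is a hypothetical
# stand-in, not RLlib's LocalIterator implementation:
#
#     import itertools
#
#     def round_robin(ops, output_indexes):
#         # Step the child iterators alternately, but only yield items
#         # from the indexes listed in `output_indexes`.
#         for i in itertools.cycle(range(len(ops))):
#             item = next(ops[i])
#             if i in output_indexes:
#                 yield item
#
#     store = iter(lambda: "stored", None)           # endless store steps
#     train = iter(lambda: {"loss": 0.0}, None)      # endless train results
#     print(next(round_robin([store, train], [1])))  # -> {'loss': 0.0}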


def validate_config(config):
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    get_policy_class=get_policy_class,
    validate_config=validate_config,
    execution_plan=execution_plan)
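
# Example usage (a sketch, not part of the upstream module). The dataset path
# below is hypothetical; point "input" at a directory of offline JSON sample
# batches, e.g. ones written via another trainer's "output" option:
#
#     import ray
#     from ray.rllib.agents.marwil import MARWILTrainer
#
#     ray.init()
#     trainer = MARWILTrainer(
#         env="CartPole-v0",
#         config={
#             "input": "/tmp/cartpole-out",      # hypothetical offline data dir
#             "input_evaluation": ["is", "wis"],
#             "beta": 1.0,                       # 0.0 -> pure imitation learning
#             "framework": "tf",
#         })
#     for _ in range(10):
#         print(trainer.train())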