from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # You should override this to point to an offline dataset (see trainer.py).
    "input": "sampler",
    # Use importance sampling estimators for off-policy reward estimation.
    "input_evaluation": ["is", "wis"],

    # Scaling of advantages in the exponential term of the loss.
    # When beta is 0.0, MARWIL reduces to plain imitation learning.
    "beta": 1.0,
    # Balances the value estimation loss against the policy optimization loss.
    "vf_coeff": 1.0,
    # Whether to calculate cumulative rewards when postprocessing input batches.
    "postprocess_inputs": True,
    # Whether to roll out "complete_episodes" or "truncate_episodes".
    "batch_mode": "complete_episodes",
    # Learning rate for the Adam optimizer.
    "lr": 1e-4,
    # Number of timesteps collected for each SGD round.
    "train_batch_size": 2000,
    # Maximum number of steps to keep in the batch replay buffer.
    "replay_buffer_size": 100000,
    # Number of steps to read before learning starts.
    "learning_starts": 0,
    # === Parallelism ===
    "num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable
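
# Illustrative sketch (not part of the original module): the MARWIL-specific
# knobs above are usually the only ones changed, together with "input". The
# values and dataset path below are placeholders, not recommendations.
#
#     config = DEFAULT_CONFIG.copy()
#     config["beta"] = 0.0  # 0.0 disables advantage weighting (pure imitation)
#     config["input"] = "/tmp/cartpole-out"  # dir of recorded JSON experiences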


def get_policy_class(config):
    """Pick the torch or (default) tf policy class based on `framework`."""
    if config["framework"] == "torch":
        from ray.rllib.agents.marwil.marwil_torch_policy import \
            MARWILTorchPolicy
        return MARWILTorchPolicy
    else:
        return MARWILTFPolicy


def execution_plan(workers, config):
    """Store sampled (or offline-read) batches and train on replayed data."""
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])

    # Write every collected batch into the local replay buffer.
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # Replay from the buffer, concatenate batches until at least
    # `train_batch_size` timesteps are gathered, then run one SGD update.
    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))

    # Alternate storing and training; only the training results
    # (output index 1) are passed on to metrics reporting.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    get_policy_class=get_policy_class,
    execution_plan=execution_plan)
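

# Minimal usage sketch (illustrative, not part of the original module). MARWIL
# is normally trained from recorded experiences, so this assumes a directory of
# RLlib JSON output at /tmp/cartpole-out (e.g. written by another trainer with
# "output": "/tmp/cartpole-out" in its config). The path, env name, and
# iteration count are placeholder assumptions.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = MARWILTrainer(
        env="CartPole-v0",
        config={"input": "/tmp/cartpole-out"})
    for _ in range(3):
        # Each call runs one training iteration and returns a metrics dict.
        print(trainer.train())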