ray/rllib/agents/marwil/marwil.py

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_policy import MARWILTFPolicy
from ray.rllib.optimizers import SyncBatchReplayOptimizer

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # You should override this to point to an offline dataset (see agent.py).
    "input": "sampler",
    # Use importance sampling estimators for reward.
    "input_evaluation": ["is", "wis"],
    # Scaling of advantages in the exponential term of the loss.
    # When beta is 0.0, MARWIL reduces to plain imitation learning.
    "beta": 1.0,
    # Balancing value estimation loss and policy optimization loss.
    "vf_coeff": 1.0,
    # Whether to calculate cumulative rewards.
    "postprocess_inputs": True,
    # Whether to rollout "complete_episodes" or "truncate_episodes".
    "batch_mode": "complete_episodes",
    # Learning rate for the Adam optimizer.
    "lr": 1e-4,
    # Number of timesteps collected for each SGD round.
    "train_batch_size": 2000,
    # Maximum number of timesteps to keep in the batch replay buffer.
    "replay_buffer_size": 100000,
    # Number of timesteps to read before learning starts.
    "learning_starts": 0,
    # === Parallelism ===
    "num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable
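
# A minimal sketch (not part of the original module) of overriding the options
# above for offline training; the path below is a hypothetical placeholder for
# a directory of RLlib-generated experience JSON files:
#
#     config = DEFAULT_CONFIG.copy()
#     config["input"] = "/path/to/offline-out"  # hypothetical offline dataset
#     config["beta"] = 0.0  # 0.0 reduces MARWIL to pure imitation learning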


def make_optimizer(workers, config):
    return SyncBatchReplayOptimizer(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"],
    )
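
# Note (added, hedged): per the config comments above, SyncBatchReplayOptimizer
# keeps whole sampled batches in a replay buffer of up to ``replay_buffer_size``
# timesteps and begins optimizing on replayed batches of ``train_batch_size``
# timesteps once ``learning_starts`` timesteps have been collected.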


def validate_config(config):
    # PyTorch check.
    if config["use_pytorch"]:
        raise ValueError(
            "MARWIL does not support PyTorch yet! Use tf instead.")


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    validate_config=validate_config,
    make_policy_optimizer=make_optimizer)
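

# --- Usage sketch -----------------------------------------------------------
# Illustrative only, not part of the original module. The environment name and
# the offline data path are assumptions; any directory of SampleBatch JSON
# files written by RLlib's output writer should work in their place.
if __name__ == "__main__":
    import ray

    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["input"] = "/tmp/cartpole-out"  # hypothetical offline dataset path
    trainer = MARWILTrainer(config=config, env="CartPole-v0")
    for _ in range(3):
        # Each call runs one training iteration and returns a result dict.
        print(trainer.train())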