from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.optimizers import SyncBatchReplayOptimizer

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # You should override this to point to an offline dataset (see agent.py).
    "input": "sampler",
    # Use importance sampling estimators for reward.
    "input_evaluation": ["is", "wis"],

    # Scaling of advantages in exponential terms.
    # When beta is 0, MARWIL is reduced to imitation learning.
    "beta": 1.0,
    # Balancing value estimation loss and policy optimization loss.
    "vf_coeff": 1.0,
    # Whether to calculate cumulative rewards.
    "postprocess_inputs": True,
    # Whether to rollout "complete_episodes" or "truncate_episodes".
    "batch_mode": "complete_episodes",
    # Learning rate for Adam optimizer.
    "lr": 1e-4,
    # Number of timesteps collected for each SGD round.
    "train_batch_size": 2000,
    # Max number of steps to keep in the batch replay buffer.
    "replay_buffer_size": 100000,
    # Number of steps to read before learning starts.
    "learning_starts": 0,
    # === Parallelism ===
    "num_workers": 0,
    # Use PyTorch as framework?
    "use_pytorch": False,
})
# __sphinx_doc_end__
# yapf: enable


def make_optimizer(workers, config):
    return SyncBatchReplayOptimizer(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"],
    )


def get_policy_class(config):
    if config.get("use_pytorch") is True:
        from ray.rllib.agents.marwil.marwil_torch_policy import \
            MARWILTorchPolicy
        return MARWILTorchPolicy
    else:
        return MARWILTFPolicy


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    get_policy_class=get_policy_class,
    make_policy_optimizer=make_optimizer)
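

# A minimal usage sketch of the trainer built above, guarded so it only runs
# when this file is executed directly. It assumes Ray is installed, that an
# offline dataset of JSON experience files exists at the hypothetical path
# "/tmp/cartpole-out" (e.g. written via another trainer's "output" config),
# and that "CartPole-v0" is the environment the data was collected on; adjust
# the path and env to your setup.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = MARWILTrainer(
        env="CartPole-v0",
        config={
            # Hypothetical path to offline experience files (overrides the
            # default "sampler" input).
            "input": "/tmp/cartpole-out",
            # Estimate the learned policy's reward from the offline data
            # via importance sampling.
            "input_evaluation": ["is", "wis"],
            # beta > 0 weights advantages exponentially; beta = 0 reduces
            # MARWIL to plain imitation learning.
            "beta": 1.0,
        })
    for _ in range(10):
        print(trainer.train())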