diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst
index b2bfc2699..25cd0d893 100644
--- a/doc/source/rllib-training.rst
+++ b/doc/source/rllib-training.rst
@@ -162,11 +162,28 @@ Tune will schedule the trials to run in parallel on your Ray cluster:
      - PPO_CartPole-v0_0_sgd_stepsize=0.01: RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
      - PPO_CartPole-v0_1_sgd_stepsize=0.001: RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew
 
-Accessing Global State
+Accessing Policy State
 ~~~~~~~~~~~~~~~~~~~~~~
 It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list.
 
-You can also access just the "master" copy of the agent state through ``agent.optimizer.local_evaluator``, but note that updates here may not be reflected in remote replicas if you have configured ``num_workers > 0``.
+You can also access just the "master" copy of the agent state through ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.local_evaluator.policy_map["default"].get_weights()``. This is also equivalent to ``agent.local_evaluator.for_policy(lambda p: p.get_weights())``:
+
+.. code-block:: python
+
+    # Get weights of the local policy
+    agent.local_evaluator.policy_map["default"].get_weights()
+
+    # Same as above
+    agent.local_evaluator.for_policy(lambda p: p.get_weights())
+
+    # Get list of weights of each evaluator, including remote replicas
+    agent.optimizer.foreach_evaluator(
+        lambda ev: ev.for_policy(lambda p: p.get_weights()))
+
+    # Same as above
+    agent.optimizer.foreach_evaluator_with_index(
+        lambda ev, i: ev.for_policy(lambda p: p.get_weights()))
+
 
 REST API
 --------
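
Note on the hunk above: the call pattern it documents (``for_policy`` on one evaluator, ``foreach_evaluator`` across all of them) can be exercised without a Ray cluster using stand-in classes. The sketch below is only an illustration of the semantics the patch describes. ``FakePolicy``, ``FakeEvaluator`` and ``FakeOptimizer`` are hypothetical names and not RLlib classes. The ``policy_id`` argument and the "local evaluator first, then replicas" result ordering are assumptions, and real remote evaluators are Ray actors rather than plain objects.

```python
# Stand-in classes (hypothetical, NOT RLlib's implementation) that mimic
# the access pattern described in the patched documentation.


class FakePolicy:
    def __init__(self, weights):
        self._weights = weights

    def get_weights(self):
        return self._weights


class FakeEvaluator:
    def __init__(self, weights):
        # The docs index the policy map with the key "default".
        self.policy_map = {"default": FakePolicy(weights)}

    def for_policy(self, func, policy_id="default"):
        # Assumption: apply func to a single policy and return its result.
        return func(self.policy_map[policy_id])


class FakeOptimizer:
    def __init__(self, local_evaluator, remote_evaluators):
        self.local_evaluator = local_evaluator
        self.remote_evaluators = remote_evaluators

    def _all_evaluators(self):
        # Assumption: local evaluator first, then the replicas.
        return [self.local_evaluator] + self.remote_evaluators

    def foreach_evaluator(self, func):
        # Results are collected into a list, one entry per evaluator.
        return [func(ev) for ev in self._all_evaluators()]

    def foreach_evaluator_with_index(self, func):
        return [func(ev, i) for i, ev in enumerate(self._all_evaluators())]


local = FakeEvaluator([0.0, 0.0])
optimizer = FakeOptimizer(
    local, [FakeEvaluator([0.1, 0.1]), FakeEvaluator([0.2, 0.2])])

# Two equivalent ways to read the local policy's weights.
local_weights = local.policy_map["default"].get_weights()
same_weights = local.for_policy(lambda p: p.get_weights())

# One entry per evaluator, including the stand-in replicas.
all_weights = optimizer.foreach_evaluator(
    lambda ev: ev.for_policy(lambda p: p.get_weights()))

# The indexed variant also passes each evaluator's position.
indexed = optimizer.foreach_evaluator_with_index(
    lambda ev, i: (i, ev.for_policy(lambda p: p.get_weights())))
```

In this sketch the replicas deliberately hold different weights from the local copy, which mirrors the caveat in the patched text that updates to the local evaluator may not be immediately reflected in remote replicas.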