ray/rllib/policy/tests/test_policy.py
Sven Mika e153e3179f
[RLlib] Exploration API: Policy changes needed for forward pass noisifications. (#7798)

Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-01 00:43:21 -07:00

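import random

from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override


class TestPolicy(Policy):
    """A dummy Policy whose compute_actions returns a random int (0 or 1).

    A single value is drawn per call and repeated for every observation in
    the batch; compute_log_likelihoods does the same with a random float.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Create this Policy's Exploration object (needed by the Exploration API).
        self.exploration = self._create_exploration()

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        # Returns (actions, RNN state outputs, extra-info dict). The list
        # multiplication repeats ONE random draw for the whole batch.
        return [random.choice([0, 1])] * len(obs_batch), [], {}

    @override(Policy)
    def compute_log_likelihoods(self,
                                actions,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None):
        # Placeholder values only: one float in [0.0, 1.0), repeated per item.
        return [random.random()] * len(obs_batch)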
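The return shapes of the dummy methods above can be illustrated without Ray installed. The sketch below is a hypothetical, dependency-free stand-in: the names `dummy_compute_actions`, `per_item_compute_actions` and `dummy_log_likelihoods` are illustrative only and are not RLlib API. It mirrors the three-tuple `(actions, state_outs, extra_info)` that `compute_actions` returns, and contrasts the list-multiplication idiom used above (one random draw repeated across the batch) with a per-observation draw.

```python
import random
from typing import Any, Dict, List, Sequence, Tuple

# Hypothetical stand-ins for illustration only; these are NOT RLlib APIs.
ActionTuple = Tuple[List[int], List[Any], Dict[str, Any]]


def dummy_compute_actions(obs_batch: Sequence[Any],
                          rng: random.Random) -> ActionTuple:
    """Same idiom as TestPolicy.compute_actions: one draw, repeated."""
    # List multiplication copies the single sampled int len(obs_batch) times.
    actions = [rng.choice([0, 1])] * len(obs_batch)
    # No RNN state outputs and no extra info, as in the dummy policy.
    return actions, [], {}


def per_item_compute_actions(obs_batch: Sequence[Any],
                             rng: random.Random) -> ActionTuple:
    """Variant that samples an independent action for each observation."""
    actions = [rng.choice([0, 1]) for _ in obs_batch]
    return actions, [], {}


def dummy_log_likelihoods(obs_batch: Sequence[Any],
                          rng: random.Random) -> List[float]:
    """Same idiom as TestPolicy.compute_log_likelihoods: one float, repeated."""
    # rng.random() returns a float in [0.0, 1.0); these are placeholders,
    # not real log-likelihoods.
    return [rng.random()] * len(obs_batch)


if __name__ == "__main__":
    rng = random.Random(0)
    obs = [None] * 8  # placeholder observations; their contents are ignored
    same, _, _ = dummy_compute_actions(obs, rng)
    independent, _, _ = per_item_compute_actions(obs, rng)
    print(same)         # all eight entries are identical
    print(independent)  # entries may differ from one another
```

Repeating one value is enough for a test double that only needs correctly shaped outputs; a per-item draw would be the choice if a test needed varied actions within a single batch.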