ray/rllib/policy/tests/test_policy.py

21 lines
711 B
Python

import random
from ray.rllib.policy.policy import Policy
class TestPolicy(Policy):
    """Dummy Policy whose actions are a single random 0/1 value per call.

    Only ``compute_actions`` is overridden in this class as shown; anything
    else is inherited unchanged from ``Policy``.
    """

    def compute_actions(
            self,
            obs_batch,
            state_batches=None,
            prev_action_batch=None,
            prev_reward_batch=None,
            episodes=None,
            deterministic=None,
            explore=True,
            time_step=None,
            **kwargs):
        """Return one randomly drawn action repeated for every observation.

        Args:
            obs_batch: Batch of observations. Only its length is used.
            state_batches: Accepted but ignored.
            prev_action_batch: Accepted but ignored.
            prev_reward_batch: Accepted but ignored.
            episodes: Accepted but ignored.
            deterministic: Accepted but ignored.
            explore: Accepted but ignored.
            time_step: Accepted but ignored.
            **kwargs: Accepted but ignored.

        Returns:
            A 3-tuple ``(actions, [], {})`` where ``actions`` is a list of
            ``len(obs_batch)`` copies of the same value, drawn once per call
            from ``[0, 1]``. The second and third items are always an empty
            list and an empty dict.
        """
        # Draw exactly once: every entry of the batch shares this value.
        sampled_action = random.choice([0, 1])
        batch_size = len(obs_batch)
        actions = [sampled_action] * batch_size
        return actions, [], {}