
- This PR adds the previously missing PyTorch model counterparts to the TFModels in examples/models.
- It also makes sure that all example scripts in the rllib/examples folder are tested for both frameworks and actually learn the given task (which is currently often not verified), using an --as-test flag in combination with a --stop-reward threshold.
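
The --as-test / --stop-reward pattern used by those example scripts looks roughly like the sketch below (a hedged illustration, not a file from this PR; the trainer, environment, and default values are placeholders):

import argparse

import ray
from ray import tune
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument("--as-test", action="store_true",
                    help="Error out if --stop-reward is not reached.")
parser.add_argument("--stop-reward", type=float, default=150.0)
parser.add_argument("--stop-iters", type=int, default=200)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Stop either when the mean episode reward is reached or after a fixed
    # number of training iterations.
    stop = {
        "episode_reward_mean": args.stop_reward,
        "training_iteration": args.stop_iters,
    }
    results = tune.run("PPO", config={"env": "CartPole-v0"}, stop=stop)

    # With --as-test, fail loudly if the stop reward was never reached.
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()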
35 lines
1.2 KiB
Python
import numpy as np
import random

from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override


class TestPolicy(Policy):
    """A dummy Policy that returns a random (batched) int for compute_actions.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        return np.array([random.choice([0, 1])] * len(obs_batch)), [], {}

    @override(Policy)
    def compute_log_likelihoods(self,
                                actions,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None):
        return np.array([random.random()] * len(obs_batch))
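
For orientation, a minimal usage sketch of the policy above (not part of the file itself). It assumes the Policy base class is constructed as Policy(observation_space, action_space, config) and that an empty config dict is enough for _create_exploration(); the spaces and batch used here are placeholders.

import gym
import numpy as np

# Placeholder spaces; the policy ignores observations and samples random actions.
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
act_space = gym.spaces.Discrete(2)

# Assumption: an empty config dict suffices here; real callers usually pass the
# trainer's full config.
policy = TestPolicy(obs_space, act_space, {})

# Three dummy observations in -> three random 0/1 actions out, plus an empty
# RNN-state list and an empty extra-fetches dict (matching the return above).
actions, state_out, extra_fetches = policy.compute_actions(
    obs_batch=np.zeros((3, 4)))
print(actions)  # e.g. [0 1 1]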