mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
33 lines
791 B
Python
33 lines
791 B
Python
import gym
|
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
from ray.rllib.utils.annotations import PublicAPI
|
|
from ray.rllib.utils.typing import ModelConfigDict
|
|
|
|
|
|
@PublicAPI
class JAXModelV2(ModelV2):
    """ModelV2 base class for models built with JAX.

    This class only tags the model with the "jax" framework; it does not
    define forward() itself, so it is not a usable model on its own.
    Subclasses must implement forward().
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Creates a JAXModelV2 by delegating to ModelV2.__init__.

        All arguments are forwarded unchanged, in the same order, to
        ModelV2.__init__; the only addition is framework="jax".

        Args:
            obs_space: Observation space, passed through to ModelV2.
            action_space: Action space, passed through to ModelV2.
            num_outputs: Output size, passed through to ModelV2.
            model_config: Model config dict, passed through to ModelV2.
            name: Model name, passed through to ModelV2.
        """
        # The base initializer is invoked explicitly on ModelV2 (not through
        # super()), matching how this class has always initialized its base.
        base_args = (obs_space, action_space, num_outputs, model_config, name)
        ModelV2.__init__(self, *base_args, framework="jax")
|