ray/rllib/models/jax/jax_modelv2.py

27 lines
760 B
Python

import gym
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import ModelConfigDict
@PublicAPI
class JAXModelV2(ModelV2):
    """ModelV2 flavour for the JAX framework.

    This class only wires the generic ModelV2 machinery up with
    ``framework="jax"``; on its own it is not a usable model. Subclasses
    must provide an implementation of ``forward()``.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str) -> None:
        """Set up a JAXModelV2 by delegating to ModelV2.

        Args:
            obs_space: Observation space, passed through to ModelV2.
            action_space: Action space, passed through to ModelV2.
            num_outputs: Output size, passed through to ModelV2.
            model_config: Model config dict, passed through to ModelV2.
            name: Model name, passed through to ModelV2.
        """
        # Call the base initializer explicitly (not via super()) so the
        # exact same constructor runs regardless of a subclass's MRO;
        # the only JAX-specific bit is the fixed framework tag.
        ModelV2.__init__(self, obs_space, action_space, num_outputs,
                         model_config, name, framework="jax")