# flake8: noqa

# __preprocessing_observations_start__
import gym

env = gym.make("Pong-v0")

# RLlib relies on preprocessors for transforms such as one-hot encoding
# and the flattening of tuple and dict observations.
from ray.rllib.models.preprocessors import get_preprocessor

# Look up the preprocessor registered for this observation space, then
# instantiate it for that same space.
preprocessor_cls = get_preprocessor(env.observation_space)
prep = preprocessor_cls(env.observation_space)

# Observations should be preprocessed before they are fed into a model.
env.reset().shape  # (210, 160, 3)
prep.transform(env.reset()).shape  # (84, 84, 3)
# __preprocessing_observations_end__


# __query_action_dist_start__
# Get a reference to the policy
import numpy as np
from ray.rllib.algorithms.ppo import PPO

# Build a PPO algorithm (tf2 framework, num_workers=0) and fetch its policy.
ppo_config = {"framework": "tf2", "num_workers": 0}
algo = PPO(env="CartPole-v0", config=ppo_config)
policy = algo.get_policy()

# A forward pass through the policy's model yields the output logits.
# Complex observations must first be preprocessed as in the above code block.
obs_batch = np.array([[0.1, 0.2, 0.3, 0.4]])
# The call returns a 2-tuple; only the first element (the logits) is kept.
logits, _ = policy.model({"obs": obs_batch})

# The policy exposes its action distribution class; build a distribution
# from the logits and the model.
policy.dist_class
dist = policy.dist_class(logits, policy.model)

# Query the distribution for a sample and for the log-prob of an action.
dist.sample()
dist.logp([1])

# Estimated values for the most recent forward pass.
policy.model.value_function()

# Layer-by-layer summary of the underlying base model.
policy.model.base_model.summary()
"""
Model: "model"
_____________________________________________________________________
Layer (type)               Output Shape  Param #  Connected to
=====================================================================
observations (InputLayer)  [(None, 4)]   0
_____________________________________________________________________
fc_1 (Dense)               (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense)         (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_2 (Dense)               (None, 256)   65792    fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense)         (None, 256)   65792    fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense)             (None, 2)     514      fc_2[0][0]
_____________________________________________________________________
value_out (Dense)          (None, 1)     257      fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0
_____________________________________________________________________
"""
# __query_action_dist_end__


# __get_q_values_dqn_start__
# Get a reference to the model through the policy
import numpy as np
from ray.rllib.algorithms.dqn import DQN

algo = DQN(env="CartPole-v0", config={"framework": "tf2"})
model = algo.get_policy().model

# List of all model variables
model.variables()

# Run a forward pass to get the base model output. Note that complex
# observations must be preprocessed first; see
# examples/saving_experiences.py for an example of preprocessing.
model_out = model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
# (