diff --git a/python/ray/tests/test_tensorflow.py b/python/ray/tests/test_tensorflow.py index d243deaae..ce4834992 100644 --- a/python/ray/tests/test_tensorflow.py +++ b/python/ray/tests/test_tensorflow.py @@ -4,7 +4,7 @@ import ray import ray.experimental.tf_utils from ray.rllib.utils.framework import try_import_tf -tf = try_import_tf() +tf, _, _ = try_import_tf() def make_linear_network(w_name=None, b_name=None): diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py index 33580127f..de551cbd6 100644 --- a/rllib/agents/ddpg/tests/test_ddpg.py +++ b/rllib/agents/ddpg/tests/test_ddpg.py @@ -9,13 +9,12 @@ from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \ from ray.rllib.agents.sac.tests.test_sac import SimpleEnv from ray.rllib.execution.replay_buffer import LocalReplayBuffer from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid from ray.rllib.utils.test_utils import check, check_compute_single_action, \ framework_iterator from ray.rllib.utils.torch_ops import convert_to_torch_tensor -tf = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/agents/maml/tests/test_maml.py b/rllib/agents/maml/tests/test_maml.py index 98b37ca78..a8daf260d 100644 --- a/rllib/agents/maml/tests/test_maml.py +++ b/rllib/agents/maml/tests/test_maml.py @@ -2,12 +2,9 @@ import unittest import ray import ray.rllib.agents.maml as maml -from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_compute_single_action, \ framework_iterator -tf = try_import_tf() - class TestMAML(unittest.TestCase): @classmethod diff --git a/rllib/policy/trajectory_view.py b/rllib/policy/trajectory_view.py index e45858356..f1c8d4e22 100644 --- a/rllib/policy/trajectory_view.py +++ b/rllib/policy/trajectory_view.py @@ -1,11 +1,9 @@ -from dataclasses import dataclass import numpy as np -from typing import Dict +from typing import Dict, Optional from ray.rllib.utils.types import TensorType -@dataclass class ViewRequirement: """Single view requirement (for one column in a ModelV2 input_dict). @@ -24,29 +22,39 @@ class ViewRequirement: >>> print(req) {"obs": ViewRequirement(timesteps=0)} """ - # The data column name from the SampleBatch (str key). - # If None, use the dict key under which this ViewRequirement resides. - data_col: str = None - # List of relative (or absolute timesteps) to be present in the - # input_dict. - timesteps: int = 0 + def __init__(self, + data_col: Optional[str] = None, + timesteps: int = 0, + fill_mode: str = "zeros", + repeat_mode: str = "all"): + """Initializes a ViewRequirement object. - # Switch on absolute timestep mode. Default: False. - # TODO: (sven) - # "absolute_timesteps", + Args: + data_col (): The data column name from the SampleBatch (str key). + If None, use the dict key under which this ViewRequirement + resides. + timesteps (Union[List[int], int]): List of relative (or absolute + timesteps) to be present in the input_dict. + fill_mode (str): The fill mode in case t<0 or t>H. + One of "zeros", "tile". + repeat_mode (str): The repeat-mode (one of "all" or "only_first"). + E.g. for training, we only want the first internal state + timestep (the NN will calculate all others again anyways). + """ + self.data_col = data_col + self.timesteps = timesteps - # The fill mode in case t<0 or t>H: One of "zeros", "tile". - fill_mode: str = "zeros" + # Switch on absolute timestep mode. Default: False. + # TODO: (sven) + # "absolute_timesteps", - # The repeat-mode (one of "all" or "only_first"). E.g. for training, - # we only want the first internal state timestep (the NN will - # calculate all others again anyways). - repeat_mode: str = "all" + self.fill_mode = fill_mode + self.repeat_mode = repeat_mode - # Provide all data as time major (default: False). - # TODO: (sven) - # "time_major", + # Provide all data as time major (default: False). + # TODO: (sven) + # "time_major", def get_trajectory_view( diff --git a/rllib/tuned_examples/debug_learning_failure_git_bisect.py b/rllib/tuned_examples/debug_learning_failure_git_bisect.py index 84c0418a6..a5c89ab94 100644 --- a/rllib/tuned_examples/debug_learning_failure_git_bisect.py +++ b/rllib/tuned_examples/debug_learning_failure_git_bisect.py @@ -25,11 +25,8 @@ import yaml import ray from ray import tune -from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import check_learning_achieved -tf = try_import_tf() - parser = argparse.ArgumentParser() parser.add_argument("--run", type=str, default=None) parser.add_argument("--torch", action="store_true")