from collections import Counter
import copy
import logging
import os
import pprint
import random
import re
import time
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import numpy as np
import tree  # pip install dm_tree
import yaml
from gym.spaces import Box

import ray
from ray import air, tune
from ray.rllib.utils.framework import try_import_jax, try_import_tf, try_import_torch
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED, NUM_ENV_STEPS_TRAINED
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
from ray.tune import CLIReporter, run_experiments

if TYPE_CHECKING:
    from ray.rllib.algorithms import Algorithm, AlgorithmConfig

jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
    eager_mode = None
    try:
        from tensorflow.python.eager.context import eager_mode
    except (ImportError, ModuleNotFoundError):
        pass

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


def framework_iterator(
    config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None,
    frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
    session: bool = False,
    with_eager_tracing: bool = False,
    time_iterations: Optional[dict] = None,
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
    """A generator that allows for looping through n frameworks for testing.

    Provides the correct config entries ("framework") as well as the correct
    eager/non-eager contexts for tfe/tf.

    Args:
        config: An optional config dict or AlgorithmConfig object. This will be
            modified (value for "framework" changed) depending on the iteration.
        frameworks: A list/tuple of the frameworks to be tested.
            Allowed are: "tf2", "tf", "tfe", "torch", and None.
        session: If True and only in the tf-case: Enter a tf.Session()
            and yield that as second return value (otherwise yield (fw, None)).
            Also sets a seed (42) on the session to make the test
            deterministic.
        with_eager_tracing: Include `eager_tracing=True` in the returned
            configs, when framework=[tfe|tf2].
        time_iterations: If provided, will write to the given dict (by
            framework key) the times in seconds that each (framework's)
            iteration takes.

    Yields:
        If `session` is False: The current framework [tf2|tf|tfe|torch] used.
        If `session` is True: A tuple consisting of the current framework
        string and the tf1.Session (if fw="tf", otherwise None).
    """
    config = config or {}
    frameworks = [frameworks] if isinstance(frameworks, str) else list(frameworks)

    # Both tf2 and tfe present -> remove "tfe" or "tf2" depending on version.
    if "tf2" in frameworks and "tfe" in frameworks:
        frameworks.remove("tfe" if tfv == 2 else "tf2")

    for fw in frameworks:
        # Skip non-installed frameworks.
        if fw == "torch" and not torch:
            logger.warning("framework_iterator skipping torch (not installed)!")
            continue
        if fw != "torch" and not tf:
            logger.warning(
                "framework_iterator skipping {} (tf not installed)!".format(fw)
            )
            continue
        elif fw == "tfe" and not eager_mode:
            logger.warning(
                "framework_iterator skipping tf-eager (could not "
                "import `eager_mode` from tensorflow.python)!"
            )
            continue
        elif fw == "tf2" and tfv != 2:
            logger.warning("framework_iterator skipping tf2.x (tf version is < 2.0)!")
            continue
        elif fw == "jax" and not jax:
            logger.warning("framework_iterator skipping JAX (not installed)!")
            continue
        assert fw in ["tf2", "tf", "tfe", "torch", "jax", None]

        # Do we need a test session?
        sess = None
        if fw == "tf" and session is True:
            sess = tf1.Session()
            sess.__enter__()
            tf1.set_random_seed(42)

        if isinstance(config, dict):
            config["framework"] = fw
        else:
            config.framework(fw)

        eager_ctx = None
        # Enable eager mode for tf2 and tfe.
        if fw in ["tf2", "tfe"]:
            eager_ctx = eager_mode()
            eager_ctx.__enter__()
            assert tf1.executing_eagerly()
        # Make sure eager mode is off.
        elif fw == "tf":
            assert not tf1.executing_eagerly()

        # Additionally loop through eager_tracing=True + False, if necessary.
        if fw in ["tf2", "tfe"] and with_eager_tracing:
            for tracing in [True, False]:
                if isinstance(config, dict):
                    config["eager_tracing"] = tracing
                else:
                    config.framework(eager_tracing=tracing)
                print(f"framework={fw} (eager-tracing={tracing})")
                time_started = time.time()
                yield fw if session is False else (fw, sess)
                if time_iterations is not None:
                    time_total = time.time() - time_started
                    time_iterations[fw + ("+tracing" if tracing else "")] = time_total
                    print(f".. took {time_total}sec")
            if isinstance(config, dict):
                config["eager_tracing"] = False
            else:
                config.framework(eager_tracing=False)
        # Yield current framework + tf-session (if necessary).
        else:
            print(f"framework={fw}")
            time_started = time.time()
            yield fw if session is False else (fw, sess)
            if time_iterations is not None:
                time_total = time.time() - time_started
                time_iterations[fw] = time_total
                print(f".. took {time_total}sec")

        # Exit any context we may have entered.
        if eager_ctx:
            eager_ctx.__exit__(None, None, None)
        elif sess:
            sess.__exit__(None, None, None)


def check(x, y, decimals=5, atol=None, rtol=None, false=False):
    """
    Checks two structures (dict, tuple, list, np.array, float, int, etc.)
    for (almost) numeric identity.

    All numbers in the two structures have to match up to `decimal` digits
    after the floating point. Uses assertions.

    Args:
        x: The value to be compared (to the expectation: `y`). This
            may be a Tensor.
        y: The expected value to be compared to `x`. This must not
            be a tf-Tensor, but may be a tfe/torch-Tensor.
        decimals: The number of digits after the floating point up to
            which all numeric values have to match.
        atol: Absolute tolerance of the difference between x and y
            (overrides `decimals` if given).
        rtol: Relative tolerance of the difference between x and y
            (overrides `decimals` if given).
        false: Whether to check that x and y are NOT the same.
    """
    # A dict type.
    if isinstance(x, dict):
        assert isinstance(
            y, dict
        ), "ERROR: If x is dict, y needs to be a dict as well!"
        y_keys = set(x.keys())
        for key, value in x.items():
            assert key in y, f"ERROR: y does not have x's key='{key}'! y={y}"
            check(value, y[key], decimals=decimals, atol=atol, rtol=rtol, false=false)
            y_keys.remove(key)
        assert not y_keys, "ERROR: y contains keys ({}) that are not in x! y={}".format(
            list(y_keys), y
        )
    # A tuple type.
    elif isinstance(x, (tuple, list)):
        assert isinstance(
            y, (tuple, list)
        ), "ERROR: If x is tuple, y needs to be a tuple as well!"
        assert len(y) == len(
            x
        ), "ERROR: y does not have the same length as x ({} vs {})!".format(
            len(y), len(x)
        )
        for i, value in enumerate(x):
            check(value, y[i], decimals=decimals, atol=atol, rtol=rtol, false=false)
    # Boolean comparison.
    elif isinstance(x, (np.bool_, bool)):
        if false is True:
            assert bool(x) is not bool(y), f"ERROR: x ({x}) is y ({y})!"
        else:
            assert bool(x) is bool(y), f"ERROR: x ({x}) is not y ({y})!"
    # Nones or primitives.
    elif x is None or y is None or isinstance(x, (str, int)):
        if false is True:
            assert x != y, f"ERROR: x ({x}) is the same as y ({y})!"
        else:
            assert x == y, f"ERROR: x ({x}) is not the same as y ({y})!"
    # String/byte comparisons.
    elif (
        hasattr(x, "dtype") and (x.dtype == object or str(x.dtype).startswith("<U"))
    ) or isinstance(x, bytes):
        try:
            np.testing.assert_array_equal(x, y)
            if false is True:
                assert False, f"ERROR: x ({x}) is the same as y ({y})!"
        except AssertionError as e:
            if false is False:
                raise e
    # Everything else (assume numeric or tf/torch.Tensor).
    else:
        if tf1 is not None:
            # y should never be a Tensor (y=expected value).
            if isinstance(y, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf.executing_eagerly():
                    y = y.numpy()
                else:
                    raise ValueError(
                        "`y` (expected value) must not be a Tensor. "
                        "Use numpy.ndarray instead"
                    )
            if isinstance(x, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf1.executing_eagerly():
                    x = x.numpy()
                # Otherwise, use a new tf-session to evaluate the tensor, then
                # re-run the check on the resulting numpy values.
                else:
                    with tf1.Session() as sess:
                        x = sess.run(x)
                        return check(
                            x, y, decimals=decimals, atol=atol, rtol=rtol, false=false
                        )
        if torch is not None:
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().numpy()
            if isinstance(y, torch.Tensor):
                y = y.detach().cpu().numpy()

        # Using decimals.
        if atol is None and rtol is None:
            # Assert equality of both values.
            try:
                np.testing.assert_almost_equal(x, y, decimal=decimals)
            # Both values are not equal.
            except AssertionError as e:
                # Raise error in normal case.
                if false is False:
                    raise e
            # Both values are equal.
            else:
                # If false is set -> raise error (not expected to be equal).
                if false is True:
                    assert False, f"ERROR: x ({x}) is the same as y ({y})!"

        # Using atol/rtol.
        else:
            # Provide defaults for either one of atol/rtol.
            if atol is None:
                atol = 0
            if rtol is None:
                rtol = 1e-7
            try:
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
            except AssertionError as e:
                if false is False:
                    raise e
            else:
                if false is True:
                    assert False, f"ERROR: x ({x}) is the same as y ({y})!"


def check_compute_single_action(
    algorithm, include_state=False, include_prev_action_reward=False
):
    """Tests different combinations of args for algorithm.compute_single_action.

    Args:
        algorithm: The Algorithm object to test.
        include_state: Whether to include the initial state of the Policy's
            Model in the `compute_single_action` call.
        include_prev_action_reward: Whether to include the prev-action and
            -reward in the `compute_single_action` call.

    Raises:
        ValueError: If anything unexpected happens.
    """
    # Have to import this here to avoid circular dependency.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

    # Some Algorithms may not abide by the standard API.
    pid = DEFAULT_POLICY_ID
    try:
        # Multi-agent: Pick any learnable policy (or DEFAULT_POLICY if it's the
        # only one).
        pid = next(iter(algorithm.workers.local_worker().get_policies_to_train()))
        pol = algorithm.get_policy(pid)
    except AttributeError:
        pol = algorithm.policy
    # Get the policy's model.
    model = pol.model

    action_space = pol.action_space

    def _test(
        what, method_to_test, obs_space, full_fetch, explore, timestep, unsquash, clip
    ):
        call_kwargs = {}
        if what is algorithm:
            call_kwargs["full_fetch"] = full_fetch
            call_kwargs["policy_id"] = pid

        obs = obs_space.sample()
        if isinstance(obs_space, Box):
            obs = np.clip(obs, -1.0, 1.0)
        state_in = None
        if include_state:
            state_in = model.get_initial_state()
            if not state_in:
                state_in = []
                i = 0
                while f"state_in_{i}" in model.view_requirements:
                    state_in.append(
                        model.view_requirements[f"state_in_{i}"].space.sample()
                    )
                    i += 1
        action_in = action_space.sample() if include_prev_action_reward else None
        reward_in = 1.0 if include_prev_action_reward else None

        if method_to_test == "input_dict":
            assert what is pol

            input_dict = {SampleBatch.OBS: obs}
            if include_prev_action_reward:
                input_dict[SampleBatch.PREV_ACTIONS] = action_in
                input_dict[SampleBatch.PREV_REWARDS] = reward_in
            if state_in:
                for i, s in enumerate(state_in):
                    input_dict[f"state_in_{i}"] = s
            input_dict_batched = SampleBatch(
                tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict)
            )
            action = pol.compute_actions_from_input_dict(
                input_dict=input_dict_batched,
                explore=explore,
                timestep=timestep,
                **call_kwargs,
            )
            # Unbatch everything to be able to compare against single
            # action below.
            # ARS and ES return action batches as lists.
            if isinstance(action[0], list):
                action = (np.array(action[0]), action[1], action[2])
            action = tree.map_structure(lambda s: s[0], action)

            try:
                action2 = pol.compute_single_action(
                    input_dict=input_dict,
                    explore=explore,
                    timestep=timestep,
                    **call_kwargs,
                )
                # Make sure these are the same, unless we have exploration
                # switched on (or noisy layers).
                if not explore and not pol.config.get("noisy"):
                    check(action, action2)
            except TypeError:
                pass
        else:
            action = what.compute_single_action(
                obs,
                state_in,
                prev_action=action_in,
                prev_reward=reward_in,
                explore=explore,
                timestep=timestep,
                unsquash_action=unsquash,
                clip_action=clip,
                **call_kwargs,
            )

        state_out = None
        if state_in or full_fetch or what is pol:
            action, state_out, _ = action
        if state_out:
            for si, so in zip(state_in, state_out):
                check(list(si.shape), so.shape)

        if unsquash is None:
            unsquash = what.config["normalize_actions"]
        if clip is None:
            clip = what.config["clip_actions"]

        # Test whether unsquash/clipping works on the Algorithm's
        # compute_single_action method: Both flags should force the action
        # to be within the space's bounds.
        if method_to_test == "single" and what == algorithm:
            if not action_space.contains(action) and (
                clip or unsquash or not isinstance(action_space, Box)
            ):
                raise ValueError(
                    f"Returned action ({action}) of algorithm/policy {what} "
                    f"not in Env's action_space {action_space}"
                )
            # We are operating in normalized space: Expect only smaller action
            # values.
            if (
                isinstance(action_space, Box)
                and not unsquash
                and what.config.get("normalize_actions")
                and np.any(np.abs(action) > 15.0)
            ):
                raise ValueError(
                    f"Returned action ({action}) of algorithm/policy {what} "
                    "should be in normalized space, but seems too large/small "
                    "for that!"
                )

    # Loop through: Policy vs Algorithm; Different API methods to calculate
    # actions; unsquash option; clip option; full fetch or not.
    for what in [pol, algorithm]:
        if what is algorithm:
            # Get the obs-space from Workers.env (not Policy) due to possible
            # pre-processor up front.
            worker_set = getattr(algorithm, "workers", None)
            assert worker_set
            if isinstance(worker_set, list):
                obs_space = algorithm.get_policy(pid).observation_space
            else:
                obs_space = worker_set.local_worker().for_policy(
                    lambda p: p.observation_space, policy_id=pid
                )
            obs_space = getattr(obs_space, "original_space", obs_space)
        else:
            obs_space = pol.observation_space

        for method_to_test in ["single"] + (["input_dict"] if what is pol else []):
            for explore in [True, False]:
                for full_fetch in [False, True] if what is algorithm else [False]:
                    timestep = random.randint(0, 100000)
                    for unsquash in [True, False, None]:
                        for clip in [False] if unsquash else [True, False, None]:
                            _test(
                                what,
                                method_to_test,
                                obs_space,
                                full_fetch,
                                explore,
                                timestep,
                                unsquash,
                                clip,
                            )


def check_learning_achieved(
    tune_results: "tune.ResultGrid", min_reward, evaluation=False
):
    """Throws an error if `min_reward` is not reached within tune_results.

    Checks the last iteration of each trial found in tune_results and compares
    the best "episode_reward_mean" (or "evaluation/episode_reward_mean" if
    `evaluation=True`) against `min_reward`.

    Args:
        tune_results: The tune.run returned results object.
        min_reward: The min reward that must be reached.
        evaluation: If True, use the evaluation reward metric instead of the
            training one.

    Raises:
        ValueError: If `min_reward` not reached.
    """
    # Get maximum reward of all trials
    # (check if at least one trial achieved some learning).
    avg_rewards = [
        (
            row["episode_reward_mean"]
            if not evaluation
            else row["evaluation/episode_reward_mean"]
        )
        for _, row in tune_results.get_dataframe().iterrows()
    ]
    best_avg_reward = max(avg_rewards)
    if best_avg_reward < min_reward:
        raise ValueError(f"`stop-reward` of {min_reward} not reached!")
    print(f"`stop-reward` of {min_reward} reached! ok")


def check_train_results(train_results):
    """Checks proper structure of an Algorithm.train() returned dict.

    Args:
        train_results: The train results dict to check.

    Raises:
        AssertionError: If `train_results` doesn't have the proper structure or
            data in it.
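
    Example (illustrative sketch only; `algo` stands for any already-built
    RLlib Algorithm and is not defined in this module):
        >>> results = algo.train()  # doctest: +SKIP
        >>> check_train_results(results)  # doctest: +SKIP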
""" # Import these here to avoid circular dependencies. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent # Assert that some keys are where we would expect them. for key in [ "agent_timesteps_total", "config", "custom_metrics", "episode_len_mean", "episode_reward_max", "episode_reward_mean", "episode_reward_min", "episodes_total", "hist_stats", "info", "iterations_since_restore", "num_healthy_workers", "perf", "policy_reward_max", "policy_reward_mean", "policy_reward_min", "sampler_perf", "time_since_restore", "time_this_iter_s", "timesteps_since_restore", "timesteps_total", "timers", "time_total_s", "training_iteration", ]: assert ( key in train_results ), f"'{key}' not found in `train_results` ({train_results})!" _, is_multi_agent = check_multi_agent(train_results["config"]) # Check in particular the "info" dict. info = train_results["info"] assert LEARNER_INFO in info, f"'learner' not in train_results['infos'] ({info})!" assert ( "num_steps_trained" in info or NUM_ENV_STEPS_TRAINED in info ), f"'num_(env_)?steps_trained' not in train_results['infos'] ({info})!" learner_info = info[LEARNER_INFO] # Make sure we have a default_policy key if we are not in a # multi-agent setup. if not is_multi_agent: # APEX algos sometimes have an empty learner info dict (no metrics # collected yet). assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, ( f"'{DEFAULT_POLICY_ID}' not found in " f"train_results['infos']['learner'] ({learner_info})!" ) for pid, policy_stats in learner_info.items(): if pid == "batch_count": continue # Make sure each policy has the LEARNER_STATS_KEY under it. assert LEARNER_STATS_KEY in policy_stats learner_stats = policy_stats[LEARNER_STATS_KEY] for key, value in learner_stats.items(): # Min- and max-stats should be single values. if key.startswith("min_") or key.startswith("max_"): assert np.isscalar(value), f"'key' value not a scalar ({value})!" return train_results def run_learning_tests_from_yaml( yaml_files: List[str], *, max_num_repeats: int = 2, use_pass_criteria_as_stop: bool = True, smoke_test: bool = False, ) -> Dict[str, Any]: """Runs the given experiments in yaml_files and returns results dict. Args: yaml_files: List of yaml file names. max_num_repeats: How many times should we repeat a failed experiment? use_pass_criteria_as_stop: Configure the Trial so that it stops as soon as pass criterias are met. smoke_test: Whether this is just a smoke-test. If True, set time_total_s to 5min and don't early out due to rewards or timesteps reached. Returns: A results dict mapping strings (e.g. "time_taken", "stats", "passed") to the respective stats/values. """ print("Will run the following yaml files:") for yaml_file in yaml_files: print("->", yaml_file) # All trials we'll ever run in this test script. all_trials = [] # The experiments (by name) we'll run up to `max_num_repeats` times. experiments = {} # The results per experiment. checks = {} # Metrics per experiment. stats = {} start_time = time.monotonic() def should_check_eval(experiment): # If we have evaluation workers, use their rewards. # This is useful for offline learning tests, where # we evaluate against an actual environment. return experiment["config"].get("evaluation_interval", None) is not None # Loop through all collected files and gather experiments. # Augment all by `torch` framework. 
    for yaml_file in yaml_files:
        with open(yaml_file) as f:
            tf_experiments = yaml.safe_load(f)

        # Add torch version of all experiments to the list.
        for k, e in tf_experiments.items():
            # If framework explicitly given, only test for that framework.
            # Some algos do not have both versions available.
            if "frameworks" in e:
                frameworks = e["frameworks"]
            else:
                # By default we don't run tf2, because tf2's multi-gpu support
                # isn't complete yet.
                frameworks = ["tf", "torch"]
            # Pop frameworks key to not confuse Tune.
            e.pop("frameworks", None)

            e["stop"] = e["stop"] if "stop" in e else {}
            e["pass_criteria"] = e["pass_criteria"] if "pass_criteria" in e else {}

            check_eval = should_check_eval(e)
            episode_reward_key = (
                "episode_reward_mean"
                if not check_eval
                else "evaluation/episode_reward_mean"
            )

            # For smoke-tests, we just run each experiment very briefly.
            if smoke_test:
                # 0sec for each(!) experiment/trial.
                # This is such that if there are many experiments/trials
                # in a test (e.g. rllib_learning_test), each one can at least
                # create its Algorithm and run a first iteration.
                e["stop"]["time_total_s"] = 0
            else:
                if use_pass_criteria_as_stop:
                    # We also stop early, once we reach the desired reward.
                    min_reward = e.get("pass_criteria", {}).get(episode_reward_key)
                    if min_reward is not None:
                        e["stop"][episode_reward_key] = min_reward

            # Generate `checks` dict for all experiments
            # (tf, tf2 and/or torch).
            for framework in frameworks:
                k_ = k + "-" + framework
                ec = copy.deepcopy(e)
                ec["config"]["framework"] = framework
                if framework == "tf2":
                    ec["config"]["eager_tracing"] = True

                checks[k_] = {
                    "min_reward": ec["pass_criteria"].get(episode_reward_key, 0.0),
                    "min_throughput": ec["pass_criteria"].get("timesteps_total", 0.0)
                    / (ec["stop"].get("time_total_s", 1.0) or 1.0),
                    "time_total_s": ec["stop"].get("time_total_s"),
                    "failures": 0,
                    "passed": False,
                }
                # This key would break tune.
                ec.pop("pass_criteria", None)

                # One experiment to run.
                experiments[k_] = ec

    # Keep track of those experiments we still have to run.
    # If an experiment passes, we'll remove it from this dict.
    experiments_to_run = experiments.copy()

    try:
        ray.init(address="auto")
    except ConnectionError:
        ray.init()

    for i in range(max_num_repeats):
        # We are done.
        if len(experiments_to_run) == 0:
            print("All experiments finished.")
            break

        print(f"Starting learning test iteration {i}...")

        # Print out the actual config.
        print("== Test config ==")
        print(yaml.dump(experiments_to_run))

        # Run remaining experiments.
        trials = run_experiments(
            experiments_to_run,
            resume=False,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    NUM_ENV_STEPS_SAMPLED: "ts (sampled)",
                    NUM_ENV_STEPS_TRAINED: "ts (trained)",
                    "episodes_this_iter": "train_episodes",
                    "episode_reward_mean": "reward_mean",
                    "evaluation/episode_reward_mean": "eval_reward_mean",
                },
                parameter_columns=["framework"],
                sort_by_metric=True,
                max_report_frequency=30,
            ),
        )

        all_trials.extend(trials)

        # Check each experiment for whether it passed.
        # The criteria are to a) reach the reward AND b) reach the throughput
        # defined by `NUM_ENV_STEPS_(SAMPLED|TRAINED)` / `time_total_s`.
        for experiment in experiments_to_run.copy():
            print(f"Analyzing experiment {experiment} ...")
            # Collect all trials within this experiment (some experiments may
            # have num_samples or grid_searches defined).
            trials_for_experiment = []
            for t in trials:
                trial_exp = re.sub(".+/([^/]+)$", "\\1", t.local_dir)
                if trial_exp == experiment:
                    trials_for_experiment.append(t)
            print(f" ... Trials: {trials_for_experiment}.")
            check_eval = should_check_eval(experiments[experiment])

            # Error: Increase failure count and repeat.
            if any(t.status == "ERROR" for t in trials_for_experiment):
                print(" ... ERROR.")
                checks[experiment]["failures"] += 1
            # Smoke-tests always succeed.
            elif smoke_test:
                print(" ... SMOKE TEST (mark ok).")
                checks[experiment]["passed"] = True
                del experiments_to_run[experiment]
            # Experiment finished: Check reward achieved and timesteps done
            # (throughput).
            else:
                # Use best_result's reward to check min_reward.
                if check_eval:
                    episode_reward_mean = np.mean(
                        [
                            t.metric_analysis["evaluation/episode_reward_mean"]["max"]
                            for t in trials_for_experiment
                        ]
                    )
                else:
                    episode_reward_mean = np.mean(
                        [
                            t.metric_analysis["episode_reward_mean"]["max"]
                            for t in trials_for_experiment
                        ]
                    )
                desired_reward = checks[experiment]["min_reward"]

                # Use last_result["timesteps_total"] to check throughput.
                timesteps_total = np.mean(
                    [t.last_result["timesteps_total"] for t in trials_for_experiment]
                )
                total_time_s = np.mean(
                    [t.last_result["time_total_s"] for t in trials_for_experiment]
                )

                # TODO(jungong): track training- and env throughput separately.
                throughput = timesteps_total / (total_time_s or 1.0)
                # Throughput verification is not working. Many algorithms, e.g.
                # TD3, achieve the learning goal, but fail the throughput check
                # miserably.
                # TODO(jungong): Figure out why.
                #
                # desired_throughput = checks[experiment]["min_throughput"]
                desired_throughput = None

                # Record performance.
                stats[experiment] = {
                    "episode_reward_mean": float(episode_reward_mean),
                    "throughput": (
                        float(throughput) if throughput is not None else 0.0
                    ),
                }

                print(
                    f" ... Desired reward={desired_reward}; "
                    f"desired throughput={desired_throughput}"
                )

                # We failed to reach the desired reward or the desired throughput.
                if (desired_reward and episode_reward_mean < desired_reward) or (
                    desired_throughput and throughput < desired_throughput
                ):
                    print(
                        " ... Not successful: Actual "
                        f"reward={episode_reward_mean}; "
                        f"actual throughput={throughput}"
                    )
                    checks[experiment]["failures"] += 1
                # We succeeded!
                else:
                    print(
                        " ... Successful: (mark ok). Actual "
                        f"reward={episode_reward_mean}; "
                        f"actual throughput={throughput}"
                    )
                    checks[experiment]["passed"] = True
                    del experiments_to_run[experiment]

    ray.shutdown()

    time_taken = time.monotonic() - start_time

    # Create and return the results dict.
    result = {
        "time_taken": float(time_taken),
        "trial_states": dict(Counter([trial.status for trial in all_trials])),
        "last_update": float(time.time()),
        "stats": stats,
        "passed": [k for k, exp in checks.items() if exp["passed"]],
        "failures": {
            k: exp["failures"] for k, exp in checks.items() if exp["failures"] > 0
        },
    }

    return result


def check_same_batch(batch1, batch2) -> None:
    """Check if both batches are (almost) identical.

    For MultiAgentBatches, the step count and individual policy's
    SampleBatches are checked for identity. For SampleBatches, identity is
    checked as the almost numerical key-value-pair identity between batches
    with ray.rllib.utils.test_utils.check(). unroll_id is compared only if
    both batches have an unroll_id.
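
    Example (illustrative sketch; the tiny SampleBatch below is made up purely
    for demonstration):
        >>> from ray.rllib.policy.sample_batch import SampleBatch  # doctest: +SKIP
        >>> b1 = SampleBatch({"obs": [1, 2, 3]})  # doctest: +SKIP
        >>> check_same_batch(b1, b1.copy())  # passes  # doctest: +SKIP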

    Args:
        batch1: Batch to compare against batch2.
        batch2: Batch to compare against batch1.
    """
    # Avoids circular import.
    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

    assert type(batch1) == type(
        batch2
    ), "Input batches are of different types {} and {}".format(
        str(type(batch1)), str(type(batch2))
    )

    def check_sample_batches(_batch1, _batch2, _policy_id=None):
        unroll_id_1 = _batch1.get("unroll_id", None)
        unroll_id_2 = _batch2.get("unroll_id", None)
        # unroll IDs only have to fit if both batches have them.
        if unroll_id_1 is not None and unroll_id_2 is not None:
            assert unroll_id_1 == unroll_id_2

        batch1_keys = set()
        for k, v in _batch1.items():
            # unroll_id is compared above already.
            if k == "unroll_id":
                continue
            check(v, _batch2[k])
            batch1_keys.add(k)

        batch2_keys = set(_batch2.keys())
        # unroll_id is compared above already.
        batch2_keys.discard("unroll_id")
        _difference = batch1_keys.symmetric_difference(batch2_keys)

        # Cases where one batch has information the other has not.
        if _policy_id:
            assert not _difference, (
                "SampleBatches for policy with ID {} don't share the "
                "following information: \n{}".format(_policy_id, _difference)
            )
        else:
            assert not _difference, (
                "SampleBatches don't share the following information: "
                "\n{}".format(_difference)
            )

    if type(batch1) == SampleBatch:
        check_sample_batches(batch1, batch2)
    elif type(batch1) == MultiAgentBatch:
        assert batch1.count == batch2.count
        batch1_ids = set()
        for policy_id, policy_batch in batch1.policy_batches.items():
            check_sample_batches(
                policy_batch, batch2.policy_batches[policy_id], policy_id
            )
            batch1_ids.add(policy_id)

        # Case where one ma batch has info on a policy the other has not.
        batch2_ids = set(batch2.policy_batches.keys())
        difference = batch1_ids.symmetric_difference(batch2_ids)
        assert (
            not difference
        ), f"MultiAgentBatches don't share the following information: \n{difference}."
    else:
        raise ValueError("Unsupported batch type " + str(type(batch1)))


def check_reproducibilty(
    algo_class: Type["Algorithm"],
    algo_config: "AlgorithmConfig",
    *,
    fw_kwargs: Dict[str, Any],
    training_iteration: int = 1,
) -> None:
    # TODO @kourosh: we can get rid of examples/deterministic_training.py once
    # this is added to all algorithms.
    """Check if the algorithm is reproducible across different testing conditions:

        frameworks: all input frameworks
        num_gpus: int(os.environ.get("RLLIB_NUM_GPUS", "0"))
        num_workers: 0 (only local worker) or
                     2 (1 local worker + 2 remote workers)
        num_envs_per_worker: 2

    Args:
        algo_class: Algorithm class to test.
        algo_config: Base config to use for the algorithm.
        fw_kwargs: Framework iterator keyword arguments.
        training_iteration: Number of training iterations to run.

    Returns:
        None

    Raises:
        It raises an AssertionError if the algorithm is not reproducible.
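
    Example (illustrative sketch; PPO is used only as a stand-in for any
    Algorithm/AlgorithmConfig pair):
        >>> from ray.rllib.algorithms.ppo import PPO, PPOConfig  # doctest: +SKIP
        >>> check_reproducibilty(  # doctest: +SKIP
        ...     PPO,
        ...     PPOConfig().environment("CartPole-v1"),
        ...     fw_kwargs={"frameworks": ("tf", "torch")},
        ...     training_iteration=1,
        ... )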
""" from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics.learner_info import LEARNER_INFO stop_dict = { "training_iteration": training_iteration, } # use 0 and 2 workers (for more that 4 workers we have to make sure the instance # type in ci build has enough resources) for num_workers in [0, 2]: algo_config = ( algo_config.debugging(seed=42) .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) .rollouts(num_rollout_workers=num_workers, num_envs_per_worker=2) ) for fw in framework_iterator(algo_config, **fw_kwargs): print( f"Testing reproducibility of {algo_class.__name__}" f" with {num_workers} workers on fw = {fw}" ) print("/// config") pprint.pprint(algo_config.to_dict()) # test tune.run() reproducibility results1 = tune.Tuner( algo_class, param_space=algo_config.to_dict(), run_config=air.RunConfig(stop=stop_dict, verbose=1), ).fit() results1 = results1.get_best_result().metrics results2 = tune.Tuner( algo_class, param_space=algo_config.to_dict(), run_config=air.RunConfig(stop=stop_dict, verbose=1), ).fit() results2 = results2.get_best_result().metrics # Test rollout behavior. check(results1["hist_stats"], results2["hist_stats"]) # As well as training behavior (minibatch sequence during SGD # iterations). check( results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"], results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"], )