diff --git a/rllib/train.py b/rllib/train.py index ef6f23039..fa492cfed 100755 --- a/rllib/train.py +++ b/rllib/train.py @@ -182,10 +182,15 @@ def run(args, parser): inputs = force_list(input_) # This script runs in the ray/rllib dir. rllib_dir = Path(__file__).parent - abs_inputs = [ - str(rllib_dir.absolute().joinpath(i)) - if not os.path.exists(i) else i for i in inputs - ] + + def patch_path(path): + if os.path.exists(path): + return path + else: + abs_path = str(rllib_dir.absolute().joinpath(path)) + return abs_path if os.path.exists(abs_path) else path + + abs_inputs = list(map(patch_path, inputs)) if not isinstance(input_, list): abs_inputs = abs_inputs[0]