From 573c66a7557915bdfbc2bf91857d5a0b870ed709 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 29 Sep 2021 21:41:26 +0200
Subject: [PATCH] [GCP] Update GCP TPU config (#18634)

* [autoscaler] Update GCP TPU config

* Preemptible by default

* Remove libtpu link from head node

* Workaround
---
NOTE(review): this patch was restored from a copy whose newlines had been
collapsed. Line boundaries follow the hunk headers and diffstat; leading
indentation and blank context lines are inferred and should be confirmed
against the upstream file before applying.

 python/ray/autoscaler/gcp/tpu.yaml | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/python/ray/autoscaler/gcp/tpu.yaml b/python/ray/autoscaler/gcp/tpu.yaml
index 34726cb22..a963e62c1 100644
--- a/python/ray/autoscaler/gcp/tpu.yaml
+++ b/python/ray/autoscaler/gcp/tpu.yaml
@@ -32,9 +32,9 @@ available_node_types:
             # Support for TPU pods will be added in the future.
             acceleratorType: v2-8
             runtimeVersion: v2-alpha
-            # Uncomment to use preemptible TPUs
-            # schedulingConfig:
-            #   preemptible: true
+            schedulingConfig:
+              # Set to false to use non-preemptible TPUs
+              preemptible: true
 
 provider:
     type: gcp
@@ -51,15 +51,21 @@ head_node_type: ray_head_default
 # Compute instances have python 3.7, but TPUs have 3.8 - need to update
 # Install Jax and other dependencies on the Compute head node
 head_setup_commands:
-  - conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10
-  - export PATH="$PATH:/opt/conda/envs/ray/bin" && echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc
-  - python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+  # Two first lines are a workaround for ssh timing out
+  - sleep 2
+  - sleep 2
+  - sudo chown -R $(whoami) /opt/conda/*
+  - conda create -y -n "ray" python=3.8.5
+  - conda activate ray && echo 'conda activate ray' >> ~/.bashrc
+  - python -m pip install --upgrade pip
+  - python -m pip install --upgrade "jax[cpu]==0.2.14"
   - python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default]
   - python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
   - git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install .
 
 # Install Jax and other dependencies on TPU
 worker_setup_commands:
+  - pip3 install --upgrade pip
   - pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
   - pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default]
   - python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)"  # test if Jax has been installed correctly