[GCP] Update GCP TPU config (#18634)

* [autoscaler] Update GCP TPU config

* Preemptible by default

* Remove libtpu link from head node

* Workaround
This commit is contained in:
Antoni Baum 2021-09-29 21:41:26 +02:00 committed by GitHub
parent 9c9b482661
commit 573c66a755
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -32,9 +32,9 @@ available_node_types:
# Support for TPU pods will be added in the future. # Support for TPU pods will be added in the future.
acceleratorType: v2-8 acceleratorType: v2-8
runtimeVersion: v2-alpha runtimeVersion: v2-alpha
# Uncomment to use preemptible TPUs schedulingConfig:
# schedulingConfig: # Set to false to use non-preemptible TPUs
# preemptible: true preemptible: true
provider: provider:
type: gcp type: gcp
@ -51,15 +51,21 @@ head_node_type: ray_head_default
# Compute instances have python 3.7, but TPUs have 3.8 - need to update # Compute instances have python 3.7, but TPUs have 3.8 - need to update
# Install Jax and other dependencies on the Compute head node # Install Jax and other dependencies on the Compute head node
head_setup_commands: head_setup_commands:
- conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10 # Two first lines are a workaround for ssh timing out
- export PATH="$PATH:/opt/conda/envs/ray/bin" && echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc - sleep 2
- python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - sleep 2
- sudo chown -R $(whoami) /opt/conda/*
- conda create -y -n "ray" python=3.8.5
- conda activate ray && echo 'conda activate ray' >> ~/.bashrc
- python -m pip install --upgrade pip
- python -m pip install --upgrade "jax[cpu]==0.2.14"
- python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default] - python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default]
- python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl - python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
- git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install . - git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install .
# Install Jax and other dependencies on TPU # Install Jax and other dependencies on TPU
worker_setup_commands: worker_setup_commands:
- pip3 install --upgrade pip
- pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
- pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default] - pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default]
- python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly - python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly