mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[GCP] Update GCP TPU config (#18634)
* [autoscaler] Update GCP TPU config
* Preemptible by default
* Remove libtpu link from head node
* Workaround for ssh timing out
This commit is contained in:
parent
9c9b482661
commit
573c66a755
1 changed file with 12 additions and 6 deletions
|
@ -32,9 +32,9 @@ available_node_types:
|
||||||
# Support for TPU pods will be added in the future.
|
# Support for TPU pods will be added in the future.
|
||||||
acceleratorType: v2-8
|
acceleratorType: v2-8
|
||||||
runtimeVersion: v2-alpha
|
runtimeVersion: v2-alpha
|
||||||
# Uncomment to use preemptible TPUs
|
schedulingConfig:
|
||||||
# schedulingConfig:
|
# Set to false to use non-preemptible TPUs
|
||||||
# preemptible: true
|
preemptible: true
|
||||||
|
|
||||||
provider:
|
provider:
|
||||||
type: gcp
|
type: gcp
|
||||||
|
@ -51,15 +51,21 @@ head_node_type: ray_head_default
|
||||||
# Compute instances have python 3.7, but TPUs have 3.8 - need to update
|
# Compute instances have python 3.7, but TPUs have 3.8 - need to update
|
||||||
# Install Jax and other dependencies on the Compute head node
|
# Install Jax and other dependencies on the Compute head node
|
||||||
head_setup_commands:
|
head_setup_commands:
|
||||||
- conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10
|
# The first two lines are a workaround for SSH timing out
|
||||||
- export PATH="$PATH:/opt/conda/envs/ray/bin" && echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc
|
- sleep 2
|
||||||
- python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
- sleep 2
|
||||||
|
- sudo chown -R $(whoami) /opt/conda/*
|
||||||
|
- conda create -y -n "ray" python=3.8.5
|
||||||
|
- conda activate ray && echo 'conda activate ray' >> ~/.bashrc
|
||||||
|
- python -m pip install --upgrade pip
|
||||||
|
- python -m pip install --upgrade "jax[cpu]==0.2.14"
|
||||||
- python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default]
|
- python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default]
|
||||||
- python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
|
- python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
|
||||||
- git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install .
|
- git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install .
|
||||||
|
|
||||||
# Install Jax and other dependencies on TPU
|
# Install Jax and other dependencies on TPU
|
||||||
worker_setup_commands:
|
worker_setup_commands:
|
||||||
|
- pip3 install --upgrade pip
|
||||||
- pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
- pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||||
- pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default]
|
- pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default]
|
||||||
- python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly
|
- python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly
|
||||||
|
|
Loading…
Add table
Reference in a new issue