From 952f60fe2eafad34e152ec530e5eb65ca7ca1a67 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 10 Dec 2024 14:55:24 -0800 Subject: [PATCH] Docker: Upgrade Jax to 0.4.37 --- Dockerfile | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Dockerfile b/Dockerfile index c35de629..ac9ece06 100644 --- a/Dockerfile +++ b/Dockerfile @@ -95,6 +95,14 @@ ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.h # Jax will fallback to CPU when run on a machine without TPU. RUN pip install .[core,tpu] RUN if [ -n "$EXTRAS" ]; then pip install .[$EXTRAS]; fi + +# V6E requires using jax 0.4.37 so temporarily override before AXLearn +# upgrades to Jax 0.4.37. +ENV DATE=20241201 +RUN pip install -U "jax[tpu]==0.4.37" "jax==0.4.37" "jaxlib==0.4.36" \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +RUN pip install -U "libtpu-nightly==0.1.dev${DATE}" \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html COPY . . ################################################################################