File tree: 3 files changed, +6 -11 lines changed
lines changed Original file line number Diff line number Diff line change @@ -78,17 +78,9 @@ RUN apt-get install -y ocl-icd-libopencl1 clinfo && \
7878 uv pip install --system /tmp/lightgbm/*.whl && \
7979 rm -rf /tmp/lightgbm && \
8080 /tmp/clean-layer.sh
81-
82- # Remove CUDA_VERSION from non-GPU image.
83- {{ else }}
84- ENV CUDA_VERSION=""
8581{{ end }}
8682
8783
88- # Update GPG key per documentation at https://cloud.google.com/compute/docs/troubleshooting/known-issues
89- RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
90- RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
91-
9284# Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections,
9385# as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346
9486RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \
Original file line number Diff line number Diff line change @@ -11,7 +11,10 @@ def getAcceleratorName():
1111 except FileNotFoundError :
1212 return ("nvidia-smi not found." )
1313
14- gpu_test = unittest .skipIf (len (os .environ .get ('CUDA_VERSION' , '' )) == 0 , 'Not running GPU tests' )
14+ def isGPU ():
15+ return os .path .isfile ('/proc/driver/nvidia/version' )
16+
17+ gpu_test = unittest .skipIf (not isGPU (), 'Not running GPU tests' )
1518# b/342143152 P100s are gradually losing support in new releases of popular ML tools such as RAPIDS.
1619p100_exempt = unittest .skipIf (getAcceleratorName () == "Tesla P100-PCIE-16GB" , 'Not running p100 exempt tests' )
1720tpu_test = unittest .skipIf (len (os .environ .get ('ISTPUVM' , '' )) == 0 , 'Not running TPU tests' )
Original file line number Diff line number Diff line change 66import jax
77import jax .numpy as np
88
9- from common import gpu_test
9+ from common import gpu_test , isGPU
1010from jax import grad , jit
1111
1212
@@ -21,5 +21,5 @@ def test_grad(self):
2121 self .assertEqual (0.4199743 , ag )
2222
2323 def test_backend (self ):
24- expected_backend = 'cpu' if len ( os . environ . get ( 'CUDA_VERSION' , '' )) == 0 else 'gpu'
24+ expected_backend = 'cpu' if not isGPU () else 'gpu'
2525 self .assertEqual (expected_backend , jax .default_backend ())
You can’t perform that action at this time.
0 commit comments