@@ -78,7 +78,10 @@ def __init__(self, args):
7878 self ._config_gpu_memory (self ._args .gpu_mem_cap )
7979
8080 def _config_gpu_memory (self , gpu_mem_cap ):
81- gpus = tf .config .experimental .list_physical_devices ('GPU' )
81+ try :
82+ gpus = tf .config .list_physical_devices ('GPU' )
83+ except AttributeError :
84+ gpus = tf .config .experimental .list_physical_devices ('GPU' )
8285
8386 if not gpus :
8487 raise RuntimeError ("No GPUs has been found." )
@@ -90,15 +93,20 @@ def _config_gpu_memory(self, gpu_mem_cap):
9093 for gpu in gpus :
9194 try :
9295 if not gpu_mem_cap :
93- tf .config .experimental .set_memory_growth (gpu , True )
96+ try :
97+ tf .config .set_memory_growth (gpu , True )
98+ except AttributeError :
99+ tf .config .experimental .set_memory_growth (gpu , True )
100+
94101 else :
95- tf .config .experimental .set_virtual_device_configuration (
96- gpu , [
97- tf .config .experimental .VirtualDeviceConfiguration (
98- memory_limit = gpu_mem_cap
99- )
100- ]
101- )
102+ try :
103+ set_virtual_device_configuration = tf .config .set_virtual_device_configuration
104+ device_config = tf .config .LogicalDeviceConfiguration (memory_limit = gpu_mem_cap )
105+ except AttributeError :
106+ set_virtual_device_configuration = tf .config .experimental .set_virtual_device_configuration
107+ device_config = tf .config .experimental .VirtualDeviceConfiguration (memory_limit = gpu_mem_cap )
108+
109+ set_virtual_device_configuration (gpu , [device_config ])
102110 except RuntimeError as e :
103111 print ('Can not set GPU memory config' , e )
104112
0 commit comments