-import numpy as np
+import sys
+
 import pytest
 from mpi4py import MPI
 
 from .algo import partition, lower_bound
-from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
-from .utils_mpi import number_of_working_processes, is_dyn_master_process
+from .utils.items import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
+from .utils.mpi import number_of_working_processes, is_dyn_master_process
 from .gather_report import gather_report_on_local_rank_0
 from .static_scheduler_utils import group_items_by_parallel_steps
 
 
 def mark_skip(item):
     comm = MPI.COMM_WORLD
-    n_rank = comm.Get_size()
+    n_rank = comm.size
     n_proc_test = get_n_proc_for_test(item)
     skip_msg = f"Not enough procs to execute: {n_proc_test} required but only {n_rank} available"
     item.add_marker(pytest.mark.skip(reason=skip_msg), append=False)
@@ -28,7 +29,8 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
     if mpi_comm_creation_function == 'MPI_Comm_create':
         return sub_comm_from_ranks(global_comm, range(0,n_proc))
     elif mpi_comm_creation_function == 'MPI_Comm_split':
-        if i_rank < n_proc_test:
+        i_rank = global_comm.rank
+        if i_rank < n_proc:
             color = 1
         else:
             color = MPI.UNDEFINED
@@ -37,8 +39,7 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
         assert 0, 'Unknown MPI communicator creation function. Available: `MPI_Comm_create`, `MPI_Comm_split`'
 
 def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
-    i_rank = global_comm.Get_rank()
-    n_rank = global_comm.Get_size()
+    n_rank = global_comm.size
     sub_comms = [None] * n_rank
     for i in range(0,n_rank):
         n_proc = i + 1
@@ -47,8 +48,7 @@ def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
 
 
 def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_function):
-    i_rank = global_comm.Get_rank()
-    n_rank = global_comm.Get_size()
+    n_rank = global_comm.size
 
     # Strategy 'by_rank': create one sub-communicator by size, from sequential (size=1) to n_rank
     if test_comm_creation == 'by_rank':
@@ -71,12 +71,17 @@ def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_funct
         assert 0, 'Unknown test MPI communicator creation strategy. Available: `by_rank`, `by_test`'
 
 class SequentialScheduler:
-    def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=True):
+    def __init__(self, global_comm):
         self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework
-        self.test_comm_creation = test_comm_creation
-        self.mpi_comm_creation_function = mpi_comm_creation_function
-        self.barrier_at_test_start = barrier_at_test_start
-        self.barrier_at_test_end = barrier_at_test_end
+
+        # These parameters are not accessible through the API, but are left here for tweaking and experimenting
+        self.test_comm_creation = 'by_rank' # possible values: 'by_rank' | 'by_test'
+        self.mpi_comm_creation_function = 'MPI_Comm_create' # possible values: 'MPI_Comm_create' | 'MPI_Comm_split'
+        self.barrier_at_test_start = True
+        self.barrier_at_test_end = True
+        if sys.platform == "win32":
+            self.mpi_comm_creation_function = 'MPI_Comm_split' # because 'MPI_Comm_create' uses `Create_group`,
+                                                               # that is not implemented in mpi4py for Windows
 
     @pytest.hookimpl(trylast=True)
     def pytest_collection_modifyitems(self, config, items):
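
For readers unfamiliar with the two `mpi_comm_creation_function` options toggled in this `__init__` (and the `win32` fallback), here is a minimal, standalone mpi4py sketch of how a sub-communicator of the first `n_proc` ranks can be built either way. It is only an illustration under the assumption of a plain `MPI.COMM_WORLD` run, not the plugin's own code (which goes through its `sub_comm_from_ranks` helper); the script and function names are hypothetical.

# Hedged sketch: build a sub-communicator of the first n_proc ranks, two ways.
# Run with e.g. `mpiexec -n 4 python demo_sub_comm.py` (file name is illustrative).
from mpi4py import MPI

def sub_comm_create(global_comm, n_proc):
    # 'MPI_Comm_create'-style path: relies on Create_group (MPI-3), which the
    # comment above reports as unavailable in mpi4py on Windows.
    group = global_comm.Get_group().Incl(list(range(n_proc)))
    return global_comm.Create_group(group)  # collective only over the ranks in `group`

def sub_comm_split(global_comm, n_proc):
    # 'MPI_Comm_split' path: ranks below n_proc share a color, the others opt out
    # with MPI.UNDEFINED and receive MPI.COMM_NULL instead of a communicator.
    color = 1 if global_comm.rank < n_proc else MPI.UNDEFINED
    return global_comm.Split(color, key=global_comm.rank)

if __name__ == "__main__":
    sub = sub_comm_split(MPI.COMM_WORLD, 2)
    if sub != MPI.COMM_NULL:
        print(f"rank {MPI.COMM_WORLD.rank} belongs to a sub-communicator of size {sub.size}")
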
@@ -86,20 +91,10 @@ def pytest_collection_modifyitems(self, config, items):
     def pytest_runtest_protocol(self, item, nextitem):
         if self.barrier_at_test_start:
             self.global_comm.barrier()
-        #print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
         _ = yield
-        #print(f'pytest_runtest_protocol end {MPI.COMM_WORLD.rank=}')
         if self.barrier_at_test_end:
             self.global_comm.barrier()
 
-    #@pytest.hookimpl(tryfirst=True)
-    #def pytest_runtest_protocol(self, item, nextitem):
-    #    if self.barrier_at_test_start:
-    #        self.global_comm.barrier()
-    #    print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
-    #    if item.sub_comm == MPI.COMM_NULL:
-    #        return True # for this hook, `firstresult=True` so returning a non-None will stop other hooks to run
-
     @pytest.hookimpl(tryfirst=True)
     def pytest_pyfunc_call(self, pyfuncitem):
         #print(f'pytest_pyfunc_call {MPI.COMM_WORLD.rank=}')
@@ -113,7 +108,7 @@ def pytest_runtestloop(self, session) -> bool:
         _ = yield
         # prevent return value being non-zero (ExitCode.NO_TESTS_COLLECTED)
         # when no test run on non-master
-        if self.global_comm.Get_rank() != 0 and session.testscollected == 0:
+        if self.global_comm.rank != 0 and session.testscollected == 0:
             session.testscollected = 1
         return True
 
@@ -136,7 +131,7 @@ def pytest_runtest_logreport(self, report):
 
 
 def prepare_items_to_run(items, comm):
-    i_rank = comm.Get_rank()
+    i_rank = comm.rank
 
     items_to_run = []
 
@@ -168,7 +163,7 @@ def prepare_items_to_run(items, comm):
 
 
 def items_to_run_on_this_proc(items_by_steps, items_to_skip, comm):
-    i_rank = comm.Get_rank()
+    i_rank = comm.rank
 
     items = []
 
@@ -204,14 +199,13 @@ def pytest_runtestloop(self, session) -> bool:
             and not session.config.option.continue_on_collection_errors
         ):
             raise session.Interrupted(
-                "%d error%s during collection"
-                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+                f"{session.testsfailed} error{'s' if session.testsfailed != 1 else ''} during collection"
             )
 
         if session.config.option.collectonly:
             return True
 
-        n_workers = self.global_comm.Get_size()
+        n_workers = self.global_comm.size
 
         add_n_procs(session.items)
 
@@ -221,20 +215,12 @@ def pytest_runtestloop(self, session) -> bool:
             items_by_steps, items_to_skip, self.global_comm
         )
 
-        for i, item in enumerate(items):
-            # nextitem = items[i + 1] if i + 1 < len(items) else None
-            # For optimization purposes, it would be nice to have the previous commented line
-            # (`nextitem` is only used internally by PyTest in _setupstate.teardown_exact)
-            # Here, it does not work:
-            # it seems that things are messed up on rank 0
-            # because the nextitem might not be run (see pytest_runtest_setup/call/teardown hooks just above)
-            # In practice though, it seems that it is not the main thing that slows things down...
-
+        for item in items:
             nextitem = None
             run_item_test(item, nextitem, session)
 
         # prevent return value being non-zero (ExitCode.NO_TESTS_COLLECTED) when no test run on non-master
-        if self.global_comm.Get_rank() != 0 and session.testscollected == 0:
+        if self.global_comm.rank != 0 and session.testscollected == 0:
             session.testscollected = 1
         return True
 
@@ -256,8 +242,8 @@ def pytest_runtest_logreport(self, report):
         gather_report_on_local_rank_0(report)
 
         # master ranks of each sub_comm must send their report to rank 0
-        if sub_comm.Get_rank() == 0: # only master are concerned
-            if self.global_comm.Get_rank() != 0: # if master is not global master, send
+        if sub_comm.rank == 0: # only master are concerned
+            if self.global_comm.rank != 0: # if master is not global master, send
                 self.global_comm.send(report, dest=0)
             elif report.master_running_proc != 0: # else, recv if test run remotely
                 # In the line below, MPI.ANY_TAG will NOT clash with communications outside the framework because self.global_comm is private
@@ -322,7 +308,7 @@ def schedule_test(item, available_procs, inter_comm):
 
     # mark the procs as busy
     for sub_rank in sub_ranks:
-        available_procs[sub_rank] = False
+        available_procs[sub_rank] = 0
 
     # TODO isend would be slightly better (less waiting)
     for sub_rank in sub_ranks:
@@ -354,19 +340,19 @@ def wait_test_to_complete(items_to_run, session, available_procs, inter_comm):
     for sub_rank in sub_ranks:
         if sub_rank != first_rank_done:
             rank_original_idx = inter_comm.recv(source=sub_rank, tag=WORK_DONE_TAG)
-            assert (rank_original_idx == original_idx) # sub_rank is supposed to have worked on the same test
+            assert rank_original_idx == original_idx # sub_rank is supposed to have worked on the same test
 
     # the procs are now available
     for sub_rank in sub_ranks:
-        available_procs[sub_rank] = True
+        available_procs[sub_rank] = 1
 
     # "run" the test (i.e. trigger PyTest pipeline but do not really run the code)
     nextitem = None # not known at this point
     run_item_test(item, nextitem, session)
 
 
 def wait_last_tests_to_complete(items_to_run, session, available_procs, inter_comm):
-    while np.sum(available_procs) < len(available_procs):
+    while sum(available_procs) < len(available_procs):
         wait_test_to_complete(items_to_run, session, available_procs, inter_comm)
 
 
@@ -418,8 +404,7 @@ def pytest_runtestloop(self, session) -> bool:
             and not session.config.option.continue_on_collection_errors
         ):
             raise session.Interrupted(
-                "%d error%s during collection"
-                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
+                f"{session.testsfailed} error{'s' if session.testsfailed != 1 else ''} during collection"
             )
 
         if session.config.option.collectonly:
@@ -451,10 +436,10 @@ def pytest_runtestloop(self, session) -> bool:
 
         # schedule tests to run
         items_left_to_run = sorted(items_to_run, key=lambda item: item.n_proc)
-        available_procs = np.ones(n_workers, dtype=np.int8)
+        available_procs = [1] * n_workers
 
         while len(items_left_to_run) > 0:
-            n_av_procs = np.sum(available_procs)
+            n_av_procs = sum(available_procs)
 
             item_idx = item_with_biggest_admissible_n_proc(items_left_to_run, n_av_procs)
 
@@ -511,7 +496,7 @@ def pytest_runtest_logreport(self, report):
         sub_comm = report.sub_comm
         gather_report_on_local_rank_0(report)
 
-        if sub_comm.Get_rank() == 0: # if local master proc, send
+        if sub_comm.rank == 0: # if local master proc, send
             # The idea of the scheduler is the following:
             # The server schedules test over clients
             # A client executes the test then report to the server it is done
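
The three comments above summarize the dynamic scheduler's protocol: a server rank hands tests to client ranks, and each client reports back when it is done. As a rough illustration of that pattern only (the tags, function names, and use of `MPI.COMM_WORLD` below are assumptions; the plugin itself communicates over a private inter-communicator with its own tags such as `WORK_DONE_TAG`), a minimal server/worker loop could look like this:

# Hedged sketch of a server/worker scheduling loop over MPI, loosely mirroring the
# protocol described in the comments above. Run with at least 2 ranks, e.g. `mpiexec -n 3`.
from mpi4py import MPI

WORK_TAG, DONE_TAG, STOP_TAG = 1, 2, 3  # illustrative tag values

def server(comm, jobs):
    pending = list(jobs)
    available = list(range(1, comm.size))  # worker ranks; rank 0 is the server
    running = 0
    while pending or running:
        # hand out work while there are idle workers and jobs left
        while pending and available:
            comm.send(pending.pop(), dest=available.pop(), tag=WORK_TAG)
            running += 1
        # wait for any worker to report completion, then mark it available again
        status = MPI.Status()
        comm.recv(source=MPI.ANY_SOURCE, tag=DONE_TAG, status=status)
        available.append(status.Get_source())
        running -= 1
    for worker_rank in range(1, comm.size):
        comm.send(None, dest=worker_rank, tag=STOP_TAG)

def worker(comm):
    while True:
        status = MPI.Status()
        job = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
        if status.Get_tag() == STOP_TAG:
            break
        # ... run the job here, then report back to the server ...
        comm.send(job, dest=0, tag=DONE_TAG)

if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    if comm.rank == 0:
        server(comm, jobs=list(range(8)))
    else:
        worker(comm)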