@@ -144,6 +144,17 @@ extern "C" {
144144///
145145/// Currently, Array objects can store only data until four dimensions
146146///
147+ /// ## Sharing Across Threads
148+ ///
/// While sharing an Array with other threads, there is no need to wrap
/// it in an Arc object unless only one such object is required to exist.
/// This is because ArrayFire's internal Array is reference counted in a
/// thread-safe manner. However, if you need to modify the Array object,
/// please wrap it in a Mutex or a read-write lock.
154+ ///
/// Examples of how to share an Array across threads are illustrated in our
/// [book](http://arrayfire.org/arrayfire-rust/book/multi-threading.html)
157+ ///
147158/// ### NOTE
148159///
149160/// All operators(traits) from std::ops module implemented for Array object
@@ -156,6 +167,11 @@ pub struct Array<T: HasAfEnum> {
156167 _marker : PhantomData < T > ,
157168}
158169
/// Enable safely moving Array objects across threads
// SAFETY: the struct-level docs state that ArrayFire's internal Array is
// reference counted in a thread-safe manner, so transferring ownership of an
// `Array<T>` handle to another thread introduces no unsynchronized access.
// NOTE(review): soundness ultimately depends on the ArrayFire C library's
// guarantees for its internal refcounting — confirm this holds for all
// backends (CPU/CUDA/OpenCL).
unsafe impl<T: HasAfEnum> Send for Array<T> {}

// SAFETY: `&Array<T>` exposes no interior mutability from this wrapper's side;
// mutation goes through `&mut Array<T>`, which Rust's borrow rules serialize.
// NOTE(review): concurrent reads of the same underlying af_array are presumed
// safe per ArrayFire's thread-safety model — verify against upstream docs.
unsafe impl<T: HasAfEnum> Sync for Array<T> {}
159175macro_rules! is_func {
160176 ( $doc_str: expr, $fn_name: ident, $ffi_fn: ident) => (
161177 #[ doc=$doc_str]
@@ -834,3 +850,236 @@ pub fn is_eval_manual() -> bool {
834850 ret_val > 0
835851 }
836852}
853+
#[cfg(test)]
mod tests {
    use super::super::array::print;
    use super::super::data::constant;
    use super::super::device::{info, set_device, sync};
    use crate::dim4;
    use std::sync::{mpsc, Arc, RwLock};
    use std::thread;

    // NOTE: the ANCHOR/ANCHOR_END comments below are consumed by mdBook to
    // embed these snippets in the multi-threading chapter; keep them intact.

    /// Moving an Array by value into another thread: no Arc/Mutex is needed
    /// because only the receiving thread touches the handle afterwards.
    #[test]
    fn thread_move_array() {
        // ANCHOR: move_array_to_thread
        set_device(0);
        info();
        let mut a = constant(1, dim4!(3, 3));

        let handle = thread::spawn(move || {
            //set_device to appropriate device id is required in each thread
            set_device(0);

            println!("\n From thread {:?}", thread::current().id());

            a += constant(2, dim4!(3, 3));
            print(&a);
        });

        //Need to join other threads as main thread holds arrayfire context
        handle.join().unwrap();
        // ANCHOR_END: move_array_to_thread
    }

    /// Same pattern with an immutable Array: the spawned thread only reads
    /// (prints) the Array it took ownership of.
    #[test]
    fn thread_borrow_array() {
        set_device(0);
        info();
        let a = constant(1i32, dim4!(3, 3));

        let handle = thread::spawn(move || {
            set_device(0); //set_device to appropriate device id is required in each thread
            println!("\n From thread {:?}", thread::current().id());
            print(&a);
        });
        //Need to join other threads as main thread holds arrayfire context
        handle.join().unwrap();
    }

    // ANCHOR: multiple_threads_enum_def
    #[derive(Debug, Copy, Clone)]
    enum Op {
        Add,
        Sub,
        Div,
        Mul,
    }
    // ANCHOR_END: multiple_threads_enum_def

    /// Multiple threads reading the same Arrays concurrently: clones of `a`
    /// and `b` are handed to each worker, which performs one arithmetic op.
    #[test]
    fn read_from_multiple_threads() {
        // ANCHOR: read_from_multiple_threads
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

        // Set active GPU/device on main thread on which
        // subsequent Array objects are created
        set_device(0);

        // ArrayFire Array's are internally maintained via atomic reference counting
        // Thus, they need no Arc wrapping while moving to another thread.
        // Just call clone method on the object and share the resulting clone object
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));

        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                let x = a.clone();
                let y = b.clone();
                thread::spawn(move || {
                    set_device(0); //Both of objects are created on device 0 earlier
                    match op {
                        Op::Add => {
                            let _c = x + y;
                        }
                        Op::Sub => {
                            let _c = x - y;
                        }
                        Op::Div => {
                            let _c = x / y;
                        }
                        Op::Mul => {
                            let _c = x * y;
                        }
                    }
                    sync(0);
                    thread::sleep(std::time::Duration::new(1, 0));
                })
            })
            .collect();
        for child in threads {
            // Propagate a panic from any worker: discarding the join Result
            // would let a failing worker pass the test silently.
            child.join().expect("worker thread panicked");
        }
        // ANCHOR_END: read_from_multiple_threads
    }

    /// Concurrent mutation of one shared Array: ownership moves into an
    /// `Arc<RwLock<..>>`; each worker takes the write lock and accumulates
    /// its op's result into the shared Array.
    #[test]
    fn access_using_rwlock() {
        // ANCHOR: access_using_rwlock
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

        // Set active GPU/device on main thread on which
        // subsequent Array objects are created
        set_device(0);

        let c = constant(0.0f32, dim4!(3, 3));
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));

        // Move ownership to RwLock and wrap in Arc since same object is to be modified
        let c_lock = Arc::new(RwLock::new(c));

        // a and b are internally reference counted by ArrayFire. Unless there
        // is prior known need that they may be modified, you can simply clone
        // the objects pass them to threads

        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                let x = a.clone();
                let y = b.clone();

                let wlock = c_lock.clone();
                thread::spawn(move || {
                    //Both of objects are created on device 0 in main thread
                    //Every thread needs to set the device that it is going to
                    //work on. Note that all Array objects must have been created
                    //on same device as of date this is written on.
                    set_device(0);
                    if let Ok(mut c_guard) = wlock.write() {
                        match op {
                            Op::Add => {
                                *c_guard += x + y;
                            }
                            Op::Sub => {
                                *c_guard += x - y;
                            }
                            Op::Div => {
                                *c_guard += x / y;
                            }
                            Op::Mul => {
                                *c_guard += x * y;
                            }
                        }
                    }
                })
            })
            .collect();

        for child in threads {
            // Propagate a panic from any worker instead of discarding it.
            child.join().expect("worker thread panicked");
        }

        //let read_guard = c_lock.read().unwrap();
        //af_print!("C after threads joined", *read_guard);
        //C after threads joined
        //[3 3 1 1]
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        // ANCHOR_END: access_using_rwlock
    }

    /// Fan-out/fan-in via a channel: each worker computes its op on clones of
    /// `a` and `b` and sends the resulting Array back; the main thread then
    /// accumulates the received results into `c`.
    #[test]
    fn accum_using_channel() {
        // ANCHOR: accum_using_channel
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
        let ops_len: usize = ops.len();

        // Set active GPU/device on main thread on which
        // subsequent Array objects are created
        set_device(0);

        let mut c = constant(0.0f32, dim4!(3, 3));
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));

        let (tx, rx) = mpsc::channel();

        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                // a and b are internally reference counted by ArrayFire. Unless there
                // is prior known need that they may be modified, you can simply clone
                // the objects pass them to threads
                let x = a.clone();
                let y = b.clone();

                let tx_clone = tx.clone();

                thread::spawn(move || {
                    //Both of objects are created on device 0 in main thread
                    //Every thread needs to set the device that it is going to
                    //work on. Note that all Array objects must have been created
                    //on same device as of date this is written on.
                    set_device(0);

                    let c = match op {
                        Op::Add => x + y,
                        Op::Sub => x - y,
                        Op::Div => x / y,
                        Op::Mul => x * y,
                    };
                    tx_clone.send(c).unwrap();
                })
            })
            .collect();

        // Exactly `ops_len` results arrive; receiving them all before joining
        // also guarantees every worker completed its send.
        for _ in 0..ops_len {
            c += rx.recv().unwrap();
        }

        //Need to join other threads as main thread holds arrayfire context
        for child in threads {
            // Propagate a panic from any worker instead of discarding it.
            child.join().expect("worker thread panicked");
        }

        //af_print!("C after accumulating results", &c);
        //[3 3 1 1]
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        // ANCHOR_END: accum_using_channel
    }
}
0 commit comments