@@ -29,7 +29,7 @@ import sys
2929
3030import  numpy as  np
3131
32- if  np.lib.NumpyVersion(np.__version__) >=  " 2.0.0a0 "  :
32+ if  np.lib.NumpyVersion(np.__version__) >=  " 2.0.0 "  :
3333    from  numpy._core._multiarray_tests import  internal_overlap
3434else :
3535    from  numpy.core._multiarray_tests import  internal_overlap
@@ -389,9 +389,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
389389    x_arr =  _process_arguments(x, n, axis, & axis_, & n_, & in_place, & xnd, 0 )
390390    x_type =  cnp.PyArray_TYPE(x_arr)
391391
392-     if  out is  not  None :
393-         in_place =  0 
394-     elif  x_type is  cnp.NPY_CFLOAT or  x_type is  cnp.NPY_CDOUBLE:
392+     if  x_type is  cnp.NPY_CFLOAT or  x_type is  cnp.NPY_CDOUBLE:
395393        #  we can operate in place if requested.
396394        if  in_place:
397395            if  not  cnp.PyArray_ISONESEGMENT(x_arr):
@@ -416,6 +414,29 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
416414        x_type =  cnp.PyArray_TYPE(x_arr)
417415        in_place =  1 
418416
417+     f_arr =  None 
418+     if  x_type is  cnp.NPY_FLOAT or  x_type is  cnp.NPY_CFLOAT:
419+         f_type =  cnp.NPY_CFLOAT
420+     else :
421+         f_type =  cnp.NPY_CDOUBLE
422+ 
423+     if  out is  not  None :
424+         out_dtype =  np.dtype(cnp.PyArray_DescrFromType(f_type))
425+         _validate_out_array(out, x, out_dtype, axis = axis_, n = n_)
426+         if  x is  out:
427+             in_place =  1 
428+         elif  (
429+             _get_element_strides(x) ==  _get_element_strides(out)
430+             and  not  np.shares_memory(x, out)
431+         ):
432+             #  out array that is used in OneMKL c2c FFT must have the same stride
433+             #  as input array and must have no common elements with input array.
434+             #  If these conditions are not met, we need to allocate a new array,
435+             #  which is done later.
436+             #  TODO: check to see if the same stride condition can be relaxed
437+             f_arr =  < cnp.ndarray>  out
438+             in_place =  0 
439+ 
419440    if  in_place:
420441        _cache_capsule =  _tls_dfti_cache_capsule()
421442        _cache =  < DftiCache * > cpython.pycapsule.PyCapsule_GetPointer(
@@ -453,25 +474,14 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
453474            ind[axis_] =  slice (0 , n_, None )
454475            x_arr =  x_arr[tuple (ind)]
455476
456-         return  x_arr
457-     else :
458-         if  x_type is  cnp.NPY_FLOAT or  x_type is  cnp.NPY_CFLOAT:
459-             f_type =  cnp.NPY_CFLOAT
477+         if  out is  not  None :
478+             out[...] =  x_arr
479+             return  out
460480        else :
461-             f_type  =  cnp.NPY_CDOUBLE 
462- 
463-         if  out  is  None :
481+             return  x_arr 
482+      else : 
483+         if  f_arr  is  None :
464484            f_arr =  _allocate_result(x_arr, n_, axis_, f_type)
465-         else :
466-             out_dtype =  np.dtype(cnp.PyArray_DescrFromType(f_type))
467-             _validate_out_array(out, x, out_dtype, axis = axis_, n = n_)
468-             #  out array that is used in OneMKL c2c FFT must have the exact same
469-             #  stride as input array. If not, we need to allocate a new array.
470-             #  TODO: check to see if this condition can be relaxed
471-             if  _get_element_strides(x) ==  _get_element_strides(out):
472-                 f_arr =  < cnp.ndarray>  out
473-             else :
474-                 f_arr =  _allocate_result(x_arr, n_, axis_, f_type)
475485
476486        #  call out-of-place FFT
477487        _cache_capsule =  _tls_dfti_cache_capsule()
@@ -612,9 +622,10 @@ def _r2c_fft1d_impl(
612622        #  be compared directly.
613623        #  TODO: currently instead of this condition, we check both input
614624        #  and output to be c_contig or f_contig, relax this condition
625+         #  In addition, for either layout, input and output must have no common elements
615626        c_contig =  x.flags.c_contiguous and  out.flags.c_contiguous
616627        f_contig =  x.flags.f_contiguous and  out.flags.f_contiguous
617-         if  c_contig or  f_contig:
628+         if  (c_contig or  f_contig) and  not  np.shares_memory(x, out):
618629            f_arr =  < cnp.ndarray>  out
619630        else :
620631            f_arr =  _allocate_result(x_arr, f_shape, axis_, f_type)
@@ -715,9 +726,10 @@ def _c2r_fft1d_impl(
715726            #  strides cannot be compared directly.
716727            #  TODO: currently instead of this condition, we check both input
717728            #  and output to be c_contig or f_contig, relax this condition
729+             #  Also, for either layout, input and output must have no common elements
718730            c_contig =  x.flags.c_contiguous and  out.flags.c_contiguous
719731            f_contig =  x.flags.f_contiguous and  out.flags.f_contiguous
720-             if  c_contig or  f_contig:
732+             if  (c_contig or  f_contig) and  not  np.shares_memory(x, out):
721733                f_arr =  < cnp.ndarray>  out
722734            else :
723735                f_arr =  _allocate_result(x_arr, n_, axis_, f_type)
@@ -755,13 +767,13 @@ def _c2r_fft1d_impl(
755767
756768
757769def  _direct_fftnd (
758-     x , direction = + 1 , double fsc = 1.0 , out = None 
770+     x , direction = + 1 , double fsc = 1.0 , in_place = 0 ,  out = None 
759771):
760772    """ Perform n-dimensional FFT over all axes""" 
761773    cdef int  err
762774    cdef cnp.ndarray x_arr " xxnd_arrayObject" 
763775    cdef cnp.ndarray f_arr " ffnd_arrayObject" 
764-     cdef int  in_place,  x_type, f_type
776+     cdef int  x_type, f_type
765777
766778    if  direction not  in  [- 1 , + 1 ]:
767779        raise  ValueError (" Direction of FFT should +1 or -1"  )
@@ -779,7 +791,7 @@ def _direct_fftnd(
779791        raise  ValueError (" An input argument x is not an array-like object"  )
780792
781793    #  a copy was made, so we can work in place.
782-     in_place =  1  if  _datacopied(x_arr, x) else  0 
794+     in_place =  1  if  _datacopied(x_arr, x) else  in_place 
783795
784796    x_type =  cnp.PyArray_TYPE(x_arr)
785797    if  (
@@ -798,15 +810,35 @@ def _direct_fftnd(
798810        assert  x_type ==  cnp.NPY_CDOUBLE
799811        in_place =  1 
800812
801-     if  out is  not  None :
802-         in_place =  0 
803- 
804813    if  in_place:
805814        if  x_type ==  cnp.NPY_CDOUBLE or  x_type ==  cnp.NPY_CFLOAT:
806815            in_place =  1 
807816        else :
808817            in_place =  0 
809818
819+     f_arr =  None 
820+     if  x_type ==  cnp.NPY_CDOUBLE or  x_type ==  cnp.NPY_DOUBLE:
821+         f_type =  cnp.NPY_CDOUBLE
822+     else :
823+         f_type =  cnp.NPY_CFLOAT
824+ 
825+     if  out is  not  None :
826+         out_dtype =  np.dtype(cnp.PyArray_DescrFromType(f_type))
827+         _validate_out_array(out, x, out_dtype)
828+         if  x is  out:
829+             in_place =  1 
830+         elif  (
831+             _get_element_strides(x) ==  _get_element_strides(out)
832+             and  not  np.shares_memory(x, out)
833+         ):
834+             #  out array that is used in OneMKL c2c FFT must have the same stride
835+             #  as input array and must have no common elements with input array.
836+             #  If these conditions are not met, we need to allocate a new array,
837+             #  which is done later.
838+             #  TODO: check to see if the same stride condition can be relaxed
839+             f_arr =  < cnp.ndarray>  out
840+             in_place =  0 
841+ 
810842    if  in_place:
811843        if  x_type ==  cnp.NPY_CDOUBLE:
812844            if  direction ==  1 :
@@ -821,24 +853,14 @@ def _direct_fftnd(
821853        else :
822854            raise  ValueError (" An input argument x is not complex type array"  )
823855
824-         return  x_arr
825-     else :
826-         if  x_type ==  cnp.NPY_CDOUBLE or  x_type ==  cnp.NPY_DOUBLE:
827-             f_type =  cnp.NPY_CDOUBLE
856+         if  out is  not  None :
857+             out[...] =  x_arr
858+             return  out
828859        else :
829-             f_type =  cnp.NPY_CFLOAT
830-         if  out is  None :
860+             return  x_arr
861+     else :
862+         if  f_arr is  None :
831863            f_arr =  _allocate_result(x_arr, - 1 , 0 , f_type)
832-         else :
833-             out_dtype =  np.dtype(cnp.PyArray_DescrFromType(f_type))
834-             _validate_out_array(out, x, out_dtype)
835-             #  out array that is used in OneMKL c2c FFT must have the exact same
836-             #  stride as input array. If not, we need to allocate a new array.
837-             #  TODO: check to see if this condition can be relaxed
838-             if  _get_element_strides(x) ==  _get_element_strides(out):
839-                 f_arr =  < cnp.ndarray>  out
840-             else :
841-                 f_arr =  _allocate_result(x_arr, - 1 , 0 , f_type)
842864
843865        if  x_type ==  cnp.NPY_CDOUBLE:
844866            if  direction ==  1 :
0 commit comments