Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions fbgemm_gpu/src/memory_utils/memory_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,19 @@ std::tuple<void*, size_t> adjust_to_page_boundaries(void* ptr, size_t size) {
return std::make_tuple((void*)raw_ptr_adjusted, (size_t)size_adjusted);
}

// Platform-neutral alias for the memory-location descriptor used by the
// CUDA 13+ memory-advise APIs: starting with CUDA 13, cudaMemAdvise and
// cudaMemPrefetchAsync take a cudaMemLocation struct instead of a plain
// device-id int. The alias is guarded explicitly (rather than relying on
// hipify translation) so it resolves on both toolchains.
#ifdef USE_ROCM
using gpuMemLocation = hipMemLocation;
#else
using gpuMemLocation = cudaMemLocation;
#endif

// Builds a gpuMemLocation describing the given device id for the
// struct-based (CUDA 13+ / HIP) mem-advise and prefetch APIs.
//
// NOTE(review): callers also pass cudaCpuDeviceId here with the Device
// location type — confirm the runtime accepts that sentinel rather than
// requiring cudaMemLocationTypeHost for host-memory hints.
inline gpuMemLocation new_mem_location_from_device(const int device_id) {
  gpuMemLocation deviceLoc;
#ifdef USE_ROCM
  // Use the HIP enumerator directly, consistent with the hand-written
  // alias above, so the ROCm build does not depend on hipify mapping the
  // CUDA enumerator name.
  deviceLoc.type = hipMemLocationTypeDevice;
#else
  deviceLoc.type = cudaMemLocationTypeDevice;
#endif
  deviceLoc.id = device_id;
  return deviceLoc;
}

} // namespace

Tensor new_managed_tensor(
Expand All @@ -158,11 +171,31 @@ Tensor new_managed_tensor(

// Set preferred memory location to host memory
AT_CUDA_CHECK(cudaMemAdvise(
ptr, size_bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId));
ptr,
size_bytes,
cudaMemAdviseSetPreferredLocation,
#if CUDA_VERSION >= 13000
// Starting with CUDA 13, the deviceId arg (int) is replaced with
// cudaMemLocation (struct)
new_mem_location_from_device(cudaCpuDeviceId)
#else
cudaCpuDeviceId
#endif
));

// User hints with "accessed by": GPU will establish direct mapping of data
// in CPU memory, no page faults will be generated
AT_CUDA_CHECK(cudaMemAdvise(
ptr, size_bytes, cudaMemAdviseSetAccessedBy, at::cuda::current_device()));
ptr,
size_bytes,
cudaMemAdviseSetAccessedBy,
#if CUDA_VERSION >= 13000
new_mem_location_from_device(at::cuda::current_device())
#else
at::cuda::current_device()
#endif
));

C10_CUDA_KERNEL_LAUNCH_CHECK();

// Work around fork issue - see uvm_mem_advice_dont_fork for details
Expand Down Expand Up @@ -353,7 +386,12 @@ void uvm_cuda_mem_advise(const Tensor& t, int64_t cuda_memory_advise) {
ptr,
size_bytes,
static_cast<enum cudaMemoryAdvise>(cuda_memory_advise),
hint_device));
#if CUDA_VERSION >= 13000
new_mem_location_from_device(hint_device)
#else
hint_device
#endif
));
return;
}

Expand All @@ -379,7 +417,18 @@ void uvm_cuda_mem_prefetch_async(

auto stream = at::cuda::getCurrentCUDAStream();

AT_CUDA_CHECK(cudaMemPrefetchAsync(ptr, size_bytes, prefetch_device, stream));
AT_CUDA_CHECK(cudaMemPrefetchAsync(
ptr,
size_bytes,
#if CUDA_VERSION >= 13000
new_mem_location_from_device(prefetch_device),
// Flags argument needs to be set to zero for now, see:
// https://docs.nvidia.com/cuda/archive/13.0.0/cuda-runtime-api/group__CUDART__MEMORY.html
0,
#else
prefetch_device,
#endif
stream));

return;
}
Expand Down
6 changes: 6 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_async_batched_cumsum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ __global__ __launch_bounds__(kMaxThreads) void _batched_complete_cumsum_kernel(
data = (val_t)values[blockIdx.x][i];
}
BlockScan(temp_storage).InclusiveSum(data, data, prefix_op);

#if CUDA_VERSION >= 13000
__syncthreads();
#else
cub::CTA_SYNC();
#endif

if (i < len) {
out[blockIdx.x][i + 1] = data;
}
Expand Down
Loading