From a1e74adf3526c477c7db71c17402c19bab1c1e7c Mon Sep 17 00:00:00 2001 From: interestingLSY Date: Thu, 6 Jun 2024 17:44:05 +0800 Subject: [PATCH 1/7] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 1e1c4ca8..749e6d4d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ build/ diff_gaussian_rasterization.egg-info/ dist/ +__pycache__/ +_C.cpython* From 3ec0567bd8b064a56e9aadbb45400414d10e2eff Mon Sep 17 00:00:00 2001 From: interestingLSY Date: Thu, 6 Jun 2024 17:52:14 +0800 Subject: [PATCH 2/7] Use rsqrt() instead of 1/sqrt for speed and precision --- cuda_rasterizer/auxiliary.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index 4d4b9b78..d7b1a6e6 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -99,7 +99,7 @@ __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, cons __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) { float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; - float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + float invsum32 = rsqrtf(sum2 * sum2 * sum2); float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; return dnormvdz; } @@ -107,7 +107,7 @@ __forceinline__ __device__ float dnormvdz(float3 v, float3 dv) __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) { float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; - float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + float invsum32 = rsqrtf(sum2 * sum2 * sum2); float3 dnormvdv; dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; @@ -119,7 +119,7 @@ __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) { float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; - float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + float invsum32 = rsqrtf(sum2 * sum2 * sum2); float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; From 53f1e86843173d15d4e98e29c1fa08adfe0018f4 Mon Sep 17 00:00:00 2001 From: interestingLSY Date: Fri, 7 Jun 2024 23:11:20 +0800 Subject: [PATCH 3/7] Optimize backward propagation --- cuda_rasterizer/auxiliary.h | 3 +- cuda_rasterizer/backward.cu | 364 ++++++++++++++++++++++------- cuda_rasterizer/backward.h | 16 +- cuda_rasterizer/forward.cu | 6 +- cuda_rasterizer/forward.h | 2 +- cuda_rasterizer/rasterizer_impl.cu | 22 +- cuda_rasterizer/rasterizer_impl.h | 9 +- 7 files changed, 328 insertions(+), 94 deletions(-) diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index d7b1a6e6..33620000 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -16,7 +16,8 @@ #include "stdio.h" #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) -#define NUM_WARPS (BLOCK_SIZE/32) +#define WARP_SIZE 32 +#define NUM_WARPS (BLOCK_SIZE/WARP_SIZE) // Spherical harmonics coefficients __device__ const float SH_C0 = 0.28209479177387814f; diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 4aa41e1c..352571c8 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -395,12 +395,49 @@ __global__ void preprocessCUDA( computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot); } + +template +__forceinline__ __device__ T warpReduceSum(T val, const cg::coalesced_group &group) { + #pragma unroll + for (int offset = 32 
/ 2; offset > 0; offset /= 2) + val += group.shfl_down(val, offset); + return val; +} + + +template +__forceinline__ __device__ T warpReduceSum(T val) { + #pragma unroll + for (int offset = 32 / 2; offset > 0; offset /= 2) + val += __shfl_xor_sync(0xFFFFFFFF, val, offset); + return val; +} +__forceinline__ __device__ float2 warpReduceSum(float2 val) { + #pragma unroll + for (int offset = 32 / 2; offset > 0; offset /= 2) { + val.x += __shfl_xor_sync(0xFFFFFFFF, val.x, offset); + val.y += __shfl_xor_sync(0xFFFFFFFF, val.y, offset); + } + return val; +} +__forceinline__ __device__ float4 warpReduceSum(float4 val) { + #pragma unroll + for (int offset = 32 / 2; offset > 0; offset /= 2) { + val.x += __shfl_xor_sync(0xFFFFFFFF, val.x, offset); + val.y += __shfl_xor_sync(0xFFFFFFFF, val.y, offset); + val.z += __shfl_xor_sync(0xFFFFFFFF, val.z, offset); + val.w += __shfl_xor_sync(0xFFFFFFFF, val.w, offset); + } + return val; +} + + // Backward version of the rendering procedure. template __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA( const uint2* __restrict__ ranges, - const uint32_t* __restrict__ point_list, + const uint64_t* __restrict__ point_list, int W, int H, const float* __restrict__ bg_color, const float2* __restrict__ points_xy_image, @@ -409,10 +446,9 @@ renderCUDA( const float* __restrict__ final_Ts, const uint32_t* __restrict__ n_contrib, const float* __restrict__ dL_dpixels, - float3* __restrict__ dL_dmean2D, - float4* __restrict__ dL_dconic2D, - float* __restrict__ dL_dopacity, - float* __restrict__ dL_dcolors) + float* __restrict__ dL_dcolors, + float2* __restrict__ dL_dmean2D, + float4* __restrict__ dL_dconic2D_dopacity) { // We rasterize again. Compute necessary block info. auto block = cg::this_thread_block(); @@ -428,10 +464,9 @@ renderCUDA( const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); - bool done = !inside; int toDo = range.y - range.x; - __shared__ int collected_id[BLOCK_SIZE]; + __shared__ int collected_offset[BLOCK_SIZE]; // Offsets of instances before sorting __shared__ float2 collected_xy[BLOCK_SIZE]; __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; __shared__ float collected_colors[C * BLOCK_SIZE]; @@ -469,8 +504,10 @@ renderCUDA( const int progress = i * BLOCK_SIZE + block.thread_rank(); if (range.x + progress < range.y) { - const int coll_id = point_list[range.y - progress - 1]; - collected_id[block.thread_rank()] = coll_id; + const uint64_t coll_id_and_offset = point_list[range.y - progress - 1]; + const int coll_id = coll_id_and_offset>>32; + const int offset_before_sorting = coll_id_and_offset & 0xFFFFFFFF; + collected_offset[block.thread_rank()] = offset_before_sorting; collected_xy[block.thread_rank()] = points_xy_image[coll_id]; collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; for (int i = 0; i < C; i++) @@ -478,81 +515,225 @@ renderCUDA( } block.sync(); + static constexpr int REDUCTION_BATCH_SIZE = 8; + __shared__ float batch_dL_dcolors[REDUCTION_BATCH_SIZE][NUM_WARPS][C]; + __shared__ float2 batch_dL_dmean2D[REDUCTION_BATCH_SIZE][NUM_WARPS]; + __shared__ float4 batch_dL_dconic2D_dopacity[REDUCTION_BATCH_SIZE][NUM_WARPS]; + // Iterate over Gaussians - for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) + for (int j = 0; j < min(BLOCK_SIZE, toDo); j++) { // Keep track of current Gaussian ID. Skip, if this one // is behind the last contributor for this pixel. 
+ float cur_dL_dcolors[C] = {0}; + float2 cur_dL_dmean2D = {0, 0}; + float4 cur_dL_dconic2D_dopacity = {0, 0, 0, 0}; + contributor--; - if (contributor >= last_contributor) - continue; - - // Compute blending values, as before. - const float2 xy = collected_xy[j]; - const float2 d = { xy.x - pixf.x, xy.y - pixf.y }; - const float4 con_o = collected_conic_opacity[j]; - const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; - if (power > 0.0f) - continue; - - const float G = exp(power); - const float alpha = min(0.99f, con_o.w * G); - if (alpha < 1.0f / 255.0f) - continue; - - T = T / (1.f - alpha); - const float dchannel_dcolor = alpha * T; - - // Propagate gradients to per-Gaussian colors and keep - // gradients w.r.t. alpha (blending factor for a Gaussian/pixel - // pair). - float dL_dalpha = 0.0f; - const int global_id = collected_id[j]; - for (int ch = 0; ch < C; ch++) - { - const float c = collected_colors[ch * BLOCK_SIZE + j]; - // Update last color (to be used in the next iteration) - accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch]; - last_color[ch] = c; - - const float dL_dchannel = dL_dpixel[ch]; - dL_dalpha += (c - accum_rec[ch]) * dL_dchannel; - // Update the gradients w.r.t. color of the Gaussian. - // Atomic, since this pixel is just one of potentially - // many that were affected by this Gaussian. - atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel); + if (inside && contributor < last_contributor) { + // Compute blending values, as before. + const float2 xy = collected_xy[j]; + const float2 d = { xy.x - pixf.x, xy.y - pixf.y }; + const float4 con_o = collected_conic_opacity[j]; + const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y; + + const float G = exp(power); + const float alpha = min(0.99f, con_o.w * G); + if (power <= 0.0f && alpha >= 1.0f / 255.0f) { + T = T / (1.f - alpha); + const float dchannel_dcolor = alpha * T; + + // Propagate gradients to per-Gaussian colors and keep + // gradients w.r.t. alpha (blending factor for a Gaussian/pixel + // pair). + float dL_dalpha = 0.0f; + #pragma unroll + for (int ch = 0; ch < C; ch++) + { + const float c = collected_colors[ch * BLOCK_SIZE + j]; + // Update last color (to be used in the next iteration) + accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch]; + last_color[ch] = c; + + const float dL_dchannel = dL_dpixel[ch]; + dL_dalpha += (c - accum_rec[ch]) * dL_dchannel; + // Update the gradients w.r.t. color of the Gaussian. + // Atomic, since this pixel is just one of potentially + // many that were affected by this Gaussian. + cur_dL_dcolors[ch] = dchannel_dcolor * dL_dchannel; + } + dL_dalpha *= T; + // Update last alpha (to be used in the next iteration) + last_alpha = alpha; + + // Account for fact that alpha also influences how much of + // the background color is added if nothing left to blend + float bg_dot_dpixel = 0; + #pragma unroll + for (int i = 0; i < C; i++) + bg_dot_dpixel += bg_color[i] * dL_dpixel[i]; + dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel; + + // Helpful reusable temporary variables + const float dL_dG = con_o.w * dL_dalpha; + const float gdx = G * d.x; + const float gdy = G * d.y; + const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y; + const float dG_ddely = -gdy * con_o.z - gdx * con_o.y; + + // Update gradients w.r.t. 2D mean position of the Gaussian + // Update gradients w.r.t. 
2D covariance (2x2 matrix, symmetric) + // Update gradients w.r.t. opacity of the Gaussian + cur_dL_dmean2D = { + dL_dG * dG_ddelx * ddelx_dx, + dL_dG * dG_ddely * ddely_dy + }; + cur_dL_dconic2D_dopacity = { + -0.5f * gdx * d.x * dL_dG, + -0.5f * gdx * d.y * dL_dG, + G * dL_dalpha, + -0.5f * gdy * d.y * dL_dG + }; + } } - dL_dalpha *= T; - // Update last alpha (to be used in the next iteration) - last_alpha = alpha; - // Account for fact that alpha also influences how much of - // the background color is added if nothing left to blend - float bg_dot_dpixel = 0; - for (int i = 0; i < C; i++) - bg_dot_dpixel += bg_color[i] * dL_dpixel[i]; - dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel; + // Perform warp-level reduction + #pragma unroll + for (int ch = 0; ch < C; ch++) + cur_dL_dcolors[ch] = warpReduceSum(cur_dL_dcolors[ch]); + cur_dL_dmean2D = warpReduceSum(cur_dL_dmean2D); + cur_dL_dconic2D_dopacity = warpReduceSum(cur_dL_dconic2D_dopacity); + // Store the results in shared memory + if (block.thread_rank() % WARP_SIZE == 0) + { + int warp_id = block.thread_rank() / WARP_SIZE; + #pragma unroll + for (int ch = 0; ch < C; ch++) + batch_dL_dcolors[j%REDUCTION_BATCH_SIZE][warp_id][ch] = cur_dL_dcolors[ch]; + batch_dL_dmean2D[j%REDUCTION_BATCH_SIZE][warp_id] = cur_dL_dmean2D; + batch_dL_dconic2D_dopacity[j%REDUCTION_BATCH_SIZE][warp_id] = cur_dL_dconic2D_dopacity; + } - // Helpful reusable temporary variables - const float dL_dG = con_o.w * dL_dalpha; - const float gdx = G * d.x; - const float gdy = G * d.y; - const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y; - const float dG_ddely = -gdy * con_o.z - gdx * con_o.y; + // If this is the last Gaussian in the batch, perform block-level + // reduction and store the results in global memory. + if (j % REDUCTION_BATCH_SIZE == REDUCTION_BATCH_SIZE - 1 || j == min(BLOCK_SIZE, toDo) - 1) + { + // Make sure we can perform this reduction with one warp + static_assert(NUM_WARPS <= WARP_SIZE); + // Make sure there are enough warps if we assign each warp with + // an item in the batch + static_assert(NUM_WARPS >= REDUCTION_BATCH_SIZE); + // Make sure the number of warps is a power of 2 + // static_assert((NUM_WARPS & (NUM_WARPS - 1)) == 0); + + // Wait for all warps to finish + block.sync(); + + int batch_start_j = j / REDUCTION_BATCH_SIZE * REDUCTION_BATCH_SIZE; + int batch_id = block.thread_rank() / WARP_SIZE; + if (batch_id < REDUCTION_BATCH_SIZE && + batch_start_j+batch_id < min(BLOCK_SIZE, toDo)) { + int lane_id = block.thread_rank() % WARP_SIZE; + + // Perform warp-level reduction + #pragma unroll + for (int ch = 0; ch < C; ch++) + cur_dL_dcolors[ch] = warpReduceSum( + lane_id < NUM_WARPS ? batch_dL_dcolors[batch_id][lane_id][ch] : 0 + ); + cur_dL_dmean2D = warpReduceSum( + lane_id < NUM_WARPS ? + batch_dL_dmean2D[batch_id][lane_id] : + float2{0, 0} + ); + cur_dL_dconic2D_dopacity = warpReduceSum( + lane_id < NUM_WARPS ? 
+ batch_dL_dconic2D_dopacity[batch_id][lane_id] : + float4{0, 0, 0, 0} + ); + + // Store the results in global memory + if (lane_id == 0) + { + const int global_offset = collected_offset[batch_start_j+batch_id]; + if constexpr(C == 3) { + // Special optimization for C == 3 + ((float3*)dL_dcolors)[global_offset] = make_float3(cur_dL_dcolors[0], cur_dL_dcolors[1], cur_dL_dcolors[2]); + } else { + #pragma unroll + for (int ch = 0; ch < C; ch++) + dL_dcolors[global_offset * C + ch] = cur_dL_dcolors[ch]; + } + dL_dmean2D[global_offset] = cur_dL_dmean2D; + dL_dconic2D_dopacity[global_offset] = cur_dL_dconic2D_dopacity; + } + } + + // Wait for all warps to finish reducing + if (j != min(BLOCK_SIZE, toDo) - 1) + block.sync(); + } + } + } +} - // Update gradients w.r.t. 2D mean position of the Gaussian - atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx); - atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy); +template +__global__ void gather_gradientsCUDA( + int P, + const uint32_t* __restrict__ point_offsets, + const float* __restrict__ dL_dcolors_bin, + const float2* __restrict__ dL_dmean2D_bin, + const float4* __restrict__ dL_dconic2D_dopacity_bin, + float* __restrict__ dL_dcolors, + float3* __restrict__ dL_dmean2D, + float4* __restrict__ dL_dconic2D, + float* __restrict__ dL_dopacity) +{ + // Every warp is responsible for one Gaussian + int gaussian_id = cg::this_grid().thread_rank() / WARP_SIZE; + if (gaussian_id >= P) + return; - // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric) - atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG); - atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG); - atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG); + int lane_id = cg::this_thread_block().thread_rank() % WARP_SIZE; + int range_start = gaussian_id == 0 ? 0 : point_offsets[gaussian_id-1]; + int range_end = point_offsets[gaussian_id]; + + float dL_dcolors_sum[C] = { 0 }; + float2 dL_dmean2D_sum = { 0, 0 }; + float4 dL_dconic2D_dopacity_sum = { 0, 0, 0, 0 }; + + #pragma unroll 2 + for (int i = range_start + lane_id; i < range_end; i += WARP_SIZE) { + #pragma unroll + for (int ch = 0; ch < C; ch++) + dL_dcolors_sum[ch] += dL_dcolors_bin[i * C + ch]; + float2 cur_mean2d = dL_dmean2D_bin[i]; + dL_dmean2D_sum.x += cur_mean2d.x; + dL_dmean2D_sum.y += cur_mean2d.y; + float4 cur_conic2d_dopacity = dL_dconic2D_dopacity_bin[i]; + dL_dconic2D_dopacity_sum.x += cur_conic2d_dopacity.x; + dL_dconic2D_dopacity_sum.y += cur_conic2d_dopacity.y; + dL_dconic2D_dopacity_sum.z += cur_conic2d_dopacity.z; + dL_dconic2D_dopacity_sum.w += cur_conic2d_dopacity.w; + } - // Update gradients w.r.t. 
opacity of the Gaussian - atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha); - } + // Warp-level reduction + #pragma unroll + for (int ch = 0; ch < C; ch++) + dL_dcolors_sum[ch] = warpReduceSum(dL_dcolors_sum[ch]); + dL_dmean2D_sum = warpReduceSum(dL_dmean2D_sum); + dL_dconic2D_dopacity_sum = warpReduceSum(dL_dconic2D_dopacity_sum); + + // Write-back + if (lane_id == 0) { + #pragma unroll + for (int ch = 0; ch < C; ch++) + dL_dcolors[gaussian_id * C + ch] = dL_dcolors_sum[ch]; + dL_dmean2D[gaussian_id].x = dL_dmean2D_sum.x; + dL_dmean2D[gaussian_id].y = dL_dmean2D_sum.y; + dL_dconic2D[gaussian_id] = dL_dconic2D_dopacity_sum; + dL_dopacity[gaussian_id] = dL_dconic2D_dopacity_sum.z; } } @@ -621,10 +802,35 @@ void BACKWARD::preprocess( dL_drot); } +void BACKWARD::gather_gradients( + int P, + const uint32_t* point_offsets, + const float* dL_dcolors_bin, + const float2* dL_dmean2D_bin, + const float4* dL_dconic2D_dopacity_bin, + float* dL_dcolors, + float3* dL_dmean2D, + float4* dL_dconic2D, + float* dL_dopacity) +{ + static constexpr int N_WARPS = 16; + int num_blocks = (P + N_WARPS - 1) / N_WARPS; + gather_gradientsCUDA << > > ( + P, + point_offsets, + dL_dcolors_bin, + dL_dmean2D_bin, + dL_dconic2D_dopacity_bin, + dL_dcolors, + dL_dmean2D, + dL_dconic2D, + dL_dopacity); +} + void BACKWARD::render( const dim3 grid, const dim3 block, const uint2* ranges, - const uint32_t* point_list, + const uint64_t* point_list, int W, int H, const float* bg_color, const float2* means2D, @@ -633,10 +839,9 @@ void BACKWARD::render( const float* final_Ts, const uint32_t* n_contrib, const float* dL_dpixels, - float3* dL_dmean2D, - float4* dL_dconic2D, - float* dL_dopacity, - float* dL_dcolors) + float* dL_dcolors, + float2* dL_dmean2D, + float4* dL_dconic2D_dopacity) { renderCUDA << > >( ranges, @@ -649,9 +854,8 @@ void BACKWARD::render( final_Ts, n_contrib, dL_dpixels, + dL_dcolors, dL_dmean2D, - dL_dconic2D, - dL_dopacity, - dL_dcolors + dL_dconic2D_dopacity ); } \ No newline at end of file diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index 93dd2e4b..bc5d581a 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -23,7 +23,7 @@ namespace BACKWARD void render( const dim3 grid, dim3 block, const uint2* ranges, - const uint32_t* point_list, + const uint64_t* point_list, int W, int H, const float* bg_color, const float2* means2D, @@ -32,10 +32,20 @@ namespace BACKWARD const float* final_Ts, const uint32_t* n_contrib, const float* dL_dpixels, + float* dL_dcolors, + float2* dL_dmean2D, + float4* dL_dconic2D_dopacity); + + void gather_gradients( + int P, + const uint32_t* point_offsets, + const float* dL_dcolors_bin, + const float2* dL_dmean2D_bin, + const float4* dL_dconic2D_dopacity_bin, + float* dL_dcolors, float3* dL_dmean2D, float4* dL_dconic2D, - float* dL_dopacity, - float* dL_dcolors); + float* dL_dopacity); void preprocess( int P, int D, int M, diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index c419a328..1a7e98f2 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -262,7 +262,7 @@ template __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA( const uint2* __restrict__ ranges, - const uint32_t* __restrict__ point_list, + const uint64_t* __restrict__ point_list, int W, int H, const float2* __restrict__ points_xy_image, const float* __restrict__ features, @@ -314,7 +314,7 @@ renderCUDA( int progress = i * BLOCK_SIZE + block.thread_rank(); if (range.x + progress < range.y) { - int coll_id = 
point_list[range.x + progress]; + int coll_id = point_list[range.x + progress]>>32; collected_id[block.thread_rank()] = coll_id; collected_xy[block.thread_rank()] = points_xy_image[coll_id]; collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; @@ -376,7 +376,7 @@ renderCUDA( void FORWARD::render( const dim3 grid, dim3 block, const uint2* ranges, - const uint32_t* point_list, + const uint64_t* point_list, int W, int H, const float2* means2D, const float* colors, diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 3c11cb91..0caf6fb0 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -51,7 +51,7 @@ namespace FORWARD void render( const dim3 grid, dim3 block, const uint2* ranges, - const uint32_t* point_list, + const uint64_t* point_list, int W, int H, const float2* points_xy_image, const float* features, diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index f8782ac4..5176e3c8 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -73,7 +73,7 @@ __global__ void duplicateWithKeys( const float* depths, const uint32_t* offsets, uint64_t* gaussian_keys_unsorted, - uint32_t* gaussian_values_unsorted, + uint64_t* gaussian_values_unsorted, int* radii, dim3 grid) { @@ -103,7 +103,7 @@ __global__ void duplicateWithKeys( key <<= 32; key |= *((uint32_t*)&depths[idx]); gaussian_keys_unsorted[off] = key; - gaussian_values_unsorted[off] = idx; + gaussian_values_unsorted[off] = idx<<32 | off; off++; } } @@ -190,6 +190,9 @@ CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chun binning.point_list_keys_unsorted, binning.point_list_keys, binning.point_list_unsorted, binning.point_list, P); obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128); + obtain(chunk, binning.dL_dcolors, P * NUM_CHANNELS, 128); + obtain(chunk, binning.dL_dmean2D, P * 2, 128); + obtain(chunk, binning.dL_dconic2D_dopacity, P * 4, 128); return binning; } @@ -400,10 +403,21 @@ void CudaRasterizer::Rasterizer::backward( imgState.accum_alpha, imgState.n_contrib, dL_dpix, + binningState.dL_dcolors, + binningState.dL_dmean2D, + binningState.dL_dconic2D_dopacity + ), debug) + + CHECK_CUDA(BACKWARD::gather_gradients( + P, + geomState.point_offsets, + binningState.dL_dcolors, + binningState.dL_dmean2D, + binningState.dL_dconic2D_dopacity, + dL_dcolor, (float3*)dL_dmean2D, (float4*)dL_dconic, - dL_dopacity, - dL_dcolor), debug) + dL_dopacity), debug) // Take care of the rest of preprocessing. Was the precomputed covariance // given to us or a scales/rot pair? If precomputed, pass that. 
If not, diff --git a/cuda_rasterizer/rasterizer_impl.h b/cuda_rasterizer/rasterizer_impl.h index bc3f0ece..f9db59cf 100644 --- a/cuda_rasterizer/rasterizer_impl.h +++ b/cuda_rasterizer/rasterizer_impl.h @@ -57,10 +57,15 @@ namespace CudaRasterizer size_t sorting_size; uint64_t* point_list_keys_unsorted; uint64_t* point_list_keys; - uint32_t* point_list_unsorted; - uint32_t* point_list; + uint64_t* point_list_unsorted; // High 32 bits are the id of the gaussian, + // Low 32 bits are the index of the copy before sorting + uint64_t* point_list; char* list_sorting_space; + float* dL_dcolors; + float2* dL_dmean2D; + float4* dL_dconic2D_dopacity; + static BinningState fromChunk(char*& chunk, size_t P); }; From fbaec5cccddae9f7c345b4e4fd7b475e9f44a0a2 Mon Sep 17 00:00:00 2001 From: interestingLSY Date: Sat, 8 Jun 2024 15:45:20 +0800 Subject: [PATCH 4/7] Further optimization --- cuda_rasterizer/backward.cu | 218 ++++++++++++++++++----------- cuda_rasterizer/backward.h | 14 +- cuda_rasterizer/rasterizer_impl.cu | 10 +- 3 files changed, 156 insertions(+), 86 deletions(-) diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 352571c8..f6265eba 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -397,32 +397,23 @@ __global__ void preprocessCUDA( template -__forceinline__ __device__ T warpReduceSum(T val, const cg::coalesced_group &group) { +__forceinline__ __device__ T warpReduceSum(T val, int num_lanes = 32) { #pragma unroll - for (int offset = 32 / 2; offset > 0; offset /= 2) - val += group.shfl_down(val, offset); - return val; -} - - -template -__forceinline__ __device__ T warpReduceSum(T val) { - #pragma unroll - for (int offset = 32 / 2; offset > 0; offset /= 2) + for (int offset = num_lanes / 2; offset > 0; offset /= 2) val += __shfl_xor_sync(0xFFFFFFFF, val, offset); return val; } -__forceinline__ __device__ float2 warpReduceSum(float2 val) { +__forceinline__ __device__ float2 warpReduceSum(float2 val, int num_lanes = 32) { #pragma unroll - for (int offset = 32 / 2; offset > 0; offset /= 2) { + for (int offset = num_lanes / 2; offset > 0; offset /= 2) { val.x += __shfl_xor_sync(0xFFFFFFFF, val.x, offset); val.y += __shfl_xor_sync(0xFFFFFFFF, val.y, offset); } return val; } -__forceinline__ __device__ float4 warpReduceSum(float4 val) { +__forceinline__ __device__ float4 warpReduceSum(float4 val, int num_lanes = 32) { #pragma unroll - for (int offset = 32 / 2; offset > 0; offset /= 2) { + for (int offset = num_lanes / 2; offset > 0; offset /= 2) { val.x += __shfl_xor_sync(0xFFFFFFFF, val.x, offset); val.y += __shfl_xor_sync(0xFFFFFFFF, val.y, offset); val.z += __shfl_xor_sync(0xFFFFFFFF, val.z, offset); @@ -433,6 +424,7 @@ __forceinline__ __device__ float4 warpReduceSum(float4 val) { // Backward version of the rendering procedure. +#define USE_ATOMIC_THRESHOLD 6 template __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA( @@ -446,9 +438,14 @@ renderCUDA( const float* __restrict__ final_Ts, const uint32_t* __restrict__ n_contrib, const float* __restrict__ dL_dpixels, - float* __restrict__ dL_dcolors, - float2* __restrict__ dL_dmean2D, - float4* __restrict__ dL_dconic2D_dopacity) + const uint32_t* __restrict__ tiles_touched, + float* __restrict__ dL_dcolors_bin, + float2* __restrict__ dL_dmean2D_bin, + float4* __restrict__ dL_dconic2D_dopacity_bin, + float* __restrict__ dL_dcolors_global, + float3* __restrict__ dL_dmean2D_global, + float4* __restrict__ dL_dconic2D_global, + float* __restrict__ dL_dopacity_global) { // We rasterize again. 
Compute necessary block info. auto block = cg::this_thread_block(); @@ -466,7 +463,9 @@ renderCUDA( int toDo = range.y - range.x; - __shared__ int collected_offset[BLOCK_SIZE]; // Offsets of instances before sorting + __shared__ int collected_offset[BLOCK_SIZE]; // Offsets of instances before sorting when USE_ATOMIC_ADD is False + // Id of the gaussian when USE_ATOMIC_ADD is True + __shared__ bool collected_use_atomic[BLOCK_SIZE]; __shared__ float2 collected_xy[BLOCK_SIZE]; __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; __shared__ float collected_colors[C * BLOCK_SIZE]; @@ -507,7 +506,10 @@ renderCUDA( const uint64_t coll_id_and_offset = point_list[range.y - progress - 1]; const int coll_id = coll_id_and_offset>>32; const int offset_before_sorting = coll_id_and_offset & 0xFFFFFFFF; - collected_offset[block.thread_rank()] = offset_before_sorting; + const int cur_tiles_touched = tiles_touched[coll_id]; + bool cur_use_atomic = cur_tiles_touched <= USE_ATOMIC_THRESHOLD; + collected_use_atomic[block.thread_rank()] = cur_use_atomic; + collected_offset[block.thread_rank()] = cur_use_atomic ? coll_id : offset_before_sorting; collected_xy[block.thread_rank()] = points_xy_image[coll_id]; collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; for (int i = 0; i < C; i++) @@ -515,10 +517,15 @@ renderCUDA( } block.sync(); - static constexpr int REDUCTION_BATCH_SIZE = 8; + static constexpr int REDUCTION_BATCH_SIZE = 16; + int cur_reduction_batch_idx = 0; + __shared__ int batch_j[REDUCTION_BATCH_SIZE]; __shared__ float batch_dL_dcolors[REDUCTION_BATCH_SIZE][NUM_WARPS][C]; __shared__ float2 batch_dL_dmean2D[REDUCTION_BATCH_SIZE][NUM_WARPS]; __shared__ float4 batch_dL_dconic2D_dopacity[REDUCTION_BATCH_SIZE][NUM_WARPS]; + __shared__ float batch_reduced_dL_dcolors[REDUCTION_BATCH_SIZE][C]; + __shared__ float2 batch_reduced_dL_dmean2D[REDUCTION_BATCH_SIZE]; + __shared__ float4 batch_reduced_dL_dconic2D_dopacity[REDUCTION_BATCH_SIZE]; // Iterate over Gaussians for (int j = 0; j < min(BLOCK_SIZE, toDo); j++) @@ -528,6 +535,7 @@ renderCUDA( float cur_dL_dcolors[C] = {0}; float2 cur_dL_dmean2D = {0, 0}; float4 cur_dL_dconic2D_dopacity = {0, 0, 0, 0}; + const bool use_atomic = collected_use_atomic[j]; contributor--; if (inside && contributor < last_contributor) { @@ -547,6 +555,7 @@ renderCUDA( // gradients w.r.t. alpha (blending factor for a Gaussian/pixel // pair). float dL_dalpha = 0.0f; + const int global_id = collected_offset[j]; #pragma unroll for (int ch = 0; ch < C; ch++) { @@ -560,7 +569,11 @@ renderCUDA( // Update the gradients w.r.t. color of the Gaussian. // Atomic, since this pixel is just one of potentially // many that were affected by this Gaussian. - cur_dL_dcolors[ch] = dchannel_dcolor * dL_dchannel; + if (use_atomic) { + atomicAdd(&dL_dcolors_global[global_id*C + ch], dchannel_dcolor * dL_dchannel); + } else { + cur_dL_dcolors[ch] = dchannel_dcolor * dL_dchannel; + } } dL_dalpha *= T; // Update last alpha (to be used in the next iteration) @@ -584,95 +597,123 @@ renderCUDA( // Update gradients w.r.t. 2D mean position of the Gaussian // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric) // Update gradients w.r.t. 
opacity of the Gaussian - cur_dL_dmean2D = { - dL_dG * dG_ddelx * ddelx_dx, - dL_dG * dG_ddely * ddely_dy - }; - cur_dL_dconic2D_dopacity = { - -0.5f * gdx * d.x * dL_dG, - -0.5f * gdx * d.y * dL_dG, - G * dL_dalpha, - -0.5f * gdy * d.y * dL_dG - }; + if (use_atomic) { + atomicAdd(&dL_dmean2D_global[global_id].x, dL_dG * dG_ddelx * ddelx_dx); + atomicAdd(&dL_dmean2D_global[global_id].y, dL_dG * dG_ddely * ddely_dy); + atomicAdd(&dL_dconic2D_global[global_id].x, -0.5f * gdx * d.x * dL_dG); + atomicAdd(&dL_dconic2D_global[global_id].y, -0.5f * gdx * d.y * dL_dG); + atomicAdd(&dL_dconic2D_global[global_id].w, -0.5f * gdy * d.y * dL_dG); + atomicAdd(&dL_dopacity_global[global_id], G * dL_dalpha); + } else { + cur_dL_dmean2D = { + dL_dG * dG_ddelx * ddelx_dx, + dL_dG * dG_ddely * ddely_dy + }; + cur_dL_dconic2D_dopacity = { + -0.5f * gdx * d.x * dL_dG, + -0.5f * gdx * d.y * dL_dG, + G * dL_dalpha, + -0.5f * gdy * d.y * dL_dG + }; + } } } - // Perform warp-level reduction - #pragma unroll - for (int ch = 0; ch < C; ch++) - cur_dL_dcolors[ch] = warpReduceSum(cur_dL_dcolors[ch]); - cur_dL_dmean2D = warpReduceSum(cur_dL_dmean2D); - cur_dL_dconic2D_dopacity = warpReduceSum(cur_dL_dconic2D_dopacity); - - // Store the results in shared memory - if (block.thread_rank() % WARP_SIZE == 0) - { - int warp_id = block.thread_rank() / WARP_SIZE; + if (!use_atomic) { + // Perform warp-level reduction #pragma unroll - for (int ch = 0; ch < C; ch++) - batch_dL_dcolors[j%REDUCTION_BATCH_SIZE][warp_id][ch] = cur_dL_dcolors[ch]; - batch_dL_dmean2D[j%REDUCTION_BATCH_SIZE][warp_id] = cur_dL_dmean2D; - batch_dL_dconic2D_dopacity[j%REDUCTION_BATCH_SIZE][warp_id] = cur_dL_dconic2D_dopacity; + for (int offset = 32/2; offset > 0; offset /= 2) { + #pragma unroll + for (int ch = 0; ch < C; ch++) + cur_dL_dcolors[ch] += __shfl_down_sync(0xFFFFFFFF, cur_dL_dcolors[ch], offset); + cur_dL_dmean2D.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.x, offset); + cur_dL_dmean2D.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.y, offset); + cur_dL_dconic2D_dopacity.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.x, offset); + cur_dL_dconic2D_dopacity.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.y, offset); + cur_dL_dconic2D_dopacity.z += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.z, offset); + cur_dL_dconic2D_dopacity.w += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.w, offset); + } + + // Store the results in shared memory + if (block.thread_rank() % WARP_SIZE == 0) + { + int warp_id = block.thread_rank() / WARP_SIZE; + batch_j[cur_reduction_batch_idx] = j; + #pragma unroll + for (int ch = 0; ch < C; ch++) + batch_dL_dcolors[cur_reduction_batch_idx][warp_id][ch] = cur_dL_dcolors[ch]; + batch_dL_dmean2D[cur_reduction_batch_idx][warp_id] = cur_dL_dmean2D; + batch_dL_dconic2D_dopacity[cur_reduction_batch_idx][warp_id] = cur_dL_dconic2D_dopacity; + } + cur_reduction_batch_idx += 1; } // If this is the last Gaussian in the batch, perform block-level // reduction and store the results in global memory. 
- if (j % REDUCTION_BATCH_SIZE == REDUCTION_BATCH_SIZE - 1 || j == min(BLOCK_SIZE, toDo) - 1) + if (cur_reduction_batch_idx == REDUCTION_BATCH_SIZE || (j == min(BLOCK_SIZE, toDo) - 1 && cur_reduction_batch_idx != 0)) { // Make sure we can perform this reduction with one warp static_assert(NUM_WARPS <= WARP_SIZE); - // Make sure there are enough warps if we assign each warp with - // an item in the batch - static_assert(NUM_WARPS >= REDUCTION_BATCH_SIZE); // Make sure the number of warps is a power of 2 - // static_assert((NUM_WARPS & (NUM_WARPS - 1)) == 0); + static_assert((NUM_WARPS & (NUM_WARPS - 1)) == 0); - // Wait for all warps to finish + // Wait for all warps to finish storing block.sync(); - int batch_start_j = j / REDUCTION_BATCH_SIZE * REDUCTION_BATCH_SIZE; - int batch_id = block.thread_rank() / WARP_SIZE; - if (batch_id < REDUCTION_BATCH_SIZE && - batch_start_j+batch_id < min(BLOCK_SIZE, toDo)) { + for (int batch_id = block.thread_rank() / WARP_SIZE; batch_id < cur_reduction_batch_idx; batch_id += NUM_WARPS) { int lane_id = block.thread_rank() % WARP_SIZE; // Perform warp-level reduction #pragma unroll for (int ch = 0; ch < C; ch++) cur_dL_dcolors[ch] = warpReduceSum( - lane_id < NUM_WARPS ? batch_dL_dcolors[batch_id][lane_id][ch] : 0 + lane_id < NUM_WARPS ? batch_dL_dcolors[batch_id][lane_id][ch] : 0, + NUM_WARPS ); cur_dL_dmean2D = warpReduceSum( lane_id < NUM_WARPS ? batch_dL_dmean2D[batch_id][lane_id] : - float2{0, 0} + float2{0, 0}, + NUM_WARPS ); cur_dL_dconic2D_dopacity = warpReduceSum( lane_id < NUM_WARPS ? batch_dL_dconic2D_dopacity[batch_id][lane_id] : - float4{0, 0, 0, 0} + float4{0, 0, 0, 0}, + NUM_WARPS ); // Store the results in global memory if (lane_id == 0) { - const int global_offset = collected_offset[batch_start_j+batch_id]; - if constexpr(C == 3) { - // Special optimization for C == 3 - ((float3*)dL_dcolors)[global_offset] = make_float3(cur_dL_dcolors[0], cur_dL_dcolors[1], cur_dL_dcolors[2]); - } else { - #pragma unroll - for (int ch = 0; ch < C; ch++) - dL_dcolors[global_offset * C + ch] = cur_dL_dcolors[ch]; - } - dL_dmean2D[global_offset] = cur_dL_dmean2D; - dL_dconic2D_dopacity[global_offset] = cur_dL_dconic2D_dopacity; + #pragma unroll + for (int ch = 0; ch < C; ch++) + batch_reduced_dL_dcolors[batch_id][ch] = cur_dL_dcolors[ch]; + batch_reduced_dL_dmean2D[batch_id] = cur_dL_dmean2D; + batch_reduced_dL_dconic2D_dopacity[batch_id] = cur_dL_dconic2D_dopacity; } } // Wait for all warps to finish reducing - if (j != min(BLOCK_SIZE, toDo) - 1) - block.sync(); + block.sync(); + + if (block.thread_rank() < cur_reduction_batch_idx) + { + const int batch_id = block.thread_rank(); + const int global_offset = collected_offset[batch_j[batch_id]]; + if constexpr(C == 3) { + // Special optimization for C == 3 + ((float3*)dL_dcolors_bin)[global_offset] = make_float3(batch_reduced_dL_dcolors[batch_id][0], batch_reduced_dL_dcolors[batch_id][1], batch_reduced_dL_dcolors[batch_id][2]); + } else { + #pragma unroll + for (int ch = 0; ch < C; ch++) + dL_dcolors_bin[global_offset * C + ch] = batch_reduced_dL_dcolors[batch_id][ch]; + } + dL_dmean2D_bin[global_offset] = batch_reduced_dL_dmean2D[batch_id]; + dL_dconic2D_dopacity_bin[global_offset] = batch_reduced_dL_dconic2D_dopacity[batch_id]; + } + + cur_reduction_batch_idx = 0; } } } @@ -682,6 +723,7 @@ template __global__ void gather_gradientsCUDA( int P, const uint32_t* __restrict__ point_offsets, + const uint32_t* __restrict__ tiles_touched, const float* __restrict__ dL_dcolors_bin, const float2* __restrict__ 
dL_dmean2D_bin, const float4* __restrict__ dL_dconic2D_dopacity_bin, @@ -694,6 +736,9 @@ __global__ void gather_gradientsCUDA( int gaussian_id = cg::this_grid().thread_rank() / WARP_SIZE; if (gaussian_id >= P) return; + int cur_tiles_touched = tiles_touched[gaussian_id]; + if (cur_tiles_touched <= USE_ATOMIC_THRESHOLD) + return; int lane_id = cg::this_thread_block().thread_rank() % WARP_SIZE; int range_start = gaussian_id == 0 ? 0 : point_offsets[gaussian_id-1]; @@ -805,6 +850,7 @@ void BACKWARD::preprocess( void BACKWARD::gather_gradients( int P, const uint32_t* point_offsets, + const uint32_t* tiles_touched, const float* dL_dcolors_bin, const float2* dL_dmean2D_bin, const float4* dL_dconic2D_dopacity_bin, @@ -818,6 +864,7 @@ void BACKWARD::gather_gradients( gather_gradientsCUDA << > > ( P, point_offsets, + tiles_touched, dL_dcolors_bin, dL_dmean2D_bin, dL_dconic2D_dopacity_bin, @@ -839,9 +886,14 @@ void BACKWARD::render( const float* final_Ts, const uint32_t* n_contrib, const float* dL_dpixels, - float* dL_dcolors, - float2* dL_dmean2D, - float4* dL_dconic2D_dopacity) + const uint32_t* tiles_touched, + float* dL_dcolors_bin, + float2* dL_dmean2D_bin, + float4* dL_dconic2D_dopacity_bin, + float* __restrict__ dL_dcolors_global, + float3* __restrict__ dL_dmean2D_global, + float4* __restrict__ dL_dconic2D_global, + float* __restrict__ dL_dopacity_global) { renderCUDA << > >( ranges, @@ -854,8 +906,12 @@ void BACKWARD::render( final_Ts, n_contrib, dL_dpixels, - dL_dcolors, - dL_dmean2D, - dL_dconic2D_dopacity - ); -} \ No newline at end of file + tiles_touched, + dL_dcolors_bin, + dL_dmean2D_bin, + dL_dconic2D_dopacity_bin, + dL_dcolors_global, + dL_dmean2D_global, + dL_dconic2D_global, + dL_dopacity_global); +} diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index bc5d581a..1a4af215 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -21,7 +21,7 @@ namespace BACKWARD { void render( - const dim3 grid, dim3 block, + const dim3 grid, const dim3 block, const uint2* ranges, const uint64_t* point_list, int W, int H, @@ -32,13 +32,19 @@ namespace BACKWARD const float* final_Ts, const uint32_t* n_contrib, const float* dL_dpixels, - float* dL_dcolors, - float2* dL_dmean2D, - float4* dL_dconic2D_dopacity); + const uint32_t* tiles_touched, + float* dL_dcolors_bin, + float2* dL_dmean2D_bin, + float4* dL_dconic2D_dopacity_bin, + float* __restrict__ dL_dcolors_global, + float3* __restrict__ dL_dmean2D_global, + float4* __restrict__ dL_dconic2D_global, + float* __restrict__ dL_dopacity_global); void gather_gradients( int P, const uint32_t* point_offsets, + const uint32_t* tiles_touched, const float* dL_dcolors_bin, const float2* dL_dmean2D_bin, const float4* dL_dconic2D_dopacity_bin, diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index 5176e3c8..0ea799b6 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -386,6 +386,8 @@ void CudaRasterizer::Rasterizer::backward( const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); const dim3 block(BLOCK_X, BLOCK_Y, 1); + // printf("%d %d %.2f\n", P, R, (float)R/P); + // Compute loss gradients w.r.t. 2D mean position, conic matrix, // opacity and RGB of Gaussians from per-pixel loss gradients. // If we were given precomputed colors and not SHs, use them. 
@@ -403,14 +405,20 @@ void CudaRasterizer::Rasterizer::backward( imgState.accum_alpha, imgState.n_contrib, dL_dpix, + geomState.tiles_touched, binningState.dL_dcolors, binningState.dL_dmean2D, - binningState.dL_dconic2D_dopacity + binningState.dL_dconic2D_dopacity, + dL_dcolor, + (float3*)dL_dmean2D, + (float4*)dL_dconic, + dL_dopacity ), debug) CHECK_CUDA(BACKWARD::gather_gradients( P, geomState.point_offsets, + geomState.tiles_touched, binningState.dL_dcolors, binningState.dL_dmean2D, binningState.dL_dconic2D_dopacity, From 83eceddbfb0ab0364124c0ccf559d6612bfd6949 Mon Sep 17 00:00:00 2001 From: interestingLSY Date: Sat, 8 Jun 2024 16:10:01 +0800 Subject: [PATCH 5/7] Further optimization --- cuda_rasterizer/backward.cu | 70 +++++++++++++++---------------------- 1 file changed, 29 insertions(+), 41 deletions(-) diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index f6265eba..7f8edd44 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -523,9 +523,6 @@ renderCUDA( __shared__ float batch_dL_dcolors[REDUCTION_BATCH_SIZE][NUM_WARPS][C]; __shared__ float2 batch_dL_dmean2D[REDUCTION_BATCH_SIZE][NUM_WARPS]; __shared__ float4 batch_dL_dconic2D_dopacity[REDUCTION_BATCH_SIZE][NUM_WARPS]; - __shared__ float batch_reduced_dL_dcolors[REDUCTION_BATCH_SIZE][C]; - __shared__ float2 batch_reduced_dL_dmean2D[REDUCTION_BATCH_SIZE]; - __shared__ float4 batch_reduced_dL_dconic2D_dopacity[REDUCTION_BATCH_SIZE]; // Iterate over Gaussians for (int j = 0; j < min(BLOCK_SIZE, toDo); j++) @@ -666,53 +663,44 @@ renderCUDA( // Perform warp-level reduction #pragma unroll for (int ch = 0; ch < C; ch++) - cur_dL_dcolors[ch] = warpReduceSum( - lane_id < NUM_WARPS ? batch_dL_dcolors[batch_id][lane_id][ch] : 0, - NUM_WARPS - ); - cur_dL_dmean2D = warpReduceSum( - lane_id < NUM_WARPS ? - batch_dL_dmean2D[batch_id][lane_id] : - float2{0, 0}, - NUM_WARPS - ); - cur_dL_dconic2D_dopacity = warpReduceSum( - lane_id < NUM_WARPS ? - batch_dL_dconic2D_dopacity[batch_id][lane_id] : - float4{0, 0, 0, 0}, - NUM_WARPS - ); + cur_dL_dcolors[ch] = lane_id < NUM_WARPS ? batch_dL_dcolors[batch_id][lane_id][ch] : 0, + cur_dL_dmean2D = lane_id < NUM_WARPS ? batch_dL_dmean2D[batch_id][lane_id] : float2{0, 0}, + cur_dL_dconic2D_dopacity = lane_id < NUM_WARPS ? 
batch_dL_dconic2D_dopacity[batch_id][lane_id] : float4{0, 0, 0, 0}; + #pragma unroll + for (int offset = NUM_WARPS/2; offset > 0; offset /= 2) { + #pragma unroll + for (int ch = 0; ch < C; ch++) + cur_dL_dcolors[ch] += __shfl_down_sync(0xFFFFFFFF, cur_dL_dcolors[ch], offset); + cur_dL_dmean2D.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.x, offset); + cur_dL_dmean2D.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.y, offset); + cur_dL_dconic2D_dopacity.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.x, offset); + cur_dL_dconic2D_dopacity.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.y, offset); + cur_dL_dconic2D_dopacity.z += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.z, offset); + cur_dL_dconic2D_dopacity.w += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.w, offset); + } + // Store the results in global memory if (lane_id == 0) { - #pragma unroll - for (int ch = 0; ch < C; ch++) - batch_reduced_dL_dcolors[batch_id][ch] = cur_dL_dcolors[ch]; - batch_reduced_dL_dmean2D[batch_id] = cur_dL_dmean2D; - batch_reduced_dL_dconic2D_dopacity[batch_id] = cur_dL_dconic2D_dopacity; + const int global_offset = collected_offset[batch_j[batch_id]]; + if constexpr(C == 3) { + // Special optimization for C == 3 + ((float3*)dL_dcolors_bin)[global_offset] = make_float3(cur_dL_dcolors[0], cur_dL_dcolors[1], cur_dL_dcolors[2]); + } else { + #pragma unroll + for (int ch = 0; ch < C; ch++) + dL_dcolors_bin[global_offset * C + ch] = cur_dL_dcolors[ch]; + } + dL_dmean2D_bin[global_offset] = cur_dL_dmean2D; + dL_dconic2D_dopacity_bin[global_offset] = cur_dL_dconic2D_dopacity; } } // Wait for all warps to finish reducing - block.sync(); + if (j != min(BLOCK_SIZE, toDo) - 1) + block.sync(); - if (block.thread_rank() < cur_reduction_batch_idx) - { - const int batch_id = block.thread_rank(); - const int global_offset = collected_offset[batch_j[batch_id]]; - if constexpr(C == 3) { - // Special optimization for C == 3 - ((float3*)dL_dcolors_bin)[global_offset] = make_float3(batch_reduced_dL_dcolors[batch_id][0], batch_reduced_dL_dcolors[batch_id][1], batch_reduced_dL_dcolors[batch_id][2]); - } else { - #pragma unroll - for (int ch = 0; ch < C; ch++) - dL_dcolors_bin[global_offset * C + ch] = batch_reduced_dL_dcolors[batch_id][ch]; - } - dL_dmean2D_bin[global_offset] = batch_reduced_dL_dmean2D[batch_id]; - dL_dconic2D_dopacity_bin[global_offset] = batch_reduced_dL_dconic2D_dopacity[batch_id]; - } - cur_reduction_batch_idx = 0; } } From 656552575d52c87c1d668a472b45d6e995e8ca12 Mon Sep 17 00:00:00 2001 From: interestingLSY Date: Mon, 10 Jun 2024 17:34:34 +0800 Subject: [PATCH 6/7] Don't use a separate gather kernel --- cuda_rasterizer/backward.cu | 188 +++++++---------------------- cuda_rasterizer/backward.h | 25 +--- cuda_rasterizer/forward.cu | 6 +- cuda_rasterizer/forward.h | 2 +- cuda_rasterizer/rasterizer_impl.cu | 31 +---- cuda_rasterizer/rasterizer_impl.h | 9 +- 6 files changed, 60 insertions(+), 201 deletions(-) diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 7f8edd44..89523419 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -429,7 +429,7 @@ template __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA( const uint2* __restrict__ ranges, - const uint64_t* __restrict__ point_list, + const uint32_t* __restrict__ point_list, int W, int H, const float* __restrict__ bg_color, const float2* __restrict__ points_xy_image, @@ -437,15 +437,12 @@ renderCUDA( const float* __restrict__ colors, const float* 
__restrict__ final_Ts, const uint32_t* __restrict__ n_contrib, - const float* __restrict__ dL_dpixels, const uint32_t* __restrict__ tiles_touched, - float* __restrict__ dL_dcolors_bin, - float2* __restrict__ dL_dmean2D_bin, - float4* __restrict__ dL_dconic2D_dopacity_bin, - float* __restrict__ dL_dcolors_global, - float3* __restrict__ dL_dmean2D_global, - float4* __restrict__ dL_dconic2D_global, - float* __restrict__ dL_dopacity_global) + const float* __restrict__ dL_dpixels, + float3* __restrict__ dL_dmean2D, + float4* __restrict__ dL_dconic2D, + float* __restrict__ dL_dopacity, + float* __restrict__ dL_dcolors) { // We rasterize again. Compute necessary block info. auto block = cg::this_thread_block(); @@ -463,8 +460,7 @@ renderCUDA( int toDo = range.y - range.x; - __shared__ int collected_offset[BLOCK_SIZE]; // Offsets of instances before sorting when USE_ATOMIC_ADD is False - // Id of the gaussian when USE_ATOMIC_ADD is True + __shared__ int collected_id[BLOCK_SIZE]; __shared__ bool collected_use_atomic[BLOCK_SIZE]; __shared__ float2 collected_xy[BLOCK_SIZE]; __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; @@ -503,13 +499,11 @@ renderCUDA( const int progress = i * BLOCK_SIZE + block.thread_rank(); if (range.x + progress < range.y) { - const uint64_t coll_id_and_offset = point_list[range.y - progress - 1]; - const int coll_id = coll_id_and_offset>>32; - const int offset_before_sorting = coll_id_and_offset & 0xFFFFFFFF; + const int coll_id = point_list[range.y - progress - 1]; const int cur_tiles_touched = tiles_touched[coll_id]; bool cur_use_atomic = cur_tiles_touched <= USE_ATOMIC_THRESHOLD; collected_use_atomic[block.thread_rank()] = cur_use_atomic; - collected_offset[block.thread_rank()] = cur_use_atomic ? coll_id : offset_before_sorting; + collected_id[block.thread_rank()] = coll_id; collected_xy[block.thread_rank()] = points_xy_image[coll_id]; collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; for (int i = 0; i < C; i++) @@ -552,7 +546,7 @@ renderCUDA( // gradients w.r.t. alpha (blending factor for a Gaussian/pixel // pair). float dL_dalpha = 0.0f; - const int global_id = collected_offset[j]; + const int global_id = collected_id[j]; #pragma unroll for (int ch = 0; ch < C; ch++) { @@ -567,7 +561,7 @@ renderCUDA( // Atomic, since this pixel is just one of potentially // many that were affected by this Gaussian. if (use_atomic) { - atomicAdd(&dL_dcolors_global[global_id*C + ch], dchannel_dcolor * dL_dchannel); + atomicAdd(&dL_dcolors[global_id*C + ch], dchannel_dcolor * dL_dchannel); } else { cur_dL_dcolors[ch] = dchannel_dcolor * dL_dchannel; } @@ -591,16 +585,16 @@ renderCUDA( const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y; const float dG_ddely = -gdy * con_o.z - gdx * con_o.y; - // Update gradients w.r.t. 2D mean position of the Gaussian - // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric) - // Update gradients w.r.t. opacity of the Gaussian if (use_atomic) { - atomicAdd(&dL_dmean2D_global[global_id].x, dL_dG * dG_ddelx * ddelx_dx); - atomicAdd(&dL_dmean2D_global[global_id].y, dL_dG * dG_ddely * ddely_dy); - atomicAdd(&dL_dconic2D_global[global_id].x, -0.5f * gdx * d.x * dL_dG); - atomicAdd(&dL_dconic2D_global[global_id].y, -0.5f * gdx * d.y * dL_dG); - atomicAdd(&dL_dconic2D_global[global_id].w, -0.5f * gdy * d.y * dL_dG); - atomicAdd(&dL_dopacity_global[global_id], G * dL_dalpha); + // Update gradients w.r.t. 
2D mean position of the Gaussian + atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx); + atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy); + // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric) + atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG); + atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG); + atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG); + // Update gradients w.r.t. opacity of the Gaussian + atomicAdd(&dL_dopacity[global_id], G * dL_dalpha); } else { cur_dL_dmean2D = { dL_dG * dG_ddelx * ddelx_dx, @@ -683,17 +677,18 @@ renderCUDA( // Store the results in global memory if (lane_id == 0) { - const int global_offset = collected_offset[batch_j[batch_id]]; - if constexpr(C == 3) { - // Special optimization for C == 3 - ((float3*)dL_dcolors_bin)[global_offset] = make_float3(cur_dL_dcolors[0], cur_dL_dcolors[1], cur_dL_dcolors[2]); - } else { - #pragma unroll - for (int ch = 0; ch < C; ch++) - dL_dcolors_bin[global_offset * C + ch] = cur_dL_dcolors[ch]; - } - dL_dmean2D_bin[global_offset] = cur_dL_dmean2D; - dL_dconic2D_dopacity_bin[global_offset] = cur_dL_dconic2D_dopacity; + const int global_id = collected_id[batch_j[batch_id]]; + // if (global_id < 0 || global_id >= 208424) + // printf("%d\n", global_id); + #pragma unroll + for (int ch = 0; ch < C; ch++) + atomicAdd(&dL_dcolors[global_id * C + ch], cur_dL_dcolors[ch]); + atomicAdd(&dL_dmean2D[global_id].x, cur_dL_dmean2D.x); + atomicAdd(&dL_dmean2D[global_id].y, cur_dL_dmean2D.y); + atomicAdd(&dL_dconic2D[global_id].x, cur_dL_dconic2D_dopacity.x); + atomicAdd(&dL_dconic2D[global_id].y, cur_dL_dconic2D_dopacity.y); + atomicAdd(&dL_dconic2D[global_id].w, cur_dL_dconic2D_dopacity.w); + atomicAdd(&dL_dopacity[global_id], cur_dL_dconic2D_dopacity.z); } } @@ -707,69 +702,6 @@ renderCUDA( } } -template -__global__ void gather_gradientsCUDA( - int P, - const uint32_t* __restrict__ point_offsets, - const uint32_t* __restrict__ tiles_touched, - const float* __restrict__ dL_dcolors_bin, - const float2* __restrict__ dL_dmean2D_bin, - const float4* __restrict__ dL_dconic2D_dopacity_bin, - float* __restrict__ dL_dcolors, - float3* __restrict__ dL_dmean2D, - float4* __restrict__ dL_dconic2D, - float* __restrict__ dL_dopacity) -{ - // Every warp is responsible for one Gaussian - int gaussian_id = cg::this_grid().thread_rank() / WARP_SIZE; - if (gaussian_id >= P) - return; - int cur_tiles_touched = tiles_touched[gaussian_id]; - if (cur_tiles_touched <= USE_ATOMIC_THRESHOLD) - return; - - int lane_id = cg::this_thread_block().thread_rank() % WARP_SIZE; - int range_start = gaussian_id == 0 ? 
0 : point_offsets[gaussian_id-1]; - int range_end = point_offsets[gaussian_id]; - - float dL_dcolors_sum[C] = { 0 }; - float2 dL_dmean2D_sum = { 0, 0 }; - float4 dL_dconic2D_dopacity_sum = { 0, 0, 0, 0 }; - - #pragma unroll 2 - for (int i = range_start + lane_id; i < range_end; i += WARP_SIZE) { - #pragma unroll - for (int ch = 0; ch < C; ch++) - dL_dcolors_sum[ch] += dL_dcolors_bin[i * C + ch]; - float2 cur_mean2d = dL_dmean2D_bin[i]; - dL_dmean2D_sum.x += cur_mean2d.x; - dL_dmean2D_sum.y += cur_mean2d.y; - float4 cur_conic2d_dopacity = dL_dconic2D_dopacity_bin[i]; - dL_dconic2D_dopacity_sum.x += cur_conic2d_dopacity.x; - dL_dconic2D_dopacity_sum.y += cur_conic2d_dopacity.y; - dL_dconic2D_dopacity_sum.z += cur_conic2d_dopacity.z; - dL_dconic2D_dopacity_sum.w += cur_conic2d_dopacity.w; - } - - // Warp-level reduction - #pragma unroll - for (int ch = 0; ch < C; ch++) - dL_dcolors_sum[ch] = warpReduceSum(dL_dcolors_sum[ch]); - dL_dmean2D_sum = warpReduceSum(dL_dmean2D_sum); - dL_dconic2D_dopacity_sum = warpReduceSum(dL_dconic2D_dopacity_sum); - - // Write-back - if (lane_id == 0) { - #pragma unroll - for (int ch = 0; ch < C; ch++) - dL_dcolors[gaussian_id * C + ch] = dL_dcolors_sum[ch]; - dL_dmean2D[gaussian_id].x = dL_dmean2D_sum.x; - dL_dmean2D[gaussian_id].y = dL_dmean2D_sum.y; - dL_dconic2D[gaussian_id] = dL_dconic2D_dopacity_sum; - dL_dopacity[gaussian_id] = dL_dconic2D_dopacity_sum.z; - } -} - void BACKWARD::preprocess( int P, int D, int M, const float3* means3D, @@ -835,37 +767,10 @@ void BACKWARD::preprocess( dL_drot); } -void BACKWARD::gather_gradients( - int P, - const uint32_t* point_offsets, - const uint32_t* tiles_touched, - const float* dL_dcolors_bin, - const float2* dL_dmean2D_bin, - const float4* dL_dconic2D_dopacity_bin, - float* dL_dcolors, - float3* dL_dmean2D, - float4* dL_dconic2D, - float* dL_dopacity) -{ - static constexpr int N_WARPS = 16; - int num_blocks = (P + N_WARPS - 1) / N_WARPS; - gather_gradientsCUDA << > > ( - P, - point_offsets, - tiles_touched, - dL_dcolors_bin, - dL_dmean2D_bin, - dL_dconic2D_dopacity_bin, - dL_dcolors, - dL_dmean2D, - dL_dconic2D, - dL_dopacity); -} - void BACKWARD::render( const dim3 grid, const dim3 block, const uint2* ranges, - const uint64_t* point_list, + const uint32_t* point_list, int W, int H, const float* bg_color, const float2* means2D, @@ -873,15 +778,12 @@ void BACKWARD::render( const float* colors, const float* final_Ts, const uint32_t* n_contrib, - const float* dL_dpixels, const uint32_t* tiles_touched, - float* dL_dcolors_bin, - float2* dL_dmean2D_bin, - float4* dL_dconic2D_dopacity_bin, - float* __restrict__ dL_dcolors_global, - float3* __restrict__ dL_dmean2D_global, - float4* __restrict__ dL_dconic2D_global, - float* __restrict__ dL_dopacity_global) + const float* dL_dpixels, + float3* dL_dmean2D, + float4* dL_dconic2D, + float* dL_dopacity, + float* dL_dcolors) { renderCUDA << > >( ranges, @@ -893,13 +795,11 @@ void BACKWARD::render( colors, final_Ts, n_contrib, - dL_dpixels, tiles_touched, - dL_dcolors_bin, - dL_dmean2D_bin, - dL_dconic2D_dopacity_bin, - dL_dcolors_global, - dL_dmean2D_global, - dL_dconic2D_global, - dL_dopacity_global); + dL_dpixels, + dL_dmean2D, + dL_dconic2D, + dL_dopacity, + dL_dcolors + ); } diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index 1a4af215..d39e0bbd 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -21,9 +21,9 @@ namespace BACKWARD { void render( - const dim3 grid, const dim3 block, + const dim3 grid, dim3 block, const uint2* ranges, 
- const uint64_t* point_list, + const uint32_t* point_list, int W, int H, const float* bg_color, const float2* means2D, @@ -31,27 +31,12 @@ namespace BACKWARD const float* colors, const float* final_Ts, const uint32_t* n_contrib, - const float* dL_dpixels, - const uint32_t* tiles_touched, - float* dL_dcolors_bin, - float2* dL_dmean2D_bin, - float4* dL_dconic2D_dopacity_bin, - float* __restrict__ dL_dcolors_global, - float3* __restrict__ dL_dmean2D_global, - float4* __restrict__ dL_dconic2D_global, - float* __restrict__ dL_dopacity_global); - - void gather_gradients( - int P, - const uint32_t* point_offsets, const uint32_t* tiles_touched, - const float* dL_dcolors_bin, - const float2* dL_dmean2D_bin, - const float4* dL_dconic2D_dopacity_bin, - float* dL_dcolors, + const float* dL_dpixels, float3* dL_dmean2D, float4* dL_dconic2D, - float* dL_dopacity); + float* dL_dopacity, + float* dL_dcolors); void preprocess( int P, int D, int M, diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index 1a7e98f2..c419a328 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -262,7 +262,7 @@ template __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA( const uint2* __restrict__ ranges, - const uint64_t* __restrict__ point_list, + const uint32_t* __restrict__ point_list, int W, int H, const float2* __restrict__ points_xy_image, const float* __restrict__ features, @@ -314,7 +314,7 @@ renderCUDA( int progress = i * BLOCK_SIZE + block.thread_rank(); if (range.x + progress < range.y) { - int coll_id = point_list[range.x + progress]>>32; + int coll_id = point_list[range.x + progress]; collected_id[block.thread_rank()] = coll_id; collected_xy[block.thread_rank()] = points_xy_image[coll_id]; collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; @@ -376,7 +376,7 @@ renderCUDA( void FORWARD::render( const dim3 grid, dim3 block, const uint2* ranges, - const uint64_t* point_list, + const uint32_t* point_list, int W, int H, const float2* means2D, const float* colors, diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 0caf6fb0..3c11cb91 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -51,7 +51,7 @@ namespace FORWARD void render( const dim3 grid, dim3 block, const uint2* ranges, - const uint64_t* point_list, + const uint32_t* point_list, int W, int H, const float2* points_xy_image, const float* features, diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index 0ea799b6..c3931845 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -73,7 +73,7 @@ __global__ void duplicateWithKeys( const float* depths, const uint32_t* offsets, uint64_t* gaussian_keys_unsorted, - uint64_t* gaussian_values_unsorted, + uint32_t* gaussian_values_unsorted, int* radii, dim3 grid) { @@ -103,7 +103,7 @@ __global__ void duplicateWithKeys( key <<= 32; key |= *((uint32_t*)&depths[idx]); gaussian_keys_unsorted[off] = key; - gaussian_values_unsorted[off] = idx<<32 | off; + gaussian_values_unsorted[off] = idx; off++; } } @@ -190,9 +190,6 @@ CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chun binning.point_list_keys_unsorted, binning.point_list_keys, binning.point_list_unsorted, binning.point_list, P); obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128); - obtain(chunk, binning.dL_dcolors, P * NUM_CHANNELS, 128); - obtain(chunk, binning.dL_dmean2D, P * 2, 128); - obtain(chunk, binning.dL_dconic2D_dopacity, P * 4, 128); 
return binning; } @@ -386,8 +383,6 @@ void CudaRasterizer::Rasterizer::backward( const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); const dim3 block(BLOCK_X, BLOCK_Y, 1); - // printf("%d %d %.2f\n", P, R, (float)R/P); - // Compute loss gradients w.r.t. 2D mean position, conic matrix, // opacity and RGB of Gaussians from per-pixel loss gradients. // If we were given precomputed colors and not SHs, use them. @@ -404,28 +399,12 @@ void CudaRasterizer::Rasterizer::backward( color_ptr, imgState.accum_alpha, imgState.n_contrib, - dL_dpix, - geomState.tiles_touched, - binningState.dL_dcolors, - binningState.dL_dmean2D, - binningState.dL_dconic2D_dopacity, - dL_dcolor, - (float3*)dL_dmean2D, - (float4*)dL_dconic, - dL_dopacity - ), debug) - - CHECK_CUDA(BACKWARD::gather_gradients( - P, - geomState.point_offsets, geomState.tiles_touched, - binningState.dL_dcolors, - binningState.dL_dmean2D, - binningState.dL_dconic2D_dopacity, - dL_dcolor, + dL_dpix, (float3*)dL_dmean2D, (float4*)dL_dconic, - dL_dopacity), debug) + dL_dopacity, + dL_dcolor), debug) // Take care of the rest of preprocessing. Was the precomputed covariance // given to us or a scales/rot pair? If precomputed, pass that. If not, diff --git a/cuda_rasterizer/rasterizer_impl.h b/cuda_rasterizer/rasterizer_impl.h index f9db59cf..bc3f0ece 100644 --- a/cuda_rasterizer/rasterizer_impl.h +++ b/cuda_rasterizer/rasterizer_impl.h @@ -57,15 +57,10 @@ namespace CudaRasterizer size_t sorting_size; uint64_t* point_list_keys_unsorted; uint64_t* point_list_keys; - uint64_t* point_list_unsorted; // High 32 bits are the id of the gaussian, - // Low 32 bits are the index of the copy before sorting - uint64_t* point_list; + uint32_t* point_list_unsorted; + uint32_t* point_list; char* list_sorting_space; - float* dL_dcolors; - float2* dL_dmean2D; - float4* dL_dconic2D_dopacity; - static BinningState fromChunk(char*& chunk, size_t P); }; From c096556de786efdf9eb7215dd488e4693275bb3a Mon Sep 17 00:00:00 2001 From: interestingLSY Date: Tue, 11 Jun 2024 11:00:59 +0800 Subject: [PATCH 7/7] Remove unused functions --- cuda_rasterizer/backward.cu | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 89523419..2ba6de3d 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -395,36 +395,8 @@ __global__ void preprocessCUDA( computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot); } - -template -__forceinline__ __device__ T warpReduceSum(T val, int num_lanes = 32) { - #pragma unroll - for (int offset = num_lanes / 2; offset > 0; offset /= 2) - val += __shfl_xor_sync(0xFFFFFFFF, val, offset); - return val; -} -__forceinline__ __device__ float2 warpReduceSum(float2 val, int num_lanes = 32) { - #pragma unroll - for (int offset = num_lanes / 2; offset > 0; offset /= 2) { - val.x += __shfl_xor_sync(0xFFFFFFFF, val.x, offset); - val.y += __shfl_xor_sync(0xFFFFFFFF, val.y, offset); - } - return val; -} -__forceinline__ __device__ float4 warpReduceSum(float4 val, int num_lanes = 32) { - #pragma unroll - for (int offset = num_lanes / 2; offset > 0; offset /= 2) { - val.x += __shfl_xor_sync(0xFFFFFFFF, val.x, offset); - val.y += __shfl_xor_sync(0xFFFFFFFF, val.y, offset); - val.z += __shfl_xor_sync(0xFFFFFFFF, val.z, offset); - val.w += __shfl_xor_sync(0xFFFFFFFF, val.w, offset); - } - return val; -} - - // Backward version of the rendering procedure. 
-#define USE_ATOMIC_THRESHOLD 6 +#define USE_ATOMIC_THRESHOLD 10 template <uint32_t C> __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA(