39 changes: 30 additions & 9 deletions src/diffusers/image_processor.py
@@ -1045,16 +1045,37 @@ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) ->
 def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
     r"""
     Convert an RGB-like depth image to a depth map.

     Args:
         image (`Union[np.ndarray, torch.Tensor]`):
             The RGB-like depth image to convert.

     Returns:
         `Union[np.ndarray, torch.Tensor]`:
             The corresponding depth map.
     """
-    return image[:, :, 1] * 2**8 + image[:, :, 2]
+    # 1. Cast the input to a wider integer type (e.g. int32) so the
+    #    multiplication by 256 cannot overflow.
+    # 2. Combine the bytes into a 16-bit value: high byte * 256 + low byte.
+    # 3. Cast the result to uint16 if that exact dtype is needed; leaving it
+    #    as int32/int64 is often safer as a return value from a library function.
+
+    if isinstance(image, torch.Tensor):
+        # Cast to a safe dtype (e.g. int32 or int64) for the calculation.
+        image_safe = image.to(torch.int32)
+
+        # Combine high and low bytes into the depth map.
+        depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
+
+        # Widening to int32 already fixes the overflow; cast to uint16 only
+        # if that exact dtype is strictly required.
+        # depth_map = depth_map.to(torch.uint16)
+        return depth_map
@DN6 (Collaborator) commented on Oct 27, 2025:
    Change looks good. Could we just store the original dtype and cast back once the computation is done?

        original_dtype = image.dtype
        # perform compute
        return depth_map.to(original_dtype)


+    elif isinstance(image, np.ndarray):
+        # NumPy equivalent: cast to a safe dtype (e.g. np.int32).
+        image_safe = image.astype(np.int32)
+
+        # Combine high and low bytes into the depth map.
+        depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
+
+        # depth_map = depth_map.astype(np.uint16)  # only if uint16 is strictly required
+        return depth_map
+    else:
+        raise TypeError("Input image must be a torch.Tensor or np.ndarray")

 def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
     r"""
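
Why the widening cast matters: when the byte arithmetic actually runs in uint8, values wrap modulo 256 and the high byte is lost (the exact failure mode of the old one-liner depends on the NumPy/PyTorch version, which may instead raise an overflow error or upcast silently). A minimal, self-contained NumPy sketch of the wraparound case, with invented values:

    import numpy as np

    high = np.array([5], dtype=np.uint8)   # high byte of the 16-bit depth
    low = np.array([200], dtype=np.uint8)  # low byte; true depth = 5 * 256 + 200 = 1480

    # In uint8, 256 itself wraps to 0, so the high byte vanishes entirely:
    scale = np.array([256]).astype(np.uint8)  # -> [0]
    wrong = high * scale + low                # -> [200], not 1480

    # Widening first, as the patch does, preserves the full 16-bit value:
    right = high.astype(np.int32) * 256 + low.astype(np.int32)  # -> [1480]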
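
For illustration, one way the reviewer's store-and-restore suggestion might be folded in (a hypothetical sketch, not the merged code). One caveat worth noting: casting back to a narrow original dtype such as uint8 would truncate depth values above 255, so the round-trip only preserves data when the original dtype can hold 16 bits:

    from typing import Union

    import numpy as np
    import torch

    def rgblike_to_depthmap_sketch(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """Hypothetical shape of rgblike_to_depthmap after the review suggestion."""
        if isinstance(image, torch.Tensor):
            original_dtype = image.dtype
            wide = image.to(torch.int32)  # widen so * 256 cannot overflow
            depth_map = wide[:, :, 1] * 256 + wide[:, :, 2]
            # Truncates if original_dtype cannot represent values > 255.
            return depth_map.to(original_dtype)
        elif isinstance(image, np.ndarray):
            original_dtype = image.dtype
            wide = image.astype(np.int32)
            depth_map = wide[:, :, 1] * 256 + wide[:, :, 2]
            return depth_map.astype(original_dtype)
        raise TypeError("Input image must be a torch.Tensor or np.ndarray")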