
Conversation

@apbose (Collaborator) commented Sep 27, 2025

Addresses #3783

@meta-cla bot added the cla signed label (Sep 27, 2025)
@github-actions bot added labels: component: conversion, component: converters, component: api [Python], component: dynamo (Sep 27, 2025)
@github-actions bot requested a review from peri044 (Sep 27, 2025 01:57)
@github-actions bot added the component: tests label (Oct 15, 2025)
@peri044 (Collaborator) left a comment


LGTM. I think the functionality of to_trt_shape_tensor is probably duplicated manually in a couple of places (e.g., concat, IIRC). Do you think we could unify all of this?

@apbose (Collaborator, Author) commented Oct 17, 2025

Yeah, that's true. cat does it for its inputs, while the above is for shape tensors, so I guess we could unify them. Should I merge this for now, since it addresses a user error, or do you think I should merge only after unifying?
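
(For illustration, a minimal sketch of what such a unified helper might look like. The name promote_trt_tensors_for_concat is hypothetical; ctx.net is assumed to be the TensorRT INetworkDefinition and set_layer_name the usual converter util, and non-int inputs are assumed to already be ITensors, e.g. converted upstream via get_trt_tensor.)

import numpy as np

# Hypothetical unified helper: normalize a mixed list of Python ints and
# ITensors into a homogeneous list of ITensors, so one concat path can
# serve both the cat converter and the shape-tensor path.
def promote_trt_tensors_for_concat(ctx, target, name, tensors):
    out = []
    for i, t in enumerate(tensors):
        if isinstance(t, int):
            # Shape-tensor path: scalar ints become 1-element int32 constants
            const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
            set_layer_name(const, target, f"{name}_static_{i}_const")
            out.append(const.get_output(0))
        else:
            # cat path: t is assumed to already be an ITensor here
            out.append(t)
    return out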

@apbose force-pushed the abose/upsample_shape_ITensor_list branch from 3bfa4e0 to 9403b0f (October 20, 2025 22:46)
@apbose force-pushed the abose/upsample_shape_ITensor_list branch from 9403b0f to 3d3a8ee (October 20, 2025 22:49)
Comment on lines 63 to 70
# Promote remaining Python ints to TRT constants before concat
for i, t in enumerate(trt_tensors):
    if isinstance(t, int):
        # Static dims become 1-element int32 constants
        const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
        set_layer_name(const, target, f"{name}_static_{i}_const")
        trt_tensors[i] = const.get_output(0)

concat = ctx.net.add_concatenation(trt_tensors)
Collaborator
If trt_tensors has a mix of scalar integers and ITensors of dtype int64, would this work, given that you're casting the scalar integers to int32 explicitly?

Collaborator (Author)

In the case of shape tensors, the ints will always be int32, so this should work.
Coming to the cat case: the concat inputs will be either torch.Tensor or TRTTensor; they cannot be int. So I think the above should cover all the cases. Can you think of any other case?

Collaborator

My thought is: how are we explicitly ensuring that all trt_tensors have the same data type before concatenating here? Mixed dtypes will error out.
This could be either an assertion check or explicit type promotion of the tensors within trt_tensors.
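
(As a sketch of the explicit-promotion option: assuming TensorRT >= 10, where int64 ITensors and INetworkDefinition.add_cast are available, and assuming every element of trt_tensors is already an ITensor. The helper name promote_concat_dtypes is illustrative, not part of the PR.)

import tensorrt as trt

# Illustrative type promotion: if any input is int64, cast the rest up to
# int64; otherwise keep the common int32 dtype.
def promote_concat_dtypes(ctx, target, name, trt_tensors):
    common = (
        trt.int64
        if any(t.dtype == trt.int64 for t in trt_tensors)
        else trt.int32
    )
    for i, t in enumerate(trt_tensors):
        if t.dtype != common:
            cast = ctx.net.add_cast(t, common)
            set_layer_name(cast, target, f"{name}_promote_{i}_cast")
            trt_tensors[i] = cast.get_output(0)
    return trt_tensors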

@apbose (Collaborator, Author) commented Oct 23, 2025

Embedding bag looks like it is failing. Need to look into it.
