addresses the case when shape of upsample tensor contains ITensor #3841
base: main
Conversation
LGTM. I think the functionality of to_trt_shape_tensor is probably implemented manually in a couple of places (e.g. concat, iirc). Do you think we could unify all this?

Yeah, that's true. cat does it for its inputs, while the above is for shape tensors. I guess we could unify this. Should
Force-pushed 3bfa4e0 to 9403b0f
Force-pushed 9403b0f to 3d3a8ee
```python
# promote remaining ints to TRT consts before concat
for i, t in enumerate(trt_tensors):
    if isinstance(t, int):
        const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
        set_layer_name(const, target, f"{name}_static_{i}_const")
        trt_tensors[i] = const.get_output(0)

concat = ctx.net.add_concatenation(trt_tensors)
```
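The promotion loop can be exercised without a real TensorRT engine. Below is a minimal sketch using a hypothetical stand-in for the network object (FakeNetwork, FakeConstLayer, and the numpy-array outputs are illustrative assumptions, not the PR's actual API):

```python
import numpy as np

class FakeConstLayer:
    """Hypothetical stand-in for a TRT constant layer (illustrative only)."""
    def __init__(self, shape, weights):
        self.shape = shape
        self.weights = weights

    def get_output(self, index):
        # Pretend the layer output is the numpy array itself.
        return self.weights

class FakeNetwork:
    """Hypothetical stand-in for ctx.net (illustrative only)."""
    def add_constant(self, shape, weights):
        return FakeConstLayer(shape, weights)

net = FakeNetwork()
# A mix of an existing length-1 "ITensor" and plain Python ints.
trt_tensors = [np.array([4], dtype=np.int32), 2, 8]

# Same promotion pattern as the diff: wrap remaining ints as int32 consts.
for i, t in enumerate(trt_tensors):
    if isinstance(t, int):
        const = net.add_constant((1,), np.array([t], dtype=np.int32))
        trt_tensors[i] = const.get_output(0)

print([str(t.dtype) for t in trt_tensors])  # ['int32', 'int32', 'int32']
print(np.concatenate(trt_tensors))          # [4 2 8]
```

After the loop every element is a uniform int32 tensor, so the subsequent concatenation sees homogeneous inputs.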
If trt_tensors has a mix of scalar integers and ITensors of dtype int64, would this work (since you're explicitly casting the scalar integers to int32)?
In the case of shape tensors, the ints will always be int32, so this should work.
As for the cat case: the concat inputs will be either torch.Tensor or TRTTensor, never int, so I think the above covers all the cases. Can you think of any other case?
My question is: how are we explicitly ensuring that all trt_tensors have the same dtype before concatenating here? Mixed dtypes will error out.
This check could be either an assertion or explicit type promotion of the tensors within trt_tensors.
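The second option suggested above can be sketched in plain numpy. The helper name and the dtype-resolution rule (promote everything to the widest dtype present via np.result_type) are assumptions for illustration, not the PR's actual code:

```python
import numpy as np

def promote_to_common_dtype(tensors):
    """Cast every tensor to the widest dtype present among the inputs,
    so a later concatenation does not fail on mixed int32/int64 tensors
    (illustrative sketch, not the converter's real helper)."""
    common = np.result_type(*[t.dtype for t in tensors])
    return [t.astype(common) for t in tensors]

mixed = [np.array([4], dtype=np.int32), np.array([2], dtype=np.int64)]
promoted = promote_to_common_dtype(mixed)

print([str(t.dtype) for t in promoted])  # ['int64', 'int64']
print(np.concatenate(promoted))          # [4 2]
```

An assertion-based variant would instead check `len({t.dtype for t in tensors}) == 1` and fail fast; the promotion variant is more forgiving but silently widens dtypes.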
Embedding bag looks like it is failing. Need to look into it.
Addresses #3783