@@ -66,6 +66,14 @@ struct PackedVec3::State {
6666 // / A map from type to the name of a helper function used to unpack that type.
6767 Hashmap<const core::type::Type*, Symbol, 4 > unpack_helpers;
6868
69+ // / @returns true if @p addrspace requires vec3 types to be packed
70+ bool AddressSpaceNeedsPacking (core::AddressSpace addrspace) {
71+ // Host-shareable address spaces need to be packed to match the memory layout on the host.
72+ // The workgroup address space needs to be packed so that the size of generated threadgroup
73+ // variables matches the size of the original WGSL declarations.
74+ return core::IsHostShareable (addrspace) || addrspace == core::AddressSpace::kWorkgroup ;
75+ }
76+
6977 // / @param ty the type to test
7078 // / @returns true if `ty` is a vec3, false otherwise
7179 bool IsVec3 (const core::type::Type* ty) {
@@ -342,7 +350,7 @@ struct PackedVec3::State {
342350 // if the transform is necessary.
343351 for (auto * decl : src->AST ().GlobalVariables ()) {
344352 auto * var = sem.Get <sem::GlobalVariable>(decl);
345- if (var && core::IsHostShareable (var->AddressSpace ()) &&
353+ if (var && AddressSpaceNeedsPacking (var->AddressSpace ()) &&
346354 ContainsVec3 (var->Type ()->UnwrapRef ())) {
347355 return true ;
348356 }
@@ -379,7 +387,7 @@ struct PackedVec3::State {
379387 [&](const sem::TypeExpression* type) {
380388 // Rewrite pointers to types that contain vec3s.
381389 auto * ptr = type->Type ()->As <core::type::Pointer>();
382- if (ptr && core::IsHostShareable (ptr->AddressSpace ())) {
390+ if (ptr && AddressSpaceNeedsPacking (ptr->AddressSpace ())) {
383391 auto new_store_type = RewriteType (ptr->StoreType ());
384392 if (new_store_type) {
385393 auto access = ptr->AddressSpace () == core::AddressSpace::kStorage
@@ -392,7 +400,7 @@ struct PackedVec3::State {
392400 }
393401 },
394402 [&](const sem::Variable* var) {
395- if (!core::IsHostShareable (var->AddressSpace ())) {
403+ if (!AddressSpaceNeedsPacking (var->AddressSpace ())) {
396404 return ;
397405 }
398406
@@ -408,7 +416,7 @@ struct PackedVec3::State {
408416 auto * lhs = sem.GetVal (assign->lhs );
409417 auto * rhs = sem.GetVal (assign->rhs );
410418 if (!ContainsVec3 (rhs->Type ()) ||
411- !core::IsHostShareable (
419+ !AddressSpaceNeedsPacking (
412420 lhs->Type ()->As <core::type::Reference>()->AddressSpace ())) {
413421 // Skip assignments to address spaces that are not host-shareable, or
414422 // that do not contain vec3 types.
@@ -436,7 +444,7 @@ struct PackedVec3::State {
436444 [&](const sem::Load* load) {
437445 // Unpack loads of types that contain vec3s in host-shareable address spaces.
438446 if (ContainsVec3 (load->Type ()) &&
439- core::IsHostShareable (load->ReferenceType ()->AddressSpace ())) {
447+ AddressSpaceNeedsPacking (load->ReferenceType ()->AddressSpace ())) {
440448 to_unpack.Add (load);
441449 }
442450 },
@@ -446,7 +454,7 @@ struct PackedVec3::State {
446454 // struct.
447455 if (auto * ref = accessor->Type ()->As <core::type::Reference>()) {
448456 if (IsVec3 (ref->StoreType ()) &&
449- core::IsHostShareable (ref->AddressSpace ())) {
457+ AddressSpaceNeedsPacking (ref->AddressSpace ())) {
450458 ctx.Replace (node, b.MemberAccessor (ctx.Clone (accessor->Declaration ()),
451459 kStructMemberName ));
452460 }
0 commit comments