@@ -538,33 +538,24 @@ private class Gemma3nAttention: Module {
             values = vProj(x).reshaped(hiddenShape)
             values = vNorm(values)
             values = values.transposed(0, 2, 1, 3)
-
-            if let cache = cache {
-                (keys, values) = cache.update(keys: keys, values: values)
-            }
         }
 
+        // Repeat keys and values for multi-head attention
         keys = repeated(keys, count: repeats, axis: 1)
         values = repeated(values, count: repeats, axis: 1)
 
-        var attnWeights = matmul(queries, keys.swappedAxes(2, 3)) * scale
-
-        if attnLogitSoftcapping > 0 {
-            attnWeights = attnWeights / attnLogitSoftcapping
-            attnWeights = tanh(attnWeights)
-            attnWeights = attnWeights * attnLogitSoftcapping
-        }
-
-        if case .array(let maskArray) = mask {
-            let causalMask = maskArray[0..., ..<keys.shape[2]]
-            attnWeights = attnWeights + causalMask
-        }
-
-        attnWeights = softmax(attnWeights.asType(.float32), axis: -1).asType(queries.dtype)
-
-        let output = matmul(attnWeights, values)
-            .transposed(0, 2, 1, 3)
-            .reshaped(inputShape + [-1])
+        // Use a custom attention function that handles the KV cache update and applies logit softcapping
+        let output = gemma3nAttentionWithCacheUpdate(
+            queries: queries,
+            keys: keys,
+            values: values,
+            cache: cache,
+            scale: scale,
+            attnLogitSoftcapping: attnLogitSoftcapping,
+            mask: mask ?? .none
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(inputShape + [-1])
 
         return oProj(output)
     }
@@ -1308,6 +1299,72 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
 
 // MARK: - Helper Functions
 
+// MARK: - Custom Attention for Gemma3n with Logit Softcapping
+
+/// Custom attention function for Gemma3n that supports:
+/// - Logit softcapping (applied before softmax)
+/// - Standard KV cache updates
+/// - Exact alignment with the Python implementation
+///
+/// TODO: Quantized KV Cache Integration
+/// Action items for adding quantized cache support:
+/// 1. Add QuantizedKVCache detection: `if let quantizedKVCache = cache as? QuantizedKVCache`
+/// 2. Use `quantizedKVCache.updateQuantized(keys: keys, values: values)` for the cache update
+/// 3. Implement manual quantized attention computation with logit softcapping:
+///    - Cannot use `quantizedScaledDotProductAttention` directly (no softcapping support)
+///    - Need to manually compute `matmul(queries, dequantized_keys)` with softcapping
+///    - May require dequantizing the keys before applying logit softcapping
+/// 4. Consider performance trade-offs:
+///    - Manual dequantization vs. the benefits of quantized attention
+///    - Might need a hybrid approach or a dedicated quantized + softcapping function
+/// 5. Test with QuantizedKVCache to ensure numerical accuracy matches the Python implementation
+/// 6. Update documentation and examples
+private func gemma3nAttentionWithCacheUpdate(
+    queries: MLXArray,
+    keys: MLXArray,
+    values: MLXArray,
+    cache: KVCache?,
+    scale: Float,
+    attnLogitSoftcapping: Float,
+    mask: MLXFast.ScaledDotProductAttentionMaskMode = .none
+) -> MLXArray {
+    // Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
+    let (cachedKeys, cachedValues): (MLXArray, MLXArray)
+
+    if let cache = cache {
+        (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
+    } else {
+        (cachedKeys, cachedValues) = (keys, values)
+    }
+
+    // Manual attention computation to support logit softcapping
+    // This matches the Python implementation exactly:
+    //   attn_weights = mx.matmul(queries, keys.swapaxes(2, 3)) * self.scale
+    var attnWeights = matmul(queries, cachedKeys.swappedAxes(2, 3)) * scale
+
+    // Apply logit softcapping if enabled (matches Python)
+    //   if self.attn_logit_softcapping is not None and self.attn_logit_softcapping > 0:
+    if attnLogitSoftcapping > 0 {
+        attnWeights = attnWeights / attnLogitSoftcapping
+        attnWeights = tanh(attnWeights)
+        attnWeights = attnWeights * attnLogitSoftcapping
+    }
+
+    // Apply mask if provided (matches Python)
+    //   if mask is not None: causal_mask = mask[:, : keys.shape[-2]]
+    if case .array(let maskArray) = mask {
+        let causalMask = maskArray[0..., ..<cachedKeys.shape[2]]
+        attnWeights = attnWeights + causalMask
+    }
+
+    // Apply softmax and compute output (matches Python)
+    //   attn_weights = mx.softmax(attn_weights.astype(mx.float32), axis=-1).astype(queries.dtype)
+    attnWeights = softmax(attnWeights.asType(.float32), axis: -1).asType(queries.dtype)
+
+    // output = mx.matmul(attn_weights, values)
+    return matmul(attnWeights, cachedValues)
+}
+
 private func bicubicInterpolate(_ x: MLXArray, to targetSize: (Int, Int), alignCorners: Bool = false) -> MLXArray {
     // TODO: This implementation uses nested loops and sequential MLX operations, which is much slower
     // than the Python version that uses mx.fast.metal_kernel() for parallel GPU computation.
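
Below is a minimal, self-contained sketch (not part of the commit) illustrating the logit softcapping transform that the new gemma3nAttentionWithCacheUpdate helper applies before softmax: logits are divided by the cap, squashed with tanh, and rescaled, which bounds every attention logit to the open interval (-cap, +cap) while leaving small values nearly unchanged. The constant and array names here are hypothetical; Gemma3n reads the actual cap from its model configuration.

import MLX

// Hypothetical cap value, for illustration only.
let cap: Float = 50.0
let logits = MLXArray([5.0, 50.0, 500.0] as [Float])

// Same math as the softcapping branch in gemma3nAttentionWithCacheUpdate:
// divide by the cap, apply tanh, then rescale back.
let capped = tanh(logits / cap) * cap
print(capped)  // small logits pass through almost unchanged; large ones saturate just below +cap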