@@ -1340,33 +1340,27 @@ private class LanguageModel: Module, KVCacheDimensionProvider {
13401340 }
13411341
13421342 func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
1343- var sanitizedWeights = [ String: MLXArray] ( )
1344-
1343+ var sanitizedWeights = weights
13451344 for (k, v) in weights {
1346- // Skip rotary embedding inverse frequency weights (matches Python exactly)
1347- if k. contains ( " self_attn.rotary_emb.inv_freq " ) {
1348- continue
1349- }
1350- // Python logic: if "language_model.model" not in k and "language_model.lm_head" not in k:
1351- else if !k. contains ( " language_model.model " ) && !k. contains ( " language_model.lm_head " ) {
1345+ if !k. contains ( " language_model.model " ) && !k. contains ( " language_model.lm_head " ) {
1346+ // Transform keys that don't contain the specific patterns
13521347 let newKey = k. replacingOccurrences (
13531348 of: " language_model " , with: " language_model.model " )
13541349 sanitizedWeights [ newKey] = v
1355- }
1356- // Otherwise, keep the key as is
1357- else {
1350+ } else if k. contains ( " self_attn.rotary_emb.inv_freq " ) {
1351+ // Skip rotary embedding inverse frequency weights
1352+ continue
1353+ } else {
13581354 sanitizedWeights [ k] = v
13591355 }
13601356 }
1361-
1362- // If lm_head weight is missing, use embed_tokens weight as fallback (matches Python exactly)
1357+ // Handle tied lm_head weights
13631358 if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
13641359 let embedTokensKey = " language_model.model.embed_tokens.weight "
13651360 if let embedWeight = sanitizedWeights [ embedTokensKey] {
13661361 sanitizedWeights [ " language_model.lm_head.weight " ] = embedWeight
13671362 }
13681363 }
1369-
13701364 return sanitizedWeights
13711365 }
13721366}
@@ -1676,7 +1670,6 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
16761670 self . _languageModel. wrappedValue = LanguageModel ( config: config. textConfig)
16771671 self . _visionTower. wrappedValue = Gemma3nVisionModel ( config: config. visionConfig)
16781672 self . _audioTower. wrappedValue = Gemma3nAudioModel ( config: config. audioConfig)
1679-
16801673 self . _embedVision. wrappedValue = Gemma3nMultimodalEmbedder (
16811674 multimodalConfig: config. visionConfig,
16821675 textConfig: config. textConfig
@@ -1893,20 +1886,16 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18931886 return languageModel ( inputs: inputs, cache: convertedCache) . logits
18941887 }
18951888
1896- // In class Gemma3n
18971889 public func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
18981890 var sanitizedWeights = [ String: MLXArray] ( )
1899-
1900- // Remove the "model." prefix from keys.
19011891 for (k, v) in weights {
1902- if k. hasPrefix ( " model. " ) {
1892+ if k. starts ( with : " model. " ) {
19031893 let newKey = k. split ( separator: " . " ) . dropFirst ( ) . joined ( separator: " . " )
19041894 sanitizedWeights [ newKey] = v
19051895 } else {
19061896 sanitizedWeights [ k] = v
19071897 }
19081898 }
1909-
19101899 return sanitizedWeights
19111900 }
19121901
@@ -1937,14 +1926,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19371926 weights. merge ( fileWeights) { _, new in new }
19381927 }
19391928
1940- // Main sanitization (remove "model." prefix)
19411929 var sanitizedWeights = model. sanitize ( weights: weights)
1942-
1943- // Vision model sanitization (transpose conv weights)
1944- sanitizedWeights = Gemma3nVisionModel . sanitizeWeights ( sanitizedWeights)
1945-
1946- // Audio model sanitization (transpose conv weights)
1947- sanitizedWeights = model. audioTower. sanitize ( weights: sanitizedWeights)
1930+ sanitizedWeights = model. visionTower. sanitize ( weights: sanitizedWeights)
1931+ // The audio and language sanitization is not done in the Python implementation
1932+ // sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933+ // sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
19481934
19491935 // Handle tied lm_head weights
19501936 if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
@@ -1992,7 +1978,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
19921978 let maxForward : Int
19931979
19941980 @ModuleInfo ( key: " pos_proj " ) var posProj : Linear
1995- @ ModuleInfo ( key : " inv_timescales " ) var invTimescales : MLXArray
1981+ private let _invTimescales : MLXArray
19961982
19971983 init ( config: AudioConfig ) {
19981984 self . config = config
@@ -2016,7 +2002,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
20162002 MLXArray ( 0 ..< numTimescales) . asType ( . float32) * ( - logTimescaleIncrement)
20172003 )
20182004
2019- self . _invTimescales. wrappedValue = expandedDimensions (
2005+ self . _invTimescales = expandedDimensions (
20202006 expandedDimensions ( invTimescales, axis: 0 ) ,
20212007 axis: 0
20222008 )
@@ -2028,7 +2014,7 @@ private class Gemma3nAudioRelativePositionEmbedding: Module {
20282014 assert ( position. ndim == 2 )
20292015 let positionFloat = expandedDimensions ( position. asType ( . float32) , axis: - 1 )
20302016
2031- let scaledTime = positionFloat * invTimescales
2017+ let scaledTime = positionFloat * _invTimescales
20322018 let timingSignal = concatenated ( [ sin ( scaledTime) , cos ( scaledTime) ] , axis: - 1 )
20332019 return timingSignal. asType ( dtype)
20342020 }
@@ -2328,6 +2314,7 @@ private class Gemma3nAudioSubSampleConvProjection: Module {
23282314
23292315 let fInPadded = currentFForBlockInput + padFLeft + padFRight
23302316 let fOutAfterConv = ( fInPadded - kernelW) / strideW + 1
2317+
23312318 calculatedFOutDims. append ( fOutAfterConv)
23322319 currentFForBlockInput = fOutAfterConv
23332320 }
@@ -2389,8 +2376,8 @@ private class Gemma3nAudioAttention: Module {
23892376 let attentionLogitsSoftCap : Float
23902377 let contextSize : Int
23912378 let qScale : Float
2392- let localCausalValidMask : MLXArray
2393- let softcap : MLXArray
2379+ private let _localCausalValidMask : MLXArray
2380+ private let _softcap : MLXArray
23942381
23952382 @ModuleInfo ( key: " relative_position_embedding " ) var relativePositionEmbedding :
23962383 Gemma3nAudioRelativePositionEmbedding
@@ -2434,9 +2421,10 @@ private class Gemma3nAudioAttention: Module {
24342421 )
24352422
24362423 let localCausalValidMaskTemp = MLXArray . ones ( [ chunkSize, contextSize] , dtype: . bool)
2437- self . localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask .&& upperCausalMask
2424+ self . _localCausalValidMask = localCausalValidMaskTemp .&& lowerCausalMask
2425+ .&& upperCausalMask
24382426
2439- self . softcap = MLXArray ( attentionLogitsSoftCap, dtype: . float32)
2427+ self . _softcap = MLXArray ( attentionLogitsSoftCap, dtype: . float32)
24402428
24412429 super. init ( )
24422430 }
@@ -2536,7 +2524,7 @@ private class Gemma3nAudioAttention: Module {
25362524
25372525 let conditionFromCausality = expandedDimensions (
25382526 expandedDimensions (
2539- expandedDimensions ( localCausalValidMask , axis: 0 ) ,
2527+ expandedDimensions ( _localCausalValidMask , axis: 0 ) ,
25402528 axis: 0
25412529 ) ,
25422530 axis: 0
@@ -2547,9 +2535,9 @@ private class Gemma3nAudioAttention: Module {
25472535 var logits = relativePositionEmbedding ( queryBlocks, keyBlocks)
25482536
25492537 // Apply attention logit softcap
2550- logits = logits / softcap
2538+ logits = logits / _softcap
25512539 logits = tanh ( logits)
2552- logits = logits * softcap
2540+ logits = logits * _softcap
25532541
25542542 // Apply the combined mask
25552543 logits = MLX . where (
@@ -2635,8 +2623,8 @@ private class Gemma3nAudioConformerFeedForward: Module {
26352623 private let _postLayerScale : MLXArray
26362624
26372625 @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2638- @ ModuleInfo ( key : " ffw_layer_1 " ) var ffwLayer1 : Linear
2639- @ ModuleInfo ( key : " ffw_layer_2 " ) var ffwLayer2 : Linear
2626+ private let _ffwLayer1 : Linear
2627+ private let _ffwLayer2 : Linear
26402628 @ModuleInfo ( key: " post_layer_norm " ) var postLayerNorm : Gemma3nRMSNormWithScale
26412629
26422630 init ( config: AudioConfig ) {
@@ -2645,8 +2633,8 @@ private class Gemma3nAudioConformerFeedForward: Module {
26452633 self . _postLayerScale = MLXArray ( config. confResidualWeight)
26462634
26472635 self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2648- self . _ffwLayer1. wrappedValue = Linear ( config. hiddenSize, config. hiddenSize * 4 , bias: false )
2649- self . _ffwLayer2. wrappedValue = Linear ( config. hiddenSize * 4 , config. hiddenSize, bias: false )
2636+ self . _ffwLayer1 = Linear ( config. hiddenSize, config. hiddenSize * 4 , bias: false )
2637+ self . _ffwLayer2 = Linear ( config. hiddenSize * 4 , config. hiddenSize, bias: false )
26502638 self . _postLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
26512639
26522640 super. init ( )
@@ -2656,9 +2644,9 @@ private class Gemma3nAudioConformerFeedForward: Module {
26562644 let residual = x
26572645 let clippedX = clip ( x, min: - _gradientClipping, max: _gradientClipping)
26582646 var result = preLayerNorm ( clippedX)
2659- result = ffwLayer1 ( result)
2647+ result = _ffwLayer1 ( result)
26602648 result = silu ( result)
2661- result = ffwLayer2 ( result)
2649+ result = _ffwLayer2 ( result)
26622650 let clippedResult = clip ( result, min: - _gradientClipping, max: _gradientClipping)
26632651 let normedResult = postLayerNorm ( clippedResult)
26642652 return residual + ( normedResult * _postLayerScale)
@@ -2737,22 +2725,22 @@ private class Gemma3nAudioConformerLightConv1d: Module {
27372725// MARK: - Conformer Block
27382726private class Gemma3nAudioConformerBlock : Module {
27392727 let config : AudioConfig
2740- private let gradientClipping : MLXArray
2728+ private let _gradientClipping : MLXArray
27412729
27422730 @ModuleInfo var ffwLayerStart : Gemma3nAudioConformerFeedForward
27432731 @ModuleInfo var attention : Gemma3nAudioConformerAttention
27442732 @ModuleInfo var lconv1d : Gemma3nAudioConformerLightConv1d
2745- @ ModuleInfo var ffwLayerEnd : Gemma3nAudioConformerFeedForward
2733+ private let _ffwLayerEnd : Gemma3nAudioConformerFeedForward
27462734 @ModuleInfo var norm : Gemma3nRMSNormWithScale
27472735
27482736 init ( config: AudioConfig ) {
27492737 self . config = config
2750- self . gradientClipping = MLXArray ( config. gradientClipping)
2738+ self . _gradientClipping = MLXArray ( config. gradientClipping)
27512739
27522740 self . _ffwLayerStart. wrappedValue = Gemma3nAudioConformerFeedForward ( config: config)
27532741 self . _attention. wrappedValue = Gemma3nAudioConformerAttention ( config: config)
27542742 self . _lconv1d. wrappedValue = Gemma3nAudioConformerLightConv1d ( config: config)
2755- self . _ffwLayerEnd. wrappedValue = Gemma3nAudioConformerFeedForward ( config: config)
2743+ self . _ffwLayerEnd = Gemma3nAudioConformerFeedForward ( config: config)
27562744 self . _norm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
27572745
27582746 super. init ( )
@@ -2770,8 +2758,8 @@ private class Gemma3nAudioConformerBlock: Module {
27702758 ) . asType ( result. dtype)
27712759
27722760 result = lconv1d ( audioencodingsForLconvInput)
2773- result = ffwLayerEnd ( result)
2774- result = clip ( result, min: - gradientClipping , max: gradientClipping )
2761+ result = _ffwLayerEnd ( result)
2762+ result = clip ( result, min: - _gradientClipping , max: _gradientClipping )
27752763 return norm ( result)
27762764 }
27772765}
@@ -2856,7 +2844,8 @@ private func numGroups(groupSize: Int?, channels: Int) -> Int {
28562844 }
28572845 // NOTE: groupSize == 1 -> depthwise conv
28582846 assert ( channels % groupSize == 0 )
2859- return channels / groupSize
2847+ let groups = channels / groupSize
2848+ return groups
28602849}
28612850
28622851private func makeDivisible(
@@ -3082,6 +3071,7 @@ private class EdgeResidual: Module, UnaryLayer {
30823071 self . hasSkip = ( inChannels == outChannels && stride == 1 ) && !noskip
30833072
30843073 let padding = ( expKernelSize - 1 ) / 2
3074+
30853075 self . _convExp. wrappedValue = Conv2d (
30863076 inputChannels: inChannels,
30873077 outputChannels: midChannels,
@@ -3195,6 +3185,7 @@ private class MultiQueryAttention2d: Module {
31953185 groups: dim, // Depthwise
31963186 bias: false
31973187 )
3188+
31983189 self . _keyNorm. wrappedValue = RMSNormAct2d ( numChannels: dim, eps: 1e-6 , applyAct: false )
31993190 } else {
32003191 self . _keyDownConv. wrappedValue = Identity ( )
@@ -3780,37 +3771,23 @@ private class Gemma3nVisionModel: Module {
37803771 }
37813772
37823773 func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
3783- return Self . sanitizeWeights ( weights)
3784- }
3785-
3786- static func sanitizeWeights( _ weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
3787- var sanitizedWeights = [ String: MLXArray] ( )
3774+ var sanitizedWeights = weights
37883775 var skipTranspose = false
3789-
3790- // This logic is correct
37913776 let testKey = " vision_tower.timm_model.blocks.0.0.conv_exp.weight "
3792- if let convWeight = weights [ testKey] {
3793- let shape = convWeight. shape
3794- if shape. count == 4 , shape [ 3 ] > shape [ 1 ] {
3795- skipTranspose = true
3796- }
3777+ if let convWeight = weights [ testKey] , convWeight. ndim == 4 ,
3778+ convWeight. shape [ 3 ] > convWeight. shape [ 1 ]
3779+ {
3780+ skipTranspose = true
37973781 }
3798-
37993782 for (k, v) in weights {
38003783 if ( k. contains ( " conv " ) && k. contains ( " weight " ) )
38013784 || ( k. contains ( " attn " ) && k. contains ( " proj.weight " ) )
38023785 {
3803- if v. shape . count == 4 && !skipTranspose {
3786+ if v. ndim == 4 && !skipTranspose {
38043787 sanitizedWeights [ k] = v. transposed ( 0 , 2 , 3 , 1 )
3805- } else {
3806- sanitizedWeights [ k] = v
38073788 }
3808- } else {
3809- // Copy all other weights (biases, norm layers, etc.)
3810- sanitizedWeights [ k] = v
38113789 }
38123790 }
3813-
38143791 return sanitizedWeights
38153792 }
38163793}
@@ -3828,8 +3805,9 @@ private class Gemma3nAudioModel: Module {
38283805
38293806 self . _subsampleConvProjection. wrappedValue = Gemma3nAudioSubSampleConvProjection (
38303807 config: config)
3831- self . _conformer. wrappedValue = ( 0 ..< config. confNumHiddenLayers) . map { _ in
3832- Gemma3nAudioConformerBlock ( config: config)
3808+
3809+ self . _conformer. wrappedValue = ( 0 ..< config. confNumHiddenLayers) . map { i in
3810+ return Gemma3nAudioConformerBlock ( config: config)
38333811 }
38343812
38353813 super. init ( )
@@ -3914,32 +3892,25 @@ private class Gemma3nAudioModel: Module {
39143892 /// Sanitizes weights by transposing convolution layers if they are not
39153893 /// already in the expected MLX format.
39163894 func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
3917- var sanitizedWeights = [ String : MLXArray ] ( )
3918-
3895+ var sanitizedWeights = weights
3896+ // Iterate over the original keys to decide which ones to modify in the copy.
39193897 for (k, v) in weights {
39203898 if k. contains ( " conv.weight " ) {
3921- // A Conv2D weight should be 4D.
3922- // If it is, check if it needs transposing from NCHW to NHWC.
3923- // If checkArrayShape is true, it's already in the correct format.
3924- if v. ndim == 4 && !checkArrayShape( v) {
3925- sanitizedWeights [ k] = v. transposed ( 0 , 2 , 3 , 1 )
3926- } else {
3899+ if checkArrayShape ( v) {
39273900 sanitizedWeights [ k] = v
3901+ } else {
3902+ sanitizedWeights [ k] = v. transposed ( 0 , 2 , 3 , 1 )
39283903 }
39293904 } else if k. contains ( " conv1d.weight " ) {
3930- // A Conv1D weight should be 3D.
3931- // If it is, check if it needs transposing from NCL to NLC.
3932- if v. ndim == 3 && !checkArrayShape( v) {
3933- sanitizedWeights [ k] = v. transposed ( 0 , 2 , 1 )
3934- } else {
3905+ if true {
39353906 sanitizedWeights [ k] = v
3907+ } else {
3908+ sanitizedWeights [ k] = v. transposed ( 0 , 2 , 1 )
39363909 }
39373910 } else {
3938- // For all other weights, keep them as they are.
39393911 sanitizedWeights [ k] = v
39403912 }
39413913 }
3942-
39433914 return sanitizedWeights
39443915 }
39453916}
0 commit comments