@@ -9,6 +9,7 @@ import Foundation
 import MLX
 import MLXLMCommon
 import MLXNN
+import ReerCodable
 
 // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/glm4.py
 
@@ -166,15 +167,13 @@ public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider {
 
     private let model: GLM4ModelInner
     let configuration: GLM4Configuration
-    let modelType: String
 
     @ModuleInfo(key: "lm_head") var lmHead: Linear
 
     public init(_ args: GLM4Configuration) {
         self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
-        self.modelType = args.modelType
         self.model = GLM4ModelInner(args)
 
         _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -196,80 +195,22 @@ public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider {
     }
 }
 
-public struct GLM4Configuration: Codable, Sendable {
-    var hiddenSize: Int
-    var hiddenLayers: Int
-    var intermediateSize: Int
-    var attentionHeads: Int
-    var attentionBias: Bool
-    var headDim: Int
-    var rmsNormEps: Float
-    var vocabularySize: Int
-    var kvHeads: Int
-    var partialRotaryFactor: Float
-    var ropeTheta: Float = 10000.0
-    var ropeTraditional: Bool = true
-    var tieWordEmbeddings = false
-    var maxPositionEmbeddings: Int = 32768
-    var modelType: String
-
-    enum CodingKeys: String, CodingKey {
-        case hiddenSize = "hidden_size"
-        case hiddenLayers = "num_hidden_layers"
-        case intermediateSize = "intermediate_size"
-        case attentionHeads = "num_attention_heads"
-        case attentionBias = "attention_bias"
-        case headDim = "head_dim"
-        case rmsNormEps = "rms_norm_eps"
-        case vocabularySize = "vocab_size"
-        case kvHeads = "num_key_value_heads"
-        case partialRotaryFactor = "partial_rotary_factor"
-        case ropeTheta = "rope_theta"
-        case ropeTraditional = "rope_traditional"
-        case tieWordEmbeddings = "tie_word_embeddings"
-        case maxPositionEmbeddings = "max_position_embeddings"
-        case modelType = "model_type"
-    }
-
-    public init(from decoder: Decoder) throws {
-        let container: KeyedDecodingContainer<GLM4Configuration.CodingKeys> =
-            try decoder.container(
-                keyedBy: GLM4Configuration.CodingKeys.self)
-
-        self.modelType = try container.decode(
-            String.self, forKey: GLM4Configuration.CodingKeys.modelType)
-        self.hiddenSize = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.attentionHeads)
-        self.attentionBias = try container.decode(
-            Bool.self, forKey: GLM4Configuration.CodingKeys.attentionBias)
-        self.headDim = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.headDim)
-        self.rmsNormEps = try container.decode(
-            Float.self, forKey: GLM4Configuration.CodingKeys.rmsNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(Int.self, forKey: GLM4Configuration.CodingKeys.kvHeads)
-        self.partialRotaryFactor = try container.decode(
-            Float.self, forKey: GLM4Configuration.CodingKeys.partialRotaryFactor)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: GLM4Configuration.CodingKeys.ropeTheta)
-            ?? 10000.0
-        self.ropeTraditional =
-            try container.decodeIfPresent(
-                Bool.self, forKey: GLM4Configuration.CodingKeys.ropeTraditional)
-            ?? true
-        self.tieWordEmbeddings =
-            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
-        self.maxPositionEmbeddings =
-            try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768
-    }
+@Codable
+public struct GLM4Configuration: Sendable {
+    @CodingKey("hidden_size") public var hiddenSize: Int
+    @CodingKey("num_hidden_layers") public var hiddenLayers: Int
+    @CodingKey("intermediate_size") public var intermediateSize: Int
+    @CodingKey("num_attention_heads") public var attentionHeads: Int
+    @CodingKey("attention_bias") public var attentionBias: Bool
+    @CodingKey("head_dim") public var headDim: Int
+    @CodingKey("rms_norm_eps") public var rmsNormEps: Float
+    @CodingKey("vocab_size") public var vocabularySize: Int
+    @CodingKey("num_key_value_heads") public var kvHeads: Int
+    @CodingKey("partial_rotary_factor") public var partialRotaryFactor: Float
+    @CodingKey("rope_theta") public var ropeTheta: Float = 10000.0
+    @CodingKey("rope_traditional") public var ropeTraditional: Bool = true
+    @CodingKey("tie_word_embeddings") public var tieWordEmbeddings = false
+    @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768
 }
 
 // MARK: - LoRA
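
A quick way to sanity-check the migration is to decode a small config snippet. The sketch below is not part of this diff: it assumes ReerCodable's @Codable macro synthesizes standard Codable conformance (so JSONDecoder keeps working) and falls back to the declared defaults when a key is missing, mirroring the removed decodeIfPresent(...) ?? default logic; the JSON values are illustrative only.

// Decoding sketch for the macro-based GLM4Configuration (illustrative values).
import Foundation

let configJSON = """
    {
        "hidden_size": 4096,
        "num_hidden_layers": 40,
        "intermediate_size": 13696,
        "num_attention_heads": 32,
        "attention_bias": true,
        "head_dim": 128,
        "rms_norm_eps": 1e-5,
        "vocab_size": 151552,
        "num_key_value_heads": 2,
        "partial_rotary_factor": 0.5
    }
    """.data(using: .utf8)!

do {
    let config = try JSONDecoder().decode(GLM4Configuration.self, from: configJSON)
    // Keys omitted above (rope_theta, rope_traditional, tie_word_embeddings,
    // max_position_embeddings) should come back as the declared defaults.
    print(config.hiddenLayers, config.ropeTheta)  // 40 10000.0
} catch {
    print("Failed to decode GLM4Configuration: \(error)")
}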