
Commit b76f9e4

use ReerCodable macro to allow for default values
- make all configuration types Sendable and their properties public var as well
1 parent 51f51cf commit b76f9e4

35 files changed (+640, -1525 lines)
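
The recurring change in the Swift files below replaces hand-written CodingKeys enums and init(from:) decoders with ReerCodable's @Codable and @CodingKey macros, with each property's initializer serving as the default when a key is missing from the JSON. The following is a minimal sketch of that pattern for orientation only; the ExampleConfiguration type and the JSON payload are invented and are not part of this commit.

import Foundation
import ReerCodable

// Hypothetical configuration mirroring the structs changed in this commit:
// snake_case JSON keys are mapped with @CodingKey, and the property
// initializers double as decode-time defaults for absent keys.
@Codable
public struct ExampleConfiguration: Sendable {
    @CodingKey("hidden_size") public var hiddenSize: Int = 4096
    @CodingKey("num_hidden_layers") public var hiddenLayers: Int = 32
    @CodingKey("rope_theta") public var ropeTheta: Float = 10_000.0
}

// "rope_theta" is absent from the payload, so ropeTheta keeps its default.
let json = #"{"hidden_size": 2048, "num_hidden_layers": 16}"#.data(using: .utf8)!
let config = try! JSONDecoder().decode(ExampleConfiguration.self, from: json)
print(config.hiddenSize, config.hiddenLayers, config.ropeTheta)  // 2048 16 10000.0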

.circleci/config.yml

Lines changed: 4 additions & 4 deletions
@@ -38,17 +38,17 @@ jobs:
            xcrun --show-sdk-build-version
            swift --version
            find . -name Package.resolved -exec rm {} \;
-           xcodebuild test -scheme mlx-libraries-Package -destination 'platform=OS X'
+           xcodebuild test -scheme mlx-libraries-Package -destination 'platform=OS X' -skipMacroValidation
      - run:
          name: Build Examples
          command: |
            xcodebuild -version
            xcrun --show-sdk-build-version
            swift --version
            find . -name Package.resolved -exec rm {} \;
-           xcodebuild -scheme llm-tool
-           xcodebuild -scheme image-tool
-           xcodebuild -scheme mnist-tool
+           xcodebuild -scheme llm-tool -skipMacroValidation
+           xcodebuild -scheme image-tool -skipMacroValidation
+           xcodebuild -scheme mnist-tool -skipMacroValidation

 workflows:
   build_and_test:
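
(Background on the new flag: ReerCodable is distributed as a Swift macro package, and non-interactive xcodebuild runs will not build macro plugins that have not been validated, so `-skipMacroValidation` is added to these CI invocations.)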

Libraries/Embedders/Pooling.swift

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ import MLX
 import MLXLinalg
 import MLXNN

-public struct PoolingConfiguration: Codable {
+public struct PoolingConfiguration: Codable, Sendable {
     public let dimension: Int
     public let poolingModeClsToken: Bool
     public let poolingModeMeanTokens: Bool

(new file)

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+import Foundation
+
+/// `swift-transformers` also declares a public `Decoder` and it conflicts with the `Codable`
+/// implementations.
+public typealias Decoder = Swift.Decoder
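
For illustration only (this example is not part of the commit): the typealias above pins unqualified `Decoder` to `Swift.Decoder` inside the library, so existing `Codable` code keeps compiling in files that also import swift-transformers. A minimal sketch with a hypothetical `Example` type; the `Tokenizers` import is an assumption about where the conflicting `Decoder` comes from.

import Foundation
import Tokenizers  // assumption: this import brings the conflicting public `Decoder` into scope

struct Example: Codable {
    var value: Int

    enum CodingKeys: String, CodingKey { case value }

    // Without the typealias, `Decoder` here would be ambiguous between
    // Swift.Decoder and the swift-transformers type; with it, this resolves to Swift.Decoder.
    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        self.value = try container.decode(Int.self, forKey: .value)
    }
}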

Libraries/MLXLLM/Lora+Data.swift

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ public func loadLoRAData(url: URL) throws -> [String] {

 func loadJSONL(url: URL) throws -> [String] {

-    struct Line: Codable {
+    struct Line: Codable, Sendable {
         let text: String?
     }

Libraries/MLXLLM/Models/Cohere.swift

Lines changed: 16 additions & 57 deletions
@@ -2,8 +2,9 @@ import Foundation
 import MLX
 import MLXLMCommon
 import MLXNN
+import ReerCodable

-// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/cohere.py
+// port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/cohere.py

 private class Attention: Module {

@@ -172,63 +173,21 @@ public class CohereModel: Module, LLMModel, KVCacheDimensionProvider {
     }
 }

-public struct CohereConfiguration: Codable, Sendable {
-
-    var hiddenSize: Int
-    var hiddenLayers: Int
-    var intermediateSize: Int
-    var attentionHeads: Int
-    var layerNormEps: Float
-    var vocabularySize: Int
-    var kvHeads: Int
-    var ropeTheta: Float = 8000000.0
-    var ropeTraditional: Bool = true
-    var ropeScaling: [String: StringOrNumber]? = nil
-    var logitScale: Float
-
-    enum CodingKeys: String, CodingKey {
-        case hiddenSize = "hidden_size"
-        case hiddenLayers = "num_hidden_layers"
-        case intermediateSize = "intermediate_size"
-        case attentionHeads = "num_attention_heads"
-        case kvHeads = "num_key_value_heads"
-        case ropeTheta = "rope_theta"
-        case vocabularySize = "vocab_size"
-        case layerNormEps = "layer_norm_eps"
-        case logitScale = "logit_scale"
-        case ropeTraditional = "rope_traditional"
-        case ropeScaling = "rope_scaling"
-    }
+@Codable
+public struct CohereConfiguration: Sendable {
+
+    @CodingKey("hidden_size") public var hiddenSize: Int = 8192
+    @CodingKey("num_hidden_layers") public var hiddenLayers: Int = 40
+    @CodingKey("intermediate_size") public var intermediateSize: Int = 22528
+    @CodingKey("num_attention_heads") public var attentionHeads: Int = 64
+    @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-5
+    @CodingKey("vocab_size") public var vocabularySize: Int = 256000
+    @CodingKey("num_key_value_heads") public var kvHeads: Int = 64
+    @CodingKey("rope_theta") public var ropeTheta: Float = 8000000.0
+    @CodingKey("rope_traditional") public var ropeTraditional: Bool = true
+    @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? = nil
+    @CodingKey("logit_scale") public var logitScale: Float = 0.0625

-    public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
-        let container: KeyedDecodingContainer<CohereConfiguration.CodingKeys> =
-            try decoder.container(
-                keyedBy: CohereConfiguration.CodingKeys.self)
-
-        self.hiddenSize = try container.decode(
-            Int.self, forKey: CohereConfiguration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: CohereConfiguration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: CohereConfiguration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: CohereConfiguration.CodingKeys.attentionHeads)
-        self.layerNormEps = try container.decode(
-            Float.self, forKey: CohereConfiguration.CodingKeys.layerNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: CohereConfiguration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(
-            Int.self, forKey: CohereConfiguration.CodingKeys.kvHeads)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: CohereConfiguration.CodingKeys.ropeTheta)
-            ?? 8000000.0
-        self.ropeScaling = try container.decodeIfPresent(
-            [String: StringOrNumber].self, forKey: CohereConfiguration.CodingKeys.ropeScaling)
-        self.logitScale = try container.decode(
-            Float.self, forKey: CohereConfiguration.CodingKeys.logitScale)
-    }
 }

 // MARK: - LoRA

Libraries/MLXLLM/Models/GLM4.swift

Lines changed: 17 additions & 76 deletions
@@ -9,6 +9,7 @@ import Foundation
 import MLX
 import MLXLMCommon
 import MLXNN
+import ReerCodable

 // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/glm4.py

@@ -166,15 +167,13 @@ public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider {

     private let model: GLM4ModelInner
     let configuration: GLM4Configuration
-    let modelType: String

     @ModuleInfo(key: "lm_head") var lmHead: Linear

     public init(_ args: GLM4Configuration) {
         self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
-        self.modelType = args.modelType
         self.model = GLM4ModelInner(args)

         _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -196,80 +195,22 @@ public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider {
     }
 }

-public struct GLM4Configuration: Codable, Sendable {
-    var hiddenSize: Int
-    var hiddenLayers: Int
-    var intermediateSize: Int
-    var attentionHeads: Int
-    var attentionBias: Bool
-    var headDim: Int
-    var rmsNormEps: Float
-    var vocabularySize: Int
-    var kvHeads: Int
-    var partialRotaryFactor: Float
-    var ropeTheta: Float = 10000.0
-    var ropeTraditional: Bool = true
-    var tieWordEmbeddings = false
-    var maxPositionEmbeddings: Int = 32768
-    var modelType: String
-
-    enum CodingKeys: String, CodingKey {
-        case hiddenSize = "hidden_size"
-        case hiddenLayers = "num_hidden_layers"
-        case intermediateSize = "intermediate_size"
-        case attentionHeads = "num_attention_heads"
-        case attentionBias = "attention_bias"
-        case headDim = "head_dim"
-        case rmsNormEps = "rms_norm_eps"
-        case vocabularySize = "vocab_size"
-        case kvHeads = "num_key_value_heads"
-        case partialRotaryFactor = "partial_rotary_factor"
-        case ropeTheta = "rope_theta"
-        case ropeTraditional = "rope_traditional"
-        case tieWordEmbeddings = "tie_word_embeddings"
-        case maxPositionEmbeddings = "max_position_embeddings"
-        case modelType = "model_type"
-    }
-
-    public init(from decoder: Decoder) throws {
-        let container: KeyedDecodingContainer<GLM4Configuration.CodingKeys> =
-            try decoder.container(
-                keyedBy: GLM4Configuration.CodingKeys.self)
-
-        self.modelType = try container.decode(
-            String.self, forKey: GLM4Configuration.CodingKeys.modelType)
-        self.hiddenSize = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.attentionHeads)
-        self.attentionBias = try container.decode(
-            Bool.self, forKey: GLM4Configuration.CodingKeys.attentionBias)
-        self.headDim = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.headDim)
-        self.rmsNormEps = try container.decode(
-            Float.self, forKey: GLM4Configuration.CodingKeys.rmsNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: GLM4Configuration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(Int.self, forKey: GLM4Configuration.CodingKeys.kvHeads)
-        self.partialRotaryFactor = try container.decode(
-            Float.self, forKey: GLM4Configuration.CodingKeys.partialRotaryFactor)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: GLM4Configuration.CodingKeys.ropeTheta)
-            ?? 10000.0
-        self.ropeTraditional =
-            try container.decodeIfPresent(
-                Bool.self, forKey: GLM4Configuration.CodingKeys.ropeTraditional)
-            ?? true
-        self.tieWordEmbeddings =
-            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
-        self.maxPositionEmbeddings =
-            try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768
-    }
+@Codable
+public struct GLM4Configuration: Sendable {
+    @CodingKey("hidden_size") public var hiddenSize: Int
+    @CodingKey("num_hidden_layers") public var hiddenLayers: Int
+    @CodingKey("intermediate_size") public var intermediateSize: Int
+    @CodingKey("num_attention_heads") public var attentionHeads: Int
+    @CodingKey("attention_bias") public var attentionBias: Bool
+    @CodingKey("head_dim") public var headDim: Int
+    @CodingKey("rms_norm_eps") public var rmsNormEps: Float
+    @CodingKey("vocab_size") public var vocabularySize: Int
+    @CodingKey("num_key_value_heads") public var kvHeads: Int
+    @CodingKey("partial_rotary_factor") public var partialRotaryFactor: Float
+    @CodingKey("rope_theta") public var ropeTheta: Float = 10000.0
+    @CodingKey("rope_traditional") public var ropeTraditional: Bool = true
+    @CodingKey("tie_word_embeddings") public var tieWordEmbeddings = false
+    @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768
 }

 // MARK: - LoRA

Libraries/MLXLLM/Models/Gemma.swift

Lines changed: 14 additions & 31 deletions
@@ -5,8 +5,9 @@ import MLX
 import MLXLMCommon
 import MLXNN
 import Tokenizers
+import ReerCodable

-// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
+// Port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/gemma.py

 // Specialized norm for Gemma
 private class RMSNorm: Module, UnaryLayer {
@@ -178,11 +179,9 @@ public class GemmaModel: Module, LLMModel, KVCacheDimensionProvider {
     public let vocabularySize: Int
     public let kvHeads: [Int]

-    let modelType: String
     private let model: GemmaModelInner

     public init(_ args: GemmaConfiguration) {
-        self.modelType = args.modelType
         self.vocabularySize = args.vocabularySize
         self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
         self.model = GemmaModelInner(args)
@@ -198,34 +197,18 @@ public class GemmaModel: Module, LLMModel, KVCacheDimensionProvider {
     }
 }

-public struct GemmaConfiguration: Codable, Sendable {
-    var modelType: String
-    var hiddenSize: Int
-    var hiddenLayers: Int
-    var intermediateSize: Int
-    var attentionHeads: Int
-    var headDimensions: Int
-    var rmsNormEps: Float
-    var vocabularySize: Int
-    var kvHeads: Int
-    private let _ropeTheta: Float?
-    public var ropeTheta: Float { _ropeTheta ?? 10_000 }
-    private let _ropeTraditional: Bool?
-    public var ropeTraditional: Bool { _ropeTraditional ?? false }
-
-    enum CodingKeys: String, CodingKey {
-        case modelType = "model_type"
-        case hiddenSize = "hidden_size"
-        case hiddenLayers = "num_hidden_layers"
-        case intermediateSize = "intermediate_size"
-        case attentionHeads = "num_attention_heads"
-        case headDimensions = "head_dim"
-        case rmsNormEps = "rms_norm_eps"
-        case vocabularySize = "vocab_size"
-        case kvHeads = "num_key_value_heads"
-        case _ropeTheta = "rope_theta"
-        case _ropeTraditional = "rope_traditional"
-    }
+@Codable
+public struct GemmaConfiguration: Sendable {
+    @CodingKey("hidden_size") public var hiddenSize: Int
+    @CodingKey("num_hidden_layers") public var hiddenLayers: Int
+    @CodingKey("intermediate_size") public var intermediateSize: Int
+    @CodingKey("num_attention_heads") public var attentionHeads: Int
+    @CodingKey("head_dim") public var headDimensions: Int
+    @CodingKey("rms_norm_eps") public var rmsNormEps: Float
+    @CodingKey("vocab_size") public var vocabularySize: Int
+    @CodingKey("num_key_value_heads") public var kvHeads: Int
+    @CodingKey("rope_theta") public var ropeTheta: Float = 10_000
+    @CodingKey("rope_traditional") public var ropeTraditional: Bool = false
 }

 // MARK: - LoRA
