diff --git a/.circleci/config.yml b/.circleci/config.yml index 5a97f48a..8ffa27c1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -38,7 +38,7 @@ jobs: xcrun --show-sdk-build-version swift --version find . -name Package.resolved -exec rm {} \; - xcodebuild test -scheme mlx-libraries-Package -destination 'platform=OS X' + xcodebuild test -scheme mlx-libraries-Package -destination 'platform=OS X' -skipMacroValidation - run: name: Build Examples command: | @@ -46,9 +46,9 @@ jobs: xcrun --show-sdk-build-version swift --version find . -name Package.resolved -exec rm {} \; - xcodebuild -scheme llm-tool - xcodebuild -scheme image-tool - xcodebuild -scheme mnist-tool + xcodebuild -scheme llm-tool -skipMacroValidation + xcodebuild -scheme image-tool -skipMacroValidation + xcodebuild -scheme mnist-tool -skipMacroValidation workflows: build_and_test: diff --git a/Libraries/Embedders/Pooling.swift b/Libraries/Embedders/Pooling.swift index 912f37b8..39341416 100644 --- a/Libraries/Embedders/Pooling.swift +++ b/Libraries/Embedders/Pooling.swift @@ -2,23 +2,16 @@ import Foundation import MLX -import MLXLinalg import MLXNN +import ReerCodable -public struct PoolingConfiguration: Codable { - public let dimension: Int - public let poolingModeClsToken: Bool - public let poolingModeMeanTokens: Bool - public let poolingModeMaxTokens: Bool - public let poolingModeLastToken: Bool - - enum CodingKeys: String, CodingKey { - case dimension = "word_embedding_dimension" - case poolingModeClsToken = "pooling_mode_cls_token" - case poolingModeMeanTokens = "pooling_mode_mean_tokens" - case poolingModeMaxTokens = "pooling_mode_max_tokens" - case poolingModeLastToken = "pooling_mode_lasttoken" - } +@Codable +public struct PoolingConfiguration: Sendable { + @CodingKey("word_embedding_dimension") public let dimension: Int + @CodingKey("pooling_mode_cls_token") public let poolingModeClsToken: Bool + @CodingKey("pooling_mode_mean_tokens") public let poolingModeMeanTokens: Bool + 
@CodingKey("pooling_mode_max_tokens") public let poolingModeMaxTokens: Bool + @CodingKey("pooling_mode_lasttoken") public let poolingModeLastToken: Bool } func loadPooling(modelDirectory: URL) -> Pooling { diff --git a/Libraries/MLXLLM/Codable+Support.swift b/Libraries/MLXLLM/Codable+Support.swift new file mode 100644 index 00000000..84329926 --- /dev/null +++ b/Libraries/MLXLLM/Codable+Support.swift @@ -0,0 +1,5 @@ +import Foundation + +/// `swift-transformers` also declares a public `Decoder` and it conflicts with the `Codable` +/// implementations. +public typealias Decoder = Swift.Decoder diff --git a/Libraries/MLXLLM/Documentation.docc/adding-model.md b/Libraries/MLXLLM/Documentation.docc/adding-model.md index a8662374..b3cb7190 100644 --- a/Libraries/MLXLLM/Documentation.docc/adding-model.md +++ b/Libraries/MLXLLM/Documentation.docc/adding-model.md @@ -14,17 +14,12 @@ and create a `.swift` file for your new model: Create a configuration struct to match the `config.json` (any parameters needed). ```swift -public struct YourModelConfiguration: Codable, Sendable { - public let hiddenSize: Int - - // use this pattern for values that need defaults - public let _layerNormEps: Float? - public var layerNormEps: Float { _layerNormEps ?? 
1e-6 } - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case _layerNormEps = "layer_norm_eps" - } +import ReerCodable + +@Codable +public struct YourModelConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-6 } ``` diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index ea8aedba..7954005a 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -35,9 +35,9 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable { "phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init), "gemma": create(GemmaConfiguration.self, GemmaModel.init), "gemma2": create(Gemma2Configuration.self, Gemma2Model.init), - "gemma3": create(Gemma3TextConfiguration.self, Gemma3TextModel.init), - "gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init), - "gemma3n": create(Gemma3nTextConfiguration.self, Gemma3nTextModel.init), + "gemma3": create(Gemma3TextConfigurationContainer.self, Gemma3TextModel.init), + "gemma3_text": create(Gemma3TextConfigurationContainer.self, Gemma3TextModel.init), + "gemma3n": create(Gemma3nTextConfigurationContainer.self, Gemma3nTextModel.init), "qwen2": create(Qwen2Configuration.self, Qwen2Model.init), "qwen3": create(Qwen3Configuration.self, Qwen3Model.init), "qwen3_moe": create(Qwen3MoEConfiguration.self, Qwen3MoEModel.init), diff --git a/Libraries/MLXLLM/Lora+Data.swift b/Libraries/MLXLLM/Lora+Data.swift index 975e41f4..defa8f5a 100644 --- a/Libraries/MLXLLM/Lora+Data.swift +++ b/Libraries/MLXLLM/Lora+Data.swift @@ -48,7 +48,7 @@ public func loadLoRAData(url: URL) throws -> [String] { func loadJSONL(url: URL) throws -> [String] { - struct Line: Codable { + struct Line: Codable, Sendable { let text: String? 
} diff --git a/Libraries/MLXLLM/Models/BaichuanM1.swift b/Libraries/MLXLLM/Models/BaichuanM1.swift index c2d707fb..5c4712fe 100644 --- a/Libraries/MLXLLM/Models/BaichuanM1.swift +++ b/Libraries/MLXLLM/Models/BaichuanM1.swift @@ -7,43 +7,26 @@ import Foundation import MLX -import MLXFast import MLXLMCommon import MLXNN -import MLXRandom - -public struct BaichuanM1Configuration: Codable, Sendable { - var vocabularySize: Int - var hiddenSize: Int - var intermediateSize: Int - var hiddenLayers: Int - var attentionHeads: Int - var kvHeads: Int - var ropeTheta: Float - var slidingWindow: Int - var slidingWindowLayers: [Int] - var convWindow: Int - var rmsNormEps: Float - var swaAttentionHeads: Int? - var swaKvHeads: Int? - var tieWordEmbeddings: Bool = false - - enum CodingKeys: String, CodingKey { - case vocabularySize = "vocab_size" - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case hiddenLayers = "num_hidden_layers" - case attentionHeads = "num_attention_heads" - case kvHeads = "num_key_value_heads" - case ropeTheta = "rope_theta" - case slidingWindow = "sliding_window" - case slidingWindowLayers = "sliding_window_layers" - case convWindow = "conv_window" - case rmsNormEps = "rms_norm_eps" - case swaAttentionHeads = "num_swa_attention_heads" - case swaKvHeads = "num_swa_key_value_heads" - case tieWordEmbeddings = "tie_word_embeddings" - } +import ReerCodable + +@Codable +public struct BaichuanM1Configuration: Sendable { + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float + @CodingKey("sliding_window") public var slidingWindow: Int + 
@CodingKey("sliding_window_layers") public var slidingWindowLayers: [Int] + @CodingKey("conv_window") public var convWindow: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("num_swa_attention_heads") public var swaAttentionHeads: Int? + @CodingKey("num_swa_key_value_heads") public var swaKvHeads: Int? + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = false } private class Attention: Module { diff --git a/Libraries/MLXLLM/Models/BailingMoe.swift b/Libraries/MLXLLM/Models/BailingMoe.swift index 4feba094..deb7fe32 100644 --- a/Libraries/MLXLLM/Models/BailingMoe.swift +++ b/Libraries/MLXLLM/Models/BailingMoe.swift @@ -10,69 +10,41 @@ import Foundation import MLX import MLXLMCommon import MLXNN - -public struct BailingMoeConfiguration: Codable, Sendable { - var modelType: String - var hiddenSize: Int - var intermediateSize: Int - var maxPositionEmbeddings: Int? - var moeIntermediateSize: Int - var numExperts: Int - var numSharedExperts: Int - var normTopkProb: Bool - var attentionHeads: Int - var numExpertsPerToken: Int - var hiddenLayers: Int - var kvHeads: Int - var rmsNormEps: Float - var ropeTheta: Float - var vocabularySize: Int - var firstKDenseReplace: Int +import ReerCodable + +@Codable +public struct BailingMoeConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int? 
+ @CodingKey("moe_intermediate_size") public var moeIntermediateSize: Int + @CodingKey("num_experts") public var numExperts: Int + @CodingKey("num_shared_experts") public var numSharedExperts: Int + @CodingKey("norm_topk_prob") public var normTopkProb: Bool + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("num_experts_per_tok") public var numExpertsPerToken: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("rope_theta") public var ropeTheta: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("first_k_dense_replace") public var firstKDenseReplace: Int // Optional features - var ropeScaling: [String: StringOrNumber]? = nil - var useBias: Bool = false - var useQKVBias: Bool = false - var useQKNorm: Bool = false - var tieWordEmbeddings: Bool = false - var partialRotaryFactor: Float = 1.0 - var moeRouterEnableExpertBias: Bool = false - var routedScalingFactor: Float = 1.0 - var scoreFunction: String = "softmax" - var nGroup: Int = 1 - var topkGroup: Int = 4 - var moeSharedExpertIntermediateSize: Int? 
= nil - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case maxPositionEmbeddings = "max_position_embeddings" - case moeIntermediateSize = "moe_intermediate_size" - case numExperts = "num_experts" - case numSharedExperts = "num_shared_experts" - case normTopkProb = "norm_topk_prob" - case attentionHeads = "num_attention_heads" - case numExpertsPerToken = "num_experts_per_tok" - case hiddenLayers = "num_hidden_layers" - case kvHeads = "num_key_value_heads" - case rmsNormEps = "rms_norm_eps" - case ropeTheta = "rope_theta" - case vocabularySize = "vocab_size" - case firstKDenseReplace = "first_k_dense_replace" - case ropeScaling = "rope_scaling" - case useBias = "use_bias" - case useQKVBias = "use_qkv_bias" - case useQKNorm = "use_qk_norm" - case tieWordEmbeddings = "tie_word_embeddings" - case partialRotaryFactor = "partial_rotary_factor" - case moeRouterEnableExpertBias = "moe_router_enable_expert_bias" - case routedScalingFactor = "routed_scaling_factor" - case scoreFunction = "score_function" - case nGroup = "n_group" - case topkGroup = "topk_group" - case moeSharedExpertIntermediateSize = "moe_shared_expert_intermediate_size" - } + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? 
= nil + @CodingKey("use_bias") public var useBias: Bool = false + @CodingKey("use_qkv_bias") public var useQKVBias: Bool = false + @CodingKey("use_qk_norm") public var useQKNorm: Bool = false + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = false + @CodingKey("partial_rotary_factor") public var partialRotaryFactor: Float = 1.0 + @CodingKey("moe_router_enable_expert_bias") public var moeRouterEnableExpertBias: Bool = false + @CodingKey("routed_scaling_factor") public var routedScalingFactor: Float = 1.0 + @CodingKey("score_function") public var scoreFunction: String = "softmax" + @CodingKey("n_group") public var nGroup: Int = 1 + @CodingKey("topk_group") public var topkGroup: Int = 4 + @CodingKey("moe_shared_expert_intermediate_size") public var moeSharedExpertIntermediateSize: + Int? = nil } private class Attention: Module { diff --git a/Libraries/MLXLLM/Models/Bitnet.swift b/Libraries/MLXLLM/Models/Bitnet.swift index bda67664..866d4bdf 100644 --- a/Libraries/MLXLLM/Models/Bitnet.swift +++ b/Libraries/MLXLLM/Models/Bitnet.swift @@ -7,9 +7,9 @@ import Foundation import MLX -import MLXFast import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/bitnet.py @@ -55,7 +55,7 @@ private func makeBitLinearKernel() -> MLXFast.MLXFastKernel { } """ - return metalKernel( + return MLXFast.metalKernel( name: "bitlinear_matmul", inputNames: ["x", "packed_weights", "weight_scale"], outputNames: ["out"], @@ -155,113 +155,32 @@ private class BitLinear: Module { // MARK: - Model Configuration -public struct BitnetConfiguration: Codable, Sendable { - var modelType: String - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var headDimensions: Int? - var maxPositionEmbeddings: Int? - var kvHeads: Int? 
- var attentionBias: Bool - var mlpBias: Bool - var ropeTheta: Float - var ropeTraditional: Bool - var ropeScaling: [String: StringOrNumber]? - var tieWordEmbeddings: Bool - - public init( - modelType: String = "bitnet", - hiddenSize: Int, - hiddenLayers: Int, - intermediateSize: Int, - attentionHeads: Int, - rmsNormEps: Float, - vocabularySize: Int, - headDimensions: Int? = nil, - maxPositionEmbeddings: Int? = nil, - kvHeads: Int? = nil, - attentionBias: Bool = false, - mlpBias: Bool = false, - ropeTheta: Float = 10000, - ropeTraditional: Bool = false, - ropeScaling: [String: StringOrNumber]? = nil, - tieWordEmbeddings: Bool = true - ) { - self.modelType = modelType - self.hiddenSize = hiddenSize - self.hiddenLayers = hiddenLayers - self.intermediateSize = intermediateSize - self.attentionHeads = attentionHeads - self.rmsNormEps = rmsNormEps - self.vocabularySize = vocabularySize - self.headDimensions = headDimensions - self.maxPositionEmbeddings = maxPositionEmbeddings - self.kvHeads = kvHeads ?? attentionHeads - self.attentionBias = attentionBias - self.mlpBias = mlpBias - self.ropeTheta = ropeTheta - self.ropeTraditional = ropeTraditional - self.ropeScaling = ropeScaling - self.tieWordEmbeddings = tieWordEmbeddings - } - - var resolvedKvHeads: Int { +@Codable +public struct BitnetConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String = "bitnet" + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("head_dim") public var headDimensions: Int? + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int? + @CodingKey("num_key_value_heads") public var kvHeads: Int? 
+ @CodingKey("attention_bias") public var attentionBias: Bool = false + @CodingKey("mlp_bias") public var mlpBias: Bool = false + @CodingKey("rope_theta") public var ropeTheta: Float = 10000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true + + public var resolvedKvHeads: Int { kvHeads ?? attentionHeads } - var resolvedHeadDimensions: Int { + public var resolvedHeadDimensions: Int { headDimensions ?? (hiddenSize / attentionHeads) } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case headDimensions = "head_dim" - case maxPositionEmbeddings = "max_position_embeddings" - case kvHeads = "num_key_value_heads" - case attentionBias = "attention_bias" - case mlpBias = "mlp_bias" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - } - - public init(from decoder: Swift.Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - modelType = try container.decodeIfPresent(String.self, forKey: .modelType) ?? 
"bitnet" - hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions) - maxPositionEmbeddings = try container.decodeIfPresent( - Int.self, forKey: .maxPositionEmbeddings - ) - kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads - attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) ?? false - mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false - ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10000 - ropeTraditional = - try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false - ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling - ) - tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? 
true - } } // MARK: - Attention diff --git a/Libraries/MLXLLM/Models/Cohere.swift b/Libraries/MLXLLM/Models/Cohere.swift index 470c9799..23f15c31 100644 --- a/Libraries/MLXLLM/Models/Cohere.swift +++ b/Libraries/MLXLLM/Models/Cohere.swift @@ -2,8 +2,9 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable -// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/cohere.py +// port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/cohere.py private class Attention: Module { @@ -172,63 +173,21 @@ public class CohereModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct CohereConfiguration: Codable, Sendable { - - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var layerNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var ropeTheta: Float = 8000000.0 - var ropeTraditional: Bool = true - var ropeScaling: [String: StringOrNumber]? = nil - var logitScale: Float - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case kvHeads = "num_key_value_heads" - case ropeTheta = "rope_theta" - case vocabularySize = "vocab_size" - case layerNormEps = "layer_norm_eps" - case logitScale = "logit_scale" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - } +@Codable +public struct CohereConfiguration: Sendable { + + @CodingKey("hidden_size") public var hiddenSize: Int = 8192 + @CodingKey("num_hidden_layers") public var hiddenLayers: Int = 40 + @CodingKey("intermediate_size") public var intermediateSize: Int = 22528 + @CodingKey("num_attention_heads") public var attentionHeads: Int = 64 + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-5 + @CodingKey("vocab_size") public var vocabularySize: Int = 256000 + @CodingKey("num_key_value_heads") 
public var kvHeads: Int = 64 + @CodingKey("rope_theta") public var ropeTheta: Float = 8000000.0 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = true + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? = nil + @CodingKey("logit_scale") public var logitScale: Float = 0.0625 - public init(from decoder: Decoder) throws { - // custom implementation to handle optional keys with required values - let container: KeyedDecodingContainer = - try decoder.container( - keyedBy: CohereConfiguration.CodingKeys.self) - - self.hiddenSize = try container.decode( - Int.self, forKey: CohereConfiguration.CodingKeys.hiddenSize) - self.hiddenLayers = try container.decode( - Int.self, forKey: CohereConfiguration.CodingKeys.hiddenLayers) - self.intermediateSize = try container.decode( - Int.self, forKey: CohereConfiguration.CodingKeys.intermediateSize) - self.attentionHeads = try container.decode( - Int.self, forKey: CohereConfiguration.CodingKeys.attentionHeads) - self.layerNormEps = try container.decode( - Float.self, forKey: CohereConfiguration.CodingKeys.layerNormEps) - self.vocabularySize = try container.decode( - Int.self, forKey: CohereConfiguration.CodingKeys.vocabularySize) - self.kvHeads = try container.decode( - Int.self, forKey: CohereConfiguration.CodingKeys.kvHeads) - self.ropeTheta = - try container.decodeIfPresent( - Float.self, forKey: CohereConfiguration.CodingKeys.ropeTheta) - ?? 
8000000.0 - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: CohereConfiguration.CodingKeys.ropeScaling) - self.logitScale = try container.decode( - Float.self, forKey: CohereConfiguration.CodingKeys.logitScale) - } } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/DeepseekV3.swift b/Libraries/MLXLLM/Models/DeepseekV3.swift index 6e577188..0daf8c0b 100644 --- a/Libraries/MLXLLM/Models/DeepseekV3.swift +++ b/Libraries/MLXLLM/Models/DeepseekV3.swift @@ -2,67 +2,39 @@ import Foundation import MLX -import MLXFast import MLXLLM import MLXLMCommon import MLXNN - -public struct DeepseekV3Configuration: Codable, Sendable { - var vocabSize: Int - var hiddenSize: Int - var intermediateSize: Int - var moeIntermediateSize: Int - var numHiddenLayers: Int - var numAttentionHeads: Int - var numKeyValueHeads: Int - var nSharedExperts: Int? - var nRoutedExperts: Int? - var routedScalingFactor: Float - var kvLoraRank: Int - var qLoraRank: Int - var qkRopeHeadDim: Int - var vHeadDim: Int - var qkNopeHeadDim: Int - var normTopkProb: Bool - var nGroup: Int? - var topkGroup: Int? - var numExpertsPerTok: Int? - var moeLayerFreq: Int - var firstKDenseReplace: Int - var maxPositionEmbeddings: Int - var rmsNormEps: Float - var ropeTheta: Float - var ropeScaling: [String: StringOrNumber]? 
- var attentionBias: Bool - - enum CodingKeys: String, CodingKey { - case vocabSize = "vocab_size" - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case moeIntermediateSize = "moe_intermediate_size" - case numHiddenLayers = "num_hidden_layers" - case numAttentionHeads = "num_attention_heads" - case numKeyValueHeads = "num_key_value_heads" - case nSharedExperts = "n_shared_experts" - case nRoutedExperts = "n_routed_experts" - case routedScalingFactor = "routed_scaling_factor" - case kvLoraRank = "kv_lora_rank" - case qLoraRank = "q_lora_rank" - case qkRopeHeadDim = "qk_rope_head_dim" - case vHeadDim = "v_head_dim" - case qkNopeHeadDim = "qk_nope_head_dim" - case normTopkProb = "norm_topk_prob" - case nGroup = "n_group" - case topkGroup = "topk_group" - case numExpertsPerTok = "num_experts_per_tok" - case moeLayerFreq = "moe_layer_freq" - case firstKDenseReplace = "first_k_dense_replace" - case maxPositionEmbeddings = "max_position_embeddings" - case rmsNormEps = "rms_norm_eps" - case ropeTheta = "rope_theta" - case ropeScaling = "rope_scaling" - case attentionBias = "attention_bias" - } +import ReerCodable + +@Codable +public struct DeepseekV3Configuration: Sendable { + @CodingKey("vocab_size") public var vocabSize: Int + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("moe_intermediate_size") public var moeIntermediateSize: Int + @CodingKey("num_hidden_layers") public var numHiddenLayers: Int + @CodingKey("num_attention_heads") public var numAttentionHeads: Int + @CodingKey("num_key_value_heads") public var numKeyValueHeads: Int + @CodingKey("n_shared_experts") public var nSharedExperts: Int? + @CodingKey("n_routed_experts") public var nRoutedExperts: Int? 
+ @CodingKey("routed_scaling_factor") public var routedScalingFactor: Float + @CodingKey("kv_lora_rank") public var kvLoraRank: Int + @CodingKey("q_lora_rank") public var qLoraRank: Int + @CodingKey("qk_rope_head_dim") public var qkRopeHeadDim: Int + @CodingKey("v_head_dim") public var vHeadDim: Int + @CodingKey("qk_nope_head_dim") public var qkNopeHeadDim: Int + @CodingKey("norm_topk_prob") public var normTopkProb: Bool + @CodingKey("n_group") public var nGroup: Int? + @CodingKey("topk_group") public var topkGroup: Int? + @CodingKey("num_experts_per_tok") public var numExpertsPerTok: Int? + @CodingKey("moe_layer_freq") public var moeLayerFreq: Int + @CodingKey("first_k_dense_replace") public var firstKDenseReplace: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("rope_theta") public var ropeTheta: Float + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? + @CodingKey("attention_bias") public var attentionBias: Bool } private func yarnFindCorrectionDim( diff --git a/Libraries/MLXLLM/Models/Ernie4_5.swift b/Libraries/MLXLLM/Models/Ernie4_5.swift index 6be2ba7d..b9088c56 100644 --- a/Libraries/MLXLLM/Models/Ernie4_5.swift +++ b/Libraries/MLXLLM/Models/Ernie4_5.swift @@ -9,55 +9,24 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/ernie4_5.py -public struct Ernie45Configuration: Codable { - var hiddenSize: Int - var intermediateSize: Int - var maxPositionEmbeddings: Int - var numAttentionHeads: Int - var numKeyValueHeads: Int - var headDim: Int? 
- var numHiddenLayers: Int - var rmsNormEps: Float - var vocabularySize: Int - var ropeTheta: Float - var useBias: Bool - var tieWordEmbeddings: Bool - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case maxPositionEmbeddings = "max_position_embeddings" - case numAttentionHeads = "num_attention_heads" - case numKeyValueHeads = "num_key_value_heads" - case headDim = "head_dim" - case numHiddenLayers = "num_hidden_layers" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case ropeTheta = "rope_theta" - case useBias = "use_bias" - case tieWordEmbeddings = "tie_word_embeddings" - } - - public init(from decoder: Decoder) throws { - let container: KeyedDecodingContainer = - try decoder.container(keyedBy: Ernie45Configuration.CodingKeys.self) - - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings) - self.numAttentionHeads = try container.decode(Int.self, forKey: .numAttentionHeads) - self.numKeyValueHeads = try container.decode(Int.self, forKey: .numKeyValueHeads) - self.headDim = try container.decode(Int.self, forKey: .headDim) - self.numHiddenLayers = try container.decode(Int.self, forKey: .numHiddenLayers) - self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - self.ropeTheta = try container.decode(Float.self, forKey: .ropeTheta) - self.useBias = try container.decode(Bool.self, forKey: .useBias) - self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings) - } +@Codable +public struct Ernie45Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + 
@CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int + @CodingKey("num_attention_heads") public var numAttentionHeads: Int + @CodingKey("num_key_value_heads") public var numKeyValueHeads: Int + @CodingKey("head_dim") public var headDim: Int? + @CodingKey("num_hidden_layers") public var numHiddenLayers: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("rope_theta") public var ropeTheta: Float + @CodingKey("use_bias") public var useBias: Bool + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool } private class Attention: Module { diff --git a/Libraries/MLXLLM/Models/Exaone4.swift b/Libraries/MLXLLM/Models/Exaone4.swift index 9c90bd69..5b8cde82 100644 --- a/Libraries/MLXLLM/Models/Exaone4.swift +++ b/Libraries/MLXLLM/Models/Exaone4.swift @@ -7,9 +7,9 @@ import Foundation import MLX -import MLXFast import MLXLMCommon import MLXNN +import ReerCodable // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/exaone4.py @@ -242,59 +242,22 @@ public class Exaone4Model: Module, LLMModel, KVCacheDimensionProvider { } } -public struct Exaone4Configuration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var maxPositionEmbeddings: Int - var ropeTheta: Float - var headDim: Int - var tieWordEmbeddings: Bool - var ropeScaling: [String: StringOrNumber]? - var slidingWindow: Int? - var slidingWindowPattern: String? 
- - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeTheta = "rope_theta" - case headDim = "head_dim" - case tieWordEmbeddings = "tie_word_embeddings" - case ropeScaling = "rope_scaling" - case slidingWindow = "sliding_window" - case slidingWindowPattern = "sliding_window_pattern" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - self.kvHeads = try container.decode(Int.self, forKey: .kvHeads) - self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings) - self.ropeTheta = try container.decode(Float.self, forKey: .ropeTheta) - self.headDim = try container.decode(Int.self, forKey: .headDim) - self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings) - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - self.slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) - self.slidingWindowPattern = try container.decodeIfPresent( - String.self, forKey: .slidingWindowPattern) - } +@Codable +public struct Exaone4Configuration: Sendable { + @CodingKey("hidden_size") public var 
hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int + @CodingKey("rope_theta") public var ropeTheta: Float + @CodingKey("head_dim") public var headDim: Int + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? + @CodingKey("sliding_window") public var slidingWindow: Int? + @CodingKey("sliding_window_pattern") public var slidingWindowPattern: String? } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/GLM4.swift b/Libraries/MLXLLM/Models/GLM4.swift index 44b07cd4..7066ffd8 100644 --- a/Libraries/MLXLLM/Models/GLM4.swift +++ b/Libraries/MLXLLM/Models/GLM4.swift @@ -9,6 +9,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/glm4.py @@ -166,7 +167,6 @@ public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider { private let model: GLM4ModelInner let configuration: GLM4Configuration - let modelType: String @ModuleInfo(key: "lm_head") var lmHead: Linear @@ -174,7 +174,6 @@ public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider { self.configuration = args self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.modelType = args.modelType self.model = GLM4ModelInner(args) _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) @@ -196,80 +195,22 @@ public class GLM4Model: Module, LLMModel, KVCacheDimensionProvider { } } -public struct GLM4Configuration: Codable, 
Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var attentionBias: Bool - var headDim: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var partialRotaryFactor: Float - var ropeTheta: Float = 10000.0 - var ropeTraditional: Bool = true - var tieWordEmbeddings = false - var maxPositionEmbeddings: Int = 32768 - var modelType: String - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case attentionBias = "attention_bias" - case headDim = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case partialRotaryFactor = "partial_rotary_factor" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case tieWordEmbeddings = "tie_word_embeddings" - case maxPositionEmbeddings = "max_position_embeddings" - case modelType = "model_type" - } - - public init(from decoder: Decoder) throws { - let container: KeyedDecodingContainer = - try decoder.container( - keyedBy: GLM4Configuration.CodingKeys.self) - - self.modelType = try container.decode( - String.self, forKey: GLM4Configuration.CodingKeys.modelType) - self.hiddenSize = try container.decode( - Int.self, forKey: GLM4Configuration.CodingKeys.hiddenSize) - self.hiddenLayers = try container.decode( - Int.self, forKey: GLM4Configuration.CodingKeys.hiddenLayers) - self.intermediateSize = try container.decode( - Int.self, forKey: GLM4Configuration.CodingKeys.intermediateSize) - self.attentionHeads = try container.decode( - Int.self, forKey: GLM4Configuration.CodingKeys.attentionHeads) - self.attentionBias = try container.decode( - Bool.self, forKey: GLM4Configuration.CodingKeys.attentionBias) - self.headDim = try container.decode( - Int.self, forKey: GLM4Configuration.CodingKeys.headDim) - 
self.rmsNormEps = try container.decode( - Float.self, forKey: GLM4Configuration.CodingKeys.rmsNormEps) - self.vocabularySize = try container.decode( - Int.self, forKey: GLM4Configuration.CodingKeys.vocabularySize) - self.kvHeads = try container.decode(Int.self, forKey: GLM4Configuration.CodingKeys.kvHeads) - self.partialRotaryFactor = try container.decode( - Float.self, forKey: GLM4Configuration.CodingKeys.partialRotaryFactor) - self.ropeTheta = - try container.decodeIfPresent( - Float.self, forKey: GLM4Configuration.CodingKeys.ropeTheta) - ?? 10000.0 - self.ropeTraditional = - try container.decodeIfPresent( - Bool.self, forKey: GLM4Configuration.CodingKeys.ropeTraditional) - ?? true - self.tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false - self.maxPositionEmbeddings = - try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768 - } +@Codable +public struct GLM4Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("attention_bias") public var attentionBias: Bool + @CodingKey("head_dim") public var headDim: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("partial_rotary_factor") public var partialRotaryFactor: Float + @CodingKey("rope_theta") public var ropeTheta: Float = 10000.0 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = true + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings = false + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768 } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/GPTOSS.swift b/Libraries/MLXLLM/Models/GPTOSS.swift index 
76c31fe5..2a2afb0d 100644 --- a/Libraries/MLXLLM/Models/GPTOSS.swift +++ b/Libraries/MLXLLM/Models/GPTOSS.swift @@ -7,67 +7,29 @@ import Foundation import MLX -import MLXFast import MLXLMCommon import MLXNN -import MLXRandom +import ReerCodable // MARK: - Configuration -public struct GPTOSSConfiguration: Codable, Sendable { - public var modelType: String = "gpt_oss" - public var hiddenLayers: Int = 36 - public var localExperts: Int = 128 - public var expertsPerToken: Int = 4 - public var vocabularySize: Int = 201088 - public var rmsNormEps: Float = 1e-5 - public var hiddenSize: Int = 2880 - public var intermediateSize: Int = 2880 - public var headDim: Int = 64 - public var attentionHeads: Int = 64 - public var kvHeads: Int = 8 - public var slidingWindow: Int = 128 - public var ropeTheta: Float = 150000 - public var ropeScaling: [String: StringOrNumber]? = nil - public var layerTypes: [String]? = nil - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenLayers = "num_hidden_layers" - case localExperts = "num_local_experts" - case expertsPerToken = "num_experts_per_tok" - case vocabularySize = "vocab_size" - case rmsNormEps = "rms_norm_eps" - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case headDim = "head_dim" - case attentionHeads = "num_attention_heads" - case kvHeads = "num_key_value_heads" - case slidingWindow = "sliding_window" - case ropeTheta = "rope_theta" - case ropeScaling = "rope_scaling" - case layerTypes = "layer_types" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - self.modelType = try container.decode(String.self, forKey: .modelType) - self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - self.localExperts = try container.decode(Int.self, forKey: .localExperts) - self.expertsPerToken = try container.decode(Int.self, forKey: .expertsPerToken) - self.vocabularySize = try 
container.decode(Int.self, forKey: .vocabularySize) - self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - self.headDim = try container.decode(Int.self, forKey: .headDim) - self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - self.kvHeads = try container.decode(Int.self, forKey: .kvHeads) - self.slidingWindow = try container.decode(Int.self, forKey: .slidingWindow) - self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 150000 - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - self.layerTypes = try container.decodeIfPresent([String].self, forKey: .layerTypes) - } +@Codable +public struct GPTOSSConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String = "gpt_oss" + @CodingKey("num_hidden_layers") public var hiddenLayers: Int = 36 + @CodingKey("num_local_experts") public var localExperts: Int = 128 + @CodingKey("num_experts_per_tok") public var expertsPerToken: Int = 4 + @CodingKey("vocab_size") public var vocabularySize: Int = 201088 + @CodingKey("rms_norm_eps") public var rmsNormEps: Float = 1e-5 + @CodingKey("hidden_size") public var hiddenSize: Int = 2880 + @CodingKey("intermediate_size") public var intermediateSize: Int = 2880 + @CodingKey("head_dim") public var headDim: Int = 64 + @CodingKey("num_attention_heads") public var attentionHeads: Int = 64 + @CodingKey("num_key_value_heads") public var kvHeads: Int = 8 + @CodingKey("sliding_window") public var slidingWindow: Int = 128 + @CodingKey("rope_theta") public var ropeTheta: Float = 150000 + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? = nil + @CodingKey("layer_types") public var layerTypes: [String]? 
= nil } private func mlxTopK(_ a: MLXArray, k: Int, axis: Int = -1) -> (values: MLXArray, indices: MLXArray) diff --git a/Libraries/MLXLLM/Models/Gemma.swift b/Libraries/MLXLLM/Models/Gemma.swift index 3f1a6653..fe899446 100644 --- a/Libraries/MLXLLM/Models/Gemma.swift +++ b/Libraries/MLXLLM/Models/Gemma.swift @@ -4,9 +4,10 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers -// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py +// Port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/gemma.py // Specialized norm for Gemma private class RMSNorm: Module, UnaryLayer { @@ -178,11 +179,9 @@ public class GemmaModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - let modelType: String private let model: GemmaModelInner public init(_ args: GemmaConfiguration) { - self.modelType = args.modelType self.vocabularySize = args.vocabularySize self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) self.model = GemmaModelInner(args) @@ -198,34 +197,18 @@ public class GemmaModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct GemmaConfiguration: Codable, Sendable { - var modelType: String - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var headDimensions: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - private let _ropeTheta: Float? - public var ropeTheta: Float { _ropeTheta ?? 10_000 } - private let _ropeTraditional: Bool? - public var ropeTraditional: Bool { _ropeTraditional ?? 
false } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDimensions = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case _ropeTheta = "rope_theta" - case _ropeTraditional = "rope_traditional" - } +@Codable +public struct GemmaConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("head_dim") public var headDimensions: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/Gemma2.swift b/Libraries/MLXLLM/Models/Gemma2.swift index 561477c1..bf1ba415 100644 --- a/Libraries/MLXLLM/Models/Gemma2.swift +++ b/Libraries/MLXLLM/Models/Gemma2.swift @@ -4,9 +4,10 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers -// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py +// Port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/gemma2.py private class Attention: Module { let args: Gemma2Configuration @@ -203,70 +204,21 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider { } } -public struct Gemma2Configuration: Codable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var 
headDimensions: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var ropeTheta: Float = 10_000 - var ropeTraditional: Bool = false - var attnLogitSoftcapping: Float = 50.0 - var finalLogitSoftcapping: Float = 30.0 - var queryPreAttnScalar: Float = 144.0 - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDimensions = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case attnLogitSoftcapping = "attn_logit_softcapping" - case finalLogitSoftcapping = "final_logit_softcapping" - case queryPreAttnScalar = "query_pre_attn_scalar" - } - - public init(from decoder: Swift.Decoder) throws { - // Custom implementation to handle optional keys with required values - let container: KeyedDecodingContainer = try decoder.container( - keyedBy: CodingKeys.self) - - self.hiddenSize = try container.decode( - Int.self, forKey: CodingKeys.hiddenSize) - self.hiddenLayers = try container.decode( - Int.self, forKey: CodingKeys.hiddenLayers) - self.intermediateSize = try container.decode( - Int.self, forKey: CodingKeys.intermediateSize) - self.attentionHeads = try container.decode( - Int.self, forKey: CodingKeys.attentionHeads) - self.headDimensions = try container.decode( - Int.self, forKey: CodingKeys.headDimensions) - self.rmsNormEps = try container.decode( - Float.self, forKey: CodingKeys.rmsNormEps) - self.vocabularySize = try container.decode( - Int.self, forKey: CodingKeys.vocabularySize) - self.kvHeads = try container.decode(Int.self, forKey: CodingKeys.kvHeads) - self.ropeTheta = - try container.decodeIfPresent(Float.self, forKey: CodingKeys.ropeTheta) - ?? 
10_000 - self.ropeTraditional = - try container.decodeIfPresent( - Bool.self, forKey: CodingKeys.ropeTraditional) ?? false - self.attnLogitSoftcapping = try container.decode( - Float.self, forKey: CodingKeys.attnLogitSoftcapping) - self.finalLogitSoftcapping = try container.decode( - Float.self, forKey: CodingKeys.finalLogitSoftcapping) - self.queryPreAttnScalar = try container.decode( - Float.self, forKey: CodingKeys.queryPreAttnScalar) - } +@Codable +public struct Gemma2Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("head_dim") public var headDimensions: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("attn_logit_softcapping") public var attnLogitSoftcapping: Float = 50.0 + @CodingKey("final_logit_softcapping") public var finalLogitSoftcapping: Float = 30.0 + @CodingKey("query_pre_attn_scalar") public var queryPreAttnScalar: Float = 144.0 } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift index eb953ec4..7f795ba9 100644 --- a/Libraries/MLXLLM/Models/Gemma3Text.swift +++ b/Libraries/MLXLLM/Models/Gemma3Text.swift @@ -9,82 +9,52 @@ import Foundation import MLX -import MLXFast import MLXLLM import MLXLMCommon import MLXNN +import ReerCodable + +@Codable +public struct Gemma3TextConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + 
@CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int = 4 + @CodingKey("head_dim") public var headDim: Int = 256 + @CodingKey("rms_norm_eps") public var rmsNormEps: Float = 1.0e-6 + @CodingKey("vocab_size") public var vocabularySize: Int = 262144 + @CodingKey("num_key_value_heads") public var kvHeads: Int = 1 + @CodingKey("rope_global_base_freq") public var ropeGlobalBaseFreq: Float = 1_000_000.0 + @CodingKey("rope_local_base_freq") public var ropeLocalBaseFreq: Float = 10_000.0 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("query_pre_attn_scalar") public var queryPreAttnScalar: Float = 256 + @CodingKey("sliding_window") public var slidingWindow: Int = 512 + @CodingKey("sliding_window_pattern") public var slidingWindowPattern: Int = 6 +} -public struct Gemma3TextConfiguration: Codable { - let modelType: String - let hiddenSize: Int - let hiddenLayers: Int - let intermediateSize: Int - let attentionHeads: Int - let headDim: Int - let rmsNormEps: Float - let vocabularySize: Int - let kvHeads: Int - let ropeGlobalBaseFreq: Float - let ropeLocalBaseFreq: Float - let ropeTraditional: Bool - let queryPreAttnScalar: Float - let slidingWindow: Int - let slidingWindowPattern: Int - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDim = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case ropeGlobalBaseFreq = "rope_global_base_freq" - case ropeLocalBaseFreq = "rope_local_base_freq" - case ropeTraditional = "rope_traditional" - case queryPreAttnScalar = "query_pre_attn_scalar" - case slidingWindow = "sliding_window" - case slidingWindowPattern = "sliding_window_pattern" - } +public 
struct Gemma3TextConfigurationContainer: Codable, Sendable { + public var configuration: Gemma3TextConfiguration enum VLMCodingKeys: String, CodingKey { case textConfig = "text_config" } - public init(from decoder: Decoder) throws { - let nestedContainer = try decoder.container(keyedBy: VLMCodingKeys.self) - + public init(from decoder: any Decoder) throws { // in the case of VLM models convertered using mlx_lm.convert // the configuration will still match the VLMs and be under text_config - let container = - if nestedContainer.contains(.textConfig) { - try nestedContainer.nestedContainer(keyedBy: CodingKeys.self, forKey: .textConfig) - } else { - try decoder.container(keyedBy: CodingKeys.self) - } + let nestedContainer = try decoder.container(keyedBy: VLMCodingKeys.self) + if let configuration = try nestedContainer.decodeIfPresent( + Gemma3TextConfiguration.self, forKey: .textConfig) + { + self.configuration = configuration + } else { + self.configuration = try Gemma3TextConfiguration(from: decoder) + } + } - modelType = try container.decode(String.self, forKey: .modelType) - hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - attentionHeads = try container.decodeIfPresent(Int.self, forKey: .attentionHeads) ?? 4 - headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? 256 - rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1.0e-6 - vocabularySize = try container.decodeIfPresent(Int.self, forKey: .vocabularySize) ?? 262144 - kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 1 - ropeGlobalBaseFreq = - try container.decodeIfPresent(Float.self, forKey: .ropeGlobalBaseFreq) ?? 1_000_000.0 - ropeLocalBaseFreq = - try container.decodeIfPresent(Float.self, forKey: .ropeLocalBaseFreq) ?? 
10_000.0 - ropeTraditional = - try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false - queryPreAttnScalar = - try container.decodeIfPresent(Float.self, forKey: .queryPreAttnScalar) ?? 256 - slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512 - slidingWindowPattern = - try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6 + public func encode(to encoder: any Encoder) throws { + try configuration.encode(to: encoder) } } @@ -336,6 +306,10 @@ public class Gemma3TextModel: Module, LLMModel { public let config: Gemma3TextConfiguration public var vocabularySize: Int { config.vocabularySize } + convenience public init(_ config: Gemma3TextConfigurationContainer) { + self.init(config.configuration) + } + public init(_ config: Gemma3TextConfiguration) { self.config = config self.model = Gemma3Model(config) diff --git a/Libraries/MLXLLM/Models/Gemma3nText.swift b/Libraries/MLXLLM/Models/Gemma3nText.swift index 6b06c55c..1f591a5d 100644 --- a/Libraries/MLXLLM/Models/Gemma3nText.swift +++ b/Libraries/MLXLLM/Models/Gemma3nText.swift @@ -9,115 +9,65 @@ import Foundation import MLX -import MLXFast import MLXLMCommon import MLXNN +import ReerCodable // MARK: - Configuration -public struct Gemma3nTextConfiguration: Codable { - let modelType: String - let hiddenSize: Int - let numHiddenLayers: Int - let intermediateSize: Int - let numAttentionHeads: Int - let headDim: Int - let rmsNormEps: Float - let vocabSize: Int - let numKeyValueHeads: Int - let numKvSharedLayers: Int - let queryPreAttnScalar: Float - let vocabSizePerLayerInput: Int - let slidingWindow: Int - let maxPositionEmbeddings: Int - let ropeLocalBaseFreq: Float - let ropeTheta: Float - let finalLogitSoftcapping: Float - let layerTypes: [String]? - let activationSparsityPattern: [Float]? - let hiddenSizePerLayerInput: Int - let altupNumInputs: Int - let altupCoefClip: Float? 
- let altupCorrectScale: Bool - let altupActiveIdx: Int - let laurelRank: Int - let ropeScaling: [String: String]? - let slidingWindowPattern: Int? - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case numHiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case numAttentionHeads = "num_attention_heads" - case headDim = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabSize = "vocab_size" - case numKeyValueHeads = "num_key_value_heads" - case numKvSharedLayers = "num_kv_shared_layers" - case queryPreAttnScalar = "query_pre_attn_scalar" - case vocabSizePerLayerInput = "vocab_size_per_layer_input" - case slidingWindow = "sliding_window" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeLocalBaseFreq = "rope_local_base_freq" - case ropeTheta = "rope_theta" - case finalLogitSoftcapping = "final_logit_softcapping" - case layerTypes = "layer_types" - case activationSparsityPattern = "activation_sparsity_pattern" - case hiddenSizePerLayerInput = "hidden_size_per_layer_input" - case altupNumInputs = "altup_num_inputs" - case altupCoefClip = "altup_coef_clip" - case altupCorrectScale = "altup_correct_scale" - case altupActiveIdx = "altup_active_idx" - case laurelRank = "laurel_rank" - case ropeScaling = "rope_scaling" - case slidingWindowPattern = "sliding_window_pattern" - } +@Codable +public struct Gemma3nTextConfiguration: Sendable { + @CodingKey("model_type") public let modelType: String + @CodingKey("hidden_size") public let hiddenSize: Int + @CodingKey("num_hidden_layers") public let numHiddenLayers: Int + @CodingKey("intermediate_size") public let intermediateSize: Int + @CodingKey("num_attention_heads") public let numAttentionHeads: Int + @CodingKey("head_dim") public let headDim: Int + @CodingKey("rms_norm_eps") public let rmsNormEps: Float + @CodingKey("vocab_size") public let vocabSize: Int + @CodingKey("num_key_value_heads") public let numKeyValueHeads: Int
+ @CodingKey("num_kv_shared_layers") public let numKvSharedLayers: Int + @CodingKey("query_pre_attn_scalar") public let queryPreAttnScalar: Float + @CodingKey("vocab_size_per_layer_input") public let vocabSizePerLayerInput: Int + @CodingKey("sliding_window") public let slidingWindow: Int + @CodingKey("max_position_embeddings") public let maxPositionEmbeddings: Int + @CodingKey("rope_local_base_freq") public let ropeLocalBaseFreq: Float + @CodingKey("rope_theta") public let ropeTheta: Float + @CodingKey("final_logit_softcapping") public let finalLogitSoftcapping: Float + @CodingKey("layer_types") public let layerTypes: [String]? + @CodingKey("activation_sparsity_pattern") public let activationSparsityPattern: [Float]? + @CodingKey("hidden_size_per_layer_input") public let hiddenSizePerLayerInput: Int + @CodingKey("altup_num_inputs") public let altupNumInputs: Int + @CodingKey("altup_coef_clip") public let altupCoefClip: Float? + @CodingKey("altup_correct_scale") public let altupCorrectScale: Bool + @CodingKey("altup_active_idx") public let altupActiveIdx: Int + @CodingKey("laurel_rank") public let laurelRank: Int + @CodingKey("rope_scaling") public let ropeScaling: [String: String]? + @CodingKey("sliding_window_pattern") public let slidingWindowPattern: Int? 
+} + +public struct Gemma3nTextConfigurationContainer: Codable, Sendable { + public var configuration: Gemma3nTextConfiguration enum VLMCodingKeys: String, CodingKey { case textConfig = "text_config" } - public init(from decoder: Decoder) throws { + public init(from decoder: any Decoder) throws { + // in the case of VLM models converted using mlx_lm.convert + // the configuration will still match the VLMs and be under text_config let nestedContainer = try decoder.container(keyedBy: VLMCodingKeys.self) + if let configuration = try nestedContainer.decodeIfPresent( + Gemma3nTextConfiguration.self, forKey: .textConfig) + { + self.configuration = configuration + } else { + self.configuration = try Gemma3nTextConfiguration(from: decoder) + } + } - // in the case of Gemma 3n model, the configuration matches VLMs and text config is under a text_config key - let container = - if nestedContainer.contains(.textConfig) { - try nestedContainer.nestedContainer(keyedBy: CodingKeys.self, forKey: .textConfig) - } else { - try decoder.container(keyedBy: CodingKeys.self) - } - - modelType = try container.decode(String.self, forKey: .modelType) - hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - numHiddenLayers = try container.decode(Int.self, forKey: .numHiddenLayers) - intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - numAttentionHeads = try container.decode(Int.self, forKey: .numAttentionHeads) - headDim = try container.decode(Int.self, forKey: .headDim) - rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - vocabSize = try container.decode(Int.self, forKey: .vocabSize) - numKeyValueHeads = try container.decode(Int.self, forKey: .numKeyValueHeads) - numKvSharedLayers = try container.decode(Int.self, forKey: .numKvSharedLayers) - queryPreAttnScalar = try container.decode(Float.self, forKey: .queryPreAttnScalar) - vocabSizePerLayerInput = try container.decode(Int.self, forKey: .vocabSizePerLayerInput) - slidingWindow =
try container.decode(Int.self, forKey: .slidingWindow) - maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings) - ropeLocalBaseFreq = try container.decode(Float.self, forKey: .ropeLocalBaseFreq) - ropeTheta = try container.decode(Float.self, forKey: .ropeTheta) - finalLogitSoftcapping = try container.decode(Float.self, forKey: .finalLogitSoftcapping) - layerTypes = try container.decode([String]?.self, forKey: .layerTypes) - activationSparsityPattern = try container.decodeIfPresent( - [Float].self, forKey: .activationSparsityPattern) - hiddenSizePerLayerInput = try container.decode(Int.self, forKey: .hiddenSizePerLayerInput) - altupNumInputs = try container.decode(Int.self, forKey: .altupNumInputs) - altupCoefClip = try container.decodeIfPresent(Float.self, forKey: .altupCoefClip) - altupCorrectScale = try container.decode(Bool.self, forKey: .altupCorrectScale) - altupActiveIdx = try container.decode(Int.self, forKey: .altupActiveIdx) - laurelRank = try container.decode(Int.self, forKey: .laurelRank) - ropeScaling = try container.decodeIfPresent([String: String].self, forKey: .ropeScaling) - slidingWindowPattern = try container.decodeIfPresent( - Int.self, forKey: .slidingWindowPattern) + public func encode(to encoder: any Encoder) throws { + try configuration.encode(to: encoder) } } @@ -928,6 +878,10 @@ public class Gemma3nTextModel: Module, LLMModel { var kvHeads: [Int] + public convenience init(config: Gemma3nTextConfigurationContainer) { + self.init(config: config.configuration) + } + public init(config: Gemma3nTextConfiguration) { self.config = config self.modelType = config.modelType diff --git a/Libraries/MLXLLM/Models/Granite.swift b/Libraries/MLXLLM/Models/Granite.swift index 936de499..b1625d99 100644 --- a/Libraries/MLXLLM/Models/Granite.swift +++ b/Libraries/MLXLLM/Models/Granite.swift @@ -5,12 +5,13 @@ // Created by Sachin Desai on 4/25/25. 
// -// Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/granite.py - import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable + +// Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/granite.py private class Attention: Module { let args: GraniteConfiguration @@ -214,69 +215,25 @@ public class GraniteModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct GraniteConfiguration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var logitsScaling: Float - var attentionMultiplier: Float - var embeddingMultiplier: Float - var residualMultiplier: Float - var maxPositionEmbeddings: Int - var kvHeads: Int - var attentionBias: Bool - var mlpBias: Bool - var ropeTheta: Float - var ropeTraditional: Bool = false - var ropeScaling: [String: StringOrNumber]? = nil - var tieWordEmbeddings: Bool = true - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case logitsScaling = "logits_scaling" - case attentionMultiplier = "attention_multiplier" - case embeddingMultiplier = "embedding_multiplier" - case residualMultiplier = "residual_multiplier" - case maxPositionEmbeddings = "max_position_embeddings" - case kvHeads = "num_key_value_heads" - case attentionBias = "attention_bias" - case mlpBias = "mlp_bias" - case ropeTheta = "rope_theta" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - } - - public init(from decoder: Decoder) throws { - let container: KeyedDecodingContainer = - try decoder.container(keyedBy: GraniteConfiguration.CodingKeys.self) - - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - 
self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - self.logitsScaling = try container.decode(Float.self, forKey: .logitsScaling) - self.attentionMultiplier = try container.decode(Float.self, forKey: .attentionMultiplier) - self.embeddingMultiplier = try container.decode(Float.self, forKey: .embeddingMultiplier) - self.residualMultiplier = try container.decode(Float.self, forKey: .residualMultiplier) - self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings) - self.kvHeads = try container.decode(Int.self, forKey: .kvHeads) - self.attentionBias = try container.decode(Bool.self, forKey: .attentionBias) - self.mlpBias = try container.decode(Bool.self, forKey: .mlpBias) ?? false - self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 
10000000.0 - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings) - } +@Codable +public struct GraniteConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("logits_scaling") public var logitsScaling: Float + @CodingKey("attention_multiplier") public var attentionMultiplier: Float + @CodingKey("embedding_multiplier") public var embeddingMultiplier: Float + @CodingKey("residual_multiplier") public var residualMultiplier: Float + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("attention_bias") public var attentionBias: Bool + @CodingKey("mlp_bias") public var mlpBias: Bool + @CodingKey("rope_theta") public var ropeTheta: Float = 10000000.0 // same fallback the removed init used when rope_theta is absent + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]?
= nil + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/Internlm2.swift b/Libraries/MLXLLM/Models/Internlm2.swift index a2166cee..b79b51ba 100644 --- a/Libraries/MLXLLM/Models/Internlm2.swift +++ b/Libraries/MLXLLM/Models/Internlm2.swift @@ -4,6 +4,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // Port of https://github.com/maiqingqiang/mlx-examples/blob/main/llms/mlx_lm/models/internlm2.py @@ -240,76 +241,36 @@ extension InternLM2Model: LoRAModel { } } -public struct InternLM2Configuration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var maxPositionEmbeddings: Int = 32768 - var ropeTheta: Float = 10000 - var ropeTraditional: Bool = false - var ropeScaling: [String: StringOrNumber]? - var tieWordEmbeddings: Bool = false - var bias: Bool = true +@Codable +public struct InternLM2Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768 + @CodingKey("rope_theta") public var ropeTheta: Float = 10000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? 
+ @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = false + @CodingKey("bias") public var bias: Bool = true var kvGroups: Int { attentionHeads / kvHeads } - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - case bias = "bias" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 
attentionHeads - maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings) - if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) { - self.ropeTheta = ropeTheta - } - if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) - { - self.ropeTraditional = ropeTraditional - } - ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - if let tieWordEmbeddings = try container.decodeIfPresent( - Bool.self, forKey: .tieWordEmbeddings) - { - self.tieWordEmbeddings = tieWordEmbeddings - } - if let bias = try container.decodeIfPresent(Bool.self, forKey: .bias) { - self.bias = bias - } + public func didDecode(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: AnyCodingKey.self) + let codingKey = AnyCodingKey("rope_scaling") if let ropeScaling { let requiredKeys: Set = ["factor", "type"] let keys = Set(ropeScaling.keys) if !requiredKeys.isSubset(of: keys) { throw DecodingError.dataCorruptedError( - forKey: .ropeScaling, in: container, + forKey: codingKey, in: container, debugDescription: "rope_scaling must contain keys \(requiredKeys)" ) } @@ -317,7 +278,7 @@ public struct InternLM2Configuration: Codable, Sendable { type != .string("linear") && type != .string("dynamic") { throw DecodingError.dataCorruptedError( - forKey: .ropeScaling, in: container, + forKey: codingKey, in: container, debugDescription: "rope_scaling 'type' currently only supports 'linear' or 'dynamic'" ) diff --git a/Libraries/MLXLLM/Models/LFM2.swift b/Libraries/MLXLLM/Models/LFM2.swift index d326595b..b4850a65 100644 --- a/Libraries/MLXLLM/Models/LFM2.swift +++ b/Libraries/MLXLLM/Models/LFM2.swift @@ -9,76 +9,31 @@ import Foundation import MLX import MLXLMCommon import MLXNN - -public struct LFM2Configuration: Codable, Sendable { - let modelType: String - let vocabularySize: Int - let hiddenSize: Int - let hiddenLayers: Int - let 
attentionHeads: Int - let kvHeads: Int - let maxPositionEmbeddings: Int? - let normEps: Float - let convBias: Bool - let convLCache: Int - private let _blockDim: Int? - var blockDim: Int { _blockDim ?? hiddenSize } - private let _blockFFDim: Int? - var blockFFDim: Int { _blockFFDim ?? hiddenSize } - let blockMultipleOf: Int - let blockFFNDimMultiplier: Float - let blockAutoAdjustFFDim: Bool - private let _fullAttnIdxs: [Int]? - var fullAttnIdxs: [Int] { _fullAttnIdxs ?? Array(0 ..< hiddenLayers) } - let ropeTheta: Float - var headDimensions: Int { hiddenSize / attentionHeads } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case vocabularySize = "vocab_size" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case attentionHeads = "num_attention_heads" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case normEps = "norm_eps" - case convBias = "conv_bias" - case convLCache = "conv_L_cache" - case _blockDim = "block_dim" - case _blockFFDim = "block_ff_dim" - case blockMultipleOf = "block_multiple_of" - case blockFFNDimMultiplier = "block_ffn_dim_multiplier" - case blockAutoAdjustFFDim = "block_auto_adjust_ff_dim" - case _fullAttnIdxs = "full_attn_idxs" - case ropeTheta = "rope_theta" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - self.modelType = try container.decodeIfPresent(String.self, forKey: .modelType) ?? "lfm2" - self.vocabularySize = - try container.decodeIfPresent(Int.self, forKey: .vocabularySize) ?? 
65536 - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - self.kvHeads = try container.decode(Int.self, forKey: .kvHeads) - self.maxPositionEmbeddings = try container.decodeIfPresent( - Int.self, forKey: .maxPositionEmbeddings) - self.normEps = try container.decode(Float.self, forKey: .normEps) - self.convBias = try container.decodeIfPresent(Bool.self, forKey: .convBias) ?? false - self.convLCache = try container.decodeIfPresent(Int.self, forKey: .convLCache) ?? 3 - self._blockDim = try container.decodeIfPresent(Int.self, forKey: ._blockDim) - self._blockFFDim = try container.decodeIfPresent(Int.self, forKey: ._blockFFDim) - self.blockMultipleOf = - try container.decodeIfPresent(Int.self, forKey: .blockMultipleOf) ?? 256 - self.blockFFNDimMultiplier = - try container.decodeIfPresent(Float.self, forKey: .blockFFNDimMultiplier) ?? 1.0 - self.blockAutoAdjustFFDim = - try container.decodeIfPresent(Bool.self, forKey: .blockAutoAdjustFFDim) ?? true - self._fullAttnIdxs = try container.decodeIfPresent([Int].self, forKey: ._fullAttnIdxs) - self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 1000000.0 - } +import ReerCodable + +@Codable +public struct LFM2Configuration: Sendable { + @CodingKey("model_type") public var modelType: String = "lfm2" + @CodingKey("vocab_size") public var vocabularySize: Int = 65536 + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int? 
+ @CodingKey("norm_eps") public var normEps: Float + @CodingKey("conv_bias") public var convBias: Bool = false + @CodingKey("conv_L_cache") public var convLCache: Int = 3 + @CodingKey("block_dim") private var _blockDim: Int? + public var blockDim: Int { _blockDim ?? hiddenSize } + @CodingKey("block_ff_dim") private var _blockFFDim: Int? + public var blockFFDim: Int { _blockFFDim ?? hiddenSize } + @CodingKey("block_multiple_of") public var blockMultipleOf: Int = 256 + @CodingKey("block_ffn_dim_multiplier") public var blockFFNDimMultiplier: Float = 1.0 + @CodingKey("block_auto_adjust_ff_dim") public var blockAutoAdjustFFDim: Bool = true + @CodingKey("full_attn_idxs") private var _fullAttnIdxs: [Int]? + public var fullAttnIdxs: [Int] { _fullAttnIdxs ?? Array(0 ..< hiddenLayers) } + @CodingKey("rope_theta") public var ropeTheta: Float = 1000000.0 + public var headDimensions: Int { hiddenSize / attentionHeads } } private class Attention: Module { diff --git a/Libraries/MLXLLM/Models/Lille130m.swift b/Libraries/MLXLLM/Models/Lille130m.swift index 3906ad67..feaede67 100644 --- a/Libraries/MLXLLM/Models/Lille130m.swift +++ b/Libraries/MLXLLM/Models/Lille130m.swift @@ -9,6 +9,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // MARK: - Attention @@ -189,53 +190,18 @@ public final class Lille130mModel: Module, LLMModel, KVCacheDimensionProvider { // MARK: - Configuration -public struct Lille130mConfiguration: Codable, Sendable { - public var modelType: String - public var blockSize: Int - public var layerNormEps: Float - public var hiddenSize: Int // n_embd - public var attentionHeads: Int // n_head - public var kvHeads: Int // n_kv_heads - public var hiddenLayers: Int // n_layer - public var ropeTheta: Float - public var vocabularySize: Int - public var tieWordEmbeddings: Bool = true - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case blockSize = "block_size" - case layerNormEps = "layer_norm_eps" - case 
hiddenSize = "n_embd" - case attentionHeads = "n_head" - case kvHeads = "n_kv_heads" - case hiddenLayers = "n_layer" - case ropeTheta = "rope_theta" - case vocabularySize = "vocab_size" - } - - public init(from decoder: Decoder) throws { - let container: KeyedDecodingContainer = - try decoder.container(keyedBy: Lille130mConfiguration.CodingKeys.self) - - self.modelType = try container.decode( - String.self, forKey: Lille130mConfiguration.CodingKeys.modelType) - self.blockSize = try container.decode( - Int.self, forKey: Lille130mConfiguration.CodingKeys.blockSize) - self.layerNormEps = try container.decode( - Float.self, forKey: Lille130mConfiguration.CodingKeys.layerNormEps) - self.hiddenSize = try container.decode( - Int.self, forKey: Lille130mConfiguration.CodingKeys.hiddenSize) - self.attentionHeads = try container.decode( - Int.self, forKey: Lille130mConfiguration.CodingKeys.attentionHeads) - self.kvHeads = try container.decode( - Int.self, forKey: Lille130mConfiguration.CodingKeys.kvHeads) - self.hiddenLayers = try container.decode( - Int.self, forKey: Lille130mConfiguration.CodingKeys.hiddenLayers) - self.ropeTheta = try container.decode( - Float.self, forKey: Lille130mConfiguration.CodingKeys.ropeTheta) - self.vocabularySize = try container.decode( - Int.self, forKey: Lille130mConfiguration.CodingKeys.vocabularySize) - } +@Codable +public struct Lille130mConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("block_size") public var blockSize: Int + @CodingKey("layer_norm_eps") public var layerNormEps: Float + @CodingKey("n_embd") public var hiddenSize: Int + @CodingKey("n_head") public var attentionHeads: Int + @CodingKey("n_kv_heads") public var kvHeads: Int + @CodingKey("n_layer") public var hiddenLayers: Int + @CodingKey("rope_theta") public var ropeTheta: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true } // MARK: - LoRA 
diff --git a/Libraries/MLXLLM/Models/Llama.swift b/Libraries/MLXLLM/Models/Llama.swift index 431b7e63..f09bdf4c 100644 --- a/Libraries/MLXLLM/Models/Llama.swift +++ b/Libraries/MLXLLM/Models/Llama.swift @@ -4,9 +4,10 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers -// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py +// port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/llama.py func computeBaseFrequency( base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]? @@ -340,108 +341,37 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct LlamaConfiguration: Codable, Sendable { - - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var headDimensions: Int? - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var maxPositionEmbeddings: Int? - var ropeTheta: Float = 10_000 - var ropeTraditional: Bool = false - var ropeScaling: [String: StringOrNumber]? - var tieWordEmbeddings: Bool = true - var attentionBias: Bool = false - var mlpBias: Bool = false - - public init( - hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int, attentionHeads: Int, - headDimensions: Int? = nil, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int, - maxPositionEmbeddings: Int? = nil, ropeTheta: Float = 10_000, ropeTraditional: Bool = false, - ropeScaling: [String: StringOrNumber]? 
= nil, tieWordEmbeddings: Bool = true, - attentionBias: Bool = false, mlpBias: Bool = false - ) { - self.hiddenSize = hiddenSize - self.hiddenLayers = hiddenLayers - self.intermediateSize = intermediateSize - self.attentionHeads = attentionHeads - self.headDimensions = headDimensions - self.rmsNormEps = rmsNormEps - self.vocabularySize = vocabularySize - self.kvHeads = kvHeads - self.maxPositionEmbeddings = maxPositionEmbeddings - self.ropeTheta = ropeTheta - self.ropeTraditional = ropeTraditional - self.ropeScaling = ropeScaling - self.tieWordEmbeddings = tieWordEmbeddings - self.attentionBias = attentionBias - self.mlpBias = mlpBias - } +@Codable +public struct LlamaConfiguration: Sendable { + + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("head_dim") public var headDimensions: Int? + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int? + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true + @CodingKey("attention_bias") public var attentionBias: Bool = false + @CodingKey("mlp_bias") public var mlpBias: Bool = false var resolvedHeadDimensions: Int { headDimensions ?? 
(hiddenSize / attentionHeads) } - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDimensions = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - case attentionBias = "attention_bias" - case mlpBias = "mlp_bias" - } - - public init(from decoder: Swift.Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions) - rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 
attentionHeads - maxPositionEmbeddings = try container.decodeIfPresent( - Int.self, forKey: .maxPositionEmbeddings) - if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) { - self.ropeTheta = ropeTheta - } - if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) - { - self.ropeTraditional = ropeTraditional - } - ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - if let tieWordEmbeddings = try container.decodeIfPresent( - Bool.self, forKey: .tieWordEmbeddings) - { - self.tieWordEmbeddings = tieWordEmbeddings - } - if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) { - self.attentionBias = attentionBias - } - if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) { - self.mlpBias = mlpBias - } + public func didDecode(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: AnyCodingKey.self) + let codingKey = AnyCodingKey("rope_scaling") if let ropeScaling { if ropeScaling["factor"] == nil { throw DecodingError.dataCorruptedError( - forKey: .ropeScaling, in: container, + forKey: codingKey, in: container, debugDescription: "rope_scaling must contain 'factor'") } if let ropeType = ropeScaling["type"] ?? 
ropeScaling["rope_type"] { @@ -452,7 +382,7 @@ public struct LlamaConfiguration: Codable, Sendable { ] if !options.contains(ropeType) { throw DecodingError.dataCorruptedError( - forKey: .ropeScaling, in: container, + forKey: codingKey, in: container, debugDescription: "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'" ) @@ -460,7 +390,7 @@ public struct LlamaConfiguration: Codable, Sendable { } } else { throw DecodingError.dataCorruptedError( - forKey: .ropeScaling, in: container, + forKey: codingKey, in: container, debugDescription: "rope_scaling must contain either 'type' or 'rope_type'") } } diff --git a/Libraries/MLXLLM/Models/MiMo.swift b/Libraries/MLXLLM/Models/MiMo.swift index 33db0a5c..981b7448 100644 --- a/Libraries/MLXLLM/Models/MiMo.swift +++ b/Libraries/MLXLLM/Models/MiMo.swift @@ -9,6 +9,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable private class Attention: Module { let args: MiMoConfiguration @@ -212,59 +213,21 @@ public class MiMoModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct MiMoConfiguration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var maxPositionEmbeddings: Int - var ropeTheta: Float - var ropeTraditional: Bool - var ropeScaling: [String: StringOrNumber]? 
- var tieWordEmbeddings: Bool - var numNextnPredictLayers: Int - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - case numNextnPredictLayers = "num_nextn_predict_layers" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - self.kvHeads = try container.decode(Int.self, forKey: .kvHeads) - self.maxPositionEmbeddings = - try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768 - self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10000.0 - self.ropeTraditional = - try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - self.tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false - self.numNextnPredictLayers = - try container.decodeIfPresent(Int.self, forKey: .numNextnPredictLayers) ?? 
2 - } +@Codable +public struct MiMoConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768 + @CodingKey("rope_theta") public var ropeTheta: Float = 10000.0 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? = nil + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = false + @CodingKey("num_nextn_predict_layers") public var numNextnPredictLayers: Int = 2 } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/Olmo2.swift b/Libraries/MLXLLM/Models/Olmo2.swift index 6f883be1..e4c69cde 100644 --- a/Libraries/MLXLLM/Models/Olmo2.swift +++ b/Libraries/MLXLLM/Models/Olmo2.swift @@ -10,6 +10,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // MARK: - RoPE helpers @@ -339,76 +340,29 @@ public class Olmo2Model: Module, LLMModel, KVCacheDimensionProvider { // MARK: - Configuration -public struct Olmo2Configuration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var headDimensions: Int? - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var maxPositionEmbeddings: Int? - var ropeTheta: Float = 10_000 - var ropeTraditional: Bool = false - var ropeScaling: [String: StringOrNumber]? - var tieWordEmbeddings: Bool = true - var attentionBias: Bool = false - var mlpBias: Bool = false - - var resolvedHeadDimensions: Int { headDimensions ?? 
(hiddenSize / attentionHeads) } - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDimensions = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - case attentionBias = "attention_bias" - case mlpBias = "mlp_bias" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions) - rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - let maybeKV = try container.decodeIfPresent(Int.self, forKey: .kvHeads) - kvHeads = maybeKV ?? 
attentionHeads - maxPositionEmbeddings = try container.decodeIfPresent( - Int.self, forKey: .maxPositionEmbeddings) - if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) { - self.ropeTheta = ropeTheta - } - if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) - { - self.ropeTraditional = ropeTraditional - } - ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - if let tieWordEmbeddings = try container.decodeIfPresent( - Bool.self, forKey: .tieWordEmbeddings) - { - self.tieWordEmbeddings = tieWordEmbeddings - } - if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) { - self.attentionBias = attentionBias - } - if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) { - self.mlpBias = mlpBias +@Codable +public struct Olmo2Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("head_dim") public var headDimensions: Int? + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int = 0 // Will be set in didDecode + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int? + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? 
+ @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true + @CodingKey("attention_bias") public var attentionBias: Bool = false + @CodingKey("mlp_bias") public var mlpBias: Bool = false + + public var resolvedHeadDimensions: Int { headDimensions ?? (hiddenSize / attentionHeads) } + + mutating public func didDecode(from decoder: any Decoder) throws { + if kvHeads == 0 { + kvHeads = attentionHeads // 0 is the "key absent" sentinel: fall back to attentionHeads as the removed init did } } } diff --git a/Libraries/MLXLLM/Models/OlmoE.swift b/Libraries/MLXLLM/Models/OlmoE.swift index f475e9cc..18762fd0 100644 --- a/Libraries/MLXLLM/Models/OlmoE.swift +++ b/Libraries/MLXLLM/Models/OlmoE.swift @@ -10,6 +10,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // MARK: - RoPE helpers @@ -374,87 +375,34 @@ public class OlmoEModel: Module, LLMModel, KVCacheDimensionProvider { // MARK: - Configuration -public struct OlmoEConfiguration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var headDimensions: Int? - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var maxPositionEmbeddings: Int? - var ropeTheta: Float = 10_000 - var ropeTraditional: Bool = false - var ropeScaling: [String: StringOrNumber]? - var tieWordEmbeddings: Bool = true - var attentionBias: Bool = false - var mlpBias: Bool = false - - var numExperts: Int - var numExpertsPerToken: Int - var normTopkProb: Bool = false - - var resolvedHeadDimensions: Int { headDimensions ??
(hiddenSize / attentionHeads) } - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDimensions = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - case attentionBias = "attention_bias" - case mlpBias = "mlp_bias" - case numExperts = "num_experts" - case numExpertsPerToken = "num_experts_per_tok" - case normTopkProb = "norm_topk_prob" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions) - rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - let maybeKV = try container.decodeIfPresent(Int.self, forKey: .kvHeads) - kvHeads = maybeKV ?? 
attentionHeads - maxPositionEmbeddings = try container.decodeIfPresent( - Int.self, forKey: .maxPositionEmbeddings) - if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) { - self.ropeTheta = ropeTheta - } - if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) - { - self.ropeTraditional = ropeTraditional - } - ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - if let tieWordEmbeddings = try container.decodeIfPresent( - Bool.self, forKey: .tieWordEmbeddings) - { - self.tieWordEmbeddings = tieWordEmbeddings - } - if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) { - self.attentionBias = attentionBias - } - if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) { - self.mlpBias = mlpBias +@Codable +public struct OlmoEConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("head_dim") public var headDimensions: Int? + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int = 0 // Will be set in didDecode + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int? + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? 
+ @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true + @CodingKey("attention_bias") public var attentionBias: Bool = false + @CodingKey("mlp_bias") public var mlpBias: Bool = false + + @CodingKey("num_experts") public var numExperts: Int + @CodingKey("num_experts_per_tok") public var numExpertsPerToken: Int + @CodingKey("norm_topk_prob") public var normTopkProb: Bool = false + + public var resolvedHeadDimensions: Int { headDimensions ?? (hiddenSize / attentionHeads) } + + mutating public func didDecode() { + if kvHeads == 0 { + kvHeads = attentionHeads } - numExperts = try container.decode(Int.self, forKey: .numExperts) - numExpertsPerToken = try container.decode(Int.self, forKey: .numExpertsPerToken) - normTopkProb = try container.decodeIfPresent(Bool.self, forKey: .normTopkProb) ?? false } } diff --git a/Libraries/MLXLLM/Models/OpenELM.swift b/Libraries/MLXLLM/Models/OpenELM.swift index 2cd9f3dd..9843039b 100644 --- a/Libraries/MLXLLM/Models/OpenELM.swift +++ b/Libraries/MLXLLM/Models/OpenELM.swift @@ -9,6 +9,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable func computeHeads(modelDim: Int, headDim: Int) -> Int { assert(modelDim % headDim == 0, "modelDim must be divisible by headDim") @@ -211,58 +212,26 @@ public class OpenELMModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct OpenElmConfiguration: Codable, Sendable { - var modelType: String - var headDimensions: Int - var numTransformerLayers: Int - var modelDim: Int - var vocabularySize: Int - var ffnDimDivisor: Int - var numQueryHeads: [Int] = [] - var kvHeads: [Int] = [] - var ffnWithGlu: Bool = true - var normalizeQkProjections: Bool = true - var shareInputOutputLayers: Bool = true - var rmsNormEps: Float = 1e-6 - var ropeTheta: Float = 10_000 - var ropeTraditional: Bool = false - var numGqaGroups: Int = 4 - var ffnMultipliers: [Float] = [0.5, 4.0] - var qkvMultiplier: [Float] = [0.5, 1.0] - - enum CodingKeys: String, CodingKey { 
- case modelType = "model_type" - case headDimensions = "head_dim" - case numTransformerLayers = "num_transformer_layers" - case modelDim = "model_dim" - case vocabularySize = "vocab_size" - case ffnDimDivisor = "ffn_dim_divisor" - case ffnMultipliers = "ffn_multipliers" - case ffnWithGlu = "ffn_with_glu" - case normalizeQkProjections = "normalize_qk_projections" - case shareInputOutputLayers = "share_input_output_layers" - } - - public init(from decoder: Decoder) throws { - // custom implementation to handle optional keys with required values - let container: KeyedDecodingContainer = - try decoder.container( - keyedBy: OpenElmConfiguration.CodingKeys.self) - - self.modelType = try container.decode( - String.self, forKey: OpenElmConfiguration.CodingKeys.modelType) - self.headDimensions = try container.decode( - Int.self, forKey: OpenElmConfiguration.CodingKeys.headDimensions) - self.numTransformerLayers = try container.decode( - Int.self, forKey: OpenElmConfiguration.CodingKeys.numTransformerLayers) - - self.modelDim = try container.decode( - Int.self, forKey: OpenElmConfiguration.CodingKeys.modelDim) - self.vocabularySize = try container.decode( - Int.self, forKey: OpenElmConfiguration.CodingKeys.vocabularySize) - self.ffnDimDivisor = try container.decode( - Int.self, forKey: OpenElmConfiguration.CodingKeys.ffnDimDivisor) - +@Codable +public struct OpenElmConfiguration: Sendable { + @CodingKey("head_dim") public var headDimensions: Int + @CodingKey("num_transformer_layers") public var numTransformerLayers: Int + @CodingKey("model_dim") public var modelDim: Int + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("ffn_dim_divisor") public var ffnDimDivisor: Int + @CodingKey("ffn_multipliers") public var ffnMultipliers: [Float] = [0.5, 4.0] + @CodingKey("ffn_with_glu") public var ffnWithGlu: Bool = true + @CodingKey("normalize_qk_projections") public var normalizeQkProjections: Bool = true + @CodingKey("share_input_output_layers") public var 
shareInputOutputLayers: Bool = true + @CodingIgnored public var numQueryHeads: [Int] = [] + @CodingIgnored public var kvHeads: [Int] = [] + @CodingIgnored public var rmsNormEps: Float = 1e-6 + @CodingIgnored public var ropeTheta: Float = 10_000 + @CodingIgnored public var ropeTraditional: Bool = false + @CodingIgnored public var numGqaGroups: Int = 4 + @CodingIgnored public var qkvMultiplier: [Float] = [0.5, 1.0] + + public mutating func didDecode(from decoder: any Decoder) throws { let qkvMultipliers = stride( from: qkvMultiplier[0], through: qkvMultiplier[1], by: (qkvMultiplier[1] - qkvMultiplier[0]) / Float(numTransformerLayers - 1) @@ -287,16 +256,6 @@ public struct OpenElmConfiguration: Codable, Sendable { by: (ffnMultipliers[1] - ffnMultipliers[0]) / Float(numTransformerLayers - 1) ) .map { round($0 * 100) / 100 } - - self.ffnWithGlu = - try container.decodeIfPresent( - Bool.self, forKey: OpenElmConfiguration.CodingKeys.ffnWithGlu) ?? true - self.normalizeQkProjections = - try container.decodeIfPresent( - Bool.self, forKey: OpenElmConfiguration.CodingKeys.normalizeQkProjections) ?? true - self.shareInputOutputLayers = - try container.decodeIfPresent( - Bool.self, forKey: OpenElmConfiguration.CodingKeys.shareInputOutputLayers) ?? 
true } } diff --git a/Libraries/MLXLLM/Models/Phi.swift b/Libraries/MLXLLM/Models/Phi.swift index dd173cbd..1068d7dc 100644 --- a/Libraries/MLXLLM/Models/Phi.swift +++ b/Libraries/MLXLLM/Models/Phi.swift @@ -4,8 +4,9 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable -// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py +// https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/phi.py private class PhiAttention: Module { @@ -179,59 +180,18 @@ public class PhiModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct PhiConfiguration: Codable, Sendable { - var maxPositionalEmbeddings = 2048 - var vocabularySize = 51200 - var hiddenSize = 2560 - var attentionHeads = 32 - var hiddenLayers = 32 - var kvHeads = 32 - var partialRotaryFactor: Float = 0.4 - var intermediateSize = 10240 - var layerNormEps: Float = 1e-5 - var ropeTheta: Float = 10_000 - - enum CodingKeys: String, CodingKey { - case maxPositionalEmbeddings = "max_position_embeddings" - case vocabularySize = "vocab_size" - case hiddenSize = "hidden_size" - case attentionHeads = "num_attention_heads" - case hiddenLayers = "num_hidden_layers" - case kvHeads = "num_key_value_heads" - case partialRotaryFactor = "partial_rotary_factor" - case intermediateSize = "intermediate_size" - case layerNormEps = "layer_norm_eps" - case ropeTheta = "rope_theta" - } - - public init(from decoder: Decoder) throws { - let container: KeyedDecodingContainer = try decoder.container( - keyedBy: PhiConfiguration.CodingKeys.self) - - self.maxPositionalEmbeddings = try container.decode( - Int.self, forKey: PhiConfiguration.CodingKeys.maxPositionalEmbeddings) - self.vocabularySize = try container.decode( - Int.self, forKey: PhiConfiguration.CodingKeys.vocabularySize) - self.hiddenSize = try container.decode( - Int.self, forKey: PhiConfiguration.CodingKeys.hiddenSize) - self.attentionHeads = try container.decode( - Int.self, forKey: 
PhiConfiguration.CodingKeys.attentionHeads) - self.hiddenLayers = try container.decode( - Int.self, forKey: PhiConfiguration.CodingKeys.hiddenLayers) - self.kvHeads = - try container.decodeIfPresent(Int.self, forKey: PhiConfiguration.CodingKeys.kvHeads) - ?? attentionHeads - self.partialRotaryFactor = try container.decode( - Float.self, forKey: PhiConfiguration.CodingKeys.partialRotaryFactor) - self.intermediateSize = try container.decode( - Int.self, forKey: PhiConfiguration.CodingKeys.intermediateSize) - self.layerNormEps = try container.decode( - Float.self, forKey: PhiConfiguration.CodingKeys.layerNormEps) - self.ropeTheta = - try container.decodeIfPresent(Float.self, forKey: PhiConfiguration.CodingKeys.ropeTheta) - ?? 10_000 - - } +@Codable +public struct PhiConfiguration: Sendable { + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings = 2048 + @CodingKey("vocab_size") public var vocabularySize = 51200 + @CodingKey("hidden_size") public var hiddenSize = 2560 + @CodingKey("num_attention_heads") public var attentionHeads = 32 + @CodingKey("num_hidden_layers") public var hiddenLayers = 32 + @CodingKey("num_key_value_heads") public var kvHeads = 32 + @CodingKey("partial_rotary_factor") public var partialRotaryFactor: Float = 0.4 + @CodingKey("intermediate_size") public var intermediateSize = 10240 + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-5 + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/Phi3.swift b/Libraries/MLXLLM/Models/Phi3.swift index 0792dd83..6fd3811a 100644 --- a/Libraries/MLXLLM/Models/Phi3.swift +++ b/Libraries/MLXLLM/Models/Phi3.swift @@ -4,6 +4,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable private class Attention: Module { @@ -229,13 +230,14 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider { } } -struct RopeScalingWithFactorArrays: Codable { - let longFactor: [Float]? 
- let shortFactor: [Float]? - let factor: Float? - let type: String? - let longMScale: Float? - let shortMScale: Float? +@Codable +public struct RopeScalingWithFactorArrays: Sendable { + @CodingKey("long_factor") public var longFactor: [Float]? + @CodingKey("short_factor") public var shortFactor: [Float]? + @CodingKey("long_mscale") public var longMScale: Float? + @CodingKey("short_mscale") public var shortMScale: Float? + public var factor: Float? + public var type: String? enum CodingKeys: String, CodingKey { case type @@ -247,74 +249,22 @@ struct RopeScalingWithFactorArrays: Codable { } } -public struct Phi3Configuration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var ropeTheta: Float = 10_000 - var ropeTraditional: Bool = false - var ropeScaling: RopeScalingWithFactorArrays? - var partialRotaryFactor: Float = 1.0 - var maxPositionEmbeddings: Int - var originalMaxPositionEmbeddings: Int - var tieWordEmbeddings: Bool = false - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case partialRotaryFactor = "partial_rotary_factor" - case maxPositionEmbeddings = "max_position_embeddings" - case originalMaxPositionEmbeddings = "original_max_position_embeddings" - case tieWordEmbeddings = "tie_word_embeddings" - } - - public init(from decoder: Decoder) throws { - // custom implementation to handle optional keys with required values - let container: KeyedDecodingContainer = try decoder.container( - keyedBy: Phi3Configuration.CodingKeys.self) - - 
hiddenSize = try container.decode(Int.self, forKey: Phi3Configuration.CodingKeys.hiddenSize) - hiddenLayers = try container.decode( - Int.self, forKey: Phi3Configuration.CodingKeys.hiddenLayers) - intermediateSize = try container.decode( - Int.self, forKey: Phi3Configuration.CodingKeys.intermediateSize) - attentionHeads = try container.decode( - Int.self, forKey: Phi3Configuration.CodingKeys.attentionHeads) - rmsNormEps = try container.decode( - Float.self, forKey: Phi3Configuration.CodingKeys.rmsNormEps) - vocabularySize = try container.decode( - Int.self, forKey: Phi3Configuration.CodingKeys.vocabularySize) - kvHeads = try container.decode(Int.self, forKey: Phi3Configuration.CodingKeys.kvHeads) - ropeTheta = - try container.decodeIfPresent( - Float.self, forKey: Phi3Configuration.CodingKeys.ropeTheta) ?? 10_000 - ropeTraditional = - try container.decodeIfPresent( - Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false - ropeScaling = try container.decodeIfPresent( - RopeScalingWithFactorArrays.self, forKey: .ropeScaling) - partialRotaryFactor = - try container.decodeIfPresent( - Float.self, forKey: .partialRotaryFactor) ?? 1.0 - maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings) - originalMaxPositionEmbeddings = try container.decode( - Int.self, forKey: .originalMaxPositionEmbeddings) - tieWordEmbeddings = - try container.decodeIfPresent( - Bool.self, forKey: .tieWordEmbeddings) ?? 
false - } +@Codable +public struct Phi3Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: RopeScalingWithFactorArrays? + @CodingKey("partial_rotary_factor") public var partialRotaryFactor: Float = 1.0 + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int + @CodingKey("original_max_position_embeddings") public var originalMaxPositionEmbeddings: Int + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = false } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/PhiMoE.swift b/Libraries/MLXLLM/Models/PhiMoE.swift index a5fdf72a..c0130f51 100644 --- a/Libraries/MLXLLM/Models/PhiMoE.swift +++ b/Libraries/MLXLLM/Models/PhiMoE.swift @@ -2,41 +2,26 @@ import Foundation import MLX import MLXLMCommon import MLXNN - -// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phimoe.py - -public struct PhiMoEConfiguration: Codable, Sendable { - var modelType: String = "phimoe" - var vocabularySize: Int = 32064 - var hiddenSize: Int = 4096 - var intermediateSize: Int = 6400 - var hiddenLayers: Int = 32 - var attentionHeads: Int = 32 - var kvHeads: Int = 8 - var maxPositionEmbeddings: Int = 131072 - var originalMaxPositionEmbeddings: Int = 4096 - var rmsNormEps: Float = 1e-6 - var ropeScaling: RopeScalingWithFactorArrays? 
- var numLocalExperts: Int = 16 - var numExpertsPerToken: Int = 2 - var ropeTheta: Float = 10000.0 - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case vocabularySize = "vocab_size" - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case hiddenLayers = "num_hidden_layers" - case attentionHeads = "num_attention_heads" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case originalMaxPositionEmbeddings = "original_max_position_embeddings" - case rmsNormEps = "rms_norm_eps" - case ropeScaling = "rope_scaling" - case numLocalExperts = "num_local_experts" - case numExpertsPerToken = "num_experts_per_tok" - case ropeTheta = "rope_theta" - } +import ReerCodable + +// Port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/phimoe.py + +@Codable +public struct PhiMoEConfiguration: Sendable { + @CodingKey("vocab_size") public var vocabularySize: Int = 32064 + @CodingKey("hidden_size") public var hiddenSize: Int = 4096 + @CodingKey("intermediate_size") public var intermediateSize: Int = 6400 + @CodingKey("num_hidden_layers") public var hiddenLayers: Int = 32 + @CodingKey("num_attention_heads") public var attentionHeads: Int = 32 + @CodingKey("num_key_value_heads") public var kvHeads: Int = 8 + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 131072 + @CodingKey("original_max_position_embeddings") public var originalMaxPositionEmbeddings: Int = + 4096 + @CodingKey("rms_norm_eps") public var rmsNormEps: Float = 1e-6 + @CodingKey("rope_scaling") public var ropeScaling: RopeScalingWithFactorArrays? 
+ @CodingKey("num_local_experts") public var numLocalExperts: Int = 16 + @CodingKey("num_experts_per_tok") public var numExpertsPerToken: Int = 2 + @CodingKey("rope_theta") public var ropeTheta: Float = 10000.0 } private class Attention: Module { diff --git a/Libraries/MLXLLM/Models/Qwen2.swift b/Libraries/MLXLLM/Models/Qwen2.swift index c9d62f83..bb0998f5 100644 --- a/Libraries/MLXLLM/Models/Qwen2.swift +++ b/Libraries/MLXLLM/Models/Qwen2.swift @@ -9,8 +9,9 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable -// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py +// port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/qwen2.py private class Attention: Module { let args: Qwen2Configuration @@ -212,64 +213,19 @@ public class Qwen2Model: Module, LLMModel, KVCacheDimensionProvider { } } -public struct Qwen2Configuration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var ropeTheta: Float = 1_000_000 - var ropeTraditional: Bool = false - var ropeScaling: [String: StringOrNumber]? 
= nil - var tieWordEmbeddings = false - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - } - - public init(from decoder: Decoder) throws { - // custom implementation to handle optional keys with required values - let container: KeyedDecodingContainer = - try decoder.container( - keyedBy: Qwen2Configuration.CodingKeys.self) - - self.hiddenSize = try container.decode( - Int.self, forKey: Qwen2Configuration.CodingKeys.hiddenSize) - self.hiddenLayers = try container.decode( - Int.self, forKey: Qwen2Configuration.CodingKeys.hiddenLayers) - self.intermediateSize = try container.decode( - Int.self, forKey: Qwen2Configuration.CodingKeys.intermediateSize) - self.attentionHeads = try container.decode( - Int.self, forKey: Qwen2Configuration.CodingKeys.attentionHeads) - self.rmsNormEps = try container.decode( - Float.self, forKey: Qwen2Configuration.CodingKeys.rmsNormEps) - self.vocabularySize = try container.decode( - Int.self, forKey: Qwen2Configuration.CodingKeys.vocabularySize) - self.kvHeads = try container.decode(Int.self, forKey: Qwen2Configuration.CodingKeys.kvHeads) - self.ropeTheta = - try container.decodeIfPresent( - Float.self, forKey: Qwen2Configuration.CodingKeys.ropeTheta) - ?? 1_000_000 - self.ropeTraditional = - try container.decodeIfPresent( - Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? 
false - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling) - self.tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false - } +@Codable +public struct Qwen2Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float = 1_000_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? = nil + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings = false } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/Qwen3.swift b/Libraries/MLXLLM/Models/Qwen3.swift index 6e9e8bb9..e65fb844 100644 --- a/Libraries/MLXLLM/Models/Qwen3.swift +++ b/Libraries/MLXLLM/Models/Qwen3.swift @@ -9,6 +9,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen3.py @@ -217,67 +218,20 @@ public class Qwen3Model: Module, LLMModel, KVCacheDimensionProvider { } } -public struct Qwen3Configuration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var ropeTheta: Float = 1_000_000 - var headDim: Int - var ropeScaling: [String: StringOrNumber]? 
= nil - var tieWordEmbeddings = false - var maxPositionEmbeddings: Int = 32768 - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case ropeTheta = "rope_theta" - case headDim = "head_dim" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - case maxPositionEmbeddings = "max_position_embeddings" - } - - public init(from decoder: Decoder) throws { - // custom implementation to handle optional keys with required values - let container: KeyedDecodingContainer = - try decoder.container( - keyedBy: Qwen3Configuration.CodingKeys.self) - - self.hiddenSize = try container.decode( - Int.self, forKey: Qwen3Configuration.CodingKeys.hiddenSize) - self.hiddenLayers = try container.decode( - Int.self, forKey: Qwen3Configuration.CodingKeys.hiddenLayers) - self.intermediateSize = try container.decode( - Int.self, forKey: Qwen3Configuration.CodingKeys.intermediateSize) - self.attentionHeads = try container.decode( - Int.self, forKey: Qwen3Configuration.CodingKeys.attentionHeads) - self.rmsNormEps = try container.decode( - Float.self, forKey: Qwen3Configuration.CodingKeys.rmsNormEps) - self.vocabularySize = try container.decode( - Int.self, forKey: Qwen3Configuration.CodingKeys.vocabularySize) - self.kvHeads = try container.decode(Int.self, forKey: Qwen3Configuration.CodingKeys.kvHeads) - self.ropeTheta = - try container.decodeIfPresent( - Float.self, forKey: Qwen3Configuration.CodingKeys.ropeTheta) - ?? 
1_000_000 - self.headDim = try container.decode( - Int.self, forKey: Qwen3Configuration.CodingKeys.headDim) - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: Qwen3Configuration.CodingKeys.ropeScaling) - self.tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false - self.maxPositionEmbeddings = - try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768 - } +@Codable +public struct Qwen3Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float = 1_000_000 + @CodingKey("head_dim") public var headDim: Int + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? 
= nil + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings = false + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768 } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/Qwen3MoE.swift b/Libraries/MLXLLM/Models/Qwen3MoE.swift index 74d8d6b1..4d82dcf3 100644 --- a/Libraries/MLXLLM/Models/Qwen3MoE.swift +++ b/Libraries/MLXLLM/Models/Qwen3MoE.swift @@ -9,6 +9,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen3_moe.py @@ -282,76 +283,26 @@ public class Qwen3MoEModel: Module, LLMModel, KVCacheDimensionProvider { } } -public struct Qwen3MoEConfiguration: Codable, Sendable { - var modelType: String = "qwen3_moe" - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var numExperts: Int - var numExpertsPerToken: Int - var decoderSparseStep: Int - var mlpOnlyLayers: [Int] - var moeIntermediateSize: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var headDim: Int - var ropeTheta: Float = 1_000_000 - var tieWordEmbeddings: Bool = false - var maxPositionEmbeddings: Int = 32768 - var normTopkProb: Bool = false - var ropeScaling: [String: StringOrNumber]? 
= nil - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case numExperts = "num_experts" - case numExpertsPerToken = "num_experts_per_tok" - case decoderSparseStep = "decoder_sparse_step" - case mlpOnlyLayers = "mlp_only_layers" - case moeIntermediateSize = "moe_intermediate_size" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case headDim = "head_dim" - case ropeTheta = "rope_theta" - case tieWordEmbeddings = "tie_word_embeddings" - case maxPositionEmbeddings = "max_position_embeddings" - case normTopkProb = "norm_topk_prob" - case ropeScaling = "rope_scaling" - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - self.modelType = - try container.decodeIfPresent(String.self, forKey: .modelType) ?? 
"qwen3_moe" - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - self.numExperts = try container.decode(Int.self, forKey: .numExperts) - self.numExpertsPerToken = try container.decode(Int.self, forKey: .numExpertsPerToken) - self.decoderSparseStep = try container.decode(Int.self, forKey: .decoderSparseStep) - self.mlpOnlyLayers = try container.decode([Int].self, forKey: .mlpOnlyLayers) - self.moeIntermediateSize = try container.decode(Int.self, forKey: .moeIntermediateSize) - self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - self.kvHeads = try container.decode(Int.self, forKey: .kvHeads) - self.headDim = try container.decode(Int.self, forKey: .headDim) - self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 1_000_000 - self.tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false - self.maxPositionEmbeddings = - try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768 - self.normTopkProb = try container.decodeIfPresent(Bool.self, forKey: .normTopkProb) ?? 
false - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - } +@Codable +public struct Qwen3MoEConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("num_experts") public var numExperts: Int + @CodingKey("num_experts_per_tok") public var numExpertsPerToken: Int + @CodingKey("decoder_sparse_step") public var decoderSparseStep: Int + @CodingKey("mlp_only_layers") public var mlpOnlyLayers: [Int] + @CodingKey("moe_intermediate_size") public var moeIntermediateSize: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("head_dim") public var headDim: Int + @CodingKey("rope_theta") public var ropeTheta: Float = 1_000_000 + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = false + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768 + @CodingKey("norm_topk_prob") public var normTopkProb: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? 
= nil } // MARK: - LoRA diff --git a/Libraries/MLXLLM/Models/SmolLM3.swift b/Libraries/MLXLLM/Models/SmolLM3.swift index f171d4ec..d7e020ae 100644 --- a/Libraries/MLXLLM/Models/SmolLM3.swift +++ b/Libraries/MLXLLM/Models/SmolLM3.swift @@ -9,6 +9,7 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable private protocol PositionEmbedding { func callAsFunction(_ x: MLXArray, offset: Int) -> MLXArray @@ -235,132 +236,49 @@ public class SmolLM3Model: Module, LLMModel, KVCacheDimensionProvider { // MARK: - Configuration -public struct SmolLM3Configuration: Codable, Sendable { - var modelType: String - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var headDimensions: Int? - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - var maxPositionEmbeddings: Int? - var ropeTheta: Float = 10_000 - var ropeTraditional: Bool = false - var ropeScaling: [String: StringOrNumber]? - var tieWordEmbeddings: Bool = true - var attentionBias: Bool = false - var mlpBias: Bool = false - - var noRopeLayerInterval: Int = 4 - var noRopeLayers: [Int] = [] - - var resolvedHeadDimensions: Int { +@Codable +public struct SmolLM3Configuration: Sendable { + @CodingKey("model_type") public var modelType: String = "smollm3" + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("head_dim") public var headDimensions: Int? + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int = 0 // Will be set in didDecode + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int? 
+ @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true + @CodingKey("attention_bias") public var attentionBias: Bool = false + @CodingKey("mlp_bias") public var mlpBias: Bool = false + @CodingKey("no_rope_layer_interval") public var noRopeLayerInterval: Int = 4 + @CodingKey("no_rope_layers") public var noRopeLayers: [Int] = [] + + public var resolvedHeadDimensions: Int { headDimensions ?? (hiddenSize / attentionHeads) } - public init( - modelType: String = "smollm3", - hiddenSize: Int, - hiddenLayers: Int, - intermediateSize: Int, - attentionHeads: Int, - headDimensions: Int? = nil, - rmsNormEps: Float, - vocabularySize: Int, - kvHeads: Int, - maxPositionEmbeddings: Int? = nil, - ropeTheta: Float = 10_000, - ropeTraditional: Bool = false, - ropeScaling: [String: StringOrNumber]? = nil, - tieWordEmbeddings: Bool = true, - attentionBias: Bool = false, - mlpBias: Bool = false, - noRopeLayerInterval: Int = 4, - noRopeLayers: [Int]? = nil - ) { - self.modelType = modelType - self.hiddenSize = hiddenSize - self.hiddenLayers = hiddenLayers - self.intermediateSize = intermediateSize - self.attentionHeads = attentionHeads - self.headDimensions = headDimensions - self.rmsNormEps = rmsNormEps - self.vocabularySize = vocabularySize - self.kvHeads = kvHeads - self.maxPositionEmbeddings = maxPositionEmbeddings - self.ropeTheta = ropeTheta - self.ropeTraditional = ropeTraditional - self.ropeScaling = ropeScaling - self.tieWordEmbeddings = tieWordEmbeddings - self.attentionBias = attentionBias - self.mlpBias = mlpBias - self.noRopeLayerInterval = noRopeLayerInterval - - if let noRopeLayers = noRopeLayers { - self.noRopeLayers = noRopeLayers - } else { - self.noRopeLayers = (0 ..< hiddenLayers).map { i in - (i + 1) % noRopeLayerInterval != 0 ? 
1 : 0 - } - } - } + public mutating func didDecode(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: AnyCodingKey.self) - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDimensions = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case ropeTheta = "rope_theta" - case ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case tieWordEmbeddings = "tie_word_embeddings" - case attentionBias = "attention_bias" - case mlpBias = "mlp_bias" - case noRopeLayerInterval = "no_rope_layer_interval" - case noRopeLayers = "no_rope_layers" - } + // Set kvHeads to attentionHeads if not provided in JSON + if kvHeads == 0 { + kvHeads = attentionHeads + } - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - self.modelType = try container.decodeIfPresent(String.self, forKey: .modelType) ?? "smollm3" - self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) - self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) - self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) - self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) - self.headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions) - self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) - self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) - self.kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 
attentionHeads - self.maxPositionEmbeddings = try container.decodeIfPresent( - Int.self, forKey: .maxPositionEmbeddings) - self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10_000 - self.ropeTraditional = - try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) ?? false - self.ropeScaling = try container.decodeIfPresent( - [String: StringOrNumber].self, forKey: .ropeScaling) - self.tieWordEmbeddings = - try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? true - self.attentionBias = - try container.decodeIfPresent(Bool.self, forKey: .attentionBias) ?? false - self.mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) ?? false - - self.noRopeLayerInterval = - try container.decodeIfPresent(Int.self, forKey: .noRopeLayerInterval) ?? 4 - - if let noRopeLayers = try container.decodeIfPresent([Int].self, forKey: .noRopeLayers) { - self.noRopeLayers = noRopeLayers - } else { - self.noRopeLayers = (0 ..< hiddenLayers).map { i in - (i + 1) % noRopeLayerInterval != 0 ? 1 : 0 + // Compute noRopeLayers if not provided in JSON + if noRopeLayers.isEmpty, + (try? container.decode(Int.self, forKey: AnyCodingKey("num_hidden_layers"))) != nil + { + let providedNoRopeLayers = try? container.decode( + [Int].self, forKey: AnyCodingKey("no_rope_layers")) + if providedNoRopeLayers == nil { + noRopeLayers = (0 ..< hiddenLayers).map { i in + (i + 1) % noRopeLayerInterval != 0 ? 
1 : 0 + } } } } diff --git a/Libraries/MLXLLM/Models/Starcoder2.swift b/Libraries/MLXLLM/Models/Starcoder2.swift index b520b5b2..9ef6a231 100644 --- a/Libraries/MLXLLM/Models/Starcoder2.swift +++ b/Libraries/MLXLLM/Models/Starcoder2.swift @@ -9,8 +9,9 @@ import Foundation import MLX import MLXLMCommon import MLXNN +import ReerCodable -// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/starcoder2.py +// port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/starcoder2.py private class Attention: Module { let args: Starcoder2Configuration @@ -182,70 +183,19 @@ public class Starcoder2Model: Module, LLMModel, KVCacheDimensionProvider { } } -public struct Starcoder2Configuration: Codable, Sendable { - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var kvHeads: Int - var maxPositionEmbeddings: Int = 16384 - var normEpsilon: Float = 1e-5 - var normType: String = "layer_norm" - var vocabularySize: Int = 49152 - var ropeTheta: Float = 100000 - var tieWordEmbeddings: Bool = true - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case kvHeads = "num_key_value_heads" - case maxPositionEmbeddings = "max_position_embeddings" - case normEpsilon = "norm_epsilon" - case normType = "norm_type" - case vocabularySize = "vocab_size" - case ropeTheta = "rope_theta" - case tieWordEmbeddings = "tie_word_embeddings" - } - - public init(from decoder: Decoder) throws { - // custom implementation to handle optional keys with required values - let container: KeyedDecodingContainer = - try decoder.container( - keyedBy: Starcoder2Configuration.CodingKeys.self) - - self.hiddenSize = try container.decode( - Int.self, forKey: Starcoder2Configuration.CodingKeys.hiddenSize) - self.hiddenLayers = try container.decode( - Int.self, forKey: 
Starcoder2Configuration.CodingKeys.hiddenLayers) - self.intermediateSize = try container.decode( - Int.self, forKey: Starcoder2Configuration.CodingKeys.intermediateSize) - self.attentionHeads = try container.decode( - Int.self, forKey: Starcoder2Configuration.CodingKeys.attentionHeads) - self.kvHeads = try container.decode( - Int.self, forKey: Starcoder2Configuration.CodingKeys.kvHeads) - self.maxPositionEmbeddings = - try container.decodeIfPresent( - Int.self, forKey: Starcoder2Configuration.CodingKeys.maxPositionEmbeddings) ?? 16384 - self.normEpsilon = - try container.decodeIfPresent( - Float.self, forKey: Starcoder2Configuration.CodingKeys.normEpsilon) ?? 1e-5 - self.normType = - try container.decodeIfPresent( - String.self, forKey: Starcoder2Configuration.CodingKeys.normType) ?? "layer_norm" - self.vocabularySize = - try container.decodeIfPresent( - Int.self, forKey: Starcoder2Configuration.CodingKeys.vocabularySize) ?? 49152 - self.ropeTheta = - try container.decodeIfPresent( - Float.self, forKey: Starcoder2Configuration.CodingKeys.ropeTheta) - ?? 100000 - self.tieWordEmbeddings = - try container.decodeIfPresent( - Bool.self, forKey: Starcoder2Configuration.CodingKeys.tieWordEmbeddings) - ?? 
true - } +@Codable +public struct Starcoder2Configuration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 16384 + @CodingKey("norm_epsilon") public var normEpsilon: Float = 1e-5 + @CodingKey("norm_type") public var normType: String = "layer_norm" + @CodingKey("vocab_size") public var vocabularySize: Int = 49152 + @CodingKey("rope_theta") public var ropeTheta: Float = 100000 + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true } // MARK: - LoRA diff --git a/Libraries/MLXLLM/README.md b/Libraries/MLXLLM/README.md index 16540fe6..c4502e52 100644 --- a/Libraries/MLXLLM/README.md +++ b/Libraries/MLXLLM/README.md @@ -11,7 +11,7 @@ This is a port of several models from: -- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/ +- https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/ using the Hugging Face swift transformers package to provide tokenization: @@ -93,17 +93,12 @@ and create a `.swift` file for your new model: Create a configuration struct to match the `config.json` (any parameters needed). ```swift -public struct YourModelConfiguration: Codable, Sendable { - public let hiddenSize: Int - - // use this pattern for values that need defaults - public let _layerNormEps: Float? - public var layerNormEps: Float { _layerNormEps ?? 
1e-6 } - - enum CodingKeys: String, CodingKey { - case hiddenSize = "hidden_size" - case _layerNormEps = "layer_norm_eps" - } +import ReerCodable + +@Codable +public struct YourModelConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-6 } ``` diff --git a/Libraries/MLXLLM/SwitchLayers.swift b/Libraries/MLXLLM/SwitchLayers.swift index 9e8e9a8f..956087cd 100644 --- a/Libraries/MLXLLM/SwitchLayers.swift +++ b/Libraries/MLXLLM/SwitchLayers.swift @@ -2,7 +2,7 @@ import Foundation import MLX import MLXNN -// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/switch_layers.py +// Port of https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/models/switch_layers.py public func gatherSort(x: MLXArray, indices: MLXArray) -> (MLXArray, MLXArray, MLXArray) { let m = indices.dim(-1) diff --git a/Libraries/MLXLMCommon/Adapters/LoRA/DoRA+Layers.swift b/Libraries/MLXLMCommon/Adapters/LoRA/DoRA+Layers.swift index c8a692ff..7e3d1230 100644 --- a/Libraries/MLXLMCommon/Adapters/LoRA/DoRA+Layers.swift +++ b/Libraries/MLXLMCommon/Adapters/LoRA/DoRA+Layers.swift @@ -7,9 +7,7 @@ import Foundation import MLX -import MLXLinalg import MLXNN -import MLXRandom /// Performs the forward pass for a DoRA linear layer. private func forward( diff --git a/Libraries/MLXLMCommon/Adapters/LoRA/LoRA+Layers.swift b/Libraries/MLXLMCommon/Adapters/LoRA/LoRA+Layers.swift index 6a8900c1..feb3cd3b 100644 --- a/Libraries/MLXLMCommon/Adapters/LoRA/LoRA+Layers.swift +++ b/Libraries/MLXLMCommon/Adapters/LoRA/LoRA+Layers.swift @@ -4,7 +4,6 @@ import Foundation import MLX import MLXNN import MLXOptimizers -import MLXRandom /// Implementation of LoRA `Linear` replacement layer. 
/// diff --git a/Libraries/MLXLMCommon/Adapters/LoRA/LoRAContainer.swift b/Libraries/MLXLMCommon/Adapters/LoRA/LoRAContainer.swift index 12951bfc..b2fbb0e2 100644 --- a/Libraries/MLXLMCommon/Adapters/LoRA/LoRAContainer.swift +++ b/Libraries/MLXLMCommon/Adapters/LoRA/LoRAContainer.swift @@ -8,6 +8,7 @@ import Foundation import MLX import MLXNN +import ReerCodable /// Configuration for how LoRA or DoRA should be applied. /// @@ -24,43 +25,24 @@ import MLXNN /// } /// } /// ``` -public struct LoRAConfiguration: Codable { +@Codable +public struct LoRAConfiguration { public enum FineTuneType: String, Codable { case lora case dora } - public struct LoRAParameters: Codable { + @Codable + public struct LoRAParameters { - public let rank: Int - public let scale: Float - - public init(rank: Int = 8, scale: Float = 10.0) { - self.rank = rank - self.scale = scale - } - } - - public let numLayers: Int - public let fineTuneType: FineTuneType - public let loraParameters: LoRAParameters - - public init( - numLayers: Int = 16, - fineTuneType: FineTuneType = .lora, - loraParameters: LoRAParameters = .init() - ) { - self.numLayers = numLayers - self.fineTuneType = fineTuneType - self.loraParameters = loraParameters + public var rank = 8 + public var scale: Float = 10.0 } - enum CodingKeys: String, CodingKey { - case numLayers = "num_layers" - case fineTuneType = "fine_tune_type" - case loraParameters = "lora_parameters" - } + @CodingKey("num_layers") public var numLayers: Int = 16 + @CodingKey("fine_tune_type") public var fineTuneType: FineTuneType = .lora + @CodingKey("lora_parameters") public var loraParameters = LoRAParameters() } /// A container for managing LoRA or DoRA adapters and applying them to a language model. 
diff --git a/Libraries/MLXLMCommon/AttentionUtils.swift b/Libraries/MLXLMCommon/AttentionUtils.swift index d00a76c2..d9a285a0 100644 --- a/Libraries/MLXLMCommon/AttentionUtils.swift +++ b/Libraries/MLXLMCommon/AttentionUtils.swift @@ -1,6 +1,5 @@ import Foundation import MLX -import MLXFast /// Attention utilities that match Python mlx-lm's interface /// diff --git a/Libraries/MLXLMCommon/BaseConfiguration.swift b/Libraries/MLXLMCommon/BaseConfiguration.swift index e9c0ed18..7899b38d 100644 --- a/Libraries/MLXLMCommon/BaseConfiguration.swift +++ b/Libraries/MLXLMCommon/BaseConfiguration.swift @@ -1,36 +1,26 @@ // Copyright © 2025 Apple Inc. import Foundation +import ReerCodable /// Base ``LanguageModel`` configuration -- provides `modelType` /// and `quantization` (used in loading the model). /// /// This is used by ``ModelFactory/load(hub:configuration:progressHandler:)`` /// to determine the type of model to load. -public struct BaseConfiguration: Codable, Sendable { - public let modelType: String +@Codable(memberwiseInit: false) +public struct BaseConfiguration: Sendable { + @CodingKey("model_type") public let modelType: String - public struct Quantization: Codable, Sendable, Equatable { - public init(groupSize: Int, bits: Int) { - self.groupSize = groupSize - self.bits = bits - } - - public let groupSize: Int + @Codable + public struct Quantization: Sendable, Equatable { + @CodingKey("group_size") public let groupSize: Int public let bits: Int - public var quantMethod: String? = nil - public var linearClass: String? = nil - public var quantizationMode: String? = nil + @CodingKey("quant_method") public var quantMethod: String? = nil + @CodingKey("linear_class") public var linearClass: String? = nil + @CodingKey("quantization_mode") public var quantizationMode: String? 
= nil public var asTuple: (Int, Int) { (groupSize, bits) } - - enum CodingKeys: String, CodingKey { - case groupSize = "group_size" - case bits = "bits" - case quantMethod = "quant_method" - case linearClass = "linear_class" - case quantizationMode = "quantization_mode" - } } /// handling instructions for ``PerLayerQuantization`` @@ -83,39 +73,26 @@ public struct BaseConfiguration: Codable, Sendable { /// /// This mixed type structure requires manual decoding. struct QuantizationContainer: Codable, Sendable { - var quantization: Quantization + var quantization: Quantization? var perLayerQuantization: PerLayerQuantization - // based on Dictionary's coding key - internal struct _DictionaryCodingKey: CodingKey { - internal let stringValue: String - internal let intValue: Int? - - internal init(stringValue: String) { - self.stringValue = stringValue - self.intValue = Int(stringValue) - } - - internal init(intValue: Int) { - self.stringValue = "\(intValue)" - self.intValue = intValue - } + internal init(quantization: Quantization?, perLayerQuantization: PerLayerQuantization) { + self.quantization = quantization + self.perLayerQuantization = perLayerQuantization } init(from decoder: any Decoder) throws { // handle the embedded Quantization - self.quantization = try Quantization(from: decoder) + self.quantization = try? 
Quantization(from: decoder) // and the interleaved per-layer values var perLayerQuantization = [String: QuantizationOption]() - let container = try decoder.container(keyedBy: _DictionaryCodingKey.self) + let container = try decoder.container(keyedBy: AnyCodingKey.self) for key in container.allKeys { switch key.stringValue { - case Quantization.CodingKeys.groupSize.rawValue: continue - case Quantization.CodingKeys.bits.rawValue: continue - case Quantization.CodingKeys.quantMethod.rawValue: continue - case Quantization.CodingKeys.linearClass.rawValue: continue - case Quantization.CodingKeys.quantizationMode.rawValue: continue + // ignore keys that belong to Quantization + case "group_size", "bits": continue + case "quant_method", "linear_class", "quantization_mode": continue default: if let f = try? container.decode(Bool.self, forKey: key) { @@ -135,19 +112,20 @@ public struct BaseConfiguration: Codable, Sendable { func encode(to encoder: any Encoder) throws { - try quantization.encode(to: encoder) + try quantization?.encode(to: encoder) - var container = encoder.container(keyedBy: _DictionaryCodingKey.self) + var container = encoder.container(keyedBy: AnyCodingKey.self) for (key, value) in perLayerQuantization.perLayerQuantization { + guard let key = AnyCodingKey(stringValue: key) else { continue } switch value { case .skip: - try container.encode(false, forKey: .init(stringValue: key)) + try container.encode(false, forKey: key) case .quantize(let q): - try container.encode(q, forKey: .init(stringValue: key)) + try container.encode(q, forKey: key) } } } } - var quantizationContainer: QuantizationContainer? + @CodingKey("quantization") var quantizationContainer: QuantizationContainer? @available(*, deprecated, message: "Please use perLayerQuantization instead") public var quantization: Quantization? 
{ @@ -158,8 +136,13 @@ public struct BaseConfiguration: Codable, Sendable { quantizationContainer?.perLayerQuantization } - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case quantizationContainer = "quantization" + public init( + modelType: String, quantization: Quantization? = nil, + perLayerQuantization: PerLayerQuantization? = nil + ) { + self.modelType = modelType + self.quantizationContainer = QuantizationContainer( + quantization: quantization, + perLayerQuantization: perLayerQuantization ?? .init(perLayerQuantization: [:])) } } diff --git a/Libraries/MLXLMCommon/Documentation.docc/porting.md b/Libraries/MLXLMCommon/Documentation.docc/porting.md index ff42ab64..f9e194e2 100644 --- a/Libraries/MLXLMCommon/Documentation.docc/porting.md +++ b/Libraries/MLXLMCommon/Documentation.docc/porting.md @@ -58,38 +58,24 @@ This will be loaded from a JSON file and used to configure the model, including This translates naturally into a `Codable` struct in Swift with a few details: -- The keys in the JSON file will be `snake_case`. The simplest way to accommodate that is to specify `CodingKeys` to name them explicitly. +- The keys in the JSON file will be `snake_case`. If you use ReerCodable (recommended), you can either name each coding key explicitly with `@CodingKey`, as shown below, or annotate the property or type with `@SnakeCase`. - Some of the parameters have default values. ```swift -public struct GemmaConfiguration: Codable, Sendable { - var modelType: String - var hiddenSize: Int - var hiddenLayers: Int - var intermediateSize: Int - var attentionHeads: Int - var headDimensions: Int - var rmsNormEps: Float - var vocabularySize: Int - var kvHeads: Int - private let _ropeTheta: Float? - public var ropeTheta: Float { _ropeTheta ?? 10_000 } - private let _ropeTraditional: Bool? - public var ropeTraditional: Bool { _ropeTraditional ?? 
false } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case headDimensions = "head_dim" - case rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case _ropeTheta = "rope_theta" - case _ropeTraditional = "rope_traditional" - } +import ReerCodable + +@Codable +public struct GemmaConfiguration: Sendable { + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("head_dim") public var headDimensions: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false } ``` diff --git a/Libraries/MLXLMCommon/Models/Gemma.swift b/Libraries/MLXLMCommon/Models/Gemma.swift index d5d1bc7b..c94af958 100644 --- a/Libraries/MLXLMCommon/Models/Gemma.swift +++ b/Libraries/MLXLMCommon/Models/Gemma.swift @@ -7,7 +7,6 @@ import Foundation import MLX -import MLXFast import MLXNN public enum Gemma { diff --git a/Libraries/MLXVLM/Codable+Support.swift b/Libraries/MLXVLM/Codable+Support.swift new file mode 100644 index 00000000..84329926 --- /dev/null +++ b/Libraries/MLXVLM/Codable+Support.swift @@ -0,0 +1,5 @@ +import Foundation + +/// `swift-transformers` also declares a public `Decoder` and it conflicts with the `Codable` +/// implementations. 
+public typealias Decoder = Swift.Decoder diff --git a/Libraries/MLXVLM/Models/Gemma3.swift b/Libraries/MLXVLM/Models/Gemma3.swift index b1a30f2a..0f81b14a 100644 --- a/Libraries/MLXVLM/Models/Gemma3.swift +++ b/Libraries/MLXVLM/Models/Gemma3.swift @@ -1,147 +1,84 @@ import CoreImage import MLX -import MLXFast import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers // Based on https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/gemma3 // MARK: - Text Configuration -public struct Gemma3TextConfiguration: Codable, Sendable { - public let modelType: String - public let hiddenSize: Int - public let hiddenLayers: Int - public let intermediateSize: Int - public let slidingWindow: Int - public let ropeScaling: [String: StringOrNumber]? - public let finalLogitSoftcapping: Float? +@Codable +public struct Gemma3TextConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("sliding_window") public var slidingWindow: Int + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? + @CodingKey("final_logit_softcapping") public var finalLogitSoftcapping: Float? - public let vocabularySize: Int = 262208 - public let rmsNormEps: Float = 1.0e-6 + @CodingKey("vocab_size") public var vocabularySize: Int = 262208 + @CodingKey("rms_norm_eps") public var rmsNormEps: Float = 1.0e-6 // Decoded from JSON when present, with fallback if not - - private let _attentionHeads: Int? - private let _kvHeads: Int? - private let _headDim: Int? - private let _queryPreAttnScalar: Float? - - // Not included in 4B model config.json, included for 12B and 27B models - public var attentionHeads: Int { - _attentionHeads ?? 8 - } - - // Not included in 4B model config.json, included for 12B and 27B models - public var kvHeads: Int { - _kvHeads ?? 
4 - } - - // Not included in 4B and 12B model config.json, included for 27B model - public var headDim: Int { - _headDim ?? 256 - } - - // Not included in 4B and 12B model config.json, included for 27B model - public var queryPreAttnScalar: Float { - _queryPreAttnScalar ?? 256 - } - - public let ropeGlobalBaseFreq: Float = 1_000_000.0 - public let ropeLocalBaseFreq: Float = 10_000.0 - public let ropeTraditional: Bool = false - public let mmTokensPerImage: Int = 256 - public let slidingWindowPattern: Int = 6 - public let maxPositionEmbeddings: Int = 4096 - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case slidingWindow = "sliding_window" - case ropeScaling = "rope_scaling" - case finalLogitSoftcapping = "final_logit_softcapping" - case _attentionHeads = "num_attention_heads" - case _kvHeads = "num_key_value_heads" - case _headDim = "head_dim" - case _queryPreAttnScalar = "query_pre_attn_scalar" - } + @CodingKey("num_attention_heads") public var attentionHeads: Int = 8 + @CodingKey("num_key_value_heads") public var kvHeads: Int = 4 + @CodingKey("head_dim") public var headDim: Int = 256 + @CodingKey("query_pre_attn_scalar") public var queryPreAttnScalar: Float = 256 + + @CodingKey("rope_global_base_freq") public var ropeGlobalBaseFreq: Float = 1_000_000.0 + @CodingKey("rope_local_base_freq") public var ropeLocalBaseFreq: Float = 10_000.0 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("mm_tokens_per_image") public var mmTokensPerImage: Int = 256 + @CodingKey("sliding_window_pattern") public var slidingWindowPattern: Int = 6 + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 4096 } // MARK: - Vision Configuration -public struct Gemma3VisionConfiguration: Codable, Sendable { - public let modelType: String - public let hiddenLayers: Int - public let 
hiddenSize: Int - public let intermediateSize: Int - public let attentionHeads: Int - public let patchSize: Int - public let imageSize: Int - - public let numChannels: Int = 3 - public let layerNormEps: Float = 1e-6 - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenLayers = "num_hidden_layers" - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case patchSize = "patch_size" - case imageSize = "image_size" - } +@Codable +public struct Gemma3VisionConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("patch_size") public var patchSize: Int + @CodingKey("image_size") public var imageSize: Int = 224 + + @CodingKey("num_channels") public var numChannels: Int = 3 + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-6 } // MARK: - Quantization Configuration -public struct QuantizationConfig: Codable, Sendable { - public let groupSize: Int - public let bits: Int - - enum CodingKeys: String, CodingKey { - case groupSize = "group_size" - case bits - } +@Codable +public struct QuantizationConfig: Sendable { + @CodingKey("group_size") public var groupSize: Int + public var bits: Int } // MARK: - Model Configuration -public struct Gemma3Configuration: Codable, Sendable { - public let textConfiguration: Gemma3TextConfiguration - public let visionConfiguration: Gemma3VisionConfiguration - public let modelType: String - public let mmTokensPerImage: Int - public let quantization: QuantizationConfig? - - private let _vocabularySize: Int? - private let _padTokenId: Int? 
- - // Computed properties that use the text configuration or provide defaults - - public var vocabularySize: Int { - _vocabularySize ?? textConfiguration.vocabularySize - } - - public var hiddenSize: Int { - textConfiguration.hiddenSize - } - - public var padTokenId: Int { - _padTokenId ?? 0 - } - - enum CodingKeys: String, CodingKey { - case textConfiguration = "text_config" - case visionConfiguration = "vision_config" - case modelType = "model_type" - case mmTokensPerImage = "mm_tokens_per_image" - case quantization - - case _vocabularySize = "vocab_size" - case _padTokenId = "pad_token_id" +@Codable +public struct Gemma3Configuration: Sendable { + @CodingKey("text_config") public var textConfiguration: Gemma3TextConfiguration + @CodingKey("vision_config") public var visionConfiguration: Gemma3VisionConfiguration + @CodingKey("model_type") public var modelType: String + @CodingKey("mm_tokens_per_image") public var mmTokensPerImage: Int + public var quantization: QuantizationConfig? + + @CodingKey("vocab_size") public var vocabularySize: Int = 257152 + @CodingKey("ignore_index") public var ignoreIndex: Int = -100 + @CodingKey("image_token_index") public var imageTokenIndex: Int = 262144 + @CodingKey("hidden_size") public var hiddenSize: Int = 2048 + @CodingKey("pad_token_id") public var padTokenId: Int = 0 + @CodingKey("eos_token_id") public var eosTokenId: [Int]? 
+ + public var textVocabularySize: Int { + textConfiguration.vocabularySize } } @@ -624,7 +561,7 @@ private class EncoderLayer: Module { } } -private class Encoder: Module { +private class GemmaEncoder: Module { @ModuleInfo var layers: [EncoderLayer] init(config: Gemma3VisionConfiguration) { @@ -715,12 +652,12 @@ private class VisionEmbeddings: Module, UnaryLayer { private class SigLipVisionModel: Module { @ModuleInfo var embeddings: VisionEmbeddings - @ModuleInfo var encoder: Encoder + @ModuleInfo var encoder: GemmaEncoder @ModuleInfo(key: "post_layernorm") var postLayerNorm: LayerNorm init(config: Gemma3VisionConfiguration) { self.embeddings = VisionEmbeddings(config: config) - self.encoder = Encoder(config: config) + self.encoder = GemmaEncoder(config: config) self._postLayerNorm.wrappedValue = LayerNorm(dimensions: config.hiddenSize) super.init() } @@ -920,7 +857,7 @@ public class Gemma3: Module, VLMModel, KVCacheDimensionProvider { public let config: Gemma3Configuration - public var vocabularySize: Int { config.vocabularySize } + public var vocabularySize: Int { config.textVocabularySize } public var kvHeads: [Int] { languageModel.kvHeads } /// Create cache with proper types for each layer @@ -1159,33 +1096,35 @@ public class Gemma3Processor: UserInputProcessor { } } -public struct Gemma3ProcessorConfiguration: Codable, Sendable { +@Codable +public struct Gemma3ProcessorConfiguration: Sendable { // Fields from the preprocessor_config.json - public let processorClass: String - public let imageProcessorType: String - public let doNormalize: Bool - public let doRescale: Bool - public let doResize: Bool - public let imageMean: [CGFloat] - public let imageStd: [CGFloat] - public let imageSeqLength: Int - public let resample: Int - public let rescaleFactor: Float - public let size: ImageSize + @CodingKey("processor_class") public var processorClass: String + @CodingKey("image_processor_type") public var imageProcessorType: String + @CodingKey("do_normalize") public 
var doNormalize: Bool + @CodingKey("do_rescale") public var doRescale: Bool + @CodingKey("do_resize") public var doResize: Bool + @CodingKey("image_mean") public var imageMean: [CGFloat] + @CodingKey("image_std") public var imageStd: [CGFloat] + @CodingKey("image_seq_length") public var imageSeqLength: Int + public var resample: Int + @CodingKey("rescale_factor") public var rescaleFactor: Float + public var size: ImageSize // Optional fields - public let doConvertRgb: Bool? - public let doPanAndScan: Bool? - public let panAndScanMaxNumCrops: Int? - public let panAndScanMinCropSize: Int? - public let panAndScanMinRatioToActivate: Float? + @CodingKey("do_convert_rgb") public var doConvertRgb: Bool? + @CodingKey("do_pan_and_scan") public var doPanAndScan: Bool? + @CodingKey("pan_and_scan_max_num_crops") public var panAndScanMaxNumCrops: Int? + @CodingKey("pan_and_scan_min_crop_size") public var panAndScanMinCropSize: Int? + @CodingKey("pan_and_scan_min_ratio_to_activate") public var panAndScanMinRatioToActivate: Float? 
// Image token identifier from model configuration - public let imageTokenId: Int = 262144 + public var imageTokenId: Int = 262144 - public struct ImageSize: Codable, Sendable { - public let height: Int - public let width: Int + @Codable + public struct ImageSize: Sendable { + public var height: Int + public var width: Int } // Computed properties for convenience @@ -1198,25 +1137,6 @@ public struct Gemma3ProcessorConfiguration: Codable, Sendable { public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { (imageStd[0], imageStd[1], imageStd[2]) } - - enum CodingKeys: String, CodingKey { - case processorClass = "processor_class" - case imageProcessorType = "image_processor_type" - case doNormalize = "do_normalize" - case doRescale = "do_rescale" - case doResize = "do_resize" - case doConvertRgb = "do_convert_rgb" - case doPanAndScan = "do_pan_and_scan" - case imageMean = "image_mean" - case imageStd = "image_std" - case imageSeqLength = "image_seq_length" - case resample - case rescaleFactor = "rescale_factor" - case size - case panAndScanMaxNumCrops = "pan_and_scan_max_num_crops" - case panAndScanMinCropSize = "pan_and_scan_min_crop_size" - case panAndScanMinRatioToActivate = "pan_and_scan_min_ratio_to_activate" - } } extension Gemma3: LoRAModel { diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift index 73d8f766..a533fb94 100644 --- a/Libraries/MLXVLM/Models/Idefics3.swift +++ b/Libraries/MLXVLM/Models/Idefics3.swift @@ -11,110 +11,50 @@ import Hub import MLX import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers // MARK: - Configuration -public struct Idefics3Configuration: Codable, Sendable { - - public struct TextConfiguration: Codable, Sendable { - public let modelType: String - public let hiddenSize: Int - public var numHiddenLayers: Int { _numHiddenLayers ?? 
32 } - public let intermediateSize: Int - public let numAttentionHeads: Int - public let rmsNormEps: Float - public let vocabSize: Int - public let numKeyValueHeads: Int - public let ropeTheta: Float - public var ropeTraditional: Bool { _ropeTraditional ?? false } - public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? false } - - private let _numHiddenLayers: Int? - private let _ropeTraditional: Bool? - private let _tieWordEmbeddings: Bool? - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case _numHiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case numAttentionHeads = "num_attention_heads" - case rmsNormEps = "rms_norm_eps" - case vocabSize = "vocab_size" - case numKeyValueHeads = "num_key_value_heads" - case ropeTheta = "rope_theta" - case _ropeTraditional = "rope_traditional" - case _tieWordEmbeddings = "tie_word_embeddings" - } - } - - public struct VisionConfiguration: Codable, Sendable { - public let modelType: String - public var numHiddenLayers: Int { _numHiddenLayers ?? 12 } - public let hiddenSize: Int - public var intermediateSize: Int { _intermediateSize ?? 3072 } - public let numAttentionHeads: Int - public let patchSize: Int - public let imageSize: Int - public var numChannels: Int { _numChannels ?? 3 } - public var layerNormEps: Float { _layerNormEps ?? 1e-6 } - - private let _numHiddenLayers: Int? - private let _intermediateSize: Int? - private let _numChannels: Int? - private let _layerNormEps: Float? 
- - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case _numHiddenLayers = "num_hidden_layers" - case hiddenSize = "hidden_size" - case _intermediateSize = "intermediate_size" - case numAttentionHeads = "num_attention_heads" - case patchSize = "patch_size" - case imageSize = "image_size" - case _numChannels = "num_channels" - case _layerNormEps = "layer_norm_eps" - } - } - - public let textConfig: TextConfiguration - public let visionConfig: VisionConfiguration - public let modelType: String - public let ignoreIndex: Int - public let vocabSize: Int - public let scaleFactor: Int - public let imageTokenId: Int - public let imageTokenIndex: Int - - enum CodingKeys: String, CodingKey { - case textConfig = "text_config" - case visionConfig = "vision_config" - case modelType = "model_type" - case ignoreIndex = "ignore_index" - case vocabSize = "vocab_size" - case scaleFactor = "scale_factor" - case imageTokenId = "image_token_id" - case imageTokenIndex = "image_token_index" - } - - public init(from decoder: any Swift.Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - - self.textConfig = - try container - .decode(TextConfiguration.self, forKey: .textConfig) - self.visionConfig = - try container - .decode(VisionConfiguration.self, forKey: .visionConfig) - self.modelType = try container.decode(String.self, forKey: .modelType) - self.ignoreIndex = (try? container.decode(Int.self, forKey: .ignoreIndex)) ?? -100 - self.vocabSize = (try? container.decode(Int.self, forKey: .vocabSize)) ?? 128259 - self.scaleFactor = (try? container.decode(Int.self, forKey: .scaleFactor)) ?? 2 - self.imageTokenId = (try? container.decode(Int.self, forKey: .imageTokenId)) ?? 49153 - self.imageTokenIndex = - (try? container.decode(Int.self, forKey: .imageTokenIndex)) ?? 
self.imageTokenId - } +@Codable +public struct Idefics3Configuration: Sendable { + + @Codable + public struct TextConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var numHiddenLayers: Int = 32 + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var numAttentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float + @CodingKey("vocab_size") public var vocabSize: Int + @CodingKey("num_key_value_heads") public var numKeyValueHeads: Int + @CodingKey("rope_theta") public var ropeTheta: Float + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = false + } + + @Codable + public struct VisionConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("num_hidden_layers") public var numHiddenLayers: Int = 12 + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("intermediate_size") public var intermediateSize: Int = 3072 + @CodingKey("num_attention_heads") public var numAttentionHeads: Int + @CodingKey("patch_size") public var patchSize: Int + @CodingKey("image_size") public var imageSize: Int + @CodingKey("num_channels") public var numChannels: Int = 3 + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-6 + } + + @CodingKey("text_config") public var textConfig: TextConfiguration + @CodingKey("vision_config") public var visionConfig: VisionConfiguration + @CodingKey("model_type") public var modelType: String + @CodingKey("ignore_index") public var ignoreIndex: Int = -100 + @CodingKey("vocab_size") public var vocabSize: Int = 128259 + @CodingKey("scale_factor") public var scaleFactor: Int = 2 + @CodingKey("image_token_id") public var imageTokenId: Int = 49153 + @CodingKey("image_token_index", "image_token_id") public var 
imageTokenIndex: Int } // MARK: - Connector @@ -776,18 +716,18 @@ public class Idefics3: Module, VLMModel, KVCacheDimensionProvider { } // MARK: - Processor Configuration -public struct Idefics3ProcessorConfiguration: Codable, Sendable { - public struct Size: Codable, Sendable { - public let longestEdge: Int - enum CodingKeys: String, CodingKey { - case longestEdge = "longest_edge" - } +@Codable +public struct Idefics3ProcessorConfiguration: Sendable { + + @Codable + public struct Size: Sendable { + @CodingKey("longest_edge") public var longestEdge: Int } - public let imageMean: [CGFloat] - public let imageStd: [CGFloat] - public let size: Size - public let imageSequenceLength: Int? + @CodingKey("image_mean") public var imageMean: [CGFloat] + @CodingKey("image_std") public var imageStd: [CGFloat] + public var size: Size + @CodingKey("image_seq_len") public var imageSequenceLength: Int? public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { (imageMean[0], imageMean[1], imageMean[2]) @@ -795,13 +735,6 @@ public struct Idefics3ProcessorConfiguration: Codable, Sendable { public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { (imageStd[0], imageStd[1], imageStd[2]) } - - enum CodingKeys: String, CodingKey { - case imageMean = "image_mean" - case imageStd = "image_std" - case size - case imageSequenceLength = "image_seq_len" - } } // MARK: - Processor diff --git a/Libraries/MLXVLM/Models/Paligemma.swift b/Libraries/MLXVLM/Models/Paligemma.swift index 7f128162..74a6b796 100644 --- a/Libraries/MLXVLM/Models/Paligemma.swift +++ b/Libraries/MLXVLM/Models/Paligemma.swift @@ -8,6 +8,7 @@ import Hub import MLX import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers // MARK: - Language @@ -629,100 +630,63 @@ public class PaliGemma: Module, VLMModel, KVCacheDimensionProvider { // MARK: - Configuration /// Confguration for ``PaliGemma`` -public struct PaliGemmaConfiguration: Codable, Sendable { - - public struct TextConfiguration: Codable, Sendable { - public let 
modelType: String - public let hiddenSize: Int - public let hiddenLayers: Int - public let intermediateSize: Int - public let attentionHeads: Int - public let kvHeads: Int - public let vocabularySize: Int - private let _rmsNormEps: Float? - public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 } - private let _ropeTheta: Float? - public var ropeTheta: Float { _ropeTheta ?? 10_000 } - private let _ropeTraditional: Bool? - public var ropeTraditional: Bool { _ropeTraditional ?? false } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case kvHeads = "num_key_value_heads" - case vocabularySize = "vocab_size" - case _rmsNormEps = "rms_norm_eps" - case _ropeTheta = "rope_theta" - case _ropeTraditional = "rope_traditional" - } - } - - public struct VisionConfiguration: Codable, Sendable { - public let modelType: String - public let hiddenSize: Int - public let hiddenLayers: Int - public let intermediateSize: Int - public let attentionHeads: Int - public let patchSize: Int - public let projectionDimensions: Int - public let imageSize: Int - private let _channels: Int? - public var channels: Int { _channels ?? 3 } - private let _layerNormEps: Float? - public var layerNormEps: Float { _layerNormEps ?? 
1e-6 } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case patchSize = "patch_size" - case projectionDimensions = "projection_dim" - case imageSize = "image_size" - case _channels = "num_channels" - case _layerNormEps = "layer_norm_eps" - } - } - - public let textConfiguration: TextConfiguration - public let visionConfiguration: VisionConfiguration - public let modelType: String - public let vocabularySize: Int - public let ignoreIndex: Int - public let imageTokenIndex: Int - public let hiddenSize: Int - public let padTokenId: Int - - enum CodingKeys: String, CodingKey { - case textConfiguration = "text_config" - case visionConfiguration = "vision_config" - case modelType = "model_type" - case vocabularySize = "vocab_size" - case ignoreIndex = "ignore_index" - case imageTokenIndex = "image_token_index" - case hiddenSize = "hidden_size" - case padTokenId = "pad_token_id" - } +@Codable +public struct PaliGemmaConfiguration: Sendable { + + @Codable + public struct TextConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float = 1e-6 + @CodingKey("rope_theta") public var ropeTheta: Float = 10_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + } + + @Codable + public struct VisionConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var 
hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("patch_size") public var patchSize: Int + @CodingKey("projection_dim") public var projectionDimensions: Int + @CodingKey("image_size") public var imageSize: Int + @CodingKey("num_channels") public var channels: Int = 3 + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-6 + } + + @CodingKey("text_config") public var textConfiguration: TextConfiguration + @CodingKey("vision_config") public var visionConfiguration: VisionConfiguration + @CodingKey("model_type") public var modelType: String + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("ignore_index") public var ignoreIndex: Int + @CodingKey("image_token_index") public var imageTokenIndex: Int + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("pad_token_id") public var padTokenId: Int } /// Configuration for ``PaliGemmaProcessor`` -public struct PaliGemmaProcessorConfiguration: Codable, Sendable { +@Codable +public struct PaliGemmaProcessorConfiguration: Sendable { - public struct Size: Codable, Sendable { - public let width: Int - public let height: Int + @Codable + public struct Size: Sendable { + public var width: Int + public var height: Int var cgSize: CGSize { .init(width: width, height: height) } } - public let imageMean: [CGFloat] - public let imageStd: [CGFloat] - public let size: Size - public let imageSequenceLength: Int + @CodingKey("image_mean") public var imageMean: [CGFloat] + @CodingKey("image_std") public var imageStd: [CGFloat] + public var size: Size + @CodingKey("image_seq_length") public var imageSequenceLength: Int public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { (imageMean[0], imageMean[1], imageMean[2]) @@ -730,11 +694,4 @@ public struct PaliGemmaProcessorConfiguration: Codable, Sendable { public var 
imageStdTuple: (CGFloat, CGFloat, CGFloat) { (imageStd[0], imageStd[1], imageStd[2]) } - - enum CodingKeys: String, CodingKey { - case imageMean = "image_mean" - case imageStd = "image_std" - case size - case imageSequenceLength = "image_seq_length" - } } diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index 29ba9fd7..45d69a0a 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -6,6 +6,7 @@ import Hub import MLX import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers // MARK: - Language @@ -915,125 +916,62 @@ public class Qwen25VL: Module, VLMModel, KVCacheDimensionProvider { /// Configuration for ``Qwen25VL`` public struct Qwen25VLConfiguration: Codable, Sendable { - public struct TextConfiguration: Codable, Sendable { - public let modelType: String - public let hiddenSize: Int - public let hiddenLayers: Int - public let intermediateSize: Int - public let attentionHeads: Int - private let _rmsNormEps: Float? - public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 } - public let vocabularySize: Int - public let kvHeads: Int - private let _maxPositionEmbeddings: Int? - public var maxPositionEmbeddings: Int { _maxPositionEmbeddings ?? 128000 } - private let _ropeTheta: Float? - public var ropeTheta: Float { _ropeTheta ?? 1_000_000 } - private let _ropeTraditional: Bool? - public var ropeTraditional: Bool { _ropeTraditional ?? false } - public let ropeScaling: [String: StringOrNumber]? - private let _tieWordEmbeddings: Bool? - public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? true } - private let _slidingWindow: Int? - public var slidingWindow: Int { _slidingWindow ?? 32768 } - private let _useSlidingWindow: Bool? - public var useSlidingWindow: Bool { _useSlidingWindow ?? 
false } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case _rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case _maxPositionEmbeddings = "max_position_embeddings" - case _ropeTheta = "rope_theta" - case _ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case _tieWordEmbeddings = "tie_word_embeddings" - case _slidingWindow = "sliding_window" - case _useSlidingWindow = "use_sliding_window" - } + @Codable + public struct TextConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float = 1e-6 + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 128000 + @CodingKey("rope_theta") public var ropeTheta: Float = 1_000_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? 
+ @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true + @CodingKey("sliding_window") public var slidingWindow: Int = 32768 + @CodingKey("use_sliding_window") public var useSlidingWindow: Bool = false } - public struct VisionConfiguration: Codable, Sendable { - public let depth: Int - public let hiddenSize: Int - public let intermediateSize: Int - public let outHiddenSize: Int - public let numHeads: Int - public let patchSize: Int - private let _inChans: Int? - public var inChannels: Int { _inChans ?? 3 } - private let _layerNormEps: Float? - public var layerNormEps: Float { _layerNormEps ?? 1e-6 } - public let spatialPatchSize: Int - public let spatialMergeSize: Int - public let temporalPatchSize: Int - public let windowSize: Int - public let fullattBlockIndexes: [Int] - public let tokensPerSecond: Int - private let _skipVision: Bool? - public var skipVision: Bool { _skipVision ?? false } - private let _hiddenAct: String? - public var hiddenAct: String { _hiddenAct ?? 
"silu" } - - enum CodingKeys: String, CodingKey { - case depth - case hiddenSize = "hidden_size" - case intermediateSize = "intermediate_size" - case outHiddenSize = "out_hidden_size" - case numHeads = "num_heads" - case patchSize = "patch_size" - case _inChans = "in_chans" - case _layerNormEps = "layer_norm_eps" // Added this line - case spatialPatchSize = "spatial_patch_size" - case spatialMergeSize = "spatial_merge_size" - case temporalPatchSize = "temporal_patch_size" - case windowSize = "window_size" - case fullattBlockIndexes = "fullatt_block_indexes" - case tokensPerSecond = "tokens_per_second" - case _skipVision = "skip_vision" - case _hiddenAct = "hidden_act" - } + @Codable + public struct VisionConfiguration: Sendable { + public var depth: Int + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("out_hidden_size") public var outHiddenSize: Int + @CodingKey("num_heads") public var numHeads: Int + @CodingKey("patch_size") public var patchSize: Int + @CodingKey("in_chans") public var inChannels: Int = 3 + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-6 + @CodingKey("spatial_patch_size") public var spatialPatchSize: Int + @CodingKey("spatial_merge_size") public var spatialMergeSize: Int + @CodingKey("temporal_patch_size") public var temporalPatchSize: Int + @CodingKey("window_size") public var windowSize: Int + @CodingKey("fullatt_block_indexes") public var fullattBlockIndexes: [Int] + @CodingKey("tokens_per_second") public var tokensPerSecond: Int + @CodingKey("skip_vision") public var skipVision: Bool = false + @CodingKey("hidden_act") public var hiddenAct: String = "silu" } + @Codable public struct BaseConfiguration: Codable, Sendable { - public let modelType: String - public let vocabularySize: Int - public let imageTokenId: Int - public let videoTokenId: Int - public let visionStartTokenId: Int - public let visionEndTokenId: Int - public let 
visionTokenId: Int - public let hiddenSize: Int - public let numAttentionHeads: Int - public let numHiddenLayers: Int - public let intermediateSize: Int - public let numKeyValueHeads: Int - public let slidingWindow: Int - public let useSlidingWindow: Bool - public let maxWindowLayers: Int - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case vocabularySize = "vocab_size" - case imageTokenId = "image_token_id" - case videoTokenId = "video_token_id" - case visionStartTokenId = "vision_start_token_id" - case visionEndTokenId = "vision_end_token_id" - case visionTokenId = "vision_token_id" - case hiddenSize = "hidden_size" - case numAttentionHeads = "num_attention_heads" - case numHiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case numKeyValueHeads = "num_key_value_heads" - case slidingWindow = "sliding_window" - case useSlidingWindow = "use_sliding_window" - case maxWindowLayers = "max_window_layers" - } + @CodingKey("model_type") public var modelType: String + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("image_token_id") public var imageTokenId: Int + @CodingKey("video_token_id") public var videoTokenId: Int + @CodingKey("vision_start_token_id") public var visionStartTokenId: Int + @CodingKey("vision_end_token_id") public var visionEndTokenId: Int + @CodingKey("vision_token_id") public var visionTokenId: Int + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_attention_heads") public var numAttentionHeads: Int + @CodingKey("num_hidden_layers") public var numHiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_key_value_heads") public var numKeyValueHeads: Int + @CodingKey("sliding_window") public var slidingWindow: Int + @CodingKey("use_sliding_window") public var useSlidingWindow: Bool + @CodingKey("max_window_layers") public var maxWindowLayers: Int } public let textConfiguration: TextConfiguration @@ -1055,6 
+993,14 @@ public struct Qwen25VLConfiguration: Codable, Sendable { self.textConfiguration = try TextConfiguration(from: decoder) self.baseConfiguration = try BaseConfiguration(from: decoder) } + + public func encode(to encoder: any Encoder) throws { + var container = try encoder.container(keyedBy: CodingKeys.self) + + try container.encode(visionConfiguration, forKey: .visionConfiguration) + try textConfiguration.encode(to: encoder) + try baseConfiguration.encode(to: encoder) + } } /// Configuration for ``Qwen25VLProcessor`` diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index 1675deca..40e06544 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -8,6 +8,7 @@ import Hub import MLX import MLXLMCommon import MLXNN +import ReerCodable import Tokenizers // MARK: - Language @@ -756,87 +757,45 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { /// Configuration for ``Qwen2VL`` public struct Qwen2VLConfiguration: Codable, Sendable { - public struct TextConfiguration: Codable, Sendable { - public let modelType: String - public let hiddenSize: Int - public let hiddenLayers: Int - public let intermediateSize: Int - public let attentionHeads: Int - private let _rmsNormEps: Float? - public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 } - public let vocabularySize: Int - public let kvHeads: Int - private let _maxPositionEmbeddings: Int? - public var maxpPositionEmbeddings: Int { _maxPositionEmbeddings ?? 32768 } - private let _ropeTheta: Float? - public var ropeTheta: Float { _ropeTheta ?? 1_000_000 } - private let _ropeTraditional: Bool? - public var ropeTraditional: Bool { _ropeTraditional ?? false } - public let ropeScaling: [String: StringOrNumber]? - private let _tieWordEmbeddings: Bool? - public var tieWordEmbeddings: Bool { _tieWordEmbeddings ?? 
true } - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case hiddenSize = "hidden_size" - case hiddenLayers = "num_hidden_layers" - case intermediateSize = "intermediate_size" - case attentionHeads = "num_attention_heads" - case _rmsNormEps = "rms_norm_eps" - case vocabularySize = "vocab_size" - case kvHeads = "num_key_value_heads" - case _maxPositionEmbeddings = "max_position_embeddings" - case _ropeTheta = "rope_theta" - case _ropeTraditional = "rope_traditional" - case ropeScaling = "rope_scaling" - case _tieWordEmbeddings = "tie_word_embeddings" - } - } - - public struct VisionConfiguration: Codable, Sendable { - public let depth: Int - public let embedDimensions: Int - public let hiddenSize: Int - public let numHeads: Int - public let patchSize: Int - public let mlpRatio: Float - public let _inChannels: Int? - public var inChannels: Int { _inChannels ?? 3 } - public let _layerNormEps: Float? - public var layerNormEps: Float { _layerNormEps ?? 1e-6 } - public let spatialPatchSize: Int - public let spatialMergeSize: Int - public let temporalPatchSize: Int - - enum CodingKeys: String, CodingKey { - case depth - case embedDimensions = "embed_dim" - case hiddenSize = "hidden_size" - case numHeads = "num_heads" - case patchSize = "patch_size" - case mlpRatio = "mlp_ratio" - case _inChannels = "in_channels" - case _layerNormEps = "layer_norm_eps" - case spatialPatchSize = "spatial_patch_size" - case spatialMergeSize = "spatial_merge_size" - case temporalPatchSize = "temporal_patch_size" - } - } - - public struct BaseConfiguration: Codable, Sendable { - public let modelType: String - public let vocabularySize: Int - public let imageTokenId: Int - public let videoTokenId: Int - public let hiddenSize: Int - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case vocabularySize = "vocab_size" - case imageTokenId = "image_token_id" - case videoTokenId = "video_token_id" - case hiddenSize = "hidden_size" - } + @Codable + 
public struct TextConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_hidden_layers") public var hiddenLayers: Int + @CodingKey("intermediate_size") public var intermediateSize: Int + @CodingKey("num_attention_heads") public var attentionHeads: Int + @CodingKey("rms_norm_eps") public var rmsNormEps: Float = 1e-6 + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("num_key_value_heads") public var kvHeads: Int + @CodingKey("max_position_embeddings") public var maxPositionEmbeddings: Int = 32768 + @CodingKey("rope_theta") public var ropeTheta: Float = 1_000_000 + @CodingKey("rope_traditional") public var ropeTraditional: Bool = false + @CodingKey("rope_scaling") public var ropeScaling: [String: StringOrNumber]? + @CodingKey("tie_word_embeddings") public var tieWordEmbeddings: Bool = true + } + + @Codable + public struct VisionConfiguration: Sendable { + public var depth: Int + @CodingKey("embed_dim") public var embedDimensions: Int + @CodingKey("hidden_size") public var hiddenSize: Int + @CodingKey("num_heads") public var numHeads: Int + @CodingKey("patch_size") public var patchSize: Int + @CodingKey("mlp_ratio") public var mlpRatio: Float + @CodingKey("in_channels") public var inChannels: Int = 3 + @CodingKey("layer_norm_eps") public var layerNormEps: Float = 1e-6 + @CodingKey("spatial_patch_size") public var spatialPatchSize: Int + @CodingKey("spatial_merge_size") public var spatialMergeSize: Int + @CodingKey("temporal_patch_size") public var temporalPatchSize: Int + } + + @Codable + public struct BaseConfiguration: Sendable { + @CodingKey("model_type") public var modelType: String + @CodingKey("vocab_size") public var vocabularySize: Int + @CodingKey("image_token_id") public var imageTokenId: Int + @CodingKey("video_token_id") public var videoTokenId: Int + @CodingKey("hidden_size") public var hiddenSize: Int } public let textConfiguration: 
TextConfiguration @@ -858,30 +817,34 @@ public struct Qwen2VLConfiguration: Codable, Sendable { self.textConfiguration = try TextConfiguration(from: decoder) self.baseConfiguration = try BaseConfiguration(from: decoder) } + + public func encode(to encoder: any Encoder) throws { + var container = try encoder.container(keyedBy: CodingKeys.self) + + try container.encode(visionConfiguration, forKey: .visionConfiguration) + try textConfiguration.encode(to: encoder) + try baseConfiguration.encode(to: encoder) + } } /// Configuration for ``Qwen2VLProcessor`` -public struct Qwen2VLProcessorConfiguration: Codable, Sendable { +@Codable +public struct Qwen2VLProcessorConfiguration: Sendable { - public struct Size: Codable, Sendable { - public let maxPixels: Int - public let minPixels: Int - - enum CodingKeys: String, CodingKey { - case maxPixels = "max_pixels" - case minPixels = "min_pixels" - } + @Codable + public struct Size: Sendable { + @CodingKey("max_pixels") public var maxPixels: Int + @CodingKey("min_pixels") public var minPixels: Int } - public let imageMean: [CGFloat] - public let imageStd: [CGFloat] - public let mergeSize: Int - public let patchSize: Int - public let temporalPatchSize: Int - - private let _size: Size? - private let _maxPixels: Int? - private let _minPixels: Int? + @CodingKey("image_mean") public var imageMean: [CGFloat] + @CodingKey("image_std") public var imageStd: [CGFloat] + @CodingKey("merge_size") public var mergeSize: Int + @CodingKey("patch_size") public var patchSize: Int + @CodingKey("temporal_patch_size") public var temporalPatchSize: Int + @CodingKey("max_pixels") private var _maxPixels: Int? + @CodingKey("min_pixels") private var _minPixels: Int? + @CodingKey("size") private var _size: Size? public var minPixels: Int { _minPixels ?? _size?.minPixels ?? 
3136 @@ -896,17 +859,6 @@ public struct Qwen2VLProcessorConfiguration: Codable, Sendable { public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { (imageStd[0], imageStd[1], imageStd[2]) } - - enum CodingKeys: String, CodingKey { - case imageMean = "image_mean" - case imageStd = "image_std" - case mergeSize = "merge_size" - case patchSize = "patch_size" - case temporalPatchSize = "temporal_patch_size" - case _maxPixels = "max_pixels" - case _minPixels = "min_pixels" - case _size = "size" - } } /// Message Generator for Qwen2VL diff --git a/Libraries/MLXVLM/Models/SmolVLM2.swift b/Libraries/MLXVLM/Models/SmolVLM2.swift index b75a9717..d841676f 100644 --- a/Libraries/MLXVLM/Models/SmolVLM2.swift +++ b/Libraries/MLXVLM/Models/SmolVLM2.swift @@ -10,6 +10,7 @@ import CoreMedia import Foundation import MLX import MLXLMCommon +import ReerCodable import Tokenizers // MARK: - Configuration and modeling are Idefics3 @@ -18,47 +19,28 @@ typealias SmolVLM2Configuration = Idefics3Configuration typealias SmolVLM2 = Idefics3 // MARK: - SmolVLMProcessor and configuration +@Codable +public struct SmolVLMProcessorConfiguration: Sendable { -public struct SmolVLMProcessorConfiguration: Codable, Sendable { - public struct Size: Codable, Sendable { - public let longestEdge: Int - enum CodingKeys: String, CodingKey { - case longestEdge = "longest_edge" - } + @Codable + public struct Size: Sendable { + @CodingKey("longest_edge") public var longestEdge: Int } - public struct VideoSampling: Codable, Sendable { - public let fps: Int - public let maxFrames: Int + @Codable + public struct VideoSampling: Sendable { + public var fps: Int + @CodingKey("max_frames") public var maxFrames: Int // Intentionally ignoring videoSize because I believe it's still wrong in the config files // public let videoSize: Size - - enum CodingKeys: String, CodingKey { - case fps - case maxFrames = "max_frames" - } } - public let imageMean: [CGFloat] - public let imageStd: [CGFloat] - public let size: Size - public 
let maxImageSize: Size - public let videoSampling: VideoSampling - private let _imageSequenceLength: Int? - // TODO: this does not come in preprocessor_config.json, verify where transformers gets it from - public var imageSequenceLength: Int { _imageSequenceLength ?? 64 } - - init( - imageMean: [CGFloat], imageStd: [CGFloat], size: Size, maxImageSize: Size, - videoSampling: VideoSampling, imageSequenceLength: Int? - ) { - self.imageMean = imageMean - self.imageStd = imageStd - self.size = size - self.maxImageSize = maxImageSize - self.videoSampling = videoSampling - self._imageSequenceLength = imageSequenceLength - } + @CodingKey("image_mean") public var imageMean: [CGFloat] + @CodingKey("image_std") public var imageStd: [CGFloat] + public var size: Size + @CodingKey("max_image_size") public var maxImageSize: Size + @CodingKey("video_sampling") public var videoSampling: VideoSampling + @CodingKey("image_seq_len") public var imageSequenceLength: Int = 64 public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { (imageMean[0], imageMean[1], imageMean[2]) @@ -66,15 +48,6 @@ public struct SmolVLMProcessorConfiguration: Codable, Sendable { public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { (imageStd[0], imageStd[1], imageStd[2]) } - - enum CodingKeys: String, CodingKey { - case imageMean = "image_mean" - case imageStd = "image_std" - case size - case maxImageSize = "max_image_size" - case videoSampling = "video_sampling" - case _imageSequenceLength = "image_seq_len" - } } public class SmolVLMProcessor: UserInputProcessor { diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index b3ea12f9..d0bb35db 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -4,6 +4,7 @@ import Foundation import Hub import MLX import MLXLMCommon +import ReerCodable import Tokenizers public enum VLMError: LocalizedError { @@ -37,12 +38,9 @@ public enum VLMError: LocalizedError { } } -public struct 
BaseProcessorConfiguration: Codable, Sendable { - public let processorClass: String - - enum CodingKeys: String, CodingKey { - case processorClass = "processor_class" - } +@Codable +public struct BaseProcessorConfiguration: Sendable { + @CodingKey("processor_class") public let processorClass: String } /// Creates a function that loads a configuration file and instantiates a model with the proper configuration diff --git a/Libraries/StableDiffusion/Configuration.swift b/Libraries/StableDiffusion/Configuration.swift index c39a06ff..e8afa346 100644 --- a/Libraries/StableDiffusion/Configuration.swift +++ b/Libraries/StableDiffusion/Configuration.swift @@ -7,7 +7,7 @@ import MLXNN // port of https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/config.py /// Configuration for ``Autoencoder`` -struct AutoencoderConfiguration: Codable { +struct AutoencoderConfiguration: Codable, Sendable { public var inputChannels = 3 public var outputChannels = 3 @@ -60,7 +60,7 @@ struct AutoencoderConfiguration: Codable { } /// Configuration for ``CLIPTextModel`` -struct CLIPTextModelConfiguration: Codable { +struct CLIPTextModelConfiguration: Codable, Sendable { public enum ClipActivation: String, Codable { case fast = "quick_gelu" @@ -137,7 +137,7 @@ struct CLIPTextModelConfiguration: Codable { } /// Configuration for ``UNetModel`` -struct UNetConfiguration: Codable { +struct UNetConfiguration: Codable, Sendable { public var inputChannels = 4 public var outputChannels = 4 @@ -250,7 +250,7 @@ struct UNetConfiguration: Codable { } /// Configuration for ``StableDiffusion`` -public struct DiffusionConfiguration: Codable { +public struct DiffusionConfiguration: Codable, Sendable { public enum BetaSchedule: String, Codable { case linear = "linear" diff --git a/Package.resolved b/Package.resolved index ab7ca101..4747d656 100644 --- a/Package.resolved +++ b/Package.resolved @@ -27,6 +27,15 @@ "version" : "0.25.5" } }, + { + "identity" : "reercodable", + "kind" 
: "remoteSourceControl", + "location" : "https://github.com/reers/ReerCodable.git", + "state" : { + "revision" : "d76657994c28360c5ac9db39fc4e7b18329667fb", + "version" : "1.3.4" + } + }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", @@ -54,6 +63,15 @@ "version" : "1.0.2" } }, + { + "identity" : "swift-syntax", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swiftlang/swift-syntax.git", + "state" : { + "revision" : "f99ae8aa18f0cf0d53481901f88a0991dc3bd4a2", + "version" : "601.0.1" + } + }, { "identity" : "swift-transformers", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 365d7814..fa1fff22 100644 --- a/Package.swift +++ b/Package.swift @@ -31,6 +31,7 @@ let package = Package( .package( url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.23") ), + .package(url: "https://github.com/reers/ReerCodable.git", from: "1.3.4"), .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... 
"6.0.1"), // Only needed by MLXMNIST ], targets: [ @@ -42,8 +43,8 @@ let package = Package( .product(name: "MLXFast", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "MLXRandom", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), + .product(name: "ReerCodable", package: "ReerCodable"), ], path: "Libraries/MLXLLM", exclude: [ @@ -61,8 +62,8 @@ let package = Package( .product(name: "MLXFast", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "MLXRandom", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), + .product(name: "ReerCodable", package: "ReerCodable"), ], path: "Libraries/MLXVLM", exclude: [ @@ -78,9 +79,8 @@ let package = Package( .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "MLXRandom", package: "mlx-swift"), - .product(name: "MLXLinalg", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), + .product(name: "ReerCodable", package: "ReerCodable"), ], path: "Libraries/MLXLMCommon", exclude: [ @@ -96,7 +96,6 @@ let package = Package( .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "MLXRandom", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), "MLXLMCommon", "MLXLLM", @@ -114,10 +113,9 @@ let package = Package( name: "MLXEmbedders", dependencies: [ .product(name: "MLX", package: "mlx-swift"), - .product(name: "MLXFast", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), - .product(name: "MLXLinalg", package: "mlx-swift"), + 
.product(name: "ReerCodable", package: "ReerCodable"), ], path: "Libraries/Embedders", exclude: [ @@ -131,7 +129,6 @@ let package = Package( .product(name: "MLXFast", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "MLXRandom", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), .product(name: "Gzip", package: "GzipSwift"), ], @@ -148,8 +145,8 @@ let package = Package( dependencies: [ .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), - .product(name: "MLXRandom", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), + .product(name: "ReerCodable", package: "ReerCodable"), ], path: "Libraries/StableDiffusion", exclude: [ diff --git a/Tools/llm-tool/ListCommands.swift b/Tools/llm-tool/ListCommands.swift index d42db427..3555645f 100644 --- a/Tools/llm-tool/ListCommands.swift +++ b/Tools/llm-tool/ListCommands.swift @@ -26,7 +26,7 @@ struct ListLLMCommand: AsyncParsableCommand { func run() async throws { for configuration in LLMRegistry.shared.models { switch configuration.id { - case .id(let id): print(id) + case .id(let id, let revision): print(id) case .directory: break } } @@ -43,7 +43,7 @@ struct ListVLMCommand: AsyncParsableCommand { func run() async throws { for configuration in VLMRegistry.shared.models { switch configuration.id { - case .id(let id): print(id) + case .id(let id, let revision): print(id) case .directory: break } } diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj index cbe9c71b..e470d085 100644 --- a/mlx-swift-examples.xcodeproj/project.pbxproj +++ b/mlx-swift-examples.xcodeproj/project.pbxproj @@ -1042,6 +1042,7 @@ C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Libraries/.." 
*/, C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */, C32B4C6B2DA7132C00EF663D /* XCRemoteSwiftPackageReference "swift-async-algorithms" */, + C3FF946F2DD54E170070900D /* XCRemoteSwiftPackageReference "ReerCodable" */, ); productRefGroup = C39273752B606A0A00368D5D /* Products */; projectDirPath = ""; @@ -3285,6 +3286,14 @@ minimumVersion = 1.4.0; }; }; + C3FF946F2DD54E170070900D /* XCRemoteSwiftPackageReference "ReerCodable" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/reers/ReerCodable.git"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 1.2.3; + }; + }; /* End XCRemoteSwiftPackageReference section */ /* Begin XCSwiftPackageProductDependency section */ diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 721510b5..d83de827 100644 --- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "77747b309bf037a12446caffbd79db01f4e13f522d9c91eefae3e4516f93b2ca", + "originHash" : "731b77b3cf1180fd67406fb628ae8041128ca06f724373c6eb2e20a069617c60", "pins" : [ { "identity" : "gzipswift", @@ -46,6 +46,15 @@ "version" : "0.4.0" } }, + { + "identity" : "reercodable", + "kind" : "remoteSourceControl", + "location" : "https://github.com/reers/ReerCodable.git", + "state" : { + "revision" : "d76657994c28360c5ac9db39fc4e7b18329667fb", + "version" : "1.3.4" + } + }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", @@ -100,6 +109,15 @@ "version" : "1.0.3" } }, + { + "identity" : "swift-syntax", + "kind" : "remoteSourceControl", + "location" : "https://github.com/swiftlang/swift-syntax.git", + "state" : { + "revision" : "f99ae8aa18f0cf0d53481901f88a0991dc3bd4a2", + "version" : 
"601.0.1" + } + }, { "identity" : "swift-transformers", "kind" : "remoteSourceControl", diff --git a/support/generate-run-all-llms.sh b/support/generate-run-all-llms.sh index f1f837aa..775cd01a 100755 --- a/support/generate-run-all-llms.sh +++ b/support/generate-run-all-llms.sh @@ -3,7 +3,10 @@ echo "#!/bin/sh" echo "# NOTE: GENERATED BY generate-run-all-llms.sh -- DO NOT MODIFY BY HAND" +# note: omit DeepSeek-R1 as many won't have the resources to run this + ./mlx-run llm-tool list llms | \ + grep -v DeepSeek-R1 | \ awk '{printf "./mlx-run llm-tool eval --download ~/Downloads/huggingface --model %s\n", $0}' | \ awk '{printf "echo\necho ======\necho '\''%s'\''\n%s\n", $0, $0}' diff --git a/support/run-all-llms.sh b/support/run-all-llms.sh index 5ae758b3..cf5fbc07 100755 --- a/support/run-all-llms.sh +++ b/support/run-all-llms.sh @@ -2,76 +2,68 @@ # NOTE: GENERATED BY generate-run-all-llms.sh -- DO NOT MODIFY BY HAND echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/quantized-gemma-2b-it' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/quantized-gemma-2b-it +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen1.5-0.5B-Chat-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen1.5-0.5B-Chat-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/granite-3.3-2b-instruct-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/granite-3.3-2b-instruct-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Meta-Llama-3.1-8B-Instruct-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Meta-Llama-3.1-8B-Instruct-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-7B-Instruct-v0.3-4bit' 
-./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-7B-Instruct-v0.3-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/quantized-gemma-2b-it' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/quantized-gemma-2b-it echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Meta-Llama-3-8B-Instruct-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Meta-Llama-3-8B-Instruct-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-0.6B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-0.6B-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-4B-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-4B-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Baichuan-M1-14B-Instruct-4bit-ft' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Baichuan-M1-14B-Instruct-4bit-ft echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen1.5-0.5B-Chat-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen1.5-0.5B-Chat-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/exaone-4.0-1.2b-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/exaone-4.0-1.2b-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-1.7B-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-1.7B-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model 
mlx-community/OLMoE-1B-7B-0125-Instruct-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/OLMoE-1B-7B-0125-Instruct-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-Nemo-Instruct-2407-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-Nemo-Instruct-2407-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-2b-it-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-2b-it-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-9b-it-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-9b-it-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/OLMo-2-1124-7B-Instruct-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/OLMo-2-1124-7B-Instruct-4bit echo echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-7B-Instruct-4bit' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-7B-Instruct-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-0.6B-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-0.6B-4bit -echo -echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX -echo -echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-8B-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model 
mlx-community/Qwen3-8B-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Meta-Llama-3-8B-Instruct-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Meta-Llama-3-8B-Instruct-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/ERNIE-4.5-0.3B-PT-bf16-ft' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/ERNIE-4.5-0.3B-PT-bf16-ft echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/phi-2-hf-4bit-mlx' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/phi-2-hf-4bit-mlx +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-1.5B-Instruct-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-1.5B-Instruct-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-9b-it-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-9b-it-4bit echo echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/GLM-4-9B-0414-4bit' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/GLM-4-9B-0414-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model 
mlx-community/Meta-Llama-3.1-8B-Instruct-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Meta-Llama-3.1-8B-Instruct-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E4B-it-lm-bf16' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E4B-it-lm-bf16 echo echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Phi-3.5-mini-instruct-4bit' @@ -82,49 +74,129 @@ echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-com ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Llama-3.2-1B-Instruct-4bit echo echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-1b-it-qat-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-1b-it-qat-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-8B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-8B-4bit +echo +echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Llama-3.2-3B-Instruct-4bit' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Llama-3.2-3B-Instruct-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/OpenELM-270M-Instruct' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/OpenELM-270M-Instruct +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-1.5B-Instruct-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-1.5B-Instruct-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model 
mlx-community/AceReason-Nemotron-7B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/AceReason-Nemotron-7B-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E4B-it-lm-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E4B-it-lm-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/phi-2-hf-4bit-mlx' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/phi-2-hf-4bit-mlx echo echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/MiMo-7B-SFT-4bit' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/MiMo-7B-SFT-4bit echo echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/bitnet-b1.58-2B-4T-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/bitnet-b1.58-2B-4T-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-Nemo-Instruct-2407-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-Nemo-Instruct-2407-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/OpenELM-270M-Instruct' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/OpenELM-270M-Instruct +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/lille-130m-instruct-bf16' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/lille-130m-instruct-bf16 +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/granite-3.3-2b-instruct-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface 
--model mlx-community/granite-3.3-2b-instruct-4bit +echo +echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolLM-135M-Instruct-4bit' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolLM-135M-Instruct-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-30B-A3B-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-30B-A3B-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/LFM2-1.2B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/LFM2-1.2B-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-2b-it-4bit' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-2-2b-it-4bit +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-1.7B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-1.7B-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E2B-it-lm-bf16' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E2B-it-lm-bf16 +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-4B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-4B-4bit echo echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Phi-3.5-MoE-instruct-4bit' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Phi-3.5-MoE-instruct-4bit echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolVLM-Instruct-4bit --resize 512 --image 
support/test.jpg' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolVLM-Instruct-4bit --resize 512 --image support/test.jpg +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E2B-it-lm-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3n-E2B-it-lm-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolLM3-3B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolLM3-3B-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Ling-mini-2.0-2bit-DWQ' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Ling-mini-2.0-2bit-DWQ +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-7B-Instruct-v0.3-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Mistral-7B-Instruct-v0.3-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-30B-A3B-4bit' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen3-30B-A3B-4bit +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-12b-it-qat-4bit --resize 512 --image support/test.jpg' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-12b-it-qat-4bit --resize 512 --image support/test.jpg echo echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/paligemma-3b-mix-448-8bit --resize 512 --image support/test.jpg' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/paligemma-3b-mix-448-8bit --resize 512 --image support/test.jpg echo echo ====== +echo './mlx-run llm-tool eval 
--download ~/Downloads/huggingface --model HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx --resize 512 --image support/test.jpg' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx --resize 512 --image support/test.jpg +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-4b-it-qat-4bit --resize 512 --image support/test.jpg' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-4b-it-qat-4bit --resize 512 --image support/test.jpg +echo +echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-VL-3B-Instruct-4bit --resize 512 --image support/test.jpg' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2.5-VL-3B-Instruct-4bit --resize 512 --image support/test.jpg echo echo ====== -echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx --resize 512 --image support/test.jpg' -./mlx-run llm-tool eval --download ~/Downloads/huggingface --model HuggingFaceTB/SmolVLM2-500M-Video-Instruct-mlx --resize 512 --image support/test.jpg +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolVLM-Instruct-4bit --resize 512 --image support/test.jpg' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/SmolVLM-Instruct-4bit --resize 512 --image support/test.jpg echo echo ====== echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2-VL-2B-Instruct-4bit --resize 512 --image support/test.jpg' ./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/Qwen2-VL-2B-Instruct-4bit --resize 512 --image support/test.jpg +echo +echo ====== +echo './mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-27b-it-qat-4bit --resize 512 
--image support/test.jpg' +./mlx-run llm-tool eval --download ~/Downloads/huggingface --model mlx-community/gemma-3-27b-it-qat-4bit --resize 512 --image support/test.jpg