
Commit 073e657

Add Exaone4 (#357)
* Add Exaone4
1 parent ab6feba commit 073e657

2 files changed (+313 −0 lines)

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions
@@ -53,6 +53,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
             "ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
             "lfm2": create(LFM2Configuration.self, LFM2Model.init),
             "baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init),
+            "exaone4": create(Exaone4Configuration.self, Exaone4Model.init),
         ]
     }
 
@@ -255,6 +256,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: "Why is the sky blue?"
     )
 
+    static public let exaone_4_0_1_2b_4bit = ModelConfiguration(
+        id: "mlx-community/exaone-4.0-1.2b-4bit",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,
@@ -291,6 +297,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             ernie_45_0_3BPT_bf16_ft,
             lfm2_1_2b_4bit,
             baichuan_m1_14b_instruct_4bit,
+            exaone_4_0_1_2b_4bit,
         ]
     }
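With the "exaone4" type registered and the configuration listed in the registry above, the model loads through the factory like any other entry. A minimal sketch follows (not part of this commit; it assumes the repo's LLMModelFactory.shared.loadContainer(configuration:) API and the ModelContainer type from MLXLMCommon):

import MLXLLM
import MLXLMCommon

// Sketch: download (if needed) and load the 4-bit EXAONE 4.0 1.2B model
// registered in this commit, returning a container ready for generation.
func loadExaone4() async throws -> ModelContainer {
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.exaone_4_0_1_2b_4bit)
}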

Exaone4.swift (new file)

Lines changed: 306 additions & 0 deletions

@@ -0,0 +1,306 @@
//
//  Exaone4.swift
//  mlx-swift-examples
//
//  Created by John Mai on 2025/7/15.
//

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN

// port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/exaone4.py

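// Hybrid local/global attention: when a sliding-window pattern is configured,
// layers marked "L" use local (sliding-window) attention with RoPE, while the
// remaining layers are global and skip RoPE. With no pattern, every layer
// applies RoPE.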
private class Attention: Module {
    let args: Exaone4Configuration
    let scale: Float
    let isLocal: Bool
    let useRope: Bool

    @ModuleInfo(key: "q_proj") var qProj: Linear
    @ModuleInfo(key: "k_proj") var kProj: Linear
    @ModuleInfo(key: "v_proj") var vProj: Linear
    @ModuleInfo(key: "o_proj") var oProj: Linear

    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm

    let rope: RoPE?

    public init(_ args: Exaone4Configuration, isLocal: Bool?) {
        self.args = args
        self.isLocal = isLocal ?? false
        self.useRope = isLocal == nil || (isLocal ?? false)

        let dim = args.hiddenSize
        let heads = args.attentionHeads
        let kvHeads = args.kvHeads

        let headDim = args.headDim
        self.scale = pow(Float(headDim), -0.5)

        _qProj.wrappedValue = Linear(dim, heads * headDim, bias: false)
        _kProj.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
        _vProj.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
        _oProj.wrappedValue = Linear(heads * headDim, dim, bias: false)

        _qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
        _kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)

        if useRope {
            let ropeScale: Float
            if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
                let factor = ropeScaling["factor"]
            {
                if let v = factor.asFloat() {
                    ropeScale = 1 / v
                } else {
                    fatalError("ropeScaling.factor must be a float")
                }
            } else {
                ropeScale = 1
            }

            self.rope = RoPE(
                dimensions: headDim, traditional: false, base: args.ropeTheta,
                scale: ropeScale)
        } else {
            self.rope = nil
        }
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = qProj(x)
        var keys = kProj(x)
        var values = vProj(x)

        queries = qNorm(queries.reshaped(B, L, args.attentionHeads, -1)).transposed(0, 2, 1, 3)
        keys = kNorm(keys.reshaped(B, L, args.kvHeads, -1)).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

        if let cache, useRope, let rope {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
        } else if useRope, let rope {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = attentionWithCacheUpdate(
            queries: queries,
            keys: keys,
            values: values,
            cache: cache,
            scale: scale,
            mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return oProj(output)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "down_proj") var down: Linear
    @ModuleInfo(key: "up_proj") var up: Linear

    public init(dimensions: Int, hiddenDimensions: Int) {
        _gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
        _down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
        _up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        down(silu(gate(x)) * up(x))
    }
}

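// Post-norm residual blocks: RMSNorm is applied to each sub-layer's output
// before it is added back to the residual stream, rather than to its input.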
private class TransformerBlock: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm

    public init(_ args: Exaone4Configuration, isLocal: Bool?) {
        _attention.wrappedValue = Attention(args, isLocal: isLocal)
        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
        _postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        _postFeedforwardLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
    ) -> MLXArray {
        var r = attention(x, mask: mask, cache: cache)
        let h = x + postAttentionLayerNorm(r)
        r = mlp(h)
        let out = h + postFeedforwardLayerNorm(r)
        return out
    }
}

private class ModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [TransformerBlock]
    let norm: RMSNorm

    public init(_ args: Exaone4Configuration) {
        precondition(args.vocabularySize > 0)

        _embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

        self.layers = (0 ..< args.hiddenLayers)
            .map { i in
                let isLocal: Bool?
                if let pattern = args.slidingWindowPattern {
                    let patternIndex = i % pattern.count
                    let character = pattern[
                        pattern.index(pattern.startIndex, offsetBy: patternIndex)]
                    isLocal = character == "L"
                } else {
                    isLocal = nil
                }
                return TransformerBlock(args, isLocal: isLocal)
            }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs)

        let mask = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}

public class Exaone4Model: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]

    private let model: ModelInner
    let configuration: Exaone4Configuration

    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: Exaone4Configuration) {
        self.configuration = args
        self.vocabularySize = args.vocabularySize
        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
        self.model = ModelInner(args)

        if !args.tieWordEmbeddings {
            _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        var out = model(inputs, cache: cache)
        if let lmHead {
            out = lmHead(out)
        } else {
            out = model.embedTokens.asLinear(out)
        }
        return out
    }

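    // With tied word embeddings there is no separate lm_head module, so drop the
    // redundant lm_head.weight if a checkpoint still carries it.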
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var weights = weights

        if configuration.tieWordEmbeddings {
            weights["lm_head.weight"] = nil
        }

        return weights
    }

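    // Local (sliding-window) layers get a rotating KV cache bounded by the window
    // size; global layers use a standard, unbounded cache.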
    public func newCache(parameters: GenerateParameters? = nil) -> [KVCache] {
        return model.layers.map { layer in
            if layer.attention.isLocal, let slidingWindow = configuration.slidingWindow {
                return RotatingKVCache(maxSize: slidingWindow, keep: 0)
            } else {
                return StandardKVCache()
            }
        }
    }
}

public struct Exaone4Configuration: Codable, Sendable {
    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var kvHeads: Int
    var maxPositionEmbeddings: Int
    var ropeTheta: Float
    var headDim: Int
    var tieWordEmbeddings: Bool
    var ropeScaling: [String: StringOrNumber]?
    var slidingWindow: Int?
    var slidingWindowPattern: String?

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case maxPositionEmbeddings = "max_position_embeddings"
        case ropeTheta = "rope_theta"
        case headDim = "head_dim"
        case tieWordEmbeddings = "tie_word_embeddings"
        case ropeScaling = "rope_scaling"
        case slidingWindow = "sliding_window"
        case slidingWindowPattern = "sliding_window_pattern"
    }

    public init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)

        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
        self.rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.kvHeads = try container.decode(Int.self, forKey: .kvHeads)
        self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
        self.ropeTheta = try container.decode(Float.self, forKey: .ropeTheta)
        self.headDim = try container.decode(Int.self, forKey: .headDim)
        self.tieWordEmbeddings = try container.decode(Bool.self, forKey: .tieWordEmbeddings)
        self.ropeScaling = try container.decodeIfPresent(
            [String: StringOrNumber].self, forKey: .ropeScaling)
        self.slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow)
        self.slidingWindowPattern = try container.decodeIfPresent(
            String.self, forKey: .slidingWindowPattern)
    }
}

// MARK: - LoRA

extension Exaone4Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
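For reference, a decoding sketch (not part of this commit): Exaone4Configuration reads a Hugging Face style config.json via its CodingKeys. The field values below are illustrative placeholders, not the shipped exaone-4.0-1.2b settings.

import Foundation

// Hypothetical config values, shown only to illustrate the expected keys.
let json = """
{
  "hidden_size": 2048,
  "num_hidden_layers": 30,
  "intermediate_size": 4096,
  "num_attention_heads": 32,
  "rms_norm_eps": 1e-5,
  "vocab_size": 102400,
  "num_key_value_heads": 8,
  "max_position_embeddings": 65536,
  "rope_theta": 1000000.0,
  "head_dim": 64,
  "tie_word_embeddings": true,
  "sliding_window": 4096,
  "sliding_window_pattern": "LLLG"
}
""".data(using: .utf8)!

let config = try JSONDecoder().decode(Exaone4Configuration.self, from: json)
// With pattern "LLLG", layers at indices 0-2 of every group of four are local
// (sliding-window, with RoPE); every fourth layer is global and skips RoPE.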
