diff --git a/src/constants/ipc-constants.ts b/src/constants/ipc-constants.ts
index 994114b11..053c8c6ae 100644
--- a/src/constants/ipc-constants.ts
+++ b/src/constants/ipc-constants.ts
@@ -38,6 +38,9 @@ export const IPC_CHANNELS = {
   // Window channels
   WINDOW_SET_TITLE: "window:setTitle",
 
+  // Token channels
+  TOKENS_COUNT_BULK: "tokens:countBulk",
+
   // Dynamic channel prefixes
   WORKSPACE_CHAT_PREFIX: "workspace:chat:",
   WORKSPACE_METADATA: "workspace:metadata",
diff --git a/src/main.ts b/src/main.ts
index e05666fbe..57cf13cda 100644
--- a/src/main.ts
+++ b/src/main.ts
@@ -436,6 +436,21 @@ if (gotTheLock) {
     }
   });
 
+  // Cleanup worker threads on quit
+  app.on("will-quit", () => {
+    console.log("App will quit - cleaning up worker threads");
+    void (async () => {
+      try {
+        // Dynamic import is acceptable here - only loaded if worker was used
+        /* eslint-disable-next-line no-restricted-syntax */
+        const { tokenizerWorkerPool } = await import("@/services/tokenizerWorkerPool");
+        tokenizerWorkerPool.terminate();
+      } catch (error) {
+        console.error("Error terminating worker pool:", error);
+      }
+    })();
+  });
+
   app.on("activate", () => {
     // Only create window if app is ready and no window exists
     // This prevents "Cannot create BrowserWindow before app is ready" error
diff --git a/src/preload.ts b/src/preload.ts
index 85cc99449..538e585da 100644
--- a/src/preload.ts
+++ b/src/preload.ts
@@ -110,6 +110,10 @@ const api: IPCApi = {
   window: {
     setTitle: (title: string) => ipcRenderer.invoke(IPC_CHANNELS.WINDOW_SET_TITLE, title),
   },
+  tokens: {
+    countBulk: (model: string, texts: string[]) =>
+      ipcRenderer.invoke(IPC_CHANNELS.TOKENS_COUNT_BULK, model, texts),
+  },
 };
 
 // Expose the API along with platform/versions
diff --git a/src/services/ipcMain.ts b/src/services/ipcMain.ts
index 766bd8428..0b68456e2 100644
--- a/src/services/ipcMain.ts
+++ b/src/services/ipcMain.ts
@@ -140,6 +140,7 @@
 
     this.registerDialogHandlers(ipcMain);
     this.registerWindowHandlers(ipcMain);
+    this.registerTokenHandlers(ipcMain);
     this.registerWorkspaceHandlers(ipcMain);
     this.registerProviderHandlers(ipcMain);
     this.registerProjectHandlers(ipcMain);
@@ -174,6 +175,24 @@
     });
   }
 
+  private registerTokenHandlers(ipcMain: ElectronIpcMain): void {
+    ipcMain.handle(
+      IPC_CHANNELS.TOKENS_COUNT_BULK,
+      async (_event, model: string, texts: string[]) => {
+        try {
+          // Offload to worker thread - keeps main process responsive
+          // Dynamic import is acceptable here - worker pool is lazy-loaded on first use
+          /* eslint-disable-next-line no-restricted-syntax */
+          const { tokenizerWorkerPool } = await import("@/services/tokenizerWorkerPool");
+          return await tokenizerWorkerPool.countTokens(model, texts);
+        } catch (error) {
+          log.error(`Failed to count tokens for model ${model}:`, error);
+          return null; // Tokenizer not loaded or error occurred
+        }
+      }
+    );
+  }
+
   private registerWorkspaceHandlers(ipcMain: ElectronIpcMain): void {
     ipcMain.handle(
       IPC_CHANNELS.WORKSPACE_CREATE,
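The handler above resolves to `number[] | null`, with `null` meaning "tokenizer unavailable" rather than a hard failure. A minimal renderer-side sketch of the resulting API, assuming the preload script exposes the `IPCApi` object as `window.api` (the `contextBridge.exposeInMainWorld` call is outside this excerpt) and with `estimateConversationTokens` as a hypothetical caller:

```ts
// Renderer-side usage sketch - illustrative only, not part of this diff.
import type { IPCApi } from "@/types/ipc";

declare const window: { api: IPCApi }; // assumed preload exposure

async function estimateConversationTokens(
  model: string,
  messages: string[]
): Promise<number | null> {
  // One IPC round trip for the whole batch instead of one invoke per message
  const counts = await window.api.tokens.countBulk(model, messages);
  if (counts === null) {
    return null; // tokenizer not loaded or errored; fall back to a heuristic
  }
  return counts.reduce((sum, n) => sum + n, 0);
}
```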
diff --git a/src/services/tokenizerWorkerPool.ts b/src/services/tokenizerWorkerPool.ts
new file mode 100644
index 000000000..703e1dbbd
--- /dev/null
+++ b/src/services/tokenizerWorkerPool.ts
@@ -0,0 +1,164 @@
+/**
+ * Tokenizer Worker Pool
+ * Manages Node.js worker thread for off-main-thread tokenization
+ */
+
+import { Worker } from "worker_threads";
+import path from "path";
+import { log } from "@/services/log";
+
+interface PendingRequest {
+  resolve: (counts: number[]) => void;
+  reject: (error: Error) => void;
+  timeoutId: NodeJS.Timeout;
+}
+
+interface TokenizeRequest {
+  requestId: number;
+  model: string;
+  texts: string[];
+}
+
+interface TokenizeResponse {
+  requestId: number;
+  success: boolean;
+  counts?: number[];
+  error?: string;
+}
+
+class TokenizerWorkerPool {
+  private worker: Worker | null = null;
+  private requestCounter = 0;
+  private pendingRequests = new Map<number, PendingRequest>();
+  private isTerminating = false;
+
+  /**
+   * Get or create the worker thread
+   */
+  private getWorker(): Worker {
+    if (this.worker && !this.isTerminating) {
+      return this.worker;
+    }
+
+    // Worker script path - compiled by tsc to dist/workers/tokenizerWorker.js
+    // __dirname in production will be dist/services, so we go up one level then into workers
+    const workerPath = path.join(__dirname, "..", "workers", "tokenizerWorker.js");
+
+    this.worker = new Worker(workerPath);
+    this.isTerminating = false;
+
+    // Allow Node to exit even if worker is still running (important for tests)
+    this.worker.unref();
+
+    this.worker.on("message", (response: TokenizeResponse) => {
+      this.handleResponse(response);
+    });
+
+    this.worker.on("error", (error: Error) => {
+      log.error("Tokenizer worker error:", error);
+      // Reject all pending requests
+      for (const [requestId, pending] of this.pendingRequests) {
+        clearTimeout(pending.timeoutId);
+        pending.reject(new Error(`Worker error: ${error.message}`));
+        this.pendingRequests.delete(requestId);
+      }
+    });
+
+    this.worker.on("exit", (code: number) => {
+      if (!this.isTerminating && code !== 0) {
+        log.error(`Tokenizer worker exited with code ${code}`);
+      }
+      this.worker = null;
+    });
+
+    return this.worker;
+  }
+
+  /**
+   * Handle response from worker
+   */
+  private handleResponse(response: TokenizeResponse): void {
+    const pending = this.pendingRequests.get(response.requestId);
+    if (!pending) {
+      return; // Request was cancelled or timed out
+    }
+
+    clearTimeout(pending.timeoutId);
+    this.pendingRequests.delete(response.requestId);
+
+    if (response.success && response.counts) {
+      pending.resolve(response.counts);
+    } else {
+      pending.reject(new Error(response.error ?? "Unknown worker error"));
+    }
+  }
+
+  /**
+   * Count tokens for multiple texts using worker thread
+   * @param model - Model identifier for tokenizer selection
+   * @param texts - Array of texts to tokenize
+   * @returns Promise resolving to array of token counts
+   */
+  async countTokens(model: string, texts: string[]): Promise<number[]> {
+    const requestId = this.requestCounter++;
+    const worker = this.getWorker();
+
+    return new Promise<number[]>((resolve, reject) => {
+      // Set timeout for request (30 seconds)
+      const timeoutId = setTimeout(() => {
+        const pending = this.pendingRequests.get(requestId);
+        if (pending) {
+          this.pendingRequests.delete(requestId);
+          reject(new Error("Tokenization request timeout (30s)"));
+        }
+      }, 30000);
+
+      // Store pending request
+      this.pendingRequests.set(requestId, {
+        resolve,
+        reject,
+        timeoutId,
+      });
+
+      // Send request to worker
+      const request: TokenizeRequest = {
+        requestId,
+        model,
+        texts,
+      };
+
+      try {
+        worker.postMessage(request);
+      } catch (error) {
+        clearTimeout(timeoutId);
+        this.pendingRequests.delete(requestId);
+        reject(error instanceof Error ? error : new Error(String(error)));
+      }
+    });
+  }
+
+  /**
+   * Terminate the worker thread and reject all pending requests
+   */
+  terminate(): void {
+    this.isTerminating = true;
+
+    // Reject all pending requests
+    for (const [requestId, pending] of this.pendingRequests) {
+      clearTimeout(pending.timeoutId);
+      pending.reject(new Error("Worker pool terminated"));
+      this.pendingRequests.delete(requestId);
+    }
+
+    // Terminate worker
+    if (this.worker) {
+      this.worker.terminate().catch((error) => {
+        log.error("Error terminating tokenizer worker:", error);
+      });
+      this.worker = null;
+    }
+  }
+}
+
+// Singleton instance
+export const tokenizerWorkerPool = new TokenizerWorkerPool();
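The pool correlates responses with callers through a monotonically increasing `requestId` and the `pendingRequests` map, so several bulk requests can be in flight on one worker; `worker.unref()` plus the 30-second timeout keep a hung worker from wedging process exit. Typical main-process usage looks like the following sketch (function name and example values are illustrative):

```ts
// Main-process usage sketch. The worker thread is spawned lazily on the
// first call, so importing the pool does not pay the worker startup cost.
import { tokenizerWorkerPool } from "@/services/tokenizerWorkerPool";

async function logPromptSizes(): Promise<void> {
  try {
    const counts = await tokenizerWorkerPool.countTokens("openai:gpt-4", [
      "You are a helpful assistant.",
      "Hello!",
    ]);
    console.log(counts); // e.g. [6, 3] - exact values depend on the encoding
  } catch (error) {
    // Rejects on worker crash, pool termination, or the 30s timeout
    console.error("Bulk token count failed:", error);
  }
}
```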
diff --git a/src/types/ipc.ts b/src/types/ipc.ts
index ece311231..356a8b780 100644
--- a/src/types/ipc.ts
+++ b/src/types/ipc.ts
@@ -230,4 +230,7 @@ export interface IPCApi {
   window: {
     setTitle(title: string): Promise<void>;
   };
+  tokens: {
+    countBulk(model: string, texts: string[]): Promise<number[] | null>;
+  };
 }
diff --git a/src/utils/main/tokenizer.test.ts b/src/utils/main/tokenizer.test.ts
new file mode 100644
index 000000000..0cb2fba18
--- /dev/null
+++ b/src/utils/main/tokenizer.test.ts
@@ -0,0 +1,53 @@
+/**
+ * Tests for tokenizer cache behavior
+ */
+
+import { describe, it, expect } from "@jest/globals";
+import { getTokenizerForModel } from "./tokenizer";
+
+describe("tokenizer cache", () => {
+  const testText = "Hello, world!";
+
+  it("should use different cache keys for different models", () => {
+    // Get tokenizers for different models
+    const gpt4Tokenizer = getTokenizerForModel("openai:gpt-4");
+    const claudeTokenizer = getTokenizerForModel("anthropic:claude-opus-4");
+
+    // Count tokens with first model
+    const gpt4Count = gpt4Tokenizer.countTokens(testText);
+
+    // Count tokens with second model
+    const claudeCount = claudeTokenizer.countTokens(testText);
+
+    // Counts may differ because the models use different encodings
+    // This test mainly ensures tokenization doesn't crash and cache entries stay isolated
+    expect(typeof gpt4Count).toBe("number");
+    expect(typeof claudeCount).toBe("number");
+    expect(gpt4Count).toBeGreaterThan(0);
+    expect(claudeCount).toBeGreaterThan(0);
+  });
+
+  it("should return same count for same (model, text) pair from cache", () => {
+    const tokenizer = getTokenizerForModel("openai:gpt-4");
+
+    // First call
+    const count1 = tokenizer.countTokens(testText);
+
+    // Second call should hit cache
+    const count2 = tokenizer.countTokens(testText);
+
+    expect(count1).toBe(count2);
+  });
+
+  it("should normalize model keys for cache consistency", () => {
+    // These should map to the same cache key
+    const tokenizer1 = getTokenizerForModel("anthropic:claude-opus-4");
+    const tokenizer2 = getTokenizerForModel("anthropic/claude-opus-4");
+
+    const count1 = tokenizer1.countTokens(testText);
+    const count2 = tokenizer2.countTokens(testText);
+
+    // Should get same count since they normalize to same model
+    expect(count1).toBe(count2);
+  });
+});
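The third test leans on `normalizeModelKey`, which the tokenizer diff below calls but never defines. A minimal sketch of what it might do, assuming the only variance handled is the `provider:model` versus `provider/model` separator (both the exact behavior and any extra normalization are assumptions here):

```ts
// Hypothetical sketch of normalizeModelKey - the real implementation lives
// elsewhere in the codebase and may normalize more than the separator.
function normalizeModelKey(modelString: string): string {
  // Treat "anthropic:claude-opus-4" and "anthropic/claude-opus-4" as one model
  const idx = modelString.indexOf(":");
  if (idx === -1) {
    return modelString;
  }
  return `${modelString.slice(0, idx)}/${modelString.slice(idx + 1)}`;
}
```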
diff --git a/src/utils/main/tokenizer.ts b/src/utils/main/tokenizer.ts
index 4c8bce7c0..c23310d8c 100644
--- a/src/utils/main/tokenizer.ts
+++ b/src/utils/main/tokenizer.ts
@@ -66,9 +66,14 @@ export async function loadTokenizerModules(): Promise<void> {
 }
 
 /**
- * LRU cache for token counts by text checksum
- * Avoids re-tokenizing identical strings (system messages, tool definitions, etc.)
- * Key: CRC32 checksum of text, Value: token count
+ * LRU cache for token counts by (model, text) pairs
+ * Avoids re-tokenizing identical strings with the same encoding
+ *
+ * Key: CRC32 checksum of "model:text" to ensure counts are model-specific
+ * Value: token count
+ *
+ * IMPORTANT: Cache key includes model because different encodings produce different counts.
+ * For async tokenization (approx → exact), the key stays stable so exact overwrites approx.
  */
 const tokenCountCache = new LRUCache<number, number>({
   max: 500000, // Max entries (safety limit)
@@ -83,11 +88,22 @@
  * Count tokens with caching via CRC32 checksum
  * Avoids re-tokenizing identical strings (system messages, tool definitions, etc.)
  *
+ * Cache key includes model to prevent cross-model count reuse.
+ *
  * NOTE: For async tokenization, this returns an approximation immediately and caches
- * the accurate count in the background. Subsequent calls will use the cached accurate count.
+ * the accurate count in the background. Subsequent calls with the same (model, text) pair
+ * will use the cached accurate count once ready.
  */
-function countTokensCached(text: string, tokenizeFn: () => number | Promise<number>): number {
-  const checksum = CRC32.str(text);
+function countTokensCached(
+  text: string,
+  modelString: string,
+  tokenizeFn: () => number | Promise<number>
+): number {
+  // Include model in cache key to prevent different encodings from reusing counts
+  // Normalize model key for consistent cache hits (e.g., "anthropic:claude" → "anthropic/claude")
+  const normalizedModel = normalizeModelKey(modelString);
+  const cacheKey = `${normalizedModel}:${text}`;
+  const checksum = CRC32.str(cacheKey);
   const cached = tokenCountCache.get(checksum);
   if (cached !== undefined) {
     return cached;
@@ -102,6 +118,7 @@ function countTokensCached(text: string, tokenizeFn: () => number | Promise<num
+  void result.then((count) => tokenCountCache.set(checksum, count));
   return approximation;
@@ -179,8 +196,8 @@
  * @returns Tokenizer interface with name and countTokens function
  */
 export function getTokenizerForModel(modelString: string): Tokenizer {
-  // Start loading tokenizer modules in background (idempotent)
-  void loadTokenizerModules();
+  // Tokenizer modules are loaded on-demand when countTokens is first called
+  // This avoids blocking app startup with 8MB+ of tokenizer downloads
 
   return {
     get encoding() {
@@ -189,7 +206,7 @@
     countTokens: (text: string) => {
       // If tokenizer already loaded, use synchronous path for accurate counts
       if (tokenizerModules) {
-        return countTokensCached(text, () => {
+        return countTokensCached(text, modelString, () => {
           try {
             return countTokensWithLoadedModules(text, modelString, tokenizerModules!);
           } catch (error) {
@@ -201,7 +218,7 @@
       }
 
       // Tokenizer not yet loaded - use async path (returns approximation immediately)
-      return countTokensCached(text, async () => {
+      return countTokensCached(text, modelString, async () => {
         await loadTokenizerModules();
         try {
           return countTokensWithLoadedModules(text, modelString, tokenizerModules!);
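The key-stability comment above is the subtle part: the first call answers with a cheap approximation while the exact count is computed off the hot path, and because the cache key is derived from the same `(model, text)` pair both times, the exact value silently replaces the estimate. A stripped-down sketch of that pattern, using a plain `Map` with string keys instead of the CRC32-keyed LRU:

```ts
// Isolated sketch of the approx-then-exact caching pattern used above.
const cache = new Map<string, number>();

function countWithUpgrade(
  key: string, // must be stable across the approx and exact passes
  approximate: () => number,
  exact: () => Promise<number>
): number {
  const hit = cache.get(key);
  if (hit !== undefined) {
    return hit; // may already be the exact value
  }
  // Answer immediately with an estimate (deliberately not cached)...
  const estimate = approximate();
  // ...and let the exact count land under the SAME key when it resolves
  void exact().then((count) => cache.set(key, count));
  return estimate;
}
```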
diff --git a/src/workers/tokenizerWorker.ts b/src/workers/tokenizerWorker.ts
new file mode 100644
index 000000000..907c2c5ca
--- /dev/null
+++ b/src/workers/tokenizerWorker.ts
@@ -0,0 +1,56 @@
+/**
+ * Node.js Worker Thread for tokenization
+ * Offloads CPU-intensive tokenization to prevent main process blocking
+ */
+
+import { parentPort } from "worker_threads";
+
+// Lazy-load tokenizer only when first needed
+let getTokenizerForModel: ((model: string) => { countTokens: (text: string) => number }) | null =
+  null;
+
+interface TokenizeRequest {
+  requestId: number;
+  model: string;
+  texts: string[];
+}
+
+interface TokenizeResponse {
+  requestId: number;
+  success: boolean;
+  counts?: number[];
+  error?: string;
+}
+
+parentPort?.on("message", (data: TokenizeRequest) => {
+  const { requestId, model, texts } = data;
+
+  void (async () => {
+    try {
+      // Lazy-load tokenizer on first use
+      // Dynamic import is acceptable here as worker is isolated and has no circular deps
+      if (!getTokenizerForModel) {
+        /* eslint-disable-next-line no-restricted-syntax */
+        const tokenizerModule = await import("@/utils/main/tokenizer");
+        getTokenizerForModel = tokenizerModule.getTokenizerForModel;
+      }
+
+      const tokenizer = getTokenizerForModel(model);
+      const counts = texts.map((text) => tokenizer.countTokens(text));
+
+      const response: TokenizeResponse = {
+        requestId,
+        success: true,
+        counts,
+      };
+      parentPort?.postMessage(response);
+    } catch (error) {
+      const response: TokenizeResponse = {
+        requestId,
+        success: false,
+        error: error instanceof Error ? error.message : String(error),
+      };
+      parentPort?.postMessage(response);
+    }
+  })();
+});
diff --git a/tsconfig.main.json b/tsconfig.main.json
index d913052f7..033067d0d 100644
--- a/tsconfig.main.json
+++ b/tsconfig.main.json
@@ -6,6 +6,6 @@
     "noEmit": false,
     "sourceMap": true
   },
-  "include": ["src/main.ts", "src/constants/**/*", "src/types/**/*.d.ts"],
+  "include": ["src/main.ts", "src/constants/**/*", "src/types/**/*.d.ts", "src/workers/**/*"],
   "exclude": ["src/App.tsx", "src/main.tsx"]
 }
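The tsconfig change is what makes the pool's `path.join(__dirname, "..", "workers", "tokenizerWorker.js")` lookup resolve: with `src/workers/**/*` included, tsc emits the worker alongside the rest of the main-process output. If that layout assumption ever breaks, a guard like the following (not part of this diff, purely defensive) would fail fast instead of surfacing as an opaque `Worker` constructor error:

```ts
// Hypothetical guard for getWorker(). Expected compiled layout:
//   dist/services/tokenizerWorkerPool.js  (__dirname === dist/services)
//   dist/workers/tokenizerWorker.js
import fs from "fs";
import path from "path";

const workerPath = path.join(__dirname, "..", "workers", "tokenizerWorker.js");
if (!fs.existsSync(workerPath)) {
  throw new Error(
    `Tokenizer worker not found at ${workerPath} - was src/workers compiled?`
  );
}
```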