Skip to content

Commit 2473a71

Browse files
committed
Refactor WebGPUComputeEngine to improve kernel management and enhance performance in webgpuJacobiSolver
1 parent 30eba31 commit 2473a71

File tree

1 file changed

+74
-38
lines changed

1 file changed

+74
-38
lines changed

src/methods/webgpuJacobiSolverScript.js

Lines changed: 74 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,21 @@ import * as ti from "../vendor/taichi.esm.js";
1313
import { debugLog, errorLog } from "../utilities/loggingScript.js";
1414

1515
/**
16-
* Class providing GPU-accelerated Jacobi solver using Taichi.js/WebGPU.
17-
* Offloads iterative linear algebra to the GPU for improved performance on large systems.
16+
* Class providing a GPU-accelerated Jacobi solver using Taichi.js/WebGPU
17+
* Offloads iterative linear algebra to the GPU for improved performance on large systems
1818
*/
1919
export class WebGPUComputeEngine {
2020
/**
21-
* Creates a WebGPUComputeEngine instance.
22-
* The engine remains uninitialized until initialize() is called.
21+
* Constructor that creates a WebGPUComputeEngine instance
22+
* The engine remains uninitialized until initialize() is called
2323
*/
2424
constructor() {
2525
this.initialized = false;
26+
this.extractDiagonalKernel = null;
27+
this.jacobiStepKernel = null;
28+
this.swapSolutionKernel = null;
29+
this.cachedSize = null;
30+
this.fields = null;
2631
}
2732

2833
/**
@@ -47,60 +52,86 @@ export class WebGPUComputeEngine {
4752
* @returns {Promise<object>} Result object containing the solution, iteration count, and convergence flag
4853
*/
4954
async webgpuJacobiSolver(A, b, x0, maxIter, tol) {
55+
await this.initialize();
5056
const n = b.length;
5157
const flatA = A.flat();
5258

53-
const AField = ti.field(ti.f32, [n * n]);
54-
const bField = ti.field(ti.f32, [n]);
55-
const xField = ti.field(ti.f32, [n]);
56-
const xNewField = ti.field(ti.f32, [n]);
57-
const diagField = ti.field(ti.f32, [n]);
58-
const maxResidualField = ti.field(ti.f32, [1]);
59+
if (!this.fields || this.cachedSize !== n) {
60+
this.fields = {
61+
AField: ti.field(ti.f32, [n * n]),
62+
bField: ti.field(ti.f32, [n]),
63+
xField: ti.field(ti.f32, [n]),
64+
xNewField: ti.field(ti.f32, [n]),
65+
diagField: ti.field(ti.f32, [n]),
66+
maxResidualField: ti.field(ti.f32, [1]),
67+
};
68+
this.cachedSize = n;
69+
}
70+
71+
const { AField, bField, xField, xNewField, diagField, maxResidualField } = this.fields;
5972

6073
AField.fromArray(flatA);
6174
bField.fromArray(b);
6275
xField.fromArray(x0);
6376
xNewField.fromArray(x0);
6477

6578
ti.addToKernelScope({ AField, bField, xField, xNewField, diagField, maxResidualField });
79+
if (!this.extractDiagonalKernel) {
80+
this.extractDiagonalKernel = ti.kernel((size) => {
81+
for (let i of ti.ndrange(size)) {
82+
diagField[i] = AField[ti.i32(i) * ti.i32(size) + ti.i32(i)];
83+
}
84+
});
6685

67-
ti.kernel((size) => {
68-
for (let i of ti.ndrange(size)) {
69-
diagField[i] = AField[ti.i32(i) * ti.i32(size) + ti.i32(i)];
70-
}
71-
})(n);
72-
73-
const jacobiStep = ti.kernel((size) => {
74-
maxResidualField[0] = 0.0;
75-
for (let i of ti.ndrange(size)) {
76-
let sum = 0.0;
77-
for (let j of ti.ndrange(size)) {
78-
sum += AField[ti.i32(i) * ti.i32(size) + ti.i32(j)] * xField[j];
86+
this.jacobiStepKernel = ti.kernel((size) => {
87+
maxResidualField[0] = 0.0;
88+
for (let i of ti.ndrange(size)) {
89+
let sum = 0.0;
90+
for (let j of ti.ndrange(size)) {
91+
sum += AField[ti.i32(i) * ti.i32(size) + ti.i32(j)] * xField[j];
92+
}
93+
const residual = bField[i] - sum;
94+
xNewField[i] = xField[i] + residual / diagField[i];
95+
ti.atomicMax(maxResidualField[0], ti.abs(residual));
7996
}
80-
const residual = bField[i] - sum;
81-
xNewField[i] = xField[i] + residual / diagField[i];
82-
ti.atomicMax(maxResidualField[0], ti.abs(residual));
83-
}
84-
});
97+
});
8598

86-
const swapSolution = ti.kernel((size) => {
87-
for (let i of ti.ndrange(size)) {
88-
xField[i] = xNewField[i];
89-
}
90-
});
99+
this.swapSolutionKernel = ti.kernel((size) => {
100+
for (let i of ti.ndrange(size)) {
101+
xField[i] = xNewField[i];
102+
}
103+
});
104+
}
105+
106+
this.extractDiagonalKernel(n);
107+
108+
const residualCheckInterval = Math.max(1, Math.min(10, Math.floor(maxIter / 4) || 1));
109+
let iterations = maxIter;
110+
let converged = false;
91111

92112
for (let iter = 0; iter < maxIter; iter++) {
93-
jacobiStep(n);
113+
this.jacobiStepKernel(n);
114+
this.swapSolutionKernel(n);
115+
116+
const shouldCheckResidual = (iter + 1) % residualCheckInterval === 0 || iter === maxIter - 1;
117+
if (!shouldCheckResidual) {
118+
continue;
119+
}
120+
94121
const rnorm = (await maxResidualField.toArray())[0];
95-
debugLog(`Jacobi: Iteration ${iter + 1}, residual norm: ${rnorm}`);
122+
iterations = iter + 1;
123+
debugLog(`Jacobi: Iteration ${iterations}, residual norm: ${rnorm}`);
96124
if (rnorm < tol) {
97-
return { solutionVector: await xNewField.toArray(), iterations: iter + 1, converged: true };
125+
converged = true;
126+
break;
98127
}
99-
swapSolution(n);
100128
}
101129

102-
errorLog(`Jacobi: Did not converge in ${maxIter} iterations`);
103-
return { solutionVector: await xField.toArray(), iterations: maxIter, converged: false };
130+
if (!converged) {
131+
errorLog(`Jacobi: Did not converge in ${maxIter} iterations`);
132+
}
133+
134+
return { solutionVector: await xField.toArray(), iterations, converged };
104135
}
105136

106137
/**
@@ -115,5 +146,10 @@ export class WebGPUComputeEngine {
115146
await ti.destroy();
116147
}
117148
this.initialized = false;
149+
this.extractDiagonalKernel = null;
150+
this.jacobiStepKernel = null;
151+
this.swapSolutionKernel = null;
152+
this.cachedSize = null;
153+
this.fields = null;
118154
}
119155
}

0 commit comments

Comments
 (0)