diff --git a/src/rpc/__tests__/dial.spec.ts b/src/rpc/__tests__/dial.spec.ts index 938b5b7f1..6e0c64fa8 100644 --- a/src/rpc/__tests__/dial.spec.ts +++ b/src/rpc/__tests__/dial.spec.ts @@ -27,8 +27,8 @@ import { } from '../../__mocks__/webrtc'; import { withICEServers } from '../__fixtures__/dial-webrtc-options'; import { createMockTransport } from '../../__mocks__/transports'; -import { createMockSignalingExchange } from '../__mocks__/signaling-exchanges'; import { ClientChannel } from '../client-channel'; +import type { Transport } from '@connectrpc/connect'; vi.mock('../peer'); vi.mock('../signaling-exchange'); @@ -52,15 +52,12 @@ const setupDialWebRTCMocks = () => { const peerConnection = createMockPeerConnection(); const dataChannel = createMockDataChannel(); const transport = createMockTransport(); - const signalingExchange = createMockSignalingExchange(transport); vi.mocked(newPeerConnectionForClient).mockResolvedValue({ pc: peerConnection, dc: dataChannel, }); - vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange); - const optionalWebRTCConfigFn = vi.fn().mockResolvedValue({ config: { additionalIceServers: [], @@ -68,12 +65,20 @@ const setupDialWebRTCMocks = () => { }, }); - vi.mocked(createClient).mockReturnValue({ + const mockClient = { optionalWebRTCConfig: optionalWebRTCConfigFn, - } as unknown as ReturnType); + } as unknown as ReturnType; + vi.mocked(createClient).mockReturnValue(mockClient); vi.mocked(createGrpcWebTransport).mockReturnValue(transport); + const signalingExchange = { + doExchange: vi.fn().mockResolvedValue(transport), + terminate: vi.fn(), + } as unknown as SignalingExchange; + + vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange); + return { peerConnection, dataChannel, @@ -207,21 +212,18 @@ describe('dialWebRTC', () => { expect(vi.mocked(peerConnection.close)).toHaveBeenCalled(); }); - it('should close peer connection if dialDirect fails', async () => { + it('should propagate error if transport creation fails', async () => { // Arrange - const { peerConnection, transport } = setupDialWebRTCMocks(); - // First call succeeds (getOptionalWebRTCConfig), second call fails (signaling) - vi.mocked(createGrpcWebTransport) - .mockReturnValueOnce(transport) - .mockImplementationOnce(() => { - throw new Error('Transport creation failed'); - }); + setupDialWebRTCMocks(); + vi.mocked(createGrpcWebTransport).mockImplementation(() => { + throw new Error('Transport creation failed'); + }); // Act & Assert await expect(dialWebRTC(TEST_URL, TEST_HOST)).rejects.toThrow( 'Transport creation failed' ); - expect(vi.mocked(peerConnection.close)).toHaveBeenCalled(); + expect(newPeerConnectionForClient).not.toHaveBeenCalled(); }); it('should rethrow errors after cleanup', async () => { @@ -327,6 +329,103 @@ describe('validateDialOptions', () => { }); }); +describe('resource management', () => { + it('should reuse a single transport for config fetching and signaling', async () => { + // Arrange + setupDialWebRTCMocks(); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(createGrpcWebTransport).toHaveBeenCalledTimes(1); + expect(createGrpcWebTransport).toHaveBeenCalledWith({ + baseUrl: TEST_URL, + credentials: 'same-origin', + }); + }); + + it('should reuse a single signaling client for config fetching and signaling', async () => { + // Arrange + setupDialWebRTCMocks(); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(createClient).toHaveBeenCalledTimes(1); + expect(createClient).toHaveBeenCalledWith( + expect.anything(), + expect.anything() + ); + }); + + it('should not leak transports on successful connection', async () => { + // Arrange + const { transport } = setupDialWebRTCMocks(); + const transportCount = { created: 0 }; + + vi.mocked(createGrpcWebTransport).mockImplementation(() => { + transportCount.created += 1; + return transport; + }); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(transportCount.created).toBe(1); + }); + + it('should not leak transports on connection failure', async () => { + // Arrange + const { transport, signalingExchange } = setupDialWebRTCMocks(); + const transportCount = { created: 0 }; + + vi.mocked(createGrpcWebTransport).mockImplementation(() => { + transportCount.created += 1; + return transport; + }); + + const error = new Error('Connection failed'); + vi.mocked(signalingExchange.doExchange).mockRejectedValueOnce(error); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST).catch(() => { + // Ignore error for this test + }); + + // Assert + expect(transportCount.created).toBe(1); + }); + + it('should use the same transport reference for both config and signaling', async () => { + // Arrange + setupDialWebRTCMocks(); + const capturedTransports: Transport[] = []; + + vi.mocked(createClient).mockImplementation( + (_service, capturedTransport) => { + capturedTransports.push(capturedTransport); + return { + optionalWebRTCConfig: vi.fn().mockResolvedValue({ + config: { + additionalIceServers: [], + disableTrickle: false, + }, + }), + } as unknown as ReturnType; + } + ); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(capturedTransports.length).toBe(1); + }); +}); + describe('dialDirect', () => { afterEach(() => { vi.restoreAllMocks(); diff --git a/src/rpc/dial.ts b/src/rpc/dial.ts index d1f0bd2be..fb7318de2 100644 --- a/src/rpc/dial.ts +++ b/src/rpc/dial.ts @@ -309,20 +309,24 @@ export interface WebRTCConnection { dataChannel: RTCDataChannel; } -const getOptionalWebRTCConfig = async ( +const getSignalingClient = async ( signalingAddress: string, - callOpts: CallOptions, - dialOpts?: DialOptions, + signalingExchangeOpts: DialOptions | undefined, transportCredentialsInclude = false -): Promise => { - const optsCopy = { ...dialOpts } as DialOptions; - const directTransport = await dialDirect( +) => { + const transport = await dialDirect( signalingAddress, - optsCopy, + signalingExchangeOpts, transportCredentialsInclude ); - const signalingClient = createClient(SignalingService, directTransport); + return createClient(SignalingService, transport); +}; + +const getOptionalWebRTCConfig = async ( + callOpts: CallOptions, + signalingClient: ReturnType> +): Promise => { try { const resp = await signalingClient.optionalWebRTCConfig({}, callOpts); return resp.config ?? new WebRTCConfig(); @@ -363,18 +367,25 @@ export const dialWebRTC = async ( }; /** - * First complete our WebRTC options, gathering any extra information like - * TURN servers from a cloud server. + * First, derive options specifically for signaling against our target. Then + * complete our WebRTC options, gathering any extra information like TURN + * servers from a cloud server. This also creates the transport and signaling + * client that we'll reuse to avoid resource leaks. */ - const webrtcOpts = await processWebRTCOpts( + const exchangeOpts = processSignalingExchangeOpts( usableSignalingAddress, - callOpts, - dialOpts, - transportCredentialsInclude + dialOpts ); - // then derive options specifically for signaling against our target. - const exchangeOpts = processSignalingExchangeOpts( + + const signalingClient = await getSignalingClient( usableSignalingAddress, + exchangeOpts, + transportCredentialsInclude + ); + + const webrtcOpts = await processWebRTCOpts( + signalingClient, + callOpts, dialOpts ); @@ -385,21 +396,6 @@ export const dialWebRTC = async ( ); let successful = false; - let directTransport: Transport; - try { - directTransport = await dialDirect( - usableSignalingAddress, - exchangeOpts, - transportCredentialsInclude - ); - } catch (error) { - pc.close(); - dc.close(); - throw error; - } - - const signalingClient = createClient(SignalingService, directTransport); - const exchange = new SignalingExchange( signalingClient, callOpts, @@ -453,18 +449,11 @@ export const dialWebRTC = async ( }; const processWebRTCOpts = async ( - signalingAddress: string, + signalingClient: ReturnType>, callOpts: CallOptions, - dialOpts?: DialOptions, - transportCredentialsInclude = false + dialOpts: DialOptions | undefined ): Promise => { - // Get TURN servers, if any. - const config = await getOptionalWebRTCConfig( - signalingAddress, - callOpts, - dialOpts, - transportCredentialsInclude - ); + const config = await getOptionalWebRTCConfig(callOpts, signalingClient); const additionalIceServers: RTCIceServer[] = config.additionalIceServers.map( (ice) => { const iceUrls = [];