diff --git a/packages/cli/src/push/index.ts b/packages/cli/src/push/index.ts index 51b0ad05d8..a88688d6a6 100644 --- a/packages/cli/src/push/index.ts +++ b/packages/cli/src/push/index.ts @@ -2,7 +2,7 @@ import { EventEmitter } from 'events'; import { ServerResponse } from 'http'; import type { Server } from 'http'; import type { Socket } from 'net'; -import type { Application, RequestHandler } from 'express'; +import type { Application } from 'express'; import { Server as WSServer } from 'ws'; import { parse as parseUrl } from 'url'; import { Container, Service } from 'typedi'; @@ -14,6 +14,7 @@ import type { IPushDataType } from '@/Interfaces'; import type { User } from '@db/entities/User'; import { OnShutdown } from '@/decorators/OnShutdown'; import { AuthService } from '@/auth/auth.service'; +import { BadRequestError } from '@/errors/response-errors/bad-request.error'; const useWebSockets = config.getEnv('push.backend') === 'websocket'; @@ -38,14 +39,24 @@ export class Push extends EventEmitter { handleRequest(req: SSEPushRequest | WebSocketPushRequest, res: PushResponse) { const { - userId, + user, + ws, query: { sessionId }, } = req; + if (!sessionId) { + if (ws) { + ws.send('The query parameter "sessionId" is missing!'); + ws.close(1008); + return; + } + throw new BadRequestError('The query parameter "sessionId" is missing!'); + } + if (req.ws) { - (this.backend as WebSocketPush).add(sessionId, userId, req.ws); + (this.backend as WebSocketPush).add(sessionId, user.id, req.ws); } else if (!useWebSockets) { - (this.backend as SSEPush).add(sessionId, userId, { req, res }); + (this.backend as SSEPush).add(sessionId, user.id, { req, res }); } else { res.status(401).send('Unauthorized'); return; @@ -101,35 +112,12 @@ export const setupPushServer = (restEndpoint: string, server: Server, app: Appli export const setupPushHandler = (restEndpoint: string, app: Application) => { const endpoint = `/${restEndpoint}/push`; - - const pushValidationMiddleware: RequestHandler = async ( - req: SSEPushRequest | WebSocketPushRequest, - _, - next, - ) => { - const ws = req.ws; - - const { sessionId } = req.query; - if (sessionId === undefined) { - if (ws) { - ws.send('The query parameter "sessionId" is missing!'); - ws.close(1008); - } else { - next(new Error('The query parameter "sessionId" is missing!')); - } - return; - } - - next(); - }; - const push = Container.get(Push); const authService = Container.get(AuthService); app.use( endpoint, // eslint-disable-next-line @typescript-eslint/unbound-method authService.authMiddleware, - pushValidationMiddleware, (req: SSEPushRequest | WebSocketPushRequest, res: PushResponse) => push.handleRequest(req, res), ); }; diff --git a/packages/cli/src/push/types.ts b/packages/cli/src/push/types.ts index 791604de28..d99ca3f612 100644 --- a/packages/cli/src/push/types.ts +++ b/packages/cli/src/push/types.ts @@ -8,8 +8,8 @@ import type { AuthenticatedRequest } from '@/requests'; export type PushRequest = AuthenticatedRequest<{}, {}, {}, { sessionId: string }>; -export type SSEPushRequest = PushRequest & { ws: undefined; userId: User['id'] }; -export type WebSocketPushRequest = PushRequest & { ws: WebSocket; userId: User['id'] }; +export type SSEPushRequest = PushRequest & { ws: undefined }; +export type WebSocketPushRequest = PushRequest & { ws: WebSocket }; export type PushResponse = Response & { req: PushRequest }; diff --git a/packages/cli/test/unit/push/index.test.ts b/packages/cli/test/unit/push/index.test.ts new file mode 100644 index 0000000000..a2296a61ba --- /dev/null +++ b/packages/cli/test/unit/push/index.test.ts @@ -0,0 +1,42 @@ +import type { WebSocket } from 'ws'; +import config from '@/config'; +import type { User } from '@db/entities/User'; +import { Push } from '@/push'; +import { SSEPush } from '@/push/sse.push'; +import { WebSocketPush } from '@/push/websocket.push'; +import type { WebSocketPushRequest, SSEPushRequest } from '@/push/types'; +import { mockInstance } from '../../shared/mocking'; +import { mock } from 'jest-mock-extended'; +import { BadRequestError } from '@/errors/response-errors/bad-request.error'; + +jest.unmock('@/push'); + +describe('Push', () => { + const user = mock(); + + const sseBackend = mockInstance(SSEPush); + const wsBackend = mockInstance(WebSocketPush); + + test('should validate sessionId on requests for websocket backend', () => { + config.set('push.backend', 'websocket'); + const push = new Push(); + const ws = mock(); + const request = mock({ user, ws }); + request.query = { sessionId: '' }; + push.handleRequest(request, mock()); + + expect(ws.send).toHaveBeenCalled(); + expect(ws.close).toHaveBeenCalledWith(1008); + expect(wsBackend.add).not.toHaveBeenCalled(); + }); + + test('should validate sessionId on requests for SSE backend', () => { + config.set('push.backend', 'sse'); + const push = new Push(); + const request = mock({ user, ws: undefined }); + request.query = { sessionId: '' }; + expect(() => push.handleRequest(request, mock())).toThrow(BadRequestError); + + expect(sseBackend.add).not.toHaveBeenCalled(); + }); +});