📦 EqualifyEverything / equalify-reflow

📄 api_key_auth.py · 312 lines
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
"""API Key authentication middleware for FastAPI."""

import logging
import secrets
from collections.abc import Awaitable, Callable
from typing import Any

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from ..auth.base import Identity
from ..config import settings

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    """
    Middleware for API key authentication.

    Validates API keys in the X-API-Key header (configurable).
    Public endpoints (health, metrics, dev monitoring) bypass authentication.
    """

    def __init__(self, app: Any) -> None:
        """
        Set up the middleware and snapshot the configured API keys.

        Args:
            app: Application instance to wrap; passed unchanged to the
                base class initializer.
        """
        super().__init__(app)
        # Keys are parsed once, here, and reused for every request. This
        # instance will not notice a later change to the configured keys.
        loaded_keys = self._load_api_keys()
        self._cached_keys: set[str] = loaded_keys

    def _load_api_keys(self) -> set[str]:
        """
        Read the configured API keys and split them into a set.

        `settings.api_keys` is treated as a comma-separated list. Each piece
        is stripped of surrounding whitespace and blank pieces are dropped.
        A warning is logged when nothing usable is configured.

        Returns:
            Set of API key strings; empty when no key is configured or every
            piece is blank.
        """
        secret = settings.api_keys
        if not secret:
            logger.warning("No API keys configured! All authenticated requests will be rejected.")
            return set()

        # Unwrap the secret, then strip each comma-separated piece once.
        trimmed = (piece.strip() for piece in secret.get_secret_value().split(","))
        parsed = {piece for piece in trimmed if piece}

        if parsed:
            logger.info(f"Loaded {len(parsed)} API key(s) for authentication")
        else:
            logger.warning("API keys configured but empty after parsing!")
        return parsed

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """
        Gate the request on a valid API key unless the endpoint is exempt.

        Exempt requests (see `_is_public_endpoint`) are forwarded untouched.
        Otherwise the key is read from the header named by
        `settings.api_key_header_name`; a missing or unknown key produces a
        401 response and a warning log entry. An accepted key is stored on
        `request.state.api_key` before the request is forwarded.

        Args:
            request: Incoming request
            call_next: Next middleware/handler in the chain

        Returns:
            The downstream response, or a 401 JSON response when the key is
            missing or invalid
        """
        # Exempt endpoints pass straight through; no key is read or stored.
        if self._is_public_endpoint(request):
            return await call_next(request)

        header_name = settings.api_key_header_name
        api_key = request.headers.get(header_name)

        if api_key and self._is_valid_key(api_key):
            # Make the accepted key available to downstream handlers.
            request.state.api_key = api_key
            return await call_next(request)

        if api_key:
            # A key was supplied but matched none of the cached keys.
            logger.warning(
                f"Invalid API key for {request.method} {request.url.path} from {self._get_client_ip(request)}"
            )
            return self._unauthorized_response(detail="Invalid API key")

        # Header absent or empty.
        logger.warning(
            f"Missing API key for {request.method} {request.url.path} from {self._get_client_ip(request)}"
        )
        return self._unauthorized_response(
            detail=f"Missing API key. Provide a valid key in the '{header_name}' header."
        )

    def _is_public_endpoint(self, request: Request) -> bool:
        """
        Check if endpoint is exempt from API key auth.

        Only `/api/*` paths require an API key โ€” everything else is served by
        the (public) viewer, the LTI JWT flow, or is a public health check.
        API routes themselves have three carve-outs:

        - Development-only endpoints under /api/dev/monitoring, /api/dev/minimal,
          /api/dev/pipeline-viewer are public when settings.environment == "dev"
        - Same-origin requests from the viewer are exempt via
          `_is_demo_ui_request` (checks Referer, Origin, and Sec-Fetch-Site)
          so the viewer's JS can call the API without the browser having to
          inject an X-API-Key header
        - SSE stream endpoints with a `?token=` query parameter bypass API
          key auth and validate the token in the endpoint handler instead

        Args:
            request: Incoming request

        Returns:
            True if the request is exempt from API key auth
        """
        path = request.url.path

        # API routes are the only ones gated by API key auth.
        if path.startswith("/api/"):
            # Session-authenticated requests (set by SessionAuthMiddleware
            # when AUTH_MODE != 'none') short-circuit. Both auth paths
            # coexist: an API key still works in parallel for programmatic
            # clients regardless of AUTH_MODE.
            #
            # We isinstance-check rather than truthiness-check because mock
            # tests upstream sometimes use MagicMock for request.state, and a
            # MagicMock attribute is always truthy. Only a real Identity
            # object should short-circuit.
            if isinstance(getattr(request.state, "identity", None), Identity):
                return True
            # /api/v1/auth/* endpoints handle their own auth (config is
            # always public; login/logout/me self-validate). Exempt the
            # subtree here so an unauthenticated user can hit /auth/login.
            if path.startswith("/api/v1/auth/"):
                return True
            # Dev-only endpoints are exempt when running in dev
            if settings.environment == "dev" and (
                path.startswith("/api/dev/monitoring")
                or path.startswith("/api/dev/minimal")
                or path.startswith("/api/dev/pipeline-viewer")
            ):
                return True
            # Same-origin requests from the viewer or demo UI.
            # When AUTH_MODE != 'none', the SPA must establish a session
            # before it can call the API โ€” the same-origin shortcut is only
            # safe in the open-default mode. Otherwise an unauthenticated
            # browser session would defeat the new auth layer.
            if settings.auth_mode == "none" and self._is_demo_ui_request(request):
                return True
            # SSE stream endpoints authenticate via ?token= query param
            if self._is_stream_token_request(request):
                return True
            return False

        # Everything outside /api/ is exempt from API key auth:
        #   - / and every SPA deep link โ†’ viewer HTML (public)
        #   - /viewer, /viewer/* โ†’ legacy 301 redirects to /
        #   - /health, /health/ready โ†’ public health checks
        #   - /docs, /openapi.json, /redoc โ†’ public API documentation
        #   - /lti/* โ†’ Canvas JWT auth
        #   - /static/canvas/* โ†’ dashboard assets
        #   - /metrics โ†’ Prometheus
        #   - /assets/*, favicons, fonts, etc. โ†’ public viewer static files
        return True

    def _is_demo_ui_request(self, request: Request) -> bool:
        """
        Check if request looks like a same-origin browser fetch from the viewer.

        A request qualifies when both of the following are true:

        - It carries no value in the API key header
          (`settings.api_key_header_name`). A request that does send a key
          is never treated as a viewer request, so it goes through normal
          key validation.
        - Its `Sec-Fetch-Site` header is exactly `same-origin`.

        NOTE(review): browsers set `Sec-Fetch-Site` themselves and page
        scripts cannot override it, but a non-browser HTTP client can send
        any value. This check identifies well-behaved browsers; it does not
        prove where a request came from. Confirm that is acceptable wherever
        this exemption is enabled.

        Args:
            request: Incoming request

        Returns:
            True if request appears to be a same-origin viewer fetch
        """
        headers = request.headers

        # A supplied API key always means "validate me the normal way".
        if headers.get(settings.api_key_header_name):
            return False

        # Otherwise rely on the browser-set fetch metadata header alone.
        return headers.get("Sec-Fetch-Site", "") == "same-origin"

    def _is_stream_token_request(self, request: Request) -> bool:
        """
        Check if request has stream token for SSE endpoint.

        Stream tokens allow bypassing API key auth for browser EventSource
        connections which cannot send custom headers.

        This only checks if the endpoint qualifies for token-based auth.
        Actual token validation and consumption happens in the endpoint handler.
        We mark it as "public" here to bypass API key check, then the
        endpoint validates the token.

        Args:
            request: Incoming request

        Returns:
            True if this is a stream endpoint with a token parameter
        """
        path = request.url.path

        # Only applies to stream endpoints (not the token creation endpoint)
        if not path.endswith("/stream"):
            return False

        # Must have token query parameter
        token = request.query_params.get("token")
        if not token:
            return False

        # Basic format validation (256-bit tokens are ~43 chars)
        # Full validation happens in endpoint handler
        if len(token) < 40:
            return False

        return True

    def _is_valid_key(self, provided_key: str) -> bool:
        """
        Validate API key using constant-time comparison.

        Uses secrets.compare_digest() to prevent timing attacks.
        Uses cached keys loaded at initialization for optimal performance.

        Args:
            provided_key: API key from request header

        Returns:
            True if key is valid
        """
        if not self._cached_keys:
            return False

        # Use constant-time comparison to prevent timing attacks
        for valid_key in self._cached_keys:
            if secrets.compare_digest(provided_key, valid_key):
                return True

        return False

    def _get_client_ip(self, request: Request) -> str:
        """
        Extract client IP from request.

        Handles X-Forwarded-For header for reverse proxy setups.

        Args:
            request: Incoming request

        Returns:
            Client IP address
        """
        # Check X-Forwarded-For header (for reverse proxies)
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            return forwarded.split(",")[0].strip()

        # Check X-Real-IP header (nginx)
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip.strip()

        # Fall back to direct connection IP
        if request.client:
            return request.client.host

        return "unknown"

    def _unauthorized_response(self, detail: str) -> JSONResponse:
        """
        Build the JSON response returned when API key auth fails.

        The body is `{"detail": <detail>}` and the response carries a
        `WWW-Authenticate` header with the `ApiKey` scheme.

        Args:
            detail: Error message placed in the response body

        Returns:
            401 Unauthorized response
        """
        challenge_headers = {"WWW-Authenticate": 'ApiKey realm="API", charset="UTF-8"'}
        payload = {"detail": detail}
        return JSONResponse(
            content=payload,
            status_code=status.HTTP_401_UNAUTHORIZED,
            headers=challenge_headers,
        )