1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196"""Unit tests for rate limiting middleware.
Tests focus on critical behavior:
1. Fail-open on Redis errors (availability)
2. Blocking after threshold exceeded (security/cost control)
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import Request, Response
from src.middleware.rate_limit import RateLimitMiddleware
@pytest.fixture
def mock_app() -> MagicMock:
    """Provide the stand-in application object wrapped by the middleware.

    The tests in this module drive ``middleware.dispatch`` directly with their
    own ``call_next``, so a bare ``MagicMock`` is enough here.
    """
    app_stub = MagicMock()
    return app_stub
@pytest.fixture
def middleware(mock_app: MagicMock) -> RateLimitMiddleware:
    """Build the ``RateLimitMiddleware`` under test around ``mock_app``."""
    instance = RateLimitMiddleware(mock_app)
    return instance
def create_mock_request(
    path: str,
    method: str = "POST",
    client_host: str = "192.168.1.100",
    rate_limiter: AsyncMock | None = None,
) -> MagicMock:
    """Build a ``Request``-spec'd mock suitable for calling ``dispatch``.

    Args:
        path: Value exposed as ``request.url.path``.
        method: HTTP method exposed as ``request.method``.
        client_host: Address exposed as ``request.client.host``.
        rate_limiter: Object stored at ``request.app.state.rate_limiter``;
            pass ``None`` to simulate a limiter that was never configured.

    Returns:
        A ``MagicMock`` constrained to the ``Request`` interface.
    """
    fake_client = MagicMock()
    fake_client.host = client_host

    fake_headers = MagicMock()
    # Every header lookup returns None, as if no request header were present.
    fake_headers.get = MagicMock(return_value=None)

    fake_state = MagicMock()
    fake_state.rate_limiter = rate_limiter
    fake_app = MagicMock()
    fake_app.state = fake_state

    request = MagicMock(spec=Request)
    request.url.path = path
    request.method = method
    request.client = fake_client
    request.headers = fake_headers
    request.app = fake_app
    return request
@pytest.mark.unit
@pytest.mark.asyncio
async def test_rate_limit_fails_open_on_redis_error(middleware: RateLimitMiddleware) -> None:
    """If Redis/rate limiter fails, requests should still succeed (fail open).

    Catches: All users blocked when Redis is unavailable.
    This is critical for availability - rate limiting should not take down the service.
    """
    # Setup: rate limiter raises exception (simulating Redis failure).
    # AsyncMock raises the side_effect when the call is awaited.
    mock_rate_limiter = AsyncMock()
    mock_rate_limiter.check_submit_rate_limit = AsyncMock(
        side_effect=Exception("Redis connection failed")
    )
    request = create_mock_request(
        path="/api/v1/documents/submit",
        method="POST",
        rate_limiter=mock_rate_limiter,
    )
    call_next = AsyncMock(return_value=Response(status_code=200))

    # Execute
    response = await middleware.dispatch(request, call_next)

    # Assert: the failing limiter was actually awaited. Without this, the test
    # would also pass if the middleware never consulted the limiter for this
    # path, leaving the fail-open branch completely unexercised.
    assert mock_rate_limiter.check_submit_rate_limit.await_count >= 1, (
        "Rate limiter should be consulted for the submit endpoint"
    )
    # The request must still proceed. Check await_count rather than .called:
    # a coroutine that is created but never awaited would never run the handler.
    assert call_next.await_count == 1, "Request should proceed when rate limiter fails"
    assert response.status_code == 200
@pytest.mark.unit
@pytest.mark.asyncio
async def test_rate_limit_fails_open_when_not_configured(middleware: RateLimitMiddleware) -> None:
    """Requests still succeed when ``app.state.rate_limiter`` is ``None``.

    Catches: Service startup issues where rate limiter isn't initialized.
    """
    downstream = AsyncMock(return_value=Response(status_code=200))
    # No limiter in app state: models a startup path that never initialised it.
    unconfigured_request = create_mock_request(
        path="/api/v1/documents/submit",
        method="POST",
        rate_limiter=None,
    )

    result = await middleware.dispatch(unconfigured_request, downstream)

    # With no limiter available, the request is expected to pass straight through.
    downstream.assert_called()
    assert result.status_code == 200
@pytest.mark.unit
@pytest.mark.asyncio
async def test_rate_limit_blocks_after_threshold(middleware: RateLimitMiddleware) -> None:
    """Once the limit is exceeded, the middleware answers 429 with Retry-After.

    Catches: Rate limiting not enforcing limits, allowing abuse/cost overruns.
    """
    # Limiter verdict tuple: (allowed=False, retry after 60 seconds).
    limiter = AsyncMock()
    limiter.check_submit_rate_limit = AsyncMock(return_value=(False, 60))
    downstream = AsyncMock(return_value=Response(status_code=200))
    blocked_request = create_mock_request(
        path="/api/v1/documents/submit",
        method="POST",
        rate_limiter=limiter,
    )

    result = await middleware.dispatch(blocked_request, downstream)

    # A rejected request never reaches the handler and carries throttle headers.
    assert result.status_code == 429, "Should return 429 when rate limited"
    assert not downstream.called, "Request should not proceed when rate limited"
    headers = result.headers
    assert headers.get("Retry-After") == "60"
    assert headers.get("X-RateLimit-Remaining") == "0"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_rate_limit_allows_request_under_threshold(middleware: RateLimitMiddleware) -> None:
    """Request under limit proceeds normally with rate limit headers.

    Catches: Rate limiter incorrectly blocking valid requests.
    """
    # Setup: rate limiter allows request
    mock_rate_limiter = AsyncMock()
    mock_rate_limiter.check_submit_rate_limit = AsyncMock(
        return_value=(True, 0)  # Allowed
    )
    mock_rate_limiter.get_remaining_quota = AsyncMock(
        return_value={"limit": 25, "remaining": 24, "reset_at": 1234567890}
    )
    request = create_mock_request(
        path="/api/v1/documents/submit",
        method="POST",
        rate_limiter=mock_rate_limiter,
    )
    # Create a real Response object that we can modify
    mock_response = Response(status_code=200)
    call_next = AsyncMock(return_value=mock_response)

    # Execute
    response = await middleware.dispatch(request, call_next)

    # Assert: the limit check actually ran. The header assertions alone would
    # not catch a middleware that decorates responses without enforcing limits.
    assert mock_rate_limiter.check_submit_rate_limit.await_count >= 1, (
        "Rate limiter should be consulted for the submit endpoint"
    )
    # Request proceeds (awaited, not merely called) with rate limit headers.
    assert call_next.await_count == 1
    assert response.status_code == 200
    assert response.headers.get("X-RateLimit-Limit") == "25"
    assert response.headers.get("X-RateLimit-Remaining") == "24"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_rate_limit_skips_health_endpoints(middleware: RateLimitMiddleware) -> None:
    """``/health`` requests bypass the submit rate-limit check.

    Catches: Health checks being rate limited, breaking load balancer probes.
    """
    # Configured to reject, so any consultation of the limiter would be visible.
    blocking_limiter = AsyncMock()
    blocking_limiter.check_submit_rate_limit = AsyncMock(return_value=(False, 60))
    downstream = AsyncMock(return_value=Response(status_code=200))
    probe = create_mock_request(
        path="/health",
        method="GET",
        rate_limiter=blocking_limiter,
    )

    result = await middleware.dispatch(probe, downstream)

    downstream.assert_called()
    assert result.status_code == 200
    blocking_limiter.check_submit_rate_limit.assert_not_called()