using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Haoliang.Core.Services;
using Haoliang.Models.Common;

namespace Haoliang.Api.Middleware
{
    /// <summary>
    /// Fixed-window rate limiter keyed by "client id : request path".
    /// Each key may make <see cref="RateLimitSettings.MaxRequests"/> requests per
    /// <see cref="RateLimitSettings.TimeWindow"/> seconds; further requests in the same
    /// window are answered with HTTP 429 and never reach the rest of the pipeline.
    /// </summary>
    public class RateLimitMiddleware
    {
        private readonly RequestDelegate _next;
        private readonly ILogger<RateLimitMiddleware> _logger;
        private readonly ILoggingService _loggingService;
        private readonly RateLimitSettings _settings;

        // One entry per key holding the window start and the number of requests seen in it.
        // Keeping both in a single immutable value lets AddOrUpdate change them atomically.
        private readonly ConcurrentDictionary<string, (DateTime WindowStart, int Count)> _windows =
            new ConcurrentDictionary<string, (DateTime WindowStart, int Count)>();

        // UTC ticks before which no cleanup sweep is due; accessed only through Interlocked.
        private long _nextCleanupTicks;

        /// <summary>Creates the middleware.</summary>
        /// <param name="next">Next delegate in the request pipeline.</param>
        /// <param name="logger">Framework logger used for the structured warning.</param>
        /// <param name="loggingService">Application logging service that receives the serialized warning.</param>
        /// <param name="settings">Rate limit configuration; its <c>Value</c> is read once here.</param>
        public RateLimitMiddleware(
            RequestDelegate next,
            ILogger<RateLimitMiddleware> logger,
            ILoggingService loggingService,
            IOptions<RateLimitSettings> settings)
        {
            _next = next;
            _logger = logger;
            _loggingService = loggingService;
            _settings = settings.Value;
        }

        /// <summary>
        /// Counts the request against its client/endpoint window, writes the
        /// <c>X-RateLimit-*</c> headers, and either forwards the request or
        /// short-circuits with a 429 JSON body.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        public async Task Invoke(HttpContext context)
        {
            if (!_settings.EnableRateLimiting || IsExcludedPath(context.Request.Path))
            {
                await _next(context);
                return;
            }

            var clientId = GetClientId(context);
            var endpoint = GetEndpoint(context);
            var now = DateTime.UtcNow;
            var window = TimeSpan.FromSeconds(_settings.TimeWindow);

            CleanupExpiredWindows(now, window);
            var state = RegisterRequest($"{clientId}:{endpoint}", now, window);

            var resetsAt = state.WindowStart + window;
            var remaining = Math.Max(0, _settings.MaxRequests - state.Count);

            // Headers must be written before the downstream pipeline runs: once the
            // response has started they are read-only and assigning them throws.
            var headers = context.Response.Headers;
            headers["X-RateLimit-Limit"] = _settings.MaxRequests.ToString(CultureInfo.InvariantCulture);
            headers["X-RateLimit-Remaining"] = remaining.ToString(CultureInfo.InvariantCulture);
            // Emitted as UTC epoch seconds so the value does not depend on the server culture.
            headers["X-RateLimit-Reset"] =
                new DateTimeOffset(resetsAt).ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture);

            if (state.Count > _settings.MaxRequests)
            {
                await LogRateLimitExceeded(clientId, endpoint);

                var retryAfterSeconds = Math.Max(0, (int)Math.Ceiling((resetsAt - now).TotalSeconds));
                headers["Retry-After"] = retryAfterSeconds.ToString(CultureInfo.InvariantCulture);

                context.Response.StatusCode = 429;
                context.Response.ContentType = "application/json";

                var response = new ApiResponse
                {
                    Success = false,
                    Message = "Rate limit exceeded. Please try again later.",
                    ErrorCode = 429,
                    Timestamp = DateTime.Now
                };

                await context.Response.WriteAsync(System.Text.Json.JsonSerializer.Serialize(response));
                return;
            }

            await _next(context);
        }

        /// <summary>
        /// Atomically records one request for <paramref name="key"/> and returns the
        /// resulting window state. A window that is at least <paramref name="window"/>
        /// old is replaced by a new one starting at <paramref name="now"/>.
        /// </summary>
        /// <returns>
        /// The window start and the request count including this request. A count greater
        /// than <see cref="RateLimitSettings.MaxRequests"/> means the request is over the limit.
        /// </returns>
        private (DateTime WindowStart, int Count) RegisterRequest(string key, DateTime now, TimeSpan window)
        {
            var maxRequests = _settings.MaxRequests;

            return _windows.AddOrUpdate(
                key,
                _ => (now, 1),
                (_, current) =>
                {
                    if (now - current.WindowStart >= window)
                    {
                        return (now, 1);
                    }

                    // Stop counting once over the limit so a flood cannot overflow the counter.
                    return current.Count > maxRequests
                        ? current
                        : (current.WindowStart, current.Count + 1);
                });
        }

        /// <summary>
        /// Removes windows that have expired. The sweep runs at most about once per
        /// window length instead of on every request, and only one caller performs it.
        /// </summary>
        private void CleanupExpiredWindows(DateTime now, TimeSpan window)
        {
            var due = Interlocked.Read(ref _nextCleanupTicks);
            if (now.Ticks < due)
            {
                return;
            }

            // Only the caller that wins this exchange performs the sweep.
            if (Interlocked.CompareExchange(ref _nextCleanupTicks, (now + window).Ticks, due) != due)
            {
                return;
            }

            foreach (var entry in _windows)
            {
                if (now - entry.Value.WindowStart >= window)
                {
                    // Key+value removal: an entry refreshed by a concurrent request is left alone.
                    _windows.TryRemove(entry);
                }
            }
        }

        /// <summary>
        /// Returns true when <paramref name="path"/> starts with one of
        /// <see cref="RateLimitSettings.ExcludedPaths"/> (segment-wise, case-insensitive).
        /// Entries that are empty or do not start with '/' are ignored.
        /// </summary>
        private bool IsExcludedPath(PathString path)
        {
            var excludedPaths = _settings.ExcludedPaths;
            if (excludedPaths is null || excludedPaths.Length == 0)
            {
                return false;
            }

            foreach (var prefix in excludedPaths)
            {
                // PathString rejects values without a leading '/', so skip those entries.
                if (!string.IsNullOrEmpty(prefix)
                    && prefix[0] == '/'
                    && path.StartsWithSegments(prefix, StringComparison.OrdinalIgnoreCase))
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Reports a rejected request to both the application logging service and the
        /// framework logger.
        /// </summary>
        private async Task LogRateLimitExceeded(string clientId, string endpoint)
        {
            var logData = new
            {
                ClientId = clientId,
                Endpoint = endpoint,
                Timestamp = DateTime.Now,
                MaxRequests = _settings.MaxRequests,
                TimeWindow = _settings.TimeWindow
            };

            await _loggingService.LogWarningAsync(
                $"Rate limit exceeded: {System.Text.Json.JsonSerializer.Serialize(logData)}");
            _logger.LogWarning(
                "Rate limit exceeded for client {ClientId} on endpoint {Endpoint}",
                clientId,
                endpoint);
        }

        /// <summary>
        /// Identifies the caller: the authenticated identity name when available
        /// ("authenticated" if the name is null), otherwise the remote IP address,
        /// otherwise "unknown".
        /// </summary>
        private string GetClientId(HttpContext context)
        {
            if (context.User?.Identity?.IsAuthenticated == true)
            {
                return context.User.Identity.Name ?? "authenticated";
            }

            // Use IP address as fallback
            return context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
        }

        /// <summary>Returns the request path, which is the per-endpoint part of the limiter key.</summary>
        private string GetEndpoint(HttpContext context)
        {
            return context.Request.Path.ToString();
        }
    }

    /// <summary>Configuration for <see cref="RateLimitMiddleware"/>.</summary>
    public class RateLimitSettings
    {
        /// <summary>Maximum number of requests allowed per key in one window.</summary>
        public int MaxRequests { get; set; } = 100;

        /// <summary>Window length in seconds.</summary>
        public int TimeWindow { get; set; } = 60; // in seconds

        /// <summary>When false, the middleware forwards every request without counting it.</summary>
        public bool EnableRateLimiting { get; set; } = true;

        /// <summary>
        /// NOTE(review): not consulted by the middleware in this file; the health-check
        /// path is not defined here. Until that is wired up, list the health endpoint in
        /// <see cref="ExcludedPaths"/> to exempt it.
        /// </summary>
        public bool ExcludeHealthChecks { get; set; } = true;

        /// <summary>Path prefixes (each starting with '/') that bypass rate limiting.</summary>
        public string[] ExcludedPaths { get; set; } = Array.Empty<string>();
    }

    /// <summary>Pipeline registration helpers for <see cref="RateLimitMiddleware"/>.</summary>
    public static class RateLimitMiddlewareExtensions
    {
        /// <summary>
        /// Adds <see cref="RateLimitMiddleware"/> to the pipeline with default settings,
        /// optionally adjusted by <paramref name="configure"/>.
        /// </summary>
        /// <param name="builder">The application builder.</param>
        /// <param name="configure">Optional callback that mutates the default settings.</param>
        /// <returns>The same builder, for chaining.</returns>
        public static IApplicationBuilder UseRateLimiting(
            this IApplicationBuilder builder,
            Action<RateLimitSettings> configure = null)
        {
            var settings = new RateLimitSettings();
            configure?.Invoke(settings);
            return builder.UseMiddleware<RateLimitMiddleware>(Options.Create(settings));
        }
    }
}