diff options
Diffstat (limited to 'lib/private')
-rw-r--r-- | lib/private/AppFramework/Middleware/Security/RateLimitingMiddleware.php | 118 |
1 files changed, 77 insertions, 41 deletions
diff --git a/lib/private/AppFramework/Middleware/Security/RateLimitingMiddleware.php b/lib/private/AppFramework/Middleware/Security/RateLimitingMiddleware.php index 5f683fa38ac..6f84a0c94d0 100644 --- a/lib/private/AppFramework/Middleware/Security/RateLimitingMiddleware.php +++ b/lib/private/AppFramework/Middleware/Security/RateLimitingMiddleware.php @@ -1,5 +1,9 @@ <?php + +declare(strict_types=1); + /** + * @copyright Copyright (c) 2023 Joas Schilling <coding@schilljs.com> * @copyright Copyright (c) 2017 Lukas Reschke <lukas@statuscode.ch> * * @author Christoph Wurst <christoph@winzerhof-wurst.at> @@ -27,11 +31,17 @@ namespace OC\AppFramework\Middleware\Security; use OC\AppFramework\Utility\ControllerMethodReflector; use OC\Security\RateLimiting\Exception\RateLimitExceededException; use OC\Security\RateLimiting\Limiter; +use OCP\AppFramework\Controller; +use OCP\AppFramework\Http\Attribute\AnonRateLimit; +use OCP\AppFramework\Http\Attribute\ARateLimit; +use OCP\AppFramework\Http\Attribute\UserRateLimit; use OCP\AppFramework\Http\DataResponse; +use OCP\AppFramework\Http\Response; use OCP\AppFramework\Http\TemplateResponse; use OCP\AppFramework\Middleware; use OCP\IRequest; use OCP\IUserSession; +use ReflectionMethod; /** * Class RateLimitingMiddleware is the middleware responsible for implementing the @@ -42,7 +52,12 @@ use OCP\IUserSession; * @UserRateThrottle(limit=5, period=100) * @AnonRateThrottle(limit=1, period=100) * - * Those annotations above would mean that logged-in users can access the page 5 + * Or attributes such as: + * + * #[UserRateLimit(limit: 5, period: 100)] + * #[AnonRateLimit(limit: 1, period: 100)] + * + * Both sets would mean that logged-in users can access the page 5 * times within 100 seconds, and anonymous users 1 time within 100 seconds. If * only an AnonRateThrottle is specified that one will also be applied to logged-in * users. @@ -50,64 +65,85 @@ use OCP\IUserSession; * @package OC\AppFramework\Middleware\Security */ class RateLimitingMiddleware extends Middleware { - /** @var IRequest $request */ - private $request; - /** @var IUserSession */ - private $userSession; - /** @var ControllerMethodReflector */ - private $reflector; - /** @var Limiter */ - private $limiter; - - /** - * @param IRequest $request - * @param IUserSession $userSession - * @param ControllerMethodReflector $reflector - * @param Limiter $limiter - */ - public function __construct(IRequest $request, - IUserSession $userSession, - ControllerMethodReflector $reflector, - Limiter $limiter) { - $this->request = $request; - $this->userSession = $userSession; - $this->reflector = $reflector; - $this->limiter = $limiter; + public function __construct( + protected IRequest $request, + protected IUserSession $userSession, + protected ControllerMethodReflector $reflector, + protected Limiter $limiter, + ) { } /** * {@inheritDoc} * @throws RateLimitExceededException */ - public function beforeController($controller, $methodName) { + public function beforeController(Controller $controller, string $methodName): void { parent::beforeController($controller, $methodName); - - $anonLimit = $this->reflector->getAnnotationParameter('AnonRateThrottle', 'limit'); - $anonPeriod = $this->reflector->getAnnotationParameter('AnonRateThrottle', 'period'); - $userLimit = $this->reflector->getAnnotationParameter('UserRateThrottle', 'limit'); - $userPeriod = $this->reflector->getAnnotationParameter('UserRateThrottle', 'period'); $rateLimitIdentifier = get_class($controller) . '::' . $methodName; - if ($userLimit !== '' && $userPeriod !== '' && $this->userSession->isLoggedIn()) { - $this->limiter->registerUserRequest( - $rateLimitIdentifier, - $userLimit, - $userPeriod, - $this->userSession->getUser() - ); - } elseif ($anonLimit !== '' && $anonPeriod !== '') { + + if ($this->userSession->isLoggedIn()) { + $rateLimit = $this->readLimitFromAnnotationOrAttribute($controller, $methodName, 'UserRateThrottle', UserRateLimit::class); + + if ($rateLimit !== null) { + $this->limiter->registerUserRequest( + $rateLimitIdentifier, + $rateLimit->getLimit(), + $rateLimit->getPeriod(), + $this->userSession->getUser() + ); + return; + } + + // If not user specific rate limit is found the Anon rate limit applies! + } + + $rateLimit = $this->readLimitFromAnnotationOrAttribute($controller, $methodName, 'AnonRateThrottle', AnonRateLimit::class); + + if ($rateLimit !== null) { $this->limiter->registerAnonRequest( $rateLimitIdentifier, - $anonLimit, - $anonPeriod, + $rateLimit->getLimit(), + $rateLimit->getPeriod(), $this->request->getRemoteAddress() ); } } /** + * @template T of ARateLimit + * + * @param Controller $controller + * @param string $methodName + * @param string $annotationName + * @param class-string<T> $attributeClass + * @return ?ARateLimit + */ + protected function readLimitFromAnnotationOrAttribute(Controller $controller, string $methodName, string $annotationName, string $attributeClass): ?ARateLimit { + $annotationLimit = $this->reflector->getAnnotationParameter($annotationName, 'limit'); + $annotationPeriod = $this->reflector->getAnnotationParameter($annotationName, 'period'); + + if ($annotationLimit !== '' && $annotationPeriod !== '') { + return new $attributeClass( + (int) $annotationLimit, + (int) $annotationPeriod, + ); + } + + $reflectionMethod = new ReflectionMethod($controller, $methodName); + $attributes = $reflectionMethod->getAttributes($attributeClass); + $attribute = current($attributes); + + if ($attribute !== false) { + return $attribute->newInstance(); + } + + return null; + } + + /** * {@inheritDoc} */ - public function afterException($controller, $methodName, \Exception $exception) { + public function afterException(Controller $controller, string $methodName, \Exception $exception): Response { if ($exception instanceof RateLimitExceededException) { if (stripos($this->request->getHeader('Accept'), 'html') === false) { $response = new DataResponse([], $exception->getCode()); |