diff options
Diffstat (limited to 'lib/private/TaskProcessing/Manager.php')
-rw-r--r-- | lib/private/TaskProcessing/Manager.php | 148 |
1 files changed, 133 insertions, 15 deletions
diff --git a/lib/private/TaskProcessing/Manager.php b/lib/private/TaskProcessing/Manager.php index e2047be9c2c..11fb2bed559 100644 --- a/lib/private/TaskProcessing/Manager.php +++ b/lib/private/TaskProcessing/Manager.php @@ -31,14 +31,19 @@ use OCP\Files\Node; use OCP\Files\NotPermittedException; use OCP\Files\SimpleFS\ISimpleFile; use OCP\Http\Client\IClientService; -use OCP\IConfig; +use OCP\IAppConfig; +use OCP\ICache; +use OCP\ICacheFactory; use OCP\IL10N; use OCP\IServerContainer; +use OCP\IUserManager; +use OCP\IUserSession; use OCP\L10N\IFactory; use OCP\Lock\LockedException; use OCP\SpeechToText\ISpeechToTextProvider; use OCP\SpeechToText\ISpeechToTextProviderWithId; use OCP\TaskProcessing\EShapeType; +use OCP\TaskProcessing\Events\GetTaskProcessingProvidersEvent; use OCP\TaskProcessing\Events\TaskFailedEvent; use OCP\TaskProcessing\Events\TaskSuccessfulEvent; use OCP\TaskProcessing\Exception\NotFoundException; @@ -68,6 +73,11 @@ class Manager implements IManager { public const LEGACY_PREFIX_TEXTTOIMAGE = 'legacy:TextToImage:'; public const LEGACY_PREFIX_SPEECHTOTEXT = 'legacy:SpeechToText:'; + public const LAZY_CONFIG_KEYS = [ + 'ai.taskprocessing_type_preferences', + 'ai.taskprocessing_provider_preferences', + ]; + /** @var list<IProvider>|null */ private ?array $providers = null; @@ -77,8 +87,17 @@ class Manager implements IManager { private ?array $availableTaskTypes = null; private IAppData $appData; + private ?array $preferences = null; + private ?array $providersById = null; + + /** @var ITaskType[]|null */ + private ?array $taskTypes = null; + private ICache $distributedCache; + + private ?GetTaskProcessingProvidersEvent $eventResult = null; + public function __construct( - private IConfig $config, + private IAppConfig $appConfig, private Coordinator $coordinator, private IServerContainer $serverContainer, private LoggerInterface $logger, @@ -91,8 +110,12 @@ class Manager implements IManager { private IUserMountCache $userMountCache, private IClientService $clientService, private IAppManager $appManager, + private IUserManager $userManager, + private IUserSession $userSession, + ICacheFactory $cacheFactory, ) { $this->appData = $appDataFactory->get('core'); + $this->distributedCache = $cacheFactory->createDistributed('task_processing::'); } @@ -481,6 +504,20 @@ class Manager implements IManager { } /** + * Dispatches the event to collect external providers and task types. + * Caches the result within the request. + */ + private function dispatchGetProvidersEvent(): GetTaskProcessingProvidersEvent { + if ($this->eventResult !== null) { + return $this->eventResult; + } + + $this->eventResult = new GetTaskProcessingProvidersEvent(); + $this->dispatcher->dispatchTyped($this->eventResult); + return $this->eventResult ; + } + + /** * @return IProvider[] */ private function _getProviders(): array { @@ -508,6 +545,16 @@ class Manager implements IManager { } } + $event = $this->dispatchGetProvidersEvent(); + $externalProviders = $event->getProviders(); + foreach ($externalProviders as $provider) { + if (!isset($providers[$provider->getId()])) { + $providers[$provider->getId()] = $provider; + } else { + $this->logger->info('Skipping external task processing provider with ID ' . $provider->getId() . ' because a local provider with the same ID already exists.'); + } + } + $providers += $this->_getTextProcessingProviders() + $this->_getTextToImageProviders() + $this->_getSpeechToTextProviders(); return $providers; @@ -523,6 +570,10 @@ class Manager implements IManager { return []; } + if ($this->taskTypes !== null) { + return $this->taskTypes; + } + // Default task types $taskTypes = [ \OCP\TaskProcessing\TaskTypes\TextToText::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToText::class), @@ -542,6 +593,10 @@ class Manager implements IManager { \OCP\TaskProcessing\TaskTypes\TextToTextChatWithTools::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextChatWithTools::class), \OCP\TaskProcessing\TaskTypes\ContextAgentInteraction::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextAgentInteraction::class), \OCP\TaskProcessing\TaskTypes\TextToTextProofread::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextProofread::class), + \OCP\TaskProcessing\TaskTypes\TextToSpeech::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToSpeech::class), + \OCP\TaskProcessing\TaskTypes\AudioToAudioChat::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AudioToAudioChat::class), + \OCP\TaskProcessing\TaskTypes\ContextAgentAudioInteraction::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextAgentAudioInteraction::class), + \OCP\TaskProcessing\TaskTypes\AnalyzeImages::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AnalyzeImages::class), ]; foreach ($context->getTaskProcessingTaskTypes() as $providerServiceRegistration) { @@ -560,9 +615,19 @@ class Manager implements IManager { } } + $event = $this->dispatchGetProvidersEvent(); + $externalTaskTypes = $event->getTaskTypes(); + foreach ($externalTaskTypes as $taskType) { + if (isset($taskTypes[$taskType->getId()])) { + $this->logger->warning('External task processing task type is using ID ' . $taskType->getId() . ' which is already used by a locally registered task type (' . get_class($taskTypes[$taskType->getId()]) . ')'); + } + $taskTypes[$taskType->getId()] = $taskType; + } + $taskTypes += $this->_getTextProcessingTaskTypes(); - return $taskTypes; + $this->taskTypes = $taskTypes; + return $this->taskTypes; } /** @@ -570,7 +635,7 @@ class Manager implements IManager { */ private function _getTaskTypeSettings(): array { try { - $json = $this->config->getAppValue('core', 'ai.taskprocessing_type_preferences', ''); + $json = $this->appConfig->getValueString('core', 'ai.taskprocessing_type_preferences', '', lazy: true); if ($json === '') { return []; } @@ -582,10 +647,10 @@ class Manager implements IManager { foreach ($taskTypes as $taskType) { $taskTypeSettings[$taskType->getId()] = false; }; - + return $taskTypeSettings; } - + } /** @@ -725,12 +790,27 @@ class Manager implements IManager { public function getPreferredProvider(string $taskTypeId) { try { - $preferences = json_decode($this->config->getAppValue('core', 'ai.taskprocessing_provider_preferences', 'null'), associative: true, flags: JSON_THROW_ON_ERROR); + if ($this->preferences === null) { + $this->preferences = $this->distributedCache->get('ai.taskprocessing_provider_preferences'); + if ($this->preferences === null) { + $this->preferences = json_decode( + $this->appConfig->getValueString('core', 'ai.taskprocessing_provider_preferences', 'null', lazy: true), + associative: true, + flags: JSON_THROW_ON_ERROR, + ); + $this->distributedCache->set('ai.taskprocessing_provider_preferences', $this->preferences, 60 * 3); + } + } + $providers = $this->getProviders(); - if (isset($preferences[$taskTypeId])) { - $provider = current(array_values(array_filter($providers, fn ($provider) => $provider->getId() === $preferences[$taskTypeId]))); - if ($provider !== false) { - return $provider; + if (isset($this->preferences[$taskTypeId])) { + $providersById = $this->providersById ?? array_reduce($providers, static function (array $carry, IProvider $provider) { + $carry[$provider->getId()] = $provider; + return $carry; + }, []); + $this->providersById = $providersById; + if (isset($providersById[$this->preferences[$taskTypeId]])) { + return $providersById[$this->preferences[$taskTypeId]]; } } // By default, use the first available provider @@ -745,7 +825,17 @@ class Manager implements IManager { throw new \OCP\TaskProcessing\Exception\Exception('No matching provider found'); } - public function getAvailableTaskTypes(bool $showDisabled = false): array { + public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userId = null): array { + // userId will be obtained from the session if left to null + if (!$this->checkGuestAccess($userId)) { + return []; + } + if ($this->availableTaskTypes === null) { + $cachedValue = $this->distributedCache->get('available_task_types_v2'); + if ($cachedValue !== null) { + $this->availableTaskTypes = unserialize($cachedValue); + } + } // Either we have no cache or showDisabled is turned on, which we don't want to cache, ever. if ($this->availableTaskTypes === null || $showDisabled) { $taskTypes = $this->_getTaskTypes(); @@ -787,6 +877,7 @@ class Manager implements IManager { } $this->availableTaskTypes = $availableTaskTypes; + $this->distributedCache->set('available_task_types_v2', serialize($this->availableTaskTypes), 60); } @@ -797,7 +888,27 @@ class Manager implements IManager { return isset($this->getAvailableTaskTypes()[$task->getTaskTypeId()]); } + private function checkGuestAccess(?string $userId = null): bool { + if ($userId === null && !$this->userSession->isLoggedIn()) { + return true; + } + if ($userId === null) { + $user = $this->userSession->getUser(); + } else { + $user = $this->userManager->get($userId); + } + + $guestsAllowed = $this->appConfig->getValueString('core', 'ai.taskprocessing_guests', 'false'); + if ($guestsAllowed == 'true' || !class_exists(\OCA\Guests\UserBackend::class) || !($user->getBackend() instanceof \OCA\Guests\UserBackend)) { + return true; + } + return false; + } + public function scheduleTask(Task $task): void { + if (!$this->checkGuestAccess($task->getUserId())) { + throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('Access to this resource is forbidden for guests.'); + } if (!$this->canHandleTask($task)) { throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId()); } @@ -812,6 +923,9 @@ class Manager implements IManager { } public function runTask(Task $task): Task { + if (!$this->checkGuestAccess($task->getUserId())) { + throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('Access to this resource is forbidden for guests.'); + } if (!$this->canHandleTask($task)) { throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId()); } @@ -973,7 +1087,7 @@ class Manager implements IManager { $task->setEndedAt(time()); $error = 'The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec'; $task->setErrorMessage($error); - $this->logger->error($error, ['exception' => $e]); + $this->logger->error($error, ['exception' => $e, 'output' => $result]); } catch (NotPermittedException $e) { $task->setProgress(1); $task->setStatus(Task::STATUS_FAILED); @@ -990,7 +1104,11 @@ class Manager implements IManager { $this->logger->error($error, ['exception' => $e]); } } - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); + try { + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); + } catch (\JsonException $e) { + throw new \OCP\TaskProcessing\Exception\Exception('The task was processed successfully but the provider\'s output could not be encoded as JSON for the database.', 0, $e); + } try { $this->taskMapper->update($taskEntity); $this->runWebhook($task); @@ -1366,7 +1484,7 @@ class Manager implements IManager { $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. Invalid method: ' . $method); } [, $exAppId, $httpMethod] = $parsedMethod; - if (!$this->appManager->isInstalled('app_api')) { + if (!$this->appManager->isEnabledForAnyone('app_api')) { $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. AppAPI is disabled or not installed.'); return; } |