diff options
Diffstat (limited to 'lib/private/TextToImage/Manager.php')
-rw-r--r-- | lib/private/TextToImage/Manager.php | 325 |
1 files changed, 325 insertions, 0 deletions
diff --git a/lib/private/TextToImage/Manager.php b/lib/private/TextToImage/Manager.php new file mode 100644 index 00000000000..eec6cc3d241 --- /dev/null +++ b/lib/private/TextToImage/Manager.php @@ -0,0 +1,325 @@ +<?php + +declare(strict_types=1); + +/** + * SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later + */ + +namespace OC\TextToImage; + +use OC\AppFramework\Bootstrap\Coordinator; +use OC\TextToImage\Db\Task as DbTask; +use OC\TextToImage\Db\TaskMapper; +use OCP\AppFramework\Db\DoesNotExistException; +use OCP\AppFramework\Db\MultipleObjectsReturnedException; +use OCP\BackgroundJob\IJobList; +use OCP\DB\Exception; +use OCP\Files\AppData\IAppDataFactory; +use OCP\Files\IAppData; +use OCP\Files\NotFoundException; +use OCP\Files\NotPermittedException; +use OCP\IConfig; +use OCP\IServerContainer; +use OCP\PreConditionNotMetException; +use OCP\TextToImage\Exception\TaskFailureException; +use OCP\TextToImage\Exception\TaskNotFoundException; +use OCP\TextToImage\IManager; +use OCP\TextToImage\IProvider; +use OCP\TextToImage\IProviderWithUserId; +use OCP\TextToImage\Task; +use Psr\Log\LoggerInterface; +use RuntimeException; +use Throwable; + +class Manager implements IManager { + /** @var ?list<IProvider> */ + private ?array $providers = null; + private IAppData $appData; + + public function __construct( + private IServerContainer $serverContainer, + private Coordinator $coordinator, + private LoggerInterface $logger, + private IJobList $jobList, + private TaskMapper $taskMapper, + private IConfig $config, + IAppDataFactory $appDataFactory, + ) { + $this->appData = $appDataFactory->get('core'); + } + + /** + * @inheritDoc + */ + public function getProviders(): array { + $context = $this->coordinator->getRegistrationContext(); + if ($context === null) { + return []; + } + + if ($this->providers !== null) { + return $this->providers; + } + + $this->providers = []; + + foreach ($context->getTextToImageProviders() as $providerServiceRegistration) { + $class = $providerServiceRegistration->getService(); + try { + /** @var IProvider $provider */ + $provider = $this->serverContainer->get($class); + $this->providers[] = $provider; + } catch (Throwable $e) { + $this->logger->error('Failed to load Text to image provider ' . $class, [ + 'exception' => $e, + ]); + } + } + + return $this->providers; + } + + /** + * @inheritDoc + */ + public function hasProviders(): bool { + $context = $this->coordinator->getRegistrationContext(); + if ($context === null) { + return false; + } + return count($context->getTextToImageProviders()) > 0; + } + + /** + * @inheritDoc + */ + public function runTask(Task $task): void { + $this->logger->debug('Running TextToImage Task'); + if (!$this->hasProviders()) { + throw new PreConditionNotMetException('No text to image provider is installed that can handle this task'); + } + $providers = $this->getPreferredProviders(); + + foreach ($providers as $provider) { + $this->logger->debug('Trying to run Text2Image provider ' . $provider::class); + try { + $task->setStatus(Task::STATUS_RUNNING); + $completionExpectedAt = new \DateTime('now'); + $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S')); + $task->setCompletionExpectedAt($completionExpectedAt); + if ($task->getId() === null) { + $this->logger->debug('Inserting Text2Image task into DB'); + $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task)); + $task->setId($taskEntity->getId()); + } else { + $this->logger->debug('Updating Text2Image task in DB'); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + } + try { + $folder = $this->appData->getFolder('text2image'); + } catch (NotFoundException) { + $this->logger->debug('Creating folder in appdata for Text2Image results'); + $folder = $this->appData->newFolder('text2image'); + } + try { + $folder = $folder->getFolder((string)$task->getId()); + } catch (NotFoundException) { + $this->logger->debug('Creating new folder in appdata Text2Image results folder'); + $folder = $folder->newFolder((string)$task->getId()); + } + $this->logger->debug('Creating result files for Text2Image task'); + $resources = []; + $files = []; + for ($i = 0; $i < $task->getNumberOfImages(); $i++) { + $file = $folder->newFile((string)$i); + $files[] = $file; + $resource = $file->write(); + if ($resource !== false && $resource !== true && is_resource($resource)) { + $resources[] = $resource; + } else { + throw new RuntimeException('Text2Image generation using provider "' . $provider->getName() . '" failed: Couldn\'t open file to write.'); + } + } + $this->logger->debug('Calling Text2Image provider\'s generate method'); + if ($provider instanceof IProviderWithUserId) { + $provider->setUserId($task->getUserId()); + } + $provider->generate($task->getInput(), $resources); + for ($i = 0; $i < $task->getNumberOfImages(); $i++) { + if (is_resource($resources[$i])) { + // If $resource hasn't been closed yet, we'll do that here + fclose($resources[$i]); + } + } + $task->setStatus(Task::STATUS_SUCCESSFUL); + $this->logger->debug('Updating Text2Image task in DB'); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + return; + } catch (\RuntimeException|\Throwable $e) { + for ($i = 0; $i < $task->getNumberOfImages(); $i++) { + if (isset($files, $files[$i])) { + try { + $files[$i]->delete(); + } catch (NotPermittedException $e) { + $this->logger->warning('Failed to clean up Text2Image result file after error', ['exception' => $e]); + } + } + } + + $this->logger->info('Text2Image generation using provider "' . $provider->getName() . '" failed', ['exception' => $e]); + $task->setStatus(Task::STATUS_FAILED); + try { + $this->taskMapper->update(DbTask::fromPublicTask($task)); + } catch (Exception $e) { + $this->logger->warning('Failed to update database after Text2Image error', ['exception' => $e]); + } + throw new TaskFailureException('Text2Image generation using provider "' . $provider->getName() . '" failed: ' . $e->getMessage(), 0, $e); + } + } + + $task->setStatus(Task::STATUS_FAILED); + try { + $this->taskMapper->update(DbTask::fromPublicTask($task)); + } catch (Exception $e) { + $this->logger->warning('Failed to update database after Text2Image error', ['exception' => $e]); + } + throw new TaskFailureException('Could not run task'); + } + + /** + * @inheritDoc + */ + public function scheduleTask(Task $task): void { + if (!$this->hasProviders()) { + throw new PreConditionNotMetException('No text to image provider is installed that can handle this task'); + } + $this->logger->debug('Scheduling Text2Image Task'); + $task->setStatus(Task::STATUS_SCHEDULED); + $completionExpectedAt = new \DateTime('now'); + $completionExpectedAt->add(new \DateInterval('PT' . $this->getPreferredProviders()[0]->getExpectedRuntime() . 'S')); + $task->setCompletionExpectedAt($completionExpectedAt); + $taskEntity = DbTask::fromPublicTask($task); + $this->taskMapper->insert($taskEntity); + $task->setId($taskEntity->getId()); + $this->jobList->add(TaskBackgroundJob::class, [ + 'taskId' => $task->getId() + ]); + } + + /** + * @inheritDoc + */ + public function runOrScheduleTask(Task $task) : void { + if (!$this->hasProviders()) { + throw new PreConditionNotMetException('No text to image provider is installed that can handle this task'); + } + $providers = $this->getPreferredProviders(); + $maxExecutionTime = (int)ini_get('max_execution_time'); + // Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time + if ($providers[0]->getExpectedRuntime() > $maxExecutionTime * 0.8) { + $this->scheduleTask($task); + return; + } + $this->runTask($task); + } + + /** + * @inheritDoc + */ + public function deleteTask(Task $task): void { + $taskEntity = DbTask::fromPublicTask($task); + $this->taskMapper->delete($taskEntity); + $this->jobList->remove(TaskBackgroundJob::class, [ + 'taskId' => $task->getId() + ]); + } + + /** + * Get a task from its id + * + * @param int $id The id of the task + * @return Task + * @throws RuntimeException If the query failed + * @throws TaskNotFoundException If the task could not be found + */ + public function getTask(int $id): Task { + try { + $taskEntity = $this->taskMapper->find($id); + return $taskEntity->toPublicTask(); + } catch (DoesNotExistException $e) { + throw new TaskNotFoundException('Could not find task with the provided id'); + } catch (MultipleObjectsReturnedException $e) { + throw new RuntimeException('Could not uniquely identify task with given id', 0, $e); + } catch (Exception $e) { + throw new RuntimeException('Failure while trying to find task by id: ' . $e->getMessage(), 0, $e); + } + } + + /** + * Get a task from its user id and task id + * If userId is null, this can only get a task that was scheduled anonymously + * + * @param int $id The id of the task + * @param string|null $userId The user id that scheduled the task + * @return Task + * @throws RuntimeException If the query failed + * @throws TaskNotFoundException If the task could not be found + */ + public function getUserTask(int $id, ?string $userId): Task { + try { + $taskEntity = $this->taskMapper->findByIdAndUser($id, $userId); + return $taskEntity->toPublicTask(); + } catch (DoesNotExistException $e) { + throw new TaskNotFoundException('Could not find task with the provided id and user id'); + } catch (MultipleObjectsReturnedException $e) { + throw new RuntimeException('Could not uniquely identify task with given id and user id', 0, $e); + } catch (Exception $e) { + throw new RuntimeException('Failure while trying to find task by id and user id: ' . $e->getMessage(), 0, $e); + } + } + + /** + * Get a list of tasks scheduled by a specific user for a specific app + * and optionally with a specific identifier. + * This cannot be used to get anonymously scheduled tasks + * + * @param string $userId + * @param string $appId + * @param string|null $identifier + * @return Task[] + * @throws RuntimeException + */ + public function getUserTasksByApp(?string $userId, string $appId, ?string $identifier = null): array { + try { + $taskEntities = $this->taskMapper->findUserTasksByApp($userId, $appId, $identifier); + return array_map(static function (DbTask $taskEntity) { + return $taskEntity->toPublicTask(); + }, $taskEntities); + } catch (Exception $e) { + throw new RuntimeException('Failure while trying to find tasks by appId and identifier: ' . $e->getMessage(), 0, $e); + } + } + + /** + * @return list<IProvider> + */ + private function getPreferredProviders() { + $providers = $this->getProviders(); + $json = $this->config->getAppValue('core', 'ai.text2image_provider', ''); + if ($json !== '') { + try { + $id = json_decode($json, true, 512, JSON_THROW_ON_ERROR); + $provider = current(array_filter($providers, fn ($provider) => $provider->getId() === $id)); + if ($provider !== false && $provider !== null) { + $providers = [$provider]; + } + } catch (\JsonException $e) { + $this->logger->warning('Failed to decode Text2Image setting `ai.text2image_provider`', ['exception' => $e]); + } + } + + return $providers; + } +} |