aboutsummaryrefslogtreecommitdiffstats
path: root/lib/private/TextToImage/Manager.php
diff options
context:
space:
mode:
Diffstat (limited to 'lib/private/TextToImage/Manager.php')
-rw-r--r--lib/private/TextToImage/Manager.php325
1 files changed, 325 insertions, 0 deletions
diff --git a/lib/private/TextToImage/Manager.php b/lib/private/TextToImage/Manager.php
new file mode 100644
index 00000000000..eec6cc3d241
--- /dev/null
+++ b/lib/private/TextToImage/Manager.php
@@ -0,0 +1,325 @@
+<?php
+
+declare(strict_types=1);
+
+/**
+ * SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-License-Identifier: AGPL-3.0-or-later
+ */
+
+namespace OC\TextToImage;
+
+use OC\AppFramework\Bootstrap\Coordinator;
+use OC\TextToImage\Db\Task as DbTask;
+use OC\TextToImage\Db\TaskMapper;
+use OCP\AppFramework\Db\DoesNotExistException;
+use OCP\AppFramework\Db\MultipleObjectsReturnedException;
+use OCP\BackgroundJob\IJobList;
+use OCP\DB\Exception;
+use OCP\Files\AppData\IAppDataFactory;
+use OCP\Files\IAppData;
+use OCP\Files\NotFoundException;
+use OCP\Files\NotPermittedException;
+use OCP\IConfig;
+use OCP\IServerContainer;
+use OCP\PreConditionNotMetException;
+use OCP\TextToImage\Exception\TaskFailureException;
+use OCP\TextToImage\Exception\TaskNotFoundException;
+use OCP\TextToImage\IManager;
+use OCP\TextToImage\IProvider;
+use OCP\TextToImage\IProviderWithUserId;
+use OCP\TextToImage\Task;
+use Psr\Log\LoggerInterface;
+use RuntimeException;
+use Throwable;
+
+class Manager implements IManager {
+ /** @var ?list<IProvider> */
+ private ?array $providers = null;
+ private IAppData $appData;
+
+ public function __construct(
+ private IServerContainer $serverContainer,
+ private Coordinator $coordinator,
+ private LoggerInterface $logger,
+ private IJobList $jobList,
+ private TaskMapper $taskMapper,
+ private IConfig $config,
+ IAppDataFactory $appDataFactory,
+ ) {
+ $this->appData = $appDataFactory->get('core');
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function getProviders(): array {
+ $context = $this->coordinator->getRegistrationContext();
+ if ($context === null) {
+ return [];
+ }
+
+ if ($this->providers !== null) {
+ return $this->providers;
+ }
+
+ $this->providers = [];
+
+ foreach ($context->getTextToImageProviders() as $providerServiceRegistration) {
+ $class = $providerServiceRegistration->getService();
+ try {
+ /** @var IProvider $provider */
+ $provider = $this->serverContainer->get($class);
+ $this->providers[] = $provider;
+ } catch (Throwable $e) {
+ $this->logger->error('Failed to load Text to image provider ' . $class, [
+ 'exception' => $e,
+ ]);
+ }
+ }
+
+ return $this->providers;
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function hasProviders(): bool {
+ $context = $this->coordinator->getRegistrationContext();
+ if ($context === null) {
+ return false;
+ }
+ return count($context->getTextToImageProviders()) > 0;
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function runTask(Task $task): void {
+ $this->logger->debug('Running TextToImage Task');
+ if (!$this->hasProviders()) {
+ throw new PreConditionNotMetException('No text to image provider is installed that can handle this task');
+ }
+ $providers = $this->getPreferredProviders();
+
+ foreach ($providers as $provider) {
+ $this->logger->debug('Trying to run Text2Image provider ' . $provider::class);
+ try {
+ $task->setStatus(Task::STATUS_RUNNING);
+ $completionExpectedAt = new \DateTime('now');
+ $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S'));
+ $task->setCompletionExpectedAt($completionExpectedAt);
+ if ($task->getId() === null) {
+ $this->logger->debug('Inserting Text2Image task into DB');
+ $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task));
+ $task->setId($taskEntity->getId());
+ } else {
+ $this->logger->debug('Updating Text2Image task in DB');
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ }
+ try {
+ $folder = $this->appData->getFolder('text2image');
+ } catch (NotFoundException) {
+ $this->logger->debug('Creating folder in appdata for Text2Image results');
+ $folder = $this->appData->newFolder('text2image');
+ }
+ try {
+ $folder = $folder->getFolder((string)$task->getId());
+ } catch (NotFoundException) {
+ $this->logger->debug('Creating new folder in appdata Text2Image results folder');
+ $folder = $folder->newFolder((string)$task->getId());
+ }
+ $this->logger->debug('Creating result files for Text2Image task');
+ $resources = [];
+ $files = [];
+ for ($i = 0; $i < $task->getNumberOfImages(); $i++) {
+ $file = $folder->newFile((string)$i);
+ $files[] = $file;
+ $resource = $file->write();
+ if ($resource !== false && $resource !== true && is_resource($resource)) {
+ $resources[] = $resource;
+ } else {
+ throw new RuntimeException('Text2Image generation using provider "' . $provider->getName() . '" failed: Couldn\'t open file to write.');
+ }
+ }
+ $this->logger->debug('Calling Text2Image provider\'s generate method');
+ if ($provider instanceof IProviderWithUserId) {
+ $provider->setUserId($task->getUserId());
+ }
+ $provider->generate($task->getInput(), $resources);
+ for ($i = 0; $i < $task->getNumberOfImages(); $i++) {
+ if (is_resource($resources[$i])) {
+ // If $resource hasn't been closed yet, we'll do that here
+ fclose($resources[$i]);
+ }
+ }
+ $task->setStatus(Task::STATUS_SUCCESSFUL);
+ $this->logger->debug('Updating Text2Image task in DB');
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ return;
+ } catch (\RuntimeException|\Throwable $e) {
+ for ($i = 0; $i < $task->getNumberOfImages(); $i++) {
+ if (isset($files, $files[$i])) {
+ try {
+ $files[$i]->delete();
+ } catch (NotPermittedException $e) {
+ $this->logger->warning('Failed to clean up Text2Image result file after error', ['exception' => $e]);
+ }
+ }
+ }
+
+ $this->logger->info('Text2Image generation using provider "' . $provider->getName() . '" failed', ['exception' => $e]);
+ $task->setStatus(Task::STATUS_FAILED);
+ try {
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ } catch (Exception $e) {
+ $this->logger->warning('Failed to update database after Text2Image error', ['exception' => $e]);
+ }
+ throw new TaskFailureException('Text2Image generation using provider "' . $provider->getName() . '" failed: ' . $e->getMessage(), 0, $e);
+ }
+ }
+
+ $task->setStatus(Task::STATUS_FAILED);
+ try {
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ } catch (Exception $e) {
+ $this->logger->warning('Failed to update database after Text2Image error', ['exception' => $e]);
+ }
+ throw new TaskFailureException('Could not run task');
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function scheduleTask(Task $task): void {
+ if (!$this->hasProviders()) {
+ throw new PreConditionNotMetException('No text to image provider is installed that can handle this task');
+ }
+ $this->logger->debug('Scheduling Text2Image Task');
+ $task->setStatus(Task::STATUS_SCHEDULED);
+ $completionExpectedAt = new \DateTime('now');
+ $completionExpectedAt->add(new \DateInterval('PT' . $this->getPreferredProviders()[0]->getExpectedRuntime() . 'S'));
+ $task->setCompletionExpectedAt($completionExpectedAt);
+ $taskEntity = DbTask::fromPublicTask($task);
+ $this->taskMapper->insert($taskEntity);
+ $task->setId($taskEntity->getId());
+ $this->jobList->add(TaskBackgroundJob::class, [
+ 'taskId' => $task->getId()
+ ]);
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function runOrScheduleTask(Task $task) : void {
+ if (!$this->hasProviders()) {
+ throw new PreConditionNotMetException('No text to image provider is installed that can handle this task');
+ }
+ $providers = $this->getPreferredProviders();
+ $maxExecutionTime = (int)ini_get('max_execution_time');
+ // Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time
+ if ($providers[0]->getExpectedRuntime() > $maxExecutionTime * 0.8) {
+ $this->scheduleTask($task);
+ return;
+ }
+ $this->runTask($task);
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function deleteTask(Task $task): void {
+ $taskEntity = DbTask::fromPublicTask($task);
+ $this->taskMapper->delete($taskEntity);
+ $this->jobList->remove(TaskBackgroundJob::class, [
+ 'taskId' => $task->getId()
+ ]);
+ }
+
+ /**
+ * Get a task from its id
+ *
+ * @param int $id The id of the task
+ * @return Task
+ * @throws RuntimeException If the query failed
+ * @throws TaskNotFoundException If the task could not be found
+ */
+ public function getTask(int $id): Task {
+ try {
+ $taskEntity = $this->taskMapper->find($id);
+ return $taskEntity->toPublicTask();
+ } catch (DoesNotExistException $e) {
+ throw new TaskNotFoundException('Could not find task with the provided id');
+ } catch (MultipleObjectsReturnedException $e) {
+ throw new RuntimeException('Could not uniquely identify task with given id', 0, $e);
+ } catch (Exception $e) {
+ throw new RuntimeException('Failure while trying to find task by id: ' . $e->getMessage(), 0, $e);
+ }
+ }
+
+ /**
+ * Get a task from its user id and task id
+ * If userId is null, this can only get a task that was scheduled anonymously
+ *
+ * @param int $id The id of the task
+ * @param string|null $userId The user id that scheduled the task
+ * @return Task
+ * @throws RuntimeException If the query failed
+ * @throws TaskNotFoundException If the task could not be found
+ */
+ public function getUserTask(int $id, ?string $userId): Task {
+ try {
+ $taskEntity = $this->taskMapper->findByIdAndUser($id, $userId);
+ return $taskEntity->toPublicTask();
+ } catch (DoesNotExistException $e) {
+ throw new TaskNotFoundException('Could not find task with the provided id and user id');
+ } catch (MultipleObjectsReturnedException $e) {
+ throw new RuntimeException('Could not uniquely identify task with given id and user id', 0, $e);
+ } catch (Exception $e) {
+ throw new RuntimeException('Failure while trying to find task by id and user id: ' . $e->getMessage(), 0, $e);
+ }
+ }
+
+ /**
+ * Get a list of tasks scheduled by a specific user for a specific app
+ * and optionally with a specific identifier.
+ * This cannot be used to get anonymously scheduled tasks
+ *
+ * @param string $userId
+ * @param string $appId
+ * @param string|null $identifier
+ * @return Task[]
+ * @throws RuntimeException
+ */
+ public function getUserTasksByApp(?string $userId, string $appId, ?string $identifier = null): array {
+ try {
+ $taskEntities = $this->taskMapper->findUserTasksByApp($userId, $appId, $identifier);
+ return array_map(static function (DbTask $taskEntity) {
+ return $taskEntity->toPublicTask();
+ }, $taskEntities);
+ } catch (Exception $e) {
+ throw new RuntimeException('Failure while trying to find tasks by appId and identifier: ' . $e->getMessage(), 0, $e);
+ }
+ }
+
+ /**
+ * @return list<IProvider>
+ */
+ private function getPreferredProviders() {
+ $providers = $this->getProviders();
+ $json = $this->config->getAppValue('core', 'ai.text2image_provider', '');
+ if ($json !== '') {
+ try {
+ $id = json_decode($json, true, 512, JSON_THROW_ON_ERROR);
+ $provider = current(array_filter($providers, fn ($provider) => $provider->getId() === $id));
+ if ($provider !== false && $provider !== null) {
+ $providers = [$provider];
+ }
+ } catch (\JsonException $e) {
+ $this->logger->warning('Failed to decode Text2Image setting `ai.text2image_provider`', ['exception' => $e]);
+ }
+ }
+
+ return $providers;
+ }
+}