diff options
Diffstat (limited to 'lib/private/TextProcessing/Manager.php')
-rw-r--r-- | lib/private/TextProcessing/Manager.php | 116 |
1 files changed, 90 insertions, 26 deletions
diff --git a/lib/private/TextProcessing/Manager.php b/lib/private/TextProcessing/Manager.php index 34f0b4e7964..3fe45ce55ec 100644 --- a/lib/private/TextProcessing/Manager.php +++ b/lib/private/TextProcessing/Manager.php @@ -3,24 +3,8 @@ declare(strict_types=1); /** - * @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net> - * - * @author Marcel Klehr <mklehr@gmx.net> - * - * @license GNU AGPL version 3 or any later version - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see <http://www.gnu.org/licenses/>. + * SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors + * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OC\TextProcessing; @@ -36,13 +20,22 @@ use OCP\DB\Exception; use OCP\IConfig; use OCP\IServerContainer; use OCP\PreConditionNotMetException; +use OCP\TaskProcessing\IManager as TaskProcessingIManager; +use OCP\TaskProcessing\TaskTypes\TextToText; +use OCP\TaskProcessing\TaskTypes\TextToTextHeadline; +use OCP\TaskProcessing\TaskTypes\TextToTextSummary; +use OCP\TaskProcessing\TaskTypes\TextToTextTopics; use OCP\TextProcessing\Exception\TaskFailureException; +use OCP\TextProcessing\FreePromptTaskType; +use OCP\TextProcessing\HeadlineTaskType; use OCP\TextProcessing\IManager; use OCP\TextProcessing\IProvider; use OCP\TextProcessing\IProviderWithExpectedRuntime; use OCP\TextProcessing\IProviderWithId; +use OCP\TextProcessing\SummaryTaskType; use OCP\TextProcessing\Task; use OCP\TextProcessing\Task as OCPTask; +use OCP\TextProcessing\TopicsTaskType; use Psr\Log\LoggerInterface; use RuntimeException; use Throwable; @@ -51,6 +44,13 @@ class Manager implements IManager { /** @var ?IProvider[] */ private ?array $providers = null; + private static array $taskProcessingCompatibleTaskTypes = [ + FreePromptTaskType::class => TextToText::ID, + HeadlineTaskType::class => TextToTextHeadline::ID, + SummaryTaskType::class => TextToTextSummary::ID, + TopicsTaskType::class => TextToTextTopics::ID, + ]; + public function __construct( private IServerContainer $serverContainer, private Coordinator $coordinator, @@ -58,6 +58,7 @@ class Manager implements IManager { private IJobList $jobList, private TaskMapper $taskMapper, private IConfig $config, + private TaskProcessingIManager $taskProcessingManager, ) { } @@ -88,6 +89,14 @@ class Manager implements IManager { } public function hasProviders(): bool { + // check if task processing equivalent types are available + $taskTaskTypes = $this->taskProcessingManager->getAvailableTaskTypes(); + foreach (self::$taskProcessingCompatibleTaskTypes as $textTaskTypeClass => $taskTaskTypeId) { + if (isset($taskTaskTypes[$taskTaskTypeId])) { + return true; + } + } + $context = $this->coordinator->getRegistrationContext(); if ($context === null) { return false; @@ -103,6 +112,15 @@ class Manager implements IManager { foreach ($this->getProviders() as $provider) { $tasks[$provider->getTaskType()] = true; } + + // check if task processing equivalent types are available + $taskTaskTypes = $this->taskProcessingManager->getAvailableTaskTypes(); + foreach (self::$taskProcessingCompatibleTaskTypes as $textTaskTypeClass => $taskTaskTypeId) { + if (isset($taskTaskTypes[$taskTaskTypeId])) { + $tasks[$textTaskTypeClass] = true; + } + } + return array_keys($tasks); } @@ -114,6 +132,49 @@ class Manager implements IManager { * @inheritDoc */ public function runTask(OCPTask $task): string { + // try to run a task processing task if possible + $taskTypeClass = $task->getType(); + if (isset(self::$taskProcessingCompatibleTaskTypes[$taskTypeClass]) && isset($this->taskProcessingManager->getAvailableTaskTypes()[self::$taskProcessingCompatibleTaskTypes[$taskTypeClass]])) { + try { + $taskProcessingTaskTypeId = self::$taskProcessingCompatibleTaskTypes[$taskTypeClass]; + $taskProcessingTask = new \OCP\TaskProcessing\Task( + $taskProcessingTaskTypeId, + ['input' => $task->getInput()], + $task->getAppId(), + $task->getUserId(), + $task->getIdentifier(), + ); + + $task->setStatus(OCPTask::STATUS_RUNNING); + if ($task->getId() === null) { + $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task)); + $task->setId($taskEntity->getId()); + } else { + $this->taskMapper->update(DbTask::fromPublicTask($task)); + } + $this->logger->debug('Running a TextProcessing (' . $taskTypeClass . ') task with TaskProcessing'); + $taskProcessingResultTask = $this->taskProcessingManager->runTask($taskProcessingTask); + if ($taskProcessingResultTask->getStatus() === \OCP\TaskProcessing\Task::STATUS_SUCCESSFUL) { + $output = $taskProcessingResultTask->getOutput(); + if (isset($output['output']) && is_string($output['output'])) { + $task->setOutput($output['output']); + $task->setStatus(OCPTask::STATUS_SUCCESSFUL); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + return $output['output']; + } + } + } catch (\Throwable $e) { + $this->logger->error('TextProcessing to TaskProcessing failed', ['exception' => $e]); + $task->setStatus(OCPTask::STATUS_FAILED); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + throw new TaskFailureException('TextProcessing to TaskProcessing failed: ' . $e->getMessage(), 0, $e); + } + $task->setStatus(OCPTask::STATUS_FAILED); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + throw new TaskFailureException('Could not run task'); + } + + // try to run the text processing task if (!$this->canHandleTask($task)) { throw new PreConditionNotMetException('No text processing provider is installed that can handle this task'); } @@ -124,7 +185,7 @@ class Manager implements IManager { $task->setStatus(OCPTask::STATUS_RUNNING); if ($provider instanceof IProviderWithExpectedRuntime) { $completionExpectedAt = new \DateTime('now'); - $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S')); + $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S')); $task->setCompletionExpectedAt($completionExpectedAt); } if ($task->getId() === null) { @@ -139,7 +200,7 @@ class Manager implements IManager { $this->taskMapper->update(DbTask::fromPublicTask($task)); return $output; } catch (\Throwable $e) { - $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]); + $this->logger->error('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]); $task->setStatus(OCPTask::STATUS_FAILED); $this->taskMapper->update(DbTask::fromPublicTask($task)); throw new TaskFailureException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage(), 0, $e); @@ -160,13 +221,17 @@ class Manager implements IManager { } $task->setStatus(OCPTask::STATUS_SCHEDULED); $providers = $this->getPreferredProviders($task); - if (count($providers) === 0) { + $equivalentTaskProcessingTypeAvailable = ( + isset(self::$taskProcessingCompatibleTaskTypes[$task->getType()]) + && isset($this->taskProcessingManager->getAvailableTaskTypes()[self::$taskProcessingCompatibleTaskTypes[$task->getType()]]) + ); + if (count($providers) === 0 && !$equivalentTaskProcessingTypeAvailable) { throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); } [$provider,] = $providers; if ($provider instanceof IProviderWithExpectedRuntime) { $completionExpectedAt = new \DateTime('now'); - $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S')); + $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S')); $task->setCompletionExpectedAt($completionExpectedAt); } $taskEntity = DbTask::fromPublicTask($task); @@ -185,7 +250,7 @@ class Manager implements IManager { throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); } [$provider,] = $this->getPreferredProviders($task); - $maxExecutionTime = (int) ini_get('max_execution_time'); + $maxExecutionTime = (int)ini_get('max_execution_time'); // Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time // or if the provider doesn't provide a getExpectedRuntime() method if (!$provider instanceof IProviderWithExpectedRuntime || $provider->getExpectedRuntime() > $maxExecutionTime * 0.8) { @@ -287,7 +352,7 @@ class Manager implements IManager { if ($provider instanceof IProviderWithId) { return $provider->getId() === $preferences[$task->getType()]; } - $provider::class === $preferences[$task->getType()]; + return $provider::class === $preferences[$task->getType()]; }))); if ($provider !== false) { $providers = array_filter($providers, fn ($p) => $p !== $provider); @@ -295,7 +360,6 @@ class Manager implements IManager { } } } - $providers = array_values(array_filter($providers, fn (IProvider $provider) => $task->canUseProvider($provider))); - return $providers; + return array_values(array_filter($providers, fn (IProvider $provider) => $task->canUseProvider($provider))); } } |