aboutsummaryrefslogtreecommitdiffstats
path: root/lib/private/TextProcessing/Manager.php
diff options
context:
space:
mode:
Diffstat (limited to 'lib/private/TextProcessing/Manager.php')
-rw-r--r--lib/private/TextProcessing/Manager.php116
1 files changed, 90 insertions, 26 deletions
diff --git a/lib/private/TextProcessing/Manager.php b/lib/private/TextProcessing/Manager.php
index 34f0b4e7964..3fe45ce55ec 100644
--- a/lib/private/TextProcessing/Manager.php
+++ b/lib/private/TextProcessing/Manager.php
@@ -3,24 +3,8 @@
declare(strict_types=1);
/**
- * @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net>
- *
- * @author Marcel Klehr <mklehr@gmx.net>
- *
- * @license GNU AGPL version 3 or any later version
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as
- * published by the Free Software Foundation, either version 3 of the
- * License, or (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ * SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-License-Identifier: AGPL-3.0-or-later
*/
namespace OC\TextProcessing;
@@ -36,13 +20,22 @@ use OCP\DB\Exception;
use OCP\IConfig;
use OCP\IServerContainer;
use OCP\PreConditionNotMetException;
+use OCP\TaskProcessing\IManager as TaskProcessingIManager;
+use OCP\TaskProcessing\TaskTypes\TextToText;
+use OCP\TaskProcessing\TaskTypes\TextToTextHeadline;
+use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
+use OCP\TaskProcessing\TaskTypes\TextToTextTopics;
use OCP\TextProcessing\Exception\TaskFailureException;
+use OCP\TextProcessing\FreePromptTaskType;
+use OCP\TextProcessing\HeadlineTaskType;
use OCP\TextProcessing\IManager;
use OCP\TextProcessing\IProvider;
use OCP\TextProcessing\IProviderWithExpectedRuntime;
use OCP\TextProcessing\IProviderWithId;
+use OCP\TextProcessing\SummaryTaskType;
use OCP\TextProcessing\Task;
use OCP\TextProcessing\Task as OCPTask;
+use OCP\TextProcessing\TopicsTaskType;
use Psr\Log\LoggerInterface;
use RuntimeException;
use Throwable;
@@ -51,6 +44,13 @@ class Manager implements IManager {
/** @var ?IProvider[] */
private ?array $providers = null;
+ private static array $taskProcessingCompatibleTaskTypes = [
+ FreePromptTaskType::class => TextToText::ID,
+ HeadlineTaskType::class => TextToTextHeadline::ID,
+ SummaryTaskType::class => TextToTextSummary::ID,
+ TopicsTaskType::class => TextToTextTopics::ID,
+ ];
+
public function __construct(
private IServerContainer $serverContainer,
private Coordinator $coordinator,
@@ -58,6 +58,7 @@ class Manager implements IManager {
private IJobList $jobList,
private TaskMapper $taskMapper,
private IConfig $config,
+ private TaskProcessingIManager $taskProcessingManager,
) {
}
@@ -88,6 +89,14 @@ class Manager implements IManager {
}
public function hasProviders(): bool {
+ // check if task processing equivalent types are available
+ $taskTaskTypes = $this->taskProcessingManager->getAvailableTaskTypes();
+ foreach (self::$taskProcessingCompatibleTaskTypes as $textTaskTypeClass => $taskTaskTypeId) {
+ if (isset($taskTaskTypes[$taskTaskTypeId])) {
+ return true;
+ }
+ }
+
$context = $this->coordinator->getRegistrationContext();
if ($context === null) {
return false;
@@ -103,6 +112,15 @@ class Manager implements IManager {
foreach ($this->getProviders() as $provider) {
$tasks[$provider->getTaskType()] = true;
}
+
+ // check if task processing equivalent types are available
+ $taskTaskTypes = $this->taskProcessingManager->getAvailableTaskTypes();
+ foreach (self::$taskProcessingCompatibleTaskTypes as $textTaskTypeClass => $taskTaskTypeId) {
+ if (isset($taskTaskTypes[$taskTaskTypeId])) {
+ $tasks[$textTaskTypeClass] = true;
+ }
+ }
+
return array_keys($tasks);
}
@@ -114,6 +132,49 @@ class Manager implements IManager {
* @inheritDoc
*/
public function runTask(OCPTask $task): string {
+ // try to run a task processing task if possible
+ $taskTypeClass = $task->getType();
+ if (isset(self::$taskProcessingCompatibleTaskTypes[$taskTypeClass]) && isset($this->taskProcessingManager->getAvailableTaskTypes()[self::$taskProcessingCompatibleTaskTypes[$taskTypeClass]])) {
+ try {
+ $taskProcessingTaskTypeId = self::$taskProcessingCompatibleTaskTypes[$taskTypeClass];
+ $taskProcessingTask = new \OCP\TaskProcessing\Task(
+ $taskProcessingTaskTypeId,
+ ['input' => $task->getInput()],
+ $task->getAppId(),
+ $task->getUserId(),
+ $task->getIdentifier(),
+ );
+
+ $task->setStatus(OCPTask::STATUS_RUNNING);
+ if ($task->getId() === null) {
+ $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task));
+ $task->setId($taskEntity->getId());
+ } else {
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ }
+ $this->logger->debug('Running a TextProcessing (' . $taskTypeClass . ') task with TaskProcessing');
+ $taskProcessingResultTask = $this->taskProcessingManager->runTask($taskProcessingTask);
+ if ($taskProcessingResultTask->getStatus() === \OCP\TaskProcessing\Task::STATUS_SUCCESSFUL) {
+ $output = $taskProcessingResultTask->getOutput();
+ if (isset($output['output']) && is_string($output['output'])) {
+ $task->setOutput($output['output']);
+ $task->setStatus(OCPTask::STATUS_SUCCESSFUL);
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ return $output['output'];
+ }
+ }
+ } catch (\Throwable $e) {
+ $this->logger->error('TextProcessing to TaskProcessing failed', ['exception' => $e]);
+ $task->setStatus(OCPTask::STATUS_FAILED);
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ throw new TaskFailureException('TextProcessing to TaskProcessing failed: ' . $e->getMessage(), 0, $e);
+ }
+ $task->setStatus(OCPTask::STATUS_FAILED);
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ throw new TaskFailureException('Could not run task');
+ }
+
+ // try to run the text processing task
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No text processing provider is installed that can handle this task');
}
@@ -124,7 +185,7 @@ class Manager implements IManager {
$task->setStatus(OCPTask::STATUS_RUNNING);
if ($provider instanceof IProviderWithExpectedRuntime) {
$completionExpectedAt = new \DateTime('now');
- $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
+ $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S'));
$task->setCompletionExpectedAt($completionExpectedAt);
}
if ($task->getId() === null) {
@@ -139,7 +200,7 @@ class Manager implements IManager {
$this->taskMapper->update(DbTask::fromPublicTask($task));
return $output;
} catch (\Throwable $e) {
- $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
+ $this->logger->error('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
$task->setStatus(OCPTask::STATUS_FAILED);
$this->taskMapper->update(DbTask::fromPublicTask($task));
throw new TaskFailureException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage(), 0, $e);
@@ -160,13 +221,17 @@ class Manager implements IManager {
}
$task->setStatus(OCPTask::STATUS_SCHEDULED);
$providers = $this->getPreferredProviders($task);
- if (count($providers) === 0) {
+ $equivalentTaskProcessingTypeAvailable = (
+ isset(self::$taskProcessingCompatibleTaskTypes[$task->getType()])
+ && isset($this->taskProcessingManager->getAvailableTaskTypes()[self::$taskProcessingCompatibleTaskTypes[$task->getType()]])
+ );
+ if (count($providers) === 0 && !$equivalentTaskProcessingTypeAvailable) {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
[$provider,] = $providers;
if ($provider instanceof IProviderWithExpectedRuntime) {
$completionExpectedAt = new \DateTime('now');
- $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
+ $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S'));
$task->setCompletionExpectedAt($completionExpectedAt);
}
$taskEntity = DbTask::fromPublicTask($task);
@@ -185,7 +250,7 @@ class Manager implements IManager {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
[$provider,] = $this->getPreferredProviders($task);
- $maxExecutionTime = (int) ini_get('max_execution_time');
+ $maxExecutionTime = (int)ini_get('max_execution_time');
// Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time
// or if the provider doesn't provide a getExpectedRuntime() method
if (!$provider instanceof IProviderWithExpectedRuntime || $provider->getExpectedRuntime() > $maxExecutionTime * 0.8) {
@@ -287,7 +352,7 @@ class Manager implements IManager {
if ($provider instanceof IProviderWithId) {
return $provider->getId() === $preferences[$task->getType()];
}
- $provider::class === $preferences[$task->getType()];
+ return $provider::class === $preferences[$task->getType()];
})));
if ($provider !== false) {
$providers = array_filter($providers, fn ($p) => $p !== $provider);
@@ -295,7 +360,6 @@ class Manager implements IManager {
}
}
}
- $providers = array_values(array_filter($providers, fn (IProvider $provider) => $task->canUseProvider($provider)));
- return $providers;
+ return array_values(array_filter($providers, fn (IProvider $provider) => $task->canUseProvider($provider)));
}
}