]> source.dussan.org Git - nextcloud-server.git/commitdiff
feat(textprocessing): TextProcessingManager::runTask calls TaskProcessingManager...
authorJulien Veyssier <julien-nc@posteo.net>
Wed, 28 Aug 2024 09:50:23 +0000 (11:50 +0200)
committerJulien Veyssier <julien-nc@posteo.net>
Fri, 30 Aug 2024 08:07:01 +0000 (10:07 +0200)
Signed-off-by: Julien Veyssier <julien-nc@posteo.net>
lib/private/TaskProcessing/Manager.php
lib/private/TextProcessing/Manager.php

index e9db978034ec2eb475100505780a8c214a7ae444..873e0720f08d7ed9e72a23633889ee8f735336d8 100644 (file)
@@ -87,7 +87,6 @@ class Manager implements IManager {
                private IEventDispatcher $dispatcher,
                IAppDataFactory $appDataFactory,
                private IRootFolder $rootFolder,
-               private \OCP\TextProcessing\IManager $textProcessingManager,
                private \OCP\TextToImage\IManager $textToImageManager,
                private \OCP\SpeechToText\ISpeechToTextManager $speechToTextManager,
                private IUserMountCache $userMountCache,
@@ -98,8 +97,34 @@ class Manager implements IManager {
        }
 
 
+       /**
+        * This is almost a copy of textProcessingManager->getProviders
+        * to avoid a dependency cycle between TextProcessingManager and TaskProcessingManager
+        */
+       private function _getRawTextProcessingProviders(): array {
+               $context = $this->coordinator->getRegistrationContext();
+               if ($context === null) {
+                       return [];
+               }
+
+               $providers = [];
+
+               foreach ($context->getTextProcessingProviders() as $providerServiceRegistration) {
+                       $class = $providerServiceRegistration->getService();
+                       try {
+                               $providers[$class] = $this->serverContainer->get($class);
+                       } catch (\Throwable $e) {
+                               $this->logger->error('Failed to load Text processing provider ' . $class, [
+                                       'exception' => $e,
+                               ]);
+                       }
+               }
+
+               return $providers;
+       }
+
        private function _getTextProcessingProviders(): array {
-               $oldProviders = $this->textProcessingManager->getProviders();
+               $oldProviders = $this->_getRawTextProcessingProviders();
                $newProviders = [];
                foreach ($oldProviders as $oldProvider) {
                        $provider = new class($oldProvider) implements IProvider, ISynchronousProvider {
@@ -190,7 +215,7 @@ class Manager implements IManager {
         * @return ITaskType[]
         */
        private function _getTextProcessingTaskTypes(): array {
-               $oldProviders = $this->textProcessingManager->getProviders();
+               $oldProviders = $this->_getRawTextProcessingProviders();
                $newTaskTypes = [];
                foreach ($oldProviders as $oldProvider) {
                        // These are already implemented in the TaskProcessing realm
index 4cc5ff7752710c06b3a06fbe746fdd46918de98e..c73a6a9ee37c21c81f23cd026f1be31dd3d7d14c 100644 (file)
@@ -20,13 +20,22 @@ use OCP\DB\Exception;
 use OCP\IConfig;
 use OCP\IServerContainer;
 use OCP\PreConditionNotMetException;
+use OCP\TaskProcessing\IManager as TaskProcessingIManager;
+use OCP\TaskProcessing\TaskTypes\TextToText;
+use OCP\TaskProcessing\TaskTypes\TextToTextHeadline;
+use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
+use OCP\TaskProcessing\TaskTypes\TextToTextTopics;
 use OCP\TextProcessing\Exception\TaskFailureException;
+use OCP\TextProcessing\FreePromptTaskType;
+use OCP\TextProcessing\HeadlineTaskType;
 use OCP\TextProcessing\IManager;
 use OCP\TextProcessing\IProvider;
 use OCP\TextProcessing\IProviderWithExpectedRuntime;
 use OCP\TextProcessing\IProviderWithId;
+use OCP\TextProcessing\SummaryTaskType;
 use OCP\TextProcessing\Task;
 use OCP\TextProcessing\Task as OCPTask;
+use OCP\TextProcessing\TopicsTaskType;
 use Psr\Log\LoggerInterface;
 use RuntimeException;
 use Throwable;
@@ -42,6 +51,7 @@ class Manager implements IManager {
                private IJobList $jobList,
                private TaskMapper $taskMapper,
                private IConfig $config,
+               private TaskProcessingIManager $taskProcessingManager,
        ) {
        }
 
@@ -98,6 +108,52 @@ class Manager implements IManager {
         * @inheritDoc
         */
        public function runTask(OCPTask $task): string {
+               // try to run a task processing task if possible
+               $taskTypeClass = $task->getType();
+               $taskProcessingCompatibleTaskTypes = [
+                       FreePromptTaskType::class => TextToText::ID,
+                       HeadlineTaskType::class => TextToTextHeadline::ID,
+                       SummaryTaskType::class => TextToTextSummary::ID,
+                       TopicsTaskType::class => TextToTextTopics::ID,
+               ];
+               if (isset($taskProcessingCompatibleTaskTypes[$taskTypeClass])) {
+                       try {
+                               $taskProcessingTaskTypeId = $taskProcessingCompatibleTaskTypes[$taskTypeClass];
+                               $taskProcessingTask = new \OCP\TaskProcessing\Task(
+                                       $taskProcessingTaskTypeId,
+                                       ['input' => $task->getInput()],
+                                       $task->getAppId(),
+                                       $task->getUserId(),
+                                       $task->getIdentifier(),
+                               );
+
+                               $task->setStatus(OCPTask::STATUS_RUNNING);
+                               if ($task->getId() === null) {
+                                       $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task));
+                                       $task->setId($taskEntity->getId());
+                               } else {
+                                       $this->taskMapper->update(DbTask::fromPublicTask($task));
+                               }
+                               $this->logger->debug('Running a TextProcessing (' . $taskTypeClass . ') task with TaskProcessing');
+                               $taskProcessingResultTask = $this->taskProcessingManager->runTask($taskProcessingTask);
+                               if ($taskProcessingResultTask->getStatus() === \OCP\TaskProcessing\Task::STATUS_SUCCESSFUL) {
+                                       $task->setOutput($taskProcessingResultTask->getOutput()['output'] ?? '');
+                                       $task->setStatus(OCPTask::STATUS_SUCCESSFUL);
+                                       $this->taskMapper->update(DbTask::fromPublicTask($task));
+                                       return $task->getOutput();
+                               }
+                       } catch (\Throwable $e) {
+                               $this->logger->error('TextProcessing to TaskProcessing failed', ['exception' => $e]);
+                               $task->setStatus(OCPTask::STATUS_FAILED);
+                               $this->taskMapper->update(DbTask::fromPublicTask($task));
+                               throw new TaskFailureException('TextProcessing to TaskProcessing failed: ' . $e->getMessage(), 0, $e);
+                       }
+                       $task->setStatus(OCPTask::STATUS_FAILED);
+                       $this->taskMapper->update(DbTask::fromPublicTask($task));
+                       throw new TaskFailureException('Could not run task');
+               }
+
+               // try to run the text processing task
                if (!$this->canHandleTask($task)) {
                        throw new PreConditionNotMetException('No text processing provider is installed that can handle this task');
                }