diff options
-rw-r--r-- | lib/private/TaskProcessing/Manager.php | 31 | ||||
-rw-r--r-- | lib/private/TextProcessing/Manager.php | 56 |
2 files changed, 84 insertions, 3 deletions
diff --git a/lib/private/TaskProcessing/Manager.php b/lib/private/TaskProcessing/Manager.php index e9db978034e..873e0720f08 100644 --- a/lib/private/TaskProcessing/Manager.php +++ b/lib/private/TaskProcessing/Manager.php @@ -87,7 +87,6 @@ class Manager implements IManager { private IEventDispatcher $dispatcher, IAppDataFactory $appDataFactory, private IRootFolder $rootFolder, - private \OCP\TextProcessing\IManager $textProcessingManager, private \OCP\TextToImage\IManager $textToImageManager, private \OCP\SpeechToText\ISpeechToTextManager $speechToTextManager, private IUserMountCache $userMountCache, @@ -98,8 +97,34 @@ class Manager implements IManager { } + /** + * This is almost a copy of textProcessingManager->getProviders + * to avoid a dependency cycle between TextProcessingManager and TaskProcessingManager + */ + private function _getRawTextProcessingProviders(): array { + $context = $this->coordinator->getRegistrationContext(); + if ($context === null) { + return []; + } + + $providers = []; + + foreach ($context->getTextProcessingProviders() as $providerServiceRegistration) { + $class = $providerServiceRegistration->getService(); + try { + $providers[$class] = $this->serverContainer->get($class); + } catch (\Throwable $e) { + $this->logger->error('Failed to load Text processing provider ' . $class, [ + 'exception' => $e, + ]); + } + } + + return $providers; + } + private function _getTextProcessingProviders(): array { - $oldProviders = $this->textProcessingManager->getProviders(); + $oldProviders = $this->_getRawTextProcessingProviders(); $newProviders = []; foreach ($oldProviders as $oldProvider) { $provider = new class($oldProvider) implements IProvider, ISynchronousProvider { @@ -190,7 +215,7 @@ class Manager implements IManager { * @return ITaskType[] */ private function _getTextProcessingTaskTypes(): array { - $oldProviders = $this->textProcessingManager->getProviders(); + $oldProviders = $this->_getRawTextProcessingProviders(); $newTaskTypes = []; foreach ($oldProviders as $oldProvider) { // These are already implemented in the TaskProcessing realm diff --git a/lib/private/TextProcessing/Manager.php b/lib/private/TextProcessing/Manager.php index 4cc5ff77527..c73a6a9ee37 100644 --- a/lib/private/TextProcessing/Manager.php +++ b/lib/private/TextProcessing/Manager.php @@ -20,13 +20,22 @@ use OCP\DB\Exception; use OCP\IConfig; use OCP\IServerContainer; use OCP\PreConditionNotMetException; +use OCP\TaskProcessing\IManager as TaskProcessingIManager; +use OCP\TaskProcessing\TaskTypes\TextToText; +use OCP\TaskProcessing\TaskTypes\TextToTextHeadline; +use OCP\TaskProcessing\TaskTypes\TextToTextSummary; +use OCP\TaskProcessing\TaskTypes\TextToTextTopics; use OCP\TextProcessing\Exception\TaskFailureException; +use OCP\TextProcessing\FreePromptTaskType; +use OCP\TextProcessing\HeadlineTaskType; use OCP\TextProcessing\IManager; use OCP\TextProcessing\IProvider; use OCP\TextProcessing\IProviderWithExpectedRuntime; use OCP\TextProcessing\IProviderWithId; +use OCP\TextProcessing\SummaryTaskType; use OCP\TextProcessing\Task; use OCP\TextProcessing\Task as OCPTask; +use OCP\TextProcessing\TopicsTaskType; use Psr\Log\LoggerInterface; use RuntimeException; use Throwable; @@ -42,6 +51,7 @@ class Manager implements IManager { private IJobList $jobList, private TaskMapper $taskMapper, private IConfig $config, + private TaskProcessingIManager $taskProcessingManager, ) { } @@ -98,6 +108,52 @@ class Manager implements IManager { * @inheritDoc */ public function runTask(OCPTask $task): string { + // try to run a task processing task if possible + $taskTypeClass = $task->getType(); + $taskProcessingCompatibleTaskTypes = [ + FreePromptTaskType::class => TextToText::ID, + HeadlineTaskType::class => TextToTextHeadline::ID, + SummaryTaskType::class => TextToTextSummary::ID, + TopicsTaskType::class => TextToTextTopics::ID, + ]; + if (isset($taskProcessingCompatibleTaskTypes[$taskTypeClass])) { + try { + $taskProcessingTaskTypeId = $taskProcessingCompatibleTaskTypes[$taskTypeClass]; + $taskProcessingTask = new \OCP\TaskProcessing\Task( + $taskProcessingTaskTypeId, + ['input' => $task->getInput()], + $task->getAppId(), + $task->getUserId(), + $task->getIdentifier(), + ); + + $task->setStatus(OCPTask::STATUS_RUNNING); + if ($task->getId() === null) { + $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task)); + $task->setId($taskEntity->getId()); + } else { + $this->taskMapper->update(DbTask::fromPublicTask($task)); + } + $this->logger->debug('Running a TextProcessing (' . $taskTypeClass . ') task with TaskProcessing'); + $taskProcessingResultTask = $this->taskProcessingManager->runTask($taskProcessingTask); + if ($taskProcessingResultTask->getStatus() === \OCP\TaskProcessing\Task::STATUS_SUCCESSFUL) { + $task->setOutput($taskProcessingResultTask->getOutput()['output'] ?? ''); + $task->setStatus(OCPTask::STATUS_SUCCESSFUL); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + return $task->getOutput(); + } + } catch (\Throwable $e) { + $this->logger->error('TextProcessing to TaskProcessing failed', ['exception' => $e]); + $task->setStatus(OCPTask::STATUS_FAILED); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + throw new TaskFailureException('TextProcessing to TaskProcessing failed: ' . $e->getMessage(), 0, $e); + } + $task->setStatus(OCPTask::STATUS_FAILED); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + throw new TaskFailureException('Could not run task'); + } + + // try to run the text processing task if (!$this->canHandleTask($task)) { throw new PreConditionNotMetException('No text processing provider is installed that can handle this task'); } |