diff options
Diffstat (limited to 'lib/private/TextProcessing/Manager.php')
-rw-r--r-- | lib/private/TextProcessing/Manager.php | 107 |
1 files changed, 73 insertions, 34 deletions
diff --git a/lib/private/TextProcessing/Manager.php b/lib/private/TextProcessing/Manager.php index 47af57bf3ec..70f5f322aa5 100644 --- a/lib/private/TextProcessing/Manager.php +++ b/lib/private/TextProcessing/Manager.php @@ -27,20 +27,22 @@ namespace OC\TextProcessing; use OC\AppFramework\Bootstrap\Coordinator; use OC\TextProcessing\Db\Task as DbTask; -use OCP\IConfig; -use OCP\TextProcessing\IProviderWithId; -use OCP\TextProcessing\Task; -use OCP\TextProcessing\Task as OCPTask; use OC\TextProcessing\Db\TaskMapper; use OCP\AppFramework\Db\DoesNotExistException; use OCP\AppFramework\Db\MultipleObjectsReturnedException; use OCP\BackgroundJob\IJobList; use OCP\Common\Exception\NotFoundException; use OCP\DB\Exception; +use OCP\IConfig; use OCP\IServerContainer; +use OCP\PreConditionNotMetException; +use OCP\TextProcessing\Exception\TaskFailureException; use OCP\TextProcessing\IManager; use OCP\TextProcessing\IProvider; -use OCP\PreConditionNotMetException; +use OCP\TextProcessing\IProviderWithExpectedRuntime; +use OCP\TextProcessing\IProviderWithId; +use OCP\TextProcessing\Task; +use OCP\TextProcessing\Task as OCPTask; use Psr\Log\LoggerInterface; use RuntimeException; use Throwable; @@ -115,31 +117,16 @@ class Manager implements IManager { if (!$this->canHandleTask($task)) { throw new PreConditionNotMetException('No text processing provider is installed that can handle this task'); } - $providers = $this->getProviders(); - $json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', ''); - if ($json !== '') { - $preferences = json_decode($json, true); - if (isset($preferences[$task->getType()])) { - // If a preference for this task type is set, move the preferred provider to the start - $provider = current(array_filter($providers, function ($provider) use ($preferences, $task) { - if ($provider instanceof IProviderWithId) { - return $provider->getId() === $preferences[$task->getType()]; - } - return $provider::class === $preferences[$task->getType()]; - })); - if ($provider !== false) { - $providers = array_filter($providers, fn ($p) => $p !== $provider); - array_unshift($providers, $provider); - } - } - } + $providers = $this->getPreferredProviders($task); foreach ($providers as $provider) { - if (!$task->canUseProvider($provider)) { - continue; - } try { $task->setStatus(OCPTask::STATUS_RUNNING); + if ($provider instanceof IProviderWithExpectedRuntime) { + $completionExpectedAt = new \DateTime('now'); + $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S')); + $task->setCompletionExpectedAt($completionExpectedAt); + } if ($task->getId() === null) { $taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task)); $task->setId($taskEntity->getId()); @@ -151,31 +138,37 @@ class Manager implements IManager { $task->setStatus(OCPTask::STATUS_SUCCESSFUL); $this->taskMapper->update(DbTask::fromPublicTask($task)); return $output; - } catch (\RuntimeException $e) { - $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]); - $task->setStatus(OCPTask::STATUS_FAILED); - $this->taskMapper->update(DbTask::fromPublicTask($task)); - throw $e; } catch (\Throwable $e) { $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]); $task->setStatus(OCPTask::STATUS_FAILED); $this->taskMapper->update(DbTask::fromPublicTask($task)); - throw new RuntimeException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage(), 0, $e); + throw new TaskFailureException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage(), 0, $e); } } - throw new RuntimeException('Could not run task'); + $task->setStatus(OCPTask::STATUS_FAILED); + $this->taskMapper->update(DbTask::fromPublicTask($task)); + throw new TaskFailureException('Could not run task'); } /** * @inheritDoc - * @throws Exception */ public function scheduleTask(OCPTask $task): void { if (!$this->canHandleTask($task)) { throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); } $task->setStatus(OCPTask::STATUS_SCHEDULED); + $providers = $this->getPreferredProviders($task); + if (count($providers) === 0) { + throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); + } + [$provider,] = $providers; + if ($provider instanceof IProviderWithExpectedRuntime) { + $completionExpectedAt = new \DateTime('now'); + $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S')); + $task->setCompletionExpectedAt($completionExpectedAt); + } $taskEntity = DbTask::fromPublicTask($task); $this->taskMapper->insert($taskEntity); $task->setId($taskEntity->getId()); @@ -187,6 +180,25 @@ class Manager implements IManager { /** * @inheritDoc */ + public function runOrScheduleTask(OCPTask $task): bool { + if (!$this->canHandleTask($task)) { + throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); + } + [$provider,] = $this->getPreferredProviders($task); + $maxExecutionTime = (int) ini_get('max_execution_time'); + // Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time + // or if the provider doesn't provide a getExpectedRuntime() method + if (!$provider instanceof IProviderWithExpectedRuntime || $provider->getExpectedRuntime() > $maxExecutionTime * 0.8) { + $this->scheduleTask($task); + return false; + } + $this->runTask($task); + return true; + } + + /** + * @inheritDoc + */ public function deleteTask(Task $task): void { $taskEntity = DbTask::fromPublicTask($task); $this->taskMapper->delete($taskEntity); @@ -259,4 +271,31 @@ class Manager implements IManager { throw new RuntimeException('Failure while trying to find tasks by appId and identifier: ' . $e->getMessage(), 0, $e); } } + + /** + * @param OCPTask $task + * @return IProvider[] + */ + public function getPreferredProviders(OCPTask $task): array { + $providers = $this->getProviders(); + $json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', ''); + if ($json !== '') { + $preferences = json_decode($json, true); + if (isset($preferences[$task->getType()])) { + // If a preference for this task type is set, move the preferred provider to the start + $provider = current(array_values(array_filter($providers, function ($provider) use ($preferences, $task) { + if ($provider instanceof IProviderWithId) { + return $provider->getId() === $preferences[$task->getType()]; + } + $provider::class === $preferences[$task->getType()]; + }))); + if ($provider !== false) { + $providers = array_filter($providers, fn ($p) => $p !== $provider); + array_unshift($providers, $provider); + } + } + } + $providers = array_values(array_filter($providers, fn (IProvider $provider) => $task->canUseProvider($provider))); + return $providers; + } } |