aboutsummaryrefslogtreecommitdiffstats
path: root/lib/private
diff options
context:
space:
mode:
authorJoas Schilling <213943+nickvergessen@users.noreply.github.com>2023-11-13 10:49:14 +0100
committerGitHub <noreply@github.com>2023-11-13 10:49:14 +0100
commit0feb55ee93b579bfcba30410325ca0c28f501779 (patch)
tree57f0752858983b0ba52344ca32c7c9a905b1c126 /lib/private
parent36d370b2163ac2a98dbecf2400eb1300025b79b2 (diff)
parent2031fe17b4f6b48d3b72fee520285dde37488afb (diff)
downloadnextcloud-server-0feb55ee93b579bfcba30410325ca0c28f501779.tar.gz
nextcloud-server-0feb55ee93b579bfcba30410325ca0c28f501779.zip
Merge pull request #41271 from nextcloud/enh/text-processing-iprovider2
enh(TextProcessing): Add two new provider interfaces
Diffstat (limited to 'lib/private')
-rw-r--r--lib/private/TextProcessing/Db/Task.php10
-rw-r--r--lib/private/TextProcessing/Manager.php87
2 files changed, 71 insertions, 26 deletions
diff --git a/lib/private/TextProcessing/Db/Task.php b/lib/private/TextProcessing/Db/Task.php
index 9c6f16d11ae..5f362d429f3 100644
--- a/lib/private/TextProcessing/Db/Task.php
+++ b/lib/private/TextProcessing/Db/Task.php
@@ -45,6 +45,8 @@ use OCP\TextProcessing\Task as OCPTask;
* @method string getAppId()
* @method setIdentifier(string $identifier)
* @method string getIdentifier()
+ * @method setCompletionExpectedAt(null|\DateTime $completionExpectedAt)
+ * @method null|\DateTime getCompletionExpectedAt()
*/
class Task extends Entity {
protected $lastUpdated;
@@ -55,16 +57,17 @@ class Task extends Entity {
protected $userId;
protected $appId;
protected $identifier;
+ protected $completionExpectedAt;
/**
* @var string[]
*/
- public static array $columns = ['id', 'last_updated', 'type', 'input', 'output', 'status', 'user_id', 'app_id', 'identifier'];
+ public static array $columns = ['id', 'last_updated', 'type', 'input', 'output', 'status', 'user_id', 'app_id', 'identifier', 'completion_expected_at'];
/**
* @var string[]
*/
- public static array $fields = ['id', 'lastUpdated', 'type', 'input', 'output', 'status', 'userId', 'appId', 'identifier'];
+ public static array $fields = ['id', 'lastUpdated', 'type', 'input', 'output', 'status', 'userId', 'appId', 'identifier', 'completionExpectedAt'];
public function __construct() {
@@ -78,6 +81,7 @@ class Task extends Entity {
$this->addType('userId', 'string');
$this->addType('appId', 'string');
$this->addType('identifier', 'string');
+ $this->addType('completionExpectedAt', 'datetime');
}
public function toRow(): array {
@@ -98,6 +102,7 @@ class Task extends Entity {
'userId' => $task->getUserId(),
'appId' => $task->getAppId(),
'identifier' => $task->getIdentifier(),
+ 'completionExpectedAt' => $task->getCompletionExpectedAt(),
]);
return $task;
}
@@ -107,6 +112,7 @@ class Task extends Entity {
$task->setId($this->getId());
$task->setStatus($this->getStatus());
$task->setOutput($this->getOutput());
+ $task->setCompletionExpectedAt($this->getCompletionExpectedAt());
return $task;
}
}
diff --git a/lib/private/TextProcessing/Manager.php b/lib/private/TextProcessing/Manager.php
index b9cb06c298e..c98ff893143 100644
--- a/lib/private/TextProcessing/Manager.php
+++ b/lib/private/TextProcessing/Manager.php
@@ -28,6 +28,8 @@ namespace OC\TextProcessing;
use OC\AppFramework\Bootstrap\Coordinator;
use OC\TextProcessing\Db\Task as DbTask;
use OCP\IConfig;
+use OCP\TextProcessing\Exception\TaskFailureException;
+use OCP\TextProcessing\IProviderWithExpectedRuntime;
use OCP\TextProcessing\Task;
use OCP\TextProcessing\Task as OCPTask;
use OC\TextProcessing\Db\TaskMapper;
@@ -114,26 +116,16 @@ class Manager implements IManager {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No text processing provider is installed that can handle this task');
}
- $providers = $this->getProviders();
- $json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', '');
- if ($json !== '') {
- $preferences = json_decode($json, true);
- if (isset($preferences[$task->getType()])) {
- // If a preference for this task type is set, move the preferred provider to the start
- $provider = current(array_filter($providers, fn ($provider) => $provider::class === $preferences[$task->getType()]));
- if ($provider !== false) {
- $providers = array_filter($providers, fn ($p) => $p !== $provider);
- array_unshift($providers, $provider);
- }
- }
- }
+ $providers = $this->getPreferredProviders($task);
foreach ($providers as $provider) {
- if (!$task->canUseProvider($provider)) {
- continue;
- }
try {
$task->setStatus(OCPTask::STATUS_RUNNING);
+ if ($provider instanceof IProviderWithExpectedRuntime) {
+ $completionExpectedAt = new \DateTime('now');
+ $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
+ $task->setCompletionExpectedAt($completionExpectedAt);
+ }
if ($task->getId() === null) {
$taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task));
$task->setId($taskEntity->getId());
@@ -145,31 +137,37 @@ class Manager implements IManager {
$task->setStatus(OCPTask::STATUS_SUCCESSFUL);
$this->taskMapper->update(DbTask::fromPublicTask($task));
return $output;
- } catch (\RuntimeException $e) {
- $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
- $task->setStatus(OCPTask::STATUS_FAILED);
- $this->taskMapper->update(DbTask::fromPublicTask($task));
- throw $e;
} catch (\Throwable $e) {
$this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
$task->setStatus(OCPTask::STATUS_FAILED);
$this->taskMapper->update(DbTask::fromPublicTask($task));
- throw new RuntimeException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage(), 0, $e);
+ throw new TaskFailureException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage(), 0, $e);
}
}
- throw new RuntimeException('Could not run task');
+ $task->setStatus(OCPTask::STATUS_FAILED);
+ $this->taskMapper->update(DbTask::fromPublicTask($task));
+ throw new TaskFailureException('Could not run task');
}
/**
* @inheritDoc
- * @throws Exception
*/
public function scheduleTask(OCPTask $task): void {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
$task->setStatus(OCPTask::STATUS_SCHEDULED);
+ $providers = $this->getPreferredProviders($task);
+ if (count($providers) === 0) {
+ throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
+ }
+ [$provider,] = $providers;
+ if ($provider instanceof IProviderWithExpectedRuntime) {
+ $completionExpectedAt = new \DateTime('now');
+ $completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
+ $task->setCompletionExpectedAt($completionExpectedAt);
+ }
$taskEntity = DbTask::fromPublicTask($task);
$this->taskMapper->insert($taskEntity);
$task->setId($taskEntity->getId());
@@ -181,6 +179,25 @@ class Manager implements IManager {
/**
* @inheritDoc
*/
+ public function runOrScheduleTask(OCPTask $task): bool {
+ if (!$this->canHandleTask($task)) {
+ throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
+ }
+ [$provider,] = $this->getPreferredProviders($task);
+ $maxExecutionTime = (int) ini_get('max_execution_time');
+ // Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time
+ // or if the provider doesn't provide a getExpectedRuntime() method
+ if (!$provider instanceof IProviderWithExpectedRuntime || $provider->getExpectedRuntime() > $maxExecutionTime * 0.8) {
+ $this->scheduleTask($task);
+ return false;
+ }
+ $this->runTask($task);
+ return true;
+ }
+
+ /**
+ * @inheritDoc
+ */
public function deleteTask(Task $task): void {
$taskEntity = DbTask::fromPublicTask($task);
$this->taskMapper->delete($taskEntity);
@@ -253,4 +270,26 @@ class Manager implements IManager {
throw new RuntimeException('Failure while trying to find tasks by appId and identifier: ' . $e->getMessage(), 0, $e);
}
}
+
+ /**
+ * @param OCPTask $task
+ * @return IProvider[]
+ */
+ public function getPreferredProviders(OCPTask $task): array {
+ $providers = $this->getProviders();
+ $json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', '');
+ if ($json !== '') {
+ $preferences = json_decode($json, true);
+ if (isset($preferences[$task->getType()])) {
+ // If a preference for this task type is set, move the preferred provider to the start
+ $provider = current(array_values(array_filter($providers, fn ($provider) => $provider::class === $preferences[$task->getType()])));
+ if ($provider !== false) {
+ $providers = array_filter($providers, fn ($p) => $p !== $provider);
+ array_unshift($providers, $provider);
+ }
+ }
+ }
+ $providers = array_values(array_filter($providers, fn (IProvider $provider) => $task->canUseProvider($provider)));
+ return $providers;
+ }
}