- allow providers to obtain current task's userId - allow providers to expose average task runtime Signed-off-by: Marcel Klehr <mklehr@gmx.net>tags/v28.0.0beta3
@@ -35,6 +35,7 @@ use OCP\AppFramework\Http\Attribute\PublicPage; | |||
use OCP\AppFramework\Http\Attribute\UserRateLimit; | |||
use OCP\AppFramework\Http\DataResponse; | |||
use OCP\Common\Exception\NotFoundException; | |||
use OCP\DB\Exception; | |||
use OCP\IL10N; | |||
use OCP\IRequest; | |||
use OCP\TextProcessing\ITaskType; | |||
@@ -102,7 +103,7 @@ class TextProcessingApiController extends \OCP\AppFramework\OCSController { | |||
* @param string $appId ID of the app that will execute the task | |||
* @param string $identifier An arbitrary identifier for the task | |||
* | |||
* @return DataResponse<Http::STATUS_OK, array{task: CoreTextProcessingTask}, array{}>|DataResponse<Http::STATUS_BAD_REQUEST|Http::STATUS_PRECONDITION_FAILED, array{message: string}, array{}> | |||
* @return DataResponse<Http::STATUS_OK, array{task: CoreTextProcessingTask}, array{}>|DataResponse<Http::STATUS_INTERNAL_SERVER_ERROR|Http::STATUS_BAD_REQUEST|Http::STATUS_PRECONDITION_FAILED, array{message: string}, array{}> | |||
* | |||
* 200: Task scheduled successfully | |||
* 400: Scheduling task is not possible | |||
@@ -118,7 +119,11 @@ class TextProcessingApiController extends \OCP\AppFramework\OCSController { | |||
return new DataResponse(['message' => $this->l->t('Requested task type does not exist')], Http::STATUS_BAD_REQUEST); | |||
} | |||
try { | |||
$this->textProcessingManager->scheduleTask($task); | |||
try { | |||
$this->textProcessingManager->runOrScheduleTask($task); | |||
} catch(\RuntimeException) { | |||
// noop, because the task object has the failure status set already, we just return the task json | |||
} | |||
$json = $task->jsonSerialize(); | |||
@@ -127,6 +132,8 @@ class TextProcessingApiController extends \OCP\AppFramework\OCSController { | |||
]); | |||
} catch (PreConditionNotMetException) { | |||
return new DataResponse(['message' => $this->l->t('Necessary language model provider is not available')], Http::STATUS_PRECONDITION_FAILED); | |||
} catch (Exception) { | |||
return new DataResponse(['message' => 'Internal server error'], Http::STATUS_INTERNAL_SERVER_ERROR); | |||
} | |||
} | |||
@@ -0,0 +1,60 @@ | |||
<?php | |||
declare(strict_types=1); | |||
/** | |||
* @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net> | |||
* | |||
* @author Marcel Klehr <mklehr@gmx.net> | |||
* | |||
* @license GNU AGPL version 3 or any later version | |||
* | |||
* This program is free software: you can redistribute it and/or modify | |||
* it under the terms of the GNU Affero General Public License as | |||
* published by the Free Software Foundation, either version 3 of the | |||
* License, or (at your option) any later version. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* GNU Affero General Public License for more details. | |||
* | |||
* You should have received a copy of the GNU Affero General Public License | |||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | |||
* | |||
*/ | |||
namespace OC\Core\Migrations; | |||
use Closure; | |||
use OCP\DB\ISchemaWrapper; | |||
use OCP\DB\Types; | |||
use OCP\Migration\IOutput; | |||
use OCP\Migration\SimpleMigrationStep; | |||
/** | |||
* Introduce completion_expected_at column in textprocessing_tasks table | |||
*/ | |||
class Version28000Date20231103104802 extends SimpleMigrationStep { | |||
/** | |||
* @param IOutput $output | |||
* @param Closure $schemaClosure The `\Closure` returns a `ISchemaWrapper` | |||
* @param array $options | |||
* @return null|ISchemaWrapper | |||
*/ | |||
public function changeSchema(IOutput $output, Closure $schemaClosure, array $options): ?ISchemaWrapper { | |||
/** @var ISchemaWrapper $schema */ | |||
$schema = $schemaClosure(); | |||
if ($schema->hasTable('textprocessing_tasks')) { | |||
$table = $schema->getTable('textprocessing_tasks'); | |||
$table->addColumn('completion_expected_at', Types::DATETIME, [ | |||
'notnull' => false, | |||
]); | |||
return $schema; | |||
} | |||
return null; | |||
} | |||
} |
@@ -143,6 +143,7 @@ namespace OCA\Core; | |||
* input: string, | |||
* output: ?string, | |||
* identifier: string, | |||
* completionExpectedAt: ?int | |||
* } | |||
* | |||
* @psalm-type CoreTextToImageTask = array{ |
@@ -45,6 +45,8 @@ use OCP\TextProcessing\Task as OCPTask; | |||
* @method string getAppId() | |||
* @method setIdentifier(string $identifier) | |||
* @method string getIdentifier() | |||
* @method setCompletionExpectedAt(null|\DateTime $completionExpectedAt) | |||
* @method null|\DateTime getCompletionExpectedAt() | |||
*/ | |||
class Task extends Entity { | |||
protected $lastUpdated; | |||
@@ -55,16 +57,17 @@ class Task extends Entity { | |||
protected $userId; | |||
protected $appId; | |||
protected $identifier; | |||
protected $completionExpectedAt; | |||
/** | |||
* @var string[] | |||
*/ | |||
public static array $columns = ['id', 'last_updated', 'type', 'input', 'output', 'status', 'user_id', 'app_id', 'identifier']; | |||
public static array $columns = ['id', 'last_updated', 'type', 'input', 'output', 'status', 'user_id', 'app_id', 'identifier', 'completion_expected_at']; | |||
/** | |||
* @var string[] | |||
*/ | |||
public static array $fields = ['id', 'lastUpdated', 'type', 'input', 'output', 'status', 'userId', 'appId', 'identifier']; | |||
public static array $fields = ['id', 'lastUpdated', 'type', 'input', 'output', 'status', 'userId', 'appId', 'identifier', 'completionExpectedAt']; | |||
public function __construct() { | |||
@@ -78,6 +81,7 @@ class Task extends Entity { | |||
$this->addType('userId', 'string'); | |||
$this->addType('appId', 'string'); | |||
$this->addType('identifier', 'string'); | |||
$this->addType('completionExpectedAt', 'datetime'); | |||
} | |||
public function toRow(): array { | |||
@@ -98,6 +102,7 @@ class Task extends Entity { | |||
'userId' => $task->getUserId(), | |||
'appId' => $task->getAppId(), | |||
'identifier' => $task->getIdentifier(), | |||
'completionExpectedAt' => $task->getCompletionExpectedAt(), | |||
]); | |||
return $task; | |||
} | |||
@@ -107,6 +112,7 @@ class Task extends Entity { | |||
$task->setId($this->getId()); | |||
$task->setStatus($this->getStatus()); | |||
$task->setOutput($this->getOutput()); | |||
$task->setCompletionExpectedAt($this->getCompletionExpectedAt()); | |||
return $task; | |||
} | |||
} |
@@ -28,6 +28,7 @@ namespace OC\TextProcessing; | |||
use OC\AppFramework\Bootstrap\Coordinator; | |||
use OC\TextProcessing\Db\Task as DbTask; | |||
use OCP\IConfig; | |||
use OCP\TextProcessing\IProvider2; | |||
use OCP\TextProcessing\Task; | |||
use OCP\TextProcessing\Task as OCPTask; | |||
use OC\TextProcessing\Db\TaskMapper; | |||
@@ -114,19 +115,7 @@ class Manager implements IManager { | |||
if (!$this->canHandleTask($task)) { | |||
throw new PreConditionNotMetException('No text processing provider is installed that can handle this task'); | |||
} | |||
$providers = $this->getProviders(); | |||
$json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', ''); | |||
if ($json !== '') { | |||
$preferences = json_decode($json, true); | |||
if (isset($preferences[$task->getType()])) { | |||
// If a preference for this task type is set, move the preferred provider to the start | |||
$provider = current(array_filter($providers, fn ($provider) => $provider::class === $preferences[$task->getType()])); | |||
if ($provider !== false) { | |||
$providers = array_filter($providers, fn ($p) => $p !== $provider); | |||
array_unshift($providers, $provider); | |||
} | |||
} | |||
} | |||
$providers = $this->getPreferredProviders($task); | |||
foreach ($providers as $provider) { | |||
if (!$task->canUseProvider($provider)) { | |||
@@ -134,6 +123,11 @@ class Manager implements IManager { | |||
} | |||
try { | |||
$task->setStatus(OCPTask::STATUS_RUNNING); | |||
if ($provider instanceof IProvider2) { | |||
$completionExpectedAt = new \DateTime('now'); | |||
$completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S')); | |||
$task->setCompletionExpectedAt($completionExpectedAt); | |||
} | |||
if ($task->getId() === null) { | |||
$taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task)); | |||
$task->setId($taskEntity->getId()); | |||
@@ -158,18 +152,25 @@ class Manager implements IManager { | |||
} | |||
} | |||
$task->setStatus(OCPTask::STATUS_FAILED); | |||
$this->taskMapper->update(DbTask::fromPublicTask($task)); | |||
throw new RuntimeException('Could not run task'); | |||
} | |||
/** | |||
* @inheritDoc | |||
* @throws Exception | |||
*/ | |||
public function scheduleTask(OCPTask $task): void { | |||
if (!$this->canHandleTask($task)) { | |||
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); | |||
} | |||
$task->setStatus(OCPTask::STATUS_SCHEDULED); | |||
[$provider, ] = $this->getPreferredProviders($task); | |||
if ($provider instanceof IProvider2) { | |||
$completionExpectedAt = new \DateTime('now'); | |||
$completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S')); | |||
$task->setCompletionExpectedAt($completionExpectedAt); | |||
} | |||
$taskEntity = DbTask::fromPublicTask($task); | |||
$this->taskMapper->insert($taskEntity); | |||
$task->setId($taskEntity->getId()); | |||
@@ -178,6 +179,25 @@ class Manager implements IManager { | |||
]); | |||
} | |||
/** | |||
* @inheritDoc | |||
*/ | |||
public function runOrScheduleTask(OCPTask $task) : bool { | |||
if (!$this->canHandleTask($task)) { | |||
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); | |||
} | |||
[$provider,] = $this->getPreferredProviders($task); | |||
$maxExecutionTime = (int) ini_get('max_execution_time'); | |||
// Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time | |||
// or if the provider doesn't provide a getExpectedRuntime() method | |||
if (!$provider instanceof IProvider2 || $provider->getExpectedRuntime() > $maxExecutionTime * 0.8) { | |||
$this->scheduleTask($task); | |||
return false; | |||
} | |||
$this->runTask($task); | |||
return true; | |||
} | |||
/** | |||
* @inheritDoc | |||
*/ | |||
@@ -253,4 +273,25 @@ class Manager implements IManager { | |||
throw new RuntimeException('Failure while trying to find tasks by appId and identifier: ' . $e->getMessage(), 0, $e); | |||
} | |||
} | |||
/** | |||
* @param OCPTask $task | |||
* @return IProvider[] | |||
*/ | |||
public function getPreferredProviders(OCPTask $task): array { | |||
$providers = $this->getProviders(); | |||
$json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', ''); | |||
if ($json !== '') { | |||
$preferences = json_decode($json, true); | |||
if (isset($preferences[$task->getType()])) { | |||
// If a preference for this task type is set, move the preferred provider to the start | |||
$provider = current(array_filter($providers, fn ($provider) => $provider::class === $preferences[$task->getType()])); | |||
if ($provider !== false) { | |||
$providers = array_filter($providers, fn ($p) => $p !== $provider); | |||
array_unshift($providers, $provider); | |||
} | |||
} | |||
} | |||
return $providers; | |||
} | |||
} |
@@ -27,6 +27,7 @@ declare(strict_types=1); | |||
namespace OCP\TextProcessing; | |||
use OCP\Common\Exception\NotFoundException; | |||
use OCP\DB\Exception; | |||
use OCP\PreConditionNotMetException; | |||
use RuntimeException; | |||
@@ -68,10 +69,25 @@ interface IManager { | |||
* | |||
* @param Task $task The task to schedule | |||
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called | |||
* @throws Exception storing the task in the database failed | |||
* @since 27.1.0 | |||
*/ | |||
public function scheduleTask(Task $task) : void; | |||
/** | |||
* If the designated provider for the passed task provides an expected average runtime, we check if the runtime fits into the | |||
* max execution time of this php process and run it synchronously if it does, if it doesn't fit (or the provider doesn't provide that information) | |||
* execution is deferred to a background job | |||
* | |||
* @param Task $task The task to schedule | |||
* @returns bool A boolean indicating whether the task was run synchronously (`true`) or offloaded to a background job (`false`) | |||
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called | |||
* @throws RuntimeException If running the task failed | |||
* @throws Exception storing the task in the database failed | |||
* @since 28.0.0 | |||
*/ | |||
public function runOrScheduleTask(Task $task): bool; | |||
/** | |||
* Delete a task that has been scheduled before | |||
* |
@@ -0,0 +1,48 @@ | |||
<?php | |||
declare(strict_types=1); | |||
/** | |||
* @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net> | |||
* | |||
* @author Marcel Klehr <mklehr@gmx.net> | |||
* | |||
* @license GNU AGPL version 3 or any later version | |||
* | |||
* This program is free software: you can redistribute it and/or modify | |||
* it under the terms of the GNU Affero General Public License as | |||
* published by the Free Software Foundation, either version 3 of the | |||
* License, or (at your option) any later version. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* GNU Affero General Public License for more details. | |||
* | |||
* You should have received a copy of the GNU Affero General Public License | |||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | |||
*/ | |||
namespace OCP\TextProcessing; | |||
/** | |||
* This interface supersedes IProvider. It allows the system to learn | |||
* the provider's expected runtime and lets the provider know which user is running a task | |||
* @since 28.0.0 | |||
* @template T of ITaskType | |||
* @template-extends IProvider<T> | |||
*/ | |||
interface IProvider2 extends IProvider { | |||
/** | |||
* @param ?string $userId the current user's id | |||
* @since 28.0.0 | |||
*/ | |||
public function setUserId(?string $userId): string; | |||
/** | |||
* @return int The expected average runtime of a task in seconds | |||
* @since 28.0.0 | |||
*/ | |||
public function getExpectedRuntime(): int; | |||
} |
@@ -35,6 +35,7 @@ namespace OCP\TextProcessing; | |||
final class Task implements \JsonSerializable { | |||
protected ?int $id = null; | |||
protected ?string $output = null; | |||
private ?\DateTime $completionExpectedAt = null; | |||
/** | |||
* @since 27.1.0 | |||
@@ -92,12 +93,15 @@ final class Task implements \JsonSerializable { | |||
/** | |||
* @psalm-param P $provider | |||
* @param IProvider $provider | |||
* @param IProvider|IProvider2 $provider | |||
* @return string | |||
* @since 27.1.0 | |||
*/ | |||
public function visitProvider(IProvider $provider): string { | |||
public function visitProvider(IProvider|IProvider2 $provider): string { | |||
if ($this->canUseProvider($provider)) { | |||
if ($provider instanceof IProvider2) { | |||
$provider->setUserId($this->getUserId()); | |||
} | |||
return $provider->process($this->getInput()); | |||
} else { | |||
throw new \RuntimeException('Task of type ' . $this->getType() . ' cannot visit provider with task type ' . $provider->getTaskType()); | |||
@@ -203,7 +207,7 @@ final class Task implements \JsonSerializable { | |||
} | |||
/** | |||
* @psalm-return array{id: ?int, type: S, status: 0|1|2|3|4, userId: ?string, appId: string, input: string, output: ?string, identifier: string} | |||
* @psalm-return array{id: ?int, type: S, status: 0|1|2|3|4, userId: ?string, appId: string, input: string, output: ?string, identifier: string, completionExpectedAt: ?int} | |||
* @since 27.1.0 | |||
*/ | |||
public function jsonSerialize(): array { | |||
@@ -216,6 +220,15 @@ final class Task implements \JsonSerializable { | |||
'input' => $this->getInput(), | |||
'output' => $this->getOutput(), | |||
'identifier' => $this->getIdentifier(), | |||
'completionExpectedAt' => $this->getCompletionExpectedAt()?->getTimestamp(), | |||
]; | |||
} | |||
final public function setCompletionExpectedAt(\DateTime $completionExpectedAt): void { | |||
$this->completionExpectedAt = $completionExpectedAt; | |||
} | |||
final public function getCompletionExpectedAt(): ?\DateTime { | |||
return $this->completionExpectedAt; | |||
} | |||
} |