Browse Source

enh(TextProcessing): Add IProvider2

- allow providers to obtain current task's userId
- allow providers to expose average task runtime

Signed-off-by: Marcel Klehr <mklehr@gmx.net>
tags/v28.0.0beta3
Marcel Klehr 8 months ago
parent
commit
181f819e41

+ 9
- 2
core/Controller/TextProcessingApiController.php View File

@@ -35,6 +35,7 @@ use OCP\AppFramework\Http\Attribute\PublicPage;
use OCP\AppFramework\Http\Attribute\UserRateLimit;
use OCP\AppFramework\Http\DataResponse;
use OCP\Common\Exception\NotFoundException;
use OCP\DB\Exception;
use OCP\IL10N;
use OCP\IRequest;
use OCP\TextProcessing\ITaskType;
@@ -102,7 +103,7 @@ class TextProcessingApiController extends \OCP\AppFramework\OCSController {
* @param string $appId ID of the app that will execute the task
* @param string $identifier An arbitrary identifier for the task
*
* @return DataResponse<Http::STATUS_OK, array{task: CoreTextProcessingTask}, array{}>|DataResponse<Http::STATUS_BAD_REQUEST|Http::STATUS_PRECONDITION_FAILED, array{message: string}, array{}>
* @return DataResponse<Http::STATUS_OK, array{task: CoreTextProcessingTask}, array{}>|DataResponse<Http::STATUS_INTERNAL_SERVER_ERROR|Http::STATUS_BAD_REQUEST|Http::STATUS_PRECONDITION_FAILED, array{message: string}, array{}>
*
* 200: Task scheduled successfully
* 400: Scheduling task is not possible
@@ -118,7 +119,11 @@ class TextProcessingApiController extends \OCP\AppFramework\OCSController {
return new DataResponse(['message' => $this->l->t('Requested task type does not exist')], Http::STATUS_BAD_REQUEST);
}
try {
$this->textProcessingManager->scheduleTask($task);
try {
$this->textProcessingManager->runOrScheduleTask($task);
} catch(\RuntimeException) {
// noop, because the task object has the failure status set already, we just return the task json
}

$json = $task->jsonSerialize();

@@ -127,6 +132,8 @@ class TextProcessingApiController extends \OCP\AppFramework\OCSController {
]);
} catch (PreConditionNotMetException) {
return new DataResponse(['message' => $this->l->t('Necessary language model provider is not available')], Http::STATUS_PRECONDITION_FAILED);
} catch (Exception) {
return new DataResponse(['message' => 'Internal server error'], Http::STATUS_INTERNAL_SERVER_ERROR);
}
}


+ 60
- 0
core/Migrations/Version28000Date20231103104802.php View File

@@ -0,0 +1,60 @@
<?php

declare(strict_types=1);

/**
* @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net>
*
* @author Marcel Klehr <mklehr@gmx.net>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
*/

namespace OC\Core\Migrations;

use Closure;
use OCP\DB\ISchemaWrapper;
use OCP\DB\Types;
use OCP\Migration\IOutput;
use OCP\Migration\SimpleMigrationStep;

/**
* Introduce completion_expected_at column in textprocessing_tasks table
*/
class Version28000Date20231103104802 extends SimpleMigrationStep {
/**
* @param IOutput $output
* @param Closure $schemaClosure The `\Closure` returns a `ISchemaWrapper`
* @param array $options
* @return null|ISchemaWrapper
*/
public function changeSchema(IOutput $output, Closure $schemaClosure, array $options): ?ISchemaWrapper {
/** @var ISchemaWrapper $schema */
$schema = $schemaClosure();
if ($schema->hasTable('textprocessing_tasks')) {
$table = $schema->getTable('textprocessing_tasks');

$table->addColumn('completion_expected_at', Types::DATETIME, [
'notnull' => false,
]);

return $schema;
}

return null;
}
}

+ 1
- 0
core/ResponseDefinitions.php View File

@@ -143,6 +143,7 @@ namespace OCA\Core;
* input: string,
* output: ?string,
* identifier: string,
* completionExpectedAt: ?int
* }
*
* @psalm-type CoreTextToImageTask = array{

+ 8
- 2
lib/private/TextProcessing/Db/Task.php View File

@@ -45,6 +45,8 @@ use OCP\TextProcessing\Task as OCPTask;
* @method string getAppId()
* @method setIdentifier(string $identifier)
* @method string getIdentifier()
* @method setCompletionExpectedAt(null|\DateTime $completionExpectedAt)
* @method null|\DateTime getCompletionExpectedAt()
*/
class Task extends Entity {
protected $lastUpdated;
@@ -55,16 +57,17 @@ class Task extends Entity {
protected $userId;
protected $appId;
protected $identifier;
protected $completionExpectedAt;

/**
* @var string[]
*/
public static array $columns = ['id', 'last_updated', 'type', 'input', 'output', 'status', 'user_id', 'app_id', 'identifier'];
public static array $columns = ['id', 'last_updated', 'type', 'input', 'output', 'status', 'user_id', 'app_id', 'identifier', 'completion_expected_at'];

/**
* @var string[]
*/
public static array $fields = ['id', 'lastUpdated', 'type', 'input', 'output', 'status', 'userId', 'appId', 'identifier'];
public static array $fields = ['id', 'lastUpdated', 'type', 'input', 'output', 'status', 'userId', 'appId', 'identifier', 'completionExpectedAt'];


public function __construct() {
@@ -78,6 +81,7 @@ class Task extends Entity {
$this->addType('userId', 'string');
$this->addType('appId', 'string');
$this->addType('identifier', 'string');
$this->addType('completionExpectedAt', 'datetime');
}

public function toRow(): array {
@@ -98,6 +102,7 @@ class Task extends Entity {
'userId' => $task->getUserId(),
'appId' => $task->getAppId(),
'identifier' => $task->getIdentifier(),
'completionExpectedAt' => $task->getCompletionExpectedAt(),
]);
return $task;
}
@@ -107,6 +112,7 @@ class Task extends Entity {
$task->setId($this->getId());
$task->setStatus($this->getStatus());
$task->setOutput($this->getOutput());
$task->setCompletionExpectedAt($this->getCompletionExpectedAt());
return $task;
}
}

+ 55
- 14
lib/private/TextProcessing/Manager.php View File

@@ -28,6 +28,7 @@ namespace OC\TextProcessing;
use OC\AppFramework\Bootstrap\Coordinator;
use OC\TextProcessing\Db\Task as DbTask;
use OCP\IConfig;
use OCP\TextProcessing\IProvider2;
use OCP\TextProcessing\Task;
use OCP\TextProcessing\Task as OCPTask;
use OC\TextProcessing\Db\TaskMapper;
@@ -114,19 +115,7 @@ class Manager implements IManager {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No text processing provider is installed that can handle this task');
}
$providers = $this->getProviders();
$json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', '');
if ($json !== '') {
$preferences = json_decode($json, true);
if (isset($preferences[$task->getType()])) {
// If a preference for this task type is set, move the preferred provider to the start
$provider = current(array_filter($providers, fn ($provider) => $provider::class === $preferences[$task->getType()]));
if ($provider !== false) {
$providers = array_filter($providers, fn ($p) => $p !== $provider);
array_unshift($providers, $provider);
}
}
}
$providers = $this->getPreferredProviders($task);

foreach ($providers as $provider) {
if (!$task->canUseProvider($provider)) {
@@ -134,6 +123,11 @@ class Manager implements IManager {
}
try {
$task->setStatus(OCPTask::STATUS_RUNNING);
if ($provider instanceof IProvider2) {
$completionExpectedAt = new \DateTime('now');
$completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
$task->setCompletionExpectedAt($completionExpectedAt);
}
if ($task->getId() === null) {
$taskEntity = $this->taskMapper->insert(DbTask::fromPublicTask($task));
$task->setId($taskEntity->getId());
@@ -158,18 +152,25 @@ class Manager implements IManager {
}
}

$task->setStatus(OCPTask::STATUS_FAILED);
$this->taskMapper->update(DbTask::fromPublicTask($task));
throw new RuntimeException('Could not run task');
}

/**
* @inheritDoc
* @throws Exception
*/
public function scheduleTask(OCPTask $task): void {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
$task->setStatus(OCPTask::STATUS_SCHEDULED);
[$provider, ] = $this->getPreferredProviders($task);
if ($provider instanceof IProvider2) {
$completionExpectedAt = new \DateTime('now');
$completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
$task->setCompletionExpectedAt($completionExpectedAt);
}
$taskEntity = DbTask::fromPublicTask($task);
$this->taskMapper->insert($taskEntity);
$task->setId($taskEntity->getId());
@@ -178,6 +179,25 @@ class Manager implements IManager {
]);
}

/**
* @inheritDoc
*/
public function runOrScheduleTask(OCPTask $task) : bool {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
[$provider,] = $this->getPreferredProviders($task);
$maxExecutionTime = (int) ini_get('max_execution_time');
// Offload the task to a background job if the expected runtime of the likely provider is longer than 80% of our max execution time
// or if the provider doesn't provide a getExpectedRuntime() method
if (!$provider instanceof IProvider2 || $provider->getExpectedRuntime() > $maxExecutionTime * 0.8) {
$this->scheduleTask($task);
return false;
}
$this->runTask($task);
return true;
}

/**
* @inheritDoc
*/
@@ -253,4 +273,25 @@ class Manager implements IManager {
throw new RuntimeException('Failure while trying to find tasks by appId and identifier: ' . $e->getMessage(), 0, $e);
}
}

/**
* @param OCPTask $task
* @return IProvider[]
*/
public function getPreferredProviders(OCPTask $task): array {
$providers = $this->getProviders();
$json = $this->config->getAppValue('core', 'ai.textprocessing_provider_preferences', '');
if ($json !== '') {
$preferences = json_decode($json, true);
if (isset($preferences[$task->getType()])) {
// If a preference for this task type is set, move the preferred provider to the start
$provider = current(array_filter($providers, fn ($provider) => $provider::class === $preferences[$task->getType()]));
if ($provider !== false) {
$providers = array_filter($providers, fn ($p) => $p !== $provider);
array_unshift($providers, $provider);
}
}
}
return $providers;
}
}

+ 16
- 0
lib/public/TextProcessing/IManager.php View File

@@ -27,6 +27,7 @@ declare(strict_types=1);
namespace OCP\TextProcessing;

use OCP\Common\Exception\NotFoundException;
use OCP\DB\Exception;
use OCP\PreConditionNotMetException;
use RuntimeException;

@@ -68,10 +69,25 @@ interface IManager {
*
* @param Task $task The task to schedule
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called
* @throws Exception storing the task in the database failed
* @since 27.1.0
*/
public function scheduleTask(Task $task) : void;

/**
* If the designated provider for the passed task provides an expected average runtime, we check if the runtime fits into the
* max execution time of this php process and run it synchronously if it does, if it doesn't fit (or the provider doesn't provide that information)
* execution is deferred to a background job
*
* @param Task $task The task to schedule
* @returns bool A boolean indicating whether the task was run synchronously (`true`) or offloaded to a background job (`false`)
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called
* @throws RuntimeException If running the task failed
* @throws Exception storing the task in the database failed
* @since 28.0.0
*/
public function runOrScheduleTask(Task $task): bool;

/**
* Delete a task that has been scheduled before
*

+ 48
- 0
lib/public/TextProcessing/IProvider2.php View File

@@ -0,0 +1,48 @@
<?php

declare(strict_types=1);

/**
* @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net>
*
* @author Marcel Klehr <mklehr@gmx.net>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/


namespace OCP\TextProcessing;

/**
* This interface supersedes IProvider. It allows the system to learn
* the provider's expected runtime and lets the provider know which user is running a task
* @since 28.0.0
* @template T of ITaskType
* @template-extends IProvider<T>
*/
interface IProvider2 extends IProvider {
/**
* @param ?string $userId the current user's id
* @since 28.0.0
*/
public function setUserId(?string $userId): string;

/**
* @return int The expected average runtime of a task in seconds
* @since 28.0.0
*/
public function getExpectedRuntime(): int;
}

+ 16
- 3
lib/public/TextProcessing/Task.php View File

@@ -35,6 +35,7 @@ namespace OCP\TextProcessing;
final class Task implements \JsonSerializable {
protected ?int $id = null;
protected ?string $output = null;
private ?\DateTime $completionExpectedAt = null;

/**
* @since 27.1.0
@@ -92,12 +93,15 @@ final class Task implements \JsonSerializable {

/**
* @psalm-param P $provider
* @param IProvider $provider
* @param IProvider|IProvider2 $provider
* @return string
* @since 27.1.0
*/
public function visitProvider(IProvider $provider): string {
public function visitProvider(IProvider|IProvider2 $provider): string {
if ($this->canUseProvider($provider)) {
if ($provider instanceof IProvider2) {
$provider->setUserId($this->getUserId());
}
return $provider->process($this->getInput());
} else {
throw new \RuntimeException('Task of type ' . $this->getType() . ' cannot visit provider with task type ' . $provider->getTaskType());
@@ -203,7 +207,7 @@ final class Task implements \JsonSerializable {
}

/**
* @psalm-return array{id: ?int, type: S, status: 0|1|2|3|4, userId: ?string, appId: string, input: string, output: ?string, identifier: string}
* @psalm-return array{id: ?int, type: S, status: 0|1|2|3|4, userId: ?string, appId: string, input: string, output: ?string, identifier: string, completionExpectedAt: ?int}
* @since 27.1.0
*/
public function jsonSerialize(): array {
@@ -216,6 +220,15 @@ final class Task implements \JsonSerializable {
'input' => $this->getInput(),
'output' => $this->getOutput(),
'identifier' => $this->getIdentifier(),
'completionExpectedAt' => $this->getCompletionExpectedAt()?->getTimestamp(),
];
}

final public function setCompletionExpectedAt(\DateTime $completionExpectedAt): void {
$this->completionExpectedAt = $completionExpectedAt;
}

final public function getCompletionExpectedAt(): ?\DateTime {
return $this->completionExpectedAt;
}
}

Loading…
Cancel
Save