Browse Source

refactor: rename getTaskType to getTaskTypeId

Signed-off-by: Marcel Klehr <mklehr@gmx.net>
pull/45094/head
Marcel Klehr 1 month ago
parent
commit
b150d779f3

+ 2
- 2
core/Controller/TaskProcessingApiController.php View File

@@ -263,10 +263,10 @@ class TaskProcessingApiController extends \OCP\AppFramework\OCSController {
private function extractFileIdsFromTask(Task $task) {
$ids = [];
$taskTypes = $this->taskProcessingManager->getAvailableTaskTypes();
if (!isset($taskTypes[$task->getTaskType()])) {
if (!isset($taskTypes[$task->getTaskTypeId()])) {
throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find task type');
}
$taskType = $taskTypes[$task->getTaskType()];
$taskType = $taskTypes[$task->getTaskTypeId()];
foreach ($taskType['inputShape'] + $taskType['optionalInputShape'] as $key => $descriptor) {
if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) {
$ids[] = $task->getInput()[$key];

+ 1
- 1
lib/private/TaskProcessing/Db/Task.php View File

@@ -102,7 +102,7 @@ class Task extends Entity {
/** @var Task $taskEntity */
$taskEntity = Task::fromParams([
'id' => $task->getId(),
'type' => $task->getTaskType(),
'type' => $task->getTaskTypeId(),
'lastUpdated' => time(),
'status' => $task->getStatus(),
'input' => json_encode($task->getInput(), JSON_THROW_ON_ERROR),

+ 18
- 18
lib/private/TaskProcessing/Manager.php View File

@@ -121,7 +121,7 @@ class Manager implements IManager {
return $this->provider->getName();
}

public function getTaskType(): string {
public function getTaskTypeId(): string {
return match ($this->provider->getTaskType()) {
\OCP\TextProcessing\FreePromptTaskType::class => TextToText::ID,
\OCP\TextProcessing\HeadlineTaskType::class => TextToTextHeadline::ID,
@@ -240,7 +240,7 @@ class Manager implements IManager {
return $this->provider->getName();
}

public function getTaskType(): string {
public function getTaskTypeId(): string {
return TextToImage::ID;
}

@@ -327,7 +327,7 @@ class Manager implements IManager {
return $this->provider->getName();
}

public function getTaskType(): string {
public function getTaskTypeId(): string {
return AudioToText::ID;
}

@@ -451,7 +451,7 @@ class Manager implements IManager {
private function _getPreferredProvider(string $taskType) {
$providers = $this->getProviders();
foreach ($providers as $provider) {
if ($provider->getTaskType() === $taskType) {
if ($provider->getTaskTypeId() === $taskType) {
return $provider;
}
}
@@ -535,11 +535,11 @@ class Manager implements IManager {

$availableTaskTypes = [];
foreach ($providers as $provider) {
if (!isset($taskTypes[$provider->getTaskType()])) {
if (!isset($taskTypes[$provider->getTaskTypeId()])) {
continue;
}
$taskType = $taskTypes[$provider->getTaskType()];
$availableTaskTypes[$provider->getTaskType()] = [
$taskType = $taskTypes[$provider->getTaskTypeId()];
$availableTaskTypes[$provider->getTaskTypeId()] = [
'name' => $taskType->getName(),
'description' => $taskType->getDescription(),
'inputShape' => $taskType->getInputShape(),
@@ -556,23 +556,23 @@ class Manager implements IManager {
}

public function canHandleTask(Task $task): bool {
return isset($this->getAvailableTaskTypes()[$task->getTaskType()]);
return isset($this->getAvailableTaskTypes()[$task->getTaskTypeId()]);
}

public function scheduleTask(Task $task): void {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskType());
throw new PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId());
}
$taskTypes = $this->getAvailableTaskTypes();
$inputShape = $taskTypes[$task->getTaskType()]['inputShape'];
$optionalInputShape = $taskTypes[$task->getTaskType()]['optionalInputShape'];
$inputShape = $taskTypes[$task->getTaskTypeId()]['inputShape'];
$optionalInputShape = $taskTypes[$task->getTaskTypeId()]['optionalInputShape'];
// validate input
$this->validateInput($inputShape, $task->getInput());
$this->validateInput($optionalInputShape, $task->getInput(), true);
// remove superfluous keys and set input
$task->setInput($this->removeSuperfluousArrayKeys($task->getInput(), $inputShape, $optionalInputShape));
$task->setStatus(Task::STATUS_SCHEDULED);
$provider = $this->_getPreferredProvider($task->getTaskType());
$provider = $this->_getPreferredProvider($task->getTaskTypeId());
// calculate expected completion time
$completionExpectedAt = new \DateTime('now');
$completionExpectedAt->add(new \DateInterval('PT'.$provider->getExpectedRuntime().'S'));
@@ -638,17 +638,17 @@ class Manager implements IManager {
// TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently
$task = $this->getTask($id);
if ($task->getStatus() === Task::STATUS_CANCELLED) {
$this->logger->info('A TaskProcessing ' . $task->getTaskType() . ' task with id ' . $id . ' finished but was cancelled in the mean time. Moving on without storing result.');
$this->logger->info('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' finished but was cancelled in the mean time. Moving on without storing result.');
return;
}
if ($error !== null) {
$task->setStatus(Task::STATUS_FAILED);
$task->setErrorMessage($error);
$this->logger->warning('A TaskProcessing ' . $task->getTaskType() . ' task with id ' . $id . ' failed with the following message: ' . $error);
$this->logger->warning('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' failed with the following message: ' . $error);
} elseif ($result !== null) {
$taskTypes = $this->getAvailableTaskTypes();
$outputShape = $taskTypes[$task->getTaskType()]['outputShape'];
$optionalOutputShape = $taskTypes[$task->getTaskType()]['optionalOutputShape'];
$outputShape = $taskTypes[$task->getTaskTypeId()]['outputShape'];
$optionalOutputShape = $taskTypes[$task->getTaskTypeId()]['optionalOutputShape'];
try {
// validate output
$this->validateOutput($outputShape, $result);
@@ -823,8 +823,8 @@ class Manager implements IManager {

public function prepareInputData(Task $task): array {
$taskTypes = $this->getAvailableTaskTypes();
$inputShape = $taskTypes[$task->getTaskType()]['inputShape'];
$optionalInputShape = $taskTypes[$task->getTaskType()]['optionalInputShape'];
$inputShape = $taskTypes[$task->getTaskTypeId()]['inputShape'];
$optionalInputShape = $taskTypes[$task->getTaskTypeId()]['optionalInputShape'];
$input = $task->getInput();
// validate input, again for good measure (should have been validated in scheduleTask)
$this->validateInput($inputShape, $input);

+ 1
- 1
lib/private/TaskProcessing/SynchronousBackgroundJob.php View File

@@ -37,7 +37,7 @@ class SynchronousBackgroundJob extends QueuedJob {
if (!$provider instanceof ISynchronousProvider) {
continue;
}
$taskType = $provider->getTaskType();
$taskType = $provider->getTaskTypeId();
try {
$task = $this->taskProcessingManager->getNextScheduledTask($taskType);
} catch (NotFoundException $e) {

+ 1
- 1
lib/public/TaskProcessing/IProvider.php View File

@@ -51,7 +51,7 @@ interface IProvider {
* @since 30.0.0
* @return string
*/
public function getTaskType(): string;
public function getTaskTypeId(): string;

/**
* @return int The expected average runtime of a task in seconds

+ 3
- 3
lib/public/TaskProcessing/Task.php View File

@@ -82,7 +82,7 @@ final class Task implements \JsonSerializable {
* @since 30.0.0
*/
final public function __construct(
protected readonly string $taskType,
protected readonly string $taskTypeId,
protected array $input,
protected readonly string $appId,
protected readonly ?string $userId,
@@ -93,8 +93,8 @@ final class Task implements \JsonSerializable {
/**
* @since 30.0.0
*/
final public function getTaskType(): string {
return $this->taskType;
final public function getTaskTypeId(): string {
return $this->taskTypeId;
}

/**

+ 4
- 4
tests/lib/TaskProcessing/TaskProcessingTest.php View File

@@ -80,7 +80,7 @@ class AsyncProvider implements IProvider {
return self::class;
}

public function getTaskType(): string {
public function getTaskTypeId(): string {
return AudioToImage::ID;
}

@@ -110,7 +110,7 @@ class SuccessfulSyncProvider implements IProvider, ISynchronousProvider {
return self::class;
}

public function getTaskType(): string {
public function getTaskTypeId(): string {
return TextToText::ID;
}

@@ -145,7 +145,7 @@ class FailingSyncProvider implements IProvider, ISynchronousProvider {
return self::class;
}

public function getTaskType(): string {
public function getTaskTypeId(): string {
return TextToText::ID;
}

@@ -179,7 +179,7 @@ class BrokenSyncProvider implements IProvider, ISynchronousProvider {
return self::class;
}

public function getTaskType(): string {
public function getTaskTypeId(): string {
return TextToText::ID;
}


Loading…
Cancel
Save