aboutsummaryrefslogtreecommitdiffstats
path: root/lib
diff options
context:
space:
mode:
authorMarcel Klehr <mklehr@gmx.net>2023-06-16 13:06:47 +0200
committerMarcel Klehr <mklehr@gmx.net>2023-07-07 13:39:10 +0200
commit34138736538f604af2c6c52aa43662d1d66087d0 (patch)
tree11a80b580159c57bbb5c3d45a0e7770da314030c /lib
parent01dd1a894dbf9eb6ed1fb013c4b5ee4816c32904 (diff)
downloadnextcloud-server-34138736538f604af2c6c52aa43662d1d66087d0.tar.gz
nextcloud-server-34138736538f604af2c6c52aa43662d1d66087d0.zip
LLM OCP API: Implement private backend code + add ILanguageModelTask
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
Diffstat (limited to 'lib')
-rw-r--r--lib/private/LanguageModel/Db/Task.php58
-rw-r--r--lib/private/LanguageModel/Db/TaskMapper.php34
-rw-r--r--lib/private/LanguageModel/LanguageModelManager.php162
-rw-r--r--lib/private/LanguageModel/TaskBackgroundJob.php72
-rw-r--r--lib/public/LanguageModel/AbstractLanguageModelTask.php48
-rw-r--r--lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php8
-rw-r--r--lib/public/LanguageModel/Events/TaskFailedEvent.php4
-rw-r--r--lib/public/LanguageModel/Events/TaskSuccessfulEvent.php4
-rw-r--r--lib/public/LanguageModel/FreePromptTask.php21
-rw-r--r--lib/public/LanguageModel/ILanguageModelManager.php10
-rw-r--r--lib/public/LanguageModel/ILanguageModelTask.php56
-rw-r--r--lib/public/LanguageModel/SummaryTask.php21
12 files changed, 452 insertions, 46 deletions
diff --git a/lib/private/LanguageModel/Db/Task.php b/lib/private/LanguageModel/Db/Task.php
new file mode 100644
index 00000000000..cee6c2fd8b9
--- /dev/null
+++ b/lib/private/LanguageModel/Db/Task.php
@@ -0,0 +1,58 @@
+<?php
+
+namespace OC\LanguageModel\Db;
+
+use OCP\AppFramework\Db\Entity;
+use OCP\LanguageModel\ILanguageModelTask;
+
+/**
+ * @method setType(string $type)
+ * @method string getType()
+ * @method setInput(string $type)
+ * @method string getInput()
+ * @method setStatus(int $type)
+ * @method int getStatus()
+ * @method setUserId(string $type)
+ * @method string getuserId()
+ * @method setAppId(string $type)
+ * @method string getAppId()
+ */
+class Task extends Entity {
+
+ protected $type;
+ protected $input;
+ protected $status;
+ protected $userId;
+ protected $appId;
+
+ /**
+ * @var string[]
+ */
+ public static array $columns = ['id', 'type', 'input', 'status', 'user_id', 'app_id'];
+
+ /**
+ * @var string[]
+ */
+ public static array $fields = ['id', 'type', 'input', 'status', 'userId', 'appId'];
+
+
+ public function __construct() {
+ // add types in constructor
+ $this->addType('id', 'integer');
+ $this->addType('type', 'string');
+ $this->addType('input', 'string');
+ $this->addType('status', 'integer');
+ $this->addType('userId', 'string');
+ $this->addType('appId', 'string');
+ }
+
+ public static function fromLanguageModelTask(ILanguageModelTask $task): Task {
+ return Task::fromParams([
+ 'type' => $task->getType(),
+ 'status' => ILanguageModelTask::STATUS_UNKNOWN,
+ 'input' => $task->getInput(),
+ 'userId' => $task->getUserId(),
+ 'appId' => $task->getAppId(),
+ ]);
+ }
+}
diff --git a/lib/private/LanguageModel/Db/TaskMapper.php b/lib/private/LanguageModel/Db/TaskMapper.php
new file mode 100644
index 00000000000..0b9004c4d96
--- /dev/null
+++ b/lib/private/LanguageModel/Db/TaskMapper.php
@@ -0,0 +1,34 @@
+<?php
+
+namespace OC\LanguageModel\Db;
+
+use OCP\AppFramework\Db\DoesNotExistException;
+use OCP\AppFramework\Db\MultipleObjectsReturnedException;
+use OCP\AppFramework\Db\QBMapper;
+use OCP\DB\Exception;
+use OCP\IDBConnection;
+
+/**
+ * @extends QBMapper<Task>
+ */
+class TaskMapper extends QBMapper {
+
+ public function __construct(IDBConnection $db) {
+ parent::__construct($db, 'oc_llm_tasks', Task::class);
+ }
+
+ /**
+ * @param int $id
+ * @return Task
+ * @throws Exception
+ * @throws DoesNotExistException
+ * @throws MultipleObjectsReturnedException
+ */
+ public function find(int $id): Task {
+ $qb = $this->db->getQueryBuilder();
+ $qb->select(Task::$columns)
+ ->from($this->tableName)
+ ->where($qb->expr()->eq('id', $qb->createPositionalParameter($id)));
+ return $this->findEntity($qb);
+ }
+}
diff --git a/lib/private/LanguageModel/LanguageModelManager.php b/lib/private/LanguageModel/LanguageModelManager.php
new file mode 100644
index 00000000000..f9f13b15d6e
--- /dev/null
+++ b/lib/private/LanguageModel/LanguageModelManager.php
@@ -0,0 +1,162 @@
+<?php
+
+namespace OC\LanguageModel;
+
+use OC\AppFramework\Bootstrap\Coordinator;
+use OC\LanguageModel\Db\Task;
+use OC\LanguageModel\Db\TaskMapper;
+use OCP\LanguageModel\AbstractLanguageModelTask;
+use OCP\LanguageModel\FreePromptTask;
+use OCP\LanguageModel\SummaryTask;
+use OCP\AppFramework\Db\DoesNotExistException;
+use OCP\AppFramework\Db\MultipleObjectsReturnedException;
+use OCP\BackgroundJob\IJobList;
+use OCP\DB\Exception;
+use OCP\IServerContainer;
+use OCP\LanguageModel\ILanguageModelManager;
+use OCP\LanguageModel\ILanguageModelProvider;
+use OCP\LanguageModel\ILanguageModelTask;
+use OCP\LanguageModel\ISummaryProvider;
+use OCP\PreConditionNotMetException;
+use Psr\Container\ContainerExceptionInterface;
+use Psr\Container\NotFoundExceptionInterface;
+use Psr\Log\LoggerInterface;
+use RuntimeException;
+use Throwable;
+
+class LanguageModelManager implements ILanguageModelManager {
+
+ /** @var ?ILanguageModelProvider[] */
+ private ?array $providers = null;
+
+ public function __construct(
+ private IServerContainer $serverContainer,
+ private Coordinator $coordinator,
+ private LoggerInterface $logger,
+ private IJobList $jobList,
+ private TaskMapper $taskMapper,
+ ) {
+ }
+
+ public function getProviders(): array {
+ $context = $this->coordinator->getRegistrationContext();
+ if ($context === null) {
+ return [];
+ }
+
+ if ($this->providers !== null) {
+ return $this->providers;
+ }
+
+ $this->providers = [];
+
+ foreach ($context->getSpeechToTextProviders() as $providerServiceRegistration) {
+ $class = $providerServiceRegistration->getService();
+ try {
+ $this->providers[$class] = $this->serverContainer->get($class);
+ } catch (NotFoundExceptionInterface|ContainerExceptionInterface|Throwable $e) {
+ $this->logger->error('Failed to load LanguageModel provider ' . $class, [
+ 'exception' => $e,
+ ]);
+ }
+ }
+
+ return $this->providers;
+ }
+
+ public function hasProviders(): bool {
+ $context = $this->coordinator->getRegistrationContext();
+ if ($context === null) {
+ return false;
+ }
+ return !empty($context->getSpeechToTextProviders());
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function getAvailableTasks(): array {
+ $tasks = [];
+ foreach ($this->getProviders() as $provider) {
+ $tasks[FreePromptTask::class] = true;
+ if ($provider instanceof ISummaryProvider) {
+ $tasks[SummaryTask::class] = true;
+ }
+ }
+ return array_keys($tasks);
+ }
+
+ public function canHandleTask(ILanguageModelTask $task): bool {
+ return !empty(array_filter($this->getAvailableTasks(), fn ($class) => $task instanceof $class));
+ }
+
+ /**
+ * @inheritDoc
+ */
+ public function runTask(ILanguageModelTask $task): string {
+ if (!$this->canHandleTask($task)) {
+ throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
+ }
+ foreach ($this->getProviders() as $provider) {
+ if (!$task->canUseProvider($provider)) {
+ continue;
+ }
+ try {
+ $task->setStatus(ILanguageModelTask::STATUS_RUNNING);
+ $this->taskMapper->update(Task::fromLanguageModelTask($task));
+ $output = $task->visitProvider($provider);
+ $task->setStatus(ILanguageModelTask::STATUS_SUCCESSFUL);
+ $this->taskMapper->update(Task::fromLanguageModelTask($task));
+ return $output;
+ } catch (\RuntimeException $e) {
+ $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
+ $task->setStatus(ILanguageModelTask::STATUS_FAILED);
+ $this->taskMapper->update(Task::fromLanguageModelTask($task));
+ throw $e;
+ } catch (\Throwable $e) {
+ $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]);
+ $task->setStatus(ILanguageModelTask::STATUS_FAILED);
+ $this->taskMapper->update(Task::fromLanguageModelTask($task));
+ throw new RuntimeException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage());
+ }
+ }
+
+ throw new RuntimeException('Could not transcribe file');
+ }
+
+ /**
+ * @inheritDoc
+ * @throws Exception
+ */
+ public function scheduleTask(ILanguageModelTask $task): void {
+ if (!$this->canHandleTask($task)) {
+ throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
+ }
+ $taskEntity = Task::fromLanguageModelTask($task);
+ $this->taskMapper->insert($taskEntity);
+ $task->setId($taskEntity->getId());
+ $task->setStatus(ILanguageModelTask::STATUS_SCHEDULED);
+ $this->jobList->add(TaskBackgroundJob::class, [
+ 'taskId' => $task->getId()
+ ]);
+ }
+
+ /**
+ * @param int $id The id of the task
+ * @return ILanguageModelTask
+ * @throws RuntimeException If the query failed
+ * @throws \ValueError If the task could not be found
+ */
+ public function getTask(int $id): ILanguageModelTask {
+ try {
+ $taskEntity = $this->taskMapper->find($id);
+ return AbstractLanguageModelTask::fromTaskEntity($taskEntity);
+ } catch (DoesNotExistException $e) {
+ throw new \ValueError('Could not find task with the provided id');
+ } catch (MultipleObjectsReturnedException $e) {
+ throw new RuntimeException('Could not uniquely identify task with given id');
+ } catch (Exception $e) {
+ throw new RuntimeException('Failure while trying to find task by id: '.$e->getMessage());
+ }
+ }
+}
diff --git a/lib/private/LanguageModel/TaskBackgroundJob.php b/lib/private/LanguageModel/TaskBackgroundJob.php
new file mode 100644
index 00000000000..55413ba3714
--- /dev/null
+++ b/lib/private/LanguageModel/TaskBackgroundJob.php
@@ -0,0 +1,72 @@
+<?php
+
+declare(strict_types=1);
+
+/**
+ * @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net>
+ *
+ * @author Marcel Klehr <mklehr@gmx.net>
+ *
+ * @license GNU AGPL version 3 or any later version
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+
+namespace OC\LanguageModel;
+
+use OC\User\NoUserException;
+use OCP\AppFramework\Utility\ITimeFactory;
+use OCP\BackgroundJob\QueuedJob;
+use OCP\EventDispatcher\IEventDispatcher;
+use OCP\Files\File;
+use OCP\Files\IRootFolder;
+use OCP\Files\NotFoundException;
+use OCP\Files\NotPermittedException;
+use OCP\LanguageModel\Events\TaskFailedEvent;
+use OCP\LanguageModel\Events\TaskSuccessfulEvent;
+use OCP\LanguageModel\ILanguageModelManager;
+use OCP\PreConditionNotMetException;
+use OCP\SpeechToText\Events\TranscriptionFailedEvent;
+use OCP\SpeechToText\Events\TranscriptionSuccessfulEvent;
+use OCP\SpeechToText\ISpeechToTextManager;
+use Psr\Log\LoggerInterface;
+
+class TaskBackgroundJob extends QueuedJob {
+ public function __construct(
+ ITimeFactory $timeFactory,
+ private ILanguageModelManager $languageModelManager,
+ private IEventDispatcher $eventDispatcher,
+ ) {
+ parent::__construct($timeFactory);
+ $this->setAllowParallelRuns(false);
+ }
+
+ /**
+ * @param array{taskId: int} $argument
+ * @inheritDoc
+ */
+ protected function run($argument) {
+ $taskId = $argument['taskId'];
+ $task = $this->languageModelManager->getTask($taskId);
+ try {
+ $output = $this->languageModelManager->runTask($task);
+ $event = new TaskSuccessfulEvent($task, $output);
+
+ } catch (\RuntimeException|PreConditionNotMetException $e) {
+ $event = new TaskFailedEvent($task, $e->getMessage());
+ }
+ $this->eventDispatcher->dispatchTyped($event);
+ }
+}
diff --git a/lib/public/LanguageModel/AbstractLanguageModelTask.php b/lib/public/LanguageModel/AbstractLanguageModelTask.php
index 50ae2095235..12aedc95fe5 100644
--- a/lib/public/LanguageModel/AbstractLanguageModelTask.php
+++ b/lib/public/LanguageModel/AbstractLanguageModelTask.php
@@ -2,71 +2,91 @@
namespace OCP\LanguageModel;
-abstract class AbstractLanguageModelTask {
- public const STATUS_UNKNOWN = 0;
- public const STATUS_RUNNING = 1;
- public const STATUS_SUCCESSFUL = 2;
- public const STATUS_FAILED = 4;
+use OC\LanguageModel\Db\Task;
+abstract class AbstractLanguageModelTask implements ILanguageModelTask {
protected ?int $id;
- protected int $status = self::STATUS_UNKNOWN;
+ protected int $status = ILanguageModelTask::STATUS_UNKNOWN;
- public function __construct(
+ public final function __construct(
protected string $input,
protected string $appId,
protected ?string $userId,
) {
}
+ /**
+ * @param ILanguageModelProvider $provider
+ * @return string
+ * @throws \RuntimeException
+ */
abstract public function visitProvider(ILanguageModelProvider $provider): string;
+ abstract public function canUseProvider(ILanguageModelProvider $provider): bool;
+
+ abstract public function getType(): string;
+
/**
* @return int
*/
- public function getStatus(): int {
+ public final function getStatus(): int {
return $this->status;
}
/**
* @param int $status
*/
- public function setStatus(int $status): void {
+ public final function setStatus(int $status): void {
$this->status = $status;
}
/**
* @return int|null
*/
- public function getId(): ?int {
+ public final function getId(): ?int {
return $this->id;
}
/**
* @param int|null $id
*/
- public function setId(?int $id): void {
+ public final function setId(?int $id): void {
$this->id = $id;
}
/**
* @return string
*/
- public function getInput(): string {
+ public final function getInput(): string {
return $this->input;
}
/**
* @return string
*/
- public function getAppId(): string {
+ public final function getAppId(): string {
return $this->appId;
}
/**
* @return string|null
*/
- public function getUserId(): ?string {
+ public final function getUserId(): ?string {
return $this->userId;
}
+
+ public final static function fromTaskEntity(Task $taskEntity): ILanguageModelTask {
+ $task = self::factory($taskEntity->getType(), $taskEntity->getInput(), $taskEntity->getuserId(), $taskEntity->getAppId());
+ $task->setId($taskEntity->getId());
+ $task->setStatus($taskEntity->getStatus());
+ return $task;
+ }
+
+ public final static function factory(string $type, string $input, ?string $userId, string $appId): ILanguageModelTask {
+ if (!in_array($type, self::TYPES)) {
+ throw new \InvalidArgumentException('Unknown task type');
+ }
+ return new ILanguageModelTask::TYPES[$type]($input, $userId, $appId);
+ }
}
diff --git a/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php b/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php
index 3d274330dc7..218a4480081 100644
--- a/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php
+++ b/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php
@@ -26,7 +26,7 @@ declare(strict_types=1);
namespace OCP\LanguageModel\Events;
use OCP\EventDispatcher\Event;
-use OCP\LanguageModel\AbstractLanguageModelTask;
+use OCP\LanguageModel\ILanguageModelTask;
/**
* @since 28.0.0
@@ -36,16 +36,16 @@ abstract class AbstractLanguageModelEvent extends Event {
* @since 28.0.0
*/
public function __construct(
- private AbstractLanguageModelTask $task
+ private ILanguageModelTask $task
) {
parent::__construct();
}
/**
- * @return AbstractLanguageModelTask
+ * @return ILanguageModelTask
* @since 28.0.0
*/
- public function getTask(): AbstractLanguageModelTask {
+ public function getTask(): ILanguageModelTask {
return $this->task;
}
}
diff --git a/lib/public/LanguageModel/Events/TaskFailedEvent.php b/lib/public/LanguageModel/Events/TaskFailedEvent.php
index 2b0dea9153f..5134c37476a 100644
--- a/lib/public/LanguageModel/Events/TaskFailedEvent.php
+++ b/lib/public/LanguageModel/Events/TaskFailedEvent.php
@@ -2,14 +2,14 @@
namespace OCP\LanguageModel\Events;
-use OCP\LanguageModel\AbstractLanguageModelTask;
+use OCP\LanguageModel\ILanguageModelTask;
/**
* @since 28.0.0
*/
class TaskFailedEvent extends AbstractLanguageModelEvent {
- public function __construct(AbstractLanguageModelTask $task,
+ public function __construct(ILanguageModelTask $task,
private string $errorMessage) {
parent::__construct($task);
}
diff --git a/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php b/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php
index 6cdb57143f9..156c5679e0b 100644
--- a/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php
+++ b/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php
@@ -2,14 +2,14 @@
namespace OCP\LanguageModel\Events;
-use OCP\LanguageModel\AbstractLanguageModelTask;
+use OCP\LanguageModel\ILanguageModelTask;
/**
* @since 28.0.0
*/
class TaskSuccessfulEvent extends AbstractLanguageModelEvent {
- public function __construct(AbstractLanguageModelTask $task,
+ public function __construct(ILanguageModelTask $task,
private string $output) {
parent::__construct($task);
}
diff --git a/lib/public/LanguageModel/FreePromptTask.php b/lib/public/LanguageModel/FreePromptTask.php
index ff7fa7fffed..a179048631c 100644
--- a/lib/public/LanguageModel/FreePromptTask.php
+++ b/lib/public/LanguageModel/FreePromptTask.php
@@ -4,7 +4,8 @@ namespace OCP\LanguageModel;
use RuntimeException;
-class FreePromptTask extends AbstractLanguageModelTask {
+final class FreePromptTask extends AbstractLanguageModelTask {
+ public const TYPE = 'free_prompt';
/**
* @param ILanguageModelProvider $provider
@@ -12,14 +13,14 @@ class FreePromptTask extends AbstractLanguageModelTask {
* @return string
*/
public function visitProvider(ILanguageModelProvider $provider): string {
- $this->setStatus(self::STATUS_RUNNING);
- try {
- $output = $provider->prompt($this->getInput());
- } catch (RuntimeException $e) {
- $this->setStatus(self::STATUS_FAILED);
- throw $e;
- }
- $this->setStatus(self::STATUS_SUCCESSFUL);
- return $output;
+ return $provider->prompt($this->getInput());
+ }
+
+ public function canUseProvider(ILanguageModelProvider $provider): bool {
+ return true;
+ }
+
+ public function getType(): string {
+ return self::TYPE;
}
}
diff --git a/lib/public/LanguageModel/ILanguageModelManager.php b/lib/public/LanguageModel/ILanguageModelManager.php
index e0d33777052..a4d3079c180 100644
--- a/lib/public/LanguageModel/ILanguageModelManager.php
+++ b/lib/public/LanguageModel/ILanguageModelManager.php
@@ -27,6 +27,7 @@ declare(strict_types=1);
namespace OCP\LanguageModel;
use InvalidArgumentException;
+use OCP\LanguageModel\AbstractLanguageModelTask;
use OCP\LanguageModel\Events\AbstractLanguageModelEvent;
use OCP\PreConditionNotMetException;
use RuntimeException;
@@ -45,11 +46,10 @@ interface ILanguageModelManager {
/**
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called
- * @throws InvalidArgumentException If the file could not be found or is not of a supported type
- * @throws RuntimeException If the transcription failed for other reasons
+ * @throws RuntimeException If something else failed
* @since 28.0.0
*/
- public function runTask(AbstractLanguageModelTask $task): AbstractLanguageModelEvent;
+ public function runTask(ILanguageModelTask $task): string;
/**
* Will schedule an LLM inference process in the background. The result will become available
@@ -58,5 +58,7 @@ interface ILanguageModelManager {
* @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called
* @since 28.0.0
*/
- public function scheduleTask(AbstractLanguageModelTask $task) : void;
+ public function scheduleTask(ILanguageModelTask $task) : void;
+
+ public function getTask(int $id): ILanguageModelTask;
}
diff --git a/lib/public/LanguageModel/ILanguageModelTask.php b/lib/public/LanguageModel/ILanguageModelTask.php
new file mode 100644
index 00000000000..478ee54e8a3
--- /dev/null
+++ b/lib/public/LanguageModel/ILanguageModelTask.php
@@ -0,0 +1,56 @@
+<?php
+
+namespace OCP\LanguageModel;
+
+interface ILanguageModelTask {
+ public const STATUS_FAILED = 4;
+ public const STATUS_SUCCESSFUL = 3;
+ public const STATUS_RUNNING = 2;
+ public const STATUS_SCHEDULED = 1;
+ public const STATUS_UNKNOWN = 0;
+
+ public const TYPES = [
+ SummaryTask::TYPE => SummaryTask::class,
+ FreePromptTask::TYPE => FreePromptTask::class,
+ ];
+
+ /**
+ * @return string
+ */
+ public function getType(): string;
+
+ /**
+ * @return int
+ */
+ public function getStatus(): int;
+
+ /**
+ * @param int $status
+ */
+ public function setStatus(int $status): void;
+
+ /**
+ * @param int|null $id
+ */
+ public function setId(?int $id): void;
+
+ /**
+ * @return int|null
+ */
+ public function getId(): ?int;
+
+ /**
+ * @return string
+ */
+ public function getInput(): string;
+
+ /**
+ * @return string
+ */
+ public function getAppId(): string;
+
+ /**
+ * @return string|null
+ */
+ public function getUserId(): ?string;
+}
diff --git a/lib/public/LanguageModel/SummaryTask.php b/lib/public/LanguageModel/SummaryTask.php
index 0037beb4593..35f20cebfb6 100644
--- a/lib/public/LanguageModel/SummaryTask.php
+++ b/lib/public/LanguageModel/SummaryTask.php
@@ -4,7 +4,8 @@ namespace OCP\LanguageModel;
use RuntimeException;
-class SummaryTask extends AbstractLanguageModelTask {
+final class SummaryTask extends AbstractLanguageModelTask {
+ public const TYPE = 'summarize';
/**
* @param ILanguageModelProvider&ISummaryProvider $provider
@@ -15,14 +16,14 @@ class SummaryTask extends AbstractLanguageModelTask {
if (!$provider instanceof ISummaryProvider) {
throw new \RuntimeException('SummaryTask#visitProvider expects ISummaryProvider');
}
- $this->setStatus(self::STATUS_RUNNING);
- try {
- $output = $provider->summarize($this->getInput());
- } catch (RuntimeException $e) {
- $this->setStatus(self::STATUS_FAILED);
- throw $e;
- }
- $this->setStatus(self::STATUS_SUCCESSFUL);
- return $output;
+ return $provider->summarize($this->getInput());
+ }
+
+ public function canUseProvider(ILanguageModelProvider $provider): bool {
+ return $provider instanceof ISummaryProvider;
+ }
+
+ public function getType(): string {
+ return self::TYPE;
}
}