diff options
-rw-r--r-- | lib/private/LanguageModel/Db/Task.php | 58 | ||||
-rw-r--r-- | lib/private/LanguageModel/Db/TaskMapper.php | 34 | ||||
-rw-r--r-- | lib/private/LanguageModel/LanguageModelManager.php | 162 | ||||
-rw-r--r-- | lib/private/LanguageModel/TaskBackgroundJob.php | 72 | ||||
-rw-r--r-- | lib/public/LanguageModel/AbstractLanguageModelTask.php | 48 | ||||
-rw-r--r-- | lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php | 8 | ||||
-rw-r--r-- | lib/public/LanguageModel/Events/TaskFailedEvent.php | 4 | ||||
-rw-r--r-- | lib/public/LanguageModel/Events/TaskSuccessfulEvent.php | 4 | ||||
-rw-r--r-- | lib/public/LanguageModel/FreePromptTask.php | 21 | ||||
-rw-r--r-- | lib/public/LanguageModel/ILanguageModelManager.php | 10 | ||||
-rw-r--r-- | lib/public/LanguageModel/ILanguageModelTask.php | 56 | ||||
-rw-r--r-- | lib/public/LanguageModel/SummaryTask.php | 21 |
12 files changed, 452 insertions, 46 deletions
diff --git a/lib/private/LanguageModel/Db/Task.php b/lib/private/LanguageModel/Db/Task.php new file mode 100644 index 00000000000..cee6c2fd8b9 --- /dev/null +++ b/lib/private/LanguageModel/Db/Task.php @@ -0,0 +1,58 @@ +<?php + +namespace OC\LanguageModel\Db; + +use OCP\AppFramework\Db\Entity; +use OCP\LanguageModel\ILanguageModelTask; + +/** + * @method setType(string $type) + * @method string getType() + * @method setInput(string $type) + * @method string getInput() + * @method setStatus(int $type) + * @method int getStatus() + * @method setUserId(string $type) + * @method string getuserId() + * @method setAppId(string $type) + * @method string getAppId() + */ +class Task extends Entity { + + protected $type; + protected $input; + protected $status; + protected $userId; + protected $appId; + + /** + * @var string[] + */ + public static array $columns = ['id', 'type', 'input', 'status', 'user_id', 'app_id']; + + /** + * @var string[] + */ + public static array $fields = ['id', 'type', 'input', 'status', 'userId', 'appId']; + + + public function __construct() { + // add types in constructor + $this->addType('id', 'integer'); + $this->addType('type', 'string'); + $this->addType('input', 'string'); + $this->addType('status', 'integer'); + $this->addType('userId', 'string'); + $this->addType('appId', 'string'); + } + + public static function fromLanguageModelTask(ILanguageModelTask $task): Task { + return Task::fromParams([ + 'type' => $task->getType(), + 'status' => ILanguageModelTask::STATUS_UNKNOWN, + 'input' => $task->getInput(), + 'userId' => $task->getUserId(), + 'appId' => $task->getAppId(), + ]); + } +} diff --git a/lib/private/LanguageModel/Db/TaskMapper.php b/lib/private/LanguageModel/Db/TaskMapper.php new file mode 100644 index 00000000000..0b9004c4d96 --- /dev/null +++ b/lib/private/LanguageModel/Db/TaskMapper.php @@ -0,0 +1,34 @@ +<?php + +namespace OC\LanguageModel\Db; + +use OCP\AppFramework\Db\DoesNotExistException; +use OCP\AppFramework\Db\MultipleObjectsReturnedException; +use OCP\AppFramework\Db\QBMapper; +use OCP\DB\Exception; +use OCP\IDBConnection; + +/** + * @extends QBMapper<Task> + */ +class TaskMapper extends QBMapper { + + public function __construct(IDBConnection $db) { + parent::__construct($db, 'oc_llm_tasks', Task::class); + } + + /** + * @param int $id + * @return Task + * @throws Exception + * @throws DoesNotExistException + * @throws MultipleObjectsReturnedException + */ + public function find(int $id): Task { + $qb = $this->db->getQueryBuilder(); + $qb->select(Task::$columns) + ->from($this->tableName) + ->where($qb->expr()->eq('id', $qb->createPositionalParameter($id))); + return $this->findEntity($qb); + } +} diff --git a/lib/private/LanguageModel/LanguageModelManager.php b/lib/private/LanguageModel/LanguageModelManager.php new file mode 100644 index 00000000000..f9f13b15d6e --- /dev/null +++ b/lib/private/LanguageModel/LanguageModelManager.php @@ -0,0 +1,162 @@ +<?php + +namespace OC\LanguageModel; + +use OC\AppFramework\Bootstrap\Coordinator; +use OC\LanguageModel\Db\Task; +use OC\LanguageModel\Db\TaskMapper; +use OCP\LanguageModel\AbstractLanguageModelTask; +use OCP\LanguageModel\FreePromptTask; +use OCP\LanguageModel\SummaryTask; +use OCP\AppFramework\Db\DoesNotExistException; +use OCP\AppFramework\Db\MultipleObjectsReturnedException; +use OCP\BackgroundJob\IJobList; +use OCP\DB\Exception; +use OCP\IServerContainer; +use OCP\LanguageModel\ILanguageModelManager; +use OCP\LanguageModel\ILanguageModelProvider; +use OCP\LanguageModel\ILanguageModelTask; +use OCP\LanguageModel\ISummaryProvider; +use OCP\PreConditionNotMetException; +use Psr\Container\ContainerExceptionInterface; +use Psr\Container\NotFoundExceptionInterface; +use Psr\Log\LoggerInterface; +use RuntimeException; +use Throwable; + +class LanguageModelManager implements ILanguageModelManager { + + /** @var ?ILanguageModelProvider[] */ + private ?array $providers = null; + + public function __construct( + private IServerContainer $serverContainer, + private Coordinator $coordinator, + private LoggerInterface $logger, + private IJobList $jobList, + private TaskMapper $taskMapper, + ) { + } + + public function getProviders(): array { + $context = $this->coordinator->getRegistrationContext(); + if ($context === null) { + return []; + } + + if ($this->providers !== null) { + return $this->providers; + } + + $this->providers = []; + + foreach ($context->getSpeechToTextProviders() as $providerServiceRegistration) { + $class = $providerServiceRegistration->getService(); + try { + $this->providers[$class] = $this->serverContainer->get($class); + } catch (NotFoundExceptionInterface|ContainerExceptionInterface|Throwable $e) { + $this->logger->error('Failed to load LanguageModel provider ' . $class, [ + 'exception' => $e, + ]); + } + } + + return $this->providers; + } + + public function hasProviders(): bool { + $context = $this->coordinator->getRegistrationContext(); + if ($context === null) { + return false; + } + return !empty($context->getSpeechToTextProviders()); + } + + /** + * @inheritDoc + */ + public function getAvailableTasks(): array { + $tasks = []; + foreach ($this->getProviders() as $provider) { + $tasks[FreePromptTask::class] = true; + if ($provider instanceof ISummaryProvider) { + $tasks[SummaryTask::class] = true; + } + } + return array_keys($tasks); + } + + public function canHandleTask(ILanguageModelTask $task): bool { + return !empty(array_filter($this->getAvailableTasks(), fn ($class) => $task instanceof $class)); + } + + /** + * @inheritDoc + */ + public function runTask(ILanguageModelTask $task): string { + if (!$this->canHandleTask($task)) { + throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); + } + foreach ($this->getProviders() as $provider) { + if (!$task->canUseProvider($provider)) { + continue; + } + try { + $task->setStatus(ILanguageModelTask::STATUS_RUNNING); + $this->taskMapper->update(Task::fromLanguageModelTask($task)); + $output = $task->visitProvider($provider); + $task->setStatus(ILanguageModelTask::STATUS_SUCCESSFUL); + $this->taskMapper->update(Task::fromLanguageModelTask($task)); + return $output; + } catch (\RuntimeException $e) { + $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]); + $task->setStatus(ILanguageModelTask::STATUS_FAILED); + $this->taskMapper->update(Task::fromLanguageModelTask($task)); + throw $e; + } catch (\Throwable $e) { + $this->logger->info('LanguageModel call using provider ' . $provider->getName() . ' failed', ['exception' => $e]); + $task->setStatus(ILanguageModelTask::STATUS_FAILED); + $this->taskMapper->update(Task::fromLanguageModelTask($task)); + throw new RuntimeException('LanguageModel call using provider ' . $provider->getName() . ' failed: ' . $e->getMessage()); + } + } + + throw new RuntimeException('Could not transcribe file'); + } + + /** + * @inheritDoc + * @throws Exception + */ + public function scheduleTask(ILanguageModelTask $task): void { + if (!$this->canHandleTask($task)) { + throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task'); + } + $taskEntity = Task::fromLanguageModelTask($task); + $this->taskMapper->insert($taskEntity); + $task->setId($taskEntity->getId()); + $task->setStatus(ILanguageModelTask::STATUS_SCHEDULED); + $this->jobList->add(TaskBackgroundJob::class, [ + 'taskId' => $task->getId() + ]); + } + + /** + * @param int $id The id of the task + * @return ILanguageModelTask + * @throws RuntimeException If the query failed + * @throws \ValueError If the task could not be found + */ + public function getTask(int $id): ILanguageModelTask { + try { + $taskEntity = $this->taskMapper->find($id); + return AbstractLanguageModelTask::fromTaskEntity($taskEntity); + } catch (DoesNotExistException $e) { + throw new \ValueError('Could not find task with the provided id'); + } catch (MultipleObjectsReturnedException $e) { + throw new RuntimeException('Could not uniquely identify task with given id'); + } catch (Exception $e) { + throw new RuntimeException('Failure while trying to find task by id: '.$e->getMessage()); + } + } +} diff --git a/lib/private/LanguageModel/TaskBackgroundJob.php b/lib/private/LanguageModel/TaskBackgroundJob.php new file mode 100644 index 00000000000..55413ba3714 --- /dev/null +++ b/lib/private/LanguageModel/TaskBackgroundJob.php @@ -0,0 +1,72 @@ +<?php + +declare(strict_types=1); + +/** + * @copyright Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net> + * + * @author Marcel Klehr <mklehr@gmx.net> + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + + +namespace OC\LanguageModel; + +use OC\User\NoUserException; +use OCP\AppFramework\Utility\ITimeFactory; +use OCP\BackgroundJob\QueuedJob; +use OCP\EventDispatcher\IEventDispatcher; +use OCP\Files\File; +use OCP\Files\IRootFolder; +use OCP\Files\NotFoundException; +use OCP\Files\NotPermittedException; +use OCP\LanguageModel\Events\TaskFailedEvent; +use OCP\LanguageModel\Events\TaskSuccessfulEvent; +use OCP\LanguageModel\ILanguageModelManager; +use OCP\PreConditionNotMetException; +use OCP\SpeechToText\Events\TranscriptionFailedEvent; +use OCP\SpeechToText\Events\TranscriptionSuccessfulEvent; +use OCP\SpeechToText\ISpeechToTextManager; +use Psr\Log\LoggerInterface; + +class TaskBackgroundJob extends QueuedJob { + public function __construct( + ITimeFactory $timeFactory, + private ILanguageModelManager $languageModelManager, + private IEventDispatcher $eventDispatcher, + ) { + parent::__construct($timeFactory); + $this->setAllowParallelRuns(false); + } + + /** + * @param array{taskId: int} $argument + * @inheritDoc + */ + protected function run($argument) { + $taskId = $argument['taskId']; + $task = $this->languageModelManager->getTask($taskId); + try { + $output = $this->languageModelManager->runTask($task); + $event = new TaskSuccessfulEvent($task, $output); + + } catch (\RuntimeException|PreConditionNotMetException $e) { + $event = new TaskFailedEvent($task, $e->getMessage()); + } + $this->eventDispatcher->dispatchTyped($event); + } +} diff --git a/lib/public/LanguageModel/AbstractLanguageModelTask.php b/lib/public/LanguageModel/AbstractLanguageModelTask.php index 50ae2095235..12aedc95fe5 100644 --- a/lib/public/LanguageModel/AbstractLanguageModelTask.php +++ b/lib/public/LanguageModel/AbstractLanguageModelTask.php @@ -2,71 +2,91 @@ namespace OCP\LanguageModel; -abstract class AbstractLanguageModelTask { - public const STATUS_UNKNOWN = 0; - public const STATUS_RUNNING = 1; - public const STATUS_SUCCESSFUL = 2; - public const STATUS_FAILED = 4; +use OC\LanguageModel\Db\Task; +abstract class AbstractLanguageModelTask implements ILanguageModelTask { protected ?int $id; - protected int $status = self::STATUS_UNKNOWN; + protected int $status = ILanguageModelTask::STATUS_UNKNOWN; - public function __construct( + public final function __construct( protected string $input, protected string $appId, protected ?string $userId, ) { } + /** + * @param ILanguageModelProvider $provider + * @return string + * @throws \RuntimeException + */ abstract public function visitProvider(ILanguageModelProvider $provider): string; + abstract public function canUseProvider(ILanguageModelProvider $provider): bool; + + abstract public function getType(): string; + /** * @return int */ - public function getStatus(): int { + public final function getStatus(): int { return $this->status; } /** * @param int $status */ - public function setStatus(int $status): void { + public final function setStatus(int $status): void { $this->status = $status; } /** * @return int|null */ - public function getId(): ?int { + public final function getId(): ?int { return $this->id; } /** * @param int|null $id */ - public function setId(?int $id): void { + public final function setId(?int $id): void { $this->id = $id; } /** * @return string */ - public function getInput(): string { + public final function getInput(): string { return $this->input; } /** * @return string */ - public function getAppId(): string { + public final function getAppId(): string { return $this->appId; } /** * @return string|null */ - public function getUserId(): ?string { + public final function getUserId(): ?string { return $this->userId; } + + public final static function fromTaskEntity(Task $taskEntity): ILanguageModelTask { + $task = self::factory($taskEntity->getType(), $taskEntity->getInput(), $taskEntity->getuserId(), $taskEntity->getAppId()); + $task->setId($taskEntity->getId()); + $task->setStatus($taskEntity->getStatus()); + return $task; + } + + public final static function factory(string $type, string $input, ?string $userId, string $appId): ILanguageModelTask { + if (!in_array($type, self::TYPES)) { + throw new \InvalidArgumentException('Unknown task type'); + } + return new ILanguageModelTask::TYPES[$type]($input, $userId, $appId); + } } diff --git a/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php b/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php index 3d274330dc7..218a4480081 100644 --- a/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php +++ b/lib/public/LanguageModel/Events/AbstractLanguageModelEvent.php @@ -26,7 +26,7 @@ declare(strict_types=1); namespace OCP\LanguageModel\Events; use OCP\EventDispatcher\Event; -use OCP\LanguageModel\AbstractLanguageModelTask; +use OCP\LanguageModel\ILanguageModelTask; /** * @since 28.0.0 @@ -36,16 +36,16 @@ abstract class AbstractLanguageModelEvent extends Event { * @since 28.0.0 */ public function __construct( - private AbstractLanguageModelTask $task + private ILanguageModelTask $task ) { parent::__construct(); } /** - * @return AbstractLanguageModelTask + * @return ILanguageModelTask * @since 28.0.0 */ - public function getTask(): AbstractLanguageModelTask { + public function getTask(): ILanguageModelTask { return $this->task; } } diff --git a/lib/public/LanguageModel/Events/TaskFailedEvent.php b/lib/public/LanguageModel/Events/TaskFailedEvent.php index 2b0dea9153f..5134c37476a 100644 --- a/lib/public/LanguageModel/Events/TaskFailedEvent.php +++ b/lib/public/LanguageModel/Events/TaskFailedEvent.php @@ -2,14 +2,14 @@ namespace OCP\LanguageModel\Events; -use OCP\LanguageModel\AbstractLanguageModelTask; +use OCP\LanguageModel\ILanguageModelTask; /** * @since 28.0.0 */ class TaskFailedEvent extends AbstractLanguageModelEvent { - public function __construct(AbstractLanguageModelTask $task, + public function __construct(ILanguageModelTask $task, private string $errorMessage) { parent::__construct($task); } diff --git a/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php b/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php index 6cdb57143f9..156c5679e0b 100644 --- a/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php +++ b/lib/public/LanguageModel/Events/TaskSuccessfulEvent.php @@ -2,14 +2,14 @@ namespace OCP\LanguageModel\Events; -use OCP\LanguageModel\AbstractLanguageModelTask; +use OCP\LanguageModel\ILanguageModelTask; /** * @since 28.0.0 */ class TaskSuccessfulEvent extends AbstractLanguageModelEvent { - public function __construct(AbstractLanguageModelTask $task, + public function __construct(ILanguageModelTask $task, private string $output) { parent::__construct($task); } diff --git a/lib/public/LanguageModel/FreePromptTask.php b/lib/public/LanguageModel/FreePromptTask.php index ff7fa7fffed..a179048631c 100644 --- a/lib/public/LanguageModel/FreePromptTask.php +++ b/lib/public/LanguageModel/FreePromptTask.php @@ -4,7 +4,8 @@ namespace OCP\LanguageModel; use RuntimeException; -class FreePromptTask extends AbstractLanguageModelTask { +final class FreePromptTask extends AbstractLanguageModelTask { + public const TYPE = 'free_prompt'; /** * @param ILanguageModelProvider $provider @@ -12,14 +13,14 @@ class FreePromptTask extends AbstractLanguageModelTask { * @return string */ public function visitProvider(ILanguageModelProvider $provider): string { - $this->setStatus(self::STATUS_RUNNING); - try { - $output = $provider->prompt($this->getInput()); - } catch (RuntimeException $e) { - $this->setStatus(self::STATUS_FAILED); - throw $e; - } - $this->setStatus(self::STATUS_SUCCESSFUL); - return $output; + return $provider->prompt($this->getInput()); + } + + public function canUseProvider(ILanguageModelProvider $provider): bool { + return true; + } + + public function getType(): string { + return self::TYPE; } } diff --git a/lib/public/LanguageModel/ILanguageModelManager.php b/lib/public/LanguageModel/ILanguageModelManager.php index e0d33777052..a4d3079c180 100644 --- a/lib/public/LanguageModel/ILanguageModelManager.php +++ b/lib/public/LanguageModel/ILanguageModelManager.php @@ -27,6 +27,7 @@ declare(strict_types=1); namespace OCP\LanguageModel; use InvalidArgumentException; +use OCP\LanguageModel\AbstractLanguageModelTask; use OCP\LanguageModel\Events\AbstractLanguageModelEvent; use OCP\PreConditionNotMetException; use RuntimeException; @@ -45,11 +46,10 @@ interface ILanguageModelManager { /** * @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called - * @throws InvalidArgumentException If the file could not be found or is not of a supported type - * @throws RuntimeException If the transcription failed for other reasons + * @throws RuntimeException If something else failed * @since 28.0.0 */ - public function runTask(AbstractLanguageModelTask $task): AbstractLanguageModelEvent; + public function runTask(ILanguageModelTask $task): string; /** * Will schedule an LLM inference process in the background. The result will become available @@ -58,5 +58,7 @@ interface ILanguageModelManager { * @throws PreConditionNotMetException If no or not the requested provider was registered but this method was still called * @since 28.0.0 */ - public function scheduleTask(AbstractLanguageModelTask $task) : void; + public function scheduleTask(ILanguageModelTask $task) : void; + + public function getTask(int $id): ILanguageModelTask; } diff --git a/lib/public/LanguageModel/ILanguageModelTask.php b/lib/public/LanguageModel/ILanguageModelTask.php new file mode 100644 index 00000000000..478ee54e8a3 --- /dev/null +++ b/lib/public/LanguageModel/ILanguageModelTask.php @@ -0,0 +1,56 @@ +<?php + +namespace OCP\LanguageModel; + +interface ILanguageModelTask { + public const STATUS_FAILED = 4; + public const STATUS_SUCCESSFUL = 3; + public const STATUS_RUNNING = 2; + public const STATUS_SCHEDULED = 1; + public const STATUS_UNKNOWN = 0; + + public const TYPES = [ + SummaryTask::TYPE => SummaryTask::class, + FreePromptTask::TYPE => FreePromptTask::class, + ]; + + /** + * @return string + */ + public function getType(): string; + + /** + * @return int + */ + public function getStatus(): int; + + /** + * @param int $status + */ + public function setStatus(int $status): void; + + /** + * @param int|null $id + */ + public function setId(?int $id): void; + + /** + * @return int|null + */ + public function getId(): ?int; + + /** + * @return string + */ + public function getInput(): string; + + /** + * @return string + */ + public function getAppId(): string; + + /** + * @return string|null + */ + public function getUserId(): ?string; +} diff --git a/lib/public/LanguageModel/SummaryTask.php b/lib/public/LanguageModel/SummaryTask.php index 0037beb4593..35f20cebfb6 100644 --- a/lib/public/LanguageModel/SummaryTask.php +++ b/lib/public/LanguageModel/SummaryTask.php @@ -4,7 +4,8 @@ namespace OCP\LanguageModel; use RuntimeException; -class SummaryTask extends AbstractLanguageModelTask { +final class SummaryTask extends AbstractLanguageModelTask { + public const TYPE = 'summarize'; /** * @param ILanguageModelProvider&ISummaryProvider $provider @@ -15,14 +16,14 @@ class SummaryTask extends AbstractLanguageModelTask { if (!$provider instanceof ISummaryProvider) { throw new \RuntimeException('SummaryTask#visitProvider expects ISummaryProvider'); } - $this->setStatus(self::STATUS_RUNNING); - try { - $output = $provider->summarize($this->getInput()); - } catch (RuntimeException $e) { - $this->setStatus(self::STATUS_FAILED); - throw $e; - } - $this->setStatus(self::STATUS_SUCCESSFUL); - return $output; + return $provider->summarize($this->getInput()); + } + + public function canUseProvider(ILanguageModelProvider $provider): bool { + return $provider instanceof ISummaryProvider; + } + + public function getType(): string { + return self::TYPE; } } |