]> source.dussan.org Git - nextcloud-server.git/commitdiff
Make tests pass
authorMarcel Klehr <mklehr@gmx.net>
Thu, 29 Jun 2023 15:08:23 +0000 (17:08 +0200)
committerMarcel Klehr <mklehr@gmx.net>
Wed, 9 Aug 2023 08:01:21 +0000 (10:01 +0200)
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
(cherry picked from commit 66c0e6b9f79d1e82f55af2acd9d2b500c8128614)

lib/private/LanguageModel/Db/Task.php
lib/private/LanguageModel/LanguageModelManager.php
lib/private/LanguageModel/TaskBackgroundJob.php
lib/public/LanguageModel/AbstractLanguageModelTask.php

index 895969e08bb91bac6c4622437733248d0c6d0336..bbafd7583c8b62205471d493c460d7f3a3a80fa9 100644 (file)
@@ -12,6 +12,8 @@ use OCP\LanguageModel\ILanguageModelTask;
  * @method int getLastUpdated()
  * @method setInput(string $type)
  * @method string getInput()
+ * @method setOutput(string $type)
+ * @method string getOutput()
  * @method setStatus(int $type)
  * @method int getStatus()
  * @method setUserId(string $type)
@@ -21,9 +23,9 @@ use OCP\LanguageModel\ILanguageModelTask;
  */
 class Task extends Entity {
        protected $lastUpdated;
-
        protected $type;
        protected $input;
+       protected $output;
        protected $status;
        protected $userId;
        protected $appId;
@@ -45,13 +47,21 @@ class Task extends Entity {
                $this->addType('lastUpdated', 'integer');
                $this->addType('type', 'string');
                $this->addType('input', 'string');
+               $this->addType('output', 'string');
                $this->addType('status', 'integer');
                $this->addType('userId', 'string');
                $this->addType('appId', 'string');
        }
 
+       public function toRow(): array {
+               return array_combine(self::$columns, array_map(function ($field) {
+                       return $this->{'get'.ucfirst($field)}();
+               }, self::$fields));
+       }
+
        public static function fromLanguageModelTask(ILanguageModelTask $task): Task {
                return Task::fromParams([
+                       'id' => $task->getId(),
                        'type' => $task->getType(),
                        'lastUpdated' => time(),
                        'status' => $task->getStatus(),
index 4a29e8d8b18c6b83a2be618316e7d0236e499424..9117b131578ab04158110aef7ba1e88479b984fd 100644 (file)
@@ -13,11 +13,15 @@ use OCP\DB\Exception;
 use OCP\IServerContainer;
 use OCP\LanguageModel\AbstractLanguageModelTask;
 use OCP\LanguageModel\FreePromptTask;
+use OCP\LanguageModel\HeadlineTask;
+use OCP\LanguageModel\IHeadlineProvider;
 use OCP\LanguageModel\ILanguageModelManager;
 use OCP\LanguageModel\ILanguageModelProvider;
 use OCP\LanguageModel\ILanguageModelTask;
 use OCP\LanguageModel\ISummaryProvider;
+use OCP\LanguageModel\ITopicsProvider;
 use OCP\LanguageModel\SummaryTask;
+use OCP\LanguageModel\TopicsTask;
 use OCP\PreConditionNotMetException;
 use Psr\Container\ContainerExceptionInterface;
 use Psr\Container\NotFoundExceptionInterface;
@@ -69,7 +73,7 @@ class LanguageModelManager implements ILanguageModelManager {
                if ($context === null) {
                        return false;
                }
-               return count($context->getSpeechToTextProviders()) > 0;
+               return count($context->getLanguageModelProviders()) > 0;
        }
 
        /**
@@ -82,6 +86,12 @@ class LanguageModelManager implements ILanguageModelManager {
                        if ($provider instanceof ISummaryProvider) {
                                $tasks[SummaryTask::class] = true;
                        }
+                       if ($provider instanceof IHeadlineProvider) {
+                               $tasks[HeadlineTask::class] = true;
+                       }
+                       if ($provider instanceof ITopicsProvider) {
+                               $tasks[TopicsTask::class] = true;
+                       }
                }
                return array_keys($tasks);
        }
@@ -110,7 +120,8 @@ class LanguageModelManager implements ILanguageModelManager {
                        }
                        try {
                                $task->setStatus(ILanguageModelTask::STATUS_RUNNING);
-                               $this->taskMapper->update(Task::fromLanguageModelTask($task));
+                               $taskEntity = $this->taskMapper->update(Task::fromLanguageModelTask($task));
+                               $task->setId($taskEntity->getId());
                                $output = $task->visitProvider($provider);
                                $task->setOutput($output);
                                $task->setStatus(ILanguageModelTask::STATUS_SUCCESSFUL);
@@ -140,10 +151,10 @@ class LanguageModelManager implements ILanguageModelManager {
                if (!$this->canHandleTask($task)) {
                        throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
                }
+               $task->setStatus(ILanguageModelTask::STATUS_SCHEDULED);
                $taskEntity = Task::fromLanguageModelTask($task);
                $this->taskMapper->insert($taskEntity);
                $task->setId($taskEntity->getId());
-               $task->setStatus(ILanguageModelTask::STATUS_SCHEDULED);
                $this->jobList->add(TaskBackgroundJob::class, [
                        'taskId' => $task->getId()
                ]);
index 1c1291dc12699ef6dd8422fb1dc461725636943b..f58510a9e3f2cf1b304bc86ab000174c39843cd4 100644 (file)
@@ -54,7 +54,7 @@ class TaskBackgroundJob extends QueuedJob {
                try {
                        $this->languageModelManager->runTask($task);
                        $event = new TaskSuccessfulEvent($task);
-               } catch (\RuntimeException|PreConditionNotMetException $e) {
+               } catch (\RuntimeException|PreConditionNotMetException|\Throwable $e) {
                        $event = new TaskFailedEvent($task, $e->getMessage());
                }
                $this->eventDispatcher->dispatchTyped($event);
index 67c341f9b3e37dd2a2bc80073ac410708f26676b..a6b091dc14fa5b28a4cdb98a6788805caacf7e5a 100644 (file)
@@ -35,8 +35,8 @@ use OC\LanguageModel\Db\Task;
  * @template-implements ILanguageModelTask<T>
  */
 abstract class AbstractLanguageModelTask implements ILanguageModelTask {
-       protected ?int $id;
-       protected ?string $output;
+       protected ?int $id = null;
+       protected ?string $output = null;
        protected int $status = ILanguageModelTask::STATUS_UNKNOWN;
 
        /**
@@ -156,6 +156,7 @@ abstract class AbstractLanguageModelTask implements ILanguageModelTask {
                $task = self::factory($taskEntity->getType(), $taskEntity->getInput(), $taskEntity->getuserId(), $taskEntity->getAppId());
                $task->setId($taskEntity->getId());
                $task->setStatus($taskEntity->getStatus());
+               $task->setOutput($taskEntity->getOutput());
                return $task;
        }
 
@@ -169,9 +170,9 @@ abstract class AbstractLanguageModelTask implements ILanguageModelTask {
         * @since 28.0.0
         */
        final public static function factory(string $type, string $input, ?string $userId, string $appId): ILanguageModelTask {
-               if (!in_array($type, self::TYPES)) {
+               if (!in_array($type, array_keys(self::TYPES))) {
                        throw new \InvalidArgumentException('Unknown task type');
                }
-               return new (ILanguageModelTask::TYPES[$type])($input, $userId, $appId);
+               return new (ILanguageModelTask::TYPES[$type])($input, $appId, $userId);
        }
 }