summaryrefslogtreecommitdiffstats
path: root/lib
diff options
context:
space:
mode:
authorMarcel Klehr <mklehr@gmx.net>2023-06-29 17:08:23 +0200
committerMarcel Klehr <mklehr@gmx.net>2023-08-09 10:01:21 +0200
commite7179598c7193ef0437811bdcad6af61980b9c0f (patch)
tree5faae0133b02eaca3705be7f289d128b5f1ed61f /lib
parentd21f7bf1fb2b420e7f5049983398d67f36fe044a (diff)
downloadnextcloud-server-e7179598c7193ef0437811bdcad6af61980b9c0f.tar.gz
nextcloud-server-e7179598c7193ef0437811bdcad6af61980b9c0f.zip
Make tests pass
Signed-off-by: Marcel Klehr <mklehr@gmx.net> (cherry picked from commit 66c0e6b9f79d1e82f55af2acd9d2b500c8128614)
Diffstat (limited to 'lib')
-rw-r--r--lib/private/LanguageModel/Db/Task.php12
-rw-r--r--lib/private/LanguageModel/LanguageModelManager.php17
-rw-r--r--lib/private/LanguageModel/TaskBackgroundJob.php2
-rw-r--r--lib/public/LanguageModel/AbstractLanguageModelTask.php9
4 files changed, 31 insertions, 9 deletions
diff --git a/lib/private/LanguageModel/Db/Task.php b/lib/private/LanguageModel/Db/Task.php
index 895969e08bb..bbafd7583c8 100644
--- a/lib/private/LanguageModel/Db/Task.php
+++ b/lib/private/LanguageModel/Db/Task.php
@@ -12,6 +12,8 @@ use OCP\LanguageModel\ILanguageModelTask;
* @method int getLastUpdated()
* @method setInput(string $type)
* @method string getInput()
+ * @method setOutput(string $type)
+ * @method string getOutput()
* @method setStatus(int $type)
* @method int getStatus()
* @method setUserId(string $type)
@@ -21,9 +23,9 @@ use OCP\LanguageModel\ILanguageModelTask;
*/
class Task extends Entity {
protected $lastUpdated;
-
protected $type;
protected $input;
+ protected $output;
protected $status;
protected $userId;
protected $appId;
@@ -45,13 +47,21 @@ class Task extends Entity {
$this->addType('lastUpdated', 'integer');
$this->addType('type', 'string');
$this->addType('input', 'string');
+ $this->addType('output', 'string');
$this->addType('status', 'integer');
$this->addType('userId', 'string');
$this->addType('appId', 'string');
}
+ public function toRow(): array {
+ return array_combine(self::$columns, array_map(function ($field) {
+ return $this->{'get'.ucfirst($field)}();
+ }, self::$fields));
+ }
+
public static function fromLanguageModelTask(ILanguageModelTask $task): Task {
return Task::fromParams([
+ 'id' => $task->getId(),
'type' => $task->getType(),
'lastUpdated' => time(),
'status' => $task->getStatus(),
diff --git a/lib/private/LanguageModel/LanguageModelManager.php b/lib/private/LanguageModel/LanguageModelManager.php
index 4a29e8d8b18..9117b131578 100644
--- a/lib/private/LanguageModel/LanguageModelManager.php
+++ b/lib/private/LanguageModel/LanguageModelManager.php
@@ -13,11 +13,15 @@ use OCP\DB\Exception;
use OCP\IServerContainer;
use OCP\LanguageModel\AbstractLanguageModelTask;
use OCP\LanguageModel\FreePromptTask;
+use OCP\LanguageModel\HeadlineTask;
+use OCP\LanguageModel\IHeadlineProvider;
use OCP\LanguageModel\ILanguageModelManager;
use OCP\LanguageModel\ILanguageModelProvider;
use OCP\LanguageModel\ILanguageModelTask;
use OCP\LanguageModel\ISummaryProvider;
+use OCP\LanguageModel\ITopicsProvider;
use OCP\LanguageModel\SummaryTask;
+use OCP\LanguageModel\TopicsTask;
use OCP\PreConditionNotMetException;
use Psr\Container\ContainerExceptionInterface;
use Psr\Container\NotFoundExceptionInterface;
@@ -69,7 +73,7 @@ class LanguageModelManager implements ILanguageModelManager {
if ($context === null) {
return false;
}
- return count($context->getSpeechToTextProviders()) > 0;
+ return count($context->getLanguageModelProviders()) > 0;
}
/**
@@ -82,6 +86,12 @@ class LanguageModelManager implements ILanguageModelManager {
if ($provider instanceof ISummaryProvider) {
$tasks[SummaryTask::class] = true;
}
+ if ($provider instanceof IHeadlineProvider) {
+ $tasks[HeadlineTask::class] = true;
+ }
+ if ($provider instanceof ITopicsProvider) {
+ $tasks[TopicsTask::class] = true;
+ }
}
return array_keys($tasks);
}
@@ -110,7 +120,8 @@ class LanguageModelManager implements ILanguageModelManager {
}
try {
$task->setStatus(ILanguageModelTask::STATUS_RUNNING);
- $this->taskMapper->update(Task::fromLanguageModelTask($task));
+ $taskEntity = $this->taskMapper->update(Task::fromLanguageModelTask($task));
+ $task->setId($taskEntity->getId());
$output = $task->visitProvider($provider);
$task->setOutput($output);
$task->setStatus(ILanguageModelTask::STATUS_SUCCESSFUL);
@@ -140,10 +151,10 @@ class LanguageModelManager implements ILanguageModelManager {
if (!$this->canHandleTask($task)) {
throw new PreConditionNotMetException('No LanguageModel provider is installed that can handle this task');
}
+ $task->setStatus(ILanguageModelTask::STATUS_SCHEDULED);
$taskEntity = Task::fromLanguageModelTask($task);
$this->taskMapper->insert($taskEntity);
$task->setId($taskEntity->getId());
- $task->setStatus(ILanguageModelTask::STATUS_SCHEDULED);
$this->jobList->add(TaskBackgroundJob::class, [
'taskId' => $task->getId()
]);
diff --git a/lib/private/LanguageModel/TaskBackgroundJob.php b/lib/private/LanguageModel/TaskBackgroundJob.php
index 1c1291dc126..f58510a9e3f 100644
--- a/lib/private/LanguageModel/TaskBackgroundJob.php
+++ b/lib/private/LanguageModel/TaskBackgroundJob.php
@@ -54,7 +54,7 @@ class TaskBackgroundJob extends QueuedJob {
try {
$this->languageModelManager->runTask($task);
$event = new TaskSuccessfulEvent($task);
- } catch (\RuntimeException|PreConditionNotMetException $e) {
+ } catch (\RuntimeException|PreConditionNotMetException|\Throwable $e) {
$event = new TaskFailedEvent($task, $e->getMessage());
}
$this->eventDispatcher->dispatchTyped($event);
diff --git a/lib/public/LanguageModel/AbstractLanguageModelTask.php b/lib/public/LanguageModel/AbstractLanguageModelTask.php
index 67c341f9b3e..a6b091dc14f 100644
--- a/lib/public/LanguageModel/AbstractLanguageModelTask.php
+++ b/lib/public/LanguageModel/AbstractLanguageModelTask.php
@@ -35,8 +35,8 @@ use OC\LanguageModel\Db\Task;
* @template-implements ILanguageModelTask<T>
*/
abstract class AbstractLanguageModelTask implements ILanguageModelTask {
- protected ?int $id;
- protected ?string $output;
+ protected ?int $id = null;
+ protected ?string $output = null;
protected int $status = ILanguageModelTask::STATUS_UNKNOWN;
/**
@@ -156,6 +156,7 @@ abstract class AbstractLanguageModelTask implements ILanguageModelTask {
$task = self::factory($taskEntity->getType(), $taskEntity->getInput(), $taskEntity->getuserId(), $taskEntity->getAppId());
$task->setId($taskEntity->getId());
$task->setStatus($taskEntity->getStatus());
+ $task->setOutput($taskEntity->getOutput());
return $task;
}
@@ -169,9 +170,9 @@ abstract class AbstractLanguageModelTask implements ILanguageModelTask {
* @since 28.0.0
*/
final public static function factory(string $type, string $input, ?string $userId, string $appId): ILanguageModelTask {
- if (!in_array($type, self::TYPES)) {
+ if (!in_array($type, array_keys(self::TYPES))) {
throw new \InvalidArgumentException('Unknown task type');
}
- return new (ILanguageModelTask::TYPES[$type])($input, $userId, $appId);
+ return new (ILanguageModelTask::TYPES[$type])($input, $appId, $userId);
}
}