summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorMarcel Klehr <mklehr@gmx.net>2023-07-14 15:59:50 +0200
committerMarcel Klehr <mklehr@gmx.net>2023-08-09 10:05:05 +0200
commitcf2c42ae36a3c7280887bd3f15329739f9a6d221 (patch)
tree3e682483cc8177afd7b9e682fcaa103791f78d2a /tests
parent696a45ddf1d460de7ffa6f252912375efd7e190e (diff)
downloadnextcloud-server-cf2c42ae36a3c7280887bd3f15329739f9a6d221.tar.gz
nextcloud-server-cf2c42ae36a3c7280887bd3f15329739f9a6d221.zip
Massive refactoring: Turn LanguageModel OCP API into TextProcessing API
Signed-off-by: Marcel Klehr <mklehr@gmx.net> (cherry picked from commit ffe27ce14ca74b509c8721c9fba7c759498fa471)
Diffstat (limited to 'tests')
-rw-r--r--tests/lib/TextProcessing/TextProcessingTest.php (renamed from tests/lib/LanguageModel/LanguageModelManagerTest.php)203
1 files changed, 99 insertions, 104 deletions
diff --git a/tests/lib/LanguageModel/LanguageModelManagerTest.php b/tests/lib/TextProcessing/TextProcessingTest.php
index 6f8d6cd868d..797571019ce 100644
--- a/tests/lib/LanguageModel/LanguageModelManagerTest.php
+++ b/tests/lib/TextProcessing/TextProcessingTest.php
@@ -6,93 +6,97 @@
* See the COPYING-README file.
*/
-namespace Test\LanguageModel;
+namespace Test\TextProcessing;
use OC\AppFramework\Bootstrap\Coordinator;
use OC\AppFramework\Bootstrap\RegistrationContext;
use OC\AppFramework\Bootstrap\ServiceRegistration;
use OC\EventDispatcher\EventDispatcher;
-use OC\LanguageModel\Db\Task;
-use OC\LanguageModel\Db\TaskMapper;
-use OC\LanguageModel\LanguageModelManager;
-use OC\LanguageModel\RemoveOldTasksBackgroundJob;
-use OC\LanguageModel\TaskBackgroundJob;
+use OC\TextProcessing\Db\Task as DbTask;
+use OC\TextProcessing\Db\TaskMapper;
+use OC\TextProcessing\Manager;
+use OC\TextProcessing\RemoveOldTasksBackgroundJob;
+use OC\TextProcessing\TaskBackgroundJob;
use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Utility\ITimeFactory;
use OCP\Common\Exception\NotFoundException;
use OCP\EventDispatcher\IEventDispatcher;
use OCP\IServerContainer;
-use OCP\LanguageModel\Events\TaskFailedEvent;
-use OCP\LanguageModel\Events\TaskSuccessfulEvent;
-use OCP\LanguageModel\FreePromptTask;
-use OCP\LanguageModel\HeadlineTask;
-use OCP\LanguageModel\IHeadlineProvider;
-use OCP\LanguageModel\ILanguageModelManager;
-use OCP\LanguageModel\ILanguageModelProvider;
-use OCP\LanguageModel\ILanguageModelTask;
-use OCP\LanguageModel\ISummaryProvider;
-use OCP\LanguageModel\SummaryTask;
-use OCP\LanguageModel\TopicsTask;
+use OCP\TextProcessing\Events\TaskFailedEvent;
+use OCP\TextProcessing\Events\TaskSuccessfulEvent;
+use OCP\TextProcessing\FreePromptTaskType;
+use OCP\TextProcessing\IManager;
+use OCP\TextProcessing\IProvider;
+use OCP\TextProcessing\SummaryTaskType;
use OCP\PreConditionNotMetException;
+use OCP\TextProcessing\Task;
+use OCP\TextProcessing\TopicsTaskType;
use PHPUnit\Framework\Constraint\IsInstanceOf;
use Psr\Log\LoggerInterface;
use Test\BackgroundJob\DummyJobList;
-class TestVanillaLanguageModelProvider implements ILanguageModelProvider {
+class SuccessfulSummaryProvider implements IProvider {
public bool $ran = false;
public function getName(): string {
return 'TEST Vanilla LLM Provider';
}
- public function prompt(string $prompt): string {
+ public function process(string $prompt): string {
$this->ran = true;
- return $prompt . ' Free Prompt';
+ return $prompt . ' Summarize';
+ }
+
+ public function getTaskType(): string {
+ return SummaryTaskType::class;
}
}
-class TestFailingLanguageModelProvider implements ILanguageModelProvider {
+class FailingSummaryProvider implements IProvider {
public bool $ran = false;
public function getName(): string {
return 'TEST Vanilla LLM Provider';
}
- public function prompt(string $prompt): string {
+ public function process(string $prompt): string {
$this->ran = true;
throw new \Exception('ERROR');
}
+
+ public function getTaskType(): string {
+ return SummaryTaskType::class;
+ }
}
-class TestAdvancedLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider {
+class FreePromptProvider implements IProvider {
+ public bool $ran = false;
+
public function getName(): string {
- return 'TEST Full LLM Provider';
+ return 'TEST Free Prompt Provider';
}
- public function prompt(string $prompt): string {
+ public function process(string $prompt): string {
+ $this->ran = true;
return $prompt . ' Free Prompt';
}
- public function findHeadline(string $text): string {
- return $text . ' Headline';
- }
-
- public function summarize(string $text): string {
- return $text. ' Summarize';
+ public function getTaskType(): string {
+ return FreePromptTaskType::class;
}
}
-class LanguageModelManagerTest extends \Test\TestCase {
- private ILanguageModelManager $languageModelManager;
+class TextProcessingTest extends \Test\TestCase {
+ private IManager $manager;
private Coordinator $coordinator;
protected function setUp(): void {
parent::setUp();
$this->providers = [
- TestVanillaLanguageModelProvider::class => new TestVanillaLanguageModelProvider(),
- TestAdvancedLanguageModelProvider::class => new TestAdvancedLanguageModelProvider(),
- TestFailingLanguageModelProvider::class => new TestFailingLanguageModelProvider(),
+ SuccessfulSummaryProvider::class => new SuccessfulSummaryProvider(),
+ FailingSummaryProvider::class => new FailingSummaryProvider(),
+ FreePromptProvider::class => new FreePromptProvider(),
];
$this->serverContainer = $this->createMock(IServerContainer::class);
@@ -117,7 +121,7 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->taskMapper
->expects($this->any())
->method('insert')
- ->willReturnCallback(function (Task $task) {
+ ->willReturnCallback(function (DbTask $task) {
$task->setId(count($this->tasksDb) ? max(array_keys($this->tasksDb)) : 1);
$task->setLastUpdated($this->currentTime->getTimestamp());
$this->tasksDb[$task->getId()] = $task->toRow();
@@ -126,7 +130,7 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->taskMapper
->expects($this->any())
->method('update')
- ->willReturnCallback(function (Task $task) {
+ ->willReturnCallback(function (DbTask $task) {
$task->setLastUpdated($this->currentTime->getTimestamp());
$this->tasksDb[$task->getId()] = $task->toRow();
return $task;
@@ -138,7 +142,7 @@ class LanguageModelManagerTest extends \Test\TestCase {
if (!isset($this->tasksDb[$id])) {
throw new DoesNotExistException('Could not find it');
}
- return Task::fromRow($this->tasksDb[$id]);
+ return DbTask::fromRow($this->tasksDb[$id]);
});
$this->taskMapper
->expects($this->any())
@@ -153,7 +157,7 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
});
- $this->languageModelManager = new LanguageModelManager(
+ $this->manager = new Manager(
$this->serverContainer,
$this->coordinator,
\OC::$server->get(LoggerInterface::class),
@@ -163,57 +167,54 @@ class LanguageModelManagerTest extends \Test\TestCase {
}
public function testShouldNotHaveAnyProviders() {
- $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([]);
- $this->assertCount(0, $this->languageModelManager->getAvailableTaskClasses());
- $this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes());
- $this->assertFalse($this->languageModelManager->hasProviders());
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]);
+ $this->assertCount(0, $this->manager->getAvailableTaskTypes());
+ $this->assertFalse($this->manager->hasProviders());
$this->expectException(PreConditionNotMetException::class);
- $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null));
+ $this->manager->runTask(new \OCP\TextProcessing\Task(FreePromptTaskType::class, 'Hello', 'test', null));
}
public function testProviderShouldBeRegisteredAndRun() {
- $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
- new ServiceRegistration('test', TestVanillaLanguageModelProvider::class)
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
+ new ServiceRegistration('test', SuccessfulSummaryProvider::class)
]);
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses());
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
- $this->assertTrue($this->languageModelManager->hasProviders());
- $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)));
+ $this->assertCount(1, $this->manager->getAvailableTaskTypes());
+ $this->assertTrue($this->manager->hasProviders());
+ $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));
// Summaries are not implemented by the vanilla provider, only free prompt
$this->expectException(PreConditionNotMetException::class);
- $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null));
+ $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null));
}
public function testProviderShouldBeRegisteredAndScheduled() {
// register provider
- $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
- new ServiceRegistration('test', TestVanillaLanguageModelProvider::class)
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
+ new ServiceRegistration('test', SuccessfulSummaryProvider::class)
]);
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses());
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
- $this->assertTrue($this->languageModelManager->hasProviders());
+ $this->assertCount(1, $this->manager->getAvailableTaskTypes());
+ $this->assertTrue($this->manager->hasProviders());
// create task object
- $task = new FreePromptTask('Hello', 'test', null);
+ $task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
$this->assertNull($task->getId());
$this->assertNull($task->getOutput());
// schedule works
- $this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus());
- $this->languageModelManager->scheduleTask($task);
+ $this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
+ $this->manager->scheduleTask($task);
// Task object is up-to-date
$this->assertNotNull($task->getId());
$this->assertNull($task->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus());
+ $this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
// Task object retrieved from db is up-to-date
- $task2 = $this->languageModelManager->getTask($task->getId());
+ $task2 = $this->manager->getTask($task->getId());
$this->assertEquals($task->getId(), $task2->getId());
$this->assertEquals('Hello', $task2->getInput());
$this->assertNull($task2->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());
+ $this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
@@ -221,79 +222,74 @@ class LanguageModelManagerTest extends \Test\TestCase {
// run background job
$bgJob = new TaskBackgroundJob(
\OC::$server->get(ITimeFactory::class),
- $this->languageModelManager,
+ $this->manager,
$this->eventDispatcher,
);
$bgJob->setArgument(['taskId' => $task->getId()]);
$bgJob->start($this->jobList);
- $provider = $this->providers[TestVanillaLanguageModelProvider::class];
+ $provider = $this->providers[SuccessfulSummaryProvider::class];
$this->assertTrue($provider->ran);
// Task object retrieved from db is up-to-date
- $task3 = $this->languageModelManager->getTask($task->getId());
+ $task3 = $this->manager->getTask($task->getId());
$this->assertEquals($task->getId(), $task3->getId());
$this->assertEquals('Hello', $task3->getInput());
- $this->assertEquals('Hello Free Prompt', $task3->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task3->getStatus());
+ $this->assertEquals('Hello Summarize', $task3->getOutput());
+ $this->assertEquals(Task::STATUS_SUCCESSFUL, $task3->getStatus());
}
public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() {
- $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
- new ServiceRegistration('test', TestVanillaLanguageModelProvider::class),
- new ServiceRegistration('test', TestAdvancedLanguageModelProvider::class),
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
+ new ServiceRegistration('test', SuccessfulSummaryProvider::class),
+ new ServiceRegistration('test', FreePromptProvider::class),
]);
- $this->assertCount(3, $this->languageModelManager->getAvailableTaskClasses());
- $this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes());
- $this->assertTrue($this->languageModelManager->hasProviders());
+ $this->assertCount(2, $this->manager->getAvailableTaskTypes());
+ $this->assertTrue($this->manager->hasProviders());
// Try free prompt again
- $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)));
-
- // Try headline task
- $this->assertEquals('Hello Headline', $this->languageModelManager->runTask(new HeadlineTask('Hello', 'test', null)));
+ $this->assertEquals('Hello Free Prompt', $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null)));
// Try summary task
- $this->assertEquals('Hello Summarize', $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null)));
+ $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null)));
// Topics are not implemented by both the vanilla provider and the full provider
$this->expectException(PreConditionNotMetException::class);
- $this->languageModelManager->runTask(new TopicsTask('Hello', 'test', null));
+ $this->manager->runTask(new Task(TopicsTaskType::class, 'Hello', 'test', null));
}
public function testNonexistentTask() {
$this->expectException(NotFoundException::class);
- $this->languageModelManager->getTask(98765432456);
+ $this->manager->getTask(98765432456);
}
public function testTaskFailure() {
// register provider
- $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
- new ServiceRegistration('test', TestFailingLanguageModelProvider::class),
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
+ new ServiceRegistration('test', FailingSummaryProvider::class),
]);
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses());
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
- $this->assertTrue($this->languageModelManager->hasProviders());
+ $this->assertCount(1, $this->manager->getAvailableTaskTypes());
+ $this->assertTrue($this->manager->hasProviders());
// create task object
- $task = new FreePromptTask('Hello', 'test', null);
+ $task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
$this->assertNull($task->getId());
$this->assertNull($task->getOutput());
// schedule works
- $this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus());
- $this->languageModelManager->scheduleTask($task);
+ $this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus());
+ $this->manager->scheduleTask($task);
// Task object is up-to-date
$this->assertNotNull($task->getId());
$this->assertNull($task->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus());
+ $this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus());
// Task object retrieved from db is up-to-date
- $task2 = $this->languageModelManager->getTask($task->getId());
+ $task2 = $this->manager->getTask($task->getId());
$this->assertEquals($task->getId(), $task2->getId());
$this->assertEquals('Hello', $task2->getInput());
$this->assertNull($task2->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());
+ $this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus());
$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
$this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
@@ -301,31 +297,30 @@ class LanguageModelManagerTest extends \Test\TestCase {
// run background job
$bgJob = new TaskBackgroundJob(
\OC::$server->get(ITimeFactory::class),
- $this->languageModelManager,
+ $this->manager,
$this->eventDispatcher,
);
$bgJob->setArgument(['taskId' => $task->getId()]);
$bgJob->start($this->jobList);
- $provider = $this->providers[TestFailingLanguageModelProvider::class];
+ $provider = $this->providers[FailingSummaryProvider::class];
$this->assertTrue($provider->ran);
// Task object retrieved from db is up-to-date
- $task3 = $this->languageModelManager->getTask($task->getId());
+ $task3 = $this->manager->getTask($task->getId());
$this->assertEquals($task->getId(), $task3->getId());
$this->assertEquals('Hello', $task3->getInput());
$this->assertNull($task3->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task3->getStatus());
+ $this->assertEquals(Task::STATUS_FAILED, $task3->getStatus());
}
public function testOldTasksShouldBeCleanedUp() {
- $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
- new ServiceRegistration('test', TestVanillaLanguageModelProvider::class)
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
+ new ServiceRegistration('test', SuccessfulSummaryProvider::class)
]);
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses());
- $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
- $this->assertTrue($this->languageModelManager->hasProviders());
- $task = new FreePromptTask('Hello', 'test', null);
- $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask($task));
+ $this->assertCount(1, $this->manager->getAvailableTaskTypes());
+ $this->assertTrue($this->manager->hasProviders());
+ $task = new Task(SummaryTaskType::class, 'Hello', 'test', null);
+ $this->assertEquals('Hello Summarize', $this->manager->runTask($task));
$this->currentTime = $this->currentTime->add(new \DateInterval('P1Y'));
// run background job
@@ -338,6 +333,6 @@ class LanguageModelManagerTest extends \Test\TestCase {
$bgJob->start($this->jobList);
$this->expectException(NotFoundException::class);
- $this->languageModelManager->getTask($task->getId());
+ $this->manager->getTask($task->getId());
}
}