diff options
author | Marcel Klehr <mklehr@gmx.net> | 2023-07-14 15:59:50 +0200 |
---|---|---|
committer | Marcel Klehr <mklehr@gmx.net> | 2023-08-09 10:05:05 +0200 |
commit | cf2c42ae36a3c7280887bd3f15329739f9a6d221 (patch) | |
tree | 3e682483cc8177afd7b9e682fcaa103791f78d2a /tests/lib | |
parent | 696a45ddf1d460de7ffa6f252912375efd7e190e (diff) | |
download | nextcloud-server-cf2c42ae36a3c7280887bd3f15329739f9a6d221.tar.gz nextcloud-server-cf2c42ae36a3c7280887bd3f15329739f9a6d221.zip |
Massive refactoring: Turn LanguageModel OCP API into TextProcessing API
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
(cherry picked from commit ffe27ce14ca74b509c8721c9fba7c759498fa471)
Diffstat (limited to 'tests/lib')
-rw-r--r-- | tests/lib/TextProcessing/TextProcessingTest.php (renamed from tests/lib/LanguageModel/LanguageModelManagerTest.php) | 203 |
1 files changed, 99 insertions, 104 deletions
diff --git a/tests/lib/LanguageModel/LanguageModelManagerTest.php b/tests/lib/TextProcessing/TextProcessingTest.php index 6f8d6cd868d..797571019ce 100644 --- a/tests/lib/LanguageModel/LanguageModelManagerTest.php +++ b/tests/lib/TextProcessing/TextProcessingTest.php @@ -6,93 +6,97 @@ * See the COPYING-README file. */ -namespace Test\LanguageModel; +namespace Test\TextProcessing; use OC\AppFramework\Bootstrap\Coordinator; use OC\AppFramework\Bootstrap\RegistrationContext; use OC\AppFramework\Bootstrap\ServiceRegistration; use OC\EventDispatcher\EventDispatcher; -use OC\LanguageModel\Db\Task; -use OC\LanguageModel\Db\TaskMapper; -use OC\LanguageModel\LanguageModelManager; -use OC\LanguageModel\RemoveOldTasksBackgroundJob; -use OC\LanguageModel\TaskBackgroundJob; +use OC\TextProcessing\Db\Task as DbTask; +use OC\TextProcessing\Db\TaskMapper; +use OC\TextProcessing\Manager; +use OC\TextProcessing\RemoveOldTasksBackgroundJob; +use OC\TextProcessing\TaskBackgroundJob; use OCP\AppFramework\Db\DoesNotExistException; use OCP\AppFramework\Utility\ITimeFactory; use OCP\Common\Exception\NotFoundException; use OCP\EventDispatcher\IEventDispatcher; use OCP\IServerContainer; -use OCP\LanguageModel\Events\TaskFailedEvent; -use OCP\LanguageModel\Events\TaskSuccessfulEvent; -use OCP\LanguageModel\FreePromptTask; -use OCP\LanguageModel\HeadlineTask; -use OCP\LanguageModel\IHeadlineProvider; -use OCP\LanguageModel\ILanguageModelManager; -use OCP\LanguageModel\ILanguageModelProvider; -use OCP\LanguageModel\ILanguageModelTask; -use OCP\LanguageModel\ISummaryProvider; -use OCP\LanguageModel\SummaryTask; -use OCP\LanguageModel\TopicsTask; +use OCP\TextProcessing\Events\TaskFailedEvent; +use OCP\TextProcessing\Events\TaskSuccessfulEvent; +use OCP\TextProcessing\FreePromptTaskType; +use OCP\TextProcessing\IManager; +use OCP\TextProcessing\IProvider; +use OCP\TextProcessing\SummaryTaskType; use OCP\PreConditionNotMetException; +use OCP\TextProcessing\Task; +use OCP\TextProcessing\TopicsTaskType; use PHPUnit\Framework\Constraint\IsInstanceOf; use Psr\Log\LoggerInterface; use Test\BackgroundJob\DummyJobList; -class TestVanillaLanguageModelProvider implements ILanguageModelProvider { +class SuccessfulSummaryProvider implements IProvider { public bool $ran = false; public function getName(): string { return 'TEST Vanilla LLM Provider'; } - public function prompt(string $prompt): string { + public function process(string $prompt): string { $this->ran = true; - return $prompt . ' Free Prompt'; + return $prompt . ' Summarize'; + } + + public function getTaskType(): string { + return SummaryTaskType::class; } } -class TestFailingLanguageModelProvider implements ILanguageModelProvider { +class FailingSummaryProvider implements IProvider { public bool $ran = false; public function getName(): string { return 'TEST Vanilla LLM Provider'; } - public function prompt(string $prompt): string { + public function process(string $prompt): string { $this->ran = true; throw new \Exception('ERROR'); } + + public function getTaskType(): string { + return SummaryTaskType::class; + } } -class TestAdvancedLanguageModelProvider implements ILanguageModelProvider, ISummaryProvider, IHeadlineProvider { +class FreePromptProvider implements IProvider { + public bool $ran = false; + public function getName(): string { - return 'TEST Full LLM Provider'; + return 'TEST Free Prompt Provider'; } - public function prompt(string $prompt): string { + public function process(string $prompt): string { + $this->ran = true; return $prompt . ' Free Prompt'; } - public function findHeadline(string $text): string { - return $text . ' Headline'; - } - - public function summarize(string $text): string { - return $text. ' Summarize'; + public function getTaskType(): string { + return FreePromptTaskType::class; } } -class LanguageModelManagerTest extends \Test\TestCase { - private ILanguageModelManager $languageModelManager; +class TextProcessingTest extends \Test\TestCase { + private IManager $manager; private Coordinator $coordinator; protected function setUp(): void { parent::setUp(); $this->providers = [ - TestVanillaLanguageModelProvider::class => new TestVanillaLanguageModelProvider(), - TestAdvancedLanguageModelProvider::class => new TestAdvancedLanguageModelProvider(), - TestFailingLanguageModelProvider::class => new TestFailingLanguageModelProvider(), + SuccessfulSummaryProvider::class => new SuccessfulSummaryProvider(), + FailingSummaryProvider::class => new FailingSummaryProvider(), + FreePromptProvider::class => new FreePromptProvider(), ]; $this->serverContainer = $this->createMock(IServerContainer::class); @@ -117,7 +121,7 @@ class LanguageModelManagerTest extends \Test\TestCase { $this->taskMapper ->expects($this->any()) ->method('insert') - ->willReturnCallback(function (Task $task) { + ->willReturnCallback(function (DbTask $task) { $task->setId(count($this->tasksDb) ? max(array_keys($this->tasksDb)) : 1); $task->setLastUpdated($this->currentTime->getTimestamp()); $this->tasksDb[$task->getId()] = $task->toRow(); @@ -126,7 +130,7 @@ class LanguageModelManagerTest extends \Test\TestCase { $this->taskMapper ->expects($this->any()) ->method('update') - ->willReturnCallback(function (Task $task) { + ->willReturnCallback(function (DbTask $task) { $task->setLastUpdated($this->currentTime->getTimestamp()); $this->tasksDb[$task->getId()] = $task->toRow(); return $task; @@ -138,7 +142,7 @@ class LanguageModelManagerTest extends \Test\TestCase { if (!isset($this->tasksDb[$id])) { throw new DoesNotExistException('Could not find it'); } - return Task::fromRow($this->tasksDb[$id]); + return DbTask::fromRow($this->tasksDb[$id]); }); $this->taskMapper ->expects($this->any()) @@ -153,7 +157,7 @@ class LanguageModelManagerTest extends \Test\TestCase { $this->jobList->expects($this->any())->method('add')->willReturnCallback(function () { }); - $this->languageModelManager = new LanguageModelManager( + $this->manager = new Manager( $this->serverContainer, $this->coordinator, \OC::$server->get(LoggerInterface::class), @@ -163,57 +167,54 @@ class LanguageModelManagerTest extends \Test\TestCase { } public function testShouldNotHaveAnyProviders() { - $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([]); - $this->assertCount(0, $this->languageModelManager->getAvailableTaskClasses()); - $this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes()); - $this->assertFalse($this->languageModelManager->hasProviders()); + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); + $this->assertCount(0, $this->manager->getAvailableTaskTypes()); + $this->assertFalse($this->manager->hasProviders()); $this->expectException(PreConditionNotMetException::class); - $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null)); + $this->manager->runTask(new \OCP\TextProcessing\Task(FreePromptTaskType::class, 'Hello', 'test', null)); } public function testProviderShouldBeRegisteredAndRun() { - $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ - new ServiceRegistration('test', TestVanillaLanguageModelProvider::class) + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class) ]); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses()); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); - $this->assertTrue($this->languageModelManager->hasProviders()); - $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); + $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null))); // Summaries are not implemented by the vanilla provider, only free prompt $this->expectException(PreConditionNotMetException::class); - $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null)); + $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null)); } public function testProviderShouldBeRegisteredAndScheduled() { // register provider - $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ - new ServiceRegistration('test', TestVanillaLanguageModelProvider::class) + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class) ]); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses()); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); - $this->assertTrue($this->languageModelManager->hasProviders()); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); // create task object - $task = new FreePromptTask('Hello', 'test', null); + $task = new Task(SummaryTaskType::class, 'Hello', 'test', null); $this->assertNull($task->getId()); $this->assertNull($task->getOutput()); // schedule works - $this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); - $this->languageModelManager->scheduleTask($task); + $this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); + $this->manager->scheduleTask($task); // Task object is up-to-date $this->assertNotNull($task->getId()); $this->assertNull($task->getOutput()); - $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); // Task object retrieved from db is up-to-date - $task2 = $this->languageModelManager->getTask($task->getId()); + $task2 = $this->manager->getTask($task->getId()); $this->assertEquals($task->getId(), $task2->getId()); $this->assertEquals('Hello', $task2->getInput()); $this->assertNull($task2->getOutput()); - $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); $this->eventDispatcher = $this->createMock(IEventDispatcher::class); $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); @@ -221,79 +222,74 @@ class LanguageModelManagerTest extends \Test\TestCase { // run background job $bgJob = new TaskBackgroundJob( \OC::$server->get(ITimeFactory::class), - $this->languageModelManager, + $this->manager, $this->eventDispatcher, ); $bgJob->setArgument(['taskId' => $task->getId()]); $bgJob->start($this->jobList); - $provider = $this->providers[TestVanillaLanguageModelProvider::class]; + $provider = $this->providers[SuccessfulSummaryProvider::class]; $this->assertTrue($provider->ran); // Task object retrieved from db is up-to-date - $task3 = $this->languageModelManager->getTask($task->getId()); + $task3 = $this->manager->getTask($task->getId()); $this->assertEquals($task->getId(), $task3->getId()); $this->assertEquals('Hello', $task3->getInput()); - $this->assertEquals('Hello Free Prompt', $task3->getOutput()); - $this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task3->getStatus()); + $this->assertEquals('Hello Summarize', $task3->getOutput()); + $this->assertEquals(Task::STATUS_SUCCESSFUL, $task3->getStatus()); } public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() { - $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ - new ServiceRegistration('test', TestVanillaLanguageModelProvider::class), - new ServiceRegistration('test', TestAdvancedLanguageModelProvider::class), + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class), + new ServiceRegistration('test', FreePromptProvider::class), ]); - $this->assertCount(3, $this->languageModelManager->getAvailableTaskClasses()); - $this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes()); - $this->assertTrue($this->languageModelManager->hasProviders()); + $this->assertCount(2, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); // Try free prompt again - $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask(new FreePromptTask('Hello', 'test', null))); - - // Try headline task - $this->assertEquals('Hello Headline', $this->languageModelManager->runTask(new HeadlineTask('Hello', 'test', null))); + $this->assertEquals('Hello Free Prompt', $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null))); // Try summary task - $this->assertEquals('Hello Summarize', $this->languageModelManager->runTask(new SummaryTask('Hello', 'test', null))); + $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null))); // Topics are not implemented by both the vanilla provider and the full provider $this->expectException(PreConditionNotMetException::class); - $this->languageModelManager->runTask(new TopicsTask('Hello', 'test', null)); + $this->manager->runTask(new Task(TopicsTaskType::class, 'Hello', 'test', null)); } public function testNonexistentTask() { $this->expectException(NotFoundException::class); - $this->languageModelManager->getTask(98765432456); + $this->manager->getTask(98765432456); } public function testTaskFailure() { // register provider - $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ - new ServiceRegistration('test', TestFailingLanguageModelProvider::class), + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', FailingSummaryProvider::class), ]); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses()); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); - $this->assertTrue($this->languageModelManager->hasProviders()); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); // create task object - $task = new FreePromptTask('Hello', 'test', null); + $task = new Task(SummaryTaskType::class, 'Hello', 'test', null); $this->assertNull($task->getId()); $this->assertNull($task->getOutput()); // schedule works - $this->assertEquals(ILanguageModelTask::STATUS_UNKNOWN, $task->getStatus()); - $this->languageModelManager->scheduleTask($task); + $this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); + $this->manager->scheduleTask($task); // Task object is up-to-date $this->assertNotNull($task->getId()); $this->assertNull($task->getOutput()); - $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task->getStatus()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); // Task object retrieved from db is up-to-date - $task2 = $this->languageModelManager->getTask($task->getId()); + $task2 = $this->manager->getTask($task->getId()); $this->assertEquals($task->getId(), $task2->getId()); $this->assertEquals('Hello', $task2->getInput()); $this->assertNull($task2->getOutput()); - $this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); $this->eventDispatcher = $this->createMock(IEventDispatcher::class); $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); @@ -301,31 +297,30 @@ class LanguageModelManagerTest extends \Test\TestCase { // run background job $bgJob = new TaskBackgroundJob( \OC::$server->get(ITimeFactory::class), - $this->languageModelManager, + $this->manager, $this->eventDispatcher, ); $bgJob->setArgument(['taskId' => $task->getId()]); $bgJob->start($this->jobList); - $provider = $this->providers[TestFailingLanguageModelProvider::class]; + $provider = $this->providers[FailingSummaryProvider::class]; $this->assertTrue($provider->ran); // Task object retrieved from db is up-to-date - $task3 = $this->languageModelManager->getTask($task->getId()); + $task3 = $this->manager->getTask($task->getId()); $this->assertEquals($task->getId(), $task3->getId()); $this->assertEquals('Hello', $task3->getInput()); $this->assertNull($task3->getOutput()); - $this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task3->getStatus()); + $this->assertEquals(Task::STATUS_FAILED, $task3->getStatus()); } public function testOldTasksShouldBeCleanedUp() { - $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([ - new ServiceRegistration('test', TestVanillaLanguageModelProvider::class) + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class) ]); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskClasses()); - $this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes()); - $this->assertTrue($this->languageModelManager->hasProviders()); - $task = new FreePromptTask('Hello', 'test', null); - $this->assertEquals('Hello Free Prompt', $this->languageModelManager->runTask($task)); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); + $task = new Task(SummaryTaskType::class, 'Hello', 'test', null); + $this->assertEquals('Hello Summarize', $this->manager->runTask($task)); $this->currentTime = $this->currentTime->add(new \DateInterval('P1Y')); // run background job @@ -338,6 +333,6 @@ class LanguageModelManagerTest extends \Test\TestCase { $bgJob->start($this->jobList); $this->expectException(NotFoundException::class); - $this->languageModelManager->getTask($task->getId()); + $this->manager->getTask($task->getId()); } } |