diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/lib/BackgroundJob/DummyJobList.php | 2 | ||||
-rw-r--r-- | tests/lib/TextProcessing/TextProcessingTest.php | 338 |
2 files changed, 339 insertions, 1 deletions
diff --git a/tests/lib/BackgroundJob/DummyJobList.php b/tests/lib/BackgroundJob/DummyJobList.php index 42b69cfbe41..8574f462ca7 100644 --- a/tests/lib/BackgroundJob/DummyJobList.php +++ b/tests/lib/BackgroundJob/DummyJobList.php @@ -141,7 +141,7 @@ class DummyJobList extends \OC\BackgroundJob\JobList { } public function hasReservedJob(?string $className = null): bool { - return $this->reserved[$className ?? '']; + return isset($this->reserved[$className ?? '']) && $this->reserved[$className ?? '']; } public function setHasReservedJob(?string $className, bool $hasReserved): void { diff --git a/tests/lib/TextProcessing/TextProcessingTest.php b/tests/lib/TextProcessing/TextProcessingTest.php new file mode 100644 index 00000000000..797571019ce --- /dev/null +++ b/tests/lib/TextProcessing/TextProcessingTest.php @@ -0,0 +1,338 @@ +<?php +/** + * Copyright (c) 2023 Marcel Klehr <mklehr@gmx.net> + * This file is licensed under the Affero General Public License version 3 or + * later. + * See the COPYING-README file. + */ + +namespace Test\TextProcessing; + +use OC\AppFramework\Bootstrap\Coordinator; +use OC\AppFramework\Bootstrap\RegistrationContext; +use OC\AppFramework\Bootstrap\ServiceRegistration; +use OC\EventDispatcher\EventDispatcher; +use OC\TextProcessing\Db\Task as DbTask; +use OC\TextProcessing\Db\TaskMapper; +use OC\TextProcessing\Manager; +use OC\TextProcessing\RemoveOldTasksBackgroundJob; +use OC\TextProcessing\TaskBackgroundJob; +use OCP\AppFramework\Db\DoesNotExistException; +use OCP\AppFramework\Utility\ITimeFactory; +use OCP\Common\Exception\NotFoundException; +use OCP\EventDispatcher\IEventDispatcher; +use OCP\IServerContainer; +use OCP\TextProcessing\Events\TaskFailedEvent; +use OCP\TextProcessing\Events\TaskSuccessfulEvent; +use OCP\TextProcessing\FreePromptTaskType; +use OCP\TextProcessing\IManager; +use OCP\TextProcessing\IProvider; +use OCP\TextProcessing\SummaryTaskType; +use OCP\PreConditionNotMetException; +use OCP\TextProcessing\Task; +use OCP\TextProcessing\TopicsTaskType; +use PHPUnit\Framework\Constraint\IsInstanceOf; +use Psr\Log\LoggerInterface; +use Test\BackgroundJob\DummyJobList; + +class SuccessfulSummaryProvider implements IProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function process(string $prompt): string { + $this->ran = true; + return $prompt . ' Summarize'; + } + + public function getTaskType(): string { + return SummaryTaskType::class; + } +} + +class FailingSummaryProvider implements IProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function process(string $prompt): string { + $this->ran = true; + throw new \Exception('ERROR'); + } + + public function getTaskType(): string { + return SummaryTaskType::class; + } +} + +class FreePromptProvider implements IProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Free Prompt Provider'; + } + + public function process(string $prompt): string { + $this->ran = true; + return $prompt . ' Free Prompt'; + } + + public function getTaskType(): string { + return FreePromptTaskType::class; + } +} + +class TextProcessingTest extends \Test\TestCase { + private IManager $manager; + private Coordinator $coordinator; + + protected function setUp(): void { + parent::setUp(); + + $this->providers = [ + SuccessfulSummaryProvider::class => new SuccessfulSummaryProvider(), + FailingSummaryProvider::class => new FailingSummaryProvider(), + FreePromptProvider::class => new FreePromptProvider(), + ]; + + $this->serverContainer = $this->createMock(IServerContainer::class); + $this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) { + return $this->providers[$class]; + }); + + $this->eventDispatcher = new EventDispatcher( + new \Symfony\Component\EventDispatcher\EventDispatcher(), + $this->serverContainer, + \OC::$server->get(LoggerInterface::class), + ); + + $this->registrationContext = $this->createMock(RegistrationContext::class); + $this->coordinator = $this->createMock(Coordinator::class); + $this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext); + + $this->currentTime = new \DateTimeImmutable('now'); + + $this->taskMapper = $this->createMock(TaskMapper::class); + $this->tasksDb = []; + $this->taskMapper + ->expects($this->any()) + ->method('insert') + ->willReturnCallback(function (DbTask $task) { + $task->setId(count($this->tasksDb) ? max(array_keys($this->tasksDb)) : 1); + $task->setLastUpdated($this->currentTime->getTimestamp()); + $this->tasksDb[$task->getId()] = $task->toRow(); + return $task; + }); + $this->taskMapper + ->expects($this->any()) + ->method('update') + ->willReturnCallback(function (DbTask $task) { + $task->setLastUpdated($this->currentTime->getTimestamp()); + $this->tasksDb[$task->getId()] = $task->toRow(); + return $task; + }); + $this->taskMapper + ->expects($this->any()) + ->method('find') + ->willReturnCallback(function (int $id) { + if (!isset($this->tasksDb[$id])) { + throw new DoesNotExistException('Could not find it'); + } + return DbTask::fromRow($this->tasksDb[$id]); + }); + $this->taskMapper + ->expects($this->any()) + ->method('deleteOlderThan') + ->willReturnCallback(function (int $timeout) { + $this->tasksDb = array_filter($this->tasksDb, function (array $task) use ($timeout) { + return $task['last_updated'] >= $this->currentTime->getTimestamp() - $timeout; + }); + }); + + $this->jobList = $this->createPartialMock(DummyJobList::class, ['add']); + $this->jobList->expects($this->any())->method('add')->willReturnCallback(function () { + }); + + $this->manager = new Manager( + $this->serverContainer, + $this->coordinator, + \OC::$server->get(LoggerInterface::class), + $this->jobList, + $this->taskMapper, + ); + } + + public function testShouldNotHaveAnyProviders() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); + $this->assertCount(0, $this->manager->getAvailableTaskTypes()); + $this->assertFalse($this->manager->hasProviders()); + $this->expectException(PreConditionNotMetException::class); + $this->manager->runTask(new \OCP\TextProcessing\Task(FreePromptTaskType::class, 'Hello', 'test', null)); + } + + public function testProviderShouldBeRegisteredAndRun() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class) + ]); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); + $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null))); + + // Summaries are not implemented by the vanilla provider, only free prompt + $this->expectException(PreConditionNotMetException::class); + $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null)); + } + + public function testProviderShouldBeRegisteredAndScheduled() { + // register provider + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class) + ]); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); + + // create task object + $task = new Task(SummaryTaskType::class, 'Hello', 'test', null); + $this->assertNull($task->getId()); + $this->assertNull($task->getOutput()); + + // schedule works + $this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); + $this->manager->scheduleTask($task); + + // Task object is up-to-date + $this->assertNotNull($task->getId()); + $this->assertNull($task->getOutput()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); + + // Task object retrieved from db is up-to-date + $task2 = $this->manager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task2->getId()); + $this->assertEquals('Hello', $task2->getInput()); + $this->assertNull($task2->getOutput()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); + + $this->eventDispatcher = $this->createMock(IEventDispatcher::class); + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); + + // run background job + $bgJob = new TaskBackgroundJob( + \OC::$server->get(ITimeFactory::class), + $this->manager, + $this->eventDispatcher, + ); + $bgJob->setArgument(['taskId' => $task->getId()]); + $bgJob->start($this->jobList); + $provider = $this->providers[SuccessfulSummaryProvider::class]; + $this->assertTrue($provider->ran); + + // Task object retrieved from db is up-to-date + $task3 = $this->manager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task3->getId()); + $this->assertEquals('Hello', $task3->getInput()); + $this->assertEquals('Hello Summarize', $task3->getOutput()); + $this->assertEquals(Task::STATUS_SUCCESSFUL, $task3->getStatus()); + } + + public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class), + new ServiceRegistration('test', FreePromptProvider::class), + ]); + $this->assertCount(2, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); + + // Try free prompt again + $this->assertEquals('Hello Free Prompt', $this->manager->runTask(new Task(FreePromptTaskType::class, 'Hello', 'test', null))); + + // Try summary task + $this->assertEquals('Hello Summarize', $this->manager->runTask(new Task(SummaryTaskType::class, 'Hello', 'test', null))); + + // Topics are not implemented by both the vanilla provider and the full provider + $this->expectException(PreConditionNotMetException::class); + $this->manager->runTask(new Task(TopicsTaskType::class, 'Hello', 'test', null)); + } + + public function testNonexistentTask() { + $this->expectException(NotFoundException::class); + $this->manager->getTask(98765432456); + } + + public function testTaskFailure() { + // register provider + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', FailingSummaryProvider::class), + ]); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); + + // create task object + $task = new Task(SummaryTaskType::class, 'Hello', 'test', null); + $this->assertNull($task->getId()); + $this->assertNull($task->getOutput()); + + // schedule works + $this->assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); + $this->manager->scheduleTask($task); + + // Task object is up-to-date + $this->assertNotNull($task->getId()); + $this->assertNull($task->getOutput()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); + + // Task object retrieved from db is up-to-date + $task2 = $this->manager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task2->getId()); + $this->assertEquals('Hello', $task2->getInput()); + $this->assertNull($task2->getOutput()); + $this->assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); + + $this->eventDispatcher = $this->createMock(IEventDispatcher::class); + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); + + // run background job + $bgJob = new TaskBackgroundJob( + \OC::$server->get(ITimeFactory::class), + $this->manager, + $this->eventDispatcher, + ); + $bgJob->setArgument(['taskId' => $task->getId()]); + $bgJob->start($this->jobList); + $provider = $this->providers[FailingSummaryProvider::class]; + $this->assertTrue($provider->ran); + + // Task object retrieved from db is up-to-date + $task3 = $this->manager->getTask($task->getId()); + $this->assertEquals($task->getId(), $task3->getId()); + $this->assertEquals('Hello', $task3->getInput()); + $this->assertNull($task3->getOutput()); + $this->assertEquals(Task::STATUS_FAILED, $task3->getStatus()); + } + + public function testOldTasksShouldBeCleanedUp() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulSummaryProvider::class) + ]); + $this->assertCount(1, $this->manager->getAvailableTaskTypes()); + $this->assertTrue($this->manager->hasProviders()); + $task = new Task(SummaryTaskType::class, 'Hello', 'test', null); + $this->assertEquals('Hello Summarize', $this->manager->runTask($task)); + + $this->currentTime = $this->currentTime->add(new \DateInterval('P1Y')); + // run background job + $bgJob = new RemoveOldTasksBackgroundJob( + \OC::$server->get(ITimeFactory::class), + $this->taskMapper, + \OC::$server->get(LoggerInterface::class), + ); + $bgJob->setArgument([]); + $bgJob->start($this->jobList); + + $this->expectException(NotFoundException::class); + $this->manager->getTask($task->getId()); + } +} |