diff options
author | Marcel Klehr <mklehr@gmx.net> | 2024-05-03 14:15:03 +0200 |
---|---|---|
committer | Marcel Klehr <mklehr@gmx.net> | 2024-05-14 11:38:40 +0200 |
commit | bd5dfd0b5f5f4bbbc0046924a62dbd54ad5fa2c2 (patch) | |
tree | 6a5eac9d618c2c555273dd8ff186bd69f9d3f6ba /tests/lib/TaskProcessing | |
parent | eebeb82416c191e29245c1364922e3d4716b8ee1 (diff) | |
download | nextcloud-server-bd5dfd0b5f5f4bbbc0046924a62dbd54ad5fa2c2.tar.gz nextcloud-server-bd5dfd0b5f5f4bbbc0046924a62dbd54ad5fa2c2.zip |
test: Add more tests for legacy pass-through
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
Diffstat (limited to 'tests/lib/TaskProcessing')
-rw-r--r-- | tests/lib/TaskProcessing/TaskProcessingTest.php | 233 |
1 files changed, 231 insertions, 2 deletions
diff --git a/tests/lib/TaskProcessing/TaskProcessingTest.php b/tests/lib/TaskProcessing/TaskProcessingTest.php index b9d402bbde8..5be43314d3e 100644 --- a/tests/lib/TaskProcessing/TaskProcessingTest.php +++ b/tests/lib/TaskProcessing/TaskProcessingTest.php @@ -38,7 +38,10 @@ use OCP\TaskProcessing\ISynchronousProvider; use OCP\TaskProcessing\ITaskType; use OCP\TaskProcessing\ShapeDescriptor; use OCP\TaskProcessing\Task; +use OCP\TaskProcessing\TaskTypes\TextToImage; use OCP\TaskProcessing\TaskTypes\TextToText; +use OCP\TaskProcessing\TaskTypes\TextToTextSummary; +use OCP\TextProcessing\SummaryTaskType; use PHPUnit\Framework\Constraint\IsInstanceOf; use Psr\Log\LoggerInterface; use Test\BackgroundJob\DummyJobList; @@ -204,6 +207,85 @@ class BrokenSyncProvider implements IProvider, ISynchronousProvider { } } +class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function process(string $prompt): string { + $this->ran = true; + return $prompt . ' Summarize'; + } + + public function getTaskType(): string { + return SummaryTaskType::class; + } +} + +class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function process(string $prompt): string { + $this->ran = true; + throw new \Exception('ERROR'); + } + + public function getTaskType(): string { + return SummaryTaskType::class; + } +} + +class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider { + public bool $ran = false; + + public function getId(): string { + return 'test:successful'; + } + + public function getName(): string { + return 'TEST Provider'; + } + + public function generate(string $prompt, array $resources): void { + $this->ran = true; + foreach($resources as $resource) { + fwrite($resource, 'test'); + fclose($resource); + } + } + + public function getExpectedRuntime(): int { + return 1; + } +} + +class FailingTextToImageProvider implements \OCP\TextToImage\IProvider { + public bool $ran = false; + + public function getId(): string { + return 'test:failing'; + } + + public function getName(): string { + return 'TEST Provider'; + } + + public function generate(string $prompt, array $resources): void { + $this->ran = true; + throw new \RuntimeException('ERROR'); + } + + public function getExpectedRuntime(): int { + return 1; + } +} + /** * @group DB */ @@ -227,6 +309,10 @@ class TaskProcessingTest extends \Test\TestCase { BrokenSyncProvider::class => new BrokenSyncProvider(), AsyncProvider::class => new AsyncProvider(), AudioToImage::class => new AudioToImage(), + SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(), + FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(), + SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(), + FailingTextToImageProvider::class => new FailingTextToImageProvider(), ]; $this->serverContainer = $this->createMock(IServerContainer::class); @@ -257,6 +343,26 @@ class TaskProcessingTest extends \Test\TestCase { $this->eventDispatcher = $this->createMock(IEventDispatcher::class); + $textProcessingManager = new \OC\TextProcessing\Manager( + $this->serverContainer, + $this->coordinator, + \OC::$server->get(LoggerInterface::class), + $this->jobList, + \OC::$server->get(\OC\TextProcessing\Db\TaskMapper::class), + \OC::$server->get(IConfig::class), + ); + + $text2imageManager = new \OC\TextToImage\Manager( + $this->serverContainer, + $this->coordinator, + \OC::$server->get(LoggerInterface::class), + $this->jobList, + \OC::$server->get(\OC\TextToImage\Db\TaskMapper::class), + \OC::$server->get(IConfig::class), + \OC::$server->get(IAppDataFactory::class), + ); + + $this->manager = new Manager( $this->coordinator, $this->serverContainer, @@ -266,8 +372,8 @@ class TaskProcessingTest extends \Test\TestCase { $this->eventDispatcher, \OC::$server->get(IAppDataFactory::class), \OC::$server->get(IRootFolder::class), - \OC::$server->get(\OCP\TextProcessing\IManager::class), - \OC::$server->get(\OCP\TextToImage\IManager::class), + $textProcessingManager, + $text2imageManager, \OC::$server->get(ISpeechToTextManager::class), ); } @@ -507,4 +613,127 @@ class TaskProcessingTest extends \Test\TestCase { $this->expectException(NotFoundException::class); $this->manager->getTask($task->getId()); } + + public function testShouldTransparentlyHandleTextProcessingProviders() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); + self::assertIsArray($task->getOutput()); + self::assertTrue(isset($task->getOutput()['output'])); + self::assertEquals('Hello Summarize', $task->getOutput()['output']); + self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran); + } + + public function testShouldTransparentlyHandleFailingTextProcessingProviders() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); + self::assertTrue($task->getOutput() === null); + self::assertEquals('ERROR', $task->getErrorMessage()); + self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran); + } + + public function testShouldTransparentlyHandleText2ImageProviders() { + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulTextToImageProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToImage::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); + self::assertIsArray($task->getOutput()); + self::assertTrue(isset($task->getOutput()['images'])); + self::assertIsArray($task->getOutput()['images']); + self::assertCount(3, $task->getOutput()['images']); + self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran); + } + + public function testShouldTransparentlyHandleFailingText2ImageProviders() { + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ + new ServiceRegistration('test', FailingTextToImageProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToImage::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); + self::assertTrue($task->getOutput() === null); + self::assertEquals('ERROR', $task->getErrorMessage()); + self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran); + } } |