diff options
-rw-r--r-- | lib/private/TaskProcessing/Manager.php | 5 | ||||
-rw-r--r-- | tests/lib/TaskProcessing/TaskProcessingTest.php | 233 |
2 files changed, 234 insertions, 4 deletions
diff --git a/lib/private/TaskProcessing/Manager.php b/lib/private/TaskProcessing/Manager.php index 818d0d1ce83..9a75217910d 100644 --- a/lib/private/TaskProcessing/Manager.php +++ b/lib/private/TaskProcessing/Manager.php @@ -38,6 +38,7 @@ use OCP\Files\GenericFileException; use OCP\Files\IAppData; use OCP\Files\IRootFolder; use OCP\Files\NotPermittedException; +use OCP\Files\SimpleFS\ISimpleFile; use OCP\IL10N; use OCP\IServerContainer; use OCP\L10N\IFactory; @@ -265,7 +266,7 @@ class Manager implements IManager { $resources = []; $files = []; for ($i = 0; $i < $input['numberOfImages']; $i++) { - $file = $folder->newFile( time() . '-' . rand(1, 100000) . '-' . $i); + $file = $folder->newFile(time() . '-' . rand(1, 100000) . '-' . $i); $files[] = $file; $resource = $file->write(); if ($resource !== false && $resource !== true && is_resource($resource)) { @@ -282,7 +283,7 @@ class Manager implements IManager { } catch (\RuntimeException $e) { throw new ProcessingException($e->getMessage(), 0, $e); } - return ['images' => array_map(fn (File $file) => $file->getContent(), $files)]; + return ['images' => array_map(fn (ISimpleFile $file) => $file->getContent(), $files)]; } }; $newProviders[$newProvider->getId()] = $newProvider; diff --git a/tests/lib/TaskProcessing/TaskProcessingTest.php b/tests/lib/TaskProcessing/TaskProcessingTest.php index b9d402bbde8..5be43314d3e 100644 --- a/tests/lib/TaskProcessing/TaskProcessingTest.php +++ b/tests/lib/TaskProcessing/TaskProcessingTest.php @@ -38,7 +38,10 @@ use OCP\TaskProcessing\ISynchronousProvider; use OCP\TaskProcessing\ITaskType; use OCP\TaskProcessing\ShapeDescriptor; use OCP\TaskProcessing\Task; +use OCP\TaskProcessing\TaskTypes\TextToImage; use OCP\TaskProcessing\TaskTypes\TextToText; +use OCP\TaskProcessing\TaskTypes\TextToTextSummary; +use OCP\TextProcessing\SummaryTaskType; use PHPUnit\Framework\Constraint\IsInstanceOf; use Psr\Log\LoggerInterface; use Test\BackgroundJob\DummyJobList; @@ -204,6 +207,85 @@ class BrokenSyncProvider implements IProvider, ISynchronousProvider { } } +class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function process(string $prompt): string { + $this->ran = true; + return $prompt . ' Summarize'; + } + + public function getTaskType(): string { + return SummaryTaskType::class; + } +} + +class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { + public bool $ran = false; + + public function getName(): string { + return 'TEST Vanilla LLM Provider'; + } + + public function process(string $prompt): string { + $this->ran = true; + throw new \Exception('ERROR'); + } + + public function getTaskType(): string { + return SummaryTaskType::class; + } +} + +class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider { + public bool $ran = false; + + public function getId(): string { + return 'test:successful'; + } + + public function getName(): string { + return 'TEST Provider'; + } + + public function generate(string $prompt, array $resources): void { + $this->ran = true; + foreach($resources as $resource) { + fwrite($resource, 'test'); + fclose($resource); + } + } + + public function getExpectedRuntime(): int { + return 1; + } +} + +class FailingTextToImageProvider implements \OCP\TextToImage\IProvider { + public bool $ran = false; + + public function getId(): string { + return 'test:failing'; + } + + public function getName(): string { + return 'TEST Provider'; + } + + public function generate(string $prompt, array $resources): void { + $this->ran = true; + throw new \RuntimeException('ERROR'); + } + + public function getExpectedRuntime(): int { + return 1; + } +} + /** * @group DB */ @@ -227,6 +309,10 @@ class TaskProcessingTest extends \Test\TestCase { BrokenSyncProvider::class => new BrokenSyncProvider(), AsyncProvider::class => new AsyncProvider(), AudioToImage::class => new AudioToImage(), + SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(), + FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(), + SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(), + FailingTextToImageProvider::class => new FailingTextToImageProvider(), ]; $this->serverContainer = $this->createMock(IServerContainer::class); @@ -257,6 +343,26 @@ class TaskProcessingTest extends \Test\TestCase { $this->eventDispatcher = $this->createMock(IEventDispatcher::class); + $textProcessingManager = new \OC\TextProcessing\Manager( + $this->serverContainer, + $this->coordinator, + \OC::$server->get(LoggerInterface::class), + $this->jobList, + \OC::$server->get(\OC\TextProcessing\Db\TaskMapper::class), + \OC::$server->get(IConfig::class), + ); + + $text2imageManager = new \OC\TextToImage\Manager( + $this->serverContainer, + $this->coordinator, + \OC::$server->get(LoggerInterface::class), + $this->jobList, + \OC::$server->get(\OC\TextToImage\Db\TaskMapper::class), + \OC::$server->get(IConfig::class), + \OC::$server->get(IAppDataFactory::class), + ); + + $this->manager = new Manager( $this->coordinator, $this->serverContainer, @@ -266,8 +372,8 @@ class TaskProcessingTest extends \Test\TestCase { $this->eventDispatcher, \OC::$server->get(IAppDataFactory::class), \OC::$server->get(IRootFolder::class), - \OC::$server->get(\OCP\TextProcessing\IManager::class), - \OC::$server->get(\OCP\TextToImage\IManager::class), + $textProcessingManager, + $text2imageManager, \OC::$server->get(ISpeechToTextManager::class), ); } @@ -507,4 +613,127 @@ class TaskProcessingTest extends \Test\TestCase { $this->expectException(NotFoundException::class); $this->manager->getTask($task->getId()); } + + public function testShouldTransparentlyHandleTextProcessingProviders() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); + self::assertIsArray($task->getOutput()); + self::assertTrue(isset($task->getOutput()['output'])); + self::assertEquals('Hello Summarize', $task->getOutput()['output']); + self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran); + } + + public function testShouldTransparentlyHandleFailingTextProcessingProviders() { + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ + new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); + self::assertTrue($task->getOutput() === null); + self::assertEquals('ERROR', $task->getErrorMessage()); + self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran); + } + + public function testShouldTransparentlyHandleText2ImageProviders() { + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ + new ServiceRegistration('test', SuccessfulTextToImageProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToImage::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); + self::assertIsArray($task->getOutput()); + self::assertTrue(isset($task->getOutput()['images'])); + self::assertIsArray($task->getOutput()['images']); + self::assertCount(3, $task->getOutput()['images']); + self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran); + } + + public function testShouldTransparentlyHandleFailingText2ImageProviders() { + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ + new ServiceRegistration('test', FailingTextToImageProvider::class) + ]); + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ + ]); + $taskTypes = $this->manager->getAvailableTaskTypes(); + self::assertCount(1, $taskTypes); + self::assertTrue(isset($taskTypes[TextToImage::ID])); + self::assertTrue($this->manager->hasProviders()); + $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); + $this->manager->scheduleTask($task); + + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); + + $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob( + \OCP\Server::get(ITimeFactory::class), + $this->manager, + $this->jobList, + \OCP\Server::get(LoggerInterface::class), + ); + $backgroundJob->start($this->jobList); + + $task = $this->manager->getTask($task->getId()); + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); + self::assertTrue($task->getOutput() === null); + self::assertEquals('ERROR', $task->getErrorMessage()); + self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran); + } } |