aboutsummaryrefslogtreecommitdiffstats
path: root/tests/lib/TaskProcessing
diff options
context:
space:
mode:
authorMarcel Klehr <mklehr@gmx.net>2024-05-03 14:15:03 +0200
committerMarcel Klehr <mklehr@gmx.net>2024-05-14 11:38:40 +0200
commitbd5dfd0b5f5f4bbbc0046924a62dbd54ad5fa2c2 (patch)
tree6a5eac9d618c2c555273dd8ff186bd69f9d3f6ba /tests/lib/TaskProcessing
parenteebeb82416c191e29245c1364922e3d4716b8ee1 (diff)
downloadnextcloud-server-bd5dfd0b5f5f4bbbc0046924a62dbd54ad5fa2c2.tar.gz
nextcloud-server-bd5dfd0b5f5f4bbbc0046924a62dbd54ad5fa2c2.zip
test: Add more tests for legacy pass-through
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
Diffstat (limited to 'tests/lib/TaskProcessing')
-rw-r--r--tests/lib/TaskProcessing/TaskProcessingTest.php233
1 files changed, 231 insertions, 2 deletions
diff --git a/tests/lib/TaskProcessing/TaskProcessingTest.php b/tests/lib/TaskProcessing/TaskProcessingTest.php
index b9d402bbde8..5be43314d3e 100644
--- a/tests/lib/TaskProcessing/TaskProcessingTest.php
+++ b/tests/lib/TaskProcessing/TaskProcessingTest.php
@@ -38,7 +38,10 @@ use OCP\TaskProcessing\ISynchronousProvider;
use OCP\TaskProcessing\ITaskType;
use OCP\TaskProcessing\ShapeDescriptor;
use OCP\TaskProcessing\Task;
+use OCP\TaskProcessing\TaskTypes\TextToImage;
use OCP\TaskProcessing\TaskTypes\TextToText;
+use OCP\TaskProcessing\TaskTypes\TextToTextSummary;
+use OCP\TextProcessing\SummaryTaskType;
use PHPUnit\Framework\Constraint\IsInstanceOf;
use Psr\Log\LoggerInterface;
use Test\BackgroundJob\DummyJobList;
@@ -204,6 +207,85 @@ class BrokenSyncProvider implements IProvider, ISynchronousProvider {
}
}
+class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
+ public bool $ran = false;
+
+ public function getName(): string {
+ return 'TEST Vanilla LLM Provider';
+ }
+
+ public function process(string $prompt): string {
+ $this->ran = true;
+ return $prompt . ' Summarize';
+ }
+
+ public function getTaskType(): string {
+ return SummaryTaskType::class;
+ }
+}
+
+class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider {
+ public bool $ran = false;
+
+ public function getName(): string {
+ return 'TEST Vanilla LLM Provider';
+ }
+
+ public function process(string $prompt): string {
+ $this->ran = true;
+ throw new \Exception('ERROR');
+ }
+
+ public function getTaskType(): string {
+ return SummaryTaskType::class;
+ }
+}
+
+class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider {
+ public bool $ran = false;
+
+ public function getId(): string {
+ return 'test:successful';
+ }
+
+ public function getName(): string {
+ return 'TEST Provider';
+ }
+
+ public function generate(string $prompt, array $resources): void {
+ $this->ran = true;
+ foreach($resources as $resource) {
+ fwrite($resource, 'test');
+ fclose($resource);
+ }
+ }
+
+ public function getExpectedRuntime(): int {
+ return 1;
+ }
+}
+
+class FailingTextToImageProvider implements \OCP\TextToImage\IProvider {
+ public bool $ran = false;
+
+ public function getId(): string {
+ return 'test:failing';
+ }
+
+ public function getName(): string {
+ return 'TEST Provider';
+ }
+
+ public function generate(string $prompt, array $resources): void {
+ $this->ran = true;
+ throw new \RuntimeException('ERROR');
+ }
+
+ public function getExpectedRuntime(): int {
+ return 1;
+ }
+}
+
/**
* @group DB
*/
@@ -227,6 +309,10 @@ class TaskProcessingTest extends \Test\TestCase {
BrokenSyncProvider::class => new BrokenSyncProvider(),
AsyncProvider::class => new AsyncProvider(),
AudioToImage::class => new AudioToImage(),
+ SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(),
+ FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(),
+ SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(),
+ FailingTextToImageProvider::class => new FailingTextToImageProvider(),
];
$this->serverContainer = $this->createMock(IServerContainer::class);
@@ -257,6 +343,26 @@ class TaskProcessingTest extends \Test\TestCase {
$this->eventDispatcher = $this->createMock(IEventDispatcher::class);
+ $textProcessingManager = new \OC\TextProcessing\Manager(
+ $this->serverContainer,
+ $this->coordinator,
+ \OC::$server->get(LoggerInterface::class),
+ $this->jobList,
+ \OC::$server->get(\OC\TextProcessing\Db\TaskMapper::class),
+ \OC::$server->get(IConfig::class),
+ );
+
+ $text2imageManager = new \OC\TextToImage\Manager(
+ $this->serverContainer,
+ $this->coordinator,
+ \OC::$server->get(LoggerInterface::class),
+ $this->jobList,
+ \OC::$server->get(\OC\TextToImage\Db\TaskMapper::class),
+ \OC::$server->get(IConfig::class),
+ \OC::$server->get(IAppDataFactory::class),
+ );
+
+
$this->manager = new Manager(
$this->coordinator,
$this->serverContainer,
@@ -266,8 +372,8 @@ class TaskProcessingTest extends \Test\TestCase {
$this->eventDispatcher,
\OC::$server->get(IAppDataFactory::class),
\OC::$server->get(IRootFolder::class),
- \OC::$server->get(\OCP\TextProcessing\IManager::class),
- \OC::$server->get(\OCP\TextToImage\IManager::class),
+ $textProcessingManager,
+ $text2imageManager,
\OC::$server->get(ISpeechToTextManager::class),
);
}
@@ -507,4 +613,127 @@ class TaskProcessingTest extends \Test\TestCase {
$this->expectException(NotFoundException::class);
$this->manager->getTask($task->getId());
}
+
+ public function testShouldTransparentlyHandleTextProcessingProviders() {
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
+ new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class)
+ ]);
+ $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
+ ]);
+ $taskTypes = $this->manager->getAvailableTaskTypes();
+ self::assertCount(1, $taskTypes);
+ self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
+ self::assertTrue($this->manager->hasProviders());
+ $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
+ $this->manager->scheduleTask($task);
+
+ $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
+
+ $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
+ \OCP\Server::get(ITimeFactory::class),
+ $this->manager,
+ $this->jobList,
+ \OCP\Server::get(LoggerInterface::class),
+ );
+ $backgroundJob->start($this->jobList);
+
+ $task = $this->manager->getTask($task->getId());
+ self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
+ self::assertIsArray($task->getOutput());
+ self::assertTrue(isset($task->getOutput()['output']));
+ self::assertEquals('Hello Summarize', $task->getOutput()['output']);
+ self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran);
+ }
+
+ public function testShouldTransparentlyHandleFailingTextProcessingProviders() {
+ $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([
+ new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class)
+ ]);
+ $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
+ ]);
+ $taskTypes = $this->manager->getAvailableTaskTypes();
+ self::assertCount(1, $taskTypes);
+ self::assertTrue(isset($taskTypes[TextToTextSummary::ID]));
+ self::assertTrue($this->manager->hasProviders());
+ $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null);
+ $this->manager->scheduleTask($task);
+
+ $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
+
+ $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
+ \OCP\Server::get(ITimeFactory::class),
+ $this->manager,
+ $this->jobList,
+ \OCP\Server::get(LoggerInterface::class),
+ );
+ $backgroundJob->start($this->jobList);
+
+ $task = $this->manager->getTask($task->getId());
+ self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
+ self::assertTrue($task->getOutput() === null);
+ self::assertEquals('ERROR', $task->getErrorMessage());
+ self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran);
+ }
+
+ public function testShouldTransparentlyHandleText2ImageProviders() {
+ $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
+ new ServiceRegistration('test', SuccessfulTextToImageProvider::class)
+ ]);
+ $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
+ ]);
+ $taskTypes = $this->manager->getAvailableTaskTypes();
+ self::assertCount(1, $taskTypes);
+ self::assertTrue(isset($taskTypes[TextToImage::ID]));
+ self::assertTrue($this->manager->hasProviders());
+ $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
+ $this->manager->scheduleTask($task);
+
+ $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class));
+
+ $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
+ \OCP\Server::get(ITimeFactory::class),
+ $this->manager,
+ $this->jobList,
+ \OCP\Server::get(LoggerInterface::class),
+ );
+ $backgroundJob->start($this->jobList);
+
+ $task = $this->manager->getTask($task->getId());
+ self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus());
+ self::assertIsArray($task->getOutput());
+ self::assertTrue(isset($task->getOutput()['images']));
+ self::assertIsArray($task->getOutput()['images']);
+ self::assertCount(3, $task->getOutput()['images']);
+ self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran);
+ }
+
+ public function testShouldTransparentlyHandleFailingText2ImageProviders() {
+ $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([
+ new ServiceRegistration('test', FailingTextToImageProvider::class)
+ ]);
+ $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([
+ ]);
+ $taskTypes = $this->manager->getAvailableTaskTypes();
+ self::assertCount(1, $taskTypes);
+ self::assertTrue(isset($taskTypes[TextToImage::ID]));
+ self::assertTrue($this->manager->hasProviders());
+ $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null);
+ $this->manager->scheduleTask($task);
+
+ $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class));
+
+ $backgroundJob = new \OC\TaskProcessing\SynchronousBackgroundJob(
+ \OCP\Server::get(ITimeFactory::class),
+ $this->manager,
+ $this->jobList,
+ \OCP\Server::get(LoggerInterface::class),
+ );
+ $backgroundJob->start($this->jobList);
+
+ $task = $this->manager->getTask($task->getId());
+ self::assertEquals(Task::STATUS_FAILED, $task->getStatus());
+ self::assertTrue($task->getOutput() === null);
+ self::assertEquals('ERROR', $task->getErrorMessage());
+ self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran);
+ }
}