aboutsummaryrefslogtreecommitdiffstats
path: root/tests/lib/LanguageModel/LanguageModelManagerTest.php
diff options
context:
space:
mode:
authorMarcel Klehr <mklehr@gmx.net>2023-06-29 17:07:31 +0200
committerMarcel Klehr <mklehr@gmx.net>2023-07-07 13:39:10 +0200
commit20cb9935ca80c32665b131315078661064037795 (patch)
tree91071461062786f9e8e184439585516f40dbdb0a /tests/lib/LanguageModel/LanguageModelManagerTest.php
parentebc76315441d75c3c7659c8a3fd0a285bdcd8cb2 (diff)
downloadnextcloud-server-20cb9935ca80c32665b131315078661064037795.tar.gz
nextcloud-server-20cb9935ca80c32665b131315078661064037795.zip
Fix tests
Signed-off-by: Marcel Klehr <mklehr@gmx.net>
Diffstat (limited to 'tests/lib/LanguageModel/LanguageModelManagerTest.php')
-rw-r--r--tests/lib/LanguageModel/LanguageModelManagerTest.php126
1 files changed, 98 insertions, 28 deletions
diff --git a/tests/lib/LanguageModel/LanguageModelManagerTest.php b/tests/lib/LanguageModel/LanguageModelManagerTest.php
index 80cf4348a01..39580fa3cb8 100644
--- a/tests/lib/LanguageModel/LanguageModelManagerTest.php
+++ b/tests/lib/LanguageModel/LanguageModelManagerTest.php
@@ -9,10 +9,15 @@
namespace Test\LanguageModel;
use OC\AppFramework\Bootstrap\Coordinator;
+use OC\AppFramework\Bootstrap\RegistrationContext;
+use OC\AppFramework\Bootstrap\ServiceRegistration;
+use OC\EventDispatcher\EventDispatcher;
+use OC\LanguageModel\Db\Task;
use OC\LanguageModel\Db\TaskMapper;
use OC\LanguageModel\LanguageModelManager;
use OC\LanguageModel\TaskBackgroundJob;
-use OCP\BackgroundJob\IJobList;
+use OCP\AppFramework\Db\DoesNotExistException;
+use OCP\AppFramework\Utility\ITimeFactory;
use OCP\Common\Exception\NotFoundException;
use OCP\EventDispatcher\IEventDispatcher;
use OCP\IServerContainer;
@@ -82,16 +87,69 @@ class LanguageModelManagerTest extends \Test\TestCase {
protected function setUp(): void {
parent::setUp();
+ $this->providers = [
+ TestVanillaLanguageModelProvider::class => new TestVanillaLanguageModelProvider(),
+ TestFullLanguageModelProvider::class => new TestFullLanguageModelProvider(),
+ TestFailingLanguageModelProvider::class => new TestFailingLanguageModelProvider(),
+ ];
+
+ $this->serverContainer = $this->createMock(IServerContainer::class);
+ $this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) {
+ return $this->providers[$class];
+ });
+
+ $this->eventDispatcher = new EventDispatcher(
+ new \Symfony\Component\EventDispatcher\EventDispatcher(),
+ $this->serverContainer,
+ \OC::$server->get(LoggerInterface::class),
+ );
+
+ $this->registrationContext = $this->createMock(RegistrationContext::class);
+ $this->coordinator = $this->createMock(Coordinator::class);
+ $this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext);
+
+ $this->taskMapper = $this->createMock(TaskMapper::class);
+ $this->tasksDb = [];
+ $this->taskMapper
+ ->expects($this->any())
+ ->method('insert')
+ ->willReturnCallback(function (Task $task) {
+ $task->setId(count($this->tasksDb) ? max(array_keys($this->tasksDb)) : 1);
+ $this->tasksDb[$task->getId()] = $task->toRow();
+ return $task;
+ });
+ $this->taskMapper
+ ->expects($this->any())
+ ->method('update')
+ ->willReturnCallback(function (Task $task) {
+ $this->tasksDb[$task->getId()] = $task->toRow();
+ return $task;
+ });
+ $this->taskMapper
+ ->expects($this->any())
+ ->method('find')
+ ->willReturnCallback(function (int $id) {
+ if (!isset($this->tasksDb[$id])) {
+ throw new DoesNotExistException('Could not find it');
+ }
+ return Task::fromRow($this->tasksDb[$id]);
+ });
+
+ $this->jobList = $this->createPartialMock(DummyJobList::class, ['add']);
+ $this->jobList->expects($this->any())->method('add')->willReturnCallback(function () {
+ });
+
$this->languageModelManager = new LanguageModelManager(
- \OC::$server->get(IServerContainer::class),
- $this->coordinator = \OC::$server->get(Coordinator::class),
+ $this->serverContainer,
+ $this->coordinator,
\OC::$server->get(LoggerInterface::class),
- \OC::$server->get(IJobList::class),
- \OC::$server->get(TaskMapper::class),
+ $this->jobList,
+ $this->taskMapper,
);
}
public function testShouldNotHaveAnyProviders() {
+ $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([]);
$this->assertCount(0, $this->languageModelManager->getAvailableTasks());
$this->assertCount(0, $this->languageModelManager->getAvailableTaskTypes());
$this->assertFalse($this->languageModelManager->hasProviders());
@@ -100,7 +158,9 @@ class LanguageModelManagerTest extends \Test\TestCase {
}
public function testProviderShouldBeRegisteredAndRun() {
- $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
+ $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
+ new ServiceRegistration('test', TestVanillaLanguageModelProvider::class)
+ ]);
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
@@ -113,7 +173,9 @@ class LanguageModelManagerTest extends \Test\TestCase {
public function testProviderShouldBeRegisteredAndScheduled() {
// register provider
- $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
+ $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
+ new ServiceRegistration('test', TestVanillaLanguageModelProvider::class)
+ ]);
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
@@ -139,10 +201,10 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->assertNull($task2->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());
- /** @var IEventDispatcher $eventDispatcher */
- $eventDispatcher = \OC::$server->get(IEventDispatcher::class);
+ /** @var IEventDispatcher $this->eventDispatcher */
+ $this->eventDispatcher = \OC::$server->get(IEventDispatcher::class);
$successfulEventFired = false;
- $eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
+ $this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$successfulEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
@@ -150,7 +212,7 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->assertEquals('Hello Free Prompt', $t->getOutput());
});
$failedEventFired = false;
- $eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
+ $this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$failedEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
@@ -159,11 +221,14 @@ class LanguageModelManagerTest extends \Test\TestCase {
});
// run background job
- /** @var TaskBackgroundJob $bgJob */
- $bgJob = \OC::$server->get(TaskBackgroundJob::class);
+ $bgJob = new TaskBackgroundJob(
+ \OC::$server->get(ITimeFactory::class),
+ $this->languageModelManager,
+ $this->eventDispatcher,
+ );
$bgJob->setArgument(['taskId' => $task->getId()]);
- $bgJob->start(new DummyJobList());
- $provider = \OC::$server->get(TestVanillaLanguageModelProvider::class);
+ $bgJob->start($this->jobList);
+ $provider = $this->providers[TestVanillaLanguageModelProvider::class];
$this->assertTrue($provider->ran);
$this->assertTrue($successfulEventFired);
$this->assertFalse($failedEventFired);
@@ -173,12 +238,14 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->assertEquals($task->getId(), $task3->getId());
$this->assertEquals('Hello', $task3->getInput());
$this->assertEquals('Hello Free Prompt', $task3->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task2->getStatus());
+ $this->assertEquals(ILanguageModelTask::STATUS_SUCCESSFUL, $task3->getStatus());
}
public function testMultipleProvidersShouldBeRegisteredAndRunCorrectly() {
- $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestVanillaLanguageModelProvider::class);
- $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFullLanguageModelProvider::class);
+ $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
+ new ServiceRegistration('test', TestVanillaLanguageModelProvider::class),
+ new ServiceRegistration('test', TestFullLanguageModelProvider::class),
+ ]);
$this->assertCount(3, $this->languageModelManager->getAvailableTasks());
$this->assertCount(3, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
@@ -204,7 +271,9 @@ class LanguageModelManagerTest extends \Test\TestCase {
public function testTaskFailure() {
// register provider
- $this->coordinator->getRegistrationContext()->registerLanguageModelProvider('test', TestFailingLanguageModelProvider::class);
+ $this->registrationContext->expects($this->any())->method('getLanguageModelProviders')->willReturn([
+ new ServiceRegistration('test', TestFailingLanguageModelProvider::class),
+ ]);
$this->assertCount(1, $this->languageModelManager->getAvailableTasks());
$this->assertCount(1, $this->languageModelManager->getAvailableTaskTypes());
$this->assertTrue($this->languageModelManager->hasProviders());
@@ -230,10 +299,8 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->assertNull($task2->getOutput());
$this->assertEquals(ILanguageModelTask::STATUS_SCHEDULED, $task2->getStatus());
- /** @var IEventDispatcher $eventDispatcher */
- $eventDispatcher = \OC::$server->get(IEventDispatcher::class);
$successfulEventFired = false;
- $eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
+ $this->eventDispatcher->addListener(TaskSuccessfulEvent::class, function (TaskSuccessfulEvent $event) use (&$successfulEventFired, $task) {
$successfulEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
@@ -241,7 +308,7 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->assertEquals('Hello Free Prompt', $t->getOutput());
});
$failedEventFired = false;
- $eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
+ $this->eventDispatcher->addListener(TaskFailedEvent::class, function (TaskFailedEvent $event) use (&$failedEventFired, $task) {
$failedEventFired = true;
$t = $event->getTask();
$this->assertEquals($task->getId(), $t->getId());
@@ -250,11 +317,14 @@ class LanguageModelManagerTest extends \Test\TestCase {
});
// run background job
- /** @var TaskBackgroundJob $bgJob */
- $bgJob = \OC::$server->get(TaskBackgroundJob::class);
+ $bgJob = new TaskBackgroundJob(
+ \OC::$server->get(ITimeFactory::class),
+ $this->languageModelManager,
+ $this->eventDispatcher,
+ );
$bgJob->setArgument(['taskId' => $task->getId()]);
- $bgJob->start(new DummyJobList());
- $provider = \OC::$server->get(TestFailingLanguageModelProvider::class);
+ $bgJob->start($this->jobList);
+ $provider = $this->providers[TestFailingLanguageModelProvider::class];
$this->assertTrue($provider->ran);
$this->assertTrue($failedEventFired);
$this->assertFalse($successfulEventFired);
@@ -264,6 +334,6 @@ class LanguageModelManagerTest extends \Test\TestCase {
$this->assertEquals($task->getId(), $task3->getId());
$this->assertEquals('Hello', $task3->getInput());
$this->assertNull($task3->getOutput());
- $this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task2->getStatus());
+ $this->assertEquals(ILanguageModelTask::STATUS_FAILED, $task3->getStatus());
}
}