grpc-javaのClient/ServerのテストをKotlinで書く - Server編


grpc-javaで実装されたgRPC ClientとgRPC Serverのテストコードについてまとめていきたい。

ClientとServerのどちらも大枠は同じである。テストコードのなかでgRPC Serverを起動させる。そしてリクエスト内のトランザクションを必要に応じてモック化しながら期待値が取得できているか、期待される関数が呼び出せれているかを検証する。

今回のエントリではServer側のテストをJUnitとKotlinを用いてまとめていく。

テスト対象のproto

テスト対象のprotoは次のとおりSimple-RPCとする。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
service TaskService {
  rpc GetTaskService (TaskInbound) returns (TaskOutbound) {
    option (google.api.http) = {
      get: "/v1/task"
    };
  }
}

message TaskInbound {
  uint32 task_id = 1;
}

message TaskOutbound {
  uint32 task_id = 1;
  string title = 2;
  string finishedAt = 3;
  string createdAt = 4;
  string updatedAt = 5;
}

テストするgRPC Serverとテスト内容

テスト対象のServerのコードは次のとおりである

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
override fun getTaskService(request: TaskInbound?, responseObserver: StreamObserver<TaskOutbound>?) {
    try {
        val (taskId) = GRpcInboundValidator.validTaskInbound(request)

        val log = GRpcLogContextHandler.getLog()
        log.elem { "taskId" to taskId }

        val task = getTaskService(GetTaskCommand(taskId.toLong()))
        val msg = getOutbound(task)
        responseObserver?.onNext(msg)
        responseObserver?.onCompleted()
    } catch (e: WebAppException.NotFoundException) {
        logger.error { "gRPC server error, task not found." }
        responseObserver?.onError(
                Status.NOT_FOUND.withDescription("task not found.").asRuntimeException())
    } catch (e: WebAppException.BadRequestException) {
        logger.error { "gRPC server error, invalid request." }
        responseObserver?.onError(
                Status.INVALID_ARGUMENT.withDescription("invalid request.").asRuntimeException())
    }
}

テスト内容

テストする内容を次のようにまとめる。

テストコード

次にテストコードである。

先述したとおりテストコードのなかでServerを起動させる。そして起動しているServerにテスト対象のgRPC Sereverをアサインする。コードとしては次のようになる。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
@Before
fun setUp() {
    getTaskService = mock(GetTaskServiceImpl::class)
    // 一部省略

    target = TaskBackendServer(getTaskService, getTaskListService, createTaskService, updateTaskService,
            deleteTaskService, finishTaskService)
    inProcessServer = InProcessServerBuilder
            .forName(UNIQUE_SERVER_NAME).addService(target).directExecutor().build()
    inProcessChannel = InProcessChannelBuilder.forName(UNIQUE_SERVER_NAME).directExecutor().build()

    inProcessServer.start()
}

@After
fun tearDown() {
    inProcessChannel.shutdownNow()
    inProcessServer.shutdownNow()
}

正常系のテスト

次のコードは正常系をテストしたコードである。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
@Test
fun getProducts_onCompleted() {

    val taskId = 1L
    val request = TaskInbound.newBuilder()
            .setTaskId(taskId.toInt())
            .build()

    val command = GetTaskCommand(taskId)
    val now = LocalDateTime.now()
    val task = Task(taskId.toInt(), "mocked Task", now, now, now)

    val log = GRpcLogBuilder()

    // mock
    mockStatic(GRpcLogContextHandler::class)
    Mockito.`when`(GRpcLogContextHandler.getLog()).thenReturn(log)
    Mockito.`when`(getTaskService(command)).thenReturn(task)

    // request server
    val blockingStub = TaskServiceGrpc.newBlockingStub(inProcessChannel)
    val actual = blockingStub.getTaskService(request) // ブロッキングしてgRPC Serverのレスポンスを受け取る

    // assertion
    actual.taskId shouldBe 1
    actual.title shouldBe "mocked Task"
}

異常系のテスト

次のコードは異常系をテストしたコードである。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@Test
fun getProducts_NOT_FOUND() {

    val taskId = 1L
    val request = TaskInbound.newBuilder().setTaskId(taskId.toInt()).build()

    val command = GetTaskCommand(taskId)

    // mock
    mockStatic(GRpcLogContextHandler::class)
    Mockito.`when`(GRpcLogContextHandler.getLog()).thenReturn(GRpcLogBuilder())
    Mockito.`when`(getTaskService(command)).thenThrow(WebAppException.NotFoundException("not found"))

    try {
        // request server
        val blockingStub = TaskServiceGrpc.newBlockingStub(inProcessChannel)
        blockingStub.getTaskService(request)
    } catch (e: StatusRuntimeException) {
        // assertion
        e.status.code shouldBe Status.NOT_FOUND.code
        e.message shouldBe "NOT_FOUND: task not found."
    }
}

@Test
fun getProducts_INVALID_ARGUMENT() {

    val taskId = 0L
    val request = TaskInbound.newBuilder().setTaskId(taskId.toInt()).build()

    try {
        // request server
        val blockingStub = TaskServiceGrpc.newBlockingStub(inProcessChannel)
        blockingStub.getTaskService(request)
    } catch (e: StatusRuntimeException) {
        // assertion
        e.status.code shouldBe Status.INVALID_ARGUMENT.code
        e.message shouldBe "INVALID_ARGUMENT: invalid request."
    }
}

まとめ

コード

エントリで紹介したコードは一部分のためコード全体はgithubを参照してください。

テストコードはこちらです。