코루틴을 이용한 Task Executor 만들기

TaskExecutor란 프로세스 내에 상주하고 있으면서 Task를 등록하면 그 Task를 실행해주는 녀석을 말한다. (내가 그냥 대충 지은거임)

이걸 몇년 전에 자바스크립트로 만든적이 있긴 한데 자바스크립트 특성 상 이게 어떻게 동작하는지 알 수가 없어서 (사실 지금 만든것도 이해가 잘 안됨) 좀 그랬는데 이걸 코루틴으로 만들어봅시다.

아래는 executor가 실제로 실행할 작업을 정의한 클래스이다. 얘는 진행상황을 알아야하기 때문이 진행률같은 정보를 들고 있는데 여기에는 StateFlow를 이용하면 된다.

import kotlinx.coroutines.Runnable
import kotlinx.coroutines.flow.*
import kotlin.coroutines.cancellation.CancellationException

data class ExecutorTask<TData, TResult>(
    val params: TData,
    val execute: suspend ExecutorProcess.(TData) -> TResult,
) : ExecutorProcess {
    private var _completionRate = MutableStateFlow(0f)
    private var _status = MutableStateFlow(ExecutorTaskStatus.Ready)
    private var _result = MutableSharedFlow<TResult>()
    private var _exception = MutableSharedFlow<Throwable>()
    private var cancelJob: Runnable? = null

    override val completionRate: StateFlow<Float> = _completionRate.asStateFlow()
    override val status: StateFlow<ExecutorTaskStatus> = _status.asStateFlow()
    val result = _result.asSharedFlow()
    val exception = _exception.asSharedFlow()


    override fun updateCompletionRate(completionRate: Float) {
        check(completionRate in 0f..100f) {
            "Completeness must be between 0 and 100, but was $completionRate"
        }
        _completionRate.value = completionRate
    }

    suspend operator fun invoke() {
        try {
            _status.value = ExecutorTaskStatus.Processing
            println("Task Started at ${Thread.currentThread().threadId()}")
            val result = this.execute(params)
            _status.value = ExecutorTaskStatus.Completed
            _result.emit(result)
        } catch (e: CancellationException) {
            _status.value = ExecutorTaskStatus.Canceled
        } catch (e: Exception) {
            _exception.emit(e)
            _status.value = ExecutorTaskStatus.Error
        }
    }

    suspend infix fun runsWithCancel(cancelJob: Runnable) {
        this.cancelJob = cancelJob
        invoke()
    }

    fun cancel() {
        cancelJob?.run()
    }
}

MutableStateFlow는 수정도 가능하기 때문에 다 private으로 선언하고 구독을 할 수 있는 것들은 readonly StateFlow로 따로 만든다. updateCompletionRate는 execute의 중간에 호출할 수 있어야하기 때문에 따로 인터페이스로 분리하여 execute의 리시버로 이 인터페이스를 넘겨서 접근해도 될 만 한 것만 넘긴다.

import kotlinx.coroutines.flow.StateFlow

interface ExecutorProcess {
    val completionRate: StateFlow<Float>
    val status: StateFlow<ExecutorTaskStatus>
    fun updateCompletionRate(completionRate: Float)
}

각 Task의 상태는 준비, 동작중, 완료, 에러, 취소됨 을 나타낸다. 이는 enum class로 되어있다.

enum class ExecutorTaskStatus {
    Ready,
    Processing,
    Completed,
    Error,
    Canceled,
}

이제 이 Task를 받아들여서 실제로 실행하는 걸 만들어야한다.

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlin.time.Duration.Companion.milliseconds

abstract class AbstractTaskExecutor(
    dispatcher: CoroutineDispatcher,
    countsAsynchronousTasks: Int,
    pollingMilliseconds: Int = 100
) {
    private val scope = CoroutineScope(dispatcher + SupervisorJob())
    private val _tasks = MutableStateFlow<List<ExecutorTask<*, *>>>(emptyList())
    private val processingJob: Job
    private val semaphore: Semaphore = Semaphore(countsAsynchronousTasks)
    private val fetchedTasks = MutableStateFlow<ArrayDeque<ExecutorTask<*, *>>>(ArrayDeque())
    val tasks = _tasks.asStateFlow()

    init {
        processingJob = scope.launch {
            while (true) {
                fetchTask().onSuccess { task ->
                    process(task)
                }.onFailure {
                    delay(pollingMilliseconds.milliseconds)
                }
            }
        }
    }

    private fun fetchTask(): Result<ExecutorTask<*, *>> {
        return fetchedTasks.value.removeFirstOrNull()?.let { Result.success(it) }
            ?: Result.failure(NoSuchElementException())
    }

    fun addTask(task: ExecutorTask<*, *>) {
        _tasks.value += task
        fetchedTasks.value += task
    }

    fun <T> addTask(data: T, converter: (T) -> ExecutorTask<*, *>) {
        val task = converter(data)
        addTask(task)
    }

    fun removeTask(task: ExecutorTask<*, *>) {
        _tasks.value -= task
    }

    fun cancel() {
        processingJob.cancel()
    }

    private fun process(task: ExecutorTask<*, *>) {
        scope.launch {
            semaphore.withPermit {
                task runsWithCancel  {
                    cancel()
                }
                task.updateCompletionRate(100f)
            }
        }
    }
}
  1. CoroutineDispatcher 를 따로 받는 이유는 얘를 순수 Kotlin 기반으로 만들기 위해서다. JVM이면 여러가지 Executor가 있고 한데 얘를 다른데서도 쓰려면 CoroutineDispatcher를 따로 받게 해야한다.
  2. SupervisorJob은 딱히 필요는 없는거같긴한데 이게 있어야 해당 scope 내에서 실행중인 자식 코루틴에서 에러가난 경우 예외의 propagation이 안일어난다고 한다. 얘가 있어야 실행중인 작업 중 하나에서 에러가 났는데 처리가 안된 경우 다른 작업이 멈추지 않게 된다. (아마도)
  3. tasks와 fetchedTasks가 따로 있는 이유는 화면에 출력할 작업과 실행할 작업을 나누기 위해서다. 몇시간 전에 이거 처음 만들때는 그냥 tasks 하나로 했는데 이 경우 tasks에서 실행해야할 task를 분리하는데 좀 어려움이 있다. 아마 뮤텍스를 쓰면 해결될 거 같은데 이 경우 몇가지 작업을 동시에 실행하는데 문제가 있다.
  4. 세마포어는 동시에 실행되는 작업의 개수를 설정하기 위해 있다.
  5. init 에서 processingJob을 실행한다. pollingMilliseconds ms 단위로 폴링을해서 작업이 끊기지 않게 한다.
  6. fetchTask는 실행해야하는 작업을 fetch 하는데 쓰인다. 큐가 비어있으면 실패를 반환하는데 processingJob 코드에서 처럼 실패하면 기다린다.
  7. addTask에 두가지가 있는데 직접 Task를 만들어서 넣는 것, 데이터를 converter 로직을 통해서 Task로 바꿔서 넘기는 게 있다. 일단 생각은 DB에서 읽어서 Task 실행하고 결과 받는거까지 한번에 하려고 이런 짓을 했는데 아직 적용은 안 해봄(…)
  8. cacnel은 processingJob을 취소하는 역할을 한다. 어플리케이션이 종료될 때 얘를 호출해야 메모리 누수같은 걸 막을 수 있지 않을까?

대충 이런 식이다. JVM에서 사용하면 다음과 같이 사용할 수 있다.

object SequentialTaskExecutor : AbstractTaskExecutor(
    dispatcher = Executors.newCachedThreadPool().asCoroutineDispatcher(),
    countsAsynchronousTasks = 5
)

캐싱된 스레드풀의 dispatcher를 이용해서 최대 5개 실행할 수 있는 executor를 생성한다.

다음은 어플리케이션 코드 (Compose for Desktop)

@Composable
@Preview
fun App() {
    val mainScope = rememberCoroutineScope()
    val coroutineScope = remember { SequentialTaskExecutor }
    val tasks by coroutineScope.tasks.collectAsState()
    var inputText by remember { mutableStateOf("") }

    MaterialTheme {
        Column(modifier = Modifier.fillMaxSize()) {
            Row(modifier = Modifier.fillMaxWidth().wrapContentHeight().padding(3.dp)) {
                Column(modifier = Modifier.fillMaxWidth().weight(1f)) {
                    Row(modifier = Modifier.fillMaxWidth()) {
                        TextField(
                            modifier = Modifier.fillMaxWidth(),
                            value = inputText,
                            onValueChange = {
                                if (it.isEmpty() || it.toIntOrNull() != null) {
                                    inputText = it
                                }
                            },
                            singleLine = true,
                        )
                    }

                    Row(modifier = Modifier.fillMaxWidth()) {
                        Button(
                            modifier = Modifier.fillMaxWidth(),
                            onClick = {
                                mainScope.launch {
                                    coroutineScope.addTask(
                                        ExecutorTask(inputText.toInt()) {
                                            for (i in 1..10) {
                                                delay(1000.milliseconds)
                                                updateCompletionRate(i * 10f)
                                            }
                                            100 / it
                                        }
                                    )
                                }
                            },
                            enabled = inputText.toIntOrNull() != null
                        ) {
                            Text(text = "Send")
                        }
                    }
                }
            }

            Row(modifier = Modifier.fillMaxWidth().weight(1f)) {
                LazyColumn {
                    items(tasks.size) {
                        val task = tasks[it]
                        val progress by task.completionRate.collectAsState()
                        val status by task.status.collectAsState()
                        val result by task.result.collectAsState(null)
                        val exception by task.exception.collectAsState(null)
                        Row(
                            modifier = Modifier.fillMaxWidth().padding(3.dp).border(
                                1.dp, Color.Black, shape = RoundedCornerShape(
                                    CornerSize(5)
                                )
                            )
                        ) {
                            Column(modifier = Modifier.weight(1f).padding(3.dp)) {
                                Row(modifier = Modifier.fillMaxWidth()) {
                                    Text("Task [$it] : ${task.params} / $status")
                                }
                                if (result !is Unit && result != null) {
                                    Row(modifier = Modifier.fillMaxWidth()) {
                                        Text("Result : $result")
                                    }
                                }
                                if (exception != null) {
                                    Row(modifier = Modifier.fillMaxWidth()) {
                                        Text("Exception : ${exception!!.message}", color = Color.Red)
                                    }
                                }
                                Row(modifier = Modifier.fillMaxWidth()) {
                                    LinearProgressIndicator(
                                        modifier = Modifier.fillMaxWidth().height(10.dp),
                                        progress = progress / 100f,
                                    )
                                }
                            }
                            Column(modifier = Modifier.wrapContentWidth().padding(3.dp)) {
                                Button(
                                    modifier = Modifier.wrapContentSize(),
                                    onClick = { task.cancel() },
                                    enabled = status == ExecutorTaskStatus.Processing
                                ) {
                                    Text("Cancel")
                                }
                            }
                            Column(modifier = Modifier.wrapContentWidth().padding(3.dp)) {
                                Button(
                                    modifier = Modifier.wrapContentSize(),
                                    onClick = {
                                        mainScope.launch {
                                            coroutineScope.removeTask(task)
                                        }
                                    },
                                    enabled = when (status) {
                                        ExecutorTaskStatus.Error,ExecutorTaskStatus.Completed, ExecutorTaskStatus.Canceled -> true
                                        else -> false
                                    }
                                ) {
                                    Text("Delete")
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

fun main() = application {
    Window(onCloseRequest = {
        SequentialTaskExecutor.cancel()
        exitApplication()
    }) {
        App()
    }
}

대충 숫자를 입력해서 1초씩 10번 기다린 다음(각 루프에서 실행률을 10%씩 증가) 100을 입력받은 값을 Int로 바꿔서 나눈 결과를 받아오는 형식이다.

Compose for Desktop에서는 Flow를 collectAsState를 통해 State로 바꾸면 변경 시 값이 잘 바뀐다.

그래서 실행 결과 스샷

잘 된다.

끗.


IntelliJ AI Assistant에 물어보니 이 기능은 JVM의 경우 Executors.newFixedThreadPool이 정확히 같은 기능을 한다고 한다 (…)

하지만 내가 이걸 만든 이유는 코틀린을 사용할 수 있는 어디서든지 사용하도록 하기 위함이니 뭐… 그리고 예시코드에서는 newCahcedThreadPool 을 사용하였는데 이건 사용가능한 스레드가 없는 경우 새로운 스레드를 받아서 쓰기 때문에 이론상으로는 무한히 스레드가 늘어날 수 있다는데… 세마포어가 있어서 최대로 세마포어의 퍼밋 수 만큼만 만들도록 제한이 될거라고는 한다…


while 을 통한 busy-waiting이 문제가 될거같다고 얘기하니 그러면 tasks에 변화가 있을 만한 곳에 (task를 더할 때, task가 완료 되었을 때) 다음 작업을 시작하도록 유도하는게 좋을 것 같다고 하여 아래와 같이 수정했다.

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit

abstract class AbstractTaskExecutor(
    dispatcher: CoroutineDispatcher,
    countsAsynchronousTasks: Int,
) {
    private val scope = CoroutineScope(dispatcher + SupervisorJob())
    private val _tasks = MutableStateFlow<List<ExecutorTask<*, *>>>(emptyList())
    private val semaphore: Semaphore = Semaphore(countsAsynchronousTasks)
    private val fetchedTasks = MutableStateFlow<ArrayDeque<ExecutorTask<*, *>>>(ArrayDeque())
    val tasks = _tasks.asStateFlow()

    init {
        processNextTask()
    }

    private fun fetchTask(): ExecutorTask<*, *>? {
        return fetchedTasks.value.removeFirstOrNull()
    }

    private fun processNextTask() {
        fetchTask()?.let { task ->
            process(task)
        }
    }

    fun addTask(task: ExecutorTask<*, *>) {
        _tasks.value += task
        fetchedTasks.value += task
        processNextTask()
    }

    fun <T> addTask(data: T, converter: (T) -> ExecutorTask<*, *>) {
        val task = converter(data)
        addTask(task)
    }

    fun removeTask(task: ExecutorTask<*, *>) {
        _tasks.value -= task
    }

    private fun process(task: ExecutorTask<*, *>) {
        scope.launch {
            semaphore.withPermit {
                task runsWithCancel {
                    cancel()
                }
                task.updateCompletionRate(100f)
            }
        }
        processNextTask()
    }
}

이러면 processingJob이 불필요하므로 따로 cancel을 호출할 필요도 없어지니 더 좋네

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다