ONNX를 이용하여 치와와와 머핀을 구분해보자.

새로 구입한 디바이스가 몇개 있는데 이게 스냅드래곤 X 엘리트를 탑재했다.

요즘 AI 디바이스니 뭐니 하는데 이걸 써보려고 했다. 그런데 스냅드래곤 X 엘리트에서 사용하려면 ONNX 모델로 만들어야하니 그거까지 해보려고 대충 짰다.

그래서 주제는 뭐가 좋을까… 하다가 치와와와 머핀을 구분하는 모델을 만들어봅시다.

우선 데이터셋을 얻기 위해 Kaggle에 있는 데이터를 받습니다.

https://www.kaggle.com/datasets/samuelcortinhas/muffin-vs-chihuahua-image-classification

학습도 ONNX로 되는거같긴한데… 이건 그냥 파이썬으로 합시다. 모델은 사전 트레이닝 된 ResNet18을 사용해보자.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

data_dir_root_path= "path/to/images"

# preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_data = datasets.ImageFolder(data_dir_root_path + "/train", transform=transform)
test_data = datasets.ImageFolder(data_dir_root_path + "/test", transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True)

model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)

device = torch.device("mps") # Apple Silicon 사용중이라 mps로 했음
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 10

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")

    # Validation
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            correct += torch.sum(preds == labels.data)

    val_acc = correct.float() / len(test_loader.dataset)
    print(f"Validation Accuracy: {val_acc:.4f}")

dummy_input = torch.randn(1, 3, 224, 224)
model.to(torch.device("cpu")) # IO작업은 CPU로 옮겨야 함

onnx_file_path = 'path/to/models/chihuahua_vs_muffin.onnx'
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=11
)

이런식으로 하면 파이토치를 통해 onnx 모델을 생성할 수 있다. 나는 M2 Max (GPU 38core) 모델로 6분 좀 넘게 걸렸다.

이제 onnx 를 자바로 불러와봅시다. 자바에서는 DJL (Deep Java Library)를 사용하고, onnxruntime-engine을 쓰면 된다.

// https://mvnrepository.com/artifact/ai.djl/api
implementation("ai.djl:api:0.31.1")
// https://mvnrepository.com/artifact/ai.djl.onnxruntime/onnxruntime-engine
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.31.1")

여기서 문제는 학습을 시킬 때 이미지를 224*224로 리사이징하고 텐서로 바꾼 다음 노멀라이즈 하는 부분인데, onnxruntime에는 파이토치의 Transform으로 제공하는 것이 없으므로 이 부분을 수동으로(…) 해야한다. 다른 방법이 있을 것 같은데 난 귀찮아서 수동으로 했다.

enum class ChihuahuaOrMuffin {
    Chihuahua, Muffin
}

@Suppress("MemberVisibilityCanBePrivate")
data class ChihuahuaOrMuffinPrediction(
    val chihuahua: Float,
    val muffin: Float
) {
    val result: ChihuahuaOrMuffin
        get() = if (chihuahua > muffin) ChihuahuaOrMuffin.Chihuahua else ChihuahuaOrMuffin.Muffin
    val probabilities: Map<ChihuahuaOrMuffin, Float>

    init {
        val sum = sequenceOf(chihuahua, muffin).map(Float::absoluteValue).sum()
        val max = maxOf(chihuahua.absoluteValue, muffin.absoluteValue)
        val min = sum - max
        probabilities = mapOf(
            ChihuahuaOrMuffin.Chihuahua to when (result) {
                ChihuahuaOrMuffin.Chihuahua -> max / sum
                ChihuahuaOrMuffin.Muffin -> min / sum
            },
            ChihuahuaOrMuffin.Muffin to when (result) {
                ChihuahuaOrMuffin.Chihuahua -> min / sum
                ChihuahuaOrMuffin.Muffin -> max / sum
            }
        )
    }
}

suspend fun preprocessImage(
    env: OrtEnvironment,
    imagePath: Path,
    resizedWidth: Int = 224,
    resizedHeight: Int = 224
): OnnxTensor = coroutineScope {
    // resize
    val originalImage = ImageIO.read(imagePath.toFile())
    val resizedImage = BufferedImage(resizedWidth, resizedWidth, BufferedImage.TYPE_INT_RGB)
    val graphics = resizedImage.createGraphics()
    graphics.drawImage(originalImage, 0, 0, resizedImage.width, resizedImage.height, null)
    graphics.dispose()

    // to floatarray and normalize
    val data = FloatArray(3 * resizedWidth * resizedHeight)

    for (y in 0 until resizedHeight) {
        for (x in 0 until resizedWidth) {
            val rgb = resizedImage.getRGB(x, y)
            val r = ((rgb shr 16) and 0xFF) / 255f
            val g = ((rgb shr 8) and 0xFF) / 255f
            val b = (rgb and 0xFF) / 255f

            // RGB 채널을 분리해서 저장해야한다.
            data[y * resizedWidth + x] = (r - 0.485f) / 0.229f
            data[y * resizedWidth + x + resizedWidth * resizedHeight] = (g - 0.456f) / 0.224f
            data[y * resizedWidth + x + resizedWidth * resizedHeight * 2] = (b - 0.406f) / 0.225f
        }
    }

    OnnxTensor.createTensor(
        env,
        FloatBuffer.wrap(data),
        longArrayOf(1, 3, resizedWidth.toLong(), resizedHeight.toLong())
    )
}

그 다음 예측함수를 만든다.

suspend fun predict(imagePath: Path): ChihuahuaOrMuffinPrediction = coroutineScope {
    OrtEnvironment.getEnvironment().use { env ->
        val modelPath =
            object {}.javaClass.classLoader.getResource("models/chihuahua_vs_muffin.onnx")!!.toURI().toPath() // resources/models/chihuahua_vs_muffin.onnx 로 위에서 파이썬으로 만든 모델 추가해둠
        val ortSessionOptions = OrtSession.SessionOptions()
        ortSessionOptions.addCoreML() // Apple Silicon에서 실행하므로 CoreML을 넣어서 NPU를 사용하도록 옵션 추가
        env.createSession(modelPath.toString(), ortSessionOptions).use { session ->
            val inputName = session.inputNames.iterator().next()
            val outputName = session.outputNames.iterator().next()

            preprocessImage(env, imagePath).use { inputTensor ->
                session.run(mapOf(inputName to inputTensor))[outputName].orElseThrow().use { outputTensor ->
                    val predictions = (outputTensor.value as Array<*>)[0] as FloatArray
                    ChihuahuaOrMuffinPrediction(predictions[0], predictions[1])
                }
            }
        }
    }
}

Compose Desktop에서 기본 FilePicker가 없으므로 이미지 파일 셀렉트를 위해 만들어준다. 옛날에 했을땐 되게 복잡하게 했었던 거 같은데, 그냥 AWT 를 사용해서 만든다.

@Composable
fun FilePicker(
    parent: Frame? = null,
    onCloseRequest: (result: String?) -> Unit
) = AwtWindow(create = {
    object : FileDialog(parent, "Choose Image", LOAD) {
        override fun setVisible(value: Boolean) {
            super.setVisible(value)
            if (value) {
                onCloseRequest(directory + file)
            }
        }
    }
}, dispose = FileDialog::dispose)

그다음 메인 Composable fun을 만든다.

@Composable
@Preview
fun App() {
    var resultStateFlow by remember { mutableStateOf<ChihuahuaOrMuffinPrediction?>(null) }
    var imagePath by remember { mutableStateOf("") }
    var loading by remember { mutableStateOf(false) }
    var predictionRunning by remember { mutableStateOf(false) }
    val coroutineScope = rememberCoroutineScope()
    var fileDialogOpened by remember { mutableStateOf(false) }

    Box(modifier = Modifier.fillMaxSize().padding(16.dp)) {
        Column(modifier = Modifier.fillMaxSize()) {
            Row(modifier = Modifier.fillMaxWidth()) {
                Column(modifier = Modifier.weight(0.5f).padding(horizontal = 8.dp)) {
                    Button(modifier = Modifier.fillMaxWidth(), onClick = {
                        coroutineScope.launch {
                            loading = true
                            fileDialogOpened = true
                            loading = false
                        }
                    }) {
                        Text("Open Image")
                    }
                }
                Column(modifier = Modifier.weight(0.5f).padding(horizontal = 8.dp)) {
                    Button(modifier = Modifier.fillMaxWidth(), onClick = {
                        coroutineScope.launch {
                            predictionRunning = true
                            withContext(Dispatchers.IO) {
                                resultStateFlow = predict(Path(imagePath))
                            }
                            predictionRunning = false
                        }
                    }, enabled = !predictionRunning && imagePath.isNotEmpty() && !loading) {
                        Text("Predict")
                    }
                }
            }
            Row(modifier = Modifier.fillMaxWidth().wrapContentHeight()) {
                if (imagePath.isNotEmpty()) {
                    if (!loading) {
                        Image(
                            bitmap = ImageIO.read(Path(imagePath).toFile()).toComposeImageBitmap(),
                            contentDescription = null,
                            modifier = Modifier.fillMaxWidth().height(300.dp)
                        )
                    } else {
                        Text("Loading...")
                    }
                }
            }
            Row(modifier = Modifier.fillMaxWidth().wrapContentHeight()) {
                if (resultStateFlow != null) {
                    Box(modifier = Modifier.fillMaxWidth(), contentAlignment = Alignment.Center) {
                        val emoji = when (resultStateFlow!!.result) {
                            ChihuahuaOrMuffin.Chihuahua -> "\uD83D\uDC36" // 🐶
                            ChihuahuaOrMuffin.Muffin -> "\uD83E\uDDC1" // 🧁
                        }
                        Text(emoji, fontSize = TextUnit(6f, TextUnitType.Em))
                    }
                }
            }
        }
    }

    if (fileDialogOpened) {
        FilePicker {
            imagePath = it ?: ""
            fileDialogOpened = false
        }
    }
}

메인 함수를 만든다.

fun main() = application {
    Window(
        onCloseRequest = ::exitApplication,
        title = "ONNX App",
        state = rememberWindowState(width = 800.dp, height = 640.dp)
    ) {
        MaterialTheme {
            App()
        }
    }
}

테스트를 해본다. 위에서 받은 Kaggle 이미지들 중 테스트에 있는 이미지로…

잘 된다.

끗.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다