새로 구입한 디바이스가 몇개 있는데 이게 스냅드래곤 X 엘리트를 탑재했다.
요즘 AI 디바이스니 뭐니 하는데 이걸 써보려고 했다. 그런데 스냅드래곤 X 엘리트에서 사용하려면 ONNX 모델로 만들어야하니 그거까지 해보려고 대충 짰다.
그래서 주제는 뭐가 좋을까… 하다가 치와와와 머핀을 구분하는 모델을 만들어봅시다.
우선 데이터셋을 얻기 위해 Kaggle에 있는 데이터를 받습니다.
https://www.kaggle.com/datasets/samuelcortinhas/muffin-vs-chihuahua-image-classification
학습도 ONNX로 되는거같긴한데… 이건 그냥 파이썬으로 합시다. 모델은 사전 트레이닝 된 ResNet18을 사용해보자.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
data_dir_root_path= "path/to/images"
# preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
train_data = datasets.ImageFolder(data_dir_root_path + "/train", transform=transform)
test_data = datasets.ImageFolder(data_dir_root_path + "/test", transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True)
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)
device = torch.device("mps") # Apple Silicon 사용중이라 mps로 했음
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 10
for epoch in range(epochs):
model.train()
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
# Forward pass
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
# Backward pass
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")
# Validation
model.eval()
correct = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
correct += torch.sum(preds == labels.data)
val_acc = correct.float() / len(test_loader.dataset)
print(f"Validation Accuracy: {val_acc:.4f}")
dummy_input = torch.randn(1, 3, 224, 224)
model.to(torch.device("cpu")) # IO작업은 CPU로 옮겨야 함
onnx_file_path = 'path/to/models/chihuahua_vs_muffin.onnx'
torch.onnx.export(
model,
dummy_input,
onnx_file_path,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
opset_version=11
)
이런식으로 하면 파이토치를 통해 onnx 모델을 생성할 수 있다. 나는 M2 Max (GPU 38core) 모델로 6분 좀 넘게 걸렸다.
이제 onnx 를 자바로 불러와봅시다. 자바에서는 DJL (Deep Java Library)를 사용하고, onnxruntime-engine을 쓰면 된다.
// https://mvnrepository.com/artifact/ai.djl/api
implementation("ai.djl:api:0.31.1")
// https://mvnrepository.com/artifact/ai.djl.onnxruntime/onnxruntime-engine
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.31.1")
여기서 문제는 학습을 시킬 때 이미지를 224*224로 리사이징하고 텐서로 바꾼 다음 노멀라이즈 하는 부분인데, onnxruntime에는 파이토치의 Transform으로 제공하는 것이 없으므로 이 부분을 수동으로(…) 해야한다. 다른 방법이 있을 것 같은데 난 귀찮아서 수동으로 했다.
enum class ChihuahuaOrMuffin {
Chihuahua, Muffin
}
@Suppress("MemberVisibilityCanBePrivate")
data class ChihuahuaOrMuffinPrediction(
val chihuahua: Float,
val muffin: Float
) {
val result: ChihuahuaOrMuffin
get() = if (chihuahua > muffin) ChihuahuaOrMuffin.Chihuahua else ChihuahuaOrMuffin.Muffin
val probabilities: Map<ChihuahuaOrMuffin, Float>
init {
val sum = sequenceOf(chihuahua, muffin).map(Float::absoluteValue).sum()
val max = maxOf(chihuahua.absoluteValue, muffin.absoluteValue)
val min = sum - max
probabilities = mapOf(
ChihuahuaOrMuffin.Chihuahua to when (result) {
ChihuahuaOrMuffin.Chihuahua -> max / sum
ChihuahuaOrMuffin.Muffin -> min / sum
},
ChihuahuaOrMuffin.Muffin to when (result) {
ChihuahuaOrMuffin.Chihuahua -> min / sum
ChihuahuaOrMuffin.Muffin -> max / sum
}
)
}
}
suspend fun preprocessImage(
env: OrtEnvironment,
imagePath: Path,
resizedWidth: Int = 224,
resizedHeight: Int = 224
): OnnxTensor = coroutineScope {
// resize
val originalImage = ImageIO.read(imagePath.toFile())
val resizedImage = BufferedImage(resizedWidth, resizedWidth, BufferedImage.TYPE_INT_RGB)
val graphics = resizedImage.createGraphics()
graphics.drawImage(originalImage, 0, 0, resizedImage.width, resizedImage.height, null)
graphics.dispose()
// to floatarray and normalize
val data = FloatArray(3 * resizedWidth * resizedHeight)
for (y in 0 until resizedHeight) {
for (x in 0 until resizedWidth) {
val rgb = resizedImage.getRGB(x, y)
val r = ((rgb shr 16) and 0xFF) / 255f
val g = ((rgb shr 8) and 0xFF) / 255f
val b = (rgb and 0xFF) / 255f
// RGB 채널을 분리해서 저장해야한다.
data[y * resizedWidth + x] = (r - 0.485f) / 0.229f
data[y * resizedWidth + x + resizedWidth * resizedHeight] = (g - 0.456f) / 0.224f
data[y * resizedWidth + x + resizedWidth * resizedHeight * 2] = (b - 0.406f) / 0.225f
}
}
OnnxTensor.createTensor(
env,
FloatBuffer.wrap(data),
longArrayOf(1, 3, resizedWidth.toLong(), resizedHeight.toLong())
)
}
그 다음 예측함수를 만든다.
suspend fun predict(imagePath: Path): ChihuahuaOrMuffinPrediction = coroutineScope {
OrtEnvironment.getEnvironment().use { env ->
val modelPath =
object {}.javaClass.classLoader.getResource("models/chihuahua_vs_muffin.onnx")!!.toURI().toPath() // resources/models/chihuahua_vs_muffin.onnx 로 위에서 파이썬으로 만든 모델 추가해둠
val ortSessionOptions = OrtSession.SessionOptions()
ortSessionOptions.addCoreML() // Apple Silicon에서 실행하므로 CoreML을 넣어서 NPU를 사용하도록 옵션 추가
env.createSession(modelPath.toString(), ortSessionOptions).use { session ->
val inputName = session.inputNames.iterator().next()
val outputName = session.outputNames.iterator().next()
preprocessImage(env, imagePath).use { inputTensor ->
session.run(mapOf(inputName to inputTensor))[outputName].orElseThrow().use { outputTensor ->
val predictions = (outputTensor.value as Array<*>)[0] as FloatArray
ChihuahuaOrMuffinPrediction(predictions[0], predictions[1])
}
}
}
}
}
Compose Desktop에서 기본 FilePicker가 없으므로 이미지 파일 셀렉트를 위해 만들어준다. 옛날에 했을땐 되게 복잡하게 했었던 거 같은데, 그냥 AWT 를 사용해서 만든다.
@Composable
fun FilePicker(
parent: Frame? = null,
onCloseRequest: (result: String?) -> Unit
) = AwtWindow(create = {
object : FileDialog(parent, "Choose Image", LOAD) {
override fun setVisible(value: Boolean) {
super.setVisible(value)
if (value) {
onCloseRequest(directory + file)
}
}
}
}, dispose = FileDialog::dispose)
그다음 메인 Composable fun을 만든다.
@Composable
@Preview
fun App() {
var resultStateFlow by remember { mutableStateOf<ChihuahuaOrMuffinPrediction?>(null) }
var imagePath by remember { mutableStateOf("") }
var loading by remember { mutableStateOf(false) }
var predictionRunning by remember { mutableStateOf(false) }
val coroutineScope = rememberCoroutineScope()
var fileDialogOpened by remember { mutableStateOf(false) }
Box(modifier = Modifier.fillMaxSize().padding(16.dp)) {
Column(modifier = Modifier.fillMaxSize()) {
Row(modifier = Modifier.fillMaxWidth()) {
Column(modifier = Modifier.weight(0.5f).padding(horizontal = 8.dp)) {
Button(modifier = Modifier.fillMaxWidth(), onClick = {
coroutineScope.launch {
loading = true
fileDialogOpened = true
loading = false
}
}) {
Text("Open Image")
}
}
Column(modifier = Modifier.weight(0.5f).padding(horizontal = 8.dp)) {
Button(modifier = Modifier.fillMaxWidth(), onClick = {
coroutineScope.launch {
predictionRunning = true
withContext(Dispatchers.IO) {
resultStateFlow = predict(Path(imagePath))
}
predictionRunning = false
}
}, enabled = !predictionRunning && imagePath.isNotEmpty() && !loading) {
Text("Predict")
}
}
}
Row(modifier = Modifier.fillMaxWidth().wrapContentHeight()) {
if (imagePath.isNotEmpty()) {
if (!loading) {
Image(
bitmap = ImageIO.read(Path(imagePath).toFile()).toComposeImageBitmap(),
contentDescription = null,
modifier = Modifier.fillMaxWidth().height(300.dp)
)
} else {
Text("Loading...")
}
}
}
Row(modifier = Modifier.fillMaxWidth().wrapContentHeight()) {
if (resultStateFlow != null) {
Box(modifier = Modifier.fillMaxWidth(), contentAlignment = Alignment.Center) {
val emoji = when (resultStateFlow!!.result) {
ChihuahuaOrMuffin.Chihuahua -> "\uD83D\uDC36" // 🐶
ChihuahuaOrMuffin.Muffin -> "\uD83E\uDDC1" // 🧁
}
Text(emoji, fontSize = TextUnit(6f, TextUnitType.Em))
}
}
}
}
}
if (fileDialogOpened) {
FilePicker {
imagePath = it ?: ""
fileDialogOpened = false
}
}
}
메인 함수를 만든다.
fun main() = application {
Window(
onCloseRequest = ::exitApplication,
title = "ONNX App",
state = rememberWindowState(width = 800.dp, height = 640.dp)
) {
MaterialTheme {
App()
}
}
}
테스트를 해본다. 위에서 받은 Kaggle 이미지들 중 테스트에 있는 이미지로…
잘 된다.
끗.
답글 남기기