diff options
Diffstat (limited to 'app/src/main/java/sh/lajo/buddy/ImagesObserver.kt')
| -rw-r--r-- | app/src/main/java/sh/lajo/buddy/ImagesObserver.kt | 218 |
1 files changed, 218 insertions, 0 deletions
package sh.lajo.buddy

import android.Manifest
import android.content.ContentResolver
import android.content.ContentUris
import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager
import android.database.ContentObserver
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import android.os.Build
import android.os.Handler
import android.os.Looper
import android.provider.MediaStore
import android.util.Log
import androidx.core.content.ContextCompat
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.nio.FloatBuffer
import java.util.Collections
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.RejectedExecutionException

/**
 * Watches the external image collection of [MediaStore] and classifies the newest image
 * with the bundled ONNX model.
 *
 * Change callbacks arrive on the main looper (see the [Handler] passed to [ContentObserver]);
 * every piece of real work (the MediaStore query, bitmap decoding and inference) is pushed
 * onto a private single-thread executor so the main thread is never blocked.
 *
 * Lifecycle: call [register] to start observing, [unregister] to stop, and [close] once the
 * observer is no longer needed to release the inference resources.
 *
 * @param context used for permission checks, asset access, config lookup and starting the service
 * @param contentResolver resolver used to observe, query, open and delete images
 */
class ImagesObserver(
    private val context: Context,
    private val contentResolver: ContentResolver
) : ContentObserver(Handler(Looper.getMainLooper())) {

    companion object {
        private const val TAG = "ImagesObserver"
        private const val MODEL_ASSET = "nsfw_model.onnx"
        private const val IMG_SIZE = 224
        private const val MEAN = 0.5f
        private const val STD = 0.5f

        /** Only images added less than this many seconds ago are classified. */
        private const val RECENT_WINDOW_SECONDS = 30L

        /** Sentinel for "no image classified yet". */
        private const val NO_IMAGE_ID = -1L

        private val LABELS = arrayOf("normal", "nsfw")
    }

    /** Row data of the most recently added image. */
    private data class NewestImage(val id: Long, val name: String, val dateAdded: Long)

    private val ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()

    // Kept as an explicit Lazy so close() can ask whether the session was ever created
    // instead of loading the model just to dispose of it.
    private val ortSessionLazy: Lazy<OrtSession> = lazy {
        val modelBytes = context.assets.open(MODEL_ASSET).use { it.readBytes() }
        ortEnv.createSession(modelBytes, OrtSession.SessionOptions())
    }
    private val ortSession: OrtSession by ortSessionLazy

    private val inferenceExecutor: ExecutorService = Executors.newSingleThreadExecutor()

    // Id of the last image that went through inference. Read and written only on
    // inferenceExecutor's thread, so it needs no synchronization.
    private var lastClassifiedId = NO_IMAGE_ID

    override fun onChange(selfChange: Boolean) {
        onChange(selfChange, null)
    }

    /**
     * Entry point for change notifications. Verifies the media read permission and then
     * schedules a scan for the newest image on the background executor.
     */
    override fun onChange(selfChange: Boolean, uri: Uri?) {
        val permission = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
            Manifest.permission.READ_MEDIA_IMAGES
        } else {
            Manifest.permission.READ_EXTERNAL_STORAGE
        }

        if (ContextCompat.checkSelfPermission(context, permission) != PackageManager.PERMISSION_GRANTED) {
            Log.w(TAG, "No permission to read media images, skipping")
            return
        }

        Log.d(TAG, "Image change detected, URI: $uri")
        try {
            inferenceExecutor.execute { findAndSendNewImage() }
        } catch (e: RejectedExecutionException) {
            // The executor refuses new work once close() has shut it down.
            Log.w(TAG, "Observer is closed, ignoring image change", e)
        }
    }

    /**
     * Looks up the newest image and classifies it when it was added within
     * [RECENT_WINDOW_SECONDS] and has not been classified already.
     *
     * Runs on [inferenceExecutor].
     */
    private fun findAndSendNewImage() {
        try {
            val newest = queryNewestImage() ?: return
            Log.d(TAG, "Newest image: ${newest.name} (ID: ${newest.id}, Added: ${newest.dateAdded})")

            val nowSeconds = System.currentTimeMillis() / 1000
            if (nowSeconds - newest.dateAdded >= RECENT_WINDOW_SECONDS) return

            // Several notifications inside the time window would otherwise classify
            // (and report) the very same image again.
            if (newest.id == lastClassifiedId) {
                Log.d(TAG, "Image ${newest.id} already classified, skipping")
                return
            }

            val imageUri = ContentUris.withAppendedId(
                MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
                newest.id
            )
            // Remember the id only after inference actually ran, so an image that could
            // not be decoded yet is retried on the next notification.
            if (classifyImage(imageUri, newest.name)) {
                lastClassifiedId = newest.id
            }
        } catch (e: Exception) {
            Log.e(TAG, "Error finding new image", e)
        }
    }

    /**
     * Queries the external image collection ordered by `DATE_ADDED` descending and returns
     * the first row, or `null` when the query yields no cursor or no rows.
     */
    private fun queryNewestImage(): NewestImage? {
        val projection = arrayOf(
            MediaStore.Images.Media._ID,
            MediaStore.Images.Media.DISPLAY_NAME,
            MediaStore.Images.Media.DATE_ADDED
        )
        val sortOrder = "${MediaStore.Images.Media.DATE_ADDED} DESC"

        return contentResolver.query(
            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
            projection,
            null,
            null,
            sortOrder
        )?.use { cursor ->
            if (!cursor.moveToFirst()) {
                null
            } else {
                val id = cursor.getLong(cursor.getColumnIndexOrThrow(MediaStore.Images.Media._ID))
                // The display name column may be null; fall back to an id-based name.
                val name = cursor.getString(
                    cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DISPLAY_NAME)
                ) ?: "image-$id"
                val dateAdded = cursor.getLong(
                    cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DATE_ADDED)
                )
                NewestImage(id, name, dateAdded)
            }
        }
    }

    /**
     * Runs the model on the image at [uri] and, when it is labelled "nsfw", applies the
     * configured gallery scanning mode.
     *
     * @param uri content URI of the image
     * @param name name used in log messages and in the reported event
     * @return `true` when inference completed (whatever the label), `false` when the image
     *   could not be decoded or inference failed
     */
    private fun classifyImage(uri: Uri, name: String): Boolean {
        try {
            val bitmap = loadAndResizeBitmap(uri) ?: run {
                Log.e(TAG, "Failed to decode bitmap for $name")
                return false
            }

            val tensorData = try {
                bitmapToFloatArray(bitmap)
            } finally {
                bitmap.recycle()
            }
            val shape = longArrayOf(1, 3, IMG_SIZE.toLong(), IMG_SIZE.toLong())

            // Tensor and result are released in finally blocks so a failing run or an
            // unexpected output does not leak them.
            val inputTensor = OnnxTensor.createTensor(ortEnv, FloatBuffer.wrap(tensorData), shape)
            val logits: FloatArray = try {
                val inputName = ortSession.inputNames.iterator().next()
                val results = ortSession.run(Collections.singletonMap(inputName, inputTensor))
                try {
                    val output = results[0].value
                    (output as? Array<*>)?.firstOrNull() as? FloatArray
                        ?: error("Unexpected model output type: ${output?.javaClass}")
                } finally {
                    results.close()
                }
            } finally {
                inputTensor.close()
            }
            check(logits.size >= LABELS.size) {
                "Model returned ${logits.size} logits, expected at least ${LABELS.size}"
            }

            val labelIndex = if (logits[0] > logits[1]) 0 else 1
            val label = LABELS[labelIndex]
            val confidence = softmax(logits)[labelIndex]

            Log.i(TAG, "[$name] => $label (${(confidence * 100).toInt()}%)")

            if (label == "nsfw") {
                handleNsfwImage(uri, name, confidence)
            }
            return true
        } catch (e: Exception) {
            Log.e(TAG, "Inference failed for $name", e)
            return false
        }
    }

    /**
     * Applies the configured `galleryScanningMode` to an image classified as NSFW:
     * "delete" removes it through the content resolver, "notify" reports it to
     * [WebSocketService]. Any other mode, or a missing config, does nothing.
     */
    private fun handleNsfwImage(uri: Uri, name: String, confidence: Float) {
        try {
            val config = ConfigManager.getConfig(context) ?: return
            when (config.galleryScanningMode) {
                "delete" -> {
                    // delete() returns the number of removed rows; only claim success
                    // when something was actually deleted.
                    val deletedRows = contentResolver.delete(uri, null, null)
                    if (deletedRows > 0) {
                        Log.w(TAG, "Deleted NSFW image: $name")
                    } else {
                        Log.e(TAG, "Could not delete NSFW image: $name")
                    }
                }
                "notify" -> {
                    sendNsfwImageDetectedEvent(name, confidence)
                    Log.w(TAG, "Reported NSFW image: $name")
                }
            }
        } catch (e: Exception) {
            Log.e(TAG, "Failed to handle NSFW image: $name", e)
        }
    }

    /** Starts [WebSocketService] with the NSFW-detected action, image name and confidence. */
    private fun sendNsfwImageDetectedEvent(name: String, confidence: Float) {
        val intent = Intent(context, WebSocketService::class.java).apply {
            action = WebSocketService.ACTION_NSFW_IMAGE_DETECTED
            putExtra(WebSocketService.EXTRA_NSFW_IMAGE_NAME, name)
            putExtra(WebSocketService.EXTRA_NSFW_IMAGE_CONFIDENCE, confidence)
        }

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
            context.startForegroundService(intent)
        } else {
            context.startService(intent)
        }
    }

    /**
     * Decodes the image at [uri] and scales it to [IMG_SIZE] x [IMG_SIZE] (aspect ratio is
     * not preserved).
     *
     * The image is decoded in two passes: the first reads only its dimensions, the second
     * decodes it subsampled, so a large photo is never fully loaded just to be shrunk.
     *
     * @return the scaled bitmap, or `null` when the image cannot be opened or decoded
     */
    private fun loadAndResizeBitmap(uri: Uri): Bitmap? {
        return try {
            val bounds = BitmapFactory.Options().apply { inJustDecodeBounds = true }
            val boundsStream = contentResolver.openInputStream(uri) ?: return null
            boundsStream.use { BitmapFactory.decodeStream(it, null, bounds) }
            if (bounds.outWidth <= 0 || bounds.outHeight <= 0) return null

            val decodeOptions = BitmapFactory.Options().apply {
                inSampleSize = computeSampleSize(bounds.outWidth, bounds.outHeight)
            }
            val raw = contentResolver.openInputStream(uri)?.use { stream ->
                BitmapFactory.decodeStream(stream, null, decodeOptions)
            } ?: return null

            val scaled = Bitmap.createScaledBitmap(raw, IMG_SIZE, IMG_SIZE, true)
            // createScaledBitmap may hand back the source when no scaling is needed.
            if (scaled !== raw) raw.recycle()
            scaled
        } catch (e: Exception) {
            Log.e(TAG, "Error loading bitmap", e)
            null
        }
    }

    /**
     * Returns the largest power-of-two sample size that keeps both dimensions at or above
     * [IMG_SIZE], or 1 when the image is already small.
     */
    private fun computeSampleSize(width: Int, height: Int): Int {
        var sampleSize = 1
        while (width / (sampleSize * 2) >= IMG_SIZE && height / (sampleSize * 2) >= IMG_SIZE) {
            sampleSize *= 2
        }
        return sampleSize
    }

    /**
     * Converts an [IMG_SIZE] x [IMG_SIZE] bitmap into a planar float array laid out as all
     * red values, then all green, then all blue, each normalised as
     * `(value / 255 - MEAN) / STD`.
     */
    private fun bitmapToFloatArray(bitmap: Bitmap): FloatArray {
        val channelSize = IMG_SIZE * IMG_SIZE
        val pixels = IntArray(channelSize)
        bitmap.getPixels(pixels, 0, IMG_SIZE, 0, 0, IMG_SIZE, IMG_SIZE)

        val tensor = FloatArray(3 * channelSize)
        for (i in pixels.indices) {
            val px = pixels[i]
            tensor[i] = (((px shr 16) and 0xFF) / 255f - MEAN) / STD // R
            tensor[i + channelSize] = (((px shr 8) and 0xFF) / 255f - MEAN) / STD // G
            tensor[i + 2 * channelSize] = ((px and 0xFF) / 255f - MEAN) / STD // B
        }

        return tensor
    }

    /**
     * Softmax over [logits]; the maximum is subtracted first to keep the exponentials
     * within range. [logits] must not be empty.
     */
    private fun softmax(logits: FloatArray): FloatArray {
        val max = logits.max()
        val exps = FloatArray(logits.size) { i -> Math.exp((logits[i] - max).toDouble()).toFloat() }
        val sum = exps.sum()
        return FloatArray(exps.size) { i -> exps[i] / sum }
    }

    /** Starts observing the external image collection, including descendant URIs. */
    fun register() {
        contentResolver.registerContentObserver(
            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
            true,
            this
        )
        Log.d(TAG, "ImagesObserver registered")
    }

    /** Stops observing the image collection. */
    fun unregister() {
        contentResolver.unregisterContentObserver(this)
        Log.d(TAG, "ImagesObserver unregistered")
    }

    /**
     * Releases the inference resources and shuts the executor down. Safe to call more
     * than once; later calls do nothing.
     *
     * The teardown is queued on [inferenceExecutor] behind any work already submitted, so
     * the session is not closed while an inference is still using it. This method
     * therefore returns before the resources are actually released.
     */
    @Synchronized
    fun close() {
        if (inferenceExecutor.isShutdown) return

        inferenceExecutor.execute {
            try {
                // Skip the session when it was never created; touching the delegate
                // would otherwise load the model only to close it.
                if (ortSessionLazy.isInitialized()) {
                    ortSession.close()
                }
                ortEnv.close()
            } catch (e: Exception) {
                Log.e(TAG, "Error releasing inference resources", e)
            }
        }
        // shutdown() still runs the tasks submitted above; it only rejects new ones.
        inferenceExecutor.shutdown()
    }
}
\ No newline at end of file |