summaryrefslogtreecommitdiff
path: root/app/src/main/java/sh/lajo/buddy/ImagesObserver.kt
diff options
context:
space:
mode:
Diffstat (limited to 'app/src/main/java/sh/lajo/buddy/ImagesObserver.kt')
-rw-r--r--app/src/main/java/sh/lajo/buddy/ImagesObserver.kt218
1 files changed, 218 insertions, 0 deletions
diff --git a/app/src/main/java/sh/lajo/buddy/ImagesObserver.kt b/app/src/main/java/sh/lajo/buddy/ImagesObserver.kt
new file mode 100644
index 0000000..f5148fb
--- /dev/null
+++ b/app/src/main/java/sh/lajo/buddy/ImagesObserver.kt
@@ -0,0 +1,218 @@
+package sh.lajo.buddy
+
+import android.Manifest
+import android.content.ContentResolver
+import android.content.Context
+import android.content.Intent
+import android.content.pm.PackageManager
+import android.database.ContentObserver
+import android.graphics.Bitmap
+import android.graphics.BitmapFactory
+import android.net.Uri
+import android.os.Build
+import android.os.Handler
+import android.os.Looper
+import android.provider.MediaStore
+import android.util.Log
+import androidx.core.content.ContextCompat
+import ai.onnxruntime.OnnxTensor
+import ai.onnxruntime.OrtEnvironment
+import ai.onnxruntime.OrtSession
+import java.nio.FloatBuffer
+import java.util.Collections
+
+class ImagesObserver(
+ private val context: Context,
+ private val contentResolver: ContentResolver
+) : ContentObserver(Handler(Looper.getMainLooper())) {
+
+ companion object {
+ private const val TAG = "ImagesObserver"
+ private const val IMG_SIZE = 224
+ private const val MEAN = 0.5f
+ private const val STD = 0.5f
+ private val LABELS = arrayOf("normal", "nsfw")
+ }
+
+ private val ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
+ private val ortSession: OrtSession by lazy {
+ val modelBytes = context.assets.open("nsfw_model.onnx").readBytes()
+ ortEnv.createSession(modelBytes, OrtSession.SessionOptions())
+ }
+ private val inferenceExecutor = java.util.concurrent.Executors.newSingleThreadExecutor()
+
+ override fun onChange(selfChange: Boolean) {
+ onChange(selfChange, null)
+ }
+
+ override fun onChange(selfChange: Boolean, uri: Uri?) {
+ val permission = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
+ Manifest.permission.READ_MEDIA_IMAGES
+ } else {
+ Manifest.permission.READ_EXTERNAL_STORAGE
+ }
+
+ if (ContextCompat.checkSelfPermission(context, permission) != PackageManager.PERMISSION_GRANTED) {
+ Log.w(TAG, "No permission to read media images, skipping")
+ return
+ }
+
+ Log.d(TAG, "Image change detected, URI: $uri")
+ findAndSendNewImage()
+ }
+
+ private fun findAndSendNewImage() {
+ try {
+ val projection = arrayOf(
+ MediaStore.Images.Media._ID,
+ MediaStore.Images.Media.DISPLAY_NAME,
+ MediaStore.Images.Media.DATE_ADDED
+ )
+
+ val sortOrder = "${MediaStore.Images.Media.DATE_ADDED} DESC"
+
+ val cursor = contentResolver.query(
+ MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
+ projection,
+ null,
+ null,
+ sortOrder
+ )
+
+ cursor?.use {
+ if (it.moveToFirst()) {
+ val id = it.getLong(it.getColumnIndexOrThrow(MediaStore.Images.Media._ID))
+ val name = it.getString(it.getColumnIndexOrThrow(MediaStore.Images.Media.DISPLAY_NAME))
+ val dateAdded = it.getLong(it.getColumnIndexOrThrow(MediaStore.Images.Media.DATE_ADDED))
+
+ Log.d(TAG, "Newest image: $name (ID: $id, Added: $dateAdded)")
+
+ val now = System.currentTimeMillis() / 1000
+ if (now - dateAdded < 30) {
+ val imageUri = Uri.withAppendedPath(
+ MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
+ id.toString()
+ )
+ inferenceExecutor.execute {
+ classifyImage(imageUri, name)
+ }
+ }
+ }
+ }
+ } catch (e: Exception) {
+ Log.e(TAG, "Error finding new image", e)
+ }
+ }
+
+ private fun classifyImage(uri: Uri, name: String) {
+ try {
+ val bitmap = loadAndResizeBitmap(uri) ?: run {
+ Log.e(TAG, "Failed to decode bitmap for $name")
+ return
+ }
+
+ val tensorData = bitmapToFloatArray(bitmap)
+ val shape = longArrayOf(1, 3, IMG_SIZE.toLong(), IMG_SIZE.toLong())
+ val inputTensor = OnnxTensor.createTensor(ortEnv, FloatBuffer.wrap(tensorData), shape)
+
+ val inputName = ortSession.inputNames.iterator().next()
+ val results = ortSession.run(Collections.singletonMap(inputName, inputTensor))
+
+ val logits = (results[0].value as Array<FloatArray>)[0]
+ val labelIndex = if (logits[0] > logits[1]) 0 else 1
+ val label = LABELS[labelIndex]
+ val confidence = softmax(logits)[labelIndex]
+
+ Log.i(TAG, "[$name] => $label (${(confidence * 100).toInt()}%)")
+
+ inputTensor.close()
+ results.close()
+
+ if (label == "nsfw") {
+ ConfigManager.getConfig(context)?.let { config ->
+ if (config.galleryScanningMode == "delete") {
+ contentResolver.delete(uri, null, null)
+ Log.w(TAG, "Deleted NSFW image: $name")
+ } else if (config.galleryScanningMode == "notify") {
+ sendNsfwImageDetectedEvent(name, confidence)
+ Log.w(TAG, "Reported NSFW image: $name")
+ }
+ }
+ }
+
+ } catch (e: Exception) {
+ Log.e(TAG, "Inference failed for $name", e)
+ }
+ }
+
+ private fun sendNsfwImageDetectedEvent(name: String, confidence: Float) {
+ val intent = Intent(context, WebSocketService::class.java).apply {
+ action = WebSocketService.ACTION_NSFW_IMAGE_DETECTED
+ putExtra(WebSocketService.EXTRA_NSFW_IMAGE_NAME, name)
+ putExtra(WebSocketService.EXTRA_NSFW_IMAGE_CONFIDENCE, confidence)
+ }
+
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
+ context.startForegroundService(intent)
+ } else {
+ context.startService(intent)
+ }
+ }
+
+ private fun loadAndResizeBitmap(uri: Uri): Bitmap? {
+ return try {
+ val stream = contentResolver.openInputStream(uri) ?: return null
+ val raw = BitmapFactory.decodeStream(stream)
+ stream.close()
+ Bitmap.createScaledBitmap(raw, IMG_SIZE, IMG_SIZE, true)
+ } catch (e: Exception) {
+ Log.e(TAG, "Error loading bitmap", e)
+ null
+ }
+ }
+
+ private fun bitmapToFloatArray(bitmap: Bitmap): FloatArray {
+ val pixels = IntArray(IMG_SIZE * IMG_SIZE)
+ bitmap.getPixels(pixels, 0, IMG_SIZE, 0, 0, IMG_SIZE, IMG_SIZE)
+
+ val tensor = FloatArray(3 * IMG_SIZE * IMG_SIZE)
+ val channelSize = IMG_SIZE * IMG_SIZE
+
+ for (i in pixels.indices) {
+ val px = pixels[i]
+ tensor[i] = (((px shr 16) and 0xFF) / 255f - MEAN) / STD // R
+ tensor[i + channelSize] = (((px shr 8) and 0xFF) / 255f - MEAN) / STD // G
+ tensor[i + 2 * channelSize] = ((px and 0xFF) / 255f - MEAN) / STD // B
+ }
+
+ return tensor
+ }
+
+ private fun softmax(logits: FloatArray): FloatArray {
+ val max = logits.max()
+ val exps = logits.map { Math.exp((it - max).toDouble()).toFloat() }
+ val sum = exps.sum()
+ return exps.map { it / sum }.toFloatArray()
+ }
+
+ fun register() {
+
+ contentResolver.registerContentObserver(
+ MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
+ true,
+ this
+ )
+ Log.d(TAG, "ImagesObserver registered")
+ }
+
+ fun unregister() {
+ contentResolver.unregisterContentObserver(this)
+ Log.d(TAG, "ImagesObserver unregistered")
+ }
+
+ fun close() {
+ inferenceExecutor.shutdown()
+ ortSession.close()
+ ortEnv.close()
+ }
+} \ No newline at end of file