package sh.lajo.buddy

import android.Manifest
import android.content.ContentResolver
import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager
import android.database.ContentObserver
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import android.os.Build
import android.os.Handler
import android.os.Looper
import android.provider.MediaStore
import android.util.Log
import androidx.core.content.ContextCompat
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.nio.FloatBuffer
import java.util.Collections
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.RejectedExecutionException

/**
 * Watches the external image collection and classifies the most recently added image
 * with the bundled `nsfw_model.onnx` model.
 *
 * Change notifications arrive on the main looper. Only the permission check runs there;
 * the MediaStore query, bitmap decoding and inference all run on a private
 * single-thread executor, so work items are processed one at a time.
 *
 * When an image is classified as "nsfw", the action depends on
 * `ConfigManager.getConfig(context).galleryScanningMode`: `"delete"` removes the image
 * through the [ContentResolver], `"notify"` starts [WebSocketService] with the image name
 * and confidence. Any other mode does nothing.
 *
 * Call [register] to start observing, [unregister] to stop, and [close] to release the
 * executor and the ONNX Runtime resources.
 */
class ImagesObserver(
    private val context: Context,
    private val contentResolver: ContentResolver
) : ContentObserver(Handler(Looper.getMainLooper())) {

    companion object {
        private const val TAG = "ImagesObserver"
        private const val IMG_SIZE = 224
        private const val MEAN = 0.5f
        private const val STD = 0.5f

        /** An image is only considered "new" if it was added less than this many seconds ago. */
        private const val RECENT_WINDOW_SECONDS = 30L

        private val LABELS = arrayOf("normal", "nsfw")
    }

    /** Row of interest from the MediaStore query. */
    private data class NewestImage(val id: Long, val name: String, val dateAdded: Long)

    /** Outcome of one inference run. */
    private data class Classification(val label: String, val confidence: Float)

    private val ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()

    // The delegate is kept separately so close() can tell whether the session was ever
    // created and avoid loading the model just to close it again.
    private val ortSessionDelegate: Lazy<OrtSession> = lazy {
        val modelBytes = context.assets.open("nsfw_model.onnx").use { it.readBytes() }
        ortEnv.createSession(modelBytes, OrtSession.SessionOptions())
    }
    private val ortSession: OrtSession by ortSessionDelegate

    private val inferenceExecutor: ExecutorService = Executors.newSingleThreadExecutor()

    // ID of the last image that went through inference successfully. Read and written
    // only on the inference executor thread, so no synchronization is needed.
    private var lastClassifiedId: Long? = null

    override fun onChange(selfChange: Boolean) {
        onChange(selfChange, null)
    }

    /**
     * Checks the media read permission and, if granted, schedules a scan for the newest
     * image on the inference executor. Does nothing once [close] has been called.
     */
    override fun onChange(selfChange: Boolean, uri: Uri?) {
        val permission = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
            Manifest.permission.READ_MEDIA_IMAGES
        } else {
            Manifest.permission.READ_EXTERNAL_STORAGE
        }

        if (ContextCompat.checkSelfPermission(context, permission) != PackageManager.PERMISSION_GRANTED) {
            Log.w(TAG, "No permission to read media images, skipping")
            return
        }

        Log.d(TAG, "Image change detected, URI: $uri")
        try {
            // Keep the content-provider query off the main thread.
            inferenceExecutor.execute { findAndSendNewImage() }
        } catch (e: RejectedExecutionException) {
            Log.w(TAG, "Observer already closed, ignoring image change", e)
        }
    }

    /**
     * Looks up the newest image and classifies it if it was added within the last
     * [RECENT_WINDOW_SECONDS] seconds and has not already been classified.
     * Runs on the inference executor.
     */
    private fun findAndSendNewImage() {
        try {
            val newest = queryNewestImage() ?: return
            Log.d(TAG, "Newest image: ${newest.name} (ID: ${newest.id}, Added: ${newest.dateAdded})")

            // DATE_ADDED is compared against the current time in seconds.
            val now = System.currentTimeMillis() / 1000
            if (now - newest.dateAdded >= RECENT_WINDOW_SECONDS) return

            // A single new image usually triggers several change notifications;
            // skip it once it has been classified successfully.
            if (newest.id == lastClassifiedId) {
                Log.d(TAG, "Image ${newest.name} already classified, skipping")
                return
            }

            val imageUri = Uri.withAppendedPath(
                MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
                newest.id.toString()
            )
            if (classifyImage(imageUri, newest.name)) {
                lastClassifiedId = newest.id
            }
        } catch (e: Exception) {
            Log.e(TAG, "Error finding new image", e)
        }
    }

    /**
     * Returns the image with the greatest DATE_ADDED value, or null when the query
     * yields no rows. The cursor is closed before this function returns.
     */
    private fun queryNewestImage(): NewestImage? {
        val projection = arrayOf(
            MediaStore.Images.Media._ID,
            MediaStore.Images.Media.DISPLAY_NAME,
            MediaStore.Images.Media.DATE_ADDED
        )
        val sortOrder = "${MediaStore.Images.Media.DATE_ADDED} DESC"

        return contentResolver.query(
            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
            projection,
            null,
            null,
            sortOrder
        )?.use { cursor ->
            if (cursor.moveToFirst()) {
                val id = cursor.getLong(cursor.getColumnIndexOrThrow(MediaStore.Images.Media._ID))
                // The display name column may be null; fall back to a name derived from the ID.
                val name: String? =
                    cursor.getString(cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DISPLAY_NAME))
                val dateAdded =
                    cursor.getLong(cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DATE_ADDED))
                NewestImage(id, name ?: "image_$id", dateAdded)
            } else {
                null
            }
        }
    }

    /**
     * Decodes [uri], runs the model on it and applies the configured action when the
     * result is "nsfw".
     *
     * @return true when inference completed (whatever the label), false when the image
     *   could not be decoded or inference failed.
     */
    private fun classifyImage(uri: Uri, name: String): Boolean {
        val result = try {
            val bitmap = loadAndResizeBitmap(uri) ?: run {
                Log.e(TAG, "Failed to decode bitmap for $name")
                return false
            }
            val tensorData = try {
                bitmapToFloatArray(bitmap)
            } finally {
                // Pixels have been copied out; the bitmap is no longer needed.
                bitmap.recycle()
            }
            runInference(tensorData)
        } catch (e: Exception) {
            Log.e(TAG, "Inference failed for $name", e)
            return false
        }

        Log.i(TAG, "[$name] => ${result.label} (${(result.confidence * 100).toInt()}%)")

        if (result.label == "nsfw") {
            handleNsfwImage(uri, name, result.confidence)
        }
        return true
    }

    /**
     * Feeds [tensorData] (laid out as 1 x 3 x IMG_SIZE x IMG_SIZE) to the model and
     * converts the first output row into a label and softmax confidence.
     * The input tensor and the result are closed even when inference throws.
     *
     * @throws IllegalStateException if the model output is not an array whose first
     *   element is a FloatArray with at least one value per label.
     */
    private fun runInference(tensorData: FloatArray): Classification {
        val shape = longArrayOf(1, 3, IMG_SIZE.toLong(), IMG_SIZE.toLong())
        val inputName = ortSession.inputNames.iterator().next()

        return OnnxTensor.createTensor(ortEnv, FloatBuffer.wrap(tensorData), shape).use { inputTensor ->
            ortSession.run(Collections.singletonMap(inputName, inputTensor)).use { results ->
                val logits = (results[0].value as? Array<*>)?.firstOrNull() as? FloatArray
                    ?: error("Unexpected model output type")
                check(logits.size >= LABELS.size) {
                    "Expected at least ${LABELS.size} logits, got ${logits.size}"
                }

                val labelIndex = if (logits[0] > logits[1]) 0 else 1
                Classification(LABELS[labelIndex], softmax(logits)[labelIndex])
            }
        }
    }

    /**
     * Applies the configured gallery scanning mode to an image classified as "nsfw".
     * Failures are logged here so they are not reported as inference errors.
     */
    private fun handleNsfwImage(uri: Uri, name: String, confidence: Float) {
        try {
            val config = ConfigManager.getConfig(context) ?: return
            when (config.galleryScanningMode) {
                "delete" -> {
                    // delete() returns the number of rows removed; only claim success if > 0.
                    val deletedRows = contentResolver.delete(uri, null, null)
                    if (deletedRows > 0) {
                        Log.w(TAG, "Deleted NSFW image: $name")
                    } else {
                        Log.w(TAG, "NSFW image was not deleted (no rows removed): $name")
                    }
                }
                "notify" -> {
                    sendNsfwImageDetectedEvent(name, confidence)
                    Log.w(TAG, "Reported NSFW image: $name")
                }
            }
        } catch (e: Exception) {
            Log.e(TAG, "Failed to handle NSFW image $name", e)
        }
    }

    /**
     * Starts [WebSocketService] with the NSFW-detected action, passing the image name
     * and confidence as extras. Uses startForegroundService on API 26+.
     */
    private fun sendNsfwImageDetectedEvent(name: String, confidence: Float) {
        val intent = Intent(context, WebSocketService::class.java).apply {
            action = WebSocketService.ACTION_NSFW_IMAGE_DETECTED
            putExtra(WebSocketService.EXTRA_NSFW_IMAGE_NAME, name)
            putExtra(WebSocketService.EXTRA_NSFW_IMAGE_CONFIDENCE, confidence)
        }

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
            context.startForegroundService(intent)
        } else {
            context.startService(intent)
        }
    }

    /**
     * Decodes the image at [uri] and scales it to IMG_SIZE x IMG_SIZE with filtering.
     *
     * @return the scaled bitmap, or null when the stream cannot be opened, the data
     *   cannot be decoded, or an exception occurs.
     */
    private fun loadAndResizeBitmap(uri: Uri): Bitmap? {
        return try {
            // NOTE(review): the image is decoded at full resolution before scaling;
            // consider BitmapFactory.Options.inSampleSize if very large images are a problem.
            val raw: Bitmap = contentResolver.openInputStream(uri)
                ?.use { stream -> BitmapFactory.decodeStream(stream) }
                ?: return null

            val scaled = Bitmap.createScaledBitmap(raw, IMG_SIZE, IMG_SIZE, true)
            // createScaledBitmap may hand back the source itself when no scaling is needed.
            if (scaled !== raw) {
                raw.recycle()
            }
            scaled
        } catch (e: Exception) {
            Log.e(TAG, "Error loading bitmap", e)
            null
        }
    }

    /**
     * Converts an IMG_SIZE x IMG_SIZE bitmap into a planar float array: all red values,
     * then all green, then all blue. Each channel value is scaled to 0..1 and then
     * normalized as (value - MEAN) / STD.
     */
    private fun bitmapToFloatArray(bitmap: Bitmap): FloatArray {
        val channelSize = IMG_SIZE * IMG_SIZE
        val pixels = IntArray(channelSize)
        bitmap.getPixels(pixels, 0, IMG_SIZE, 0, 0, IMG_SIZE, IMG_SIZE)

        val tensor = FloatArray(3 * channelSize)
        for (i in pixels.indices) {
            val px = pixels[i]
            tensor[i] = (((px shr 16) and 0xFF) / 255f - MEAN) / STD // R
            tensor[i + channelSize] = (((px shr 8) and 0xFF) / 255f - MEAN) / STD // G
            tensor[i + 2 * channelSize] = ((px and 0xFF) / 255f - MEAN) / STD // B
        }
        return tensor
    }

    /**
     * Returns the softmax of [logits]. The maximum is subtracted before exponentiating
     * so large logits do not overflow. An empty input yields an empty array.
     */
    private fun softmax(logits: FloatArray): FloatArray {
        val max = logits.maxOrNull() ?: return FloatArray(0)
        val exps = FloatArray(logits.size) { i -> Math.exp((logits[i] - max).toDouble()).toFloat() }
        val sum = exps.sum()
        return FloatArray(exps.size) { i -> exps[i] / sum }
    }

    /** Registers this observer for the external images URI and its descendants. */
    fun register() {
        contentResolver.registerContentObserver(
            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
            true,
            this
        )
        Log.d(TAG, "ImagesObserver registered")
    }

    /** Unregisters this observer from the content resolver. */
    fun unregister() {
        contentResolver.unregisterContentObserver(this)
        Log.d(TAG, "ImagesObserver unregistered")
    }

    /**
     * Releases the executor and the ONNX Runtime resources.
     *
     * The release is queued as the last task on the inference executor, so work that
     * was already submitted finishes before the session is closed. The session is only
     * closed if it was actually created. Calling this more than once is harmless.
     */
    fun close() {
        if (inferenceExecutor.isShutdown) return

        try {
            inferenceExecutor.execute {
                try {
                    if (ortSessionDelegate.isInitialized()) {
                        ortSession.close()
                    }
                    ortEnv.close()
                } catch (e: Exception) {
                    Log.e(TAG, "Error releasing ONNX Runtime resources", e)
                }
            }
        } catch (e: RejectedExecutionException) {
            Log.w(TAG, "Executor already shut down while closing", e)
        }
        // shutdown() still runs tasks that were submitted before it, including the one above.
        inferenceExecutor.shutdown()
    }
}