1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
|
package sh.lajo.buddy
import android.Manifest
import android.content.ContentResolver
import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager
import android.database.ContentObserver
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import android.os.Build
import android.os.Handler
import android.os.Looper
import android.provider.MediaStore
import android.util.Log
import androidx.core.content.ContextCompat
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.nio.FloatBuffer
import java.util.Collections
/**
 * Observes [MediaStore.Images.Media.EXTERNAL_CONTENT_URI] and runs the bundled
 * `nsfw_model.onnx` classifier on the newest image added within the last
 * [RECENT_WINDOW_SECONDS] seconds.
 *
 * When an image is classified as `nsfw`, the action depends on
 * `ConfigManager.getConfig(context)?.galleryScanningMode`:
 * - `"delete"` removes it through [contentResolver].
 * - `"notify"` forwards the detection to `WebSocketService`.
 * - Any other value (or no config) only logs the result.
 *
 * Threading: change callbacks arrive on the main looper (the [Handler] given to
 * [ContentObserver]). MediaStore queries, bitmap decoding and inference all run on the
 * single [inferenceExecutor] thread, so that work is serialized.
 *
 * Lifecycle: [register] starts observing, [unregister] stops it, and [close] releases the
 * executor and the ONNX Runtime resources.
 */
class ImagesObserver(
    private val context: Context,
    private val contentResolver: ContentResolver
) : ContentObserver(Handler(Looper.getMainLooper())) {
    companion object {
        private const val TAG = "ImagesObserver"

        /** Side length, in pixels, of the square model input. */
        private const val IMG_SIZE = 224

        /** Per-channel normalization applied as `(pixel / 255 - MEAN) / STD`. */
        private const val MEAN = 0.5f
        private const val STD = 0.5f

        /** Class names indexed by position in the model's output logits. */
        private val LABELS = arrayOf("normal", "nsfw")

        /** Only images whose DATE_ADDED (seconds) is newer than this many seconds ago are scanned. */
        private const val RECENT_WINDOW_SECONDS = 30L
    }

    private val ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()

    // Kept as an explicit Lazy so close() can tell whether the model was ever loaded
    // instead of loading it just to close it.
    private val ortSessionDelegate: Lazy<OrtSession> = lazy {
        val modelBytes = context.assets.open("nsfw_model.onnx").use { it.readBytes() }
        ortEnv.createSession(modelBytes, OrtSession.SessionOptions())
    }
    private val ortSession: OrtSession by ortSessionDelegate

    /** Single worker thread that serializes queries, decoding and inference. */
    private val inferenceExecutor = java.util.concurrent.Executors.newSingleThreadExecutor()

    /**
     * MediaStore `_ID` of the last image that was classified successfully.
     * Only read and written on [inferenceExecutor].
     */
    private var lastHandledImageId: Long? = null

    override fun onChange(selfChange: Boolean) {
        onChange(selfChange, null)
    }

    /**
     * Runs on the main looper for any change under the observed URI.
     *
     * Verifies the media-read permission, then hands the lookup to [inferenceExecutor],
     * so no MediaStore I/O happens on the main thread.
     */
    override fun onChange(selfChange: Boolean, uri: Uri?) {
        val permission = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
            Manifest.permission.READ_MEDIA_IMAGES
        } else {
            Manifest.permission.READ_EXTERNAL_STORAGE
        }
        if (ContextCompat.checkSelfPermission(context, permission) != PackageManager.PERMISSION_GRANTED) {
            Log.w(TAG, "No permission to read media images, skipping")
            return
        }
        Log.d(TAG, "Image change detected, URI: $uri")
        try {
            inferenceExecutor.execute { findAndSendNewImage() }
        } catch (e: java.util.concurrent.RejectedExecutionException) {
            // close() has already shut the executor down.
            Log.w(TAG, "Observer is closed, ignoring change for $uri", e)
        }
    }

    /**
     * Classifies the newest recently-added image unless it was already handled.
     * Must run on [inferenceExecutor].
     *
     * One new image may produce several change notifications. The image ID is therefore
     * remembered once classification succeeds, and repeats are skipped. A failed attempt
     * (for example, a bitmap that cannot be decoded yet) is retried on the next notification.
     */
    private fun findAndSendNewImage() {
        try {
            val (id, name) = queryNewestRecentImage() ?: return
            if (id == lastHandledImageId) {
                Log.d(TAG, "Image $name (ID: $id) already handled, skipping")
                return
            }
            val imageUri = Uri.withAppendedPath(
                MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
                id.toString()
            )
            if (classifyImage(imageUri, name)) {
                lastHandledImageId = id
            }
        } catch (e: Exception) {
            Log.e(TAG, "Error finding new image", e)
        }
    }

    /**
     * Returns `(id, displayName)` of the newest image with DATE_ADDED within the last
     * [RECENT_WINDOW_SECONDS] seconds, or null if there is none.
     *
     * The display name falls back to the ID when the column is null. The cursor is closed
     * before returning, so it is not held open during inference.
     */
    private fun queryNewestRecentImage(): Pair<Long, String>? {
        val nowSeconds = System.currentTimeMillis() / 1000
        val projection = arrayOf(
            MediaStore.Images.Media._ID,
            MediaStore.Images.Media.DISPLAY_NAME,
            MediaStore.Images.Media.DATE_ADDED
        )
        // DATE_ADDED is stored in seconds. Filtering in the query avoids reading every image row.
        val selection = "${MediaStore.Images.Media.DATE_ADDED} > ?"
        val selectionArgs = arrayOf((nowSeconds - RECENT_WINDOW_SECONDS).toString())
        val sortOrder = "${MediaStore.Images.Media.DATE_ADDED} DESC"
        return contentResolver.query(
            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
            projection,
            selection,
            selectionArgs,
            sortOrder
        )?.use { cursor ->
            if (!cursor.moveToFirst()) return null
            val id = cursor.getLong(cursor.getColumnIndexOrThrow(MediaStore.Images.Media._ID))
            val name = cursor.getString(cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DISPLAY_NAME))
                ?: id.toString()
            val dateAdded = cursor.getLong(cursor.getColumnIndexOrThrow(MediaStore.Images.Media.DATE_ADDED))
            Log.d(TAG, "Newest image: $name (ID: $id, Added: $dateAdded)")
            id to name
        }
    }

    /**
     * Classifies the image at [uri] and, when it is `nsfw`, applies the configured action.
     *
     * @return true if the image was classified, whether or not the follow-up action succeeded;
     *   false if decoding or inference failed.
     */
    private fun classifyImage(uri: Uri, name: String): Boolean {
        val probabilities = try {
            inferProbabilities(uri, name) ?: return false
        } catch (e: Exception) {
            Log.e(TAG, "Inference failed for $name", e)
            return false
        }
        val labelIndex = probabilities.indices.maxByOrNull { probabilities[it] } ?: return false
        val label = LABELS[labelIndex]
        val confidence = probabilities[labelIndex]
        Log.i(TAG, "[$name] => $label (${(confidence * 100).toInt()}%)")
        if (label == "nsfw") {
            try {
                handleNsfwImage(uri, name, confidence)
            } catch (e: Exception) {
                Log.e(TAG, "Failed to act on NSFW image $name", e)
            }
        }
        return true
    }

    /**
     * Decodes the image, runs the model, and returns softmax probabilities aligned with
     * [LABELS]. Returns null (after logging) if the bitmap cannot be decoded or the model
     * output does not have one logit per label.
     *
     * The input tensor and the session result are closed even if inference throws.
     */
    private fun inferProbabilities(uri: Uri, name: String): FloatArray? {
        val bitmap = loadAndResizeBitmap(uri) ?: run {
            Log.e(TAG, "Failed to decode bitmap for $name")
            return null
        }
        val tensorData = try {
            bitmapToFloatArray(bitmap)
        } finally {
            bitmap.recycle()
        }
        val shape = longArrayOf(1, 3, IMG_SIZE.toLong(), IMG_SIZE.toLong())
        val inputName = ortSession.inputNames.first()
        val logits = OnnxTensor.createTensor(ortEnv, FloatBuffer.wrap(tensorData), shape).use { input ->
            ortSession.run(Collections.singletonMap(inputName, input)).use { results ->
                // Expected output shape is [1, numLabels]. Safe casts avoid an unchecked cast.
                (results[0].value as? Array<*>)?.firstOrNull() as? FloatArray
            }
        }
        if (logits == null || logits.size != LABELS.size) {
            Log.e(TAG, "Unexpected model output for $name")
            return null
        }
        return softmax(logits)
    }

    /** Applies the configured `galleryScanningMode` to an image classified as NSFW. */
    private fun handleNsfwImage(uri: Uri, name: String, confidence: Float) {
        val config = ConfigManager.getConfig(context) ?: return
        when (config.galleryScanningMode) {
            "delete" -> deleteImage(uri, name)
            "notify" -> {
                sendNsfwImageDetectedEvent(name, confidence)
                Log.w(TAG, "Reported NSFW image: $name")
            }
        }
    }

    /**
     * Deletes [uri] through the content resolver. Logs success only when a row was actually
     * removed.
     */
    private fun deleteImage(uri: Uri, name: String) {
        try {
            val deletedRows = contentResolver.delete(uri, null, null)
            if (deletedRows > 0) {
                Log.w(TAG, "Deleted NSFW image: $name")
            } else {
                Log.w(TAG, "NSFW image $name was not deleted (no matching row)")
            }
        } catch (e: SecurityException) {
            // NOTE(review): under scoped storage (Android 10+), deleting media owned by another
            // app throws RecoverableSecurityException. Removal then needs user consent, e.g.
            // MediaStore.createDeleteRequest on API 30+.
            Log.e(TAG, "Not allowed to delete NSFW image: $name", e)
        }
    }

    /**
     * Sends the detection to `WebSocketService` as an [Intent] with
     * `ACTION_NSFW_IMAGE_DETECTED` plus the image name and confidence extras.
     */
    private fun sendNsfwImageDetectedEvent(name: String, confidence: Float) {
        val intent = Intent(context, WebSocketService::class.java).apply {
            action = WebSocketService.ACTION_NSFW_IMAGE_DETECTED
            putExtra(WebSocketService.EXTRA_NSFW_IMAGE_NAME, name)
            putExtra(WebSocketService.EXTRA_NSFW_IMAGE_CONFIDENCE, confidence)
        }
        // NOTE(review): presumably WebSocketService is already running in the foreground.
        // Otherwise Android 12+ may reject a foreground-service start from the background.
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
            context.startForegroundService(intent)
        } else {
            context.startService(intent)
        }
    }

    /**
     * Decodes [uri] into an [IMG_SIZE] x [IMG_SIZE] bitmap, or returns null on failure.
     *
     * A bounds-only pass picks a power-of-two `inSampleSize`, so large photos are not decoded
     * at full resolution. The intermediate bitmap is recycled once scaled.
     */
    private fun loadAndResizeBitmap(uri: Uri): Bitmap? {
        try {
            val bounds = BitmapFactory.Options().apply { inJustDecodeBounds = true }
            val boundsStream = contentResolver.openInputStream(uri) ?: return null
            boundsStream.use { BitmapFactory.decodeStream(it, null, bounds) }
            if (bounds.outWidth <= 0 || bounds.outHeight <= 0) return null

            val decodeOptions = BitmapFactory.Options().apply {
                inSampleSize = sampleSizeFor(bounds.outWidth, bounds.outHeight)
            }
            val decoded = contentResolver.openInputStream(uri)
                ?.use { BitmapFactory.decodeStream(it, null, decodeOptions) }
                ?: return null
            val scaled = Bitmap.createScaledBitmap(decoded, IMG_SIZE, IMG_SIZE, true)
            // createScaledBitmap may return the same instance when no scaling is needed.
            if (scaled !== decoded) decoded.recycle()
            return scaled
        } catch (e: Exception) {
            Log.e(TAG, "Error loading bitmap", e)
            return null
        }
    }

    /** Largest power of two that keeps both sub-sampled dimensions at least [IMG_SIZE]. */
    private fun sampleSizeFor(width: Int, height: Int): Int {
        var sampleSize = 1
        while (width / (sampleSize * 2) >= IMG_SIZE && height / (sampleSize * 2) >= IMG_SIZE) {
            sampleSize *= 2
        }
        return sampleSize
    }

    /**
     * Converts the top-left [IMG_SIZE] x [IMG_SIZE] pixels of [bitmap] into a planar float
     * array: all R values, then all G values, then all B values.
     *
     * Each value is normalized as `(channel / 255 - MEAN) / STD`.
     */
    private fun bitmapToFloatArray(bitmap: Bitmap): FloatArray {
        val channelSize = IMG_SIZE * IMG_SIZE
        val pixels = IntArray(channelSize)
        bitmap.getPixels(pixels, 0, IMG_SIZE, 0, 0, IMG_SIZE, IMG_SIZE)
        val tensor = FloatArray(3 * channelSize)
        for (i in pixels.indices) {
            val px = pixels[i]
            tensor[i] = (((px shr 16) and 0xFF) / 255f - MEAN) / STD // R
            tensor[i + channelSize] = (((px shr 8) and 0xFF) / 255f - MEAN) / STD // G
            tensor[i + 2 * channelSize] = ((px and 0xFF) / 255f - MEAN) / STD // B
        }
        return tensor
    }

    /** Numerically stable softmax: subtracts the max logit before exponentiating. */
    private fun softmax(logits: FloatArray): FloatArray {
        val max = logits.max()
        val exps = FloatArray(logits.size) { Math.exp((logits[it] - max).toDouble()).toFloat() }
        val sum = exps.sum()
        return FloatArray(exps.size) { exps[it] / sum }
    }

    /** Starts observing the external images collection, including descendant URIs. */
    fun register() {
        contentResolver.registerContentObserver(
            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,
            true,
            this
        )
        Log.d(TAG, "ImagesObserver registered")
    }

    /** Stops receiving change notifications. */
    fun unregister() {
        contentResolver.unregisterContentObserver(this)
        Log.d(TAG, "ImagesObserver unregistered")
    }

    /**
     * Stops accepting new work and releases the ONNX Runtime session and environment.
     *
     * The release is queued behind work already submitted to [inferenceExecutor], so a queued
     * classification never runs against a closed session. It therefore finishes
     * asynchronously. Calling this more than once only logs.
     */
    fun close() {
        try {
            inferenceExecutor.execute { releaseModel() }
        } catch (e: java.util.concurrent.RejectedExecutionException) {
            Log.d(TAG, "ImagesObserver already closed")
        }
        inferenceExecutor.shutdown()
    }

    /**
     * Closes the session (only if it was ever created), then the environment.
     * Exceptions are logged so they cannot kill the worker thread.
     */
    private fun releaseModel() {
        try {
            if (ortSessionDelegate.isInitialized()) {
                ortSession.close()
            }
            // NOTE(review): the environment comes from OrtEnvironment.getEnvironment(), which
            // is shared process-wide. Confirm that closing it here is safe for other users.
            ortEnv.close()
        } catch (e: Exception) {
            Log.e(TAG, "Error releasing ONNX Runtime resources", e)
        }
    }
}
|