summaryrefslogtreecommitdiff
path: root/app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt
diff options
context:
space:
mode:
authorJustZvan <justzvan@justzvan.xyz>2026-02-06 13:38:36 +0100
committerJustZvan <justzvan@justzvan.xyz>2026-02-06 13:38:36 +0100
commitadb6a4fd9ec3a23c04d5e4c2ce799448237915c4 (patch)
tree786edcf5888788e0667a90fae96d7ebec68c507a /app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt
feat: initial commit
Diffstat (limited to 'app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt')
-rw-r--r--app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt334
1 files changed, 334 insertions, 0 deletions
diff --git a/app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt b/app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt
new file mode 100644
index 0000000..9517464
--- /dev/null
+++ b/app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt
@@ -0,0 +1,334 @@
+package sh.lajo.buddy
+
+import android.content.Intent
+import android.net.VpnService
+import android.os.IBinder
+import android.util.Log
+import java.io.BufferedReader
+import java.io.FileInputStream
+import java.io.FileOutputStream
+import java.io.InputStreamReader
+import java.io.OutputStream
+import java.net.HttpURLConnection
+import java.net.URL
+import java.util.Locale
+import java.util.concurrent.atomic.AtomicReference
+import java.util.zip.GZIPInputStream
+
+class BuddyVPNService : VpnService() {
+ private companion object {
+ private const val TAG = "BuddyVPNService"
+
+ private const val BLOCKLIST_URL_PRIMARY =
+ "https://cdn.jsdelivr.net/gh/lajo-sh/assets/domains.gz"
+ private const val BLOCKLIST_URL_FALLBACK =
+ "https://github.com/lajo-sh/assets/raw/refs/heads/main/domains.gz"
+
+ private const val CONNECT_TIMEOUT_MS = 10_000
+ private const val READ_TIMEOUT_MS = 20_000
+
+ private const val MAX_DOMAINS = 2_000_000
+ }
+
+ /** In-memory only. Rebuilt on every service start. */
+ private val blockedDomainsRef = AtomicReference<Set<String>>(emptySet())
+
+ override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
+ Thread {
+ val loaded = loadRemoteBlocklist()
+ if (loaded != null) {
+ blockedDomainsRef.set(loaded)
+ Log.i(TAG, "Loaded blocklist: ${loaded.size} domains")
+ } else {
+ Log.w(TAG, "Blocklist load failed; continuing with empty list")
+ }
+ }.start()
+
+ val builder = Builder()
+ .setSession("BuddyVPN")
+ .addAddress("10.0.0.2", 32)
+ .addRoute("10.0.0.1", 32)
+ .addDnsServer("10.0.0.1")
+
+ val vpnInterface = builder.establish() ?: return START_STICKY
+
+ Thread {
+ try {
+ val input = FileInputStream(vpnInterface.fileDescriptor)
+ val output = FileOutputStream(vpnInterface.fileDescriptor)
+ val buffer = ByteArray(32767)
+
+ val dnsSocket = java.net.DatagramSocket()
+ protect(dnsSocket)
+
+ while (true) {
+ val length = input.read(buffer)
+ if (length > 0) {
+ try {
+ handlePacket(buffer, length, output, dnsSocket)
+ } catch (e: Exception) {
+ Log.e(TAG, "Packet handling error", e)
+ }
+ }
+ }
+ } catch (e: Exception) {
+ Log.e(TAG, "VPN loop error: ${e.message}", e)
+ }
+ }.start()
+
+ return START_STICKY
+ }
+
+ private fun loadRemoteBlocklist(): Set<String>? {
+ return try {
+ fetchAndParseDomains(BLOCKLIST_URL_PRIMARY) ?: fetchAndParseDomains(BLOCKLIST_URL_FALLBACK)
+ } catch (e: Exception) {
+ Log.w(TAG, "Unexpected blocklist load error: ${e.message}", e)
+ null
+ }
+ }
+
+ private fun fetchAndParseDomains(url: String): Set<String>? {
+ var conn: HttpURLConnection? = null
+ return try {
+ conn = (URL(url).openConnection() as HttpURLConnection).apply {
+ instanceFollowRedirects = true
+ connectTimeout = CONNECT_TIMEOUT_MS
+ readTimeout = READ_TIMEOUT_MS
+ requestMethod = "GET"
+ setRequestProperty("Accept-Encoding", "gzip")
+ setRequestProperty("User-Agent", "BuddyVPNService")
+ }
+
+ val code = conn.responseCode
+ if (code !in 200..299) {
+ Log.w(TAG, "Blocklist fetch failed ($code) from $url")
+ return null
+ }
+
+ conn.inputStream.use { raw ->
+ GZIPInputStream(raw).use { gz ->
+ return parseDomainsFromGzipStream(gz)
+ }
+ }
+ } catch (e: Exception) {
+ Log.w(TAG, "Blocklist fetch/parse failed from $url: ${e.message}")
+ null
+ } finally {
+ conn?.disconnect()
+ }
+ }
+
+ private fun parseDomainsFromGzipStream(stream: java.io.InputStream): Set<String> {
+ val result = HashSet<String>(256 * 1024)
+ BufferedReader(InputStreamReader(stream)).useLines { lines ->
+ for (line in lines) {
+ val d = normalizeDomain(line) ?: continue
+ result.add(d)
+ if (result.size >= MAX_DOMAINS) break
+ }
+ }
+ return result
+ }
+
+ private fun normalizeDomain(raw: String): String? {
+ var s = raw.trim()
+ if (s.isEmpty()) return null
+
+ val hash = s.indexOf('#')
+ if (hash >= 0) s = s.substring(0, hash).trim()
+ if (s.isEmpty()) return null
+
+ val parts = s.split(Regex("\\s+"))
+ val candidate = parts.lastOrNull()?.trim().orEmpty()
+ if (candidate.isEmpty()) return null
+
+ val cleaned = candidate
+ .trimEnd('.')
+ .lowercase(Locale.US)
+
+ if (cleaned.length !in 1..253) return null
+ if (!cleaned.any { it == '.' }) return null
+
+ return cleaned
+ }
+
+ private fun isBlocked(domain: String): Boolean {
+ // Check if blocking is enabled in config
+ val config = ConfigManager.getConfig(this)
+ if (!config.blockAdultSites) {
+ return false
+ }
+
+ val blockedDomains = blockedDomainsRef.get()
+ if (blockedDomains.isEmpty()) return false
+
+ val d = normalizeDomain(domain) ?: return false
+ if (blockedDomains.contains(d)) return true
+
+ var idx = d.indexOf('.')
+ while (idx >= 0 && idx + 1 < d.length) {
+ val suffix = d.substring(idx + 1)
+ if (blockedDomains.contains(suffix)) return true
+ idx = d.indexOf('.', idx + 1)
+ }
+
+ return false
+ }
+
+ private fun isDns(packet: ByteArray, length: Int): Boolean {
+ if (length < 28) return false
+ val protocol = packet.getOrNull(9)?.toInt() ?: return false
+ if (protocol != 17) return false
+ val destPort = ((packet[22].toInt() and 0xFF) shl 8) or (packet[23].toInt() and 0xFF)
+ return destPort == 53
+ }
+
+ private fun extractDomain(packet: ByteArray): String? {
+ var index = 40
+ val labels = mutableListOf<String>()
+
+ try {
+ while (index < packet.size) {
+ val len = packet[index].toInt() and 0xFF
+ if (len == 0) break
+ if ((len and 0xC0) == 0xC0) {
+ return null
+ }
+ index++
+ if (index + len > packet.size) return null
+ labels.add(String(packet, index, len))
+ index += len
+ }
+ } catch (_: Exception) {
+ return null
+ }
+
+ return if (labels.isEmpty()) null else labels.joinToString(".")
+ }
+
+ private fun handlePacket(packet: ByteArray, length: Int, output: OutputStream, dnsSocket: java.net.DatagramSocket) {
+ if (!isDns(packet, length)) {
+ return
+ }
+
+ val domain = extractDomain(packet)
+ if (domain == null) {
+ forwardDns(packet, length, output, dnsSocket)
+ return
+ }
+
+ if (isBlocked(domain)) {
+ Log.w(TAG, "blocked: $domain")
+ sendNxDomain(packet, length, output)
+ } else {
+ forwardDns(packet, length, output, dnsSocket)
+ }
+ }
+
+ private fun forwardDns(packet: ByteArray, length: Int, output: OutputStream, dnsSocket: java.net.DatagramSocket) {
+ val dnsPayloadLen = length - 28
+ if (dnsPayloadLen <= 0) return
+
+ val buf = ByteArray(dnsPayloadLen)
+ System.arraycopy(packet, 28, buf, 0, dnsPayloadLen)
+
+ val outPacket = java.net.DatagramPacket(buf, dnsPayloadLen,
+ java.net.InetAddress.getByName("9.9.9.9"), 53) // quad 9
+
+ dnsSocket.send(outPacket)
+
+ val respBuf = ByteArray(4096)
+ val inPacket = java.net.DatagramPacket(respBuf, respBuf.size)
+ try {
+ dnsSocket.soTimeout = 2000
+ dnsSocket.receive(inPacket)
+
+ val response = checkAndConstructResponse(packet, inPacket.data, inPacket.length)
+ output.write(response)
+ } catch (_: Exception) {
+
+ }
+ }
+
+ private fun sendNxDomain(packet: ByteArray, length: Int, output: OutputStream) {
+ val dnsLen = length - 28
+ val responseDns = ByteArray(dnsLen)
+ System.arraycopy(packet, 28, responseDns, 0, dnsLen)
+
+ responseDns[2] = 0x81.toByte()
+ responseDns[3] = 0x03.toByte()
+
+ val ipPacket = checkAndConstructResponse(packet, responseDns, responseDns.size)
+ output.write(ipPacket)
+ }
+
+ private fun checkAndConstructResponse(request: ByteArray, dnsPayload: ByteArray, dnsLen: Int): ByteArray {
+ val totalLen = 28 + dnsLen
+ val response = ByteArray(totalLen)
+
+ response[0] = 0x45 // IPv4
+ response[1] = 0x00
+ // Length
+ response[2] = (totalLen shr 8).toByte()
+ response[3] = (totalLen and 0xFF).toByte()
+ // ID (random or 0)
+ response[4] = 0x00
+ response[5] = 0x00
+ // Flags/Frag
+ response[6] = 0x40 // Don't fragment
+ response[7] = 0x00
+ // TTL
+ response[8] = 64
+ // Protocol
+ response[9] = 17 // UDP
+ // Checksum (calc later)
+ response[10] = 0; response[11] = 0
+
+ System.arraycopy(request, 16, response, 12, 4) // Old Dst is new Src
+ System.arraycopy(request, 12, response, 16, 4) // Old Src is new Dst
+
+ // Calc IP Checksum
+ fillIpChecksum(response)
+
+ // 2. UDP Header (8 bytes)
+ // Src Port = Request Dst Port
+ response[20] = request[22]
+ response[21] = request[23]
+ // Dst Port = Request Src Port
+ response[22] = request[20]
+ response[23] = request[21]
+ // Length
+ val udpLen = 8 + dnsLen
+ response[24] = (udpLen shr 8).toByte()
+ response[25] = (udpLen and 0xFF).toByte()
+ // Checksum = 0 (valid for UDP)
+ response[26] = 0
+ response[27] = 0
+
+ // 3. DNS Payload
+ System.arraycopy(dnsPayload, 0, response, 28, dnsLen)
+
+ return response
+ }
+
+ private fun fillIpChecksum(packet: ByteArray) {
+ var sum = 0
+ // Header length 20 bytes = 10 shorts
+ for (i in 0 until 10) {
+ // Skip checksum field itself (10, 11)
+ if (i == 5) continue
+ val b1 = packet[i * 2].toInt() and 0xFF
+ val b2 = packet[i * 2 + 1].toInt() and 0xFF
+ sum += (b1 shl 8) + b2
+ }
+ while ((sum shr 16) > 0) {
+ sum = (sum and 0xFFFF) + (sum shr 16)
+ }
+ val checksum = sum.inv() and 0xFFFF
+ packet[10] = (checksum shr 8).toByte()
+ packet[11] = (checksum and 0xFF).toByte()
+ }
+
+ override fun onBind(intent: Intent?): IBinder? = super.onBind(intent)
+}