diff options
Diffstat (limited to 'app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt')
| -rw-r--r-- | app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt | 334 |
1 files changed, 334 insertions, 0 deletions
diff --git a/app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt b/app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt new file mode 100644 index 0000000..9517464 --- /dev/null +++ b/app/src/main/java/sh/lajo/buddy/BuddyVPNService.kt @@ -0,0 +1,334 @@ +package sh.lajo.buddy + +import android.content.Intent +import android.net.VpnService +import android.os.IBinder +import android.util.Log +import java.io.BufferedReader +import java.io.FileInputStream +import java.io.FileOutputStream +import java.io.InputStreamReader +import java.io.OutputStream +import java.net.HttpURLConnection +import java.net.URL +import java.util.Locale +import java.util.concurrent.atomic.AtomicReference +import java.util.zip.GZIPInputStream + +class BuddyVPNService : VpnService() { + private companion object { + private const val TAG = "BuddyVPNService" + + private const val BLOCKLIST_URL_PRIMARY = + "https://cdn.jsdelivr.net/gh/lajo-sh/assets/domains.gz" + private const val BLOCKLIST_URL_FALLBACK = + "https://github.com/lajo-sh/assets/raw/refs/heads/main/domains.gz" + + private const val CONNECT_TIMEOUT_MS = 10_000 + private const val READ_TIMEOUT_MS = 20_000 + + private const val MAX_DOMAINS = 2_000_000 + } + + /** In-memory only. Rebuilt on every service start. */ + private val blockedDomainsRef = AtomicReference<Set<String>>(emptySet()) + + override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { + Thread { + val loaded = loadRemoteBlocklist() + if (loaded != null) { + blockedDomainsRef.set(loaded) + Log.i(TAG, "Loaded blocklist: ${loaded.size} domains") + } else { + Log.w(TAG, "Blocklist load failed; continuing with empty list") + } + }.start() + + val builder = Builder() + .setSession("BuddyVPN") + .addAddress("10.0.0.2", 32) + .addRoute("10.0.0.1", 32) + .addDnsServer("10.0.0.1") + + val vpnInterface = builder.establish() ?: return START_STICKY + + Thread { + try { + val input = FileInputStream(vpnInterface.fileDescriptor) + val output = FileOutputStream(vpnInterface.fileDescriptor) + val buffer = ByteArray(32767) + + val dnsSocket = java.net.DatagramSocket() + protect(dnsSocket) + + while (true) { + val length = input.read(buffer) + if (length > 0) { + try { + handlePacket(buffer, length, output, dnsSocket) + } catch (e: Exception) { + Log.e(TAG, "Packet handling error", e) + } + } + } + } catch (e: Exception) { + Log.e(TAG, "VPN loop error: ${e.message}", e) + } + }.start() + + return START_STICKY + } + + private fun loadRemoteBlocklist(): Set<String>? { + return try { + fetchAndParseDomains(BLOCKLIST_URL_PRIMARY) ?: fetchAndParseDomains(BLOCKLIST_URL_FALLBACK) + } catch (e: Exception) { + Log.w(TAG, "Unexpected blocklist load error: ${e.message}", e) + null + } + } + + private fun fetchAndParseDomains(url: String): Set<String>? { + var conn: HttpURLConnection? = null + return try { + conn = (URL(url).openConnection() as HttpURLConnection).apply { + instanceFollowRedirects = true + connectTimeout = CONNECT_TIMEOUT_MS + readTimeout = READ_TIMEOUT_MS + requestMethod = "GET" + setRequestProperty("Accept-Encoding", "gzip") + setRequestProperty("User-Agent", "BuddyVPNService") + } + + val code = conn.responseCode + if (code !in 200..299) { + Log.w(TAG, "Blocklist fetch failed ($code) from $url") + return null + } + + conn.inputStream.use { raw -> + GZIPInputStream(raw).use { gz -> + return parseDomainsFromGzipStream(gz) + } + } + } catch (e: Exception) { + Log.w(TAG, "Blocklist fetch/parse failed from $url: ${e.message}") + null + } finally { + conn?.disconnect() + } + } + + private fun parseDomainsFromGzipStream(stream: java.io.InputStream): Set<String> { + val result = HashSet<String>(256 * 1024) + BufferedReader(InputStreamReader(stream)).useLines { lines -> + for (line in lines) { + val d = normalizeDomain(line) ?: continue + result.add(d) + if (result.size >= MAX_DOMAINS) break + } + } + return result + } + + private fun normalizeDomain(raw: String): String? { + var s = raw.trim() + if (s.isEmpty()) return null + + val hash = s.indexOf('#') + if (hash >= 0) s = s.substring(0, hash).trim() + if (s.isEmpty()) return null + + val parts = s.split(Regex("\\s+")) + val candidate = parts.lastOrNull()?.trim().orEmpty() + if (candidate.isEmpty()) return null + + val cleaned = candidate + .trimEnd('.') + .lowercase(Locale.US) + + if (cleaned.length !in 1..253) return null + if (!cleaned.any { it == '.' }) return null + + return cleaned + } + + private fun isBlocked(domain: String): Boolean { + // Check if blocking is enabled in config + val config = ConfigManager.getConfig(this) + if (!config.blockAdultSites) { + return false + } + + val blockedDomains = blockedDomainsRef.get() + if (blockedDomains.isEmpty()) return false + + val d = normalizeDomain(domain) ?: return false + if (blockedDomains.contains(d)) return true + + var idx = d.indexOf('.') + while (idx >= 0 && idx + 1 < d.length) { + val suffix = d.substring(idx + 1) + if (blockedDomains.contains(suffix)) return true + idx = d.indexOf('.', idx + 1) + } + + return false + } + + private fun isDns(packet: ByteArray, length: Int): Boolean { + if (length < 28) return false + val protocol = packet.getOrNull(9)?.toInt() ?: return false + if (protocol != 17) return false + val destPort = ((packet[22].toInt() and 0xFF) shl 8) or (packet[23].toInt() and 0xFF) + return destPort == 53 + } + + private fun extractDomain(packet: ByteArray): String? { + var index = 40 + val labels = mutableListOf<String>() + + try { + while (index < packet.size) { + val len = packet[index].toInt() and 0xFF + if (len == 0) break + if ((len and 0xC0) == 0xC0) { + return null + } + index++ + if (index + len > packet.size) return null + labels.add(String(packet, index, len)) + index += len + } + } catch (_: Exception) { + return null + } + + return if (labels.isEmpty()) null else labels.joinToString(".") + } + + private fun handlePacket(packet: ByteArray, length: Int, output: OutputStream, dnsSocket: java.net.DatagramSocket) { + if (!isDns(packet, length)) { + return + } + + val domain = extractDomain(packet) + if (domain == null) { + forwardDns(packet, length, output, dnsSocket) + return + } + + if (isBlocked(domain)) { + Log.w(TAG, "blocked: $domain") + sendNxDomain(packet, length, output) + } else { + forwardDns(packet, length, output, dnsSocket) + } + } + + private fun forwardDns(packet: ByteArray, length: Int, output: OutputStream, dnsSocket: java.net.DatagramSocket) { + val dnsPayloadLen = length - 28 + if (dnsPayloadLen <= 0) return + + val buf = ByteArray(dnsPayloadLen) + System.arraycopy(packet, 28, buf, 0, dnsPayloadLen) + + val outPacket = java.net.DatagramPacket(buf, dnsPayloadLen, + java.net.InetAddress.getByName("9.9.9.9"), 53) // quad 9 + + dnsSocket.send(outPacket) + + val respBuf = ByteArray(4096) + val inPacket = java.net.DatagramPacket(respBuf, respBuf.size) + try { + dnsSocket.soTimeout = 2000 + dnsSocket.receive(inPacket) + + val response = checkAndConstructResponse(packet, inPacket.data, inPacket.length) + output.write(response) + } catch (_: Exception) { + + } + } + + private fun sendNxDomain(packet: ByteArray, length: Int, output: OutputStream) { + val dnsLen = length - 28 + val responseDns = ByteArray(dnsLen) + System.arraycopy(packet, 28, responseDns, 0, dnsLen) + + responseDns[2] = 0x81.toByte() + responseDns[3] = 0x03.toByte() + + val ipPacket = checkAndConstructResponse(packet, responseDns, responseDns.size) + output.write(ipPacket) + } + + private fun checkAndConstructResponse(request: ByteArray, dnsPayload: ByteArray, dnsLen: Int): ByteArray { + val totalLen = 28 + dnsLen + val response = ByteArray(totalLen) + + response[0] = 0x45 // IPv4 + response[1] = 0x00 + // Length + response[2] = (totalLen shr 8).toByte() + response[3] = (totalLen and 0xFF).toByte() + // ID (random or 0) + response[4] = 0x00 + response[5] = 0x00 + // Flags/Frag + response[6] = 0x40 // Don't fragment + response[7] = 0x00 + // TTL + response[8] = 64 + // Protocol + response[9] = 17 // UDP + // Checksum (calc later) + response[10] = 0; response[11] = 0 + + System.arraycopy(request, 16, response, 12, 4) // Old Dst is new Src + System.arraycopy(request, 12, response, 16, 4) // Old Src is new Dst + + // Calc IP Checksum + fillIpChecksum(response) + + // 2. UDP Header (8 bytes) + // Src Port = Request Dst Port + response[20] = request[22] + response[21] = request[23] + // Dst Port = Request Src Port + response[22] = request[20] + response[23] = request[21] + // Length + val udpLen = 8 + dnsLen + response[24] = (udpLen shr 8).toByte() + response[25] = (udpLen and 0xFF).toByte() + // Checksum = 0 (valid for UDP) + response[26] = 0 + response[27] = 0 + + // 3. DNS Payload + System.arraycopy(dnsPayload, 0, response, 28, dnsLen) + + return response + } + + private fun fillIpChecksum(packet: ByteArray) { + var sum = 0 + // Header length 20 bytes = 10 shorts + for (i in 0 until 10) { + // Skip checksum field itself (10, 11) + if (i == 5) continue + val b1 = packet[i * 2].toInt() and 0xFF + val b2 = packet[i * 2 + 1].toInt() and 0xFF + sum += (b1 shl 8) + b2 + } + while ((sum shr 16) > 0) { + sum = (sum and 0xFFFF) + (sum shr 16) + } + val checksum = sum.inv() and 0xFFFF + packet[10] = (checksum shr 8).toByte() + packet[11] = (checksum and 0xFF).toByte() + } + + override fun onBind(intent: Intent?): IBinder? = super.onBind(intent) +} |