summaryrefslogtreecommitdiff
path: root/internal/blocklist
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
committerradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
commit4e6a897a0b55ee533c05f89fa38dbe0704f2798d (patch)
tree12d9700e53775503ad7ba2beb72bedfc64bdd70d /internal/blocklist
parent3e44adc94f32bfe500730fcbf1c02cedf65b0a30 (diff)
dns recursive resolver(iterative, root hints, delegfation, glue, fallback), adblocker, dns cache
Diffstat (limited to 'internal/blocklist')
-rw-r--r--internal/blocklist/blocklist.go181
-rw-r--r--internal/blocklist/blocklist_test.go78
2 files changed, 259 insertions, 0 deletions
diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go
new file mode 100644
index 0000000..ae6d20f
--- /dev/null
+++ b/internal/blocklist/blocklist.go
@@ -0,0 +1,181 @@
+package blocklist
+
+import (
+ "bufio"
+ "net/http"
+ "os"
+ "strings"
+ "sync"
+ "sync/atomic"
+)
+
+type ResponseAction int
+
+const (
+ ResponseZeroIP ResponseAction = iota
+ ResponseNXDOMAIN
+)
+
+type Blocklist struct {
+ mu sync.RWMutex
+ blocked *trie
+ exceptions *trie
+ response ResponseAction
+ TotalRules int32
+ Hits int64
+}
+
+func New(action ResponseAction) *Blocklist {
+ return &Blocklist{
+ blocked: newTrie(),
+ exceptions: newTrie(),
+ response: action,
+ }
+}
+
+func (b *Blocklist) Response() ResponseAction {
+ return b.response
+}
+
+func (b *Blocklist) IsBlocked(domain string) bool {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ labels := splitDomain(domain)
+ if !b.blocked.match(labels) {
+ return false
+ }
+ if b.exceptions.match(labels) {
+ return false
+ }
+ atomic.AddInt64(&b.Hits, 1)
+ return true
+}
+
+func (b *Blocklist) LoadFile(path string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ scanner := bufio.NewScanner(f)
+ var n int32
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" || line[0] == '#' || line[0] == '!' {
+ continue
+ }
+ if b.addRule(line) {
+ n++
+ }
+ }
+ atomic.StoreInt32(&b.TotalRules, n)
+ return scanner.Err()
+}
+
+func (b *Blocklist) LoadURL(url string) error {
+ resp, err := http.Get(url)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ scanner := bufio.NewScanner(resp.Body)
+ var n int32
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" || line[0] == '#' {
+ continue
+ }
+ if b.addRule(line) {
+ n++
+ }
+ }
+ atomic.AddInt32(&b.TotalRules, n)
+ return scanner.Err()
+}
+
+func (b *Blocklist) addRule(line string) bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if strings.HasPrefix(line, "@@") {
+ domain := strings.TrimPrefix(line, "@@")
+ domain = strings.TrimPrefix(domain, "||")
+ if idx := strings.Index(domain, "^"); idx > 0 {
+ domain = domain[:idx]
+ }
+ b.exceptions.insert(splitDomain(domain))
+ return true
+ }
+
+ fields := strings.Fields(line)
+ if len(fields) >= 2 {
+ ip := fields[0]
+ if ip == "0.0.0.0" || ip == "127.0.0.1" || ip == "::1" || ip == "::" {
+ b.blocked.insert(splitDomain(fields[len(fields)-1]))
+ return true
+ }
+ }
+ if strings.HasPrefix(line, "||") {
+ domain := strings.TrimPrefix(line, "||")
+ if idx := strings.Index(domain, "^"); idx > 0 {
+ domain = domain[:idx]
+ }
+ b.blocked.insert(splitDomain(domain))
+ return true
+ }
+ if strings.Contains(line, ".") && !strings.ContainsAny(line, " /") {
+ b.blocked.insert(splitDomain(line))
+ return true
+ }
+ return false
+}
+
+func splitDomain(domain string) []string {
+ domain = strings.TrimSuffix(domain, ".")
+ return strings.Split(domain, ".")
+}
+
+type trieNode struct {
+ children map[string]*trieNode
+ terminal bool
+}
+
+type trie struct{ root *trieNode }
+
+func newTrie() *trie {
+ return &trie{
+ root: &trieNode{
+ children: make(map[string]*trieNode),
+ },
+ }
+}
+func (t *trie) insert(labels []string) {
+ node := t.root
+ for i := len(labels) - 1; i >= 0; i-- {
+ child, ok := node.children[labels[i]]
+ if !ok {
+ child = &trieNode{children: make(map[string]*trieNode)}
+ node.children[labels[i]] = child
+ }
+ node = child
+ }
+ node.terminal = true
+}
+
+func (t *trie) match(labels []string) bool {
+ node := t.root
+ for i := len(labels) - 1; i >= 0; i-- {
+ if node.terminal {
+ return true
+ }
+ child, ok := node.children[labels[i]]
+ if !ok {
+ return false
+ }
+ node = child
+ }
+ return node.terminal
+}
diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go
new file mode 100644
index 0000000..9ae1749
--- /dev/null
+++ b/internal/blocklist/blocklist_test.go
@@ -0,0 +1,78 @@
+package blocklist
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestBlockZeroIP(t *testing.T) {
+ b := New(ResponseZeroIP)
+ b.addRule("||example.com^")
+ b.addRule("0.0.0.0 doubleclick.net")
+
+ cases := []struct {
+ domain string
+ block bool
+ }{
+ {"example.com.", true},
+ {"sub.example.com.", true},
+ {"notexample.com.", false},
+ {"doubleclick.net.", true},
+ {"ads.doubleclick.net.", true},
+ {"google.com.", false},
+ }
+ for _, c := range cases {
+ got := b.IsBlocked(c.domain)
+ if got != c.block {
+ t.Errorf("IsBlocked(%q) = %v, want %v", c.domain, got, c.block)
+ }
+ }
+}
+
+func TestException(t *testing.T) {
+ b := New(ResponseZeroIP)
+ b.addRule("||example.com^")
+ b.addRule("@@||whitelist.example.com^")
+
+ if !b.IsBlocked("example.com.") {
+ t.Error("expected blocked")
+ }
+ if b.IsBlocked("whitelist.example.com.") {
+ t.Error("expected NOT blocked (exception)")
+ }
+ if b.IsBlocked("sub.whitelist.example.com.") {
+ t.Error("expected NOT blocked (exception subdomain)")
+ }
+}
+
+func TestLoadFile(t *testing.T) {
+ dir := t.TempDir()
+ p := filepath.Join(dir, "block.txt")
+ os.WriteFile(p, []byte("0.0.0.0 ads.com\n||tracker.net^\n"), 0644)
+
+ b := New(ResponseZeroIP)
+ if err := b.LoadFile(p); err != nil {
+ t.Fatal(err)
+ }
+ if !b.IsBlocked("ads.com.") {
+ t.Error("expected blocked")
+ }
+ if !b.IsBlocked("sub.tracker.net.") {
+ t.Error("expected blocked")
+ }
+ if b.IsBlocked("safe.org.") {
+ t.Error("expected NOT blocked")
+ }
+}
+
+func TestResponseType(t *testing.T) {
+ b := New(ResponseNXDOMAIN)
+ if b.Response() != ResponseNXDOMAIN {
+ t.Error("expected NXDOMAIN")
+ }
+ b2 := New(ResponseZeroIP)
+ if b2.Response() != ResponseZeroIP {
+ t.Error("expected ZeroIP")
+ }
+}