diff options
| author | radhitya <alif@radhitya.org> | 2026-06-14 14:36:32 +0700 |
|---|---|---|
| committer | radhitya <alif@radhitya.org> | 2026-06-14 14:36:32 +0700 |
| commit | 4e6a897a0b55ee533c05f89fa38dbe0704f2798d (patch) | |
| tree | 12d9700e53775503ad7ba2beb72bedfc64bdd70d /internal/blocklist | |
| parent | 3e44adc94f32bfe500730fcbf1c02cedf65b0a30 (diff) | |
dns recursive resolver(iterative, root hints, delegfation, glue, fallback), adblocker, dns cache
Diffstat (limited to 'internal/blocklist')
| -rw-r--r-- | internal/blocklist/blocklist.go | 181 | ||||
| -rw-r--r-- | internal/blocklist/blocklist_test.go | 78 |
2 files changed, 259 insertions, 0 deletions
diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go new file mode 100644 index 0000000..ae6d20f --- /dev/null +++ b/internal/blocklist/blocklist.go @@ -0,0 +1,181 @@ +package blocklist + +import ( + "bufio" + "net/http" + "os" + "strings" + "sync" + "sync/atomic" +) + +type ResponseAction int + +const ( + ResponseZeroIP ResponseAction = iota + ResponseNXDOMAIN +) + +type Blocklist struct { + mu sync.RWMutex + blocked *trie + exceptions *trie + response ResponseAction + TotalRules int32 + Hits int64 +} + +func New(action ResponseAction) *Blocklist { + return &Blocklist{ + blocked: newTrie(), + exceptions: newTrie(), + response: action, + } +} + +func (b *Blocklist) Response() ResponseAction { + return b.response +} + +func (b *Blocklist) IsBlocked(domain string) bool { + b.mu.RLock() + defer b.mu.RUnlock() + + labels := splitDomain(domain) + if !b.blocked.match(labels) { + return false + } + if b.exceptions.match(labels) { + return false + } + atomic.AddInt64(&b.Hits, 1) + return true +} + +func (b *Blocklist) LoadFile(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' || line[0] == '!' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.StoreInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) LoadURL(url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.AddInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) addRule(line string) bool { + b.mu.Lock() + defer b.mu.Unlock() + + if strings.HasPrefix(line, "@@") { + domain := strings.TrimPrefix(line, "@@") + domain = strings.TrimPrefix(domain, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.exceptions.insert(splitDomain(domain)) + return true + } + + fields := strings.Fields(line) + if len(fields) >= 2 { + ip := fields[0] + if ip == "0.0.0.0" || ip == "127.0.0.1" || ip == "::1" || ip == "::" { + b.blocked.insert(splitDomain(fields[len(fields)-1])) + return true + } + } + if strings.HasPrefix(line, "||") { + domain := strings.TrimPrefix(line, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.blocked.insert(splitDomain(domain)) + return true + } + if strings.Contains(line, ".") && !strings.ContainsAny(line, " /") { + b.blocked.insert(splitDomain(line)) + return true + } + return false +} + +func splitDomain(domain string) []string { + domain = strings.TrimSuffix(domain, ".") + return strings.Split(domain, ".") +} + +type trieNode struct { + children map[string]*trieNode + terminal bool +} + +type trie struct{ root *trieNode } + +func newTrie() *trie { + return &trie{ + root: &trieNode{ + children: make(map[string]*trieNode), + }, + } +} +func (t *trie) insert(labels []string) { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + child, ok := node.children[labels[i]] + if !ok { + child = &trieNode{children: make(map[string]*trieNode)} + node.children[labels[i]] = child + } + node = child + } + node.terminal = true +} + +func (t *trie) match(labels []string) bool { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + if node.terminal { + return true + } + child, ok := node.children[labels[i]] + if !ok { + return false + } + node = child + } + return node.terminal +} diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go new file mode 100644 index 0000000..9ae1749 --- /dev/null +++ b/internal/blocklist/blocklist_test.go @@ -0,0 +1,78 @@ +package blocklist + +import ( + "os" + "path/filepath" + "testing" +) + +func TestBlockZeroIP(t *testing.T) { + b := New(ResponseZeroIP) + b.addRule("||example.com^") + b.addRule("0.0.0.0 doubleclick.net") + + cases := []struct { + domain string + block bool + }{ + {"example.com.", true}, + {"sub.example.com.", true}, + {"notexample.com.", false}, + {"doubleclick.net.", true}, + {"ads.doubleclick.net.", true}, + {"google.com.", false}, + } + for _, c := range cases { + got := b.IsBlocked(c.domain) + if got != c.block { + t.Errorf("IsBlocked(%q) = %v, want %v", c.domain, got, c.block) + } + } +} + +func TestException(t *testing.T) { + b := New(ResponseZeroIP) + b.addRule("||example.com^") + b.addRule("@@||whitelist.example.com^") + + if !b.IsBlocked("example.com.") { + t.Error("expected blocked") + } + if b.IsBlocked("whitelist.example.com.") { + t.Error("expected NOT blocked (exception)") + } + if b.IsBlocked("sub.whitelist.example.com.") { + t.Error("expected NOT blocked (exception subdomain)") + } +} + +func TestLoadFile(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "block.txt") + os.WriteFile(p, []byte("0.0.0.0 ads.com\n||tracker.net^\n"), 0644) + + b := New(ResponseZeroIP) + if err := b.LoadFile(p); err != nil { + t.Fatal(err) + } + if !b.IsBlocked("ads.com.") { + t.Error("expected blocked") + } + if !b.IsBlocked("sub.tracker.net.") { + t.Error("expected blocked") + } + if b.IsBlocked("safe.org.") { + t.Error("expected NOT blocked") + } +} + +func TestResponseType(t *testing.T) { + b := New(ResponseNXDOMAIN) + if b.Response() != ResponseNXDOMAIN { + t.Error("expected NXDOMAIN") + } + b2 := New(ResponseZeroIP) + if b2.Response() != ResponseZeroIP { + t.Error("expected ZeroIP") + } +} |
