diff options
Diffstat (limited to 'internal/blocklist/blocklist.go')
| -rw-r--r-- | internal/blocklist/blocklist.go | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go new file mode 100644 index 0000000..ae6d20f --- /dev/null +++ b/internal/blocklist/blocklist.go @@ -0,0 +1,181 @@ +package blocklist + +import ( + "bufio" + "net/http" + "os" + "strings" + "sync" + "sync/atomic" +) + +type ResponseAction int + +const ( + ResponseZeroIP ResponseAction = iota + ResponseNXDOMAIN +) + +type Blocklist struct { + mu sync.RWMutex + blocked *trie + exceptions *trie + response ResponseAction + TotalRules int32 + Hits int64 +} + +func New(action ResponseAction) *Blocklist { + return &Blocklist{ + blocked: newTrie(), + exceptions: newTrie(), + response: action, + } +} + +func (b *Blocklist) Response() ResponseAction { + return b.response +} + +func (b *Blocklist) IsBlocked(domain string) bool { + b.mu.RLock() + defer b.mu.RUnlock() + + labels := splitDomain(domain) + if !b.blocked.match(labels) { + return false + } + if b.exceptions.match(labels) { + return false + } + atomic.AddInt64(&b.Hits, 1) + return true +} + +func (b *Blocklist) LoadFile(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' || line[0] == '!' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.StoreInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) LoadURL(url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.AddInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) addRule(line string) bool { + b.mu.Lock() + defer b.mu.Unlock() + + if strings.HasPrefix(line, "@@") { + domain := strings.TrimPrefix(line, "@@") + domain = strings.TrimPrefix(domain, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.exceptions.insert(splitDomain(domain)) + return true + } + + fields := strings.Fields(line) + if len(fields) >= 2 { + ip := fields[0] + if ip == "0.0.0.0" || ip == "127.0.0.1" || ip == "::1" || ip == "::" { + b.blocked.insert(splitDomain(fields[len(fields)-1])) + return true + } + } + if strings.HasPrefix(line, "||") { + domain := strings.TrimPrefix(line, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.blocked.insert(splitDomain(domain)) + return true + } + if strings.Contains(line, ".") && !strings.ContainsAny(line, " /") { + b.blocked.insert(splitDomain(line)) + return true + } + return false +} + +func splitDomain(domain string) []string { + domain = strings.TrimSuffix(domain, ".") + return strings.Split(domain, ".") +} + +type trieNode struct { + children map[string]*trieNode + terminal bool +} + +type trie struct{ root *trieNode } + +func newTrie() *trie { + return &trie{ + root: &trieNode{ + children: make(map[string]*trieNode), + }, + } +} +func (t *trie) insert(labels []string) { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + child, ok := node.children[labels[i]] + if !ok { + child = &trieNode{children: make(map[string]*trieNode)} + node.children[labels[i]] = child + } + node = child + } + node.terminal = true +} + +func (t *trie) match(labels []string) bool { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + if node.terminal { + return true + } + child, ok := node.children[labels[i]] + if !ok { + return false + } + node = child + } + return node.terminal +} |
