summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/blocklist/blocklist.go93
-rw-r--r--internal/blocklist/blocklist_test.go20
2 files changed, 61 insertions, 52 deletions
diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go
index ae6d20f..364c492 100644
--- a/internal/blocklist/blocklist.go
+++ b/internal/blocklist/blocklist.go
@@ -4,6 +4,7 @@ import (
"bufio"
"net/http"
"os"
+ "fmt"
"strings"
"sync"
"sync/atomic"
@@ -58,63 +59,23 @@ func (b *Blocklist) LoadFile(path string) error {
return err
}
defer f.Close()
-
- scanner := bufio.NewScanner(f)
- var n int32
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if line == "" || line[0] == '#' || line[0] == '!' {
- continue
- }
- if b.addRule(line) {
- n++
- }
- }
- atomic.StoreInt32(&b.TotalRules, n)
- return scanner.Err()
-}
-
-func (b *Blocklist) LoadURL(url string) error {
- resp, err := http.Get(url)
- if err != nil {
- return err
- }
- defer resp.Body.Close()
-
- scanner := bufio.NewScanner(resp.Body)
- var n int32
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if line == "" || line[0] == '#' {
- continue
- }
- if b.addRule(line) {
- n++
- }
- }
- atomic.AddInt32(&b.TotalRules, n)
- return scanner.Err()
+ return b.load(bufio.NewScanner(f))
}
-
-func (b *Blocklist) addRule(line string) bool {
- b.mu.Lock()
- defer b.mu.Unlock()
-
+func (b *Blocklist) parseRule(line string, blocked, exceptions *trie) bool {
if strings.HasPrefix(line, "@@") {
domain := strings.TrimPrefix(line, "@@")
domain = strings.TrimPrefix(domain, "||")
if idx := strings.Index(domain, "^"); idx > 0 {
domain = domain[:idx]
}
- b.exceptions.insert(splitDomain(domain))
+ exceptions.insert(splitDomain(domain))
return true
}
-
fields := strings.Fields(line)
if len(fields) >= 2 {
ip := fields[0]
if ip == "0.0.0.0" || ip == "127.0.0.1" || ip == "::1" || ip == "::" {
- b.blocked.insert(splitDomain(fields[len(fields)-1]))
+ blocked.insert(splitDomain(fields[len(fields)-1]))
return true
}
}
@@ -123,15 +84,55 @@ func (b *Blocklist) addRule(line string) bool {
if idx := strings.Index(domain, "^"); idx > 0 {
domain = domain[:idx]
}
- b.blocked.insert(splitDomain(domain))
+ blocked.insert(splitDomain(domain))
return true
}
if strings.Contains(line, ".") && !strings.ContainsAny(line, " /") {
- b.blocked.insert(splitDomain(line))
+ blocked.insert(splitDomain(line))
return true
}
return false
}
+func (b *Blocklist) load(scanner *bufio.Scanner) error {
+ blocked := newTrie()
+ exceptions := newTrie()
+ var n int32
+
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" || line[0] == '#' || line[0] == '!' {
+ continue
+ }
+ if b.parseRule(line, blocked, exceptions) {
+ n++
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ return err
+ }
+
+ b.mu.Lock()
+ b.blocked = blocked
+ b.exceptions = exceptions
+ b.mu.Unlock()
+ atomic.StoreInt32(&b.TotalRules, n)
+ return nil
+}
+func (b *Blocklist) LoadURL(url string) error {
+ resp, err := http.Get(url)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("blocklist url %s: status %d", url, resp.StatusCode)
+ }
+
+ return b.load(bufio.NewScanner(resp.Body))
+}
+
+
func splitDomain(domain string) []string {
domain = strings.TrimSuffix(domain, ".")
diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go
index 9ae1749..b6d4e83 100644
--- a/internal/blocklist/blocklist_test.go
+++ b/internal/blocklist/blocklist_test.go
@@ -7,9 +7,14 @@ import (
)
func TestBlockZeroIP(t *testing.T) {
+dir := t.TempDir()
+ p := filepath.Join(dir, "block.txt")
+ os.WriteFile(p, []byte("||example.com^\n0.0.0.0 doubleclick.net\n"), 0644)
+
b := New(ResponseZeroIP)
- b.addRule("||example.com^")
- b.addRule("0.0.0.0 doubleclick.net")
+ if err := b.LoadFile(p); err != nil {
+ t.Fatal(err)
+ }
cases := []struct {
domain string
@@ -29,11 +34,15 @@ func TestBlockZeroIP(t *testing.T) {
}
}
}
-
func TestException(t *testing.T) {
+ dir := t.TempDir()
+ p := filepath.Join(dir, "block.txt")
+ os.WriteFile(p, []byte("||example.com^\n@@||whitelist.example.com^\n"), 0644)
+
b := New(ResponseZeroIP)
- b.addRule("||example.com^")
- b.addRule("@@||whitelist.example.com^")
+ if err := b.LoadFile(p); err != nil {
+ t.Fatal(err)
+ }
if !b.IsBlocked("example.com.") {
t.Error("expected blocked")
@@ -45,7 +54,6 @@ func TestException(t *testing.T) {
t.Error("expected NOT blocked (exception subdomain)")
}
}
-
func TestLoadFile(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "block.txt")