summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-21 09:48:42 +0700
committerradhitya <alif@radhitya.org>2026-06-21 09:48:42 +0700
commitb7359e1d45f505171356bcae3c7d5e2341ecc859 (patch)
treef91d4a4b08ce279d488a76e9b7141e69fc844ea9 /internal
parent2b1f613c42de3861141eb6f93c1740b6937ee183 (diff)
forward mode, cache opt, ACL, rate limit, admin/health, systemd, fix UDP reply
Diffstat (limited to 'internal')
-rw-r--r--internal/blocklist/blocklist.go16
-rw-r--r--internal/blocklist/blocklist_test.go2
-rw-r--r--internal/cache/cache.go159
-rw-r--r--internal/cache/cache_test.go20
-rw-r--r--internal/config/config.go63
-rw-r--r--internal/resolver/resolver.go14
-rw-r--r--internal/resolver/resolver_test.go2
-rw-r--r--internal/resolver/root.go2
-rw-r--r--internal/server/admin.go62
-rw-r--r--internal/server/doh.go70
-rw-r--r--internal/server/handler.go121
-rw-r--r--internal/server/server.go163
-rw-r--r--internal/server/server_test.go2
13 files changed, 466 insertions, 230 deletions
diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go
index f6e0592..8925299 100644
--- a/internal/blocklist/blocklist.go
+++ b/internal/blocklist/blocklist.go
@@ -2,9 +2,9 @@ package blocklist
import (
"bufio"
+ "fmt"
"net/http"
"os"
- "fmt"
"strings"
"sync"
"sync/atomic"
@@ -18,19 +18,19 @@ const (
)
type Blocklist struct {
- mu sync.RWMutex
- blocked *trie
+ mu sync.RWMutex
+ blocked *trie
exceptions *trie
- response ResponseAction
+ response ResponseAction
TotalRules int32
- Hits int64
+ Hits int64
}
func New(action ResponseAction) *Blocklist {
return &Blocklist{
- blocked: newTrie(),
+ blocked: newTrie(),
exceptions: newTrie(),
- response: action,
+ response: action,
}
}
@@ -128,8 +128,6 @@ func (b *Blocklist) LoadURL(url string) error {
return b.load(bufio.NewScanner(resp.Body))
}
-
-
func splitDomain(domain string) []string {
domain = strings.TrimSuffix(domain, ".")
return strings.Split(domain, ".")
diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go
index b6d4e83..874744c 100644
--- a/internal/blocklist/blocklist_test.go
+++ b/internal/blocklist/blocklist_test.go
@@ -7,7 +7,7 @@ import (
)
func TestBlockZeroIP(t *testing.T) {
-dir := t.TempDir()
+ dir := t.TempDir()
p := filepath.Join(dir, "block.txt")
os.WriteFile(p, []byte("||example.com^\n0.0.0.0 doubleclick.net\n"), 0644)
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index bb35b8e..b0a0959 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -12,18 +12,18 @@ import (
)
type Key struct {
- Name string
+ Name string
Qtype uint16
Class uint16
}
type entry struct {
- msg *dns.Msg
+ packed []byte
storedAt time.Time
- ttl time.Duration
+ ttl time.Duration
}
-func (e *entry) expired() bool {
+func (e *entry) expired() bool {
return time.Since(e.storedAt) >= e.ttl
}
@@ -36,20 +36,20 @@ func (e *entry) remaining() time.Duration {
}
type Cache struct {
- mu sync.RWMutex
+ mu sync.RWMutex
entries map[Key]*entry
maxSize int
- db *sql.DB
- dbCh chan dbWrite
- wg sync.WaitGroup
- hits int64
- misses int64
+ db *sql.DB
+ dbCh chan dbWrite
+ wg sync.WaitGroup
+ hits int64
+ misses int64
evicted int64
- stopCh chan struct{}
+ stopCh chan struct{}
}
type dbWrite struct {
key Key
- e *entry
+ e *entry
}
func NewCache(maxSize int, dbPath string) (*Cache, error) {
@@ -59,7 +59,7 @@ func NewCache(maxSize int, dbPath string) (*Cache, error) {
c := &Cache{
entries: make(map[Key]*entry),
maxSize: maxSize,
- stopCh: make(chan struct{}),
+ stopCh: make(chan struct{}),
}
if dbPath != "" {
@@ -108,7 +108,7 @@ func (c *Cache) Stop() {
}
}
-func (c *Cache) Get(key Key) (*dns.Msg, bool) {
+func (c *Cache) Get(key Key) ([]byte, bool) {
c.mu.RLock()
e, ok := c.entries[key]
if !ok || e.expired() {
@@ -119,10 +119,11 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) {
c.mu.RUnlock()
atomic.AddInt64(&c.hits, 1)
- msg := deepCopyMsg(e.msg)
- remaining := e.remaining()
- adjustTTL(msg, e.ttl, remaining)
- return msg, true
+ cp := make([]byte, len(e.packed))
+ copy(cp, e.packed)
+ cp[0] = 0
+ cp[1] = 0
+ return cp, true
}
func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
@@ -133,10 +134,16 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
ttl = 60 * time.Second
}
+ if err := msg.Pack(); err != nil {
+ return
+ }
+ cp := make([]byte, len(msg.Data))
+ copy(cp, msg.Data)
+
e := &entry{
- msg: deepCopyMsg(msg),
+ packed: cp,
storedAt: time.Now(),
- ttl: ttl,
+ ttl: ttl,
}
c.mu.Lock()
@@ -156,8 +163,8 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
func (c *Cache) Stats() (int64, int64, int64) {
return atomic.LoadInt64(&c.hits),
- atomic.LoadInt64(&c.misses),
- atomic.LoadInt64(&c.evicted)
+ atomic.LoadInt64(&c.misses),
+ atomic.LoadInt64(&c.evicted)
}
func (c *Cache) Len() int {
@@ -228,15 +235,9 @@ func (c *Cache) writeToDB(key Key, e *entry) {
if c.db == nil {
return
}
- err := e.msg.Pack()
- if err != nil {
- return
- }
- data := e.msg.Data
- _, err = c.db.Exec(
+ _, err := c.db.Exec(
`INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns)
- VALUES (?, ?, ?, ?, ?, ?)`,
- key.Name, key.Qtype, key.Class, data,
+ VALUES (?,?,?,?,?,?)`, key.Name, key.Qtype, key.Class, e.packed,
e.storedAt.UnixNano(), int64(e.ttl),
)
if err != nil {
@@ -269,14 +270,8 @@ func (c *Cache) loadFromDB() {
continue
}
- msg := new(dns.Msg)
- msg.Data = data
- if err := msg.Unpack(); err != nil {
- continue
- }
-
e := &entry{
- msg: msg,
+ packed: data,
storedAt: time.Unix(0, storedAtNano),
ttl: time.Duration(ttlNs),
}
@@ -289,7 +284,6 @@ func (c *Cache) loadFromDB() {
}
}
-
func computeTTL(msg *dns.Msg) time.Duration {
if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 {
for _, rr := range msg.Ns {
@@ -332,92 +326,3 @@ func computeTTL(msg *dns.Msg) time.Duration {
}
return 0
}
-
-func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) {
- ratio := float64(remaining) / float64(originalTTL)
- for _, rr := range msg.Answer {
- rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
- }
- for _, rr := range msg.Ns {
- if dns.RRToType(rr) == dns.TypeOPT {
- continue
- }
- rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
- }
- for _, rr := range msg.Extra {
- if dns.RRToType(rr) == dns.TypeOPT {
- continue
- }
- rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
- }
-}
-
-func applyRatio(ttl uint32, ratio float64) uint32 {
- return uint32(float64(ttl) * ratio)
-}
-
-func deepCopyMsg(msg *dns.Msg) *dns.Msg {
- cp := new(dns.Msg)
- cp.Response = msg.Response
- cp.ID = msg.ID
- cp.Opcode = msg.Opcode
- cp.Authoritative = msg.Authoritative
- cp.Truncated = msg.Truncated
- cp.RecursionDesired = msg.RecursionDesired
- cp.RecursionAvailable = msg.RecursionAvailable
- cp.Rcode = msg.Rcode
- cp.UDPSize = msg.UDPSize
- cp.Question = copyRRSlice(msg.Question)
- cp.Answer = copyRRSlice(msg.Answer)
- cp.Ns = copyRRSlice(msg.Ns)
- cp.Extra = copyRRSlice(msg.Extra)
- return cp
-}
-
-func copyRRSlice(src []dns.RR) []dns.RR {
- if src == nil {
- return nil
- }
- dst := make([]dns.RR, len(src))
- for i, rr := range src {
- dst[i] = copyRR(rr)
- }
- return dst
-}
-
-func copyRR(rr dns.RR) dns.RR {
- if rr == nil {
- return nil
- }
- switch v := rr.(type) {
- case *dns.A:
- cp := *v
- return &cp
- case *dns.AAAA:
- cp := *v
- return &cp
- case *dns.NS:
- cp := *v
- return &cp
- case *dns.CNAME:
- cp := *v
- return &cp
- case *dns.SOA:
- cp := *v
- return &cp
- case *dns.MX:
- cp := *v
- return &cp
- case *dns.TXT:
- cp := *v
- return &cp
- case *dns.SRV:
- cp := *v
- return &cp
- case *dns.PTR:
- cp := *v
- return &cp
- default:
- return rr
- }
-}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
index 33e1d30..aa0aeb0 100644
--- a/internal/cache/cache_test.go
+++ b/internal/cache/cache_test.go
@@ -26,14 +26,18 @@ func TestSetGet(t *testing.T) {
key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
c.Set(key, msg, 300*time.Second)
- got, ok := c.Get(key)
+ packed, ok := c.Get(key)
if !ok {
t.Fatal("expected cache hit")
}
- if len(got.Answer) != 1 {
- t.Fatalf("expected 1 answer, got %d", len(got.Answer))
+ msg.Data = packed
+ if err := msg.Unpack(); err != nil {
+ t.Fatal(err)
+ }
+ if len(msg.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(msg.Answer))
}
- a, _ := got.Answer[0].(*dns.A)
+ a, _ := msg.Answer[0].(*dns.A)
if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}
@@ -130,11 +134,15 @@ func TestSQLitePersistence(t *testing.T) {
}
defer c2.Stop()
- got, ok := c2.Get(key)
+ packed, ok := c2.Get(key)
if !ok {
t.Fatal("expected cache hit from SQLite load")
}
- a, _ := got.Answer[0].(*dns.A)
+ msg.Data = packed
+ if err := msg.Unpack(); err != nil {
+ t.Fatal(err)
+ }
+ a, _ := msg.Answer[0].(*dns.A)
if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 1fa8069..b2c88ee 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -8,11 +8,13 @@ import (
)
type Config struct {
- Server ServerConfig `toml:"server"`
- Cache CacheConfig `toml:"cache"`
- Resolver ResolverConfig `toml:"resolver"`
+ Server ServerConfig `toml:"server"`
+ Cache CacheConfig `toml:"cache"`
+ Resolver ResolverConfig `toml:"resolver"`
Blocklist BlocklistConfig `toml:"blocklist"`
- Log LogConfig `toml:"log"`
+ Admin AdminConfig `toml:"admin"`
+ ACL ACLConfig `toml:"acl"`
+ Log LogConfig `toml:"log"`
}
type ServerConfig struct {
@@ -27,10 +29,10 @@ type CacheConfig struct {
}
type ResolverConfig struct {
- Mode string `toml:"mode"`
- Timeout string `toml:"timeout"`
- MaxDelegations int `toml:"max_delegations"`
- Forwarders []string `toml:"forwarders"`
+ Mode string `toml:"mode"`
+ Timeout string `toml:"timeout"`
+ MaxDelegations int `toml:"max_delegations"`
+ Forwarders []string `toml:"forwarders"`
}
type BlocklistConfig struct {
@@ -44,13 +46,23 @@ type LogConfig struct {
}
type CLIFlags struct {
- Config string
- LogLevel string
+ Config string
+ LogLevel string
ListenUDP string
ListenTCP string
ListenDOH string
}
+type ACLConfig struct {
+ Allow []string `toml:"allow"`
+ RateLimitQPS int `toml:"rate_limit_qps"`
+ RateLimitBurst int `toml:"rate_limit_burst"`
+}
+
+type AdminConfig struct {
+ Listen string `toml:"listen"`
+}
+
func ParseFlags() CLIFlags {
var f CLIFlags
flag.StringVar(&f.Config, "config", "linum.toml", "path to config file")
@@ -62,7 +74,7 @@ func ParseFlags() CLIFlags {
return f
}
-func Default() Config{
+func Default() Config {
return Config{
Server: ServerConfig{
ListenUDP: ":5353",
@@ -73,8 +85,8 @@ func Default() Config{
MaxEntries: 100000,
},
Resolver: ResolverConfig{
- Mode: "recursive",
- Timeout: "2s",
+ Mode: "recursive",
+ Timeout: "2s",
MaxDelegations: 30,
},
Blocklist: BlocklistConfig{
@@ -83,6 +95,14 @@ func Default() Config{
Log: LogConfig{
Level: "info",
},
+ Admin: AdminConfig{
+ Listen: "127.0.0.1:8080",
+ },
+ ACL: ACLConfig{
+ Allow: []string{},
+ RateLimitQPS: 50,
+ RateLimitBurst: 10,
+ },
}
}
@@ -132,6 +152,18 @@ func Merge(dst, src Config) Config {
if src.Log.Level != "" {
dst.Log.Level = src.Log.Level
}
+ if src.ACL.Allow != nil {
+ dst.ACL.Allow = src.ACL.Allow
+ }
+ if src.ACL.RateLimitQPS != 0 {
+ dst.ACL.RateLimitQPS = src.ACL.RateLimitQPS
+ }
+ if src.ACL.RateLimitBurst != 0 {
+ dst.ACL.RateLimitBurst = src.ACL.RateLimitBurst
+ }
+ if src.Admin.Listen != "" {
+ dst.Admin.Listen = src.Admin.Listen
+ }
return dst
}
@@ -157,10 +189,13 @@ func (c Config) Validate() error {
default:
return fmt.Errorf("invalid blocklist response %q (want zero_ip or nxdomain)", c.Blocklist.Response)
}
+ switch c.ACL.Allow {
+ default:
+ }
switch c.Resolver.Mode {
case "recursive", "forward", "":
- // nothing happened lol
+ // nothing happened lol
default:
return fmt.Errorf("invalid resolver mode %q (recursive or forward)", c.Resolver.Mode)
}
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go
index 1e4f1da..c7a0694 100644
--- a/internal/resolver/resolver.go
+++ b/internal/resolver/resolver.go
@@ -20,8 +20,8 @@ type Resolver struct {
maxDelegations int
timeout time.Duration
retries int
- forwarders []string
- client *dns.Client
+ forwarders []string
+ client *dns.Client
}
type Option func(*Resolver)
@@ -77,7 +77,7 @@ func (r *Resolver) forward(ctx context.Context, qname string, qtype uint16) (*dn
reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, true)
if err != nil {
return nil, fmt.Errorf("forward %s %s: %w",
- qname, dns.TypeToString[qtype],err)
+ qname, dns.TypeToString[qtype], err)
}
return reply, nil
}
@@ -94,10 +94,10 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
return nil, ErrNoServers
}
reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype,
- false)
+ false)
if err != nil {
return nil, fmt.Errorf("resolve %s %s: %w",
- qname, dns.TypeToString[qtype], err)
+ qname, dns.TypeToString[qtype], err)
}
switch {
case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
@@ -248,7 +248,7 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
}
func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
-qname string, qtype uint16, rd bool) (*dns.Msg, error) {
+ qname string, qtype uint16, rd bool) (*dns.Msg, error) {
msg := dns.NewMsg(qname, qtype)
msg.UDPSize = 4096
@@ -256,7 +256,7 @@ qname string, qtype uint16, rd bool) (*dns.Msg, error) {
type result struct {
reply *dns.Msg
- err error
+ err error
}
ch := make(chan result, len(servers))
diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go
index 54f727e..daa8a98 100644
--- a/internal/resolver/resolver_test.go
+++ b/internal/resolver/resolver_test.go
@@ -116,7 +116,7 @@ func TestNextServersWithGlue(t *testing.T) {
A: rdata.A{Addr: netip.MustParseAddr("192.0.2.1")},
})
msg.Extra = append(msg.Extra, &dns.AAAA{
- Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300},
+ Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300},
AAAA: rdata.AAAA{Addr: netip.MustParseAddr("2001:db8::1")},
})
diff --git a/internal/resolver/root.go b/internal/resolver/root.go
index 0557dac..7c5244a 100644
--- a/internal/resolver/root.go
+++ b/internal/resolver/root.go
@@ -1,8 +1,8 @@
package resolver
import (
- _ "embed"
"codeberg.org/miekg/dns"
+ _ "embed"
"strings"
)
diff --git a/internal/server/admin.go b/internal/server/admin.go
new file mode 100644
index 0000000..46242d4
--- /dev/null
+++ b/internal/server/admin.go
@@ -0,0 +1,62 @@
+package server
+
+import (
+ "fmt"
+ "log/slog"
+ "net/http"
+)
+
+type Admin struct {
+ server *http.Server
+ s *Server
+}
+
+func NewAdmin(listen string, s *Server) *Admin {
+ if listen == "" {
+ return nil
+ }
+
+ mux := http.NewServeMux()
+ a := &Admin{
+ server: &http.Server{
+ Addr: listen,
+ Handler: mux,
+ },
+ s: s,
+ }
+ mux.HandleFunc("/health", a.healthHandler)
+ return a
+}
+
+func (a *Admin) Start() {
+ if a == nil {
+ return
+ }
+ go func() {
+ slog.Info("admin server listening", "addr", a.server.Addr)
+ if err := a.server.ListenAndServe(); err != nil &&
+ err != http.ErrServerClosed {
+ slog.Error("admin server failed", "error", err)
+ }
+ }()
+}
+
+func (a *Admin) Close() error {
+ if a == nil {
+ return nil
+ }
+ return a.server.Close()
+}
+
+func (a *Admin) healthHandler(w http.ResponseWriter, r *http.Request) {
+ if a.s == nil {
+ http.Error(w, "not ready", http.StatusServiceUnavailable)
+ return
+ }
+ if a.s.Ready() {
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprint(w, "ok\n")
+ return
+ }
+ http.Error(w, "not ready", http.StatusServiceUnavailable)
+}
diff --git a/internal/server/doh.go b/internal/server/doh.go
index 2f9dfc0..0feb094 100644
--- a/internal/server/doh.go
+++ b/internal/server/doh.go
@@ -2,13 +2,27 @@ package server
import (
"encoding/base64"
- "codeberg.org/miekg/dns"
"io"
"log/slog"
"net/http"
+
+ "codeberg.org/miekg/dns"
+ "linum/internal/cache"
)
func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
+ clientIP := parseHTTPClientIP(r.RemoteAddr)
+ if !s.isAllowed(clientIP) {
+ slog.Warn("doh query denied by ACL", "client", clientIP)
+ http.Error(w, "refused", http.StatusForbidden)
+ return
+ }
+ if !s.rateLimit(clientIP.String()) {
+ slog.Warn("doh query rate limited", "client", clientIP)
+ http.Error(w, "too many requests", http.StatusTooManyRequests)
+ return
+ }
+
var raw []byte
switch r.Method {
@@ -53,23 +67,55 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
return
}
+ q := msg.Question[0]
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+ qclass := q.Header().Class
+
+ if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
+ httpWriteMsg(w, s.blockedResponse(msg))
+ slog.Debug("doh query served",
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
+ "blocked", true,
+ )
+ return
+ }
+
+ if s.cache != nil {
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
+ if packed, ok := s.cache.Get(key); ok {
+ packed[0] = byte(msg.ID >> 8)
+ packed[1] = byte(msg.ID)
+ w.Header().Set("Content-Type", "application/dns-message")
+ w.Header().Set("Cache-Control", "no-cache, max-age=0")
+ w.Write(packed)
+ slog.Debug("doh query served",
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
+ "cached", true,
+ )
+ return
+ }
+ }
+
resp, _ := s.buildResponse(msg)
- if err := resp.Pack(); err != nil {
+ httpWriteMsg(w, resp)
+ slog.Debug("doh query served",
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
+ "rcode", dns.RcodeToString[resp.Rcode],
+ )
+}
+
+func httpWriteMsg(w http.ResponseWriter, msg *dns.Msg) {
+ if err := msg.Pack(); err != nil {
http.Error(w, "pack response", http.StatusInternalServerError)
return
}
-
w.Header().Set("Content-Type", "application/dns-message")
w.Header().Set("Cache-Control", "no-cache, max-age=0")
- if _, err := w.Write(resp.Data); err != nil {
+ if _, err := w.Write(msg.Data); err != nil {
slog.Error("doh write failed", "err", err)
- return
}
-
- slog.Info("doh query served",
- "qname", msg.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])],
- "rcode", dns.RcodeToString[resp.Rcode],
- "client", r.RemoteAddr,
- )
}
diff --git a/internal/server/handler.go b/internal/server/handler.go
index d0aa705..4516468 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -4,6 +4,7 @@ import (
"context"
"io"
"log/slog"
+ "net"
"net/netip"
"time"
@@ -14,6 +15,18 @@ import (
)
func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
+ clientIP := parseClientIP(w.RemoteAddr())
+ if !s.isAllowed(clientIP) {
+ slog.Warn("query denied by ACL", "client", clientIP)
+ s.writeRefused(w, req)
+ return
+ }
+ if !s.rateLimit(clientIP.String()) {
+ slog.Warn("query rate limited", "client", clientIP)
+ s.writeRefused(w, req)
+ return
+ }
+
if len(req.Question) == 0 {
m := new(dns.Msg)
m.Rcode = dns.RcodeFormatError
@@ -23,27 +36,42 @@ func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns
io.Copy(w, m)
return
}
+ q := req.Question[0]
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+ qclass := q.Header().Class
- resp, blocked := s.buildResponse(req)
+ if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
+ resp := s.blockedResponse(req)
+ resp.ID = req.ID
+ io.Copy(w, resp)
+ slog.Info("query served", "qname", qname, "qtype",
+ dns.TypeToString[qtype], "rcode", dns.RcodeToString[resp.Rcode],
+ "blocked", true)
+ return
+ }
+ if s.cache != nil {
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
+ if packed, ok := s.cache.Get(key); ok {
+ packed[0] = byte(req.ID >> 8)
+ packed[1] = byte(req.ID)
+ io.Copy(w, &dns.Msg{Data: packed})
+ slog.Debug("query served", "qname", qname, "qtype",
+ dns.TypeToString[qtype], "cached", true)
+ return
+ }
+ }
+ resp, blocked := s.buildResponse(req)
resp.ID = req.ID
resp.Data = nil
if _, err := io.Copy(w, resp); err != nil {
- slog.Error("write response failed",
- "err", err,
- "qname", req.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
- )
+ slog.Error("write failed", "err", err, "qname", qname)
return
}
- slog.Info("query served",
- "qname", req.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
- "rcode", dns.RcodeToString[resp.Rcode],
- "blocked", blocked,
- )
+ slog.Info("query served", "qname", qname, "qtype", dns.TypeToString[qtype],
+ "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked)
}
-
func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
m := new(dns.Msg)
@@ -53,30 +81,19 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
m.Question = req.Question
return m, false
}
+
q := req.Question[0]
qname := q.Header().Name
qtype := dns.RRToType(q)
qclass := q.Header().Class
- if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
- return s.blockedResponse(req), true
- }
-
- if s.cache != nil {
- key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
- if cached, ok := s.cache.Get(key); ok {
- cached.ID = req.ID
- return cached, false
- }
- }
-
ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second)
defer cancel()
reply, err := s.resolver.Lookup(ctx, qname, qtype)
if err != nil {
slog.Error("resolution failed",
- "err", err,
+ "err", err,
"qname", qname,
"qtype", dns.TypeToString[qtype],
)
@@ -124,9 +141,59 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
})
case dns.TypeAAAA:
m.Answer = append(m.Answer, &dns.AAAA{
- Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
+ Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()},
})
}
return m
}
+
+func parseClientIP(addr net.Addr) net.IP {
+ switch a := addr.(type) {
+ case *net.UDPAddr:
+ return a.IP
+ case *net.TCPAddr:
+ return a.IP
+ }
+ host, _, _ := net.SplitHostPort(addr.String())
+ return net.ParseIP(host)
+}
+
+func parseHTTPClientIP(addr string) net.IP {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return net.ParseIP(addr)
+ }
+ return net.ParseIP(host)
+}
+
+func (s *Server) isAllowed(ip net.IP) bool {
+ if ip == nil {
+ return false
+ }
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ for _, n := range s.aclNets {
+ if n.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+func (s *Server) rateLimit(ip string) bool {
+ if s.rateLimiter == nil {
+ return true
+ }
+ return s.rateLimiter.allow(ip)
+}
+
+func (s *Server) writeRefused(w dns.ResponseWriter, req *dns.Msg) {
+ resp := new(dns.Msg)
+ resp.Response = true
+ resp.ID = req.ID
+ resp.Question = req.Question
+ resp.Rcode = dns.RcodeRefused
+ io.Copy(w, resp)
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 1aa3256..a90e5ac 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -2,55 +2,95 @@ package server
import (
"context"
+ "fmt"
"log/slog"
+ "net"
"net/http"
+ "sync"
"time"
- "codeberg.org/miekg/dns"
- "linum/internal/resolver"
"linum/internal/blocklist"
"linum/internal/cache"
+ "linum/internal/config"
+ "linum/internal/resolver"
+
+ "codeberg.org/miekg/dns"
)
type Server struct {
- logger *slog.Logger
- resolver *resolver.Resolver
- cache *cache.Cache
- blocklist *blocklist.Blocklist
- udp *dns.Server
- tcp *dns.Server
- doh *http.Server
- baseCtx context.Context
+ resolver *resolver.Resolver
+ cache *cache.Cache
+ blocklist *blocklist.Blocklist
+ logger *slog.Logger
+ cfg config.ServerConfig
+ baseCtx context.Context
+ udp *dns.Server
+ tcp *dns.Server
+ doh *http.Server
+ admin *Admin
+ aclNets []*net.IPNet
+ rateLimiter *rateLimiter
+
+ mu sync.RWMutex
+ upUDP bool
+ upTCP bool
+ upDoH bool
+
cancel context.CancelFunc
}
-func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger,
-r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
+func (s *Server) Ready() bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ wantUDP := s.cfg.ListenUDP != ""
+ wantTCP := s.cfg.ListenTCP != ""
+ wantDoH := s.cfg.ListenDOH != ""
+ return (!wantUDP || s.upUDP) && (!wantTCP || s.upTCP) && (!wantDoH || s.upDoH)
+
+}
+func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver, c *cache.Cache, bl *blocklist.Blocklist, cfg config.Config) (*Server, error) {
baseCtx, cancel := context.WithCancel(context.Background())
+ s := &Server{
+ logger: logger,
+ resolver: r,
+ cache: c,
+ blocklist: bl,
+ cfg: cfg.Server,
+ baseCtx: baseCtx,
+ cancel: cancel,
+ aclNets: make([]*net.IPNet, 0, len(cfg.ACL.Allow)),
+ rateLimiter: newRateLimiter(cfg.ACL.RateLimitQPS, cfg.ACL.RateLimitBurst),
+ }
- s := &Server{logger: logger, resolver: r, cache: c, blocklist: b,
- baseCtx: baseCtx, cancel: cancel}
+ for _, cidr := range cfg.ACL.Allow {
+ _, ipnet, err := net.ParseCIDR(cidr)
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("invalid acl CIDR %q: %w", cidr, err)
+ }
+ s.aclNets = append(s.aclNets, ipnet)
+ }
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
if udpAddr != "" {
s.udp = &dns.Server{
- Addr: udpAddr,
- Net: "udp",
- Handler: mux,
- UDPSize: 4096,
- ReadTimeout: 5 * time.Second,
- ReusePort: true,
+ Addr: udpAddr,
+ Net: "udp",
+ Handler: mux,
+ UDPSize: 4096,
+ ReadTimeout: 5 * time.Second,
+ ReusePort: true,
}
}
if tcpAddr != "" {
s.tcp = &dns.Server{
- Addr: tcpAddr,
- Net: "tcp",
- Handler: mux,
- ReadTimeout: 5 * time.Second,
+ Addr: tcpAddr,
+ Net: "tcp",
+ Handler: mux,
+ ReadTimeout: 5 * time.Second,
}
}
@@ -64,9 +104,30 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
WriteTimeout: 5 * time.Second,
}
}
+
+ if cfg.Admin.Listen != "" {
+ s.admin = NewAdmin(cfg.Admin.Listen, s)
+ }
+
return s, nil
}
func (s *Server) Run(ctx context.Context) error {
+ s.mu.Lock()
+ if s.udp != nil {
+ s.upUDP = true
+ }
+ if s.tcp != nil {
+ s.upTCP = true
+ }
+ if s.doh != nil {
+ s.upDoH = true
+ }
+ s.mu.Unlock()
+
+ if s.admin != nil {
+ s.admin.Start()
+ }
+
errCh := make(chan error, 3)
if s.udp != nil {
go func() {
@@ -110,6 +171,9 @@ func (s *Server) Run(ctx context.Context) error {
}
func (s *Server) Close() error {
+ if s.admin != nil {
+ s.admin.Close()
+ }
if s.udp != nil {
s.udp.Shutdown(context.Background())
}
@@ -121,3 +185,54 @@ func (s *Server) Close() error {
}
return nil
}
+
+type rateLimiter struct {
+ mu sync.Mutex
+ rate float64
+ burst float64
+ buckets map[string]*rateBucket
+}
+
+type rateBucket struct {
+ tokens float64
+ last time.Time
+}
+
+func newRateLimiter(qps, burst int) *rateLimiter {
+ if qps <= 0 && burst <= 0 {
+ return nil
+ }
+ if burst <= 0 {
+ burst = qps
+ }
+ return &rateLimiter{
+ rate: float64(qps),
+ burst: float64(burst),
+ buckets: make(map[string]*rateBucket),
+ }
+}
+
+func (rl *rateLimiter) allow(ip string) bool {
+ now := time.Now()
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ b, ok := rl.buckets[ip]
+ if !ok {
+ b = &rateBucket{tokens: rl.burst, last: now}
+ rl.buckets[ip] = b
+ }
+
+ elapsed := now.Sub(b.last).Seconds()
+ b.tokens += elapsed * rl.rate
+ if b.tokens > rl.burst {
+ b.tokens = rl.burst
+ }
+ b.last = now
+
+ if b.tokens < 1 {
+ return false
+ }
+ b.tokens--
+ return true
+}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 002aba8..2340acd 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -1,8 +1,8 @@
package server
import (
- "log/slog"
"context"
+ "log/slog"
"testing"
"time"