diff options
| author | radhitya <alif@radhitya.org> | 2026-06-18 18:17:19 +0700 |
|---|---|---|
| committer | radhitya <alif@radhitya.org> | 2026-06-18 18:17:19 +0700 |
| commit | 359f6a1cba3f2e281cefa727db34e3497dc15a2c (patch) | |
| tree | efaeeef69c903127b28c51f962de1714bd121ecd | |
| parent | f5753c6a8cac5a57a042b0388f38abeff5d1f37d (diff) | |
add custom forwarders
| -rw-r--r-- | internal/config/config.go | 19 | ||||
| -rw-r--r-- | internal/resolver/resolver.go | 90 | ||||
| -rw-r--r-- | linum.toml | 7 | ||||
| -rw-r--r-- | main.go | 35 |
4 files changed, 101 insertions, 50 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index f2624c2..1fa8069 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -27,8 +27,10 @@ type CacheConfig struct { } type ResolverConfig struct { + Mode string `toml:"mode"` Timeout string `toml:"timeout"` MaxDelegations int `toml:"max_delegations"` + Forwarders []string `toml:"forwarders"` } type BlocklistConfig struct { @@ -71,6 +73,7 @@ func Default() Config{ MaxEntries: 100000, }, Resolver: ResolverConfig{ + Mode: "recursive", Timeout: "2s", MaxDelegations: 30, }, @@ -111,6 +114,12 @@ func Merge(dst, src Config) Config { if src.Resolver.MaxDelegations > 0 { dst.Resolver.MaxDelegations = src.Resolver.MaxDelegations } + if src.Resolver.Mode != "" { + dst.Resolver.Mode = src.Resolver.Mode + } + if len(src.Resolver.Forwarders) > 0 { + dst.Resolver.Forwarders = src.Resolver.Forwarders + } if src.Blocklist.Response != "" { dst.Blocklist.Response = src.Blocklist.Response } @@ -148,5 +157,15 @@ func (c Config) Validate() error { default: return fmt.Errorf("invalid blocklist response %q (want zero_ip or nxdomain)", c.Blocklist.Response) } + + switch c.Resolver.Mode { + case "recursive", "forward", "": + // nothing happened lol + default: + return fmt.Errorf("invalid resolver mode %q (recursive or forward)", c.Resolver.Mode) + } + if c.Resolver.Mode == "forward" && len(c.Resolver.Forwarders) == 0 { + return fmt.Errorf("resolver mode=forward requires at least one forwarder") + } return nil } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 5aa7bc1..1e4f1da 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -20,6 +20,7 @@ type Resolver struct { maxDelegations int timeout time.Duration retries int + forwarders []string client *dns.Client } @@ -42,6 +43,11 @@ func New(opts ...Option) *Resolver { return r } +func WithForwarders(addrs []string) Option { + return func(r *Resolver) { + r.forwarders = addrs + } +} func WithRootAddresses(addrs []string) Option { return func(r *Resolver) { r.roots = addrs @@ -59,9 +65,22 @@ func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns if ctx == nil { ctx = context.Background() } + if len(r.forwarders) > 0 { + return r.forward(ctx, qname, qtype) + } return r.resolve(ctx, qname, qtype) } +func (r *Resolver) forward(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { + servers := make([]string, len(r.forwarders)) + copy(servers, r.forwarders) + reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, true) + if err != nil { + return nil, fmt.Errorf("forward %s %s: %w", + qname, dns.TypeToString[qtype],err) + } + return reply, nil +} func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { servers := make([]string, len(r.roots)) copy(servers, r.roots) @@ -74,10 +93,11 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn if len(servers) == 0 { return nil, ErrNoServers } - reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype) + reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, + false) if err != nil { return nil, fmt.Errorf("resolve %s %s: %w", - qname, dns.TypeToString[qtype], err) + qname, dns.TypeToString[qtype], err) } switch { case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0: @@ -128,40 +148,40 @@ func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint1 } } if target == "" { - return reply, nil - } + return reply, nil + } - const maxChain = 10 - seen := map[string]bool{reply.Question[0].Header().Name: true} - for i := 0; i < maxChain; i++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - if seen[target] { - break - } - seen[target] = true + const maxChain = 10 + seen := map[string]bool{reply.Question[0].Header().Name: true} + for i := 0; i < maxChain; i++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if seen[target] { + break + } + seen[target] = true - chainReply, err := r.resolve(ctx, target, qtype) - if err != nil { - return reply, nil - } - reply.Answer = append(reply.Answer, chainReply.Answer...) + chainReply, err := r.resolve(ctx, target, qtype) + if err != nil { + return reply, nil + } + reply.Answer = append(reply.Answer, chainReply.Answer...) - nextTarget := "" - for _, rr := range chainReply.Answer { - if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target { - nextTarget = cn.Target - } - } - if nextTarget == "" { - break - } - target = nextTarget - } - return reply, nil + nextTarget := "" + for _, rr := range chainReply.Answer { + if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target { + nextTarget = cn.Target + } + } + if nextTarget == "" { + break + } + target = nextTarget + } + return reply, nil } func isReferral(msg *dns.Msg) bool { return !msg.Authoritative && len(msg.Ns) > 0 @@ -228,11 +248,11 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err } func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, - qname string, qtype uint16) (*dns.Msg, error) { +qname string, qtype uint16, rd bool) (*dns.Msg, error) { msg := dns.NewMsg(qname, qtype) msg.UDPSize = 4096 - msg.RecursionDesired = false + msg.RecursionDesired = rd type result struct { reply *dns.Msg @@ -1,14 +1,17 @@ [server] -listen_udp = ":5353" -listen_tcp = ":5353" +listen_udp = ":5354" +listen_tcp = ":5354" listen_doh = ":8443" [cache] max_entries = 100000 db_path = "/tmp/cache.db" + [resolver] +mode = "forward" timeout = "2s" max_delegations = 30 +forwarders = ["1.1.1.1", "8.8.8.8"] [blocklist] response = "zero_ip" @@ -45,10 +45,19 @@ func main() { logger.Info("config loaded", "file", flags.Config) - r := resolver.New( - resolver.WithTimeout(2 * time.Second), - ) - + var ropts []resolver.Option + if cfg.Resolver.Mode == "forward" && len(cfg.Resolver.Forwarders) > 0 { + ropts = append(ropts, resolver.WithForwarders(cfg.Resolver.Forwarders)) + logger.Info("resolver mode: forward", "upstreams", cfg.Resolver.Forwarders) + } else { + logger.Info("resolver mode: recursive (root hints)") + } + if dur, err := time.ParseDuration(cfg.Resolver.Timeout); err == nil { + ropts = append(ropts, resolver.WithTimeout(dur)) + } else { + ropts = append(ropts, resolver.WithTimeout(2*time.Second)) + } + r := resolver.New(ropts...) c, err := cache.NewCache(cfg.Cache.MaxEntries, cfg.Cache.DBPath) if err != nil { logger.Error("create cache failed", "err", err) @@ -100,14 +109,14 @@ func main() { defer srv.Close() logger.Info("linum starting", - "udp", cfg.Server.ListenUDP, - "tcp", cfg.Server.ListenTCP, - "doh", cfg.Server.ListenDOH, - ) + "udp", cfg.Server.ListenUDP, + "tcp", cfg.Server.ListenTCP, + "doh", cfg.Server.ListenDOH, +) - if err := srv.Run(ctx); err != nil && err != context.Canceled { - logger.Error("server stopped with error", "err", err) - os.Exit(1) - } - logger.Info("linum stopped cleanly") +if err := srv.Run(ctx); err != nil && err != context.Canceled { + logger.Error("server stopped with error", "err", err) + os.Exit(1) +} +logger.Info("linum stopped cleanly") } |
