diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/config/config.go | 19 | ||||
| -rw-r--r-- | internal/resolver/resolver.go | 90 |
2 files changed, 74 insertions, 35 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index f2624c2..1fa8069 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -27,8 +27,10 @@ type CacheConfig struct { } type ResolverConfig struct { + Mode string `toml:"mode"` Timeout string `toml:"timeout"` MaxDelegations int `toml:"max_delegations"` + Forwarders []string `toml:"forwarders"` } type BlocklistConfig struct { @@ -71,6 +73,7 @@ func Default() Config{ MaxEntries: 100000, }, Resolver: ResolverConfig{ + Mode: "recursive", Timeout: "2s", MaxDelegations: 30, }, @@ -111,6 +114,12 @@ func Merge(dst, src Config) Config { if src.Resolver.MaxDelegations > 0 { dst.Resolver.MaxDelegations = src.Resolver.MaxDelegations } + if src.Resolver.Mode != "" { + dst.Resolver.Mode = src.Resolver.Mode + } + if len(src.Resolver.Forwarders) > 0 { + dst.Resolver.Forwarders = src.Resolver.Forwarders + } if src.Blocklist.Response != "" { dst.Blocklist.Response = src.Blocklist.Response } @@ -148,5 +157,15 @@ func (c Config) Validate() error { default: return fmt.Errorf("invalid blocklist response %q (want zero_ip or nxdomain)", c.Blocklist.Response) } + + switch c.Resolver.Mode { + case "recursive", "forward", "": + // nothing happened lol + default: + return fmt.Errorf("invalid resolver mode %q (recursive or forward)", c.Resolver.Mode) + } + if c.Resolver.Mode == "forward" && len(c.Resolver.Forwarders) == 0 { + return fmt.Errorf("resolver mode=forward requires at least one forwarder") + } return nil } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 5aa7bc1..1e4f1da 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -20,6 +20,7 @@ type Resolver struct { maxDelegations int timeout time.Duration retries int + forwarders []string client *dns.Client } @@ -42,6 +43,11 @@ func New(opts ...Option) *Resolver { return r } +func WithForwarders(addrs []string) Option { + return func(r *Resolver) { + r.forwarders = addrs + } +} func WithRootAddresses(addrs []string) Option { return func(r *Resolver) { r.roots = addrs @@ -59,9 +65,22 @@ func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns if ctx == nil { ctx = context.Background() } + if len(r.forwarders) > 0 { + return r.forward(ctx, qname, qtype) + } return r.resolve(ctx, qname, qtype) } +func (r *Resolver) forward(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { + servers := make([]string, len(r.forwarders)) + copy(servers, r.forwarders) + reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, true) + if err != nil { + return nil, fmt.Errorf("forward %s %s: %w", + qname, dns.TypeToString[qtype],err) + } + return reply, nil +} func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) { servers := make([]string, len(r.roots)) copy(servers, r.roots) @@ -74,10 +93,11 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn if len(servers) == 0 { return nil, ErrNoServers } - reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype) + reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, + false) if err != nil { return nil, fmt.Errorf("resolve %s %s: %w", - qname, dns.TypeToString[qtype], err) + qname, dns.TypeToString[qtype], err) } switch { case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0: @@ -128,40 +148,40 @@ func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint1 } } if target == "" { - return reply, nil - } + return reply, nil + } - const maxChain = 10 - seen := map[string]bool{reply.Question[0].Header().Name: true} - for i := 0; i < maxChain; i++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - if seen[target] { - break - } - seen[target] = true + const maxChain = 10 + seen := map[string]bool{reply.Question[0].Header().Name: true} + for i := 0; i < maxChain; i++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if seen[target] { + break + } + seen[target] = true - chainReply, err := r.resolve(ctx, target, qtype) - if err != nil { - return reply, nil - } - reply.Answer = append(reply.Answer, chainReply.Answer...) + chainReply, err := r.resolve(ctx, target, qtype) + if err != nil { + return reply, nil + } + reply.Answer = append(reply.Answer, chainReply.Answer...) - nextTarget := "" - for _, rr := range chainReply.Answer { - if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target { - nextTarget = cn.Target - } - } - if nextTarget == "" { - break - } - target = nextTarget - } - return reply, nil + nextTarget := "" + for _, rr := range chainReply.Answer { + if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target { + nextTarget = cn.Target + } + } + if nextTarget == "" { + break + } + target = nextTarget + } + return reply, nil } func isReferral(msg *dns.Msg) bool { return !msg.Authoritative && len(msg.Ns) > 0 @@ -228,11 +248,11 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err } func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, - qname string, qtype uint16) (*dns.Msg, error) { +qname string, qtype uint16, rd bool) (*dns.Msg, error) { msg := dns.NewMsg(qname, qtype) msg.UDPSize = 4096 - msg.RecursionDesired = false + msg.RecursionDesired = rd type result struct { reply *dns.Msg |
