summaryrefslogtreecommitdiff
path: root/internal/resolver/resolver.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/resolver/resolver.go')
-rw-r--r--internal/resolver/resolver.go192
1 files changed, 192 insertions, 0 deletions
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go
new file mode 100644
index 0000000..4ad023a
--- /dev/null
+++ b/internal/resolver/resolver.go
@@ -0,0 +1,192 @@
+package resolver
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+var (
+ ErrMaxDelegations = errors.New("max delegations exceeded")
+ ErrNoServers = errors.New("no nameservers available")
+)
+
+type Resolver struct {
+ roots []string
+ maxDelegations int
+ timeout time.Duration
+ retries int
+}
+
+type Option func(*Resolver)
+
+func New(opts ...Option) *Resolver {
+ r := &Resolver{
+ roots: loadRootServers(),
+ maxDelegations: 30,
+ timeout: 2 * time.Second,
+ retries: 2,
+ }
+ for _, opt := range opts {
+ opt(r)
+ }
+ return r
+}
+
+func WithRootAddresses(addrs []string) Option {
+ return func(r *Resolver) {
+ r.roots = addrs
+ }
+}
+
+func WithTimeout(d time.Duration) Option {
+ return func(r *Resolver) {
+ r.timeout = d
+ }
+}
+
+func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return r.resolve(ctx, qname, qtype)
+}
+
+func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
+ servers := make([]string, len(r.roots))
+ copy(servers, r.roots)
+ for depth := 0; depth < r.maxDelegations; depth++ {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+ if len(servers) == 0 {
+ return nil, ErrNoServers
+ }
+ reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype)
+ if err != nil {
+ return nil, fmt.Errorf("resolve %s %s: %w",
+ qname, dns.TypeToString[qtype], err)
+ }
+ switch {
+ case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
+ return reply, nil
+ case reply.Rcode == dns.RcodeNameError:
+ return reply, nil
+ case reply.Rcode == dns.RcodeSuccess && isReferral(reply):
+ next, err := r.nextServers(ctx, reply)
+ if err != nil {
+ return nil, err
+ }
+ servers = next
+ default:
+ if len(servers) > 1 {
+ servers = servers[1:]
+ } else {
+ return reply, nil
+ }
+ }
+
+ }
+ return nil, ErrMaxDelegations
+}
+
+func isReferral(msg *dns.Msg) bool {
+ return !msg.Authoritative && len(msg.Ns) > 0
+}
+
+func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, error) {
+ var targets []string
+ glue := make(map[string]string)
+
+ for _, rr := range msg.Ns {
+ ns, ok := rr.(*dns.NS)
+ if !ok {
+ continue
+ }
+ targets = append(targets, ns.Ns)
+ }
+ for _, rr := range msg.Extra {
+ switch v := rr.(type) {
+ case *dns.A:
+ if _, exists := glue[v.Hdr.Name]; !exists {
+ glue[v.Hdr.Name] = v.A.String()
+ }
+ case *dns.AAAA:
+ if _, exists := glue[v.Hdr.Name]; !exists {
+ glue[v.Hdr.Name] = v.AAAA.String()
+ }
+ }
+ }
+
+ var addrs []string
+ var unresolved []string
+ for _, t := range targets {
+ if ip, ok := glue[t]; ok {
+ addrs = append(addrs, ip)
+ } else {
+ unresolved = append(unresolved, t)
+ }
+ }
+
+ if len(addrs) > 0 {
+ return addrs, nil
+ }
+
+ for _, name := range unresolved {
+ reply, err := r.resolve(ctx, name, dns.TypeA)
+ if err != nil {
+ continue
+ }
+ if reply.Rcode != dns.RcodeSuccess {
+ continue
+ }
+ for _, rr := range reply.Answer {
+ if a, ok := rr.(*dns.A); ok {
+ addrs = append(addrs, a.A.String())
+ break
+ }
+ }
+ }
+
+ if len(addrs) == 0 {
+ return nil, ErrNoServers
+ }
+ return addrs, nil
+}
+
+func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
+ qname string, qtype uint16) (*dns.Msg, error) {
+ msg := new(dns.Msg)
+ msg.SetQuestion(qname, qtype)
+ msg.SetEdns0(4096, false)
+ msg.RecursionDesired = false
+
+ client := &dns.Client{
+ Net: "udp",
+ UDPSize: 4096,
+ Timeout: r.timeout,
+ }
+
+ var lastErr error
+ for _, srv := range servers {
+ addr := srv
+ if _, _, err := net.SplitHostPort(srv); err != nil {
+ addr = net.JoinHostPort(srv, "53")
+ }
+ for attempt := 0; attempt < r.retries; attempt++ {
+ reply, _, err := client.ExchangeContext(ctx, msg, addr)
+ if err == nil {
+ return reply, nil
+ }
+ lastErr = err
+ time.Sleep(time.Duration(attempt+1) * 200 * time.Millisecond)
+ }
+ }
+ return nil, lastErr
+}