summaryrefslogtreecommitdiff
path: root/internal/server/server_test.go
blob: 93e2a6917be981e25f0231e73fe3128ea954d193 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package server

import (
	"codeberg.org/miekg/dns"
	"codeberg.org/miekg/dns/rdata"
	"context"
	"linum/internal/resolver"
	"log/slog"
	"testing"
	"time"
)

// TestBuildResponseUnknownType checks that a question whose RR type is carried
// as an RFC 3597 generic record with an arbitrary type code (0xaa58) is
// answered with NOTIMPL.
func TestBuildResponseUnknownType(t *testing.T) {
	srv := testServer(t)

	// 0xaa58 is an arbitrary type code the test expects the server not to implement.
	question := &dns.RFC3597{
		Hdr:     dns.Header{Name: "example.com.", Class: dns.ClassINET},
		RFC3597: rdata.RFC3597{RRType: 0xaa58},
	}
	req := new(dns.Msg)
	req.Question = []dns.RR{question}

	// Only the response is inspected; the second return value is ignored.
	got, _ := srv.buildResponse(req)
	switch {
	case got == nil:
		t.Fatal("buildResponse returned nil")
	case got.Rcode != dns.RcodeNotImplemented:
		t.Errorf("expected NOTIMPL, got %d", got.Rcode)
	}
}

// testServer builds a Server for tests. Its resolver's only root address is
// 127.0.0.1:1 (not expected to answer) and it uses a 50ms timeout, so any
// resolution attempt gives up quickly. The Server's base context is cancelled
// via t.Cleanup when the test finishes.
func testServer(t *testing.T) *Server {
	t.Helper()

	res := resolver.New(
		resolver.WithRootAddresses([]string{"127.0.0.1:1"}),
		resolver.WithTimeout(50*time.Millisecond),
	)

	ctx, stop := context.WithCancel(context.Background())
	t.Cleanup(stop)

	return &Server{
		logger:   slog.Default(),
		resolver: res,
		baseCtx:  ctx,
	}
}

// TestBuildResponse is a table-driven test of buildResponse on hand-built
// requests. Each case checks the response rcode and answer count. When
// wantEdns0 is true, it also requires EDNS0 to be present with a 4096-byte UDP
// payload size; when false, EDNS0 is not checked at all.
func TestBuildResponse(t *testing.T) {
	s := testServer(t)
	tests := []struct {
		name        string
		req         *dns.Msg
		wantRcode   uint16
		wantAnswers int
		wantEdns0   bool
	}{
		{
			name:        "no questions returns FORMERR",
			req:         new(dns.Msg),
			wantRcode:   dns.RcodeFormatError,
			wantAnswers: 0,
			wantEdns0:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			resp, _ := s.buildResponse(tt.req)
			// Guard before dereferencing: a nil response should fail this
			// subtest, not panic and take down the whole test binary.
			if resp == nil {
				t.Fatal("buildResponse returned nil")
			}
			if resp.Rcode != tt.wantRcode {
				t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode)
			}
			if len(resp.Answer) != tt.wantAnswers {
				t.Errorf("answers: got %d, want %d", len(resp.Answer), tt.wantAnswers)
			}
			if tt.wantEdns0 {
				if resp.UDPSize == 0 {
					t.Error("expected EDNS0 in response, got none")
				} else if resp.UDPSize != 4096 {
					t.Errorf("edns0 udp size: got %d, want 4096", resp.UDPSize)
				}
			}
		})
	}
}

// TestBuildResponseWithQuery sends an ordinary A query for example.com. The
// test resolver's only root (127.0.0.1:1) is not expected to answer, so either
// NOERROR or SERVFAIL is accepted. The essential requirement is that
// buildResponse returns a non-nil response.
func TestBuildResponseWithQuery(t *testing.T) {
	srv := testServer(t)
	got, _ := srv.buildResponse(dns.NewMsg("example.com.", dns.TypeA))
	if got == nil {
		t.Fatal("buildResponse returned nil")
	}
	switch got.Rcode {
	case dns.RcodeSuccess, dns.RcodeServerFailure:
		// Both outcomes are acceptable for this query.
	default:
		t.Errorf("expected success or server failure, got %d", got.Rcode)
	}
}

func FuzzBuildResponse(f *testing.F) {
	baseCtx, cancel := context.WithCancel(context.Background())
	defer cancel()
	s := &Server{logger: slog.Default(), baseCtx: baseCtx}
	s.resolver = resolver.New(
		resolver.WithRootAddresses([]string{"127.0.0.1:1"}),
		resolver.WithTimeout(10*time.Millisecond),
	)

	seed := []byte{
		0x00, 0x00, // ID
		0x01, 0x00, // flags: RD
		0x00, 0x01, // QDCOUNT: 1
		0x00, 0x00, // ANCOUNT
		0x00, 0x00, // NSCOUNT
		0x00, 0x00, // ARCOUNT
		// Question: example.com A
		0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65,
		0x03, 0x63, 0x6f, 0x6d,
		0x00,
		0x00, 0x01, // QTYPE: A
		0x00, 0x01, // QCLASS: IN
	}
	f.Add(seed)
	f.Fuzz(func(t *testing.T, data []byte) {
		msg := new(dns.Msg)
		msg.Data = data
		if err := msg.Unpack(); err != nil {
			return
		}
		resp, _ := s.buildResponse(msg)
		if resp == nil {
			t.Fatal("buildResponse returned nil")
		}
		if err := resp.Pack(); err != nil {
			t.Errorf("pack failed: %v", err)
		}
	})
}