summaryrefslogtreecommitdiff
path: root/main.go
blob: 70b5c0e458d7341c31c49e290e0bfa8317a71257 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package main

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"path/filepath"
	"syscall"
	"time"

	"linum/internal/blocklist"
	"linum/internal/cache"
	"linum/internal/config"
	"linum/internal/resolver"
	"linum/internal/server"
)

func main() {
	// All real work lives in run so that its deferred cleanups (cache stop,
	// signal unregistration, server close) execute before the process exits;
	// os.Exit would skip any defers registered in main itself.
	os.Exit(run())
}

// run loads the configuration, builds the resolver, cache, optional blocklist
// and server, and serves until SIGINT or SIGTERM cancels the context.
//
// Configuration precedence is: built-in defaults, then the config file named
// by the -config flag (when present), then command-line flags.
//
// It returns the process exit status: 0 on a clean shutdown, 1 on any
// configuration, startup, or runtime error.
func run() int {
	flags := config.ParseFlags()

	cfg := config.Default()
	fileLoaded := false
	if _, err := os.Stat(flags.Config); err == nil {
		fileCfg, err := config.LoadFile(flags.Config)
		if err != nil {
			slog.Error("invalid config file", "err", err)
			return 1
		}
		cfg = config.Merge(cfg, fileCfg)
		fileLoaded = true
	} else if !errors.Is(err, os.ErrNotExist) {
		// A missing file is expected (defaults + flags apply), but a file we
		// cannot even inspect (e.g. permission denied) should not be ignored
		// silently.
		slog.Warn("config file not accessible, ignoring it", "file", flags.Config, "err", err)
	}
	cfg = flags.Apply(cfg)

	if err := cfg.Validate(); err != nil {
		slog.Error("config validation failed", "err", err)
		return 1
	}

	// On a parse error lvl keeps its zero value, slog.LevelInfo.
	var lvl slog.Level
	levelErr := lvl.UnmarshalText([]byte(cfg.Log.Level))
	logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: lvl}))
	slog.SetDefault(logger)
	if levelErr != nil && cfg.Log.Level != "" {
		logger.Warn("invalid log level, defaulting to info", "level", cfg.Log.Level, "err", levelErr)
	}

	fmt.Println("linum - simple dns recursive")
	fmt.Printf("github.com/radhityax/linum\n\n")

	if fileLoaded {
		logger.Info("config loaded", "file", flags.Config)
	} else {
		logger.Info("config file not loaded, using defaults and flags", "file", flags.Config)
	}

	r := resolver.New(
		resolver.WithTimeout(2 * time.Second),
	)

	c, err := cache.NewCache(cfg.Cache.MaxEntries, cfg.Cache.DBPath)
	if err != nil {
		logger.Error("create cache failed", "err", err)
		return 1
	}
	defer c.Stop()

	// bl stays nil when no blocklist sources are configured; it is passed to
	// server.New as-is in that case.
	var bl *blocklist.Blocklist
	if len(cfg.Blocklist.Files) > 0 || len(cfg.Blocklist.URLs) > 0 {
		resp := blocklist.ResponseZeroIP
		if cfg.Blocklist.Response == "nxdomain" {
			resp = blocklist.ResponseNXDOMAIN
		}
		bl = blocklist.New(resp)

		// Local files are mandatory: a file that matches but fails to load is
		// fatal. Remote URLs below are best-effort and only warn.
		for _, pattern := range cfg.Blocklist.Files {
			matches, err := filepath.Glob(pattern)
			if err != nil {
				logger.Warn("invalid blocklist glob", "pattern", pattern, "err", err)
				continue
			}
			if len(matches) == 0 {
				logger.Warn("blocklist glob matched no files", "pattern", pattern)
				continue
			}
			for _, f := range matches {
				if err := bl.LoadFile(f); err != nil {
					logger.Error("load blocklist failed", "file", f, "err", err)
					return 1
				}
				logger.Info("blocklist loaded", "file", f, "rules", bl.TotalRules)
			}
		}
		for _, u := range cfg.Blocklist.URLs {
			if err := bl.LoadURL(u); err != nil {
				logger.Warn("load blocklist url failed", "url", u, "err", err)
				continue
			}
			logger.Info("blocklist url loaded", "url", u)
		}
	} else {
		logger.Info("no blocklist configured, ad-blocking disabled")
	}

	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	srv, err := server.New(cfg.Server.ListenUDP, cfg.Server.ListenTCP, cfg.Server.ListenDOH, logger, r, c, bl)
	if err != nil {
		logger.Error("create server failed", "err", err)
		return 1
	}
	defer srv.Close()

	logger.Info("linum starting",
		"udp", cfg.Server.ListenUDP,
		"tcp", cfg.Server.ListenTCP,
		"doh", cfg.Server.ListenDOH,
	)

	// Cancellation via signal is the normal shutdown path; use errors.Is so a
	// wrapped context.Canceled is also treated as clean.
	if err := srv.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
		logger.Error("server stopped with error", "err", err)
		return 1
	}
	logger.Info("linum stopped cleanly")
	return 0
}