diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5bd4d275..93be810c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,12 +15,12 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 1 - - uses: WillAbides/setup-go-faster@v1.8.0 + - uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} - run: "go test -race ./..." - - uses: dominikh/staticcheck-action@v1.3.1 + - uses: dominikh/staticcheck-action@v1.4.0 with: - version: "2025.1" + version: "2025.1.1" install-go: false cache-key: ${{ matrix.go }} diff --git a/README.md b/README.md index 5b048ca8..2e936150 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,6 @@ A highly configurable DNS forwarding proxy with support for: - Multiple upstreams with fallbacks - Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN) - Policy driven domain based "split horizon" DNS with wildcard support -- Integrations with common router vendors and firmware - LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing - Prometheus metrics exporter @@ -26,35 +25,17 @@ All DNS protocols are supported, including: - `DNS-over-QUIC` # Use Cases -1. Use secure DNS protocols on networks and devices that don't natively support them (legacy routers, legacy OSes, TVs, smart toasters). +1. Use secure DNS protocols on networks and devices that don't natively support them (legacy OSes, TVs, smart toasters). 2. Create source IP based DNS routing policies with variable secure DNS upstreams. Subnet 1 (admin) uses upstream resolver A, while Subnet 2 (employee) uses upstream resolver B. 3. Create destination IP based DNS routing policies with variable secure DNS upstreams. Listener 1 uses upstream resolver C, while Listener 2 uses upstream resolver D. 4. Create domain level "split horizon" DNS routing policies to send internal domains (*.company.int) to a local DNS server, while everything else goes to another upstream. -5. 
Deploy on a router and create LAN client specific DNS routing policies from a web GUI (When using ControlD.com). ## OS Support -- Windows (386, amd64, arm) -- Windows Server (386, amd64) +- Windows Desktop (386, amd64, arm) - MacOS (amd64, arm64) - Linux (386, amd64, arm, mips) - FreeBSD (386, amd64, arm) -- Common routers (See below) - - -### Supported Routers -You can run `ctrld` on any supported router. The list of supported routers and firmware includes: -- Asus Merlin -- DD-WRT -- Firewalla -- FreshTomato -- GL.iNet -- OpenWRT -- pfSense / OPNsense -- Synology -- Ubiquiti (UniFi, EdgeOS) - -`ctrld` will attempt to interface with dnsmasq (or Windows Server) whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly. # Install There are several ways to download and install `ctrld`. @@ -63,12 +44,12 @@ There are several ways to download and install `ctrld`. The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform: ```shell -sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"' +sh -c 'sh -c "$(curl -sL https://api.controld.com/dl?version=2)"' ``` Windows user and prefer Powershell (who doesn't)? 
No problem, execute this command instead in administrative PowerShell: ```shell -(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'" +(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1?version=2' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'" ``` Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld) @@ -80,7 +61,7 @@ docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp control Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform. ## Build -Lastly, you can build `ctrld` from source which requires `go1.21+`: +Lastly, you can build `ctrld` from source which requires `go1.23+`: ```shell go build ./cmd/ctrld @@ -130,7 +111,7 @@ Available Commands: Flags: -h, --help help for ctrld -s, --silent do not write any log output - -v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging + -v, --verbose count verbose log output, "-v" basic logging, "-vv" debug logging --version version for ctrld Use "ctrld [command] --help" for more information about a command. @@ -161,9 +142,7 @@ You can then run a test query using a DNS client, for example, `dig`: If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server. ## Service Mode -This mode will run the application as a background system service on any Windows, MacOS, Linux, FreeBSD distribution or supported router. 
This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot. - -When Control D upstreams are used on a router type device, `ctrld` will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them. +This mode will run the application as a background system service on any Windows, MacOS, Linux or FreeBSD distribution. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot. ### Command @@ -200,7 +179,7 @@ Linux or Macos `ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args. ## API Based Auto Configuration -Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. In this mode, the application will automatically choose a non-conflicting IP and/or port and configure itself as the upstream to whatever process is running on port 53 (like dnsmasq or Windows DNS Server). This mode is used when the 1 liner installer command from the Control D onboarding guide is executed. +Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. This mode is used when the 1 liner installer command from the Control D onboarding guide is executed. 
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint. @@ -217,7 +196,7 @@ sudo ctrld start --cd abcd1234 Once you run the above command, the following things will happen: - You resolver configuration will be fetched from the API, and config file templated with the resolver data - Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands -- All physical network interface will be updated to use the listener started by the service or dnsmasq upstream will be switched to `ctrld` +- All physical network interface will be updated to use the listener started by the service - All DNS queries will be sent to the listener ## Manual Configuration diff --git a/client_info_darwin.go b/client_info_darwin.go deleted file mode 100644 index 4c3d10b2..00000000 --- a/client_info_darwin.go +++ /dev/null @@ -1,4 +0,0 @@ -package ctrld - -// SelfDiscover reports whether ctrld should only do self discover. -func SelfDiscover() bool { return true } diff --git a/client_info_others.go b/client_info_others.go deleted file mode 100644 index d728913a..00000000 --- a/client_info_others.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build !windows && !darwin - -package ctrld - -// SelfDiscover reports whether ctrld should only do self discover. -func SelfDiscover() bool { return false } diff --git a/client_info_windows.go b/client_info_windows.go deleted file mode 100644 index f20bca78..00000000 --- a/client_info_windows.go +++ /dev/null @@ -1,18 +0,0 @@ -package ctrld - -import ( - "golang.org/x/sys/windows" -) - -// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine. 
-func isWindowsWorkStation() bool { - // From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa - const VER_NT_WORKSTATION = 0x0000001 - osvi := windows.RtlGetVersion() - return osvi.ProductType == VER_NT_WORKSTATION -} - -// SelfDiscover reports whether ctrld should only do self discover. -func SelfDiscover() bool { - return isWindowsWorkStation() -} diff --git a/cmd/cli/ad_others.go b/cmd/cli/ad_others.go index b23476fe..6a7417fb 100644 --- a/cmd/cli/ad_others.go +++ b/cmd/cli/ad_others.go @@ -8,8 +8,3 @@ import ( // addExtraSplitDnsRule adds split DNS rule if present. func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false } - -// getActiveDirectoryDomain returns AD domain name of this computer. -func getActiveDirectoryDomain() (string, error) { - return "", nil -} diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index 66180a90..4820f72a 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -16,11 +16,11 @@ import ( func addExtraSplitDnsRule(cfg *ctrld.Config) bool { domain, err := getActiveDirectoryDomain() if err != nil { - mainLog.Load().Debug().Msgf("unable to get active directory domain: %v", err) + mainLog.Load().Debug().Msgf("Unable to get active directory domain: %v", err) return false } if domain == "" { - mainLog.Load().Debug().Msg("no active directory domain found") + mainLog.Load().Debug().Msg("No active directory domain found") return false } // Network rules are lowercase during toml config marshaling, @@ -40,11 +40,11 @@ func addSplitDnsRule(cfg *ctrld.Config, domain string) bool { } for _, rule := range lc.Policy.Rules { if _, ok := rule[domain]; ok { - mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n) + mainLog.Load().Debug().Msgf("Split-rule %q already existed for listener.%s", domain, n) return false } } - mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n) + mainLog.Load().Debug().Msgf("Adding split-rule %q for 
listener.%s", domain, n) lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) } return true diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 1984f702..eb2d7286 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -31,16 +31,15 @@ import ( "github.com/kardianos/service" "github.com/miekg/dns" "github.com/pelletier/go-toml/v2" - "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/viper" + "go.uber.org/zap" "tailscale.com/logtail/backoff" "tailscale.com/net/netmon" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/controld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" - "github.com/Control-D-Inc/ctrld/internal/router" ) // selfCheckInternalTestDomain is used for testing ctrld self response to clients. @@ -62,10 +61,13 @@ var ( defaultConfigFile = "ctrld.toml" rootCertPool *x509.CertPool errSelfCheckNoAnswer = errors.New("no response from ctrld listener. You can try to re-launch with flag --skip_self_checks") + // Store version once during init to avoid repeated calls to curVersion() + appVersion = curVersion() ) var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"} +// isNoConfigStart checks if the command is using no-config start mode func isNoConfigStart(cmd *cobra.Command) bool { for _, flagName := range basicModeFlags { if cmd.Flags().Lookup(flagName).Changed { @@ -84,34 +86,41 @@ _/ ___\ __\_ __ \ | / __ | \/ dns forwarding proxy \/ ` -var rootCmd = &cobra.Command{ - Use: "ctrld", - Short: strings.TrimLeft(rootShortDesc, "\n"), - Version: curVersion(), - PersistentPreRun: func(cmd *cobra.Command, args []string) { - initConsoleLogging() - }, -} - +// curVersion returns the current version string func curVersion() string { + // Ensure version has proper "v" prefix for semantic versioning + // This is needed because some build systems may provide version without the "v" prefix if version != "dev" && !strings.HasPrefix(version, "v") { version = "v" + version 
} + // Return version directly if it's not empty and not a dev build + // This avoids unnecessary commit hash concatenation for release versions if version != "" && version != "dev" { return version } + // Truncate commit hash to 7 characters for readability + // Git commit hashes are typically 40 characters, but 7 is sufficient for identification if len(commit) > 7 { commit = commit[:7] } return fmt.Sprintf("%s-%s", version, commit) } -func initCLI() { +func initCLI() *cobra.Command { // Enable opening via explorer.exe on Windows. // See: https://github.com/spf13/cobra/issues/844. cobra.MousetrapHelpText = "" cobra.EnableCommandSorting = false + rootCmd := &cobra.Command{ + Use: "ctrld", + Short: strings.TrimLeft(rootShortDesc, "\n"), + Version: appVersion, + PersistentPreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + }, + } + rootCmd.PersistentFlags().CountVarP( &verbose, "verbose", @@ -128,18 +137,13 @@ func initCLI() { rootCmd.SetHelpCommand(&cobra.Command{Hidden: true}) rootCmd.CompletionOptions.HiddenDefaultCmd = true - initRunCmd() - startCmd := initStartCmd() - stopCmd := initStopCmd() - restartCmd := initRestartCmd() - reloadCmd := initReloadCmd(restartCmd) - statusCmd := initStatusCmd() - uninstallCmd := initUninstallCmd() - interfacesCmd := initInterfacesCmd() - initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) - initClientsCmd() - initUpgradeCmd() - initLogCmd() + InitRunCmd(rootCmd) + InitServiceCmd(rootCmd) + InitClientsCmd(rootCmd) + InitUpgradeCmd(rootCmd) + InitLogCmd(rootCmd) + + return rootCmd } // isMobile reports whether the current OS is a mobile platform. 
@@ -219,6 +223,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cfg: &cfg, appCallback: appCallback, } + p.logger.Store(mainLog.Load()) if homedir == "" { if dir, err := userHomeDir(); err == nil { homedir = dir @@ -229,39 +234,40 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { sockDir = d } sockPath := filepath.Join(sockDir, ctrldLogUnixSock) - if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil { - if conn, err := net.Dial(addr.Network(), addr.String()); err == nil { - lc := &logConn{conn: conn} - consoleWriter.Out = io.MultiWriter(os.Stdout, lc) - p.logConn = lc - } else { - if !errors.Is(err, os.ErrNotExist) { - mainLog.Load().Warn().Err(err).Msg("unable to create log ipc connection") - } + hlc := newHTTPLogClient(sockPath) + + // Test if HTTP log server is available + if err := hlc.Ping(); err != nil { + if !errConnectionRefused(err) { + p.Warn().Err(err).Msg("Unable to ping log server") } } else { - mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath) + // Server is available, use HTTP log client + consoleWriter = newHumanReadableZapCore(io.MultiWriter(os.Stdout, hlc), consoleWriterLevel) + p.logConn = hlc } notifyExitToLogServer := func() { if p.logConn != nil { - _, _ = p.logConn.Write([]byte(msgExit)) + _ = p.logConn.Close() } } if daemon && runtime.GOOS == "windows" { - mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") + p.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") } if !daemon { // We need to call s.Run() as soon as possible to response to the OS manager, so it // can see ctrld is running and don't mark ctrld as failed service. 
go func() { - s, err := newService(p, svcConfig) + svcCmd := NewServiceCommand() + svcConfig := svcCmd.createServiceConfig() + s, err := svcCmd.newService(p, svcConfig) if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed create new service") + p.Fatal().Err(err).Msg("Failed to create new service") } if err := s.Run(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start service") + p.Error().Err(err).Msg("Failed to start service") } }() } @@ -269,7 +275,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { tryReadingConfig(writeDefaultConfig) if err := readBase64Config(configBase64); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config") + p.Fatal().Err(err).Msg("Failed to read base64 config") } processNoConfigFlags(noConfigStart) @@ -278,7 +284,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + p.Fatal().Msgf("Failed to unmarshal config: %v", err) } p.mu.Unlock() @@ -288,30 +294,21 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // so it's able to log information in processCDFlags. p.initLogging(true) - mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) - mainLog.Load().Info().Msgf("os: %s", osVersion()) + p.Info().Msgf("Starting ctrld %s", curVersion()) + p.Info().Msgf("OS: %s", osVersion()) // Wait for network up. 
if !ctrldnet.Up() { notifyExitToLogServer() - mainLog.Load().Fatal().Msg("network is not up yet") + p.Fatal().Msg("Network is not up yet") } - p.router = router.New(&cfg, cdUID != "") cs, err := newControlServer(filepath.Join(sockDir, ControlSocketName())) if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create control server") + p.Warn().Err(err).Msg("Could not create control server") } p.cs = cs - // Processing --cd flag require connecting to ControlD API, which needs valid - // time for validating server certificate. Some routers need NTP synchronization - // to set the current time, so this check must happen before processCDFlags. - if err := p.router.PreRun(); err != nil { - notifyExitToLogServer() - mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check") - } - oldLogPath := cfg.Service.LogPath if uid := cdUIDFromProvToken(); uid != "" { cdUID = uid @@ -324,14 +321,14 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { return } - cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() + cdLogger := p.logger.Load().With().Str("mode", "cd") // Performs self-uninstallation if the ControlD device does not exist. 
var uer *controld.ErrorResponse if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { _ = uninstallInvalidCdUID(p, cdLogger, false) } notifyExitToLogServer() - cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config") + cdLogger.Fatal().Err(err).Msg("Failed to fetch resolver config") } else { p.mu.Lock() p.rc = rc @@ -348,24 +345,25 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if updated { if err := writeConfigFile(&cfg); err != nil { notifyExitToLogServer() - mainLog.Load().Fatal().Err(err).Msg("failed to write config file") + p.Fatal().Err(err).Msg("Failed to write config file") } else { - mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) + p.Info().Msg("Writing config file to: " + defaultConfigFile) } } if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath { // After processCDFlags, log config may change, so reset mainLog and re-init logging. - l := zerolog.New(io.Discard) - mainLog.Store(&l) + l := zap.NewNop() + mainLog.Store(&ctrld.Logger{Logger: l}) // Copy logs written so far to new log file if possible. 
if buf, err := os.ReadFile(oldLogPath); err == nil { if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not copy old log file") + p.Warn().Err(err).Msg("Could not copy old log file") } } initLoggingWithBackup(false) + p.logger.Store(mainLog.Load()) } if err := validateConfig(&cfg); err != nil { @@ -377,13 +375,13 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if daemon { exe, err := os.Executable() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to find the binary") + p.Error().Err(err).Msg("Failed to find the binary") notifyExitToLogServer() os.Exit(1) } curDir, err := os.Getwd() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get current working directory") + p.Error().Err(err).Msg("Failed to get current working directory") notifyExitToLogServer() os.Exit(1) } @@ -391,11 +389,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...) 
cmd.Dir = curDir if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start process as daemon") + p.Error().Err(err).Msg("Failed to start process as daemon") notifyExitToLogServer() os.Exit(1) } - mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") + p.Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") os.Exit(0) } @@ -403,7 +401,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { for _, lc := range p.cfg.Listener { if shouldAllocateLoopbackIP(lc.IP) { if err := allocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP) + p.Error().Err(err).Msgf("Could not allocate ip: %s", lc.IP) } } } @@ -414,41 +412,22 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { for _, lc := range p.cfg.Listener { if shouldAllocateLoopbackIP(lc.IP) { if err := deAllocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP) + p.Error().Err(err).Msgf("Could not de-allocate ip: %s", lc.IP) } } } }) - if platform := router.Name(); platform != "" { - if cp := router.CertPool(); cp != nil { - rootCertPool = cp - } - if iface != "" { - p.onStarted = append(p.onStarted, func() { - mainLog.Load().Debug().Msg("router setup on start") - if err := p.router.Setup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not configure router") - } - }) - p.onStopped = append(p.onStopped, func() { - mainLog.Load().Debug().Msg("router cleanup on stop") - if err := p.router.Cleanup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not cleanup router") - } - }) - } - } p.onStopped = append(p.onStopped, func() { // restore static DNS settings or DHCP p.resetDNS(false, true) // Iterate over all physical interfaces and restore static DNS if a saved static config exists. 
withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) + file := ctrld.SavedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { if err := restoreDNS(i); err != nil { - mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + p.Error().Err(err).Msgf("Could not restore static dns on interface %s", i.Name) } else { - mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + p.Debug().Msgf("Restored static dns on interface %s successfully", i.Name) } } return nil @@ -459,29 +438,41 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { <-stopCh } +// writeConfigFile writes the configuration to a file func writeConfigFile(cfg *ctrld.Config) error { + mainLog.Load().Debug().Msg("Writing configuration file") + if cfu := v.ConfigFileUsed(); cfu != "" { defaultConfigFile = cfu } else if configPath != "" { defaultConfigFile = configPath } + + mainLog.Load().Debug().Str("config_file", defaultConfigFile).Msg("Opening configuration file for writing") + f, err := os.OpenFile(defaultConfigFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644)) if err != nil { + mainLog.Load().Error().Err(err).Str("config_file", defaultConfigFile).Msg("Failed to open configuration file") return err } defer f.Close() if cdUID != "" { if _, err := f.WriteString("# AUTO-GENERATED VIA CD FLAG - DO NOT MODIFY\n\n"); err != nil { + mainLog.Load().Error().Err(err).Msg("Failed to write CD header to configuration file") return err } } enc := toml.NewEncoder(f).SetIndentTables(true) if err := enc.Encode(&cfg); err != nil { + mainLog.Load().Error().Err(err).Str("config_file", defaultConfigFile).Msg("Failed to encode configuration") return err } if err := f.Close(); err != nil { + mainLog.Load().Error().Err(err).Str("config_file", defaultConfigFile).Msg("Failed to close configuration file") return err } + + 
mainLog.Load().Debug().Str("config_file", defaultConfigFile).Msg("Configuration file written successfully") return nil } @@ -496,7 +487,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { if notice { mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed()) } - mainLog.Load().Info().Msg("loading config file from: " + v.ConfigFileUsed()) + mainLog.Load().Info().Msg("Loading config file from: " + v.ConfigFileUsed()) defaultConfigFile = v.ConfigFileUsed() return true } @@ -508,22 +499,21 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { // If error is viper.ConfigFileNotFoundError, write default config. if errors.As(err, &viper.ConfigFileNotFoundError{}) { if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) + mainLog.Load().Fatal().Msgf("Failed to unmarshal default config: %v", err) } - nop := zerolog.Nop() - _, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true) + _, _ = tryUpdateListenerConfig(&cfg, func() {}, true) addExtraSplitDnsRule(&cfg) if err := writeConfigFile(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) + mainLog.Load().Fatal().Msgf("Failed to write default config file: %v", err) } else { fp, err := filepath.Abs(defaultConfigFile) if err != nil { - mainLog.Load().Fatal().Msgf("failed to get default config file path: %v", err) + mainLog.Load().Fatal().Msgf("Failed to get default config file path: %v", err) } if cdUID == "" && nextdns == "" { mainLog.Load().Notice().Msg("Generating controld default config: " + fp) } - mainLog.Load().Info().Msg("writing default config file to: " + fp) + mainLog.Load().Info().Msg("Writing default config file to: " + fp) } return false } @@ -532,12 +522,12 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { if errors.As(err, &viper.ConfigParseError{}) { if de := decoderErrorFromTomlFile(v.ConfigFileUsed()); de != nil { row, col := de.Position() - 
mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err) + mainLog.Load().Fatal().Msgf("Failed to decode config file at line: %d, column: %d, error: %v", row, col, err) } } // Otherwise, report fatal error and exit. - mainLog.Load().Fatal().Msgf("failed to decode config file: %v", err) + mainLog.Load().Fatal().Msgf("Failed to decode config file: %v", err) return false } @@ -559,11 +549,17 @@ func readBase64Config(configBase64 string) error { if configBase64 == "" { return nil } + + mainLog.Load().Debug().Msg("Reading base64 encoded configuration") + configStr, err := base64.StdEncoding.DecodeString(configBase64) if err != nil { + mainLog.Load().Error().Err(err).Msg("Failed to decode base64 configuration") return fmt.Errorf("invalid base64 config: %w", err) } + mainLog.Load().Debug().Int("config_length", len(configStr)).Msg("Base64 configuration decoded successfully") + // readBase64Config is called when: // // - "--base64_config" flag set. @@ -572,24 +568,39 @@ func readBase64Config(configBase64 string) error { // So we need to re-create viper instance to discard old one. 
v = viper.NewWithOptions(viper.KeyDelimiter("::")) v.SetConfigType("toml") - return v.ReadConfig(bytes.NewReader(configStr)) + + mainLog.Load().Debug().Msg("Parsing base64 configuration as TOML") + + if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil { + mainLog.Load().Error().Err(err).Msg("Failed to parse base64 configuration as TOML") + return err + } + + mainLog.Load().Debug().Msg("Base64 configuration processed successfully") + return nil } +// processNoConfigFlags processes flags for no-config mode func processNoConfigFlags(noConfigStart bool) { if !noConfigStart { return } + + mainLog.Load().Debug().Msg("Processing no-config mode flags") + if listenAddress == "" || primaryUpstream == "" { mainLog.Load().Fatal().Msg(`"listen" and "primary_upstream" flags must be set in no config mode`) } processListenFlag() endpointAndTyp := func(endpoint string) (string, string) { + mainLog.Load().Debug().Str("endpoint", endpoint).Msg("Processing endpoint for resolver type") typ := ctrld.ResolverTypeFromEndpoint(endpoint) endpoint = strings.TrimPrefix(endpoint, "quic://") if after, found := strings.CutPrefix(endpoint, "h3://"); found { endpoint = "https://" + after } + mainLog.Load().Debug().Str("endpoint", endpoint).Str("type", typ).Msg("Endpoint processed") return endpoint, typ } pEndpoint, pType := endpointAndTyp(primaryUpstream) @@ -599,7 +610,8 @@ func processNoConfigFlags(noConfigStart bool) { Type: pType, Timeout: 5000, } - puc.Init() + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + puc.Init(loggerCtx) upstream := map[string]*ctrld.UpstreamConfig{"0": puc} if secondaryUpstream != "" { sEndpoint, sType := endpointAndTyp(secondaryUpstream) @@ -609,7 +621,7 @@ func processNoConfigFlags(noConfigStart bool) { Type: sType, Timeout: 5000, } - suc.Init() + suc.Init(loggerCtx) upstream["1"] = suc rules := make([]ctrld.Rule, 0, len(domains)) for _, domain := range domains { @@ -637,18 +649,23 @@ func deactivationPinSet() bool { return 
cdDeactivationPin.Load() != defaultDeactivationPin } +// processCDFlags processes Control D related flags func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { - logger := mainLog.Load().With().Str("mode", "cd").Logger() - logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) + logger := mainLog.Load().With().Str("mode", "cd") + logger.Info().Msgf("Fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) bo.LogLongerThan = 30 * time.Second - ctx := context.Background() - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + ctx := ctrld.LoggerCtx(context.Background(), logger) + resolverConfig, err := controld.FetchResolverConfig(ctx, cdUID, appVersion, cdDev) + + // Retry logic for network errors using bootstrap DNS + // This is needed because the initial DNS resolution might fail due to network issues + // or DNS server unavailability, but bootstrap DNS can provide alternative resolution for { if errUrlNetworkError(err) { bo.BackOff(ctx, err) - logger.Warn().Msg("could not fetch resolver using bootstrap DNS, retrying...") - resolverConfig, err = controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + logger.Warn().Msg("Could not fetch resolver using bootstrap DNS, retrying...") + resolverConfig, err = controld.FetchResolverConfig(ctx, cdUID, appVersion, cdDev) continue } break @@ -657,21 +674,23 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { if isMobile() { return nil, err } - logger.Warn().Err(err).Msg("could not fetch resolver config") + logger.Warn().Err(err).Msg("Could not fetch resolver config") return nil, err } if resolverConfig.DeactivationPin != nil { - logger.Debug().Msg("saving deactivation pin") + logger.Debug().Msg("Saving deactivation pin") cdDeactivationPin.Store(*resolverConfig.DeactivationPin) } - logger.Info().Msg("generating ctrld config from Control-D configuration") + 
logger.Info().Msg("Generating ctrld config from Control-D configuration") + // Reset config to ensure clean state before applying Control-D settings + // This prevents mixing of old configuration with new Control-D settings *cfg = ctrld.Config{} // Fetch config, unmarshal to cfg. if resolverConfig.Ctrld.CustomConfig != "" { - logger.Info().Msg("using defined custom config of Control-D resolver") + logger.Info().Msg("Using defined custom config of Control-D resolver") var cfgErr error if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil { setListenerDefaultValue(cfg) @@ -680,13 +699,13 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { return resolverConfig, nil } } - mainLog.Load().Warn().Err(err).Msg("disregarding invalid custom config") + mainLog.Load().Warn().Err(err).Msg("Disregarding invalid custom config") } bootstrapIP := func(endpoint string) string { u, err := url.Parse(endpoint) if err != nil { - logger.Warn().Err(err).Msgf("no bootstrap IP for invalid endpoint: %s", endpoint) + logger.Warn().Err(err).Msgf("No bootstrap ip for invalid endpoint: %s", endpoint) return "" } switch { @@ -698,6 +717,8 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { return "" } + // Initialize upstream configuration with Control-D resolver settings + // This creates the primary DNS resolver configuration for the proxy cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) cfg.Upstream["0"] = &ctrld.UpstreamConfig{ BootstrapIP: bootstrapIP(resolverConfig.DOH), @@ -705,10 +726,16 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { Type: cdUpstreamProto, Timeout: 5000, } + + // Create exclusion rules for domains that should bypass Control-D + // These domains will be resolved using the system's default DNS servers rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) for _, domain := range resolverConfig.Exclude { rules = append(rules, ctrld.Rule{domain: []string{}}) } + + // 
Initialize listener configuration with policy rules + // This sets up the DNS proxy listener with the exclusion policy cfg.Listener = make(map[string]*ctrld.ListenerConfig) lc := &ctrld.ListenerConfig{ Policy: &ctrld.ListenerPolicyConfig{ @@ -759,17 +786,20 @@ func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) erro return v.Unmarshal(&cfg) } +// processListenFlag processes the listen flag func processListenFlag() { if listenAddress == "" { return } + mainLog.Load().Debug().Str("listen_address", listenAddress).Msg("Processing listen flag") + host, portStr, err := net.SplitHostPort(listenAddress) if err != nil { - mainLog.Load().Fatal().Msgf("invalid listener address: %v", err) + mainLog.Load().Fatal().Msgf("Invalid listener address: %v", err) } port, err := strconv.Atoi(portStr) if err != nil { - mainLog.Load().Fatal().Msgf("invalid port number: %v", err) + mainLog.Load().Fatal().Msgf("Invalid port number: %v", err) } lc := &ctrld.ListenerConfig{ IP: host, @@ -778,23 +808,34 @@ func processListenFlag() { v.Set("listener", map[string]*ctrld.ListenerConfig{ "0": lc, }) + + mainLog.Load().Debug().Str("host", host).Int("port", port).Msg("Listen flag processed successfully") } +// processLogAndCacheFlags processes log and cache related flags func processLogAndCacheFlags() { + mainLog.Load().Debug().Msg("Processing log and cache flags") + if logPath != "" { cfg.Service.LogPath = logPath + mainLog.Load().Debug().Str("log_path", logPath).Msg("Log path flag processed") } if logPath != "" && cfg.Service.LogLevel == "" { cfg.Service.LogLevel = "debug" + mainLog.Load().Debug().Msg("Log level set to debug") } if cacheSize != 0 { cfg.Service.CacheEnable = true cfg.Service.CacheSize = cacheSize + mainLog.Load().Debug().Int("cache_size", cacheSize).Msg("Cache flag processed") } v.Set("service", cfg.Service) + + mainLog.Load().Debug().Msg("Log and cache flags processed successfully") } +// netInterface returns the network interface by name func 
netInterface(ifaceName string) (*net.Interface, error) { if ifaceName == "auto" { ifaceName = defaultIfaceName() @@ -814,10 +855,8 @@ func netInterface(ifaceName string) (*net.Interface, error) { return iface, err } +// defaultIfaceName returns the default interface name func defaultIfaceName() string { - if ifaceName := router.DefaultInterfaceName(); ifaceName != "" { - return ifaceName - } dri, err := netmon.DefaultRouteInterface() if err != nil { // On WSL 1, the route table does not have any default route. But the fact that @@ -830,7 +869,7 @@ func defaultIfaceName() string { if runtime.GOOS == "linux" { return "lo" } - mainLog.Load().Debug().Err(err).Msg("no default route interface found") + mainLog.Load().Debug().Err(err).Msg("No default route interface found") return "" } return dri @@ -849,7 +888,7 @@ func defaultIfaceName() string { func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bool, service.Status, error) { status, err := s.Status() if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not get service status") + mainLog.Load().Warn().Err(err).Msg("Could not get service status") return false, service.StatusUnknown, err } // If ctrld is not running, do nothing, just return the status as-is. 
@@ -861,7 +900,7 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo return true, status, nil } - mainLog.Load().Debug().Msg("waiting for ctrld listener to be ready") + mainLog.Load().Debug().Msg("Waiting for ctrld listener to be ready") cc := newSocketControlClient(ctx, s, sockDir) if cc == nil { return false, status, errors.New("could not connect to control server") @@ -874,13 +913,13 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo v.SetConfigFile(defaultConfigFile) } if err := v.ReadInConfig(); err != nil { - mainLog.Load().Error().Err(err).Msgf("failed to re-read configuration file: %s", v.ConfigFileUsed()) + mainLog.Load().Error().Err(err).Msgf("Failed to re-read configuration file: %s", v.ConfigFileUsed()) return false, status, err } cfg = ctrld.Config{} if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to update new config") + mainLog.Load().Error().Err(err).Msg("Failed to update new config") return false, status, err } @@ -890,12 +929,12 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo return true, status, nil } - mainLog.Load().Debug().Msg("ctrld listener is ready") + mainLog.Load().Debug().Msg("Ctrld listener is ready") lc := cfg.FirstListener() addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) - mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr) + mainLog.Load().Debug().Msgf("Performing listener test, sending queries to %s", addr) if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil { return false, status, err @@ -945,19 +984,21 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri lastErr = exErr bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr)) } - mainLog.Load().Debug().Msgf("self-check against %q failed", domain) + mainLog.Load().Debug().Msgf("Self-check against %q failed", domain) 
+ loggerCtx := ctrld.LoggerCtx(ctx, mainLog.Load()) // Ping all upstreams to provide better error message to users. for name, uc := range cfg.Upstream { - if err := uc.ErrorPing(); err != nil { - mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) + if err := uc.ErrorPing(loggerCtx); err != nil { + mainLog.Load().Err(err).Msgf("Failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) } } marker := strings.Repeat("=", 32) mainLog.Load().Debug().Msg(marker) - mainLog.Load().Debug().Msgf("listener address : %s", addr) - mainLog.Load().Debug().Msgf("last error : %v", lastErr) + + mainLog.Load().Debug().Msgf("Listener address : %s", addr) + mainLog.Load().Debug().Msgf("Last error : %v", lastErr) if lastAnswer != nil { - mainLog.Load().Debug().Msgf("last answer from ctrld :") + mainLog.Load().Debug().Msgf("Last answer from ctrld :") mainLog.Load().Debug().Msg(marker) for _, s := range strings.Split(lastAnswer.String(), "\n") { mainLog.Load().Debug().Msgf("%s", s) @@ -966,36 +1007,13 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri return errSelfCheckNoAnswer } +// userHomeDir returns the user's home directory func userHomeDir() (string, error) { - dir, err := router.HomeDir() - if err != nil { - return "", err - } - if dir != "" { - return dir, nil - } - // viper will expand for us. - if runtime.GOOS == "windows" { - // If we're on windows, use the install path for this. - exePath, err := os.Executable() - if err != nil { - return "", err - } - - return filepath.Dir(exePath), nil - } // Mobile platform should provide a rw dir path for this. 
if isMobile() { return homedir, nil } - dir = "/etc/controld" - if err := os.MkdirAll(dir, 0750); err != nil { - return os.UserHomeDir() // fallback to user home directory - } - if ok, _ := dirWritable(dir); !ok { - return os.UserHomeDir() - } - return dir, nil + return ctrld.UserHomeDir() } // socketDir returns directory that ctrld will create socket file for running controlServer. @@ -1051,7 +1069,7 @@ func readConfigWithNotice(writeDefaultConfig, notice bool) { dir, err := userHomeDir() if err != nil { - mainLog.Load().Fatal().Msgf("failed to get user home dir: %v", err) + mainLog.Load().Fatal().Msgf("Failed to get user home dir: %v", err) } for _, config := range configs { ctrld.SetConfigNameWithPath(v, config.name, dir) @@ -1073,54 +1091,46 @@ func uninstall(p *prog, s service.Service) { } initInteractiveLogging() if doTasks(tasks) { - if err := p.router.ConfigureService(svcConfig); err != nil { - mainLog.Load().Fatal().Err(err).Msg("could not configure service") - } - if err := p.router.Uninstall(svcConfig); err != nil { - mainLog.Load().Warn().Err(err).Msg("post uninstallation failed, please check system/service log for details error") - return - } // restore static DNS settings or DHCP p.resetDNS(false, true) // Iterate over all physical interfaces and restore DNS if a saved static config exists. 
withEachPhysicalInterfaces(p.runningIface, "restore static DNS", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) + file := ctrld.SavedStaticDnsSettingsFilePath(i) if _, err := os.Stat(file); err == nil { if err := restoreDNS(i); err != nil { - mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name) + mainLog.Load().Error().Err(err).Msgf("Could not restore static dns on interface %s", i.Name) } else { - mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name) + mainLog.Load().Debug().Msgf("Restored static dns on interface %s successfully", i.Name) err = os.Remove(file) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("Could not remove saved static DNS file for interface %s", i.Name) + mainLog.Load().Debug().Err(err).Msgf("Could not remove saved static dns file for interface %s", i.Name) } } } return nil }) - if router.Name() != "" { - mainLog.Load().Debug().Msg("Router cleanup") - } - // Stop already did router.Cleanup and report any error if happens, - // ignoring error here to prevent false positive. - _ = p.router.Cleanup() mainLog.Load().Notice().Msg("Service uninstalled") return } } func validateConfig(cfg *ctrld.Config) error { + mainLog.Load().Debug().Msg("Validating configuration") + if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil { var ve validator.ValidationErrors if errors.As(err, &ve) { for _, fe := range ve { - mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) + mainLog.Load().Error().Msgf("Invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) } } + mainLog.Load().Error().Err(err).Msg("Configuration validation failed") return err } + + mainLog.Load().Debug().Msg("Configuration validation completed successfully") return nil } @@ -1206,7 +1216,7 @@ func mobileListenerIp() string { // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. 
func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool { - updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true) + updated, _ := tryUpdateListenerConfig(cfg, notifyToLogServerFunc, true) if addExtraSplitDnsRule(cfg) { updated = true } @@ -1216,26 +1226,20 @@ func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool // tryUpdateListenerConfig tries updating listener config with a working one. // If fatal is true, and there's listen address conflicted, the function do // fatal error. -func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) { +func tryUpdateListenerConfig(cfg *ctrld.Config, notifyFunc func(), fatal bool) (updated, ok bool) { ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" nextdnsMode := nextdns != "" - // For Windows server with local Dns server running, we can only try on random local IP. - hasLocalDnsServer := hasLocalDnsServerRunning() - notRouter := router.Name() == "" isDesktop := ctrld.IsDesktopPlatform() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { listener.IP = "0.0.0.0" - // Windows Server lies to us that we could listen on 0.0.0.0:53 - // even there's a process already done that, stick to local IP only. - // // For desktop clients, also stick the listener to the local IP only. // Listening on 0.0.0.0 would expose it to the entire local network, potentially // creating security vulnerabilities (such as DNS amplification or abusing). - if hasLocalDnsServer || isDesktop { + if isDesktop { listener.IP = "127.0.0.1" } lcc[n].IP = true @@ -1246,25 +1250,19 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti } // In cd mode, we always try to pick an ip:port pair to work. // Same if nextdns resolver is used. 
- // - // Except on Windows Server with local Dns running, - // we could only listen on random local IP port 53. if cdMode || nextdnsMode { lcc[n].IP = true lcc[n].Port = true - if hasLocalDnsServer { - lcc[n].Port = false - } } updated = updated || lcc[n].IP || lcc[n].Port } il := mainLog.Load() - if infoLogger != nil { - il = infoLogger - } if isMobile() { // On Mobile, only use first listener, ignore others. + // This is needed because mobile platforms have limited resources and + // multiple listeners can cause conflicts with system DNS services and + // likely don't work anyway. firstLn := cfg.FirstListener() for k := range cfg.Listener { if cfg.Listener[k] != firstLn { @@ -1272,6 +1270,8 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti } } if cdMode { + // Use mobile-specific listener settings for Control-D mode + // Mobile platforms require specific IP/port combinations to avoid permission issues. firstLn.IP = mobileListenerIp() firstLn.Port = mobileListenerPort() clear(lcc) @@ -1299,7 +1299,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti return errors.Join(udpErr, tcpErr) } - logMsg := func(e *zerolog.Event, listenerNum int, format string, v ...any) { + logMsg := func(e *ctrld.LogEvent, listenerNum int, format string, v ...any) { e.MsgFunc(func() string { return fmt.Sprintf("listener.%d %s", listenerNum, fmt.Sprintf(format, v...)) }) @@ -1334,21 +1334,12 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti // On firewalla, we don't need to check localhost, because the lo interface is excluded in dnsmasq // config, so we can always listen on localhost port 53, but no traffic could be routed there. 
- tryLocalhost := !isLoopback(listener.IP) && router.CanListenLocalhost() + tryLocalhost := !isLoopback(listener.IP) tryAllPort53 := true - tryOldIPPort5354 := true - tryPort5354 := true - if hasLocalDnsServer { + if isZeroIP && listener.Port == 53 { tryAllPort53 = false - tryOldIPPort5354 = false - tryPort5354 = false - } - // if not running on a router, we should not try to listen on any port other than 53 - // if we do, this will break the dns resolution for the system. - if notRouter { - tryOldIPPort5354 = false - tryPort5354 = false } + attempts := 0 maxAttempts := 10 for { @@ -1362,7 +1353,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti break } - logMsg(il.Info(), n, "error listening on address: %s, error: %v", addr, err) + logMsg(il.Debug().Err(err), n, "error listening on address: %s", addr) if !check.IP && !check.Port { if fatal { @@ -1372,6 +1363,9 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti ok = false break } + + // Try standard port 53 first for better compatibility + // This is the most common DNS port and has the highest chance of working if tryAllPort53 { tryAllPort53 = false if check.IP { @@ -1385,6 +1379,9 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti } continue } + + // Try localhost as fallback for security and compatibility + // Localhost is often available even when other addresses are blocked if tryLocalhost { tryLocalhost = false if check.IP { @@ -1398,36 +1395,15 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti } continue } - if tryOldIPPort5354 { - tryOldIPPort5354 = false - if check.IP { - listener.IP = oldIP - } - if check.Port { - listener.Port = 5354 - } - logMsg(il.Info(), n, "could not listen on address: %s, trying current ip with port 5354", addr) - continue - } - if tryPort5354 { - tryPort5354 = false - if check.IP { - listener.IP = "0.0.0.0" - } - if check.Port { - listener.Port = 
5354 - } - logMsg(il.Info(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr) - continue - } + + // Try random IP/port combinations as last resort + // This ensures the service can start even in constrained environments if check.IP && !isZeroIP { // for "0.0.0.0" or "::", we only need to try new port. listener.IP = randomLocalIP() } else { listener.IP = oldIP } - // if we are not running on a router, we should not try to listen on any port other than 53 - // if we do, this will break the dns resolution for the system. - if check.Port && !notRouter { + if check.Port { listener.Port = randomPort() } else { listener.Port = oldPort @@ -1449,6 +1425,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti } // Specific case for systemd-resolved. + // systemd-resolved has specific requirements for DNS forwarding that we must handle if useSystemdResolved { if listener := cfg.FirstListener(); listener != nil && listener.Port == 53 { n := listeners[0] @@ -1482,6 +1459,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti return } +// dirWritable checks if a directory is writable func dirWritable(dir string) (bool, error) { f, err := os.CreateTemp(dir, "") if err != nil { @@ -1491,6 +1469,7 @@ func dirWritable(dir string) (bool, error) { return true, f.Close() } +// osVersion returns the operating system version func osVersion() string { oi := osinfo.New() if runtime.GOOS == "freebsd" { @@ -1513,13 +1492,14 @@ func cdUIDFromProvToken() string { } // Validate custom hostname if provided. if customHostname != "" && !validHostname(customHostname) { - mainLog.Load().Fatal().Msgf("invalid custom hostname: %q", customHostname) + mainLog.Load().Fatal().Msgf("Invalid custom hostname: %q", customHostname) } req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname} // Process provision token if provided. 
- resolverConfig, err := controld.FetchResolverUID(req, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + resolverConfig, err := controld.FetchResolverUID(loggerCtx, req, appVersion, cdDev) if err != nil { - mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg) + mainLog.Load().Fatal().Err(err).Msgf("Failed to fetch resolver uid with provision token: %s", cdOrg) } return resolverConfig.UID } @@ -1631,6 +1611,7 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { } } +// validateCdUpstreamProtocol validates the Control D upstream protocol func validateCdUpstreamProtocol() { if cdUID == "" { return @@ -1638,10 +1619,11 @@ func validateCdUpstreamProtocol() { switch cdUpstreamProto { case ctrld.ResolverTypeDOH, ctrld.ResolverTypeDOH3: default: - mainLog.Load().Fatal().Msg(`flag "--protocol" must be "doh" or "doh3"`) + mainLog.Load().Fatal().Msg(`Flag "--protocol" must be "doh" or "doh3"`) } } +// validateCdAndNextDNSFlags validates that Control D and NextDNS flags are not used together func validateCdAndNextDNSFlags() { if (cdUID != "" || cdOrg != "") && nextdns != "" { mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName) @@ -1682,6 +1664,7 @@ func doGenerateNextDNSConfig(uid string) error { return writeConfigFile(&cfg) } +// noticeWritingControlDConfig logs on notice level that a Control D config is being written func noticeWritingControlDConfig() error { if cdUID != "" { mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile) @@ -1703,7 +1686,7 @@ func checkDeactivationPin(s service.Service, stopCh chan struct{}) error { mainLog.Load().Debug().Msg("Checking deactivation pin") dir, err := socketDir() if err != nil { - mainLog.Load().Err(err).Msg("could not check deactivation pin") + mainLog.Load().Err(err).Msg("Could not check deactivation pin") return err } 
mainLog.Load().Debug().Msg("Creating control client") @@ -1762,29 +1745,13 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M return c.ExchangeContext(ctx, msg, addr) } -// absHomeDir returns the absolute path to given filename using home directory as root dir. -func absHomeDir(filename string) string { - if homedir != "" { - return filepath.Join(homedir, filename) - } - dir, err := userHomeDir() - if err != nil { - return filename - } - return filepath.Join(dir, filename) -} - -// runInCdMode reports whether ctrld service is running in cd mode. -func runInCdMode() bool { - return curCdUID() != "" -} - // curCdUID returns the current ControlD UID used by running ctrld process. func curCdUID() string { - if s, _ := newService(&prog{}, svcConfig); s != nil { + svcCmd := NewServiceCommand() + if s, _, _ := svcCmd.initializeServiceManager(); s != nil { // Configure Windows service failure actions if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + mainLog.Load().Debug().Err(err).Msgf("Failed to configure windows service %s failure actions", ctrldServiceName) } if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(context.TODO(), s, dir) @@ -1819,7 +1786,7 @@ func goArm() string { // upgradeUrl returns the url for downloading new ctrld binary. func upgradeUrl(baseUrl string) string { - dlPath := fmt.Sprintf("%s-%s/ctrld", runtime.GOOS, runtime.GOARCH) + dlPath := fmt.Sprintf("v2/%s-%s/ctrld", runtime.GOOS, runtime.GOARCH) // Use arm version set during build time, v5 binary can be run on higher arm version system. 
if armVersion := goArm(); armVersion != "" { dlPath = fmt.Sprintf("%s-%sv%s/ctrld", runtime.GOOS, runtime.GOARCH, armVersion) @@ -1856,13 +1823,14 @@ func runningIface(s service.Service) *ifaceResponse { // doValidateCdRemoteConfig fetches and validates custom config for cdUID. func doValidateCdRemoteConfig(cdUID string, fatal bool) error { - rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + loggerCtx := ctrld.LoggerCtx(context.Background(), mainLog.Load()) + rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev) if err != nil { logger := mainLog.Load().Fatal() if !fatal { logger = mainLog.Load().Warn() } - logger.Err(err).Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) + logger.Err(err).Err(err).Msgf("Failed to fetch resolver uid: %s", cdUID) if !fatal { return err } @@ -1891,32 +1859,33 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) error { if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil { if de := decoderErrorFromTomlFile(tmpConfFile); de != nil { row, col := de.Position() - mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error()) + mainLog.Load().Error().Msgf("Failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error()) errorLogged = true } _ = os.Remove(tmpConfFile) } // If we could not log details error, emit what we have already got. 
if !errorLogged { - mainLog.Load().Error().Msgf("failed to parse custom config: %v", cfgErr) + mainLog.Load().Error().Msgf("Failed to parse custom config: %v", cfgErr) } } } else { - mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err) + mainLog.Load().Error().Msgf("Failed to unmarshal custom config: %v", err) } } if cfgErr != nil { - mainLog.Load().Warn().Msg("disregarding invalid custom config") + mainLog.Load().Warn().Msg("Disregarding invalid custom config") } v = oldV return nil } // uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist. -func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { - s, err := newService(p, svcConfig) +func uninstallInvalidCdUID(p *prog, logger *ctrld.Logger, doStop bool) bool { + svcCmd := NewServiceCommand() + s, _, err := svcCmd.initializeServiceManager() if err != nil { - logger.Warn().Err(err).Msg("failed to create new service") + logger.Warn().Err(err).Msg("Failed to create new service") return false } // restore static DNS settings or DHCP @@ -1924,7 +1893,7 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { tasks := []task{{s.Uninstall, true, "Uninstall"}} if doTasks(tasks) { - logger.Info().Msg("uninstalled service") + logger.Info().Msg("Uninstalled service") if doStop { _ = s.Stop() } diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go deleted file mode 100644 index a1074f29..00000000 --- a/cmd/cli/commands.go +++ /dev/null @@ -1,1397 +0,0 @@ -package cli - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "net/http" - "os" - "os/exec" - "path/filepath" - "runtime" - "slices" - "sort" - "strconv" - "strings" - "time" - - "github.com/docker/go-units" - "github.com/kardianos/service" - "github.com/minio/selfupdate" - "github.com/olekukonko/tablewriter" - "github.com/spf13/cobra" - "github.com/spf13/pflag" - - "github.com/Control-D-Inc/ctrld" - 
"github.com/Control-D-Inc/ctrld/internal/clientinfo" - "github.com/Control-D-Inc/ctrld/internal/router" -) - -// dialSocketControlServerTimeout is the default timeout to wait when ping control server. -const dialSocketControlServerTimeout = 30 * time.Second - -func initLogCmd() *cobra.Command { - warnRuntimeLoggingNotEnabled := func() { - mainLog.Load().Warn().Msg("runtime debug logging is not enabled") - mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`) - } - logSendCmd := &cobra.Command{ - Use: "send", - Short: "Send runtime debug logs to ControlD", - Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - - p := &prog{router: router.New(&cfg, false)} - s, _ := newService(p, svcConfig) - - status, err := s.Status() - if errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - if status == service.StatusStopped { - mainLog.Load().Warn().Msg("service is not running") - return - } - - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(sendLogsPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to send logs") - } - defer resp.Body.Close() - switch resp.StatusCode { - case http.StatusServiceUnavailable: - mainLog.Load().Warn().Msg("runtime logs could only be sent once per minute") - return - case http.StatusMovedPermanently: - warnRuntimeLoggingNotEnabled() - return - } - var logs logSentResponse - if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to decode sent logs result") - } - size := units.BytesSize(float64(logs.Size)) - if logs.Error == "" { - mainLog.Load().Notice().Msgf("runtime logs sent successfully (%s)", size) - } else 
{ - mainLog.Load().Error().Msgf("failed to send logs (%s)", size) - mainLog.Load().Error().Msg(logs.Error) - } - }, - } - logViewCmd := &cobra.Command{ - Use: "view", - Short: "View current runtime debug logs", - Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - - p := &prog{router: router.New(&cfg, false)} - s, _ := newService(p, svcConfig) - - status, err := s.Status() - if errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - if status == service.StatusStopped { - mainLog.Load().Warn().Msg("service is not running") - return - } - - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(viewLogsPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to get logs") - } - defer resp.Body.Close() - - switch resp.StatusCode { - case http.StatusMovedPermanently: - warnRuntimeLoggingNotEnabled() - return - case http.StatusBadRequest: - mainLog.Load().Warn().Msg("runtime debugs log is not available") - buf, err := io.ReadAll(resp.Body) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to read response body") - } - mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf)) - return - case http.StatusOK: - } - var logs logViewResponse - if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to decode view logs result") - } - fmt.Println(logs.Data) - }, - } - logCmd := &cobra.Command{ - Use: "log", - Short: "Manage runtime debug logs", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - logSendCmd.Use, - }, - } - logCmd.AddCommand(logSendCmd) - logCmd.AddCommand(logViewCmd) - rootCmd.AddCommand(logCmd) - - return logCmd -} - -func initRunCmd() *cobra.Command 
{ - runCmd := &cobra.Command{ - Use: "run", - Short: "Run the DNS proxy server", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - RunCobraCommand(cmd) - }, - } - runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon") - runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") - runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") - runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") - runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") - runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") - runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") - runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") - runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") - runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") - runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") - runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") - _ = runCmd.Flags().MarkHidden("dev") - runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") - _ = runCmd.Flags().MarkHidden("homedir") - runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) - _ = runCmd.Flags().MarkHidden("iface") - runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) - runCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener") - - 
runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true} - rootCmd.AddCommand(runCmd) - - return runCmd -} - -func initStartCmd() *cobra.Command { - startCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "start", - Short: "Install and start the ctrld service", - Long: `Install and start the ctrld service - -NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, - Args: func(cmd *cobra.Command, args []string) error { - args = filterEmptyStrings(args) - if len(args) > 0 { - return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + - "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - checkStrFlagEmpty(cmd, cdUidFlagName) - checkStrFlagEmpty(cmd, cdOrgFlagName) - validateCdAndNextDNSFlags() - sc := &service.Config{} - *sc = *svcConfig - osArgs := os.Args[2:] - osArgs = filterEmptyStrings(osArgs) - if os.Args[1] == "service" { - osArgs = os.Args[3:] - } - setDependencies(sc) - sc.Arguments = append([]string{"run"}, osArgs...) - - p := &prog{ - router: router.New(&cfg, cdUID != ""), - cfg: &cfg, - } - s, err := newService(p, sc) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - p.preRun() - - status, err := s.Status() - isCtrldRunning := status == service.StatusRunning - isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) - - // Get current running iface, if any. - var currentIface *ifaceResponse - - // If pin code was set, do not allow running start command. 
- if isCtrldRunning { - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - currentIface = runningIface(s) - mainLog.Load().Debug().Msgf("current interface on start: %v", currentIface) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - reportSetDnsOk := func(sockDir string) { - if cc := newSocketControlClient(ctx, s, sockDir); cc != nil { - if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { - if iface == "auto" { - iface = defaultIfaceName() - } - res := &ifaceResponse{} - if err := json.NewDecoder(resp.Body).Decode(res); err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get iface info") - return - } - if res.OK { - name := res.Name - if iff, err := net.InterfaceByName(name); err == nil { - _, _ = patchNetIfaceName(iff) - name = iff.Name - } - logger := mainLog.Load().With().Str("iface", name).Logger() - logger.Debug().Msg("setting DNS successfully") - if res.All { - // Log that DNS is set for other interfaces. - withEachPhysicalInterfaces( - name, - "set DNS", - func(i *net.Interface) error { return nil }, - ) - } - } - } - } - } - - // No config path, generating config in HOME directory. - noConfigStart := isNoConfigStart(cmd) - writeDefaultConfig := !noConfigStart && configBase64 == "" - - logServerStarted := make(chan struct{}) - // A buffer channel to gather log output from runCmd and report - // to user in case self-check process failed. 
- runCmdLogCh := make(chan string, 256) - ud, err := userHomeDir() - sockDir := ud - if err != nil { - mainLog.Load().Warn().Msg("log server did not start") - close(logServerStarted) - } else { - setWorkingDirectory(sc, ud) - if configPath == "" && writeDefaultConfig { - defaultConfigFile = filepath.Join(ud, defaultConfigFile) - } - sc.Arguments = append(sc.Arguments, "--homedir="+ud) - if d, err := socketDir(); err == nil { - sockDir = d - } - sockPath := filepath.Join(sockDir, ctrldLogUnixSock) - _ = os.Remove(sockPath) - go func() { - defer func() { - close(runCmdLogCh) - _ = os.Remove(sockPath) - }() - close(logServerStarted) - if conn := runLogServer(sockPath); conn != nil { - // Enough buffer for log message, we don't produce - // such long log message, but just in case. - buf := make([]byte, 1024) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - msg := string(buf[:n]) - if _, _, found := strings.Cut(msg, msgExit); found { - cancel() - } - runCmdLogCh <- msg - } - } - }() - } - <-logServerStarted - - if !startOnly { - startOnly = len(osArgs) == 0 - } - // If user run "ctrld start" and ctrld is already installed, starting existing service. - if startOnly && isCtrldInstalled { - tryReadingConfigWithNotice(false, true) - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) - } - - // if already running, dont restart - if isCtrldRunning { - mainLog.Load().Notice().Msg("service is already running") - return - } - - initInteractiveLogging() - tasks := []task{ - {func() error { - // Save current DNS so we can restore later. 
- withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { - if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { - return err - } - return nil - }) - return nil - }, false, "Save current DNS"}, - {func() error { - return ConfigureWindowsServiceFailureActions(ctrldServiceName) - }, false, "Configure service failure actions"}, - {s.Start, true, "Start"}, - {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, - } - mainLog.Load().Notice().Msg("Starting existing ctrld service") - if doTasks(tasks) { - mainLog.Load().Notice().Msg("Service started") - sockDir, err := socketDir() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory") - os.Exit(1) - } - reportSetDnsOk(sockDir) - } else { - mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service") - os.Exit(1) - } - return - } - - if cdUID != "" { - _ = doValidateCdRemoteConfig(cdUID, true) - } else if uid := cdUIDFromProvToken(); uid != "" { - cdUID = uid - mainLog.Load().Debug().Msg("using uid from provision token") - removeOrgFlagsFromArgs(sc) - // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. 
- sc.Arguments = append(sc.Arguments, "--cd="+cdUID) - } - if cdUID != "" { - validateCdUpstreamProtocol() - } - - if err := p.router.ConfigureService(sc); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to configure service on router") - } - - if configPath != "" { - v.SetConfigFile(configPath) - } - - tryReadingConfigWithNotice(writeDefaultConfig, true) - - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) - } - - initInteractiveLogging() - - if nextdns != "" { - removeNextDNSFromArgs(sc) - } - - // Explicitly passing config, so on system where home directory could not be obtained, - // or sub-process env is different with the parent, we still behave correctly and use - // the expected config file. - if configPath == "" { - sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile) - } - - if router.Name() != "" && iface != "" { - mainLog.Load().Debug().Msg("cleaning up router before installing") - _ = p.router.Cleanup() - } - - tasks := []task{ - {s.Stop, false, "Stop"}, - {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"}, - {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, - //resetDnsTask(p, s, isCtrldInstalled, currentIface), - {func() error { - // Save current DNS so we can restore later. 
- withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { - if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { - return err - } - return nil - }) - return nil - }, false, "Save current DNS"}, - {s.Install, false, "Install"}, - {func() error { - return ConfigureWindowsServiceFailureActions(ctrldServiceName) - }, false, "Configure Windows service failure actions"}, - {s.Start, true, "Start"}, - // Note that startCmd do not actually write ControlD config, but the config file was - // generated after s.Start, so we notice users here for consistent with nextdns mode. - {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, - } - mainLog.Load().Notice().Msg("Starting service") - if doTasks(tasks) { - if err := p.router.Install(sc); err != nil { - mainLog.Load().Warn().Err(err).Msg("post installation failed, please check system/service log for details error") - return - } - - // add a small delay to ensure the service is started and did not crash - time.Sleep(1 * time.Second) - - ok, status, err := selfCheckStatus(ctx, s, sockDir) - switch { - case ok && status == service.StatusRunning: - mainLog.Load().Notice().Msg("Service started") - default: - marker := bytes.Repeat([]byte("="), 32) - // If ctrld service is not running, emitting log obtained from ctrld process. - if status != service.StatusRunning || ctx.Err() != nil { - mainLog.Load().Error().Msg("ctrld service may not have started due to an error or misconfiguration, service log:") - _, _ = mainLog.Load().Write(marker) - haveLog := false - for msg := range runCmdLogCh { - _, _ = mainLog.Load().Write([]byte(strings.ReplaceAll(msg, msgExit, ""))) - haveLog = true - } - // If we're unable to get log from "ctrld run", notice users about it. - if !haveLog { - mainLog.Load().Write([]byte(`"`)) - } - } - // Report any error if occurred. 
- if err != nil { - _, _ = mainLog.Load().Write(marker) - msg := fmt.Sprintf("An error occurred while performing test query: %s", err) - mainLog.Load().Write([]byte(msg)) - } - // If ctrld service is running but selfCheckStatus failed, it could be related - // to user's system firewall configuration, notice users about it. - if status == service.StatusRunning && err == nil { - _, _ = mainLog.Load().Write(marker) - mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) - mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) - } - - _, _ = mainLog.Load().Write(marker) - uninstall(p, s) - os.Exit(1) - } - reportSetDnsOk(sockDir) - } - }, - } - // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". - startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") - startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") - startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") - startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") - startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") - startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") - startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") - startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") - startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") - startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") - startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, 
"Use Control D dev resolver/domain") - _ = startCmd.Flags().MarkHidden("dev") - startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) - startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") - startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) - startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) - startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service") - _ = startCmd.Flags().MarkHidden("start_only") - startCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener") - - routerCmd := &cobra.Command{ - Use: "setup", - Run: func(cmd *cobra.Command, _ []string) { - exe, err := os.Executable() - if err != nil { - mainLog.Load().Fatal().Msgf("could not find executable path: %v", err) - os.Exit(1) - } - flags := make([]string, 0) - cmd.Flags().Visit(func(flag *pflag.Flag) { - flags = append(flags, fmt.Sprintf("--%s=%s", flag.Name, flag.Value)) - }) - cmdArgs := []string{"start"} - cmdArgs = append(cmdArgs, flags...) - command := exec.Command(exe, cmdArgs...) 
- command.Stdout = os.Stdout - command.Stderr = os.Stderr - command.Stdin = os.Stdin - if err := command.Run(); err != nil { - mainLog.Load().Fatal().Msg(err.Error()) - } - }, - } - routerCmd.Flags().AddFlagSet(startCmd.Flags()) - routerCmd.Hidden = true - rootCmd.AddCommand(routerCmd) - - startCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "start", - Short: "Quick start service and configure DNS on interface", - Long: `Quick start service and configure DNS on interface - -NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, - Args: func(cmd *cobra.Command, args []string) error { - args = filterEmptyStrings(args) - if len(args) > 0 { - return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + - "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - if len(os.Args) == 2 { - startOnly = true - } - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - startCmd.Run(cmd, args) - }, - } - startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) - startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) - rootCmd.AddCommand(startCmdAlias) - - return startCmd -} - -func initStopCmd() *cobra.Command { - stopCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "stop", - Short: "Stop the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - p.preRun() - if ir := runningIface(s); ir != nil { - 
p.runningIface = ir.Name - p.requiredMultiNICsConfig = ir.All - } - - initInteractiveLogging() - - status, err := s.Status() - if errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - if status == service.StatusStopped { - mainLog.Load().Warn().Msg("service is already stopped") - return - } - - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - if doTasks([]task{{s.Stop, true, "Stop"}}) { - if router.WaitProcessExited() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - for { - select { - case <-ctx.Done(): - mainLog.Load().Error().Msg("timeout while waiting for service to stop") - return - default: - } - time.Sleep(time.Second) - if status, _ := s.Status(); status == service.StatusStopped { - break - } - } - } - mainLog.Load().Notice().Msg("Service stopped") - } - }, - } - stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) - stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`) - _ = stopCmd.Flags().MarkHidden("pin") - - stopCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "stop", - Short: "Quick stop service and remove DNS from interface", - Run: func(cmd *cobra.Command, args []string) { - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - stopCmd.Run(cmd, args) - }, - } - stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) - rootCmd.AddCommand(stopCmdAlias) - - return stopCmd -} - -func initRestartCmd() *cobra.Command { - restartCmd := &cobra.Command{ - PreRun: func(cmd 
*cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "restart", - Short: "Restart the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - cdUID = curCdUID() - cdMode := cdUID != "" - - p := &prog{router: router.New(&cfg, cdMode)} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - if iface == "" { - iface = "auto" - } - p.preRun() - if ir := runningIface(s); ir != nil { - p.runningIface = ir.Name - p.requiredMultiNICsConfig = ir.All - } - - initInteractiveLogging() - - var validateConfigErr error - if cdMode { - validateConfigErr = doValidateCdRemoteConfig(cdUID, false) - } - - if ir := runningIface(s); ir != nil { - iface = ir.Name - } - - doRestart := func() bool { - tasks := []task{ - {s.Stop, true, "Stop"}, - {func() error { - p.router.Cleanup() - // restore static DNS settings or DHCP - p.resetDNS(false, true) - return nil - }, false, "Cleanup"}, - {func() error { - time.Sleep(time.Second * 1) - return nil - }, false, "Waiting for service to stop"}, - } - if doTasks(tasks) { - - if router.WaitProcessExited() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - loop: - for { - select { - case <-ctx.Done(): - mainLog.Load().Error().Msg("timeout while waiting for service to stop") - break loop - default: - } - time.Sleep(time.Second) - if status, _ := s.Status(); status == service.StatusStopped { - break - } - } - } - } else { - return false - } - - tasks = []task{ - {s.Start, true, "Start"}, - } - - return doTasks(tasks) - - } - - if doRestart() { - if dir, err := socketDir(); err == nil { - timeout := dialSocketControlServerTimeout - // If we failed to validate remote config above, it's likely that - // we are having problem 
with network connection. So using a shorter - // timeout than default one for better UX. - if validateConfigErr != nil { - timeout = 5 * time.Second - } - if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil { - _, _ = cc.post(ifacePath, nil) - } else { - mainLog.Load().Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet") - } - } else { - mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") - } - mainLog.Load().Notice().Msg("Service restarted") - } else { - mainLog.Load().Error().Msg("Service restart failed") - } - }, - } - - restartCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "restart", - Short: "Restart the ctrld service", - Run: func(cmd *cobra.Command, args []string) { - restartCmd.Run(cmd, args) - }, - } - rootCmd.AddCommand(restartCmdAlias) - - return restartCmd -} - -func initReloadCmd(restartCmd *cobra.Command) *cobra.Command { - reloadCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "reload", - Short: "Reload the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - - p := &prog{router: router.New(&cfg, false)} - s, _ := newService(p, svcConfig) - - status, err := s.Status() - if errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - if status == service.StatusStopped { - mainLog.Load().Warn().Msg("service is not running") - return - } - - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(reloadPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") - } - defer resp.Body.Close() - switch resp.StatusCode { - 
case http.StatusOK: - mainLog.Load().Notice().Msg("Service reloaded") - case http.StatusCreated: - s, err := newService(&prog{}, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") - mainLog.Load().Warn().Msg("Restarting service") - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("Service not installed") - return - } - restartCmd.Run(cmd, args) - default: - buf, err := io.ReadAll(resp.Body) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") - } - mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) - } - }, - } - - reloadCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "reload", - Short: "Reload the ctrld service", - Run: func(cmd *cobra.Command, args []string) { - reloadCmd.Run(cmd, args) - }, - } - rootCmd.AddCommand(reloadCmdAlias) - - return reloadCmd -} - -func initStatusCmd() *cobra.Command { - statusCmd := &cobra.Command{ - Use: "status", - Short: "Show status of the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - s, err := newService(&prog{}, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - status, err := s.Status() - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - os.Exit(1) - } - switch status { - case service.StatusUnknown: - mainLog.Load().Notice().Msg("Unknown status") - os.Exit(2) - case service.StatusRunning: - mainLog.Load().Notice().Msg("Service is running") - os.Exit(0) - case service.StatusStopped: - mainLog.Load().Notice().Msg("Service is stopped") - os.Exit(1) - } - }, - } - if runtime.GOOS == "darwin" { - // On darwin, running status command without privileges may return wrong information. 
- statusCmd.PreRun = func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - } - } - - statusCmdAlias := &cobra.Command{ - Use: "status", - Short: "Show status of the ctrld service", - Args: cobra.NoArgs, - Run: statusCmd.Run, - } - rootCmd.AddCommand(statusCmdAlias) - - return statusCmd -} - -func initUninstallCmd() *cobra.Command { - uninstallCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "uninstall", - Short: "Stop and uninstall the ctrld service", - Long: `Stop and uninstall the ctrld service. - -NOTE: Uninstalling will set DNS to values provided by DHCP.`, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if iface == "" { - iface = "auto" - } - p.preRun() - if ir := runningIface(s); ir != nil { - p.runningIface = ir.Name - p.requiredMultiNICsConfig = ir.All - } - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - uninstall(p, s) - if cleanup { - var files []string - // Config file. - files = append(files, v.ConfigFileUsed()) - // Log file and backup log file. - // For safety, only process if log file path is absolute. - if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) { - files = append(files, logFile) - oldLogFile := logFile + oldLogSuffix - if _, err := os.Stat(oldLogFile); err == nil { - files = append(files, oldLogFile) - } - } - // Socket files. - if dir, _ := socketDir(); dir != "" { - files = append(files, filepath.Join(dir, ctrldControlUnixSock)) - files = append(files, filepath.Join(dir, ctrldLogUnixSock)) - } - // Static DNS settings files. 
- withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) - if _, err := os.Stat(file); err == nil { - files = append(files, file) - } - return nil - }) - // Windows forwarders file. - if hasLocalDnsServerRunning() { - files = append(files, absHomeDir(windowsForwardersFilename)) - } - // Binary itself. - bin, _ := os.Executable() - if bin != "" && supportedSelfDelete { - files = append(files, bin) - } - // Backup file after upgrading. - oldBin := bin + oldBinSuffix - if _, err := os.Stat(oldBin); err == nil { - files = append(files, oldBin) - } - for _, file := range files { - if file == "" { - continue - } - if err := os.Remove(file); err != nil { - if os.IsNotExist(err) { - continue - } - mainLog.Load().Warn().Err(err).Msgf("failed to remove file: %s", file) - } else { - mainLog.Load().Debug().Msgf("file removed: %s", file) - } - } - if err := selfDeleteExe(); err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to delete ctrld binary") - } else { - if !supportedSelfDelete { - mainLog.Load().Debug().Msgf("file removed: %s", bin) - } - } - } - }, - } - uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`) - uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`) - _ = uninstallCmd.Flags().MarkHidden("pin") - uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`) - - uninstallCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "uninstall", - Short: "Stop and uninstall the ctrld service", - Long: `Stop and uninstall the ctrld service. 
- -NOTE: Uninstalling will set DNS to values provided by DHCP.`, - Run: func(cmd *cobra.Command, args []string) { - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - uninstallCmd.Run(cmd, args) - }, - } - uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags()) - rootCmd.AddCommand(uninstallCmdAlias) - - return uninstallCmd -} - -func initInterfacesCmd() *cobra.Command { - listIfacesCmd := &cobra.Command{ - Use: "list", - Short: "List network interfaces of the host", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error { - fmt.Printf("Index : %d\n", i.Index) - fmt.Printf("Name : %s\n", i.Name) - var status string - if i.Flags&net.FlagUp != 0 { - status = "Up" - } else { - status = "Down" - } - fmt.Printf("Status: %s\n", status) - addrs, _ := i.Addrs() - for i, ipaddr := range addrs { - if i == 0 { - fmt.Printf("Addrs : %v\n", ipaddr) - continue - } - fmt.Printf(" %v\n", ipaddr) - } - nss, err := currentStaticDNS(i) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get DNS") - } - if len(nss) == 0 { - nss = currentDNS(i) - } - for i, dns := range nss { - if i == 0 { - fmt.Printf("DNS : %s\n", dns) - continue - } - fmt.Printf(" : %s\n", dns) - } - println() - return nil - }) - }, - } - interfacesCmd := &cobra.Command{ - Use: "interfaces", - Short: "Manage network interfaces", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - listIfacesCmd.Use, - }, - } - interfacesCmd.AddCommand(listIfacesCmd) - - return interfacesCmd -} - -func initClientsCmd() *cobra.Command { - listClientsCmd := &cobra.Command{ - Use: "list", - Short: "List clients that ctrld discovered", - Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { 
- checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - - p := &prog{router: router.New(&cfg, false)} - s, _ := newService(p, svcConfig) - - status, err := s.Status() - if errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - if status == service.StatusStopped { - mainLog.Load().Warn().Msg("service is not running") - return - } - - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(listClientsPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to get clients list") - } - defer resp.Body.Close() - - var clients []*clientinfo.Client - if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to decode clients list result") - } - map2Slice := func(m map[string]struct{}) []string { - s := make([]string, 0, len(m)) - for k := range m { - if k == "" { // skip empty source from output. - continue - } - s = append(s, k) - } - sort.Strings(s) - return s - } - // If metrics is enabled, server set this for all clients, so we can check only the first one. - // Ideally, we may have a field in response to indicate that query count should be shown, but - // it would break earlier version of ctrld, which only look list of clients in response. 
- withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount - data := make([][]string, len(clients)) - for i, c := range clients { - row := []string{ - c.IP.String(), - c.Hostname, - c.Mac, - strings.Join(map2Slice(c.Source), ","), - } - if withQueryCount { - row = append(row, strconv.FormatInt(c.QueryCount, 10)) - } - data[i] = row - } - table := tablewriter.NewWriter(os.Stdout) - headers := []string{"IP", "Hostname", "Mac", "Discovered"} - if withQueryCount { - headers = append(headers, "Queries") - } - table.SetHeader(headers) - table.SetAutoFormatHeaders(false) - table.AppendBulk(data) - table.Render() - }, - } - clientsCmd := &cobra.Command{ - Use: "clients", - Short: "Manage clients", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - listClientsCmd.Use, - }, - } - clientsCmd.AddCommand(listClientsCmd) - rootCmd.AddCommand(clientsCmd) - - return clientsCmd -} - -func initUpgradeCmd() *cobra.Command { - const ( - upgradeChannelDev = "dev" - upgradeChannelProd = "prod" - upgradeChannelDefault = "default" - ) - upgradeChannel := map[string]string{ - upgradeChannelDefault: "https://dl.controld.dev", - upgradeChannelDev: "https://dl.controld.dev", - upgradeChannelProd: "https://dl.controld.com", - } - if isStableVersion(curVersion()) { - upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd] - } - upgradeCmd := &cobra.Command{ - Use: "upgrade", - Short: "Upgrading ctrld to latest version", - ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, - Args: cobra.MaximumNArgs(1), - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - bin, err := os.Executable() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path") - } - sc := &service.Config{} - *sc = *svcConfig - sc.Executable = bin - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, sc) 
- if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if iface == "" { - iface = "auto" - } - p.preRun() - if ir := runningIface(s); ir != nil { - p.runningIface = ir.Name - p.requiredMultiNICsConfig = ir.All - } - - svcInstalled := true - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - svcInstalled = false - } - oldBin := bin + oldBinSuffix - baseUrl := upgradeChannel[upgradeChannelDefault] - if len(args) > 0 { - channel := args[0] - switch channel { - case upgradeChannelProd, upgradeChannelDev: // ok - default: - mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) - } - baseUrl = upgradeChannel[channel] - } - dlUrl := upgradeUrl(baseUrl) - mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) - - resp, err := getWithRetry(dlUrl, downloadServerIp) - if err != nil { - - mainLog.Load().Fatal().Err(err).Msg("failed to download binary") - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - mainLog.Load().Fatal().Msgf("could not download binary: %s", http.StatusText(resp.StatusCode)) - } - mainLog.Load().Debug().Msg("Updating current binary") - if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil { - if rerr := selfupdate.RollbackError(err); rerr != nil { - mainLog.Load().Error().Err(rerr).Msg("could not rollback old binary") - } - mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") - } - - doRestart := func() bool { - if !svcInstalled { - return true - } - tasks := []task{ - {s.Stop, true, "Stop"}, - {func() error { - p.router.Cleanup() - // restore static DNS settings or DHCP - p.resetDNS(false, true) - return nil - }, false, "Cleanup"}, - {func() error { - time.Sleep(time.Second * 1) - return nil - }, false, "Waiting for service to stop"}, - } - if doTasks(tasks) { - - if router.WaitProcessExited() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer 
cancel() - - loop: - for { - select { - case <-ctx.Done(): - mainLog.Load().Error().Msg("timeout while waiting for service to stop") - break loop - default: - } - time.Sleep(time.Second) - if status, _ := s.Status(); status == service.StatusStopped { - break - } - } - } - } - - tasks = []task{ - {s.Start, true, "Start"}, - } - if doTasks(tasks) { - if dir, err := socketDir(); err == nil { - if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { - _, _ = cc.post(ifacePath, nil) - return true - } - } - } - return false - } - if svcInstalled { - mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") - } - if doRestart() { - _ = os.Remove(oldBin) - _ = os.Chmod(bin, 0755) - ver := "unknown version" - out, err := exec.Command(bin, "--version").CombinedOutput() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version") - } - if after, found := strings.CutPrefix(string(out), "ctrld version "); found { - ver = after - } - mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver) - return - } - - mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin) - if err := os.Remove(bin); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to remove new binary") - } - if err := os.Rename(oldBin, bin); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to restore old binary") - } - if doRestart() { - mainLog.Load().Notice().Msg("Restored previous binary successfully") - return - } - }, - } - rootCmd.AddCommand(upgradeCmd) - - return upgradeCmd -} - -func initServicesCmd(commands ...*cobra.Command) *cobra.Command { - serviceCmd := &cobra.Command{ - Use: "service", - Short: "Manage ctrld service", - Args: cobra.OnlyValidArgs, - } - serviceCmd.ValidArgs = make([]string, len(commands)) - for i, cmd := range commands { - serviceCmd.ValidArgs[i] = cmd.Use - serviceCmd.AddCommand(cmd) - } - rootCmd.AddCommand(serviceCmd) - - return serviceCmd -} - -// filterEmptyStrings removes empty 
strings from a slice of strings. -// It returns a new slice containing only non-empty strings. -func filterEmptyStrings(slice []string) []string { - return slices.DeleteFunc(slice, func(s string) bool { - return s == "" - }) -} diff --git a/cmd/cli/commands_clients.go b/cmd/cli/commands_clients.go new file mode 100644 index 00000000..9f577758 --- /dev/null +++ b/cmd/cli/commands_clients.go @@ -0,0 +1,141 @@ +package cli + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + + "github.com/kardianos/service" + "github.com/olekukonko/tablewriter" + "github.com/spf13/cobra" + + "github.com/Control-D-Inc/ctrld/internal/clientinfo" +) + +// ClientsCommand handles clients-related operations +type ClientsCommand struct { + controlClient *controlClient +} + +// NewClientsCommand creates a new clients command handler +func NewClientsCommand() (*ClientsCommand, error) { + dir, err := socketDir() + if err != nil { + return nil, fmt.Errorf("failed to find ctrld home dir: %w", err) + } + + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + return &ClientsCommand{ + controlClient: cc, + }, nil +} + +// ListClients lists all connected clients +func (cc *ClientsCommand) ListClients(cmd *cobra.Command, args []string) error { + // Check service status first + sc := NewServiceCommand() + s, _, err := sc.initializeServiceManager() + if err != nil { + return err + } + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return nil + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("Service is not running") + return nil + } + + resp, err := cc.controlClient.post(listClientsPath, nil) + if err != nil { + return fmt.Errorf("failed to get clients: %w", err) + } + defer resp.Body.Close() + + var clients []*clientinfo.Client + if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil { + return fmt.Errorf("failed to 
decode clients result: %w", err) + } + + map2Slice := func(m map[string]struct{}) []string { + s := make([]string, 0, len(m)) + for k := range m { + if k == "" { // skip empty source from output. + continue + } + s = append(s, k) + } + sort.Strings(s) + return s + } + + // If metrics is enabled, server set this for all clients, so we can check only the first one. + // Ideally, we may have a field in response to indicate that query count should be shown, but + // it would break earlier version of ctrld, which only look list of clients in response. + withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount + data := make([][]string, len(clients)) + for i, c := range clients { + row := []string{ + c.IP.String(), + c.Hostname, + c.Mac, + strings.Join(map2Slice(c.Source), ","), + } + if withQueryCount { + row = append(row, strconv.FormatInt(c.QueryCount, 10)) + } + data[i] = row + } + + table := tablewriter.NewWriter(os.Stdout) + headers := []string{"IP", "Hostname", "Mac", "Discovered"} + if withQueryCount { + headers = append(headers, "Queries") + } + table.SetHeader(headers) + table.SetAutoFormatHeaders(false) + table.AppendBulk(data) + table.Render() + + return nil +} + +// InitClientsCmd creates the clients command with proper logic +func InitClientsCmd(rootCmd *cobra.Command) *cobra.Command { + listClientsCmd := &cobra.Command{ + Use: "list", + Short: "List clients that ctrld discovered", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: func(cmd *cobra.Command, args []string) error { + cc, err := NewClientsCommand() + if err != nil { + return err + } + return cc.ListClients(cmd, args) + }, + } + + clientsCmd := &cobra.Command{ + Use: "clients", + Short: "Manage clients", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + listClientsCmd.Use, + }, + } + clientsCmd.AddCommand(listClientsCmd) + rootCmd.AddCommand(clientsCmd) + + return clientsCmd +} diff --git 
a/cmd/cli/commands_interfaces.go b/cmd/cli/commands_interfaces.go new file mode 100644 index 00000000..e4565725 --- /dev/null +++ b/cmd/cli/commands_interfaces.go @@ -0,0 +1,87 @@ +package cli + +import ( + "fmt" + "net" + + "github.com/spf13/cobra" +) + +// InterfacesCommand handles interfaces-related operations +type InterfacesCommand struct{} + +// NewInterfacesCommand creates a new interfaces command handler +func NewInterfacesCommand() (*InterfacesCommand, error) { + return &InterfacesCommand{}, nil +} + +// ListInterfaces lists all network interfaces +func (ic *InterfacesCommand) ListInterfaces(cmd *cobra.Command, args []string) error { + withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error { + fmt.Printf("Index : %d\n", i.Index) + fmt.Printf("Name : %s\n", i.Name) + var status string + if i.Flags&net.FlagUp != 0 { + status = "Up" + } else { + status = "Down" + } + fmt.Printf("Status: %s\n", status) + addrs, _ := i.Addrs() + for i, ipaddr := range addrs { + if i == 0 { + fmt.Printf("Addrs : %v\n", ipaddr) + continue + } + fmt.Printf(" %v\n", ipaddr) + } + nss, err := currentStaticDNS(i) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get DNS") + } + if len(nss) == 0 { + nss = currentDNS(i) + } + for i, dns := range nss { + if i == 0 { + fmt.Printf("DNS : %s\n", dns) + continue + } + fmt.Printf(" : %s\n", dns) + } + println() + return nil + }) + return nil +} + +// InitInterfacesCmd creates the interfaces command with proper logic +func InitInterfacesCmd(_ *cobra.Command) *cobra.Command { + listInterfacesCmd := &cobra.Command{ + Use: "list", + Short: "List network interfaces", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: func(cmd *cobra.Command, args []string) error { + ic, err := NewInterfacesCommand() + if err != nil { + return err + } + return ic.ListInterfaces(cmd, args) + }, + } + + interfacesCmd := &cobra.Command{ + Use: "interfaces", + 
Short: "Manage network interfaces", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + listInterfacesCmd.Use, + }, + } + interfacesCmd.AddCommand(listInterfacesCmd) + + return interfacesCmd +} diff --git a/cmd/cli/commands_log.go b/cmd/cli/commands_log.go new file mode 100644 index 00000000..f96306b0 --- /dev/null +++ b/cmd/cli/commands_log.go @@ -0,0 +1,175 @@ +package cli + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "path/filepath" + + "github.com/docker/go-units" + "github.com/kardianos/service" + "github.com/spf13/cobra" +) + +// LogCommand handles log-related operations +type LogCommand struct { + controlClient *controlClient +} + +// NewLogCommand creates a new log command handler +func NewLogCommand() (*LogCommand, error) { + dir, err := socketDir() + if err != nil { + return nil, fmt.Errorf("failed to find ctrld home dir: %w", err) + } + + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + return &LogCommand{ + controlClient: cc, + }, nil +} + +// warnRuntimeLoggingNotEnabled logs a warning about runtime logging not being enabled +func (lc *LogCommand) warnRuntimeLoggingNotEnabled() { + mainLog.Load().Warn().Msg("Runtime debug logging is not enabled") + mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`) +} + +// SendLogs sends runtime debug logs to ControlD +func (lc *LogCommand) SendLogs(cmd *cobra.Command, args []string) error { + sc := NewServiceCommand() + s, _, err := sc.initializeServiceManager() + if err != nil { + return err + } + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return nil + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("Service is not running") + return nil + } + + resp, err := lc.controlClient.post(sendLogsPath, nil) + if err != nil { + return fmt.Errorf("failed to send logs: %w", err) + } + defer resp.Body.Close() + + switch 
resp.StatusCode { + case http.StatusServiceUnavailable: + mainLog.Load().Warn().Msg("Runtime logs could only be sent once per minute") + return nil + case http.StatusMovedPermanently: + lc.warnRuntimeLoggingNotEnabled() + return nil + } + + var logs logSentResponse + if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { + return fmt.Errorf("failed to decode sent logs result: %w", err) + } + + if logs.Error != "" { + return fmt.Errorf("failed to send logs: %s", logs.Error) + } + + mainLog.Load().Notice().Msgf("Sent %s of runtime logs", units.BytesSize(float64(logs.Size))) + return nil +} + +// ViewLogs views current runtime debug logs +func (lc *LogCommand) ViewLogs(cmd *cobra.Command, args []string) error { + sc := NewServiceCommand() + s, _, err := sc.initializeServiceManager() + if err != nil { + return err + } + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return nil + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("Service is not running") + return nil + } + + resp, err := lc.controlClient.post(viewLogsPath, nil) + if err != nil { + return fmt.Errorf("failed to get logs: %w", err) + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusMovedPermanently: + lc.warnRuntimeLoggingNotEnabled() + return nil + case http.StatusBadRequest: + mainLog.Load().Warn().Msg("Runtime debug logs are not available") + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("Failed to read response body") + } + mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf)) + return nil + case http.StatusOK: + } + + var logs logViewResponse + if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { + return fmt.Errorf("failed to decode view logs result: %w", err) + } + + fmt.Print(logs.Data) + return nil +} + +// InitLogCmd creates the log command with proper logic +func InitLogCmd(rootCmd 
*cobra.Command) *cobra.Command { + lc, err := NewLogCommand() + if err != nil { + panic(fmt.Sprintf("failed to create log command: %v", err)) + } + + logSendCmd := &cobra.Command{ + Use: "send", + Short: "Send runtime debug logs to ControlD", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: lc.SendLogs, + } + + logViewCmd := &cobra.Command{ + Use: "view", + Short: "View current runtime debug logs", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: lc.ViewLogs, + } + + logCmd := &cobra.Command{ + Use: "log", + Short: "Manage runtime debug logs", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + logSendCmd.Use, + logViewCmd.Use, + }, + } + logCmd.AddCommand(logSendCmd) + logCmd.AddCommand(logViewCmd) + rootCmd.AddCommand(logCmd) + + return logCmd +} diff --git a/cmd/cli/commands_run.go b/cmd/cli/commands_run.go new file mode 100644 index 00000000..9d3260b4 --- /dev/null +++ b/cmd/cli/commands_run.go @@ -0,0 +1,59 @@ +package cli + +import ( + "github.com/spf13/cobra" + + "github.com/Control-D-Inc/ctrld" +) + +// RunCommand handles run-related operations +type RunCommand struct { + // Add any dependencies here if needed in the future +} + +// NewRunCommand creates a new run command handler +func NewRunCommand() *RunCommand { + return &RunCommand{} +} + +// Run implements the logic for the run command +func (rc *RunCommand) Run(cmd *cobra.Command, args []string) { + RunCobraCommand(cmd) +} + +// InitRunCmd creates the run command with proper logic +func InitRunCmd(rootCmd *cobra.Command) *cobra.Command { + rc := NewRunCommand() + + runCmd := &cobra.Command{ + Use: "run", + Short: "Run the DNS proxy server", + Args: cobra.NoArgs, + Run: rc.Run, + } + runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon") + runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") + 
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") + runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") + runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") + runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") + runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") + runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") + runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") + runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = runCmd.Flags().MarkHidden("dev") + runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") + _ = runCmd.Flags().MarkHidden("homedir") + runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + _ = runCmd.Flags().MarkHidden("iface") + runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + runCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener") + + runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true} + rootCmd.AddCommand(runCmd) + + return runCmd +} diff --git a/cmd/cli/commands_service.go b/cmd/cli/commands_service.go new file mode 100644 index 00000000..eb263081 --- /dev/null +++ b/cmd/cli/commands_service.go @@ -0,0 +1,256 @@ +package cli + +import ( + "fmt" + "os" + 
"runtime" + + "github.com/kardianos/service" + "github.com/spf13/cobra" +) + +// filterEmptyStrings removes empty strings from a slice +// This is used to clean up command line arguments and configuration values +func filterEmptyStrings(slice []string) []string { + var result []string + for _, s := range slice { + if s != "" { + result = append(result, s) + } + } + return result +} + +// ServiceCommand handles service-related operations +// This encapsulates all service management functionality for the CLI +type ServiceCommand struct { + serviceManager *ServiceManager +} + +// initializeServiceManager creates a service manager with default configuration +// This sets up the basic service infrastructure needed for all service operations +func (sc *ServiceCommand) initializeServiceManager() (service.Service, *prog, error) { + svcConfig := sc.createServiceConfig() + return sc.initializeServiceManagerWithServiceConfig(svcConfig) +} + +// initializeServiceManagerWithServiceConfig creates a service manager with the given configuration +// This allows for custom service configuration while maintaining the same initialization pattern +func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *service.Config) (service.Service, *prog, error) { + p := &prog{} + + s, err := sc.newService(p, svcConfig) + if err != nil { + return nil, nil, fmt.Errorf("failed to create service: %w", err) + } + + sc.serviceManager = &ServiceManager{prog: p, svc: s} + return s, p, nil +} + +// newService creates a new service instance using the provided program and configuration. 
+// This abstracts the service creation process for different operating systems +func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (service.Service, error) { + s, err := newService(p, svcConfig) + if err != nil { + return nil, fmt.Errorf("failed to create service: %w", err) + } + return s, nil +} + +// NewServiceCommand creates a new service command handler +// This provides a clean factory method for creating service command instances +func NewServiceCommand() *ServiceCommand { + return &ServiceCommand{} +} + +// createServiceConfig creates a properly initialized service configuration +// This ensures consistent service naming and description across all platforms +func (sc *ServiceCommand) createServiceConfig() *service.Config { + return &service.Config{ + Name: ctrldServiceName, + DisplayName: "Control-D Helper Service", + Description: "A highly configurable, multi-protocol DNS forwarding proxy", + Option: service.KeyValue{}, + } +} + +// InitServiceCmd creates the service command with proper logic and aliases +// This sets up all service-related subcommands with appropriate permissions and flags +func InitServiceCmd(rootCmd *cobra.Command) *cobra.Command { + // Create service command handlers + sc := NewServiceCommand() + + startCmd, startCmdAlias := createStartCommands(sc) + rootCmd.AddCommand(startCmdAlias) + + // Stop command + stopCmd := &cobra.Command{ + Use: "stop", + Short: "Stop the ctrld service", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: sc.Stop, + } + stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) + stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`) + _ = stopCmd.Flags().MarkHidden("pin") + + // Restart command + restartCmd := &cobra.Command{ + Use: "restart", + Short: "Restart the ctrld service", + Args: cobra.NoArgs, 
+ PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: sc.Restart, + } + + // Status command + statusCmd := &cobra.Command{ + Use: "status", + Short: "Show status of the ctrld service", + Args: cobra.NoArgs, + RunE: sc.Status, + } + if runtime.GOOS == "darwin" { + // On darwin, running status command without privileges may return wrong information. + statusCmd.PreRun = func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + } + } + + // Reload command + reloadCmd := &cobra.Command{ + Use: "reload", + Short: "Reload the ctrld service", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: sc.Reload, + } + + // Uninstall command + uninstallCmd := &cobra.Command{ + Use: "uninstall", + Short: "Stop and uninstall the ctrld service", + Long: `Stop and uninstall the ctrld service. + +NOTE: Uninstalling will set DNS to values provided by DHCP.`, + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: sc.Uninstall, + } + uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) + uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`) + _ = uninstallCmd.Flags().MarkHidden("pin") + uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`) + + // Interfaces command - use the existing InitInterfacesCmd function + interfacesCmd := InitInterfacesCmd(rootCmd) + + stopCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "stop", + Short: "Quick stop service and remove DNS from interface", + RunE: func(cmd *cobra.Command, args []string) error { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = 
ifaceStartStop + return stopCmd.RunE(cmd, args) + }, + } + stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) + stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) + rootCmd.AddCommand(stopCmdAlias) + + // Create aliases for other service commands + restartCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "restart", + Short: "Restart the ctrld service", + RunE: func(cmd *cobra.Command, args []string) error { + return restartCmd.RunE(cmd, args) + }, + } + rootCmd.AddCommand(restartCmdAlias) + + reloadCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + RunE: func(cmd *cobra.Command, args []string) error { + return reloadCmd.RunE(cmd, args) + }, + } + rootCmd.AddCommand(reloadCmdAlias) + + statusCmdAlias := &cobra.Command{ + Use: "status", + Short: "Show status of the ctrld service", + Args: cobra.NoArgs, + RunE: statusCmd.RunE, + } + rootCmd.AddCommand(statusCmdAlias) + + uninstallCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "uninstall", + Short: "Stop and uninstall the ctrld service", + Long: `Stop and uninstall the ctrld service. 
+ +NOTE: Uninstalling will set DNS to values provided by DHCP.`, + RunE: func(cmd *cobra.Command, args []string) error { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + return uninstallCmd.RunE(cmd, args) + }, + } + uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) + uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags()) + rootCmd.AddCommand(uninstallCmdAlias) + + // Create service command + serviceCmd := &cobra.Command{ + Use: "service", + Short: "Manage ctrld service", + Args: cobra.OnlyValidArgs, + } + serviceCmd.ValidArgs = make([]string, 7) + serviceCmd.ValidArgs[0] = startCmd.Use + serviceCmd.ValidArgs[1] = stopCmd.Use + serviceCmd.ValidArgs[2] = restartCmd.Use + serviceCmd.ValidArgs[3] = reloadCmd.Use + serviceCmd.ValidArgs[4] = statusCmd.Use + serviceCmd.ValidArgs[5] = uninstallCmd.Use + serviceCmd.ValidArgs[6] = interfacesCmd.Use + + serviceCmd.AddCommand(startCmd) + serviceCmd.AddCommand(stopCmd) + serviceCmd.AddCommand(restartCmd) + serviceCmd.AddCommand(reloadCmd) + serviceCmd.AddCommand(statusCmd) + serviceCmd.AddCommand(uninstallCmd) + serviceCmd.AddCommand(interfacesCmd) + + rootCmd.AddCommand(serviceCmd) + + return serviceCmd +} diff --git a/cmd/cli/commands_service_manager.go b/cmd/cli/commands_service_manager.go new file mode 100644 index 00000000..2b35e8eb --- /dev/null +++ b/cmd/cli/commands_service_manager.go @@ -0,0 +1,41 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/kardianos/service" +) + +// dialSocketControlServerTimeout is the default timeout to wait when ping control server. 
+const dialSocketControlServerTimeout = 30 * time.Second + +// ServiceManager handles service operations +type ServiceManager struct { + prog *prog + svc service.Service +} + +// NewServiceManager creates a new service manager +func NewServiceManager() (*ServiceManager, error) { + p := &prog{} + + // Create a proper service configuration + svcConfig := &service.Config{ + Name: ctrldServiceName, + DisplayName: "Control-D Helper Service", + Description: "A highly configurable, multi-protocol DNS forwarding proxy", + Option: service.KeyValue{}, + } + + s, err := newService(p, svcConfig) + if err != nil { + return nil, fmt.Errorf("failed to create service: %w", err) + } + return &ServiceManager{prog: p, svc: s}, nil +} + +// Status returns the current service status +func (sm *ServiceManager) Status() (service.Status, error) { + return sm.svc.Status() +} diff --git a/cmd/cli/commands_service_reload.go b/cmd/cli/commands_service_reload.go new file mode 100644 index 00000000..5ddf4ff6 --- /dev/null +++ b/cmd/cli/commands_service_reload.go @@ -0,0 +1,67 @@ +package cli + +import ( + "errors" + "io" + "net/http" + "path/filepath" + + "github.com/kardianos/service" + "github.com/spf13/cobra" +) + +// Reload implements the logic from cmdReload.Run +func (sc *ServiceCommand) Reload(cmd *cobra.Command, args []string) error { + logger := mainLog.Load() + logger.Debug().Msg("Service reload command started") + + s, _, err := sc.initializeServiceManager() + if err != nil { + logger.Error().Err(err).Msg("Failed to initialize service manager") + return err + } + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + logger.Warn().Msg("Service not installed") + return nil + } + if status == service.StatusStopped { + logger.Warn().Msg("Service is not running") + return nil + } + + dir, err := socketDir() + if err != nil { + logger.Fatal().Err(err).Msg("Failed to find ctrld home dir") + } + + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, 
err := cc.post(reloadPath, nil) + if err != nil { + logger.Fatal().Err(err).Msg("Failed to send reload signal to ctrld") + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + logger.Notice().Msg("Service reloaded") + case http.StatusCreated: + logger.Warn().Msg("Service was reloaded, but new config requires service restart.") + logger.Warn().Msg("Restarting service") + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + logger.Warn().Msg("Service not installed") + return nil + } + return sc.Restart(cmd, args) + default: + buf, err := io.ReadAll(resp.Body) + if err != nil { + logger.Fatal().Err(err).Msg("Could not read response from control server") + } + logger.Error().Err(err).Msgf("Failed to reload ctrld: %s", string(buf)) + } + + logger.Debug().Msg("Service reload command completed") + return nil +} diff --git a/cmd/cli/commands_service_restart.go b/cmd/cli/commands_service_restart.go new file mode 100644 index 00000000..02e5a69b --- /dev/null +++ b/cmd/cli/commands_service_restart.go @@ -0,0 +1,111 @@ +package cli + +import ( + "context" + "errors" + "time" + + "github.com/kardianos/service" + "github.com/spf13/cobra" +) + +// Restart implements the logic from cmdRestart.Run +func (sc *ServiceCommand) Restart(cmd *cobra.Command, args []string) error { + logger := mainLog.Load() + logger.Debug().Msg("Service restart command started") + + readConfig(false) + v.Unmarshal(&cfg) + cdUID = curCdUID() + cdMode := cdUID != "" + + s, p, err := sc.initializeServiceManager() + if err != nil { + logger.Error().Err(err).Msg("Failed to initialize service manager") + return err + } + + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + logger.Warn().Msg("Service not installed") + return nil + } + + p.cfg = &cfg + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + + initInteractiveLogging() + + var 
validateConfigErr error + if cdMode { + logger.Debug().Msg("Validating ControlD remote config") + validateConfigErr = doValidateCdRemoteConfig(cdUID, false) + if validateConfigErr != nil { + logger.Warn().Err(validateConfigErr).Msg("ControlD remote config validation failed") + } + } + + if ir := runningIface(s); ir != nil { + iface = ir.Name + } + + doRestart := func() bool { + logger.Debug().Msg("Starting service restart sequence") + + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + // restore static DNS settings or DHCP + p.resetDNS(false, true) + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + if !doTasks(tasks) { + logger.Error().Msg("Service stop tasks failed") + return false + } + tasks = []task{ + {s.Start, true, "Start"}, + } + success := doTasks(tasks) + if success { + logger.Debug().Msg("Service restart sequence completed successfully") + } else { + logger.Error().Msg("Service restart sequence failed") + } + return success + } + + if doRestart() { + if dir, err := socketDir(); err == nil { + timeout := dialSocketControlServerTimeout + if validateConfigErr != nil { + timeout = 5 * time.Second + } + if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil { + _, _ = cc.post(ifacePath, nil) + logger.Debug().Msg("Control server ping successful") + } else { + logger.Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet") + } + } else { + logger.Warn().Err(err).Msg("Service was restarted, but could not ping the control server") + } + logger.Notice().Msg("Service restarted") + } else { + logger.Error().Msg("Service restart failed") + } + + logger.Debug().Msg("Service restart command completed") + return nil +} diff --git a/cmd/cli/commands_service_start.go b/cmd/cli/commands_service_start.go new file mode 100644 index 00000000..c5430efd --- /dev/null +++ b/cmd/cli/commands_service_start.go @@ 
-0,0 +1,387 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/kardianos/service" + "github.com/spf13/cobra" + + "github.com/Control-D-Inc/ctrld" +) + +// Start implements the logic from cmdStart.Run +func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error { + logger := mainLog.Load() + logger.Debug().Msg("Service start command started") + + checkStrFlagEmpty(cmd, cdUidFlagName) + checkStrFlagEmpty(cmd, cdOrgFlagName) + validateCdAndNextDNSFlags() + + svcConfig := sc.createServiceConfig() + osArgs := os.Args[2:] + osArgs = filterEmptyStrings(osArgs) + if os.Args[1] == "service" { + osArgs = os.Args[3:] + } + setDependencies(svcConfig) + svcConfig.Arguments = append([]string{"run"}, osArgs...) + + // Initialize service manager with proper configuration + s, p, err := sc.initializeServiceManagerWithServiceConfig(svcConfig) + if err != nil { + logger.Error().Err(err).Msg("Failed to initialize service manager") + return err + } + + p.cfg = &cfg + p.preRun() + + status, err := s.Status() + isCtrldRunning := status == service.StatusRunning + isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) + + // Get current running iface, if any. + var currentIface *ifaceResponse + + // If pin code was set, do not allow running start command. 
+ if isCtrldRunning { + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + logger.Error().Msg("Deactivation pin check failed") + os.Exit(deactivationPinInvalidExitCode) + } + currentIface = runningIface(s) + logger.Debug().Msgf("Current interface on start: %v", currentIface) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + reportSetDnsOk := func(sockDir string) { + if cc := newSocketControlClient(ctx, s, sockDir); cc != nil { + if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { + if iface == "auto" { + iface = defaultIfaceName() + } + res := &ifaceResponse{} + if err := json.NewDecoder(resp.Body).Decode(res); err != nil { + logger.Warn().Err(err).Msg("Failed to get iface info") + return + } + if res.OK { + name := res.Name + if iff, err := net.InterfaceByName(name); err == nil { + _, _ = patchNetIfaceName(iff) + name = iff.Name + } + logger := logger.With().Str("iface", name) + logger.Debug().Msg("Setting DNS successfully") + if res.All { + // Log that DNS is set for other interfaces. + withEachPhysicalInterfaces( + name, + "set DNS", + func(i *net.Interface) error { return nil }, + ) + } + } + } + } + } + + // No config path, generating config in HOME directory. 
+ noConfigStart := isNoConfigStart(cmd) + writeDefaultConfig := !noConfigStart && configBase64 == "" + + logServerStarted := make(chan struct{}) + stopLogCh := make(chan struct{}) + ud, err := userHomeDir() + sockDir := ud + var logServerSocketPath string + if err != nil { + logger.Warn().Err(err).Msg("Failed to get user home directory") + logger.Warn().Msg("Log server did not start") + close(logServerStarted) + } else { + setWorkingDirectory(svcConfig, ud) + if configPath == "" && writeDefaultConfig { + defaultConfigFile = filepath.Join(ud, defaultConfigFile) + } + svcConfig.Arguments = append(svcConfig.Arguments, "--homedir="+ud) + if d, err := socketDir(); err == nil { + sockDir = d + } + logServerSocketPath = filepath.Join(sockDir, ctrldLogUnixSock) + _ = os.Remove(logServerSocketPath) + go func() { + defer os.Remove(logServerSocketPath) + + close(logServerStarted) + + // Start HTTP log server + if err := httpLogServer(logServerSocketPath, stopLogCh); err != nil && err != http.ErrServerClosed { + logger.Warn().Err(err).Msg("Failed to serve HTTP log server") + return + } + }() + } + <-logServerStarted + + if !startOnly { + startOnly = len(osArgs) == 0 + } + // If user run "ctrld start" and ctrld is already installed, starting existing service. + if startOnly && isCtrldInstalled { + tryReadingConfigWithNotice(false, true) + if err := v.Unmarshal(&cfg); err != nil { + logger.Fatal().Msgf("Failed to unmarshal config: %v", err) + } + + // if already running, dont restart + if isCtrldRunning { + logger.Notice().Msg("Service is already running") + return nil + } + + initInteractiveLogging() + tasks := []task{ + {func() error { + // Save current DNS so we can restore later. 
+ withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil + }) + return nil + }, false, "Save current DNS"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure service failure actions"}, + {s.Start, true, "Start"}, + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, + } + logger.Notice().Msg("Starting existing ctrld service") + if doTasks(tasks) { + logger.Notice().Msg("Service started") + sockDir, err := socketDir() + if err != nil { + logger.Warn().Err(err).Msg("Failed to get socket directory") + os.Exit(1) + } + reportSetDnsOk(sockDir) + } else { + logger.Error().Err(err).Msg("Failed to start existing ctrld service") + os.Exit(1) + } + return nil + } + + if cdUID != "" { + _ = doValidateCdRemoteConfig(cdUID, true) + } else if uid := cdUIDFromProvToken(); uid != "" { + cdUID = uid + logger.Debug().Msg("Using uid from provision token") + removeOrgFlagsFromArgs(svcConfig) + // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. + svcConfig.Arguments = append(svcConfig.Arguments, "--cd="+cdUID) + } + if cdUID != "" { + validateCdUpstreamProtocol() + } + + if configPath != "" { + v.SetConfigFile(configPath) + } + + tryReadingConfigWithNotice(writeDefaultConfig, true) + + if err := v.Unmarshal(&cfg); err != nil { + logger.Fatal().Msgf("Failed to unmarshal config: %v", err) + } + + initInteractiveLogging() + + if nextdns != "" { + removeNextDNSFromArgs(svcConfig) + } + + // Explicitly passing config, so on system where home directory could not be obtained, + // or sub-process env is different with the parent, we still behave correctly and use + // the expected config file. 
+ if configPath == "" { + svcConfig.Arguments = append(svcConfig.Arguments, "--config="+defaultConfigFile) + } + + tasks := []task{ + {s.Stop, false, "Stop"}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"}, + {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, + //resetDnsTask(p, s, isCtrldInstalled, currentIface), + {func() error { + // Save current DNS so we can restore later. + withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil + }) + return nil + }, false, "Save current DNS"}, + {s.Install, false, "Install"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure Windows service failure actions"}, + {s.Start, true, "Start"}, + // Note that startCmd do not actually write ControlD config, but the config file was + // generated after s.Start, so we notice users here for consistent with nextdns mode. + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, + } + logger.Notice().Msg("Starting service") + if doTasks(tasks) { + // add a small delay to ensure the service is started and did not crash + time.Sleep(1 * time.Second) + + ok, status, err := selfCheckStatus(ctx, s, sockDir) + switch { + case ok && status == service.StatusRunning: + logger.Notice().Msg("Service started") + default: + marker := append(bytes.Repeat([]byte("="), 32), '\n') + // If ctrld service is not running, emitting log obtained from ctrld process. 
+ if status != service.StatusRunning || ctx.Err() != nil { + logger.Error().Msg("Ctrld service may not have started due to an error or misconfiguration, service log:") + _, _ = logger.Write(marker) + + // Wait for log collection to complete + <-stopLogCh + + // Retrieve logs from HTTP server if available + if logServerSocketPath != "" { + hlc := newHTTPLogClient(logServerSocketPath) + logs, err := hlc.GetLogs() + if err != nil { + logger.Warn().Err(err).Msg("Failed to get logs from HTTP log server") + } + if len(logs) == 0 { + logger.Write([]byte("\n")) + } else { + logger.Write(logs) + logger.Write([]byte("\n")) + } + } else { + logger.Write([]byte("\n")) + } + } + // Report any error if occurred. + if err != nil { + _, _ = logger.Write(marker) + msg := fmt.Sprintf("An error occurred while performing test query: %s\n", err) + logger.Write([]byte(msg)) + } + // If ctrld service is running but selfCheckStatus failed, it could be related + // to user's system firewall configuration, notice users about it. 
+ if status == service.StatusRunning && err == nil { + _, _ = logger.Write(marker) + logger.Write([]byte("ctrld service was running, but a DNS query could not be sent to its listener\n")) + logger.Write([]byte("Please check your system firewall if it is configured to block/intercept/redirect DNS queries\n")) + } + + _, _ = logger.Write(marker) + uninstall(p, s) + os.Exit(1) + } + reportSetDnsOk(sockDir) + } + + logger.Debug().Msg("Service start command completed") + return nil +} + +// createStartCommands creates the start command and its alias +func createStartCommands(sc *ServiceCommand) (*cobra.Command, *cobra.Command) { + // Start command + startCmd := &cobra.Command{ + Use: "start", + Short: "Install and start the ctrld service", + Long: `Install and start the ctrld service + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Args: func(cmd *cobra.Command, args []string) error { + args = filterEmptyStrings(args) + if len(args) > 0 { + return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") + } + return nil + }, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: sc.Start, + } + // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". 
+ startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") + startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") + startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") + startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") + startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") + startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") + startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") + startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") + startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = startCmd.Flags().MarkHidden("dev") + startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") + startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) + startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service") + _ = startCmd.Flags().MarkHidden("start_only") + startCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener") + + // Start command 
alias + startCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "start", + Short: "Quick start service and configure DNS on interface", + Long: `Quick start service and configure DNS on interface + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Args: func(cmd *cobra.Command, args []string) error { + args = filterEmptyStrings(args) + if len(args) > 0 { + return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" + + "Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options") + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + if len(os.Args) == 2 { + startOnly = true + } + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + return startCmd.RunE(cmd, args) + }, + } + startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) + + return startCmd, startCmdAlias +} diff --git a/cmd/cli/commands_service_status.go b/cmd/cli/commands_service_status.go new file mode 100644 index 00000000..270e0e06 --- /dev/null +++ b/cmd/cli/commands_service_status.go @@ -0,0 +1,41 @@ +package cli + +import ( + "os" + + "github.com/kardianos/service" + "github.com/spf13/cobra" +) + +// Status implements the logic from cmdStatus.Run +func (sc *ServiceCommand) Status(cmd *cobra.Command, args []string) error { + logger := mainLog.Load() + logger.Debug().Msg("Service status command started") + + s, _, err := sc.initializeServiceManager() + if err != nil { + logger.Error().Err(err).Msg("Failed to initialize service manager") + return err + } + + status, err := s.Status() + if err != nil { + logger.Error().Msg(err.Error()) + os.Exit(1) + } + + switch status { + case service.StatusUnknown: + 
logger.Notice().Msg("Unknown status") + os.Exit(2) + case service.StatusRunning: + logger.Notice().Msg("Service is running") + os.Exit(0) + case service.StatusStopped: + logger.Notice().Msg("Service is stopped") + os.Exit(1) + } + + logger.Debug().Msg("Service status command completed") + return nil +} diff --git a/cmd/cli/commands_service_stop.go b/cmd/cli/commands_service_stop.go new file mode 100644 index 00000000..0f47e462 --- /dev/null +++ b/cmd/cli/commands_service_stop.go @@ -0,0 +1,61 @@ +package cli + +import ( + "errors" + "os" + + "github.com/kardianos/service" + "github.com/spf13/cobra" +) + +// Stop implements the logic from cmdStop.Run +func (sc *ServiceCommand) Stop(cmd *cobra.Command, args []string) error { + logger := mainLog.Load() + logger.Debug().Msg("Service stop command started") + + readConfig(false) + v.Unmarshal(&cfg) + + s, p, err := sc.initializeServiceManager() + if err != nil { + logger.Error().Err(err).Msg("Failed to initialize service manager") + return err + } + + p.cfg = &cfg + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + + initInteractiveLogging() + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + logger.Warn().Msg("Service not installed") + return nil + } + if status == service.StatusStopped { + logger.Warn().Msg("Service is already stopped") + return nil + } + + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + logger.Error().Msg("Deactivation pin check failed") + os.Exit(deactivationPinInvalidExitCode) + } + + logger.Debug().Msg("Stopping service") + if doTasks([]task{{s.Stop, true, "Stop"}}) { + logger.Notice().Msg("Service stopped") + } else { + logger.Error().Msg("Service stop failed") + } + + logger.Debug().Msg("Service stop command completed") + return nil +} diff --git a/cmd/cli/commands_service_uninstall.go b/cmd/cli/commands_service_uninstall.go new file 
mode 100644 index 00000000..78a3d5e1 --- /dev/null +++ b/cmd/cli/commands_service_uninstall.go @@ -0,0 +1,106 @@ +package cli + +import ( + "net" + "os" + "path/filepath" + + "github.com/spf13/cobra" + + "github.com/Control-D-Inc/ctrld" +) + +// Uninstall implements the logic from cmdUninstall.Run +func (sc *ServiceCommand) Uninstall(cmd *cobra.Command, args []string) error { + logger := mainLog.Load() + logger.Debug().Msg("Service uninstall command started") + + readConfig(false) + v.Unmarshal(&cfg) + + s, p, err := sc.initializeServiceManager() + if err != nil { + logger.Error().Err(err).Msg("Failed to initialize service manager") + return err + } + + p.cfg = &cfg + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + logger.Error().Msg("Deactivation pin check failed") + os.Exit(deactivationPinInvalidExitCode) + } + + logger.Debug().Msg("Starting service uninstall") + uninstall(p, s) + + if cleanup { + logger.Debug().Msg("Performing cleanup operations") + var files []string + // Config file. + files = append(files, v.ConfigFileUsed()) + // Log file and backup log file. + // For safety, only process if log file path is absolute. + if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) { + files = append(files, logFile) + oldLogFile := logFile + oldLogSuffix + if _, err := os.Stat(oldLogFile); err == nil { + files = append(files, oldLogFile) + } + } + // Socket files. + if dir, _ := socketDir(); dir != "" { + files = append(files, filepath.Join(dir, ctrldControlUnixSock)) + files = append(files, filepath.Join(dir, ctrldLogUnixSock)) + } + // Static DNS settings files. 
+ withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + file := ctrld.SavedStaticDnsSettingsFilePath(i) + files = append(files, file) + return nil + }) + bin, err := os.Executable() + if err != nil { + logger.Warn().Err(err).Msg("Failed to get executable path") + } + if bin != "" && supportedSelfDelete { + files = append(files, bin) + } + // Backup file after upgrading. + oldBin := bin + oldBinSuffix + if _, err := os.Stat(oldBin); err == nil { + files = append(files, oldBin) + } + for _, file := range files { + if file == "" { + continue + } + if err := os.Remove(file); err == nil { + logger.Notice().Str("file", file).Msg("File removed during cleanup") + } else { + logger.Debug().Err(err).Str("file", file).Msg("Failed to remove file during cleanup") + } + } + // Self-delete the ctrld binary if supported + if err := selfDeleteExe(); err != nil { + logger.Warn().Err(err).Msg("Failed to delete ctrld binary") + } else { + if !supportedSelfDelete { + logger.Debug().Msgf("File removed: %s", bin) + } + } + + logger.Debug().Msg("Cleanup operations completed") + } + + logger.Debug().Msg("Service uninstall command completed") + return nil +} diff --git a/cmd/cli/commands_test.go b/cmd/cli/commands_test.go new file mode 100644 index 00000000..98ac760b --- /dev/null +++ b/cmd/cli/commands_test.go @@ -0,0 +1,197 @@ +package cli + +import ( + "bytes" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestBasicCommandStructure tests the actual root command structure +func TestBasicCommandStructure(t *testing.T) { + // Test the actual root command that's returned from initCLI() + rootCmd := initCLI() + + // Test that root command has basic properties + assert.Equal(t, "ctrld", rootCmd.Use) + assert.NotEmpty(t, rootCmd.Short, "Root command should have a short description") + + // Test that root command has subcommands + commands := rootCmd.Commands() + assert.NotNil(t, commands, "Root 
command should have subcommands") + assert.Greater(t, len(commands), 0, "Root command should have at least one subcommand") + + // Test that expected commands exist + expectedCommands := []string{"run", "service", "clients", "upgrade", "log"} + for _, cmdName := range expectedCommands { + found := false + for _, cmd := range commands { + if cmd.Name() == cmdName { + found = true + break + } + } + assert.True(t, found, "Expected command %s not found in root command", cmdName) + } +} + +// TestServiceCommandCreation tests service command creation +func TestServiceCommandCreation(t *testing.T) { + sc := NewServiceCommand() + require.NotNil(t, sc, "ServiceCommand should be created") + + // Test service config creation + config := sc.createServiceConfig() + require.NotNil(t, config, "Service config should be created") + assert.Equal(t, ctrldServiceName, config.Name) + assert.Equal(t, "Control-D Helper Service", config.DisplayName) + assert.Equal(t, "A highly configurable, multi-protocol DNS forwarding proxy", config.Description) +} + +// TestServiceCommandSubCommands tests service command sub commands +func TestServiceCommandSubCommands(t *testing.T) { + rootCmd := &cobra.Command{ + Use: "ctrld", + Short: "DNS forwarding proxy", + } + + serviceCmd := InitServiceCmd(rootCmd) + require.NotNil(t, serviceCmd, "Service command should be created") + + // Test that service command has subcommands + subcommands := serviceCmd.Commands() + assert.Greater(t, len(subcommands), 0, "Service command should have subcommands") + + // Test specific subcommands exist + expectedCommands := []string{"start", "stop", "restart", "reload", "status", "uninstall", "interfaces"} + + for _, cmdName := range expectedCommands { + found := false + for _, cmd := range subcommands { + if cmd.Name() == cmdName { + found = true + break + } + } + assert.True(t, found, "Expected service subcommand %s not found", cmdName) + } +} + +// TestCommandHelp tests basic help functionality +func TestCommandHelp(t 
*testing.T) { + // Initialize the CLI to set up the root command + rootCmd := initCLI() + + // Test help command execution + var buf bytes.Buffer + rootCmd.SetOut(&buf) + rootCmd.SetErr(&buf) + + rootCmd.SetArgs([]string{"--help"}) + err := rootCmd.Execute() + assert.NoError(t, err, "Help command should execute without error") + assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description") +} + +// TestCommandVersion tests version command +func TestCommandVersion(t *testing.T) { + // Initialize the CLI to set up the root command + rootCmd := initCLI() + + var buf bytes.Buffer + rootCmd.SetOut(&buf) + rootCmd.SetErr(&buf) + + // Test version command + rootCmd.SetArgs([]string{"--version"}) + err := rootCmd.Execute() + assert.NoError(t, err, "Version command should execute without error") + assert.Contains(t, buf.String(), "version", "Version output should contain version information") +} + +// TestCommandErrorHandling tests error handling +func TestCommandErrorHandling(t *testing.T) { + // Initialize the CLI to set up the root command + rootCmd := initCLI() + + // Test invalid flag instead of invalid command + rootCmd.SetArgs([]string{"--invalid-flag"}) + err := rootCmd.Execute() + assert.Error(t, err, "Invalid flag should return error") +} + +// TestCommandFlags tests flag functionality +func TestCommandFlags(t *testing.T) { + // Initialize the CLI to set up the root command + rootCmd := initCLI() + + // Test that root command has expected flags + verboseFlag := rootCmd.PersistentFlags().Lookup("verbose") + assert.NotNil(t, verboseFlag, "Verbose flag should exist") + assert.Equal(t, "v", verboseFlag.Shorthand) + + silentFlag := rootCmd.PersistentFlags().Lookup("silent") + assert.NotNil(t, silentFlag, "Silent flag should exist") + assert.Equal(t, "s", silentFlag.Shorthand) +} + +// TestCommandExecution tests basic command execution +func TestCommandExecution(t *testing.T) { + // Initialize the CLI to set up the root command + 
rootCmd := initCLI() + + // Test that root command can be executed (help command) + var buf bytes.Buffer + rootCmd.SetOut(&buf) + rootCmd.SetErr(&buf) + + rootCmd.SetArgs([]string{"--help"}) + err := rootCmd.Execute() + assert.NoError(t, err, "Root command should execute without error") + assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description") +} + +// TestCommandArgs tests argument handling +func TestCommandArgs(t *testing.T) { + // Initialize the CLI to set up the root command + rootCmd := initCLI() + + // Test that root command can handle arguments properly + // Test with no args (should succeed) + err := rootCmd.Execute() + assert.NoError(t, err, "Root command with no args should execute") + + // Test with help flag (should succeed) + rootCmd.SetArgs([]string{"--help"}) + err = rootCmd.Execute() + assert.NoError(t, err, "Root command with help flag should execute") +} + +// TestCommandSubcommands tests subcommand functionality +func TestCommandSubcommands(t *testing.T) { + // Initialize the CLI to set up the root command + rootCmd := initCLI() + + // Test that root command has subcommands + commands := rootCmd.Commands() + assert.Greater(t, len(commands), 0, "Root command should have subcommands") + + // Test that specific subcommands exist and can be executed + expectedSubcommands := []string{"run", "service", "clients", "upgrade", "log"} + for _, subCmdName := range expectedSubcommands { + // Find the subcommand + var subCmd *cobra.Command + for _, cmd := range commands { + if cmd.Name() == subCmdName { + subCmd = cmd + break + } + } + assert.NotNil(t, subCmd, "Subcommand %s should exist", subCmdName) + + // Test that subcommand has help + assert.NotEmpty(t, subCmd.Short, "Subcommand %s should have a short description", subCmdName) + } +} diff --git a/cmd/cli/commands_upgrade.go b/cmd/cli/commands_upgrade.go new file mode 100644 index 00000000..a6ab304f --- /dev/null +++ b/cmd/cli/commands_upgrade.go @@ -0,0 +1,192 
@@ +package cli + +import ( + "context" + "errors" + "net/http" + "os" + "os/exec" + "strings" + "time" + + "github.com/kardianos/service" + "github.com/minio/selfupdate" + "github.com/spf13/cobra" +) + +const ( + upgradeChannelDev = "dev" + upgradeChannelProd = "prod" + upgradeChannelDefault = "default" +) + +// UpgradeCommand handles upgrade-related operations +type UpgradeCommand struct { +} + +// NewUpgradeCommand creates a new upgrade command handler +func NewUpgradeCommand() (*UpgradeCommand, error) { + return &UpgradeCommand{}, nil +} + +// Upgrade performs the upgrade operation +func (uc *UpgradeCommand) Upgrade(cmd *cobra.Command, args []string) error { + upgradeChannel := map[string]string{ + upgradeChannelDefault: "https://dl.controld.dev", + upgradeChannelDev: "https://dl.controld.dev", + upgradeChannelProd: "https://dl.controld.com", + } + if isStableVersion(curVersion()) { + upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd] + } + + bin, err := os.Executable() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("Failed to get current ctrld binary path") + } + + readConfig(false) + v.Unmarshal(&cfg) + svcCmd := NewServiceCommand() + s, p, err := svcCmd.initializeServiceManager() + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return nil + } + + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + + svcInstalled := true + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + svcInstalled = false + } + + oldBin := bin + oldBinSuffix + baseUrl := upgradeChannel[upgradeChannelDefault] + if len(args) > 0 { + channel := args[0] + switch channel { + case upgradeChannelProd, upgradeChannelDev: // ok + default: + mainLog.Load().Fatal().Msgf("Upgrade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) + } + baseUrl = upgradeChannel[channel] + } + + dlUrl := upgradeUrl(baseUrl) + 
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) + + resp, err := getWithRetry(dlUrl, downloadServerIp) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("Failed to download binary") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + mainLog.Load().Fatal().Msgf("Could not download binary: %s", http.StatusText(resp.StatusCode)) + } + + mainLog.Load().Debug().Msg("Updating current binary") + if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil { + if rerr := selfupdate.RollbackError(err); rerr != nil { + mainLog.Load().Error().Err(rerr).Msg("Could not rollback old binary") + } + mainLog.Load().Fatal().Err(err).Msg("Failed to update current binary") + } + + doRestart := func() bool { + if !svcInstalled { + return true + } + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + // restore static DNS settings or DHCP + p.resetDNS(false, true) + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + doTasks(tasks) + + tasks = []task{ + {s.Start, true, "Start"}, + } + if doTasks(tasks) { + if dir, err := socketDir(); err == nil { + if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { + _, _ = cc.post(ifacePath, nil) + return true + } + } + } + return false + } + + if svcInstalled { + mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") + } + + if doRestart() { + _ = os.Remove(oldBin) + _ = os.Chmod(bin, 0755) + ver := "unknown version" + out, err := exec.Command(bin, "--version").CombinedOutput() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version") + } + if after, found := strings.CutPrefix(string(out), "ctrld version "); found { + ver = after + } + mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver) + return nil + } + + mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin) + if err := 
os.Remove(bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("Failed to remove new binary") + } + if err := os.Rename(oldBin, bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("Failed to restore old binary") + } + if doRestart() { + mainLog.Load().Notice().Msg("Restored previous binary successfully") + return nil + } + + return nil +} + +// InitUpgradeCmd creates the upgrade command with proper logic +func InitUpgradeCmd(rootCmd *cobra.Command) *cobra.Command { + upgradeCmd := &cobra.Command{ + Use: "upgrade", + Short: "Upgrading ctrld to latest version", + ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, + Args: cobra.MaximumNArgs(1), + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + RunE: func(cmd *cobra.Command, args []string) error { + uc, err := NewUpgradeCommand() + if err != nil { + return err + } + return uc.Upgrade(cmd, args) + }, + } + + rootCmd.AddCommand(upgradeCmd) + + return upgradeCmd +} diff --git a/cmd/cli/conn.go b/cmd/cli/conn.go deleted file mode 100644 index 82e64688..00000000 --- a/cmd/cli/conn.go +++ /dev/null @@ -1,51 +0,0 @@ -package cli - -import ( - "net" - "time" -) - -// logConn wraps a net.Conn, override the Write behavior. -// runCmd uses this wrapper, so as long as startCmd finished, -// ctrld log won't be flushed with un-necessary write errors. 
-type logConn struct { - conn net.Conn -} - -func (lc *logConn) Read(b []byte) (n int, err error) { - return lc.conn.Read(b) -} - -func (lc *logConn) Close() error { - return lc.conn.Close() -} - -func (lc *logConn) LocalAddr() net.Addr { - return lc.conn.LocalAddr() -} - -func (lc *logConn) RemoteAddr() net.Addr { - return lc.conn.RemoteAddr() -} - -func (lc *logConn) SetDeadline(t time.Time) error { - return lc.conn.SetDeadline(t) -} - -func (lc *logConn) SetReadDeadline(t time.Time) error { - return lc.conn.SetReadDeadline(t) -} - -func (lc *logConn) SetWriteDeadline(t time.Time) error { - return lc.conn.SetWriteDeadline(t) -} - -func (lc *logConn) Write(b []byte) (int, error) { - // Write performs writes with underlying net.Conn, ignore any errors happen. - // "ctrld run" command use this wrapper to report errors to "ctrld start". - // If no error occurred, "ctrld start" may finish before "ctrld run" attempt - // to close the connection, so ignore errors conservatively here, prevent - // un-necessary error "write to closed connection" flushed to ctrld log. 
- _, _ = lc.conn.Write(b) - return len(b), nil -} diff --git a/cmd/cli/control_client.go b/cmd/cli/control_client.go index 7382d4e8..0ab10404 100644 --- a/cmd/cli/control_client.go +++ b/cmd/cli/control_client.go @@ -8,10 +8,12 @@ import ( "time" ) +// controlClient represents an HTTP client for communicating with the control server type controlClient struct { c *http.Client } +// newControlClient creates a new control client with Unix socket transport func newControlClient(addr string) *controlClient { return &controlClient{c: &http.Client{ Transport: &http.Transport{ diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 9281b904..1c9d37cf 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -37,12 +37,14 @@ type ifaceResponse struct { OK bool `json:"ok"` } +// controlServer represents an HTTP server for handling control requests type controlServer struct { server *http.Server mux *http.ServeMux addr string } +// newControlServer creates a new control server instance func newControlServer(addr string) (*controlServer, error) { mux := http.NewServeMux() s := &controlServer{ @@ -79,34 +81,34 @@ func (s *controlServer) register(pattern string, handler http.Handler) { func (p *prog) registerControlServerHandler() { p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { - mainLog.Load().Debug().Msg("handling list clients request") + p.Debug().Msg("Handling list clients request") clients := p.ciTable.ListClients() - mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list") + p.Debug().Int("client_count", len(clients)).Msg("Retrieved clients list") sort.Slice(clients, func(i, j int) bool { return clients[i].IP.Less(clients[j].IP) }) - mainLog.Load().Debug().Msg("sorted clients by IP address") + p.Debug().Msg("Sorted clients by IP address") if p.metricsQueryStats.Load() { - mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts") + 
p.Debug().Msg("Metrics query stats enabled, collecting query counts") for idx, client := range clients { - mainLog.Load().Debug(). + p.Debug(). Int("index", idx). Str("ip", client.IP.String()). Str("mac", client.Mac). Str("hostname", client.Hostname). - Msg("processing client metrics") + Msg("Processing client metrics") client.IncludeQueryCount = true dm := &dto.Metric{} if statsClientQueriesCount.MetricVec == nil { - mainLog.Load().Debug(). + p.Debug(). Str("client_ip", client.IP.String()). - Msg("skipping metrics collection: MetricVec is nil") + Msg("Skipping metrics collection: MetricVec is nil") continue } @@ -116,44 +118,44 @@ func (p *prog) registerControlServerHandler() { client.Hostname, ) if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Str("client_ip", client.IP.String()). Str("mac", client.Mac). Str("hostname", client.Hostname). - Msg("failed to get metrics for client") + Msg("Failed to get metrics for client") continue } if err := m.Write(dm); err == nil && dm.Counter != nil { client.QueryCount = int64(dm.Counter.GetValue()) - mainLog.Load().Debug(). + p.Debug(). Str("client_ip", client.IP.String()). Int64("query_count", client.QueryCount). - Msg("successfully collected query count") + Msg("Successfully collected query count") } else if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Str("client_ip", client.IP.String()). - Msg("failed to write metric") + Msg("Failed to write metric") } } } else { - mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts") + p.Debug().Msg("Metrics query stats disabled, skipping query counts") } if err := json.NewEncoder(w).Encode(&clients); err != nil { - mainLog.Load().Error(). + p.Error(). Err(err). Int("client_count", len(clients)). - Msg("failed to encode clients response") + Msg("Failed to encode clients response") http.Error(w, err.Error(), http.StatusInternalServerError) return } - mainLog.Load().Debug(). + p.Debug(). Int("client_count", len(clients)). 
- Msg("successfully sent clients list response") + Msg("Successfully sent clients list response") })) p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { select { @@ -175,14 +177,14 @@ func (p *prog) registerControlServerHandler() { oldSvc := p.cfg.Service p.mu.Unlock() if err := p.sendReloadSignal(); err != nil { - mainLog.Load().Err(err).Msg("could not send reload signal") + p.Error().Err(err).Msg("Could not send reload signal") http.Error(w, err.Error(), http.StatusInternalServerError) return } select { case <-p.reloadDoneCh: case <-time.After(5 * time.Second): - http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError) + http.Error(w, "Timeout waiting for ctrld reload", http.StatusInternalServerError) return } @@ -216,15 +218,16 @@ func (p *prog) registerControlServerHandler() { return } + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) // Re-fetch pin code from API. - if rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); rc != nil { + if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev); rc != nil { if rc.DeactivationPin != nil { cdDeactivationPin.Store(*rc.DeactivationPin) } else { cdDeactivationPin.Store(defaultDeactivationPin) } } else { - mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code") + p.Warn().Err(err).Msg("Could not re-fetch deactivation pin code") } // If pin code not set, allowing deactivation. 
@@ -236,7 +239,7 @@ func (p *prog) registerControlServerHandler() { var req deactivationRequest if err := json.NewDecoder(request.Body).Decode(&req); err != nil { w.WriteHeader(http.StatusPreconditionFailed) - mainLog.Load().Err(err).Msg("invalid deactivation request") + p.Error().Err(err).Msg("Invalid deactivation request") return } @@ -280,7 +283,7 @@ func (p *prog) registerControlServerHandler() { } })) p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { - lr, err := p.logReader() + lr, err := p.logReaderRaw() if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -306,7 +309,7 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusServiceUnavailable) return } - r, err := p.logReader() + r, err := p.logReaderNoColor() if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -319,14 +322,15 @@ func (p *prog) registerControlServerHandler() { UID: cdUID, Data: r.r, } - mainLog.Load().Debug().Msg("sending log file to ControlD server") + p.Debug().Msg("Sending log file to ControlD server") resp := logSentResponse{Size: r.size} - if err := controld.SendLogs(req, cdDev); err != nil { - mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) + if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil { + p.Error().Msgf("Could not send log file to ControlD server: %v", err) resp.Error = err.Error() w.WriteHeader(http.StatusInternalServerError) } else { - mainLog.Load().Debug().Msg("sending log file successfully") + p.Debug().Msg("Sending log file successfully") w.WriteHeader(http.StatusOK) } if err := json.NewEncoder(w).Encode(&resp); err != nil { @@ -336,6 +340,7 @@ func (p *prog) registerControlServerHandler() { })) } +// jsonResponse wraps an HTTP handler to set JSON content type func jsonResponse(next http.Handler) http.Handler { return http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/cmd/cli/dns.go b/cmd/cli/dns.go deleted file mode 100644 index cf9d779e..00000000 --- a/cmd/cli/dns.go +++ /dev/null @@ -1,4 +0,0 @@ -package cli - -//lint:ignore U1000 use in os_linux.go -type getDNS func(iface string) []string diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 994741b1..10a9581e 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -25,56 +25,70 @@ import ( "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" - "github.com/Control-D-Inc/ctrld/internal/router" + "github.com/Control-D-Inc/ctrld/internal/rulematcher" ) +// DNS proxy constants for configuration and behavior control const ( + // staleTTL is the TTL for stale cache entries + // This allows serving cached responses even when upstreams are temporarily unavailable staleTTL = 60 * time.Second + + // localTTL is the TTL for local network responses + // Longer TTL for local queries reduces unnecessary repeated lookups localTTL = 3600 * time.Second + // EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option. // https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81 // This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification. + // This enables MAC address-based client identification for policy routing EDNS0_OPTION_MAC = 0xFDE9 // selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation. 
+ // This prevents premature self-uninstallation due to temporary network issues selfUninstallMaxQueries = 32 ) +// osUpstreamConfig defines the default OS resolver configuration +// This is used as a fallback when all configured upstreams fail var osUpstreamConfig = &ctrld.UpstreamConfig{ Name: "OS resolver", Type: ctrld.ResolverTypeOS, Timeout: 3000, } +// privateUpstreamConfig defines the default private resolver configuration +// This is used for internal network queries that should not go to public resolvers var privateUpstreamConfig = &ctrld.UpstreamConfig{ Name: "Private resolver", Type: ctrld.ResolverTypePrivate, Timeout: 2000, } -var localUpstreamConfig = &ctrld.UpstreamConfig{ - Name: "Local resolver", - Type: ctrld.ResolverTypeLocal, - Timeout: 2000, -} - // proxyRequest contains data for proxying a DNS query to upstream. +// This structure encapsulates all the information needed to process a DNS request type proxyRequest struct { - msg *dns.Msg - ci *ctrld.ClientInfo - failoverRcodes []int - ufr *upstreamForResult + msg *dns.Msg + ci *ctrld.ClientInfo + failoverRcodes []int + ufr *upstreamForResult + staleAnswer *dns.Msg + isLanOrPtrQuery bool + upstreamConfigs []*ctrld.UpstreamConfig } // proxyResponse contains data for proxying a DNS response from upstream. +// This structure encapsulates the response and metadata for logging and metrics type proxyResponse struct { answer *dns.Msg + upstream string cached bool clientInfo bool - upstream string + refused bool } // upstreamForResult represents the result of processing rules for a request. +// This contains the matched policy information for logging and debugging type upstreamForResult struct { upstreams []string matchedPolicy string @@ -84,170 +98,277 @@ type upstreamForResult struct { srcAddr string } -func (p *prog) serveDNS(listenerNum string) error { +// serveDNS sets up and starts a DNS server on the specified listener, handling DNS queries and network monitoring. 
+// This is the main entry point for DNS server functionality +func (p *prog) serveDNS(ctx context.Context, listenerNum string) error { + logger := p.logger.Load() + logger.Debug().Msg("DNS server setup started") + listenerConfig := p.cfg.Listener[listenerNum] - // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { - mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") + p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: Failed to allocate listen IP") return allocErr } handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { - p.sema.acquire() - defer p.sema.release() - if len(m.Question) == 0 { - answer := new(dns.Msg) - answer.SetRcode(m, dns.RcodeFormatError) - _ = w.WriteMsg(answer) - return - } - listenerConfig := p.cfg.Listener[listenerNum] - reqId := requestID() - ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) - if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { - ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) - answer := new(dns.Msg) - answer.SetRcode(m, dns.RcodeRefused) - _ = w.WriteMsg(answer) - return - } - go p.detectLoop(m) - q := m.Question[0] - domain := canonicalName(q.Name) - switch { - case domain == "": - answer := new(dns.Msg) - answer.SetRcode(m, dns.RcodeFormatError) - _ = w.WriteMsg(answer) - return - case domain == selfCheckInternalTestDomain: - answer := resolveInternalDomainTestQuery(ctx, domain, m) - _ = w.WriteMsg(answer) - return - } - - if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { - p.cache.Purge() - ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain) - } - remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - ci := p.getClientInfo(remoteIP, m) - ci.ClientIDPref = p.cfg.Service.ClientIDPref - stripClientSubnet(m) - 
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) - fmtSrcToDest := fmtRemoteToLocal(listenerNum, ci.Hostname, remoteAddr.String()) - t := time.Now() - ctrld.Log(ctx, mainLog.Load().Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) - ur := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) - - labelValues := make([]string, 0, len(statsQueriesCountLabels)) - labelValues = append(labelValues, net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))) - labelValues = append(labelValues, ci.IP) - labelValues = append(labelValues, ci.Mac) - labelValues = append(labelValues, ci.Hostname) - - var answer *dns.Msg - if !ur.matched && listenerConfig.Restricted { - ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String()) - answer = new(dns.Msg) - answer.SetRcode(m, dns.RcodeRefused) - labelValues = append(labelValues, "") // no upstream - } else { - var failoverRcode []int - if listenerConfig.Policy != nil { - failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers - } - pr := p.proxy(ctx, &proxyRequest{ - msg: m, - ci: ci, - failoverRcodes: failoverRcode, - ufr: ur, - }) - go p.doSelfUninstall(pr.answer) - - answer = pr.answer - rtt := time.Since(t) - ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) - upstream := pr.upstream - switch { - case pr.cached: - upstream = "cache" - case pr.clientInfo: - upstream = "client_info_table" - } - labelValues = append(labelValues, upstream) - } - labelValues = append(labelValues, dns.TypeToString[q.Qtype]) - labelValues = append(labelValues, dns.RcodeToString[answer.Rcode]) - go func() { - p.WithLabelValuesInc(statsQueriesCount, labelValues...) - p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...) 
- p.forceFetchingAPI(domain) - }() - if err := w.WriteMsg(answer); err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client") - } + p.handleDNSQuery(w, m, listenerNum, listenerConfig) }) - g, ctx := errgroup.WithContext(context.Background()) + logger.Debug().Msg("DNS server setup completed") + return p.startListeners(ctx, listenerConfig, handler) +} + +// startListeners starts DNS listeners on specified configurations, supporting UDP and TCP protocols. +// It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors. +// This function manages the lifecycle of DNS server listeners +func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error { + logger := p.logger.Load() + logger.Debug().Msg("Starting DNS listeners") + + g, gctx := errgroup.WithContext(ctx) + for _, proto := range []string{"udp", "tcp"} { - proto := proto if needLocalIPv6Listener() { + logger.Debug().Str("protocol", proto).Msg("Starting local IPv6 listener") g.Go(func() error { - s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler) + s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(cfg.Port)), proto, handler) defer s.Shutdown() select { case <-p.stopCh: - case <-ctx.Done(): + case <-gctx.Done(): case err := <-errCh: - // Local ipv6 listener should not terminate ctrld. - // It's a workaround for a quirk on Windows. - mainLog.Load().Warn().Err(err).Msg("local ipv6 listener failed") + p.Warn().Err(err).Msg("Local IPv6 listener failed") } return nil }) } + // When we spawn a listener on 127.0.0.1, also spawn listeners on the RFC1918 addresses of the machine // if explicitly set via setting rfc1918 flag, so ctrld could receive queries from LAN clients. 
- if needRFC1918Listeners(listenerConfig) { + if needRFC1918Listeners(cfg) { + logger.Debug().Str("protocol", proto).Msg("Starting RFC1918 listeners") g.Go(func() error { for _, addr := range ctrld.Rfc1918Addresses() { func() { - listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port)) + listenAddr := net.JoinHostPort(addr, strconv.Itoa(cfg.Port)) s, errCh := runDNSServer(listenAddr, proto, handler) defer s.Shutdown() select { case <-p.stopCh: - case <-ctx.Done(): + case <-gctx.Done(): case err := <-errCh: - // RFC1918 listener should not terminate ctrld. - // It's a workaround for a quirk on system with systemd-resolved. - mainLog.Load().Warn().Err(err).Msgf("could not listen on %s: %s", proto, listenAddr) + p.Warn().Err(err).Msgf("Could not listen on %s: %s", proto, listenAddr) } }() } return nil }) } + + logger.Debug().Str("protocol", proto).Str("ip", cfg.IP).Int("port", cfg.Port).Msg("Starting main listener") g.Go(func() error { - addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) + addr := net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)) s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() - p.started <- struct{}{} - select { case <-p.stopCh: - case <-ctx.Done(): + case <-gctx.Done(): case err := <-errCh: return err } return nil }) } + + logger.Debug().Msg("DNS listeners started successfully") return g.Wait() } +// handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers. 
+// This is the main entry point for all DNS query processing +func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) { + p.sema.acquire() + defer p.sema.release() + + if len(m.Question) == 0 { + sendDNSResponse(w, m, dns.RcodeFormatError) + return + } + + reqID := requestID() + ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqID) + ctx = ctrld.LoggerCtx(ctx, p.logger.Load()) + + ctrld.Log(ctx, p.Debug(), "Processing DNS query from %s", w.RemoteAddr().String()) + + if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { + ctrld.Log(ctx, p.Debug(), "Query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) + sendDNSResponse(w, m, dns.RcodeRefused) + return + } + + go p.detectLoop(m) + + q := m.Question[0] + domain := canonicalName(q.Name) + + if p.handleSpecialDomains(ctx, w, m, domain) { + ctrld.Log(ctx, p.Debug(), "Special domain query handled") + return + } + + ctrld.Log(ctx, p.Debug(), "Processing standard query for domain: %s", domain) + p.processStandardQuery(&standardQueryRequest{ + ctx: ctx, + writer: w, + msg: m, + listenerNum: listenerNum, + listenerConfig: listenerConfig, + domain: domain, + }) +} + +// handleSpecialDomains processes special domain queries, handles errors, purges cache if necessary, and returns a bool status. 
+// This handles internal test domains and cache management commands +func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool { + switch { + case domain == "": + ctrld.Log(ctx, p.Debug(), "Empty domain query, sending format error") + sendDNSResponse(w, m, dns.RcodeFormatError) + return true + case domain == selfCheckInternalTestDomain: + ctrld.Log(ctx, p.Debug(), "Internal test domain query: %s", domain) + answer := resolveInternalDomainTestQuery(ctx, domain, m) + _ = w.WriteMsg(answer) + return true + } + + if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { + p.cache.Purge() + ctrld.Log(ctx, p.Debug(), "Received query %q, local cache is purged", domain) + } + + return false +} + +// standardQueryRequest represents a standard DNS query request with associated context and configuration. +// This encapsulates all the data needed to process a standard DNS query +type standardQueryRequest struct { + ctx context.Context + writer dns.ResponseWriter + msg *dns.Msg + listenerNum string + listenerConfig *ctrld.ListenerConfig + domain string +} + +// processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response. 
+// This is the main processing pipeline for normal DNS queries +func (p *prog) processStandardQuery(req *standardQueryRequest) { + ctrld.Log(req.ctx, p.Debug(), "Processing standard query started") + + remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String()) + ci := p.getClientInfo(remoteIP, req.msg) + ci.ClientIDPref = p.cfg.Service.ClientIDPref + + stripClientSubnet(req.msg) + remoteAddr := spoofRemoteAddr(req.writer.RemoteAddr(), ci) + fmtSrcToDest := fmtRemoteToLocal(req.listenerNum, ci.Hostname, remoteAddr.String()) + + startTime := time.Now() + q := req.msg.Question[0] + ctrld.Log(req.ctx, p.Info(), "QUERY: %s: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], req.domain) + + ur := p.upstreamFor(req.ctx, req.listenerNum, req.listenerConfig, remoteAddr, ci.Mac, req.domain) + + var answer *dns.Msg + // Handle restricted listener case + if !ur.matched && req.listenerConfig.Restricted { + ctrld.Log(req.ctx, p.Debug(), "Query refused, %s does not match any network policy", remoteAddr.String()) + answer = new(dns.Msg) + answer.SetRcode(req.msg, dns.RcodeRefused) + // Process the refused query + go p.postProcessStandardQuery(ci, req.listenerConfig, q, &proxyResponse{answer: answer, refused: true}) + } else { + // Process a normal query + ctrld.Log(req.ctx, p.Debug(), "Starting proxy query processing") + pr := p.proxy(req.ctx, &proxyRequest{ + msg: req.msg, + ci: ci, + failoverRcodes: p.getFailoverRcodes(req.listenerConfig), + ufr: ur, + }) + + rtt := time.Since(startTime) + ctrld.Log(req.ctx, p.Debug(), "Received response of %d bytes in %s", pr.answer.Len(), rtt) + + go p.postProcessStandardQuery(ci, req.listenerConfig, q, pr) + answer = pr.answer + } + + if err := req.writer.WriteMsg(answer); err != nil { + ctrld.Log(req.ctx, p.Error().Err(err), "serveDNS: failed to send DNS response to client") + } + + ctrld.Log(req.ctx, p.Debug(), "Standard query processing completed") +} + +// postProcessStandardQuery performs additional actions after processing a 
standard DNS query, such as metrics recording, +// handling canonical name adjustments, and triggering specific post-query actions like uninstallation procedures. +func (p *prog) postProcessStandardQuery(ci *ctrld.ClientInfo, listenerConfig *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) { + p.doSelfUninstall(pr) + p.recordMetrics(ci, listenerConfig, q, pr) + p.forceFetchingAPI(canonicalName(q.Name)) +} + +// getFailoverRcodes retrieves the failover response codes from the provided ListenerConfig. Returns nil if no policy exists. +func (p *prog) getFailoverRcodes(cfg *ctrld.ListenerConfig) []int { + if cfg.Policy != nil { + return cfg.Policy.FailoverRcodeNumbers + } + return nil +} + +// recordMetrics updates Prometheus metrics for DNS queries, including query count and client-specific query statistics. +func (p *prog) recordMetrics(ci *ctrld.ClientInfo, cfg *ctrld.ListenerConfig, q dns.Question, pr *proxyResponse) { + upstream := pr.upstream + switch { + case pr.cached: + upstream = "cache" + case pr.clientInfo: + upstream = "client_info_table" + } + labelValues := []string{ + net.JoinHostPort(cfg.IP, strconv.Itoa(cfg.Port)), + ci.IP, + ci.Mac, + ci.Hostname, + upstream, + dns.TypeToString[q.Qtype], + dns.RcodeToString[pr.answer.Rcode], + } + p.WithLabelValuesInc(statsQueriesCount, labelValues...) + p.WithLabelValuesInc(statsClientQueriesCount, []string{ci.IP, ci.Mac, ci.Hostname}...) +} + +// sendDNSResponse sends a DNS response with the specified RCODE to the client using the provided ResponseWriter. 
+func sendDNSResponse(w dns.ResponseWriter, m *dns.Msg, rcode int) { + answer := new(dns.Msg) + answer.SetRcode(m, rcode) + _ = w.WriteMsg(answer) +} + +// upstreamForRequest contains all parameters needed for upstream determination +type upstreamForRequest struct { + DefaultUpstreamNum string + ListenerConfig *ctrld.ListenerConfig + Addr net.Addr + SrcMac string + Domain string + MatchingConfig *rulematcher.MatchingConfig +} + // upstreamFor returns the list of upstreams for resolving the given domain, // matching by policies defined in the listener config. The second return value // reports whether the domain matches the policy. @@ -256,94 +377,95 @@ func (p *prog) serveDNS(listenerNum string) error { // processed later, because policy logging want to know whether a network rule // is disregarded in favor of the domain level rule. func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) { - upstreams := []string{upstreamPrefix + defaultUpstreamNum} - matchedPolicy := "no policy" - matchedNetwork := "no network" - matchedRule := "no rule" - matched := false - res = &upstreamForResult{srcAddr: addr.String()} - - defer func() { - res.upstreams = upstreams - res.matched = matched - res.matchedPolicy = matchedPolicy - res.matchedNetwork = matchedNetwork - res.matchedRule = matchedRule - }() + var matchingConfig *rulematcher.MatchingConfig + if lc.Policy != nil && lc.Policy.Matching != nil { + // Convert string-based order to RuleType enum + var order []rulematcher.RuleType + for _, ruleTypeStr := range lc.Policy.Matching.Order { + switch ruleTypeStr { + case "network": + order = append(order, rulematcher.RuleTypeNetwork) + case "mac": + order = append(order, rulematcher.RuleTypeMac) + case "domain": + order = append(order, rulematcher.RuleTypeDomain) + } + } - if lc.Policy == nil { - return + matchingConfig = &rulematcher.MatchingConfig{ + Order: order, + } + } + 
+ req := &upstreamForRequest{ + DefaultUpstreamNum: defaultUpstreamNum, + ListenerConfig: lc, + Addr: addr, + SrcMac: srcMac, + Domain: domain, + MatchingConfig: matchingConfig, } - do := func(policyUpstreams []string) { - upstreams = append([]string(nil), policyUpstreams...) + return p.upstreamForWithConfig(ctx, req) +} + +// upstreamForWithConfig determines upstreams using configurable rule matching +func (p *prog) upstreamForWithConfig(ctx context.Context, req *upstreamForRequest) (res *upstreamForResult) { + // Default upstreams + upstreams := []string{upstreamPrefix + req.DefaultUpstreamNum} + res = &upstreamForResult{srcAddr: req.Addr.String()} + + // If no policy, return default upstreams + if req.ListenerConfig.Policy == nil { + res.upstreams = upstreams + res.matched = false + res.matchedPolicy = "no policy" + res.matchedNetwork = "no network" + res.matchedRule = "no rule" + return } - var networkTargets []string + // Extract source IP from address var sourceIP net.IP - switch addr := addr.(type) { + switch addr := req.Addr.(type) { case *net.UDPAddr: sourceIP = addr.IP case *net.TCPAddr: sourceIP = addr.IP } -networkRules: - for _, rule := range lc.Policy.Networks { - for source, targets := range rule { - networkNum := strings.TrimPrefix(source, "network.") - nc := p.cfg.Network[networkNum] - if nc == nil { - continue - } - for _, ipNet := range nc.IPNets { - if ipNet.Contains(sourceIP) { - matchedPolicy = lc.Policy.Name - matchedNetwork = source - networkTargets = targets - matched = true - break networkRules - } - } - } + // Create match request + matchRequest := &rulematcher.MatchRequest{ + SourceIP: sourceIP, + SourceMac: req.SrcMac, + Domain: req.Domain, + Policy: req.ListenerConfig.Policy, + Config: p.cfg, } -macRules: - for _, rule := range lc.Policy.Macs { - for source, targets := range rule { - if source != "" && (strings.EqualFold(source, srcMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(srcMac))) { - matchedPolicy = 
lc.Policy.Name - matchedNetwork = source - networkTargets = targets - matched = true - break macRules - } - } - } + // Use matching engine to find upstreams + engine := rulematcher.NewMatchingEngine(req.MatchingConfig) + matchResult := engine.FindUpstreams(ctx, matchRequest) - for _, rule := range lc.Policy.Rules { - // There's only one entry per rule, config validation ensures this. - for source, targets := range rule { - if source == domain || wildcardMatches(source, domain) { - matchedPolicy = lc.Policy.Name - if len(networkTargets) > 0 { - matchedNetwork += " (unenforced)" - } - matchedRule = source - do(targets) - matched = true - return - } - } - } + // Convert result to upstreamForResult format + res.upstreams = matchResult.Upstreams + res.matched = matchResult.Matched + res.matchedPolicy = matchResult.MatchedPolicy + res.matchedNetwork = matchResult.MatchedNetwork + res.matchedRule = matchResult.MatchedRule - if matched { - do(networkTargets) + // If no match found, use default upstreams + if !matchResult.Matched { + res.upstreams = upstreams } return } +// proxyPrivatePtrLookup performs a private PTR DNS lookup based on the client info table for the given query. +// It prevents DNS loops by locking the processing of the same domain name simultaneously. +// If a valid IP-to-hostname mapping exists, it creates a PTR DNS record as the response. +// Returns the DNS response if a hostname is found or nil otherwise. 
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg { cDomainName := msg.Question[0].Name locked := p.ptrLoopGuard.TryLock(cDomainName) @@ -364,8 +486,8 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg }, Ptr: dns.Fqdn(name), }} - ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + ctrld.Log(ctx, p.Info(), "Private PTR lookup, using client info table") + ctrld.Log(ctx, p.Debug(), "Client info: %v", ctrld.ClientInfo{ Mac: p.ciTable.LookupMac(ip.String()), IP: ip.String(), Hostname: name, @@ -375,6 +497,10 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg return nil } +// proxyLanHostnameQuery resolves LAN hostnames to their corresponding IP addresses based on the dns.Msg request. +// It uses a loop guard mechanism to prevent DNS query loops and ensures a hostname is processed only once at a time. +// This method queries the client info table for the hostname's IP address and logs relevant debug and client info. +// If the hostname matches known IPs in the table, it generates an appropriate dns.Msg response; otherwise, it returns nil. 
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg { q := msg.Question[0] hostname := strings.TrimSuffix(q.Name, ".") @@ -409,8 +535,8 @@ func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg AAAA: ip.AsSlice(), }} } - ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + ctrld.Log(ctx, p.Info(), "Lan hostname lookup, using client info table") + ctrld.Log(ctx, p.Debug(), "Client info: %v", ctrld.ClientInfo{ Mac: p.ciTable.LookupMac(ip.String()), IP: ip.String(), Hostname: hostname, @@ -420,243 +546,372 @@ func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg return nil } +// handleSpecialQueryTypes processes specific types of DNS queries such as SRV, PTR, and LAN hostname lookups. +// It modifies upstreams and upstreamConfigs based on the query type and updates the query context accordingly. +// Returns a proxyResponse if the query is resolved locally; otherwise, returns nil to proceed with upstream processing. 
+func (p *prog) handleSpecialQueryTypes(ctx *context.Context, req *proxyRequest, upstreams *[]string, upstreamConfigs *[]*ctrld.UpstreamConfig) *proxyResponse { + if req.ufr.matched { + ctrld.Log(*ctx, p.Debug(), "%s, %s, %s -> %v", + req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, *upstreams) + return nil + } + + switch { + case isSrvLanLookup(req.msg): + *upstreams = []string{upstreamOS} + *upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + *ctx = ctrld.LanQueryCtx(*ctx) + ctrld.Log(*ctx, p.Debug(), "SRV record lookup, using upstreams: %v", *upstreams) + return nil + case isPrivatePtrLookup(req.msg): + req.isLanOrPtrQuery = true + if answer := p.proxyPrivatePtrLookup(*ctx, req.msg); answer != nil { + return &proxyResponse{answer: answer, clientInfo: true} + } + *upstreams, *upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(*upstreams, *upstreamConfigs) + *ctx = ctrld.LanQueryCtx(*ctx) + ctrld.Log(*ctx, p.Debug(), "Private PTR lookup, using upstreams: %v", *upstreams) + return nil + case isLanHostnameQuery(req.msg): + req.isLanOrPtrQuery = true + if answer := p.proxyLanHostnameQuery(*ctx, req.msg); answer != nil { + return &proxyResponse{answer: answer, clientInfo: true} + } + *upstreams = []string{upstreamOS} + *upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + *ctx = ctrld.LanQueryCtx(*ctx) + ctrld.Log(*ctx, p.Debug(), "Lan hostname lookup, using upstreams: %v", *upstreams) + return nil + default: + ctrld.Log(*ctx, p.Debug(), "No explicit policy matched, using default routing -> %v", *upstreams) + return nil + } +} + +// proxy handles DNS query proxying by selecting upstreams, attempting cache lookups, and querying configured resolvers. 
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { - var staleAnswer *dns.Msg + ctrld.Log(ctx, p.Debug(), "Proxy query processing started") + + upstreams, upstreamConfigs := p.initializeUpstreams(req) + ctrld.Log(ctx, p.Debug(), "Initialized upstreams: %v", upstreams) + + if specialRes := p.handleSpecialQueryTypes(&ctx, req, &upstreams, &upstreamConfigs); specialRes != nil { + ctrld.Log(ctx, p.Debug(), "Special query type handled") + return specialRes + } + + if cachedRes := p.tryCache(ctx, req, upstreams); cachedRes != nil { + ctrld.Log(ctx, p.Debug(), "Cache hit, returning cached response") + return cachedRes + } + + ctrld.Log(ctx, p.Debug(), "No cache hit, trying upstreams") + if res := p.tryUpstreams(ctx, req, upstreams, upstreamConfigs); res != nil { + ctrld.Log(ctx, p.Debug(), "Upstream query successful") + return res + } + + ctrld.Log(ctx, p.Debug(), "All upstreams failed, handling failure") + return p.handleAllUpstreamsFailure(ctx, req, upstreams) +} + +// initializeUpstreams determines which upstreams and configurations to use for a given proxyRequest. +// If no upstreams are configured, it defaults to the operating system's resolver configuration. +// Returns a slice of upstream names and their corresponding configurations. +func (p *prog) initializeUpstreams(req *proxyRequest) ([]string, []*ctrld.UpstreamConfig) { upstreams := req.ufr.upstreams - serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) - if len(upstreamConfigs) == 0 { - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - upstreams = []string{upstreamOS} - // For OS resolver, local addresses are ignored to prevent possible looping. - // However, on Active Directory Domain Controller, where it has local DNS server - // running and listening on local addresses, these local addresses must be used - // as nameservers, so queries for ADDC could be resolved as expected. 
- if p.isAdDomainQuery(req.msg) { - ctrld.Log(ctx, mainLog.Load().Debug(), - "AD domain query detected for %s in domain %s", - req.msg.Question[0].Name, p.adDomain) - upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} - upstreams = []string{upstreamOSLocal} + return []string{upstreamOS}, []*ctrld.UpstreamConfig{osUpstreamConfig} + } + return upstreams, upstreamConfigs +} + +// tryCache attempts to retrieve a cached response for the given DNS request from specified upstreams. +// Returns a proxyResponse if a cache hit occurs; otherwise, returns nil. +// Skips cache checking if caching is disabled or the request is a PTR query. +// Iterates through the provided upstreams to find a cached response using the checkCache method. +func (p *prog) tryCache(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse { + if p.cache == nil || req.msg.Question[0].Qtype == dns.TypePTR { // https://www.rfc-editor.org/rfc/rfc1035#section-7.4 + ctrld.Log(ctx, p.Debug(), "Cache disabled or PTR query, skipping cache lookup") + return nil + } + + ctrld.Log(ctx, p.Debug(), "Checking cache for upstreams: %v", upstreams) + for _, upstream := range upstreams { + if res := p.checkCache(ctx, req, upstream); res != nil { + ctrld.Log(ctx, p.Debug(), "Cache hit found for upstream: %s", upstream) + return res } } - res := &proxyResponse{} + ctrld.Log(ctx, p.Debug(), "No cache hit found") + return nil +} - // LAN/PTR lookup flow: - // - // 1. If there's matching rule, follow it. - // 2. Try from client info table. - // 3. Try private resolver. - // 4. Try remote upstream. 
- isLanOrPtrQuery := false - if req.ufr.matched { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) - } else { - switch { - case isSrvLanLookup(req.msg): - upstreams = []string{upstreamOS} - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams) - case isPrivatePtrLookup(req.msg): - isLanOrPtrQuery = true - if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { - res.answer = answer - res.clientInfo = true - return res - } - upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) - ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams) - case isLanHostnameQuery(req.msg): - isLanOrPtrQuery = true - if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil { - res.answer = answer - res.clientInfo = true - return res +// checkCache checks if a cached DNS response exists for the given request and upstream. +// Returns a proxyResponse with the cached response if found and valid, or nil otherwise. 
+func (p *prog) checkCache(ctx context.Context, req *proxyRequest, upstream string) *proxyResponse { + cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream)) + if cachedValue == nil { + ctrld.Log(ctx, p.Debug(), "No cached value found for upstream: %s", upstream) + return nil + } + + answer := cachedValue.Msg.Copy() + ctrld.SetCacheReply(answer, req.msg, answer.Rcode) + now := time.Now() + + if cachedValue.Expire.After(now) { + ctrld.Log(ctx, p.Debug(), "Hit cached response") + setCachedAnswerTTL(answer, now, cachedValue.Expire) + return &proxyResponse{answer: answer, cached: true} + } + + ctrld.Log(ctx, p.Debug(), "Cached response expired, storing as stale") + req.staleAnswer = answer + return nil +} + +// updateCache updates the DNS response cache with the given request, response, TTL, and upstream information. +func (p *prog) updateCache(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string) { + ttl := ttlFromMsg(answer) + now := time.Now() + expired := now.Add(time.Duration(ttl) * time.Second) + if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 { + expired = now.Add(time.Duration(cachedTTL) * time.Second) + } + setCachedAnswerTTL(answer, now, expired) + p.cache.Add(dnscache.NewKey(req.msg, upstream), dnscache.NewValue(answer, expired)) + ctrld.Log(ctx, p.Debug(), "Added cached response") +} + +// serveStaleResponse serves a stale cached DNS response when an upstream query fails, updating TTL for cached records. +func (p *prog) serveStaleResponse(ctx context.Context, staleAnswer *dns.Msg) *proxyResponse { + ctrld.Log(ctx, p.Debug(), "Serving stale cached response") + now := time.Now() + setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL)) + return &proxyResponse{answer: staleAnswer, cached: true} +} + +// handleAllUpstreamsFailure handles the failure scenario when all upstream resolvers fail to respond or process the request. 
+func (p *prog) handleAllUpstreamsFailure(ctx context.Context, req *proxyRequest, upstreams []string) *proxyResponse { + ctrld.Log(ctx, p.Error(), "All %v endpoints failed", upstreams) + + if p.leakOnUpstreamFailure() { + ctrld.Log(ctx, p.Debug(), "Leak on upstream failure enabled") + if p.um.countHealthy(upstreams) == 0 { + ctrld.Log(ctx, p.Debug(), "No healthy upstreams, triggering recovery") + p.triggerRecovery(upstreams[0] == upstreamOS) + } else { + ctrld.Log(ctx, p.Debug(), "One upstream is down but at least one is healthy; skipping recovery trigger") + } + + if upstreams[0] != upstreamOS { + ctrld.Log(ctx, p.Debug(), "Trying OS resolver as fallback") + if answer := p.tryOSResolver(ctx, req); answer != nil { + ctrld.Log(ctx, p.Debug(), "OS resolver fallback successful") + return answer } - upstreams = []string{upstreamOS} - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - ctx = ctrld.LanQueryCtx(ctx) - ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams) - default: - ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams) } } - // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 + ctrld.Log(ctx, p.Debug(), "Returning server failure response") + answer := new(dns.Msg) + answer.SetRcode(req.msg, dns.RcodeServerFailure) + return &proxyResponse{answer: answer} +} + +// shouldContinueWithNextUpstream determines whether processing should continue with the next upstream based on response conditions. +func (p *prog) shouldContinueWithNextUpstream(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, lastUpstream bool) bool { + if answer.Rcode == dns.RcodeSuccess { + ctrld.Log(ctx, p.Debug(), "Successful response, not continuing to next upstream") + return false + } + + // We are doing LAN/PTR lookup using private resolver, so always process the next one. 
+ // Except for the last, we want to send a response instead of saying all upstream failed. + if req.isLanOrPtrQuery && !lastUpstream { + ctrld.Log(ctx, p.Debug(), "No response for LAN/PTR query from %s, process to next upstream", upstream) + return true + } + + if len(req.upstreamConfigs) > 1 && slices.Contains(req.failoverRcodes, answer.Rcode) { + ctrld.Log(ctx, p.Debug(), "Failover rcode matched, process to next upstream") + return true + } + + ctrld.Log(ctx, p.Debug(), "Not continuing to next upstream") + return false +} + +// prepareSuccessResponse prepares a successful DNS response for a given request, logs it, and updates the cache if applicable. +func (p *prog) prepareSuccessResponse(ctx context.Context, req *proxyRequest, answer *dns.Msg, upstream string, upstreamConfig *ctrld.UpstreamConfig) *proxyResponse { + ctrld.Log(ctx, p.Debug(), "Preparing success response") + + answer.Compress = true + if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { - for _, upstream := range upstreams { - cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream)) - if cachedValue == nil { - continue - } - answer := cachedValue.Msg.Copy() - ctrld.SetCacheReply(answer, req.msg, answer.Rcode) - now := time.Now() - if cachedValue.Expire.After(now) { - ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") - setCachedAnswerTTL(answer, now, cachedValue.Expire) - res.answer = answer - res.cached = true - return res - } - staleAnswer = answer - } + ctrld.Log(ctx, p.Debug(), "Updating cache with successful response") + p.updateCache(ctx, req, answer, upstream) } - resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { - ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) - dnsResolver, err := ctrld.NewResolver(upstreamConfig) - if err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") - return nil, err - } - resolveCtx, cancel := 
upstreamConfig.Context(ctx) - defer cancel() - return dnsResolver.Resolve(resolveCtx, msg) - } - resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { - if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request") - ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) - } - answer, err := resolve1(upstream, upstreamConfig, msg) - // if we have an answer, we should reset the failure count - // we dont use reset here since we dont want to prevent failure counts from being incremented - if answer != nil { - p.um.mu.Lock() - p.um.failureReq[upstream] = 0 - p.um.down[upstream] = false - p.um.mu.Unlock() - return answer - } - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") + hostname := "" + if req.ci != nil { + hostname = req.ci.Hostname + } - // increase failure count when there is no answer - // rehardless of what kind of error we get - p.um.increaseFailureCount(upstream) + ctrld.Log(ctx, p.Info(), "REPLY: %s -> %s (%s): %s", + upstream, req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) - if err != nil { - // For timeout error (i.e: context deadline exceed), force re-bootstrapping. - var e net.Error - if errors.As(err, &e) && e.Timeout() { - upstreamConfig.ReBootstrap() - } - // For network error, turn ipv6 off if enabled. - if ctrld.HasIPv6() && (errUrlNetworkError(err) || errNetworkError(err)) { - ctrld.DisableIPv6() - } + return &proxyResponse{ + answer: answer, + upstream: upstreamConfig.Endpoint, + } +} + +// tryUpstreams attempts to proxy a DNS request through the provided upstreams and their configurations sequentially. +// It returns a successful proxyResponse if any upstream processes the request successfully, or nil otherwise. +// The function supports "serve stale" for cache by utilizing cached responses when upstreams fail. 
+func (p *prog) tryUpstreams(ctx context.Context, req *proxyRequest, upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) *proxyResponse { + serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale + req.upstreamConfigs = upstreamConfigs + + ctrld.Log(ctx, p.Debug(), "Trying %d upstreams", len(upstreamConfigs)) + + for n, upstreamConfig := range upstreamConfigs { + last := n == len(upstreamConfigs)-1 + ctrld.Log(ctx, p.Debug(), "Processing upstream %d/%d: %s", n+1, len(upstreamConfigs), upstreams[n]) + + if res := p.processUpstream(ctx, req, upstreams[n], upstreamConfig, serveStaleCache, last); res != nil { + ctrld.Log(ctx, p.Debug(), "Upstream %s succeeded", upstreams[n]) + return res } + ctrld.Log(ctx, p.Debug(), "Upstream %s failed", upstreams[n]) + } + + ctrld.Log(ctx, p.Debug(), "All upstreams failed") + return nil +} + +// processUpstream proxies a DNS query to a given upstream server and processes the response based on the provided configuration. +// It supports serving stale cache when upstream queries fail, and checks if processing should continue to another upstream. +// Returns a proxyResponse on success or nil if the upstream query fails or processing conditions are not met. +func (p *prog) processUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig, serveStaleCache, lastUpstream bool) *proxyResponse { + if upstreamConfig == nil { + ctrld.Log(ctx, p.Debug(), "Upstream config is nil, skipping") return nil } - for n, upstreamConfig := range upstreamConfigs { - if upstreamConfig == nil { - continue - } - logger := mainLog.Load().Debug(). + if p.isLoop(upstreamConfig) { + logger := p.Debug(). Str("upstream", upstreamConfig.String()). Str("query", req.msg.Question[0].Name). - Bool("is_ad_query", p.isAdDomainQuery(req.msg)). 
- Bool("is_lan_query", isLanOrPtrQuery) + Bool("is_lan_query", req.isLanOrPtrQuery) + ctrld.Log(ctx, logger, "DNS loop detected") + return nil + } - if p.isLoop(upstreamConfig) { - ctrld.Log(ctx, logger, "DNS loop detected") - continue - } - answer := resolve(upstreams[n], upstreamConfig, req.msg) - if answer == nil { - if serveStaleCache && staleAnswer != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response") - now := time.Now() - setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL)) - res.answer = staleAnswer - res.cached = true - return res - } - continue - } - // We are doing LAN/PTR lookup using private resolver, so always process next one. - // Except for the last, we want to send response instead of saying all upstream failed. - if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { - ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n]) - continue - } - if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) { - ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream") - continue + ctrld.Log(ctx, p.Debug(), "Querying upstream: %s", upstream) + answer := p.queryUpstream(ctx, req, upstream, upstreamConfig) + if answer == nil { + ctrld.Log(ctx, p.Debug(), "Upstream query failed") + if serveStaleCache && req.staleAnswer != nil { + ctrld.Log(ctx, p.Debug(), "Serving stale response due to upstream failure") + return p.serveStaleResponse(ctx, req.staleAnswer) } + return nil + } - // set compression, as it is not set by default when unpacking - answer.Compress = true + ctrld.Log(ctx, p.Debug(), "Upstream query successful") + if p.shouldContinueWithNextUpstream(ctx, req, answer, upstream, lastUpstream) { + return nil + } + return p.prepareSuccessResponse(ctx, req, answer, upstream, upstreamConfig) +} - if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { - ttl 
:= ttlFromMsg(answer) - now := time.Now() - expired := now.Add(time.Duration(ttl) * time.Second) - if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 { - expired = now.Add(time.Duration(cachedTTL) * time.Second) - } - setCachedAnswerTTL(answer, now, expired) - p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired)) - ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response") +// queryUpstream sends a DNS query to a specified upstream using its configuration and handles errors and retries. +func (p *prog) queryUpstream(ctx context.Context, req *proxyRequest, upstream string, upstreamConfig *ctrld.UpstreamConfig) *dns.Msg { + if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { + ctrld.Log(ctx, p.Debug(), "Adding client info to upstream query") + ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) + } + + ctrld.Log(ctx, p.Debug(), "Sending query to %s: %s", upstream, upstreamConfig.Name) + dnsResolver, err := ctrld.NewResolver(ctx, upstreamConfig) + if err != nil { + ctrld.Log(ctx, p.Error().Err(err), "Failed to create resolver") + return nil + } + + resolveCtx, cancel := upstreamConfig.Context(ctx) + defer cancel() + + ctrld.Log(ctx, p.Debug(), "Resolving query with upstream") + answer, err := dnsResolver.Resolve(resolveCtx, req.msg) + if answer != nil { + ctrld.Log(ctx, p.Debug(), "Upstream resolution successful") + p.um.mu.Lock() + p.um.failureReq[upstream] = 0 + p.um.down[upstream] = false + p.um.mu.Unlock() + return answer + } + + ctrld.Log(ctx, p.Error().Err(err), "Failed to resolve query") + // Increasing the failure count when there is no answer regardless of what kind of error we get + p.um.increaseFailureCount(upstream) + if err != nil { + // For timeout error (i.e: context deadline exceed), force re-bootstrapping. 
+ var e net.Error + if errors.As(err, &e) && e.Timeout() { + ctrld.Log(ctx, p.Debug(), "Timeout error, forcing re-bootstrapping") + upstreamConfig.ReBootstrap(ctx) } - hostname := "" - if req.ci != nil { - hostname = req.ci.Hostname + // For network error, turn ipv6 off if enabled. + if ctrld.HasIPv6(ctx) && (errUrlNetworkError(err) || errNetworkError(err)) { + ctrld.Log(ctx, p.Debug(), "Network error, disabling IPv6") + ctrld.DisableIPv6(ctx) } - ctrld.Log(ctx, mainLog.Load().Info(), "REPLY: %s -> %s (%s): %s", upstreams[n], req.ufr.srcAddr, hostname, dns.RcodeToString[answer.Rcode]) - res.answer = answer - res.upstream = upstreamConfig.Endpoint - return res } - ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) + return nil +} - // if we have no healthy upstreams, trigger recovery flow - if p.leakOnUpstreamFailure() { - if p.um.countHealthy(upstreams) == 0 { - p.recoveryCancelMu.Lock() - if p.recoveryCancel == nil { - var reason RecoveryReason - if upstreams[0] == upstreamOS { - reason = RecoveryReasonOSFailure - } else { - reason = RecoveryReasonRegularFailure - } - mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) - go p.handleRecovery(reason) - } else { - mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") - } - p.recoveryCancelMu.Unlock() - } else { - mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") - } +// triggerRecovery attempts to initiate a recovery process if no healthy upstreams are detected. +// If "isOSFailure" is true, the recovery will account for an operating system failure. +// Logs are generated to indicate whether recovery is triggered or already in progress. 
+func (p *prog) triggerRecovery(isOSFailure bool) { + p.recoveryCancelMu.Lock() + defer p.recoveryCancelMu.Unlock() - // attempt query to OS resolver while as a retry catch all - // we dont want this to happen if leakOnUpstreamFailure is false - if upstreams[0] != upstreamOS { - ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") - answer := resolve(upstreamOS, osUpstreamConfig, req.msg) - if answer != nil { - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") - res.answer = answer - res.upstream = osUpstreamConfig.Endpoint - return res - } - ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed") + if p.recoveryCancel == nil { + var reason RecoveryReason + if isOSFailure { + reason = RecoveryReasonOSFailure + } else { + reason = RecoveryReasonRegularFailure } + p.Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) + go p.handleRecovery(reason) + } else { + p.Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") } +} - answer := new(dns.Msg) - answer.SetRcode(req.msg, dns.RcodeServerFailure) - res.answer = answer - return res +// tryOSResolver attempts to query the OS resolver as a fallback mechanism when other upstreams fail. +// Logs success or failure of the query attempt and returns a proxyResponse or nil based on query result. 
+func (p *prog) tryOSResolver(ctx context.Context, req *proxyRequest) *proxyResponse { + ctrld.Log(ctx, p.Debug(), "Attempting query to OS resolver as a retry catch all") + answer := p.queryUpstream(ctx, req, upstreamOS, osUpstreamConfig) + if answer != nil { + ctrld.Log(ctx, p.Debug(), "OS resolver retry query successful") + return &proxyResponse{answer: answer, upstream: osUpstreamConfig.Endpoint} + } + ctrld.Log(ctx, p.Debug(), "OS resolver retry query failed") + return nil } +// upstreamsAndUpstreamConfigForPtr returns the updated upstreams and upstreamConfigs for a private PTR lookup scenario. func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { if len(p.localUpstreams) > 0 { tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams)) @@ -667,6 +922,7 @@ func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConf return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...) } +// upstreamConfigsFromUpstreamNumbers converts a list of upstream names into their corresponding UpstreamConfig objects. func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { @@ -676,14 +932,6 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U return upstreamConfigs } -func (p *prog) isAdDomainQuery(msg *dns.Msg) bool { - if p.adDomain == "" { - return false - } - cDomainName := canonicalName(msg.Question[0].Name) - return dns.IsSubDomain(p.adDomain, cDomainName) -} - // canonicalName returns canonical name from FQDN with "." trimmed. 
func canonicalName(fqdn string) string { q := strings.TrimSpace(fqdn) @@ -720,10 +968,12 @@ func wildcardMatches(wildcard, str string) bool { return false } +// fmtRemoteToLocal formats a remote address to indicate its mapping to a local listener using listener number and hostname. func fmtRemoteToLocal(listenerNum, hostname, remote string) string { return fmt.Sprintf("%s (%s) -> listener.%s", remote, hostname, listenerNum) } +// requestID generates a random 6-character hexadecimal string to uniquely identify a request. It panics on error. func requestID() string { b := make([]byte, 3) // 6 chars if _, err := rand.Read(b); err != nil { @@ -732,15 +982,7 @@ func requestID() string { return hex.EncodeToString(b) } -func containRcode(rcodes []int, rcode int) bool { - for i := range rcodes { - if rcodes[i] == rcode { - return true - } - } - return false -} - +// setCachedAnswerTTL updates the TTL of each DNS record in the provided message based on the current and expiration times. func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) { ttlSecs := expiredTime.Sub(now).Seconds() if ttlSecs < 0 { @@ -761,6 +1003,8 @@ func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) { } } +// ttlFromMsg extracts and returns the TTL value from the first record in the Answer or Ns sections of a DNS message. +// If no records exist in either section, the function returns 0. func ttlFromMsg(msg *dns.Msg) uint32 { for _, rr := range msg.Answer { return rr.Header().Ttl @@ -771,6 +1015,7 @@ func ttlFromMsg(msg *dns.Msg) uint32 { return 0 } +// needLocalIPv6Listener checks if a local IPv6 listener is required on Windows by verifying IPv6 support and the OS type. func needLocalIPv6Listener() bool { // On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can // listen on ::1, then spawn a listener for receiving DNS requests. 
@@ -842,6 +1087,8 @@ func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr { // // It's the caller responsibility to call Shutdown to close the server. func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) { + mainLog.Load().Debug().Str("address", addr).Str("network", network).Msg("Starting DNS server") + s := &dns.Server{ Addr: addr, Net: network, @@ -856,11 +1103,12 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha defer close(errCh) if err := s.ListenAndServe(); err != nil { s.NotifyStartedFunc() - mainLog.Load().Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr) + mainLog.Load().Error().Err(err).Msgf("Could not listen and serve on: %s", s.Addr) errCh <- err } }() <-startedCh + mainLog.Load().Debug().Str("address", addr).Str("network", network).Msg("DNS server started successfully") return s, errCh } @@ -940,8 +1188,9 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) { // - There is only 1 ControlD upstream in-use. // - Number of refused queries seen so far equals to selfUninstallMaxQueries. // - The cdUID is deleted. 
-func (p *prog) doSelfUninstall(answer *dns.Msg) { - if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused { +func (p *prog) doSelfUninstall(pr *proxyResponse) { + answer := pr.answer + if pr.refused || !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused { return } @@ -951,15 +1200,16 @@ func (p *prog) doSelfUninstall(answer *dns.Msg) { return } - logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger() + logger := p.logger.Load().With().Str("mode", "self-uninstall") if p.refusedQueryCount > selfUninstallMaxQueries { p.checkingSelfUninstall = true - _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) - logger.Debug().Msg("maximum number of refused queries reached, checking device status") + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) + _, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev) + logger.Debug().Msg("Maximum number of refused queries reached, checking device status") selfUninstallCheck(err, p, logger) if err != nil { - logger.Warn().Err(err).Msg("could not fetch resolver config") + logger.Warn().Err(err).Msg("Could not fetch resolver config") } // Cool-of period to prevent abusing the API. go p.selfUninstallCoolOfPeriod() @@ -1023,7 +1273,7 @@ func (p *prog) queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) regularIPs, loopbackIPs, err := netmon.LocalAddresses() if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not get local addresses") + p.Warn().Err(err).Msg("Could not get local addresses") return false } for _, localIP := range slices.Concat(regularIPs, loopbackIPs) { @@ -1143,7 +1393,8 @@ func isWanClient(na net.Addr) bool { // resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller. 
func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg { - ctrld.Log(ctx, mainLog.Load().Debug(), "internal domain test query") + logger := ctrld.LoggerFromCtx(ctx) + ctrld.Log(ctx, logger.Debug(), "Internal domain test query") q := m.Question[0] answer := new(dns.Msg) @@ -1181,10 +1432,10 @@ func FlushDNSCache() error { } // monitorNetworkChanges starts monitoring for network interface changes -func (p *prog) monitorNetworkChanges() error { +func (p *prog) monitorNetworkChanges(ctx context.Context) error { mon, err := netmon.New(func(format string, args ...any) { // Always fetch the latest logger (and inject the prefix) - mainLog.Load().Printf("netmon: "+format, args...) + p.logger.Load().Printf("netmon: "+format, args...) }) if err != nil { return fmt.Errorf("creating network monitor: %w", err) @@ -1192,11 +1443,11 @@ func (p *prog) monitorNetworkChanges() error { mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { // Get map of valid interfaces - validIfaces := validInterfacesMap() + validIfaces := ctrld.ValidInterfaces(ctrld.LoggerCtx(ctx, p.logger.Load())) isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New) - mainLog.Load().Debug(). + p.Debug(). Interface("old_state", delta.Old). Interface("new_state", delta.New). Bool("is_major_change", isMajorChange). @@ -1224,7 +1475,7 @@ func (p *prog) monitorNetworkChanges() error { if newIface.IsUp() && len(usableNewIPs) > 0 { changed = true changeIPs = usableNewIPs - mainLog.Load().Debug(). + p.Debug(). Str("interface", ifaceName). Interface("new_ips", usableNewIPs). Msg("Interface newly appeared (was not present in old state)") @@ -1246,7 +1497,7 @@ func (p *prog) monitorNetworkChanges() error { if newIface.IsUp() && len(usableNewIPs) > 0 { changed = true changeIPs = usableNewIPs - mainLog.Load().Debug(). + p.Debug(). Str("interface", ifaceName). Interface("old_ips", oldIPs). Interface("new_ips", usableNewIPs). 
@@ -1259,39 +1510,39 @@ func (p *prog) monitorNetworkChanges() error { // if the default route changed, set changed to true if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface { changed = true - mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface) + p.Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface) } if !changed { - mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") + p.Debug().Msg("Ignoring interface change - no valid interfaces affected") // check if the default IPs are still on an interface that is up ValidateDefaultLocalIPsFromDelta(delta.New) return } if !activeInterfaceExists { - mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") + p.Debug().Msg("No active interfaces found, skipping reinitialization") return } // Get IPs from default route interface in new state - selfIP := defaultRouteIP() + selfIP := p.defaultRouteIP() // Ensure that selfIP is an IPv4 address. 
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil { - mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP) + p.Debug().Msgf("DefaultRouteIP returned a non-ipv4 address: %s, ignoring it", selfIP) selfIP = "" } var ipv6 string if delta.New.DefaultRouteInterface != "" { - mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) + p.Debug().Msgf("Default route interface: %s, ips: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] { ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() if selfIP == "" && addr.Is4() { - mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + p.Debug().Msgf("Checking ip: %s", addr.String()) if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { selfIP = addr.String() } @@ -1302,12 +1553,12 @@ func (p *prog) monitorNetworkChanges() error { } } else { // If no default route interface is set yet, use the changed IPs - mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs) + p.Debug().Msgf("No default route interface found, using changed ips: %v", changeIPs) for _, ip := range changeIPs { ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() if selfIP == "" && addr.Is4() { - mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + p.Debug().Msgf("Checking ip: %s", addr.String()) if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { selfIP = addr.String() } @@ -1320,24 +1571,21 @@ func (p *prog) monitorNetworkChanges() error { // Only set the IPv4 default if selfIP is a valid IPv4 address. 
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil { - ctrld.SetDefaultLocalIPv4(ip) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(ctx, p.logger.Load()), ip) if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) } } if ip := net.ParseIP(ipv6); ip != nil { - ctrld.SetDefaultLocalIPv6(ip) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(ctx, p.logger.Load()), ip) } - mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) + p.Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) - // we only trigger recovery flow for network changes on non router devices - if router.Name() == "" { - p.handleRecovery(RecoveryReasonNetworkChange) - } + p.handleRecovery(RecoveryReasonNetworkChange) }) mon.Start() - mainLog.Load().Debug().Msg("Network monitor started") + p.Debug().Msg("Network monitor started") return nil } @@ -1392,11 +1640,11 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool { // checkUpstreamOnce sends a test query to the specified upstream. // Returns nil if the upstream responds successfully. 
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error { - mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream) + p.Debug().Msgf("Starting check for upstream: %s", upstream) - resolver, err := ctrld.NewResolver(uc) + resolver, err := ctrld.NewResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), uc) if err != nil { - mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) + p.Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) return err } @@ -1404,13 +1652,13 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro if uc.Timeout > 0 { timeout = time.Millisecond * time.Duration(uc.Timeout) } - mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) + p.Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - uc.ReBootstrap() - mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) + uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load())) + p.Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) start := time.Now() msg := uc.VerifyMsg() @@ -1418,89 +1666,138 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro duration := time.Since(start) if err != nil { - mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) + p.Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) } else { - mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) + p.Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) } return err } -// handleRecovery performs a unified recovery by removing DNS settings, -// canceling existing recovery checks for network changes, but coalescing duplicate -// upstream failure recoveries, waiting for recovery to 
complete (using a cancellable context without timeout), -// and then re-applying the DNS settings. +// handleRecovery orchestrates the recovery process by coordinating multiple smaller methods. +// It handles recovery cancellation logic, creates recovery context, prepares the system, +// waits for upstream recovery with timeout, and completes the recovery process. +// The method is designed to be called from a goroutine and handles different recovery reasons +// (network changes, regular failures, OS failures) with appropriate logic for each. func (p *prog) handleRecovery(reason RecoveryReason) { - mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings") + p.Debug().Msg("Starting recovery process: removing DNS settings") + + // Handle recovery cancellation based on reason + if !p.shouldStartRecovery(reason) { + return + } + + // Create recovery context and cleanup function + recoveryCtx, cleanup := p.createRecoveryContext() + defer cleanup() + + // Remove DNS settings and prepare for recovery + if err := p.prepareForRecovery(reason); err != nil { + p.Error().Err(err).Msg("Failed to prepare for recovery") + return + } + + // Build upstream map based on the recovery reason + upstreams := p.buildRecoveryUpstreams(reason) + + // Wait for upstream recovery + recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams) + if err != nil { + p.Error().Err(err).Msg("Recovery failed; DNS settings remain removed") + return + } + + // Complete recovery process + if err := p.completeRecovery(reason, recovered); err != nil { + p.Error().Err(err).Msg("Failed to complete recovery") + return + } + + p.Info().Msgf("Recovery completed successfully for upstream %q", recovered) +} + +// shouldStartRecovery determines if recovery should start based on the reason and current state. +// Returns true if recovery should proceed, false otherwise. 
+func (p *prog) shouldStartRecovery(reason RecoveryReason) bool { + p.recoveryCancelMu.Lock() + defer p.recoveryCancelMu.Unlock() - // For network changes, cancel any existing recovery check because the network state has changed. if reason == RecoveryReasonNetworkChange { - p.recoveryCancelMu.Lock() + // For network changes, cancel any existing recovery check because the network state has changed. if p.recoveryCancel != nil { - mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)") + p.Debug().Msg("Cancelling existing recovery check (network change)") p.recoveryCancel() p.recoveryCancel = nil } - p.recoveryCancelMu.Unlock() - } else { - // For upstream failures, if a recovery is already in progress, do nothing new. - p.recoveryCancelMu.Lock() - if p.recoveryCancel != nil { - mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") - p.recoveryCancelMu.Unlock() - return - } - p.recoveryCancelMu.Unlock() + return true } - // Create a new recovery context without a fixed timeout. + // For upstream failures, if a recovery is already in progress, do nothing new. + if p.recoveryCancel != nil { + p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") + return false + } + + return true +} + +// createRecoveryContext creates a new recovery context and returns it along with a cleanup function. +func (p *prog) createRecoveryContext() (context.Context, func()) { p.recoveryCancelMu.Lock() recoveryCtx, cancel := context.WithCancel(context.Background()) p.recoveryCancel = cancel p.recoveryCancelMu.Unlock() - // Immediately remove our DNS settings from the interface. 
- // set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface + cleanup := func() { + p.recoveryCancelMu.Lock() + p.recoveryCancel = nil + p.recoveryCancelMu.Unlock() + } + + return recoveryCtx, cleanup +} + +// prepareForRecovery removes DNS settings and initializes OS resolver if needed. +func (p *prog) prepareForRecovery(reason RecoveryReason) error { + // Set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface p.recoveryRunning.Store(true) - // we do not want to restore any static DNS settings + + // Remove DNS settings - we do not want to restore any static DNS settings // we must try to get the DHCP values, any static DNS settings // will be appended to nameservers from the saved interface values p.resetDNS(false, false) // For an OS failure, reinitialize OS resolver nameservers immediately. if reason == RecoveryReasonOSFailure { - mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") - ns := ctrld.InitializeOsResolver(true) - if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") - } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + if err := p.reinitializeOSResolver("OS resolver failure detected"); err != nil { + return fmt.Errorf("failed to reinitialize OS resolver: %w", err) } } - // Build upstream map based on the recovery reason. - upstreams := p.buildRecoveryUpstreams(reason) + return nil +} - // Wait indefinitely until one of the upstreams recovers. - recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams) - if err != nil { - mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed") - p.recoveryCancelMu.Lock() - p.recoveryCancel = nil - p.recoveryCancelMu.Unlock() - return +// reinitializeOSResolver reinitializes the OS resolver and logs the results. 
+func (p *prog) reinitializeOSResolver(message string) error { + p.Debug().Msg(message) + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) + ns := ctrld.InitializeOsResolver(loggerCtx, true) + if len(ns) == 0 { + p.Warn().Msg("No nameservers found for OS resolver; using existing values") + } else { + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } - mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) + return nil +} - // reset the upstream failure count and down state +// completeRecovery completes the recovery process by resetting upstream state and reapplying DNS settings. +func (p *prog) completeRecovery(reason RecoveryReason, recovered string) error { + // Reset the upstream failure count and down state p.um.reset(recovered) // For network changes we also reinitialize the OS resolver. if reason == RecoveryReasonNetworkChange { - ns := ctrld.InitializeOsResolver(true) - if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") - } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + if err := p.reinitializeOSResolver("Network change detected during recovery"); err != nil { + return fmt.Errorf("failed to reinitialize OS resolver during network change: %w", err) } } @@ -1508,13 +1805,10 @@ func (p *prog) handleRecovery(reason RecoveryReason) { p.setDNS() p.logInterfacesState() - // allow watchdogs to put the listener back on the interface if its changed for any reason + // Allow watchdogs to put the listener back on the interface if it's changed for any reason p.recoveryRunning.Store(false) - // Clear the recovery cancellation for a clean slate. - p.recoveryCancelMu.Lock() - p.recoveryCancel = nil - p.recoveryCancelMu.Unlock() + return nil } // waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers. 
@@ -1523,44 +1817,44 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string recoveredCh := make(chan string, 1) var wg sync.WaitGroup - mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) + p.Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) for name, uc := range upstreams { wg.Add(1) go func(name string, uc *ctrld.UpstreamConfig) { defer wg.Done() - mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name) + p.Debug().Msgf("Starting recovery check loop for upstream: %s", name) attempts := 0 for { select { case <-ctx.Done(): - mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name) + p.Debug().Msgf("Context canceled for upstream %s", name) return default: attempts++ // checkUpstreamOnce will reset any failure counters on success. if err := p.checkUpstreamOnce(name, uc); err == nil { - mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name) + p.Debug().Msgf("Upstream %s recovered successfully", name) select { case recoveredCh <- name: - mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name) + p.Debug().Msgf("Sent recovery notification for upstream %s", name) default: - mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered") + p.Debug().Msg("Recovery channel full, another upstream already recovered") } return } - mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name) + p.Debug().Msgf("Upstream %s check failed, sleeping before retry", name) time.Sleep(checkUpstreamBackoffSleep) // if this is the upstreamOS and it's the 3rd attempt (or multiple of 3), // we should try to reinit the OS resolver to ensure we can recover if name == upstreamOS && attempts%3 == 0 { - mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) - ns := ctrld.InitializeOsResolver(true) + 
p.Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts) + ns := ctrld.InitializeOsResolver(ctrld.LoggerCtx(ctx, p.logger.Load()), true) if len(ns) == 0 { - mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + p.Warn().Msg("No nameservers found for OS resolver; using existing values") } else { - mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) } } } @@ -1616,12 +1910,12 @@ func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) { // Check if the default IPv4 is still active. if currentIPv4 != nil && !activeIPs[currentIPv4.String()] { mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4) - ctrld.SetDefaultLocalIPv4(nil) + ctrld.SetDefaultLocalIPv4(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil) } // Check if the default IPv6 is still active. if currentIPv6 != nil && !activeIPs[currentIPv6.String()] { mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. 
Resetting.", currentIPv6) - ctrld.SetDefaultLocalIPv6(nil) + ctrld.SetDefaultLocalIPv6(ctrld.LoggerCtx(context.Background(), mainLog.Load()), nil) } } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 4a4e5b4e..6f5f7f05 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -77,7 +77,8 @@ func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false) p := &prog{cfg: cfg} - p.um = newUpstreamMonitor(p.cfg) + p.logger.Store(mainLog.Load()) + p.um = newUpstreamMonitor(p.cfg, mainLog.Load()) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() for _, nc := range p.cfg.Network { @@ -142,9 +143,94 @@ func Test_prog_upstreamFor(t *testing.T) { } } +func Test_prog_upstreamForWithCustomMatching(t *testing.T) { + cfg := testhelper.SampleConfig(t) + prog := &prog{cfg: cfg} + prog.logger.Store(mainLog.Load()) + for _, nc := range prog.cfg.Network { + for _, cidr := range nc.Cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatal(err) + } + nc.IPNets = append(nc.IPNets, ipNet) + } + } + + // Create a custom policy with domain-first matching order + customPolicy := &ctrld.ListenerPolicyConfig{ + Name: "Custom Policy", + Networks: []ctrld.Rule{ + {"network.0": []string{"upstream.1", "upstream.0"}}, + }, + Macs: []ctrld.Rule{ + {"14:45:A0:67:83:0A": []string{"upstream.2"}}, + }, + Rules: []ctrld.Rule{ + {"*.ru": []string{"upstream.1"}}, + }, + Matching: &ctrld.MatchingConfig{ + Order: []string{"domain", "mac", "network"}, + }, + } + + customListener := &ctrld.ListenerConfig{ + Policy: customPolicy, + } + + tests := []struct { + name string + ip string + mac string + domain string + upstreams []string + matched bool + }{ + { + name: "Domain rule should match first with custom order", + ip: "192.168.0.1:0", + mac: "14:45:A0:67:83:0A", + domain: "example.ru", + upstreams: []string{"upstream.1"}, + matched: true, + }, + { + 
name: "MAC rule should match when no domain rule", + ip: "192.168.0.1:0", + mac: "14:45:A0:67:83:0A", + domain: "example.com", + upstreams: []string{"upstream.2"}, + matched: true, + }, + { + name: "Network rule should match when no domain or MAC rule", + ip: "192.168.0.1:0", + mac: "00:11:22:33:44:55", + domain: "example.com", + upstreams: []string{"upstream.1", "upstream.0"}, + matched: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", tc.ip) + require.NoError(t, err) + require.NotNil(t, addr) + + ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID()) + ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain) + + assert.Equal(t, tc.matched, ufr.matched) + assert.Equal(t, tc.upstreams, ufr.upstreams) + }) + } +} + func TestCache(t *testing.T) { cfg := testhelper.SampleConfig(t) prog := &prog{cfg: cfg} + prog.logger.Store(mainLog.Load()) for _, nc := range prog.cfg.Network { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) @@ -464,3 +550,254 @@ func Test_isWanClient(t *testing.T) { }) } } + +func Test_shouldStartRecovery(t *testing.T) { + tests := []struct { + name string + reason RecoveryReason + hasExistingRecovery bool + expectedResult bool + description string + }{ + { + name: "network change with existing recovery", + reason: RecoveryReasonNetworkChange, + hasExistingRecovery: true, + expectedResult: true, + description: "should cancel existing recovery and start new one for network change", + }, + { + name: "network change without existing recovery", + reason: RecoveryReasonNetworkChange, + hasExistingRecovery: false, + expectedResult: true, + description: "should start new recovery for network change", + }, + { + name: "regular failure with existing recovery", + reason: RecoveryReasonRegularFailure, + hasExistingRecovery: true, + expectedResult: false, + description: "should skip duplicate recovery for regular failure", + 
}, + { + name: "regular failure without existing recovery", + reason: RecoveryReasonRegularFailure, + hasExistingRecovery: false, + expectedResult: true, + description: "should start new recovery for regular failure", + }, + { + name: "OS failure with existing recovery", + reason: RecoveryReasonOSFailure, + hasExistingRecovery: true, + expectedResult: false, + description: "should skip duplicate recovery for OS failure", + }, + { + name: "OS failure without existing recovery", + reason: RecoveryReasonOSFailure, + hasExistingRecovery: false, + expectedResult: true, + description: "should start new recovery for OS failure", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + p := newTestProg(t) + + // Setup existing recovery if needed + if tc.hasExistingRecovery { + p.recoveryCancelMu.Lock() + p.recoveryCancel = func() {} // Mock cancel function + p.recoveryCancelMu.Unlock() + } + + result := p.shouldStartRecovery(tc.reason) + assert.Equal(t, tc.expectedResult, result, tc.description) + }) + } +} + +func Test_createRecoveryContext(t *testing.T) { + p := newTestProg(t) + + ctx, cleanup := p.createRecoveryContext() + + // Verify context is created + assert.NotNil(t, ctx) + assert.NotNil(t, cleanup) + + // Verify recoveryCancel is set + p.recoveryCancelMu.Lock() + assert.NotNil(t, p.recoveryCancel) + p.recoveryCancelMu.Unlock() + + // Test cleanup function + cleanup() + + // Verify recoveryCancel is cleared + p.recoveryCancelMu.Lock() + assert.Nil(t, p.recoveryCancel) + p.recoveryCancelMu.Unlock() +} + +func Test_prepareForRecovery(t *testing.T) { + tests := []struct { + name string + reason RecoveryReason + wantErr bool + }{ + { + name: "regular failure", + reason: RecoveryReasonRegularFailure, + wantErr: false, + }, + { + name: "network change", + reason: RecoveryReasonNetworkChange, + wantErr: false, + }, + { + name: "OS failure", + reason: RecoveryReasonOSFailure, + wantErr: false, + }, + } + + for _, tc := range tests { + tc 
:= tc + t.Run(tc.name, func(t *testing.T) { + p := newTestProg(t) + + err := p.prepareForRecovery(tc.reason) + + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Verify recoveryRunning is set to true + assert.True(t, p.recoveryRunning.Load()) + }) + } +} + +func Test_completeRecovery(t *testing.T) { + tests := []struct { + name string + reason RecoveryReason + recovered string + wantErr bool + }{ + { + name: "regular failure recovery", + reason: RecoveryReasonRegularFailure, + recovered: "upstream1", + wantErr: false, + }, + { + name: "network change recovery", + reason: RecoveryReasonNetworkChange, + recovered: "upstream2", + wantErr: false, + }, + { + name: "OS failure recovery", + reason: RecoveryReasonOSFailure, + recovered: "upstream3", + wantErr: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + p := newTestProg(t) + + err := p.completeRecovery(tc.reason, tc.recovered) + + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Verify recoveryRunning is set to false + assert.False(t, p.recoveryRunning.Load()) + }) + } +} + +func Test_reinitializeOSResolver(t *testing.T) { + p := newTestProg(t) + + err := p.reinitializeOSResolver("Test message") + + // This function should not return an error under normal circumstances + // The actual behavior depends on the OS resolver implementation + assert.NoError(t, err) +} + +func Test_handleRecovery_Integration(t *testing.T) { + tests := []struct { + name string + reason RecoveryReason + wantErr bool + }{ + { + name: "network change recovery", + reason: RecoveryReasonNetworkChange, + wantErr: false, + }, + { + name: "regular failure recovery", + reason: RecoveryReasonRegularFailure, + wantErr: false, + }, + { + name: "OS failure recovery", + reason: RecoveryReasonOSFailure, + wantErr: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + p := newTestProg(t) + + // 
This is an integration test that exercises the full recovery flow + // In a real test environment, you would mock the dependencies + // For now, we're just testing that the method doesn't panic + // and that the recovery logic flows correctly + assert.NotPanics(t, func() { + // Test only the preparation phase to avoid actual upstream checking + if !p.shouldStartRecovery(tc.reason) { + return + } + + _, cleanup := p.createRecoveryContext() + defer cleanup() + + if err := p.prepareForRecovery(tc.reason); err != nil { + return + } + + // Skip the actual upstream recovery check for this test + // as it requires properly configured upstreams + }) + }) + } +} + +// newTestProg creates a properly initialized *prog for testing. +func newTestProg(t *testing.T) *prog { + p := &prog{cfg: testhelper.SampleConfig(t)} + p.logger.Store(mainLog.Load()) + p.um = newUpstreamMonitor(p.cfg, mainLog.Load()) + return p +} diff --git a/cmd/cli/hostname.go b/cmd/cli/hostname.go index d28435db..5b091c29 100644 --- a/cmd/cli/hostname.go +++ b/cmd/cli/hostname.go @@ -4,11 +4,15 @@ import "regexp" // validHostname reports whether hostname is a valid hostname. // A valid hostname contains 3 -> 64 characters and conform to RFC1123. 
+// This function validates hostnames to ensure they meet DNS naming standards +// and prevents invalid hostnames from being used in DNS configurations func validHostname(hostname string) bool { hostnameLen := len(hostname) if hostnameLen < 3 || hostnameLen > 64 { return false } + // RFC1123 regex pattern ensures hostnames follow DNS naming conventions + // This prevents issues with DNS resolution and system compatibility validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`) return validHostnameRfc1123.MatchString(hostname) } diff --git a/cmd/cli/http_log.go b/cmd/cli/http_log.go new file mode 100644 index 00000000..c794cf00 --- /dev/null +++ b/cmd/cli/http_log.go @@ -0,0 +1,172 @@ +package cli + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "sync" +) + +// HTTP log server endpoint constants +const ( + httpLogEndpointPing = "/ping" + httpLogEndpointLogs = "/logs" + httpLogEndpointExit = "/exit" +) + +// httpLogClient sends logs to an HTTP server via POST requests. +// This replaces the logConn functionality with HTTP-based communication. 
+type httpLogClient struct { + baseURL string + client *http.Client +} + +// newHTTPLogClient creates a new HTTP log client +func newHTTPLogClient(sockPath string) *httpLogClient { + return &httpLogClient{ + baseURL: "http://unix", + client: &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + }, + } +} + +// Write sends log data to the HTTP server via POST request +func (hlc *httpLogClient) Write(b []byte) (int, error) { + // Send log data via HTTP POST to /logs endpoint + resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointLogs, "text/plain", bytes.NewReader(b)) + if err != nil { + // Ignore errors to prevent log pollution, just like the original logConn + return len(b), nil + } + resp.Body.Close() + return len(b), nil +} + +// Ping tests if the HTTP log server is available +func (hlc *httpLogClient) Ping() error { + resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointPing) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +// Close sends exit signal to the HTTP server +func (hlc *httpLogClient) Close() error { + // Send exit signal via HTTP POST with empty body + resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{})) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +// GetLogs retrieves all collected logs from the HTTP server +func (hlc *httpLogClient) GetLogs() ([]byte, error) { + resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointLogs) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNoContent { + return []byte{}, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// httpLogServer starts an HTTP server listening on unix socket to collect logs from runCmd. 
+func httpLogServer(sockPath string, stopLogCh chan struct{}) error { + addr, err := net.ResolveUnixAddr("unix", sockPath) + if err != nil { + return fmt.Errorf("invalid log sock path: %w", err) + } + + ln, err := net.ListenUnix("unix", addr) + if err != nil { + return fmt.Errorf("could not listen log socket: %w", err) + } + defer ln.Close() + + // Create a log writer to store all logs + logWriter := newLogWriter() + + // Use a sync.Once to ensure channel is only closed once + var channelClosed sync.Once + + mux := http.NewServeMux() + mux.HandleFunc(httpLogEndpointPing, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + w.WriteHeader(http.StatusOK) + }) + + mux.HandleFunc(httpLogEndpointLogs, func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + // POST /logs - Store log data + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + + // Store log data in log writer + logWriter.Write(body) + + w.WriteHeader(http.StatusOK) + + case http.MethodGet: + // GET /logs - Retrieve all logs + // Get all logs from the log writer + logWriter.mu.Lock() + logs := logWriter.buf.Bytes() + logWriter.mu.Unlock() + + if len(logs) == 0 { + w.WriteHeader(http.StatusNoContent) + return + } + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write(logs) + + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + + mux.HandleFunc(httpLogEndpointExit, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Close the stop channel to signal completion (only once) + channelClosed.Do(func() { + close(stopLogCh) + }) + w.WriteHeader(http.StatusOK) + }) + + server := &http.Server{Handler: mux} + 
return server.Serve(ln) +} diff --git a/cmd/cli/http_log_test.go b/cmd/cli/http_log_test.go new file mode 100644 index 00000000..ad664d49 --- /dev/null +++ b/cmd/cli/http_log_test.go @@ -0,0 +1,747 @@ +package cli + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "golang.org/x/net/nettest" +) + +func unixDomainSocketPath(t *testing.T) string { + t.Helper() + sockPath, err := nettest.LocalPath() + if err != nil { + t.Fatalf("Failed to create temporary directory: %v", err) + } + return sockPath +} + +func TestHTTPLogServer(t *testing.T) { + sockPath := unixDomainSocketPath(t) + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait a bit for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + t.Run("Ping endpoint", func(t *testing.T) { + resp, err := client.Get("http://unix" + httpLogEndpointPing) + if err != nil { + t.Fatalf("Failed to ping server: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + }) + + t.Run("Ping endpoint wrong method", func(t *testing.T) { + resp, err := client.Post("http://unix"+httpLogEndpointPing, "text/plain", bytes.NewReader([]byte("test"))) + if err != nil { + t.Fatalf("Failed to send POST to ping: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", resp.StatusCode) + } + }) + + t.Run("Log endpoint", func(t *testing.T) { + testLog := "test log message" + resp, err := 
client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(testLog))) + if err != nil { + t.Fatalf("Failed to send log: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if log was stored by retrieving it + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + if !strings.Contains(string(body), testLog) { + t.Errorf("Expected log '%s' not found in stored logs", testLog) + } + }) + + t.Run("Log endpoint wrong method", func(t *testing.T) { + // Test unsupported method (PUT) on /logs endpoint + req, err := http.NewRequest("PUT", "http://unix"+httpLogEndpointLogs, bytes.NewReader([]byte("test"))) + if err != nil { + t.Fatalf("Failed to create PUT request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to send PUT to logs: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", resp.StatusCode) + } + }) + + t.Run("Exit endpoint", func(t *testing.T) { + resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{})) + if err != nil { + t.Fatalf("Failed to send exit: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if channel is closed by trying to read from it + select { + case _, ok := <-stopLogCh: + if ok { + t.Error("Expected channel to be closed, but it's still open") + } + case <-time.After(1 * time.Second): + t.Error("Timeout waiting for 
channel closure") + } + }) + + t.Run("Exit endpoint wrong method", func(t *testing.T) { + resp, err := client.Get("http://unix" + httpLogEndpointExit) + if err != nil { + t.Fatalf("Failed to send GET to exit: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", resp.StatusCode) + } + }) + + t.Run("Multiple log messages", func(t *testing.T) { + logs := []string{"log1", "log2", "log3"} + + for _, log := range logs { + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n"))) + if err != nil { + t.Fatalf("Failed to send log '%s': %v", log, err) + } + resp.Body.Close() + } + + // Check if all logs were stored by retrieving them + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + logContent := string(body) + for i, expectedLog := range logs { + if !strings.Contains(logContent, expectedLog) { + t.Errorf("Log %d: expected '%s' not found in stored logs", i, expectedLog) + } + } + }) + + t.Run("Large log message", func(t *testing.T) { + largeLog := strings.Repeat("a", 1024*10) // 10KB log message + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(largeLog))) + if err != nil { + t.Fatalf("Failed to send large log: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if large log was stored by retrieving it + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer 
logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + if !strings.Contains(string(body), largeLog) { + t.Error("Large log message was not stored correctly") + } + }) + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogServerInvalidSocketPath(t *testing.T) { + // Test with invalid socket path + invalidPath := "/invalid/path/that/does/not/exist.sock" + stopLogCh := make(chan struct{}) + + err := httpLogServer(invalidPath, stopLogCh) + if err == nil { + t.Error("Expected error for invalid socket path") + } + + if !strings.Contains(err.Error(), "could not listen log socket") { + t.Errorf("Expected 'could not listen log socket' error, got: %v", err) + } +} + +func TestHTTPLogServerSocketInUse(t *testing.T) { + // Create a temporary socket path + sockPath := unixDomainSocketPath(t) + defer os.Remove(sockPath) + + // Create the first server + stopLogCh1 := make(chan struct{}) + serverErr1 := make(chan error, 1) + go func() { + serverErr1 <- httpLogServer(sockPath, stopLogCh1) + }() + + // Wait for first server to start + time.Sleep(100 * time.Millisecond) + + // Try to create a second server on the same socket + stopLogCh2 := make(chan struct{}) + err := httpLogServer(sockPath, stopLogCh2) + if err == nil { + t.Error("Expected error when socket is already in use") + } + + if !strings.Contains(err.Error(), "could not listen log socket") { + t.Errorf("Expected 'could not listen log socket' error, got: %v", err) + } +} + +func TestHTTPLogServerConcurrentRequests(t *testing.T) { + // Create a temporary socket path + sockPath := unixDomainSocketPath(t) + defer os.Remove(sockPath) + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, 
stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + // Send concurrent requests + numRequests := 10 + done := make(chan bool, numRequests) + + for i := 0; i < numRequests; i++ { + go func(i int) { + defer func() { done <- true }() + + logMsg := fmt.Sprintf("concurrent log %d", i) + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg))) + if err != nil { + t.Errorf("Failed to send concurrent log %d: %v", i, err) + return + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for request %d, got %d", i, resp.StatusCode) + } + }(i) + } + + // Wait for all requests to complete + for i := 0; i < numRequests; i++ { + select { + case <-done: + // Request completed + case <-time.After(5 * time.Second): + t.Errorf("Timeout waiting for concurrent request %d", i) + } + } + + // Check if all logs were stored by retrieving them + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + logContent := string(body) + // Verify all logs were stored + for i := 0; i < numRequests; i++ { + expectedLog := fmt.Sprintf("concurrent log %d", i) + if !strings.Contains(logContent, expectedLog) { + t.Errorf("Log '%s' was not stored", expectedLog) + } + } +} + +func TestHTTPLogServerErrorHandling(t *testing.T) { + // Create a temporary socket path + sockPath := unixDomainSocketPath(t) + defer os.Remove(sockPath) + + // 
Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + t.Run("Invalid request body", func(t *testing.T) { + // Send a POST with an empty body - the request itself is well-formed at the HTTP level + // The server is expected to accept it rather than return 400 Bad Request + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", strings.NewReader("")) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Empty body should still be processed successfully + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + }) +} + +func BenchmarkHTTPLogServer(b *testing.B) { + // Create a temporary socket path + tmpDir := b.TempDir() + sockPath := filepath.Join(tmpDir, "bench.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + go func() { + httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + // Benchmark log sending + b.ResetTimer() + for i := 0; i < b.N; i++ { + logMsg := fmt.Sprintf("benchmark log %d", i) + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg))) + if err != nil { + b.Fatalf("Failed to send log: %v", err) + } + resp.Body.Close() + } + + // Clean up + 
os.Remove(sockPath) +} + +func TestHTTPLogClient(t *testing.T) { + // Create a temporary socket path + sockPath := unixDomainSocketPath(t) + defer os.Remove(sockPath) + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP log client + client := newHTTPLogClient(sockPath) + + t.Run("Ping server", func(t *testing.T) { + err := client.Ping() + if err != nil { + t.Errorf("Ping failed: %v", err) + } + }) + + t.Run("Write logs", func(t *testing.T) { + testLog := "test log message from client" + n, err := client.Write([]byte(testLog)) + if err != nil { + t.Errorf("Write failed: %v", err) + } + if n != len(testLog) { + t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n) + } + + // Check if log was stored by retrieving it + logs, err := client.GetLogs() + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + + if !strings.Contains(string(logs), testLog) { + t.Errorf("Expected log '%s' not found in stored logs", testLog) + } + }) + + t.Run("Close client", func(t *testing.T) { + err := client.Close() + if err != nil { + t.Errorf("Close failed: %v", err) + } + + // Check if channel is closed (signaling completion) + select { + case _, ok := <-stopLogCh: + if ok { + t.Error("Expected channel to be closed, but it's still open") + } + case <-time.After(1 * time.Second): + t.Error("Timeout waiting for channel closure") + } + }) +} + +func TestHTTPLogClientServerUnavailable(t *testing.T) { + // Create client with non-existent socket + sockPath := "/non/existent/socket.sock" + client := newHTTPLogClient(sockPath) + + t.Run("Ping unavailable server", func(t *testing.T) { + err := client.Ping() + if err == nil { + t.Error("Expected ping to fail for unavailable server") + } + }) + + t.Run("Write to unavailable server", 
func(t *testing.T) { + testLog := "test log message" + n, err := client.Write([]byte(testLog)) + if err != nil { + t.Errorf("Write should not return error (ignores errors): %v", err) + } + if n != len(testLog) { + t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n) + } + }) + + t.Run("Close unavailable server", func(t *testing.T) { + err := client.Close() + if err == nil { + t.Error("Expected close to fail for unavailable server") + } + }) +} + +func BenchmarkHTTPLogClient(b *testing.B) { + // Create a temporary socket path + tmpDir := b.TempDir() + sockPath := filepath.Join(tmpDir, "bench.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + go func() { + httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP log client + client := newHTTPLogClient(sockPath) + + // Benchmark client writes + b.ResetTimer() + for i := 0; i < b.N; i++ { + logMsg := fmt.Sprintf("benchmark write %d", i) + client.Write([]byte(logMsg)) + } + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogServerWithLogWriter(t *testing.T) { + // Create a temporary socket path + sockPath := unixDomainSocketPath(t) + defer os.Remove(sockPath) + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait a bit for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + t.Run("Store and retrieve logs", func(t *testing.T) { + // Send multiple log messages + logs := []string{"log message 1", "log message 2", "log message 3"} + + for _, log := range logs { + resp, err := 
client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n"))) + if err != nil { + t.Fatalf("Failed to send log '%s': %v", log, err) + } + resp.Body.Close() + } + + // Retrieve all logs + resp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read logs response: %v", err) + } + + logContent := string(body) + for _, log := range logs { + if !strings.Contains(logContent, log) { + t.Errorf("Expected log '%s' not found in retrieved logs", log) + } + } + }) + + t.Run("Empty logs endpoint", func(t *testing.T) { + // Create a new server for this test + sockPath2 := unixDomainSocketPath(t) + stopLogCh2 := make(chan struct{}) + + go func() { + httpLogServer(sockPath2, stopLogCh2) + }() + time.Sleep(100 * time.Millisecond) + + client2 := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath2) + }, + }, + } + + resp, err := client2.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + t.Errorf("Expected status 204, got %d", resp.StatusCode) + } + + os.Remove(sockPath2) + }) + + t.Run("Channel closure on exit", func(t *testing.T) { + // Send exit signal + resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{})) + if err != nil { + t.Fatalf("Failed to send exit: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if channel is closed by trying to read from it + select { + case _, ok := <-stopLogCh: 
+ if ok { + t.Error("Expected channel to be closed, but it's still open") + } + case <-time.After(1 * time.Second): + t.Error("Timeout waiting for channel closure") + } + }) +} + +func TestHTTPLogClientGetLogs(t *testing.T) { + // Create a temporary socket path + sockPath := unixDomainSocketPath(t) + defer os.Remove(sockPath) + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + go func() { + httpLogServer(sockPath, stopLogCh) + }() + + // Wait a bit for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP log client + client := newHTTPLogClient(sockPath) + + t.Run("Get logs from client", func(t *testing.T) { + // Send some logs + testLogs := []string{"client log 1", "client log 2", "client log 3"} + for _, log := range testLogs { + client.Write([]byte(log + "\n")) + } + + // Retrieve logs using client method + logs, err := client.GetLogs() + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + + logContent := string(logs) + for _, log := range testLogs { + if !strings.Contains(logContent, log) { + t.Errorf("Expected log '%s' not found in retrieved logs", log) + } + } + }) + + t.Run("Get empty logs", func(t *testing.T) { + // Create a new client for empty logs test + sockPath2 := unixDomainSocketPath(t) + stopLogCh2 := make(chan struct{}) + + go func() { + httpLogServer(sockPath2, stopLogCh2) + }() + time.Sleep(100 * time.Millisecond) + + client2 := newHTTPLogClient(sockPath2) + logs, err := client2.GetLogs() + if err != nil { + t.Fatalf("Failed to get empty logs: %v", err) + } + + if len(logs) != 0 { + t.Errorf("Expected empty logs, got %d bytes", len(logs)) + } + + os.Remove(sockPath2) + }) +} diff --git a/cmd/cli/library.go b/cmd/cli/library.go index 7847dd7f..52474401 100644 --- a/cmd/cli/library.go +++ b/cmd/cli/library.go @@ -9,6 +9,7 @@ import ( // AppCallback provides hooks for injecting certain functionalities // from mobile platforms to main ctrld cli. 
+// This allows mobile applications to customize behavior without modifying core CLI code type AppCallback struct { HostName func() string LanIp func() string @@ -17,6 +18,7 @@ type AppCallback struct { } // AppConfig allows overwriting ctrld cli flags from mobile platforms. +// This provides a clean interface for mobile apps to configure ctrld behavior type AppConfig struct { CdUID string ProvisionID string @@ -27,18 +29,29 @@ type AppConfig struct { LogPath string } +// Network and HTTP configuration constants const ( + // defaultHTTPTimeout provides reasonable timeout for HTTP operations + // This prevents hanging requests while allowing sufficient time for network delays defaultHTTPTimeout = 30 * time.Second - defaultMaxRetries = 3 - downloadServerIp = "23.171.240.151" + + // defaultMaxRetries provides retry attempts for failed HTTP requests + // This improves reliability in unstable network conditions + defaultMaxRetries = 3 + + // downloadServerIp is the fallback IP for download operations + // This ensures downloads work even when DNS resolution fails + downloadServerIp = "23.171.240.151" ) // httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback +// This ensures reliable HTTP operations by preferring IPv4 and handling timeouts gracefully func httpClientWithFallback(timeout time.Duration) *http.Client { return &http.Client{ Timeout: timeout, Transport: &http.Transport{ // Prefer IPv4 over IPv6 + // This improves compatibility with networks that have IPv6 issues DialContext: (&net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, @@ -49,6 +62,7 @@ func httpClientWithFallback(timeout time.Duration) *http.Client { } // doWithRetry performs an HTTP request with retries +// This improves reliability by automatically retrying failed requests with linear backoff func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) { var lastErr error client := 
httpClientWithFallback(defaultHTTPTimeout) @@ -60,7 +74,8 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, } for attempt := 0; attempt < maxRetries; attempt++ { if attempt > 0 { - time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff + // Linear backoff reduces server load and improves success rate + time.Sleep(time.Second * time.Duration(attempt+1)) } resp, err := client.Do(req) @@ -68,8 +83,8 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, return resp, nil } if ipReq != nil { - mainLog.Load().Warn().Err(err).Msgf("dial to %q failed", req.Host) - mainLog.Load().Warn().Msgf("fallback to direct IP to download prod version: %q", ip) + mainLog.Load().Warn().Err(err).Msgf("Dial to %q failed", req.Host) + mainLog.Load().Warn().Msgf("Fallback to direct ip to download prod version: %q", ip) resp, err = client.Do(ipReq) if err == nil { return resp, nil @@ -86,6 +101,7 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, } // Helper for making GET requests with retries +// This provides a simplified interface for common GET operations with built-in retry logic func getWithRetry(url string, ip string) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index ab6b855f..13b3cf3f 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -6,39 +6,105 @@ import ( "fmt" "io" "os" + "regexp" "strings" "sync" "time" - "github.com/rs/zerolog" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" "github.com/Control-D-Inc/ctrld" ) +// Log writer constants for buffer management and log formatting const ( - logWriterSize = 1024 * 1024 * 5 // 5 MB - logWriterSmallSize = 1024 * 1024 * 1 // 1 MB - logWriterInitialSize = 32 * 1024 // 32 KB - logWriterSentInterval = time.Minute + // logWriterSize is the default buffer size for log writers + // This provides sufficient 
space for runtime logs without excessive memory usage + logWriterSize = 1024 * 1024 * 5 // 5 MB + + // logWriterSmallSize is used for memory-constrained environments + // This reduces memory footprint while still maintaining log functionality + logWriterSmallSize = 1024 * 1024 * 1 // 1 MB + + // logWriterInitialSize is the initial buffer allocation + // This provides immediate space for early log entries + logWriterInitialSize = 32 * 1024 // 32 KB + + // logWriterSentInterval controls how often logs are sent to external systems + // This balances real-time logging with system performance + logWriterSentInterval = time.Minute + + // logWriterInitEndMarker marks the end of initialization logs + // This helps separate startup logs from runtime logs logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n" - logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n" + + // logWriterLogEndMarker marks the end of log sections + // This provides clear boundaries for log parsing and analysis + logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n" ) +// Custom level encoders that handle NOTICE level +// Since NOTICE and WARN share the same numeric value (1), we handle them specially +// in the encoder to display NOTICE messages with the "NOTICE" prefix. +// Note: WARN messages will also display as "NOTICE" because they share the same level value. +// This is the intended behavior for visual distinction. 
+ +// noticeLevelEncoder provides custom level encoding for NOTICE level +// This ensures NOTICE messages are clearly distinguished from other log levels +func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { + switch l { + case ctrld.NoticeLevel: + enc.AppendString("NOTICE") + default: + zapcore.CapitalLevelEncoder(l, enc) + } +} + +// noticeColorLevelEncoder provides colored level encoding for NOTICE level +// This uses cyan color to make NOTICE messages visually distinct in terminal output +func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { + switch l { + case ctrld.NoticeLevel: + enc.AppendString("\x1b[36mNOTICE\x1b[0m") // Cyan color for NOTICE + default: + zapcore.CapitalColorLevelEncoder(l, enc) + } +} + +// logViewResponse represents the response structure for log viewing requests +// This provides a consistent JSON format for log data retrieval type logViewResponse struct { Data string `json:"data"` } +// logSentResponse represents the response structure for log sending operations +// This includes size information and error details for debugging type logSentResponse struct { Size int64 `json:"size"` Error string `json:"error"` } +// logReader provides read access to log data with size information. +// +// This struct encapsulates log reading functionality for external consumers, +// providing both the log content and metadata about the log size. It supports +// reading from both internal log buffers (when no external logging is configured) +// and external log files (when logging to file is enabled). +// +// Fields: +// - r: An io.ReadCloser that provides access to the log content +// - size: The total size of the log data in bytes +// +// The logReader is used by the control server to serve log content to clients +// and by various CLI commands that need to display or process log data. 
type logReader struct { r io.ReadCloser size int64 } // logWriter is an internal buffer to keep track of runtime log when no logging is enabled. +// This provides in-memory log storage for debugging and monitoring purposes type logWriter struct { mu sync.Mutex buf bytes.Buffer @@ -46,30 +112,37 @@ type logWriter struct { } // newLogWriter creates an internal log writer. +// This provides the default log writer with standard buffer size func newLogWriter() *logWriter { return newLogWriterWithSize(logWriterSize) } // newSmallLogWriter creates an internal log writer with small buffer size. +// This is used in memory-constrained environments or for temporary logging func newSmallLogWriter() *logWriter { return newLogWriterWithSize(logWriterSmallSize) } // newLogWriterWithSize creates an internal log writer with a given buffer size. +// This allows customization of log buffer size based on specific requirements func newLogWriterWithSize(size int) *logWriter { lw := &logWriter{size: size} return lw } +// Write implements io.Writer interface for logWriter +// This manages buffer overflow by discarding old data while preserving important markers func (lw *logWriter) Write(p []byte) (int, error) { lw.mu.Lock() defer lw.mu.Unlock() // If writing p causes overflows, discard old data. + // This prevents unbounded memory growth while maintaining recent logs if lw.buf.Len()+len(p) > lw.size { buf := lw.buf.Bytes() haveEndMarker := false // If there's init end marker already, preserve the data til the marker. + // This ensures initialization logs are always available for debugging if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 { buf = buf[:idx+len(logWriterInitEndMarker)] haveEndMarker = true @@ -95,20 +168,20 @@ func (lw *logWriter) Write(p []byte) (int, error) { // initLogging initializes global logging setup. 
func (p *prog) initLogging(backup bool) { - zerolog.TimeFieldFormat = time.RFC3339 + ".000" - logWriters := initLoggingWithBackup(backup) + logCores := initLoggingWithBackup(backup) // Initializing internal logging after global logging. - p.initInternalLogging(logWriters) + p.initInternalLogging(logCores) + p.logger.Store(mainLog.Load()) } // initInternalLogging performs internal logging if there's no log enabled. -func (p *prog) initInternalLogging(writers []io.Writer) { +func (p *prog) initInternalLogging(externalCores []zapcore.Core) { if !p.needInternalLogging() { return } p.initInternalLogWriterOnce.Do(func() { - mainLog.Load().Notice().Msg("internal logging enabled") + p.Notice().Msg("Internal logging enabled") p.internalLogWriter = newLogWriter() p.internalLogSent = time.Now().Add(-logWriterSentInterval) p.internalWarnLogWriter = newSmallLogWriter() @@ -117,28 +190,26 @@ func (p *prog) initInternalLogging(writers []io.Writer) { lw := p.internalLogWriter wlw := p.internalWarnLogWriter p.mu.Unlock() - // If ctrld was run without explicit verbose level, - // run the internal logging at debug level, so we could + + // Create zap cores for different writers + var cores []zapcore.Core + cores = append(cores, externalCores...) + + // Add core for internal log writer. + // Run the internal logging at debug level, so we could // have enough information for troubleshooting. - if verbose == 0 { - for i := range writers { - w := &zerolog.FilteredLevelWriter{ - Writer: zerolog.LevelWriterAdapter{Writer: writers[i]}, - Level: zerolog.NoticeLevel, - } - writers[i] = w - } - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - writers = append(writers, lw) - writers = append(writers, &zerolog.FilteredLevelWriter{ - Writer: zerolog.LevelWriterAdapter{Writer: wlw}, - Level: zerolog.WarnLevel, - }) - multi := zerolog.MultiLevelWriter(writers...) 
- l := mainLog.Load().Output(multi).With().Logger() - mainLog.Store(&l) - ctrld.ProxyLogger.Store(&l) + internalCore := newHumanReadableZapCore(lw, zapcore.DebugLevel) + cores = append(cores, internalCore) + + // Add core for internal warn log writer + warnCore := newHumanReadableZapCore(wlw, zapcore.WarnLevel) + cores = append(cores, warnCore) + + // Create a multi-core logger + multiCore := zapcore.NewTee(cores...) + logger := zap.New(multiCore) + + mainLog.Store(&ctrld.Logger{Logger: logger}) } // needInternalLogging reports whether prog needs to run internal logging. @@ -154,7 +225,69 @@ func (p *prog) needInternalLogging() bool { return true } -func (p *prog) logReader() (*logReader, error) { +// logReaderNoColor returns a logReader with ANSI color codes stripped from the log content. +// +// This method is useful when log content needs to be processed by tools that don't +// handle ANSI escape sequences properly, or when storing logs in plain text format. +// It internally calls logReader(true) to strip color codes. +// +// Returns: +// - *logReader: A logReader instance with color codes removed, or nil if no logs available +// - error: Any error encountered during log reading (e.g., empty logs, file access issues) +// +// Use cases: +// - Log processing pipelines that require plain text +// - Storing logs in databases or text files +// - Displaying logs in environments that don't support color +func (p *prog) logReaderNoColor() (*logReader, error) { + return p.logReader(true) +} + +// logReaderRaw returns a logReader with ANSI color codes preserved in the log content. +// +// This method maintains the original formatting of log entries including color codes, +// which is useful for displaying logs in terminals that support ANSI colors or when +// the original visual formatting needs to be preserved. It internally calls logReader(false). 
+// +// Returns: +// - *logReader: A logReader instance with color codes preserved, or nil if no logs available +// - error: Any error encountered during log reading (e.g., empty logs, file access issues) +// +// Use cases: +// - Terminal-based log viewers that support color +// - Interactive debugging sessions +// - Preserving original log formatting for display +func (p *prog) logReaderRaw() (*logReader, error) { + return p.logReader(false) +} + +// logReader creates a logReader instance for accessing log content with optional color stripping. +// +// This is the core method that handles log reading from different sources based on the +// current logging configuration. It supports both internal logging (when no external +// logging is configured) and external file logging (when logging to file is enabled). +// +// Behavior: +// - Internal logging: Reads from internal log buffers (normal logs + warning logs) +// and combines them with appropriate markers for separation +// - External logging: Reads directly from the configured log file +// - Empty logs: Returns appropriate error messages when no log content is available +// +// Parameters: +// - stripColor: If true, removes ANSI color codes from log content; if false, preserves them +// +// Returns: +// - *logReader: A logReader instance providing access to log content and size metadata +// - error: Any error encountered during log reading, including: +// - "nil internal log writer" - Internal logging not properly initialized +// - "nil internal warn log writer" - Warning log writer not properly initialized +// - "internal log is empty" - No content in internal log buffers +// - "log file is empty" - External log file exists but contains no data +// - File system errors when accessing external log files +// +// The method handles thread-safe access to internal log buffers and provides +// comprehensive error handling for various edge cases. 
+func (p *prog) logReader(stripColor bool) (*logReader, error) { if p.needInternalLogging() { p.mu.Lock() lw := p.internalLogWriter @@ -166,14 +299,15 @@ func (p *prog) logReader() (*logReader, error) { if wlw == nil { return nil, errors.New("nil internal warn log writer") } + // Normal log content. lw.mu.Lock() - lwReader := bytes.NewReader(lw.buf.Bytes()) + lwReader := newLogReader(&lw.buf, stripColor) lwSize := lw.buf.Len() lw.mu.Unlock() // Warn log content. wlw.mu.Lock() - wlwReader := bytes.NewReader(wlw.buf.Bytes()) + wlwReader := newLogReader(&wlw.buf, stripColor) wlwSize := wlw.buf.Len() wlw.mu.Unlock() reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader) @@ -202,3 +336,72 @@ func (p *prog) logReader() (*logReader, error) { } return lr, nil } + +// newHumanReadableZapCore creates a zap core optimized for human-readable log output. +// +// Features: +// - Uses development encoder configuration for enhanced readability +// - Console encoding with colored log levels for easy visual scanning +// - Millisecond precision timestamps in human-friendly format +// - Structured field output with clear key-value pairs +// - Ideal for development, debugging, and interactive terminal sessions +// +// Parameters: +// - w: The output writer (e.g., os.Stdout, file, buffer) +// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error) +// +// Returns a zapcore.Core configured for human consumption. +func newHumanReadableZapCore(w io.Writer, level zapcore.Level) zapcore.Core { + encoderConfig := zap.NewDevelopmentEncoderConfig() + encoderConfig.TimeKey = "time" + encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli) + encoderConfig.EncodeLevel = noticeColorLevelEncoder + encoder := zapcore.NewConsoleEncoder(encoderConfig) + return zapcore.NewCore(encoder, zapcore.AddSync(w), level) +} + +// newMachineFriendlyZapCore creates a zap core optimized for machine processing and log aggregation. 
+// +// Features: +// - Uses production encoder configuration for consistent, parseable output +// - Console encoding with non-colored log levels for log parsing tools +// - Millisecond precision timestamps in ISO-like format +// - Structured field output optimized for log aggregation systems +// - Ideal for production environments, log shipping, and automated analysis +// +// Parameters: +// - w: The output writer (e.g., os.Stdout, file, buffer) +// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error) +// +// Returns a zapcore.Core configured for machine consumption and log aggregation. +func newMachineFriendlyZapCore(w io.Writer, level zapcore.Level) zapcore.Core { + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "time" + encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli) + encoderConfig.EncodeLevel = noticeLevelEncoder + encoder := zapcore.NewConsoleEncoder(encoderConfig) + return zapcore.NewCore(encoder, zapcore.AddSync(w), level) +} + +// ansiRegex is a regular expression to match ANSI color codes. +var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`) + +// newLogReader creates a reader for log buffer content with optional ANSI color stripping. +// +// This function provides flexible log content access by allowing consumers to choose +// between raw log data (with ANSI color codes) or stripped content (without color codes). +// The color stripping is useful when logs need to be processed by tools that don't +// handle ANSI escape sequences properly, or when storing logs in plain text format. +// +// Parameters: +// - buf: The log buffer containing the log data to read +// - stripColor: If true, strips ANSI color codes from the log content; +// if false, returns raw log content with color codes preserved +// +// Returns an io.Reader that provides access to the processed log content. 
+func newLogReader(buf *bytes.Buffer, stripColor bool) io.Reader { + if stripColor { + return strings.NewReader(ansiRegex.ReplaceAllString(buf.String(), "")) + } + return strings.NewReader(buf.String()) +} diff --git a/cmd/cli/log_writer_test.go b/cmd/cli/log_writer_test.go index 5336d4eb..5af5c132 100644 --- a/cmd/cli/log_writer_test.go +++ b/cmd/cli/log_writer_test.go @@ -1,9 +1,16 @@ package cli import ( + "bytes" + "io" "strings" "sync" "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/Control-D-Inc/ctrld" ) func Test_logWriter_Write(t *testing.T) { @@ -83,3 +90,328 @@ func Test_logWriter_MarkerInitEnd(t *testing.T) { t.Fatalf("unexpected log content: %s", lw.buf.String()) } } + +// TestNoticeLevel tests that the custom NOTICE level works correctly +func TestNoticeLevel(t *testing.T) { + // Create a buffer to capture log output + var buf bytes.Buffer + + // Create encoder config with custom NOTICE level support + encoderConfig := zap.NewDevelopmentEncoderConfig() + encoderConfig.TimeKey = "time" + encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout("15:04:05.000") + encoderConfig.EncodeLevel = noticeLevelEncoder + + // Test with NOTICE level + encoder := zapcore.NewConsoleEncoder(encoderConfig) + core := zapcore.NewCore(encoder, zapcore.AddSync(&buf), ctrld.NoticeLevel) + logger := zap.New(core) + ctrldLogger := &ctrld.Logger{Logger: logger} + + // Log messages at different levels + ctrldLogger.Debug().Msg("This is a DEBUG message") + ctrldLogger.Info().Msg("This is an INFO message") + ctrldLogger.Notice().Msg("This is a NOTICE message") + ctrldLogger.Warn().Msg("This is a WARN message") + ctrldLogger.Error().Msg("This is an ERROR message") + + output := buf.String() + + // Verify that DEBUG and INFO messages are NOT logged (filtered out) + if strings.Contains(output, "DEBUG") { + t.Error("DEBUG message should not be logged when level is NOTICE") + } + if strings.Contains(output, "INFO") { + t.Error("INFO message should not be 
logged when level is NOTICE") + } + + // Verify that NOTICE, WARN, and ERROR messages ARE logged + if !strings.Contains(output, "NOTICE") { + t.Error("NOTICE message should be logged when level is NOTICE") + } + if !strings.Contains(output, "WARN") { + t.Error("WARN message should be logged when level is NOTICE") + } + if !strings.Contains(output, "ERROR") { + t.Error("ERROR message should be logged when level is NOTICE") + } + + // Verify the NOTICE message content + if !strings.Contains(output, "This is a NOTICE message") { + t.Error("NOTICE message content should be present") + } + + t.Logf("Log output with NOTICE level:\n%s", output) +} + +func TestNewLogReader(t *testing.T) { + tests := []struct { + name string + bufContent string + stripColor bool + expected string + description string + }{ + { + name: "empty_buffer_no_color_strip", + bufContent: "", + stripColor: false, + expected: "", + description: "Empty buffer should return empty reader", + }, + { + name: "empty_buffer_with_color_strip", + bufContent: "", + stripColor: true, + expected: "", + description: "Empty buffer with color strip should return empty reader", + }, + { + name: "plain_text_no_color_strip", + bufContent: "This is plain text without any color codes", + stripColor: false, + expected: "This is plain text without any color codes", + description: "Plain text should be returned as-is when not stripping colors", + }, + { + name: "plain_text_with_color_strip", + bufContent: "This is plain text without any color codes", + stripColor: true, + expected: "This is plain text without any color codes", + description: "Plain text should be returned as-is when stripping colors", + }, + { + name: "text_with_ansi_codes_no_strip", + bufContent: "Normal text \x1b[31mred text\x1b[0m normal again", + stripColor: false, + expected: "Normal text \x1b[31mred text\x1b[0m normal again", + description: "ANSI color codes should be preserved when not stripping", + }, + { + name: "text_with_ansi_codes_with_strip", + 
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again", + stripColor: true, + expected: "Normal text red text normal again", + description: "ANSI color codes should be removed when stripping colors", + }, + { + name: "multiple_ansi_codes_no_strip", + bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text", + stripColor: false, + expected: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text", + description: "Multiple ANSI codes should be preserved when not stripping", + }, + { + name: "multiple_ansi_codes_with_strip", + bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text", + stripColor: true, + expected: "Bold Green Blue text", + description: "Multiple ANSI codes should be removed when stripping colors", + }, + { + name: "complex_ansi_sequences_no_strip", + bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m", + stripColor: false, + expected: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m", + description: "Complex ANSI sequences should be preserved when not stripping", + }, + { + name: "complex_ansi_sequences_with_strip", + bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m", + stripColor: true, + expected: "Bold red on green Orange", + description: "Complex ANSI sequences should be removed when stripping colors", + }, + { + name: "ansi_codes_with_newlines_no_strip", + bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3", + stripColor: false, + expected: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3", + description: "ANSI codes with newlines should be preserved when not stripping", + }, + { + name: "ansi_codes_with_newlines_with_strip", + bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3", + stripColor: true, + expected: "Line 1\nRed line\nLine 3", + description: "ANSI codes with newlines should be removed when stripping colors", + }, + { + name: "malformed_ansi_codes_no_strip", + bufContent: "Text \x1b[invalidm \x1b[0m 
normal", + stripColor: false, + expected: "Text \x1b[invalidm \x1b[0m normal", + description: "Malformed ANSI codes should be preserved when not stripping", + }, + { + name: "malformed_ansi_codes_with_strip", + bufContent: "Text \x1b[invalidm \x1b[0m normal", + stripColor: true, + expected: "Text \x1b[invalidm normal", + description: "Non-matching ANSI sequences should be preserved when stripping colors", + }, + { + name: "large_buffer_no_strip", + bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m", + stripColor: false, + expected: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m", + description: "Large buffer should handle ANSI codes correctly when not stripping", + }, + { + name: "large_buffer_with_strip", + bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m", + stripColor: true, + expected: strings.Repeat("A", 10000) + strings.Repeat("B", 1000), + description: "Large buffer should remove ANSI codes correctly when stripping", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a buffer with the test content + buf := &bytes.Buffer{} + buf.WriteString(tt.bufContent) + + // Create the log reader + reader := newLogReader(buf, tt.stripColor) + + // Read all content from the reader + content, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("Failed to read from log reader: %v", err) + } + + // Verify the content matches expected + actual := string(content) + if actual != tt.expected { + t.Errorf("Expected content: %q, got: %q", tt.expected, actual) + t.Logf("Description: %s", tt.description) + } + }) + } +} + +func TestNewLogReader_ReaderBehavior(t *testing.T) { + // Test that the returned reader behaves correctly + buf := &bytes.Buffer{} + buf.WriteString("Test content with \x1b[31mred\x1b[0m text") + + // Test with color stripping + reader := newLogReader(buf, true) + + // Test reading in chunks + chunk1 := 
make([]byte, 10) + n1, err := reader.Read(chunk1) + if err != nil && err != io.EOF { + t.Fatalf("Unexpected error reading first chunk: %v", err) + } + if n1 != 10 { + t.Errorf("Expected to read 10 bytes, got %d", n1) + } + + // Test reading remaining content + remaining, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("Failed to read remaining content: %v", err) + } + + // Verify total content + totalContent := string(chunk1[:n1]) + string(remaining) + expected := "Test content with red text" + if totalContent != expected { + t.Errorf("Expected total content: %q, got: %q", expected, totalContent) + } +} + +func TestNewLogReader_ConcurrentAccess(t *testing.T) { + // Test concurrent access to the same buffer + buf := &bytes.Buffer{} + buf.WriteString("Concurrent test with \x1b[32mgreen\x1b[0m text") + + var wg sync.WaitGroup + numGoroutines := 10 + results := make(chan string, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + reader := newLogReader(buf, true) + content, err := io.ReadAll(reader) + if err != nil { + t.Errorf("Failed to read content: %v", err) + return + } + results <- string(content) + }() + } + + wg.Wait() + close(results) + + // Verify all goroutines got the same result + expected := "Concurrent test with green text" + for result := range results { + if result != expected { + t.Errorf("Expected: %q, got: %q", expected, result) + } + } +} + +func TestNewLogReader_ANSIRegexEdgeCases(t *testing.T) { + // Test edge cases for ANSI regex matching + tests := []struct { + name string + input string + expected string + }{ + { + name: "empty_escape_sequence", + input: "Text \x1b[m normal", + expected: "Text normal", + }, + { + name: "multiple_semicolons", + input: "Text \x1b[1;2;3;4m normal", + expected: "Text normal", + }, + { + name: "numeric_only", + input: "Text \x1b[123m normal", + expected: "Text normal", + }, + { + name: "mixed_numeric_semicolon", + input: "Text \x1b[1;23;456m normal", + 
expected: "Text normal", + }, + { + name: "no_closing_bracket", + input: "Text \x1b[31 normal", + expected: "Text \x1b[31 normal", + }, + { + name: "no_opening_bracket", + input: "Text 31m normal", + expected: "Text 31m normal", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + buf.WriteString(tt.input) + + reader := newLogReader(buf, true) + content, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("Failed to read content: %v", err) + } + + actual := string(content) + if actual != tt.expected { + t.Errorf("Expected: %q, got: %q", tt.expected, actual) + } + }) + } +} diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 3504bc34..a3c00eda 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -84,7 +84,7 @@ func (p *prog) detectLoop(msg *dns.Msg) { // // See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html func (p *prog) checkDnsLoop() { - mainLog.Load().Debug().Msg("start checking DNS loop") + p.Debug().Msg("Start checking DNS loop") upstream := make(map[string]*ctrld.UpstreamConfig) p.loopMu.Lock() for n, uc := range p.cfg.Upstream { @@ -93,7 +93,7 @@ func (p *prog) checkDnsLoop() { } // Do not send test query to external upstream. 
if !canBeLocalUpstream(uc.Domain) { - mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n) + p.Debug().Msgf("Skipping external: upstream.%s", n) continue } uid := uc.UID() @@ -102,6 +102,7 @@ func (p *prog) checkDnsLoop() { } p.loopMu.Unlock() + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) for uid := range p.loop { msg := loopTestMsg(uid) uc := upstream[uid] @@ -109,16 +110,16 @@ func (p *prog) checkDnsLoop() { if uc == nil { continue } - resolver, err := ctrld.NewResolver(uc) + resolver, err := ctrld.NewResolver(loggerCtx, uc) if err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + p.Warn().Err(err).Msgf("Could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) continue } if _, err := resolver.Resolve(context.Background(), msg); err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + p.Warn().Err(err).Msgf("Could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) } } - mainLog.Load().Debug().Msg("end checking DNS loop") + p.Debug().Msg("End checking DNS loop") } // checkDnsLoopTicker performs p.checkDnsLoop every minute. @@ -137,7 +138,7 @@ func (p *prog) checkDnsLoopTicker(ctx context.Context) { } } -// loopTestMsg generates DNS message for checking loop. 
+// loopTestMsg creates a DNS test message for loop detection func loopTestMsg(uid string) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 07839756..7581a16f 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -5,14 +5,16 @@ import ( "os" "path/filepath" "sync/atomic" - "time" "github.com/kardianos/service" - "github.com/rs/zerolog" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" "github.com/Control-D-Inc/ctrld" ) +// Global variables for CLI configuration and state management +// These are used across multiple commands and need to persist throughout the application lifecycle var ( configPath string configBase64 string @@ -41,11 +43,14 @@ var ( startOnly bool rfc1918 bool - mainLog atomic.Pointer[zerolog.Logger] - consoleWriter zerolog.ConsoleWriter - noConfigStart bool + mainLog atomic.Pointer[ctrld.Logger] + consoleWriter zapcore.Core + consoleWriterLevel zapcore.Level + noConfigStart bool ) +// Flag name constants for consistent reference across the codebase +// Using constants prevents typos and makes refactoring easier const ( cdUidFlagName = "cd" cdOrgFlagName = "cd-org" @@ -53,20 +58,26 @@ const ( nextdnsFlagName = "nextdns" ) +// init initializes the default logger before any CLI commands are executed +// This ensures logging is available even during early initialization phases func init() { - l := zerolog.New(io.Discard) - mainLog.Store(&l) + l := zap.NewNop() + mainLog.Store(&ctrld.Logger{Logger: l}) } +// Main is the entry point for the CLI application +// It initializes configuration, sets up the CLI structure, and executes the root command func Main() { ctrld.InitConfig(v, "ctrld") - initCLI() + rootCmd := initCLI() if err := rootCmd.Execute(); err != nil { mainLog.Load().Error().Msg(err.Error()) os.Exit(1) } } +// normalizeLogFilePath converts relative log file paths to absolute paths +// This ensures log files are created in predictable locations 
regardless of working directory func normalizeLogFilePath(logFilePath string) string { if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() { return logFilePath @@ -82,40 +93,36 @@ func normalizeLogFilePath(logFilePath string) string { } // initConsoleLogging initializes console logging, then storing to mainLog. +// This sets up human-readable logging output for interactive use func initConsoleLogging() { - consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) { - w.TimeFormat = time.StampMilli - }) - multi := zerolog.MultiLevelWriter(consoleWriter) - l := mainLog.Load().Output(multi).With().Timestamp().Logger() - mainLog.Store(&l) - + consoleWriterLevel = ctrld.NoticeLevel switch { case silent: - zerolog.SetGlobalLevel(zerolog.NoLevel) + // For silent mode, use a no-op logger to suppress all output + l := zap.NewNop() + mainLog.Store(&ctrld.Logger{Logger: l}) case verbose == 1: - ctrld.ProxyLogger.Store(&l) - zerolog.SetGlobalLevel(zerolog.InfoLevel) + // Info level provides basic operational information + consoleWriterLevel = zapcore.InfoLevel case verbose > 1: - ctrld.ProxyLogger.Store(&l) - zerolog.SetGlobalLevel(zerolog.DebugLevel) - default: - zerolog.SetGlobalLevel(zerolog.NoticeLevel) + // Debug level provides detailed diagnostic information + consoleWriterLevel = zapcore.DebugLevel } + consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel) + l := zap.New(consoleWriter) + mainLog.Store(&ctrld.Logger{Logger: l}) } // initInteractiveLogging is like initLogging, but the ProxyLogger is discarded // to be used for all interactive commands. // // Current log file config will also be ignored. 
+// This prevents log file conflicts during interactive command execution func initInteractiveLogging() { old := cfg.Service.LogPath cfg.Service.LogPath = "" - zerolog.TimeFieldFormat = time.RFC3339 + ".000" initLoggingWithBackup(false) cfg.Service.LogPath = old - l := zerolog.New(io.Discard) - ctrld.ProxyLogger.Store(&l) } // initLoggingWithBackup initializes log setup base on current config. @@ -124,68 +131,101 @@ func initInteractiveLogging() { // This is only used in runCmd for special handling in case of logging config // change in cd mode. Without special reason, the caller should use initLogging // wrapper instead of calling this function directly. -func initLoggingWithBackup(doBackup bool) []io.Writer { +func initLoggingWithBackup(doBackup bool) []zapcore.Core { var writers []io.Writer if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" { // Create parent directory if necessary. + // This ensures log files can be created even if the directory doesn't exist if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil { - mainLog.Load().Error().Msgf("failed to create log path: %v", err) + mainLog.Load().Error().Msgf("Failed to create log path: %v", err) os.Exit(1) } // Default open log file in append mode. + // This preserves existing log entries across restarts flags := os.O_CREATE | os.O_RDWR | os.O_APPEND if doBackup { // Backup old log file with .1 suffix. + // This prevents log file corruption during rotation if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) { - mainLog.Load().Error().Msgf("could not backup old log file: %v", err) + mainLog.Load().Error().Msgf("Could not backup old log file: %v", err) } else { // Backup was created, set flags for truncating old log file. 
+ // This ensures a clean start for the new log file flags = os.O_CREATE | os.O_RDWR } } logFile, err := openLogFile(logFilePath, flags) if err != nil { - mainLog.Load().Error().Msgf("failed to create log file: %v", err) + mainLog.Load().Error().Msgf("Failed to create log file: %v", err) os.Exit(1) } writers = append(writers, logFile) } - writers = append(writers, consoleWriter) - multi := zerolog.MultiLevelWriter(writers...) - l := mainLog.Load().Output(multi).With().Logger() - mainLog.Store(&l) - // TODO: find a better way. - ctrld.ProxyLogger.Store(&l) - - zerolog.SetGlobalLevel(zerolog.NoticeLevel) + + // Create zap cores for different writers + // Multiple cores allow logging to both console and file simultaneously + var cores []zapcore.Core + cores = append(cores, consoleWriter) + + // Determine log level based on verbosity and configuration + // This provides flexible logging control for different use cases logLevel := cfg.Service.LogLevel switch { case silent: - zerolog.SetGlobalLevel(zerolog.NoLevel) - return writers + // For silent mode, use a no-op logger to suppress all output + l := zap.NewNop() + mainLog.Store(&ctrld.Logger{Logger: l}) + return cores case verbose == 1: logLevel = "info" case verbose > 1: logLevel = "debug" } - if logLevel == "" { - return writers + + // Parse log level string to zapcore.Level + // This provides human-readable log level configuration + var level zapcore.Level + switch logLevel { + case "debug": + level = zapcore.DebugLevel + case "info": + level = zapcore.InfoLevel + case "notice": + level = ctrld.NoticeLevel + case "warn": + level = zapcore.WarnLevel + case "error": + level = zapcore.ErrorLevel + default: + level = zapcore.InfoLevel // default level } - level, err := zerolog.ParseLevel(logLevel) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not set log level") - return writers + + consoleWriter.Enabled(level) + // Add cores for all writers + // This enables multi-destination logging (console + file) + 
for _, writer := range writers { + core := newMachineFriendlyZapCore(writer, level) + cores = append(cores, core) } - zerolog.SetGlobalLevel(level) - return writers + + // Create a multi-core logger + // This allows simultaneous logging to multiple destinations + multiCore := zapcore.NewTee(cores...) + logger := zap.New(multiCore) + mainLog.Store(&ctrld.Logger{Logger: logger}) + + return cores } +// initCache initializes DNS cache configuration +// This improves performance by caching frequently requested DNS responses func initCache() { if !cfg.Service.CacheEnable { return } if cfg.Service.CacheSize == 0 { + // Default cache size provides good balance between memory usage and performance cfg.Service.CacheSize = 4096 } } diff --git a/cmd/cli/main_test.go b/cmd/cli/main_test.go index 6ed26c73..d0a11492 100644 --- a/cmd/cli/main_test.go +++ b/cmd/cli/main_test.go @@ -5,13 +5,28 @@ import ( "strings" "testing" - "github.com/rs/zerolog" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/Control-D-Inc/ctrld" ) var logOutput strings.Builder func TestMain(m *testing.M) { - l := zerolog.New(&logOutput) - mainLog.Store(&l) + // Create a custom writer that writes to logOutput + writer := zapcore.AddSync(&logOutput) + + // Create zap encoder + encoderConfig := zap.NewDevelopmentEncoderConfig() + encoder := zapcore.NewConsoleEncoder(encoderConfig) + + // Create core that writes to our string builder + core := zapcore.NewCore(encoder, writer, zap.DebugLevel) + + // Create logger + l := zap.New(core) + + mainLog.Store(&ctrld.Logger{Logger: l}) os.Exit(m.Run()) } diff --git a/cmd/cli/metrics.go b/cmd/cli/metrics.go index 565cdcc5..330918c9 100644 --- a/cmd/cli/metrics.go +++ b/cmd/cli/metrics.go @@ -15,6 +15,7 @@ import ( ) // metricsServer represents a server to expose Prometheus metrics via HTTP. 
+// This provides monitoring and observability for the DNS proxy service type metricsServer struct { server *http.Server mux *http.ServeMux @@ -24,6 +25,7 @@ type metricsServer struct { } // newMetricsServer returns new metrics server. +// This initializes the HTTP server for exposing Prometheus metrics func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) { mux := http.NewServeMux() ms := &metricsServer{ @@ -37,11 +39,13 @@ func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, er } // register adds handlers for given pattern. +// This provides a clean interface for adding HTTP endpoints to the metrics server func (ms *metricsServer) register(pattern string, handler http.Handler) { ms.mux.Handle(pattern, handler) } // registerMetricsServerHandler adds handlers for metrics server. +// This sets up both Prometheus format and JSON format endpoints for metrics func (ms *metricsServer) registerMetricsServerHandler() { ms.register("/metrics", promhttp.HandlerFor( ms.reg, @@ -74,6 +78,7 @@ func (ms *metricsServer) registerMetricsServerHandler() { } // start runs the metricsServer. +// This starts the HTTP server for metrics exposure func (ms *metricsServer) start() error { listener, err := net.Listen("tcp", ms.addr) if err != nil { @@ -85,6 +90,7 @@ func (ms *metricsServer) start() error { } // stop shutdowns the metricsServer within 2 seconds timeout. +// This ensures graceful shutdown of the metrics server func (ms *metricsServer) stop() error { if !ms.started { return nil @@ -95,6 +101,7 @@ func (ms *metricsServer) stop() error { } // runMetricsServer initializes metrics stats and runs the metrics server if enabled. 
+// This sets up the complete metrics infrastructure including Prometheus collectors func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) { if !p.metricsEnabled() { return @@ -115,7 +122,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) { addr := p.cfg.Service.MetricsListener ms, err := newMetricsServer(addr, reg) if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create new metrics server") + mainLog.Load().Warn().Err(err).Msg("Could not create new metrics server") return } // Only start listener address if defined. @@ -130,9 +137,9 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) { statsVersion.WithLabelValues(commit, runtime.Version(), curVersion()).Inc() reg.MustRegister(statsTimeStart) statsTimeStart.Set(float64(time.Now().Unix())) - mainLog.Load().Debug().Msgf("starting metrics server on: %s", addr) + mainLog.Load().Debug().Msgf("Starting metrics server on: %s", addr) if err := ms.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start metrics server") + mainLog.Load().Warn().Err(err).Msg("Could not start metrics server") return } } @@ -144,7 +151,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) { } if err := ms.stop(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not stop metrics server") + mainLog.Load().Warn().Err(err).Msg("Could not stop metrics server") return } } diff --git a/cmd/cli/net_darwin.go b/cmd/cli/net_darwin.go index 62331610..7f756c4f 100644 --- a/cmd/cli/net_darwin.go +++ b/cmd/cli/net_darwin.go @@ -49,28 +49,3 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo _, ok := validIfacesMap[iface.Name] return ok } - -// validInterfacesMap returns a set of all valid hardware ports. 
-func validInterfacesMap() map[string]struct{} { - b, err := exec.Command("networksetup", "-listallhardwareports").Output() - if err != nil { - return nil - } - return parseListAllHardwarePorts(bytes.NewReader(b)) -} - -// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports" -// and returns map presents all hardware ports. -func parseListAllHardwarePorts(r io.Reader) map[string]struct{} { - m := make(map[string]struct{}) - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - after, ok := strings.CutPrefix(line, "Device: ") - if !ok { - continue - } - m[after] = struct{}{} - } - return m -} diff --git a/cmd/cli/net_linux.go b/cmd/cli/net_linux.go index ea17d3d8..f5a07de4 100644 --- a/cmd/cli/net_linux.go +++ b/cmd/cli/net_linux.go @@ -2,51 +2,16 @@ package cli import ( "net" - "net/netip" - "os" - "strings" - - "tailscale.com/net/netmon" ) +// patchNetIfaceName patches network interface names on Linux +// This is a no-op on Linux as interface names don't need special handling func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } // validInterface reports whether the *net.Interface is a valid one. // Only non-virtual interfaces are considered valid. +// This prevents DNS configuration on virtual interfaces like docker, veth, etc. func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { _, ok := validIfacesMap[iface.Name] return ok } - -// validInterfacesMap returns a set containing non virtual interfaces. -func validInterfacesMap() map[string]struct{} { - m := make(map[string]struct{}) - vis := virtualInterfaces() - netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { - if _, existed := vis[i.Name]; existed { - return - } - m[i.Name] = struct{}{} - }) - // Fallback to default route interface if found nothing. 
- if len(m) == 0 { - defaultRoute, err := netmon.DefaultRoute() - if err != nil { - return m - } - m[defaultRoute.InterfaceName] = struct{}{} - } - return m -} - -// virtualInterfaces returns a map of virtual interfaces on current machine. -func virtualInterfaces() map[string]struct{} { - s := make(map[string]struct{}) - entries, _ := os.ReadDir("/sys/devices/virtual/net") - for _, entry := range entries { - if entry.IsDir() { - s[strings.TrimSpace(entry.Name())] = struct{}{} - } - } - return s -} diff --git a/cmd/cli/net_others.go b/cmd/cli/net_others.go index f3472781..4ab96dea 100644 --- a/cmd/cli/net_others.go +++ b/cmd/cli/net_others.go @@ -4,19 +4,10 @@ package cli import ( "net" - - "tailscale.com/net/netmon" ) +// patchNetIfaceName patches network interface names on non-Linux/Darwin platforms func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } +// validInterface checks if an interface is valid on non-Linux/Darwin platforms func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true } - -// validInterfacesMap returns a set containing only default route interfaces. 
-func validInterfacesMap() map[string]struct{} { - defaultRoute, err := netmon.DefaultRoute() - if err != nil { - return nil - } - return map[string]struct{}{defaultRoute.InterfaceName: {}} -} diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index bed06b57..bdd6dcf5 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -1,16 +1,7 @@ package cli import ( - "io" - "log" "net" - "os" - - "github.com/microsoft/wmi/pkg/base/host" - "github.com/microsoft/wmi/pkg/base/instance" - "github.com/microsoft/wmi/pkg/base/query" - "github.com/microsoft/wmi/pkg/constant" - "github.com/microsoft/wmi/pkg/hardware/network/netadapter" ) func patchNetIfaceName(iface *net.Interface) (bool, error) { @@ -23,71 +14,3 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo _, ok := validIfacesMap[iface.Name] return ok } - -// validInterfacesMap returns a set of all physical interfaces. -func validInterfacesMap() map[string]struct{} { - m := make(map[string]struct{}) - for _, ifaceName := range validInterfaces() { - m[ifaceName] = struct{}{} - } - return m -} - -// validInterfaces returns a list of all physical interfaces. 
-func validInterfaces() []string { - log.SetOutput(io.Discard) - defer log.SetOutput(os.Stderr) - whost := host.NewWmiLocalHost() - q := query.NewWmiQuery("MSFT_NetAdapter") - instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) - if instances != nil { - defer instances.Close() - } - if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter") - return nil - } - var adapters []string - for _, i := range instances { - adapter, err := netadapter.NewNetworkAdapter(i) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get network adapter") - continue - } - - name, err := adapter.GetPropertyName() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get interface name") - continue - } - - // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) - // - // "Indicates if a connector is present on the network adapter. This value is set to TRUE - // if this is a physical adapter or FALSE if this is not a physical adapter." - physical, err := adapter.GetPropertyConnectorPresent() - if err != nil { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property") - continue - } - if !physical { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter") - continue - } - - // Check if it's a hardware interface. Checking only for connector present is not enough - // because some interfaces are not physical but have a connector. 
- hardware, err := adapter.GetPropertyHardwareInterface() - if err != nil { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property") - continue - } - if !hardware { - mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface") - continue - } - - adapters = append(adapters, name) - } - return adapters -} diff --git a/cmd/cli/net_windows_test.go b/cmd/cli/net_windows_test.go index a15f119b..551fe784 100644 --- a/cmd/cli/net_windows_test.go +++ b/cmd/cli/net_windows_test.go @@ -3,18 +3,23 @@ package cli import ( "bufio" "bytes" + "context" + "maps" "slices" "strings" "testing" "time" + + "github.com/Control-D-Inc/ctrld" ) func Test_validInterfaces(t *testing.T) { verbose = 3 initConsoleLogging() start := time.Now() - ifaces := validInterfaces() + im := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load())) t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + ifaces := slices.Collect(maps.Keys(im)) start = time.Now() ifacesPowershell := validInterfacesPowershell() diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index d757f8b7..1c6aab6e 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -5,6 +5,8 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" + + "github.com/Control-D-Inc/ctrld" ) func (p *prog) watchLinkState(ctx context.Context) { @@ -12,7 +14,7 @@ func (p *prog) watchLinkState(ctx context.Context) { done := make(chan struct{}) defer close(done) if err := netlink.LinkSubscribe(ch, done); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not subscribe link") + p.Warn().Err(err).Msg("Could not subscribe link") return } for { @@ -24,9 +26,9 @@ func (p *prog) watchLinkState(ctx context.Context) { continue } if lu.Change&unix.IFF_UP != 0 { - mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") + 
p.Debug().Msgf("Link state changed, re-bootstrapping") for _, uc := range p.cfg.Upstream { - uc.ReBootstrap() + uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load())) } } } diff --git a/cmd/cli/network_manager_linux.go b/cmd/cli/network_manager_linux.go index 1a8c22b9..dc847e3a 100644 --- a/cmd/cli/network_manager_linux.go +++ b/cmd/cli/network_manager_linux.go @@ -23,66 +23,67 @@ systemd-resolved=false var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename) // hasNetworkManager reports whether NetworkManager executable found. +// hasNetworkManager checks if NetworkManager is available on the system func hasNetworkManager() bool { exe, _ := exec.LookPath("NetworkManager") return exe != "" } -func setupNetworkManager() error { +func (p *prog) setupNetworkManager() error { if !hasNetworkManager() { return nil } if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent { - mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do") + p.Debug().Msg("NetworkManager already setup, nothing to do") return nil } err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644)) if os.IsNotExist(err) { - mainLog.Load().Debug().Msg("NetworkManager is not available") + p.Debug().Msg("NetworkManager is not available") return nil } if err != nil { - mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file") + p.Debug().Err(err).Msg("Could not write NetworkManager ctrld config file") return err } - reloadNetworkManager() - mainLog.Load().Debug().Msg("setup NetworkManager done") + p.reloadNetworkManager() + p.Debug().Msg("Setup NetworkManager done") return nil } -func restoreNetworkManager() error { +func (p *prog) restoreNetworkManager() error { if !hasNetworkManager() { return nil } err := os.Remove(networkManagerCtrldConfFile) if os.IsNotExist(err) { - mainLog.Load().Debug().Msg("NetworkManager is not available") + p.Debug().Msg("NetworkManager is not 
available") return nil } if err != nil { - mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file") + p.Debug().Err(err).Msg("Could not remove NetworkManager ctrld config file") return err } - reloadNetworkManager() - mainLog.Load().Debug().Msg("restore NetworkManager done") + p.reloadNetworkManager() + p.Debug().Msg("Restore NetworkManager done") return nil } -func reloadNetworkManager() { +func (p *prog) reloadNetworkManager() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() conn, err := dbus.NewSystemConnectionContext(ctx) if err != nil { - mainLog.Load().Error().Err(err).Msg("could not create new system connection") + p.Error().Err(err).Msg("Could not create new system connection") return } defer conn.Close() waitCh := make(chan string) if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil { - mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager") + p.Debug().Err(err).Msg("Could not reload NetworkManager") return } <-waitCh diff --git a/cmd/cli/network_manager_others.go b/cmd/cli/network_manager_others.go index 323d2f2e..e6e5f687 100644 --- a/cmd/cli/network_manager_others.go +++ b/cmd/cli/network_manager_others.go @@ -2,14 +2,14 @@ package cli -func setupNetworkManager() error { - reloadNetworkManager() +func (p *prog) setupNetworkManager() error { + p.reloadNetworkManager() return nil } -func restoreNetworkManager() error { - reloadNetworkManager() +func (p *prog) restoreNetworkManager() error { + p.reloadNetworkManager() return nil } -func reloadNetworkManager() {} +func (p *prog) reloadNetworkManager() {} diff --git a/cmd/cli/nextdns.go b/cmd/cli/nextdns.go index f4fed479..53e0492a 100644 --- a/cmd/cli/nextdns.go +++ b/cmd/cli/nextdns.go @@ -8,11 +8,12 @@ import ( const nextdnsURL = "https://dns.nextdns.io" +// generateNextDNSConfig generates NextDNS configuration for the given UID func generateNextDNSConfig(uid string) { 
if uid == "" { return } - mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver") + mainLog.Load().Info().Msg("Generating ctrld config for NextDNS resolver") cfg = ctrld.Config{ Listener: map[string]*ctrld.ListenerConfig{ "0": { diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index 4c358b0e..7421aee9 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -8,26 +8,31 @@ import ( "os/exec" "strings" - "github.com/Control-D-Inc/ctrld/internal/resolvconffile" + "github.com/Control-D-Inc/ctrld" ) -// allocate loopback ip +// allocateIP allocates an IP address on the specified interface // sudo ifconfig lo0 alias 127.0.0.2 up func allocateIP(ip string) error { + mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address") cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up") if err := cmd.Run(); err != nil { - mainLog.Load().Error().Err(err).Msg("allocateIP failed") + mainLog.Load().Error().Err(err).Msg("AllocateIP failed") return err } + mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully") return nil } +// deAllocateIP deallocates an IP address from the specified interface func deAllocateIP(ip string) error { + mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address") cmd := exec.Command("ifconfig", "lo0", "-alias", ip) if err := cmd.Run(); err != nil { - mainLog.Load().Error().Err(err).Msg("deAllocateIP failed") + mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed") return err } + mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully") return nil } @@ -47,6 +52,8 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e // networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1 // TODO(cuonglm): use system API func setDNS(iface *net.Interface, nameservers []string) error { + mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration") + // Note that networksetup won't modify search 
domains settings, // This assignment is just a placeholder to silent linter. _ = searchDomains @@ -56,6 +63,8 @@ func setDNS(iface *net.Interface, nameservers []string) error { if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil { return fmt.Errorf("%v: %w", string(out), err) } + + mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully") return nil } @@ -73,25 +82,30 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { // TODO(cuonglm): use system API func resetDNS(iface *net.Interface) error { + mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration") + cmd := "networksetup" args := []string{"-setdnsservers", iface.Name, "empty"} if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil { return fmt.Errorf("%v: %w", string(out), err) } + + mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully") return nil } // restoreDNS restores the DNS settings of the given interface. // this should only be executed upon turning off the ctrld service. func restoreDNS(iface *net.Interface) (err error) { - if ns := savedStaticNameservers(iface); len(ns) > 0 { + if ns := ctrld.SavedStaticNameservers(iface); len(ns) > 0 { err = setDNS(iface, ns) } return err } +// currentDNS returns the current DNS servers for the specified interface func currentDNS(_ *net.Interface) []string { - return resolvconffile.NameServers() + return ctrld.CurrentNameserversFromResolvconf() } // currentStaticDNS returns the current static DNS settings of given interface. 
diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index d66e4bff..76ac998e 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -9,27 +9,32 @@ import ( "tailscale.com/health" "tailscale.com/util/dnsname" + "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dns" - "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) -// allocate loopback ip +// allocateIP allocates an IP address on the specified interface // sudo ifconfig lo0 127.0.0.53 alias func allocateIP(ip string) error { + mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address") cmd := exec.Command("ifconfig", "lo0", ip, "alias") if err := cmd.Run(); err != nil { mainLog.Load().Error().Err(err).Msg("allocateIP failed") return err } + mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully") return nil } +// deAllocateIP deallocates an IP address from the specified interface func deAllocateIP(ip string) error { + mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address") cmd := exec.Command("ifconfig", "lo0", ip, "-alias") if err := cmd.Run(); err != nil { mainLog.Load().Error().Err(err).Msg("deAllocateIP failed") return err } + mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully") return nil } @@ -40,9 +45,11 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e // set the dns server for the provided network interface func setDNS(iface *net.Interface, nameservers []string) error { + mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration") + r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator") + mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator") return err } @@ -58,13 +65,15 @@ func setDNS(iface *net.Interface, nameservers []string) error { if sds, err 
:= searchDomains(); err == nil { osConfig.SearchDomains = sds } else { - mainLog.Load().Debug().Err(err).Msg("failed to get search domains list") + mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list") } if err := r.SetDNS(osConfig); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to set DNS") + mainLog.Load().Error().Err(err).Msg("Failed to set DNS") return err } + + mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully") return nil } @@ -73,17 +82,22 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { return resetDNS(iface) } +// resetDNS resets DNS servers for the specified interface func resetDNS(iface *net.Interface) error { + mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration") + r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator") + mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator") return err } if err := r.Close(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting") + mainLog.Load().Error().Err(err).Msg("Failed to rollback DNS setting") return err } + + mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully") return nil } @@ -93,8 +107,9 @@ func restoreDNS(iface *net.Interface) (err error) { return err } +// currentDNS returns the current DNS servers for the specified interface func currentDNS(_ *net.Interface) []string { - return resolvconffile.NameServers() + return ctrld.CurrentNameserversFromResolvconf() } // currentStaticDNS returns the current static DNS settings of given interface. 
diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 8caad63c..013132b6 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -21,30 +21,36 @@ import ( "tailscale.com/health" "tailscale.com/util/dnsname" + "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dns" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" - "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system" +type getDNS func(iface string) []string + // allocate loopback ip // sudo ip a add 127.0.0.2/24 dev lo func allocateIP(ip string) error { + mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address") cmd := exec.Command("ip", "a", "add", ip+"/24", "dev", "lo") if out, err := cmd.CombinedOutput(); err != nil { - mainLog.Load().Error().Err(err).Msgf("allocateIP failed: %s", string(out)) + mainLog.Load().Error().Err(err).Msgf("AllocateIP failed: %s", string(out)) return err } + mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully") return nil } func deAllocateIP(ip string) error { + mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address") cmd := exec.Command("ip", "a", "del", ip+"/24", "dev", "lo") if err := cmd.Run(); err != nil { - mainLog.Load().Error().Err(err).Msg("deAllocateIP failed") + mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed") return err } + mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully") return nil } @@ -56,9 +62,11 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e } func setDNS(iface *net.Interface, nameservers []string) error { + mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration") + r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator") 
+ mainLog.Load().Error().Err(err).Msg("Failed to create dns os configurator") return err } @@ -74,7 +82,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { if sds, err := searchDomains(); err == nil { osConfig.SearchDomains = sds } else { - mainLog.Load().Debug().Err(err).Msg("failed to get search domains list") + mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list") } trySystemdResolve := false if err := r.SetDNS(osConfig); err != nil { @@ -117,6 +125,8 @@ systemdResolve: } mainLog.Load().Debug().Msg("DNS was not set for some reason") } + + mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully") return nil } @@ -126,6 +136,8 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { } func resetDNS(iface *net.Interface) (err error) { + mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration") + defer func() { if err == nil { return @@ -137,7 +149,7 @@ func resetDNS(iface *net.Interface) (err error) { if r, oerr := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name); oerr == nil { _ = r.SetDNS(dns.OSConfig{}) if err := r.Close(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting") + mainLog.Load().Error().Err(err).Msg("Failed to rollback dns setting") return } err = nil @@ -165,18 +177,18 @@ func resetDNS(iface *net.Interface) (err error) { } // TODO(cuonglm): handle DHCPv6 properly. 
- mainLog.Load().Debug().Msg("checking for IPv6 availability") + mainLog.Load().Debug().Msg("Checking for ipv6 availability") if ctrldnet.IPv6Available(ctx) { c := client6.NewClient() conversation, err := c.Exchange(iface.Name) if err != nil && !errAddrInUse(err) { - mainLog.Load().Debug().Err(err).Msg("could not exchange DHCPv6") + mainLog.Load().Debug().Err(err).Msg("Could not exchange dhcpv6") } for _, packet := range conversation { if packet.Type() == dhcpv6.MessageTypeReply { msg, err := packet.GetInnerMessage() if err != nil { - mainLog.Load().Debug().Err(err).Msg("could not get inner DHCPv6 message") + mainLog.Load().Debug().Err(err).Msg("Could not get inner dhcpv6 message") return nil } nameservers := msg.Options.DNS() @@ -201,7 +213,7 @@ func restoreDNS(iface *net.Interface) (err error) { } func currentDNS(iface *net.Interface) []string { - resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() } + resolvconfFunc := func(_ string) []string { return ctrld.CurrentNameserversFromResolvconf() } for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} { if ns := fn(iface.Name); len(ns) > 0 { return ns diff --git a/cmd/cli/os_others.go b/cmd/cli/os_others.go index 45edf0a9..64b9709c 100644 --- a/cmd/cli/os_others.go +++ b/cmd/cli/os_others.go @@ -2,12 +2,12 @@ package cli -// TODO(cuonglm): implement. +// allocateIP allocates an IP address on the specified interface func allocateIP(ip string) error { return nil } -// TODO(cuonglm): implement. 
+// deAllocateIP deallocates an IP address from the specified interface func deAllocateIP(ip string) error { return nil } diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 7ebc54a8..d67ca06c 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -1,21 +1,18 @@ package cli import ( - "bytes" "errors" "fmt" "net" "net/netip" - "os" - "os/exec" "slices" "strings" - "sync" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "github.com/Control-D-Inc/ctrld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) @@ -24,11 +21,6 @@ const ( v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` ) -var ( - setDNSOnce sync.Once - resetDNSOnce sync.Once -) - // setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable. func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error { return setDNS(iface, nameservers) @@ -39,49 +31,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { if len(nameservers) == 0 { return errors.New("empty DNS nameservers") } - setDNSOnce.Do(func() { - // If there's a Dns server running, that means we are on AD with Dns feature enabled. - // Configuring the Dns server to forward queries to ctrld instead. 
- if hasLocalDnsServerRunning() { - mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders") - - file := absHomeDir(windowsForwardersFilename) - mainLog.Load().Debug().Msgf("Using forwarders file: %s", file) - - oldForwardersContent, err := os.ReadFile(file) - if err != nil { - mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file") - } else { - mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent)) - } - - hasLocalIPv6Listener := needLocalIPv6Listener() - mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status") - forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool { - if !hasLocalIPv6Listener { - return false - } - return s == "::1" - }) - mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list") - - if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings") - } else { - mainLog.Load().Debug().Msg("Successfully wrote new forwarders file") - } - - oldForwarders := strings.Split(string(oldForwardersContent), ",") - mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders") - - if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings") - } else { - mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders") - } - } - }) luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { return fmt.Errorf("setDNS: %w", err) @@ -125,25 +75,8 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { return resetDNS(iface) } -// TODO(cuonglm): should we use system API? +// resetDNS resets DNS servers for the specified interface func resetDNS(iface *net.Interface) error { - resetDNSOnce.Do(func() { - // See corresponding comment in setDNS. 
- if hasLocalDnsServerRunning() { - file := absHomeDir(windowsForwardersFilename) - content, err := os.ReadFile(file) - if err != nil { - mainLog.Load().Error().Err(err).Msg("could not read forwarders settings") - return - } - nameservers := strings.Split(string(content), ",") - if err := removeDnsServerForwarders(nameservers); err != nil { - mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings") - return - } - } - }) - luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { return fmt.Errorf("resetDNS: %w", err) @@ -161,7 +94,7 @@ func resetDNS(iface *net.Interface) error { // restoreDNS restores the DNS settings of the given interface. // this should only be executed upon turning off the ctrld service. func restoreDNS(iface *net.Interface) (err error) { - if nss := savedStaticNameservers(iface); len(nss) > 0 { + if nss := ctrld.SavedStaticNameservers(iface); len(nss) > 0 { v4ns := make([]string, 0, 2) v6ns := make([]string, 0, 2) for _, ns := range nss { @@ -178,24 +111,24 @@ func restoreDNS(iface *net.Interface) (err error) { } if len(v4ns) > 0 { - mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns) + mainLog.Load().Debug().Msgf("Restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns) if err := setDNS(iface, v4ns); err != nil { return fmt.Errorf("restoreDNS (IPv4): %w", err) } } else { - mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name) + mainLog.Load().Debug().Msgf("Restoring IPv4 DHCP for interface %q", iface.Name) if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil { return fmt.Errorf("restoreDNS (IPv4 clear): %w", err) } } if len(v6ns) > 0 { - mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns) + mainLog.Load().Debug().Msgf("Restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns) if err := setDNS(iface, v6ns); err != nil { return fmt.Errorf("restoreDNS (IPv6): %w", err) } } 
else { - mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name) + mainLog.Load().Debug().Msgf("Restoring IPv6 DHCP for interface %q", iface.Name) if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { return fmt.Errorf("restoreDNS (IPv6 clear): %w", err) } @@ -204,15 +137,16 @@ func restoreDNS(iface *net.Interface) (err error) { return err } +// currentDNS returns the current DNS servers for the specified interface func currentDNS(iface *net.Interface) []string { luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get interface LUID") + mainLog.Load().Error().Err(err).Msg("Failed to get interface LUID") return nil } nameservers, err := luid.DNS() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get interface DNS") + mainLog.Load().Error().Err(err).Msg("Failed to get interface DNS") return nil } ns := make([]string, 0, len(nameservers)) @@ -240,7 +174,7 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) { interfaceKeyPath := path + guid.String() k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name) + mainLog.Load().Debug().Err(err).Msgf("Failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name) continue } func() { @@ -248,11 +182,11 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) { for _, keyName := range []string{"NameServer", "ProfileNameServer"} { value, _, err := k.GetStringValue(keyName) if err != nil && !errors.Is(err, registry.ErrNotExist) { - mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName) + mainLog.Load().Debug().Err(err).Msgf("Error reading %s registry key", keyName) continue } if len(value) > 0 { - mainLog.Load().Debug().Msgf("found static DNS for 
interface %q: %s", iface.Name, value) + mainLog.Load().Debug().Msgf("Found static DNS for interface %q: %s", iface.Name, value) parsed := parseDNSServers(value) for _, pns := range parsed { if !slices.Contains(ns, pns) { @@ -264,7 +198,7 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) { }() } if len(ns) == 0 { - mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name) + mainLog.Load().Debug().Msgf("No static DNS values found for interface %q", iface.Name) } return ns, nil } @@ -284,49 +218,3 @@ func parseDNSServers(val string) []string { } return servers } - -// addDnsServerForwarders adds given nameservers to DNS server forwarders list, -// and also removing old forwarders if provided. -func addDnsServerForwarders(nameservers, old []string) error { - newForwardersMap := make(map[string]struct{}) - newForwarders := make([]string, len(nameservers)) - for i := range nameservers { - newForwardersMap[nameservers[i]] = struct{}{} - newForwarders[i] = fmt.Sprintf("%q", nameservers[i]) - } - oldForwarders := old[:0] - for _, fwd := range old { - if _, ok := newForwardersMap[fwd]; !ok { - oldForwarders = append(oldForwarders, fwd) - } - } - // NOTE: It is important to add new forwarder before removing old one. - // Testing on Windows Server 2022 shows that removing forwarder1 - // then adding forwarder2 sometimes ends up adding both of them - // to the forwarders list. - cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ",")) - if len(oldForwarders) > 0 { - cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ",")) - } - if out, err := powershell(cmd); err != nil { - return fmt.Errorf("%w: %s", err, string(out)) - } - return nil -} - -// removeDnsServerForwarders removes given nameservers from DNS server forwarders list. 
-func removeDnsServerForwarders(nameservers []string) error { - for _, ns := range nameservers { - cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns) - if out, err := powershell(cmd); err != nil { - return fmt.Errorf("%w: %s", err, string(out)) - } - } - return nil -} - -// powershell runs the given powershell command. -func powershell(cmd string) ([]byte, error) { - out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() - return bytes.TrimSpace(out), err -} diff --git a/cmd/cli/os_windows_test.go b/cmd/cli/os_windows_test.go index 40be5ed2..054b77cc 100644 --- a/cmd/cli/os_windows_test.go +++ b/cmd/cli/os_windows_test.go @@ -1,8 +1,10 @@ package cli import ( + "bytes" "fmt" "net" + "os/exec" "slices" "strings" "testing" @@ -66,3 +68,9 @@ func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) { } return ns, nil } + +// powershell runs the given powershell command. +func powershell(cmd string) ([]byte, error) { + out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() + return bytes.TrimSpace(out), err +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 76f7c366..069b8835 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "io/fs" "math/rand" "net" @@ -24,7 +25,6 @@ import ( "github.com/Masterminds/semver/v3" "github.com/kardianos/service" - "github.com/rs/zerolog" "github.com/spf13/viper" "golang.org/x/sync/singleflight" "tailscale.com/net/netmon" @@ -34,8 +34,6 @@ import ( "github.com/Control-D-Inc/ctrld/internal/clientinfo" "github.com/Control-D-Inc/ctrld/internal/controld" "github.com/Control-D-Inc/ctrld/internal/dnscache" - "github.com/Control-D-Inc/ctrld/internal/router" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) const ( @@ -82,13 +80,6 @@ var logf = func(format string, args ...any) { //lint:ignore U1000 use in newLoopbackOSConfigurator var noopLogf = func(format string, args ...any) {} -var svcConfig = 
&service.Config{ - Name: ctrldServiceName, - DisplayName: "Control-D Helper Service", - Description: "A highly configurable, multi-protocol DNS forwarding proxy", - Option: service.KeyValue{}, -} - var useSystemdResolved = false type prog struct { @@ -101,8 +92,9 @@ type prog struct { apiReloadCh chan *ctrld.Config apiForceReloadCh chan struct{} apiForceReloadGroup singleflight.Group - logConn net.Conn + logConn io.WriteCloser cs *controlServer + logger atomic.Pointer[ctrld.Logger] csSetDnsDone chan struct{} csSetDnsOk bool dnsWg sync.WaitGroup @@ -119,7 +111,6 @@ type prog struct { sema semaphore ciTable *clientinfo.Table um *upstreamMonitor - router router.Router ptrLoopGuard *loopGuard lanLoopGuard *loopGuard metricsQueryStats atomic.Bool @@ -130,8 +121,6 @@ type prog struct { internalLogSent time.Time runningIface string requiredMultiNICsConfig bool - adDomain string - runningOnDomainController bool selfUninstallMu sync.Mutex refusedQueryCount int @@ -151,7 +140,7 @@ type prog struct { onStopped []func() } -func (p *prog) Start(s service.Service) error { +func (p *prog) Start(_ service.Service) error { go p.runWait() return nil } @@ -165,7 +154,6 @@ func (p *prog) runWait() { notifyReloadSigCh(reloadSigCh) reload := false - logger := mainLog.Load() for { reloadCh := make(chan struct{}) done := make(chan struct{}) @@ -178,9 +166,9 @@ func (p *prog) runWait() { var newCfg *ctrld.Config select { case sig := <-reloadSigCh: - logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) + p.Notice().Msgf("Got signal: %s, reloading...", sig.String()) case <-p.reloadCh: - logger.Notice().Msg("reloading...") + p.Notice().Msg("Reloading...") case apiCfg := <-p.apiReloadCh: newCfg = apiCfg case <-p.stopCh: @@ -203,18 +191,18 @@ func (p *prog) runWait() { } v.SetConfigFile(confFile) if err := v.ReadInConfig(); err != nil { - logger.Err(err).Msg("could not read new config") + p.Error().Err(err).Msg("Could not read new config") waitOldRunDone() continue } if err := 
v.Unmarshal(&newCfg); err != nil { - logger.Err(err).Msg("could not unmarshal new config") + p.Error().Err(err).Msg("Could not unmarshal new config") waitOldRunDone() continue } if cdUID != "" { if rc, err := processCDFlags(newCfg); err != nil { - logger.Err(err).Msg("could not fetch ControlD config") + p.Error().Err(err).Msg("Could not fetch controld config") waitOldRunDone() continue } else { @@ -244,28 +232,29 @@ func (p *prog) runWait() { } } if err := validateConfig(newCfg); err != nil { - logger.Err(err).Msg("invalid config") + p.Error().Err(err).Msg("Invalid config") continue } addExtraSplitDnsRule(newCfg) if err := writeConfigFile(newCfg); err != nil { - logger.Err(err).Msg("could not write new config") + p.Error().Err(err).Msg("Could not write new config") } // This needs to be done here, otherwise, the DNS handler may observe an invalid // upstream config because its initialization function have not been called yet. - mainLog.Load().Debug().Msg("setup upstream with new config") + p.Debug().Msg("Setup upstream with new config") p.setupUpstream(newCfg) p.mu.Lock() *p.cfg = *newCfg p.mu.Unlock() - logger.Notice().Msg("reloading config successfully") + p.Notice().Msg("Reloading config successfully") select { case p.reloadDoneCh <- struct{}{}: + p.Debug().Msg("Reload done signal sent") default: } } @@ -277,18 +266,14 @@ func (p *prog) preRun() { p.requiredMultiNICsConfig = requiredMultiNICsConfig() } p.runningIface = iface + p.logger.Store(mainLog.Load()) } func (p *prog) postRun() { if !service.Interactive() { - if runtime.GOOS == "windows" { - isDC, roleInt := isRunningOnDomainController() - p.runningOnDomainController = isDC - mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt) - } p.resetDNS(false, false) - ns := ctrld.InitializeOsResolver(false) - mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) + ns := 
ctrld.InitializeOsResolver(ctrld.LoggerCtx(context.Background(), p.logger.Load()), false) + p.Debug().Msgf("Initialized os resolver with nameservers: %v", ns) p.setDNS() p.csSetDnsDone <- struct{}{} close(p.csSetDnsDone) @@ -305,31 +290,32 @@ func (p *prog) apiConfigReload() { ticker := time.NewTicker(timeDurationOrDefault(p.cfg.Service.RefetchTime, 3600) * time.Second) defer ticker.Stop() - logger := mainLog.Load().With().Str("mode", "api-reload").Logger() - logger.Debug().Msg("starting custom config reload timer") + logger := p.logger.Load().With().Str("mode", "api-reload") + logger.Debug().Msg("Starting custom config reload timer") lastUpdated := time.Now().Unix() curVerStr := curVersion() curVer, err := semver.NewVersion(curVerStr) isStable := curVer != nil && curVer.Prerelease() == "" if err != nil || !isStable { - l := mainLog.Load().Warn() + l := p.Warn() if err != nil { l = l.Err(err) } - l.Msgf("current version is not stable, skipping self-upgrade: %s", curVerStr) + l.Msgf("Current version is not stable, skipping self-upgrade: %s", curVerStr) } - doReloadApiConfig := func(forced bool, logger zerolog.Logger) { - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + doReloadApiConfig := func(forced bool, logger *ctrld.Logger) { + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) + resolverConfig, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev) selfUninstallCheck(err, p, logger) if err != nil { - logger.Warn().Err(err).Msg("could not fetch resolver config") + logger.Warn().Err(err).Msg("Could not fetch resolver config") return } // Performing self-upgrade check for production version. 
if isStable { - _ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger) + _ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, logger) } if resolverConfig.DeactivationPin != nil { @@ -337,9 +323,9 @@ func (p *prog) apiConfigReload() { curDeactivationPin := cdDeactivationPin.Load() switch { case curDeactivationPin != defaultDeactivationPin: - logger.Debug().Msg("saving deactivation pin") + logger.Debug().Msg("Saving deactivation pin") case curDeactivationPin != newDeactivationPin: - logger.Debug().Msg("update deactivation pin") + logger.Debug().Msg("Update deactivation pin") } cdDeactivationPin.Store(newDeactivationPin) } else { @@ -362,7 +348,7 @@ func (p *prog) apiConfigReload() { } if noCustomConfig && !noExcludeListChanged { - logger.Debug().Msg("exclude list changes detected, reloading...") + logger.Debug().Msg("Exclude list changes detected, reloading...") p.apiReloadCh <- nil return } @@ -377,22 +363,22 @@ func (p *prog) apiConfigReload() { cfgErr = validateConfig(cfg) } if cfgErr != nil { - logger.Warn().Err(err).Msg("skipping invalid custom config") - if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil { - logger.Error().Err(err).Msg("could not mark custom last update failed") + logger.Warn().Err(err).Msg("Skipping invalid custom config") + if _, err := controld.UpdateCustomLastFailed(loggerCtx, cdUID, appVersion, cdDev, true); err != nil { + logger.Error().Err(err).Msg("Could not mark custom last update failed") } return } - logger.Debug().Msg("custom config changes detected, reloading...") + logger.Debug().Msg("Custom config changes detected, reloading...") p.apiReloadCh <- cfg } else { - logger.Debug().Msg("custom config does not change") + logger.Debug().Msg("Custom config does not change") } } for { select { case <-p.apiForceReloadCh: - doReloadApiConfig(true, logger.With().Bool("forced", true).Logger()) + doReloadApiConfig(true, logger.With().Bool("forced", true)) case <-ticker.C: 
doReloadApiConfig(false, logger) case <-p.stopCh: @@ -405,22 +391,23 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) { localUpstreams := make([]string, 0, len(cfg.Upstream)) ptrNameservers := make([]string, 0, len(cfg.Upstream)) isControlDUpstream := false + loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load()) for n := range cfg.Upstream { uc := cfg.Upstream[n] sdns := uc.Type == ctrld.ResolverTypeSDNS - uc.Init() + uc.Init(loggerCtx) if sdns { - mainLog.Load().Debug().Msgf("initialized DNS Stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) + p.Debug().Msgf("Initialized dns stamps with endpoint: %s, type: %s", uc.Endpoint, uc.Type) } isControlDUpstream = isControlDUpstream || uc.IsControlD() if uc.BootstrapIP == "" { - uc.SetupBootstrapIP() - mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) + uc.SetupBootstrapIP(ctrld.LoggerCtx(context.Background(), p.logger.Load())) + p.Info().Msgf("Bootstrap ips for upstream.%s: %q", n, uc.BootstrapIPs()) } else { - mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) + p.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap ip for upstream.%s", n) } uc.SetCertPool(rootCertPool) - go uc.Ping() + go uc.Ping(loggerCtx) if canBeLocalUpstream(uc.Domain) { localUpstreams = append(localUpstreams, upstreamPrefix+n) @@ -458,9 +445,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.csSetDnsDone = make(chan struct{}, 1) p.registerControlServerHandler() if err := p.cs.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start control server") + p.Warn().Err(err).Msg("Could not start control server") } - mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr) + p.Debug().Msgf("Control server started: %s", p.cs.addr) } } p.onStartedDone = make(chan struct{}) @@ -472,7 +459,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if p.cfg.Service.CacheEnable { cacher, err 
:= dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled") + p.Error().Err(err).Msg("Failed to create cacher, caching is disabled") } else { p.cache = cacher p.cacheFlushDomainsMap = make(map[string]struct{}, 256) @@ -481,10 +468,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } } - if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() { - mainLog.Load().Debug().Msgf("active directory domain: %s", domain) - p.adDomain = domain - } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) @@ -493,14 +476,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) if err != nil { - mainLog.Load().Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr") + p.Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("Invalid cidr") continue } nc.IPNets = append(nc.IPNets, ipNet) } } - p.um = newUpstreamMonitor(p.cfg) + p.um = newUpstreamMonitor(p.cfg, p.logger.Load()) if !reload { p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} @@ -513,7 +496,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } p.setupUpstream(p.cfg) - p.setupClientInfoDiscover(defaultRouteIP()) + p.setupClientInfoDiscover() } // context for managing spawn goroutines. 
@@ -533,8 +516,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !reload { go func() { // Start network monitoring - if err := p.monitorNetworkChanges(); err != nil { - mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") + if err := p.monitorNetworkChanges(ctx); err != nil { + p.Error().Err(err).Msg("Failed to start network monitoring") } }() } @@ -546,14 +529,17 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { listenerConfig := p.cfg.Listener[listenerNum] upstreamConfig := p.cfg.Upstream[listenerNum] if upstreamConfig == nil { - mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + p.Warn().Msgf("No default upstream for: [listener.%s]", listenerNum) } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) - mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum); err != nil { - mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) + p.Info().Msgf("Starting dns server on listener.%s: %s", listenerNum, addr) + // serveCtx uses Background() context so listeners survive between reloads. + // Changes to listeners config require a service restart, not just reload. + serveCtx := context.Background() + if err := p.serveDNS(serveCtx, listenerNum); err != nil { + p.Fatal().Err(err).Msgf("Unable to start dns proxy on listener.%s", listenerNum) } - mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) + p.Debug().Msgf("End of serveDNS listener.%s: %s", listenerNum, addr) }(listenerNum) } go func() { @@ -598,7 +584,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !reload { // Stop writing log to unix socket. 
- consoleWriter.Out = os.Stdout + consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel) p.initLogging(false) if p.logConn != nil { _ = p.logConn.Close() @@ -610,19 +596,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } // setupClientInfoDiscover performs necessary works for running client info discover. -func (p *prog) setupClientInfoDiscover(selfIP string) { - p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers) +func (p *prog) setupClientInfoDiscover() { + selfIP := p.defaultRouteIP() + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers, p.logger.Load()) if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + p.Debug().Msgf("Watching custom lease file: %s", leaseFile) format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } - if leaseFiles := dnsmasq.AdditionalLeaseFiles(); len(leaseFiles) > 0 { - mainLog.Load().Debug().Msgf("watching additional lease files: %v", leaseFiles) - for _, leaseFile := range leaseFiles { - p.ciTable.AddLeaseFile(leaseFile, ctrld.Dnsmasq) - } - } } // runClientInfoDiscover runs the client info discover. 
@@ -636,18 +617,18 @@ func (p *prog) metricsEnabled() bool { return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != "" } -func (p *prog) Stop(s service.Service) error { +func (p *prog) Stop(_ service.Service) error { p.stopDnsWatchers() - mainLog.Load().Debug().Msg("dns watchers stopped") + p.Debug().Msg("Dns watchers stopped") for _, f := range p.onStopped { f() } - mainLog.Load().Debug().Msg("finish running onStopped functions") + p.Debug().Msg("Finish running onStopped functions") defer func() { - mainLog.Load().Info().Msg("Service stopped") + p.Info().Msg("Service stopped") }() if err := p.deAllocateIP(); err != nil { - mainLog.Load().Error().Err(err).Msg("de-allocate ip failed") + p.Error().Err(err).Msg("De-allocate ip failed") return err } if deactivationPinSet() { @@ -659,16 +640,16 @@ func (p *prog) Stop(s service.Service) error { // No valid pin code was checked, that mean we are stopping // because of OS signal sent directly from someone else. // In this case, restarting ctrld service by ourselves. 
- mainLog.Load().Debug().Msgf("receiving stopping signal without valid pin code") - mainLog.Load().Debug().Msgf("self restarting ctrld service") + p.Debug().Msgf("Receiving stopping signal without valid pin code") + p.Debug().Msgf("Self restarting ctrld service") if exe, err := os.Executable(); err == nil { cmd := exec.Command(exe, "restart") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to run self restart command") + p.Error().Err(err).Msg("Failed to run self restart command") } } else { - mainLog.Load().Error().Err(err).Msg("failed to self restart ctrld service") + p.Error().Err(err).Msg("Failed to self restart ctrld service") } os.Exit(deactivationPinInvalidExitCode) } @@ -729,9 +710,6 @@ func (p *prog) setDNS() { ns = "127.0.0.1" case lc.Port != 53: ns = "127.0.0.1" - if resolver := router.LocalResolverIP(); resolver != "" { - ns = resolver - } default: // If we ever reach here, it means ctrld is running on lc.IP port 53, // so we could just use lc.IP as nameserver. 
@@ -769,7 +747,7 @@ func (p *prog) setDNS() { p.dnsWg.Add(1) go func() { defer p.dnsWg.Done() - p.watchResolvConf(netIface, servers, setResolvConf) + p.watchResolvConf(netIface, servers, p.setResolvConf) }() } if p.dnsWatchdogEnabled() { @@ -786,7 +764,7 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In return } - logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger := p.logger.Load().With().Str("iface", p.runningIface) const maxDNSRetryAttempts = 3 const retryDelay = 1 * time.Second @@ -799,33 +777,33 @@ func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.In } if attempt < maxDNSRetryAttempts { // Try to find a different working interface - newIface := findWorkingInterface(p.runningIface) + newIface := p.findWorkingInterface() if newIface != p.runningIface { p.runningIface = newIface - logger = mainLog.Load().With().Str("iface", p.runningIface).Logger() - logger.Info().Msg("switched to new interface") + logger = p.logger.Load().With().Str("iface", p.runningIface) + logger.Info().Msg("Switched to new interface") continue } - logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...") + logger.Warn().Err(err).Int("attempt", attempt).Msg("Could not get interface, retrying...") time.Sleep(retryDelay) continue } - logger.Error().Err(err).Msg("could not get interface after all attempts") + logger.Error().Err(err).Msg("Could not get interface after all attempts") return } - if err := setupNetworkManager(); err != nil { - logger.Error().Err(err).Msg("could not patch NetworkManager") + if err := p.setupNetworkManager(); err != nil { + logger.Error().Err(err).Msg("Could not patch networkmanager") return } runningIface = netIface - logger.Debug().Msg("setting DNS for interface") + logger.Debug().Msg("Setting dns for interface") if err := setDNS(netIface, nameservers); err != nil { - logger.Error().Err(err).Msgf("could not set DNS for interface") + 
logger.Error().Err(err).Msgf("Could not set dns for interface") return } - logger.Debug().Msg("setting DNS successfully") + logger.Debug().Msg("Setting dns successfully") return } @@ -854,7 +832,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { return } - mainLog.Load().Debug().Msg("start DNS settings watchdog") + p.Debug().Msg("Start dns settings watchdog") ns := nameservers slices.Sort(ns) @@ -865,34 +843,34 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { case <-p.dnsWatcherStopCh: return case <-p.stopCh: - mainLog.Load().Debug().Msg("stop dns watchdog") + p.Debug().Msg("Stop dns watchdog") return case <-ticker.C: if p.recoveryRunning.Load() { return } - if dnsChanged(iface, ns) { - mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings") + if p.dnsChanged(iface, ns) { + p.Debug().Msg("DNS settings were changed, re-applying settings") // Check if the interface already has static DNS servers configured. // currentStaticDNS is an OS-dependent helper that returns the current static DNS. staticDNS, err := currentStaticDNS(iface) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name) + p.Debug().Err(err).Msgf("Failed to get static DNS for interface %s", iface.Name) } else if len(staticDNS) > 0 { //filter out loopback addresses staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { return net.ParseIP(s).IsLoopback() }) // if we have a static config and no saved IPs already, save them - if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 { + if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(iface)) == 0 { // Save these static DNS values so that they can be restored later. 
if err := saveCurrentStaticDNS(iface); err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name) + p.Debug().Err(err).Msgf("Failed to save static DNS for interface %s", iface.Name) } } } if err := setDNS(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings") + p.Error().Err(err).Str("iface", iface.Name).Msgf("Could not re-apply DNS settings") } } if p.requiredMultiNICsConfig { @@ -901,31 +879,31 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) { ifaceName = iface.Name } withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error { - if dnsChanged(i, ns) { + if p.dnsChanged(i, ns) { // Check if the interface already has static DNS servers configured. // currentStaticDNS is an OS-dependent helper that returns the current static DNS. staticDNS, err := currentStaticDNS(i) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name) + p.Debug().Err(err).Msgf("Failed to get static DNS for interface %s", i.Name) } else if len(staticDNS) > 0 { //filter out loopback addresses staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool { return net.ParseIP(s).IsLoopback() }) // if we have a static config and no saved IPs already, save them - if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 { + if len(staticDNS) > 0 && len(ctrld.SavedStaticNameservers(i)) == 0 { // Save these static DNS values so that they can be restored later. 
if err := saveCurrentStaticDNS(i); err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name) + p.Debug().Err(err).Msgf("Failed to save static DNS for interface %s", i.Name) } } } if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil { - mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings") + p.Error().Err(err).Str("iface", i.Name).Msgf("Could not re-apply DNS settings") } else { - mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name) + p.Debug().Msgf("Re-applying DNS for interface %q successfully", i.Name) } } return nil @@ -955,18 +933,18 @@ func (p *prog) resetDNS(isStart bool, restoreStatic bool) { // Otherwise, we restore the saved configuration (if any) or reset to DHCP. func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) { if p.runningIface == "" { - mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") + p.Debug().Msg("No running interface, skipping resetDNS") return } - logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger := p.logger.Load().With().Str("iface", p.runningIface) netIface, err := netInterface(p.runningIface) if err != nil { - logger.Error().Err(err).Msg("could not get interface") + logger.Error().Err(err).Msg("Could not get interface") return } runningIface = netIface - if err := restoreNetworkManager(); err != nil { - logger.Error().Err(err).Msg("could not restore NetworkManager") + if err := p.restoreNetworkManager(); err != nil { + logger.Error().Err(err).Msg("Could not restore NetworkManager") return } @@ -974,7 +952,7 @@ func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runnin if isStart { current, err := currentStaticDNS(netIface) if err != nil { - logger.Warn().Err(err).Msg("unable to obtain current static DNS configuration; proceeding to restore saved config") + 
logger.Warn().Err(err).Msg("Unable to obtain current static DNS configuration; proceeding to restore saved config") } else if len(current) > 0 { // If any static DNS value is not our own listener, assume an admin override. hasManualConfig := false @@ -992,17 +970,17 @@ func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runnin } // Default logic: if there is a saved static DNS configuration, restore it. - saved := savedStaticNameservers(netIface) + saved := ctrld.SavedStaticNameservers(netIface) if len(saved) > 0 && restoreStatic { logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved) if err := setDNS(netIface, saved); err != nil { - logger.Error().Err(err).Msgf("failed to restore static DNS config on interface %q", netIface.Name) + logger.Error().Err(err).Msgf("Failed to restore static DNS config on interface %q", netIface.Name) return } } else { logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name) if err := resetDNS(netIface); err != nil { - logger.Error().Err(err).Msgf("failed to reset DNS to DHCP on interface %q", netIface.Name) + logger.Error().Err(err).Msgf("Failed to reset DNS to DHCP on interface %q", netIface.Name) return } } @@ -1013,16 +991,16 @@ func (p *prog) logInterfacesState() { withEachPhysicalInterfaces("", "", func(i *net.Interface) error { addrs, err := i.Addrs() if err != nil { - mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses") + p.Warn().Str("interface", i.Name).Err(err).Msg("Failed to get addresses") } nss, err := currentStaticDNS(i) if err != nil { - mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS") + p.Warn().Str("interface", i.Name).Err(err).Msg("Failed to get DNS") } if len(nss) == 0 { nss = currentDNS(i) } - mainLog.Load().Debug(). + p.Debug(). Any("addrs", addrs). Strs("nameservers", nss). Int("index", i.Index). 
@@ -1032,7 +1010,8 @@ func (p *prog) logInterfacesState() { } // findWorkingInterface looks for a network interface with a valid IP configuration -func findWorkingInterface(currentIface string) string { +func (p *prog) findWorkingInterface() string { + currentIface := p.runningIface // Helper to check if IP is valid (not link-local) isValidIP := func(ip net.IP) bool { return ip != nil && @@ -1050,7 +1029,7 @@ func findWorkingInterface(currentIface string) string { addrs, err := iface.Addrs() if err != nil { - mainLog.Load().Debug(). + p.Debug(). Str("interface", iface.Name). Err(err). Msg("failed to get interface addresses") @@ -1069,13 +1048,15 @@ func findWorkingInterface(currentIface string) string { } // Get default route interface + foundDefaultRoute := false defaultRoute, err := netmon.DefaultRoute() if err != nil { - mainLog.Load().Debug(). + p.Debug(). Err(err). Msg("failed to get default route") } else { - mainLog.Load().Debug(). + foundDefaultRoute = true + p.Debug(). Str("default_route_iface", defaultRoute.InterfaceName). Msg("found default route") } @@ -1083,7 +1064,7 @@ func findWorkingInterface(currentIface string) string { // Get all interfaces ifaces, err := net.Interfaces() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to list network interfaces") + p.Error().Err(err).Msg("Failed to list network interfaces") return currentIface // Return current interface as fallback } @@ -1111,9 +1092,9 @@ func findWorkingInterface(currentIface string) string { } // Found working physical interface - if err == nil && defaultRoute.InterfaceName == iface.Name { + if foundDefaultRoute && defaultRoute.InterfaceName == iface.Name { // Found interface with default route - use it immediately - mainLog.Load().Info(). + p.Info(). Str("old_iface", currentIface). Str("new_iface", iface.Name). 
Msg("switching to interface with default route") @@ -1134,7 +1115,7 @@ func findWorkingInterface(currentIface string) string { // Return interfaces in order of preference: // 1. Current interface if it's still valid if currentIfaceValid { - mainLog.Load().Debug(). + p.Debug(). Str("interface", currentIface). Msg("keeping current interface") return currentIface @@ -1142,7 +1123,7 @@ func findWorkingInterface(currentIface string) string { // 2. First working interface found if firstWorkingIface != "" { - mainLog.Load().Info(). + p.Info(). Str("old_iface", currentIface). Str("new_iface", firstWorkingIface). Msg("switching to first working physical interface") @@ -1150,9 +1131,9 @@ func findWorkingInterface(currentIface string) string { } // 3. Fall back to current interface if nothing else works - mainLog.Load().Warn(). + p.Warn(). Str("current_iface", currentIface). - Msg("no working physical interface found, keeping current") + Msg("No working physical interface found, keeping current") return currentIface } @@ -1168,28 +1149,6 @@ func randomPort() int { return n } -// runLogServer starts a unix listener, use by startCmd to gather log from runCmd. -func runLogServer(sockPath string) net.Conn { - addr, err := net.ResolveUnixAddr("unix", sockPath) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("invalid log sock path") - return nil - } - ln, err := net.ListenUnix("unix", addr) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not listen log socket") - return nil - } - defer ln.Close() - - server, err := ln.Accept() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not accept connection") - return nil - } - return server -} - func errAddrInUse(err error) bool { var opErr *net.OpError if errors.As(err, &opErr) { @@ -1272,7 +1231,7 @@ func ifaceFirstPrivateIP(iface *net.Interface) string { } // defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6. 
-func defaultRouteIP() string { +func (p *prog) defaultRouteIP() string { dr, err := netmon.DefaultRoute() if err != nil { return "" @@ -1281,9 +1240,9 @@ func defaultRouteIP() string { if err != nil { return "" } - mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface") + p.Debug().Str("iface", drNetIface.Name).Msg("Checking default route interface") if ip := ifaceFirstPrivateIP(drNetIface); ip != "" { - mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface") + p.Debug().Str("ip", ip).Msg("Found ip with default route interface") return ip } @@ -1308,7 +1267,7 @@ func defaultRouteIP() string { }) if len(addrs) == 0 { - mainLog.Load().Warn().Msg("no default route IP found") + p.Warn().Msg("No default route IP found") return "" } sort.Slice(addrs, func(i, j int) bool { @@ -1316,7 +1275,7 @@ func defaultRouteIP() string { }) ip := addrs[0].String() - mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP") + p.Debug().Str("ip", ip).Msg("Found LAN interface IP") return ip } @@ -1331,8 +1290,8 @@ func canBeLocalUpstream(addr string) bool { // withEachPhysicalInterfaces runs the function f with each physical interfaces, excluding // the interface that matches excludeIfaceName. The context is used to clarify the // log message when error happens. -func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) { - validIfacesMap := validInterfacesMap() +func withEachPhysicalInterfaces(excludeIfaceName, contextStr string, f func(i *net.Interface) error) { + validIfacesMap := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load())) netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { // Skip loopback/virtual/down interface. if i.IsLoopback() || len(i.HardwareAddr) == 0 { @@ -1344,7 +1303,7 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. 
} netIface := i.Interface if patched, err := patchNetIfaceName(netIface); err != nil { - mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name") + mainLog.Load().Debug().Err(err).Msg("Failed to patch net interface name") return } else if !patched { // The interface is not functional, skipping. @@ -1356,11 +1315,11 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. } // TODO: investigate whether we should report this error? if err := f(netIface); err == nil { - if context != "" { - mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name) + if contextStr != "" { + mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", contextStr, i.Name) } } else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) { - mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name) + mainLog.Load().Err(err).Msgf("%s for interface %q failed", contextStr, i.Name) } }) } @@ -1381,7 +1340,7 @@ var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not // Only works on Windows and Mac. 
func saveCurrentStaticDNS(iface *net.Interface) error { if iface == nil { - mainLog.Load().Debug().Msg("could not save current static DNS settings for nil interface") + mainLog.Load().Debug().Msg("Could not save current static DNS settings for nil interface") return nil } switch runtime.GOOS { @@ -1389,14 +1348,14 @@ func saveCurrentStaticDNS(iface *net.Interface) error { default: return errSaveCurrentStaticDNSNotSupported } - file := savedStaticDnsSettingsFilePath(iface) + file := ctrld.SavedStaticDnsSettingsFilePath(iface) ns, err := currentStaticDNS(iface) if err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name) + mainLog.Load().Warn().Err(err).Msgf("Could not get current static DNS settings for %q", iface.Name) return err } if len(ns) == 0 { - mainLog.Load().Debug().Msgf("no static DNS settings for %q, removing old static DNS settings file", iface.Name) + mainLog.Load().Debug().Msgf("No static DNS settings for %q, removing old static DNS settings file", iface.Name) _ = os.Remove(file) // removing old static DNS settings return nil } @@ -1411,47 +1370,15 @@ func saveCurrentStaticDNS(iface *net.Interface) error { return nil } if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { - mainLog.Load().Warn().Err(err).Msgf("could not remove old static DNS settings file: %s", file) + mainLog.Load().Warn().Err(err).Msgf("Could not remove old static DNS settings file: %s", file) } nss := strings.Join(ns, ",") mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss) if err := os.WriteFile(file, []byte(nss), 0600); err != nil { - mainLog.Load().Err(err).Msgf("could not save DNS settings for iface: %s", iface.Name) + mainLog.Load().Err(err).Msgf("Could not save DNS settings for iface: %s", iface.Name) return err } - mainLog.Load().Debug().Msgf("save DNS settings for interface %q successfully", iface.Name) - return nil -} - -// savedStaticDnsSettingsFilePath 
returns the path to saved DNS settings of the given interface. -func savedStaticDnsSettingsFilePath(iface *net.Interface) string { - if iface == nil { - return "" - } - return absHomeDir(".dns_" + iface.Name) -} - -// savedStaticNameservers returns the static DNS nameservers of the given interface. -// -//lint:ignore U1000 use in os_windows.go and os_darwin.go -func savedStaticNameservers(iface *net.Interface) []string { - if iface == nil { - mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface") - return nil - } - file := savedStaticDnsSettingsFilePath(iface) - if data, _ := os.ReadFile(file); len(data) > 0 { - saveValues := strings.Split(string(data), ",") - returnValues := []string{} - // check each one, if its in loopback range, remove it - for _, v := range saveValues { - if net.ParseIP(v).IsLoopback() { - continue - } - returnValues = append(returnValues, v) - } - return returnValues - } + mainLog.Load().Debug().Msgf("Save DNS settings for interface %q successfully", iface.Name) return nil } @@ -1459,21 +1386,21 @@ func savedStaticNameservers(iface *net.Interface) []string { // It returns false for a nil iface. // // The caller must sort the nameservers before calling this function. -func dnsChanged(iface *net.Interface, nameservers []string) bool { +func (p *prog) dnsChanged(iface *net.Interface, nameservers []string) bool { if iface == nil { return false } curNameservers, _ := currentStaticDNS(iface) slices.Sort(curNameservers) if !slices.Equal(curNameservers, nameservers) { - mainLog.Load().Debug().Msgf("interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers) + p.Debug().Msgf("Interface %q current DNS settings: %v, expected: %v", iface.Name, curNameservers, nameservers) return true } return false } // selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then. 
-func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { +func selfUninstallCheck(uninstallErr error, p *prog, logger *ctrld.Logger) { var uer *controld.ErrorResponse if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { p.stopDnsWatchers() @@ -1488,9 +1415,9 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { // // The callers must ensure curVer and logger are non-nil. // Returns true if upgrade is allowed, false otherwise. -func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { +func shouldUpgrade(vt string, cv *semver.Version, logger *ctrld.Logger) bool { if vt == "" { - logger.Debug().Msg("no version target set, skipped checking self-upgrade") + logger.Debug().Msg("No version target set, skipped checking self-upgrade") return false } vts := vt @@ -1499,7 +1426,7 @@ func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { } targetVer, err := semver.NewVersion(vts) if err != nil { - logger.Warn().Err(err).Msgf("invalid target version, skipped self-upgrade: %s", vt) + logger.Warn().Err(err).Msgf("Invalid target version, skipped self-upgrade: %s", vt) return false } @@ -1508,7 +1435,7 @@ func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { logger.Warn(). Str("target", vt). Str("current", cv.String()). - Msgf("major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major()) + Msgf("Major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major()) return false } @@ -1516,7 +1443,7 @@ func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { logger.Debug(). Str("target", vt). Str("current", cv.String()). 
- Msgf("target version is not greater than current one, skipped self-upgrade") + Msgf("Target version is not greater than current one, skipped self-upgrade") return false } @@ -1525,19 +1452,19 @@ func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool { // performUpgrade executes the self-upgrade command. // Returns true if upgrade was initiated successfully, false otherwise. -func performUpgrade(vt string) bool { +func performUpgrade(vt string, logger *ctrld.Logger) bool { exe, err := os.Executable() if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade") + logger.Error().Err(err).Msg("Failed to get executable path, skipped self-upgrade") return false } cmd := exec.Command(exe, "upgrade", "prod", "-vv") cmd.SysProcAttr = sysProcAttrForDetachedChildProcess() if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade") + logger.Error().Err(err).Msg("Failed to start self-upgrade") return false } - mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vt) + logger.Debug().Msgf("Self-upgrade triggered, version target: %s", vt) return true } @@ -1547,9 +1474,9 @@ func performUpgrade(vt string) bool { // // The callers must ensure curVer and logger are non-nil. // Returns true if upgrade is allowed and should proceed, false otherwise. -func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) bool { +func selfUpgradeCheck(vt string, cv *semver.Version, logger *ctrld.Logger) bool { if shouldUpgrade(vt, cv, logger) { - return performUpgrade(vt) + return performUpgrade(vt, logger) } return false } @@ -1560,29 +1487,5 @@ func (p *prog) leakOnUpstreamFailure() bool { if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil { return *ptr } - // Default is false on routers, since this leaking is only useful for devices that move between networks. 
- if router.Name() != "" { - return false - } - // if we are running on ADDC, we should not leak on upstream failure - if p.runningOnDomainController { - return false - } return true } - -// Domain controller role values from Win32_ComputerSystem -// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem -const ( - BackupDomainController = 4 - PrimaryDomainController = 5 -) - -// isRunningOnDomainController checks if the current machine is a domain controller -// by querying the DomainRole property from Win32_ComputerSystem via WMI. -func isRunningOnDomainController() (bool, int) { - if runtime.GOOS != "windows" { - return false, 0 - } - return isRunningOnDomainControllerWindows() -} diff --git a/cmd/cli/prog_darwin.go b/cmd/cli/prog_darwin.go index 9cd57864..a3854703 100644 --- a/cmd/cli/prog_darwin.go +++ b/cmd/cli/prog_darwin.go @@ -4,8 +4,10 @@ import ( "github.com/kardianos/service" ) +// setDependencies sets service dependencies for Darwin func setDependencies(svc *service.Config) {} +// setWorkingDirectory sets the working directory for the service func setWorkingDirectory(svc *service.Config, dir string) { svc.WorkingDirectory = dir } diff --git a/cmd/cli/prog_freebsd.go b/cmd/cli/prog_freebsd.go index 93d737fc..1be94cae 100644 --- a/cmd/cli/prog_freebsd.go +++ b/cmd/cli/prog_freebsd.go @@ -6,9 +6,11 @@ import ( "github.com/kardianos/service" ) +// setDependencies sets service dependencies for FreeBSD func setDependencies(svc *service.Config) { // TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed. 
_ = os.MkdirAll("/usr/local/etc/rc.d", 0755) } +// setWorkingDirectory is a no-op on FreeBSD; the working directory is left unset func setWorkingDirectory(svc *service.Config, dir string) {} diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index 2e5c7c76..c834b495 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -9,8 +9,6 @@ import ( "strings" "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld/internal/router" ) func init() { @@ -23,6 +21,7 @@ func init() { } } +// setDependencies sets service dependencies for Linux func setDependencies(svc *service.Config) { svc.Dependencies = []string{ "Wants=network-online.target", @@ -37,11 +36,9 @@ func setDependencies(svc *service.Config) { svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service") } } - if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 { - svc.Dependencies = append(svc.Dependencies, routerDeps...) - } } +// setWorkingDirectory sets the working directory for the service func setWorkingDirectory(svc *service.Config, dir string) { svc.WorkingDirectory = dir } diff --git a/cmd/cli/prog_log.go b/cmd/cli/prog_log.go new file mode 100644 index 00000000..91e797e0 --- /dev/null +++ b/cmd/cli/prog_log.go @@ -0,0 +1,33 @@ +package cli + +import "github.com/Control-D-Inc/ctrld" + +// Debug starts a new message with debug level. +func (p *prog) Debug() *ctrld.LogEvent { + return p.logger.Load().Debug() +} + +// Warn starts a new message with warn level. +func (p *prog) Warn() *ctrld.LogEvent { + return p.logger.Load().Warn() +} + +// Info starts a new message with info level. +func (p *prog) Info() *ctrld.LogEvent { + return p.logger.Load().Info() +} + +// Fatal starts a new message with fatal level. +func (p *prog) Fatal() *ctrld.LogEvent { + return p.logger.Load().Fatal() +} + +// Error starts a new message with error level. 
+func (p *prog) Error() *ctrld.LogEvent { + return p.logger.Load().Error() +} + +// Notice starts a new message with notice level. +func (p *prog) Notice() *ctrld.LogEvent { + return p.logger.Load().Notice() +} diff --git a/cmd/cli/prog_others.go b/cmd/cli/prog_others.go index 9026318b..c1b7f17d 100644 --- a/cmd/cli/prog_others.go +++ b/cmd/cli/prog_others.go @@ -4,8 +4,10 @@ package cli import "github.com/kardianos/service" +// setDependencies sets service dependencies for other platforms func setDependencies(svc *service.Config) {} +// setWorkingDirectory sets the working directory for the service func setWorkingDirectory(svc *service.Config, dir string) { // WorkingDirectory is not supported on Windows. svc.WorkingDirectory = dir diff --git a/cmd/cli/prog_test.go b/cmd/cli/prog_test.go index c4ef5c3b..eccc30bc 100644 --- a/cmd/cli/prog_test.go +++ b/cmd/cli/prog_test.go @@ -1,13 +1,12 @@ package cli import ( - "runtime" "testing" "time" "github.com/Masterminds/semver/v3" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "go.uber.org/zap" "github.com/Control-D-Inc/ctrld" ) @@ -174,10 +173,10 @@ func Test_shouldUpgrade(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { // Create test logger - testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + testLogger := &ctrld.Logger{Logger: zap.NewNop()} // Call the function and capture the result - result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger) + result := shouldUpgrade(tc.versionTarget, tc.currentVersion, testLogger) // Assert the expected result assert.Equal(t, tc.shouldUpgrade, result, tc.description) @@ -186,10 +185,6 @@ func Test_shouldUpgrade(t *testing.T) { } func Test_selfUpgradeCheck(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("skipped due to Windows file locking issue on Github Action runners") - } - // Helper function to create a version makeVersion := func(v string) *semver.Version { ver, err := semver.NewVersion(v) @@ -226,10 
+221,10 @@ func Test_selfUpgradeCheck(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { // Create test logger - testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger() + testLogger := &ctrld.Logger{Logger: zap.NewNop()} // Call the function and capture the result - result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger) + result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, testLogger) // Assert the expected result assert.Equal(t, tc.shouldUpgrade, result, tc.description) @@ -238,10 +233,6 @@ func Test_selfUpgradeCheck(t *testing.T) { } func Test_performUpgrade(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("skipped due to Windows file locking issue on Github Action runners") - } - tests := []struct { name string versionTarget string @@ -265,8 +256,10 @@ func Test_performUpgrade(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { + // Create test logger + testLogger := &ctrld.Logger{Logger: zap.NewNop()} // Call the function and capture the result - result := performUpgrade(tc.versionTarget) + result := performUpgrade(tc.versionTarget, testLogger) assert.Equal(t, tc.expectedResult, result, tc.description) }) } diff --git a/cmd/cli/prog_windows.go b/cmd/cli/prog_windows.go index e4486255..bd5673f6 100644 --- a/cmd/cli/prog_windows.go +++ b/cmd/cli/prog_windows.go @@ -2,12 +2,10 @@ package cli import "github.com/kardianos/service" -func setDependencies(svc *service.Config) { - if hasLocalDnsServerRunning() { - svc.Dependencies = []string{"DNS"} - } -} +// setDependencies sets service dependencies for Windows +func setDependencies(svc *service.Config) {} +// setWorkingDirectory sets the working directory for the service func setWorkingDirectory(svc *service.Config, dir string) { // WorkingDirectory is not supported on Windows. 
svc.WorkingDirectory = dir diff --git a/cmd/cli/prometheus.go b/cmd/cli/prometheus.go index 9082a58f..90fce209 100644 --- a/cmd/cli/prometheus.go +++ b/cmd/cli/prometheus.go @@ -2,6 +2,8 @@ package cli import "github.com/prometheus/client_golang/prometheus" +// Prometheus metrics label constants for consistent labeling across all metrics +// These ensure standardized metric labeling for monitoring and alerting const ( metricsLabelListener = "listener" metricsLabelClientSourceIP = "client_source_ip" @@ -13,17 +15,21 @@ const ( ) // statsVersion represent ctrld version. +// This metric provides version information for monitoring and debugging var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "ctrld_build_info", Help: "Version of ctrld process.", }, []string{"gitref", "goversion", "version"}) // statsTimeStart represents start time of ctrld service. +// This metric tracks service uptime and helps with monitoring service restarts var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{ Name: "ctrld_time_seconds", Help: "Start time of the ctrld process since unix epoch in seconds.", }) +// statsQueriesCountLabels defines the labels for query count metrics +// These labels provide detailed breakdown of DNS query statistics var statsQueriesCountLabels = []string{ metricsLabelListener, metricsLabelClientSourceIP, @@ -35,6 +41,7 @@ var statsQueriesCountLabels = []string{ } // statsQueriesCount counts total number of queries. +// This provides comprehensive DNS query statistics for monitoring and alerting var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "ctrld_queries_count", Help: "Total number of queries.", @@ -44,12 +51,14 @@ var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{ // // The labels "client_source_ip", "client_mac", "client_hostname" are unbounded, // thus this stat is highly inefficient if there are many devices. 
+// This metric should be used carefully in high-client environments var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "ctrld_client_queries_count", Help: "Total number queries of a client.", }, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname}) // WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled. +// This provides conditional metric collection to avoid performance impact when metrics are disabled func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) { if p.metricsQueryStats.Load() { c.WithLabelValues(lvs...).Inc() diff --git a/cmd/cli/reload_others.go b/cmd/cli/reload_others.go index 0977af90..cf374a04 100644 --- a/cmd/cli/reload_others.go +++ b/cmd/cli/reload_others.go @@ -8,10 +8,12 @@ import ( "syscall" ) +// notifyReloadSigCh registers ch to receive SIGUSR1, which is used as the reload signal func notifyReloadSigCh(ch chan os.Signal) { signal.Notify(ch, syscall.SIGUSR1) } +// sendReloadSignal sends a reload signal to the current process func (p *prog) sendReloadSignal() error { return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) } diff --git a/cmd/cli/reload_windows.go b/cmd/cli/reload_windows.go index 0e817e46..b60f796d 100644 --- a/cmd/cli/reload_windows.go +++ b/cmd/cli/reload_windows.go @@ -6,8 +6,10 @@ import ( "time" ) +// notifyReloadSigCh is a no-op on Windows platforms func notifyReloadSigCh(ch chan os.Signal) {} +// sendReloadSignal sends a reload signal to the program func (p *prog) sendReloadSignal() error { select { case p.reloadCh <- struct{}{}: diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 0f3f731a..325d19d0 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -3,59 +3,45 @@ package cli import ( "net" "net/netip" - "os" "path/filepath" - "strings" "time" "github.com/fsnotify/fsnotify" + + "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) // parseResolvConfNameservers reads the resolv.conf file and returns the 
nameservers found. // Returns nil if no nameservers are found. +// This function parses the system DNS configuration to understand current nameserver settings func (p *prog) parseResolvConfNameservers(path string) ([]string, error) { - content, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - // Parse the file for "nameserver" lines - var currentNS []string - lines := strings.Split(string(content), "\n") - for _, line := range lines { - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "nameserver") { - parts := strings.Fields(trimmed) - if len(parts) >= 2 { - currentNS = append(currentNS, parts[1]) - } - } - } - - return currentNS, nil + return resolvconffile.NameserversFromFile(path) } // watchResolvConf watches any changes to /etc/resolv.conf file, // and reverting to the original config set by ctrld. +// This ensures that DNS settings are not overridden by other applications or system processes func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) { resolvConfPath := "/etc/resolv.conf" // Evaluating symbolics link to watch the target file that /etc/resolv.conf point to. 
+ // This handles systems where resolv.conf is a symlink to another location if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" { resolvConfPath = rp } - mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath) + p.Debug().Msgf("Start watching %s file", resolvConfPath) watcher, err := fsnotify.NewWatcher() if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") + p.Warn().Err(err).Msg("Could not create watcher for /etc/resolv.conf") return } defer watcher.Close() // We watch /etc instead of /etc/resolv.conf directly, // see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well + // This is necessary because some systems don't properly notify on file changes watchDir := filepath.Dir(resolvConfPath) if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir) + p.Warn().Err(err).Msgf("Could not add %s to watcher list", watchDir) return } @@ -64,7 +50,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f case <-p.dnsWatcherStopCh: return case <-p.stopCh: - mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) + p.Debug().Msgf("Stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: if p.recoveryRunning.Load() { @@ -77,9 +63,10 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") + p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") // Convert expected nameservers to strings for comparison + // This allows us to detect when the resolv.conf has been modified expectedNS := make([]string, len(ns)) for i, addr := range ns { expectedNS[i] = addr.String() @@ -92,18 +79,20 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns 
[]netip.Addr, setDnsFn f for retry := 0; retry < maxRetries; retry++ { foundNS, err = p.parseResolvConfNameservers(resolvConfPath) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content") + p.Error().Err(err).Msg("Failed to read resolv.conf content") break } // If we found nameservers, break out of retry loop + // This handles cases where the file is being written but not yet complete if len(foundNS) > 0 { break } // Only retry if we found no nameservers + // This handles temporary file states during updates if retry < maxRetries-1 { - mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) + p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) select { case <-p.stopCh: return @@ -113,7 +102,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } } else { - mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries") + p.Debug().Msg("resolv.conf remained empty after all retries") } } @@ -130,7 +119,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f } } - mainLog.Load().Debug(). + p.Debug(). Strs("found", foundNS). Strs("expected", expectedNS). Bool("matches", matches). 
@@ -139,16 +128,16 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f // Only revert if the nameservers don't match if !matches { if err := watcher.Remove(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + p.Error().Err(err).Msg("Failed to pause watcher") continue } if err := setDnsFn(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + p.Error().Err(err).Msg("Failed to revert /etc/resolv.conf changes") } if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + p.Error().Err(err).Msg("Failed to continue running watcher") return } } @@ -158,7 +147,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f if !ok { return } - mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf") + p.Error().Err(err).Msg("Could not get event for /etc/resolv.conf") } } } diff --git a/cmd/cli/resolvconf_darwin.go b/cmd/cli/resolvconf_darwin.go index eb70eed6..05c70178 100644 --- a/cmd/cli/resolvconf_darwin.go +++ b/cmd/cli/resolvconf_darwin.go @@ -12,7 +12,7 @@ import ( const resolvConfPath = "/etc/resolv.conf" // setResolvConf sets the content of resolv.conf file using the given nameservers list. -func setResolvConf(iface *net.Interface, ns []netip.Addr) error { +func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error { servers := make([]string, len(ns)) for i := range ns { servers[i] = ns[i].String() diff --git a/cmd/cli/resolvconf_not_darwin_unix.go b/cmd/cli/resolvconf_not_darwin_unix.go index af335720..6eb52959 100644 --- a/cmd/cli/resolvconf_not_darwin_unix.go +++ b/cmd/cli/resolvconf_not_darwin_unix.go @@ -14,7 +14,7 @@ import ( ) // setResolvConf sets the content of the resolv.conf file using the given nameservers list. 
-func setResolvConf(iface *net.Interface, ns []netip.Addr) error { +func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error { r, err := newLoopbackOSConfigurator() if err != nil { return err @@ -27,7 +27,7 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error { if sds, err := searchDomains(); err == nil { oc.SearchDomains = sds } else { - mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file") + p.Debug().Err(err).Msg("Failed to get search domains list when reverting resolv.conf file") } return r.SetDNS(oc) } diff --git a/cmd/cli/resolvconf_test.go b/cmd/cli/resolvconf_test.go new file mode 100644 index 00000000..9d93607c --- /dev/null +++ b/cmd/cli/resolvconf_test.go @@ -0,0 +1,51 @@ +//go:build unix + +package cli + +import ( + "os" + "slices" + "strings" + "testing" + + "github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile" +) + +func oldParseResolvConfNameservers(path string) ([]string, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + // Parse the file for "nameserver" lines + var currentNS []string + lines := strings.Split(string(content), "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "nameserver") { + parts := strings.Fields(trimmed) + if len(parts) >= 2 { + currentNS = append(currentNS, parts[1]) + } + } + } + + return currentNS, nil +} + +// Test_prog_parseResolvConfNameservers tests the parsing of nameservers from resolv.conf content. +// Note: The previous implementation was removed to reduce code duplication and consolidate +// the resolv.conf handling logic into a single unified approach. All resolv.conf parsing +// is now handled by the resolvconffile package, which provides a consistent interface +// for both reading and modifying resolv.conf files across different platforms. 
+func Test_prog_parseResolvConfNameservers(t *testing.T) { + oldNss, _ := oldParseResolvConfNameservers(resolvconffile.Path) + p := &prog{} + nss, _ := p.parseResolvConfNameservers(resolvconffile.Path) + slices.Sort(oldNss) + slices.Sort(nss) + if !slices.Equal(oldNss, nss) { + t.Errorf("result mismatched, old: %v, new: %v", oldNss, nss) + } + t.Logf("result: %v", nss) +} diff --git a/cmd/cli/resolvconf_windows.go b/cmd/cli/resolvconf_windows.go index 3e4ba1c0..20a984fe 100644 --- a/cmd/cli/resolvconf_windows.go +++ b/cmd/cli/resolvconf_windows.go @@ -6,7 +6,7 @@ import ( ) // setResolvConf sets the content of resolv.conf file using the given nameservers list. -func setResolvConf(_ *net.Interface, _ []netip.Addr) error { +func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error { return nil } diff --git a/cmd/cli/search_domains_windows.go b/cmd/cli/search_domains_windows.go index 320a3223..28d1bb97 100644 --- a/cmd/cli/search_domains_windows.go +++ b/cmd/cli/search_domains_windows.go @@ -33,7 +33,7 @@ func searchDomains() ([]dnsname.FQDN, error) { for a := aa.FirstDNSSuffix; a != nil; a = a.Next { d, err := dnsname.ToFQDN(a.String()) if err != nil { - mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String()) + mainLog.Load().Debug().Err(err).Msgf("Failed to parse domain: %s", a.String()) continue } sds = append(sds, d) diff --git a/cmd/cli/self_delete_others.go b/cmd/cli/self_delete_others.go index 02ae9774..826590ec 100644 --- a/cmd/cli/self_delete_others.go +++ b/cmd/cli/self_delete_others.go @@ -4,4 +4,5 @@ package cli var supportedSelfDelete = true +// selfDeleteExe is a no-op on non-Windows platforms; it always returns nil func selfDeleteExe() error { return nil } diff --git a/cmd/cli/self_delete_windows.go b/cmd/cli/self_delete_windows.go index c2f2719e..c9618a27 100644 --- a/cmd/cli/self_delete_windows.go +++ b/cmd/cli/self_delete_windows.go @@ -33,6 +33,7 @@ type FILE_DISPOSITION_INFO struct { DeleteFile bool } +// dsOpenHandle opens a 
handle to the specified file with DELETE access func dsOpenHandle(pwPath *uint16) (windows.Handle, error) { handle, err := windows.CreateFile( pwPath, @@ -51,6 +52,7 @@ func dsOpenHandle(pwPath *uint16) (windows.Handle, error) { return handle, nil } +// dsRenameHandle renames a file handle to a stream name func dsRenameHandle(hHandle windows.Handle) error { var fRename FILE_RENAME_INFO DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef") @@ -82,6 +84,7 @@ func dsRenameHandle(hHandle windows.Handle) error { return nil } +// dsDepositeHandle marks a file handle for deletion func dsDepositeHandle(hHandle windows.Handle) error { var fDelete FILE_DISPOSITION_INFO fDelete.DeleteFile = true @@ -100,6 +103,7 @@ func dsDepositeHandle(hHandle windows.Handle) error { return nil } +// selfDeleteExe performs self-deletion on Windows platforms func selfDeleteExe() error { var wcPath [windows.MAX_PATH + 1]uint16 var hCurrent windows.Handle diff --git a/cmd/cli/self_kill_others.go b/cmd/cli/self_kill_others.go index e9fb1f8f..4f32d6f8 100644 --- a/cmd/cli/self_kill_others.go +++ b/cmd/cli/self_kill_others.go @@ -5,12 +5,13 @@ package cli import ( "os" - "github.com/rs/zerolog" + "github.com/Control-D-Inc/ctrld" ) -func selfUninstall(p *prog, logger zerolog.Logger) { +// selfUninstall performs self-uninstallation on non-Unix platforms +func selfUninstall(p *prog, logger *ctrld.Logger) { if uninstallInvalidCdUID(p, logger, false) { - logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) + logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID) os.Exit(0) } } diff --git a/cmd/cli/self_kill_unix.go b/cmd/cli/self_kill_unix.go index 157425fd..70c7c08d 100644 --- a/cmd/cli/self_kill_unix.go +++ b/cmd/cli/self_kill_unix.go @@ -9,17 +9,18 @@ import ( "runtime" "syscall" - "github.com/rs/zerolog" + "github.com/Control-D-Inc/ctrld" ) -func selfUninstall(p *prog, logger zerolog.Logger) { +// selfUninstall performs 
self-uninstallation on Unix platforms +func selfUninstall(p *prog, logger *ctrld.Logger) { if runtime.GOOS == "linux" { selfUninstallLinux(p, logger) } bin, err := os.Executable() if err != nil { - logger.Fatal().Err(err).Msg("could not determine executable") + logger.Fatal().Err(err).Msg("Could not determine executable") } args := []string{"uninstall"} if deactivationPinSet() { @@ -28,18 +29,19 @@ func selfUninstall(p *prog, logger zerolog.Logger) { cmd := exec.Command(bin, args...) cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} if err := cmd.Start(); err != nil { - logger.Fatal().Err(err).Msg("could not start self uninstall command") + logger.Fatal().Err(err).Msg("Could not start self uninstall command") } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) + logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID) _ = cmd.Wait() os.Exit(0) } -func selfUninstallLinux(p *prog, logger zerolog.Logger) { +// selfUninstallLinux performs self-uninstallation on Linux platforms +func selfUninstallLinux(p *prog, logger *ctrld.Logger) { if uninstallInvalidCdUID(p, logger, true) { - logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) + logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID) os.Exit(0) } } diff --git a/cmd/cli/sema.go b/cmd/cli/sema.go index 92b6ce0f..4285eaf4 100644 --- a/cmd/cli/sema.go +++ b/cmd/cli/sema.go @@ -1,24 +1,31 @@ package cli +// semaphore provides a simple synchronization mechanism type semaphore interface { acquire() release() } +// noopSemaphore is a no-operation implementation of semaphore type noopSemaphore struct{} +// acquire performs a no-operation for the noop semaphore func (n noopSemaphore) acquire() {} +// release performs a no-operation for the noop semaphore func (n noopSemaphore) release() {} +// chanSemaphore is a channel-based implementation of 
semaphore type chanSemaphore struct { ready chan struct{} } +// acquire blocks until a slot is available in the semaphore func (c *chanSemaphore) acquire() { c.ready <- struct{}{} } +// release signals that a slot has been freed in the semaphore func (c *chanSemaphore) release() { <-c.ready } diff --git a/cmd/cli/service.go b/cmd/cli/service.go index f75ee558..7046353d 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -11,9 +11,6 @@ import ( "github.com/coreos/go-systemd/v22/unit" "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld/internal/router" - "github.com/Control-D-Inc/ctrld/internal/router/openwrt" ) // newService wraps service.New call to return service.Service @@ -24,10 +21,6 @@ func newService(i service.Interface, c *service.Config) (service.Service, error) return nil, err } switch { - case router.IsOldOpenwrt(), router.IsNetGearOrbi(): - return &procd{sysV: &sysV{s}, svcConfig: c}, nil - case router.IsGLiNet(): - return &sysV{s}, nil case s.Platform() == "unix-systemv": return &sysV{s}, nil case s.Platform() == "linux-systemd": @@ -42,7 +35,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error) // sysV wraps a service.Service, and provide start/stop/status command // base on "/etc/init.d/". // -// Use this on system where "service" command is not available, like GL.iNET router. +// Use this on system where "service" command is not available. type sysV struct { service.Service } @@ -89,37 +82,6 @@ func (s *sysV) Status() (service.Status, error) { return unixSystemVServiceStatus() } -// procd wraps a service.Service, and provide start/stop command -// base on "/etc/init.d/", status command base on parsing "ps" command output. -// -// Use this on system where "/etc/init.d/ status" command is not available, -// like old GL.iNET Opal router. 
-type procd struct { - *sysV - svcConfig *service.Config -} - -func (s *procd) Status() (service.Status, error) { - if !s.installed() { - return service.StatusUnknown, service.ErrNotInstalled - } - bin := s.svcConfig.Executable - if bin == "" { - exe, err := os.Executable() - if err != nil { - return service.StatusUnknown, nil - } - bin = exe - } - - // Looking for something like "/sbin/ctrld run ". - shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ") - if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil { - return service.StatusStopped, nil - } - return service.StatusRunning, nil -} - // systemd wraps a service.Service, and provide status command to // report the status correctly. type systemd struct { @@ -153,7 +115,7 @@ func (s *systemd) Start() error { if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil { return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out)) } - mainLog.Load().Debug().Msg("set KillMode=process successfully") + mainLog.Load().Debug().Msg("Set KillMode=process successfully") } return s.Service.Start() } @@ -163,7 +125,7 @@ func (s *systemd) Start() error { func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) { opts, err := unit.DeserializeOptions(r) if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to deserialize options") + mainLog.Load().Error().Err(err).Msg("Failed to deserialize options") return } change = true @@ -187,6 +149,7 @@ func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) { return opts, change } +// newLaunchd creates a new launchd service wrapper func newLaunchd(s service.Service) *launchd { return &launchd{ Service: s, @@ -216,28 +179,30 @@ type task struct { Name string } +// doTasks executes a list of tasks and returns success status func doTasks(tasks []task) bool { for _, task := range tasks { mainLog.Load().Debug().Msgf("Running task %s", task.Name) if err := task.f(); err != nil { if 
task.abortOnError { - mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err) + mainLog.Load().Error().Msgf("Error running task %s: %v", task.Name, err) return false } // if this is darwin stop command, dont print debug // since launchctl complains on every start if runtime.GOOS != "darwin" || task.Name != "Stop" { - mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err) + mainLog.Load().Debug().Msgf("Error running task %s: %v", task.Name, err) } } } return true } +// checkHasElevatedPrivilege checks if the process has elevated privileges and exits if not func checkHasElevatedPrivilege() { ok, err := hasElevatedPrivilege() if err != nil { - mainLog.Load().Error().Msgf("could not detect user privilege: %v", err) + mainLog.Load().Error().Msgf("Could not detect user privilege: %v", err) return } if !ok { @@ -246,16 +211,10 @@ func checkHasElevatedPrivilege() { } } +// unixSystemVServiceStatus checks the status of a Unix System V service func unixSystemVServiceStatus() (service.Status, error) { out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput() if err != nil { - // Specific case for openwrt >= 24.10, it returns non-success code - // for above status command, which may not right. 
- if router.Name() == openwrt.Name { - if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" { - return service.StatusStopped, nil - } - } return service.StatusUnknown, nil } diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index 954b2287..ce630d68 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -6,17 +6,15 @@ import ( "os" ) +// hasElevatedPrivilege checks if the current process has elevated privileges func hasElevatedPrivilege() (bool, error) { return os.Geteuid() == 0, nil } +// openLogFile opens a log file with the specified flags func openLogFile(path string, flags int) (*os.File, error) { return os.OpenFile(path, flags, os.FileMode(0o600)) } -// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. -func hasLocalDnsServerRunning() bool { return false } - +// ConfigureWindowsServiceFailureActions is a no-op on non-Windows platforms func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } - -func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index fddb0ef8..aa36bd8f 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,22 +2,16 @@ package cli import ( "os" - "reflect" "runtime" - "strconv" - "strings" "syscall" "time" "unsafe" - "github.com/microsoft/wmi/pkg/base/host" - "github.com/microsoft/wmi/pkg/base/instance" - "github.com/microsoft/wmi/pkg/base/query" - "github.com/microsoft/wmi/pkg/constant" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc/mgr" ) +// hasElevatedPrivilege checks if the current process has elevated privileges on Windows func hasElevatedPrivilege() (bool, error) { var sid *windows.SID if err := windows.AllocateAndInitializeSid( @@ -100,6 +94,7 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } +// openLogFile opens a log file with the specified mode on Windows func 
openLogFile(path string, mode int) (*os.File, error) { if len(path) == 0 { return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND} @@ -151,77 +146,3 @@ func openLogFile(path string, mode int) (*os.File, error) { return os.NewFile(uintptr(handle), path), nil } - -const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{})) - -// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. -func hasLocalDnsServerRunning() bool { - h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) - if e != nil { - return false - } - p := windows.ProcessEntry32{Size: processEntrySize} - for { - e := windows.Process32Next(h, &p) - if e != nil { - return false - } - if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" { - return true - } - } -} - -func isRunningOnDomainControllerWindows() (bool, int) { - whost := host.NewWmiLocalHost() - q := query.NewWmiQuery("Win32_ComputerSystem") - instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q) - if err != nil { - mainLog.Load().Debug().Err(err).Msg("WMI query failed") - return false, 0 - } - if instances == nil { - mainLog.Load().Debug().Msg("WMI query returned nil instances") - return false, 0 - } - defer instances.Close() - - if len(instances) == 0 { - mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem") - return false, 0 - } - - val, err := instances[0].GetProperty("DomainRole") - if err != nil { - mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property") - return false, 0 - } - if val == nil { - mainLog.Load().Debug().Msg("DomainRole property is nil") - return false, 0 - } - - // Safely handle varied types: string or integer - var roleInt int - switch v := val.(type) { - case string: - // "4", "5", etc. 
- parsed, parseErr := strconv.Atoi(v) - if parseErr != nil { - mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v) - return false, 0 - } - roleInt = parsed - case int8, int16, int32, int64: - roleInt = int(reflect.ValueOf(v).Int()) - case uint8, uint16, uint32, uint64: - roleInt = int(reflect.ValueOf(v).Uint()) - default: - mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v) - return false, 0 - } - - // Check if role indicates a domain controller - isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController - return isDC, roleInt -} diff --git a/cmd/cli/service_windows_test.go b/cmd/cli/service_windows_test.go deleted file mode 100644 index 67c2725d..00000000 --- a/cmd/cli/service_windows_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package cli - -import ( - "testing" - "time" -) - -func Test_hasLocalDnsServerRunning(t *testing.T) { - start := time.Now() - hasDns := hasLocalDnsServerRunning() - t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) - - start = time.Now() - hasDnsPowershell := hasLocalDnsServerRunningPowershell() - t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) - - if hasDns != hasDnsPowershell { - t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns) - } -} - -func hasLocalDnsServerRunningPowershell() bool { - _, err := powershell("Get-Process -Name DNS") - return err == nil -} diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 6e19e38a..fcd8c7c7 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -2,6 +2,7 @@ package cli import ( "sync" + "sync/atomic" "time" "github.com/Control-D-Inc/ctrld" @@ -16,7 +17,8 @@ const ( // upstreamMonitor performs monitoring upstreams health. 
type upstreamMonitor struct { - cfg *ctrld.Config + cfg *ctrld.Config + logger atomic.Pointer[ctrld.Logger] mu sync.RWMutex checking map[string]bool @@ -28,7 +30,8 @@ type upstreamMonitor struct { failureTimerActive map[string]bool } -func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { +// newUpstreamMonitor creates a new upstream monitor instance +func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor { um := &upstreamMonitor{ cfg: cfg, checking: make(map[string]bool), @@ -37,6 +40,7 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { recovered: make(map[string]bool), failureTimerActive: make(map[string]bool), } + um.logger.Store(logger) for n := range cfg.Upstream { upstream := upstreamPrefix + n um.reset(upstream) @@ -53,7 +57,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { defer um.mu.Unlock() if um.recovered[upstream] { - mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream) + um.logger.Load().Debug().Msgf("Upstream %q is recovered, skipping failure count increase", upstream) return } @@ -61,7 +65,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { failedCount := um.failureReq[upstream] // Log the updated failure count. - mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) + um.logger.Load().Debug().Msgf("Upstream %q failure count updated to %d", upstream, failedCount) // If this is the first failure and no timer is running, start a 10-second timer. if failedCount == 1 && !um.failureTimerActive[upstream] { @@ -74,7 +78,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { // and the upstream is not in a recovered state, mark it as down. 
if um.failureReq[upstream] > 0 && !um.recovered[upstream] { um.down[upstream] = true - mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) + um.logger.Load().Warn().Msgf("Upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) } // Reset the timer flag so that a new timer can be spawned if needed. um.failureTimerActive[upstream] = false @@ -84,7 +88,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { // If the failure count quickly reaches the threshold, mark the upstream as down immediately. if failedCount >= maxFailureRequest { um.down[upstream] = true - mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) + um.logger.Load().Warn().Msgf("Upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) } } diff --git a/cmd/ctrld_library/main.go b/cmd/ctrld_library/main.go index b2e643db..6713568c 100644 --- a/cmd/ctrld_library/main.go +++ b/cmd/ctrld_library/main.go @@ -45,7 +45,7 @@ func (c *Controller) Start(CdUID string, ProvisionID string, CustomHostname stri } } -// As workaround to avoid circular dependency between cli and ctrld_library module +// mapCallback maps the AppCallback interface to cli.AppCallback to avoid circular dependency func mapCallback(callback AppCallback) cli.AppCallback { return cli.AppCallback{ HostName: func() string { diff --git a/config.go b/config.go index 73484d70..3e6548de 100644 --- a/config.go +++ b/config.go @@ -114,6 +114,10 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) { // InitConfig initializes default config values for given *viper.Viper instance. 
func InitConfig(v *viper.Viper, name string) { + ctx := context.Background() + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "Config initialization started") + v.SetDefault("listener", map[string]*ListenerConfig{ "0": { IP: "", @@ -152,6 +156,8 @@ func InitConfig(v *viper.Viper, name string) { Timeout: 3000, }, }) + + Log(ctx, logger.Debug(), "Config initialization completed") } // Config represents ctrld supported configuration. @@ -309,14 +315,20 @@ func (lc *ListenerConfig) IsDirectDnsListener() bool { } } +// MatchingConfig defines the configuration for rule matching behavior +type MatchingConfig struct { + Order []string `mapstructure:"order" toml:"order,omitempty"` +} + // ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests. type ListenerPolicyConfig struct { - Name string `mapstructure:"name" toml:"name,omitempty"` - Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"` - Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"` - Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"` - FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"` - FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` + Name string `mapstructure:"name" toml:"name,omitempty"` + Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"` + Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"` + Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"` + FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"` + FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` + Matching *MatchingConfig `mapstructure:"-" toml:"-"` } // Rule is a map from source to list of 
upstreams. @@ -325,12 +337,13 @@ type ListenerPolicyConfig struct { type Rule map[string][]string // Init initialized necessary values for an UpstreamConfig. -func (uc *UpstreamConfig) Init() { +func (uc *UpstreamConfig) Init(ctx context.Context) { + logger := LoggerFromCtx(ctx) if err := uc.initDnsStamps(); err != nil { - ProxyLogger.Load().Fatal().Err(err).Msg("invalid DNS Stamps") + logger.Fatal().Err(err).Msg("Invalid dns stamps") } uc.initDoHScheme() - uc.uid = upstreamUID() + uc.uid = upstreamUID(ctx) if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Hostname() switch uc.Type { @@ -350,6 +363,9 @@ func (uc *UpstreamConfig) Init() { } } if uc.IPStack == "" { + // Set default IP stack based on upstream type + // Control-D upstreams use split stack for better IPv4/IPv6 handling, + // while other upstreams use both stacks for maximum compatibility if uc.IsControlD() { uc.IPStack = IpStackSplit } else { @@ -411,7 +427,7 @@ func (uc *UpstreamConfig) IsDiscoverable() bool { return *uc.Discoverable } switch uc.Type { - case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal: + case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate: if ip, err := netip.ParseAddr(uc.Domain); err == nil { return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) } @@ -434,21 +450,18 @@ func (uc *UpstreamConfig) UID() string { return uc.uid } -// SetupBootstrapIP manually find all available IPs of the upstream. -// The first usable IP will be used as bootstrap IP of the upstream. -// The upstream domain will be looked up using following orders: -// -// - Current system DNS settings. -// - Direct IPs table for ControlD upstreams. -// - ControlD Bootstrap DNS 76.76.2.22 -// +// SetupBootstrapIP sets up bootstrap IPs for the upstream config. // The setup process will block until there's usable IPs found. 
-func (uc *UpstreamConfig) SetupBootstrapIP() { +func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) { + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "Setting up bootstrap IPs for upstream: %s", uc.Name) + b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second) isControlD := uc.IsControlD() - nss := initDefaultOsResolver() + nss := initDefaultOsResolver(ctx) for { - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss) + Log(ctx, logger.Debug(), "Looking up bootstrap IPs for domain: %s", uc.Domain) + uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, nss) // For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses, // filtering them out here to prevent weird behavior. if isControlD { @@ -463,18 +476,18 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { uc.bootstrapIPs = uc.bootstrapIPs[:n] if len(uc.bootstrapIPs) == 0 { uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain) - ProxyLogger.Load().Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain) + logger.Warn().Msgf("No record found for %q, lookup from direct ip table", uc.Domain) } } if len(uc.bootstrapIPs) == 0 { - ProxyLogger.Load().Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP) - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")}) + logger.Warn().Msgf("No record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP) + uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")}) } if len(uc.bootstrapIPs) > 0 { break } - ProxyLogger.Load().Warn().Msg("could not resolve bootstrap IPs, retrying...") + logger.Warn().Msg("Could not resolve bootstrap ips, retrying...") b.BackOff(context.Background(), errors.New("no bootstrap IPs")) } for _, ip := range uc.bootstrapIPs { @@ -484,11 +497,12 @@ func (uc *UpstreamConfig) 
SetupBootstrapIP() { uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip) } } - ProxyLogger.Load().Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs) + logger.Debug().Msgf("Bootstrap ips: %v", uc.bootstrapIPs) + Log(ctx, logger.Debug(), "Bootstrap IP setup completed for upstream: %s", uc.Name) } // ReBootstrap re-setup the bootstrap IP and the transport. -func (uc *UpstreamConfig) ReBootstrap() { +func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: default: @@ -496,7 +510,8 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { if uc.rebootstrap.CompareAndSwap(false, true) { - ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc) + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "Re-bootstrapping upstream: %s", uc.Name) } return true, nil }) @@ -504,35 +519,35 @@ func (uc *UpstreamConfig) ReBootstrap() { // SetupTransport initializes the network transport used to connect to upstream server. // For now, only DoH upstream is supported. 
-func (uc *UpstreamConfig) SetupTransport() { +func (uc *UpstreamConfig) SetupTransport(ctx context.Context) { switch uc.Type { case ResolverTypeDOH: - uc.setupDOHTransport() + uc.setupDOHTransport(ctx) case ResolverTypeDOH3: - uc.setupDOH3Transport() + uc.setupDOH3Transport(ctx) } } -func (uc *UpstreamConfig) setupDOHTransport() { +func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) { switch uc.IPStack { case IpStackBoth, "": - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) case IpStackV4: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs4) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4) case IpStackV6: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs6) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6) case IpStackSplit: - uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) - if HasIPv6() { - uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) + uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6) } else { uc.transport6 = uc.transport4 } - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs) } } -func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { +func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.MaxIdleConnsPerHost = 100 transport.TLSClientConfig = &tls.Config{ @@ -552,12 +567,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { dialerTimeoutMs = uc.Timeout } dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond + logger := LoggerFromCtx(ctx) transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { _, port, _ := net.SplitHostPort(addr) if uc.BootstrapIP != "" { dialer := net.Dialer{Timeout: dialerTimeout, 
KeepAlive: dialerTimeout} addr := net.JoinHostPort(uc.BootstrapIP, port) - Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr) + Log(ctx, logger.Debug(), "Sending doh request to: %s", addr) return dialer.DialContext(ctx, network, addr) } pd := &ctrldnet.ParallelDialer{} @@ -567,11 +583,11 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { for i := range addrs { dialAddrs[i] = net.JoinHostPort(addrs[i], port) } - conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load()) + conn, err := pd.DialContext(ctx, network, dialAddrs, logger.Logger) if err != nil { return nil, err } - Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr()) + Log(ctx, logger.Debug(), "Sending doh request to: %s", conn.RemoteAddr()) return conn, nil } runtime.SetFinalizer(transport, func(transport *http.Transport) { @@ -581,19 +597,20 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { } // Ping warms up the connection to DoH/DoH3 upstream. -func (uc *UpstreamConfig) Ping() { - if err := uc.ping(); err != nil { - ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint) - _ = uc.FallbackToDirectIP() +func (uc *UpstreamConfig) Ping(ctx context.Context) { + if err := uc.ping(ctx); err != nil { + logger := LoggerFromCtx(ctx) + logger.Debug().Err(err).Msgf("Upstream ping failed: %s", uc.Endpoint) + _ = uc.FallbackToDirectIP(ctx) } } // ErrorPing is like Ping, but return an error if any. 
-func (uc *UpstreamConfig) ErrorPing() error { - return uc.ping() +func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error { + return uc.ping(ctx) } -func (uc *UpstreamConfig) ping() error { +func (uc *UpstreamConfig) ping(ctx context.Context) error { switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: default: @@ -622,11 +639,11 @@ func (uc *UpstreamConfig) ping() error { for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} { switch uc.Type { case ResolverTypeDOH: - if err := ping(uc.dohTransport(typ)); err != nil { + if err := ping(uc.dohTransport(ctx, typ)); err != nil { return err } case ResolverTypeDOH3: - if err := ping(uc.doh3Transport(typ)); err != nil { + if err := ping(uc.doh3Transport(ctx, typ)); err != nil { return err } } @@ -661,12 +678,12 @@ func (uc *UpstreamConfig) isNextDNS() bool { return domain == "dns.nextdns.io" } -func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { +func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { - uc.SetupTransport() + uc.SetupTransport(ctx) }) if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport() + uc.SetupTransport(ctx) } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: @@ -682,7 +699,7 @@ func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { return uc.transport } -func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { +func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string { switch uc.IPStack { case IpStackBoth: return pick(uc.bootstrapIPs) @@ -695,7 +712,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { case dns.TypeA: return pick(uc.bootstrapIPs4) default: - if HasIPv6() { + if HasIPv6(ctx) { return pick(uc.bootstrapIPs6) } return pick(uc.bootstrapIPs4) @@ -704,7 +721,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { return pick(uc.bootstrapIPs) } -func 
(uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { +func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) { switch uc.IPStack { case IpStackBoth: return "tcp-tls", "udp" @@ -717,7 +734,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { case dns.TypeA: return "tcp4-tls", "udp4" default: - if HasIPv6() { + if HasIPv6(ctx) { return "tcp6-tls", "udp6" } return "tcp4-tls", "udp4" @@ -798,7 +815,7 @@ func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context } // FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain. -func (uc *UpstreamConfig) FallbackToDirectIP() bool { +func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool { if !uc.IsControlD() { return false } @@ -817,7 +834,8 @@ func (uc *UpstreamConfig) FallbackToDirectIP() bool { default: return } - ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip) + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Warn(), "Using direct IP for %q: %s", uc.Endpoint, ip) uc.u.Host = ip done = true }) @@ -826,12 +844,18 @@ func (uc *UpstreamConfig) FallbackToDirectIP() bool { // Init initialized necessary values for an ListenerConfig. func (lc *ListenerConfig) Init() { + logger := LoggerFromCtx(context.Background()) + Log(context.Background(), logger.Debug(), "Initializing listener config") + if lc.Policy != nil { lc.Policy.FailoverRcodeNumbers = make([]int, len(lc.Policy.FailoverRcodes)) for i, rcode := range lc.Policy.FailoverRcodes { lc.Policy.FailoverRcodeNumbers[i] = dnsrcode.FromString(rcode) } + Log(context.Background(), logger.Debug(), "Listener policy initialized with %d failover rcodes", len(lc.Policy.FailoverRcodes)) } + + Log(context.Background(), logger.Debug(), "Listener config initialization completed") } // ValidateConfig validates the given config. 
@@ -951,11 +975,12 @@ func pick(s []string) string { } // upstreamUID generates an unique identifier for an upstream. -func upstreamUID() string { +func upstreamUID(ctx context.Context) string { + logger := LoggerFromCtx(ctx) b := make([]byte, 4) for { if _, err := crand.Read(b); err != nil { - ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...") + logger.Warn().Err(err).Msg("Could not generate uid for upstream, retrying...") continue } return hex.EncodeToString(b) diff --git a/config_internal_test.go b/config_internal_test.go index b37e982f..0e7f3bb4 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -1,6 +1,7 @@ package ctrld import ( + "context" "net/url" "testing" @@ -36,10 +37,10 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed. // t.Parallel() - tc.uc.Init() - tc.uc.SetupBootstrapIP() + tc.uc.Init(context.Background()) + tc.uc.SetupBootstrapIP(context.Background()) if len(tc.uc.bootstrapIPs) == 0 { - t.Log(defaultNameservers()) + t.Log(defaultNameservers(context.Background())) t.Fatalf("could not bootstrap ip: %s", tc.uc.String()) } }) @@ -355,7 +356,7 @@ func TestUpstreamConfig_Init(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() + tc.uc.Init(context.Background()) tc.uc.uid = "" // we don't care about the uid. 
assert.Equal(t, tc.expected, tc.uc) }) @@ -497,7 +498,7 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() + tc.uc.Init(context.Background()) if got := tc.uc.IsDiscoverable(); got != tc.discoverable { t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got) } diff --git a/config_quic.go b/config_quic.go index 33f56b92..57bd8641 100644 --- a/config_quic.go +++ b/config_quic.go @@ -14,34 +14,35 @@ import ( "github.com/quic-go/quic-go/http3" ) -func (uc *UpstreamConfig) setupDOH3Transport() { +func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) { switch uc.IPStack { case IpStackBoth, "": - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) case IpStackV4: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) case IpStackV6: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) case IpStackSplit: - uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) - if HasIPv6() { - uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) + uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4) + if HasIPv6(ctx) { + uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6) } else { uc.http3RoundTripper6 = uc.http3RoundTripper4 } - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs) } } -func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { +func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper { rt := &http3.Transport{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} + logger := LoggerFromCtx(ctx) rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg 
*quic.Config) (*quic.Conn, error) { _, port, _ := net.SplitHostPort(addr) // if we have a bootstrap ip set, use it to avoid DNS lookup if uc.BootstrapIP != "" { addr = net.JoinHostPort(uc.BootstrapIP, port) - ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr) + Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr) udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, err @@ -61,7 +62,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { if err != nil { return nil, err } - ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) + Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } runtime.SetFinalizer(rt, func(rt *http3.Transport) { @@ -70,12 +71,12 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { return rt } -func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { +func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { - uc.SetupTransport() + uc.SetupTransport(ctx) }) if uc.rebootstrap.CompareAndSwap(true, false) { - uc.SetupTransport() + uc.SetupTransport(ctx) } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: diff --git a/desktop_darwin.go b/desktop_darwin.go index 039c0fac..7ba8b6b2 100644 --- a/desktop_darwin.go +++ b/desktop_darwin.go @@ -5,3 +5,6 @@ package ctrld func IsDesktopPlatform() bool { return true } + +// SelfDiscover reports whether ctrld should only do self discover. +func SelfDiscover() bool { return true } diff --git a/desktop_others.go b/desktop_others.go index de486e78..6d6a9a3f 100644 --- a/desktop_others.go +++ b/desktop_others.go @@ -7,3 +7,6 @@ package ctrld func IsDesktopPlatform() bool { return false } + +// SelfDiscover reports whether ctrld should only do self discover. 
+func SelfDiscover() bool { return false } diff --git a/desktop_windows.go b/desktop_windows.go index 4e9526b9..186a5ffc 100644 --- a/desktop_windows.go +++ b/desktop_windows.go @@ -1,7 +1,22 @@ package ctrld +import "golang.org/x/sys/windows" + // IsDesktopPlatform indicates if ctrld is running on a desktop platform, // currently defined as macOS or Windows workstation. func IsDesktopPlatform() bool { return isWindowsWorkStation() } + +// SelfDiscover reports whether ctrld should only do self discover. +func SelfDiscover() bool { + return isWindowsWorkStation() +} + +// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine. +func isWindowsWorkStation() bool { + // From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa + const VER_NT_WORKSTATION = 0x0000001 + osvi := windows.RtlGetVersion() + return osvi.ProductType == VER_NT_WORKSTATION +} diff --git a/docs/config.md b/docs/config.md index 99e98c9c..69ba0103 100644 --- a/docs/config.md +++ b/docs/config.md @@ -18,10 +18,6 @@ The config file allows for advanced configuration of the `ctrld` utility to cove - `/etc/controld` on *nix. - User's home directory on Windows. - - Same directory with `ctrld` binary on these routers: - - `ddwrt` - - `merlin` - - `freshtomato` - Current directory. The user can choose to override default value using command line `--config` or `-c`: @@ -293,7 +289,7 @@ If a remote upstream fails to resolve a query or is unreachable, `ctrld` will fo - Type: boolean - Required: no -- Default: true on Windows, MacOS and non-router Linux. +- Default: true on Windows, MacOS and Linux. ## Upstream The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to. 
diff --git a/docs/runtime-internal-logging.md b/docs/runtime-internal-logging.md new file mode 100644 index 00000000..982632cb --- /dev/null +++ b/docs/runtime-internal-logging.md @@ -0,0 +1,46 @@ +# Runtime Internal Logging + +When no logging is configured (i.e., `log_path` is not set), ctrld automatically enables an internal logging system. This system stores logs in memory to provide troubleshooting information when problems occur. + +## Purpose + +The runtime internal logging system is designed primarily for **ctrld developers**, not end users. It captures detailed diagnostic information that can be useful for troubleshooting issues when they arise, especially in production environments where explicit logging may not be configured. + +## When It's Enabled + +Internal logging is automatically enabled when: + +- ctrld is running in Control D mode (i.e., `--cd` flag is provided) +- No log file is configured (i.e., `log_path` is empty or not set) + +If a log file is explicitly configured via `log_path`, internal logging will **not** be enabled, as the configured log file serves the logging purpose. + +## How It Works + +The internal logging system: + +- Stores logs in **in-memory buffers** (not written to disk) +- Captures logs at **debug level** for normal operations and **warn level** for warnings +- Maintains separate buffers for normal logs and warning logs +- Automatically manages buffer size to prevent unbounded memory growth +- Preserves initialization logs even when buffers overflow + +## Configuration + +**Important**: The `log_level` configuration option does **not** affect the internal logging system. Internal logging always operates at debug level for normal logs and warn level for warnings, regardless of the `log_level` setting in the configuration file. 
+ +The `log_level` setting only affects: +- Console output (when running interactively) +- File-based logging (when `log_path` is configured) + +## Accessing Internal Logs + +Internal logs can be accessed through the control server API endpoints. This functionality is intended for developers and support personnel who need to diagnose issues. + +## Notes + +- Internal logging is **not** a replacement for proper log file configuration in production environments +- For production deployments, it is recommended to configure `log_path` to enable persistent file-based logging +- Internal logs are stored in memory and will be lost if the process terminates unexpectedly +- The internal logging system is automatically disabled when explicit logging is configured + diff --git a/docs/v2.0.0-breaking-changes.md b/docs/v2.0.0-breaking-changes.md new file mode 100644 index 00000000..30ac034a --- /dev/null +++ b/docs/v2.0.0-breaking-changes.md @@ -0,0 +1,135 @@ +# ctrld v2.0.0 Breaking Changes + +This document outlines the breaking changes introduced in ctrld v2.0.0 and provides migration guidance for affected users. + +## Overview + +ctrld v2.0.0 removes automatic configuration support for router and server platforms. This means ctrld will no longer perform "magic" configuration to automatically set itself up as an upstream for existing DNS software on these platforms. + +## What's Changing + +### Removed Platform Support + +**Router Platforms:** +- ctrld will no longer automatically configure itself as an upstream for dnsmasq or other DNS software +- No automatic detection and configuration of router-specific DNS settings + +**Server Platforms:** +- ctrld will no longer automatically configure Windows Server DNS forwarder settings +- No automatic integration with server DNS services + +### What Remains Supported + +**Desktop Platforms:** +- Windows Desktop +- macOS Desktop +- Linux Desktop + +These platforms continue to receive full automatic configuration support. 
+ +## Stay on v1.x.x + +ctrld v1.x.x will continue to be supported for router and server platforms: +- Important bug fixes (regression or security) will be cherry-picked to the v1.x.x branch +- New features may still be added (but may take longer to implement) +- Long-term support for these platforms + +## Migration Path for Router and Server Users + +If you're currently using ctrld v1.x.x on router or server platforms, you need to follow these steps to migrate to v2.0.0: + +### Step 1: Downloading ctrld v2 binary + +To download ctrld v2.0.0, follow these steps: + +Stop the current ctrld service: + +```sh +ctrld stop +``` + +Or uninstall the current version: + +```sh +ctrld uninstall +``` + +Download the appropriate binary for your platform: https://dl.controld.com/v2/linux-amd64/ctrld + +> **Note**: Replace `amd64` with your platform architecture as needed. + +Verify that the binary was updated correctly: + +```sh +ctrld --version +``` + +Expected output: +``` +ctrld version v2.0.0 +``` + +### Step 2: Start ctrld without self-checking + +You have two ways to start ctrld: + +**Option A: Use Remote Configuration (Recommended)** +1. **Export your current configuration:** + - Copy the contents of your current `ctrld.toml` file + +2. **Import to Control D Dashboard:** + - Log into your Control D dashboard + - Use the remote configuration feature to upload your configuration + +3. **Start ctrld with remote config:** + ```bash + sudo ctrld service start --cd=RESOLVER_ID_GOES_HERE --skip_self_checks + ``` + +> **Note**: You must use `ctrld service start` to prevent DNS from being set automatically by ctrld. + +**Option B: Use Local Configuration** +```bash +sudo ctrld service start --skip_self_checks +``` + +### Step 3: Configure DNS Software to Use ctrld as Upstream + +**For dnsmasq users:** +1. 
Configure dnsmasq to use ctrld as upstream: + ```bash + # Add to dnsmasq.conf + no-resolv + server=127.0.0.1#5354 + add-mac + add-subnet=32,128 + # Disable cache or set max-cache-ttl=0 + # to prevent queries from caching + cache-size=0 + # max-cache-ttl=0 + ``` +2. Restart dnsmasq: + ```bash + sudo service dnsmasq restart + ``` + +**For Windows Server users:** +1. Configure DNS forwarder in Windows Server: + - Open DNS Manager + - Right-click on your server name + - Select "Properties" → "Forwarders" tab + - Add the IP address of the `ctrld` listener as a forwarder + +## Getting Help + +If you encounter any issues during migration or have questions about the v2.0.0 changes: + +1. **File an issue:** [GitHub Issues](https://github.com/Control-D-Inc/ctrld/issues) +2. **Contact support:** Email help@controld.com. +3. **Check documentation:** Review the [configuration documentation](config.md) for detailed setup instructions + +## Summary + +While ctrld v2.0.0 removes automatic configuration for router and server platforms, it provides a more focused experience for desktop users while still allowing router/server users to continue using ctrld with manual configuration or by staying on the v1.x.x branch. + +The migration path is designed to be straightforward, with multiple options to suit different use cases and technical comfort levels. diff --git a/doh.go b/doh.go index 3459cb8a..9e944dd1 100644 --- a/doh.go +++ b/doh.go @@ -53,6 +53,9 @@ var EncodeArchNameMap = map[string]string{ var DecodeArchNameMap = map[string]string{} func init() { + // Create reverse mappings for OS and architecture names + // This is needed because the API expects encoded values, but we need to decode + // them back to their original form for processing for k, v := range EncodeOsNameMap { DecodeOsNameMap[v] = k } @@ -85,8 +88,12 @@ type dohResolver struct { // Resolve performs DNS query with given DNS message using DOH protocol. 
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "DoH resolver query started") + data, err := msg.Pack() if err != nil { + Log(ctx, logger.Error().Err(err), "Failed to pack DNS message") return nil, err } @@ -98,6 +105,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro endpoint.RawQuery = query.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) if err != nil { + Log(ctx, logger.Error().Err(err), "Could not create HTTP request") return nil, fmt.Errorf("could not create request: %w", err) } addHeader(ctx, req, r.uc) @@ -105,19 +113,23 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - c := http.Client{Transport: r.uc.dohTransport(dnsTyp)} + c := http.Client{Transport: r.uc.dohTransport(ctx, dnsTyp)} if r.isDoH3 { - transport := r.uc.doh3Transport(dnsTyp) + transport := r.uc.doh3Transport(ctx, dnsTyp) if transport == nil { + Log(ctx, logger.Error(), "DoH3 is not supported") return nil, errors.New("DoH3 is not supported") } c.Transport = transport } + + Log(ctx, logger.Debug(), "Sending DoH request to: %s", endpoint.String()) resp, err := c.Do(req) - if err != nil && r.uc.FallbackToDirectIP() { + if err != nil && r.uc.FallbackToDirectIP(ctx) { retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx)) defer cancel() - Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip") + logger := LoggerFromCtx(ctx) + logger.Warn().Err(err).Msg("Retrying request after fallback to direct ip") resp, err = c.Do(req.Clone(retryCtx)) } if err != nil { @@ -127,23 +139,29 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro closer.Close() } } + Log(ctx, logger.Error().Err(err), "DoH request failed") return nil, fmt.Errorf("could not perform request: %w", err) 
} defer resp.Body.Close() buf, err := io.ReadAll(resp.Body) if err != nil { + Log(ctx, logger.Error().Err(err), "Could not read response body") return nil, fmt.Errorf("could not read message from response: %w", err) } if resp.StatusCode != http.StatusOK { + Log(ctx, logger.Error(), "Wrong response from DOH server, got: %s, status: %d", string(buf), resp.StatusCode) return nil, fmt.Errorf("wrong response from DOH server, got: %s, status: %d", string(buf), resp.StatusCode) } answer := new(dns.Msg) if err := answer.Unpack(buf); err != nil { + Log(ctx, logger.Error().Err(err), "Failed to unpack DNS answer") return nil, fmt.Errorf("answer.Unpack: %w", err) } + + Log(ctx, logger.Debug(), "DoH resolver query successful") return answer, nil } @@ -163,7 +181,8 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { } } if printed { - Log(ctx, ProxyLogger.Load().Debug(), "sending request header: %v", dohHeader) + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "Sending request header: %v", dohHeader) } dohHeader.Set("Content-Type", headerApplicationDNS) dohHeader.Set("Accept", headerApplicationDNS) diff --git a/doh_test.go b/doh_test.go index 92fa79f8..700b299c 100644 --- a/doh_test.go +++ b/doh_test.go @@ -157,20 +157,21 @@ func Test_ClientCertificateVerificationError(t *testing.T) { }, } + ctx := context.Background() for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - tc.uc.Init() - tc.uc.SetupBootstrapIP() - r, err := NewResolver(tc.uc) + tc.uc.Init(ctx) + tc.uc.SetupBootstrapIP(ctx) + r, err := NewResolver(ctx, tc.uc) if err != nil { t.Fatal(err) } msg := new(dns.Msg) msg.SetQuestion("verify.controld.com.", dns.TypeA) msg.RecursionDesired = true - _, err = r.Resolve(context.Background(), msg) + _, err = r.Resolve(ctx, msg) // Verify the error contains the expected certificate information if err == nil { t.Fatal("expected certificate verification error, got nil") diff --git a/doq.go b/doq.go index 
0903411c..b665cece 100644 --- a/doq.go +++ b/doq.go @@ -18,6 +18,9 @@ type doqResolver struct { } func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "DoQ resolver query started") + endpoint := r.uc.Endpoint tlsConfig := &tls.Config{NextProtos: []string{"doq"}} ip := r.uc.BootstrapIP @@ -26,12 +29,20 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - ip = r.uc.bootstrapIPForDNSType(dnsTyp) + ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp) } tlsConfig.ServerName = r.uc.Domain _, port, _ := net.SplitHostPort(endpoint) endpoint = net.JoinHostPort(ip, port) - return resolve(ctx, msg, endpoint, tlsConfig) + + Log(ctx, logger.Debug(), "Sending DoQ request to: %s", endpoint) + answer, err := resolve(ctx, msg, endpoint, tlsConfig) + if err != nil { + Log(ctx, logger.Error().Err(err), "DoQ request failed") + } else { + Log(ctx, logger.Debug(), "DoQ resolver query successful") + } + return answer, err } func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { diff --git a/dot.go b/dot.go index 295134c9..96fa651b 100644 --- a/dot.go +++ b/dot.go @@ -13,6 +13,9 @@ type dotResolver struct { } func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "DoT resolver query started") + // The dialer is used to prevent bootstrapping cycle. // If r.endpoint is set to dns.controld.dev, we need to resolve // dns.controld.dev first. 
By using a dialer with custom resolver, @@ -23,7 +26,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(dnsTyp) + tcpNet, _ := r.uc.netForDNSType(ctx, dnsTyp) dnsClient := &dns.Client{ Net: tcpNet, Dialer: dialer, @@ -37,6 +40,12 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro endpoint = net.JoinHostPort(r.uc.BootstrapIP, port) } + Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint) answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) + if err != nil { + Log(ctx, logger.Error().Err(err), "DoT request failed") + } else { + Log(ctx, logger.Debug(), "DoT resolver query successful") + } return answer, wrapCertificateVerificationError(err) } diff --git a/go.mod b/go.mod index 2280eb65..45a21d1c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/Control-D-Inc/ctrld -go 1.23.0 - -toolchain go1.23.7 +go 1.24 require ( github.com/Masterminds/semver/v3 v3.2.1 @@ -29,16 +27,15 @@ require ( github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.5.0 github.com/prometheus/prom2json v1.3.3 - github.com/quic-go/quic-go v0.54.0 - github.com/rs/zerolog v1.28.0 - github.com/spf13/cobra v1.8.1 - github.com/spf13/pflag v1.0.5 + github.com/quic-go/quic-go v0.57.1 + github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.16.0 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/net v0.38.0 - golang.org/x/sync v0.12.0 - golang.org/x/sys v0.31.0 + go.uber.org/zap v1.27.0 + golang.org/x/net v0.43.0 + golang.org/x/sync v0.16.0 + golang.org/x/sys v0.35.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.74.0 ) @@ -64,8 +61,6 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/magiconair/properties v1.8.7 // 
indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/mdlayher/netlink v1.7.2 // indirect @@ -77,23 +72,24 @@ require ( github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect - github.com/quic-go/qpack v0.5.1 // indirect + github.com/quic-go/qpack v0.6.0 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/spf13/afero v1.9.5 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/spf13/pflag v1.0.6 // indirect github.com/subosito/gotenv v1.4.2 // indirect github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e // indirect github.com/vishvananda/netns v0.0.4 // indirect - go.uber.org/mock v0.5.0 // indirect + go.uber.org/multierr v1.11.0 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect - golang.org/x/crypto v0.36.0 // indirect + golang.org/x/crypto v0.41.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect - golang.org/x/mod v0.19.0 // indirect - golang.org/x/text v0.23.0 // indirect - golang.org/x/tools v0.23.0 // indirect + golang.org/x/mod v0.27.0 // indirect + golang.org/x/text v0.28.0 // indirect + golang.org/x/tools v0.36.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 56a71e19..cedd8554 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod 
h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= -github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE= -github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo= @@ -64,7 +62,7 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q= github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= @@ -207,12 +205,6 @@ github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties 
v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= @@ -259,10 +251,10 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcETyaUgo= github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc= -github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= -github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg= -github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= +github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= +github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI10= +github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s= github.com/rivo/uniseg v0.2.0/go.mod 
h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -272,18 +264,17 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM= github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= -github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmqnqMYADFk= github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.16.0 
h1:rGGH0XDZhdUOryiDWjmIvUSWpbNqisK8Wk0Vyefw8hc= github.com/spf13/viper v1.16.0/go.mod h1:yg78JgCJcbrQOvV9YLXgkLaZqUidkY9K+Dd1FofRzQg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -299,8 +290,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= @@ -320,8 +311,14 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod 
h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= @@ -336,8 +333,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -373,8 +370,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= -golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -407,8 +404,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -428,8 +425,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -472,14 +469,11 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys 
v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -490,11 +484,13 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -542,8 +538,8 @@ golang.org/x/tools 
v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= -golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= +golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/clientinfo/arp_unix.go b/internal/clientinfo/arp_unix.go index f5d8f884..51c934ae 100644 --- a/internal/clientinfo/arp_unix.go +++ b/internal/clientinfo/arp_unix.go @@ -20,6 +20,8 @@ func (a *arpDiscover) scan() { } // trim brackets + // Unix "arp -an" output formats IP addresses with parentheses like "(192.168.1.1)" + // We need to remove these brackets for proper IP parsing ip := strings.ReplaceAll(fields[1], "(", "") ip = strings.ReplaceAll(ip, ")", "") diff --git a/internal/clientinfo/arp_windows.go b/internal/clientinfo/arp_windows.go index 016b752f..c037b29f 100644 --- a/internal/clientinfo/arp_windows.go +++ b/internal/clientinfo/arp_windows.go @@ -17,10 +17,14 @@ func (a *arpDiscover) scan() { continue // empty lines } if line[0] != ' ' { + // Mark that we've found an interface header line + // Windows "arp -a" output has interface headers followed by ARP entries header = true // "Interface:" lines, next is header line. 
continue } if header { + // Skip the header line that follows interface names + // These lines contain column headers like "Internet Address" and "Physical Address" header = false // header lines continue } diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index f69b670f..93c9a8d2 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -79,10 +79,9 @@ type Table struct { initOnce sync.Once stopOnce sync.Once refreshInterval int + logger *ctrld.Logger dhcp *dhcp - merlin *merlinDiscover - ubios *ubiosDiscover arp *arpDiscover ndp *ndpDiscover ptr *ptrDiscover @@ -98,11 +97,18 @@ type Table struct { ptrNameservers []string } -func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { +func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string, logger *ctrld.Logger) *Table { refreshInterval := cfg.Service.DiscoverRefreshInterval + // Set default refresh interval if not configured + // This ensures client discovery continues to work even without explicit configuration if refreshInterval <= 0 { refreshInterval = 2 * 60 // 2 minutes } + // Use no-op logger if none provided + // This prevents nil pointer dereferences when logging is not configured + if logger == nil { + logger = ctrld.NopLogger + } return &Table{ svcCfg: cfg.Service, quitCh: make(chan struct{}), @@ -111,6 +117,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { cdUID: cdUID, ptrNameservers: ns, refreshInterval: refreshInterval, + logger: logger, } } @@ -179,7 +186,7 @@ func (t *Table) SetSelfIP(ip string) { // initSelfDiscover initializes necessary client metadata for self query. 
func (t *Table) initSelfDiscover() { - t.dhcp = &dhcp{selfIP: t.selfIP} + t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger} t.dhcp.addSelf() t.ipResolvers = append(t.ipResolvers, t.dhcp) t.macResolvers = append(t.macResolvers, t.dhcp) @@ -189,48 +196,24 @@ func (t *Table) initSelfDiscover() { func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { - ctrld.ProxyLogger.Load().Debug().Msg("start self discovery with custom client id") + t.logger.Debug().Msg("Start self discovery with custom client id") t.initSelfDiscover() return } // If we are running on platforms that should only do self discover, use it as the only source, too. if ctrld.SelfDiscover() { - ctrld.ProxyLogger.Load().Debug().Msg("start self discovery on desktop platforms") + t.logger.Debug().Msg("Start self discovery on desktop platforms") t.initSelfDiscover() return } - // Otherwise, process all possible sources in order, that means - // the first result of IP/MAC/Hostname lookup will be used. - // - // Routers custom clients: - // - Merlin - // - Ubios - if t.discoverDHCP() || t.discoverARP() { - t.merlin = &merlinDiscover{} - t.ubios = &ubiosDiscover{} - discovers := map[string]interface { - refresher - HostnameResolver - }{ - "Merlin": t.merlin, - "Ubios": t.ubios, - } - for platform, discover := range discovers { - if err := discover.refresh(); err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform) - } - t.hostnameResolvers = append(t.hostnameResolvers, discover) - t.refreshers = append(t.refreshers, discover) - } - } // Hosts file mapping. 
if t.discoverHosts() { - t.hf = &hostsFile{} - ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery") + t.hf = &hostsFile{logger: t.logger} + t.logger.Debug().Msg("Start hosts file discovery") if err := t.hf.init(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover") + t.logger.Error().Err(err).Msg("Could not init hosts file discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.hf) t.refreshers = append(t.refreshers, t.hf) @@ -239,10 +222,10 @@ func (t *Table) init() { } // DHCP lease files. if t.discoverDHCP() { - t.dhcp = &dhcp{selfIP: t.selfIP} - ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery") + t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger} + t.logger.Debug().Msg("Start dhcp discovery") if err := t.dhcp.init(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init DHCP discover") + t.logger.Error().Err(err).Msg("Could not init dhcp discover") } else { t.ipResolvers = append(t.ipResolvers, t.dhcp) t.macResolvers = append(t.macResolvers, t.dhcp) @@ -253,8 +236,8 @@ func (t *Table) init() { // ARP/NDP table. if t.discoverARP() { t.arp = &arpDiscover{} - t.ndp = &ndpDiscover{} - ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery") + t.ndp = &ndpDiscover{logger: t.logger} + t.logger.Debug().Msg("Start arp discovery") discovers := map[string]interface { refresher IpResolver @@ -266,7 +249,7 @@ func (t *Table) init() { for protocol, discover := range discovers { if err := discover.refresh(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", protocol) + t.logger.Error().Err(err).Msgf("Could not init %s discover", protocol) } else { t.ipResolvers = append(t.ipResolvers, discover) t.macResolvers = append(t.macResolvers, discover) @@ -283,7 +266,10 @@ func (t *Table) init() { } // PTR lookup. 
if t.discoverPTR() { - t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} + t.ptr = &ptrDiscover{ + resolver: ctrld.NewPrivateResolver(context.Background()), + logger: t.logger, + } if len(t.ptrNameservers) > 0 { nss := make([]string, 0, len(t.ptrNameservers)) for _, ns := range t.ptrNameservers { @@ -292,21 +278,22 @@ func (t *Table) init() { host, port = h, p } // Only use valid ip:port pair. + // Invalid nameservers can cause PTR discovery to fail silently if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { nss = append(nss, net.JoinHostPort(host, port)) } else { - ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) + t.logger.Warn().Msgf("Ignoring invalid nameserver for ptr discover: %q", ns) } } if len(nss) > 0 { t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) + t.logger.Debug().Msgf("Using nameservers %v for ptr discovery", nss) } } - ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") + t.logger.Debug().Msg("Start ptr discovery") if err := t.ptr.refresh(); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover") + t.logger.Error().Err(err).Msg("Could not init ptr discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.ptr) t.refreshers = append(t.refreshers, t.ptr) @@ -314,10 +301,10 @@ func (t *Table) init() { } // mdns. 
if t.discoverMDNS() { - t.mdns = &mdns{} - ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery") + t.mdns = &mdns{logger: t.logger} + t.logger.Debug().Msg("Start mdns discovery") if err := t.mdns.init(t.quitCh); err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init mDNS discover") + t.logger.Error().Err(err).Msg("Could not init mdns discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.mdns) } @@ -483,6 +470,7 @@ func (t *Table) ListClients() []*Client { for _, c := range ipMap { // If we found a client with empty hostname, use hostname from // an existed client which has the same MAC address. + // This helps fill in missing hostnames when multiple IPs share the same MAC if cFromMac := clientsByMAC[c.Mac]; cFromMac != nil && c.Hostname == "" { c.Hostname = cFromMac.Hostname } diff --git a/internal/clientinfo/client_info_test.go b/internal/clientinfo/client_info_test.go index b5bdfa57..7abb9078 100644 --- a/internal/clientinfo/client_info_test.go +++ b/internal/clientinfo/client_info_test.go @@ -2,6 +2,8 @@ package clientinfo import ( "testing" + + "github.com/Control-D-Inc/ctrld" ) func Test_normalizeIP(t *testing.T) { @@ -28,8 +30,9 @@ func Test_normalizeIP(t *testing.T) { func TestTable_LookupRFC1918IPv4(t *testing.T) { table := &Table{ - dhcp: &dhcp{}, - arp: &arpDiscover{}, + dhcp: &dhcp{}, + arp: &arpDiscover{}, + logger: ctrld.NopLogger, } table.ipResolvers = append(table.ipResolvers, table.dhcp) diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 5d11d5eb..efe44ed9 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -13,13 +13,11 @@ import ( "strings" "sync" - "tailscale.com/net/netmon" - "github.com/fsnotify/fsnotify" + "tailscale.com/net/netmon" "tailscale.com/util/lineread" "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router" ) type dhcp struct { @@ -30,6 +28,7 @@ type dhcp struct { watcher *fsnotify.Watcher selfIP string + 
logger *ctrld.Logger } func (d *dhcp) init() error { @@ -39,10 +38,6 @@ func (d *dhcp) init() error { } d.addSelf() d.watcher = watcher - for file, format := range clientInfoFiles { - // Ignore errors for default lease files. - _ = d.addLeaseFile(file, format) - } return nil } @@ -50,11 +45,7 @@ func (d *dhcp) watchChanges() { if d.watcher == nil { return } - if dir := router.LeaseFilesDir(); dir != "" { - if err := d.watcher.Add(dir); err != nil { - ctrld.ProxyLogger.Load().Err(err).Str("dir", dir).Msg("could not watch lease dir") - } - } + for { select { case event, ok := <-d.watcher.Events: @@ -64,7 +55,7 @@ func (d *dhcp) watchChanges() { if event.Has(fsnotify.Create) { if format, ok := clientInfoFiles[event.Name]; ok { if err := d.addLeaseFile(event.Name, format); err != nil { - ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("could not add lease file") + d.logger.Err(err).Str("file", event.Name).Msg("Could not add lease file") } } continue @@ -72,14 +63,14 @@ func (d *dhcp) watchChanges() { if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { format := clientInfoFiles[event.Name] if err := d.readLeaseFile(event.Name, format); err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("leases file changed but failed to update client info") + d.logger.Err(err).Str("file", event.Name).Msg("Leases file changed but failed to update client info") } } case err, ok := <-d.watcher.Errors: if !ok { return } - ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file") + d.logger.Err(err).Msg("Could not watch client info file") } } @@ -150,6 +141,9 @@ func (d *dhcp) lookupIPByHostname(name string, v6 bool) string { return true } if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + // Categorize addresses into RFC1918 (private) and public + // RFC1918 addresses are prioritized because they're more likely to 
be + // the actual client IP in most network configurations if addr.IsPrivate() { rfc1918Addrs = append(rfc1918Addrs, addr) } else { @@ -222,7 +216,7 @@ func (d *dhcp) dnsmasqReadClientInfoReader(reader io.Reader) error { } ip := normalizeIP(string(fields[2])) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("Invalid ip address entry: %q", ip) ip = "" } @@ -273,13 +267,17 @@ func (d *dhcp) iscDHCPReadClientInfoReader(reader io.Reader) error { } switch fields[0] { case "lease": + // Normalize IP address to lowercase for consistent comparison + // DHCP lease files may contain mixed-case IP addresses ip = normalizeIP(strings.ToLower(fields[1])) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("Invalid ip address entry: %q", ip) ip = "" } case "hardware": if len(fields) >= 3 { + // Convert MAC to lowercase and remove trailing semicolon + // DHCP lease files use semicolon-terminated MAC addresses mac = strings.ToLower(strings.TrimRight(fields[2], ";")) if _, err := net.ParseMAC(mac); err != nil { // Invalid dhcp, skip. 
@@ -287,6 +285,8 @@ func (d *dhcp) iscDHCPReadClientInfoReader(reader io.Reader) error { } } case "client-hostname": + // Remove quotes and semicolons from hostname + // DHCP lease files may quote hostnames and add semicolons hostname = strings.Trim(fields[1], `";`) } } @@ -328,7 +328,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error { } ip := normalizeIP(record[0]) if net.ParseIP(ip) == nil { - ctrld.ProxyLogger.Load().Warn().Msgf("invalid ip address entry: %q", ip) + d.logger.Warn().Msgf("Invalid ip address entry: %q", ip) ip = "" } @@ -350,7 +350,7 @@ func (d *dhcp) keaDhcp4ReadClientInfoReader(r io.Reader) error { func (d *dhcp) addSelf() { hostname, err := os.Hostname() if err != nil { - ctrld.ProxyLogger.Load().Err(err).Msg("could not get hostname") + d.logger.Err(err).Msg("Could not get hostname") return } hostname = normalizeHostname(hostname) @@ -390,22 +390,4 @@ func (d *dhcp) addSelf() { } } }) - for _, netIface := range router.SelfInterfaces() { - mac := netIface.HardwareAddr.String() - if mac == "" { - return - } - d.mac2name.Store(mac, hostname) - addrs, _ := netIface.Addrs() - for _, addr := range addrs { - ipNet, ok := addr.(*net.IPNet) - if !ok { - continue - } - ip := ipNet.IP - d.mac.LoadOrStore(ip.String(), mac) - d.ip.LoadOrStore(mac, ip.String()) - d.ip2name.Store(ip.String(), hostname) - } - } } diff --git a/internal/clientinfo/dhcp_lease_files.go b/internal/clientinfo/dhcp_lease_files.go index 34aabf3a..3f1c5ac8 100644 --- a/internal/clientinfo/dhcp_lease_files.go +++ b/internal/clientinfo/dhcp_lease_files.go @@ -3,18 +3,5 @@ package clientinfo import "github.com/Control-D-Inc/ctrld" // clientInfoFiles specifies client info files and how to read them on supported platforms. 
-var clientInfoFiles = map[string]ctrld.LeaseFileFormat{ - "/tmp/dnsmasq.leases": ctrld.Dnsmasq, // ddwrt - "/tmp/dhcp.leases": ctrld.Dnsmasq, // openwrt - "/var/lib/misc/dnsmasq.leases": ctrld.Dnsmasq, // merlin - "/mnt/data/udapi-config/dnsmasq.lease": ctrld.Dnsmasq, // UDM Pro - "/data/udapi-config/dnsmasq.lease": ctrld.Dnsmasq, // UDR - "/etc/dhcpd/dhcpd-leases.log": ctrld.Dnsmasq, // Synology - "/tmp/var/lib/misc/dnsmasq.leases": ctrld.Dnsmasq, // Tomato - "/run/dnsmasq-dhcp.leases": ctrld.Dnsmasq, // EdgeOS - "/run/dhcpd.leases": ctrld.IscDhcpd, // EdgeOS - "/var/dhcpd/var/db/dhcpd.leases": ctrld.IscDhcpd, // Pfsense - "/home/pi/.router/run/dhcp/dnsmasq.leases": ctrld.Dnsmasq, // Firewalla - "/var/lib/kea/dhcp4.leases": ctrld.KeaDHCP4, // Pfsense - "/var/db/dnsmasq.leases": ctrld.Dnsmasq, // OPNsense -} +// TODO: cleanup this after server support removal. +var clientInfoFiles = map[string]ctrld.LeaseFileFormat{} diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go index d96229df..003e1b81 100644 --- a/internal/clientinfo/hostsfile.go +++ b/internal/clientinfo/hostsfile.go @@ -27,6 +27,7 @@ type hostsFile struct { watcher *fsnotify.Watcher mu sync.Mutex m map[string][]string + logger *ctrld.Logger } // init performs initialization works, which is necessary before hostsFile can be fully operated. @@ -55,7 +56,7 @@ func (hf *hostsFile) refresh() error { // override hosts file with host_entries.conf content if present. 
hem, err := parseHostEntriesConf(hostEntriesConfPath) if err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not read host_entries.conf file") + hf.logger.Debug().Err(err).Msg("Could not read host_entries.conf file") } for k, v := range hem { hf.m[k] = v @@ -77,14 +78,14 @@ func (hf *hostsFile) watchChanges() { } if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { if err := hf.refresh(); err != nil && !os.IsNotExist(err) { - ctrld.ProxyLogger.Load().Err(err).Msg("hosts file changed but failed to update client info") + hf.logger.Err(err).Msg("Hosts file changed but Failed to update client info") } } case err, ok := <-hf.watcher.Errors: if !ok { return } - ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file") + hf.logger.Err(err).Msg("Could not watch client info file") } } @@ -164,6 +165,8 @@ func parseHostEntriesConfFromReader(r io.Reader) map[string][]string { for scanner.Scan() { line := scanner.Text() if after, found := strings.CutPrefix(line, "local-zone:"); found { + // Extract local zone name for domain suffix removal + // This is needed because unbound appends the local zone to hostnames after = strings.TrimSpace(after) fields := strings.Fields(after) if len(fields) > 1 { @@ -176,6 +179,8 @@ func parseHostEntriesConfFromReader(r io.Reader) map[string][]string { if !found { continue } + // Clean up the parsed data by removing whitespace and quotes + // This ensures consistent formatting for hostname processing after = strings.TrimSpace(after) after = strings.Trim(after, `"`) fields := strings.Fields(after) @@ -183,6 +188,8 @@ func parseHostEntriesConfFromReader(r io.Reader) map[string][]string { continue } ip := fields[0] + // Remove local zone suffix from hostname for cleaner lookups + // Unbound adds the local zone to hostnames, but we want just the base name name := strings.TrimSuffix(fields[1], "."+localZone) hostsMap[ip] = 
append(hostsMap[ip], name) } diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index a09d7296..981de124 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -34,7 +34,8 @@ var ( ) type mdns struct { - name sync.Map // ip => hostname + name sync.Map // ip => hostname + logger *ctrld.Logger } func (m *mdns) LookupHostnameByIP(ip string) string { @@ -74,10 +75,7 @@ func (m *mdns) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() - if addr.IsLoopback() { // Continue searching if this is loopback address. - return true - } - return false + return addr.IsLoopback() // Continue searching if this is loopback address. } } return true @@ -92,9 +90,9 @@ func (m *mdns) init(quitCh chan struct{}) error { } // Check if IPv6 is available once and use the result for the rest of the function. - ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init") + m.logger.Debug().Msgf("Checking for ipv6 availability in mdns init") ipv6 := ctrldnet.IPv6Available(context.Background()) - ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6) + m.logger.Debug().Msgf("ipv6 is %v in mdns init", ipv6) v4ConnList := make([]*net.UDPConn, 0, len(ifaces)) v6ConnList := make([]*net.UDPConn, 0, len(ifaces)) @@ -128,11 +126,11 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan for { err := m.probe(conns, remoteAddr) if shouldStopProbing(err) { - ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err) + m.logger.Warn().Msgf("Stop probing %q: %v", remoteAddr, err) break } if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("error while probing mdns") + m.logger.Warn().Err(err).Msg("Error while probing mdns") bo.BackOff(context.Background(), errors.New("mdns probe backoff")) continue } @@ -160,7 +158,7 @@ func (m *mdns) readLoop(conn 
*net.UDPConn) { if errors.Is(err, net.ErrClosed) { return } - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error") + m.logger.Debug().Err(err).Msg("Mdns readLoop error") return } @@ -183,11 +181,11 @@ func (m *mdns) readLoop(conn *net.UDPConn) { if ip != "" && name != "" { name = normalizeHostname(name) if val, loaded := m.name.LoadOrStore(ip, name); !loaded { - ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip) + m.logger.Debug().Msgf("Found hostname: %q, ip: %q via mdns", name, ip) } else { old := val.(string) if old != name { - ctrld.ProxyLogger.Load().Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old) + m.logger.Debug().Msgf("Update hostname: %q, ip: %q, old: %q via mdns", name, ip, old) m.name.Store(ip, name) } } @@ -217,6 +215,8 @@ func (m *mdns) probe(conns []*net.UDPConn, remoteAddr net.Addr) error { for _, conn := range conns { _ = conn.SetWriteDeadline(time.Now().Add(time.Second * 30)) if _, werr := conn.WriteTo(buf, remoteAddr); werr != nil { + // Capture the last write error for reporting + // Multiple connections may fail, but we only report the last error err = werr } } @@ -226,7 +226,7 @@ func (m *mdns) probe(conns []*net.UDPConn, remoteAddr net.Addr) error { // getDataFromAvahiDaemonCache reads entries from avahi-daemon cache to update mdns data. func (m *mdns) getDataFromAvahiDaemonCache() { if _, err := exec.LookPath("avahi-browse"); err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not find avahi-browse binary, skipping.") + m.logger.Debug().Err(err).Msg("Could not find avahi-browse binary, skipping.") return } // Run avahi-browse to discover services from cache: @@ -236,7 +236,7 @@ func (m *mdns) getDataFromAvahiDaemonCache() { // - "-c" -> read from cache. 
out, err := exec.Command("avahi-browse", "-a", "-r", "-p", "-c").Output() if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("could not browse services from avahi cache") + m.logger.Debug().Err(err).Msg("Could not browse services from avahi cache") return } m.storeDataFromAvahiBrowseOutput(bytes.NewReader(out)) @@ -256,7 +256,7 @@ func (m *mdns) storeDataFromAvahiBrowseOutput(r io.Reader) { name := normalizeHostname(fields[6]) // Only using cache value if we don't have existed one. if _, loaded := m.name.LoadOrStore(ip, name); !loaded { - ctrld.ProxyLogger.Load().Debug().Msgf("found hostname: %q, ip: %q via avahi cache", name, ip) + m.logger.Debug().Msgf("Found hostname: %q, ip: %q via avahi cache", name, ip) } } } diff --git a/internal/clientinfo/mdns_test.go b/internal/clientinfo/mdns_test.go index e6f86989..28c23d9f 100644 --- a/internal/clientinfo/mdns_test.go +++ b/internal/clientinfo/mdns_test.go @@ -3,6 +3,8 @@ package clientinfo import ( "strings" "testing" + + "github.com/Control-D-Inc/ctrld" ) func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) { @@ -11,7 +13,7 @@ func Test_mdns_storeDataFromAvahiBrowseOutput(t *testing.T) { =;wlp0s20f3;IPv6;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0" =;wlp0s20f3;IPv4;Foo\032\0402\041;_companion-link._tcp;local;Foo-2.local;192.168.1.123;64842;"rpBA=00:00:00:00:00:01" "rpHI=e6ae2cbbca0e" "rpAD=36566f4d850f" "rpVr=510.71.1" "rpHA=0ddc20fdddc8" "rpFl=0x30000" "rpHN=1d4a03afdefa" "rpMac=0" ` - m := &mdns{} + m := &mdns{logger: ctrld.NopLogger} m.storeDataFromAvahiBrowseOutput(strings.NewReader(content)) ip := "192.168.1.123" val, loaded := m.name.LoadOrStore(ip, "") diff --git a/internal/clientinfo/merlin.go b/internal/clientinfo/merlin.go deleted file mode 100644 index 8a39398f..00000000 --- a/internal/clientinfo/merlin.go +++ 
/dev/null @@ -1,71 +0,0 @@ -package clientinfo - -import ( - "strings" - "sync" - - "github.com/Control-D-Inc/ctrld/internal/router" - "github.com/Control-D-Inc/ctrld/internal/router/merlin" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -const merlinNvramCustomClientListKey = "custom_clientlist" - -type merlinDiscover struct { - hostname sync.Map // mac => hostname -} - -func (m *merlinDiscover) refresh() error { - if router.Name() != merlin.Name { - return nil - } - out, err := nvram.Run("get", merlinNvramCustomClientListKey) - if err != nil { - return err - } - ctrld.ProxyLogger.Load().Debug().Msg("reading Merlin custom client list") - m.parseMerlinCustomClientList(out) - return nil -} - -func (m *merlinDiscover) LookupHostnameByIP(ip string) string { - return "" -} - -func (m *merlinDiscover) LookupHostnameByMac(mac string) string { - val, ok := m.hostname.Load(mac) - if !ok { - return "" - } - return val.(string) -} - -// "nvram get custom_clientlist" output: -// -// 00:00:00:00:00:01>0>4>>00:00:00:00:00:02>0>24>>... 
-// -// So to parse it, do the following steps: -// -// - Split by "<" => entries -// - For each entry, split by ">" => parts -// - Empty parts => skip -// - Empty parts[0] => skip empty hostname -// - Empty parts[1] => skip empty MAC -func (m *merlinDiscover) parseMerlinCustomClientList(data string) { - entries := strings.Split(data, "<") - for _, entry := range entries { - parts := strings.SplitN(string(entry), ">", 3) - if len(parts) < 2 || len(parts[0]) == 0 || len(parts[1]) == 0 { - continue - } - hostname := normalizeHostname(parts[0]) - mac := strings.ToLower(parts[1]) - m.hostname.Store(mac, hostname) - } -} - -func (m *merlinDiscover) String() string { - return "merlin" -} diff --git a/internal/clientinfo/merlin_test.go b/internal/clientinfo/merlin_test.go deleted file mode 100644 index 0437035a..00000000 --- a/internal/clientinfo/merlin_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package clientinfo - -import ( - "testing" -) - -func TestParseMerlinCustomClientList(t *testing.T) { - tests := []struct { - name string - clientList string - macList []string - hostnameList []string - macNotPresentList []string - }{ - { - "normal", - "00:00:00:00:00:01>0>4>>", - []string{"00:00:00:00:00:01"}, - []string{"client1"}, - nil, - }, - { - "multiple clients", - "00:00:00:00:00:01>0>4>>00:00:00:00:00:02>0>24>>", - []string{"00:00:00:00:00:01", "00:00:00:00:00:02"}, - []string{"client1", "client2"}, - nil, - }, - { - "empty hostname", - "00:00:00:00:00:01>0>4>><>00:00:00:00:00:02>0>24>>", - []string{"00:00:00:00:00:01"}, - []string{"client1"}, - []string{"00:00:00:00:00:02"}, - }, - { - "empty dhcp", - "00:00:00:00:00:01>0>4>>>>", - []string{"00:00:00:00:00:01"}, - []string{"client1"}, - []string{""}, - }, - { - "invalid", - "qwerty", - nil, - nil, - nil, - }, - { - "empty", - "", - - nil, - nil, - nil, - }, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - m := &merlinDiscover{} - m.parseMerlinCustomClientList(tc.clientList) 
- for i, mac := range tc.macList { - val, ok := m.hostname.Load(mac) - if !ok { - t.Errorf("missing hostname: %s", mac) - } - hostname := val.(string) - if hostname != tc.hostnameList[i] { - t.Errorf("hostname mismatch, want: %q, got: %q", tc.hostnameList[i], hostname) - } - } - for _, mac := range tc.macNotPresentList { - if _, ok := m.hostname.Load(mac); ok { - t.Errorf("mac2name address %q should not be present", mac) - } - } - }) - } -} diff --git a/internal/clientinfo/ndp.go b/internal/clientinfo/ndp.go index 9d9155d7..f53e7fe1 100644 --- a/internal/clientinfo/ndp.go +++ b/internal/clientinfo/ndp.go @@ -20,8 +20,9 @@ import ( // ndpDiscover provides client discovery functionality using NDP protocol. type ndpDiscover struct { - mac sync.Map // ip => mac - ip sync.Map // mac => ip + mac sync.Map // ip => mac + ip sync.Map // mac => ip + logger *ctrld.Logger } // refresh re-scans the NDP table. @@ -97,7 +98,7 @@ func (nd *ndpDiscover) saveInfo(ip, mac string) { func (nd *ndpDiscover) listen(ctx context.Context) { ifis, err := allInterfacesWithV6LinkLocal() if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6 interfaces") + nd.logger.Debug().Err(err).Msg("Failed to find valid ipv6 interfaces") return } for _, ifi := range ifis { @@ -110,11 +111,11 @@ func (nd *ndpDiscover) listen(ctx context.Context) { func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface) { c, ip, err := ndp.Listen(ifi, ndp.Unspecified) if err != nil { - ctrld.ProxyLogger.Load().Debug().Err(err).Msg("ndp listen failed") + nd.logger.Debug().Err(err).Msg("Ndp listen failed") return } defer c.Close() - ctrld.ProxyLogger.Load().Debug().Msgf("listening ndp on: %s", ip.String()) + nd.logger.Debug().Msgf("Listening ndp on: %s", ip.String()) for { select { case <-ctx.Done(): @@ -128,7 +129,7 @@ func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface if errors.As(readErr, &opErr) && (opErr.Timeout() || 
opErr.Temporary()) { continue } - ctrld.ProxyLogger.Load().Debug().Err(readErr).Msg("ndp read loop error") + nd.logger.Debug().Err(readErr).Msg("Ndp read loop error") return } @@ -173,6 +174,9 @@ func (nd *ndpDiscover) scanUnix(r io.Reader) { } if mac := parseMAC(fields[1]); mac != "" { ip := fields[0] + // Remove interface suffix from IPv6 addresses + // Unix systems append interface names to IPv6 addresses (e.g., "fe80::1%eth0") + // This suffix needs to be removed for proper IP parsing if idx := strings.IndexByte(ip, '%'); idx != -1 { ip = ip[:idx] } @@ -191,11 +195,15 @@ func normalizeMac(mac string) string { return mac } // Windows use "-" instead of ":" as separator. + // This normalization is needed because different operating systems use different + // separators for MAC addresses, but net.ParseMAC expects ":" format mac = strings.ReplaceAll(mac, "-", ":") parts := strings.Split(mac, ":") if len(parts) != 6 { return "" } + // Pad single-digit hex values with leading zero + // This ensures consistent formatting for MAC address parsing for i, c := range parts { if len(c) == 1 { parts[i] = "0" + c diff --git a/internal/clientinfo/ndp_linux.go b/internal/clientinfo/ndp_linux.go index ebd416f0..fb3aacd2 100644 --- a/internal/clientinfo/ndp_linux.go +++ b/internal/clientinfo/ndp_linux.go @@ -5,15 +5,13 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" - - "github.com/Control-D-Inc/ctrld" ) // scan populates NDP table using information from system mappings. 
func (nd *ndpDiscover) scan() { neighs, err := netlink.NeighList(0, netlink.FAMILY_V6) if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not get neigh list") + nd.logger.Warn().Err(err).Msg("Could not get neighbor list") return } @@ -34,7 +32,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { done := make(chan struct{}) defer close(done) if err := netlink.NeighSubscribe(ch, done); err != nil { - ctrld.ProxyLogger.Load().Err(err).Msg("could not perform neighbor subscribing") + nd.logger.Err(err).Msg("Could not perform neighbor subscribing") return } for { @@ -47,7 +45,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { } ip := normalizeIP(nu.IP.String()) if nu.Type == unix.RTM_DELNEIGH { - ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor: %s", ip) + nd.logger.Debug().Msgf("Removing ndp neighbor: %s", ip) nd.mac.Delete(ip) continue } @@ -56,7 +54,7 @@ func (nd *ndpDiscover) subscribe(ctx context.Context) { case netlink.NUD_REACHABLE: nd.saveInfo(ip, mac) case netlink.NUD_FAILED: - ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor with failed state: %s", ip) + nd.logger.Debug().Msgf("Removing ndp neighbor with failed state: %s", ip) nd.mac.Delete(ip) } } diff --git a/internal/clientinfo/ndp_others.go b/internal/clientinfo/ndp_others.go index 007407b8..70d0c90b 100644 --- a/internal/clientinfo/ndp_others.go +++ b/internal/clientinfo/ndp_others.go @@ -7,8 +7,6 @@ import ( "context" "os/exec" "runtime" - - "github.com/Control-D-Inc/ctrld" ) // scan populates NDP table using information from system mappings. 
@@ -17,14 +15,14 @@ func (nd *ndpDiscover) scan() { case "windows": data, err := exec.Command("netsh", "interface", "ipv6", "show", "neighbors").Output() if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table") + nd.logger.Warn().Err(err).Msg("Could not query ndp table") return } nd.scanWindows(bytes.NewReader(data)) default: data, err := exec.Command("ndp", "-an").Output() if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msg("could not query ndp table") + nd.logger.Warn().Err(err).Msg("Could not query ndp table") return } nd.scanUnix(bytes.NewReader(data)) diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 9a1d10c4..aa6d5ec4 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -17,6 +17,7 @@ type ptrDiscover struct { hostname sync.Map // ip => hostname resolver ctrld.Resolver serverDown atomic.Bool + logger *ctrld.Logger } func (p *ptrDiscover) refresh() error { @@ -73,14 +74,14 @@ func (p *ptrDiscover) lookupHostname(ip string) string { msg := new(dns.Msg) addr, err := dns.ReverseAddr(ip) if err != nil { - ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address") + p.logger.Info().Str("discovery", "ptr").Err(err).Msg("Invalid ip address") return "" } msg.SetQuestion(addr, dns.TypePTR) ans, err := p.resolver.Resolve(ctx, msg) if err != nil { if p.serverDown.CompareAndSwap(false, true) { - ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") + p.logger.Info().Str("discovery", "ptr").Err(err).Msg("Could not perform ptr lookup") go p.checkServer() } return "" @@ -104,10 +105,9 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() - if addr.IsLoopback() { // Continue searching if this is loopback address. 
- return true - } - return false + // Continue searching if this is a loopback address + // We prefer non-loopback addresses as they're more likely to be the actual client IP + return addr.IsLoopback() // Continue searching if this is loopback address. } } return true diff --git a/internal/clientinfo/ubios.go b/internal/clientinfo/ubios.go deleted file mode 100644 index 0ffd6e59..00000000 --- a/internal/clientinfo/ubios.go +++ /dev/null @@ -1,79 +0,0 @@ -package clientinfo - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "os/exec" - "strings" - "sync" - - "github.com/Control-D-Inc/ctrld/internal/router" - "github.com/Control-D-Inc/ctrld/internal/router/ubios" -) - -// ubiosDiscover provides client discovery functionality on Ubios routers. -type ubiosDiscover struct { - hostname sync.Map // mac => hostname -} - -// refresh reloads unifi devices from database. -func (u *ubiosDiscover) refresh() error { - if router.Name() != ubios.Name { - return nil - } - return u.refreshDevices() -} - -// LookupHostnameByIP returns hostname for given IP. -func (u *ubiosDiscover) LookupHostnameByIP(ip string) string { - return "" -} - -// LookupHostnameByMac returns unifi device custom hostname for the given MAC address. -func (u *ubiosDiscover) LookupHostnameByMac(mac string) string { - val, ok := u.hostname.Load(mac) - if !ok { - return "" - } - return val.(string) -} - -// refreshDevices updates unifi devices name from local mongodb. -func (u *ubiosDiscover) refreshDevices() error { - cmd := exec.Command("/usr/bin/mongo", "localhost:27117/ace", "--quiet", "--eval", ` - DBQuery.shellBatchSize = 256; - db.user.find({name: {$exists: true, $ne: ""}}, {_id:0, mac:1, name:1});`) - b, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("out: %s, err: %w", string(b), err) - } - return u.storeDevices(bytes.NewReader(b)) -} - -// storeDevices saves unifi devices name for caching. 
-func (u *ubiosDiscover) storeDevices(r io.Reader) error { - decoder := json.NewDecoder(r) - device := struct { - MAC string - Name string - }{} - for { - err := decoder.Decode(&device) - if err == io.EOF { - break - } - if err != nil { - return err - } - mac := strings.ToLower(device.MAC) - u.hostname.Store(mac, normalizeHostname(device.Name)) - } - return nil -} - -// String returns human-readable format of ubiosDiscover. -func (u *ubiosDiscover) String() string { - return "ubios" -} diff --git a/internal/clientinfo/ubios_test.go b/internal/clientinfo/ubios_test.go deleted file mode 100644 index 657cf180..00000000 --- a/internal/clientinfo/ubios_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package clientinfo - -import ( - "strings" - "testing" -) - -func Test_ubiosDiscover_storeDevices(t *testing.T) { - ud := &ubiosDiscover{} - r := strings.NewReader(`{ "mac": "00:00:00:00:00:01", "name": "device 1" } -{ "mac": "00:00:00:00:00:02", "name": "device 2" } -`) - if err := ud.storeDevices(r); err != nil { - t.Fatal(err) - } - - tests := []struct { - name string - mac string - hostname string - }{ - {"device 1", "00:00:00:00:00:01", "device 1"}, - {"device 2", "00:00:00:00:00:02", "device 2"}, - {"non-existed", "00:00:00:00:00:03", ""}, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - if got := ud.LookupHostnameByMac(tc.mac); got != tc.hostname { - t.Errorf("hostname mismatched, want: %q, got: %q", tc.hostname, got) - } - }) - } - - // Test for invalid input. 
- r = strings.NewReader(`{ "mac": "00:00:00:00:00:01", "name": "device 1"`) - if err := ud.storeDevices(r); err == nil { - t.Fatal("expected error, got nil") - } else { - t.Log(err) - } -} diff --git a/internal/controld/config.go b/internal/controld/config.go index 595e758e..fe5bd72c 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -18,8 +18,6 @@ import ( "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/certs" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" - "github.com/Control-D-Inc/ctrld/internal/router" - "github.com/Control-D-Inc/ctrld/internal/router/ddwrt" ) const ( @@ -88,117 +86,176 @@ type LogsRequest struct { } // FetchResolverConfig fetch Control D config for given uid. -func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) { +func FetchResolverConfig(ctx context.Context, rawUID, version string, cdDev bool) (*ResolverConfig, error) { + logger := ctrld.LoggerFromCtx(ctx) + ctrld.Log(ctx, logger.Debug(), "Fetching ControlD resolver configuration") + uid, clientID := ParseRawUID(rawUID) + ctrld.Log(ctx, logger.Debug(), "Parsed UID: %s, ClientID: %s", uid, clientID) + req := utilityRequest{UID: uid} if clientID != "" { req.ClientID = clientID + ctrld.Log(ctx, logger.Debug(), "Including client ID in request") } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) + + ctrld.Log(ctx, logger.Debug(), "Sending resolver config request to ControlD API") + return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body)) } // FetchResolverUID fetch resolver uid from provision token. 
-func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) { +func FetchResolverUID(ctx context.Context, req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) { + logger := ctrld.LoggerFromCtx(ctx) + ctrld.Log(ctx, logger.Debug(), "Fetching resolver UID from provision token") + if req == nil { + ctrld.Log(ctx, logger.Error(), "Invalid request: request is nil") return nil, errors.New("invalid request") } + hostname := req.Hostname if hostname == "" { hostname, _ = os.Hostname() + ctrld.Log(ctx, logger.Debug(), "Using system hostname: %s", hostname) + } else { + ctrld.Log(ctx, logger.Debug(), "Using provided hostname: %s", hostname) } + + ctrld.Log(ctx, logger.Debug(), "Sending UID request to ControlD API") body, _ := json.Marshal(UtilityOrgRequest{ProvToken: req.ProvToken, Hostname: hostname}) - return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, false, bytes.NewReader(body)) } // UpdateCustomLastFailed calls API to mark custom config is bad. 
-func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { +func UpdateCustomLastFailed(ctx context.Context, rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) req := utilityRequest{UID: uid} if clientID != "" { req.ClientID = clientID } body, _ := json.Marshal(req) - return postUtilityAPI(version, cdDev, true, bytes.NewReader(body)) + return postUtilityAPI(ctx, version, cdDev, true, bytes.NewReader(body)) } -func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { +func postUtilityAPI(ctx context.Context, version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) { + logger := ctrld.LoggerFromCtx(ctx) + ctrld.Log(ctx, logger.Debug(), "Posting utility API request") + apiUrl := resolverDataURLCom if cdDev { apiUrl = resolverDataURLDev + ctrld.Log(ctx, logger.Debug(), "Using development API URL: %s", apiUrl) + } else { + ctrld.Log(ctx, logger.Debug(), "Using production API URL: %s", apiUrl) } + + ctrld.Log(ctx, logger.Debug(), "Creating HTTP request") req, err := http.NewRequest("POST", apiUrl, body) if err != nil { + ctrld.Log(ctx, logger.Error(), "Failed to create HTTP request: %v", err) return nil, fmt.Errorf("http.NewRequest: %w", err) } + + ctrld.Log(ctx, logger.Debug(), "Setting request parameters") q := req.URL.Query() q.Set("platform", "ctrld") q.Set("version", version) if lastUpdatedFailed { q.Set("custom_last_failed", "1") + ctrld.Log(ctx, logger.Debug(), "Marking custom config as failed") } req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") - transport := apiTransport(cdDev) + + ctrld.Log(ctx, logger.Debug(), "Setting up API transport") + transport := apiTransport(ctx, cdDev) client := &http.Client{ Timeout: defaultTimeout, Transport: transport, } - resp, err := doWithFallback(client, req, apiServerIP(cdDev)) + + ctrld.Log(ctx, 
logger.Debug(), "Sending request to ControlD API") + resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev)) if err != nil { + ctrld.Log(ctx, logger.Error(), "Failed to send request to ControlD API: %v", err) return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err) } defer resp.Body.Close() + + ctrld.Log(ctx, logger.Debug(), "Processing API response") d := json.NewDecoder(resp.Body) if resp.StatusCode != http.StatusOK { errResp := &ErrorResponse{} if err := d.Decode(errResp); err != nil { + ctrld.Log(ctx, logger.Error(), "Failed to decode error response: %v", err) return nil, err } + ctrld.Log(ctx, logger.Error(), "ControlD API returned error: %s", errResp.Error()) return nil, errResp } ur := &utilityResponse{} if err := d.Decode(ur); err != nil { + ctrld.Log(ctx, logger.Error(), "Failed to decode utility response: %v", err) return nil, err } + + ctrld.Log(ctx, logger.Debug(), "Successfully received resolver configuration") return &ur.Body.Resolver, nil } // SendLogs sends runtime log to ControlD API. 
-func SendLogs(lr *LogsRequest, cdDev bool) error { +func SendLogs(ctx context.Context, lr *LogsRequest, cdDev bool) error { + logger := ctrld.LoggerFromCtx(ctx) + ctrld.Log(ctx, logger.Debug(), "Sending runtime logs to ControlD API") + defer lr.Data.Close() apiUrl := logURLCom if cdDev { apiUrl = logURLDev } + + ctrld.Log(ctx, logger.Debug(), "Creating HTTP request for log upload") req, err := http.NewRequest("POST", apiUrl, lr.Data) if err != nil { + ctrld.Log(ctx, logger.Error(), "Failed to create HTTP request: %v", err) return fmt.Errorf("http.NewRequest: %w", err) } q := req.URL.Query() q.Set("uid", lr.UID) req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - transport := apiTransport(cdDev) + + ctrld.Log(ctx, logger.Debug(), "Setting up API transport") + transport := apiTransport(ctx, cdDev) client := &http.Client{ Timeout: sendLogTimeout, Transport: transport, } - resp, err := doWithFallback(client, req, apiServerIP(cdDev)) + + ctrld.Log(ctx, logger.Debug(), "Sending log data to ControlD API") + resp, err := doWithFallback(ctx, client, req, apiServerIP(cdDev)) if err != nil { + ctrld.Log(ctx, logger.Error(), "Failed to send logs to ControlD API: %v", err) return fmt.Errorf("SendLogs client.Do: %w", err) } defer resp.Body.Close() + + ctrld.Log(ctx, logger.Debug(), "Processing API response") d := json.NewDecoder(resp.Body) if resp.StatusCode != http.StatusOK { errResp := &ErrorResponse{} if err := d.Decode(errResp); err != nil { + ctrld.Log(ctx, logger.Error(), "Failed to decode error response: %v", err) return err } + ctrld.Log(ctx, logger.Error(), "ControlD API returned error: %s", errResp.Error()) return errResp } _, _ = io.Copy(io.Discard, resp.Body) + + ctrld.Log(ctx, logger.Debug(), "Runtime logs sent successfully to ControlD API") return nil } @@ -213,7 +270,7 @@ func ParseRawUID(rawUID string) (string, string) { } // apiTransport returns an HTTP transport for connecting to ControlD API endpoint. 
-func apiTransport(cdDev bool) *http.Transport { +func apiTransport(loggerCtx context.Context, cdDev bool) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { apiDomain := apiDomainCom @@ -227,13 +284,16 @@ func apiTransport(cdDev bool) *http.Transport { apiIPs = []string{apiDomainDevIPv4} } - ips := ctrld.LookupIP(apiDomain) + ips := ctrld.LookupIP(loggerCtx, apiDomain) if len(ips) == 0 { - ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs) + logger := ctrld.LoggerFromCtx(loggerCtx) + logger.Warn().Msgf("No ips found for %s, use direct ips: %v", apiDomain, apiIPs) ips = apiIPs } // Separate IPv4 and IPv6 addresses + // This separation is needed because different network stacks may have different + // connectivity to IPv4 vs IPv6, so we try them separately for better reliability var ipv4s, ipv6s []string for _, ip := range ips { if strings.Contains(ip, ":") { @@ -245,7 +305,8 @@ func apiTransport(cdDev bool) *http.Transport { dial := func(ctx context.Context, network string, addrs []string) (net.Conn, error) { d := &ctrldnet.ParallelDialer{} - return d.DialContext(ctx, network, addrs, ctrld.ProxyLogger.Load()) + logger := ctrld.LoggerFromCtx(loggerCtx) + return d.DialContext(ctx, network, addrs, logger.Logger) } _, port, _ := net.SplitHostPort(addr) @@ -269,7 +330,7 @@ func apiTransport(cdDev bool) *http.Transport { // Fallback to direct IPv6 return dial(ctx, "tcp6", addrsFromPort(apiIpsV6, port)) } - if router.Name() == ddwrt.Name || runtime.GOOS == "android" { + if runtime.GOOS == "android" { transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()} } return transport @@ -283,10 +344,11 @@ func addrsFromPort(ips []string, port string) []string { return addrs } -func doWithFallback(client *http.Client, req *http.Request, apiIp string) (*http.Response, error) { +func doWithFallback(ctx 
context.Context, client *http.Client, req *http.Request, apiIp string) (*http.Response, error) { resp, err := client.Do(req) if err != nil { - ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp) + logger := ctrld.LoggerFromCtx(ctx) + logger.Warn().Err(err).Msgf("Failed to send request, fallback to direct ip: %s", apiIp) ipReq := req.Clone(req.Context()) ipReq.Host = apiIp ipReq.URL.Host = apiIp diff --git a/internal/net/net.go b/internal/net/net.go index f4b55860..e10db0f3 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -3,7 +3,6 @@ package net import ( "context" "errors" - "io" "net" "os" "os/signal" @@ -12,7 +11,7 @@ import ( "syscall" "time" - "github.com/rs/zerolog" + "go.uber.org/zap" "tailscale.com/logtail/backoff" ) @@ -34,8 +33,8 @@ var Dialer = &net.Dialer{ Dial: func(ctx context.Context, network, address string) (net.Conn, error) { d := ParallelDialer{} d.Timeout = 10 * time.Second - l := zerolog.New(io.Discard) - return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS}, &l) + l := zap.NewNop() + return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS}, l) }, }, } @@ -161,7 +160,7 @@ type ParallelDialer struct { net.Dialer } -func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string, logger *zerolog.Logger) (net.Conn, error) { +func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string, logger *zap.Logger) (net.Conn, error) { if len(addrs) == 0 { return nil, errors.New("empty addresses") } @@ -181,16 +180,16 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs for _, addr := range addrs { go func(addr string) { defer wg.Done() - logger.Debug().Msgf("dialing to %s", addr) + logger.Debug("Dialing to", zap.String("address", addr)) conn, err := d.Dialer.DialContext(ctx, network, addr) if err != nil { - logger.Debug().Msgf("failed to dial %s: %v", addr, err) + 
logger.Debug("Failed to dial", zap.String("address", addr), zap.Error(err)) } select { case ch <- &parallelDialerResult{conn: conn, err: err}: case <-done: if conn != nil { - logger.Debug().Msgf("connection closed: %s", conn.RemoteAddr()) + logger.Debug("Connection closed", zap.String("remote_address", conn.RemoteAddr().String())) conn.Close() } } @@ -201,7 +200,7 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs for res := range ch { if res.err == nil { cancel() - logger.Debug().Msgf("connected to %s", res.conn.RemoteAddr()) + logger.Debug("Connected to", zap.String("remote_address", res.conn.RemoteAddr().String())) return res.conn, res.err } errs = append(errs, res.err) diff --git a/internal/resolvconffile/dns.go b/internal/resolvconffile/dns.go index 0d532eb2..db987504 100644 --- a/internal/resolvconffile/dns.go +++ b/internal/resolvconffile/dns.go @@ -1,5 +1,3 @@ -//go:build !js && !windows - package resolvconffile import ( @@ -11,6 +9,7 @@ import ( const resolvconfPath = "/etc/resolv.conf" +// NameServersWithPort retrieves a list of nameservers with the default DNS port 53 appended to each address. func NameServersWithPort() []string { c, err := resolvconffile.ParseFile(resolvconfPath) if err != nil { @@ -23,16 +22,27 @@ func NameServersWithPort() []string { return ns } +// NameServers retrieves a list of nameservers from the /etc/resolv.conf file +// Returns an empty slice if reading fails. func NameServers() []string { - c, err := resolvconffile.ParseFile(resolvconfPath) + nss, _ := NameserversFromFile(resolvconfPath) + return nss +} + +// NameserversFromFile reads nameserver addresses from the specified resolv.conf file +// and returns them as a slice of strings. +// +// Returns an error if the file cannot be parsed. 
+func NameserversFromFile(path string) ([]string, error) { + c, err := resolvconffile.ParseFile(path) if err != nil { - return nil + return nil, err } ns := make([]string, 0, len(c.Nameservers)) for _, nameserver := range c.Nameservers { ns = append(ns, nameserver.String()) } - return ns + return ns, nil } // SearchDomains returns the current search domains config in /etc/resolv.conf file. diff --git a/internal/router/ddwrt/ddwrt.go b/internal/router/ddwrt/ddwrt.go deleted file mode 100644 index edd7e6b6..00000000 --- a/internal/router/ddwrt/ddwrt.go +++ /dev/null @@ -1,117 +0,0 @@ -package ddwrt - -import ( - "errors" - "fmt" - "os/exec" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/Control-D-Inc/ctrld/internal/router/ntp" - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -const Name = "ddwrt" - -//lint:ignore ST1005 This error is for human. -var errDdwrtJffs2NotEnabled = errors.New(`could not install service without jffs, follow this guide to enable: - -https://wiki.dd-wrt.com/wiki/index.php/Journalling_Flash_File_System -`) - -var nvramKvMap = map[string]string{ - "dns_dnsmasq": "1", // Make dnsmasq running but disable DNS ability, ctrld will replace it. - "dnsmasq_options": "", // Configuration of dnsmasq set by ctrld, filled by setupDDWrt. - "dns_crypt": "0", // Disable DNSCrypt. - "dnssec": "0", // Disable DNSSEC. -} - -type Ddwrt struct { - cfg *ctrld.Config -} - -// New returns a router.Router for configuring/setup/run ctrld on ddwrt routers. 
-func New(cfg *ctrld.Config) *Ddwrt { - return &Ddwrt{cfg: cfg} -} - -func (d *Ddwrt) ConfigureService(config *service.Config) error { - if !ddwrtJff2Enabled() { - return errDdwrtJffs2NotEnabled - } - return nil -} - -func (d *Ddwrt) Install(_ *service.Config) error { - return nil -} - -func (d *Ddwrt) Uninstall(_ *service.Config) error { - return nil -} - -func (d *Ddwrt) PreRun() error { - _ = d.Cleanup() - return ntp.WaitNvram() -} - -func (d *Ddwrt) Setup() error { - if d.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Already setup. - if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" { - return nil - } - - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, d.cfg) - if err != nil { - return err - } - - nvramKvMap["dnsmasq_options"] = data - if err := nvram.SetKV(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func (d *Ddwrt) Cleanup() error { - if d.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val != "1" { - return nil // was restored, nothing to do. - } - - nvramKvMap["dnsmasq_options"] = "" - // Restore old configs. - if err := nvram.Restore(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - // Restart dnsmasq service. 
- if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func restartDNSMasq() error { - if out, err := exec.Command("restart_dns").CombinedOutput(); err != nil { - return fmt.Errorf("restart_dns: %s, %w", string(out), err) - } - return nil -} - -func ddwrtJff2Enabled() bool { - out, _ := nvram.Run("get", "enable_jffs2") - return out == "1" -} diff --git a/internal/router/dnsmasq/conf.go b/internal/router/dnsmasq/conf.go deleted file mode 100644 index bb81d607..00000000 --- a/internal/router/dnsmasq/conf.go +++ /dev/null @@ -1,90 +0,0 @@ -package dnsmasq - -import ( - "bufio" - "bytes" - "errors" - "io" - "os" - "path/filepath" - "strings" -) - -func InterfaceNameFromConfig(filename string) (string, error) { - buf, err := os.ReadFile(filename) - if err != nil { - return "", err - } - return interfaceNameFromReader(bytes.NewReader(buf)) -} - -func interfaceNameFromReader(r io.Reader) (string, error) { - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - after, found := strings.CutPrefix(line, "interface=") - if found { - return after, nil - } - } - return "", errors.New("not found") -} - -// AdditionalConfigFiles returns a list of Dnsmasq configuration files found in the "/tmp/etc" directory. -func AdditionalConfigFiles() []string { - if paths, err := filepath.Glob("/tmp/etc/dnsmasq-*.conf"); err == nil { - return paths - } - return nil -} - -// AdditionalLeaseFiles returns a list of lease file paths corresponding to the Dnsmasq configuration files. 
-func AdditionalLeaseFiles() []string { - cfgFiles := AdditionalConfigFiles() - if len(cfgFiles) == 0 { - return nil - } - leaseFiles := make([]string, 0, len(cfgFiles)) - for _, cfgFile := range cfgFiles { - if leaseFile := leaseFileFromConfigFileName(cfgFile); leaseFile != "" { - leaseFiles = append(leaseFiles, leaseFile) - - } else { - leaseFiles = append(leaseFiles, defaultLeaseFileFromConfigPath(cfgFile)) - } - } - return leaseFiles -} - -// leaseFileFromConfigFileName retrieves the DHCP lease file path by reading and parsing the provided configuration file. -func leaseFileFromConfigFileName(cfgFile string) string { - if f, err := os.Open(cfgFile); err == nil { - return leaseFileFromReader(f) - } - return "" -} - -// leaseFileFromReader parses the given io.Reader for the "dhcp-leasefile" configuration and returns its value as a string. -func leaseFileFromReader(r io.Reader) string { - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "#") { - continue - } - before, after, found := strings.Cut(line, "=") - if !found { - continue - } - if before == "dhcp-leasefile" { - return after - } - } - return "" -} - -// defaultLeaseFileFromConfigPath generates the default lease file path based on the provided configuration file path. 
-func defaultLeaseFileFromConfigPath(path string) string { - name := filepath.Base(path) - return filepath.Join("/var/lib/misc", strings.TrimSuffix(name, ".conf")+".leases") -} diff --git a/internal/router/dnsmasq/conf_test.go b/internal/router/dnsmasq/conf_test.go deleted file mode 100644 index 9ca672be..00000000 --- a/internal/router/dnsmasq/conf_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package dnsmasq - -import ( - "io" - "strings" - "testing" -) - -func Test_interfaceNameFromReader(t *testing.T) { - tests := []struct { - name string - in string - wantIface string - }{ - { - "good", - `interface=lo`, - "lo", - }, - { - "multiple", - `interface=lo -interface=eth0 -`, - "lo", - }, - { - "no iface", - `cache-size=100`, - "", - }, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - ifaceName, err := interfaceNameFromReader(strings.NewReader(tc.in)) - if tc.wantIface != "" && err != nil { - t.Errorf("unexpected error: %v", err) - return - } - if tc.wantIface != ifaceName { - t.Errorf("mismatched, want: %q, got: %q", tc.wantIface, ifaceName) - } - }) - } -} - -func Test_leaseFileFromReader(t *testing.T) { - tests := []struct { - name string - in io.Reader - expected string - }{ - { - "default", - strings.NewReader(` -dhcp-script=/sbin/dhcpc_lease -dhcp-leasefile=/var/lib/misc/dnsmasq-1.leases -script-arp -`), - "/var/lib/misc/dnsmasq-1.leases", - }, - { - "non-default", - strings.NewReader(` -dhcp-script=/sbin/dhcpc_lease -dhcp-leasefile=/tmp/var/lib/misc/dnsmasq-1.leases -script-arp -`), - "/tmp/var/lib/misc/dnsmasq-1.leases", - }, - { - "missing", - strings.NewReader(` -dhcp-script=/sbin/dhcpc_lease -script-arp -`), - "", - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - if got := leaseFileFromReader(tc.in); got != tc.expected { - t.Errorf("leaseFileFromReader() = %v, want %v", got, tc.expected) - } - }) - } - -} diff --git a/internal/router/dnsmasq/dnsmasq.go 
b/internal/router/dnsmasq/dnsmasq.go deleted file mode 100644 index 058b0b59..00000000 --- a/internal/router/dnsmasq/dnsmasq.go +++ /dev/null @@ -1,190 +0,0 @@ -package dnsmasq - -import ( - "errors" - "html/template" - "net" - "os" - "path/filepath" - "strings" - - "github.com/Control-D-Inc/ctrld" -) - -const CtrldMarker = `# GENERATED BY ctrld - DO NOT MODIFY` - -const ConfigContentTmpl = `# GENERATED BY ctrld - DO NOT MODIFY -no-resolv -{{- range .Upstreams}} -server={{ .IP }}#{{ .Port }} -{{- end}} -add-mac -add-subnet=32,128 -{{- if .CacheDisabled}} -cache-size=0 -{{- else}} -max-cache-ttl=0 -{{- end}} -` - -const ( - MerlinConfPath = "/tmp/etc/dnsmasq.conf" - MerlinJffsConfDir = "/jffs/configs" - MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf" - MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" -) - -const MerlinPostConfMarker = `# GENERATED BY ctrld - EOF` -const MerlinPostConfTmpl = `# GENERATED BY ctrld - DO NOT MODIFY - -#!/bin/sh - -config_file="$1" -. /usr/sbin/helper.sh - -pid=$(cat /tmp/ctrld.pid 2>/dev/null) -if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then - pc_delete "servers-file" "$config_file" # no WAN DNS settings - pc_append "no-resolv" "$config_file" # do not read /etc/resolv.conf - # use ctrld as upstream - pc_delete "server=" "$config_file" - {{- range .Upstreams}} - pc_append "server={{ .IP }}#{{ .Port }}" "$config_file" - {{- end}} - pc_delete "add-mac" "$config_file" - pc_delete "add-subnet" "$config_file" - pc_append "add-mac" "$config_file" # add client mac - pc_append "add-subnet=32,128" "$config_file" # add client ip - pc_delete "dnssec" "$config_file" # disable DNSSEC - pc_delete "trust-anchor=" "$config_file" # disable DNSSEC - pc_delete "cache-size=" "$config_file" - pc_append "cache-size=0" "$config_file" # disable cache - - # For John fork - pc_delete "resolv-file" "$config_file" # no WAN DNS settings - - # Change /etc/resolv.conf, which may be changed by WAN DNS setup - pc_delete "nameserver" /etc/resolv.conf - 
pc_append "nameserver 127.0.0.1" /etc/resolv.conf - - exit 0 -fi -` - -type Upstream struct { - IP string - Port int -} - -// ConfTmpl generates dnsmasq configuration from ctrld config. -func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { - return ConfTmplWithCacheDisabled(tmplText, cfg, true) -} - -// ConfTmplWithCacheDisabled is like ConfTmpl, but the caller can control whether -// dnsmasq cache is disabled using cacheDisabled parameter. -// -// Generally, the caller should use ConfTmpl, but on some routers which dnsmasq config may be changed -// after ctrld started (like EdgeOS/Ubios, Firewalla ...), dnsmasq cache should not be disabled because -// the cache-size=0 generated by ctrld will conflict with router's generated config. -func ConfTmplWithCacheDisabled(tmplText string, cfg *ctrld.Config, cacheDisabled bool) (string, error) { - listener := cfg.FirstListener() - if listener == nil { - return "", errors.New("missing listener") - } - ip := listener.IP - if ip == "0.0.0.0" || ip == "::" || ip == "" { - ip = "127.0.0.1" - } - upstreams := []Upstream{{IP: ip, Port: listener.Port}} - return confTmpl(tmplText, upstreams, cacheDisabled) -} - -// FirewallaConfTmpl generates dnsmasq config for Firewalla routers. -func FirewallaConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { - // If ctrld listen on all interfaces, generating config for all of them. - if lc := cfg.FirstListener(); lc != nil && (lc.IP == "0.0.0.0" || lc.IP == "") { - return confTmpl(tmplText, firewallaUpstreams(lc.Port), false) - } - // Otherwise, generating config for the specific listener from ctrld's config. 
- return ConfTmplWithCacheDisabled(tmplText, cfg, false) -} - -func confTmpl(tmplText string, upstreams []Upstream, cacheDisabled bool) (string, error) { - tmpl := template.Must(template.New("").Parse(tmplText)) - var to = &struct { - Upstreams []Upstream - CacheDisabled bool - }{ - Upstreams: upstreams, - CacheDisabled: cacheDisabled, - } - var sb strings.Builder - if err := tmpl.Execute(&sb, to); err != nil { - return "", err - } - return sb.String(), nil -} - -func firewallaUpstreams(port int) []Upstream { - ifaces := FirewallaSelfInterfaces() - upstreams := make([]Upstream, 0, len(ifaces)) - for _, netIface := range ifaces { - addrs, _ := netIface.Addrs() - for _, addr := range addrs { - if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.To4() != nil { - upstreams = append(upstreams, Upstream{ - IP: netIP.IP.To4().String(), - Port: port, - }) - } - } - } - return upstreams -} - -// firewallaDnsmasqConfFiles returns dnsmasq config files of all firewalla interfaces. -func firewallaDnsmasqConfFiles() ([]string, error) { - return filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf") -} - -// FirewallaSelfInterfaces returns list of interfaces that will be configured with default dnsmasq setup on Firewalla. -func FirewallaSelfInterfaces() []*net.Interface { - matches, err := firewallaDnsmasqConfFiles() - if err != nil { - return nil - } - ifaces := make([]*net.Interface, 0, len(matches)) - for _, match := range matches { - // Trim prefix and suffix to get the iface name only. 
- ifaceName := strings.TrimSuffix(strings.TrimPrefix(match, "/home/pi/firerouter/etc/dnsmasq.dns."), ".conf") - if netIface, _ := net.InterfaceByName(ifaceName); netIface != nil { - ifaces = append(ifaces, netIface) - } - } - return ifaces -} - -const ( - ubios43ConfPath = "/run/dnsmasq.dhcp.conf.d" - ubios42ConfPath = "/run/dnsmasq.conf.d" - ubios43PidFile = "/run/dnsmasq-main.pid" - ubios42PidFile = "/run/dnsmasq.pid" - UbiosConfName = "zzzctrld.conf" -) - -// UbiosConfPath returns the appropriate configuration path based on the system's directory structure. -func UbiosConfPath() string { - if st, _ := os.Stat(ubios43ConfPath); st != nil && st.IsDir() { - return ubios43ConfPath - } - return ubios42ConfPath -} - -// UbiosPidFile returns the appropriate dnsmasq pid file based on the system's directory structure. -func UbiosPidFile() string { - if st, _ := os.Stat(ubios43PidFile); st != nil && !st.IsDir() { - return ubios43PidFile - } - return ubios42PidFile -} diff --git a/internal/router/edgeos/edgeos.go b/internal/router/edgeos/edgeos.go deleted file mode 100644 index 7364ac11..00000000 --- a/internal/router/edgeos/edgeos.go +++ /dev/null @@ -1,209 +0,0 @@ -package edgeos - -import ( - "bufio" - "bytes" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" -) - -const ( - Name = "edgeos" - edgeOSDNSMasqConfigPath = "/etc/dnsmasq.d/dnsmasq-zzz-ctrld.conf" - usgDNSMasqConfigPath = "/etc/dnsmasq.conf" - usgDNSMasqBackupConfigPath = "/etc/dnsmasq.conf.bak" - toggleContentFilteringLink = "https://community.ui.com/questions/UDM-Pro-disable-enable-DNS-filtering/e2cc4060-e56a-4139-b200-62d7f773ff8f" - toggleDnsShieldLink = "https://community.ui.com/questions/UniFi-OS-3-2-7-DNS-Shield-Missing/d3a85905-4ce0-4fe4-8bf0-6cb04f21371d" -) - -var ErrContentFilteringEnabled = fmt.Errorf(`the "Content Filtering" feature" is enabled, which is 
conflicted with ctrld.\n -To disable it, folowing instruction here: %s`, toggleContentFilteringLink) - -var ErrDnsShieldEnabled = fmt.Errorf(`the "DNS Shield" feature" is enabled, which is conflicted with ctrld.\n -To disable it, folowing screenshot here: %s`, toggleDnsShieldLink) - -type EdgeOS struct { - cfg *ctrld.Config - isUSG bool -} - -// New returns a router.Router for configuring/setup/run ctrld on EdgeOS routers. -func New(cfg *ctrld.Config) *EdgeOS { - e := &EdgeOS{cfg: cfg} - e.isUSG = checkUSG() - return e -} - -func (e *EdgeOS) ConfigureService(config *service.Config) error { - return nil -} - -func (e *EdgeOS) Install(_ *service.Config) error { - // If "Content Filtering" is enabled, UniFi OS will create firewall rules to intercept all DNS queries - // from outside, and route those queries to separated interfaces (e.g: dnsfilter-2@if79) created by UniFi OS. - // Thus, those queries will never reach ctrld listener. UniFi OS does not provide any mechanism to toggle this - // feature via command line, so there's nothing ctrld can do to disable this feature. For now, reporting an - // error and guiding users to disable the feature using UniFi OS web UI. - if ContentFilteringEnabled() { - return ErrContentFilteringEnabled - } - // If "DNS Shield" is enabled, UniFi OS will spawn dnscrypt-proxy process, and route all DNS queries to it. So - // reporting an error and guiding users to disable the feature using UniFi OS web UI. 
- if DnsShieldEnabled() { - return ErrDnsShieldEnabled - } - return nil -} - -func (e *EdgeOS) Uninstall(_ *service.Config) error { - return nil -} - -func (e *EdgeOS) PreRun() error { - return nil -} - -func (e *EdgeOS) Setup() error { - if e.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - if e.isUSG { - return e.setupUSG() - } - return e.setupUDM() -} - -func (e *EdgeOS) Cleanup() error { - if e.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - if e.isUSG { - return e.cleanupUSG() - } - return e.cleanupUDM() -} - -func (e *EdgeOS) setupUSG() error { - // On USG, dnsmasq is configured to forward queries to external provider by default. - // So instead of generating config in /etc/dnsmasq.d, we need to create a backup of - // the config, then modify it to forward queries to ctrld listener. - - // Creating a backup. - buf, err := os.ReadFile(usgDNSMasqConfigPath) - if err != nil { - return fmt.Errorf("setupUSG: reading current config: %w", err) - } - if err := os.WriteFile(usgDNSMasqBackupConfigPath, buf, 0600); err != nil { - return fmt.Errorf("setupUSG: backup current config: %w", err) - } - - // Removing all configured upstreams and cache config. - var sb strings.Builder - scanner := bufio.NewScanner(bytes.NewReader(buf)) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "server=") { - continue - } - if strings.HasPrefix(line, "all-servers") { - continue - } - sb.WriteString(line) - } - - data, err := dnsmasq.ConfTmplWithCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false) - if err != nil { - return err - } - sb.WriteString("\n") - sb.WriteString(data) - if err := os.WriteFile(usgDNSMasqConfigPath, []byte(sb.String()), 0644); err != nil { - return fmt.Errorf("setupUSG: writing dnsmasq config: %w", err) - } - - // Restart dnsmasq service. 
- if err := restartDNSMasq(); err != nil { - return fmt.Errorf("setupUSG: restartDNSMasq: %w", err) - } - return nil -} - -func (e *EdgeOS) setupUDM() error { - data, err := dnsmasq.ConfTmplWithCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false) - if err != nil { - return err - } - if err := os.WriteFile(edgeOSDNSMasqConfigPath, []byte(data), 0600); err != nil { - return fmt.Errorf("setupUDM: generating dnsmasq config: %w", err) - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return fmt.Errorf("setupUDM: restartDNSMasq: %w", err) - } - return nil -} - -func (e *EdgeOS) cleanupUSG() error { - if err := os.Rename(usgDNSMasqBackupConfigPath, usgDNSMasqConfigPath); err != nil { - return fmt.Errorf("cleanupUSG: os.Rename: %w", err) - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return fmt.Errorf("cleanupUSG: restartDNSMasq: %w", err) - } - return nil -} - -func (e *EdgeOS) cleanupUDM() error { - // Remove the custom dnsmasq config - if err := os.Remove(edgeOSDNSMasqConfigPath); err != nil { - return fmt.Errorf("cleanupUDM: os.Remove: %w", err) - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return fmt.Errorf("cleanupUDM: restartDNSMasq: %w", err) - } - return nil -} - -func ContentFilteringEnabled() bool { - st, err := os.Stat("/run/dnsfilter/dnsfilter") - return err == nil && !st.IsDir() -} - -// DnsShieldEnabled reports whether DNS Shield is enabled. 
-// See: https://community.ui.com/releases/UniFi-OS-Dream-Machines-3-2-7/251dfc1e-f4dd-4264-a080-3be9d8b9e02b -func DnsShieldEnabled() bool { - buf, err := os.ReadFile(filepath.Join(dnsmasq.UbiosConfPath(), "dns.conf")) - if err != nil { - return false - } - return bytes.Contains(buf, []byte("server=127.0.0.1#5053")) -} - -func LeaseFileDir() string { - if checkUSG() { - return "" - } - return "/run" -} - -func checkUSG() bool { - out, _ := os.ReadFile("/etc/version") - return bytes.HasPrefix(out, []byte("UniFiSecurityGateway.")) -} - -func restartDNSMasq() error { - if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil { - return fmt.Errorf("edgeosRestartDNSMasq: %s, %w", string(out), err) - } - return nil -} diff --git a/internal/router/firewalla/firewalla.go b/internal/router/firewalla/firewalla.go deleted file mode 100644 index cdf65864..00000000 --- a/internal/router/firewalla/firewalla.go +++ /dev/null @@ -1,110 +0,0 @@ -package firewalla - -import ( - "fmt" - "os" - "os/exec" - "strings" - - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - - "github.com/Control-D-Inc/ctrld" - "github.com/kardianos/service" -) - -const ( - Name = "firewalla" - - firewallaDNSMasqConfigPath = "/home/pi/.firewalla/config/dnsmasq_local/ctrld" - firewallaConfigPostMainDir = "/home/pi/.firewalla/config/post_main.d" - firewallaCtrldInitScriptPath = "/home/pi/.firewalla/config/post_main.d/start_ctrld.sh" -) - -type Firewalla struct { - cfg *ctrld.Config -} - -// New returns a router.Router for configuring/setup/run ctrld on Firewalla routers. -func New(cfg *ctrld.Config) *Firewalla { - return &Firewalla{cfg: cfg} -} - -func (f *Firewalla) ConfigureService(_ *service.Config) error { - return nil -} - -func (f *Firewalla) Install(_ *service.Config) error { - // Writing startup script. 
- if err := writeFirewallStartupScript(); err != nil { - return fmt.Errorf("writing startup script: %w", err) - } - return nil -} - -func (f *Firewalla) Uninstall(_ *service.Config) error { - // Removing startup script. - if err := os.Remove(firewallaCtrldInitScriptPath); err != nil { - return fmt.Errorf("removing startup script: %w", err) - } - return nil -} - -func (f *Firewalla) PreRun() error { - return nil -} - -func (f *Firewalla) Setup() error { - if f.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - data, err := dnsmasq.FirewallaConfTmpl(dnsmasq.ConfigContentTmpl, f.cfg) - if err != nil { - return fmt.Errorf("generating dnsmasq config: %w", err) - } - if err := os.WriteFile(firewallaDNSMasqConfigPath, []byte(data), 0600); err != nil { - return fmt.Errorf("writing ctrld config: %w", err) - } - - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return fmt.Errorf("restartDNSMasq: %w", err) - } - - return nil -} - -func (f *Firewalla) Cleanup() error { - if f.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Removing current config. - if err := os.Remove(firewallaDNSMasqConfigPath); err != nil { - return fmt.Errorf("removing ctrld config: %w", err) - } - - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return fmt.Errorf("restartDNSMasq: %w", err) - } - - return nil -} - -func writeFirewallStartupScript() error { - if err := os.MkdirAll(firewallaConfigPostMainDir, 0775); err != nil { - return err - } - exe, err := os.Executable() - if err != nil { - return err - } - // This is called when "ctrld start ..." runs, so recording - // the same command line arguments to use in startup script. 
- argStr := strings.Join(os.Args[1:], " ") - script := fmt.Sprintf("#!/bin/bash\n\nsudo %q %s\n", exe, argStr) - return os.WriteFile(firewallaCtrldInitScriptPath, []byte(script), 0755) -} - -func restartDNSMasq() error { - return exec.Command("systemctl", "restart", "firerouter_dns").Run() -} diff --git a/internal/router/merlin/merlin.go b/internal/router/merlin/merlin.go deleted file mode 100644 index c1c68210..00000000 --- a/internal/router/merlin/merlin.go +++ /dev/null @@ -1,266 +0,0 @@ -package merlin - -import ( - "bytes" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - "unicode" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/Control-D-Inc/ctrld/internal/router/ntp" - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -const Name = "merlin" - -// nvramKvMap is a map of NVRAM key-value pairs used to configure and manage Merlin-specific settings. -var nvramKvMap = map[string]string{ - "dnspriv_enable": "0", // Ensure Merlin native DoT disabled. -} - -// dnsmasqConfig represents configuration paths for dnsmasq operations in Merlin firmware. -type dnsmasqConfig struct { - confPath string - jffsConfPath string -} - -// Merlin represents a configuration handler for setting up and managing ctrld on Merlin routers. -type Merlin struct { - cfg *ctrld.Config -} - -// New returns a router.Router for configuring/setup/run ctrld on Merlin routers. -func New(cfg *ctrld.Config) *Merlin { - return &Merlin{cfg: cfg} -} - -// ConfigureService configures the service based on the provided configuration. It returns an error if the configuration fails. -func (m *Merlin) ConfigureService(config *service.Config) error { - return nil -} - -// Install sets up the necessary configurations and services required for the Merlin instance to function properly. 
-func (m *Merlin) Install(_ *service.Config) error { - return nil -} - -// Uninstall removes the ctrld-related configurations and services from the Merlin router and reverts to the original state. -func (m *Merlin) Uninstall(_ *service.Config) error { - return nil -} - -// PreRun prepares the Merlin instance for operation by waiting for essential services and directories to become available. -func (m *Merlin) PreRun() error { - // Wait NTP ready. - _ = m.Cleanup() - if err := ntp.WaitNvram(); err != nil { - return err - } - // Wait until directories mounted. - for _, dir := range []string{"/tmp", "/proc"} { - waitDirExists(dir) - } - // Wait dnsmasq started. - for { - out, _ := exec.Command("pidof", "dnsmasq").CombinedOutput() - if len(bytes.TrimSpace(out)) > 0 { - break - } - time.Sleep(time.Second) - } - return nil -} - -// Setup initializes and configures the Merlin instance for use, including setting up dnsmasq and necessary nvram settings. -func (m *Merlin) Setup() error { - if m.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Already setup. - if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" { - return nil - } - - if err := m.writeDnsmasqPostconf(); err != nil { - return err - } - - for _, cfg := range getDnsmasqConfigs() { - if err := m.setupDnsmasq(cfg); err != nil { - return fmt.Errorf("failed to setup dnsmasq: config: %s, error: %w", cfg.confPath, err) - } - } - - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - - if err := nvram.SetKV(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - return nil -} - -// Cleanup restores the original dnsmasq and nvram configurations and restarts dnsmasq if necessary. -func (m *Merlin) Cleanup() error { - if m.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val != "1" { - return nil // was restored, nothing to do. - } - - // Restore old configs. 
- if err := nvram.Restore(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath) - if err != nil && !os.IsNotExist(err) { - return err - } - // Restore dnsmasq post conf file. - if err := os.WriteFile(dnsmasq.MerlinPostConfPath, merlinParsePostConf(buf), 0750); err != nil { - return err - } - - for _, cfg := range getDnsmasqConfigs() { - if err := m.cleanupDnsmasqJffs(cfg); err != nil { - return fmt.Errorf("failed to cleanup jffs dnsmasq: config: %s, error: %w", cfg.confPath, err) - } - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -// setupDnsmasq sets up dnsmasq configuration by writing postconf, copying configuration, and running a postconf script. -func (m *Merlin) setupDnsmasq(cfg *dnsmasqConfig) error { - src, err := os.Open(cfg.confPath) - if os.IsNotExist(err) { - return nil // nothing to do if conf file does not exist. - } - if err != nil { - return fmt.Errorf("failed to open dnsmasq config: %w", err) - } - defer src.Close() - - // Copy current dnsmasq config to cfg.jffsConfPath, - // Then we will run postconf script on this file. - // - // Normally, adding postconf script is enough. However, we see - // reports on some Merlin devices that postconf scripts does not - // work, but manipulating the config directly via /jffs/configs does. - dst, err := os.Create(cfg.jffsConfPath) - if err != nil { - return fmt.Errorf("failed to create %s: %w", cfg.jffsConfPath, err) - } - defer dst.Close() - - if _, err := io.Copy(dst, src); err != nil { - return fmt.Errorf("failed to copy current dnsmasq config: %w", err) - } - if err := dst.Close(); err != nil { - return fmt.Errorf("failed to save %s: %w", cfg.jffsConfPath, err) - } - - // Run postconf script on cfg.jffsConfPath directly. 
- cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, cfg.jffsConfPath) - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to run post conf: %s: %w", string(out), err) - } - return nil -} - -// cleanupDnsmasqJffs removes the JFFS configuration file specified in the given dnsmasqConfig, if it exists. -func (m *Merlin) cleanupDnsmasqJffs(cfg *dnsmasqConfig) error { - // Remove cfg.jffsConfPath file. - if err := os.Remove(cfg.jffsConfPath); err != nil && !os.IsNotExist(err) { - return err - } - return nil -} - -// writeDnsmasqPostconf writes the requireddnsmasqConfigs post-configuration for dnsmasq to enable custom DNS settings with ctrld. -func (m *Merlin) writeDnsmasqPostconf() error { - buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath) - // Already setup. - if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) { - return nil - } - if err != nil && !os.IsNotExist(err) { - return err - } - - data, err := dnsmasq.ConfTmpl(dnsmasq.MerlinPostConfTmpl, m.cfg) - if err != nil { - return err - } - data = strings.Join([]string{ - data, - "\n", - dnsmasq.MerlinPostConfMarker, - "\n", - string(buf), - }, "\n") - // Write dnsmasq post conf file. - return os.WriteFile(dnsmasq.MerlinPostConfPath, []byte(data), 0750) -} - -// restartDNSMasq restarts the dnsmasq service by executing the appropriate system command using "service". -// Returns an error if the command fails or if there is an issue processing the command output. -func restartDNSMasq() error { - if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil { - return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err) - } - return nil -} - -// getDnsmasqConfigs retrieves a list of dnsmasqConfig containing configuration and JFFS paths for dnsmasq operations. 
-func getDnsmasqConfigs() []*dnsmasqConfig { - cfgs := []*dnsmasqConfig{ - {dnsmasq.MerlinConfPath, dnsmasq.MerlinJffsConfPath}, - } - for _, path := range dnsmasq.AdditionalConfigFiles() { - jffsConfPath := filepath.Join(dnsmasq.MerlinJffsConfDir, filepath.Base(path)) - cfgs = append(cfgs, &dnsmasqConfig{path, jffsConfPath}) - } - - return cfgs -} - -// merlinParsePostConf parses the dnsmasq post configuration by removing content after the MerlinPostConfMarker, if present. -// If no marker is found, the original buffer is returned unmodified. -// Returns nil if the input buffer is empty. -func merlinParsePostConf(buf []byte) []byte { - if len(buf) == 0 { - return nil - } - parts := bytes.Split(buf, []byte(dnsmasq.MerlinPostConfMarker)) - if len(parts) != 1 { - return bytes.TrimLeftFunc(parts[1], unicode.IsSpace) - } - return buf -} - -// waitDirExists waits until the specified directory exists, polling its existence every second. -func waitDirExists(dir string) { - for { - if _, err := os.Stat(dir); !os.IsNotExist(err) { - return - } - time.Sleep(time.Second) - } -} diff --git a/internal/router/merlin/merlin_test.go b/internal/router/merlin/merlin_test.go deleted file mode 100644 index 057628cd..00000000 --- a/internal/router/merlin/merlin_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package merlin - -import ( - "bytes" - "strings" - "testing" - - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" -) - -func Test_merlinParsePostConf(t *testing.T) { - origContent := "# foo" - data := strings.Join([]string{ - dnsmasq.MerlinPostConfTmpl, - "\n", - dnsmasq.MerlinPostConfMarker, - "\n", - }, "\n") - - tests := []struct { - name string - data string - expected string - }{ - {"empty", "", ""}, - {"no ctrld", origContent, origContent}, - {"ctrld with data", data + origContent, origContent}, - {"ctrld without data", data, ""}, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - //t.Parallel() - if got := 
merlinParsePostConf([]byte(tc.data)); !bytes.Equal(got, []byte(tc.expected)) { - t.Errorf("unexpected result, want: %q, got: %q", tc.expected, string(got)) - } - }) - } -} diff --git a/internal/router/netgear_orbi_voxel/procd.go b/internal/router/netgear_orbi_voxel/procd.go deleted file mode 100644 index 750a17da..00000000 --- a/internal/router/netgear_orbi_voxel/procd.go +++ /dev/null @@ -1,22 +0,0 @@ -package netgear - -const openWrtScript = `#!/bin/sh /etc/rc.common -USE_PROCD=1 -# After dnsmasq starts -START=61 -# Before network stops -STOP=89 -cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}" -name="{{.Name}}" -pid_file="/var/run/${name}.pid" - -start_service() { - echo "Starting ${name}" - procd_open_instance - procd_set_param command ${cmd} - procd_set_param respawn # respawn automatically if something died - procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop - procd_close_instance - echo "${name} has been started" -} -` diff --git a/internal/router/netgear_orbi_voxel/voxel.go b/internal/router/netgear_orbi_voxel/voxel.go deleted file mode 100644 index 4338f9c6..00000000 --- a/internal/router/netgear_orbi_voxel/voxel.go +++ /dev/null @@ -1,220 +0,0 @@ -package netgear - -import ( - "bufio" - "bytes" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -const ( - Name = "netgear_orbi_voxel" - netgearOrbiVoxelDNSMasqConfigPath = "/etc/dnsmasq.conf" - netgearOrbiVoxelHomedir = "/mnt/bitdefender" - netgearOrbiVoxelStartupScript = "/mnt/bitdefender/rc.user" - netgearOrbiVoxelStartupScriptBackup = "/mnt/bitdefender/rc.user.bak" - netgearOrbiVoxelStartupScriptMarker = "\n# GENERATED BY ctrld" -) - -var nvramKvMap = map[string]string{ - "dns_hijack": "0", // Disable dns hijacking -} - -type NetgearOrbiVoxel struct { - cfg 
*ctrld.Config -} - -// New returns a router.Router for configuring/setup/run ctrld on ddwrt routers. -func New(cfg *ctrld.Config) *NetgearOrbiVoxel { - return &NetgearOrbiVoxel{cfg: cfg} -} - -func (d *NetgearOrbiVoxel) ConfigureService(svc *service.Config) error { - if err := d.checkInstalledDir(); err != nil { - return err - } - svc.Option["SysvScript"] = openWrtScript - return nil -} - -func (d *NetgearOrbiVoxel) Install(_ *service.Config) error { - // Ignoring error here at this moment is ok, since everything will be wiped out on reboot. - _ = exec.Command("/etc/init.d/ctrld", "enable").Run() - if err := d.checkInstalledDir(); err != nil { - return err - } - if err := backupVoxelStartupScript(); err != nil { - return fmt.Errorf("backup startup script: %w", err) - } - if err := writeVoxelStartupScript(); err != nil { - return fmt.Errorf("writing startup script: %w", err) - } - return nil -} - -func (d *NetgearOrbiVoxel) Uninstall(_ *service.Config) error { - if err := os.Remove(netgearOrbiVoxelStartupScript); err != nil && !os.IsNotExist(err) { - return err - } - err := os.Rename(netgearOrbiVoxelStartupScriptBackup, netgearOrbiVoxelStartupScript) - if err != nil && !os.IsNotExist(err) { - return err - } - return nil -} - -func (d *NetgearOrbiVoxel) PreRun() error { - return nil -} - -func (d *NetgearOrbiVoxel) Setup() error { - if d.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Already setup. - if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" { - return nil - } - - data, err := dnsmasq.ConfTmplWithCacheDisabled(dnsmasq.ConfigContentTmpl, d.cfg, false) - if err != nil { - return err - } - currentConfig, _ := os.ReadFile(netgearOrbiVoxelDNSMasqConfigPath) - configContent := append(currentConfig, data...) - if err := os.WriteFile(netgearOrbiVoxelDNSMasqConfigPath, configContent, 0600); err != nil { - return err - } - // Restart dnsmasq service. 
- if err := restartDNSMasq(); err != nil { - return err - } - - if err := nvram.SetKV(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - return nil -} - -func (d *NetgearOrbiVoxel) Cleanup() error { - if d.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val != "1" { - return nil // was restored, nothing to do. - } - - // Restore old configs. - if err := nvram.Restore(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - // Restore dnsmasq config. - if err := restoreDnsmasqConf(); err != nil { - return err - } - - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -// checkInstalledDir checks that ctrld binary was installed in the correct directory. -func (d *NetgearOrbiVoxel) checkInstalledDir() error { - exePath, err := os.Executable() - if err != nil { - return fmt.Errorf("checkHomeDir: failed to get binary path %w", err) - } - if !strings.HasSuffix(filepath.Dir(exePath), netgearOrbiVoxelHomedir) { - return fmt.Errorf("checkHomeDir: could not install service outside %s", netgearOrbiVoxelHomedir) - } - return nil -} - -// backupVoxelStartupScript creates a backup of original startup script if existed. -func backupVoxelStartupScript() error { - // Do nothing if the startup script was modified by ctrld. - script, _ := os.ReadFile(netgearOrbiVoxelStartupScript) - if bytes.Contains(script, []byte(netgearOrbiVoxelStartupScriptMarker)) { - return nil - } - err := os.Rename(netgearOrbiVoxelStartupScript, netgearOrbiVoxelStartupScriptBackup) - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("backupVoxelStartupScript: %w", err) - } - return nil -} - -// writeVoxelStartupScript writes startup script to re-install ctrld upon reboot. 
-// See: https://github.com/SVoxel/ORBI-RBK50/pull/7 -func writeVoxelStartupScript() error { - exe, err := os.Executable() - if err != nil { - return fmt.Errorf("configure service: failed to get binary path %w", err) - } - // This is called when "ctrld start ..." runs, so recording - // the same command line arguments to use in startup script. - argStr := strings.Join(os.Args[1:], " ") - script, _ := os.ReadFile(netgearOrbiVoxelStartupScriptBackup) - script = append(script, fmt.Sprintf("%s\n%q %s\n", netgearOrbiVoxelStartupScriptMarker, exe, argStr)...) - f, err := os.Create(netgearOrbiVoxelStartupScript) - if err != nil { - return fmt.Errorf("failed to create startup script: %w", err) - } - defer f.Close() - - if _, err := f.Write(script); err != nil { - return fmt.Errorf("failed to write startup script: %w", err) - } - if err := f.Close(); err != nil { - return fmt.Errorf("failed to save startup script: %w", err) - } - return nil -} - -// restoreDnsmasqConf restores original dnsmasq configuration. 
-func restoreDnsmasqConf() error { - f, err := os.Open(netgearOrbiVoxelDNSMasqConfigPath) - if err != nil { - return err - } - defer f.Close() - - var bs []byte - buf := bytes.NewBuffer(bs) - - removed := false - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - if line == dnsmasq.CtrldMarker { - removed = true - } - if !removed { - _, err := buf.WriteString(line + "\n") - if err != nil { - return err - } - } - } - return os.WriteFile(netgearOrbiVoxelDNSMasqConfigPath, buf.Bytes(), 0644) -} - -func restartDNSMasq() error { - if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil { - return fmt.Errorf("restartDNSMasq: %s, %w", string(out), err) - } - return nil -} diff --git a/internal/router/ntp/ntp.go b/internal/router/ntp/ntp.go deleted file mode 100644 index 5c04a36d..00000000 --- a/internal/router/ntp/ntp.go +++ /dev/null @@ -1,49 +0,0 @@ -package ntp - -import ( - "bytes" - "context" - "errors" - "fmt" - "os/exec" - "time" - - "tailscale.com/logtail/backoff" - - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -// WaitNvram waits NTP synced by checking "ntp_ready" value using nvram. -func WaitNvram() error { - // Wait until `ntp_ready=1` set. - b := backoff.NewBackoff("ntp.Wait", func(format string, args ...any) {}, 10*time.Second) - for { - // ddwrt use "ntp_done": https://github.com/mirror/dd-wrt/blob/a08c693527ab3204bf7bebd408a7c9a83b6ede47/src/router/rc/ntp.c#L100 - for _, key := range []string{"ntp_ready", "ntp_done"} { - out, err := nvram.Run("get", key) - if err != nil { - return fmt.Errorf("PreStart: nvram: %w", err) - } - if out == "1" { - return nil - } - } - b.BackOff(context.Background(), errors.New("ntp not ready")) - } -} - -// WaitUpstart waits NTP synced by checking upstart task "ntpsync" is in "stop/waiting" state. -func WaitUpstart() error { - // Wait until `initctl status ntpsync` returns stop state. 
- b := backoff.NewBackoff("ntp.WaitUpstart", func(format string, args ...any) {}, 10*time.Second) - for { - out, err := exec.Command("initctl", "status", "ntpsync").CombinedOutput() - if err != nil { - return fmt.Errorf("exec.Command: %w", err) - } - if bytes.Contains(out, []byte("stop/waiting")) { - return nil - } - b.BackOff(context.Background(), errors.New("ntp not ready")) - } -} diff --git a/internal/router/nvram/nvram.go b/internal/router/nvram/nvram.go deleted file mode 100644 index e76c0171..00000000 --- a/internal/router/nvram/nvram.go +++ /dev/null @@ -1,89 +0,0 @@ -package nvram - -import ( - "bytes" - "fmt" - "os/exec" - "strings" -) - -const ( - CtrldKeyPrefix = "ctrld_" - CtrldSetupKey = "ctrld_setup" - CtrldInstallKey = "ctrld_install" - RCStartupKey = "rc_startup" -) - -// Run runs the given nvram command. -func Run(args ...string) (string, error) { - cmd := exec.Command("nvram", args...) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("%s:%w", stderr.String(), err) - } - return strings.TrimSpace(stdout.String()), nil -} - -/* -NOTE: - - For Openwrt, DNSSEC is not included in default dnsmasq (require dnsmasq-full). - - For Merlin, DNSSEC is configured during postconf script (see merlinDNSMasqPostConfTmpl). - - For Ubios UDM Pro/Dream Machine, DNSSEC is not included in their dnsmasq package: - +https://community.ui.com/questions/Implement-DNSSEC-into-UniFi/951c72b0-4d88-4c86-9174-45417bd2f9ca - +https://community.ui.com/questions/Enable-DNSSEC-for-Unifi-Dream-Machine-FW-updates/e68e367c-d09b-4459-9444-18908f7c1ea1 -*/ - -// SetKV writes the given key/value from map to nvram. -// The given setupKey is set to 1 to indicates key/value set. -func SetKV(m map[string]string, setupKey string) error { - // Backup current value, store ctrld's configs. 
- for key, value := range m { - old, err := Run("get", key) - if err != nil { - return fmt.Errorf("%s: %w", old, err) - } - if out, err := Run("set", CtrldKeyPrefix+key+"="+old); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - if out, err := Run("set", key+"="+value); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - } - - if out, err := Run("set", setupKey+"=1"); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - // Commit. - if out, err := Run("commit"); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - return nil -} - -// Restore restores the old value of given key from map m. -// The given setupKey is set to 0 to indicates key/value restored. -func Restore(m map[string]string, setupKey string) error { - // Restore old configs. - for key := range m { - ctrldKey := CtrldKeyPrefix + key - old, err := Run("get", ctrldKey) - if err != nil { - return fmt.Errorf("%s: %w", old, err) - } - _, _ = Run("unset", ctrldKey) - if out, err := Run("set", key+"="+old); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - } - - if out, err := Run("unset", setupKey); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - // Commit. 
- if out, err := Run("commit"); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - return nil -} diff --git a/internal/router/openwrt/openwrt.go b/internal/router/openwrt/openwrt.go deleted file mode 100644 index 73f5a06f..00000000 --- a/internal/router/openwrt/openwrt.go +++ /dev/null @@ -1,191 +0,0 @@ -package openwrt - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" -) - -const ( - Name = "openwrt" - openwrtDNSMasqConfigName = "ctrld.conf" - openwrtDNSMasqDefaultConfigDir = "/tmp/dnsmasq.d" -) - -var openwrtDnsmasqDefaultConfigPath = filepath.Join(openwrtDNSMasqDefaultConfigDir, openwrtDNSMasqConfigName) - -type Openwrt struct { - cfg *ctrld.Config - dnsmasqCacheSize string -} - -// New returns a router.Router for configuring/setup/run ctrld on Openwrt routers. -func New(cfg *ctrld.Config) *Openwrt { - return &Openwrt{cfg: cfg} -} - -func (o *Openwrt) ConfigureService(svc *service.Config) error { - svc.Option["SysvScript"] = openWrtScript - return nil -} - -func (o *Openwrt) Install(config *service.Config) error { - return exec.Command("/etc/init.d/ctrld", "enable").Run() -} - -func (o *Openwrt) Uninstall(config *service.Config) error { - return nil -} - -func (o *Openwrt) PreRun() error { - return nil -} - -func (o *Openwrt) Setup() error { - if o.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - - // Save current dnsmasq config cache size if present. - if cs, err := uci("get", "dhcp.@dnsmasq[0].cachesize"); err == nil { - o.dnsmasqCacheSize = cs - if _, err := uci("delete", "dhcp.@dnsmasq[0].cachesize"); err != nil { - return err - } - // Commit. 
- if _, err := uci("commit", "dhcp"); err != nil { - return err - } - } - - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, o.cfg) - if err != nil { - return err - } - if err := os.WriteFile(dnsmasqConfPathFromUbus(), []byte(data), 0600); err != nil { - return err - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func (o *Openwrt) Cleanup() error { - if o.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Remove the custom dnsmasq config - if err := os.Remove(dnsmasqConfPathFromUbus()); err != nil { - return err - } - - // Restore original value if present. - if o.dnsmasqCacheSize != "" { - if _, err := uci("set", fmt.Sprintf("dhcp.@dnsmasq[0].cachesize=%s", o.dnsmasqCacheSize)); err != nil { - return err - } - // Commit. - if _, err := uci("commit", "dhcp"); err != nil { - return err - } - } - - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func restartDNSMasq() error { - if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil { - return fmt.Errorf("%s: %w", string(out), err) - } - return nil -} - -var errUCIEntryNotFound = errors.New("uci: Entry not found") - -func uci(args ...string) (string, error) { - cmd := exec.Command("uci", args...) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - if strings.HasPrefix(stderr.String(), errUCIEntryNotFound.Error()) { - return "", errUCIEntryNotFound - } - return "", fmt.Errorf("%s:%w", stderr.String(), err) - } - return strings.TrimSpace(stdout.String()), nil -} - -// openwrtServiceList represents openwrt services config. -type openwrtServiceList struct { - Dnsmasq dnsmasqConf `json:"dnsmasq"` -} - -// dnsmasqConf represents dnsmasq config. 
-type dnsmasqConf struct { - Instances map[string]confInstances `json:"instances"` -} - -// confInstances represents an instance config of a service. -type confInstances struct { - Mount map[string]string `json:"mount"` -} - -// dnsmasqConfPath returns the dnsmasq config path. -// -// Since version 24.10, openwrt makes some changes to dnsmasq to support -// multiple instances of dnsmasq. This change causes breaking changes to -// software which depends on the default dnsmasq path. -// -// There are some discussion/PRs in openwrt repo to address this: -// -// - https://github.com/openwrt/openwrt/pull/16806 -// - https://github.com/openwrt/openwrt/pull/16890 -// -// In the meantime, workaround this problem by querying the actual config path -// by querying ubus service list. -func dnsmasqConfPath(r io.Reader) string { - var svc openwrtServiceList - if err := json.NewDecoder(r).Decode(&svc); err != nil { - return openwrtDnsmasqDefaultConfigPath - } - for _, inst := range svc.Dnsmasq.Instances { - for mount := range inst.Mount { - dirName := filepath.Base(mount) - parts := strings.Split(dirName, ".") - if len(parts) < 2 { - continue - } - if parts[0] == "dnsmasq" && parts[len(parts)-1] == "d" { - return filepath.Join(mount, openwrtDNSMasqConfigName) - } - } - } - return openwrtDnsmasqDefaultConfigPath -} - -// dnsmasqConfPathFromUbus get dnsmasq config path from ubus service list. 
-func dnsmasqConfPathFromUbus() string { - output, err := exec.Command("ubus", "call", "service", "list").Output() - if err != nil { - return openwrtDnsmasqDefaultConfigPath - } - return dnsmasqConfPath(bytes.NewReader(output)) -} diff --git a/internal/router/openwrt/openwrt_test.go b/internal/router/openwrt/openwrt_test.go deleted file mode 100644 index 8b260e88..00000000 --- a/internal/router/openwrt/openwrt_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package openwrt - -import ( - "io" - "path/filepath" - "strings" - "testing" -) - -// Sample output from https://github.com/openwrt/openwrt/pull/16806#issuecomment-2448255734 -const ubusDnsmasqBefore2410 = `{ - "dnsmasq": { - "instances": { - "guest_dns": { - "mount": { - "/tmp/dnsmasq.d": "0", - "/var/run/dnsmasq/": "1" - } - } - } - } -}` - -const ubusDnsmasq2410 = `{ - "dnsmasq": { - "instances": { - "guest_dns": { - "mount": { - "/tmp/dnsmasq.guest_dns.d": "0", - "/var/run/dnsmasq/": "1" - } - } - } - } -}` - -func Test_dnsmasqConfPath(t *testing.T) { - var dnsmasq2410expected = filepath.Join("/tmp/dnsmasq.guest_dns.d", openwrtDNSMasqConfigName) - tests := []struct { - name string - in io.Reader - expected string - }{ - {"empty", strings.NewReader(""), openwrtDnsmasqDefaultConfigPath}, - {"invalid", strings.NewReader("}}"), openwrtDnsmasqDefaultConfigPath}, - {"before 24.10", strings.NewReader(ubusDnsmasqBefore2410), openwrtDnsmasqDefaultConfigPath}, - {"24.10", strings.NewReader(ubusDnsmasq2410), dnsmasq2410expected}, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - if got := dnsmasqConfPath(tc.in); got != tc.expected { - t.Errorf("dnsmasqConfPath() = %v, want %v", got, tc.expected) - } - }) - } -} diff --git a/internal/router/openwrt/procd.go b/internal/router/openwrt/procd.go deleted file mode 100644 index bf7253e6..00000000 --- a/internal/router/openwrt/procd.go +++ /dev/null @@ -1,25 +0,0 @@ -package openwrt - -const openWrtScript = `#!/bin/sh /etc/rc.common 
-USE_PROCD=1 -# After network starts -START=21 -# Before network stops -STOP=89 -cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}" -name="{{.Name}}" -pid_file="/var/run/${name}.pid" - -start_service() { - echo "Starting ${name}" - procd_open_instance - procd_set_param command ${cmd} - procd_set_param respawn # respawn automatically if something died - procd_set_param stdout 1 # forward stdout of the command to logd - procd_set_param stderr 1 # same for stderr - procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop - procd_set_param term_timeout 10 - procd_close_instance - echo "${name} has been started" -} -` diff --git a/internal/router/os_config_freebsd.go b/internal/router/os_config_freebsd.go deleted file mode 100644 index 9066191e..00000000 --- a/internal/router/os_config_freebsd.go +++ /dev/null @@ -1,40 +0,0 @@ -package router - -import ( - "encoding/xml" - "os" -) - -// Config represents /conf/config.xml file found on pfsense/opnsense. -type Config struct { - PfsenseUnbound *string `xml:"unbound>enable,omitempty"` - OPNsenseUnbound *string `xml:"OPNsense>unboundplus>general>enabled,omitempty"` - Dnsmasq *string `xml:"dnsmasq>enable,omitempty"` -} - -// DnsmasqEnabled reports whether dnsmasq is enabled. -func (c *Config) DnsmasqEnabled() bool { - if isPfsense() { // pfsense only set the attribute if dnsmasq is enabled. - return c.Dnsmasq != nil - } - return c.Dnsmasq != nil && *c.Dnsmasq == "1" -} - -// UnboundEnabled reports whether unbound is enabled. -func (c *Config) UnboundEnabled() bool { - if isPfsense() { // pfsense only set the attribute if unbound is enabled. - return c.PfsenseUnbound != nil - } - return c.OPNsenseUnbound != nil && *c.OPNsenseUnbound == "1" -} - -// currentConfig does unmarshalling /conf/config.xml file, -// return the corresponding *Config represent it. 
-func currentConfig() (*Config, error) { - buf, _ := os.ReadFile("/conf/config.xml") - c := Config{} - if err := xml.Unmarshal(buf, &c); err != nil { - return nil, err - } - return &c, nil -} diff --git a/internal/router/os_freebsd.go b/internal/router/os_freebsd.go deleted file mode 100644 index 9a79188f..00000000 --- a/internal/router/os_freebsd.go +++ /dev/null @@ -1,157 +0,0 @@ -package router - -import ( - "bytes" - "fmt" - "net" - "os" - "os/exec" - "path/filepath" - "text/template" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" -) - -const ( - osName = "freebsd" - rcPath = "/usr/local/etc/rc.d" - rcConfPath = "/etc/rc.conf.d/" - unboundRcPath = rcPath + "/unbound" - dnsmasqRcPath = rcPath + "/dnsmasq" -) - -func newOsRouter(cfg *ctrld.Config, cdMode bool) Router { - return &osRouter{cfg: cfg, cdMode: cdMode} -} - -type osRouter struct { - cfg *ctrld.Config - svcName string - // cdMode indicates whether the router will configure ctrld in cd mode (aka --cd=). - // When ctrld is running on freebsd-like routers, and there's process running on port 53 - // in cd mode, ctrld will attempt to kill the process and become direct listener. - // See details implemenation in osRouter.PreRun method. - cdMode bool -} - -func (or *osRouter) ConfigureService(svc *service.Config) error { - svc.Option["SysvScript"] = bsdInitScript - or.svcName = svc.Name - rcFile := filepath.Join(rcConfPath, or.svcName) - var to = &struct { - Name string - }{ - or.svcName, - } - - f, err := os.Create(rcFile) - if err != nil { - return fmt.Errorf("os.Create: %w", err) - } - defer f.Close() - if err := template.Must(template.New("").Parse(rcConfTmpl)).Execute(f, to); err != nil { - return err - } - return f.Close() -} - -func (or *osRouter) Install(_ *service.Config) error { - if isPfsense() { - // pfsense need ".sh" extension for script to be run at boot. 
- // See: https://docs.netgate.com/pfsense/en/latest/development/boot-commands.html#shell-script-option - oldname := filepath.Join(rcPath, or.svcName) - newname := filepath.Join(rcPath, or.svcName+".sh") - _ = os.Remove(newname) - if err := os.Symlink(oldname, newname); err != nil { - return fmt.Errorf("os.Symlink: %w", err) - } - } - return nil -} - -func (or *osRouter) Uninstall(_ *service.Config) error { - rcFiles := []string{filepath.Join(rcConfPath, or.svcName)} - if isPfsense() { - rcFiles = append(rcFiles, filepath.Join(rcPath, or.svcName+".sh")) - } - for _, filename := range rcFiles { - if err := os.Remove(filename); err != nil { - return fmt.Errorf("os.Remove: %w", err) - } - } - - return nil -} - -func (or *osRouter) PreRun() error { - if or.cdMode { - addr := "0.0.0.0:53" - udpLn, udpErr := net.ListenPacket("udp", addr) - if udpLn != nil { - udpLn.Close() - } - tcpLn, tcpErr := net.Listen("tcp", addr) - if tcpLn != nil { - tcpLn.Close() - } - // If we could not listen on :53 for any reason, try killing unbound/dnsmasq, become direct listener - if udpErr != nil || tcpErr != nil { - _ = exec.Command("killall", "unbound").Run() - _ = exec.Command("killall", "dnsmasq").Run() - } - } - return nil -} - -func (or *osRouter) Setup() error { - return nil -} - -func (or *osRouter) Cleanup() error { - if or.cdMode { - c, err := currentConfig() - if err != nil { - return err - } - if c.UnboundEnabled() { - _ = exec.Command(unboundRcPath, "onerestart").Run() - } - if c.DnsmasqEnabled() { - _ = exec.Command(dnsmasqRcPath, "onerestart").Run() - } - } - return nil -} - -func isPfsense() bool { - b, err := os.ReadFile("/etc/platform") - return err == nil && bytes.HasPrefix(b, []byte("pfSense")) -} - -const bsdInitScript = `#!/bin/sh - -# PROVIDE: {{.Name}} -# REQUIRE: SERVERS -# REQUIRE: unbound dnsmasq securelevel -# KEYWORD: shutdown - -. 
/etc/rc.subr - -name="{{.Name}}" -rcvar="${name}_enable" -{{.Name}}_env="IS_DAEMON=1" -pidfile="/var/run/${name}.pid" -child_pidfile="/var/run/${name}_child.pid" -command="/usr/sbin/daemon" -daemon_args="-r -P ${pidfile} -p ${child_pidfile} -t \"${name}: daemon\"{{if .WorkingDirectory}} -c {{.WorkingDirectory}}{{end}}" -command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}" - -load_rc_config "${name}" -run_rc_command "$1" -` - -var rcConfTmpl = `# {{.Name}} -{{.Name}}_enable="YES" -` diff --git a/internal/router/os_others.go b/internal/router/os_others.go deleted file mode 100644 index 52b41e4b..00000000 --- a/internal/router/os_others.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build !freebsd - -package router - -import ( - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" -) - -const osName = "" - -func newOsRouter(cfg *ctrld.Config, cdMode bool) Router { - return &osRouter{} -} - -type osRouter struct{} - -func (d *osRouter) ConfigureService(_ *service.Config) error { - return nil -} - -func (d *osRouter) Install(_ *service.Config) error { - return nil -} - -func (d *osRouter) Uninstall(_ *service.Config) error { - return nil -} - -func (d *osRouter) PreRun() error { - return nil -} - -func (d *osRouter) Setup() error { - return nil -} - -func (d *osRouter) Cleanup() error { - return nil -} diff --git a/internal/router/router.go b/internal/router/router.go deleted file mode 100644 index 2d8c462d..00000000 --- a/internal/router/router.go +++ /dev/null @@ -1,288 +0,0 @@ -package router - -import ( - "bytes" - "crypto/x509" - "net" - "os" - "os/exec" - "path/filepath" - "strings" - "sync/atomic" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/certs" - "github.com/Control-D-Inc/ctrld/internal/router/ddwrt" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/Control-D-Inc/ctrld/internal/router/edgeos" - 
"github.com/Control-D-Inc/ctrld/internal/router/firewalla" - "github.com/Control-D-Inc/ctrld/internal/router/merlin" - netgear "github.com/Control-D-Inc/ctrld/internal/router/netgear_orbi_voxel" - "github.com/Control-D-Inc/ctrld/internal/router/openwrt" - "github.com/Control-D-Inc/ctrld/internal/router/synology" - "github.com/Control-D-Inc/ctrld/internal/router/tomato" - "github.com/Control-D-Inc/ctrld/internal/router/ubios" -) - -// Service is the interface to manage ctrld service on router. -type Service interface { - // ConfigureService performs works for installing ctrla as a service on router. - ConfigureService(*service.Config) error - // Install performs necessary works after service.Install done. - Install(*service.Config) error - // Uninstall performs necessary works after service.Uninstallation done. - Uninstall(*service.Config) error -} - -// Router is the interface for managing ctrld running on router. -type Router interface { - Service - - // PreRun performs works need to be done before ctrld being run on router. - // Implementation should only return if the pre-condition was met (e.g: ntp synced). - PreRun() error - // Setup configures ctrld to be run on the router. - Setup() error - // Cleanup cleans up works setup on router by ctrld. - Cleanup() error -} - -// New returns new Router interface. -func New(cfg *ctrld.Config, cdMode bool) Router { - switch Name() { - case ddwrt.Name: - return ddwrt.New(cfg) - case merlin.Name: - return merlin.New(cfg) - case openwrt.Name: - return openwrt.New(cfg) - case edgeos.Name: - return edgeos.New(cfg) - case ubios.Name: - return ubios.New(cfg) - case synology.Name: - return synology.New(cfg) - case tomato.Name: - return tomato.New(cfg) - case firewalla.Name: - return firewalla.New(cfg) - case netgear.Name: - return netgear.New(cfg) - } - return newOsRouter(cfg, cdMode) -} - -// IsNetGearOrbi reports whether the router is a Netgear Orbi router. 
-func IsNetGearOrbi() bool { - return Name() == netgear.Name -} - -// IsGLiNet reports whether the router is an GL.iNet router. -func IsGLiNet() bool { - if Name() != openwrt.Name { - return false - } - buf, _ := os.ReadFile("/proc/version") - // The output of /proc/version contains "(glinet@glinet)". - return bytes.Contains(buf, []byte(" (glinet")) -} - -// IsOldOpenwrt reports whether the router is an "old" version of Openwrt, -// aka versions which don't have "service" command. -func IsOldOpenwrt() bool { - if Name() != openwrt.Name { - return false - } - cmd, _ := exec.LookPath("service") - return cmd == "" -} - -// WaitProcessExited reports whether the "ctrld stop" command have to wait until ctrld process exited. -func WaitProcessExited() bool { - return Name() == openwrt.Name -} - -var routerPlatform atomic.Pointer[router] - -type router struct { - name string -} - -// Name returns name of the router platform. -func Name() string { - if r := routerPlatform.Load(); r != nil { - return r.name - } - r := &router{} - r.name = distroName() - routerPlatform.Store(r) - return r.name -} - -// DefaultInterfaceName returns the default interface name of the current router. -func DefaultInterfaceName() string { - switch Name() { - case ubios.Name: - return "lo" - } - return "" -} - -// LocalResolverIP returns the IP that could be used as nameserver in /etc/resolv.conf file. -func LocalResolverIP() string { - var iface string - switch Name() { - case edgeos.Name: - // On EdgeOS, dnsmasq is run with "--local-service", so we need to get - // the proper interface from dnsmasq config. - if name, _ := dnsmasq.InterfaceNameFromConfig("/etc/dnsmasq.conf"); name != "" { - iface = name - } - case firewalla.Name: - // On Firewalla, the lo interface is excluded in all dnsmasq settings of all interfaces. - // Thus, we use "br0" as the nameserver in /etc/resolv.conf file. 
- iface = "br0" - } - if netIface, _ := net.InterfaceByName(iface); netIface != nil { - addrs, _ := netIface.Addrs() - for _, addr := range addrs { - if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.To4() != nil { - return netIP.IP.To4().String() - } - } - } - return "" -} - -// HomeDir returns the home directory of ctrld on current router. -func HomeDir() (string, error) { - switch Name() { - case ddwrt.Name, firewalla.Name, merlin.Name, netgear.Name, tomato.Name: - exe, err := os.Executable() - if err != nil { - return "", err - } - return filepath.Dir(exe), nil - case edgeos.Name: - exe, err := os.Executable() - if err != nil { - return "", err - } - // Using binary directory as home dir if it is located in /config. - // Otherwise, fallback to old behavior for compatibility. - if strings.HasPrefix(exe, "/config/") { - return filepath.Dir(exe), nil - } - } - return "", nil -} - -// CertPool returns the system certificate pool of the current router. -func CertPool() *x509.CertPool { - if Name() == ddwrt.Name { - return certs.CACertPool() - } - return nil -} - -// CanListenLocalhost reports whether the ctrld can listen on localhost with current host. -func CanListenLocalhost() bool { - switch { - case Name() == firewalla.Name: - return false - default: - return true - } -} - -// SelfInterfaces return list of *net.Interface that will be source of requests from router itself. -func SelfInterfaces() []*net.Interface { - switch Name() { - case firewalla.Name: - return dnsmasq.FirewallaSelfInterfaces() - default: - return nil - } -} - -// LeaseFilesDir is the directory which contains lease files. -func LeaseFilesDir() string { - if Name() == edgeos.Name { - edgeos.LeaseFileDir() - } - return "" -} - -// ServiceDependencies returns list of dependencies that ctrld services needs on this router. -// See https://pkg.go.dev/github.com/kardianos/service#Config for list format. 
-func ServiceDependencies() []string { - if Name() == ubios.Name { - // On Ubios, ctrld needs to start after unifi-mongodb, - // so it can query custom client info mapping. - return []string{ - "Wants=unifi-mongodb.service", - "After=unifi-mongodb.service", - } - } - return nil -} - -func distroName() string { - switch { - case bytes.HasPrefix(unameO(), []byte("DD-WRT")): - return ddwrt.Name - case bytes.HasPrefix(unameO(), []byte("ASUSWRT-Merlin")): - return merlin.Name - case haveFile("/etc/openwrt_version"): - if haveFile("/bin/config") { // TODO: is there any more reliable way? - return netgear.Name - } - return openwrt.Name - case isUbios(): - return ubios.Name - case bytes.HasPrefix(unameU(), []byte("synology")): - return synology.Name - case bytes.HasPrefix(unameO(), []byte("Tomato")): - return tomato.Name - case haveDir("/config/scripts/post-config.d"): - return edgeos.Name - case haveFile("/etc/ubnt/init/vyatta-router"): - return edgeos.Name // For 2.x - case haveFile("/etc/firewalla_release"): - return firewalla.Name - } - return osName -} - -func haveFile(file string) bool { - _, err := os.Stat(file) - return err == nil -} - -func haveDir(dir string) bool { - fi, _ := os.Stat(dir) - return fi != nil && fi.IsDir() -} - -func unameO() []byte { - out, _ := exec.Command("uname", "-o").Output() - return out -} - -func unameU() []byte { - out, _ := exec.Command("uname", "-u").Output() - return out -} - -// isUbios reports whether the current machine is running on Ubios. 
-func isUbios() bool { - if haveDir("/data/unifi") { - return true - } - if err := exec.Command("ubnt-device-info", "firmware").Run(); err == nil { - return true - } - return false -} diff --git a/internal/router/service.go b/internal/router/service.go deleted file mode 100644 index 33339646..00000000 --- a/internal/router/service.go +++ /dev/null @@ -1,96 +0,0 @@ -package router - -import ( - "bytes" - "os" - "os/exec" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld/internal/router/ddwrt" - "github.com/Control-D-Inc/ctrld/internal/router/merlin" - "github.com/Control-D-Inc/ctrld/internal/router/tomato" - "github.com/Control-D-Inc/ctrld/internal/router/ubios" -) - -func init() { - systems := []service.System{ - &linuxSystemService{ - name: "ddwrt", - detect: func() bool { return Name() == ddwrt.Name }, - interactive: func() bool { - is, _ := isInteractive() - return is - }, - new: newddwrtService, - }, - &linuxSystemService{ - name: "merlin", - detect: func() bool { return Name() == merlin.Name }, - interactive: func() bool { - is, _ := isInteractive() - return is - }, - new: newMerlinService, - }, - &linuxSystemService{ - name: "ubios", - detect: func() bool { - if Name() != ubios.Name { - return false - } - out, err := exec.Command("ubnt-device-info", "firmware").CombinedOutput() - if err == nil { - // For v2/v3, UbiOS use a Debian base with systemd, so it is not - // necessary to use custom implementation for supporting init system. - return bytes.HasPrefix(out, []byte("1.")) - } - return true - }, - interactive: func() bool { - is, _ := isInteractive() - return is - }, - new: newUbiosService, - }, - &linuxSystemService{ - name: "tomato", - detect: func() bool { return Name() == tomato.Name }, - interactive: func() bool { - is, _ := isInteractive() - return is - }, - new: newTomatoService, - }, - } - systems = append(systems, service.AvailableSystems()...) - service.ChooseSystem(systems...) 
-} - -type linuxSystemService struct { - name string - detect func() bool - interactive func() bool - new func(i service.Interface, platform string, c *service.Config) (service.Service, error) -} - -func (sc linuxSystemService) String() string { - return sc.name -} -func (sc linuxSystemService) Detect() bool { - return sc.detect() -} -func (sc linuxSystemService) Interactive() bool { - return sc.interactive() -} -func (sc linuxSystemService) New(i service.Interface, c *service.Config) (service.Service, error) { - return sc.new(i, sc.String(), c) -} - -func isInteractive() (bool, error) { - ppid := os.Getppid() - if ppid == 1 { - return false, nil - } - return true, nil -} diff --git a/internal/router/service_ddwrt.go b/internal/router/service_ddwrt.go deleted file mode 100644 index 3217f8a4..00000000 --- a/internal/router/service_ddwrt.go +++ /dev/null @@ -1,294 +0,0 @@ -package router - -import ( - "bytes" - "errors" - "fmt" - "os" - "os/exec" - "os/signal" - "strings" - "syscall" - "text/template" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -type ddwrtSvc struct { - i service.Interface - platform string - *service.Config - rcStartup string -} - -func newddwrtService(i service.Interface, platform string, c *service.Config) (service.Service, error) { - s := &ddwrtSvc{ - i: i, - platform: platform, - Config: c, - } - if err := os.MkdirAll("/jffs/etc/config", 0644); err != nil { - return nil, err - } - return s, nil -} - -func (s *ddwrtSvc) String() string { - if len(s.DisplayName) > 0 { - return s.DisplayName - } - return s.Name -} - -func (s *ddwrtSvc) Platform() string { - return s.platform -} - -func (s *ddwrtSvc) configPath() string { - return fmt.Sprintf("/jffs/etc/config/%s.startup", s.Config.Name) -} - -func (s *ddwrtSvc) template() *template.Template { - return template.Must(template.New("").Parse(ddwrtSvcScript)) -} - -func (s *ddwrtSvc) Install() error { - confPath := s.configPath() - if _, err := 
os.Stat(confPath); err == nil { - return fmt.Errorf("already installed: %s", confPath) - } - - path, err := os.Executable() - if err != nil { - return err - } - - if !strings.HasPrefix(path, "/jffs/") { - return errors.New("could not install service outside /jffs") - } - - var to = &struct { - *service.Config - Path string - }{ - s.Config, - path, - } - - f, err := os.Create(confPath) - if err != nil { - return err - } - defer f.Close() - - if err := s.template().Execute(f, to); err != nil { - return err - } - - if err = os.Chmod(confPath, 0755); err != nil { - return err - } - - var sb strings.Builder - if err := template.Must(template.New("").Parse(ddwrtStartupCmd)).Execute(&sb, to); err != nil { - return err - } - s.rcStartup = sb.String() - curVal, err := nvram.Run("get", nvram.RCStartupKey) - if err != nil { - return err - } - if _, err := nvram.Run("set", nvram.CtrldKeyPrefix+nvram.RCStartupKey+"="+curVal); err != nil { - return err - } - val := strings.Join([]string{curVal, s.rcStartup + " &", fmt.Sprintf(`echo $! 
> "/tmp/%s.pid"`, s.Config.Name)}, "\n") - - if _, err := nvram.Run("set", nvram.RCStartupKey+"="+val); err != nil { - return err - } - if out, err := nvram.Run("commit"); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - - return nil -} - -func (s *ddwrtSvc) Uninstall() error { - if err := os.Remove(s.configPath()); err != nil { - return err - } - - ctrldStartupKey := nvram.CtrldKeyPrefix + nvram.RCStartupKey - rcStartup, err := nvram.Run("get", ctrldStartupKey) - if err != nil { - return err - } - _, _ = nvram.Run("unset", ctrldStartupKey) - if _, err := nvram.Run("set", nvram.RCStartupKey+"="+rcStartup); err != nil { - return err - } - if out, err := nvram.Run("commit"); err != nil { - return fmt.Errorf("%s: %w", out, err) - } - - return nil -} - -func (s *ddwrtSvc) Logger(errs chan<- error) (service.Logger, error) { - if service.Interactive() { - return service.ConsoleLogger, nil - } - return s.SystemLogger(errs) -} - -func (s *ddwrtSvc) SystemLogger(errs chan<- error) (service.Logger, error) { - // TODO(cuonglm): detect syslog enable and return proper logger? - // this at least works with default configuration. 
- if service.Interactive() { - return service.ConsoleLogger, nil - - } - return &noopLogger{}, nil -} - -func (s *ddwrtSvc) Run() (err error) { - err = s.i.Start(s) - if err != nil { - return err - } - - if interactice, _ := isInteractive(); !interactice { - signal.Ignore(syscall.SIGHUP) - } - var sigChan = make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt) - <-sigChan - - return s.i.Stop(s) -} - -func (s *ddwrtSvc) Status() (service.Status, error) { - if _, err := os.Stat(s.configPath()); os.IsNotExist(err) { - return service.StatusUnknown, service.ErrNotInstalled - } - out, err := exec.Command(s.configPath(), "status").CombinedOutput() - if err != nil { - return service.StatusUnknown, err - } - switch string(bytes.TrimSpace(out)) { - case "running": - return service.StatusRunning, nil - default: - return service.StatusStopped, nil - } -} - -func (s *ddwrtSvc) Start() error { - return exec.Command(s.configPath(), "start").Run() -} - -func (s *ddwrtSvc) Stop() error { - return exec.Command(s.configPath(), "stop").Run() -} - -func (s *ddwrtSvc) Restart() error { - err := s.Stop() - if err != nil { - return err - } - return s.Start() -} - -type noopLogger struct { -} - -func (c noopLogger) Error(v ...interface{}) error { - return nil -} -func (c noopLogger) Warning(v ...interface{}) error { - return nil -} -func (c noopLogger) Info(v ...interface{}) error { - return nil -} -func (c noopLogger) Errorf(format string, a ...interface{}) error { - return nil -} -func (c noopLogger) Warningf(format string, a ...interface{}) error { - return nil -} -func (c noopLogger) Infof(format string, a ...interface{}) error { - return nil -} - -const ddwrtStartupCmd = `{{.Path}}{{range .Arguments}} {{.}}{{end}}` -const ddwrtSvcScript = `#!/bin/sh - -name="{{.Name}}" -cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}" -pid_file="/tmp/$name.pid" - -get_pid() { - cat "$pid_file" -} - -is_running() { - [ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) " -} - 
-case "$1" in - start) - if is_running; then - echo "Already started" - else - echo "Starting $name" - $cmd & - echo $! > "$pid_file" - chmod 600 "$pid_file" - if ! is_running; then - echo "Failed to start $name" - exit 1 - fi - fi - ;; - stop) - if is_running; then - echo -n "Stopping $name..." - kill "$(get_pid)" - for _ in 1 2 3 4 5; do - if ! is_running; then - echo "stopped" - if [ -f "$pid_file" ]; then - rm "$pid_file" - fi - exit 0 - fi - printf "." - sleep 2 - done - echo "failed to stop $name" - exit 1 - fi - exit 0 - ;; - restart) - $0 stop - $0 start - ;; - status) - if is_running; then - echo "running" - else - echo "stopped" - exit 1 - fi - ;; - *) - echo "Usage: $0 {start|stop|restart|status}" - exit 1 - ;; -esac -exit 0 -` diff --git a/internal/router/service_merlin.go b/internal/router/service_merlin.go deleted file mode 100644 index 8ab6d6a7..00000000 --- a/internal/router/service_merlin.go +++ /dev/null @@ -1,360 +0,0 @@ -package router - -import ( - "bytes" - "errors" - "fmt" - "os" - "os/exec" - "os/signal" - "path/filepath" - "strings" - "syscall" - "text/template" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -const ( - merlinJFFSScriptPath = "/jffs/scripts/services-start" - merlinJFFSServiceEventScriptPath = "/jffs/scripts/service-event" -) - -type merlinSvc struct { - i service.Interface - platform string - *service.Config -} - -func newMerlinService(i service.Interface, platform string, c *service.Config) (service.Service, error) { - s := &merlinSvc{ - i: i, - platform: platform, - Config: c, - } - return s, nil -} - -func (s *merlinSvc) String() string { - if len(s.DisplayName) > 0 { - return s.DisplayName - } - return s.Name -} - -func (s *merlinSvc) Platform() string { - return s.platform -} - -func (s *merlinSvc) configPath() string { - bin := s.Config.Executable - if bin == "" { - path, err := os.Executable() - if err != nil { - return "" - } - bin = path - } - return bin + 
".startup" -} - -func (s *merlinSvc) template() *template.Template { - return template.Must(template.New("").Parse(merlinSvcScript)) -} - -func (s *merlinSvc) Install() error { - exePath, err := os.Executable() - if err != nil { - return err - } - - if !strings.HasPrefix(exePath, "/jffs/") { - return errors.New("could not install service outside /jffs") - } - if _, err := nvram.Run("set", "jffs2_scripts=1"); err != nil { - return err - } - if _, err := nvram.Run("commit"); err != nil { - return err - } - - confPath := s.configPath() - if _, err := os.Stat(confPath); err == nil { - return fmt.Errorf("already installed: %s", confPath) - } - - var to = &struct { - *service.Config - Path string - }{ - s.Config, - exePath, - } - - f, err := os.Create(confPath) - if err != nil { - return fmt.Errorf("os.Create: %w", err) - } - defer f.Close() - - if err := s.template().Execute(f, to); err != nil { - return fmt.Errorf("s.template.Execute: %w", err) - } - - if err = os.Chmod(confPath, 0755); err != nil { - return fmt.Errorf("os.Chmod: startup script: %w", err) - } - - if err := os.MkdirAll(filepath.Dir(merlinJFFSScriptPath), 0755); err != nil { - return fmt.Errorf("os.MkdirAll: %w", err) - } - - tmpScript, err := os.CreateTemp("", "ctrld_install") - if err != nil { - return fmt.Errorf("os.CreateTemp: %w", err) - } - defer os.Remove(tmpScript.Name()) - defer tmpScript.Close() - - if _, err := tmpScript.WriteString(merlinAddLineToScript); err != nil { - return fmt.Errorf("tmpScript.WriteString: %w", err) - } - if err := tmpScript.Close(); err != nil { - return fmt.Errorf("tmpScript.Close: %w", err) - } - addLineToScript := func(line, script string) error { - if _, err := os.Stat(script); os.IsNotExist(err) { - if err := os.WriteFile(script, []byte("#!/bin/sh\n"), 0755); err != nil { - return err - } - } - if err := os.Chmod(script, 0755); err != nil { - return fmt.Errorf("os.Chmod: jffs script: %w", err) - } - - if err := exec.Command("sh", tmpScript.Name(), line, 
script).Run(); err != nil { - return fmt.Errorf("exec.Command: add startup script: %w", err) - } - return nil - } - - for script, line := range map[string]string{ - merlinJFFSScriptPath: s.configPath() + " start", - merlinJFFSServiceEventScriptPath: s.configPath() + ` service_event "$1" "$2"`, - } { - if err := addLineToScript(line, script); err != nil { - return err - } - } - - return nil -} - -func (s *merlinSvc) Uninstall() error { - if err := os.Remove(s.configPath()); err != nil { - return fmt.Errorf("os.Remove: %w", err) - } - tmpScript, err := os.CreateTemp("", "ctrld_uninstall") - if err != nil { - return fmt.Errorf("os.CreateTemp: %w", err) - } - defer os.Remove(tmpScript.Name()) - defer tmpScript.Close() - - if _, err := tmpScript.WriteString(merlinRemoveLineFromScript); err != nil { - return fmt.Errorf("tmpScript.WriteString: %w", err) - } - if err := tmpScript.Close(); err != nil { - return fmt.Errorf("tmpScript.Close: %w", err) - } - removeLineFromScript := func(line, script string) error { - if _, err := os.Stat(script); os.IsNotExist(err) { - if err := os.WriteFile(script, []byte("#!/bin/sh\n"), 0755); err != nil { - return err - } - } - if err := os.Chmod(script, 0755); err != nil { - return fmt.Errorf("os.Chmod: jffs script: %w", err) - } - - if err := exec.Command("sh", tmpScript.Name(), line, script).Run(); err != nil { - return fmt.Errorf("exec.Command: add startup script: %w", err) - } - return nil - } - - for script, line := range map[string]string{ - merlinJFFSScriptPath: s.configPath() + " start", - merlinJFFSServiceEventScriptPath: s.configPath() + ` service_event "$1" "$2"`, - } { - if err := removeLineFromScript(line, script); err != nil { - return err - } - } - - return nil -} - -func (s *merlinSvc) Logger(errs chan<- error) (service.Logger, error) { - if service.Interactive() { - return service.ConsoleLogger, nil - } - return s.SystemLogger(errs) -} - -func (s *merlinSvc) SystemLogger(errs chan<- error) (service.Logger, error) { - 
return newSysLogger(s.Name, errs) -} - -func (s *merlinSvc) Run() (err error) { - err = s.i.Start(s) - if err != nil { - return err - } - - if interactice, _ := isInteractive(); !interactice { - signal.Ignore(syscall.SIGHUP) - } - - var sigChan = make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt) - <-sigChan - - return s.i.Stop(s) -} - -func (s *merlinSvc) Status() (service.Status, error) { - if _, err := os.Stat(s.configPath()); os.IsNotExist(err) { - return service.StatusUnknown, service.ErrNotInstalled - } - out, err := exec.Command(s.configPath(), "status").CombinedOutput() - if err != nil { - return service.StatusUnknown, err - } - switch string(bytes.TrimSpace(out)) { - case "running": - return service.StatusRunning, nil - default: - return service.StatusStopped, nil - } -} - -func (s *merlinSvc) Start() error { - return exec.Command(s.configPath(), "start").Run() -} - -func (s *merlinSvc) Stop() error { - return exec.Command(s.configPath(), "stop").Run() -} - -func (s *merlinSvc) Restart() error { - err := s.Stop() - if err != nil { - return err - } - return s.Start() -} - -const merlinSvcScript = `#!/bin/sh - -name="{{.Name}}" -cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}" -pid_file="/tmp/$name.pid" - -get_pid() { - cat "$pid_file" -} - -is_running() { - [ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) " -} - -case "$1" in - start) - if is_running; then - logger -c "Already started" - else - logger -c "Starting $name" - if [ -f /rom/ca-bundle.crt ]; then - # For John’s fork - export SSL_CERT_FILE=/rom/ca-bundle.crt - fi - $cmd & - echo $! > "$pid_file" - chmod 600 "$pid_file" - if ! is_running; then - logger -c "Failed to start $name" - exit 1 - fi - fi - ;; - stop) - if is_running; then - logger -c "Stopping $name..." - kill "$(get_pid)" - for _ in 1 2 3 4 5; do - if ! is_running; then - logger -c "stopped" - if [ -f "$pid_file" ]; then - rm "$pid_file" - fi - exit 0 - fi - printf "." 
- sleep 2 - done - logger -c "failed to stop $name" - exit 1 - fi - exit 0 - ;; - restart) - $0 stop - $0 start - ;; - status) - if is_running; then - echo "running" - else - echo "stopped" - exit 1 - fi - ;; - service_event) - event=$2 - svc=$3 - dnsmasq_pid_file=$(sed -n '/pid-file=/s///p' /etc/dnsmasq.conf) - - if [ "$event" = "restart" ] && [ "$svc" = "diskmon" ]; then - kill "$(cat "$dnsmasq_pid_file")" >/dev/null 2>&1 - fi - ;; - *) - echo "Usage: $0 {start|stop|restart|status}" - exit 1 - ;; -esac -exit 0 -` - -const merlinAddLineToScript = `#!/bin/sh - -line=$1 -file=$2 - -. /usr/sbin/helper.sh - -pc_append "$line" "$file" -` - -const merlinRemoveLineFromScript = `#!/bin/sh - -line=$1 -file=$2 - -. /usr/sbin/helper.sh - -pc_delete "$line" "$file" -` diff --git a/internal/router/service_tomato.go b/internal/router/service_tomato.go deleted file mode 100644 index 2cf59391..00000000 --- a/internal/router/service_tomato.go +++ /dev/null @@ -1,289 +0,0 @@ -package router - -import ( - "bytes" - "errors" - "fmt" - "os" - "os/exec" - "os/signal" - "strings" - "syscall" - "text/template" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld/internal/router/nvram" -) - -const tomatoNvramScriptWanupKey = "script_wanup" - -type tomatoSvc struct { - i service.Interface - platform string - *service.Config -} - -func newTomatoService(i service.Interface, platform string, c *service.Config) (service.Service, error) { - s := &tomatoSvc{ - i: i, - platform: platform, - Config: c, - } - return s, nil -} - -func (s *tomatoSvc) String() string { - if len(s.DisplayName) > 0 { - return s.DisplayName - } - return s.Name -} - -func (s *tomatoSvc) Platform() string { - return s.platform -} - -func (s *tomatoSvc) configPath() string { - bin := s.Config.Executable - if bin == "" { - path, err := os.Executable() - if err != nil { - return "" - } - bin = path - } - return bin + ".startup" -} - -func (s *tomatoSvc) template() *template.Template { - return 
template.Must(template.New("").Parse(tomatoSvcScript)) -} - -func (s *tomatoSvc) Install() error { - exePath, err := os.Executable() - if err != nil { - return err - } - - if !strings.HasPrefix(exePath, "/jffs/") { - return errors.New("could not install service outside /jffs") - } - if _, err := nvram.Run("set", "jffs2_on=1"); err != nil { - return err - } - if _, err := nvram.Run("commit"); err != nil { - return err - } - - confPath := s.configPath() - if _, err := os.Stat(confPath); err == nil { - return fmt.Errorf("already installed: %s", confPath) - } - - var to = &struct { - *service.Config - Path string - }{ - s.Config, - exePath, - } - - f, err := os.Create(confPath) - if err != nil { - return fmt.Errorf("os.Create: %w", err) - } - defer f.Close() - - if err := s.template().Execute(f, to); err != nil { - return fmt.Errorf("s.template.Execute: %w", err) - } - - if err = os.Chmod(confPath, 0755); err != nil { - return fmt.Errorf("os.Chmod: startup script: %w", err) - } - - nvramKvMap := map[string]string{ - tomatoNvramScriptWanupKey: "", // script to start ctrld, filled by tomatoSvc.Install method. - } - old, err := nvram.Run("get", tomatoNvramScriptWanupKey) - if err != nil { - return fmt.Errorf("nvram: %w", err) - } - nvramKvMap[tomatoNvramScriptWanupKey] = strings.Join([]string{old, s.configPath() + " start"}, "\n") - if err := nvram.SetKV(nvramKvMap, nvram.CtrldInstallKey); err != nil { - return err - } - return nil -} - -func (s *tomatoSvc) Uninstall() error { - if err := os.Remove(s.configPath()); err != nil { - return fmt.Errorf("os.Remove: %w", err) - } - nvramKvMap := map[string]string{ - tomatoNvramScriptWanupKey: "", // script to start ctrld, filled by tomatoSvc.Install method. - } - // Restore old configs. 
- if err := nvram.Restore(nvramKvMap, nvram.CtrldInstallKey); err != nil { - return err - } - return nil -} - -func (s *tomatoSvc) Logger(errs chan<- error) (service.Logger, error) { - if service.Interactive() { - return service.ConsoleLogger, nil - } - return s.SystemLogger(errs) -} - -func (s *tomatoSvc) SystemLogger(errs chan<- error) (service.Logger, error) { - return newSysLogger(s.Name, errs) -} - -func (s *tomatoSvc) Run() (err error) { - err = s.i.Start(s) - if err != nil { - return err - } - - if interactice, _ := isInteractive(); !interactice { - signal.Ignore(syscall.SIGHUP) - } - - var sigChan = make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt) - <-sigChan - - return s.i.Stop(s) -} - -func (s *tomatoSvc) Status() (service.Status, error) { - if _, err := os.Stat(s.configPath()); os.IsNotExist(err) { - return service.StatusUnknown, service.ErrNotInstalled - } - out, err := exec.Command(s.configPath(), "status").CombinedOutput() - if err != nil { - return service.StatusUnknown, err - } - switch string(bytes.TrimSpace(out)) { - case "running": - return service.StatusRunning, nil - default: - return service.StatusStopped, nil - } -} - -func (s *tomatoSvc) Start() error { - return exec.Command(s.configPath(), "start").Run() -} - -func (s *tomatoSvc) Stop() error { - return exec.Command(s.configPath(), "stop").Run() -} - -func (s *tomatoSvc) Restart() error { - return exec.Command(s.configPath(), "restart").Run() -} - -// https://wiki.freshtomato.org/doku.php/freshtomato_zerotier?s[]=%2Aservice%2A -const tomatoSvcScript = `#!/bin/sh - - -NAME="{{.Name}}" -CMD="{{.Path}}{{range .Arguments}} {{.}}{{end}}" -LOG_FILE="/var/log/${NAME}.log" -PID_FILE="/tmp/$NAME.pid" - - -alias elog="logger -t $NAME -s" - - -COND=$1 -[ $# -eq 0 ] && COND="start" - -get_pid() { - cat "$PID_FILE" -} - -is_running() { - [ -f "$PID_FILE" ] && ps | grep -q "^ *$(get_pid) " -} - -start() { - if is_running; then - elog "$NAME is already running." 
- exit 1 - fi - elog "Starting $NAME Services: " - $CMD & - echo $! > "$PID_FILE" - chmod 600 "$PID_FILE" - if is_running; then - elog "succeeded." - else - elog "failed." - fi -} - - -stop() { - if ! is_running; then - elog "$NAME is not running." - exit 0 - fi - elog "Shutting down $NAME Services: " - kill -SIGTERM "$(get_pid)" - for _ in 1 2 3 4 5; do - if ! is_running; then - if [ -f "$pid_file" ]; then - rm "$pid_file" - fi - return 0 - fi - printf "." - sleep 2 - done - if ! is_running; then - elog "succeeded." - else - elog "failed." - fi -} - - -do_restart() { - stop - start -} - - -do_status() { - if ! is_running; then - echo "stopped" - else - echo "running" - fi -} - - -case "$COND" in -start) - start - ;; -stop) - stop - ;; -restart) - do_restart - ;; -status) - do_status - ;; -*) - elog "Usage: $0 (start|stop|restart|status)" - ;; -esac -exit 0 -` diff --git a/internal/router/service_ubios.go b/internal/router/service_ubios.go deleted file mode 100644 index 9ad971d2..00000000 --- a/internal/router/service_ubios.go +++ /dev/null @@ -1,340 +0,0 @@ -package router - -import ( - "bytes" - "fmt" - "os" - "os/exec" - "os/signal" - "path/filepath" - "strings" - "syscall" - "text/template" - "time" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" -) - -// This is a copy of https://github.com/kardianos/service/blob/v1.2.1/service_sysv_linux.go, -// with modification for supporting ubios v1 init system. 
- -type ubiosSvc struct { - i service.Interface - platform string - *service.Config -} - -func newUbiosService(i service.Interface, platform string, c *service.Config) (service.Service, error) { - s := &ubiosSvc{ - i: i, - platform: platform, - Config: c, - } - return s, nil -} - -func (s *ubiosSvc) String() string { - if len(s.DisplayName) > 0 { - return s.DisplayName - } - return s.Name -} - -func (s *ubiosSvc) Platform() string { - return s.platform -} - -func (s *ubiosSvc) configPath() string { - return "/etc/init.d/" + s.Config.Name -} - -func (s *ubiosSvc) execPath() (string, error) { - if len(s.Executable) != 0 { - return filepath.Abs(s.Executable) - } - return os.Executable() -} - -func (s *ubiosSvc) template() *template.Template { - return template.Must(template.New("").Funcs(tf).Parse(ubiosSvcScript)) -} - -func (s *ubiosSvc) Install() error { - confPath := s.configPath() - if _, err := os.Stat(confPath); err == nil { - return fmt.Errorf("init already exists: %s", confPath) - } - - f, err := os.Create(confPath) - if err != nil { - return fmt.Errorf("failed to create config path: %w", err) - } - defer f.Close() - - path, err := s.execPath() - if err != nil { - return fmt.Errorf("failed to get exec path: %w", err) - } - - var to = &struct { - *service.Config - Path string - DnsMasqConfPath string - }{ - s.Config, - path, - filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName), - } - - if err := s.template().Execute(f, to); err != nil { - return fmt.Errorf("failed to create init script: %w", err) - } - - if err := f.Close(); err != nil { - return fmt.Errorf("failed to save init script: %w", err) - } - - if err = os.Chmod(confPath, 0755); err != nil { - return fmt.Errorf("failed to set init script executable: %w", err) - } - - // Enable on boot - script, err := os.CreateTemp("", "ctrld_boot.service") - if err != nil { - return fmt.Errorf("failed to create boot service tmp file: %w", err) - } - defer script.Close() - - svcConfig := *to.Config - 
svcConfig.Arguments = os.Args[1:] - to.Config = &svcConfig - if err := template.Must(template.New("").Funcs(tf).Parse(ubiosBootSystemdService)).Execute(script, &to); err != nil { - return fmt.Errorf("failed to create boot service file: %w", err) - } - if err := script.Close(); err != nil { - return fmt.Errorf("failed to save boot service file: %w", err) - } - - // Copy the boot script to container and start. - cmd := exec.Command("podman", "cp", "--pause=false", script.Name(), "unifi-os:/lib/systemd/system/ctrld-boot.service") - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to copy boot script, out: %s, err: %v", string(out), err) - } - cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "enable", "--now", "ctrld-boot.service") - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to start ctrld boot script, out: %s, err: %v", string(out), err) - } - return nil -} - -func (s *ubiosSvc) Uninstall() error { - if err := os.Remove(s.configPath()); err != nil { - return err - } - // Remove ctrld-boot service inside unifi-os container. 
- cmd := exec.Command("podman", "exec", "unifi-os", "systemctl", "disable", "ctrld-boot.service") - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to disable ctrld-boot service, out: %s, err: %v", string(out), err) - } - cmd = exec.Command("podman", "exec", "unifi-os", "rm", "/lib/systemd/system/ctrld-boot.service") - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to remove ctrld-boot service file, out: %s, err: %v", string(out), err) - } - cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "daemon-reload") - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to reload systemd service, out: %s, err: %v", string(out), err) - } - cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "reset-failed") - if out, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("failed to reset-failed systemd service, out: %s, err: %v", string(out), err) - } - return nil -} - -func (s *ubiosSvc) Logger(errs chan<- error) (service.Logger, error) { - if service.Interactive() { - return service.ConsoleLogger, nil - } - return s.SystemLogger(errs) -} - -func (s *ubiosSvc) SystemLogger(errs chan<- error) (service.Logger, error) { - return newSysLogger(s.Name, errs) -} - -func (s *ubiosSvc) Run() (err error) { - err = s.i.Start(s) - if err != nil { - return err - } - - if interactice, _ := isInteractive(); !interactice { - signal.Ignore(syscall.SIGHUP) - } - - var sigChan = make(chan os.Signal, 3) - signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt) - <-sigChan - - return s.i.Stop(s) -} - -func (s *ubiosSvc) Status() (service.Status, error) { - if _, err := os.Stat(s.configPath()); os.IsNotExist(err) { - return service.StatusUnknown, service.ErrNotInstalled - } - out, err := exec.Command(s.configPath(), "status").CombinedOutput() - if err != nil { - return service.StatusUnknown, err - } - switch string(bytes.TrimSpace(out)) { - case "Running": - return 
service.StatusRunning, nil - default: - return service.StatusStopped, nil - } -} - -func (s *ubiosSvc) Start() error { - return exec.Command(s.configPath(), "start").Run() -} - -func (s *ubiosSvc) Stop() error { - return exec.Command(s.configPath(), "stop").Run() -} - -func (s *ubiosSvc) Restart() error { - err := s.Stop() - if err != nil { - return err - } - time.Sleep(50 * time.Millisecond) - return s.Start() -} - -const ubiosBootSystemdService = `[Unit] -Description=Run ctrld On Startup UDM -Wants=network-online.target -After=network-online.target -Wants=unifi-mongodb -After=unifi-mongodb -StartLimitIntervalSec=500 -StartLimitBurst=5 - -[Service] -Restart=on-failure -RestartSec=5s -ExecStart=/sbin/ssh-proxy '[ -f "{{.DnsMasqConfPath}}" ] || {{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}' -RemainAfterExit=true -[Install] -WantedBy=multi-user.target -` - -const ubiosSvcScript = `#!/bin/sh -# For RedHat and cousins: -# chkconfig: - 99 01 -# description: {{.Description}} -# processname: {{.Path}} - -### BEGIN INIT INFO -# Provides: {{.Path}} -# Required-Start: -# Required-Stop: -# Default-Start: 2 3 4 5 -# Default-Stop: 0 1 6 -# Short-Description: {{.DisplayName}} -# Description: {{.Description}} -### END INIT INFO - -cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}" - -name=$(basename $(readlink -f $0)) -pid_file="/var/run/$name.pid" -stdout_log="/var/log/$name.log" -stderr_log="/var/log/$name.err" - -[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name - -get_pid() { - cat "$pid_file" -} - -is_running() { - [ -f "$pid_file" ] && cat /proc/$(get_pid)/stat > /dev/null 2>&1 -} - -case "$1" in - start) - if is_running; then - echo "Already started" - else - echo "Starting $name" - {{if .WorkingDirectory}}cd '{{.WorkingDirectory}}'{{end}} - $cmd >> "$stdout_log" 2>> "$stderr_log" & - echo $! > "$pid_file" - if ! 
is_running; then - echo "Unable to start, see $stdout_log and $stderr_log" - exit 1 - fi - fi - ;; - stop) - if is_running; then - echo -n "Stopping $name.." - kill $(get_pid) - for i in $(seq 1 10) - do - if ! is_running; then - break - fi - echo -n "." - sleep 1 - done - echo - if is_running; then - echo "Not stopped; may still be shutting down or shutdown may have failed" - exit 1 - else - echo "Stopped" - if [ -f "$pid_file" ]; then - rm "$pid_file" - fi - fi - else - echo "Not running" - fi - ;; - restart) - $0 stop - if is_running; then - echo "Unable to stop, will not attempt to start" - exit 1 - fi - $0 start - ;; - status) - if is_running; then - echo "Running" - else - echo "Stopped" - exit 1 - fi - ;; - *) - echo "Usage: $0 {start|stop|restart|status}" - exit 1 - ;; -esac -exit 0 -` - -var tf = map[string]interface{}{ - "cmd": func(s string) string { - return `"` + strings.Replace(s, `"`, `\"`, -1) + `"` - }, - "cmdEscape": func(s string) string { - return strings.Replace(s, " ", `\x20`, -1) - }, -} diff --git a/internal/router/synology/synology.go b/internal/router/synology/synology.go deleted file mode 100644 index 79339430..00000000 --- a/internal/router/synology/synology.go +++ /dev/null @@ -1,125 +0,0 @@ -package synology - -import ( - "bytes" - "context" - "errors" - "fmt" - "os" - "os/exec" - "strings" - "time" - - "github.com/kardianos/service" - "tailscale.com/logtail/backoff" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/Control-D-Inc/ctrld/internal/router/ntp" -) - -const ( - Name = "synology" - - synologyDNSMasqConfigPath = "/etc/dhcpd/dhcpd-zzz-ctrld.conf" - synologyDhcpdInfoPath = "/etc/dhcpd/dhcpd-zzz-ctrld.info" -) - -type Synology struct { - cfg *ctrld.Config - useUpstart bool -} - -// New returns a router.Router for configuring/setup/run ctrld on Ubios routers. 
-func New(cfg *ctrld.Config) *Synology { - return &Synology{ - cfg: cfg, - useUpstart: service.Platform() == "linux-upstart", - } -} - -func (s *Synology) ConfigureService(svc *service.Config) error { - svc.Option["LogOutput"] = true - return nil -} - -func (s *Synology) Install(_ *service.Config) error { - return nil -} - -func (s *Synology) Uninstall(_ *service.Config) error { - return nil -} - -func (s *Synology) PreRun() error { - if s.useUpstart { - if err := ntp.WaitUpstart(); err != nil { - return err - } - return waitDhcpServer() - } - return nil -} - -func (s *Synology) Setup() error { - if s.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, s.cfg) - if err != nil { - return err - } - if err := os.WriteFile(synologyDNSMasqConfigPath, []byte(data), 0600); err != nil { - return err - } - if err := os.WriteFile(synologyDhcpdInfoPath, []byte(`enable="yes"`), 0600); err != nil { - return err - } - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func (s *Synology) Cleanup() error { - if s.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Remove the custom config files. - for _, f := range []string{synologyDNSMasqConfigPath, synologyDhcpdInfoPath} { - if err := os.Remove(f); err != nil { - return err - } - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func restartDNSMasq() error { - if out, err := exec.Command("/etc/rc.network", "nat-restart-dhcp").CombinedOutput(); err != nil { - return fmt.Errorf("synologyRestartDNSMasq: %s - %w", string(out), err) - } - return nil -} - -func waitDhcpServer() error { - // Wait until `initctl status dhcpserver` returns running state. 
- b := backoff.NewBackoff("waitDhcpServer", func(format string, args ...any) {}, 10*time.Second) - for { - out, err := exec.Command("initctl", "status", "dhcpserver").CombinedOutput() - if err != nil { - if strings.Contains(err.Error(), "Unknown job") { - // dhcpserver service does not exist. - return nil - } - return fmt.Errorf("exec.Command: %w", err) - } - if bytes.Contains(out, []byte("start/running")) { - return nil - } - b.BackOff(context.Background(), errors.New("ntp not ready")) - } -} diff --git a/internal/router/syslog.go b/internal/router/syslog.go deleted file mode 100644 index 008bbeb7..00000000 --- a/internal/router/syslog.go +++ /dev/null @@ -1,49 +0,0 @@ -//go:build linux || darwin || freebsd - -package router - -import ( - "fmt" - "log/syslog" - - "github.com/kardianos/service" -) - -func newSysLogger(name string, errs chan<- error) (service.Logger, error) { - w, err := syslog.New(syslog.LOG_INFO, name) - if err != nil { - return nil, err - } - return sysLogger{w, errs}, nil -} - -type sysLogger struct { - *syslog.Writer - errs chan<- error -} - -func (s sysLogger) send(err error) error { - if err != nil && s.errs != nil { - s.errs <- err - } - return err -} - -func (s sysLogger) Error(v ...interface{}) error { - return s.send(s.Writer.Err(fmt.Sprint(v...))) -} -func (s sysLogger) Warning(v ...interface{}) error { - return s.send(s.Writer.Warning(fmt.Sprint(v...))) -} -func (s sysLogger) Info(v ...interface{}) error { - return s.send(s.Writer.Info(fmt.Sprint(v...))) -} -func (s sysLogger) Errorf(format string, a ...interface{}) error { - return s.send(s.Writer.Err(fmt.Sprintf(format, a...))) -} -func (s sysLogger) Warningf(format string, a ...interface{}) error { - return s.send(s.Writer.Warning(fmt.Sprintf(format, a...))) -} -func (s sysLogger) Infof(format string, a ...interface{}) error { - return s.send(s.Writer.Info(fmt.Sprintf(format, a...))) -} diff --git a/internal/router/syslog_windows.go b/internal/router/syslog_windows.go deleted file 
mode 100644 index ecac969f..00000000 --- a/internal/router/syslog_windows.go +++ /dev/null @@ -1,7 +0,0 @@ -package router - -import "github.com/kardianos/service" - -func newSysLogger(name string, errs chan<- error) (service.Logger, error) { - return service.ConsoleLogger, nil -} diff --git a/internal/router/tomato/tomato.go b/internal/router/tomato/tomato.go deleted file mode 100644 index ee5f09b8..00000000 --- a/internal/router/tomato/tomato.go +++ /dev/null @@ -1,133 +0,0 @@ -package tomato - -import ( - "fmt" - "os/exec" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/Control-D-Inc/ctrld/internal/router/ntp" - "github.com/Control-D-Inc/ctrld/internal/router/nvram" - "github.com/kardianos/service" -) - -const ( - Name = "freshtomato" - - tomatoDnsCryptProxySvcName = "dnscrypt-proxy" - tomatoStubbySvcName = "stubby" - tomatoDNSMasqSvcName = "dnsmasq" -) - -var nvramKvMap = map[string]string{ - "dnsmasq_custom": "", // Configuration of dnsmasq set by ctrld, filled by setupTomato. - "dnscrypt_proxy": "0", // Disable DNSCrypt. - "dnssec_enable": "0", // Disable DNSSEC. - "stubby_proxy": "0", // Disable Stubby -} - -type FreshTomato struct { - cfg *ctrld.Config -} - -// New returns a router.Router for configuring/setup/run ctrld on Ubios routers. -func New(cfg *ctrld.Config) *FreshTomato { - return &FreshTomato{cfg: cfg} -} - -func (f *FreshTomato) ConfigureService(config *service.Config) error { - return nil -} - -func (f *FreshTomato) Install(_ *service.Config) error { - return nil -} - -func (f *FreshTomato) Uninstall(_ *service.Config) error { - return nil -} - -func (f *FreshTomato) PreRun() error { - _ = f.Cleanup() - return ntp.WaitNvram() -} - -func (f *FreshTomato) Setup() error { - if f.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Already setup. 
- if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" { - return nil - } - - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, f.cfg) - if err != nil { - return err - } - nvramKvMap["dnsmasq_custom"] = data - if err := nvram.SetKV(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - // Restart dnscrypt-proxy service. - if err := tomatoRestartServiceWithKill(tomatoDnsCryptProxySvcName, true); err != nil { - return err - } - // Restart stubby service. - if err := tomatoRestartService(tomatoStubbySvcName); err != nil { - return err - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func (f *FreshTomato) Cleanup() error { - if f.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val != "1" { - return nil // was restored, nothing to do. - } - - nvramKvMap["dnsmasq_custom"] = "" - // Restore old configs. - if err := nvram.Restore(nvramKvMap, nvram.CtrldSetupKey); err != nil { - return err - } - - // Restart dnscrypt-proxy service. - if err := tomatoRestartServiceWithKill(tomatoDnsCryptProxySvcName, true); err != nil { - return err - } - // Restart stubby service. - if err := tomatoRestartService(tomatoStubbySvcName); err != nil { - return err - } - // Restart dnsmasq service. 
- if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func tomatoRestartService(name string) error { - return tomatoRestartServiceWithKill(name, false) -} - -func tomatoRestartServiceWithKill(name string, killBeforeRestart bool) error { - if killBeforeRestart { - _, _ = exec.Command("killall", name).CombinedOutput() - } - if out, err := exec.Command("service", name, "restart").CombinedOutput(); err != nil { - return fmt.Errorf("service restart %s: %s, %w", name, string(out), err) - } - return nil -} - -func restartDNSMasq() error { - return tomatoRestartService(tomatoDNSMasqSvcName) -} diff --git a/internal/router/ubios/ubios.go b/internal/router/ubios/ubios.go deleted file mode 100644 index cba68426..00000000 --- a/internal/router/ubios/ubios.go +++ /dev/null @@ -1,102 +0,0 @@ -package ubios - -import ( - "bytes" - "os" - "path/filepath" - "strconv" - - "github.com/kardianos/service" - - "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" - "github.com/Control-D-Inc/ctrld/internal/router/edgeos" -) - -const Name = "ubios" - -type Ubios struct { - cfg *ctrld.Config - dnsmasqConfPath string -} - -// New returns a router.Router for configuring/setup/run ctrld on Ubios routers. -func New(cfg *ctrld.Config) *Ubios { - return &Ubios{ - cfg: cfg, - dnsmasqConfPath: filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName), - } -} - -func (u *Ubios) ConfigureService(config *service.Config) error { - return nil -} - -func (u *Ubios) Install(config *service.Config) error { - // See comment in (*edgeos.EdgeOS).Install method. - if edgeos.ContentFilteringEnabled() { - return edgeos.ErrContentFilteringEnabled - } - // See comment in (*edgeos.EdgeOS).Install method. 
- if edgeos.DnsShieldEnabled() { - return edgeos.ErrDnsShieldEnabled - } - return nil -} - -func (u *Ubios) Uninstall(_ *service.Config) error { - return nil -} - -func (u *Ubios) PreRun() error { - return nil -} - -func (u *Ubios) Setup() error { - if u.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - data, err := dnsmasq.ConfTmplWithCacheDisabled(dnsmasq.ConfigContentTmpl, u.cfg, false) - if err != nil { - return err - } - if err := os.WriteFile(u.dnsmasqConfPath, []byte(data), 0600); err != nil { - return err - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func (u *Ubios) Cleanup() error { - if u.cfg.FirstListener().IsDirectDnsListener() { - return nil - } - // Remove the custom dnsmasq config - if err := os.Remove(u.dnsmasqConfPath); err != nil { - return err - } - // Restart dnsmasq service. - if err := restartDNSMasq(); err != nil { - return err - } - return nil -} - -func restartDNSMasq() error { - buf, err := os.ReadFile(dnsmasq.UbiosPidFile()) - if err != nil { - return err - } - pid, err := strconv.ParseUint(string(bytes.TrimSpace(buf)), 10, 64) - if err != nil { - return err - } - proc, err := os.FindProcess(int(pid)) - if err != nil { - return err - } - return proc.Kill() -} diff --git a/internal/rulematcher/domain.go b/internal/rulematcher/domain.go new file mode 100644 index 00000000..e70ea583 --- /dev/null +++ b/internal/rulematcher/domain.go @@ -0,0 +1,31 @@ +package rulematcher + +import ( + "context" +) + +// DomainRuleMatcher handles matching of domain-based rules +type DomainRuleMatcher struct{} + +// Match evaluates domain rules against the requested domain +func (d *DomainRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult { + if req.Policy == nil || len(req.Policy.Rules) == 0 { + return &MatchResult{Matched: false, RuleType: RuleTypeDomain} + } + + for _, rule := range req.Policy.Rules { + // There's only one entry per rule, config validation 
ensures this. + for source, targets := range rule { + if source == req.Domain || wildcardMatches(source, req.Domain) { + return &MatchResult{ + Matched: true, + Targets: targets, + MatchedRule: source, + RuleType: RuleTypeDomain, + } + } + } + } + + return &MatchResult{Matched: false, RuleType: RuleTypeDomain} +} diff --git a/internal/rulematcher/engine.go b/internal/rulematcher/engine.go new file mode 100644 index 00000000..8a5b9513 --- /dev/null +++ b/internal/rulematcher/engine.go @@ -0,0 +1,148 @@ +// Package rulematcher provides a flexible rule matching engine for DNS request routing. +// +// The rulematcher package implements a policy-based DNS routing system that allows +// configuring different types of rules to determine which upstream DNS servers should +// handle specific requests. It supports three types of rules: +// +// - Network rules: Match requests based on source IP address ranges +// - MAC rules: Match requests based on source MAC addresses +// - Domain rules: Match requests based on requested domain names +// +// The matching engine uses a configurable priority order to determine which rules +// take precedence when multiple rules match. By default, the priority order is: +// Network -> MAC -> Domain, with Domain rules having the highest priority and +// overriding all other matches. +// +// Example usage: +// +// config := &MatchingConfig{ +// Order: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, +// } +// engine := NewMatchingEngine(config) +// +// request := &MatchRequest{ +// SourceIP: net.ParseIP("192.168.1.100"), +// SourceMac: "aa:bb:cc:dd:ee:ff", +// Domain: "example.com", +// Policy: policyConfig, +// Config: appConfig, +// } +// +// result := engine.FindUpstreams(ctx, request) +// if result.Matched { +// // Use result.Upstreams to route the request +// } +// +// The package maintains backward compatibility with existing behavior while +// providing a clean, extensible interface for adding new rule types. 
+package rulematcher + +import ( + "context" +) + +// MatchingEngine orchestrates rule matching based on configurable order +type MatchingEngine struct { + config *MatchingConfig + matchers map[RuleType]RuleMatcher +} + +// NewMatchingEngine creates a new matching engine with the given configuration +func NewMatchingEngine(config *MatchingConfig) *MatchingEngine { + if config == nil { + config = DefaultMatchingConfig() + } + + engine := &MatchingEngine{ + config: config, + matchers: map[RuleType]RuleMatcher{ + RuleTypeNetwork: &NetworkRuleMatcher{}, + RuleTypeMac: &MacRuleMatcher{}, + RuleTypeDomain: &DomainRuleMatcher{}, + }, + } + + return engine +} + +// FindUpstreams determines which upstreams should handle a request based on policy rules +// It implements the original behavior where MAC and domain rules can override network rules +func (e *MatchingEngine) FindUpstreams(ctx context.Context, req *MatchRequest) *MatchingResult { + result := &MatchingResult{ + Upstreams: []string{}, + MatchedPolicy: "no policy", + MatchedNetwork: "no network", + MatchedRule: "no rule", + Matched: false, + SrcAddr: req.SourceIP.String(), + MatchedRuleType: "", + MatchingOrder: e.config.Order, + } + + if req.Policy == nil { + return result + } + + result.MatchedPolicy = req.Policy.Name + + var networkMatch *MatchResult + var macMatch *MatchResult + var domainMatch *MatchResult + + // Check all rule types and store matches + for _, ruleType := range e.config.Order { + matcher, exists := e.matchers[ruleType] + if !exists { + continue + } + + matchResult := matcher.Match(ctx, req) + if matchResult.Matched { + switch matchResult.RuleType { + case RuleTypeNetwork: + networkMatch = matchResult + case RuleTypeMac: + macMatch = matchResult + case RuleTypeDomain: + domainMatch = matchResult + } + } + } + + // Determine the final match based on original logic: + // Domain rules override everything, MAC rules override network rules + if domainMatch != nil { + result.Upstreams = 
domainMatch.Targets + result.Matched = true + result.MatchedRuleType = string(domainMatch.RuleType) + result.MatchedRule = domainMatch.MatchedRule + // Special case: domain rules override network rules + if networkMatch != nil { + result.MatchedNetwork = networkMatch.MatchedRule + " (unenforced)" + } + } else if macMatch != nil { + result.Upstreams = macMatch.Targets + result.Matched = true + result.MatchedRuleType = string(macMatch.RuleType) + result.MatchedNetwork = macMatch.MatchedRule + } else if networkMatch != nil { + result.Upstreams = networkMatch.Targets + result.Matched = true + result.MatchedRuleType = string(networkMatch.RuleType) + result.MatchedNetwork = networkMatch.MatchedRule + } + + return result +} + +// MatchingResult represents the result of the matching engine +type MatchingResult struct { + Upstreams []string + MatchedPolicy string + MatchedNetwork string + MatchedRule string + Matched bool + SrcAddr string + MatchedRuleType string + MatchingOrder []RuleType +} diff --git a/internal/rulematcher/engine_test.go b/internal/rulematcher/engine_test.go new file mode 100644 index 00000000..1c388dc4 --- /dev/null +++ b/internal/rulematcher/engine_test.go @@ -0,0 +1,216 @@ +package rulematcher + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/Control-D-Inc/ctrld/testhelper" +) + +func TestMatchingEngine(t *testing.T) { + cfg := testhelper.SampleConfig(t) + // Convert Cidrs to IPNets like in the original test + for _, nc := range cfg.Network { + for _, cidr := range nc.Cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatal(err) + } + nc.IPNets = append(nc.IPNets, ipNet) + } + } + + tests := []struct { + name string + config *MatchingConfig + request *MatchRequest + expected *MatchingResult + }{ + { + name: "Default config - network match first", + config: DefaultMatchingConfig(), + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + SourceMac: "14:45:A0:67:83:0A", + 
Domain: "example.ru", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchingResult{ + Upstreams: []string{"upstream.1"}, + MatchedPolicy: "My Policy", + MatchedNetwork: "network.0 (unenforced)", + MatchedRule: "*.ru", + Matched: true, + SrcAddr: "192.168.0.1", + MatchedRuleType: "domain", + MatchingOrder: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, + }, + }, + { + name: "Custom order - domain first", + config: &MatchingConfig{ + Order: []RuleType{RuleTypeDomain, RuleTypeNetwork, RuleTypeMac}, + }, + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + SourceMac: "14:45:A0:67:83:0A", + Domain: "example.ru", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchingResult{ + Upstreams: []string{"upstream.1"}, + MatchedPolicy: "My Policy", + MatchedNetwork: "network.0 (unenforced)", + MatchedRule: "*.ru", + Matched: true, + SrcAddr: "192.168.0.1", + MatchedRuleType: "domain", + MatchingOrder: []RuleType{RuleTypeDomain, RuleTypeNetwork, RuleTypeMac}, + }, + }, + { + name: "Custom order - MAC first", + config: &MatchingConfig{ + Order: []RuleType{RuleTypeMac, RuleTypeNetwork, RuleTypeDomain}, + }, + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + SourceMac: "14:45:A0:67:83:0A", + Domain: "example.ru", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchingResult{ + Upstreams: []string{"upstream.1"}, + MatchedPolicy: "My Policy", + MatchedNetwork: "network.0 (unenforced)", + MatchedRule: "*.ru", + Matched: true, + SrcAddr: "192.168.0.1", + MatchedRuleType: "domain", + MatchingOrder: []RuleType{RuleTypeMac, RuleTypeNetwork, RuleTypeDomain}, + }, + }, + { + name: "No policy", + config: DefaultMatchingConfig(), + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + SourceMac: "14:45:A0:67:83:0A", + Domain: "example.ru", + Policy: nil, + Config: cfg, + }, + expected: &MatchingResult{ + Upstreams: []string{}, + MatchedPolicy: "no policy", + MatchedNetwork: 
"no network", + MatchedRule: "no rule", + Matched: false, + SrcAddr: "192.168.0.1", + MatchedRuleType: "", + MatchingOrder: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, + }, + }, + { + name: "No matches", + config: DefaultMatchingConfig(), + request: &MatchRequest{ + SourceIP: net.ParseIP("10.0.0.1"), + SourceMac: "00:11:22:33:44:55", + Domain: "example.com", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchingResult{ + Upstreams: []string{}, + MatchedPolicy: "My Policy", + MatchedNetwork: "no network", + MatchedRule: "no rule", + Matched: false, + SrcAddr: "10.0.0.1", + MatchedRuleType: "", + MatchingOrder: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, + }, + }, + { + name: "MAC rule overrides network rule", + config: DefaultMatchingConfig(), + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + SourceMac: "14:45:A0:67:83:0A", + Domain: "example.com", // This domain doesn't match any domain rules + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchingResult{ + Upstreams: []string{"upstream.2"}, + MatchedPolicy: "My Policy", + MatchedNetwork: "14:45:a0:67:83:0a", + MatchedRule: "no rule", + Matched: true, + SrcAddr: "192.168.0.1", + MatchedRuleType: "mac", + MatchingOrder: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + engine := NewMatchingEngine(tc.config) + result := engine.FindUpstreams(context.Background(), tc.request) + + assert.Equal(t, tc.expected.Upstreams, result.Upstreams) + assert.Equal(t, tc.expected.MatchedPolicy, result.MatchedPolicy) + assert.Equal(t, tc.expected.MatchedNetwork, result.MatchedNetwork) + assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule) + assert.Equal(t, tc.expected.Matched, result.Matched) + assert.Equal(t, tc.expected.SrcAddr, result.SrcAddr) + assert.Equal(t, tc.expected.MatchedRuleType, result.MatchedRuleType) + assert.Equal(t, 
tc.expected.MatchingOrder, result.MatchingOrder) + }) + } +} + +func TestDefaultMatchingConfig(t *testing.T) { + config := DefaultMatchingConfig() + + assert.Equal(t, []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, config.Order) +} + +func TestMatchingEngineWithInvalidRuleType(t *testing.T) { + cfg := testhelper.SampleConfig(t) + // Convert Cidrs to IPNets like in the original test + for _, nc := range cfg.Network { + for _, cidr := range nc.Cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatal(err) + } + nc.IPNets = append(nc.IPNets, ipNet) + } + } + + config := &MatchingConfig{ + Order: []RuleType{RuleType("invalid"), RuleTypeNetwork}, + } + + engine := NewMatchingEngine(config) + request := &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + Policy: cfg.Listener["0"].Policy, + Config: cfg, + } + + result := engine.FindUpstreams(context.Background(), request) + + // Should still work, just skip the invalid rule type + assert.True(t, result.Matched) + assert.Equal(t, "network", result.MatchedRuleType) +} diff --git a/internal/rulematcher/mac.go b/internal/rulematcher/mac.go new file mode 100644 index 00000000..ff20e814 --- /dev/null +++ b/internal/rulematcher/mac.go @@ -0,0 +1,62 @@ +package rulematcher + +import ( + "context" + "strings" +) + +// MacRuleMatcher handles matching of MAC address-based rules +type MacRuleMatcher struct{} + +// Match evaluates MAC address rules against the source MAC address +func (m *MacRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult { + if req.Policy == nil || len(req.Policy.Macs) == 0 { + return &MatchResult{Matched: false, RuleType: RuleTypeMac} + } + + for _, rule := range req.Policy.Macs { + for source, targets := range rule { + if source != "" && (strings.EqualFold(source, req.SourceMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(req.SourceMac))) { + return &MatchResult{ + Matched: true, + Targets: targets, + MatchedRule: source, // Return the original 
source from the rule + RuleType: RuleTypeMac, + } + } + } + } + + return &MatchResult{Matched: false, RuleType: RuleTypeMac} +} + +// wildcardMatches checks if a wildcard pattern matches a string +// This is copied from the original implementation to maintain compatibility +func wildcardMatches(wildcard, str string) bool { + if wildcard == "" { + return false + } + if wildcard == "*" { + return true + } + if !strings.Contains(wildcard, "*") { + return wildcard == str + } + + parts := strings.Split(wildcard, "*") + if len(parts) != 2 { + return false + } + + prefix := parts[0] + suffix := parts[1] + + if prefix != "" && !strings.HasPrefix(str, prefix) { + return false + } + if suffix != "" && !strings.HasSuffix(str, suffix) { + return false + } + + return true +} diff --git a/internal/rulematcher/network.go b/internal/rulematcher/network.go new file mode 100644 index 00000000..1c20406a --- /dev/null +++ b/internal/rulematcher/network.go @@ -0,0 +1,38 @@ +package rulematcher + +import ( + "context" + "strings" +) + +// NetworkRuleMatcher handles matching of network-based rules +type NetworkRuleMatcher struct{} + +// Match evaluates network rules against the source IP address +func (n *NetworkRuleMatcher) Match(ctx context.Context, req *MatchRequest) *MatchResult { + if req.Policy == nil || len(req.Policy.Networks) == 0 { + return &MatchResult{Matched: false, RuleType: RuleTypeNetwork} + } + + for _, rule := range req.Policy.Networks { + for source, targets := range rule { + networkNum := strings.TrimPrefix(source, "network.") + nc := req.Config.Network[networkNum] + if nc == nil { + continue + } + for _, ipNet := range nc.IPNets { + if ipNet.Contains(req.SourceIP) { + return &MatchResult{ + Matched: true, + Targets: targets, + MatchedRule: source, + RuleType: RuleTypeNetwork, + } + } + } + } + } + + return &MatchResult{Matched: false, RuleType: RuleTypeNetwork} +} diff --git a/internal/rulematcher/rulematcher_test.go b/internal/rulematcher/rulematcher_test.go new 
file mode 100644 index 00000000..d4eb2356 --- /dev/null +++ b/internal/rulematcher/rulematcher_test.go @@ -0,0 +1,248 @@ +package rulematcher + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/testhelper" +) + +// Test NetworkRuleMatcher +func TestNetworkRuleMatcher(t *testing.T) { + cfg := testhelper.SampleConfig(t) + // Convert Cidrs to IPNets like in the original test + for _, nc := range cfg.Network { + for _, cidr := range nc.Cidrs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatal(err) + } + nc.IPNets = append(nc.IPNets, ipNet) + } + } + matcher := &NetworkRuleMatcher{} + + tests := []struct { + name string + request *MatchRequest + expected *MatchResult + }{ + { + name: "No policy", + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + Policy: nil, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork}, + }, + { + name: "No network rules", + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + Policy: &ctrld.ListenerPolicyConfig{}, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork}, + }, + { + name: "Match network rule", + request: &MatchRequest{ + SourceIP: net.ParseIP("192.168.0.1"), + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.1", "upstream.0"}, + MatchedRule: "network.0", + RuleType: RuleTypeNetwork, + }, + }, + { + name: "No match for IP", + request: &MatchRequest{ + SourceIP: net.ParseIP("10.0.0.1"), + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeNetwork}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := matcher.Match(context.Background(), tc.request) + assert.Equal(t, tc.expected.Matched, result.Matched) + assert.Equal(t, tc.expected.RuleType, 
result.RuleType) + if tc.expected.Matched { + assert.Equal(t, tc.expected.Targets, result.Targets) + assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule) + } + }) + } +} + +// Test MacRuleMatcher +func TestMacRuleMatcher(t *testing.T) { + cfg := testhelper.SampleConfig(t) + matcher := &MacRuleMatcher{} + + tests := []struct { + name string + request *MatchRequest + expected *MatchResult + }{ + { + name: "No policy", + request: &MatchRequest{ + SourceMac: "14:45:A0:67:83:0A", + Policy: nil, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeMac}, + }, + { + name: "No MAC rules", + request: &MatchRequest{ + SourceMac: "14:45:A0:67:83:0A", + Policy: &ctrld.ListenerPolicyConfig{}, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeMac}, + }, + { + name: "Match MAC rule - exact", + request: &MatchRequest{ + SourceMac: "14:45:A0:67:83:0A", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.2"}, + MatchedRule: "14:45:a0:67:83:0a", // Config loading normalizes MAC addresses to lowercase + RuleType: RuleTypeMac, + }, + }, + { + name: "Match MAC rule - case insensitive", + request: &MatchRequest{ + SourceMac: "14:54:4a:8e:08:2d", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.2"}, + MatchedRule: "14:54:4a:8e:08:2d", + RuleType: RuleTypeMac, + }, + }, + { + name: "No match for MAC", + request: &MatchRequest{ + SourceMac: "00:11:22:33:44:55", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeMac}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := matcher.Match(context.Background(), tc.request) + assert.Equal(t, tc.expected.Matched, result.Matched) + assert.Equal(t, tc.expected.RuleType, result.RuleType) + if tc.expected.Matched { + assert.Equal(t, 
tc.expected.Targets, result.Targets) + assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule) + } + }) + } +} + +// Test DomainRuleMatcher +func TestDomainRuleMatcher(t *testing.T) { + cfg := testhelper.SampleConfig(t) + matcher := &DomainRuleMatcher{} + + tests := []struct { + name string + request *MatchRequest + expected *MatchResult + }{ + { + name: "No policy", + request: &MatchRequest{ + Domain: "example.com", + Policy: nil, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain}, + }, + { + name: "No domain rules", + request: &MatchRequest{ + Domain: "example.com", + Policy: &ctrld.ListenerPolicyConfig{}, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain}, + }, + { + name: "Match domain rule - exact", + request: &MatchRequest{ + Domain: "example.ru", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.1"}, + MatchedRule: "*.ru", + RuleType: RuleTypeDomain, + }, + }, + { + name: "Match domain rule - wildcard", + request: &MatchRequest{ + Domain: "test.ru", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{ + Matched: true, + Targets: []string{"upstream.1"}, + MatchedRule: "*.ru", + RuleType: RuleTypeDomain, + }, + }, + { + name: "No match for domain", + request: &MatchRequest{ + Domain: "example.com", + Policy: cfg.Listener["0"].Policy, + Config: cfg, + }, + expected: &MatchResult{Matched: false, RuleType: RuleTypeDomain}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := matcher.Match(context.Background(), tc.request) + assert.Equal(t, tc.expected.Matched, result.Matched) + assert.Equal(t, tc.expected.RuleType, result.RuleType) + if tc.expected.Matched { + assert.Equal(t, tc.expected.Targets, result.Targets) + assert.Equal(t, tc.expected.MatchedRule, result.MatchedRule) + } + }) + } +} diff --git a/internal/rulematcher/types.go 
b/internal/rulematcher/types.go new file mode 100644 index 00000000..073830e1 --- /dev/null +++ b/internal/rulematcher/types.go @@ -0,0 +1,52 @@ +package rulematcher + +import ( + "context" + "net" + + "github.com/Control-D-Inc/ctrld" +) + +// RuleType represents the type of rule being matched +type RuleType string + +const ( + RuleTypeNetwork RuleType = "network" + RuleTypeMac RuleType = "mac" + RuleTypeDomain RuleType = "domain" +) + +// RuleMatcher defines the interface for matching different types of rules +type RuleMatcher interface { + Match(ctx context.Context, request *MatchRequest) *MatchResult +} + +// MatchRequest contains all the information needed for rule matching +type MatchRequest struct { + SourceIP net.IP + SourceMac string + Domain string + Policy *ctrld.ListenerPolicyConfig + Config *ctrld.Config +} + +// MatchResult represents the result of a rule matching operation +type MatchResult struct { + Matched bool + Targets []string + MatchedRule string + RuleType RuleType +} + +// MatchingConfig defines the configuration for rule matching behavior +type MatchingConfig struct { + Order []RuleType `json:"order" yaml:"order"` +} + +// DefaultMatchingConfig returns the default matching configuration +// This maintains backward compatibility with the current behavior +func DefaultMatchingConfig() *MatchingConfig { + return &MatchingConfig{ + Order: []RuleType{RuleTypeNetwork, RuleTypeMac, RuleTypeDomain}, + } +} diff --git a/log.go b/log.go index 14c82e8a..2f3a42f4 100644 --- a/log.go +++ b/log.go @@ -4,25 +4,126 @@ import ( "context" "fmt" "io" - "sync/atomic" + "os" + "time" - "github.com/rs/zerolog" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) -// ProxyLog emits the log record for proxy operations. -// The caller should set it only once. -// DEPRECATED: use ProxyLogger instead. 
-var ProxyLog = zerolog.New(io.Discard) +// Custom log level for NOTICE (between INFO and WARN) +// DEBUG = -1, INFO = 0, WARN = 1, ERROR = 2, FATAL = 3 +// Since there's no integer between INFO (0) and WARN (1), we'll use the same value as WARN +// but handle NOTICE specially in the encoder to display it differently. +// Note: NOTICE and WARN share the same numeric value (1), so they will both display as "NOTICE" +// when using the custom encoder. This is the intended behavior for visual distinction. +const NoticeLevel = zapcore.Level(zapcore.WarnLevel) // Same value as WARN, but handled specially -// ProxyLogger emits the log record for proxy operations. -var ProxyLogger atomic.Pointer[zerolog.Logger] +// LoggerCtxKey is the context.Context key for a logger. +type LoggerCtxKey struct{} + +// LoggerCtx returns a context.Context with LoggerCtxKey set. +func LoggerCtx(ctx context.Context, l *Logger) context.Context { + return context.WithValue(ctx, LoggerCtxKey{}, l) +} + +// A Logger provides fast, leveled, structured logging. +type Logger struct { + *zap.Logger +} + +var noOpZapLogger = zap.NewNop() + +// NopLogger returns a logger which all operation are no-op. +var NopLogger = &Logger{noOpZapLogger} + +// LoggerFromCtx returns the logger associated with given ctx. +// +// If there's no logger, a no-op logger will be returned. +func LoggerFromCtx(ctx context.Context) *Logger { + if logger, ok := ctx.Value(LoggerCtxKey{}).(*Logger); ok && logger != nil { + return logger + } + return NopLogger +} // ReqIdCtxKey is the context.Context key for a request id. type ReqIdCtxKey struct{} -// Log emits the logs for a particular zerolog event. +// LogEvent represents a logging event with structured fields +type LogEvent struct { + logger *zap.Logger + level zapcore.Level + fields []zap.Field +} + +// Msg logs the message with the collected fields +func (e *LogEvent) Msg(msg string) { + e.logger.Check(e.level, msg).Write(e.fields...) 
+} + +// Msgf logs a formatted message with the collected fields +func (e *LogEvent) Msgf(format string, v ...any) { + e.Msg(fmt.Sprintf(format, v...)) +} + +// MsgFunc logs a message from a function with the collected fields +func (e *LogEvent) MsgFunc(fn func() string) { + e.Msg(fn()) +} + +// Str adds a string field to the event +func (e *LogEvent) Str(key, val string) *LogEvent { + e.fields = append(e.fields, zap.String(key, val)) + return e +} + +// Int adds an integer field to the event +func (e *LogEvent) Int(key string, val int) *LogEvent { + e.fields = append(e.fields, zap.Int(key, val)) + return e +} + +// Int64 adds an int64 field to the event +func (e *LogEvent) Int64(key string, val int64) *LogEvent { + e.fields = append(e.fields, zap.Int64(key, val)) + return e +} + +// Err adds an error field to the event +func (e *LogEvent) Err(err error) *LogEvent { + if err != nil { + e.fields = append(e.fields, zap.Error(err)) + } + return e +} + +// Bool adds a boolean field to the event +func (e *LogEvent) Bool(key string, val bool) *LogEvent { + e.fields = append(e.fields, zap.Bool(key, val)) + return e +} + +// Interface adds an interface field to the event +func (e *LogEvent) Interface(key string, val interface{}) *LogEvent { + e.fields = append(e.fields, zap.Any(key, val)) + return e +} + +// Any adds an interface field to the event (alias for Interface) +func (e *LogEvent) Any(key string, val interface{}) *LogEvent { + return e.Interface(key, val) +} + +// Strs adds a string slice field to the event +func (e *LogEvent) Strs(key string, vals []string) *LogEvent { + e.fields = append(e.fields, zap.Strings(key, vals)) + return e +} + +// Log emits the logs for a particular logging event. // The request id associated with the context will be included if presents. 
-func Log(ctx context.Context, e *zerolog.Event, format string, v ...any) { +func Log(ctx context.Context, e *LogEvent, format string, v ...any) { id, ok := ctx.Value(ReqIdCtxKey{}).(string) if !ok { e.Msgf(format, v...) @@ -32,3 +133,124 @@ func Log(ctx context.Context, e *zerolog.Event, format string, v ...any) { return fmt.Sprintf("[%s] %s", id, fmt.Sprintf(format, v...)) }) } + +// Logger methods that mimic zerolog API +func (l *Logger) Debug() *LogEvent { + return &LogEvent{ + logger: l.Logger, + level: zapcore.DebugLevel, + fields: []zap.Field{}, + } +} + +func (l *Logger) Info() *LogEvent { + return &LogEvent{ + logger: l.Logger, + level: zapcore.InfoLevel, + fields: []zap.Field{}, + } +} + +func (l *Logger) Warn() *LogEvent { + return &LogEvent{ + logger: l.Logger, + level: zapcore.WarnLevel, + fields: []zap.Field{}, + } +} + +func (l *Logger) Error() *LogEvent { + return &LogEvent{ + logger: l.Logger, + level: zapcore.ErrorLevel, + fields: []zap.Field{}, + } +} + +func (l *Logger) Fatal() *LogEvent { + return &LogEvent{ + logger: l.Logger, + level: zapcore.FatalLevel, + fields: []zap.Field{}, + } +} + +func (l *Logger) Notice() *LogEvent { + return &LogEvent{ + logger: l.Logger, + level: NoticeLevel, // Custom NOTICE level between INFO and WARN + fields: []zap.Field{}, + } +} + +// With returns a logger with additional fields +func (l *Logger) With() *Logger { + return l +} + +// Str adds a string field to the logger +func (l *Logger) Str(key, val string) *Logger { + // Create a new logger with the field added + newLogger := l.Logger.With(zap.String(key, val)) + return &Logger{newLogger} +} + +// Err adds an error field to the logger +func (l *Logger) Err(err error) *Logger { + // Create a new logger with the error field added + newLogger := l.Logger.With(zap.Error(err)) + return &Logger{newLogger} +} + +// Any adds an interface field to the logger +func (l *Logger) Any(key string, val interface{}) *Logger { + // Create a new logger with the field added + 
newLogger := l.Logger.With(zap.Any(key, val)) + return &Logger{newLogger} +} + +// Bool adds a boolean field to the logger +func (l *Logger) Bool(key string, val bool) *Logger { + // Create a new logger with the field added + newLogger := l.Logger.With(zap.Bool(key, val)) + return &Logger{newLogger} +} + +// Msgf logs a formatted message at info level +func (l *Logger) Msgf(format string, v ...any) { + l.Info().Msgf(format, v...) +} + +// Msg logs a message at info level +func (l *Logger) Msg(msg string) { + l.Info().Msg(msg) +} + +// Output returns a logger with the specified output +func (l *Logger) Output(w io.Writer) *Logger { + // Create a new zap logger with the writer + encoderConfig := zap.NewDevelopmentEncoderConfig() + encoderConfig.TimeKey = "time" + encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.RFC3339) + encoder := zapcore.NewConsoleEncoder(encoderConfig) + core := zapcore.NewCore(encoder, zapcore.AddSync(w), zapcore.InfoLevel) + newLogger := zap.New(core) + return &Logger{newLogger} +} + +// GetLogger returns the underlying logger +func (l *Logger) GetLogger() *Logger { + return l +} + +// Write implements io.Writer to allow direct writing to the logger +func (l *Logger) Write(p []byte) (n int, err error) { + stdoutSyncer := zapcore.AddSync(os.Stdout) + stdoutSyncer.Write(p) + return len(p), nil +} + +// Printf logs a formatted message at info level +func (l *Logger) Printf(format string, v ...any) { + l.Info().Msgf(format, v...) +} diff --git a/nameservers.go b/nameservers.go index 0aebf9e1..da573e67 100644 --- a/nameservers.go +++ b/nameservers.go @@ -1,9 +1,15 @@ package ctrld -type dnsFn func() []string +import ( + "context" + + "github.com/Control-D-Inc/ctrld/internal/resolvconffile" +) + +type dnsFn func(ctx context.Context) []string // nameservers returns DNS nameservers from system settings. 
-func nameservers() []string { +func nameservers(ctx context.Context) []string { var dns []string seen := make(map[string]bool) ch := make(chan []string) @@ -11,7 +17,7 @@ func nameservers() []string { for _, fn := range fns { go func(fn dnsFn) { - ch <- fn() + ch <- fn(ctx) }(fn) } for range fns { @@ -26,3 +32,8 @@ func nameservers() []string { return dns } + +// CurrentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file. +func CurrentNameserversFromResolvconf() []string { + return resolvconffile.NameServers() +} diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 09c9516d..15c30c94 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -3,6 +3,7 @@ package ctrld import ( + "context" "net" "syscall" @@ -13,7 +14,7 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, dnsFromRIB} } -func dnsFromRIB() []string { +func dnsFromRIB(_ context.Context) []string { var dns []string rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { diff --git a/nameservers_darwin.go b/nameservers_darwin.go index 1bf45746..eff05bb5 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -22,8 +22,8 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} } -func getDNSFromScutil() []string { - logger := *ProxyLogger.Load() +func getDNSFromScutil(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) const ( maxRetries = 10 @@ -41,7 +41,7 @@ func getDNSFromScutil() []string { cmd := exec.Command("scutil", "--dns") output, err := cmd.Output() if err != nil { - Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err) + Log(context.Background(), logger.Error(), "Failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err) continue } @@ -75,7 +75,7 @@ func getDNSFromScutil() []string { } if err := scanner.Err(); err != nil { - Log(context.Background(), 
logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err) + Log(context.Background(), logger.Error(), "Error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err) continue } @@ -109,8 +109,8 @@ func getDHCPNameservers(iface string) ([]string, error) { return nameservers, nil } -func getAllDHCPNameservers() []string { - logger := *ProxyLogger.Load() +func getAllDHCPNameservers(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) interfaces, err := net.Interfaces() if err != nil { @@ -172,7 +172,7 @@ func getAllDHCPNameservers() []string { // if we have static DNS servers saved for the current default route, we should add them to the list drIfaceName, err := netmon.DefaultRouteInterface() - Log(context.Background(), logger.Debug(), "checking for static DNS servers for default route interface: %s", drIfaceName) + Log(context.Background(), logger.Debug(), "Checking for static DNS servers for default route interface: %s", drIfaceName) if err != nil { Log(context.Background(), logger.Debug(), "Failed to get default route interface: %v", err) @@ -186,7 +186,7 @@ func getAllDHCPNameservers() []string { Log(context.Background(), logger.Debug(), "Failed to patch interface name %s: %v", drIfaceName, err) } - staticNs, file := SavedStaticNameservers(drIface) + staticNs, file := SavedStaticNameserversAndPath(drIface) Log(context.Background(), logger.Debug(), "static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { diff --git a/nameservers_linux.go b/nameservers_linux.go index 37a9ed24..7a0406df 100644 --- a/nameservers_linux.go +++ b/nameservers_linux.go @@ -3,6 +3,7 @@ package ctrld import ( "bufio" "bytes" + "context" "encoding/hex" "net" "net/netip" @@ -23,7 +24,7 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromResolvConf, dns4, dns6, dnsFromSystemdResolver} } -func dns4() []string { +func dns4(ctx context.Context) []string { f, err := os.Open(v4RouteFile) if err != nil { return nil @@ -32,7 
+33,7 @@ func dns4() []string { var dns []string seen := make(map[string]bool) - vis := virtualInterfaces() + vis := virtualInterfaces(ctx) s := bufio.NewScanner(f) first := true for s.Scan() { @@ -45,7 +46,7 @@ func dns4() []string { continue } // Skip virtual interfaces. - if vis.contains(string(bytes.TrimSpace(fields[0]))) { + if _, ok := vis[string(bytes.TrimSpace(fields[0]))]; ok { continue } gw := make([]byte, net.IPv4len) @@ -63,7 +64,7 @@ func dns4() []string { return dns } -func dns6() []string { +func dns6(ctx context.Context) []string { f, err := os.Open(v6RouteFile) if err != nil { return nil @@ -71,7 +72,7 @@ func dns6() []string { defer f.Close() var dns []string - vis := virtualInterfaces() + vis := virtualInterfaces(ctx) s := bufio.NewScanner(f) for s.Scan() { fields := bytes.Fields(s.Bytes()) @@ -79,7 +80,7 @@ func dns6() []string { continue } // Skip virtual interfaces. - if vis.contains(string(bytes.TrimSpace(fields[len(fields)-1]))) { + if _, ok := vis[string(bytes.TrimSpace(fields[len(fields)-1]))]; ok { continue } @@ -97,7 +98,7 @@ func dns6() []string { return dns } -func dnsFromSystemdResolver() []string { +func dnsFromSystemdResolver(_ context.Context) []string { c, err := resolvconffile.ParseFile("/run/systemd/resolve/resolv.conf") if err != nil { return nil @@ -109,34 +110,29 @@ func dnsFromSystemdResolver() []string { return ns } -type set map[string]struct{} - -func (s *set) add(e string) { - (*s)[e] = struct{}{} -} - -func (s *set) contains(e string) bool { - _, ok := (*s)[e] - return ok -} - -// virtualInterfaces returns a set of virtual interfaces on current machine. -func virtualInterfaces() set { - s := make(set) - entries, _ := os.ReadDir("/sys/devices/virtual/net") +// virtualInterfaces returns a map of virtual interfaces on the current machine. 
+// This reads from /sys/devices/virtual/net to identify virtual network interfaces +// Virtual interfaces should not have DNS configured as they don't represent physical network connections +func virtualInterfaces(ctx context.Context) map[string]struct{} { + logger := LoggerFromCtx(ctx) + s := make(map[string]struct{}) + entries, err := os.ReadDir("/sys/devices/virtual/net") + if err != nil { + logger.Error().Err(err).Msg("Failed to read /sys/devices/virtual/net") + return nil + } for _, entry := range entries { if entry.IsDir() { - s.add(strings.TrimSpace(entry.Name())) + s[strings.TrimSpace(entry.Name())] = struct{}{} } } return s } -// validInterfacesMap returns a set containing non virtual interfaces. -// TODO: deduplicated with cmd/cli/net_linux.go in v2. -func validInterfaces() set { +// ValidInterfaces returns a set containing non virtual interfaces. +func ValidInterfaces(ctx context.Context) map[string]struct{} { m := make(map[string]struct{}) - vis := virtualInterfaces() + vis := virtualInterfaces(ctx) netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { if _, existed := vis[i.Name]; existed { return diff --git a/nameservers_linux_test.go b/nameservers_linux_test.go index 23f15441..dddd377e 100644 --- a/nameservers_linux_test.go +++ b/nameservers_linux_test.go @@ -1,10 +1,11 @@ package ctrld import ( + "context" "testing" ) func Test_virtualInterfaces(t *testing.T) { - vis := virtualInterfaces() + vis := virtualInterfaces(context.Background()) t.Log(vis) } diff --git a/nameservers_test.go b/nameservers_test.go index 166cced6..e2e2bace 100644 --- a/nameservers_test.go +++ b/nameservers_test.go @@ -1,9 +1,12 @@ package ctrld -import "testing" +import ( + "context" + "testing" +) func TestNameservers(t *testing.T) { - ns := nameservers() + ns := nameservers(context.Background()) if len(ns) == 0 { t.Fatal("failed to get nameservers") } diff --git a/nameservers_unix.go b/nameservers_unix.go index d8e6035e..6022f7a5 100644 --- 
a/nameservers_unix.go +++ b/nameservers_unix.go @@ -3,24 +3,18 @@ package ctrld import ( + "context" "net" "slices" "time" "tailscale.com/net/netmon" - - "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) -// currentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file. -func currentNameserversFromResolvconf() []string { - return resolvconffile.NameServers() -} - // dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file. // A nameserver is usable if it's not one of current machine's IP addresses // and loopback IP addresses. -func dnsFromResolvConf() []string { +func dnsFromResolvConf(_ context.Context) []string { const ( maxRetries = 10 retryInterval = 100 * time.Millisecond @@ -34,7 +28,7 @@ func dnsFromResolvConf() []string { time.Sleep(retryInterval) } - nss := resolvconffile.NameServers() + nss := CurrentNameserversFromResolvconf() var localDNS []string seen := make(map[string]bool) diff --git a/nameservers_windows.go b/nameservers_windows.go index 7b16e8e1..589d14d8 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -52,28 +52,25 @@ func dnsFns() []dnsFn { return []dnsFn{dnsFromAdapter} } -func dnsFromAdapter() []string { - ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout) +func dnsFromAdapter(ctx context.Context) []string { + ctx, cancel := context.WithTimeout(ctx, defaultDNSAdapterTimeout) defer cancel() var ns []string var err error - logger := *ProxyLogger.Load() + logger := LoggerFromCtx(ctx) for i := 0; i < maxDNSAdapterRetries; i++ { if ctx.Err() != nil { - Log(context.Background(), logger.Debug(), - "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) + logger.Debug().Msgf("dnsFromAdapter lookup cancelled or timed out, attempt %d", i) return nil } ns, err = getDNSServers(ctx) if err == nil && len(ns) >= minDNSServers { if i > 0 { - Log(context.Background(), logger.Debug(), - "Successfully got DNS servers after %d attempts, found %d 
servers", - i+1, len(ns)) + logger.Debug().Msgf("Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns)) } return ns } @@ -85,11 +82,9 @@ func dnsFromAdapter() []string { } if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get DNS servers, attempt %d: %v", i+1, err) + logger.Debug().Msgf("Failed to get DNS servers, attempt %d: %v", i+1, err) } else { - Log(context.Background(), logger.Debug(), - "Got insufficient DNS servers, retrying, found %d servers", len(ns)) + logger.Debug().Msgf("Got insufficient DNS servers, retrying, found %d servers", len(ns)) } select { @@ -99,14 +94,12 @@ func dnsFromAdapter() []string { } } - Log(context.Background(), logger.Debug(), - "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + logger.Debug().Msgf("Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + return ns } func getDNSServers(ctx context.Context) ([]string, error) { - logger := *ProxyLogger.Load() - // Check context before making the call if ctx.Err() != nil { return nil, ctx.Err() @@ -121,17 +114,16 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("getting adapters: %w", err) } - Log(context.Background(), logger.Debug(), - "Found network adapters, count=%d", len(aas)) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Found network adapters, count=%d", len(aas)) // Try to get domain controller info if domain-joined var dcServers []string - isDomain := checkDomainJoined() + isDomain := checkDomainJoined(ctx) if isDomain { domainName, err := getLocalADDomain() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get local AD domain: %v", err) + logger.Debug().Msgf("Failed to get local AD domain: %v", err) } else { // Load netapi32.dll netapi32 := windows.NewLazySystemDLL("netapi32.dll") @@ -142,11 +134,9 @@ func getDNSServers(ctx context.Context) ([]string, error) { 
domainUTF16, err := windows.UTF16PtrFromString(domainName) if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to convert domain name to UTF16: %v", err) + logger.Debug().Msgf("Failed to convert domain name to UTF16: %v", err) } else { - Log(context.Background(), logger.Debug(), - "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) + logger.Debug().Msgf("Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) // Call DsGetDcNameW with domain name ret, _, err := dsDcName.Call( @@ -160,38 +150,32 @@ func getDNSServers(ctx context.Context) ([]string, error) { if ret != 0 { switch ret { case 1355: // ERROR_NO_SUCH_DOMAIN - Log(context.Background(), logger.Debug(), - "Domain not found: %s (%d)", domainName, ret) + logger.Debug().Msgf("Domain not found: %s (%d)", domainName, ret) case 1311: // ERROR_NO_LOGON_SERVERS - Log(context.Background(), logger.Debug(), - "No logon servers available for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("No logon servers available for domain: %s (%d)", domainName, ret) case 1004: // ERROR_DC_NOT_FOUND - Log(context.Background(), logger.Debug(), - "Domain controller not found for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("Domain controller not found for domain: %s (%d)", domainName, ret) case 1722: // RPC_S_SERVER_UNAVAILABLE - Log(context.Background(), logger.Debug(), - "RPC server unavailable for domain: %s (%d)", domainName, ret) + logger.Debug().Msgf("RPC server unavailable for domain: %s (%d)", domainName, ret) default: - Log(context.Background(), logger.Debug(), - "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) + logger.Debug().Msgf("Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) } } else if info != nil { defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) if info.DomainControllerAddress != nil { dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) + // 
Remove "\\" prefix from domain controller address + // Windows domain controller addresses are returned with "\\" prefix, + // but we need just the IP address for DNS resolution dcAddr = strings.TrimPrefix(dcAddr, "\\\\") - Log(context.Background(), logger.Debug(), - "Found domain controller address: %s", dcAddr) - + logger.Debug().Msgf("Found domain controller address: %s", dcAddr) if ip := net.ParseIP(dcAddr); ip != nil { dcServers = append(dcServers, ip.String()) - Log(context.Background(), logger.Debug(), - "Added domain controller DNS servers: %v", dcServers) + logger.Debug().Msgf("Added domain controller DNS servers: %v", dcServers) } } else { - Log(context.Background(), logger.Debug(), - "No domain controller address found") + logger.Debug().Msg("No domain controller address found") } } } @@ -206,31 +190,27 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Collect all local IPs for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { - Log(context.Background(), logger.Debug(), - "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) + logger.Debug().Msgf("Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) continue } // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName()) continue } - Log(context.Background(), logger.Debug(), - "Processing adapter %s", aa.FriendlyName()) + logger.Debug().Msgf("Processing adapter %s", aa.FriendlyName()) for a := aa.FirstUnicastAddress; a != nil; a = a.Next { ip := a.Address.IP().String() addressMap[ip] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP %s from adapter %s", ip, aa.FriendlyName()) + logger.Debug().Msgf("Added 
local IP %s from adapter %s", ip, aa.FriendlyName()) } } - validInterfacesMap := validInterfaces() + validInterfacesMap := ValidInterfaces(ctx) // Collect DNS servers for _, aa := range aas { @@ -241,23 +221,20 @@ func getDNSServers(ctx context.Context) ([]string, error) { // Skip if software loopback or other non-physical types // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows if aa.IfType == winipcfg.IfTypeSoftwareLoopback { - Log(context.Background(), logger.Debug(), - "Skipping %s (software loopback)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (software loopback)", aa.FriendlyName()) continue } // if not in the validInterfacesMap, skip if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { - Log(context.Background(), logger.Debug(), - "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + logger.Debug().Msgf("Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) continue } for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() if ip == nil { - Log(context.Background(), logger.Debug(), - "Skipping nil IP from adapter %s", aa.FriendlyName()) + logger.Debug().Msgf("Skipping nil IP from adapter %s", aa.FriendlyName()) continue } @@ -290,28 +267,23 @@ func getDNSServers(ctx context.Context) ([]string, error) { if !seen[dcServer] { seen[dcServer] = true ns = append(ns, dcServer) - Log(context.Background(), logger.Debug(), - "Added additional domain controller DNS server: %s", dcServer) + logger.Debug().Msgf("Added additional domain controller DNS server: %s", dcServer) } } // if we have static DNS servers saved for the current default route, we should add them to the list drIfaceName, err := netmon.DefaultRouteInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "Failed to get default route interface: %v", err) + logger.Debug().Msgf("Failed to get default route interface: %v", err) } else { drIface, err := net.InterfaceByName(drIfaceName) if err != nil { - 
Log(context.Background(), logger.Debug(), - "Failed to get interface by name %s: %v", drIfaceName, err) + logger.Debug().Msgf("Failed to get interface by name %s: %v", drIfaceName, err) } else { - staticNs, file := SavedStaticNameservers(drIface) - Log(context.Background(), logger.Debug(), - "static dns servers from %s: %v", file, staticNs) + staticNs, file := SavedStaticNameserversAndPath(drIface) + logger.Debug().Msgf("Static dns servers from %s: %v", file, staticNs) if len(staticNs) > 0 { - Log(context.Background(), logger.Debug(), - "Adding static DNS servers from %s: %v", drIfaceName, staticNs) + logger.Debug().Msgf("Adding static DNS servers from %s: %v", drIfaceName, staticNs) ns = append(ns, staticNs...) } } @@ -321,27 +293,20 @@ func getDNSServers(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("no valid DNS servers found") } - Log(context.Background(), logger.Debug(), - "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", - len(ns), ns, len(dcServers)) + logger.Debug().Msgf("DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", len(ns), ns, len(dcServers)) return ns, nil } -// currentNameserversFromResolvconf returns a nil slice of strings. 
-func currentNameserversFromResolvconf() []string { - return nil -} - // checkDomainJoined checks if the machine is joined to an Active Directory domain // Returns whether it's domain joined and the domain name if available -func checkDomainJoined() bool { - logger := *ProxyLogger.Load() +func checkDomainJoined(ctx context.Context) bool { + logger := LoggerFromCtx(ctx) var domain *uint16 var status uint32 if err := windows.NetGetJoinInformation(nil, &domain, &status); err != nil { - Log(context.Background(), logger.Debug(), "Failed to get domain join status: %v", err) + logger.Debug().Msgf("Failed to get domain join status: %v", err) return false } defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) @@ -356,12 +321,12 @@ func checkDomainJoined() bool { // // We only care about NetSetupDomainName. domainName := windows.UTF16PtrToString(domain) - Log(context.Background(), logger.Debug(), + logger.Debug().Msgf( "Domain join status: domain=%s status=%d (UnknownStatus=0, Unjoined=1, WorkgroupName=2, DomainName=3)", domainName, status) isDomain := status == syscall.NetSetupDomainName - Log(context.Background(), logger.Debug(), "Is domain joined? status=%d, result=%v", status, isDomain) + logger.Debug().Msgf("Is domain joined? status=%d, result=%v", status, isDomain) return isDomain } @@ -406,15 +371,14 @@ func getLocalADDomain() (string, error) { return domainName, nil } -// validInterfaces returns a list of all physical interfaces. -// this is a duplicate of what is in net_windows.go, we should -// clean this up so there is only one version -func validInterfaces() map[string]struct{} { +// ValidInterfaces returns a map of valid network interface names as keys with empty struct values. +// It filters interfaces to include only physical, hardware-based adapters using WMI queries. 
+func ValidInterfaces(ctx context.Context) map[string]struct{} { log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) //load the logger - logger := *ProxyLogger.Load() + logger := LoggerFromCtx(ctx) whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") @@ -423,23 +387,20 @@ func validInterfaces() map[string]struct{} { defer instances.Close() } if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get wmi network adapter: %v", err) + logger.Warn().Msgf("Failed to get wmi network adapter: %v", err) return nil } var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get network adapter: %v", err) + logger.Warn().Msgf("Failed to get network adapter: %v", err) continue } name, err := adapter.GetPropertyName() if err != nil { - Log(context.Background(), logger.Warn(), - "failed to get interface name: %v", err) + logger.Warn().Msgf("Failed to get interface name: %v", err) continue } @@ -449,13 +410,11 @@ func validInterfaces() map[string]struct{} { // if this is a physical adapter or FALSE if this is not a physical adapter." physical, err := adapter.GetPropertyConnectorPresent() if err != nil { - Log(context.Background(), logger.Debug(), - "failed to get network adapter connector present property: %v", err) + logger.Debug().Msgf("Failed to get network adapter connector present property: %v", err) continue } if !physical { - Log(context.Background(), logger.Debug(), - "skipping non-physical adapter: %s", name) + logger.Debug().Msgf("Skipping non-physical adapter: %s", name) continue } @@ -463,13 +422,11 @@ func validInterfaces() map[string]struct{} { // because some interfaces are not physical but have a connector. 
hardware, err := adapter.GetPropertyHardwareInterface() if err != nil { - Log(context.Background(), logger.Debug(), - "failed to get network adapter hardware interface property: %v", err) + logger.Debug().Msgf("Failed to get network adapter hardware interface property: %v", err) continue } if !hardware { - Log(context.Background(), logger.Debug(), - "skipping non-hardware interface: %s", name) + logger.Debug().Msgf("Skipping non-hardware interface: %s", name) continue } diff --git a/net.go b/net.go index 7bbf54bb..30799bff 100644 --- a/net.go +++ b/net.go @@ -17,26 +17,27 @@ var ( ) // HasIPv6 reports whether the current network stack has IPv6 available. -func HasIPv6() bool { +func HasIPv6(ctx context.Context) bool { hasIPv6Once.Do(func() { - ProxyLogger.Load().Debug().Msg("checking for IPv6 availability once") + logger := LoggerFromCtx(ctx) + logger.Debug().Msg("Checking for ipv6 availability once") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() val := ctrldnet.IPv6Available(ctx) ipv6Available.Store(val) - ProxyLogger.Load().Debug().Msgf("ipv6 availability: %v", val) + logger.Debug().Msgf("ipv6 availability: %v", val) mon, err := netmon.New(func(format string, args ...any) {}) if err != nil { - ProxyLogger.Load().Debug().Err(err).Msg("failed to monitor IPv6 state") + logger.Debug().Err(err).Msg("Failed to monitor ipv6 state") return } mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { old := ipv6Available.Load() cur := delta.Monitor.InterfaceState().HaveV6 if old != cur { - ProxyLogger.Load().Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur) + logger.Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur) } else { - ProxyLogger.Load().Debug().Msg("ipv6 availability does not changed") + logger.Debug().Msg("ipv6 availability does not Changed") } ipv6Available.Store(cur) }) @@ -46,8 +47,9 @@ func HasIPv6() bool { } // DisableIPv6 marks IPv6 as unavailable if enabled. 
-func DisableIPv6() { +func DisableIPv6(ctx context.Context) { if ipv6Available.CompareAndSwap(true, false) { - ProxyLogger.Load().Debug().Msg("turned off IPv6 availability") + logger := LoggerFromCtx(ctx) + logger.Debug().Msg("Turned off ipv6 availability") } } diff --git a/net_darwin.go b/net_darwin.go index 5b01e9f2..42c26a2c 100644 --- a/net_darwin.go +++ b/net_darwin.go @@ -3,14 +3,14 @@ package ctrld import ( "bufio" "bytes" + "context" "io" "os/exec" "strings" ) -// validInterfaces returns a set of all valid hardware ports. -// TODO: deduplicated with cmd/cli/net_darwin.go in v2. -func validInterfaces() map[string]struct{} { +// ValidInterfaces returns a set of all valid hardware ports. +func ValidInterfaces(_ context.Context) map[string]struct{} { b, err := exec.Command("networksetup", "-listallhardwareports").Output() if err != nil { return nil diff --git a/net_others.go b/net_others.go index ae7ab8e2..fef1e7d6 100644 --- a/net_others.go +++ b/net_others.go @@ -2,11 +2,14 @@ package ctrld -import "tailscale.com/net/netmon" +import ( + "context" -// validInterfaces returns a set containing only default route interfaces. -// TODO: deuplicated with cmd/cli/net_others.go in v2. -func validInterfaces() map[string]struct{} { + "tailscale.com/net/netmon" +) + +// ValidInterfaces returns a set containing only default route interfaces. +func ValidInterfaces(_ context.Context) map[string]struct{} { defaultRoute, err := netmon.DefaultRoute() if err != nil { return nil diff --git a/resolver.go b/resolver.go index 3aeddd0d..878663d4 100644 --- a/resolver.go +++ b/resolver.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net" "net/netip" "runtime" @@ -15,7 +14,6 @@ import ( "time" "github.com/miekg/dns" - "github.com/rs/zerolog" "golang.org/x/sync/singleflight" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" @@ -36,8 +34,6 @@ const ( ResolverTypeLegacy = "legacy" // ResolverTypePrivate is like ResolverTypeOS, but use for private resolver only. 
ResolverTypePrivate = "private" - // ResolverTypeLocal is like ResolverTypeOS, but use for local resolver only. - ResolverTypeLocal = "local" // ResolverTypeSDNS specifies resolver with information encoded using DNS Stamps. // See: https://dnscrypt.info/stamps-specifications/ ResolverTypeSDNS = "sdns" @@ -47,16 +43,6 @@ const controldPublicDns = "76.76.2.0" var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") -var localResolver Resolver - -func init() { - // Initializing ProxyLogger here, so other places don't have to do nil check. - l := zerolog.New(io.Discard) - ProxyLogger.Store(&l) - - localResolver = newLocalResolver() -} - var ( resolverMutex sync.Mutex or *osResolver @@ -64,14 +50,6 @@ var ( defaultLocalIPv6 atomic.Value // holds net.IP (IPv6) ) -func newLocalResolver() Resolver { - var nss []string - for _, addr := range Rfc1918Addresses() { - nss = append(nss, net.JoinHostPort(addr, "53")) - } - return NewResolverWithNameserver(nss) -} - // LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network. type LanQueryCtxKey struct{} @@ -81,8 +59,8 @@ func LanQueryCtx(ctx context.Context) context.Context { } // defaultNameservers is like nameservers with each element formed "ip:53". -func defaultNameservers() []string { - ns := nameservers() +func defaultNameservers(ctx context.Context) []string { + ns := nameservers(ctx) nss := make([]string, len(ns)) for i := range ns { nss[i] = net.JoinHostPort(ns[i], "53") @@ -91,42 +69,36 @@ func defaultNameservers() []string { } // availableNameservers returns list of current available DNS servers of the system. -func availableNameservers() []string { +func availableNameservers(ctx context.Context) []string { var nss []string // Ignore local addresses to prevent loop. 
regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - //load the logger - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), - "Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) + // Load the logger. + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) for _, v := range slices.Concat(regularIPs, loopbackIPs) { ipStr := v.String() machineIPsMap[ipStr] = struct{}{} - Log(context.Background(), logger.Debug(), - "Added local IP to OS resolverexclusion map: %s", ipStr) + logger.Debug().Msgf("Added local IP to OS resolverexclusion map: %s", ipStr) } - systemNameservers := nameservers() - Log(context.Background(), logger.Debug(), - "Got system nameservers: %v", systemNameservers) + systemNameservers := nameservers(ctx) + logger.Debug().Msgf("Got system nameservers: %v", systemNameservers) for _, ns := range systemNameservers { if _, ok := machineIPsMap[ns]; ok { - Log(context.Background(), logger.Debug(), - "Skipping local nameserver: %s", ns) + logger.Debug().Msgf("Skipping local nameserver: %s", ns) continue } nss = append(nss, ns) - Log(context.Background(), logger.Debug(), - "Added non-local nameserver: %s", ns) + logger.Debug().Msgf("Added non-local nameserver: %s", ns) } - Log(context.Background(), logger.Debug(), - "Final available nameservers: %v", nss) + logger.Debug().Msgf("Final available nameservers: %v", nss) + return nss } @@ -135,8 +107,8 @@ func availableNameservers() []string { // // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. 
-func InitializeOsResolver(guardAgainstNoNameservers bool) []string { - nameservers := availableNameservers() +func InitializeOsResolver(ctx context.Context, guardAgainstNoNameservers bool) []string { + nameservers := availableNameservers(ctx) // if no nameservers, return empty slice so we dont remove all nameservers if len(nameservers) == 0 && guardAgainstNoNameservers { return []string{} @@ -154,10 +126,11 @@ func InitializeOsResolver(guardAgainstNoNameservers bool) []string { // - First available LAN servers are saved and store. // - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { - var lanNss, publicNss []string - // First categorize servers + // Categorize DNS servers into LAN and public servers + // This is needed because LAN servers should be tried first for better performance, + // while public servers serve as fallback for external queries for _, ns := range servers { addr, err := netip.ParseAddr(ns) if err != nil { @@ -171,6 +144,8 @@ func initializeOsResolver(servers []string) []string { } } + // Ensure we have at least one public DNS server as fallback + // This prevents DNS resolution failures when no public servers are configured if len(publicNss) == 0 { publicNss = []string{controldPublicDnsWithPort} } @@ -188,7 +163,7 @@ type Resolver interface { var errUnknownResolver = errors.New("unknown resolver") // NewResolver creates a Resolver based on the given upstream config. 
-func NewResolver(uc *UpstreamConfig) (Resolver, error) { +func NewResolver(ctx context.Context, uc *UpstreamConfig) (Resolver, error) { typ := uc.Type switch typ { case ResolverTypeDOH, ResolverTypeDOH3: @@ -200,17 +175,16 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeOS: resolverMutex.Lock() if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver") - or = newResolverWithNameserver(defaultNameservers()) + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("Initialize new OS resolver") + or = newResolverWithNameserver(defaultNameservers(ctx)) } resolverMutex.Unlock() return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil case ResolverTypePrivate: - return NewPrivateResolver(), nil - case ResolverTypeLocal: - return localResolver, nil + return NewPrivateResolver(ctx), nil } return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ) } @@ -235,14 +209,16 @@ type publicResponse struct { } // SetDefaultLocalIPv4 updates the stored local IPv4. -func SetDefaultLocalIPv4(ip net.IP) { - Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip) +func SetDefaultLocalIPv4(ctx context.Context, ip net.IP) { + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("SetDefaultLocalIPv4: %s", ip) defaultLocalIPv4.Store(ip) } // SetDefaultLocalIPv6 updates the stored local IPv6. -func SetDefaultLocalIPv6(ip net.IP) { - Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip) +func SetDefaultLocalIPv6(ctx context.Context, ip net.IP) { + logger := LoggerFromCtx(ctx) + logger.Debug().Msgf("SetDefaultLocalIPv6: %s", ip) defaultLocalIPv6.Store(ip) } @@ -300,10 +276,13 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // Unique key for the singleflight group. 
key := fmt.Sprintf("%s:%d:", domain, qtype) + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "OS resolver query started: %s - %s", domain, dns.TypeToString[qtype]) + // Checking the cache first. if val, ok := o.cache.Load(key); ok { if val, ok := val.(*dns.Msg); ok { - Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) + Log(ctx, logger.Debug(), "Hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) res := val.Copy() SetCacheReply(res, msg, val.Rcode) return res, nil @@ -312,8 +291,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // Ensure only one DNS query is in flight for the key. v, err, shared := o.group.Do(key, func() (interface{}, error) { + Log(ctx, logger.Debug(), "Resolving query: %s - %s", domain, dns.TypeToString[qtype]) msg, err := o.resolve(ctx, msg) if err != nil { + Log(ctx, logger.Error().Err(err), "OS resolver query failed: %s - %s", domain, dns.TypeToString[qtype]) return nil, err } // If we got an answer, storing it to the hot cache for hotCacheTTL @@ -325,6 +306,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error time.AfterFunc(hotCacheTTL, func() { o.removeCache(key) }) + Log(ctx, logger.Debug(), "OS resolver query successful: %s - %s", domain, dns.TypeToString[qtype]) return msg, nil }) if err != nil { @@ -338,7 +320,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error res := sharedMsg.Copy() SetCacheReply(res, msg, sharedMsg.Rcode) if shared { - Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) + Log(ctx, logger.Debug(), "Shared result: %s - %s", domain, dns.TypeToString[qtype]) } return res, nil @@ -368,7 +350,8 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if msg != nil && len(msg.Question) > 0 { question = msg.Question[0].Name } - Log(ctx, ProxyLogger.Load().Debug(), "os resolver 
query for %s with nameservers: %v public: %v", question, nss, publicServers) + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "OS resolver query for %s with nameservers: %v public: %v", question, nss, publicServers) // New check: If no resolvers are available, return an error. if numServers == 0 { @@ -417,7 +400,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // If splitting fails, fallback to the original server string host = server } - Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host) + Log(ctx, logger.Debug(), "Got answer from nameserver: %s", host) } // try local nameservers @@ -444,7 +427,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error switch { case res.lan: // Always prefer LAN responses immediately - Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server) + Log(ctx, logger.Debug(), "Using LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -454,7 +437,7 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error // if there are no LAN nameservers, we should not wait // just use the first response if len(nss) == 0 { - Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server) + Log(ctx, logger.Debug(), "Using public answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -465,12 +448,12 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error }) } case res.answer != nil: - Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", + Log(ctx, logger.Debug(), "Got non-success answer from: %s with code: %d", res.server, res.answer.Rcode) // When there are no LAN nameservers, we should not wait // for other nameservers to respond. 
if len(nss) == 0 { - Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer") + Log(ctx, logger.Debug(), "No lan nameservers using public non success answer") cancel() logAnswer(res.server) return res.answer, nil @@ -483,17 +466,17 @@ func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if len(publicResponses) > 0 { resp := publicResponses[0] - Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server) + Log(ctx, logger.Debug(), "Using public answer from: %s", resp.server) logAnswer(resp.server) return resp.answer, nil } if controldSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort) + Log(ctx, logger.Debug(), "Using ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer) + Log(ctx, logger.Debug(), "Using non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } @@ -509,13 +492,16 @@ type legacyResolver struct { } func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + logger := LoggerFromCtx(ctx) + Log(ctx, logger.Debug(), "Legacy resolver query started") + // See comment in (*dotResolver).resolve method. 
dialer := newDialer(net.JoinHostPort(controldPublicDns, "53")) dnsTyp := uint16(0) if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - _, udpNet := r.uc.netForDNSType(dnsTyp) + _, udpNet := r.uc.netForDNSType(ctx, dnsTyp) dnsClient := &dns.Client{ Net: udpNet, Dialer: dialer, @@ -527,7 +513,13 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e endpoint = net.JoinHostPort(r.uc.BootstrapIP, port) } + Log(ctx, logger.Debug(), "Sending legacy request to: %s", endpoint) answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) + if err != nil { + Log(ctx, logger.Error().Err(err), "Legacy request failed") + } else { + Log(ctx, logger.Debug(), "Legacy resolver query successful") + } return answer, err } @@ -541,39 +533,43 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err // LookupIP looks up domain using current system nameservers settings. // It returns a slice of that host's IPv4 and IPv6 addresses. -func LookupIP(domain string) []string { - nss := initDefaultOsResolver() - return lookupIP(domain, -1, nss) +func LookupIP(ctx context.Context, domain string) []string { + nss := initDefaultOsResolver(ctx) + return lookupIP(ctx, domain, -1, nss) } // initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet. // It returns the combined list of LAN and public nameservers currently held by the resolver. 
-func initDefaultOsResolver() []string { +func initDefaultOsResolver(ctx context.Context) []string { + logger := LoggerFromCtx(ctx) resolverMutex.Lock() defer resolverMutex.Unlock() if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers") - or = newResolverWithNameserver(defaultNameservers()) + logger.Debug().Msgf("Initialize new OS resolver with default nameservers") + or = newResolverWithNameserver(defaultNameservers(ctx)) } nss := *or.lanServers.Load() nss = append(nss, *or.publicServers.Load()...) return nss + } // lookupIP looks up domain with given timeout and bootstrapDNS. // If the timeout is negative, default timeout 2000 ms will be used. // It returns nil if bootstrapDNS is nil or empty. -func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) { +func lookupIP(ctx context.Context, domain string, timeout int, bootstrapDNS []string) (ips []string) { if net.ParseIP(domain) != nil { return []string{domain} } + logger := LoggerFromCtx(ctx) if bootstrapDNS == nil { - ProxyLogger.Load().Debug().Msgf("empty bootstrap DNS") + logger.Debug().Msgf("Empty bootstrap dns") return nil } resolver := newResolverWithNameserver(bootstrapDNS) - ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS) + logger.Debug().Msgf("Resolving %q using bootstrap dns %q", domain, bootstrapDNS) + timeoutMs := 2000 if timeout > 0 && timeout < timeoutMs { timeoutMs = timeout @@ -616,15 +612,15 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) r, err := resolver.Resolve(ctx, m) if err != nil { - ProxyLogger.Load().Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain) + logger.Error().Err(err).Msgf("Could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain) return } if r.Rcode != dns.RcodeSuccess { - ProxyLogger.Load().Error().Msgf("could not resolve domain %q, return code: %s", domain, 
dns.RcodeToString[r.Rcode]) + logger.Error().Msgf("Could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode]) return } if len(r.Answer) == 0 { - ProxyLogger.Load().Error().Msg("no answer from OS resolver") + logger.Error().Msg("No answer from os resolver") return } target := targetDomain(r.Answer) @@ -641,22 +637,6 @@ func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) return ips } -// NewBootstrapResolver returns an OS resolver, which use following nameservers: -// -// - Gateway IP address (depends on OS). -// - Input servers. -func NewBootstrapResolver(servers ...string) Resolver { - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers) - nss := defaultNameservers() - nss = append([]string{controldPublicDnsWithPort}, nss...) - for _, ns := range servers { - nss = append([]string{net.JoinHostPort(ns, "53")}, nss...) - } - return NewResolverWithNameserver(nss) -} - // NewPrivateResolver returns an OS resolver, which includes only private DNS servers, // excluding: // @@ -664,9 +644,9 @@ func NewBootstrapResolver(servers ...string) Resolver { // - Nameservers which is local RFC1918 addresses. // // This is useful for doing PTR lookup in LAN network. 
-func NewPrivateResolver() Resolver { - nss := initDefaultOsResolver() - resolveConfNss := currentNameserversFromResolvconf() +func NewPrivateResolver(ctx context.Context) Resolver { + nss := initDefaultOsResolver(ctx) + resolveConfNss := CurrentNameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() n := 0 for _, ns := range nss { @@ -731,7 +711,7 @@ func newResolverWithNameserver(nameservers []string) *osResolver { // Rfc1918Addresses returns the list of local physical interfaces private IP addresses func Rfc1918Addresses() []string { - vis := validInterfaces() + vis := ValidInterfaces(context.Background()) var res []string netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { // Skip virtual interfaces. diff --git a/resolver_test.go b/resolver_test.go index f030739e..cfa284fb 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -132,7 +132,7 @@ func Test_osResolver_InitializationRace(t *testing.T) { for range n { go func() { defer wg.Done() - InitializeOsResolver(false) + InitializeOsResolver(LoggerCtx(context.Background(), nil), false) }() } wg.Wait() @@ -143,6 +143,8 @@ func Test_osResolver_Singleflight(t *testing.T) { if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } + defer lanPC.Close() + call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { @@ -153,7 +155,13 @@ func Test_osResolver_Singleflight(t *testing.T) { or := newResolverWithNameserver([]string{lanAddr}) domain := "controld.com" n := 10 + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var wg sync.WaitGroup + errs := make(chan error, n) + wg.Add(n) for i := 0; i < n; i++ { go func() { @@ -161,25 +169,40 @@ func Test_osResolver_Singleflight(t *testing.T) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true - _, err := or.Resolve(context.Background(), m) + _, err := or.Resolve(ctx, m) if err != 
nil { - t.Error(err) + errs <- err } }() } wg.Wait() + close(errs) + + // Collect any errors that occurred + for err := range errs { + t.Errorf("resolver error: %v", err) + } // All above queries should only make 1 call to server. - if call.Load() != 1 { - t.Fatalf("expected 1 result from singleflight lookup, got %d", call) + if got := call.Load(); got != 1 { + t.Fatalf("expected 1 result from singleflight lookup, got %d", got) } } func Test_osResolver_HotCache(t *testing.T) { + const ( + testIterations = 2 + cacheCheckTimeout = 5 * time.Second + pollInterval = 10 * time.Millisecond + ) + + // Setup test server lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } + defer lanPC.Close() + call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { @@ -187,58 +210,81 @@ func Test_osResolver_HotCache(t *testing.T) { } defer lanServer.Shutdown() + // Initialize resolver or := newResolverWithNameserver([]string{lanAddr}) domain := "controld.com" m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true - // Make 2 repeated queries to server, should hit hot cache. 
- for i := 0; i < 2; i++ { - if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + // Setup context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Make repeated queries to server, should hit hot cache + for i := 0; i < testIterations; i++ { + resp, err := or.Resolve(ctx, m.Copy()) + if err != nil { t.Fatal(err) } + // Verify response content + if resp.Rcode != dns.RcodeSuccess { + t.Errorf("expected success response, got %v", resp.Rcode) + } } + if call.Load() != 1 { t.Fatalf("cache not hit, server was called: %d", call.Load()) } + // Wait for cache to be cleaned timeoutChan := make(chan struct{}) - time.AfterFunc(5*time.Second, func() { + time.AfterFunc(cacheCheckTimeout, func() { close(timeoutChan) }) + // Check cache with proper polling interval +waitLoop: for { select { case <-timeoutChan: t.Fatal("timed out waiting for cache cleaned") - default: + case <-time.After(pollInterval): count := 0 or.cache.Range(func(key, value interface{}) bool { count++ return true }) - if count != 0 { - t.Logf("hot cache is not empty: %d elements", count) - continue + if count == 0 { + break waitLoop } + t.Logf("hot cache is not empty: %d elements", count) } - break } - if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + // Verify cache miss after cleanup + resp, err := or.Resolve(ctx, m.Copy()) + if err != nil { t.Fatal(err) } + if resp.Rcode != dns.RcodeSuccess { + t.Errorf("expected success response after cache cleanup, got %v", resp.Rcode) + } if call.Load() != 2 { t.Fatal("cache hit unexpectedly") } } func Test_Edns0_CacheReply(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } + defer lanPC.Close() + call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, 
countHandler(call)) if err != nil { @@ -252,33 +298,45 @@ func Test_Edns0_CacheReply(t *testing.T) { m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true - do := func() *dns.Msg { + do := func() (*dns.Msg, error) { msg := m.Copy() msg.SetEdns0(4096, true) cookieOption := new(dns.EDNS0_COOKIE) cookieOption.Code = dns.EDNS0COOKIE cookieOption.Cookie = generateEdns0ClientCookie() msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption) + return or.Resolve(ctx, msg) + } - answer, err := or.Resolve(context.Background(), msg) - if err != nil { - t.Fatal(err) - } - return answer + answer1, err := do() + if err != nil { + t.Fatalf("first resolve failed: %v", err) } - answer1 := do() - answer2 := do() - // Ensure the cache was hit, so we can check that edns0 cookie must be modified. - if call.Load() != 1 { - t.Fatalf("cache not hit, server was called: %d", call.Load()) + + answer2, err := do() + if err != nil { + t.Fatalf("second resolve failed: %v", err) } + + // Ensure the cache was hit + if got := call.Load(); got != 1 { + t.Fatalf("expected 1 server call, got: %d", got) + } + cookie1 := getEdns0Cookie(answer1.IsEdns0()) cookie2 := getEdns0Cookie(answer2.IsEdns0()) + if cookie1 == nil || cookie2 == nil { - t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2) + t.Fatalf("unexpected nil cookie (cookie1: %v, cookie2: %v)", cookie1, cookie2) } + if cookie1.Cookie == cookie2.Cookie { - t.Fatalf("edns0 cookie is not modified: %v", cookie1) + t.Fatalf("edns0 cookie was not modified (cookie: %v)", cookie1.Cookie) + } + + // Validate response code + if answer1.Rcode != dns.RcodeSuccess || answer2.Rcode != dns.RcodeSuccess { + t.Errorf("expected success response code, got: %v, %v", answer1.Rcode, answer2.Rcode) } } @@ -299,8 +357,9 @@ func Test_legacyResolverWithBigExtraSection(t *testing.T) { Type: ResolverTypeLegacy, Endpoint: lanAddr, } - uc.Init() - r, err := NewResolver(uc) + ctx := context.Background() + uc.Init(ctx) + 
r, err := NewResolver(ctx, uc) if err != nil { t.Fatal(err) } diff --git a/scripts/build.sh b/scripts/build.sh index 2faeddc8..fa365987 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -44,11 +44,11 @@ compress() { return 0 ;; *-linux-armv*) - echo >&2 "upx does not work on arm routers" + echo >&2 "upx does not work on arm platforms" return 0 ;; *-linux-mips*) - echo >&2 "upx does not work on mips routers" + echo >&2 "upx does not work on mips platforms" return 0 ;; esac diff --git a/staticdns.go b/staticdns.go index 1bfd5562..b1de8ec4 100644 --- a/staticdns.go +++ b/staticdns.go @@ -8,14 +8,9 @@ import ( "strings" ) -var homedir string - -// absHomeDir returns the absolute path to given filename using home directory as root dir. -func absHomeDir(filename string) string { - if homedir != "" { - return filepath.Join(homedir, filename) - } - dir, err := userHomeDir() +// AbsHomeDir returns the absolute path to given filename using home directory as root dir. +func AbsHomeDir(filename string) string { + dir, err := UserHomeDir() if err != nil { return filename } @@ -31,7 +26,8 @@ func dirWritable(dir string) (bool, error) { return true, f.Close() } -func userHomeDir() (string, error) { +// UserHomeDir returns the home directory for user who is running ctrld. +func UserHomeDir() (string, error) { // viper will expand for us. if runtime.GOOS == "windows" { // If we're on windows, use the install path for this. @@ -54,13 +50,18 @@ func userHomeDir() (string, error) { // SavedStaticDnsSettingsFilePath returns the file path where the static DNS settings // for the provided interface are saved. +// +// The caller must ensure iface is non-nil. func SavedStaticDnsSettingsFilePath(iface *net.Interface) string { // The file is stored in the user home directory under a hidden file. - return absHomeDir(".dns_" + iface.Name) + return AbsHomeDir(".dns_" + iface.Name) } -// SavedStaticNameservers returns the stored static nameservers for the given interface. 
-func SavedStaticNameservers(iface *net.Interface) ([]string, string) { +// SavedStaticNameserversAndPath returns the stored static nameservers for the given interface, +// and the absolute path to file that stored the settings. +// +// The caller must ensure iface is non-nil. +func SavedStaticNameserversAndPath(iface *net.Interface) ([]string, string) { file := SavedStaticDnsSettingsFilePath(iface) data, err := os.ReadFile(file) if err != nil || len(data) == 0 { @@ -77,3 +78,9 @@ func SavedStaticNameservers(iface *net.Interface) ([]string, string) { } return ns, file } + +// SavedStaticNameservers is like SavedStaticNameserversAndPath, but only returns the static nameservers. +func SavedStaticNameservers(iface *net.Interface) []string { + nss, _ := SavedStaticNameserversAndPath(iface) + return nss +}