diff --git a/cmd/wireproxy/main.go b/cmd/wireproxy/main.go index 713943a..f97be6b 100644 --- a/cmd/wireproxy/main.go +++ b/cmd/wireproxy/main.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "github.com/landlock-lsm/go-landlock/landlock" "log" "net" "net/http" @@ -13,6 +12,8 @@ import ( "strconv" "syscall" + "github.com/landlock-lsm/go-landlock/landlock" + "github.com/akamensky/argparse" "github.com/pufferffish/wireproxy" "golang.zx2c4.com/wireguard/device" @@ -23,9 +24,9 @@ import ( const daemonProcess = "daemon-process" // default paths for wireproxy config file -var default_config_paths = []string { - "/etc/wireproxy/wireproxy.conf", - os.Getenv("HOME")+"/.config/wireproxy.conf", +var default_config_paths = []string{ + "/etc/wireproxy/wireproxy.conf", + os.Getenv("HOME") + "/.config/wireproxy.conf", } var version = "1.0.8-dev" @@ -59,12 +60,12 @@ func executablePath() string { // check if default config file paths exist func configFilePath() (string, bool) { - for _, path := range default_config_paths { - if _, err := os.Stat(path); err == nil { - return path, true - } - } - return "", false + for _, path := range default_config_paths { + if _, err := os.Stat(path); err == nil { + return path, true + } + } + return "", false } func lock(stage string) { @@ -193,12 +194,12 @@ func main() { } if *config == "" { - if path, config_exist := configFilePath(); config_exist { - *config = path - } else { - fmt.Println("configuration path is required") - return - } + if path, config_exist := configFilePath(); config_exist { + *config = path + } else { + fmt.Println("configuration path is required") + return + } } if !*daemon { @@ -244,7 +245,7 @@ func main() { lock("ready") - tun, err := wireproxy.StartWireguard(conf.Device, logLevel) + tun, err := wireproxy.StartWireguard(conf, logLevel) if err != nil { log.Fatal(err) } diff --git a/config.go b/config.go index 1f6e4e4..9af5b22 100644 --- a/config.go +++ b/config.go @@ -59,9 +59,14 @@ type HTTPConfig struct { Password string } +type ResolveConfig struct { + ResolveStrategy string +} + type Configuration struct { Device *DeviceConfig Routines []RoutineSpawner + Resolve *ResolveConfig } func parseString(section *ini.Section, keyName string) (string, error) { @@ -175,7 +180,7 @@ func parseCIDRNetIP(section *ini.Section, keyName string) ([]netip.Addr, error) if len(str) == 0 { continue } - + if addr, err := netip.ParseAddr(str); err == nil { ips = append(ips, addr) } else { @@ -183,7 +188,7 @@ func parseCIDRNetIP(section *ini.Section, keyName string) ([]netip.Addr, error) if err != nil { return nil, err } - + addr := prefix.Addr() ips = append(ips, addr) } @@ -435,6 +440,15 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) { return config, nil } +func parseResolveConfig(section *ini.Section) (*ResolveConfig, error) { + config := &ResolveConfig{} + + resolvStrategy, _ := parseString(section, "ResolveStrategy") + config.ResolveStrategy = resolvStrategy + + return config, nil +} + // Takes a function that parses an individual section into a config, and apply it on all // specified sections func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName string, f func(*ini.Section) (RoutineSpawner, error)) error { @@ -472,6 +486,10 @@ func ParseConfig(path string) (*Configuration, error) { MTU: 1420, } + resolve := &ResolveConfig{ + ResolveStrategy: "auto", + } + root := cfg.Section("") wgConf, err := root.GetKey("WGConfig") wgCfg := cfg @@ -519,8 +537,16 @@ func ParseConfig(path string) (*Configuration, error) { return nil, err } + if resolveSection, err := cfg.GetSection("Resolve"); err == nil { + resolve, err = parseResolveConfig(resolveSection) + if err != nil { + return nil, err + } + } + return &Configuration{ Device: device, Routines: routinesSpawners, + Resolve: resolve, }, nil } diff --git a/routine.go b/routine.go index edfc793..0a0ea82 100644 --- a/routine.go +++ b/routine.go @@ -8,10 +8,6 @@ import ( "encoding/binary" "encoding/json" "errors" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/device" "io" "log" "math/rand" @@ -24,6 +20,11 @@ import ( "sync" "time" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/device" + "github.com/things-go/go-socks5" "github.com/things-go/go-socks5/bufferpool" @@ -43,10 +44,11 @@ type CredentialValidator struct { // VirtualTun stores a reference to netstack network and DNS configuration type VirtualTun struct { - Tnet *netstack.Net - Dev *device.Device - SystemDNS bool - Conf *DeviceConfig + Tnet *netstack.Net + Dev *device.Device + SystemDNS bool + Conf *DeviceConfig + ResolveConfig *ResolveConfig // PingRecord stores the last time an IP was pinged PingRecord map[string]uint64 PingRecordLock *sync.Mutex @@ -79,33 +81,48 @@ func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*n return nil, err } - size := len(addrs) - if size == 0 { - return nil, errors.New("no address found for: " + name) - } - - rand.Shuffle(size, func(i, j int) { - addrs[i], addrs[j] = addrs[j], addrs[i] - }) + addrs_v4 := []netip.Addr{} + addrs_v6 := []netip.Addr{} - var addr netip.Addr for _, saddr := range addrs { - addr, err = netip.ParseAddr(saddr) + addr, err := netip.ParseAddr(saddr) if err == nil { - break + if addr.Is4() { + addrs_v4 = append(addrs_v4, addr) + } else if addr.Is6() { + addrs_v6 = append(addrs_v6, addr) + } } } - if err != nil { - return nil, err + rand.Shuffle(len(addrs_v4), func(i, j int) { + addrs_v4[i], addrs_v4[j] = addrs_v4[j], addrs_v4[i] + }) + rand.Shuffle(len(addrs_v6), func(i, j int) { + addrs_v6[i], addrs_v6[j] = addrs_v6[j], addrs_v6[i] + }) + + addrs_all := []netip.Addr{} + + switch d.ResolveConfig.ResolveStrategy { + case "ipv4": + addrs_all = append(addrs_v4, addrs_v6...) + case "ipv6": + addrs_all = append(addrs_v6, addrs_v4...) + } + + if len(addrs_all) == 0 { + return nil, errors.New("no address found for: " + name) } - return &addr, nil + return &addrs_all[0], nil } // Resolve resolves a hostname and returns an IP. // DNS traffic may or may not be routed depending on VirtualTun's setting func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { + log.Printf("Resolving address for %s\n", name) + addr, err := d.ResolveAddrWithContext(ctx, name) if err != nil { return nil, nil, err diff --git a/wireguard.go b/wireguard.go index 71a2960..6b54887 100644 --- a/wireguard.go +++ b/wireguard.go @@ -60,8 +60,9 @@ func CreateIPCRequest(conf *DeviceConfig) (*DeviceSetting, error) { } // StartWireguard creates a tun interface on netstack given a configuration -func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) { - setting, err := CreateIPCRequest(conf) +func StartWireguard(conf *Configuration, logLevel int) (*VirtualTun, error) { + deviceConf := conf.Device + setting, err := CreateIPCRequest(deviceConf) if err != nil { return nil, err } @@ -81,10 +82,29 @@ func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) { return nil, err } + hasV4 := false + hasV6 := false + for _, addr := range setting.DeviceAddr { + if addr.Is4() { + hasV4 = true + } + if addr.Is6() { + hasV6 = true + } + } + + if conf.Resolve.ResolveStrategy == "auto" { + if hasV4 && !hasV6 { + conf.Resolve.ResolveStrategy = "ipv4" + } else { + conf.Resolve.ResolveStrategy = "ipv6" + } + } return &VirtualTun{ Tnet: tnet, Dev: dev, - Conf: conf, + Conf: deviceConf, + ResolveConfig: conf.Resolve, SystemDNS: len(setting.DNS) == 0, PingRecord: make(map[string]uint64), PingRecordLock: new(sync.Mutex),