Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions cmd/wireproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"context"
"fmt"
"github.com/landlock-lsm/go-landlock/landlock"
"log"
"net"
"net/http"
Expand All @@ -13,6 +12,8 @@ import (
"strconv"
"syscall"

"github.com/landlock-lsm/go-landlock/landlock"

"github.com/akamensky/argparse"
"github.com/pufferffish/wireproxy"
"golang.zx2c4.com/wireguard/device"
Expand All @@ -23,9 +24,9 @@ import (
const daemonProcess = "daemon-process"

// default paths for wireproxy config file
var default_config_paths = []string {
"/etc/wireproxy/wireproxy.conf",
os.Getenv("HOME")+"/.config/wireproxy.conf",
var default_config_paths = []string{
"/etc/wireproxy/wireproxy.conf",
os.Getenv("HOME") + "/.config/wireproxy.conf",
}

var version = "1.0.8-dev"
Expand Down Expand Up @@ -59,12 +60,12 @@ func executablePath() string {

// check if default config file paths exist
func configFilePath() (string, bool) {
for _, path := range default_config_paths {
if _, err := os.Stat(path); err == nil {
return path, true
}
}
return "", false
for _, path := range default_config_paths {
if _, err := os.Stat(path); err == nil {
return path, true
}
}
return "", false
}

func lock(stage string) {
Expand Down Expand Up @@ -193,12 +194,12 @@ func main() {
}

if *config == "" {
if path, config_exist := configFilePath(); config_exist {
*config = path
} else {
fmt.Println("configuration path is required")
return
}
if path, config_exist := configFilePath(); config_exist {
*config = path
} else {
fmt.Println("configuration path is required")
return
}
}

if !*daemon {
Expand Down Expand Up @@ -244,7 +245,7 @@ func main() {

lock("ready")

tun, err := wireproxy.StartWireguard(conf.Device, logLevel)
tun, err := wireproxy.StartWireguard(conf, logLevel)
if err != nil {
log.Fatal(err)
}
Expand Down
30 changes: 28 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,14 @@ type HTTPConfig struct {
Password string
}

type ResolveConfig struct {
ResolveStrategy string
}

type Configuration struct {
Device *DeviceConfig
Routines []RoutineSpawner
Resolve *ResolveConfig
}

func parseString(section *ini.Section, keyName string) (string, error) {
Expand Down Expand Up @@ -175,15 +180,15 @@ func parseCIDRNetIP(section *ini.Section, keyName string) ([]netip.Addr, error)
if len(str) == 0 {
continue
}

if addr, err := netip.ParseAddr(str); err == nil {
ips = append(ips, addr)
} else {
prefix, err := netip.ParsePrefix(str)
if err != nil {
return nil, err
}

addr := prefix.Addr()
ips = append(ips, addr)
}
Expand Down Expand Up @@ -435,6 +440,15 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) {
return config, nil
}

func parseResolveConfig(section *ini.Section) (*ResolveConfig, error) {
config := &ResolveConfig{}

resolvStrategy, _ := parseString(section, "ResolveStrategy")
config.ResolveStrategy = resolvStrategy

return config, nil
}

// Takes a function that parses an individual section into a config, and apply it on all
// specified sections
func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName string, f func(*ini.Section) (RoutineSpawner, error)) error {
Expand Down Expand Up @@ -472,6 +486,10 @@ func ParseConfig(path string) (*Configuration, error) {
MTU: 1420,
}

resolve := &ResolveConfig{
ResolveStrategy: "auto",
}

root := cfg.Section("")
wgConf, err := root.GetKey("WGConfig")
wgCfg := cfg
Expand Down Expand Up @@ -519,8 +537,16 @@ func ParseConfig(path string) (*Configuration, error) {
return nil, err
}

if resolveSection, err := cfg.GetSection("Resolve"); err == nil {
resolve, err = parseResolveConfig(resolveSection)
if err != nil {
return nil, err
}
}

return &Configuration{
Device: device,
Routines: routinesSpawners,
Resolve: resolve,
}, nil
}
61 changes: 39 additions & 22 deletions routine.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/device"
"io"
"log"
"math/rand"
Expand All @@ -24,6 +20,11 @@ import (
"sync"
"time"

"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/device"

"github.com/things-go/go-socks5"
"github.com/things-go/go-socks5/bufferpool"

Expand All @@ -43,10 +44,11 @@ type CredentialValidator struct {

// VirtualTun stores a reference to netstack network and DNS configuration
type VirtualTun struct {
Tnet *netstack.Net
Dev *device.Device
SystemDNS bool
Conf *DeviceConfig
Tnet *netstack.Net
Dev *device.Device
SystemDNS bool
Conf *DeviceConfig
ResolveConfig *ResolveConfig
// PingRecord stores the last time an IP was pinged
PingRecord map[string]uint64
PingRecordLock *sync.Mutex
Expand Down Expand Up @@ -79,33 +81,48 @@ func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*n
return nil, err
}

size := len(addrs)
if size == 0 {
return nil, errors.New("no address found for: " + name)
}

rand.Shuffle(size, func(i, j int) {
addrs[i], addrs[j] = addrs[j], addrs[i]
})
addrs_v4 := []netip.Addr{}
addrs_v6 := []netip.Addr{}

var addr netip.Addr
for _, saddr := range addrs {
addr, err = netip.ParseAddr(saddr)
addr, err := netip.ParseAddr(saddr)
if err == nil {
break
if addr.Is4() {
addrs_v4 = append(addrs_v4, addr)
} else if addr.Is6() {
addrs_v6 = append(addrs_v6, addr)
}
}
}

if err != nil {
return nil, err
rand.Shuffle(len(addrs_v4), func(i, j int) {
addrs_v4[i], addrs_v4[j] = addrs_v4[j], addrs_v4[i]
})
rand.Shuffle(len(addrs_v6), func(i, j int) {
addrs_v6[i], addrs_v6[j] = addrs_v6[j], addrs_v6[i]
})

addrs_all := []netip.Addr{}

switch d.ResolveConfig.ResolveStrategy {
case "ipv4":
addrs_all = append(addrs_v4, addrs_v6...)
case "ipv6":
addrs_all = append(addrs_v6, addrs_v4...)
}

if len(addrs_all) == 0 {
return nil, errors.New("no address found for: " + name)
}

return &addr, nil
return &addrs_all[0], nil
}

// Resolve resolves a hostname and returns an IP.
// DNS traffic may or may not be routed depending on VirtualTun's setting
func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
log.Printf("Resolving address for %s\n", name)

addr, err := d.ResolveAddrWithContext(ctx, name)
if err != nil {
return nil, nil, err
Expand Down
26 changes: 23 additions & 3 deletions wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ func CreateIPCRequest(conf *DeviceConfig) (*DeviceSetting, error) {
}

// StartWireguard creates a tun interface on netstack given a configuration
func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) {
setting, err := CreateIPCRequest(conf)
func StartWireguard(conf *Configuration, logLevel int) (*VirtualTun, error) {
deviceConf := conf.Device
setting, err := CreateIPCRequest(deviceConf)
if err != nil {
return nil, err
}
Expand All @@ -81,10 +82,29 @@ func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) {
return nil, err
}

hasV4 := false
hasV6 := false
for _, addr := range setting.DeviceAddr {
if addr.Is4() {
hasV4 = true
}
if addr.Is6() {
hasV6 = true
}
}

if conf.Resolve.ResolveStrategy == "auto" {
if hasV4 && !hasV6 {
conf.Resolve.ResolveStrategy = "ipv4"
} else {
conf.Resolve.ResolveStrategy = "ipv6"
}
}
return &VirtualTun{
Tnet: tnet,
Dev: dev,
Conf: conf,
Conf: deviceConf,
ResolveConfig: conf.Resolve,
SystemDNS: len(setting.DNS) == 0,
PingRecord: make(map[string]uint64),
PingRecordLock: new(sync.Mutex),
Expand Down