From b919d3b9976a53fd8d6c8162cb81ad7aef0181bf Mon Sep 17 00:00:00 2001 From: j-xhan Date: Wed, 23 May 2018 16:48:24 -0700 Subject: [PATCH] add EndlessServer struct wrapping smtpd server instead of http server --- endless.go | 351 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 341 insertions(+), 10 deletions(-) diff --git a/endless.go b/endless.go index 9604a20..326a038 100644 --- a/endless.go +++ b/endless.go @@ -17,6 +17,7 @@ import ( "time" // "github.com/fvbock/uds-go/introspect" + "github.com/rguliyev/smtpd" ) const ( @@ -33,6 +34,8 @@ var ( runningServerReg sync.RWMutex runningServers map[string]*endlessServer runningServersOrder []string + SMTPDrunningServers map[string]*endlessSMTPDServer + SMTPDrunningServersOrder []string socketPtrOffsetMap map[string]uint runningServersForked bool @@ -51,6 +54,8 @@ func init() { runningServerReg = sync.RWMutex{} runningServers = make(map[string]*endlessServer) runningServersOrder = []string{} + SMTPDrunningServers = make(map[string]*endlessSMTPDServer) + SMTPDrunningServersOrder = []string{} socketPtrOffsetMap = make(map[string]uint) DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB) @@ -82,6 +87,19 @@ type endlessServer struct { BeforeBegin func(add string) } +type endlessSMTPDServer struct { + smtpd.Server + EndlessListener net.Listener + SignalHooks map[int]map[os.Signal][]func() + tlsInnerListener *SMTPDendlessListener + wg sync.WaitGroup + sigChan chan os.Signal + isChild bool + state uint8 + lock *sync.RWMutex + BeforeBegin func(add string) +} + /* NewServer returns an intialized endlessServer Object. Calling Serve on it will actually "start" the server. @@ -98,7 +116,7 @@ func NewServer(addr string, handler http.Handler) (srv *endlessServer) { socketPtrOffsetMap[addr] = uint(i) } } else { - socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) + socketPtrOffsetMap[addr] = uint(len(runningServersOrder) + len(SMTPDrunningServersOrder)) } srv = &endlessServer{ @@ -309,6 +327,244 @@ func (srv *endlessServer) getListener(laddr string) (l net.Listener, err error) return } +/* +SMTPDNewServer returns an intialized endlessSMTPDServer Object. Calling Serve on it will +actually "start" the server. +*/ +func SMTPDNewServer(addr string, handler smtpd.Handler, appname string, hostname string) (srv *endlessSMTPDServer) { + runningServerReg.Lock() + defer runningServerReg.Unlock() + + socketOrder = os.Getenv("ENDLESS_SOCKET_ORDER") + isChild = os.Getenv("ENDLESS_CONTINUE") != "" + + if len(socketOrder) > 0 { + for i, addr := range strings.Split(socketOrder, ",") { + socketPtrOffsetMap[addr] = uint(i) + } + } else { + socketPtrOffsetMap[addr] = uint(len(runningServersOrder) + len(SMTPDrunningServersOrder)) + } + + srv = &endlessSMTPDServer{ + wg: sync.WaitGroup{}, + sigChan: make(chan os.Signal), + isChild: isChild, + SignalHooks: map[int]map[os.Signal][]func(){ + PRE_SIGNAL: map[os.Signal][]func(){ + syscall.SIGHUP: []func(){}, + syscall.SIGUSR1: []func(){}, + syscall.SIGUSR2: []func(){}, + syscall.SIGINT: []func(){}, + syscall.SIGTERM: []func(){}, + syscall.SIGTSTP: []func(){}, + }, + POST_SIGNAL: map[os.Signal][]func(){ + syscall.SIGHUP: []func(){}, + syscall.SIGUSR1: []func(){}, + syscall.SIGUSR2: []func(){}, + syscall.SIGINT: []func(){}, + syscall.SIGTERM: []func(){}, + syscall.SIGTSTP: []func(){}, + }, + }, + state: STATE_INIT, + lock: &sync.RWMutex{}, + } + + srv.Server.Addr = addr + srv.Server.Handler = handler + srv.Server.Appname = appname + srv.Server.Hostname = hostname + + srv.BeforeBegin = func(addr string) { + log.Println(syscall.Getpid(), addr) + } + + if SMTPDrunningServers == nil { + SMTPDrunningServers = make(map[string]*endlessSMTPDServer) + SMTPDrunningServersOrder = []string{} + } + SMTPDrunningServersOrder = append(SMTPDrunningServersOrder, addr) + SMTPDrunningServers[addr] = srv + + return +} + +/* +ListenAndServe listens on the TCP network address addr and then calls Serve +with handler to handle requests on incoming connections. Handler is typically +nil, in which case the DefaultServeMux is used. +*/ +func SMTPDListenAndServe(addr string, handler smtpd.Handler, appname string, hostname string) error { + server := SMTPDNewServer(addr, handler, appname, hostname) + return server.SMTPDListenAndServe() +} + +/* +ListenAndServeTLS acts identically to ListenAndServe, except that it expects +HTTPS connections. Additionally, files containing a certificate and matching +private key for the server must be provided. If the certificate is signed by a +certificate authority, the certFile should be the concatenation of the server's +certificate followed by the CA's certificate. +*/ +func SMTPDListenAndServeTLS(addr string, certFile string, keyFile string, handler smtpd.Handler, appname string, hostname string) error { + server := SMTPDNewServer(addr, handler, appname, hostname) + return server.SMTPDListenAndServeTLS(addr, certFile, keyFile, handler, appname, hostname) +} + +func (srv *endlessSMTPDServer) SMTPDgetState() uint8 { + srv.lock.RLock() + defer srv.lock.RUnlock() + + return srv.state +} + +func (srv *endlessSMTPDServer) SMTPDsetState(st uint8) { + srv.lock.Lock() + defer srv.lock.Unlock() + + srv.state = st +} + +/* +Serve accepts incoming HTTP connections on the listener l, creating a new +service goroutine for each. The service goroutines read requests and then call +handler to reply to them. Handler is typically nil, in which case the +DefaultServeMux is used. + +In addition to the stl Serve behaviour each connection is added to a +sync.Waitgroup so that all outstanding connections can be served before shutting +down the server. +*/ +func (srv *endlessSMTPDServer) SMTPDServe() (err error) { + defer log.Println(syscall.Getpid(), "Serve() returning...") + srv.SMTPDsetState(STATE_RUNNING) + err = srv.Server.Serve(srv.EndlessListener) + log.Println(syscall.Getpid(), "Waiting for connections to finish...") + srv.wg.Wait() + srv.SMTPDsetState(STATE_TERMINATE) + return +} + +/* +ListenAndServe listens on the TCP network address srv.Addr and then calls Serve +to handle requests on incoming connections. If srv.Addr is blank, ":http" is +used. +*/ +func (srv *endlessSMTPDServer) SMTPDListenAndServe() (err error) { + addr := srv.Addr + if srv.Addr == "" { + srv.Addr = ":25" + } + if srv.Appname == "" { + srv.Appname = "smtpd" + } + if srv.Hostname == "" { + srv.Hostname, _ = os.Hostname() + } + if srv.Timeout == 0 { + srv.Timeout = 5 * time.Minute + } + + // go srv.handleSignals() + + l, err := srv.SMTPDgetListener(addr) + if err != nil { + log.Println(err) + return + } + + srv.EndlessListener = SMTPDnewEndlessListener(l, srv) + + if srv.isChild { + syscall.Kill(syscall.Getppid(), syscall.SIGTERM) + } + + srv.BeforeBegin(srv.Addr) + + return srv.Serve(l) + +} + +/* +ListenAndServeTLS listens on the TCP network address srv.Addr and then calls +Serve to handle requests on incoming TLS connections. + +Filenames containing a certificate and matching private key for the server must +be provided. If the certificate is signed by a certificate authority, the +certFile should be the concatenation of the server's certificate followed by the +CA's certificate. + +If srv.Addr is blank, ":https" is used. +*/ +func (srv *endlessSMTPDServer) SMTPDListenAndServeTLS(addr, certFile, keyFile string, handler smtpd.Handler, appname string, hostname string) (err error) { + if srv.Addr == "" { + srv.Addr = ":25" + } + if srv.Appname == "" { + srv.Appname = "smtpd" + } + if srv.Hostname == "" { + srv.Hostname, _ = os.Hostname() + } + if srv.Timeout == 0 { + srv.Timeout = 5 * time.Minute + } + + err = srv.ConfigureTLS(certFile, keyFile) + if err != nil { + return err + } + config := srv.TLSConfig + + l, err := srv.SMTPDgetListener(addr) + if err != nil { + log.Println(err) + return + } + + srv.tlsInnerListener = SMTPDnewEndlessListener(l, srv) + srv.EndlessListener = tls.NewListener(srv.tlsInnerListener, config) + + if srv.isChild { + syscall.Kill(syscall.Getppid(), syscall.SIGTERM) + } + + log.Println(syscall.Getpid(), srv.Addr) + return srv.Serve(l) +} + +/* +getListener either opens a new socket to listen on, or takes the acceptor socket +it got passed when restarted. +*/ +func (srv *endlessSMTPDServer) SMTPDgetListener(laddr string) (l net.Listener, err error) { + if srv.isChild { + var ptrOffset uint = 0 + runningServerReg.RLock() + defer runningServerReg.RUnlock() + if len(socketPtrOffsetMap) > 0 { + ptrOffset = socketPtrOffsetMap[laddr] + // log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) + } + + f := os.NewFile(uintptr(3+ptrOffset), "") + l, err = net.FileListener(f) + if err != nil { + err = fmt.Errorf("net.FileListener error: %v", err) + return + } + } else { + l, err = net.Listen("tcp", laddr) + if err != nil { + err = fmt.Errorf("net.Listen error: %v", err) + return + } + } + return +} + /* handleSignals listens for os Signals and calls any hooked in function that the user had registered with the signal. @@ -377,7 +633,7 @@ func (srv *endlessServer) shutdown() { go srv.hammerTime(DefaultHammerTime) } // disable keep-alives on existing connections - srv.SetKeepAlivesEnabled(false) + // srv.SetKeepAlivesEnabled(false) err := srv.EndlessListener.Close() if err != nil { log.Println(syscall.Getpid(), "Listener.Close() error:", err) @@ -429,18 +685,31 @@ func (srv *endlessServer) fork() (err error) { runningServersForked = true - var files = make([]*os.File, len(runningServers)) - var orderArgs = make([]string, len(runningServers)) + var files = make([]*os.File, len(runningServers) + len(SMTPDrunningServers)) + var orderArgs = make([]string, len(runningServers) + len(SMTPDrunningServers)) // get the accessor socket fds for _all_ server instances for _, srvPtr := range runningServers { // introspect.PrintTypeDump(srvPtr.EndlessListener) switch srvPtr.EndlessListener.(type) { case *endlessListener: // normal listener - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.EndlessListener.(*endlessListener).File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]], err = srvPtr.EndlessListener.(*endlessListener).File() + default: + // tls listener + files[socketPtrOffsetMap[srvPtr.Server.Addr]], err = srvPtr.tlsInnerListener.File() + } + orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr + } + + for _, srvPtr := range SMTPDrunningServers { + // introspect.PrintTypeDump(srvPtr.EndlessListener) + switch srvPtr.EndlessListener.(type) { + case *SMTPDendlessListener: + // normal listener + files[socketPtrOffsetMap[srvPtr.Server.Addr]], err = srvPtr.EndlessListener.(*SMTPDendlessListener).File() default: // tls listener - files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File() + files[socketPtrOffsetMap[srvPtr.Server.Addr]], err = srvPtr.tlsInnerListener.File() } orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr } @@ -449,7 +718,7 @@ func (srv *endlessServer) fork() (err error) { os.Environ(), "ENDLESS_CONTINUE=1", ) - if len(runningServers) > 1 { + if len(runningServers) + len(SMTPDrunningServers) > 1 { env = append(env, fmt.Sprintf(`ENDLESS_SOCKET_ORDER=%s`, strings.Join(orderArgs, ","))) } @@ -522,11 +791,11 @@ func (el *endlessListener) Close() error { return el.Listener.Close() } -func (el *endlessListener) File() *os.File { +func (el *endlessListener) File() (*os.File, error) { // returns a dup(2) - FD_CLOEXEC flag *not* set tl := el.Listener.(*net.TCPListener) - fl, _ := tl.File() - return fl + fl, err := tl.File() + return fl, err } type endlessConn struct { @@ -542,6 +811,68 @@ func (w endlessConn) Close() error { return err } +type SMTPDendlessListener struct { + net.Listener + stopped bool + server *endlessSMTPDServer +} + +func (el *SMTPDendlessListener) Accept() (c net.Conn, err error) { + tc, err := el.Listener.(*net.TCPListener).AcceptTCP() + if err != nil { + return + } + + tc.SetKeepAlive(true) // see http.tcpKeepAliveListener + tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener + + c = SMTPDendlessConn{ + Conn: tc, + server: el.server, + } + + el.server.wg.Add(1) + return +} + +func SMTPDnewEndlessListener(l net.Listener, srv *endlessSMTPDServer) (el *SMTPDendlessListener) { + el = &SMTPDendlessListener{ + Listener: l, + server: srv, + } + + return +} + +func (el *SMTPDendlessListener) Close() error { + if el.stopped { + return syscall.EINVAL + } + + el.stopped = true + return el.Listener.Close() +} + +func (el *SMTPDendlessListener) File() (*os.File, error) { + // returns a dup(2) - FD_CLOEXEC flag *not* set + tl := el.Listener.(*net.TCPListener) + fl, err := tl.File() + return fl, err +} + +type SMTPDendlessConn struct { + net.Conn + server *endlessSMTPDServer +} + +func (w SMTPDendlessConn) Close() error { + err := w.Conn.Close() + if err == nil { + w.server.wg.Done() + } + return err +} + /* RegisterSignalHook registers a function to be run PRE_SIGNAL or POST_SIGNAL for a given signal. PRE or POST in this case means before or after the signal