Skip to content
This repository was archived by the owner on Jul 26, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gomail

import (
"bytes"
"context"
"io"
"os"
"path/filepath"
Expand All @@ -19,6 +20,7 @@ type Message struct {
hEncoder mimeEncoder
buf bytes.Buffer
boundary string
ctx context.Context
}

type header map[string][]string
Expand All @@ -36,6 +38,7 @@ func NewMessage(settings ...MessageSetting) *Message {
header: make(header),
charset: "UTF-8",
encoding: QuotedPrintable,
ctx: context.Background(),
}

m.applySettings(settings)
Expand All @@ -60,6 +63,19 @@ func (m *Message) Reset() {
m.embedded = nil
}

// Context returns the message's internal context. The context is either set
// using SetContext or it's defaulted to Background.
func (m *Message) Context() context.Context {
return m.ctx
}

// WithContext copies the message and makes it use a different context.
func (m *Message) WithContext(ctx context.Context) *Message {
m2 := *m
m2.ctx = ctx
return &m2
}

func (m *Message) applySettings(settings []MessageSetting) {
for _, s := range settings {
s(m)
Expand All @@ -84,6 +100,15 @@ func SetEncoding(enc Encoding) MessageSetting {
}
}

// SetContext is a message setting to set the context of the email. The context
// determines cancellation and timeout for sending the message over the SMTP
// connection.
func SetContext(ctx context.Context) MessageSetting {
return func(m *Message) {
m.ctx = ctx
}
}

// Encoding represents a MIME encoding scheme like quoted-printable or base64.
type Encoding string

Expand Down
72 changes: 69 additions & 3 deletions smtp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gomail

import (
"context"
"crypto/tls"
"fmt"
"io"
Expand All @@ -26,6 +27,9 @@ type Dialer struct {
Auth smtp.Auth
// Port represents the port of the SMTP server.
Port int
// NetDialer is the net.Dialer instance to use. For legacy purposes, if
// NetDialTimeout is not net.DialTimeout, then this field is not used.
NetDialer net.Dialer
// TLSConfig represents the TLS configuration used for the TLS (when the
// STARTTLS extension is used) or SSL connection.
TLSConfig *tls.Config
Expand Down Expand Up @@ -74,12 +78,30 @@ func NewPlainDialer(host string, port int, username, password string) *Dialer {
// NetDialTimeout specifies the DialTimeout function to establish a connection
// to the SMTP server. This can be used to override dialing in the case that a
// proxy or other special behavior is needed.
var NetDialTimeout = net.DialTimeout
//
// Deprecated: use (*Dialer).NetDialer instead. If NetDialTimeout is nil, then
// (*Dialer).NetDialer is used.
var NetDialTimeout func(network, address string, timeout time.Duration) (net.Conn, error) = nil

// Dial dials and authenticates to an SMTP server. The returned SendCloser
// should be closed when done using it.
func (d *Dialer) Dial() (SendCloser, error) {
conn, err := NetDialTimeout("tcp", addr(d.Host, d.Port), d.Timeout)
return d.DialCtx(context.Background())
}

// DialCtx is Dial with context support.
func (d *Dialer) DialCtx(ctx context.Context) (SendCloser, error) {
var conn net.Conn
var err error

if NetDialTimeout == nil {
ctx, cancel := context.WithTimeout(ctx, d.Timeout)
defer cancel()

conn, err = d.NetDialer.DialContext(ctx, "tcp", addr(d.Host, d.Port))
} else {
conn, err = NetDialTimeout("tcp", addr(d.Host, d.Port), d.Timeout)
}
if err != nil {
return nil, err
}
Expand All @@ -93,6 +115,19 @@ func (d *Dialer) Dial() (SendCloser, error) {
return nil, err
}

done := make(chan struct{})
defer close(done)

go func() {
select {
case <-ctx.Done():
// Parent context expired. Immediately terminate and return.
c.Close()
case <-done:
// ok
}
}()

if d.Timeout > 0 {
conn.SetDeadline(time.Now().Add(d.Timeout))
}
Expand Down Expand Up @@ -201,12 +236,21 @@ func addr(host string, port int) string {
// DialAndSend opens a connection to the SMTP server, sends the given emails and
// closes the connection.
func (d *Dialer) DialAndSend(m ...*Message) error {
s, err := d.Dial()
return d.DialAndSendCtx(context.Background(), m...)
}

// DialAndSendCtx is DialAndSend with context support.
func (d *Dialer) DialAndSendCtx(ctx context.Context, m ...*Message) error {
s, err := d.DialCtx(ctx)
if err != nil {
return err
}
defer s.Close()

for i := range m {
m[i] = m[i].WithContext(ctx)
}

return Send(s, m...)
}

Expand All @@ -228,7 +272,29 @@ func (c *smtpSender) retryError(err error) bool {
return err == io.EOF
}

type messageContexter interface {
Context() context.Context
}

var _ messageContexter = (*Message)(nil)

func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error {
if ctxer, ok := msg.(messageContexter); ok {
if ctx := ctxer.Context(); ctx != context.Background() {
done := make(chan struct{})
defer close(done)

go func() {
select {
case <-ctx.Done():
c.conn.Close()
case <-done:
// ok
}
}()
}
}

if c.d.Timeout > 0 {
c.conn.SetDeadline(time.Now().Add(c.d.Timeout))
}
Expand Down