diff --git a/message.go b/message.go index 86bcd75..3f5ce57 100644 --- a/message.go +++ b/message.go @@ -2,6 +2,7 @@ package gomail import ( "bytes" + "context" "io" "os" "path/filepath" @@ -19,6 +20,7 @@ type Message struct { hEncoder mimeEncoder buf bytes.Buffer boundary string + ctx context.Context } type header map[string][]string @@ -36,6 +38,7 @@ func NewMessage(settings ...MessageSetting) *Message { header: make(header), charset: "UTF-8", encoding: QuotedPrintable, + ctx: context.Background(), } m.applySettings(settings) @@ -60,6 +63,19 @@ func (m *Message) Reset() { m.embedded = nil } +// Context returns the message's internal context. The context is either set +// using SetContext or it's defaulted to Background. +func (m *Message) Context() context.Context { + return m.ctx +} + +// WithContext copies the message and makes it use a different context. +func (m *Message) WithContext(ctx context.Context) *Message { + m2 := *m + m2.ctx = ctx + return &m2 +} + func (m *Message) applySettings(settings []MessageSetting) { for _, s := range settings { s(m) @@ -84,6 +100,15 @@ func SetEncoding(enc Encoding) MessageSetting { } } +// SetContext is a message setting to set the context of the email. The context +// determines cancellation and timeout for sending the message over the SMTP +// connection. +func SetContext(ctx context.Context) MessageSetting { + return func(m *Message) { + m.ctx = ctx + } +} + // Encoding represents a MIME encoding scheme like quoted-printable or base64. type Encoding string diff --git a/smtp.go b/smtp.go index c56a65b..b19ede1 100644 --- a/smtp.go +++ b/smtp.go @@ -1,6 +1,7 @@ package gomail import ( + "context" "crypto/tls" "fmt" "io" @@ -26,6 +27,9 @@ type Dialer struct { Auth smtp.Auth // Port represents the port of the SMTP server. Port int + // NetDialer is the net.Dialer instance to use. For legacy purposes, if + // NetDialTimeout is not net.DialTimeout, then this field is not used. + NetDialer net.Dialer // TLSConfig represents the TLS configuration used for the TLS (when the // STARTTLS extension is used) or SSL connection. TLSConfig *tls.Config @@ -74,12 +78,30 @@ func NewPlainDialer(host string, port int, username, password string) *Dialer { // NetDialTimeout specifies the DialTimeout function to establish a connection // to the SMTP server. This can be used to override dialing in the case that a // proxy or other special behavior is needed. -var NetDialTimeout = net.DialTimeout +// +// Deprecated: use (*Dialer).NetDialer instead. If NetDialTimeout is nil, then +// (*Dialer).NetDialer is used. +var NetDialTimeout func(network, address string, timeout time.Duration) (net.Conn, error) = nil // Dial dials and authenticates to an SMTP server. The returned SendCloser // should be closed when done using it. func (d *Dialer) Dial() (SendCloser, error) { - conn, err := NetDialTimeout("tcp", addr(d.Host, d.Port), d.Timeout) + return d.DialCtx(context.Background()) +} + +// DialCtx is Dial with context support. +func (d *Dialer) DialCtx(ctx context.Context) (SendCloser, error) { + var conn net.Conn + var err error + + if NetDialTimeout == nil { + ctx, cancel := context.WithTimeout(ctx, d.Timeout) + defer cancel() + + conn, err = d.NetDialer.DialContext(ctx, "tcp", addr(d.Host, d.Port)) + } else { + conn, err = NetDialTimeout("tcp", addr(d.Host, d.Port), d.Timeout) + } if err != nil { return nil, err } @@ -93,6 +115,19 @@ func (d *Dialer) Dial() (SendCloser, error) { return nil, err } + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + // Parent context expired. Immediately terminate and return. + c.Close() + case <-done: + // ok + } + }() + if d.Timeout > 0 { conn.SetDeadline(time.Now().Add(d.Timeout)) } @@ -201,12 +236,21 @@ func addr(host string, port int) string { // DialAndSend opens a connection to the SMTP server, sends the given emails and // closes the connection. func (d *Dialer) DialAndSend(m ...*Message) error { - s, err := d.Dial() + return d.DialAndSendCtx(context.Background(), m...) +} + +// DialAndSendCtx is DialAndSend with context support. +func (d *Dialer) DialAndSendCtx(ctx context.Context, m ...*Message) error { + s, err := d.DialCtx(ctx) if err != nil { return err } defer s.Close() + for i := range m { + m[i] = m[i].WithContext(ctx) + } + return Send(s, m...) } @@ -228,7 +272,29 @@ func (c *smtpSender) retryError(err error) bool { return err == io.EOF } +type messageContexter interface { + Context() context.Context +} + +var _ messageContexter = (*Message)(nil) + func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error { + if ctxer, ok := msg.(messageContexter); ok { + if ctx := ctxer.Context(); ctx != context.Background() { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + c.conn.Close() + case <-done: + // ok + } + }() + } + } + if c.d.Timeout > 0 { c.conn.SetDeadline(time.Now().Add(c.d.Timeout)) }