diff --git a/README.md b/README.md index 3c1e3a0..fc29506 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Stack Up is a simple deployment tool that performs given set of commands on mult | Option | Description | |-------------------|----------------------------------| | `-f Supfile` | Custom path to Supfile | +| `-i`, `sshKey` | Set the the ssh key to use | | `-e`, `--env=[]` | Set environment variables | | `--only REGEXP` | Filter hosts matching regexp | | `--except REGEXP` | Filter out hosts matching regexp | @@ -43,6 +44,8 @@ networks: - api1.example.com - api2.example.com - api3.example.com + # Optional, override the ssh key to use for this network + ssh-key: ~/.ssh/prodKey staging: # fetch dynamic list of hosts inventory: curl http://example.com/latest/meta-data/hostname diff --git a/cmd/sup/main.go b/cmd/sup/main.go index e62b1e3..5e91bd5 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -16,6 +16,7 @@ import ( var ( supfile string envVars flagStringSlice + sshKey string onlyHosts string exceptHosts string @@ -46,6 +47,7 @@ func (f *flagStringSlice) Set(value string) error { func init() { flag.StringVar(&supfile, "f", "./Supfile", "Custom path to Supfile") flag.Var(&envVars, "e", "Set environment variables") + flag.StringVar(&sshKey, "i", "", "Set the ssh key to use") flag.Var(&envVars, "env", "Set environment variables") flag.StringVar(&onlyHosts, "only", "", "Filter hosts using regexp") flag.StringVar(&exceptHosts, "except", "", "Filter out hosts using regexp") @@ -280,6 +282,7 @@ func main() { } app.Debug(debug) app.Prefix(!disablePrefix) + app.SSHKey(sshKey) // Run all the commands in the given network. err = app.Run(network, commands...) diff --git a/ssh.go b/ssh.go index dba3f26..12c0b2b 100644 --- a/ssh.go +++ b/ssh.go @@ -16,18 +16,21 @@ import ( // Client is a wrapper over the SSH connection/sessions. type SSHClient struct { - conn *ssh.Client - sess *ssh.Session - user string - host string - remoteStdin io.WriteCloser - remoteStdout io.Reader - remoteStderr io.Reader - connOpened bool - sessOpened bool - running bool - env string //export FOO="bar"; export BAR="baz"; - color string + conn *ssh.Client + sess *ssh.Session + user string + host string + sshKeys []string // ssh key to use + remoteStdin io.WriteCloser + remoteStdout io.Reader + remoteStderr io.Reader + connOpened bool + sessOpened bool + running bool + env string //export FOO="bar"; export BAR="baz"; + color string + initAuthMethodOnce sync.Once + authMethod ssh.AuthMethod } type ErrConnect struct { @@ -40,6 +43,15 @@ func (e ErrConnect) Error() string { return fmt.Sprintf(`Connect("%v@%v"): %v`, e.User, e.Host, e.Reason) } +func newSSHClient() *SSHClient { + return &SSHClient{ + sshKeys: []string{ + os.Getenv("HOME") + "/.ssh/id_rsa", + os.Getenv("HOME") + "/.ssh/id_dsa", + }, + } +} + // parseHost parses and normalizes @ from a given string. func (c *SSHClient) parseHost(host string) error { c.host = host @@ -75,11 +87,8 @@ func (c *SSHClient) parseHost(host string) error { return nil } -var initAuthMethodOnce sync.Once -var authMethod ssh.AuthMethod - // initAuthMethod initiates SSH authentication method. -func initAuthMethod() { +func (c *SSHClient) initAuthMethod() { var signers []ssh.Signer // If there's a running SSH Agent, try to use its Private keys. @@ -90,11 +99,7 @@ func initAuthMethod() { } // Try to read user's SSH private keys form the standard paths. - files := []string{ - os.Getenv("HOME") + "/.ssh/id_rsa", - os.Getenv("HOME") + "/.ssh/id_dsa", - } - for _, file := range files { + for _, file := range c.sshKeys { data, err := ioutil.ReadFile(file) if err != nil { continue @@ -106,7 +111,7 @@ func initAuthMethod() { signers = append(signers, signer) } - authMethod = ssh.PublicKeys(signers...) + c.authMethod = ssh.PublicKeys(signers...) } // SSHDialFunc can dial an ssh server and return a client @@ -126,7 +131,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { return fmt.Errorf("Already connected") } - initAuthMethodOnce.Do(initAuthMethod) + c.initAuthMethodOnce.Do(c.initAuthMethod) err := c.parseHost(host) if err != nil { @@ -136,7 +141,7 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { config := &ssh.ClientConfig{ User: c.user, Auth: []ssh.AuthMethod{ - authMethod, + c.authMethod, }, } diff --git a/sup.go b/sup.go index 4c42745..06f1ae9 100644 --- a/sup.go +++ b/sup.go @@ -19,6 +19,7 @@ type Stackup struct { conf *Supfile debug bool prefix bool + sshKey string } func New(conf *Supfile) (*Stackup, error) { @@ -27,6 +28,26 @@ func New(conf *Supfile) (*Stackup, error) { }, nil } +// Returns the first defined key parameter, otherwise empty string +func firstDefinedKey(keys ...string) string { + for _, key := range keys { + if key != "" { + return expandHome(key) + } + } + return "" +} + +// Expands ~/foo -> /home/user/foo +func expandHome(path string) string { + if !strings.HasPrefix(path, "~/") { + return path + } + parts := strings.Split(path, "/") + parts[0] = os.Getenv("HOME") + return strings.Join(parts, "/") +} + // Run runs set of commands on multiple hosts defined by network sequentially. // TODO: This megamoth method needs a big refactor and should be split // to multiple smaller methods. @@ -45,7 +66,7 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error { // Create clients for every host (either SSH or Localhost). var bastion *SSHClient if network.Bastion != "" { - bastion = &SSHClient{} + bastion = newSSHClient() if err := bastion.Connect(network.Bastion); err != nil { return errors.Wrap(err, "connecting to bastion failed") } @@ -74,9 +95,12 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error { } // SSH client. - remote := &SSHClient{ - env: env + `export SUP_HOST="` + host + `";`, - color: Colors[i%len(Colors)], + remote := newSSHClient() + remote.env = env + `export SUP_HOST="` + host + `";` + remote.color = Colors[i%len(Colors)] + sshKey := firstDefinedKey(sup.sshKey, network.SSHKey) + if sshKey != "" { + remote.sshKeys = []string{sshKey} } if bastion != nil { @@ -251,3 +275,7 @@ func (sup *Stackup) Debug(value bool) { func (sup *Stackup) Prefix(value bool) { sup.prefix = value } + +func (sup *Stackup) SSHKey(value string) { + sup.sshKey = value +} diff --git a/supfile.go b/supfile.go index 2a7ce35..ae9533c 100644 --- a/supfile.go +++ b/supfile.go @@ -26,6 +26,7 @@ type Network struct { Env EnvList `yaml:"env"` Inventory string `yaml:"inventory"` Hosts []string `yaml:"hosts"` + SSHKey string `yaml:"ssh-key"` Bastion string `yaml:"bastion"` // Jump host for the environment }