// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"bytes"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"io"
	"math/big"
	"net"
	"sync"
)

// ServerConfig holds the configuration shared by server connections created
// with Server or Listen: the host key, the client authentication policy and
// the cryptographic settings.
type ServerConfig struct {
	// rsa is the server's RSA host key, installed by SetRSAPrivateKey.
	rsa *rsa.PrivateKey
	// rsaSerialized caches the SSH encoding of rsa's public half (as built by
	// marshalRSA); it is the host key presented to clients during key
	// exchange.
	rsaSerialized []byte

	// Rand provides the source of entropy for key exchange. If Rand is
	// nil, the cryptographic random reader in package crypto/rand will
	// be used.
	Rand io.Reader

	// NoClientAuth is true if clients are allowed to connect without
	// authenticating.
	NoClientAuth bool

	// PasswordCallback, if non-nil, is called when a user attempts to
	// authenticate using a password. It may be called concurrently from
	// several goroutines.
	PasswordCallback func(user, password string) bool

	// PublicKeyCallback, if non-nil, is called when a client attempts public
	// key authentication. It must return true iff the given public key is
	// valid for the given user. It is only consulted for algorithms the
	// server accepts, and its answers are cached per connection, so repeated
	// attempts with the same key may not invoke it again.
	PublicKeyCallback func(user, algo string, pubkey []byte) bool

	// Cryptographic-related configuration.
	Crypto CryptoConfig
}

// rand returns the entropy source for this configuration: c.Rand when it is
// set, otherwise crypto/rand.Reader.
func (c *ServerConfig) rand() io.Reader {
	if c.Rand == nil {
		return rand.Reader
	}
	return c.Rand
}

// SetRSAPrivateKey sets the private key for a Server. A Server must have a
// private key configured in order to accept connections. The private key must
// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
// typically contains such a key.
//
// Only the first PEM block in pemBytes is decoded, and its type header is not
// checked. On success the public half of the key is also serialized and
// cached for use as the server host key.
//
// NOTE(review): on a parse failure s.rsa is overwritten with whatever
// x509.ParsePKCS1PrivateKey returned (presumably nil) while rsaSerialized
// keeps any previous value; consider parsing into a local variable first.
func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error {
	block, _ := pem.Decode(pemBytes)
	if block == nil {
		return errors.New("ssh: no key found")
	}
	var err error
	s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		return err
	}
	s.rsaSerialized = marshalRSA(s.rsa)
	return nil
}

// marshalRSA serializes the public half of priv (exponent e and modulus n) as
// an SSH public key blob, following RFC 4253, section 6.6: the algorithm name
// hostAlgoRSA followed by e and n.
func marshalRSA(priv *rsa.PrivateKey) []byte { e := new(big.Int).SetInt64(int64(priv.E)) length := stringLength([]byte(hostAlgoRSA)) length += intLength(e) length += intLength(priv.N) ret := make([]byte, length) r := marshalString(ret, []byte(hostAlgoRSA)) r = marshalInt(r, e) r = marshalInt(r, priv.N) return ret } // parseRSA parses an RSA key according to RFC 4256, section 6.6. func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) { algo, in, ok := parseString(in) if !ok || string(algo) != hostAlgoRSA { return nil, false } bigE, in, ok := parseInt(in) if !ok || bigE.BitLen() > 24 { return nil, false } e := bigE.Int64() if e < 3 || e&1 == 0 { return nil, false } N, in, ok := parseInt(in) if !ok || len(in) > 0 { return nil, false } return &rsa.PublicKey{ N: N, E: int(e), }, true } func parseRSASig(in []byte) (sig []byte, ok bool) { algo, in, ok := parseString(in) if !ok || string(algo) != hostAlgoRSA { return nil, false } sig, in, ok = parseString(in) if len(in) > 0 { ok = false } return } // cachedPubKey contains the results of querying whether a public key is // acceptable for a user. The cache only applies to a single ServerConn. type cachedPubKey struct { user, algo string pubKey []byte result bool } const maxCachedPubKeys = 16 // A ServerConn represents an incomming connection. type ServerConn struct { *transport config *ServerConfig channels map[uint32]*channel nextChanId uint32 // lock protects err and also allows Channels to serialise their writes // to out. lock sync.RWMutex err error // cachedPubKeys contains the cache results of tests for public keys. // Since SSH clients will query whether a public key is acceptable // before attempting to authenticate with it, we end up with duplicate // queries for public key validity. cachedPubKeys []cachedPubKey } // Server returns a new SSH server connection // using c as the underlying transport. 
func Server(c net.Conn, config *ServerConfig) *ServerConn { conn := &ServerConn{ transport: newTransport(c, config.rand()), channels: make(map[uint32]*channel), config: config, } return conn } // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The // returned values are given the same names as in RFC 4253, section 8. func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err error) { packet, err := s.readPacket() if err != nil { return } var kexDHInit kexDHInitMsg if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil { return } if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 { return nil, nil, errors.New("client DH parameter out of bounds") } y, err := rand.Int(s.config.rand(), group.p) if err != nil { return } Y := new(big.Int).Exp(group.g, y, group.p) kInt := new(big.Int).Exp(kexDHInit.X, y, group.p) var serializedHostKey []byte switch hostKeyAlgo { case hostAlgoRSA: serializedHostKey = s.config.rsaSerialized default: return nil, nil, errors.New("internal error") } h := hashFunc.New() writeString(h, magics.clientVersion) writeString(h, magics.serverVersion) writeString(h, magics.clientKexInit) writeString(h, magics.serverKexInit) writeString(h, serializedHostKey) writeInt(h, kexDHInit.X) writeInt(h, Y) K = make([]byte, intLength(kInt)) marshalInt(K, kInt) h.Write(K) H = h.Sum(nil) h.Reset() h.Write(H) hh := h.Sum(nil) var sig []byte switch hostKeyAlgo { case hostAlgoRSA: sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh) if err != nil { return } default: return nil, nil, errors.New("internal error") } serializedSig := serializeSignature(hostAlgoRSA, sig) kexDHReply := kexDHReplyMsg{ HostKey: serializedHostKey, Y: Y, Signature: serializedSig, } packet = marshal(msgKexDHReply, kexDHReply) err = s.writePacket(packet) return } // serverVersion is the fixed identification string that Server will use. 
var serverVersion = []byte("SSH-2.0-Go\r\n") // Handshake performs an SSH transport and client authentication on the given ServerConn. func (s *ServerConn) Handshake() error { var magics handshakeMagics if _, err := s.Write(serverVersion); err != nil { return err } if err := s.Flush(); err != nil { return err } magics.serverVersion = serverVersion[:len(serverVersion)-2] version, err := readVersion(s) if err != nil { return err } magics.clientVersion = version serverKexInit := kexInitMsg{ KexAlgos: supportedKexAlgos, ServerHostKeyAlgos: supportedHostKeyAlgos, CiphersClientServer: s.config.Crypto.ciphers(), CiphersServerClient: s.config.Crypto.ciphers(), MACsClientServer: supportedMACs, MACsServerClient: supportedMACs, CompressionClientServer: supportedCompressions, CompressionServerClient: supportedCompressions, } kexInitPacket := marshal(msgKexInit, serverKexInit) magics.serverKexInit = kexInitPacket if err := s.writePacket(kexInitPacket); err != nil { return err } packet, err := s.readPacket() if err != nil { return err } magics.clientKexInit = packet var clientKexInit kexInitMsg if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil { return err } kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, &clientKexInit, &serverKexInit) if !ok { return errors.New("ssh: no common algorithms") } if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { // The client sent a Kex message for the wrong algorithm, // which we have to ignore. 
if _, err := s.readPacket(); err != nil { return err } } var H, K []byte var hashFunc crypto.Hash switch kexAlgo { case kexAlgoDH14SHA1: hashFunc = crypto.SHA1 dhGroup14Once.Do(initDHGroup14) H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) default: err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo) } if err != nil { return err } if err = s.writePacket([]byte{msgNewKeys}); err != nil { return err } if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { return err } if packet, err = s.readPacket(); err != nil { return err } if packet[0] != msgNewKeys { return UnexpectedMessageError{msgNewKeys, packet[0]} } if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil { return err } if packet, err = s.readPacket(); err != nil { return err } var serviceRequest serviceRequestMsg if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil { return err } if serviceRequest.Service != serviceUserAuth { return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") } serviceAccept := serviceAcceptMsg{ Service: serviceUserAuth, } if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil { return err } if err = s.authenticate(H); err != nil { return err } return nil } func isAcceptableAlgo(algo string) bool { return algo == hostAlgoRSA } // testPubKey returns true if the given public key is acceptable for the user. 
func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool { if s.config.PublicKeyCallback == nil || !isAcceptableAlgo(algo) { return false } for _, c := range s.cachedPubKeys { if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) { return c.result } } result := s.config.PublicKeyCallback(user, algo, pubKey) if len(s.cachedPubKeys) < maxCachedPubKeys { c := cachedPubKey{ user: user, algo: algo, pubKey: make([]byte, len(pubKey)), result: result, } copy(c.pubKey, pubKey) s.cachedPubKeys = append(s.cachedPubKeys, c) } return result } func (s *ServerConn) authenticate(H []byte) error { var userAuthReq userAuthRequestMsg var err error var packet []byte userAuthLoop: for { if packet, err = s.readPacket(); err != nil { return err } if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil { return err } if userAuthReq.Service != serviceSSH { return errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) } switch userAuthReq.Method { case "none": if s.config.NoClientAuth { break userAuthLoop } case "password": if s.config.PasswordCallback == nil { break } payload := userAuthReq.Payload if len(payload) < 1 || payload[0] != 0 { return ParseError{msgUserAuthRequest} } payload = payload[1:] password, payload, ok := parseString(payload) if !ok || len(payload) > 0 { return ParseError{msgUserAuthRequest} } if s.config.PasswordCallback(userAuthReq.User, string(password)) { break userAuthLoop } case "publickey": if s.config.PublicKeyCallback == nil { break } payload := userAuthReq.Payload if len(payload) < 1 { return ParseError{msgUserAuthRequest} } isQuery := payload[0] == 0 payload = payload[1:] algoBytes, payload, ok := parseString(payload) if !ok { return ParseError{msgUserAuthRequest} } algo := string(algoBytes) pubKey, payload, ok := parseString(payload) if !ok { return ParseError{msgUserAuthRequest} } if isQuery { // The client can query if the given public key // would be ok. 
if len(payload) > 0 { return ParseError{msgUserAuthRequest} } if s.testPubKey(userAuthReq.User, algo, pubKey) { okMsg := userAuthPubKeyOkMsg{ Algo: algo, PubKey: string(pubKey), } if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil { return err } continue userAuthLoop } } else { sig, payload, ok := parseString(payload) if !ok || len(payload) > 0 { return ParseError{msgUserAuthRequest} } if !isAcceptableAlgo(algo) { break } rsaSig, ok := parseRSASig(sig) if !ok { return ParseError{msgUserAuthRequest} } signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey) switch algo { case hostAlgoRSA: hashFunc := crypto.SHA1 h := hashFunc.New() h.Write(signedData) digest := h.Sum(nil) rsaKey, ok := parseRSA(pubKey) if !ok { return ParseError{msgUserAuthRequest} } if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil { return ParseError{msgUserAuthRequest} } default: return errors.New("ssh: isAcceptableAlgo incorrect") } if s.testPubKey(userAuthReq.User, algo, pubKey) { break userAuthLoop } } } var failureMsg userAuthFailureMsg if s.config.PasswordCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "password") } if s.config.PublicKeyCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "publickey") } if len(failureMsg.Methods) == 0 { return errors.New("ssh: no authentication methods configured but NoClientAuth is also false") } if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { return err } } packet = []byte{msgUserAuthSuccess} if err = s.writePacket(packet); err != nil { return err } return nil } const defaultWindowSize = 32768 // Accept reads and processes messages on a ServerConn. It must be called // in order to demultiplex messages to any resulting Channels. 
func (s *ServerConn) Accept() (Channel, error) { if s.err != nil { return nil, s.err } for { packet, err := s.readPacket() if err != nil { s.lock.Lock() s.err = err s.lock.Unlock() for _, c := range s.channels { c.dead = true c.handleData(nil) } return nil, err } switch packet[0] { case msgChannelData: if len(packet) < 9 { // malformed data packet return nil, ParseError{msgChannelData} } peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4]) s.lock.Lock() c, ok := s.channels[peersId] if !ok { s.lock.Unlock() continue } if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 { packet = packet[9:] c.handleData(packet[:length]) } s.lock.Unlock() default: switch msg := decode(packet).(type) { case *channelOpenMsg: c := new(channel) c.chanType = msg.ChanType c.theirId = msg.PeersId c.theirWindow = msg.PeersWindow c.maxPacketSize = msg.MaxPacketSize c.extraData = msg.TypeSpecificData c.myWindow = defaultWindowSize c.serverConn = s c.cond = sync.NewCond(&c.lock) c.pendingData = make([]byte, c.myWindow) s.lock.Lock() c.myId = s.nextChanId s.nextChanId++ s.channels[c.myId] = c s.lock.Unlock() return c, nil case *channelRequestMsg: s.lock.Lock() c, ok := s.channels[msg.PeersId] if !ok { s.lock.Unlock() continue } c.handlePacket(msg) s.lock.Unlock() case *channelEOFMsg: s.lock.Lock() c, ok := s.channels[msg.PeersId] if !ok { s.lock.Unlock() continue } c.handlePacket(msg) s.lock.Unlock() case *channelCloseMsg: s.lock.Lock() c, ok := s.channels[msg.PeersId] if !ok { s.lock.Unlock() continue } c.handlePacket(msg) s.lock.Unlock() case *globalRequestMsg: if msg.WantReply { if err := s.writePacket([]byte{msgRequestFailure}); err != nil { return nil, err } } case UnexpectedMessageError: return nil, msg case *disconnectMsg: return nil, io.EOF default: // Unknown message. Ignore. } } } panic("unreachable") } // A Listener implements a network listener (net.Listener) for SSH connections. 
type Listener struct { listener net.Listener config *ServerConfig } // Accept waits for and returns the next incoming SSH connection. // The receiver should call Handshake() in another goroutine // to avoid blocking the accepter. func (l *Listener) Accept() (*ServerConn, error) { c, err := l.listener.Accept() if err != nil { return nil, err } conn := Server(c, l.config) return conn, nil } // Addr returns the listener's network address. func (l *Listener) Addr() net.Addr { return l.listener.Addr() } // Close closes the listener. func (l *Listener) Close() error { return l.listener.Close() } // Listen creates an SSH listener accepting connections on // the given network address using net.Listen. func Listen(network, addr string, config *ServerConfig) (*Listener, error) { l, err := net.Listen(network, addr) if err != nil { return nil, err } return &Listener{ l, config, }, nil }