aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/exp/ssh/client.go
blob: 331138581fef4ddf19ea6210d37c2f0063970234 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"big"
	"crypto"
	"crypto/rand"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
)

// clientVersion is the fixed identification string that the client will use.
// It is written to the connection as is, including the trailing CR LF;
// handshake strips those two bytes before recording it in the handshake
// magics.
var clientVersion = []byte("SSH-2.0-Go\r\n")

// ClientConn represents the client side of an SSH connection.
type ClientConn struct {
	// transport is built by Client from the net.Conn it is given.
	*transport
	// config is the configuration passed to Client.
	config *ClientConfig
	// chanlist tracks the channels opened on this connection.
	chanlist
}

// Client returns a new SSH client connection using c as the underlying transport.
// It runs the key exchange handshake followed by user authentication. If
// either step fails, Close is called on the new connection and the error is
// returned. On success the message routing loop is started in its own
// goroutine before the connection is handed back.
func Client(c net.Conn, config *ClientConfig) (*ClientConn, os.Error) {
	conn := new(ClientConn)
	conn.transport = newTransport(c, config.rand())
	conn.config = config

	// Authentication is only attempted once the handshake has succeeded.
	err := conn.handshake()
	if err == nil {
		err = conn.authenticate()
	}
	if err != nil {
		conn.Close()
		return nil, err
	}

	go conn.mainLoop()
	return conn, nil
}

// handshake performs the client side key exchange. See RFC 4253 Section 7.
//
// The steps are, in order: exchange version strings, exchange KEXINIT
// messages, pick the agreed algorithms, run the key exchange itself, and
// finally exchange NEWKEYS messages, installing the derived keys on the
// writer before and on the reader after the server's NEWKEYS is received.
// The first error encountered is returned.
func (c *ClientConn) handshake() os.Error {
	// magics collects the values that are later fed into the exchange hash.
	var magics handshakeMagics

	if _, err := c.Write(clientVersion); err != nil {
		return err
	}
	if err := c.Flush(); err != nil {
		return err
	}
	// Record the version string without its trailing CR LF.
	magics.clientVersion = clientVersion[:len(clientVersion)-2]

	// read remote server version
	version, err := readVersion(c)
	if err != nil {
		return err
	}
	magics.serverVersion = version
	// Offer the same algorithm lists in both directions.
	clientKexInit := kexInitMsg{
		KexAlgos:                supportedKexAlgos,
		ServerHostKeyAlgos:      supportedHostKeyAlgos,
		CiphersClientServer:     supportedCiphers,
		CiphersServerClient:     supportedCiphers,
		MACsClientServer:        supportedMACs,
		MACsServerClient:        supportedMACs,
		CompressionClientServer: supportedCompressions,
		CompressionServerClient: supportedCompressions,
	}
	kexInitPacket := marshal(msgKexInit, clientKexInit)
	magics.clientKexInit = kexInitPacket

	if err := c.writePacket(kexInitPacket); err != nil {
		return err
	}
	packet, err := c.readPacket()
	if err != nil {
		return err
	}

	// The raw server KEXINIT packet is kept for the exchange hash.
	magics.serverKexInit = packet

	var serverKexInit kexInitMsg
	if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
		return err
	}

	kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
	if !ok {
		return os.NewError("ssh: no common algorithms")
	}

	if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
		// The server sent a Kex message for the wrong algorithm,
		// which we have to ignore.
		if _, err := c.readPacket(); err != nil {
			return err
		}
	}

	// H and K are named as in RFC 4253 Section 8; see kexDH.
	var H, K []byte
	var hashFunc crypto.Hash
	switch kexAlgo {
	case kexAlgoDH14SHA1:
		hashFunc = crypto.SHA1
		dhGroup14Once.Do(initDHGroup14)
		H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
	default:
		err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
	}
	if err != nil {
		return err
	}

	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
		return err
	}
	// NOTE(review): H is passed twice; presumably once as the exchange hash
	// and once as the session identifier. Confirm against setupKeys.
	if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
		return err
	}
	if packet, err = c.readPacket(); err != nil {
		return err
	}
	// NOTE(review): packet[0] is read without a length check; this assumes
	// readPacket never returns an empty packet with a nil error.
	if packet[0] != msgNewKeys {
		return UnexpectedMessageError{msgNewKeys, packet[0]}
	}
	return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
}

// authenticate authenticates with the remote server. See RFC 4252.
// It first requests the user authentication service, then sends a single
// authentication request: "password" when the config carries a non-empty
// Password, "none" otherwise. Any reply other than an authentication
// success message is reported as an UnexpectedMessageError.
func (c *ClientConn) authenticate() os.Error {
	serviceRequest := marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})
	if err := c.writePacket(serviceRequest); err != nil {
		return err
	}

	packet, err := c.readPacket()
	if err != nil {
		return err
	}
	var serviceAccept serviceAcceptMsg
	if err = unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
		return err
	}

	// TODO(dfc) support proper authentication method negotiation
	method := "password"
	if c.config.Password == "" {
		method = "none"
	}
	if err = c.sendUserAuthReq(method); err != nil {
		return err
	}

	packet, err = c.readPacket()
	if err != nil {
		return err
	}
	if packet[0] == msgUserAuthSuccess {
		return nil
	}
	return UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}

// sendUserAuthReq sends a user authentication request (RFC 4252 Section 5)
// for c.config.User and the serviceSSH service using the given method.
//
// For the "password" method the request carries the method-specific fields
// of RFC 4252 Section 8: a FALSE boolean followed by the password encoded
// as a string. Every other method is sent without method-specific fields;
// in particular RFC 4252 Section 5.2 defines the "none" request as having
// none, so the password fields must not be appended to it.
//
// It returns the error from writing the packet, if any.
func (c *ClientConn) sendUserAuthReq(method string) os.Error {
	var payload []byte
	if method == "password" {
		password := []byte(c.config.Password)
		payload = make([]byte, 1+stringLength(password))
		// always false for password auth, see RFC 4252 Section 8.
		payload[0] = 0
		marshalString(payload[1:], password)
	}

	return c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
		User:    c.config.User,
		Service: serviceSSH,
		Method:  method,
		Payload: payload,
	}))
}

// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
// returned values are given the same names as in RFC 4253, section 8.
//
// It sends the client's public value, reads the server's reply, rejects a
// server value that is zero or not below the group prime, and then computes
// the exchange hash H over the handshake magics, the server host key and
// both public values, followed by the marshalled shared secret K.
//
// NOTE(review): hostKeyAlgo is not used in this function and the reply is
// not checked against a host key signature here; the server is therefore
// not authenticated by this code.
func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, os.Error) {
	// x is the client's private exponent, drawn from the configured
	// entropy source.
	x, err := rand.Int(c.config.rand(), group.p)
	if err != nil {
		return nil, nil, err
	}
	// X = g^x mod p is the client's public value.
	X := new(big.Int).Exp(group.g, x, group.p)
	kexDHInit := kexDHInitMsg{
		X: X,
	}
	if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
		return nil, nil, err
	}

	packet, err := c.readPacket()
	if err != nil {
		return nil, nil, err
	}

	var kexDHReply = new(kexDHReplyMsg)
	if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
		return nil, nil, err
	}

	// Reject a server public value of zero or one that is >= p.
	if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
		return nil, nil, os.NewError("server DH parameter out of bounds")
	}

	// The shared secret is Y^x mod p.
	kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
	// The order of the writes below determines the value of H.
	h := hashFunc.New()
	writeString(h, magics.clientVersion)
	writeString(h, magics.serverVersion)
	writeString(h, magics.clientKexInit)
	writeString(h, magics.serverKexInit)
	writeString(h, kexDHReply.HostKey)
	writeInt(h, X)
	writeInt(h, kexDHReply.Y)
	// K is the marshalled form of the shared secret; it is both hashed
	// into H and returned to the caller.
	K := make([]byte, intLength(kInt))
	marshalInt(K, kInt)
	h.Write(K)

	H := h.Sum()

	return H, K, nil
}

// openChan opens a new client channel of the given type. The most common
// channel type is "session"; the full set of valid types is listed in
// RFC 4250 4.9.1.
//
// A local channel is allocated, an open request is sent, and the first
// message delivered to the channel is taken as the reply. On confirmation
// the remote side's id is recorded on the channel and the channel is
// returned. On a write error, an open failure, or any other reply, the
// channel is removed from the list again and an error is returned.
func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) {
	ch := c.newChan(c.transport)

	request := channelOpenMsg{
		ChanType:      typ,
		PeersId:       ch.id,
		PeersWindow:   1 << 14,
		MaxPacketSize: 1 << 15, // RFC 4253 6.1
	}
	if err := c.writePacket(marshal(msgChannelOpen, request)); err != nil {
		c.chanlist.remove(ch.id)
		return nil, err
	}

	// wait for response
	reply := <-ch.msg
	if confirm, ok := reply.(*channelOpenConfirmMsg); ok {
		ch.peersId = confirm.MyId
		return ch, nil
	}

	// Every remaining outcome abandons the channel.
	c.chanlist.remove(ch.id)
	if failure, ok := reply.(*channelOpenFailureMsg); ok {
		return nil, os.NewError(failure.Message)
	}
	return nil, os.NewError("Unexpected packet")
}

// mainLoop reads incoming messages and routes channel messages
// to their respective clientChans. It runs until readPacket returns an
// error, at which point it calls Close on the connection and returns.
//
// NOTE(review): the target channel is looked up with getChan using the id
// carried in the received message, and the result is used unchecked.
// getChan indexes the channel slice directly, so an id beyond the slice
// panics and an id whose slot has been removed yields a nil *clientChan.
func (c *ClientConn) mainLoop() {
	for {
		packet, err := c.readPacket()
		if err != nil {
			// TODO(dfc) signal the underlying close to all channels
			c.Close()
			return
		}
		// TODO(dfc) A note on blocking channel use.
		// The msg, win, data and dataExt channels of a clientChan can
		// cause this loop to block indefinitely if the consumer does
		// not service them.
		switch msg := decode(packet).(type) {
		case *channelOpenMsg:
			c.getChan(msg.PeersId).msg <- msg
		case *channelOpenConfirmMsg:
			c.getChan(msg.PeersId).msg <- msg
		case *channelOpenFailureMsg:
			c.getChan(msg.PeersId).msg <- msg
		case *channelCloseMsg:
			// Closing win, data and dataExt lets blocked readers and
			// writers observe the close; the msg channel is left open.
			ch := c.getChan(msg.PeersId)
			close(ch.win)
			close(ch.data)
			close(ch.dataExt)
			c.chanlist.remove(msg.PeersId)
		case *channelEOFMsg:
			c.getChan(msg.PeersId).msg <- msg
		case *channelRequestSuccessMsg:
			c.getChan(msg.PeersId).msg <- msg
		case *channelRequestFailureMsg:
			c.getChan(msg.PeersId).msg <- msg
		case *channelRequestMsg:
			c.getChan(msg.PeersId).msg <- msg
		case *windowAdjustMsg:
			c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
		case *channelData:
			c.getChan(msg.PeersId).data <- msg.Payload
		case *channelExtendedData:
			// RFC 4254 5.2 defines data_type_code 1 to be data destined
			// for stderr on interactive sessions. Other data types are
			// silently discarded.
			if msg.Datatype == 1 {
				c.getChan(msg.PeersId).dataExt <- msg.Payload
			}
		default:
			// Anything else is only reported on standard output.
			fmt.Printf("mainLoop: unhandled %#v\n", msg)
		}
	}
}

// Dial connects to the given network address using net.Dial and
// then initiates a SSH handshake, returning the resulting client connection.
// A dial failure is returned unchanged; otherwise the result is whatever
// Client returns for the new connection and config.
func Dial(network, addr string, config *ClientConfig) (*ClientConn, os.Error) {
	conn, err := net.Dial(network, addr)
	if err == nil {
		return Client(conn, config)
	}
	return nil, err
}

// A ClientConfig structure is used to configure a ClientConn. After one has
// been passed to an SSH function it must not be modified.
type ClientConfig struct {
	// Rand provides the source of entropy for key exchange. If Rand is
	// nil, the cryptographic random reader in package crypto/rand will
	// be used.
	Rand io.Reader

	// The username to authenticate.
	User string

	// Used for "password" method authentication. When it is empty the
	// client sends a "none" authentication request instead.
	Password string
}

// rand returns the entropy source to use: the configured Rand when one is
// set, otherwise the reader from package crypto/rand.
func (c *ClientConfig) rand() io.Reader {
	if src := c.Rand; src != nil {
		return src
	}
	return rand.Reader
}

// A clientChan represents a single RFC 4254 channel that is multiplexed
// over a single SSH connection.
type clientChan struct {
	// packetWriter is used to send packets for this channel;
	// newClientChan sets it to the connection's transport.
	packetWriter
	// id is the local channel id. peersId is the id the remote side
	// reported in its open confirmation; openChan fills it in.
	id, peersId uint32
	data        chan []byte      // receives the payload of channelData messages
	dataExt     chan []byte      // receives the payload of channelExtendedData messages
	win         chan int         // receives window adjustments
	msg         chan interface{} // incoming messages
}

// newClientChan builds a clientChan with the given local id that sends its
// packets through t. Each of its four channels is buffered to hold 16
// elements.
func newClientChan(t *transport, id uint32) *clientChan {
	const depth = 16

	ch := new(clientChan)
	ch.packetWriter = t
	ch.id = id
	ch.data = make(chan []byte, depth)
	ch.dataExt = make(chan []byte, depth)
	ch.win = make(chan int, depth)
	ch.msg = make(chan interface{}, depth)
	return ch
}

// Close closes the channel. This does not close the underlying connection.
//
// The close message is addressed to the remote side's number for this
// channel (peersId, recorded by openChan from the open confirmation), since
// RFC 4254 Section 5.3 defines the field as the recipient's channel. Using
// the local id only works when both sides happen to pick the same number.
//
// It returns the error from writing the packet, if any.
func (c *clientChan) Close() os.Error {
	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
		PeersId: c.peersId,
	}))
}

// sendChanReq sends req on the channel and waits for the next message
// delivered to the channel. It returns nil when that message is a request
// success, and otherwise an error describing the request and the message
// that was received instead.
func (c *clientChan) sendChanReq(req channelRequestMsg) os.Error {
	err := c.writePacket(marshal(msgChannelRequest, req))
	if err != nil {
		return err
	}

	reply := <-c.msg
	switch reply.(type) {
	case *channelRequestSuccessMsg:
		return nil
	}
	return fmt.Errorf("failed to complete request: %s, %#v", req.Request, reply)
}

// Thread safe channel list. Every method takes the embedded mutex for the
// duration of its access to chans.
type chanlist struct {
	// protects concurrent access to chans
	sync.Mutex
	// chans are indexed by the local id of the channel, clientChan.id.
	// The PeersId value of messages received by ClientConn.mainLoop is
	// used to locate the right local clientChan in this slice.
	// A nil entry marks a free slot that newChan may reuse.
	chans []*clientChan
}

// newChan allocates a new clientChan with the next available local id.
// The lowest free (nil) slot in the list is reused when there is one;
// otherwise the list grows by one entry. The channel's id equals its index.
func (c *chanlist) newChan(t *transport) *clientChan {
	c.Lock()
	defer c.Unlock()

	// Default to appending; prefer the first free slot if one exists.
	slot := len(c.chans)
	for i, existing := range c.chans {
		if existing == nil {
			slot = i
			break
		}
	}

	ch := newClientChan(t, uint32(slot))
	if slot == len(c.chans) {
		c.chans = append(c.chans, ch)
	} else {
		c.chans[slot] = ch
	}
	return ch
}

// getChan returns the entry stored at index id. The slice is indexed
// directly, so the result is nil for a freed slot and an id beyond the end
// of the list panics.
func (c *chanlist) getChan(id uint32) *clientChan {
	c.Lock()
	defer c.Unlock()

	idx := int(id)
	return c.chans[idx]
}

// remove frees the slot at index id by setting it to nil, which makes the
// id available to newChan again. The list itself never shrinks.
func (c *chanlist) remove(id uint32) {
	c.Lock()
	defer c.Unlock()

	idx := int(id)
	c.chans[idx] = nil
}

// A chanWriter represents the stdin of a remote process.
type chanWriter struct {
	win          chan int // receives window adjustments
	id           uint32   // channel id placed in outgoing data and EOF messages
	rwin         int      // remaining window; Write waits on win when it is zero
	packetWriter          // for sending channelDataMsg
}

// Write writes data to the remote process's standard input.
//
// While the tracked window (rwin) is zero, Write blocks waiting for a
// window adjustment on w.win and returns 0, os.EOF if that channel has been
// closed. All of data is then sent as one channel data packet and rwin is
// reduced by len(data). n is len(data) and err is the result of writing
// the packet.
//
// NOTE(review): the write is not limited to the available window, so rwin
// can become negative, after which the zero check no longer blocks.
func (w *chanWriter) Write(data []byte) (n int, err os.Error) {
	for {
		if w.rwin == 0 {
			win, ok := <-w.win
			if !ok {
				return 0, os.EOF
			}
			w.rwin += win
			continue
		}
		n = len(data)
		packet := make([]byte, 0, 9+n)
		// The header is the message type followed by the channel id and
		// the payload length as big-endian uint32s. The shift has to
		// happen before the truncation to byte: byte(x)>>24 is always 0.
		packet = append(packet, msgChannelData,
			byte(w.id>>24), byte(w.id>>16), byte(w.id>>8), byte(w.id),
			byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
		err = w.writePacket(append(packet, data...))
		w.rwin -= n
		return
	}
	panic("unreachable")
}

// Close sends a channel EOF message carrying the writer's channel id and
// returns the error from writing it, if any.
func (w *chanWriter) Close() os.Error {
	eof := channelEOFMsg{w.id}
	return w.writePacket(marshal(msgChannelEOF, eof))
}

// A chanReader represents stdout or stderr of a remote process.
type chanReader struct {
	// TODO(dfc) a fixed size channel may not be the right data structure.
	// If writes to this channel block, they will block mainLoop, making
	// it unable to receive new messages from the remote side.
	data         chan []byte // receives data from remote
	id           uint32      // channel id placed in outgoing window adjust and EOF messages
	packetWriter             // for sending windowAdjustMsg
	buf          []byte      // data taken from the channel but not yet returned by Read
}

// Read reads data from the remote process's stdout or stderr.
//
// When no buffered bytes remain, Read blocks until a non-empty payload
// arrives on the data channel, returning 0, os.EOF once that channel is
// closed. It then copies as much as fits into data, keeps the remainder
// for the next call, and sends a window adjustment for the number of bytes
// copied. The returned error is the result of sending that adjustment.
func (r *chanReader) Read(data []byte) (int, os.Error) {
	// Refill the buffer until there is something to hand out.
	for len(r.buf) == 0 {
		chunk, open := <-r.data
		r.buf = chunk
		if !open {
			return 0, os.EOF
		}
	}

	n := copy(data, r.buf)
	r.buf = r.buf[n:]

	adjust := windowAdjustMsg{
		PeersId:         r.id,
		AdditionalBytes: uint32(n),
	}
	return n, r.writePacket(marshal(msgChannelWindowAdjust, adjust))
}

// Close sends a channel EOF message carrying the reader's channel id and
// returns the error from writing it, if any.
func (r *chanReader) Close() os.Error {
	eof := channelEOFMsg{r.id}
	return r.writePacket(marshal(msgChannelEOF, eof))
}