// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "bytes" "fmt" "http" "http/httptest" "io" "log" "net" "strings" "sync" "testing" "url" ) var serverAddr string var once sync.Once func echoServer(ws *Conn) { io.Copy(ws, ws) } type Count struct { S string N int } func countServer(ws *Conn) { for { var count Count err := JSON.Receive(ws, &count) if err != nil { return } count.N++ count.S = strings.Repeat(count.S, count.N) err = JSON.Send(ws, count) if err != nil { return } } } func startServer() { http.Handle("/echo", Handler(echoServer)) http.Handle("/count", Handler(countServer)) server := httptest.NewServer(nil) serverAddr = server.Listener.Addr().String() log.Print("Test WebSocket server listening on ", serverAddr) } func newConfig(t *testing.T, path string) *Config { config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") return config } func TestEcho(t *testing.T) { once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } msg := []byte("hello, world\n") if _, err := conn.Write(msg); err != nil { t.Errorf("Write: %v", err) } var actual_msg = make([]byte, 512) n, err := conn.Read(actual_msg) if err != nil { t.Errorf("Read: %v", err) } actual_msg = actual_msg[0:n] if !bytes.Equal(msg, actual_msg) { t.Errorf("Echo: expected %q got %q", msg, actual_msg) } conn.Close() } func TestAddr(t *testing.T) { once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } ra := conn.RemoteAddr().String() if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") { t.Errorf("Bad remote addr: %v", ra) } la := conn.LocalAddr().String() if !strings.HasPrefix(la, "http://") { t.Errorf("Bad local addr: %v", la) } conn.Close() } func TestCount(t *testing.T) { once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/count"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } var count Count count.S = "hello" if err := JSON.Send(conn, count); err != nil { t.Errorf("Write: %v", err) } if err := JSON.Receive(conn, &count); err != nil { t.Errorf("Read: %v", err) } if count.N != 1 { t.Errorf("count: expected %d got %d", 1, count.N) } if count.S != "hello" { t.Errorf("count: expected %q got %q", "hello", count.S) } if err := JSON.Send(conn, count); err != nil { t.Errorf("Write: %v", err) } if err := JSON.Receive(conn, &count); err != nil { t.Errorf("Read: %v", err) } if count.N != 2 { t.Errorf("count: expected %d got %d", 2, count.N) } if count.S != "hellohello" { t.Errorf("count: expected %q got %q", "hellohello", count.S) } conn.Close() } func TestWithQuery(t *testing.T) { once.Do(startServer) client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } config := newConfig(t, "/echo") config.Location, err = url.ParseRequest(fmt.Sprintf("ws://%s/echo?q=v", serverAddr)) if err != nil { t.Fatal("location url", err) } ws, err := NewClient(config, client) if err != nil { t.Errorf("WebSocket handshake: %v", err) return } ws.Close() } func TestWithProtocol(t *testing.T) { once.Do(startServer) client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } config := newConfig(t, "/echo") config.Protocol = append(config.Protocol, "test") ws, err := NewClient(config, client) if err != nil { t.Errorf("WebSocket handshake: %v", err) return } ws.Close() } func TestHTTP(t *testing.T) { once.Do(startServer) // If the client did not send a handshake that matches the protocol // specification, the server should abort the WebSocket connection. _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) if err == nil { t.Error("Get: unexpected success") return } urlerr, ok := err.(*url.Error) if !ok { t.Errorf("Get: not url.Error %#v", err) return } if urlerr.Err != io.ErrUnexpectedEOF { t.Errorf("Get: error %#v", err) return } } func TestTrailingSpaces(t *testing.T) { // http://code.google.com/p/go/issues/detail?id=955 // The last runs of this create keys with trailing spaces that should not be // generated by the client. once.Do(startServer) config := newConfig(t, "/echo") for i := 0; i < 30; i++ { // body ws, err := DialConfig(config) if err != nil { t.Errorf("Dial #%d failed: %v", i, err) break } ws.Close() } } func TestSmallBuffer(t *testing.T) { // http://code.google.com/p/go/issues/detail?id=1145 // Read should be able to handle reading a fragment of a frame. once.Do(startServer) // websocket.Dial() client, err := net.Dial("tcp", serverAddr) if err != nil { t.Fatal("dialing", err) } conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } msg := []byte("hello, world\n") if _, err := conn.Write(msg); err != nil { t.Errorf("Write: %v", err) } var small_msg = make([]byte, 8) n, err := conn.Read(small_msg) if err != nil { t.Errorf("Read: %v", err) } if !bytes.Equal(msg[:len(small_msg)], small_msg) { t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) } var second_msg = make([]byte, len(msg)) n, err = conn.Read(second_msg) if err != nil { t.Errorf("Read: %v", err) } second_msg = second_msg[0:n] if !bytes.Equal(msg[len(small_msg):], second_msg) { t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) } conn.Close() }