diff --git a/.gitignore b/.gitignore index f1c181e..408fb4b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ # Test binary, build with `go test -c` *.test +*/*node_modules +*/*package-lock.json # Output of the go coverage tool, specifically when used with LiteIDE *.out diff --git a/test/package.json b/test/package.json new file mode 100644 index 0000000..f37c823 --- /dev/null +++ b/test/package.json @@ -0,0 +1,15 @@ +{ + "name": "server", + "version": "1.0.0", + "description": "", + "main": "server.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1", + "start": "node server.js" + }, + "author": "", + "license": "ISC", + "dependencies": { + "faye": "^1.2.4" + } +} diff --git a/test/server.js b/test/server.js new file mode 100644 index 0000000..554153b --- /dev/null +++ b/test/server.js @@ -0,0 +1,24 @@ +var http = require('http'), + faye = require('faye'); + +var server = http.createServer(), + bayeux = new faye.NodeAdapter({mount: '/faye', timeout: 45}); + +var unauthorized = [ + '/unauthorized', +]; + +bayeux.addExtension({ + incoming: function (message, callback) { + if (message.channel === '/meta/subscribe') { + console.log(message) + if (unauthorized.indexOf(message.subscription) >= 0) { + message.error = '500::unauthorized channel'; + } + } + callback(message); + } +}); + +bayeux.attach(server); +server.listen(8000); \ No newline at end of file diff --git a/test/setup_test.go b/test/setup_test.go new file mode 100644 index 0000000..4b826f4 --- /dev/null +++ b/test/setup_test.go @@ -0,0 +1,120 @@ +package faye_test + +import ( + "fmt" + "github.com/thesyncim/faye" + "github.com/thesyncim/faye/extensions" + "github.com/thesyncim/faye/message" + "log" + "os" + "os/exec" + "sync" + "testing" + "time" +) + +type cancelFn func() + +func setup(t *testing.T) (cancelFn, error) { + cmd := exec.Command("npm", "start") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + err := cmd.Start() + if err != nil { + return nil, err + } + + var cancel = func() { + exec.Command("taskkill", "/F", "/T", "/PID", fmt.Sprint(cmd.Process.Pid)).Run() + log.Println("canceled") + } + go func() { + select { + case <-time.After(time.Second * 30): + cancel() + t.Fatal("failed") + os.Exit(1) + + } + }() + + return cancel, nil +} + +func TestServerSubscribeAndPublish(t *testing.T) { + shutdown, err := setup(t) + if err != nil { + t.Fatal(err) + } + defer shutdown() + + debug := extensions.NewDebugExtension(os.Stdout) + + client, err := fayec.NewClient("ws://localhost:8000/faye", fayec.WithExtension(debug.InExtension, debug.OutExtension)) + if err != nil { + t.Fatal(err) + } + + client.OnPublishResponse("/test", func(message *message.Message) { + if !message.Successful { + t.Fatalf("failed to send message with id %s", message.Id) + } + }) + var done sync.WaitGroup + done.Add(10) + var delivered int + go func() { + client.Subscribe("/test", func(data message.Data) { + if data != "hello world" { + t.Fatalf("expecting: `hello world` got : %s", data) + } + delivered++ + done.Done() + }) + }() + + //give some time for setup + time.Sleep(time.Second) + for i := 0; i < 10; i++ { + id, err := client.Publish("/test", "hello world") + if err != nil { + t.Fatal(err) + } + log.Println(id, i) + } + + done.Wait() + err = client.Unsubscribe("/test") + if err != nil { + t.Fatal(err) + } + if delivered != 10 { + t.Fatal("message received after client unsubscribe") + } + log.Println("complete") + +} + +func TestSubscribeUnauthorizedChannel(t *testing.T) { + shutdown, err := setup(t) + if err != nil { + t.Fatal(err) + } + defer shutdown() + + debug := extensions.NewDebugExtension(os.Stdout) + + client, err := fayec.NewClient("ws://localhost:8000/faye", fayec.WithExtension(debug.InExtension, debug.OutExtension)) + if err != nil { + t.Fatal(err) + } + + err = client.Subscribe("/unauthorized", func(data message.Data) { + t.Fatal("received message on unauthorized channel") + }) + if err == nil { + t.Fatal("subscribed to an unauthorized channel") + } + +} diff --git a/transport/websocket/websocket.go b/transport/websocket/websocket.go index 9be1620..1018f42 100644 --- a/transport/websocket/websocket.go +++ b/transport/websocket/websocket.go @@ -20,11 +20,13 @@ func init() { //Websocket represents an websocket transport for the faye protocol type Websocket struct { TransportOpts *transport.Options - conn *websocket.Conn - clientID string - msgID *uint64 - once sync.Once - advice atomic.Value //type message.Advise + + connMu sync.Mutex + conn *websocket.Conn + clientID string + msgID *uint64 + once sync.Once + advice atomic.Value //type message.Advise stopCh chan error @@ -46,6 +48,7 @@ func (w *Websocket) Init(options *transport.Options) error { w.TransportOpts = options w.msgID = &msgID w.subs = map[string]chan *message.Message{} + w.onPublishResponse = map[string]func(message *message.Message){} w.stopCh = make(chan error) w.conn, _, err = websocket.DefaultDialer.Dial(options.Url, nil) if err != nil { @@ -68,6 +71,7 @@ func (w *Websocket) readWorker() error { } //dispatch msg := &payload[0] + w.applyInExtensions(msg) if msg.Advice != nil { w.handleAdvise(msg.Advice) @@ -116,14 +120,15 @@ func (w *Websocket) readWorker() error { if transport.IsEventDelivery(msg) { w.subsMu.Lock() - subscription := w.subs[msg.Channel] + subscription, ok := w.subs[msg.Channel] w.subsMu.Unlock() - w.applyInExtensions(msg) - - if subscription != nil { - subscription <- msg + if ok { + if subscription != nil { + subscription <- msg + } } + continue } @@ -145,6 +150,8 @@ func (w *Websocket) Name() string { } func (w *Websocket) sendMessage(m *message.Message) error { + w.connMu.Lock() + defer w.connMu.Unlock() w.applyOutExtensions(m) var payload []message.Message