diff --git a/subscription/subscription.go b/subscription/subscription.go index 60a1c7c..994271e 100644 --- a/subscription/subscription.go +++ b/subscription/subscription.go @@ -2,6 +2,7 @@ package subscription import ( "github.com/thesyncim/faye/message" + "regexp" ) type Unsubscriber func(subscription *Subscription) error @@ -63,3 +64,10 @@ func (s *Subscription) Unsubscribe() error { func (s *Subscription) Publish(msg message.Data) (string, error) { return s.pub(msg) } + +var validChannelName = regexp.MustCompile(`^\/(((([a-z]|[A-Z])|[0-9])|(\-|\_|\!|\~|\(|\)|\$|\@)))+(\/(((([a-z]|[A-Z])|[0-9])|(\-|\_|\!|\~|\(|\)|\$|\@)))+)*$`) +var validChannelPattern = regexp.MustCompile(`^(\/(((([a-z]|[A-Z])|[0-9])|(\-|\_|\!|\~|\(|\)|\$|\@)))+)*\/\*{1,2}$`) + +func IsValidChannel(channel string) bool { + return validChannelName.MatchString(channel) || validChannelPattern.MatchString(channel) +} diff --git a/subscription/subscription_test.go b/subscription/subscription_test.go new file mode 100644 index 0000000..049471b --- /dev/null +++ b/subscription/subscription_test.go @@ -0,0 +1,73 @@ +package subscription + +import "testing" + +/* + assertEqual( ["/**", "/foo", "/*"], + Channel.expand("/foo") ) + + assertEqual( ["/**", "/foo/bar", "/foo/*", "/foo/**"], + Channel.expand("/foo/bar") ) + + assertEqual( ["/**", "/foo/bar/qux", "/foo/bar/*", "/foo/**", "/foo/bar/**"], +*/ +func TestIsValidChannel(t *testing.T) { + type args struct { + channel string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "single asterisc", + args: args{ + channel: "/*", + }, + want: true, + }, + { + name: "double asterisc", + args: args{ + channel: "/**", + }, + want: true, + }, + { + name: "regular channel", + args: args{ + channel: "/foo", + }, + want: true, + }, + { + name: "regular channel 2", + args: args{ + channel: "/foo/bar", + }, + want: true, + }, + { + name: "invalid slash ending", + args: args{ + channel: "/foo/", + }, + want: false, + }, + { + name: "invalid asterisc at the middle", + args: args{ + channel: "/foo/**/bar", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsValidChannel(tt.args.channel); got != tt.want { + t.Errorf("IsValidChannel() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/transport/websocket/websocket.go b/transport/websocket/websocket.go index e90d7ce..9c71647 100644 --- a/transport/websocket/websocket.go +++ b/transport/websocket/websocket.go @@ -30,13 +30,10 @@ type Websocket struct { once sync.Once advice atomic.Value //type message.Advise - stopCh chan error + stopCh chan error //todo replace wth context - //subsMu sync.Mutex //todo sync.Map - //subs map[string]chan *message.Message - - subsMu2 sync.Mutex //todo sync.Map - subs2 map[string][]*subscription.Subscription + subscriptionsMu sync.Mutex //todo thread safe radix tree + subscriptions map[string][]*subscription.Subscription onPubResponseMu sync.Mutex //todo sync.Map onPublishResponse map[string]func(message *message.Message) @@ -53,7 +50,7 @@ func (w *Websocket) Init(endpoint string, options *transport.Options) error { w.topts = options w.msgID = &msgID //w.subs = map[string]chan *message.Message{} - w.subs2 = map[string][]*subscription.Subscription{} + w.subscriptions = map[string][]*subscription.Subscription{} w.onPublishResponse = map[string]func(message *message.Message){} w.stopCh = make(chan error) w.conn, _, err = websocket.DefaultDialer.Dial(endpoint, options.Headers) @@ -88,8 +85,8 @@ func (w *Websocket) readWorker() error { switch msg.Channel { case transport.MetaSubscribe: //handle MetaSubscribe resp - w.subsMu2.Lock() - subscriptions, ok := w.subs2[msg.Subscription] + w.subscriptionsMu.Lock() + subscriptions, ok := w.subscriptions[msg.Subscription] if !ok { panic("BUG: subscription not registered `" + msg.Subscription + "`") } @@ -115,7 +112,7 @@ func (w *Websocket) readWorker() error { subscriptions = subscriptions[:si+copy(subscriptions[si:], subscriptions[si+1:])] } - w.subs2[msg.Subscription] = subscriptions + w.subscriptions[msg.Subscription] = subscriptions //v2 } else { for i := range subscriptions { @@ -128,7 +125,7 @@ func (w *Websocket) readWorker() error { } } } - w.subsMu2.Unlock() + w.subscriptionsMu.Unlock() } @@ -139,8 +136,8 @@ func (w *Websocket) readWorker() error { // 1. Publish // 2. Delivery if transport.IsEventDelivery(msg) { - w.subsMu2.Lock() - subscriptions, ok := w.subs2[msg.Channel] + w.subscriptionsMu.Lock() + subscriptions, ok := w.subscriptions[msg.Channel] if ok { //send to all listeners @@ -155,7 +152,7 @@ func (w *Websocket) readWorker() error { } } } - w.subsMu2.Unlock() + w.subscriptionsMu.Unlock() continue } @@ -256,15 +253,15 @@ func (w *Websocket) Disconnect() error { w.stopCh <- nil close(w.stopCh) - w.subsMu2.Lock() - for i := range w.subs2 { + w.subscriptionsMu.Lock() + for i := range w.subscriptions { //close all listeners - for j := range w.subs2[i] { - close(w.subs2[i][j].MsgChannel()) + for j := range w.subscriptions[i] { + close(w.subscriptions[i][j].MsgChannel()) } - delete(w.subs2, i) + delete(w.subscriptions, i) } - w.subsMu2.Unlock() + w.subscriptionsMu.Unlock() return w.sendMessage(&m) } @@ -291,9 +288,9 @@ func (w *Websocket) Subscribe(channel string) (*subscription.Subscription, error } sub := subscription.NewSubscription(id, channel, w.Unsubscribe, pub, inMsgCh, subRes) - w.subsMu2.Lock() - w.subs2[channel] = append(w.subs2[channel], sub) - w.subsMu2.Unlock() + w.subscriptionsMu.Lock() + w.subscriptions[channel] = append(w.subscriptions[channel], sub) + w.subscriptionsMu.Unlock() //todo timeout here err := <-subRes @@ -309,9 +306,9 @@ func (w *Websocket) Subscribe(channel string) (*subscription.Subscription, error //the specified channel/subscription func (w *Websocket) Unsubscribe(subscription *subscription.Subscription) error { //https://docs.cometd.org/current/reference/#_bayeux_meta_unsubscribe - w.subsMu2.Lock() - defer w.subsMu2.Unlock() - subs, ok := w.subs2[subscription.Channel()] + w.subscriptionsMu.Lock() + defer w.subscriptionsMu.Unlock() + subs, ok := w.subscriptions[subscription.Channel()] if ok { var si = -1 for i := range subs { @@ -324,10 +321,10 @@ func (w *Websocket) Unsubscribe(subscription *subscription.Subscription) error { //remove the subscription subs = subs[:si+copy(subs[si:], subs[si+1:])] } - w.subs2[subscription.Channel()] = subs + w.subscriptions[subscription.Channel()] = subs //if no more listeners to this subscription send unsubscribe to server if len(subs) == 0 { - delete(w.subs2, subscription.Channel()) + delete(w.subscriptions, subscription.Channel()) //remove onPublishResponse handler w.onPubResponseMu.Lock() delete(w.onPublishResponse, subscription.Channel())