create helper to validate channel name

add test cases
This commit is contained in:
Marcelo Pires 2018-09-12 09:42:45 +02:00
parent 17a2e4a408
commit b5bb59476c
3 changed files with 106 additions and 28 deletions

View File

@ -2,6 +2,7 @@ package subscription
import ( import (
"github.com/thesyncim/faye/message" "github.com/thesyncim/faye/message"
"regexp"
) )
type Unsubscriber func(subscription *Subscription) error type Unsubscriber func(subscription *Subscription) error
@ -63,3 +64,10 @@ func (s *Subscription) Unsubscribe() error {
func (s *Subscription) Publish(msg message.Data) (string, error) { func (s *Subscription) Publish(msg message.Data) (string, error) {
return s.pub(msg) return s.pub(msg)
} }
var validChannelName = regexp.MustCompile(`^\/(((([a-z]|[A-Z])|[0-9])|(\-|\_|\!|\~|\(|\)|\$|\@)))+(\/(((([a-z]|[A-Z])|[0-9])|(\-|\_|\!|\~|\(|\)|\$|\@)))+)*$`)
var validChannelPattern = regexp.MustCompile(`^(\/(((([a-z]|[A-Z])|[0-9])|(\-|\_|\!|\~|\(|\)|\$|\@)))+)*\/\*{1,2}$`)
func IsValidChannel(channel string) bool {
return validChannelName.MatchString(channel) || validChannelPattern.MatchString(channel)
}

View File

@ -0,0 +1,73 @@
package subscription
import "testing"
/*
assertEqual( ["/**", "/foo", "/*"],
Channel.expand("/foo") )
assertEqual( ["/**", "/foo/bar", "/foo/*", "/foo/**"],
Channel.expand("/foo/bar") )
assertEqual( ["/**", "/foo/bar/qux", "/foo/bar/*", "/foo/**", "/foo/bar/**"],
*/
func TestIsValidChannel(t *testing.T) {
type args struct {
channel string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "single asterisc",
args: args{
channel: "/*",
},
want: true,
},
{
name: "double asterisc",
args: args{
channel: "/**",
},
want: true,
},
{
name: "regular channel",
args: args{
channel: "/foo",
},
want: true,
},
{
name: "regular channel 2",
args: args{
channel: "/foo/bar",
},
want: true,
},
{
name: "invalid slash ending",
args: args{
channel: "/foo/",
},
want: false,
},
{
name: "invalid asterisc at the middle",
args: args{
channel: "/foo/**/bar",
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidChannel(tt.args.channel); got != tt.want {
t.Errorf("IsValidChannel() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -30,13 +30,10 @@ type Websocket struct {
once sync.Once once sync.Once
advice atomic.Value //type message.Advise advice atomic.Value //type message.Advise
stopCh chan error stopCh chan error //todo replace wth context
//subsMu sync.Mutex //todo sync.Map subscriptionsMu sync.Mutex //todo thread safe radix tree
//subs map[string]chan *message.Message subscriptions map[string][]*subscription.Subscription
subsMu2 sync.Mutex //todo sync.Map
subs2 map[string][]*subscription.Subscription
onPubResponseMu sync.Mutex //todo sync.Map onPubResponseMu sync.Mutex //todo sync.Map
onPublishResponse map[string]func(message *message.Message) onPublishResponse map[string]func(message *message.Message)
@ -53,7 +50,7 @@ func (w *Websocket) Init(endpoint string, options *transport.Options) error {
w.topts = options w.topts = options
w.msgID = &msgID w.msgID = &msgID
//w.subs = map[string]chan *message.Message{} //w.subs = map[string]chan *message.Message{}
w.subs2 = map[string][]*subscription.Subscription{} w.subscriptions = map[string][]*subscription.Subscription{}
w.onPublishResponse = map[string]func(message *message.Message){} w.onPublishResponse = map[string]func(message *message.Message){}
w.stopCh = make(chan error) w.stopCh = make(chan error)
w.conn, _, err = websocket.DefaultDialer.Dial(endpoint, options.Headers) w.conn, _, err = websocket.DefaultDialer.Dial(endpoint, options.Headers)
@ -88,8 +85,8 @@ func (w *Websocket) readWorker() error {
switch msg.Channel { switch msg.Channel {
case transport.MetaSubscribe: case transport.MetaSubscribe:
//handle MetaSubscribe resp //handle MetaSubscribe resp
w.subsMu2.Lock() w.subscriptionsMu.Lock()
subscriptions, ok := w.subs2[msg.Subscription] subscriptions, ok := w.subscriptions[msg.Subscription]
if !ok { if !ok {
panic("BUG: subscription not registered `" + msg.Subscription + "`") panic("BUG: subscription not registered `" + msg.Subscription + "`")
} }
@ -115,7 +112,7 @@ func (w *Websocket) readWorker() error {
subscriptions = subscriptions[:si+copy(subscriptions[si:], subscriptions[si+1:])] subscriptions = subscriptions[:si+copy(subscriptions[si:], subscriptions[si+1:])]
} }
w.subs2[msg.Subscription] = subscriptions w.subscriptions[msg.Subscription] = subscriptions
//v2 //v2
} else { } else {
for i := range subscriptions { for i := range subscriptions {
@ -128,7 +125,7 @@ func (w *Websocket) readWorker() error {
} }
} }
} }
w.subsMu2.Unlock() w.subscriptionsMu.Unlock()
} }
@ -139,8 +136,8 @@ func (w *Websocket) readWorker() error {
// 1. Publish // 1. Publish
// 2. Delivery // 2. Delivery
if transport.IsEventDelivery(msg) { if transport.IsEventDelivery(msg) {
w.subsMu2.Lock() w.subscriptionsMu.Lock()
subscriptions, ok := w.subs2[msg.Channel] subscriptions, ok := w.subscriptions[msg.Channel]
if ok { if ok {
//send to all listeners //send to all listeners
@ -155,7 +152,7 @@ func (w *Websocket) readWorker() error {
} }
} }
} }
w.subsMu2.Unlock() w.subscriptionsMu.Unlock()
continue continue
} }
@ -256,15 +253,15 @@ func (w *Websocket) Disconnect() error {
w.stopCh <- nil w.stopCh <- nil
close(w.stopCh) close(w.stopCh)
w.subsMu2.Lock() w.subscriptionsMu.Lock()
for i := range w.subs2 { for i := range w.subscriptions {
//close all listeners //close all listeners
for j := range w.subs2[i] { for j := range w.subscriptions[i] {
close(w.subs2[i][j].MsgChannel()) close(w.subscriptions[i][j].MsgChannel())
} }
delete(w.subs2, i) delete(w.subscriptions, i)
} }
w.subsMu2.Unlock() w.subscriptionsMu.Unlock()
return w.sendMessage(&m) return w.sendMessage(&m)
} }
@ -291,9 +288,9 @@ func (w *Websocket) Subscribe(channel string) (*subscription.Subscription, error
} }
sub := subscription.NewSubscription(id, channel, w.Unsubscribe, pub, inMsgCh, subRes) sub := subscription.NewSubscription(id, channel, w.Unsubscribe, pub, inMsgCh, subRes)
w.subsMu2.Lock() w.subscriptionsMu.Lock()
w.subs2[channel] = append(w.subs2[channel], sub) w.subscriptions[channel] = append(w.subscriptions[channel], sub)
w.subsMu2.Unlock() w.subscriptionsMu.Unlock()
//todo timeout here //todo timeout here
err := <-subRes err := <-subRes
@ -309,9 +306,9 @@ func (w *Websocket) Subscribe(channel string) (*subscription.Subscription, error
//the specified channel/subscription //the specified channel/subscription
func (w *Websocket) Unsubscribe(subscription *subscription.Subscription) error { func (w *Websocket) Unsubscribe(subscription *subscription.Subscription) error {
//https://docs.cometd.org/current/reference/#_bayeux_meta_unsubscribe //https://docs.cometd.org/current/reference/#_bayeux_meta_unsubscribe
w.subsMu2.Lock() w.subscriptionsMu.Lock()
defer w.subsMu2.Unlock() defer w.subscriptionsMu.Unlock()
subs, ok := w.subs2[subscription.Channel()] subs, ok := w.subscriptions[subscription.Channel()]
if ok { if ok {
var si = -1 var si = -1
for i := range subs { for i := range subs {
@ -324,10 +321,10 @@ func (w *Websocket) Unsubscribe(subscription *subscription.Subscription) error {
//remove the subscription //remove the subscription
subs = subs[:si+copy(subs[si:], subs[si+1:])] subs = subs[:si+copy(subs[si:], subs[si+1:])]
} }
w.subs2[subscription.Channel()] = subs w.subscriptions[subscription.Channel()] = subs
//if no more listeners to this subscription send unsubscribe to server //if no more listeners to this subscription send unsubscribe to server
if len(subs) == 0 { if len(subs) == 0 {
delete(w.subs2, subscription.Channel()) delete(w.subscriptions, subscription.Channel())
//remove onPublishResponse handler //remove onPublishResponse handler
w.onPubResponseMu.Lock() w.onPubResponseMu.Lock()
delete(w.onPublishResponse, subscription.Channel()) delete(w.onPublishResponse, subscription.Channel())