diff --git a/transport/websocket/websocket.go b/transport/websocket/websocket.go index 4000a25..78239bf 100644 --- a/transport/websocket/websocket.go +++ b/transport/websocket/websocket.go @@ -8,6 +8,7 @@ import ( "log" "strconv" "strings" + "sync" "sync/atomic" ) @@ -29,7 +30,9 @@ type Websocket struct { conn *websocket.Conn clientID string msgID *uint64 - subs map[string]chan struct{} + + subsMu sync.Mutex //todo sync.Map + subs map[string]chan *message.Message } var _ transport.Transport = (*Websocket)(nil) @@ -41,16 +44,48 @@ func (w *Websocket) Init(options *transport.Options) error { ) w.TransportOpts = options w.msgID = &msgID - w.subs = map[string]chan struct{}{} + w.subs = map[string]chan *message.Message{} w.conn, _, err = websocket.DefaultDialer.Dial(options.Url, nil) if err != nil { return err } return nil } + +func (w *Websocket) readWorker() error { + var payload []message.Message + for { + err := w.conn.ReadJSON(&payload) + if err != nil { + return err + } + //dispatch + msg := payload[0] + if strings.HasPrefix(msg.Channel, "/meta") { + continue //todo update introspect message and update state + } + + w.subsMu.Lock() + subscription := w.subs[msg.Channel] + w.subsMu.Unlock() + + subscription <- &msg + } + +} + func (w *Websocket) Name() string { return transportName } + +func (w *Websocket) sendMessage(m *message.Message) error { + var payload []message.Message + payload = append(payload, *m) + if Debug { + log.Println("sending request", debugJson(payload)) + } + return w.conn.WriteJSON(payload) +} func (w *Websocket) nextMsgID() string { return strconv.Itoa(int(atomic.AddUint64(w.msgID, 1))) } @@ -59,16 +94,14 @@ func (w *Websocket) Options() *transport.Options { return w.TransportOpts } func (w *Websocket) Handshake() (err error) { - var payload []message.Message - payload = append(payload, message.Message{ + + m := message.Message{ Channel: string(transport.Handshake), Version: "1.0", //todo const SupportedConnectionTypes: []string{transportName}, - }) - if Debug { - log.Println("handshake request", debugJson(payload)) } - if err = w.conn.WriteJSON(payload); err != nil { + err = w.sendMessage(&m) + if err != nil { return err } @@ -84,23 +117,20 @@ func (w *Websocket) Handshake() (err error) { if resp.GetError() != nil { return err } + log.Println(debugJson(resp)) w.clientID = resp.ClientId return nil } func (w *Websocket) Connect() error { - var payload []message.Message - payload = append(payload, message.Message{ + m := message.Message{ Channel: string(transport.Connect), ClientId: w.clientID, ConnectionType: transportName, Id: w.nextMsgID(), - }) - if Debug { - log.Println("connect request", debugJson(payload)) } //todo verify if extensions are applied on connect,verify if hs is complete - return w.conn.WriteJSON(payload) + return w.sendMessage(&m) } func (w *Websocket) Subscribe(subscription string, onMessage func(message *message.Message)) error { @@ -114,53 +144,19 @@ func (w *Websocket) Subscribe(subscription string, onMessage func(message *messa w.TransportOpts.OutExt(m) } - var payload []message.Message - payload = append(payload, *m) - if Debug { - log.Println("subscribe request", debugJson(payload)) - } - err := w.conn.WriteJSON(payload) - if err != nil { + if err := w.sendMessage(m); err != nil { return err } - var hsResps []message.Message - if err = w.conn.ReadJSON(&hsResps); err != nil { - return err - } - if Debug { - log.Println("subscribe response", debugJson(hsResps)) - } + //todo validate - subResp := hsResps[0] - if subResp.GetError() != nil { - return err - } - if !subResp.Successful { - //report err just for sanity - } - unsubsCh := make(chan struct{}, 0) + inMsgCh := make(chan *message.Message, 0) - w.subs[subscription] = unsubsCh + w.subs[subscription] = inMsgCh - for { - select { - case <-unsubsCh: - return nil - default: - } - //todo guard unsusribe - var payload []message.Message - err := w.conn.ReadJSON(&payload) - if err != nil { - return err - } - //hack - msg := payload[0] - if strings.HasPrefix(msg.Channel, "/meta") { - continue //todo update introspect message and update state - } - onMessage(&msg) + var inMsg *message.Message + for inMsg = range inMsgCh { + onMessage(inMsg) } return nil }