diff --git a/.gitignore b/.gitignore index 1062418..8e98a03 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ .idea/ *.iml +src +pkg diff --git a/channel.go b/channel.go index 4975f1b..d331111 100644 --- a/channel.go +++ b/channel.go @@ -17,13 +17,6 @@ const ( bufferSize = 100 ) -type state int - -const ( - open = iota - closed -) - var ( errClosedChannel = errors.New("zerorpc/channel closed channel") ErrLostRemote = errors.New("zerorpc/channel lost remote") @@ -32,7 +25,7 @@ var ( // Channel representation type channel struct { Id string - state state + open bool socket *socket socketInput chan *Event channelOutput chan *Event @@ -53,7 +46,7 @@ func (s *socket) newChannel(id string) *channel { c := channel{ Id: id, - state: open, + open: true, socket: s, socketInput: make(chan *Event, bufferSize), channelOutput: make(chan *Event), @@ -77,11 +70,11 @@ func (ch *channel) close() { ch.mu.Lock() defer ch.mu.Unlock() - if ch.state == closed { + if !ch.open { return } - ch.state = closed + ch.open = false ch.socket.removeChannel(ch) @@ -99,19 +92,19 @@ func (ch *channel) sendEvent(e *Event) error { ch.mu.Lock() defer ch.mu.Unlock() - if ch.state == closed { + if !ch.open { return errClosedChannel } if ch.Id != "" { - e.Header["response_to"] = ch.Id + e.Header.ResponseTo = ch.Id } else { - ch.Id = e.Header["message_id"].(string) + ch.Id = e.Header.Id go ch.sendHeartbeats() } - log.Printf("Channel %s sending event %s", ch.Id, e.Header["message_id"].(string)) + log.Printf("Channel %s sending event %s", ch.Id, e.Header.Id) identity := ch.identity @@ -138,7 +131,7 @@ func (ch *channel) sendHeartbeats() { for { time.Sleep(HeartbeatFrequency) - if ch.state == closed { + if !ch.open { return } @@ -160,10 +153,11 @@ func (ch *channel) sendHeartbeats() { } func (ch *channel) listen() { + streamCounter := 0 for { - if ch.state == closed { + if !ch.open { return } @@ -175,6 +169,7 @@ func (ch *channel) listen() { switch ev.Name { case "OK": + ch.lastHeartbeat = time.Now() ch.channelOutput <- ev case "ERR": @@ -242,7 +237,7 @@ func (ch *channel) listen() { func (ch *channel) handleHeartbeats() { for { - if ch.state == closed { + if !ch.open { return } diff --git a/client.go b/client.go index c228a7a..bc1af42 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,7 @@ package zerorpc import ( - "errors" + "fmt" "log" ) @@ -96,7 +96,7 @@ func (c *Client) Invoke(name string, args ...interface{}) (*Event, error) { select { case response := <-ch.channelOutput: if response.Name == "ERR" { - return response, errors.New(response.Args[0].(string)) + return response, fmt.Errorf("%s", response.Args) } else { return response, nil } @@ -152,7 +152,7 @@ It also supports first class exceptions, in case of an exception, the error returned from Invoke() or InvokeStream() is the exception name and the args of the returned event are the exception description and traceback. -The client sends heartbeat events every 5 seconds, if twp heartbeat events are missed, +The client sends heartbeat events every 5 seconds, if two heartbeat events are missed, the remote is considered as lost and an ErrLostRemote is returned. */ func (c *Client) InvokeStream(name string, args ...interface{}) ([]*Event, error) { @@ -177,7 +177,7 @@ func (c *Client) InvokeStream(name string, args ...interface{}) ([]*Event, error select { case response := <-ch.channelOutput: if response.Name == "ERR" { - return []*Event{response}, errors.New(response.Args[0].(string)) + return []*Event{response}, fmt.Errorf("%s", response.Args) } else if response.Name == "OK" { return []*Event{response}, nil } else if response.Name == "STREAM" { diff --git a/event.go b/event.go index 8edae65..4ef7f3f 100644 --- a/event.go +++ b/event.go @@ -1,7 +1,6 @@ package zerorpc import ( - "errors" uuid "github.com/nu7hatch/gouuid" "github.com/ugorji/go/codec" ) @@ -9,11 +8,27 @@ import ( // ZeroRPC protocol version const ProtocolVersion = 3 +var ( + mh codec.MsgpackHandle +) + +func init() { + mh.RawToString = true + +} + // Event representation + +type EventHeader struct { + Id string `codec:"message_id"` + ResponseTo string `codec:"response_to,omitempty"` + Version int `codec:"v"` +} + type Event struct { - Header map[string]interface{} + Header *EventHeader Name string - Args []interface{} + Args codec.MsgpackSpecRpcMultiArgs } // Returns a pointer to a new event, @@ -24,9 +39,7 @@ func newEvent(name string, args ...interface{}) (*Event, error) { return nil, err } - header := make(map[string]interface{}) - header["message_id"] = id.String() - header["v"] = ProtocolVersion + header := &EventHeader{Id: id.String(), Version: ProtocolVersion} e := Event{ Header: header, @@ -49,7 +62,7 @@ func (e *Event) packBytes() ([]byte, error) { var buf []byte - enc := codec.NewEncoderBytes(&buf, &codec.MsgpackHandle{}) + enc := codec.NewEncoderBytes(&buf, &mh) if err := enc.Encode(data); err != nil { return nil, err } @@ -59,63 +72,13 @@ func (e *Event) packBytes() ([]byte, error) { // Unpacks an event fom MsgPack bytes func unPackBytes(b []byte) (*Event, error) { - var mh codec.MsgpackHandle - var v interface{} - + var e Event dec := codec.NewDecoderBytes(b, &mh) - - err := dec.Decode(&v) + err := dec.Decode(&e) if err != nil { return nil, err } - // get the event headers - h, ok := v.([]interface{})[0].(map[interface{}]interface{}) - if !ok { - return nil, errors.New("zerorpc/event interface conversion error") - } - - header := make(map[string]interface{}) - - for k, v := range h { - switch t := v.(type) { - case []byte: - header[k.(string)] = string(t) - - default: - header[k.(string)] = t - } - } - - // get the event name - n, ok := v.([]interface{})[1].([]byte) - if !ok { - return nil, errors.New("zerorpc/event interface conversion error") - } - - // get the event args - args := make([]interface{}, 0) - - for i := 2; i < len(v.([]interface{})); i++ { - t := v.([]interface{})[i] - - switch t.(type) { - case []interface{}: - for _, a := range t.([]interface{}) { - args = append(args, convertValue(a)) - } - - default: - args = append(args, convertValue(t)) - } - } - - e := Event{ - Header: header, - Name: string(n), - Args: args, - } - return &e, nil } diff --git a/server.go b/server.go index cce060c..c616ef5 100644 --- a/server.go +++ b/server.go @@ -7,15 +7,11 @@ import ( // ZeroRPC server representation, // it holds a pointer to the ZeroMQ socket +type HandlerFunc func(args []interface{}) (interface{}, error) + type Server struct { socket *socket - handlers []*taskHandler -} - -// Task handler representation -type taskHandler struct { - TaskName string - HandlerFunc *func(args []interface{}) (interface{}, error) + handlers map[string]HandlerFunc } var ( @@ -66,7 +62,7 @@ func NewServer(endpoint string) (*Server, error) { server := Server{ socket: s, - handlers: make([]*taskHandler, 0), + handlers: make(map[string]HandlerFunc, 0), } server.socket.server = &server @@ -83,14 +79,13 @@ func (s *Server) Close() error { // tasks are invoked in new goroutines // // it returns ErrDuplicateHandler if an handler was already registered for the task -func (s *Server) RegisterTask(name string, handlerFunc *func(args []interface{}) (interface{}, error)) error { - for _, h := range s.handlers { - if h.TaskName == name { - return ErrDuplicateHandler - } +func (s *Server) RegisterTask(name string, handlerFunc func(args []interface{}) (interface{}, error)) error { + + if _, ok := s.handlers[name]; ok { + return ErrDuplicateHandler } - s.handlers = append(s.handlers, &taskHandler{TaskName: name, HandlerFunc: handlerFunc}) + s.handlers[name] = handlerFunc log.Printf("ZeroRPC server registered handler for task %s", name) @@ -100,12 +95,11 @@ func (s *Server) RegisterTask(name string, handlerFunc *func(args []interface{}) // Invoke the handler for a task event, // it returns ErrNoTaskHandler if no handler is registered for the task func (s *Server) handleTask(ev *Event) (interface{}, error) { - for _, h := range s.handlers { - if h.TaskName == ev.Name { - log.Printf("ZeroRPC server handling task %s with args %s", ev.Name, ev.Args) - return (*h.HandlerFunc)(ev.Args) - } + if h, ok := s.handlers[ev.Name]; ok { + + log.Printf("ZeroRPC server handling task %s with args %s", ev.Name, ev.Args) + return h(ev.Args) } return nil, ErrNoTaskHandler diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..8fdf10c --- /dev/null +++ b/server_test.go @@ -0,0 +1,18 @@ +package zerorpc + +import "testing" + +func TestServerBind(t *testing.T) { + s, err := NewServer("tcp://0.0.0.0:4242") + if err != nil { + panic(err) + } + + defer s.Close() + + h := func(v []interface{}) (interface{}, error) { + return "Hello, " + v[0].(string), nil + } + + s.RegisterTask("hello", h) +} diff --git a/socket.go b/socket.go index c73561a..d28a313 100644 --- a/socket.go +++ b/socket.go @@ -4,9 +4,10 @@ package zerorpc import ( - zmq "github.com/pebbe/zmq4" "log" "sync" + + zmq "github.com/pebbe/zmq4" ) // ZeroRPC socket representation @@ -75,6 +76,9 @@ func (s *socket) close() error { s.removeChannel(c) } + s.mu.Lock() + defer s.mu.Unlock() + log.Printf("ZeroRPC socket closed") return s.zmqSocket.Close() } @@ -102,7 +106,7 @@ func (s *socket) sendEvent(e *Event, identity string) error { return err } - log.Printf("ZeroRPC socket sent event %s", e.Header["message_id"].(string)) + log.Printf("ZeroRPC socket sent event %s", e.Header.Id) i, err := s.zmqSocket.SendMessage(identity, "", b) if err != nil { @@ -135,11 +139,11 @@ func (s *socket) listen() { s.socketErrors <- err } - log.Printf("ZeroRPC socket recieved event %s", ev.Header["message_id"].(string)) + log.Printf("ZeroRPC socket recieved event %s", ev.Header.Id) var ch *channel - if _, ok := ev.Header["response_to"]; !ok { - ch = s.newChannel(ev.Header["message_id"].(string)) + if ev.Header.ResponseTo == "" { + ch = s.newChannel(ev.Header.Id) go ch.sendHeartbeats() if len(barr) > 1 { @@ -147,14 +151,14 @@ func (s *socket) listen() { } } else { for _, c := range s.Channels { - if c.Id == ev.Header["response_to"].(string) { + if c.Id == ev.Header.ResponseTo { ch = c } } } - if ch != nil && ch.state == open { - log.Printf("ZeroRPC socket routing event %s to channel %s", ev.Header["message_id"].(string), ch.Id) + if ch != nil && ch.open { + log.Printf("ZeroRPC socket routing event %s to channel %s", ev.Header.Id, ch.Id) ch.socketInput <- ev }