Improve locking and add panic handlers to websocket functionality

ws_check
Martin Boehm 2020-10-29 08:55:04 +01:00
parent 4832205f45
commit b48a9ef6a6
2 changed files with 103 additions and 67 deletions

View File

@ -545,12 +545,22 @@ func syncIndexLoop() {
}
func onNewBlockHash(hash string, height uint32) {
defer func() {
if r := recover(); r != nil {
glog.Error("onNewBlockHash recovered from panic: ", r)
}
}()
for _, c := range callbacksOnNewBlock {
c(hash, height)
}
}
func onNewFiatRatesTicker(ticker *db.CurrencyRatesTicker) {
defer func() {
if r := recover(); r != nil {
glog.Error("onNewFiatRatesTicker recovered from panic: ", r)
}
}()
for _, c := range callbacksOnNewFiatRatesTicker {
c(ticker)
}
@ -617,12 +627,22 @@ func storeInternalStateLoop() {
}
func onNewTxAddr(tx *bchain.Tx, desc bchain.AddressDescriptor) {
defer func() {
if r := recover(); r != nil {
glog.Error("onNewTxAddr recovered from panic: ", r)
}
}()
for _, c := range callbacksOnNewTxAddr {
c(tx, desc)
}
}
func onNewTx(tx *bchain.MempoolTx) {
defer func() {
if r := recover(); r != nil {
glog.Error("onNewTx recovered from panic: ", r)
}
}()
for _, c := range callbacksOnNewTx {
c(tx)
}

View File

@ -143,6 +143,12 @@ func (s *WebsocketServer) GetHandler() http.Handler {
}
func (s *WebsocketServer) closeChannel(c *websocketChannel) {
if c.CloseOut() {
s.onDisconnect(c)
}
}
func (c *websocketChannel) CloseOut() bool {
c.aliveLock.Lock()
defer c.aliveLock.Unlock()
if c.alive {
@ -153,14 +159,24 @@ func (s *WebsocketServer) closeChannel(c *websocketChannel) {
for len(c.out) > 0 {
<-c.out
}
s.onDisconnect(c)
return true
}
return false
}
func (c *websocketChannel) IsAlive() bool {
func (c *websocketChannel) DataOut(data *websocketRes) {
c.aliveLock.Lock()
defer c.aliveLock.Unlock()
return c.alive
if c.alive {
if len(c.out) < outChannelSize-1 {
c.out <- data
} else {
glog.Warning("Channel ", c.id, " overflow, closing")
// close the connection but do not call CloseOut - would call duplicate c.aliveLock.Lock
// CloseOut will be called because the closed connection will cause break in the inputLoop
c.conn.Close()
}
}
}
func (s *WebsocketServer) inputLoop(c *websocketChannel) {
@ -204,11 +220,18 @@ func (s *WebsocketServer) inputLoop(c *websocketChannel) {
}
func (s *WebsocketServer) outputLoop(c *websocketChannel) {
defer func() {
if r := recover(); r != nil {
glog.Error("recovered from panic: ", r, ", ", c.id)
s.closeChannel(c)
}
}()
for m := range c.out {
err := c.conn.WriteJSON(m)
if err != nil {
glog.Error("Error sending message to ", c.id, ", ", err)
s.closeChannel(c)
return
}
}
}
@ -383,18 +406,6 @@ var requestHandlers = map[string]func(*WebsocketServer, *websocketChannel, *webs
},
}
func sendResponse(c *websocketChannel, req *websocketReq, data interface{}) {
defer func() {
if r := recover(); r != nil {
glog.Error("Client ", c.id, ", onRequest ", req.Method, " recovered from panic: ", r)
}
}()
c.out <- &websocketRes{
ID: req.ID,
Data: data,
}
}
func (s *WebsocketServer) onRequest(c *websocketChannel, req *websocketReq) {
var err error
var data interface{}
@ -408,7 +419,10 @@ func (s *WebsocketServer) onRequest(c *websocketChannel, req *websocketReq) {
}
// nil data means no response
if data != nil {
sendResponse(c, req, data)
c.DataOut(&websocketRes{
ID: req.ID,
Data: data,
})
}
}()
t := time.Now()
@ -429,7 +443,7 @@ func (s *WebsocketServer) onRequest(c *websocketChannel, req *websocketReq) {
data = e
}
} else {
glog.Warning("Client ", c.id, " onMessage ", req.Method, ": unknown method, data ", string(req.Params))
glog.V(1).Info("Client ", c.id, " onMessage ", req.Method, ": unknown method, data ", string(req.Params))
}
}
@ -665,11 +679,25 @@ func (s *WebsocketServer) unmarshalAddresses(params []byte) ([]bchain.AddressDes
return rv, nil
}
// unsubscribe addresses without addressSubscriptionsLock - can be called only from subscribeAddresses and unsubscribeAddresses
func (s *WebsocketServer) doUnsubscribeAddresses(c *websocketChannel) {
for ads, sa := range s.addressSubscriptions {
for sc := range sa {
if sc == c {
delete(sa, c)
}
}
if len(sa) == 0 {
delete(s.addressSubscriptions, ads)
}
}
}
func (s *WebsocketServer) subscribeAddresses(c *websocketChannel, addrDesc []bchain.AddressDescriptor, req *websocketReq) (res interface{}, err error) {
// unsubscribe all previous subscriptions
s.unsubscribeAddresses(c)
s.addressSubscriptionsLock.Lock()
defer s.addressSubscriptionsLock.Unlock()
// unsubscribe all previous subscriptions
s.doUnsubscribeAddresses(c)
for i := range addrDesc {
ads := string(addrDesc[i])
as, ok := s.addressSubscriptions[ads]
@ -686,26 +714,30 @@ func (s *WebsocketServer) subscribeAddresses(c *websocketChannel, addrDesc []bch
func (s *WebsocketServer) unsubscribeAddresses(c *websocketChannel) (res interface{}, err error) {
s.addressSubscriptionsLock.Lock()
defer s.addressSubscriptionsLock.Unlock()
for ads, sa := range s.addressSubscriptions {
s.doUnsubscribeAddresses(c)
return &subscriptionResponse{false}, nil
}
// unsubscribe fiat rates without fiatRatesSubscriptionsLock - can be called only from subscribeFiatRates and unsubscribeFiatRates
func (s *WebsocketServer) doUnsubscribeFiatRates(c *websocketChannel) {
for fr, sa := range s.fiatRatesSubscriptions {
for sc := range sa {
if sc == c {
delete(sa, c)
}
}
if len(sa) == 0 {
delete(s.addressSubscriptions, ads)
delete(s.fiatRatesSubscriptions, fr)
}
}
return &subscriptionResponse{false}, nil
}
// subscribeFiatRates subscribes all FiatRates subscriptions by this channel
func (s *WebsocketServer) subscribeFiatRates(c *websocketChannel, currency string, req *websocketReq) (res interface{}, err error) {
// unsubscribe all previous subscriptions
s.unsubscribeFiatRates(c)
s.fiatRatesSubscriptionsLock.Lock()
defer s.fiatRatesSubscriptionsLock.Unlock()
// unsubscribe all previous subscriptions
s.doUnsubscribeFiatRates(c)
if currency == "" {
currency = allFiatRates
}
@ -722,16 +754,7 @@ func (s *WebsocketServer) subscribeFiatRates(c *websocketChannel, currency strin
func (s *WebsocketServer) unsubscribeFiatRates(c *websocketChannel) (res interface{}, err error) {
s.fiatRatesSubscriptionsLock.Lock()
defer s.fiatRatesSubscriptionsLock.Unlock()
for fr, sa := range s.fiatRatesSubscriptions {
for sc := range sa {
if sc == c {
delete(sa, c)
}
}
if len(sa) == 0 {
delete(s.fiatRatesSubscriptions, fr)
}
}
s.doUnsubscribeFiatRates(c)
return &subscriptionResponse{false}, nil
}
@ -747,12 +770,10 @@ func (s *WebsocketServer) OnNewBlock(hash string, height uint32) {
Hash: hash,
}
for c, id := range s.newBlockSubscriptions {
if c.IsAlive() {
c.out <- &websocketRes{
ID: id,
Data: &data,
}
}
c.DataOut(&websocketRes{
ID: id,
Data: &data,
})
}
glog.Info("broadcasting new block ", height, " ", hash, " to ", len(s.newBlockSubscriptions), " channels")
}
@ -772,30 +793,26 @@ func (s *WebsocketServer) sendOnNewTxAddr(stringAddressDescriptor string, tx *ap
Address: addr[0],
Tx: tx,
}
// get the list of subscriptions again, this time keep the lock
s.addressSubscriptionsLock.Lock()
defer s.addressSubscriptionsLock.Unlock()
as, ok := s.addressSubscriptions[stringAddressDescriptor]
if ok {
for c, id := range as {
if c.IsAlive() {
c.out <- &websocketRes{
ID: id,
Data: &data,
}
}
c.DataOut(&websocketRes{
ID: id,
Data: &data,
})
}
glog.Info("broadcasting new tx ", tx.Txid, ", addr ", addr[0], " to ", len(as), " channels")
}
}
}
// OnNewTx is a callback that broadcasts info about a tx affecting subscribed address
func (s *WebsocketServer) OnNewTx(tx *bchain.MempoolTx) {
func (s *WebsocketServer) getNewTxSubscriptions(tx *bchain.MempoolTx) map[string]struct{} {
// check if there is any subscription in inputs, outputs and erc20
// release the lock immediately, GetTransactionFromMempoolTx is potentially slow
subscribed := make(map[string]struct{})
s.addressSubscriptionsLock.Lock()
defer s.addressSubscriptionsLock.Unlock()
subscribed := make(map[string]struct{})
for i := range tx.Vin {
sad := string(tx.Vin[i].AddrDesc)
if len(sad) > 0 {
@ -833,7 +850,12 @@ func (s *WebsocketServer) OnNewTx(tx *bchain.MempoolTx) {
}
}
}
s.addressSubscriptionsLock.Unlock()
return subscribed
}
// OnNewTx is a callback that broadcasts info about a tx affecting subscribed address
func (s *WebsocketServer) OnNewTx(tx *bchain.MempoolTx) {
subscribed := s.getNewTxSubscriptions(tx)
if len(subscribed) > 0 {
atx, err := s.api.GetTransactionFromMempoolTx(tx)
if err != nil {
@ -847,8 +869,6 @@ func (s *WebsocketServer) OnNewTx(tx *bchain.MempoolTx) {
}
func (s *WebsocketServer) broadcastTicker(currency string, rates map[string]float64) {
s.fiatRatesSubscriptionsLock.Lock()
defer s.fiatRatesSubscriptionsLock.Unlock()
as, ok := s.fiatRatesSubscriptions[currency]
if ok && len(as) > 0 {
data := struct {
@ -856,24 +876,20 @@ func (s *WebsocketServer) broadcastTicker(currency string, rates map[string]floa
}{
Rates: rates,
}
// get the list of subscriptions again, this time keep the lock
as, ok = s.fiatRatesSubscriptions[currency]
if ok {
for c, id := range as {
if c.IsAlive() {
c.out <- &websocketRes{
ID: id,
Data: &data,
}
}
}
glog.Info("broadcasting new rates for currency ", currency, " to ", len(as), " channels")
for c, id := range as {
c.DataOut(&websocketRes{
ID: id,
Data: &data,
})
}
glog.Info("broadcasting new rates for currency ", currency, " to ", len(as), " channels")
}
}
// OnNewFiatRatesTicker is a callback that broadcasts info about fiat rates affecting subscribed currency
func (s *WebsocketServer) OnNewFiatRatesTicker(ticker *db.CurrencyRatesTicker) {
s.fiatRatesSubscriptionsLock.Lock()
defer s.fiatRatesSubscriptionsLock.Unlock()
for currency, rate := range ticker.Rates {
s.broadcastTicker(currency, map[string]float64{currency: rate})
}