复制项目

This commit is contained in:
kim.dev.6789
2026-01-14 22:16:44 +08:00
parent e2577b8cee
commit e50142a3b9
691 changed files with 97009 additions and 1 deletions

View File

@@ -0,0 +1,76 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"time"
cbapi "git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/mcontext"
)
func (ws *WsServer) webhookAfterUserOnline(ctx context.Context, after *config.AfterConfig, userID string, platformID int, isAppBackground bool, connID string) {
req := cbapi.CallbackUserOnlineReq{
UserStatusCallbackReq: cbapi.UserStatusCallbackReq{
UserStatusBaseCallback: cbapi.UserStatusBaseCallback{
CallbackCommand: cbapi.CallbackAfterUserOnlineCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: platformID,
Platform: constant.PlatformIDToName(platformID),
},
UserID: userID,
},
Seq: time.Now().UnixMilli(),
IsAppBackground: isAppBackground,
ConnID: connID,
}
ws.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CommonCallbackResp{}, after)
}
func (ws *WsServer) webhookAfterUserOffline(ctx context.Context, after *config.AfterConfig, userID string, platformID int, connID string) {
req := &cbapi.CallbackUserOfflineReq{
UserStatusCallbackReq: cbapi.UserStatusCallbackReq{
UserStatusBaseCallback: cbapi.UserStatusBaseCallback{
CallbackCommand: cbapi.CallbackAfterUserOfflineCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: platformID,
Platform: constant.PlatformIDToName(platformID),
},
UserID: userID,
},
Seq: time.Now().UnixMilli(),
ConnID: connID,
}
ws.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CallbackUserOfflineResp{}, after)
}
func (ws *WsServer) webhookAfterUserKickOff(ctx context.Context, after *config.AfterConfig, userID string, platformID int) {
req := &cbapi.CallbackUserKickOffReq{
UserStatusCallbackReq: cbapi.UserStatusCallbackReq{
UserStatusBaseCallback: cbapi.UserStatusBaseCallback{
CallbackCommand: cbapi.CallbackAfterUserKickOffCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: platformID,
Platform: constant.PlatformIDToName(platformID),
},
UserID: userID,
},
Seq: time.Now().UnixMilli(),
}
ws.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CommonCallbackResp{}, after)
}

View File

@@ -0,0 +1,476 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"google.golang.org/protobuf/proto"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/stringutil"
)
var (
ErrConnClosed = errs.New("conn has closed")
ErrNotSupportMessageProtocol = errs.New("not support message protocol")
ErrClientClosed = errs.New("client actively close the connection")
ErrPanic = errs.New("panic error")
)
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
// CloseMessage denotes a close control message. The optional message
// payload contains a numeric code and text. Use the FormatCloseMessage
// function to format a close message payload.
CloseMessage = 8
// PingMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PingMessage = 9
// PongMessage denotes a pong control message. The optional message payload
// is UTF-8 encoded text.
PongMessage = 10
)
type PingPongHandler func(string) error
type Client struct {
w *sync.Mutex
conn LongConn
PlatformID int `json:"platformID"`
IsCompress bool `json:"isCompress"`
UserID string `json:"userID"`
IsBackground bool `json:"isBackground"`
SDKType string `json:"sdkType"`
SDKVersion string `json:"sdkVersion"`
Encoder Encoder
ctx *UserConnContext
longConnServer LongConnServer
closed atomic.Bool
closedErr error
token string
hbCtx context.Context
hbCancel context.CancelFunc
subLock *sync.Mutex
subUserIDs map[string]struct{} // client conn subscription list
}
// ResetClient updates the client's state with new connection and context information.
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
c.w = new(sync.Mutex)
c.conn = conn
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
c.IsCompress = ctx.GetCompression()
c.IsBackground = ctx.GetBackground()
c.UserID = ctx.GetUserID()
c.ctx = ctx
c.longConnServer = longConnServer
c.IsBackground = false
c.closed.Store(false)
c.closedErr = nil
c.token = ctx.GetToken()
c.SDKType = ctx.GetSDKType()
c.SDKVersion = ctx.GetSDKVersion()
c.hbCtx, c.hbCancel = context.WithCancel(c.ctx)
c.subLock = new(sync.Mutex)
if c.subUserIDs != nil {
clear(c.subUserIDs)
}
if c.SDKType == GoSDK {
c.Encoder = NewGobEncoder()
} else {
c.Encoder = NewJsonEncoder()
}
c.subUserIDs = make(map[string]struct{})
}
func (c *Client) pingHandler(appData string) error {
if err := c.conn.SetReadDeadline(pongWait); err != nil {
return err
}
log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData)
return c.writePongMsg(appData)
}
func (c *Client) pongHandler(_ string) error {
if err := c.conn.SetReadDeadline(pongWait); err != nil {
return err
}
return nil
}
// readMessage continuously reads messages from the connection.
func (c *Client) readMessage() {
defer func() {
if r := recover(); r != nil {
c.closedErr = ErrPanic
log.ZPanic(c.ctx, "socket have panic err:", errs.ErrPanic(r))
}
c.close()
}()
c.conn.SetReadLimit(maxMessageSize)
_ = c.conn.SetReadDeadline(pongWait)
c.conn.SetPongHandler(c.pongHandler)
c.conn.SetPingHandler(c.pingHandler)
c.activeHeartbeat(c.hbCtx)
for {
log.ZDebug(c.ctx, "readMessage")
messageType, message, returnErr := c.conn.ReadMessage()
if returnErr != nil {
log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType)
c.closedErr = returnErr
return
}
log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
if c.closed.Load() {
// The scenario where the connection has just been closed, but the coroutine has not exited
c.closedErr = ErrConnClosed
return
}
switch messageType {
case MessageBinary:
_ = c.conn.SetReadDeadline(pongWait)
parseDataErr := c.handleMessage(message)
if parseDataErr != nil {
c.closedErr = parseDataErr
return
}
case MessageText:
_ = c.conn.SetReadDeadline(pongWait)
parseDataErr := c.handlerTextMessage(message)
if parseDataErr != nil {
c.closedErr = parseDataErr
return
}
case PingMessage:
err := c.writePongMsg("")
log.ZError(c.ctx, "writePongMsg", err)
case CloseMessage:
c.closedErr = ErrClientClosed
return
default:
}
}
}
// handleMessage processes a single message received by the client.
func (c *Client) handleMessage(message []byte) error {
if c.IsCompress {
var err error
message, err = c.longConnServer.DecompressWithPool(message)
if err != nil {
return errs.Wrap(err)
}
}
var binaryReq = getReq()
defer freeReq(binaryReq)
err := c.Encoder.Decode(message, binaryReq)
if err != nil {
return err
}
if err := c.longConnServer.Validate(binaryReq); err != nil {
return err
}
if binaryReq.SendID != c.UserID {
return errs.New("exception conn userID not same to req userID", "binaryReq", binaryReq.String())
}
ctx := mcontext.WithMustInfoCtx(
[]string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.PlatformID), c.ctx.GetConnID()},
)
log.ZDebug(ctx, "gateway req message", "req", binaryReq.String())
var (
resp []byte
messageErr error
)
switch binaryReq.ReqIdentifier {
case WSGetNewestSeq:
resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq)
case WSSendMsg:
resp, messageErr = c.longConnServer.SendMessage(ctx, binaryReq)
case WSSendSignalMsg:
resp, messageErr = c.longConnServer.SendSignalMessage(ctx, binaryReq)
case WSPullMsgBySeqList:
resp, messageErr = c.longConnServer.PullMessageBySeqList(ctx, binaryReq)
case WSPullMsg:
resp, messageErr = c.longConnServer.GetSeqMessage(ctx, binaryReq)
case WSGetConvMaxReadSeq:
resp, messageErr = c.longConnServer.GetConversationsHasReadAndMaxSeq(ctx, binaryReq)
case WsPullConvLastMessage:
resp, messageErr = c.longConnServer.GetLastMessage(ctx, binaryReq)
case WsLogoutMsg:
resp, messageErr = c.longConnServer.UserLogout(ctx, binaryReq)
case WsSetBackgroundStatus:
resp, messageErr = c.setAppBackgroundStatus(ctx, binaryReq)
case WsSubUserOnlineStatus:
resp, messageErr = c.longConnServer.SubUserOnlineStatus(ctx, c, binaryReq)
default:
return fmt.Errorf(
"ReqIdentifier failed,sendID:%s,msgIncr:%s,reqIdentifier:%d",
binaryReq.SendID,
binaryReq.MsgIncr,
binaryReq.ReqIdentifier,
)
}
return c.replyMessage(ctx, binaryReq, messageErr, resp)
}
func (c *Client) setAppBackgroundStatus(ctx context.Context, req *Req) ([]byte, error) {
resp, isBackground, messageErr := c.longConnServer.SetUserDeviceBackground(ctx, req)
if messageErr != nil {
return nil, messageErr
}
c.IsBackground = isBackground
// TODO: callback
return resp, nil
}
func (c *Client) close() {
c.w.Lock()
defer c.w.Unlock()
if c.closed.Load() {
return
}
c.closed.Store(true)
c.conn.Close()
c.hbCancel() // Close server-initiated heartbeat.
c.longConnServer.UnRegister(c)
}
func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, resp []byte) error {
errResp := apiresp.ParseError(err)
mReply := Resp{
ReqIdentifier: binaryReq.ReqIdentifier,
MsgIncr: binaryReq.MsgIncr,
OperationID: binaryReq.OperationID,
ErrCode: errResp.ErrCode,
ErrMsg: errResp.ErrMsg,
Data: resp,
}
t := time.Now()
log.ZDebug(ctx, "gateway reply message", "resp", mReply.String())
err = c.writeBinaryMsg(mReply)
if err != nil {
log.ZWarn(ctx, "wireBinaryMsg replyMessage", err, "resp", mReply.String())
}
log.ZDebug(ctx, "wireBinaryMsg end", "time cost", time.Since(t))
if binaryReq.ReqIdentifier == WsLogoutMsg {
return errs.New("user logout", "operationID", binaryReq.OperationID).Wrap()
}
return nil
}
func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error {
var msg sdkws.PushMessages
conversationID := msgprocessor.GetConversationIDByMsg(msgData)
m := map[string]*sdkws.PullMsgs{conversationID: {Msgs: []*sdkws.MsgData{msgData}}}
if msgprocessor.IsNotification(conversationID) {
msg.NotificationMsgs = m
} else {
msg.Msgs = m
}
log.ZDebug(ctx, "PushMessage", "msg", &msg)
data, err := proto.Marshal(&msg)
if err != nil {
return err
}
resp := Resp{
ReqIdentifier: WSPushMsg,
OperationID: mcontext.GetOperationID(ctx),
Data: data,
}
return c.writeBinaryMsg(resp)
}
func (c *Client) KickOnlineMessage() error {
resp := Resp{
ReqIdentifier: WSKickOnlineMsg,
}
log.ZDebug(c.ctx, "KickOnlineMessage debug ")
// 先发送踢掉通知消息
err := c.writeBinaryMsg(resp)
if err != nil {
log.ZWarn(c.ctx, "KickOnlineMessage writeBinaryMsg failed", err)
// 即使发送失败,也要关闭连接
c.close()
return err
}
// 给消息一些时间发送,然后关闭连接
// 使用一个小延迟确保消息已经写入缓冲区
time.Sleep(10 * time.Millisecond)
c.close()
return nil
}
func (c *Client) PushUserOnlineStatus(data []byte) error {
resp := Resp{
ReqIdentifier: WsSubUserOnlineStatus,
Data: data,
}
return c.writeBinaryMsg(resp)
}
func (c *Client) writeBinaryMsg(resp Resp) error {
if c.closed.Load() {
return nil
}
encodedBuf, err := c.Encoder.Encode(resp)
if err != nil {
return err
}
c.w.Lock()
defer c.w.Unlock()
err = c.conn.SetWriteDeadline(writeWait)
if err != nil {
return err
}
if c.IsCompress {
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
if compressErr != nil {
return compressErr
}
return c.conn.WriteMessage(MessageBinary, resultBuf)
}
return c.conn.WriteMessage(MessageBinary, encodedBuf)
}
// Actively initiate Heartbeat when platform in Web.
func (c *Client) activeHeartbeat(ctx context.Context) {
if c.PlatformID == constant.WebPlatformID {
go func() {
defer func() {
if r := recover(); r != nil {
log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r))
}
}()
log.ZDebug(ctx, "server initiative send heartbeat start.")
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.writePingMsg(); err != nil {
log.ZWarn(c.ctx, "send Ping Message error.", err)
return
}
case <-c.hbCtx.Done():
return
}
}
}()
}
}
func (c *Client) writePingMsg() error {
if c.closed.Load() {
return nil
}
c.w.Lock()
defer c.w.Unlock()
err := c.conn.SetWriteDeadline(writeWait)
if err != nil {
return err
}
return c.conn.WriteMessage(PingMessage, nil)
}
func (c *Client) writePongMsg(appData string) error {
log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData)
if c.closed.Load() {
log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr)
return nil
}
c.w.Lock()
defer c.w.Unlock()
err := c.conn.SetWriteDeadline(writeWait)
if err != nil {
log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData)
return errs.Wrap(err)
}
err = c.conn.WriteMessage(PongMessage, []byte(appData))
if err != nil {
log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage)
}
return errs.Wrap(err)
}
func (c *Client) handlerTextMessage(b []byte) error {
var msg TextMessage
if err := json.Unmarshal(b, &msg); err != nil {
return err
}
switch msg.Type {
case TextPong:
return nil
case TextPing:
msg.Type = TextPong
msgData, err := json.Marshal(msg)
if err != nil {
return err
}
c.w.Lock()
defer c.w.Unlock()
if err := c.conn.SetWriteDeadline(writeWait); err != nil {
return err
}
return c.conn.WriteMessage(MessageText, msgData)
default:
return fmt.Errorf("not support message type %s", msg.Type)
}
}

View File

@@ -0,0 +1,113 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"bytes"
"compress/gzip"
"io"
"sync"
"github.com/openimsdk/tools/errs"
)
var (
gzipWriterPool = sync.Pool{New: func() any { return gzip.NewWriter(nil) }}
gzipReaderPool = sync.Pool{New: func() any { return new(gzip.Reader) }}
)
type Compressor interface {
Compress(rawData []byte) ([]byte, error)
CompressWithPool(rawData []byte) ([]byte, error)
DeCompress(compressedData []byte) ([]byte, error)
DecompressWithPool(compressedData []byte) ([]byte, error)
}
type GzipCompressor struct {
compressProtocol string
}
func NewGzipCompressor() *GzipCompressor {
return &GzipCompressor{compressProtocol: "gzip"}
}
func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) {
gzipBuffer := bytes.Buffer{}
gz := gzip.NewWriter(&gzipBuffer)
if _, err := gz.Write(rawData); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.Compress: writing to gzip writer failed")
}
if err := gz.Close(); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.Compress: closing gzip writer failed")
}
return gzipBuffer.Bytes(), nil
}
func (g *GzipCompressor) CompressWithPool(rawData []byte) ([]byte, error) {
gz := gzipWriterPool.Get().(*gzip.Writer)
defer gzipWriterPool.Put(gz)
gzipBuffer := bytes.Buffer{}
gz.Reset(&gzipBuffer)
if _, err := gz.Write(rawData); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.CompressWithPool: error writing data")
}
if err := gz.Close(); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.CompressWithPool: error closing gzip writer")
}
return gzipBuffer.Bytes(), nil
}
func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) {
buff := bytes.NewBuffer(compressedData)
reader, err := gzip.NewReader(buff)
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DeCompress: NewReader creation failed")
}
decompressedData, err := io.ReadAll(reader)
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DeCompress: reading from gzip reader failed")
}
if err = reader.Close(); err != nil {
// Even if closing the reader fails, we've successfully read the data,
// so we return the decompressed data and an error indicating the close failure.
return decompressedData, errs.WrapMsg(err, "GzipCompressor.DeCompress: closing gzip reader failed")
}
return decompressedData, nil
}
func (g *GzipCompressor) DecompressWithPool(compressedData []byte) ([]byte, error) {
reader := gzipReaderPool.Get().(*gzip.Reader)
defer gzipReaderPool.Put(reader)
err := reader.Reset(bytes.NewReader(compressedData))
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DecompressWithPool: resetting gzip reader failed")
}
decompressedData, err := io.ReadAll(reader)
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DecompressWithPool: reading from pooled gzip reader failed")
}
if err = reader.Close(); err != nil {
// Similar to DeCompress, return the data and error for close failure.
return decompressedData, errs.WrapMsg(err, "GzipCompressor.DecompressWithPool: closing pooled gzip reader failed")
}
return decompressedData, nil
}

View File

@@ -0,0 +1,139 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"crypto/rand"
"github.com/stretchr/testify/assert"
"sync"
"testing"
"unsafe"
)
func mockRandom() []byte {
bs := make([]byte, 50)
rand.Read(bs)
return bs
}
func TestCompressDecompress(t *testing.T) {
compressor := NewGzipCompressor()
for i := 0; i < 2000; i++ {
src := mockRandom()
// compress
dest, err := compressor.CompressWithPool(src)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// decompress
res, err := compressor.DecompressWithPool(dest)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// check
assert.EqualValues(t, src, res)
}
}
func TestCompressDecompressWithConcurrency(t *testing.T) {
wg := sync.WaitGroup{}
compressor := NewGzipCompressor()
for i := 0; i < 200; i++ {
wg.Add(1)
go func() {
defer wg.Done()
src := mockRandom()
// compress
dest, err := compressor.CompressWithPool(src)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// decompress
res, err := compressor.DecompressWithPool(dest)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// check
assert.EqualValues(t, src, res)
}()
}
wg.Wait()
}
func BenchmarkCompress(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
for i := 0; i < b.N; i++ {
_, err := compressor.Compress(src)
assert.Equal(b, nil, err)
}
}
func BenchmarkCompressWithSyncPool(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
for i := 0; i < b.N; i++ {
_, err := compressor.CompressWithPool(src)
assert.Equal(b, nil, err)
}
}
func BenchmarkDecompress(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
comdata, err := compressor.Compress(src)
assert.Equal(b, nil, err)
for i := 0; i < b.N; i++ {
_, err := compressor.DeCompress(comdata)
assert.Equal(b, nil, err)
}
}
func BenchmarkDecompressWithSyncPool(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
comdata, err := compressor.Compress(src)
assert.Equal(b, nil, err)
for i := 0; i < b.N; i++ {
_, err := compressor.DecompressWithPool(comdata)
assert.Equal(b, nil, err)
}
}
func TestName(t *testing.T) {
t.Log(unsafe.Sizeof(Client{}))
}

View File

@@ -0,0 +1,72 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import "time"
const (
WsUserID = "sendID"
CommonUserID = "userID"
PlatformID = "platformID"
ConnID = "connID"
Token = "token"
OperationID = "operationID"
Compression = "compression"
GzipCompressionProtocol = "gzip"
BackgroundStatus = "isBackground"
SendResponse = "isMsgResp"
SDKType = "sdkType"
SDKVersion = "sdkVersion"
)
const (
GoSDK = "go"
JsSDK = "js"
)
const (
WebSocket = iota + 1
)
const (
// Websocket Protocol.
WSGetNewestSeq = 1001
WSPullMsgBySeqList = 1002
WSSendMsg = 1003
WSSendSignalMsg = 1004
WSPullMsg = 1005
WSGetConvMaxReadSeq = 1006
WsPullConvLastMessage = 1007
WSPushMsg = 2001
WSKickOnlineMsg = 2002
WsLogoutMsg = 2003
WsSetBackgroundStatus = 2004
WsSubUserOnlineStatus = 2005
WSDataError = 3001
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 30 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 51200
)

View File

@@ -0,0 +1,216 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"net/http"
"net/url"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/utils/encrypt"
"github.com/openimsdk/tools/utils/stringutil"
"github.com/openimsdk/tools/utils/timeutil"
)
type UserConnContext struct {
RespWriter http.ResponseWriter
Req *http.Request
Path string
Method string
RemoteAddr string
ConnID string
}
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
return
}
func (c *UserConnContext) Done() <-chan struct{} {
return nil
}
func (c *UserConnContext) Err() error {
return nil
}
func (c *UserConnContext) Value(key any) any {
switch key {
case constant.OpUserID:
return c.GetUserID()
case constant.OperationID:
return c.GetOperationID()
case constant.ConnID:
return c.GetConnID()
case constant.OpUserPlatform:
return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID()))
case constant.RemoteAddr:
return c.RemoteAddr
default:
return ""
}
}
func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnContext {
remoteAddr := req.RemoteAddr
if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" {
remoteAddr += "_" + forwarded
}
return &UserConnContext{
RespWriter: respWriter,
Req: req,
Path: req.URL.Path,
Method: req.Method,
RemoteAddr: remoteAddr,
ConnID: encrypt.Md5(req.RemoteAddr + "_" + strconv.Itoa(int(timeutil.GetCurrentTimestampByMill()))),
}
}
func newTempContext() *UserConnContext {
return &UserConnContext{
Req: &http.Request{URL: &url.URL{}},
}
}
func (c *UserConnContext) GetRemoteAddr() string {
return c.RemoteAddr
}
func (c *UserConnContext) Query(key string) (string, bool) {
var value string
if value = c.Req.URL.Query().Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) GetHeader(key string) (string, bool) {
var value string
if value = c.Req.Header.Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) SetHeader(key, value string) {
c.RespWriter.Header().Set(key, value)
}
func (c *UserConnContext) ErrReturn(error string, code int) {
http.Error(c.RespWriter, error, code)
}
func (c *UserConnContext) GetConnID() string {
return c.ConnID
}
func (c *UserConnContext) GetUserID() string {
return c.Req.URL.Query().Get(WsUserID)
}
func (c *UserConnContext) GetPlatformID() string {
return c.Req.URL.Query().Get(PlatformID)
}
func (c *UserConnContext) GetOperationID() string {
return c.Req.URL.Query().Get(OperationID)
}
func (c *UserConnContext) SetOperationID(operationID string) {
values := c.Req.URL.Query()
values.Set(OperationID, operationID)
c.Req.URL.RawQuery = values.Encode()
}
func (c *UserConnContext) GetToken() string {
return c.Req.URL.Query().Get(Token)
}
func (c *UserConnContext) GetSDKVersion() string {
return c.Req.URL.Query().Get(SDKVersion)
}
func (c *UserConnContext) GetCompression() bool {
compression, exists := c.Query(Compression)
if exists && compression == GzipCompressionProtocol {
return true
} else {
compression, exists := c.GetHeader(Compression)
if exists && compression == GzipCompressionProtocol {
return true
}
}
return false
}
func (c *UserConnContext) GetSDKType() string {
sdkType := c.Req.URL.Query().Get(SDKType)
if sdkType == "" {
sdkType = GoSDK
}
return sdkType
}
func (c *UserConnContext) ShouldSendResp() bool {
errResp, exists := c.Query(SendResponse)
if exists {
b, err := strconv.ParseBool(errResp)
if err != nil {
return false
} else {
return b
}
}
return false
}
func (c *UserConnContext) SetToken(token string) {
c.Req.URL.RawQuery = Token + "=" + token
}
func (c *UserConnContext) GetBackground() bool {
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
if err != nil {
return false
}
return b
}
func (c *UserConnContext) ParseEssentialArgs() error {
_, exists := c.Query(Token)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
}
_, exists = c.Query(WsUserID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
platformIDStr, exists := c.Query(PlatformID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
}
_, err := strconv.Atoi(platformIDStr)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
switch sdkType, _ := c.Query(SDKType); sdkType {
case "", GoSDK, JsSDK:
default:
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js")
}
return nil
}

View File

@@ -0,0 +1,74 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"bytes"
"encoding/gob"
"encoding/json"
"github.com/openimsdk/tools/errs"
)
type Encoder interface {
Encode(data any) ([]byte, error)
Decode(encodeData []byte, decodeData any) error
}
type GobEncoder struct{}
func NewGobEncoder() Encoder {
return GobEncoder{}
}
func (g GobEncoder) Encode(data any) ([]byte, error) {
var buff bytes.Buffer
enc := gob.NewEncoder(&buff)
if err := enc.Encode(data); err != nil {
return nil, errs.WrapMsg(err, "GobEncoder.Encode failed", "action", "encode")
}
return buff.Bytes(), nil
}
func (g GobEncoder) Decode(encodeData []byte, decodeData any) error {
buff := bytes.NewBuffer(encodeData)
dec := gob.NewDecoder(buff)
if err := dec.Decode(decodeData); err != nil {
return errs.WrapMsg(err, "GobEncoder.Decode failed", "action", "decode")
}
return nil
}
type JsonEncoder struct{}
func NewJsonEncoder() Encoder {
return JsonEncoder{}
}
func (g JsonEncoder) Encode(data any) ([]byte, error) {
b, err := json.Marshal(data)
if err != nil {
return nil, errs.New("JsonEncoder.Encode failed", "action", "encode")
}
return b, nil
}
func (g JsonEncoder) Decode(encodeData []byte, decodeData any) error {
err := json.Unmarshal(encodeData, decodeData)
if err != nil {
return errs.New("JsonEncoder.Decode failed", "action", "decode")
}
return nil
}

View File

@@ -0,0 +1,25 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/log"
)
func httpError(ctx *UserConnContext, err error) {
log.ZWarn(ctx, "ws connection error", err)
apiresp.HttpError(ctx.RespWriter, err)
}

View File

@@ -0,0 +1,264 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"sync/atomic"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/mq/memamq"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/grpc"
)
func (s *Server) InitServer(ctx context.Context, config *Config, disCov discovery.Conn, server grpc.ServiceRegistrar) error {
userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
s.userClient = rpcli.NewUserClient(userConn)
if err := s.LongConnServer.SetDiscoveryRegistry(ctx, disCov, config); err != nil {
return err
}
msggateway.RegisterMsgGatewayServer(server, s)
if s.ready != nil {
return s.ready(s)
}
return nil
}
//func (s *Server) Start(ctx context.Context, index int, conf *Config) error {
// return startrpc.Start(ctx, &conf.Discovery, &conf.MsgGateway.Prometheus, conf.MsgGateway.ListenIP,
// conf.MsgGateway.RPC.RegisterIP,
// conf.MsgGateway.RPC.AutoSetPorts, conf.MsgGateway.RPC.Ports, index,
// conf.Discovery.RpcService.MessageGateway,
// nil,
// conf,
// []string{
// conf.Share.GetConfigFileName(),
// conf.Discovery.GetConfigFileName(),
// conf.MsgGateway.GetConfigFileName(),
// conf.WebhooksConfig.GetConfigFileName(),
// conf.RedisConfig.GetConfigFileName(),
// },
// []string{
// conf.Discovery.RpcService.MessageGateway,
// },
// s.InitServer,
// )
//}
type Server struct {
msggateway.UnimplementedMsgGatewayServer
LongConnServer LongConnServer
config *Config
pushTerminal map[int]struct{}
ready func(srv *Server) error
queue *memamq.MemoryQueue
userClient *rpcli.UserClient
}
func (s *Server) SetLongConnServer(LongConnServer LongConnServer) {
s.LongConnServer = LongConnServer
}
func NewServer(longConnServer LongConnServer, conf *Config, ready func(srv *Server) error) *Server {
s := &Server{
LongConnServer: longConnServer,
pushTerminal: make(map[int]struct{}),
config: conf,
ready: ready,
queue: memamq.NewMemoryQueue(512, 1024*16),
}
s.pushTerminal[constant.IOSPlatformID] = struct{}{}
s.pushTerminal[constant.AndroidPlatformID] = struct{}{}
s.pushTerminal[constant.WebPlatformID] = struct{}{}
return s
}
func (s *Server) GetUsersOnlineStatus(ctx context.Context, req *msggateway.GetUsersOnlineStatusReq) (*msggateway.GetUsersOnlineStatusResp, error) {
if !authverify.IsAdmin(ctx) {
return nil, errs.ErrNoPermission.WrapMsg("only app manager")
}
var resp msggateway.GetUsersOnlineStatusResp
for _, userID := range req.UserIDs {
clients, ok := s.LongConnServer.GetUserAllCons(userID)
if !ok {
continue
}
uresp := new(msggateway.GetUsersOnlineStatusResp_SuccessResult)
uresp.UserID = userID
for _, client := range clients {
if client == nil {
continue
}
ps := new(msggateway.GetUsersOnlineStatusResp_SuccessDetail)
ps.PlatformID = int32(client.PlatformID)
ps.ConnID = client.ctx.GetConnID()
ps.Token = client.token
ps.IsBackground = client.IsBackground
uresp.Status = constant.Online
uresp.DetailPlatformStatus = append(uresp.DetailPlatformStatus, ps)
}
if uresp.Status == constant.Online {
resp.SuccessResult = append(resp.SuccessResult, uresp)
}
}
return &resp, nil
}
func (s *Server) pushToUser(ctx context.Context, userID string, msgData *sdkws.MsgData) *msggateway.SingleMsgToUserResults {
clients, ok := s.LongConnServer.GetUserAllCons(userID)
if !ok {
log.ZDebug(ctx, "push user not online", "userID", userID)
return &msggateway.SingleMsgToUserResults{
UserID: userID,
}
}
log.ZDebug(ctx, "push user online", "clients", clients, "userID", userID)
result := &msggateway.SingleMsgToUserResults{
UserID: userID,
Resp: make([]*msggateway.SingleMsgToUserPlatform, 0, len(clients)),
}
for _, client := range clients {
if client == nil {
continue
}
userPlatform := &msggateway.SingleMsgToUserPlatform{
RecvPlatFormID: int32(client.PlatformID),
}
if !client.IsBackground ||
(client.IsBackground && client.PlatformID != constant.IOSPlatformID) {
err := client.PushMessage(ctx, msgData)
if err != nil {
log.ZWarn(ctx, "online push msg failed", err, "userID", userID, "platformID", client.PlatformID)
userPlatform.ResultCode = int64(servererrs.ErrPushMsgErr.Code())
} else {
if _, ok := s.pushTerminal[client.PlatformID]; ok {
result.OnlinePush = true
}
}
} else {
userPlatform.ResultCode = int64(servererrs.ErrIOSBackgroundPushErr.Code())
}
result.Resp = append(result.Resp, userPlatform)
}
return result
}
func (s *Server) SuperGroupOnlineBatchPushOneMsg(ctx context.Context, req *msggateway.OnlineBatchPushOneMsgReq) (*msggateway.OnlineBatchPushOneMsgResp, error) {
if len(req.PushToUserIDs) == 0 {
return &msggateway.OnlineBatchPushOneMsgResp{}, nil
}
ch := make(chan *msggateway.SingleMsgToUserResults, len(req.PushToUserIDs))
var count atomic.Int64
count.Add(int64(len(req.PushToUserIDs)))
for i := range req.PushToUserIDs {
userID := req.PushToUserIDs[i]
err := s.queue.PushCtx(ctx, func() {
ch <- s.pushToUser(ctx, userID, req.MsgData)
if count.Add(-1) == 0 {
close(ch)
}
})
if err != nil {
if count.Add(-1) == 0 {
close(ch)
}
log.ZError(ctx, "pushToUser MemoryQueue failed", err, "userID", userID)
ch <- &msggateway.SingleMsgToUserResults{
UserID: userID,
}
}
}
resp := &msggateway.OnlineBatchPushOneMsgResp{
SinglePushResult: make([]*msggateway.SingleMsgToUserResults, 0, len(req.PushToUserIDs)),
}
for {
select {
case <-ctx.Done():
log.ZError(ctx, "SuperGroupOnlineBatchPushOneMsg ctx done", context.Cause(ctx))
userIDSet := datautil.SliceSet(req.PushToUserIDs)
for _, results := range resp.SinglePushResult {
delete(userIDSet, results.UserID)
}
for userID := range userIDSet {
resp.SinglePushResult = append(resp.SinglePushResult, &msggateway.SingleMsgToUserResults{
UserID: userID,
})
}
return resp, nil
case res, ok := <-ch:
if !ok {
return resp, nil
}
resp.SinglePushResult = append(resp.SinglePushResult, res)
}
}
}
func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOfflineReq) (*msggateway.KickUserOfflineResp, error) {
for _, v := range req.KickUserIDList {
clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID))
if !ok {
log.ZDebug(ctx, "conn not exist", "userID", v, "platformID", req.PlatformID)
continue
}
for _, client := range clients {
log.ZDebug(ctx, "kick user offline", "userID", v, "platformID", req.PlatformID, "client", client)
if err := client.longConnServer.KickUserConn(client); err != nil {
log.ZWarn(ctx, "kick user offline failed", err, "userID", v, "platformID", req.PlatformID)
}
}
continue
}
return &msggateway.KickUserOfflineResp{}, nil
}
func (s *Server) MultiTerminalLoginCheck(ctx context.Context, req *msggateway.MultiTerminalLoginCheckReq) (*msggateway.MultiTerminalLoginCheckResp, error) {
if oldClients, userOK, clientOK := s.LongConnServer.GetUserPlatformCons(req.UserID, int(req.PlatformID)); userOK {
tempUserCtx := newTempContext()
tempUserCtx.SetToken(req.Token)
tempUserCtx.SetOperationID(mcontext.GetOperationID(ctx))
client := &Client{}
client.ctx = tempUserCtx
client.token = req.Token
client.UserID = req.UserID
client.PlatformID = int(req.PlatformID)
i := &kickHandler{
clientOK: clientOK,
oldClients: oldClients,
newClient: client,
}
s.LongConnServer.SetKickHandlerInfo(i)
}
return &msggateway.MultiTerminalLoginCheckResp{}, nil
}

117
internal/msggateway/init.go Normal file
View File

@@ -0,0 +1,117 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpccache"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/runtimeenv"
"google.golang.org/grpc"
"github.com/openimsdk/tools/log"
)
type Config struct {
MsgGateway config.MsgGateway
Share config.Share
RedisConfig config.Redis
WebhooksConfig config.Webhooks
Discovery config.Discovery
Index config.Index
}
// Start run ws server.
func Start(ctx context.Context, conf *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
log.CInfo(ctx, "MSG-GATEWAY server is initializing", "runtimeEnv", runtimeenv.RuntimeEnvironment(),
"rpcPorts", conf.MsgGateway.RPC.Ports,
"wsPort", conf.MsgGateway.LongConnSvr.Ports, "prometheusPorts", conf.MsgGateway.Prometheus.Ports)
wsPort, err := datautil.GetElemByIndex(conf.MsgGateway.LongConnSvr.Ports, int(conf.Index))
if err != nil {
return err
}
dbb := dbbuild.NewBuilder(nil, &conf.RedisConfig)
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
longServer := NewWsServer(
conf,
WithPort(wsPort),
WithMaxConnNum(int64(conf.MsgGateway.LongConnSvr.WebsocketMaxConnNum)),
WithHandshakeTimeout(time.Duration(conf.MsgGateway.LongConnSvr.WebsocketTimeout)*time.Second),
WithMessageMaxMsgLength(conf.MsgGateway.LongConnSvr.WebsocketMaxMsgLen),
)
hubServer := NewServer(longServer, conf, func(srv *Server) error {
var err error
longServer.online, err = rpccache.NewOnlineCache(srv.userClient, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges)
return err
})
if err := hubServer.InitServer(ctx, conf, client, server); err != nil {
return err
}
go longServer.ChangeOnlineStatus(4)
return hubServer.LongConnServer.Run(ctx)
}
//
//// Start run ws server.
//func Start(ctx context.Context, index int, conf *Config) error {
// log.CInfo(ctx, "MSG-GATEWAY server is initializing", "runtimeEnv", runtimeenv.RuntimeEnvironment(),
// "rpcPorts", conf.MsgGateway.RPC.Ports,
// "wsPort", conf.MsgGateway.LongConnSvr.Ports, "prometheusPorts", conf.MsgGateway.Prometheus.Ports)
// wsPort, err := datautil.GetElemByIndex(conf.MsgGateway.LongConnSvr.Ports, index)
// if err != nil {
// return err
// }
//
// rdb, err := redisutil.NewRedisClient(ctx, conf.RedisConfig.Build())
// if err != nil {
// return err
// }
// longServer := NewWsServer(
// conf,
// WithPort(wsPort),
// WithMaxConnNum(int64(conf.MsgGateway.LongConnSvr.WebsocketMaxConnNum)),
// WithHandshakeTimeout(time.Duration(conf.MsgGateway.LongConnSvr.WebsocketTimeout)*time.Second),
// WithMessageMaxMsgLength(conf.MsgGateway.LongConnSvr.WebsocketMaxMsgLen),
// )
//
// hubServer := NewServer(longServer, conf, func(srv *Server) error {
// var err error
// longServer.online, err = rpccache.NewOnlineCache(srv.userClient, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges)
// return err
// })
//
// go longServer.ChangeOnlineStatus(4)
//
// netDone := make(chan error)
// go func() {
// err = hubServer.Start(ctx, index, conf)
// netDone <- err
// }()
// return hubServer.LongConnServer.Run(netDone)
//}

View File

@@ -0,0 +1,179 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"encoding/json"
"net/http"
"time"
"github.com/openimsdk/tools/apiresp"
"github.com/gorilla/websocket"
"github.com/openimsdk/tools/errs"
)
type LongConn interface {
// Close this connection
Close() error
// WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1).
WriteMessage(messageType int, message []byte) error
// ReadMessage Read message from connection.
ReadMessage() (int, []byte, error)
// SetReadDeadline sets the read deadline on the underlying network connection,
// after a read has timed out, will return an error.
SetReadDeadline(timeout time.Duration) error
// SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error.
SetWriteDeadline(timeout time.Duration) error
// Dial Try to dial a connection,url must set auth args,header can control compress data
Dial(urlStr string, requestHeader http.Header) (*http.Response, error)
// IsNil Whether the connection of the current long connection is nil
IsNil() bool
// SetConnNil Set the connection of the current long connection to nil
SetConnNil()
// SetReadLimit sets the maximum size for a message read from the peer.bytes
SetReadLimit(limit int64)
SetPongHandler(handler PingPongHandler)
SetPingHandler(handler PingPongHandler)
// GenerateLongConn Check the connection of the current and when it was sent are the same
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
}
type GWebSocket struct {
protocolType int
conn *websocket.Conn
handshakeTimeout time.Duration
writeBufferSize int
}
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket {
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs}
}
func (d *GWebSocket) Close() error {
return d.conn.Close()
}
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
upgrader := &websocket.Upgrader{
HandshakeTimeout: d.handshakeTimeout,
CheckOrigin: func(r *http.Request) bool { return true },
}
if d.writeBufferSize > 0 { // default is 4kb.
upgrader.WriteBufferSize = d.writeBufferSize
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
// The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade
return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed")
}
d.conn = conn
return nil
}
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
// d.setSendConn(d.conn)
return d.conn.WriteMessage(messageType, message)
}
// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) {
// d.sendConn = sendConn
//}
func (d *GWebSocket) ReadMessage() (int, []byte, error) {
return d.conn.ReadMessage()
}
func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
return d.conn.SetReadDeadline(time.Now().Add(timeout))
}
func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
if timeout <= 0 {
return errs.New("timeout must be greater than 0")
}
// TODO SetWriteDeadline Future add error handling
if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed")
}
return nil
}
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {
conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader)
if err != nil {
return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr)
}
d.conn = conn
return httpResp, nil
}
func (d *GWebSocket) IsNil() bool {
return d.conn == nil
//
// if d.conn != nil {
// return false
// }
// return true
}
func (d *GWebSocket) SetConnNil() {
d.conn = nil
}
func (d *GWebSocket) SetReadLimit(limit int64) {
d.conn.SetReadLimit(limit)
}
func (d *GWebSocket) SetPongHandler(handler PingPongHandler) {
d.conn.SetPongHandler(handler)
}
func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
d.conn.SetPingHandler(handler)
}
func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error {
if err := d.GenerateLongConn(w, r); err != nil {
return err
}
data, err := json.Marshal(apiresp.ParseError(err))
if err != nil {
_ = d.Close()
return errs.WrapMsg(err, "json marshal failed")
}
if err := d.WriteMessage(MessageText, data); err != nil {
_ = d.Close()
return errs.WrapMsg(err, "WriteMessage failed")
}
_ = d.Close()
return nil
}
func (d *GWebSocket) RespondWithSuccess() error {
data, err := json.Marshal(apiresp.ParseError(nil))
if err != nil {
_ = d.Close()
return errs.WrapMsg(err, "json marshal failed")
}
if err := d.WriteMessage(MessageText, data); err != nil {
_ = d.Close()
return errs.WrapMsg(err, "WriteMessage failed")
}
return nil
}

View File

@@ -0,0 +1,282 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"encoding/json"
"sync"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/go-playground/validator/v10"
"google.golang.org/protobuf/proto"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/push"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/jsonutil"
)
const (
TextPing = "ping"
TextPong = "pong"
)
type TextMessage struct {
Type string `json:"type"`
Body json.RawMessage `json:"body"`
}
type Req struct {
ReqIdentifier int32 `json:"reqIdentifier" validate:"required"`
Token string `json:"token"`
SendID string `json:"sendID" validate:"required"`
OperationID string `json:"operationID" validate:"required"`
MsgIncr string `json:"msgIncr" validate:"required"`
Data []byte `json:"data"`
}
func (r *Req) String() string {
var tReq Req
tReq.ReqIdentifier = r.ReqIdentifier
tReq.Token = r.Token
tReq.SendID = r.SendID
tReq.OperationID = r.OperationID
tReq.MsgIncr = r.MsgIncr
return jsonutil.StructToJsonString(tReq)
}
var reqPool = sync.Pool{
New: func() any {
return new(Req)
},
}
func getReq() *Req {
req := reqPool.Get().(*Req)
req.Data = nil
req.MsgIncr = ""
req.OperationID = ""
req.ReqIdentifier = 0
req.SendID = ""
req.Token = ""
return req
}
func freeReq(req *Req) {
reqPool.Put(req)
}
type Resp struct {
ReqIdentifier int32 `json:"reqIdentifier"`
MsgIncr string `json:"msgIncr"`
OperationID string `json:"operationID"`
ErrCode int `json:"errCode"`
ErrMsg string `json:"errMsg"`
Data []byte `json:"data"`
}
func (r *Resp) String() string {
var tResp Resp
tResp.ReqIdentifier = r.ReqIdentifier
tResp.MsgIncr = r.MsgIncr
tResp.OperationID = r.OperationID
tResp.ErrCode = r.ErrCode
tResp.ErrMsg = r.ErrMsg
return jsonutil.StructToJsonString(tResp)
}
type MessageHandler interface {
GetSeq(ctx context.Context, data *Req) ([]byte, error)
SendMessage(ctx context.Context, data *Req) ([]byte, error)
SendSignalMessage(ctx context.Context, data *Req) ([]byte, error)
PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error)
GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error)
GetSeqMessage(ctx context.Context, data *Req) ([]byte, error)
UserLogout(ctx context.Context, data *Req) ([]byte, error)
SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error)
GetLastMessage(ctx context.Context, data *Req) ([]byte, error)
}
var _ MessageHandler = (*GrpcHandler)(nil)
type GrpcHandler struct {
validate *validator.Validate
msgClient *rpcli.MsgClient
pushClient *rpcli.PushMsgServiceClient
}
func NewGrpcHandler(validate *validator.Validate, msgClient *rpcli.MsgClient, pushClient *rpcli.PushMsgServiceClient) *GrpcHandler {
return &GrpcHandler{
validate: validate,
msgClient: msgClient,
pushClient: pushClient,
}
}
func (g *GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) {
req := sdkws.GetMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "GetSeq: error unmarshaling request", "action", "unmarshal", "dataType", "GetMaxSeqReq")
}
if err := g.validate.Struct(&req); err != nil {
return nil, errs.WrapMsg(err, "GetSeq: validation failed", "action", "validate", "dataType", "GetMaxSeqReq")
}
resp, err := g.msgClient.MsgClient.GetMaxSeq(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "GetSeq: error marshaling response", "action", "marshal", "dataType", "GetMaxSeqResp")
}
return c, nil
}
// SendMessage handles the sending of messages through gRPC. It unmarshals the request data,
// validates the message, and then sends it using the message RPC client.
func (g *GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) {
var msgData sdkws.MsgData
if err := proto.Unmarshal(data.Data, &msgData); err != nil {
return nil, errs.WrapMsg(err, "SendMessage: error unmarshaling message data", "action", "unmarshal", "dataType", "MsgData")
}
if err := g.validate.Struct(&msgData); err != nil {
return nil, errs.WrapMsg(err, "SendMessage: message data validation failed", "action", "validate", "dataType", "MsgData")
}
req := msg.SendMsgReq{MsgData: &msgData}
resp, err := g.msgClient.MsgClient.SendMsg(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "SendMessage: error marshaling response", "action", "marshal", "dataType", "SendMsgResp")
}
return c, nil
}
func (g *GrpcHandler) SendSignalMessage(ctx context.Context, data *Req) ([]byte, error) {
resp, err := g.msgClient.MsgClient.SendMsg(ctx, nil)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "SendMsgResp")
}
return c, nil
}
func (g *GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error) {
req := sdkws.PullMessageBySeqsReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "PullMessageBySeqsReq")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "PullMessageBySeqsReq")
}
resp, err := g.msgClient.MsgClient.PullMessageBySeqs(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "PullMessageBySeqsResp")
}
return c, nil
}
func (g *GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error) {
req := msg.GetConversationsHasReadAndMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "GetConversationsHasReadAndMaxSeq")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetConversationsHasReadAndMaxSeq")
}
resp, err := g.msgClient.MsgClient.GetConversationsHasReadAndMaxSeq(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "GetConversationsHasReadAndMaxSeq")
}
return c, nil
}
func (g *GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, error) {
req := msg.GetSeqMessageReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "GetSeqMessage")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetSeqMessage")
}
resp, err := g.msgClient.MsgClient.GetSeqMessage(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "GetSeqMessage")
}
return c, nil
}
func (g *GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error) {
req := push.DelUserPushTokenReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "DelUserPushTokenReq")
}
resp, err := g.pushClient.PushMsgServiceClient.DelUserPushToken(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "DelUserPushTokenResp")
}
return c, nil
}
func (g *GrpcHandler) SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error) {
req := sdkws.SetAppBackgroundStatusReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, false, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "SetAppBackgroundStatusReq")
}
if err := g.validate.Struct(data); err != nil {
return nil, false, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "SetAppBackgroundStatusReq")
}
return nil, req.IsBackground, nil
}
func (g *GrpcHandler) GetLastMessage(ctx context.Context, data *Req) ([]byte, error) {
var req msg.GetLastMessageReq
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, err
}
resp, err := g.msgClient.GetLastMessage(ctx, &req)
if err != nil {
return nil, err
}
return proto.Marshal(resp)
}

View File

@@ -0,0 +1,131 @@
package msggateway
import (
"context"
"crypto/md5"
"encoding/binary"
"fmt"
"math/rand"
"os"
"strconv"
"sync/atomic"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
pbuser "git.imall.cloud/openim/protocol/user"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
func (ws *WsServer) ChangeOnlineStatus(concurrent int) {
if concurrent < 1 {
concurrent = 1
}
const renewalTime = cachekey.OnlineExpire / 3
//const renewalTime = time.Second * 10
renewalTicker := time.NewTicker(renewalTime)
requestChs := make([]chan *pbuser.SetUserOnlineStatusReq, concurrent)
changeStatus := make([][]UserState, concurrent)
for i := 0; i < concurrent; i++ {
requestChs[i] = make(chan *pbuser.SetUserOnlineStatusReq, 64)
changeStatus[i] = make([]UserState, 0, 100)
}
mergeTicker := time.NewTicker(time.Second)
local2pb := func(u UserState) *pbuser.UserOnlineStatus {
return &pbuser.UserOnlineStatus{
UserID: u.UserID,
Online: u.Online,
Offline: u.Offline,
}
}
rNum := rand.Uint64()
pushUserState := func(us ...UserState) {
for _, u := range us {
sum := md5.Sum([]byte(u.UserID))
i := (binary.BigEndian.Uint64(sum[:]) + rNum) % uint64(concurrent)
changeStatus[i] = append(changeStatus[i], u)
status := changeStatus[i]
if len(status) == cap(status) {
req := &pbuser.SetUserOnlineStatusReq{
Status: datautil.Slice(status, local2pb),
}
changeStatus[i] = status[:0]
select {
case requestChs[i] <- req:
default:
log.ZError(context.Background(), "user online processing is too slow", nil)
}
}
}
}
pushAllUserState := func() {
for i, status := range changeStatus {
if len(status) == 0 {
continue
}
req := &pbuser.SetUserOnlineStatusReq{
Status: datautil.Slice(status, local2pb),
}
changeStatus[i] = status[:0]
select {
case requestChs[i] <- req:
default:
log.ZError(context.Background(), "user online processing is too slow", nil)
}
}
}
var count atomic.Int64
operationIDPrefix := fmt.Sprintf("p_%d_", os.Getpid())
doRequest := func(req *pbuser.SetUserOnlineStatusReq) {
opIdCtx := mcontext.SetOperationID(context.Background(), operationIDPrefix+strconv.FormatInt(count.Add(1), 10))
ctx, cancel := context.WithTimeout(opIdCtx, time.Second*5)
defer cancel()
if err := ws.userClient.SetUserOnlineStatus(ctx, req); err != nil {
log.ZError(ctx, "update user online status", err)
}
for _, ss := range req.Status {
for _, online := range ss.Online {
client, _, _ := ws.clients.Get(ss.UserID, int(online))
back := false
if len(client) > 0 {
back = client[0].IsBackground
}
ws.webhookAfterUserOnline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOnline, ss.UserID, int(online), back, ss.ConnID)
}
for _, offline := range ss.Offline {
ws.webhookAfterUserOffline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOffline, ss.UserID, int(offline), ss.ConnID)
}
}
}
for i := 0; i < concurrent; i++ {
go func(ch <-chan *pbuser.SetUserOnlineStatusReq) {
for req := range ch {
doRequest(req)
}
}(requestChs[i])
}
for {
select {
case <-mergeTicker.C:
pushAllUserState()
case now := <-renewalTicker.C:
deadline := now.Add(-cachekey.OnlineExpire / 3)
users := ws.clients.GetAllUserStatus(deadline, now)
log.ZDebug(context.Background(), "renewal ticker", "deadline", deadline, "nowtime", now, "num", len(users), "users", users)
pushUserState(users...)
case state := <-ws.clients.UserState():
log.ZDebug(context.Background(), "OnlineCache user online change", "userID", state.UserID, "online", state.Online, "offline", state.Offline)
pushUserState(state)
}
}
}

View File

@@ -0,0 +1,63 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import "time"
type (
Option func(opt *configs)
configs struct {
// Long connection listening port
port int
// Maximum number of connections allowed for long connection
maxConnNum int64
// Connection handshake timeout
handshakeTimeout time.Duration
// Maximum length allowed for messages
messageMaxMsgLength int
// Websocket write buffer, default: 4096, 4kb.
writeBufferSize int
}
)
func WithPort(port int) Option {
return func(opt *configs) {
opt.port = port
}
}
func WithMaxConnNum(num int64) Option {
return func(opt *configs) {
opt.maxConnNum = num
}
}
func WithHandshakeTimeout(t time.Duration) Option {
return func(opt *configs) {
opt.handshakeTimeout = t
}
}
func WithMessageMaxMsgLength(length int) Option {
return func(opt *configs) {
opt.messageMaxMsgLength = length
}
}
func WithWriteBufferSize(size int) Option {
return func(opt *configs) {
opt.writeBufferSize = size
}
}

View File

@@ -0,0 +1,167 @@
package msggateway
import (
"context"
"sync"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/protobuf/proto"
)
func (ws *WsServer) subscriberUserOnlineStatusChanges(ctx context.Context, userID string, platformIDs []int32) {
if ws.clients.RecvSubChange(userID, platformIDs) {
log.ZDebug(ctx, "gateway receive subscription message and go back online", "userID", userID, "platformIDs", platformIDs)
} else {
log.ZDebug(ctx, "gateway ignore user online status changes", "userID", userID, "platformIDs", platformIDs)
}
ws.pushUserIDOnlineStatus(ctx, userID, platformIDs)
}
func (ws *WsServer) SubUserOnlineStatus(ctx context.Context, client *Client, data *Req) ([]byte, error) {
var sub sdkws.SubUserOnlineStatus
if err := proto.Unmarshal(data.Data, &sub); err != nil {
return nil, err
}
ws.subscription.Sub(client, sub.SubscribeUserID, sub.UnsubscribeUserID)
var resp sdkws.SubUserOnlineStatusTips
if len(sub.SubscribeUserID) > 0 {
resp.Subscribers = make([]*sdkws.SubUserOnlineStatusElem, 0, len(sub.SubscribeUserID))
for _, userID := range sub.SubscribeUserID {
platformIDs, err := ws.online.GetUserOnlinePlatform(ctx, userID)
if err != nil {
return nil, err
}
resp.Subscribers = append(resp.Subscribers, &sdkws.SubUserOnlineStatusElem{
UserID: userID,
OnlinePlatformIDs: platformIDs,
})
}
}
return proto.Marshal(&resp)
}
func newSubscription() *Subscription {
return &Subscription{
userIDs: make(map[string]*subClient),
}
}
type subClient struct {
clients map[string]*Client
}
type Subscription struct {
lock sync.RWMutex
userIDs map[string]*subClient // subscribe to the user's client connection
}
func (s *Subscription) DelClient(client *Client) {
client.subLock.Lock()
userIDs := datautil.Keys(client.subUserIDs)
for _, userID := range userIDs {
delete(client.subUserIDs, userID)
}
client.subLock.Unlock()
if len(userIDs) == 0 {
return
}
addr := client.ctx.GetRemoteAddr()
s.lock.Lock()
defer s.lock.Unlock()
for _, userID := range userIDs {
sub, ok := s.userIDs[userID]
if !ok {
continue
}
delete(sub.clients, addr)
if len(sub.clients) == 0 {
delete(s.userIDs, userID)
}
}
}
func (s *Subscription) GetClient(userID string) []*Client {
s.lock.RLock()
defer s.lock.RUnlock()
cs, ok := s.userIDs[userID]
if !ok {
return nil
}
clients := make([]*Client, 0, len(cs.clients))
for _, client := range cs.clients {
clients = append(clients, client)
}
return clients
}
func (s *Subscription) Sub(client *Client, addUserIDs, delUserIDs []string) {
if len(addUserIDs)+len(delUserIDs) == 0 {
return
}
var (
del = make(map[string]struct{})
add = make(map[string]struct{})
)
client.subLock.Lock()
for _, userID := range delUserIDs {
if _, ok := client.subUserIDs[userID]; !ok {
continue
}
del[userID] = struct{}{}
delete(client.subUserIDs, userID)
}
for _, userID := range addUserIDs {
delete(del, userID)
if _, ok := client.subUserIDs[userID]; ok {
continue
}
client.subUserIDs[userID] = struct{}{}
add[userID] = struct{}{}
}
client.subLock.Unlock()
if len(del)+len(add) == 0 {
return
}
addr := client.ctx.GetRemoteAddr()
s.lock.Lock()
defer s.lock.Unlock()
for userID := range del {
sub, ok := s.userIDs[userID]
if !ok {
continue
}
delete(sub.clients, addr)
if len(sub.clients) == 0 {
delete(s.userIDs, userID)
}
}
for userID := range add {
sub, ok := s.userIDs[userID]
if !ok {
sub = &subClient{clients: make(map[string]*Client)}
s.userIDs[userID] = sub
}
sub.clients[addr] = client
}
}
func (ws *WsServer) pushUserIDOnlineStatus(ctx context.Context, userID string, platformIDs []int32) {
clients := ws.subscription.GetClient(userID)
if len(clients) == 0 {
return
}
onlineStatus, err := proto.Marshal(&sdkws.SubUserOnlineStatusTips{
Subscribers: []*sdkws.SubUserOnlineStatusElem{{UserID: userID, OnlinePlatformIDs: platformIDs}},
})
if err != nil {
log.ZError(ctx, "pushUserIDOnlineStatus json.Marshal", err)
return
}
for _, client := range clients {
if err := client.PushUserOnlineStatus(onlineStatus); err != nil {
log.ZError(ctx, "UserSubscribeOnlineStatusNotification push failed", err, "userID", client.UserID, "platformID", client.PlatformID, "changeUserID", userID, "changePlatformID", platformIDs)
}
}
}

View File

@@ -0,0 +1,185 @@
package msggateway
import (
"github.com/openimsdk/tools/utils/datautil"
"sync"
"time"
)
type UserMap interface {
GetAll(userID string) ([]*Client, bool)
Get(userID string, platformID int) ([]*Client, bool, bool)
Set(userID string, v *Client)
DeleteClients(userID string, clients []*Client) (isDeleteUser bool)
UserState() <-chan UserState
GetAllUserStatus(deadline time.Time, nowtime time.Time) []UserState
RecvSubChange(userID string, platformIDs []int32) bool
}
type UserState struct {
UserID string
Online []int32
Offline []int32
}
type UserPlatform struct {
Time time.Time
Clients []*Client
}
func (u *UserPlatform) PlatformIDs() []int32 {
if len(u.Clients) == 0 {
return nil
}
platformIDs := make([]int32, 0, len(u.Clients))
for _, client := range u.Clients {
platformIDs = append(platformIDs, int32(client.PlatformID))
}
return platformIDs
}
func (u *UserPlatform) PlatformIDSet() map[int32]struct{} {
if len(u.Clients) == 0 {
return nil
}
platformIDs := make(map[int32]struct{})
for _, client := range u.Clients {
platformIDs[int32(client.PlatformID)] = struct{}{}
}
return platformIDs
}
func newUserMap() UserMap {
return &userMap{
data: make(map[string]*UserPlatform),
ch: make(chan UserState, 10000),
}
}
type userMap struct {
lock sync.RWMutex
data map[string]*UserPlatform
ch chan UserState
}
func (u *userMap) RecvSubChange(userID string, platformIDs []int32) bool {
u.lock.RLock()
defer u.lock.RUnlock()
result, ok := u.data[userID]
if !ok {
return false
}
localPlatformIDs := result.PlatformIDSet()
for _, platformID := range platformIDs {
delete(localPlatformIDs, platformID)
}
if len(localPlatformIDs) == 0 {
return false
}
u.push(userID, result, nil)
return true
}
func (u *userMap) push(userID string, userPlatform *UserPlatform, offline []int32) bool {
select {
case u.ch <- UserState{UserID: userID, Online: userPlatform.PlatformIDs(), Offline: offline}:
userPlatform.Time = time.Now()
return true
default:
return false
}
}
func (u *userMap) GetAll(userID string) ([]*Client, bool) {
u.lock.RLock()
defer u.lock.RUnlock()
result, ok := u.data[userID]
if !ok {
return nil, false
}
return result.Clients, true
}
func (u *userMap) Get(userID string, platformID int) ([]*Client, bool, bool) {
u.lock.RLock()
defer u.lock.RUnlock()
result, ok := u.data[userID]
if !ok {
return nil, false, false
}
var clients []*Client
for _, client := range result.Clients {
if client.PlatformID == platformID {
clients = append(clients, client)
}
}
return clients, true, len(clients) > 0
}
func (u *userMap) Set(userID string, client *Client) {
u.lock.Lock()
defer u.lock.Unlock()
result, ok := u.data[userID]
if ok {
result.Clients = append(result.Clients, client)
} else {
result = &UserPlatform{
Clients: []*Client{client},
}
u.data[userID] = result
}
u.push(client.UserID, result, nil)
}
func (u *userMap) DeleteClients(userID string, clients []*Client) (isDeleteUser bool) {
if len(clients) == 0 {
return false
}
u.lock.Lock()
defer u.lock.Unlock()
result, ok := u.data[userID]
if !ok {
return false
}
offline := make([]int32, 0, len(clients))
deleteAddr := datautil.SliceSetAny(clients, func(client *Client) string {
return client.ctx.GetRemoteAddr()
})
tmp := result.Clients
result.Clients = result.Clients[:0]
for _, client := range tmp {
if _, delCli := deleteAddr[client.ctx.GetRemoteAddr()]; delCli {
offline = append(offline, int32(client.PlatformID))
} else {
result.Clients = append(result.Clients, client)
}
}
defer u.push(userID, result, offline)
if len(result.Clients) > 0 {
return false
}
delete(u.data, userID)
return true
}
func (u *userMap) GetAllUserStatus(deadline time.Time, nowtime time.Time) (result []UserState) {
u.lock.RLock()
defer u.lock.RUnlock()
result = make([]UserState, 0, len(u.data))
for userID, userPlatform := range u.data {
if deadline.Before(userPlatform.Time) {
continue
}
userPlatform.Time = nowtime
online := make([]int32, 0, len(userPlatform.Clients))
for _, client := range userPlatform.Clients {
online = append(online, int32(client.PlatformID))
}
result = append(result, UserState{UserID: userID, Online: online})
}
return result
}
func (u *userMap) UserState() <-chan UserState {
return u.ch
}

View File

@@ -0,0 +1,568 @@
package msggateway
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpccache"
pbAuth "git.imall.cloud/openim/protocol/auth"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/mcontext"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
"github.com/go-playground/validator/v10"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/stringutil"
"golang.org/x/sync/errgroup"
)
type LongConnServer interface {
Run(ctx context.Context) error
wsHandler(w http.ResponseWriter, r *http.Request)
GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s any) error
SetDiscoveryRegistry(ctx context.Context, client discovery.Conn, config *Config) error
KickUserConn(client *Client) error
UnRegister(c *Client)
SetKickHandlerInfo(i *kickHandler)
SubUserOnlineStatus(ctx context.Context, client *Client, data *Req) ([]byte, error)
Compressor
MessageHandler
}
type WsServer struct {
msgGatewayConfig *Config
port int
wsMaxConnNum int64
registerChan chan *Client
unregisterChan chan *Client
kickHandlerChan chan *kickHandler
clients UserMap
online rpccache.OnlineCache
subscription *Subscription
clientPool sync.Pool
onlineUserNum atomic.Int64
onlineUserConnNum atomic.Int64
handshakeTimeout time.Duration
writeBufferSize int
validate *validator.Validate
disCov discovery.Conn
Compressor
//Encoder
MessageHandler
webhookClient *webhook.Client
userClient *rpcli.UserClient
authClient *rpcli.AuthClient
ready atomic.Bool
}
type kickHandler struct {
clientOK bool
oldClients []*Client
newClient *Client
}
func (ws *WsServer) SetDiscoveryRegistry(ctx context.Context, disCov discovery.Conn, config *Config) error {
userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
pushConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Push)
if err != nil {
return err
}
authConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Auth)
if err != nil {
return err
}
msgConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Msg)
if err != nil {
return err
}
ws.userClient = rpcli.NewUserClient(userConn)
ws.authClient = rpcli.NewAuthClient(authConn)
ws.MessageHandler = NewGrpcHandler(ws.validate, rpcli.NewMsgClient(msgConn), rpcli.NewPushMsgServiceClient(pushConn))
ws.disCov = disCov
ws.ready.Store(true)
return nil
}
//func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, status int32) {
// err := ws.userClient.SetUserStatus(ctx, client.UserID, status, client.PlatformID)
// if err != nil {
// log.ZWarn(ctx, "SetUserStatus err", err)
// }
// switch status {
// case constant.Online:
// ws.webhookAfterUserOnline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOnline, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID())
// case constant.Offline:
// ws.webhookAfterUserOffline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOffline, client.UserID, client.PlatformID, client.ctx.GetConnID())
// }
//}
func (ws *WsServer) UnRegister(c *Client) {
ws.unregisterChan <- c
}
func (ws *WsServer) Validate(_ any) error {
return nil
}
func (ws *WsServer) GetUserAllCons(userID string) ([]*Client, bool) {
return ws.clients.GetAll(userID)
}
func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) {
return ws.clients.Get(userID, platform)
}
func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer {
var config configs
for _, o := range opts {
o(&config)
}
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Discovery.RpcService.User, config.Share.IMAdminUser)
v := validator.New()
return &WsServer{
msgGatewayConfig: msgGatewayConfig,
port: config.port,
wsMaxConnNum: config.maxConnNum,
writeBufferSize: config.writeBufferSize,
handshakeTimeout: config.handshakeTimeout,
clientPool: sync.Pool{
New: func() any {
return new(Client)
},
},
registerChan: make(chan *Client, 1000),
unregisterChan: make(chan *Client, 1000),
kickHandlerChan: make(chan *kickHandler, 1000),
validate: v,
clients: newUserMap(),
subscription: newSubscription(),
Compressor: NewGzipCompressor(),
webhookClient: webhook.NewWebhookClient(msgGatewayConfig.WebhooksConfig.URL),
}
}
func (ws *WsServer) Run(ctx context.Context) error {
var client *Client
ctx, cancel := context.WithCancelCause(ctx)
go func() {
for {
select {
case <-ctx.Done():
return
case client = <-ws.registerChan:
ws.registerClient(client)
case client = <-ws.unregisterChan:
ws.unregisterClient(client)
case onlineInfo := <-ws.kickHandlerChan:
ws.multiTerminalLoginChecker(onlineInfo.clientOK, onlineInfo.oldClients, onlineInfo.newClient)
}
}
}()
done := make(chan struct{})
go func() {
wsServer := http.Server{Addr: fmt.Sprintf(":%d", ws.port), Handler: nil}
http.HandleFunc("/", ws.wsHandler)
go func() {
defer close(done)
<-ctx.Done()
_ = wsServer.Shutdown(context.Background())
}()
err := wsServer.ListenAndServe()
if err == nil {
err = fmt.Errorf("http server closed")
}
cancel(fmt.Errorf("msg gateway %w", err))
}()
<-ctx.Done()
timeout := time.NewTimer(time.Second * 15)
defer timeout.Stop()
select {
case <-timeout.C:
log.ZWarn(ctx, "msg gateway graceful stop timeout", nil)
case <-done:
log.ZDebug(ctx, "msg gateway graceful stop done")
}
return context.Cause(ctx)
}
const concurrentRequest = 3
func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error {
conns, err := ws.disCov.GetConns(ctx, ws.msgGatewayConfig.Discovery.RpcService.MessageGateway)
if err != nil {
return err
}
if len(conns) == 0 || (len(conns) == 1 && ws.disCov.IsSelfNode(conns[0])) {
return nil
}
wg := errgroup.Group{}
wg.SetLimit(concurrentRequest)
// Online push user online message to other node
for _, v := range conns {
v := v
log.ZDebug(ctx, "sendUserOnlineInfoToOtherNode conn")
if ws.disCov.IsSelfNode(v) {
log.ZDebug(ctx, "Filter out this node")
continue
}
wg.Go(func() error {
msgClient := msggateway.NewMsgGatewayClient(v)
_, err := msgClient.MultiTerminalLoginCheck(ctx, &msggateway.MultiTerminalLoginCheckReq{
UserID: client.UserID,
PlatformID: int32(client.PlatformID), Token: client.token,
})
if err != nil {
log.ZWarn(ctx, "MultiTerminalLoginCheck err", err)
}
return nil
})
}
_ = wg.Wait()
return nil
}
func (ws *WsServer) SetKickHandlerInfo(i *kickHandler) {
ws.kickHandlerChan <- i
}
func (ws *WsServer) registerClient(client *Client) {
var (
userOK bool
clientOK bool
oldClients []*Client
)
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID,
"sdkVersion", client.SDKVersion)
if !userOK {
ws.clients.Set(client.UserID, client)
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
prommetrics.OnlineUserGauge.Add(1)
ws.onlineUserNum.Add(1)
ws.onlineUserConnNum.Add(1)
} else {
ws.multiTerminalLoginChecker(clientOK, oldClients, client)
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
if clientOK {
ws.clients.Set(client.UserID, client)
// There is already a connection to the platform
log.ZDebug(client.ctx, "repeat login", "userID", client.UserID, "platformID",
client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
ws.onlineUserConnNum.Add(1)
} else {
ws.clients.Set(client.UserID, client)
ws.onlineUserConnNum.Add(1)
}
}
wg := sync.WaitGroup{}
log.ZDebug(client.ctx, "ws.msgGatewayConfig.Discovery.Enable", "discoveryEnable", ws.msgGatewayConfig.Discovery.Enable)
if ws.msgGatewayConfig.Discovery.Enable != "k8s" {
wg.Add(1)
go func() {
defer wg.Done()
_ = ws.sendUserOnlineInfoToOtherNode(client.ctx, client)
}()
}
//wg.Add(1)
//go func() {
// defer wg.Done()
// ws.SetUserOnlineStatus(client.ctx, client, constant.Online)
//}()
wg.Wait()
log.ZDebug(client.ctx, "user online", "online user Num", ws.onlineUserNum.Load(), "online user conn Num", ws.onlineUserConnNum.Load())
}
func getRemoteAdders(client []*Client) string {
var ret string
for i, c := range client {
if i == 0 {
ret = c.ctx.GetRemoteAddr()
} else {
ret += "@" + c.ctx.GetRemoteAddr()
}
}
return ret
}
func (ws *WsServer) KickUserConn(client *Client) error {
// 先发送踢掉通知,再删除客户端,确保通知能正确发送
err := client.KickOnlineMessage()
// KickOnlineMessage 内部会调用 close()close() 会调用 UnRegisterUnRegister 会删除客户端
// 所以这里不需要再调用 DeleteClients
return err
}
func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) {
kickTokenFunc := func(kickClients []*Client) {
if len(kickClients) == 0 {
return
}
var kickTokens []string
// 先发送踢掉通知KickOnlineMessage 内部会关闭连接并删除客户端
// 使用 goroutine 并发发送通知,提高效率
var wg sync.WaitGroup
for _, c := range kickClients {
kickTokens = append(kickTokens, c.token)
wg.Add(1)
go func(client *Client) {
defer wg.Done()
err := client.KickOnlineMessage()
if err != nil {
log.ZWarn(client.ctx, "KickOnlineMessage", err)
}
}(c)
}
// 等待所有通知发送完成
wg.Wait()
ctx := mcontext.WithMustInfoCtx(
[]string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(),
constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()},
)
if err := ws.authClient.KickTokens(ctx, kickTokens); err != nil {
log.ZWarn(newClient.ctx, "kickTokens err", err)
}
}
// If reconnect: When multiple msgGateway instances are deployed, a client may disconnect from instance A and reconnect to instance B.
// During this process, instance A might still be executing, resulting in two clients with the same token existing simultaneously.
// This situation needs to be filtered to prevent duplicate clients.
checkSameTokenFunc := func(oldClients []*Client) []*Client {
var clientsNeedToKick []*Client
for _, c := range oldClients {
if c.token == newClient.token {
log.ZDebug(newClient.ctx, "token is same, not kick",
"userID", newClient.UserID,
"platformID", newClient.PlatformID,
"token", newClient.token)
continue
}
clientsNeedToKick = append(clientsNeedToKick, c)
}
return clientsNeedToKick
}
switch ws.msgGatewayConfig.Share.MultiLogin.Policy {
case constant.DefalutNotKick:
case constant.PCAndOther:
if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC {
return
}
clients, ok := ws.clients.GetAll(newClient.UserID)
clientOK = ok
oldClients = make([]*Client, 0, len(clients))
for _, c := range clients {
if constant.PlatformIDToClass(c.PlatformID) == constant.TerminalPC {
continue
}
oldClients = append(oldClients, c)
}
fallthrough
case constant.AllLoginButSameTermKick:
if !clientOK {
return
}
oldClients = checkSameTokenFunc(oldClients)
if len(oldClients) == 0 {
return
}
// 先发送踢掉通知KickOnlineMessage 内部会关闭连接并删除客户端
// 使用 goroutine 并发发送通知,提高效率
var wg sync.WaitGroup
for _, c := range oldClients {
wg.Add(1)
go func(client *Client) {
defer wg.Done()
err := client.KickOnlineMessage()
if err != nil {
log.ZWarn(client.ctx, "KickOnlineMessage", err)
}
}(c)
}
// 等待所有通知发送完成
wg.Wait()
ctx := mcontext.WithMustInfoCtx(
[]string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(),
constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()},
)
req := &pbAuth.InvalidateTokenReq{
PreservedToken: newClient.token,
UserID: newClient.UserID,
PlatformID: int32(newClient.PlatformID),
}
if err := ws.authClient.InvalidateToken(ctx, req); err != nil {
log.ZWarn(newClient.ctx, "InvalidateToken err", err, "userID", newClient.UserID,
"platformID", newClient.PlatformID)
}
case constant.AllLoginButSameClassKick:
clients, ok := ws.clients.GetAll(newClient.UserID)
if !ok {
return
}
var kickClients []*Client
for _, client := range clients {
if constant.PlatformIDToClass(client.PlatformID) == constant.PlatformIDToClass(newClient.PlatformID) {
{
kickClients = append(kickClients, client)
}
}
}
kickClients = checkSameTokenFunc(kickClients)
kickTokenFunc(kickClients)
}
}
func (ws *WsServer) unregisterClient(client *Client) {
defer ws.clientPool.Put(client)
isDeleteUser := ws.clients.DeleteClients(client.UserID, []*Client{client})
if isDeleteUser {
ws.onlineUserNum.Add(-1)
prommetrics.OnlineUserGauge.Dec()
}
ws.onlineUserConnNum.Add(-1)
ws.subscription.DelClient(client)
//ws.SetUserOnlineStatus(client.ctx, client, constant.Offline)
log.ZDebug(client.ctx, "user offline", "close reason", client.closedErr, "online user Num",
ws.onlineUserNum.Load(), "online user conn Num",
ws.onlineUserConnNum.Load(),
)
}
// validateRespWithRequest checks if the response matches the expected userID and platformID.
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
userID := ctx.GetUserID()
platformID := stringutil.StringToInt32(ctx.GetPlatformID())
if resp.UserID != userID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
}
if resp.PlatformID != platformID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
}
return nil
}
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
// Create a new connection context
connContext := newContext(w, r)
if !ws.ready.Load() {
httpError(connContext, errs.New("ws server not ready"))
return
}
// Check if the current number of online user connections exceeds the maximum limit
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
return
}
// Parse essential arguments (e.g., user ID, Token)
err := connContext.ParseEssentialArgs()
if err != nil {
// If there's an error during parsing, return an error via HTTP and stop processing
httpError(connContext, err)
return
}
if ws.authClient == nil {
httpError(connContext, errs.New("auth client is not initialized"))
return
}
// Call the authentication client to parse the Token obtained from the context
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
if err != nil {
// If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
shouldSendError := connContext.ShouldSendResp()
if shouldSendError {
// Create a WebSocket connection object and attempt to send the error message via WebSocket
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.RespondWithError(err, w, r); err == nil {
// If the error message is successfully sent via WebSocket, stop processing
return
}
}
// If sending via WebSocket is not required or fails, return the error via HTTP and stop processing
httpError(connContext, err)
return
}
// Validate the authentication response matches the request (e.g., user ID and platform ID)
err = ws.validateRespWithRequest(connContext, resp)
if err != nil {
// If validation fails, return an error via HTTP and stop processing
httpError(connContext, err)
return
}
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
// Create a WebSocket long connection object
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
//If the creation of the long connection fails, the error is handled internally during the handshake process.
log.ZWarn(connContext, "long connection fails", err)
return
} else {
// Check if a normal response should be sent via WebSocket
shouldSendSuccessResp := connContext.ShouldSendResp()
if shouldSendSuccessResp {
// Attempt to send a success message through WebSocket
if err := wsLongConn.RespondWithSuccess(); err != nil {
// If the success message is successfully sent, end further processing
return
}
}
}
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
client := ws.clientPool.Get().(*Client)
client.ResetClient(connContext, wsLongConn, ws)
// Register the client with the server and start message processing
ws.registerChan <- client
go client.readMessage()
}