Browse Source

Binary packet parsing

master
Nikita 13 years ago
parent
commit
7ccd152da8
  1. 6
      pom.xml
  2. 84
      src/main/java/com/corundumstudio/socketio/PacketHandler.java
  3. 12
      src/main/java/com/corundumstudio/socketio/SocketIOPipelineFactory.java
  4. 25
      src/main/java/com/corundumstudio/socketio/messages/PacketsMessage.java
  5. 107
      src/main/java/com/corundumstudio/socketio/transport/WebSocketTransport.java
  6. 6
      src/main/java/com/corundumstudio/socketio/transport/XHRPollingClient.java
  7. 30
      src/main/java/com/corundumstudio/socketio/transport/XHRPollingTransport.java
  8. 123
      src/test/java/com/corundumstudio/socketio/PacketHandlerTest.java

6
pom.xml

@ -41,6 +41,12 @@
<version>4.10</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.googlecode.jmockit</groupId>
<artifactId>jmockit</artifactId>
<version>0.999.15</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty</artifactId>

84
src/main/java/com/corundumstudio/socketio/PacketHandler.java

@ -0,0 +1,84 @@
package com.corundumstudio.socketio;
import java.io.IOException;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.ChannelHandler.Sharable;
import org.jboss.netty.util.CharsetUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.corundumstudio.socketio.messages.PacketsMessage;
import com.corundumstudio.socketio.parser.Decoder;
import com.corundumstudio.socketio.parser.Packet;
@Sharable
public class PacketHandler extends SimpleChannelUpstreamHandler {
private final Logger log = LoggerFactory.getLogger(getClass());
private final PacketListener packetListener;
private final Decoder decoder;
public PacketHandler(PacketListener packetListener, Decoder decoder) {
super();
this.packetListener = packetListener;
this.decoder = decoder;
}
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
throws Exception {
Object msg = e.getMessage();
if (msg instanceof PacketsMessage) {
PacketsMessage message = (PacketsMessage) msg;
ChannelBuffer content = message.getContent();
if (log.isTraceEnabled()) {
log.trace("In message: {} sessionId: {}", new Object[] {content.toString(CharsetUtil.UTF_8), message.getClient().getSessionId()});
}
while (content.readable()) {
Packet packet = decode(content);
packetListener.onPacket(packet, message.getClient());
}
} else {
ctx.sendUpstream(e);
}
}
// TODO use ForkJoin
private Packet decode(ChannelBuffer buffer) throws IOException {
char delimiter = getChar(buffer, buffer.readerIndex());
if (delimiter == Packet.DELIMITER) {
StringBuilder length = new StringBuilder(4);
for (int i = buffer.readerIndex() + 2 + 1; i < buffer.readerIndex() + buffer.readableBytes(); i++) {
if (getChar(buffer, i) == Packet.DELIMITER) {
break;
} else {
length.append((char)buffer.getUnsignedByte(i));
}
}
Integer len = Integer.valueOf(length.toString());
int startIndex = buffer.readerIndex() + 3 + length.length() + 3;
ChannelBuffer frame = buffer.slice(startIndex, len);
Packet packet = decoder.decodePacket(frame.toString(CharsetUtil.UTF_8));
buffer.readerIndex(startIndex + len);
return packet;
} else {
Packet packet = decoder.decodePacket(buffer.toString(CharsetUtil.UTF_8));
buffer.readerIndex(buffer.readableBytes());
return packet;
}
}
// TODO refactor it
private char getChar(ChannelBuffer buffer, int index) {
byte[] bytes = {buffer.getByte(index), buffer.getByte(index + 1)};
return new String(bytes).charAt(0);
}
}

12
src/main/java/com/corundumstudio/socketio/SocketIOPipelineFactory.java

@ -46,17 +46,21 @@ public class SocketIOPipelineFactory implements ChannelPipelineFactory, Disconne
private SocketIOListener socketIOHandler;
private HeartbeatHandler heartbeatHandler;
private PacketHandler packetHandler;
public SocketIOPipelineFactory(Configuration configuration) {
this.socketIOHandler = configuration.getListener();
this.heartbeatHandler = new HeartbeatHandler(configuration);
ObjectMapper objectMapper = configuration.getObjectMapper();
Encoder encoder = new Encoder(objectMapper);
Decoder decoder = new Decoder(objectMapper);
this.heartbeatHandler = new HeartbeatHandler(configuration);
PacketListener packetListener = new PacketListener(socketIOHandler, this, heartbeatHandler);
packetHandler = new PacketHandler(packetListener, decoder);
authorizeHandler = new AuthorizeHandler(connectPath, socketIOHandler, configuration);
xhrPollingTransport = new XHRPollingTransport(connectPath, decoder, packetListener, this, heartbeatHandler, authorizeHandler, configuration);
webSocketTransport = new WebSocketTransport(connectPath, decoder, this, packetListener, authorizeHandler);
xhrPollingTransport = new XHRPollingTransport(connectPath, this, heartbeatHandler, authorizeHandler, configuration);
webSocketTransport = new WebSocketTransport(connectPath, this, authorizeHandler);
socketIOEncoder = new SocketIOEncoder(objectMapper, encoder);
}
@ -67,6 +71,8 @@ public class SocketIOPipelineFactory implements ChannelPipelineFactory, Disconne
pipeline.addLast("aggregator", new HttpChunkAggregator(65536));
pipeline.addLast("encoder", new HttpResponseEncoder());
pipeline.addLast("packetHandler", packetHandler);
pipeline.addLast("authorizeHandler", authorizeHandler);
pipeline.addLast("xhrPollingTransport", xhrPollingTransport);
pipeline.addLast("webSocketTransport", webSocketTransport);

25
src/main/java/com/corundumstudio/socketio/messages/PacketsMessage.java

@ -0,0 +1,25 @@
package com.corundumstudio.socketio.messages;
import org.jboss.netty.buffer.ChannelBuffer;
import com.corundumstudio.socketio.SocketIOClient;
public class PacketsMessage {
private final SocketIOClient client;
private final ChannelBuffer content;
public PacketsMessage(SocketIOClient client, ChannelBuffer content) {
this.client = client;
this.content = content;
}
public SocketIOClient getClient() {
return client;
}
public ChannelBuffer getContent() {
return content;
}
}

107
src/main/java/com/corundumstudio/socketio/transport/WebSocketTransport.java

@ -15,13 +15,15 @@
*/
package com.corundumstudio.socketio.transport;
import java.util.List;
import java.io.IOException;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.handler.codec.http.HttpHeaders;
@ -31,15 +33,14 @@ import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import org.jboss.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import org.jboss.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import org.jboss.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import org.jboss.netty.util.CharsetUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.corundumstudio.socketio.AuthorizeHandler;
import com.corundumstudio.socketio.Disconnectable;
import com.corundumstudio.socketio.PacketListener;
import com.corundumstudio.socketio.SocketIOClient;
import com.corundumstudio.socketio.parser.Decoder;
import com.corundumstudio.socketio.parser.Packet;
import com.corundumstudio.socketio.messages.PacketsMessage;
public class WebSocketTransport extends SimpleChannelUpstreamHandler implements Disconnectable {
@ -50,18 +51,13 @@ public class WebSocketTransport extends SimpleChannelUpstreamHandler implements
private final AuthorizeHandler authorizeHandler;
private final Disconnectable disconnectable;
private final PacketListener packetListener;
private final Decoder decoder;
private final String path;
public WebSocketTransport(String connectPath, Decoder decoder,
Disconnectable disconnectable, PacketListener packetListener, AuthorizeHandler authorizeHandler) {
public WebSocketTransport(String connectPath, Disconnectable disconnectable, AuthorizeHandler authorizeHandler) {
this.path = connectPath + "websocket";
this.decoder = decoder;
this.authorizeHandler = authorizeHandler;
this.disconnectable = disconnectable;
this.packetListener = packetListener;
}
@Override
@ -71,57 +67,66 @@ public class WebSocketTransport extends SimpleChannelUpstreamHandler implements
ctx.getChannel().close();
} else if (msg instanceof TextWebSocketFrame) {
TextWebSocketFrame frame = (TextWebSocketFrame) msg;
WebSocketClient client = channelId2Client.get(ctx.getChannel().getId());
String content = frame.getText();
log.trace("In message: {} sessionId: {}", new Object[] {content, client.getSessionId()});
List<Packet> packets = decoder.decodePayload(content);
for (Packet packet : packets) {
packetListener.onPacket(packet, client);
}
receivePackets(ctx, frame.getBinaryData());
} else if (msg instanceof HttpRequest) {
HttpRequest req = (HttpRequest) msg;
WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, false);
WebSocketServerHandshaker handshaker = factory.newHandshaker(req);
if (handshaker != null) {
handshaker.handshake(ctx.getChannel(), req);
QueryStringDecoder queryDecoder = new QueryStringDecoder(req.getUri());
connectClient(ctx.getChannel(), queryDecoder);
} else {
factory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel());
}
handshake(ctx, req);
} else {
ctx.sendUpstream(e);
}
}
private void connectClient(Channel channel, QueryStringDecoder queryDecoder) {
String path = queryDecoder.getPath();
if (!path.startsWith(this.path)) {
private void handshake(ChannelHandlerContext ctx, HttpRequest req) {
QueryStringDecoder queryDecoder = new QueryStringDecoder(req.getUri());
Channel channel = ctx.getChannel();
String path = queryDecoder.getPath();
if (!path.startsWith(this.path)) {
return;
}
String[] parts = path.split("/");
if (parts.length <= 3) {
log.warn("Wrong GET request path: {}, from ip: {}. Channel closed!",
new Object[] {path, channel.getRemoteAddress()});
channel.close();
return;
}
UUID sessionId = UUID.fromString(parts[4]);
WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, false);
WebSocketServerHandshaker handshaker = factory.newHandshaker(req);
if (handshaker != null) {
handshaker.handshake(channel, req);
connectClient(channel, sessionId);
} else {
factory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel());
}
}
private void receivePackets(ChannelHandlerContext ctx,
ChannelBuffer channelBuffer) throws IOException {
WebSocketClient client = channelId2Client.get(ctx.getChannel().getId());
if (log.isTraceEnabled()) {
String content = channelBuffer.toString(CharsetUtil.UTF_8);
log.trace("In message: {} sessionId: {}", new Object[] {content, client.getSessionId()});
}
Channels.fireMessageReceived(ctx.getChannel(), new PacketsMessage(client, channelBuffer));
}
private void connectClient(Channel channel, UUID sessionId) {
if (!authorizeHandler.isSessionAuthorized(sessionId)) {
log.warn("Unauthorized client with sessionId: {}, from ip: {}. Channel closed!",
new Object[] {sessionId, channel.getRemoteAddress()});
channel.close();
return;
}
String[] parts = path.split("/");
if (parts.length > 3) {
UUID sessionId = UUID.fromString(parts[4]);
if (!authorizeHandler.isSessionAuthorized(sessionId)) {
log.warn("Unauthorized client with sessionId: {}, from ip: {}. Channel closed!",
new Object[] {sessionId, channel.getRemoteAddress()});
channel.close();
return;
}
WebSocketClient client = new WebSocketClient(channel, disconnectable, sessionId);
channelId2Client.put(channel.getId(), client);
sessionId2Client.put(sessionId, client);
authorizeHandler.connect(client);
} else {
log.warn("Wrong GET request path: {}, from ip: {}. Channel closed!",
new Object[] {path, channel.getRemoteAddress()});
channel.close();
}
WebSocketClient client = new WebSocketClient(channel, disconnectable, sessionId);
channelId2Client.put(channel.getId(), client);
sessionId2Client.put(sessionId, client);
authorizeHandler.connect(client);
}
private String getWebSocketLocation(HttpRequest req) {

6
src/main/java/com/corundumstudio/socketio/transport/XHRPollingClient.java

@ -19,8 +19,6 @@ import java.util.UUID;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpRequest;
import com.corundumstudio.socketio.Disconnectable;
import com.corundumstudio.socketio.messages.XHRNewChannelMessage;
@ -39,8 +37,8 @@ public class XHRPollingClient extends BaseClient {
this.disconnectable = disconnectable;
}
public void update(Channel channel, HttpRequest req) {
this.origin = req.getHeader(HttpHeaders.Names.ORIGIN);
public void update(Channel channel, String origin) {
this.origin = origin;
this.channel = channel;
channel.write(new XHRNewChannelMessage(sessionId, origin));
}

30
src/main/java/com/corundumstudio/socketio/transport/XHRPollingTransport.java

@ -16,20 +16,19 @@
package com.corundumstudio.socketio.transport;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.QueryStringDecoder;
import org.jboss.netty.util.CharsetUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -37,11 +36,10 @@ import com.corundumstudio.socketio.AuthorizeHandler;
import com.corundumstudio.socketio.Configuration;
import com.corundumstudio.socketio.Disconnectable;
import com.corundumstudio.socketio.HeartbeatHandler;
import com.corundumstudio.socketio.PacketListener;
import com.corundumstudio.socketio.SocketIOClient;
import com.corundumstudio.socketio.messages.PacketsMessage;
import com.corundumstudio.socketio.messages.XHRErrorMessage;
import com.corundumstudio.socketio.messages.XHRPostMessage;
import com.corundumstudio.socketio.parser.Decoder;
import com.corundumstudio.socketio.parser.ErrorAdvice;
import com.corundumstudio.socketio.parser.ErrorReason;
import com.corundumstudio.socketio.parser.Packet;
@ -55,22 +53,17 @@ public class XHRPollingTransport extends SimpleChannelUpstreamHandler implements
private final AuthorizeHandler authorizeHandler;
private final HeartbeatHandler heartbeatHandler;
private final PacketListener packetListener;
private final Disconnectable disconnectable;
private final Decoder decoder;
private final String path;
private final Configuration configuration;
public XHRPollingTransport(String connectPath, Decoder decoder,
PacketListener packetListener, Disconnectable disconnectable,
public XHRPollingTransport(String connectPath, Disconnectable disconnectable,
HeartbeatHandler heartbeatHandler, AuthorizeHandler authorizeHandler, Configuration configuration) {
this.path = connectPath + "xhr-polling/";
this.authorizeHandler = authorizeHandler;
this.configuration = configuration;
this.heartbeatHandler = heartbeatHandler;
this.disconnectable = disconnectable;
this.decoder = decoder;
this.packetListener = packetListener;
}
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
@ -112,14 +105,8 @@ public class XHRPollingTransport extends SimpleChannelUpstreamHandler implements
return;
}
String content = req.getContent().toString(CharsetUtil.UTF_8);
log.trace("In message: {} sessionId: {}", new Object[] {content, sessionId});
List<Packet> packets = decoder.decodePayload(content);
for (Packet packet : packets) {
packetListener.onPacket(packet, client);
}
String origin = req.getHeader(HttpHeaders.Names.ORIGIN);
Channels.fireMessageReceived(channel, new PacketsMessage(client, req.getContent()));
channel.write(new XHRPostMessage(origin));
}
@ -128,18 +115,19 @@ public class XHRPollingTransport extends SimpleChannelUpstreamHandler implements
sendError(channel, req, sessionId);
return;
}
String origin = req.getHeader(HttpHeaders.Names.ORIGIN);
XHRPollingClient client = sessionId2Client.get(sessionId);
if (client == null) {
client = createClient(req, channel, sessionId);
client = createClient(origin, channel, sessionId);
}
client.update(channel, req);
client.update(channel, origin);
}
private XHRPollingClient createClient(HttpRequest req, Channel channel, UUID sessionId) {
private XHRPollingClient createClient(String origin, Channel channel, UUID sessionId) {
XHRPollingClient client = new XHRPollingClient(authorizeHandler, sessionId);
sessionId2Client.put(sessionId, client);
client.update(channel, req);
client.update(channel, origin);
authorizeHandler.connect(client);
if (configuration.isHeartbeatsEnabled()) {

123
src/test/java/com/corundumstudio/socketio/PacketHandlerTest.java

@ -0,0 +1,123 @@
package com.corundumstudio.socketio;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import junit.framework.Assert;
import mockit.Mocked;
import org.codehaus.jackson.map.ObjectMapper;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.UpstreamMessageEvent;
import org.junit.Test;
import com.corundumstudio.socketio.messages.PacketsMessage;
import com.corundumstudio.socketio.parser.Decoder;
import com.corundumstudio.socketio.parser.Encoder;
import com.corundumstudio.socketio.parser.Packet;
import com.corundumstudio.socketio.parser.PacketType;
public class PacketHandlerTest {
private ObjectMapper map = new ObjectMapper();
private Decoder decoder = new Decoder(map);
private Encoder encoder = new Encoder(map);
@Mocked
private Channel channel;
@Test
public void testOnePacket() throws Exception {
final AtomicInteger invocations = new AtomicInteger();
PacketListener listener = new PacketListener(null, null, null) {
@Override
public void onPacket(Packet packet, SocketIOClient client) {
invocations.incrementAndGet();
Assert.assertEquals(PacketType.JSON, packet.getType());
Map<String, String> map = (Map<String, String>) packet.getData();
Assert.assertTrue(map.keySet().size() == 1);
Assert.assertTrue(map.keySet().contains("test1"));
}
};
PacketHandler handler = new PacketHandler(listener, decoder);
List<Packet> packets = new ArrayList<Packet>();
Packet packet = new Packet(PacketType.JSON);
packet.setData(Collections.singletonMap("test1", "test2"));
packets.add(packet);
testHandler(invocations, handler, packets);
}
@Test
public void testMultiplePackets() throws Exception {
final AtomicInteger invocations = new AtomicInteger();
PacketListener listener = new PacketListener(null, null, null) {
@Override
public void onPacket(Packet packet, SocketIOClient client) {
if (packet.getType() == PacketType.CONNECT) {
invocations.incrementAndGet();
return;
}
Assert.assertEquals(PacketType.JSON, packet.getType());
Map<String, String> map = (Map<String, String>) packet.getData();
Set<String> keys = new HashSet<String>();
keys.add("test1");
keys.add("fsdfdf");
Assert.assertTrue(map.keySet().size() == 1);
Assert.assertTrue(map.keySet().removeAll(keys));
invocations.incrementAndGet();
}
};
PacketHandler handler = new PacketHandler(listener, decoder);
List<Packet> packets = new ArrayList<Packet>();
Packet packet3 = new Packet(PacketType.CONNECT);
packets.add(packet3);
Packet packet = new Packet(PacketType.JSON);
packet.setData(Collections.singletonMap("test1", "test2"));
packets.add(packet);
Packet packet1 = new Packet(PacketType.JSON);
packet1.setData(Collections.singletonMap("fsdfdf", "wqeq"));
packets.add(packet1);
testHandler(invocations, handler, packets);
}
private void testHandler(final AtomicInteger invocations,
PacketHandler handler, List<Packet> packets) throws Exception {
String str = encoder.encodePackets(packets);
ChannelBuffer buffer = ChannelBuffers.wrappedBuffer(str.getBytes());
handler.messageReceived(null, new UpstreamMessageEvent(channel, new PacketsMessage(null, buffer), null));
Assert.assertEquals(packets.size(), invocations.get());
}
//@Test
public void testDecodePerf() throws Exception {
PacketListener listener = new PacketListener(null, null, null) {
@Override
public void onPacket(Packet packet, SocketIOClient client) {
}
};
PacketHandler handler = new PacketHandler(listener, decoder);
long start = System.currentTimeMillis();
ChannelBuffer buffer = ChannelBuffers.wrappedBuffer("\ufffd5\ufffd3:::5\ufffd7\ufffd3:::53d\ufffd3\ufffd0::\ufffd5\ufffd3:::5\ufffd7\ufffd3:::53d\ufffd3\ufffd0::\ufffd5\ufffd3:::5\ufffd7\ufffd3:::53d\ufffd3\ufffd0::\ufffd5\ufffd3:::5\ufffd7\ufffd3:::53d\ufffd3\ufffd0::\ufffd5\ufffd3:::5\ufffd7\ufffd3:::53d\ufffd3\ufffd0::".getBytes());
for (int i = 0; i < 50000; i++) {
ChannelBuffer t = buffer.copy();
handler.messageReceived(null, new UpstreamMessageEvent(channel, new PacketsMessage(null, t), null));
}
long end = System.currentTimeMillis() - start;
System.out.println(end + "ms");
// 1143ms
}
}
Loading…
Cancel
Save