From 5803aadd3f209eba1ffbb2cf7bf16778019dbee1 Mon Sep 17 00:00:00 2001 From: fredoboulo Date: Fri, 23 Feb 2018 23:55:57 +0100 Subject: [PATCH] Fix #524 : V1 and V2 protocol downgrades handle received data in handshake buffer This patch is upstream pull request, see: https://gihub.com/zeromq/jeromq/pull/527. It is merged on commit c2afa9c, and we can drop it on the 0.4.4 release. --- src/main/java/zmq/io/StreamEngine.java | 21 ++++++++++-- src/test/java/zmq/io/AbstractProtocolVersion.java | 41 +++++++++++++---------- src/test/java/zmq/io/V0ProtocolTest.java | 12 +++++++ src/test/java/zmq/io/V1ProtocolTest.java | 16 +++++++-- src/test/java/zmq/io/V2ProtocolTest.java | 16 +++++++-- 5 files changed, 81 insertions(+), 25 deletions(-) diff --git a/src/main/java/zmq/io/StreamEngine.java b/src/main/java/zmq/io/StreamEngine.java index b8933c92..fe2f2d8d 100644 --- a/src/main/java/zmq/io/StreamEngine.java +++ b/src/main/java/zmq/io/StreamEngine.java @@ -816,9 +816,7 @@ private boolean handshake() assert (bufferSize == headerSize); // Make sure the decoder sees the data we have already received. - greetingRecv.flip(); - inpos = greetingRecv; - insize = greetingRecv.limit(); + decodeDataAfterHandshake(0); // To allow for interoperability with peers that do not forward // their subscriptions, we inject a phantom subscription message @@ -846,6 +844,8 @@ else if (greetingRecv.get(revisionPos) == Protocol.V1.revision) { } encoder = new V1Encoder(errno, Config.OUT_BATCH_SIZE.getValue()); decoder = new V1Decoder(errno, Config.IN_BATCH_SIZE.getValue(), options.maxMsgSize, options.allocator); + + decodeDataAfterHandshake(V2_GREETING_SIZE); } else if (greetingRecv.get(revisionPos) == Protocol.V2.revision) { // ZMTP/2.0 framing. @@ -859,6 +859,8 @@ else if (greetingRecv.get(revisionPos) == Protocol.V2.revision) { } encoder = new V2Encoder(errno, Config.OUT_BATCH_SIZE.getValue()); decoder = new V2Decoder(errno, Config.IN_BATCH_SIZE.getValue(), options.maxMsgSize, options.allocator); + + decodeDataAfterHandshake(V2_GREETING_SIZE); } else { zmtpVersion = Protocol.V3; @@ -904,6 +906,19 @@ else if (greetingRecv.get(revisionPos) == Protocol.V2.revision) { return true; } + private void decodeDataAfterHandshake(int greetingSize) + { + final int pos = greetingRecv.position(); + if (pos > greetingSize) { + // data is present after handshake + greetingRecv.position(greetingSize).limit(pos); + + // Make sure the decoder sees this extra data. + inpos = greetingRecv; + insize = greetingRecv.remaining(); + } + } + private Msg identityMsg() { Msg msg = new Msg(options.identitySize); diff --git a/src/test/java/zmq/io/AbstractProtocolVersion.java b/src/test/java/zmq/io/AbstractProtocolVersion.java index e60db403..aa06b4a7 100644 --- a/src/test/java/zmq/io/AbstractProtocolVersion.java +++ b/src/test/java/zmq/io/AbstractProtocolVersion.java @@ -18,15 +18,18 @@ import zmq.SocketBase; import zmq.ZError; import zmq.ZMQ; +import zmq.ZMQ.Event; import zmq.util.Utils; public abstract class AbstractProtocolVersion { + protected static final int REPETITIONS = 1000; + static class SocketMonitor extends Thread { - private final Ctx ctx; - private final String monitorAddr; - private final List events = new ArrayList<>(); + private final Ctx ctx; + private final String monitorAddr; + private final ZMQ.Event[] events = new ZMQ.Event[1]; public SocketMonitor(Ctx ctx, String monitorAddr) { @@ -41,15 +44,15 @@ public void run() boolean rc = s.connect(monitorAddr); assertThat(rc, is(true)); // Only some of the exceptional events could fire - while (true) { - ZMQ.Event event = ZMQ.Event.read(s); - if (event == null && s.errno() == ZError.ETERM) { - break; - } - assertThat(event, notNullValue()); - - events.add(event); + + ZMQ.Event event = ZMQ.Event.read(s); + if (event == null && s.errno() == ZError.ETERM) { + s.close(); + return; } + assertThat(event, notNullValue()); + + events[0] = event; s.close(); } } @@ -69,11 +72,12 @@ public void run() boolean rc = ZMQ.setSocketOption(receiver, ZMQ.ZMQ_LINGER, 0); assertThat(rc, is(true)); - SocketMonitor monitor = new SocketMonitor(ctx, "inproc://monitor"); - monitor.start(); rc = ZMQ.monitorSocket(receiver, "inproc://monitor", ZMQ.ZMQ_EVENT_HANDSHAKE_PROTOCOL); assertThat(rc, is(true)); + SocketMonitor monitor = new SocketMonitor(ctx, "inproc://monitor"); + monitor.start(); + rc = ZMQ.bind(receiver, host); assertThat(rc, is(true)); @@ -81,17 +85,18 @@ public void run() OutputStream out = sender.getOutputStream(); for (ByteBuffer raw : raws) { out.write(raw.array()); - ZMQ.msleep(100); } Msg msg = ZMQ.recv(receiver, 0); assertThat(msg, notNullValue()); assertThat(new String(msg.data(), ZMQ.CHARSET), is(payload)); - ZMQ.msleep(500); - assertThat(monitor.events.size(), is(1)); - assertThat(monitor.events.get(0).event, is(ZMQ.ZMQ_EVENT_HANDSHAKE_PROTOCOL)); - assertThat((Integer) monitor.events.get(0).arg, is(version)); + monitor.join(); + + final Event event = monitor.events[0]; + assertThat(event, notNullValue()); + assertThat(event.event, is(ZMQ.ZMQ_EVENT_HANDSHAKE_PROTOCOL)); + assertThat((Integer) event.arg, is(version)); InputStream in = sender.getInputStream(); byte[] data = new byte[255]; diff --git a/src/test/java/zmq/io/V0ProtocolTest.java b/src/test/java/zmq/io/V0ProtocolTest.java index bd547d23..1a5b7aef 100644 --- a/src/test/java/zmq/io/V0ProtocolTest.java +++ b/src/test/java/zmq/io/V0ProtocolTest.java @@ -10,6 +10,18 @@ public class V0ProtocolTest extends AbstractProtocolVersion { + @Test + public void testFixIssue524() throws IOException, InterruptedException + { + for (int idx = 0; idx < REPETITIONS; ++idx) { + if (idx % 100 == 0) { + System.out.print(idx + " "); + } + testProtocolVersion0short(); + } + System.out.println(); + } + @Test(timeout = 2000) public void testProtocolVersion0short() throws IOException, InterruptedException { diff --git a/src/test/java/zmq/io/V1ProtocolTest.java b/src/test/java/zmq/io/V1ProtocolTest.java index e1045f34..764159d0 100644 --- a/src/test/java/zmq/io/V1ProtocolTest.java +++ b/src/test/java/zmq/io/V1ProtocolTest.java @@ -10,7 +10,19 @@ public class V1ProtocolTest extends AbstractProtocolVersion { - @Test(timeout = 2000) + @Test + public void testFixIssue524() throws IOException, InterruptedException + { + for (int idx = 0; idx < REPETITIONS; ++idx) { + if (idx % 100 == 0) { + System.out.print(idx + " "); + } + testProtocolVersion1short(); + } + System.out.println(); + } + + @Test public void testProtocolVersion1short() throws IOException, InterruptedException { List raws = raws(0); @@ -25,7 +37,7 @@ public void testProtocolVersion1short() throws IOException, InterruptedException assertProtocolVersion(1, raws, "abcdefg"); } - @Test(timeout = 2000) + @Test public void testProtocolVersion1long() throws IOException, InterruptedException { List raws = raws(0); diff --git a/src/test/java/zmq/io/V2ProtocolTest.java b/src/test/java/zmq/io/V2ProtocolTest.java index d5e64bce..7fda31bc 100644 --- a/src/test/java/zmq/io/V2ProtocolTest.java +++ b/src/test/java/zmq/io/V2ProtocolTest.java @@ -21,7 +21,19 @@ protected ByteBuffer identity() .put((byte) 0); } - @Test(timeout = 2000) + @Test + public void testFixIssue524() throws IOException, InterruptedException + { + for (int idx = 0; idx < REPETITIONS; ++idx) { + if (idx % 100 == 0) { + System.out.print(idx + " "); + } + testProtocolVersion2short(); + } + System.out.println(); + } + + @Test public void testProtocolVersion2short() throws IOException, InterruptedException { List raws = raws(1); @@ -38,7 +50,7 @@ public void testProtocolVersion2short() throws IOException, InterruptedException assertProtocolVersion(2, raws, "abcdefg"); } - @Test(timeout = 2000) + @Test public void testProtocolVersion2long() throws IOException, InterruptedException { List raws = raws(1);