diff --git a/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/message/AmqpMessageSupport.java b/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/message/AmqpMessageSupport.java index eedb464f72d..102a1701a5e 100644 --- a/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/message/AmqpMessageSupport.java +++ b/activemq-amqp/src/main/java/org/apache/activemq/transport/amqp/message/AmqpMessageSupport.java @@ -18,6 +18,7 @@ import java.io.DataInputStream; import java.io.IOException; +import java.io.InputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.nio.charset.Charset; @@ -38,6 +39,7 @@ import org.apache.activemq.util.ByteArrayOutputStream; import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.JMSExceptionSupport; +import org.apache.activemq.util.MarshallingSupport; import org.apache.qpid.proton.amqp.Binary; import org.apache.qpid.proton.amqp.Symbol; import org.apache.qpid.proton.amqp.messaging.Data; @@ -213,6 +215,12 @@ public static Binary getBinaryFromMessageBody(ActiveMQBytesMessage message) thro if (message.isCompressed()) { int length = (int) message.getBodyLength(); + // before we allocate the buffer ensure it's not too large + try { + MarshallingSupport.validateMaxInflatedDataSize(message.getMaxInflatedDataSize(), length); + } catch (IOException cause) { + throw JMSExceptionSupport.create(cause); + } byte[] uncompressed = new byte[length]; message.readBytes(uncompressed); @@ -244,7 +252,9 @@ public static Binary getBinaryFromMessageBody(ActiveMQObjectMessage message) thr if (message.isCompressed()) { try (ByteArrayOutputStream os = new ByteArrayOutputStream(); ByteArrayInputStream is = new ByteArrayInputStream(contents); - InflaterInputStream iis = new InflaterInputStream(is);) { + // wrap to prevent allocating more than maxInflatedDataSize + InputStream iis = MarshallingSupport.createInflaterInputStream( + message.getMaxInflatedDataSize(), is)) { byte value; while ((value = (byte) iis.read()) != -1) { @@ -282,10 +292,14 @@ public static Binary getBinaryFromMessageBody(ActiveMQTextMessage message) throw if (message.isCompressed()) { try (ByteArrayInputStream is = new ByteArrayInputStream(contents); + // We do not need to wrap this stream, the size is validated below + // before allocation InflaterInputStream iis = new InflaterInputStream(is); DataInputStream dis = new DataInputStream(iis);) { int size = dis.readInt(); + // before we allocate the buffer ensure it's not too large + MarshallingSupport.validateMaxInflatedDataSize(message.getMaxInflatedDataSize(), size); byte[] uncompressed = new byte[size]; dis.readFully(uncompressed); diff --git a/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/JMSInteroperabilityTest.java b/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/JMSInteroperabilityTest.java index 8bad4f39c7f..4e5feeec17b 100644 --- a/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/JMSInteroperabilityTest.java +++ b/activemq-amqp/src/test/java/org/apache/activemq/transport/amqp/JMSInteroperabilityTest.java @@ -496,6 +496,42 @@ public void testOpenWireToQpidObjectMessageWithOpenWireCompression() throws Exce } } + @Test + public void testOpenWireToQpidCompressionFailure() throws Exception { + + // Raw Transformer doesn't expand message properties. + assumeFalse(!transformer.equals("jms")); + + // set to 512 bytes + brokerService.setMaxInflatedDataSize(512); + + try (Connection openwire = createJMSConnection(); Connection amqp = createConnection()) { + ((ActiveMQConnection) openwire).setUseCompression(true); + openwire.start(); + amqp.start(); + + Session openwireSession = openwire.createSession(false, Session.AUTO_ACKNOWLEDGE); + Session amqpSession = amqp.createSession(false, Session.AUTO_ACKNOWLEDGE); + + Destination queue = openwireSession.createQueue(getDestinationName()); + MessageProducer openwireProducer = openwireSession.createProducer(queue); + MessageConsumer amqpConsumer = amqpSession.createConsumer(queue); + + StringBuilder builder = new StringBuilder(); + // generate a string longer than 512 bytes + for (int i = 1; i <= 50; i++) { + builder.append("compresedpayload"); + } + // Create and send the Message + openwireProducer.send(openwireSession.createTextMessage(builder.toString())); + + // There should be an error triggered on dispatch during decompression + // and message should go to the DLQ + assertNull(amqpConsumer.receive(1000)); + assertTrue(sentToDlq.get()); + } + } + // The following tests for corruption will corrupt the headers or body // to test that the AMQP protocol correctly passes the error during // dispatch to allow the Transport Connection to properly handle diff --git a/activemq-broker/src/main/java/org/apache/activemq/broker/BrokerService.java b/activemq-broker/src/main/java/org/apache/activemq/broker/BrokerService.java index 74599becf73..1ffd8d79133 100644 --- a/activemq-broker/src/main/java/org/apache/activemq/broker/BrokerService.java +++ b/activemq-broker/src/main/java/org/apache/activemq/broker/BrokerService.java @@ -269,6 +269,7 @@ public class BrokerService implements Service { private final List preShutdownHooks = new CopyOnWriteArrayList<>(); private int maxUncommittedCount = DEFAULT_MAX_UNCOMMITTED_COUNT; + private int maxInflatedDataSize = OpenWireFormat.DEFAULT_MAX_INFLATED_DATA_SIZE; static { @@ -3349,4 +3350,17 @@ public void setMaxUncommittedCount(int maxUncommittedCount) { this.maxUncommittedCount = maxUncommittedCount; } + public int getMaxInflatedDataSize() { + return maxInflatedDataSize; + } + + /** + * Set the maximum size that a compressed message can inflate to + * if a message has to be decompressed. + * + * @param maxInflatedDataSize + */ + public void setMaxInflatedDataSize(int maxInflatedDataSize) { + this.maxInflatedDataSize = maxInflatedDataSize; + } } diff --git a/activemq-broker/src/main/java/org/apache/activemq/broker/region/BaseDestination.java b/activemq-broker/src/main/java/org/apache/activemq/broker/region/BaseDestination.java index 6aaa317fe42..e98156fa488 100644 --- a/activemq-broker/src/main/java/org/apache/activemq/broker/region/BaseDestination.java +++ b/activemq-broker/src/main/java/org/apache/activemq/broker/region/BaseDestination.java @@ -932,4 +932,9 @@ public MessageInterceptorStrategy getMessageInterceptorStrategy() { public void setMessageInterceptorStrategy(MessageInterceptorStrategy messageInterceptorStrategy) { this.messageInterceptorStrategy = messageInterceptorStrategy; } + + @Override + public int getMaxInflatedDataSize() { + return brokerService.getMaxInflatedDataSize(); + } } diff --git a/activemq-broker/src/main/java/org/apache/activemq/broker/region/DestinationFilter.java b/activemq-broker/src/main/java/org/apache/activemq/broker/region/DestinationFilter.java index 1ab96560ac2..154f013d6c6 100644 --- a/activemq-broker/src/main/java/org/apache/activemq/broker/region/DestinationFilter.java +++ b/activemq-broker/src/main/java/org/apache/activemq/broker/region/DestinationFilter.java @@ -439,6 +439,11 @@ public void deleteSubscription(ConnectionContext context, SubscriptionKey key) t } } + @Override + public int getMaxInflatedDataSize() { + return next.getMaxInflatedDataSize(); + } + public Destination getNext() { return next; } diff --git a/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnection.java b/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnection.java index 216aba8d6b2..900938ec04c 100644 --- a/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnection.java +++ b/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnection.java @@ -152,6 +152,9 @@ public class ActiveMQConnection implements Connection, TopicConnection, QueueCon private boolean optimizedMessageDispatch = true; private boolean copyMessageOnSend = true; private boolean useCompression; + private double maxInflatedDataSizeRatio = ActiveMQConnectionFactory.DEFAULT_MAX_INFLATED_DATA_SIZE_RATIO; + // This will be configured during negotiation if maxFrameSize has been configured. + private int maxInflatedDataSize = Integer.MAX_VALUE; private boolean objectMessageSerializationDefered; private boolean useAsyncSend; private boolean optimizeAcknowledge; @@ -2034,6 +2037,15 @@ protected void onWireFormatInfo(WireFormatInfo info) { if(tmpMaxFrameSize > 0) { maxFrameSize.set(tmpMaxFrameSize); } + + // Compute the maxInflatedData size as a ratio of maxFrameSize + // This prevents overflow and sets to Integer.MAX_VALUE if too large + double updatedMaxInflated = (double)tmpMaxFrameSize * maxInflatedDataSizeRatio; + if (Double.isInfinite(updatedMaxInflated) || updatedMaxInflated > Integer.MAX_VALUE) { + this.maxInflatedDataSize = Integer.MAX_VALUE; + } else { + this.maxInflatedDataSize = (int) updatedMaxInflated; + } } /** @@ -2210,6 +2222,18 @@ public void setUseCompression(boolean useCompression) { this.useCompression = useCompression; } + public int getMaxInflatedDataSize() { + return maxInflatedDataSize; + } + + public double getMaxInflatedDataSizeRatio() { + return maxInflatedDataSizeRatio; + } + + public void setMaxInflatedDataSizeRatio(double maxInflatedDataSizeRatio) { + this.maxInflatedDataSizeRatio = maxInflatedDataSizeRatio; + } + public void destroyDestination(ActiveMQDestination destination) throws JMSException { checkClosedOrFailed(); diff --git a/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnectionFactory.java b/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnectionFactory.java index cb9947a9107..93cfb20d309 100644 --- a/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnectionFactory.java +++ b/activemq-client/src/main/java/org/apache/activemq/ActiveMQConnectionFactory.java @@ -99,6 +99,9 @@ public class ActiveMQConnectionFactory extends JNDIBaseStorable implements Conne public static final String DEFAULT_USER = null; public static final String DEFAULT_PASSWORD = null; public static final int DEFAULT_PRODUCER_WINDOW_SIZE = 0; + // The default ratio for maxInflatedDataSize. The default is 10x the size + // of maxFrameSize + public static final double DEFAULT_MAX_INFLATED_DATA_SIZE_RATIO = 10.0; protected URI brokerURL; protected String userName; @@ -135,6 +138,7 @@ public class ActiveMQConnectionFactory extends JNDIBaseStorable implements Conne private long optimizedAckScheduledAckInterval = 0; private boolean copyMessageOnSend = true; private boolean useCompression; + private double maxInflatedDataSizeRatio = DEFAULT_MAX_INFLATED_DATA_SIZE_RATIO; private boolean objectMessageSerializationDefered; private boolean useAsyncSend; private boolean optimizeAcknowledge; @@ -428,6 +432,7 @@ protected void configureConnection(ActiveMQConnection connection) throws JMSExce connection.setOptimizedMessageDispatch(isOptimizedMessageDispatch()); connection.setCopyMessageOnSend(isCopyMessageOnSend()); connection.setUseCompression(isUseCompression()); + connection.setMaxInflatedDataSizeRatio(getMaxInflatedDataSizeRatio()); connection.setObjectMessageSerializationDefered(isObjectMessageSerializationDefered()); connection.setDispatchAsync(isDispatchAsync()); connection.setUseAsyncSend(isUseAsyncSend()); @@ -876,6 +881,7 @@ public void populateProperties(Properties props) { props.setProperty("useAsyncSend", Boolean.toString(isUseAsyncSend())); props.setProperty("useCompression", Boolean.toString(isUseCompression())); + props.setProperty("maxInflatedDataSizeRatio", Double.toString(getMaxInflatedDataSizeRatio())); props.setProperty("useRetroactiveConsumer", Boolean.toString(isUseRetroactiveConsumer())); props.setProperty("watchTopicAdvisories", Boolean.toString(isWatchTopicAdvisories())); @@ -917,6 +923,21 @@ public void setUseCompression(boolean useCompression) { this.useCompression = useCompression; } + public double getMaxInflatedDataSizeRatio() { + return maxInflatedDataSizeRatio; + } + + /** + * Set the ratio to use to compute maxInflatedDataSize which controls + * how large a decompressed message buffer can be. maxInflatedDataSize + * is computed as maxFrameSize * maxInflatedDataSizeRatio. + * + * @param maxInflatedDataSizeRatio + */ + public void setMaxInflatedDataSizeRatio(double maxInflatedDataSizeRatio) { + this.maxInflatedDataSizeRatio = maxInflatedDataSizeRatio; + } + public boolean isObjectMessageSerializationDefered() { return objectMessageSerializationDefered; } diff --git a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQBytesMessage.java b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQBytesMessage.java index 98517f1b39c..822d7643652 100644 --- a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQBytesMessage.java +++ b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQBytesMessage.java @@ -39,6 +39,7 @@ import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.ByteSequenceData; import org.apache.activemq.util.JMSExceptionSupport; +import org.apache.activemq.util.MarshallingSupport; /** * A BytesMessage object is used to send a message containing a @@ -901,11 +902,15 @@ protected byte[] decompress(ByteSequence dataSequence) throws IOException { ByteArrayOutputStream decompressed = new ByteArrayOutputStream(); try { length = ByteSequenceData.readIntBig(dataSequence); + // verify the length of the buffer is not larger than maxInflatedDataSize + MarshallingSupport.validateMaxInflatedDataSize(getMaxInflatedDataSize(), length); dataSequence.offset = 0; - byte[] data = Arrays.copyOfRange(dataSequence.getData(), 4, dataSequence.getLength()); - inflater.setInput(data); + inflater.setInput(dataSequence.getData(), 4, dataSequence.getLength() - 4); byte[] buffer = new byte[length]; int count = inflater.inflate(buffer); + if (count != length) { + throw new IllegalStateException("Inflated buffer size is different than expected size of " + length); + } decompressed.write(buffer, 0, count); return decompressed.toByteArray(); } catch (Exception e) { diff --git a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQMapMessage.java b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQMapMessage.java index 140431573f5..6f5473f659b 100644 --- a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQMapMessage.java +++ b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQMapMessage.java @@ -189,7 +189,8 @@ private Map deserialize(ByteSequence content) throws JMSExceptio if (content != null) { InputStream is = new ByteArrayInputStream(content); if (isCompressed()) { - is = MarshallingSupport.createInflaterInputStream(is); + // wrap the stream so we don't inflate past maxInflatedDataSize + is = MarshallingSupport.createInflaterInputStream(getMaxInflatedDataSize(), is); } DataInputStream dataIn = new DataInputStream(is); map = MarshallingSupport.unmarshalPrimitiveMap(dataIn); diff --git a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQObjectMessage.java b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQObjectMessage.java index 79cbf4c0d5e..395dcca554c 100644 --- a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQObjectMessage.java +++ b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQObjectMessage.java @@ -38,6 +38,7 @@ import org.apache.activemq.util.ByteSequence; import org.apache.activemq.util.ClassLoadingAwareObjectInputStream; import org.apache.activemq.util.JMSExceptionSupport; +import org.apache.activemq.util.MarshallingSupport; import org.apache.activemq.wireformat.WireFormat; /** @@ -208,7 +209,8 @@ private Serializable deserialize(ByteSequence content) throws JMSException { try { InputStream is = new ByteArrayInputStream(content); if (isCompressed()) { - is = new InflaterInputStream(is); + // wrap the stream so we don't inflate past maxInflatedDataSize + is = MarshallingSupport.createInflaterInputStream(getMaxInflatedDataSize(), is); } DataInputStream dataIn = new DataInputStream(is); ClassLoadingAwareObjectInputStream objIn = new ClassLoadingAwareObjectInputStream(dataIn); diff --git a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQStreamMessage.java b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQStreamMessage.java index 1cd0d70e97d..19640dad4ef 100644 --- a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQStreamMessage.java +++ b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQStreamMessage.java @@ -833,6 +833,8 @@ public Object readObject() throws JMSException { } if (type == MarshallingSupport.BYTE_ARRAY_TYPE) { int len = this.dataIn.readInt(); + // verify that there are enough bytes remaining before allocation + MarshallingSupport.validateBufferSizeRemaining(dataIn, len); byte[] value = new byte[len]; this.dataIn.readFully(value); return value; @@ -1165,10 +1167,15 @@ private void initializeWriting() throws JMSException { if (compressed) { ByteArrayInputStream input = new ByteArrayInputStream(this.content.getData(), this.content.getOffset(), this.content.getLength()); InflaterInputStream inflater = new InflaterInputStream(input); + int total = 0; try { byte[] buffer = new byte[8*1024]; int read = 0; while ((read = inflater.read(buffer)) != -1) { + total = Math.addExact(total, read); + // each time through the loop see if we are >= max inflated size so we stop + // by doing this here we might go slightly pass the limit (up to 8 KB) but that is fine + MarshallingSupport.validateMaxInflatedDataSize(getMaxInflatedDataSize(), total); this.dataOut.write(buffer, 0, read); } } finally { @@ -1203,7 +1210,10 @@ private void initializeReading() throws MessageNotReadableException { if (isCompressed()) { is = new InflaterInputStream(is); is = new BufferedInputStream(is); - is = MarshallingSupport.createFrameLimitedInputStream(Integer.MAX_VALUE, is); + // Wrap the buffered stream in a frame limited stream so we can error if we exceed + // max inflate size + is = MarshallingSupport.createFrameLimitedInputStream(getMaxInflatedDataSize(), is); + } this.dataIn = new DataInputStream(is); } diff --git a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQTextMessage.java b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQTextMessage.java index 5e73c3314b2..a5df4164f0a 100644 --- a/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQTextMessage.java +++ b/activemq-client/src/main/java/org/apache/activemq/command/ActiveMQTextMessage.java @@ -95,7 +95,8 @@ private String decodeContent(ByteSequence bodyAsBytes) throws JMSException { try { is = new ByteArrayInputStream(bodyAsBytes); if (isCompressed()) { - is = MarshallingSupport.createInflaterInputStream(is); + // wrap the stream so we don't inflate past maxInflatedDataSize + is = MarshallingSupport.createInflaterInputStream(getMaxInflatedDataSize(), is); } DataInputStream dataIn = new DataInputStream(is); text = MarshallingSupport.readUTF8(dataIn); diff --git a/activemq-client/src/main/java/org/apache/activemq/command/Message.java b/activemq-client/src/main/java/org/apache/activemq/command/Message.java index 2a31047c9eb..06e0aebeea6 100644 --- a/activemq-client/src/main/java/org/apache/activemq/command/Message.java +++ b/activemq-client/src/main/java/org/apache/activemq/command/Message.java @@ -33,6 +33,7 @@ import org.apache.activemq.ActiveMQConnection; import org.apache.activemq.advisory.AdvisorySupport; import org.apache.activemq.broker.region.MessageReference; +import org.apache.activemq.openwire.OpenWireFormat; import org.apache.activemq.usage.MemoryUsage; import org.apache.activemq.util.ByteArrayInputStream; import org.apache.activemq.util.ByteArrayOutputStream; @@ -102,9 +103,10 @@ public abstract class Message extends BaseCommand implements MarshallAware, Mess private BrokerId[] brokerPath; private BrokerId[] cluster; - public static interface MessageDestination { + public interface MessageDestination { int getMinimumMessageSize(); MemoryUsage getMemoryUsage(); + int getMaxInflatedDataSize(); } public abstract Message copy(); @@ -871,4 +873,15 @@ protected Object readResolve() throws ObjectStreamException { } return this; } + + public int getMaxInflatedDataSize() { + // If this is set then this is on a broker + if (regionDestination != null) { + return regionDestination.getMaxInflatedDataSize(); + // connection is set on Clients + } else if (connection != null) { + return connection.getMaxInflatedDataSize(); + } + return OpenWireFormat.DEFAULT_MAX_INFLATED_DATA_SIZE; + } } diff --git a/activemq-client/src/main/java/org/apache/activemq/openwire/OpenWireFormat.java b/activemq-client/src/main/java/org/apache/activemq/openwire/OpenWireFormat.java index 86171dc288f..a927ed8f4c6 100644 --- a/activemq-client/src/main/java/org/apache/activemq/openwire/OpenWireFormat.java +++ b/activemq-client/src/main/java/org/apache/activemq/openwire/OpenWireFormat.java @@ -43,6 +43,7 @@ public final class OpenWireFormat implements WireFormat { public static final int DEFAULT_WIRE_VERSION = CommandTypes.PROTOCOL_VERSION; public static final int DEFAULT_LEGACY_VERSION = CommandTypes.PROTOCOL_LEGACY_STORE_VERSION; public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE; + public static final int DEFAULT_MAX_INFLATED_DATA_SIZE = 1024 * 1024 * 100; static final byte NULL_TYPE = CommandTypes.NULL; private static final int MARSHAL_CACHE_SIZE = Short.MAX_VALUE / 2; diff --git a/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java b/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java index c7bc76f02fd..43eed3aa714 100644 --- a/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java +++ b/activemq-client/src/main/java/org/apache/activemq/util/MarshallingSupport.java @@ -60,15 +60,27 @@ public final class MarshallingSupport { private MarshallingSupport() {} - // TODO: This will be limited in a future PR to something besides Integer.MAX_VALUE - public static InputStream createInflaterInputStream(InputStream is) { - return createFrameLimitedInputStream(Integer.MAX_VALUE, new InflaterInputStream(is)); + public static InputStream createInflaterInputStream(int maxAvailable, InputStream is) { + return createFrameLimitedInputStream(maxAvailable, new InflaterInputStream(is)); } public static InputStream createFrameLimitedInputStream(int maxAvailable, InputStream is) { return new FrameSizeLimitedFilterInputStream(maxAvailable, is); } + // Validate that the size value is not greater than the max available size + public static void validateMaxInflatedDataSize(int maxAvailable, int size) throws IOException { + if (size > maxAvailable) { + throw new MaxInflatedDataSizeExceededException( + "Cannot read more than the uncompressed size bytes: requested " + size); + } + } + + // Validate the size value is not greater than the remaining bytes in the stream + public static void validateBufferSizeRemaining(DataInputStream stream, int size) throws IOException { + validateBufferSize(stream, Integer.MAX_VALUE, size); + } + public static void marshalPrimitiveMap(Map map, DataOutputStream out) throws IOException { if (map == null) { out.writeInt(-1); @@ -486,4 +498,10 @@ public ActiveMQUnmarshalEOFException(String message) { } } + public static class MaxInflatedDataSizeExceededException extends ActiveMQUnmarshalEOFException { + public MaxInflatedDataSizeExceededException(String message) { + super(message); + } + } + } diff --git a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java index badd9912173..0bae2e8b825 100644 --- a/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java +++ b/activemq-mqtt/src/main/java/org/apache/activemq/transport/mqtt/MQTTProtocolConverter.java @@ -60,6 +60,7 @@ import org.apache.activemq.util.JMSExceptionSupport; import org.apache.activemq.util.LRUCache; import org.apache.activemq.util.LongSequenceGenerator; +import org.apache.activemq.util.MarshallingSupport; import org.fusesource.hawtbuf.Buffer; import org.fusesource.hawtbuf.UTF8Buffer; import org.fusesource.mqtt.client.QoS; @@ -636,8 +637,12 @@ public PUBLISH convertMessage(ActiveMQMessage message) throws IOException, JMSEx inflater.setInput(byteSequence.data, byteSequence.offset, byteSequence.length); byte[] data = new byte[4096]; int read; + int total = 0; ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); while ((read = inflater.inflate(data)) != 0) { + total = Math.addExact(total, read); + // check if we have exceeded maxInflatedSize before continuing + MarshallingSupport.validateMaxInflatedDataSize(message.getMaxInflatedDataSize(), total); bytesOut.write(data, 0, read); } byteSequence = bytesOut.toByteSequence(); diff --git a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java index 6f0176a558d..9c66fc0b64f 100644 --- a/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java +++ b/activemq-mqtt/src/test/java/org/apache/activemq/transport/mqtt/MQTTTest.java @@ -1194,6 +1194,53 @@ public void testCorruptBody() throws Exception { .getDestinationStatistics().getMessages().getCount() == 0, 500, 10)); } + @Test + public void testMaxInflatedDataSizeErrorBytes() throws Exception { + testMaxInflatedDataSizeError(true); + } + + @Test + public void testMaxInflatedDataSizeErrorText() throws Exception { + testMaxInflatedDataSizeError(false); + } + + private void testMaxInflatedDataSizeError(boolean bytes) throws Exception { + final MQTTClientProvider provider = getMQTTClientProvider(); + initializeConnection(provider); + + brokerService.setMaxInflatedDataSize(10); + String destinationName = "foo.far"; + ActiveMQConnection activeMQConnection = (ActiveMQConnection) cf.createConnection(); + activeMQConnection.setUseCompression(true); + activeMQConnection.start(); + Session s = activeMQConnection.createSession(false, Session.AUTO_ACKNOWLEDGE); + jakarta.jms.Topic jmsTopic = s.createTopic(destinationName); + MessageProducer producer = s.createProducer(jmsTopic); + + provider.subscribe("foo/+", AT_MOST_ONCE); + ActiveMQMessage sendMessage; + if (bytes) { + BytesMessage bytesMessage = s.createBytesMessage(); + bytesMessage.writeBytes("bodybodybodybodybody".getBytes()); + sendMessage = (ActiveMQMessage) bytesMessage; + } else { + sendMessage = (ActiveMQMessage) s.createTextMessage("bodybodybodybodybody"); + } + // marshal and clear so the broker will have to decompress + sendMessage.storeContentAndClear(); + producer.send(sendMessage); + + byte[] message = provider.receive(1000); + assertNull("Should not get message", message); + + provider.disconnect(); + activeMQConnection.close(); + + // verify message is gone off the dest + assertTrue(Wait.waitFor(() -> brokerService.getDestination(new ActiveMQTopic(destinationName)) + .getDestinationStatistics().getMessages().getCount() == 0, 500, 10)); + } + @Test(timeout = 60 * 1000) public void testPingKeepsInactivityMonitorAlive() throws Exception { MQTT mqtt = createMQTTConnection(); diff --git a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java index 783e439d006..cb3b58eca43 100644 --- a/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java +++ b/activemq-stomp/src/test/java/org/apache/activemq/transport/stomp/StompTest.java @@ -357,6 +357,61 @@ public void testCorruptMessage() throws Exception { assertTrue(sentToDlq.get()); } + @Test(timeout = 60000) + public void testMaxInflatedDataSizeErrorDlqText() throws Exception { + testMaxInflatedDataSizeErrorDlq(false); + } + + @Test(timeout = 60000) + public void testMaxInflatedDataSizeErrorDlqBytes() throws Exception { + testMaxInflatedDataSizeErrorDlq(true); + } + + private void testMaxInflatedDataSizeErrorDlq(boolean bytes) throws Exception { + String body = "testtesttesttesttesttest"; + + // set a tiny max size to trigger an error on dispatch + brokerService.setMaxInflatedDataSize(10); + ((ActiveMQConnection)connection).setUseCompression(true); + MessageProducer producer = session.createProducer(queue); + + String frame = "CONNECT\n" + "login:system\n" + "passcode:manager\n\n" + Stomp.NULL; + stompConnection.sendFrame(frame); + + frame = stompConnection.receiveFrame(); + assertTrue(frame.startsWith("CONNECTED")); + frame = "SUBSCRIBE\n" + "destination:/queue/" + getQueueName() + "\n" + "ack:auto\n\n" + Stomp.NULL; + stompConnection.sendFrame(frame); + + // marshal and clear so the broker will have to decompress + ActiveMQMessage m; + if (bytes) { + BytesMessage bytesMessage = session.createBytesMessage(); + bytesMessage.writeBytes(body.getBytes()); + m = (ActiveMQMessage) bytesMessage; + } else { + m = (ActiveMQMessage) session.createTextMessage(body); + } + m.storeContentAndClear(); + producer.send(m); + + assertTrue(Wait.waitFor(() -> brokerService.getDestination(queue) + .getDestinationStatistics().getMessages().getCount() == 1, 500, 10)); + + // Message should be DLQ'd because it exceeds max inflated data size + try { + StompFrame frameNull = stompConnection.receive(500); + if (frameNull != null) { + fail("Should not have received any messages"); + } + } catch (SocketTimeoutException soe) {} + + // verify message is gone off the dest and went to the DLQ + assertTrue(Wait.waitFor(() -> brokerService.getDestination(queue) + .getDestinationStatistics().getMessages().getCount() == 0, 500, 10)); + assertTrue(sentToDlq.get()); + } + @Test(timeout = 60000) public void testJMSXGroupIdCanBeSet() throws Exception { diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQBytesMessageTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQBytesMessageTest.java index 062b797161b..fcb4cdd51a8 100644 --- a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQBytesMessageTest.java +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQBytesMessageTest.java @@ -16,13 +16,20 @@ */ package org.apache.activemq.command; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + import jakarta.jms.JMSException; import jakarta.jms.MessageFormatException; import jakarta.jms.MessageNotReadableException; import jakarta.jms.MessageNotWriteableException; import junit.framework.TestCase; +import org.apache.activemq.ActiveMQConnection; import org.apache.activemq.test.annotations.ParallelTest; +import org.apache.activemq.util.ByteSequenceData; +import org.apache.activemq.util.ExceptionUtils; +import org.apache.activemq.util.MarshallingSupport.ActiveMQUnmarshalEOFException; import org.junit.experimental.categories.Category; /** @@ -512,4 +519,28 @@ public void testWriteOnlyBody() throws JMSException { } catch (MessageNotReadableException e) { } } + + public void testCompressedUnmarshalException() throws Exception { + ActiveMQConnection connection = mock(ActiveMQConnection.class); + when(connection.isUseCompression()).thenReturn(true); + + ActiveMQBytesMessage msg = new ActiveMQBytesMessage(); + msg.setConnection(connection); + msg.writeDouble(3.3d); + + // store and reset for reading + msg.reset(); + assertTrue(msg.isCompressed()); + + // corrupt the buffer + ByteSequenceData.writeIntBig(msg.content, 100000); + + try { + msg.readDouble(); + fail("Should have thrown exception"); + } catch (JMSException e) { + // expected + assertTrue(ExceptionUtils.getRootCause(e) instanceof ActiveMQUnmarshalEOFException); + } + } } diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQMapMessageTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQMapMessageTest.java index c09c5e03ef4..6a183f5882f 100644 --- a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQMapMessageTest.java +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQMapMessageTest.java @@ -480,7 +480,18 @@ public void testWriteOnlyBody() throws JMSException { @Test public void testUnmarshalException() throws Exception { + testUnmarshalException(false); + } + + @Test + public void testCompressedUnmarshalException() throws Exception { + testUnmarshalException(true); + } + + // For map messages both compressed and uncompressed need to be unmarshalled + private void testUnmarshalException(boolean compressed) throws Exception { ActiveMQConnection connection = mock(ActiveMQConnection.class); + when(connection.isUseCompression()).thenReturn(compressed); ActiveMQMapMessage msg = new ActiveMQMapMessage(); msg.setConnection(connection); @@ -489,6 +500,7 @@ public void testUnmarshalException() throws Exception { // store and marshal msg.storeContentAndClear(); assertTrue(msg.map.isEmpty()); + assertEquals(compressed, msg.isCompressed()); // corrupt the buffer ByteSequenceData.writeIntBig(msg.content, 1000); @@ -503,5 +515,4 @@ public void testUnmarshalException() throws Exception { ExceptionUtils.getRootCause(e) instanceof ActiveMQUnmarshalEOFException); } } - } diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQObjectMessageTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQObjectMessageTest.java index c03fa938c2b..3c3cfc7106b 100644 --- a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQObjectMessageTest.java +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQObjectMessageTest.java @@ -18,6 +18,7 @@ import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.IOException; @@ -29,6 +30,7 @@ import org.apache.activemq.ActiveMQConnection; import org.apache.activemq.test.annotations.ParallelTest; import org.apache.activemq.util.ByteSequenceData; +import org.apache.activemq.util.MarshallingSupport.ActiveMQUnmarshalEOFException; import org.apache.commons.lang.exception.ExceptionUtils; import org.junit.experimental.categories.Category; @@ -132,8 +134,17 @@ public void testWriteOnlyBody() throws JMSException { // should always be readab } } - public void testUnCompressedException() throws Exception { + public void testUnCompressedUnmarshalException() throws Exception { + testUnmarshalException(false); + } + + public void testCompressedUnmarshalException() throws Exception { + testUnmarshalException(true); + } + + private void testUnmarshalException(boolean compressed) throws Exception { ActiveMQConnection connection = mock(ActiveMQConnection.class); + when(connection.isUseCompression()).thenReturn(compressed); ActiveMQObjectMessage msg = new ActiveMQObjectMessage(); msg.setConnection(connection); @@ -142,6 +153,7 @@ public void testUnCompressedException() throws Exception { // store and marshal msg.storeContentAndClear(); assertNull(msg.object); + assertEquals(compressed, msg.isCompressed()); // corrupt the buffer ByteSequenceData.writeIntBig(msg.content, 1000); @@ -153,6 +165,12 @@ public void testUnCompressedException() throws Exception { } catch (JMSException e) { // uncompressed will have an error from the JDK deserialization assertTrue(ExceptionUtils.getRootCause(e) instanceof IOException); + + // our validation causes BufferUnmarshalException for a compressed stream + if (compressed) { + // expected + assertTrue(ExceptionUtils.getRootCause(e) instanceof ActiveMQUnmarshalEOFException); + } } } diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQStreamMessageTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQStreamMessageTest.java index edac458fe03..d30cbf84676 100644 --- a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQStreamMessageTest.java +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQStreamMessageTest.java @@ -18,10 +18,12 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import jakarta.jms.JMSException; import jakarta.jms.MessageEOFException; @@ -32,8 +34,7 @@ import org.apache.activemq.ActiveMQConnection; import org.apache.activemq.test.annotations.ParallelTest; import org.apache.activemq.util.ByteSequenceData; -import org.apache.activemq.util.MarshallingSupport.ActiveMQUnmarshalEOFException; -import org.apache.commons.lang.exception.ExceptionUtils; +import org.apache.activemq.util.ExceptionUtils; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -1084,7 +1085,17 @@ public void testReadMixBufferValuesFromStream() throws JMSException { @Test public void testUnmarshalException() throws Exception { + testUnmarshalException(false); + } + + @Test + public void testCompressedUnmarshalException() throws Exception { + testUnmarshalException(true); + } + + private void testUnmarshalException(boolean compressed) throws Exception { ActiveMQConnection connection = mock(ActiveMQConnection.class); + when(connection.isUseCompression()).thenReturn(compressed); ActiveMQStreamMessage msg = new ActiveMQStreamMessage(); msg.setConnection(connection); @@ -1093,6 +1104,7 @@ public void testUnmarshalException() throws Exception { // store and marshal msg.reset(); assertNull(msg.dataOut); + assertEquals(compressed, msg.isCompressed()); // corrupt the buffer ByteSequenceData.writeIntBig(msg.content, 1000000); @@ -1101,8 +1113,8 @@ public void testUnmarshalException() throws Exception { msg.readBytes(new byte[1024]); fail("Should have thrown exception"); } catch (JMSException e) { - // expected - assertTrue(e instanceof MessageFormatException); + // if this is not null then there was an expected format exception + assertNotNull(ExceptionUtils.createMessageFormatException(e)); } } } diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQTextMessageTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQTextMessageTest.java index 2ea8695d1a8..1ea1b0553ce 100644 --- a/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQTextMessageTest.java +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/command/ActiveMQTextMessageTest.java @@ -17,6 +17,7 @@ package org.apache.activemq.command; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.beans.Transient; import java.io.DataOutputStream; @@ -165,8 +166,19 @@ public void testTransient() throws Exception { assertTrue(method.isAnnotationPresent(Transient.class)); } - public void testUnUnmarshalException() throws Exception { + + public void testUnCompressedUnmarshalException() throws Exception { + testUnmarshalException(false); + } + + public void testCompressedUnmarshalException() throws Exception { + testUnmarshalException(true); + } + + // For text messages both compressed and uncompressed need to be unmarshalled + private void testUnmarshalException(boolean compressed) throws Exception { ActiveMQConnection connection = mock(ActiveMQConnection.class); + when(connection.isUseCompression()).thenReturn(compressed); ActiveMQTextMessage msg = new ActiveMQTextMessage(); msg.setConnection(connection); @@ -175,6 +187,7 @@ public void testUnUnmarshalException() throws Exception { // store and marshal msg.storeContentAndClear(); assertNull(msg.text); + assertEquals(compressed, msg.isCompressed()); // corrupt the buffer ByteSequenceData.writeIntBig(msg.content, 1000); diff --git a/activemq-unit-tests/src/test/java/org/apache/activemq/command/MessageCompressionTest.java b/activemq-unit-tests/src/test/java/org/apache/activemq/command/MessageCompressionTest.java index d796329edf7..9eea8fd0420 100644 --- a/activemq-unit-tests/src/test/java/org/apache/activemq/command/MessageCompressionTest.java +++ b/activemq-unit-tests/src/test/java/org/apache/activemq/command/MessageCompressionTest.java @@ -16,6 +16,12 @@ */ package org.apache.activemq.command; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + import java.io.UnsupportedEncodingException; import jakarta.jms.BytesMessage; @@ -24,16 +30,30 @@ import jakarta.jms.MessageProducer; import jakarta.jms.Session; -import junit.framework.TestCase; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.activemq.ActiveMQConnection; import org.apache.activemq.ActiveMQConnectionFactory; +import org.apache.activemq.broker.Broker; +import org.apache.activemq.broker.BrokerFilter; +import org.apache.activemq.broker.BrokerPlugin; +import org.apache.activemq.broker.BrokerPluginSupport; import org.apache.activemq.broker.BrokerService; +import org.apache.activemq.broker.ConnectionContext; +import org.apache.activemq.broker.region.MessageReference; +import org.apache.activemq.broker.region.Subscription; import org.apache.activemq.test.annotations.ParallelTest; +import org.apache.activemq.util.MarshallingSupport.ActiveMQUnmarshalEOFException; +import org.apache.activemq.util.MarshallingSupport.MaxInflatedDataSizeExceededException; +import org.apache.activemq.util.Wait; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; import org.junit.experimental.categories.Category; @Category(ParallelTest.class) -public class MessageCompressionTest extends TestCase { +public class MessageCompressionTest { private static final String BROKER_URL = "tcp://localhost:0"; // The following text should compress well @@ -50,20 +70,58 @@ public class MessageCompressionTest extends TestCase { private BrokerService broker; private ActiveMQQueue queue; private String connectionUri; + private final AtomicBoolean throwMaxInflatedException = new AtomicBoolean(false); + private final AtomicBoolean sentToDlq = new AtomicBoolean(false); + + @Before + public void setUp() throws Exception { + throwMaxInflatedException.set(false); + sentToDlq.set(false); - protected void setUp() throws Exception { broker = new BrokerService(); + broker.setPlugins(new BrokerPlugin[]{new BrokerPluginSupport() { + @Override + public Broker installPlugin(Broker broker) { + return new BrokerFilter(broker) { + @Override + public void preProcessDispatch(MessageDispatch messageDispatch) { + super.preProcessDispatch(messageDispatch); + // simulate a max inflated data size exception during protocol + // conversion and decompression + if (throwMaxInflatedException.get()) { + try { + throw new MaxInflatedDataSizeExceededException("Test"); + } catch (ActiveMQUnmarshalEOFException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public boolean sendToDeadLetterQueue(ConnectionContext context, + MessageReference messageReference, Subscription subscription, + Throwable poisonCause) { + sentToDlq.set(true); + return super.sendToDeadLetterQueue(context, messageReference, + subscription, poisonCause); + } + }; + } + }}); + connectionUri = broker.addConnector(BROKER_URL).getPublishableConnectString(); broker.start(); queue = new ActiveMQQueue("TEST." + System.currentTimeMillis()); } - protected void tearDown() throws Exception { + @After + public void tearDown() throws Exception { if (broker != null) { broker.stop(); } } + @Test public void testTextMessageCompression() throws Exception { ActiveMQConnectionFactory factory = new ActiveMQConnectionFactory(connectionUri); @@ -82,6 +140,7 @@ public void testTextMessageCompression() throws Exception { compressedSize < unCompressedSize); } + @Test public void testBytesMessageCompression() throws Exception { ActiveMQConnectionFactory factory = new ActiveMQConnectionFactory(connectionUri); @@ -89,9 +148,9 @@ public void testBytesMessageCompression() throws Exception { sendTestBytesMessage(factory, TEXT); ActiveMQBytesMessage message = receiveTestBytesMessage(factory); int compressedSize = message.getContent().getLength(); - byte[] bytes = new byte[TEXT.getBytes("UTF8").length]; + byte[] bytes = new byte[TEXT.getBytes(StandardCharsets.UTF_8).length]; message.readBytes(bytes); - assertTrue(message.readBytes(new byte[255]) == -1); + assertEquals(-1, message.readBytes(new byte[255])); String rcvString = new String(bytes, "UTF8"); assertEquals(TEXT, rcvString); assertTrue(message.isCompressed()); @@ -103,7 +162,55 @@ public void testBytesMessageCompression() throws Exception { int unCompressedSize = message.getContent().getLength(); assertTrue("expected: compressed Size '" + compressedSize + "' < unCompressedSize '" + unCompressedSize + "'", - compressedSize < unCompressedSize); + compressedSize < unCompressedSize); + } + + // Test that an error during dispatch goes to the DLQ + @Test + public void testMaxInflatedSizeDlq() throws Exception { + + ActiveMQConnectionFactory factory = new ActiveMQConnectionFactory(connectionUri); + factory.setUseCompression(true); + ActiveMQConnection con1 = (ActiveMQConnection) factory.createConnection(); + con1.start(); + + Session session1 = con1.createSession(false, Session.AUTO_ACKNOWLEDGE); + MessageProducer producer = session1.createProducer(queue); + ActiveMQBytesMessage bytesMessage = (ActiveMQBytesMessage) session1.createBytesMessage(); + bytesMessage.writeBytes(TEXT.getBytes(StandardCharsets.UTF_8)); + producer.send(bytesMessage); + + assertTrue(Wait.waitFor(() -> broker.getDestination(queue) + .getDestinationStatistics().getMessages().getCount() == 1, 500, 10)); + assertFalse(sentToDlq.get()); + + // simulate a decompression error + // this should poison ack and DLQ and we shouldn't get the message + // but the connection should still be open + this.throwMaxInflatedException.set(true); + + ActiveMQConnection con2 = (ActiveMQConnection) factory.createConnection(); + con2.start(); + Session session2 = con1.createSession(false, Session.AUTO_ACKNOWLEDGE); + MessageConsumer consumer = session2.createConsumer(queue); + assertNull(consumer.receive(1000)); + + // verify message is gone off the dest and went to the DLQ + assertTrue(Wait.waitFor(() -> broker.getDestination(queue) + .getDestinationStatistics().getMessages().getCount() == 0, 500, 10)); + assertTrue(sentToDlq.get()); + + // no longer throw an exception + this.throwMaxInflatedException.set(false); + sentToDlq.set(false); + + // exception has been disabled so we should receive again on the same connection + producer.send(bytesMessage); + assertNotNull(consumer.receive(1000)); + assertFalse(sentToDlq.get()); + + con1.close(); + con2.close(); } private void sendTestMessage(ActiveMQConnectionFactory factory, String message) throws JMSException {