mirror of
https://github.com/zxing/zxing.git
synced 2024-11-09 20:44:03 -08:00
Dynamic request capping system for web app plus minor optimization
This commit is contained in:
parent
639f51d340
commit
78ccf2323f
|
@ -25,7 +25,8 @@ import javax.servlet.annotation.WebInitParam;
|
|||
@WebFilter(urlPatterns = {"/w/chart"}, initParams = {
|
||||
@WebInitParam(name = "maxAccessPerTime", value = "250"),
|
||||
@WebInitParam(name = "accessTimeSec", value = "500"),
|
||||
@WebInitParam(name = "maxEntries", value = "10000")
|
||||
@WebInitParam(name = "maxEntries", value = "10000"),
|
||||
@WebInitParam(name = "maxLoad", value = "0.9")
|
||||
})
|
||||
public final class ChartDoSFilter extends DoSFilter {
|
||||
// no additional implementation
|
||||
|
|
|
@ -25,7 +25,8 @@ import javax.servlet.annotation.WebInitParam;
|
|||
@WebFilter(urlPatterns = {"/w/decode"}, initParams = {
|
||||
@WebInitParam(name = "maxAccessPerTime", value = "60"),
|
||||
@WebInitParam(name = "accessTimeSec", value = "180"),
|
||||
@WebInitParam(name = "maxEntries", value = "10000")
|
||||
@WebInitParam(name = "maxEntries", value = "10000"),
|
||||
@WebInitParam(name = "maxLoad", value = "0.9")
|
||||
})
|
||||
public final class DecodeDoSFilter extends DoSFilter {
|
||||
// no additional implementation
|
||||
|
|
|
@ -103,7 +103,6 @@ public final class DecodeServlet extends HttpServlet {
|
|||
private static final long MAX_IMAGE_SIZE = 1L << 26;
|
||||
// No real reason to deal with more than ~32 megapixels
|
||||
private static final int MAX_PIXELS = 1 << 25;
|
||||
private static final byte[] REMAINDER_BUFFER = new byte[1 << 16];
|
||||
private static final Map<DecodeHintType,Object> HINTS;
|
||||
private static final Map<DecodeHintType,Object> HINTS_PURE;
|
||||
|
||||
|
@ -144,7 +143,7 @@ public final class DecodeServlet extends HttpServlet {
|
|||
|
||||
String name = getClass().getSimpleName();
|
||||
timer = new Timer(name);
|
||||
destHostTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries);
|
||||
destHostTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries, null);
|
||||
// Hack to try to avoid odd OOM due to memory leak in JAI?
|
||||
timer.scheduleAtFixedRate(
|
||||
new TimerTask() {
|
||||
|
@ -198,7 +197,7 @@ public final class DecodeServlet extends HttpServlet {
|
|||
errorResponse(request, response, "badurl");
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Shortcut for data URI
|
||||
if ("data".equals(imageURI.getScheme())) {
|
||||
BufferedImage image;
|
||||
|
@ -226,8 +225,8 @@ public final class DecodeServlet extends HttpServlet {
|
|||
errorResponse(request, response, "badurl");
|
||||
return;
|
||||
}
|
||||
|
||||
URL imageURL;
|
||||
|
||||
URL imageURL;
|
||||
try {
|
||||
imageURL = imageURI.toURL();
|
||||
} catch (MalformedURLException ignored) {
|
||||
|
@ -273,30 +272,26 @@ public final class DecodeServlet extends HttpServlet {
|
|||
}
|
||||
|
||||
try (InputStream is = connection.getInputStream()) {
|
||||
try {
|
||||
if (connection.getResponseCode() != HttpServletResponse.SC_OK) {
|
||||
log.info("Unsuccessful return code " + connection.getResponseCode() + " from " + imageURIString);
|
||||
errorResponse(request, response, "badurl");
|
||||
return;
|
||||
}
|
||||
if (connection.getHeaderFieldInt(HttpHeaders.CONTENT_LENGTH, 0) > MAX_IMAGE_SIZE) {
|
||||
log.info("Too large: " + imageURIString);
|
||||
errorResponse(request, response, "badimage");
|
||||
return;
|
||||
}
|
||||
// Assume we'll only handle image/* content types
|
||||
String contentType = connection.getContentType();
|
||||
if (contentType != null && !contentType.startsWith("image/")) {
|
||||
log.info("Wrong content type " + contentType + ": " + imageURIString);
|
||||
errorResponse(request, response, "badimage");
|
||||
return;
|
||||
}
|
||||
|
||||
log.info("Decoding " + imageURIString);
|
||||
processStream(is, request, response);
|
||||
} finally {
|
||||
consumeRemainder(is);
|
||||
if (connection.getResponseCode() != HttpServletResponse.SC_OK) {
|
||||
log.info("Unsuccessful return code " + connection.getResponseCode() + " from " + imageURIString);
|
||||
errorResponse(request, response, "badurl");
|
||||
return;
|
||||
}
|
||||
if (connection.getHeaderFieldInt(HttpHeaders.CONTENT_LENGTH, 0) > MAX_IMAGE_SIZE) {
|
||||
log.info("Too large: " + imageURIString);
|
||||
errorResponse(request, response, "badimage");
|
||||
return;
|
||||
}
|
||||
// Assume we'll only handle image/* content types
|
||||
String contentType = connection.getContentType();
|
||||
if (contentType != null && !contentType.startsWith("image/")) {
|
||||
log.info("Wrong content type " + contentType + ": " + imageURIString);
|
||||
errorResponse(request, response, "badimage");
|
||||
return;
|
||||
}
|
||||
|
||||
log.info("Decoding " + imageURIString);
|
||||
processStream(is, request, response);
|
||||
} catch (IOException ioe) {
|
||||
log.info("Error " + ioe + " processing " + imageURIString);
|
||||
errorResponse(request, response, "badurl");
|
||||
|
@ -306,17 +301,6 @@ public final class DecodeServlet extends HttpServlet {
|
|||
|
||||
}
|
||||
|
||||
private static void consumeRemainder(InputStream is) {
|
||||
try {
|
||||
while (is.read(REMAINDER_BUFFER) > 0) {
|
||||
// don't care about value, or collision
|
||||
}
|
||||
} catch (IOException | IndexOutOfBoundsException ioe) {
|
||||
// sun.net.www.http.ChunkedInputStream.read is throwing IndexOutOfBoundsException
|
||||
// continue
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response)
|
||||
throws ServletException, IOException {
|
||||
|
@ -378,7 +362,7 @@ public final class DecodeServlet extends HttpServlet {
|
|||
image.flush();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static void processImage(BufferedImage image,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) throws IOException, ServletException {
|
||||
|
@ -401,7 +385,7 @@ public final class DecodeServlet extends HttpServlet {
|
|||
} catch (ReaderException re) {
|
||||
savedException = re;
|
||||
}
|
||||
|
||||
|
||||
if (results.isEmpty()) {
|
||||
try {
|
||||
// Look for pure barcode
|
||||
|
@ -413,7 +397,7 @@ public final class DecodeServlet extends HttpServlet {
|
|||
savedException = re;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (results.isEmpty()) {
|
||||
try {
|
||||
// Look for normal barcode in photo
|
||||
|
@ -425,7 +409,7 @@ public final class DecodeServlet extends HttpServlet {
|
|||
savedException = re;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (results.isEmpty()) {
|
||||
try {
|
||||
// Try again with other binarizer
|
||||
|
@ -438,7 +422,7 @@ public final class DecodeServlet extends HttpServlet {
|
|||
savedException = re;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (results.isEmpty()) {
|
||||
try {
|
||||
throw savedException == null ? NotFoundException.getNotFoundInstance() : savedException;
|
||||
|
|
|
@ -45,15 +45,28 @@ public abstract class DoSFilter implements Filter {
|
|||
public void init(FilterConfig filterConfig) {
|
||||
int maxAccessPerTime = Integer.parseInt(filterConfig.getInitParameter("maxAccessPerTime"));
|
||||
Preconditions.checkArgument(maxAccessPerTime > 0);
|
||||
|
||||
int accessTimeSec = Integer.parseInt(filterConfig.getInitParameter("accessTimeSec"));
|
||||
Preconditions.checkArgument(accessTimeSec > 0);
|
||||
long accessTimeMS = TimeUnit.MILLISECONDS.convert(accessTimeSec, TimeUnit.SECONDS);
|
||||
int maxEntries = Integer.parseInt(filterConfig.getInitParameter("maxEntries"));
|
||||
Preconditions.checkArgument(maxEntries > 0);
|
||||
|
||||
String maxEntriesValue = filterConfig.getInitParameter("maxEntries");
|
||||
int maxEntries = Integer.MAX_VALUE;
|
||||
if (maxEntriesValue != null) {
|
||||
maxEntries = Integer.parseInt(maxEntriesValue);
|
||||
Preconditions.checkArgument(maxEntries > 0);
|
||||
}
|
||||
|
||||
String maxLoadValue = filterConfig.getInitParameter("maxLoad");
|
||||
Double maxLoad = null;
|
||||
if (maxLoadValue != null) {
|
||||
maxLoad = Double.valueOf(maxLoadValue);
|
||||
Preconditions.checkArgument(maxLoad > 0.0);
|
||||
}
|
||||
|
||||
String name = getClass().getSimpleName();
|
||||
timer = new Timer(name);
|
||||
sourceAddrTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries);
|
||||
sourceAddrTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries, maxLoad);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,11 +16,13 @@
|
|||
|
||||
package com.google.zxing.web;
|
||||
|
||||
import java.lang.management.ManagementFactory;
|
||||
import java.lang.management.OperatingSystemMXBean;
|
||||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
import java.util.Timer;
|
||||
import java.util.TimerTask;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
|
@ -32,49 +34,101 @@ final class DoSTracker {
|
|||
|
||||
private static final Logger log = Logger.getLogger(DoSTracker.class.getName());
|
||||
|
||||
private final long maxAccessesPerTime;
|
||||
private final Map<String,AtomicLong> numRecentAccesses;
|
||||
private volatile int maxAccessesPerTime;
|
||||
private final Map<String,AtomicInteger> numRecentAccesses;
|
||||
|
||||
DoSTracker(Timer timer, final String name, final int maxAccessesPerTime, long accessTimeMS, int maxEntries) {
|
||||
/**
|
||||
* @param timer {@link Timer} to use for scheduling update tasks
|
||||
* @param name identifier for this tracker
|
||||
* @param maxAccessesPerTime maximum number of accesses allowed from one source per {@code accessTimeMS}
|
||||
* @param accessTimeMS interval in milliseconds over which up to {@code maxAccessesPerTime} accesses are allowed
|
||||
* @param maxEntries maximum number of source entries to track before forgetting least recent ones
|
||||
* @param maxLoad if set, dynamically adjust {@code maxAccessesPerTime} downwards when average load per core
|
||||
* exceeds this value, and upwards when below this value
|
||||
*/
|
||||
DoSTracker(Timer timer,
|
||||
final String name,
|
||||
final int maxAccessesPerTime,
|
||||
long accessTimeMS,
|
||||
int maxEntries,
|
||||
Double maxLoad) {
|
||||
this.maxAccessesPerTime = maxAccessesPerTime;
|
||||
this.numRecentAccesses = new LRUMap<>(maxEntries);
|
||||
timer.schedule(new TimerTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
synchronized (numRecentAccesses) {
|
||||
Iterator<Map.Entry<String,AtomicLong>> accessIt = numRecentAccesses.entrySet().iterator();
|
||||
while (accessIt.hasNext()) {
|
||||
Map.Entry<String,AtomicLong> entry = accessIt.next();
|
||||
AtomicLong count = entry.getValue();
|
||||
// If number of accesses is below the threshold, remove it entirely
|
||||
if (count.get() <= maxAccessesPerTime) {
|
||||
accessIt.remove();
|
||||
} else {
|
||||
// Else it exceeded the max, so log it (again)
|
||||
log.warning(name + ": Blocking " + entry.getKey() + " (" + count + " outstanding)");
|
||||
// Reduce count of accesses held against the host
|
||||
count.getAndAdd(-maxAccessesPerTime);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}, accessTimeMS, accessTimeMS);
|
||||
|
||||
timer.schedule(new TrackerTask(name, maxLoad), accessTimeMS, accessTimeMS);
|
||||
}
|
||||
|
||||
boolean isBanned(String event) {
|
||||
if (event == null) {
|
||||
return true;
|
||||
}
|
||||
AtomicLong count;
|
||||
AtomicInteger count;
|
||||
synchronized (numRecentAccesses) {
|
||||
count = numRecentAccesses.get(event);
|
||||
if (count == null) {
|
||||
numRecentAccesses.put(event, new AtomicLong(1));
|
||||
numRecentAccesses.put(event, new AtomicInteger(1));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return count.incrementAndGet() > maxAccessesPerTime;
|
||||
}
|
||||
|
||||
private final class TrackerTask extends TimerTask {
|
||||
|
||||
private final String name;
|
||||
private final Double maxLoad;
|
||||
|
||||
private TrackerTask(String name, Double maxLoad) {
|
||||
this.name = name;
|
||||
this.maxLoad = maxLoad;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
// largest count <= maxAccessesPerTime
|
||||
int maxAllowedCount = 1;
|
||||
// smallest count > maxAccessesPerTime
|
||||
int minDisallowedCount = Integer.MAX_VALUE;
|
||||
int localMAPT = maxAccessesPerTime;
|
||||
synchronized (numRecentAccesses) {
|
||||
Iterator<Map.Entry<String,AtomicInteger>> accessIt = numRecentAccesses.entrySet().iterator();
|
||||
while (accessIt.hasNext()) {
|
||||
Map.Entry<String,AtomicInteger> entry = accessIt.next();
|
||||
AtomicInteger atomicCount = entry.getValue();
|
||||
int count = atomicCount.get();
|
||||
// If number of accesses is below the threshold, remove it entirely
|
||||
if (count <= localMAPT) {
|
||||
accessIt.remove();
|
||||
maxAllowedCount = Math.max(maxAllowedCount, count);
|
||||
} else {
|
||||
// Else it exceeded the max, so log it (again)
|
||||
log.warning(name + ": Blocking " + entry.getKey() + " (" + count + " outstanding)");
|
||||
// Reduce count of accesses held against the host
|
||||
atomicCount.getAndAdd(-localMAPT);
|
||||
minDisallowedCount = Math.min(minDisallowedCount, count);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (maxLoad != null) {
|
||||
OperatingSystemMXBean mxBean = ManagementFactory.getOperatingSystemMXBean();
|
||||
if (mxBean == null) {
|
||||
log.warning("Could not obtain OperatingSystemMXBean; ignoring load");
|
||||
} else {
|
||||
double loadAvg = mxBean.getSystemLoadAverage();
|
||||
if (loadAvg >= 0.0) {
|
||||
int cores = mxBean.getAvailableProcessors();
|
||||
double loadRatio = loadAvg / cores;
|
||||
log.info(name + ": Load ratio: " + loadRatio + " (" + loadAvg + '/' + cores + ") vs " + maxLoad);
|
||||
if (loadRatio > maxLoad) {
|
||||
maxAccessesPerTime = Math.min(maxAllowedCount, maxAccessesPerTime);
|
||||
} else {
|
||||
maxAccessesPerTime = Math.max(minDisallowedCount, maxAccessesPerTime);
|
||||
}
|
||||
log.info(name + ": New maxAccessesPerTime: " + maxAccessesPerTime);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ public final class WelcomeFilter extends AbstractFilter {
|
|||
public void doFilter(ServletRequest servletRequest,
|
||||
ServletResponse servletResponse,
|
||||
FilterChain filterChain) {
|
||||
redirect(servletResponse, "/w/decode.jspx");
|
||||
redirect(servletResponse, "https://" + servletRequest.getServerName() + "/w/decode.jspx");
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ public final class DoSTrackerTestCase extends Assert {
|
|||
Timer timer = new Timer();
|
||||
long timerTimeMS = 500;
|
||||
int maxAccessPerTime = 2;
|
||||
DoSTracker tracker = new DoSTracker(timer, "test", maxAccessPerTime, timerTimeMS, 3);
|
||||
DoSTracker tracker = new DoSTracker(timer, "test", maxAccessPerTime, timerTimeMS, 3, null);
|
||||
|
||||
// 2 requests allowed per time; 3rd should be banned
|
||||
assertFalse(tracker.isBanned("A"));
|
||||
|
|
|
@ -39,7 +39,10 @@ public final class WelcomeFilterTestCase extends Assert {
|
|||
FilterChain chain = new MockFilterChain();
|
||||
new WelcomeFilter().doFilter(request, response, chain);
|
||||
assertEquals(HttpServletResponse.SC_MOVED_PERMANENTLY, response.getStatus());
|
||||
assertEquals("/w/decode.jspx", response.getHeader(HttpHeaders.LOCATION));
|
||||
String location = response.getHeader(HttpHeaders.LOCATION);
|
||||
assertNotNull(location);
|
||||
assertTrue(location.startsWith("https://"));
|
||||
assertTrue(location.endsWith("/w/decode.jspx"));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue