Dynamic request capping system for web app plus minor optimization

This commit is contained in:
Sean Owen 2020-05-13 14:11:54 -05:00
parent 639f51d340
commit 78ccf2323f
8 changed files with 136 additions and 80 deletions

View file

@ -25,7 +25,8 @@ import javax.servlet.annotation.WebInitParam;
@WebFilter(urlPatterns = {"/w/chart"}, initParams = {
@WebInitParam(name = "maxAccessPerTime", value = "250"),
@WebInitParam(name = "accessTimeSec", value = "500"),
@WebInitParam(name = "maxEntries", value = "10000")
@WebInitParam(name = "maxEntries", value = "10000"),
@WebInitParam(name = "maxLoad", value = "0.9")
})
public final class ChartDoSFilter extends DoSFilter {
// no additional implementation

View file

@ -25,7 +25,8 @@ import javax.servlet.annotation.WebInitParam;
@WebFilter(urlPatterns = {"/w/decode"}, initParams = {
@WebInitParam(name = "maxAccessPerTime", value = "60"),
@WebInitParam(name = "accessTimeSec", value = "180"),
@WebInitParam(name = "maxEntries", value = "10000")
@WebInitParam(name = "maxEntries", value = "10000"),
@WebInitParam(name = "maxLoad", value = "0.9")
})
public final class DecodeDoSFilter extends DoSFilter {
// no additional implementation

View file

@ -103,7 +103,6 @@ public final class DecodeServlet extends HttpServlet {
private static final long MAX_IMAGE_SIZE = 1L << 26;
// No real reason to deal with more than ~32 megapixels
private static final int MAX_PIXELS = 1 << 25;
private static final byte[] REMAINDER_BUFFER = new byte[1 << 16];
private static final Map<DecodeHintType,Object> HINTS;
private static final Map<DecodeHintType,Object> HINTS_PURE;
@ -144,7 +143,7 @@ public final class DecodeServlet extends HttpServlet {
String name = getClass().getSimpleName();
timer = new Timer(name);
destHostTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries);
destHostTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries, null);
// Hack to try to avoid odd OOM due to memory leak in JAI?
timer.scheduleAtFixedRate(
new TimerTask() {
@ -198,7 +197,7 @@ public final class DecodeServlet extends HttpServlet {
errorResponse(request, response, "badurl");
return;
}
// Shortcut for data URI
if ("data".equals(imageURI.getScheme())) {
BufferedImage image;
@ -226,8 +225,8 @@ public final class DecodeServlet extends HttpServlet {
errorResponse(request, response, "badurl");
return;
}
URL imageURL;
URL imageURL;
try {
imageURL = imageURI.toURL();
} catch (MalformedURLException ignored) {
@ -273,30 +272,26 @@ public final class DecodeServlet extends HttpServlet {
}
try (InputStream is = connection.getInputStream()) {
try {
if (connection.getResponseCode() != HttpServletResponse.SC_OK) {
log.info("Unsuccessful return code " + connection.getResponseCode() + " from " + imageURIString);
errorResponse(request, response, "badurl");
return;
}
if (connection.getHeaderFieldInt(HttpHeaders.CONTENT_LENGTH, 0) > MAX_IMAGE_SIZE) {
log.info("Too large: " + imageURIString);
errorResponse(request, response, "badimage");
return;
}
// Assume we'll only handle image/* content types
String contentType = connection.getContentType();
if (contentType != null && !contentType.startsWith("image/")) {
log.info("Wrong content type " + contentType + ": " + imageURIString);
errorResponse(request, response, "badimage");
return;
}
log.info("Decoding " + imageURIString);
processStream(is, request, response);
} finally {
consumeRemainder(is);
if (connection.getResponseCode() != HttpServletResponse.SC_OK) {
log.info("Unsuccessful return code " + connection.getResponseCode() + " from " + imageURIString);
errorResponse(request, response, "badurl");
return;
}
if (connection.getHeaderFieldInt(HttpHeaders.CONTENT_LENGTH, 0) > MAX_IMAGE_SIZE) {
log.info("Too large: " + imageURIString);
errorResponse(request, response, "badimage");
return;
}
// Assume we'll only handle image/* content types
String contentType = connection.getContentType();
if (contentType != null && !contentType.startsWith("image/")) {
log.info("Wrong content type " + contentType + ": " + imageURIString);
errorResponse(request, response, "badimage");
return;
}
log.info("Decoding " + imageURIString);
processStream(is, request, response);
} catch (IOException ioe) {
log.info("Error " + ioe + " processing " + imageURIString);
errorResponse(request, response, "badurl");
@ -306,17 +301,6 @@ public final class DecodeServlet extends HttpServlet {
}
private static void consumeRemainder(InputStream is) {
try {
while (is.read(REMAINDER_BUFFER) > 0) {
// don't care about value, or collision
}
} catch (IOException | IndexOutOfBoundsException ioe) {
// sun.net.www.http.ChunkedInputStream.read is throwing IndexOutOfBoundsException
// continue
}
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
@ -378,7 +362,7 @@ public final class DecodeServlet extends HttpServlet {
image.flush();
}
}
private static void processImage(BufferedImage image,
HttpServletRequest request,
HttpServletResponse response) throws IOException, ServletException {
@ -401,7 +385,7 @@ public final class DecodeServlet extends HttpServlet {
} catch (ReaderException re) {
savedException = re;
}
if (results.isEmpty()) {
try {
// Look for pure barcode
@ -413,7 +397,7 @@ public final class DecodeServlet extends HttpServlet {
savedException = re;
}
}
if (results.isEmpty()) {
try {
// Look for normal barcode in photo
@ -425,7 +409,7 @@ public final class DecodeServlet extends HttpServlet {
savedException = re;
}
}
if (results.isEmpty()) {
try {
// Try again with other binarizer
@ -438,7 +422,7 @@ public final class DecodeServlet extends HttpServlet {
savedException = re;
}
}
if (results.isEmpty()) {
try {
throw savedException == null ? NotFoundException.getNotFoundInstance() : savedException;

View file

@ -45,15 +45,28 @@ public abstract class DoSFilter implements Filter {
public void init(FilterConfig filterConfig) {
int maxAccessPerTime = Integer.parseInt(filterConfig.getInitParameter("maxAccessPerTime"));
Preconditions.checkArgument(maxAccessPerTime > 0);
int accessTimeSec = Integer.parseInt(filterConfig.getInitParameter("accessTimeSec"));
Preconditions.checkArgument(accessTimeSec > 0);
long accessTimeMS = TimeUnit.MILLISECONDS.convert(accessTimeSec, TimeUnit.SECONDS);
int maxEntries = Integer.parseInt(filterConfig.getInitParameter("maxEntries"));
Preconditions.checkArgument(maxEntries > 0);
String maxEntriesValue = filterConfig.getInitParameter("maxEntries");
int maxEntries = Integer.MAX_VALUE;
if (maxEntriesValue != null) {
maxEntries = Integer.parseInt(maxEntriesValue);
Preconditions.checkArgument(maxEntries > 0);
}
String maxLoadValue = filterConfig.getInitParameter("maxLoad");
Double maxLoad = null;
if (maxLoadValue != null) {
maxLoad = Double.valueOf(maxLoadValue);
Preconditions.checkArgument(maxLoad > 0.0);
}
String name = getClass().getSimpleName();
timer = new Timer(name);
sourceAddrTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries);
sourceAddrTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries, maxLoad);
}
@Override

View file

@ -16,11 +16,13 @@
package com.google.zxing.web;
import java.lang.management.ManagementFactory;
import java.lang.management.OperatingSystemMXBean;
import java.util.Iterator;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
/**
@ -32,49 +34,101 @@ final class DoSTracker {
private static final Logger log = Logger.getLogger(DoSTracker.class.getName());
private final long maxAccessesPerTime;
private final Map<String,AtomicLong> numRecentAccesses;
private volatile int maxAccessesPerTime;
private final Map<String,AtomicInteger> numRecentAccesses;
DoSTracker(Timer timer, final String name, final int maxAccessesPerTime, long accessTimeMS, int maxEntries) {
/**
* @param timer {@link Timer} to use for scheduling update tasks
* @param name identifier for this tracker
* @param maxAccessesPerTime maximum number of accesses allowed from one source per {@code accessTimeMS}
* @param accessTimeMS interval in milliseconds over which up to {@code maxAccessesPerTime} accesses are allowed
* @param maxEntries maximum number of source entries to track before forgetting least recent ones
* @param maxLoad if set, dynamically adjust {@code maxAccessesPerTime} downwards when average load per core
* exceeds this value, and upwards when below this value
*/
DoSTracker(Timer timer,
final String name,
final int maxAccessesPerTime,
long accessTimeMS,
int maxEntries,
Double maxLoad) {
this.maxAccessesPerTime = maxAccessesPerTime;
this.numRecentAccesses = new LRUMap<>(maxEntries);
timer.schedule(new TimerTask() {
@Override
public void run() {
synchronized (numRecentAccesses) {
Iterator<Map.Entry<String,AtomicLong>> accessIt = numRecentAccesses.entrySet().iterator();
while (accessIt.hasNext()) {
Map.Entry<String,AtomicLong> entry = accessIt.next();
AtomicLong count = entry.getValue();
// If number of accesses is below the threshold, remove it entirely
if (count.get() <= maxAccessesPerTime) {
accessIt.remove();
} else {
// Else it exceeded the max, so log it (again)
log.warning(name + ": Blocking " + entry.getKey() + " (" + count + " outstanding)");
// Reduce count of accesses held against the host
count.getAndAdd(-maxAccessesPerTime);
}
}
}
}
}, accessTimeMS, accessTimeMS);
timer.schedule(new TrackerTask(name, maxLoad), accessTimeMS, accessTimeMS);
}
boolean isBanned(String event) {
if (event == null) {
return true;
}
AtomicLong count;
AtomicInteger count;
synchronized (numRecentAccesses) {
count = numRecentAccesses.get(event);
if (count == null) {
numRecentAccesses.put(event, new AtomicLong(1));
numRecentAccesses.put(event, new AtomicInteger(1));
return false;
}
}
return count.incrementAndGet() > maxAccessesPerTime;
}
private final class TrackerTask extends TimerTask {
private final String name;
private final Double maxLoad;
private TrackerTask(String name, Double maxLoad) {
this.name = name;
this.maxLoad = maxLoad;
}
@Override
public void run() {
// largest count <= maxAccessesPerTime
int maxAllowedCount = 1;
// smallest count > maxAccessesPerTime
int minDisallowedCount = Integer.MAX_VALUE;
int localMAPT = maxAccessesPerTime;
synchronized (numRecentAccesses) {
Iterator<Map.Entry<String,AtomicInteger>> accessIt = numRecentAccesses.entrySet().iterator();
while (accessIt.hasNext()) {
Map.Entry<String,AtomicInteger> entry = accessIt.next();
AtomicInteger atomicCount = entry.getValue();
int count = atomicCount.get();
// If number of accesses is below the threshold, remove it entirely
if (count <= localMAPT) {
accessIt.remove();
maxAllowedCount = Math.max(maxAllowedCount, count);
} else {
// Else it exceeded the max, so log it (again)
log.warning(name + ": Blocking " + entry.getKey() + " (" + count + " outstanding)");
// Reduce count of accesses held against the host
atomicCount.getAndAdd(-localMAPT);
minDisallowedCount = Math.min(minDisallowedCount, count);
}
}
}
if (maxLoad != null) {
OperatingSystemMXBean mxBean = ManagementFactory.getOperatingSystemMXBean();
if (mxBean == null) {
log.warning("Could not obtain OperatingSystemMXBean; ignoring load");
} else {
double loadAvg = mxBean.getSystemLoadAverage();
if (loadAvg >= 0.0) {
int cores = mxBean.getAvailableProcessors();
double loadRatio = loadAvg / cores;
log.info(name + ": Load ratio: " + loadRatio + " (" + loadAvg + '/' + cores + ") vs " + maxLoad);
if (loadRatio > maxLoad) {
maxAccessesPerTime = Math.min(maxAllowedCount, maxAccessesPerTime);
} else {
maxAccessesPerTime = Math.max(minDisallowedCount, maxAccessesPerTime);
}
log.info(name + ": New maxAccessesPerTime: " + maxAccessesPerTime);
}
}
}
}
}
}

View file

@ -31,7 +31,7 @@ public final class WelcomeFilter extends AbstractFilter {
public void doFilter(ServletRequest servletRequest,
ServletResponse servletResponse,
FilterChain filterChain) {
redirect(servletResponse, "/w/decode.jspx");
redirect(servletResponse, "https://" + servletRequest.getServerName() + "/w/decode.jspx");
}
}

View file

@ -31,7 +31,7 @@ public final class DoSTrackerTestCase extends Assert {
Timer timer = new Timer();
long timerTimeMS = 500;
int maxAccessPerTime = 2;
DoSTracker tracker = new DoSTracker(timer, "test", maxAccessPerTime, timerTimeMS, 3);
DoSTracker tracker = new DoSTracker(timer, "test", maxAccessPerTime, timerTimeMS, 3, null);
// 2 requests allowed per time; 3rd should be banned
assertFalse(tracker.isBanned("A"));

View file

@ -39,7 +39,10 @@ public final class WelcomeFilterTestCase extends Assert {
FilterChain chain = new MockFilterChain();
new WelcomeFilter().doFilter(request, response, chain);
assertEquals(HttpServletResponse.SC_MOVED_PERMANENTLY, response.getStatus());
assertEquals("/w/decode.jspx", response.getHeader(HttpHeaders.LOCATION));
String location = response.getHeader(HttpHeaders.LOCATION);
assertNotNull(location);
assertTrue(location.startsWith("https://"));
assertTrue(location.endsWith("/w/decode.jspx"));
}
}