mirror of
https://github.com/zxing/zxing.git
synced 2024-11-09 20:44:03 -08:00
More sophisticated load protection, plus tests
This commit is contained in:
parent
bc645c50bb
commit
8a53ade692
|
@ -0,0 +1,32 @@
|
|||
/*
|
||||
* Copyright 2019 ZXing authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.zxing.web;
|
||||
|
||||
import javax.servlet.annotation.WebFilter;
|
||||
import javax.servlet.annotation.WebInitParam;
|
||||
|
||||
/**
|
||||
* Protect the /chart endpoint from too many requests.
|
||||
*/
|
||||
@WebFilter(urlPatterns = {"/w/chart"}, initParams = {
|
||||
@WebInitParam(name = "maxAccessPerTime", value = "250"),
|
||||
@WebInitParam(name = "accessTimeSec", value = "500"),
|
||||
@WebInitParam(name = "maxEntries", value = "10000")
|
||||
})
|
||||
public final class ChartDoSFilter extends DoSFilter {
|
||||
// no additional implementation
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/*
|
||||
* Copyright 2019 ZXing authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.google.zxing.web;
|
||||
|
||||
import javax.servlet.annotation.WebFilter;
|
||||
import javax.servlet.annotation.WebInitParam;
|
||||
|
||||
/**
|
||||
* Protect the /decode endpoint from too many requests.
|
||||
*/
|
||||
@WebFilter(urlPatterns = {"/w/decode"}, initParams = {
|
||||
@WebInitParam(name = "maxAccessPerTime", value = "60"),
|
||||
@WebInitParam(name = "accessTimeSec", value = "180"),
|
||||
@WebInitParam(name = "maxEntries", value = "10000")
|
||||
})
|
||||
public final class DecodeDoSFilter extends DoSFilter {
|
||||
// no additional implementation
|
||||
}
|
|
@ -59,6 +59,7 @@ import java.util.Locale;
|
|||
import java.util.Map;
|
||||
import java.util.ResourceBundle;
|
||||
import java.util.Timer;
|
||||
import java.util.TimerTask;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
@ -89,8 +90,8 @@ import javax.servlet.http.Part;
|
|||
fileSizeThreshold = 1 << 23, // ~8MB
|
||||
location = "/tmp")
|
||||
@WebServlet(value = "/w/decode", loadOnStartup = 1, initParams = {
|
||||
@WebInitParam(name = "maxAccessPerTime", value = "150"),
|
||||
@WebInitParam(name = "accessTimeSec", value = "300"),
|
||||
@WebInitParam(name = "maxAccessPerTime", value = "120"),
|
||||
@WebInitParam(name = "accessTimeSec", value = "120"),
|
||||
@WebInitParam(name = "maxEntries", value = "10000")
|
||||
})
|
||||
public final class DecodeServlet extends HttpServlet {
|
||||
|
@ -141,8 +142,17 @@ public final class DecodeServlet extends HttpServlet {
|
|||
long accessTimeMS = TimeUnit.MILLISECONDS.convert(accessTimeSec, TimeUnit.SECONDS);
|
||||
int maxEntries = Integer.parseInt(servletConfig.getInitParameter("maxEntries"));
|
||||
|
||||
timer = new Timer("DecodeServlet");
|
||||
destHostTracker = new DoSTracker(timer, maxAccessPerTime, accessTimeMS, maxEntries);
|
||||
String name = getClass().getSimpleName();
|
||||
timer = new Timer(name);
|
||||
destHostTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries);
|
||||
// Hack to try to avoid odd OOM due to memory leak in JAI?
|
||||
timer.scheduleAtFixedRate(
|
||||
new TimerTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
System.gc();
|
||||
}
|
||||
}, 0L, TimeUnit.MILLISECONDS.convert(10, TimeUnit.MINUTES));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -433,10 +443,8 @@ public final class DecodeServlet extends HttpServlet {
|
|||
try {
|
||||
throw savedException == null ? NotFoundException.getNotFoundInstance() : savedException;
|
||||
} catch (FormatException | ChecksumException e) {
|
||||
log.info(e.toString());
|
||||
errorResponse(request, response, "format");
|
||||
} catch (ReaderException e) { // Including NotFoundException
|
||||
log.info(e.toString());
|
||||
errorResponse(request, response, "notfound");
|
||||
}
|
||||
return;
|
||||
|
|
|
@ -16,19 +16,18 @@
|
|||
|
||||
package com.google.zxing.web;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import javax.servlet.Filter;
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.FilterConfig;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.ServletResponse;
|
||||
import javax.servlet.annotation.WebFilter;
|
||||
import javax.servlet.annotation.WebInitParam;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.Timer;
|
||||
import java.util.TimerTask;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
|
@ -37,12 +36,7 @@ import java.util.concurrent.TimeUnit;
|
|||
*
|
||||
* @author Sean Owen
|
||||
*/
|
||||
@WebFilter(urlPatterns = {"/w/decode", "/w/chart"}, initParams = {
|
||||
@WebInitParam(name = "maxAccessPerTime", value = "150"),
|
||||
@WebInitParam(name = "accessTimeSec", value = "300"),
|
||||
@WebInitParam(name = "maxEntries", value = "10000")
|
||||
})
|
||||
public final class DoSFilter implements Filter {
|
||||
public abstract class DoSFilter implements Filter {
|
||||
|
||||
private Timer timer;
|
||||
private DoSTracker sourceAddrTracker;
|
||||
|
@ -50,18 +44,16 @@ public final class DoSFilter implements Filter {
|
|||
@Override
|
||||
public void init(FilterConfig filterConfig) {
|
||||
int maxAccessPerTime = Integer.parseInt(filterConfig.getInitParameter("maxAccessPerTime"));
|
||||
Preconditions.checkArgument(maxAccessPerTime > 0);
|
||||
int accessTimeSec = Integer.parseInt(filterConfig.getInitParameter("accessTimeSec"));
|
||||
Preconditions.checkArgument(accessTimeSec > 0);
|
||||
long accessTimeMS = TimeUnit.MILLISECONDS.convert(accessTimeSec, TimeUnit.SECONDS);
|
||||
int maxEntries = Integer.parseInt(filterConfig.getInitParameter("maxEntries"));
|
||||
timer = new Timer("DoSFilter");
|
||||
sourceAddrTracker = new DoSTracker(timer, maxAccessPerTime, accessTimeMS, maxEntries);
|
||||
timer.scheduleAtFixedRate(
|
||||
new TimerTask() {
|
||||
@Override
|
||||
public void run() {
|
||||
System.gc();
|
||||
}
|
||||
}, 0L, TimeUnit.MILLISECONDS.convert(15, TimeUnit.MINUTES));
|
||||
Preconditions.checkArgument(maxEntries > 0);
|
||||
|
||||
String name = getClass().getSimpleName();
|
||||
timer = new Timer(name);
|
||||
sourceAddrTracker = new DoSTracker(timer, name, maxAccessPerTime, accessTimeMS, maxEntries);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -79,9 +71,17 @@ public final class DoSFilter implements Filter {
|
|||
}
|
||||
|
||||
private boolean isBanned(HttpServletRequest request) {
|
||||
String remoteIPAddress = request.getHeader("x-forwarded-for");
|
||||
String remoteHost = request.getHeader("x-forwarded-for");
|
||||
if (remoteHost != null) {
|
||||
int comma = remoteHost.indexOf(',');
|
||||
if (comma >= 0) {
|
||||
remoteHost = remoteHost.substring(0, comma);
|
||||
}
|
||||
remoteHost = remoteHost.trim();
|
||||
}
|
||||
// Non-short-circuit "|" below is on purpose
|
||||
return
|
||||
(remoteIPAddress != null && sourceAddrTracker.isBanned(remoteIPAddress)) ||
|
||||
(remoteHost != null && sourceAddrTracker.isBanned(remoteHost)) |
|
||||
sourceAddrTracker.isBanned(request.getRemoteAddr());
|
||||
}
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ final class DoSTracker {
|
|||
private final long maxAccessesPerTime;
|
||||
private final Map<String,AtomicLong> numRecentAccesses;
|
||||
|
||||
DoSTracker(Timer timer, final int maxAccessesPerTime, long accessTimeMS, int maxEntries) {
|
||||
DoSTracker(Timer timer, final String name, final int maxAccessesPerTime, long accessTimeMS, int maxEntries) {
|
||||
this.maxAccessesPerTime = maxAccessesPerTime;
|
||||
this.numRecentAccesses = new LRUMap<>(maxEntries);
|
||||
timer.schedule(new TimerTask() {
|
||||
|
@ -51,8 +51,8 @@ final class DoSTracker {
|
|||
accessIt.remove();
|
||||
} else {
|
||||
// Else it exceeded the max, so log it (again)
|
||||
log.warning("Blocking " + entry.getKey() + " (" + count + " outstanding)");
|
||||
// Reduce count of accesses held against the IP
|
||||
log.warning(name + ": Blocking " + entry.getKey() + " (" + count + " outstanding)");
|
||||
// Reduce count of accesses held against the host
|
||||
count.getAndAdd(-maxAccessesPerTime);
|
||||
}
|
||||
}
|
||||
|
@ -70,8 +70,8 @@ final class DoSTracker {
|
|||
synchronized (numRecentAccesses) {
|
||||
count = numRecentAccesses.get(event);
|
||||
if (count == null) {
|
||||
count = new AtomicLong();
|
||||
numRecentAccesses.put(event, count);
|
||||
numRecentAccesses.put(event, new AtomicLong(1));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return count.incrementAndGet() > maxAccessesPerTime;
|
||||
|
|
|
@ -23,32 +23,79 @@ import org.springframework.mock.web.MockFilterConfig;
|
|||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
|
||||
import javax.servlet.Filter;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* Tests {@link DoSFilter}.
|
||||
* Tests {@link DoSFilter} implementations.
|
||||
*/
|
||||
public final class DoSFilterTestCase extends Assert {
|
||||
|
||||
private static final int MAX_ACCESS_PER_TIME = 10;
|
||||
|
||||
@Test
|
||||
public void testRedirect() throws Exception {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setRequestURI("/");
|
||||
request.setRemoteAddr("1.2.3.4");
|
||||
HttpServletResponse response = new MockHttpServletResponse();
|
||||
DoSFilter filter = new DoSFilter();
|
||||
for (DoSFilter filter : Arrays.asList(new ChartDoSFilter(), new DecodeDoSFilter())) {
|
||||
initFilter(filter);
|
||||
try {
|
||||
for (int i = 0; i < MAX_ACCESS_PER_TIME; i++) {
|
||||
testRequest(filter, "1.2.3.4", null, HttpServletResponse.SC_OK);
|
||||
}
|
||||
testRequest(filter, "1.2.3.4", null, HttpServletResponse.SC_FORBIDDEN);
|
||||
} finally {
|
||||
filter.destroy();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNoRemoteHost() throws Exception {
|
||||
Filter filter = new DecodeDoSFilter();
|
||||
initFilter(filter);
|
||||
try {
|
||||
testRequest(filter, null, null, HttpServletResponse.SC_FORBIDDEN);
|
||||
testRequest(filter, null, "1.1.1.1", HttpServletResponse.SC_FORBIDDEN);
|
||||
} finally {
|
||||
filter.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testProxy() throws Exception {
|
||||
Filter filter = new DecodeDoSFilter();
|
||||
initFilter(filter);
|
||||
try {
|
||||
for (int i = 0; i < MAX_ACCESS_PER_TIME; i++) {
|
||||
testRequest(filter, "1.2.3.4", "1.1.1." + i + ", proxy1", HttpServletResponse.SC_OK);
|
||||
}
|
||||
testRequest(filter, "1.2.3.4", "1.1.1.0", HttpServletResponse.SC_FORBIDDEN);
|
||||
} finally {
|
||||
filter.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
private void initFilter(Filter filter) throws ServletException {
|
||||
MockFilterConfig config = new MockFilterConfig();
|
||||
int maxAccessPerTime = 10;
|
||||
config.addInitParameter("maxAccessPerTime", Integer.toString(maxAccessPerTime));
|
||||
config.addInitParameter("maxAccessPerTime", Integer.toString(MAX_ACCESS_PER_TIME));
|
||||
config.addInitParameter("accessTimeSec", "60");
|
||||
config.addInitParameter("maxEntries", "100");
|
||||
filter.init(config);
|
||||
for (int i = 0; i < maxAccessPerTime; i++) {
|
||||
filter.doFilter(request, response, new MockFilterChain());
|
||||
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
|
||||
}
|
||||
|
||||
private void testRequest(Filter filter, String host, String proxy, int expectedStatus)
|
||||
throws IOException, ServletException {
|
||||
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||
request.setRequestURI("/");
|
||||
request.setRemoteAddr(host);
|
||||
if (proxy != null) {
|
||||
request.addHeader("X-Forwarded-For", proxy);
|
||||
}
|
||||
HttpServletResponse response = new MockHttpServletResponse();
|
||||
filter.doFilter(request, response, new MockFilterChain());
|
||||
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
|
||||
assertEquals(expectedStatus, response.getStatus());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ public final class DoSTrackerTestCase extends Assert {
|
|||
Timer timer = new Timer();
|
||||
long timerTimeMS = 500;
|
||||
int maxAccessPerTime = 2;
|
||||
DoSTracker tracker = new DoSTracker(timer, maxAccessPerTime, timerTimeMS, 3);
|
||||
DoSTracker tracker = new DoSTracker(timer, "test", maxAccessPerTime, timerTimeMS, 3);
|
||||
|
||||
// 2 requests allowed per time; 3rd should be banned
|
||||
assertFalse(tracker.isBanned("A"));
|
||||
|
|
Loading…
Reference in a new issue