Revamp protection against flood of requests; minor related tweaks

This commit is contained in:
Sean Owen 2017-10-02 13:40:53 +01:00
parent 05093ed3d2
commit 65d2b163eb
3 changed files with 40 additions and 39 deletions

View file

@ -447,6 +447,7 @@ public final class DecodeServlet extends HttpServlet {
if (dispatcher == null) {
log.warning("Can't obtain RequestDispatcher");
} else {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
dispatcher.forward(request, response);
}
}

View file

@ -26,15 +26,13 @@ import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Logger;
/**
@ -48,40 +46,49 @@ public final class DoSFilter implements Filter {
private static final Logger log = Logger.getLogger(DoSFilter.class.getName());
private static final int MAX_ACCESSES_PER_IP_PER_TIME = 100;
private static final int MAX_RECENT_ACCESS_MAP_SIZE = 100_000;
private static final int MAX_ACCESSES_PER_IP_PER_TIME = 50;
private static final long MAX_ACCESSES_TIME_MS = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES);
private static final int MAX_RECENT_ACCESS_MAP_SIZE = 10_000;
private final Map<String,AtomicInteger> numRecentAccesses;
private final Set<String> bannedIPAddresses;
private final Map<String,AtomicLong> numRecentAccesses;
private Timer timer;
public DoSFilter() {
numRecentAccesses = Collections.synchronizedMap(new LinkedHashMap<String,AtomicInteger>() {
numRecentAccesses = new LinkedHashMap<String,AtomicLong>() {
@Override
protected boolean removeEldestEntry(Map.Entry<String,AtomicInteger> eldest) {
protected boolean removeEldestEntry(Map.Entry<String,AtomicLong> eldest) {
return size() > MAX_RECENT_ACCESS_MAP_SIZE;
}
});
bannedIPAddresses = Collections.synchronizedSet(new HashSet<String>());
};
}
@Override
public void init(FilterConfig filterConfig) {
timer = new Timer("DoSFilter reset timer");
timer = new Timer("DoSFilter");
timer.scheduleAtFixedRate(
new TimerTask() {
@Override
public void run() {
numRecentAccesses.clear();
synchronized (numRecentAccesses) {
// Periodically reduce allowed accesses per IP
Iterator<Map.Entry<String,AtomicLong>> accessIt = numRecentAccesses.entrySet().iterator();
while (accessIt.hasNext()) {
Map.Entry<String,AtomicLong> entry = accessIt.next();
AtomicLong count = entry.getValue();
// If number of accesses is below the threshold, remove it entirely
if (count.get() <= MAX_ACCESSES_PER_IP_PER_TIME) {
accessIt.remove();
} else {
// Else it exceeded the max, so log it (again)
log.warning("Possible DoS attack from " + entry.getKey() + " (" + count + " outstanding)");
// Reduce count of accesses held against the IP
count.getAndAdd(-MAX_ACCESSES_PER_IP_PER_TIME);
}
}
log.info("Tracking accesses from " + numRecentAccesses.size() + " IPs");
}
}
}, 0L, TimeUnit.MILLISECONDS.convert(1, TimeUnit.MINUTES));
timer.scheduleAtFixedRate(
new TimerTask() {
@Override
public void run() {
bannedIPAddresses.clear();
}
}, 0L, TimeUnit.MILLISECONDS.convert(15, TimeUnit.MINUTES));
}, MAX_ACCESSES_TIME_MS, MAX_ACCESSES_TIME_MS);
timer.scheduleAtFixedRate(
new TimerTask() {
@Override
@ -108,25 +115,18 @@ public final class DoSFilter implements Filter {
if (remoteIPAddress == null) {
remoteIPAddress = request.getRemoteAddr();
}
if (remoteIPAddress == null || bannedIPAddresses.contains(remoteIPAddress)) {
if (remoteIPAddress == null) {
return true;
}
if (getCount(remoteIPAddress) > MAX_ACCESSES_PER_IP_PER_TIME) {
log.warning("Possible DoS attack from " + remoteIPAddress);
bannedIPAddresses.add(remoteIPAddress);
return true;
}
return false;
}
private int getCount(String remoteIPAddress) {
AtomicInteger count = numRecentAccesses.get(remoteIPAddress);
if (count == null) {
numRecentAccesses.put(remoteIPAddress, new AtomicInteger(1));
return 1;
} else {
return count.incrementAndGet();
AtomicLong count;
synchronized (numRecentAccesses) {
count = numRecentAccesses.get(remoteIPAddress);
if (count == null) {
count = new AtomicLong();
numRecentAccesses.put(remoteIPAddress, count);
}
}
return count.incrementAndGet() > MAX_ACCESSES_PER_IP_PER_TIME;
}
@Override

View file

@ -55,7 +55,7 @@ JAVA_OPTS="-Djava.security.egd=file:/dev/urandom -Djava.awt.headless=true -Xmx32
<Host name="localhost" appBase="webapps" unpackWARs="true" autoDeploy="true">
<Valve className="org.apache.catalina.valves.AccessLogValve" directory="logs"
prefix="localhost_access_log" suffix=".txt" rotatable="false"
prefix="localhost_access_log" suffix=".txt"
pattern="%h %l %u %t &quot;%r&quot; %s %b" />
</Host>