Some time ago I wrote a post describing a multi-channel asynchronous throttler I had written in Java. At the time, I stated that it would preserve the order of calls, but, as Asa pointed out in a comment on that post, this was not always the case. Here is a new version that does preserve order and passes Asa's test. It works by placing incoming tasks on an internal FIFO queue: the throttler still schedules one execution slot per submitted task, but each slot simply runs whichever task is at the head of the queue, so tasks are started in the order they were queued. As part of this work I also extracted common code into new classes and created a ChannelThrottler interface. All the code detailed in this post (and the other throttler post) is available here.
//imports skipped for prettier code display
/**
 * Throttled task submission, optionally partitioned into named channels.
 */
public interface ChannelThrottler {

    /**
     * Submits {@code task} under {@code channelKey}, so that it is held back by the limit shared by
     * every call and, in addition, by the limit associated with that channel.
     *
     * @param channelKey identifies the channel whose limit also applies
     * @param task the work to run once the limits permit it
     * @return a future tracking the deferred execution of {@code task}
     */
    Future<?> submit(Object channelKey, Runnable task);

    /**
     * Submits {@code task} without a channel, so that only the limit shared by every call applies.
     *
     * @param task the work to run once the limit permits it
     * @return a future tracking the deferred execution of {@code task}
     */
    Future<?> submit(Runnable task);
}
//imports skipped for prettier code display
/**
 * A sliding-window limit of at most {@code numberCalls} calls per {@code timeLength}
 * {@code timeUnit}, together with the times of the calls already granted against it.
 *
 * <p>Times are {@code long} millisecond values supplied by the caller. A quoted call is kept at
 * least one period plus one millisecond after the call {@code numberCalls} places before it.
 *
 * <p>Instances are stateful (they keep a call history in an unsynchronized {@code LinkedList}):
 * callers must serialize access, and using one instance for several limits shares its history.
 */
public final class Rate {
    private final int numberCalls;
    private final int timeLength;
    private final TimeUnit timeUnit;
    // Granted call times, kept in ascending order by addCall().
    private final LinkedList<Long> callHistory = new LinkedList<Long>();

    /**
     * @param numberCalls maximum number of calls per period; must be at least 1
     * @param timeLength length of the period in {@code timeUnit}; must not be negative
     * @param timeUnit unit of {@code timeLength}
     * @throws IllegalArgumentException if {@code numberCalls < 1} or {@code timeLength < 0}
     * @throws NullPointerException if {@code timeUnit} is null
     */
    public Rate(int numberCalls, int timeLength, TimeUnit timeUnit) {
        if (numberCalls < 1) {
            // With no allowance callTime() would call getLast() on an empty history and throw.
            throw new IllegalArgumentException("numberCalls must be at least 1: " + numberCalls);
        }
        if (timeLength < 0) {
            throw new IllegalArgumentException("timeLength must not be negative: " + timeLength);
        }
        if (timeUnit == null) {
            throw new NullPointerException("timeUnit");
        }
        this.numberCalls = numberCalls;
        this.timeLength = timeLength;
        this.timeUnit = timeUnit;
    }

    /** Length of one period in milliseconds. */
    private long timeInMillis() {
        return timeUnit.toMillis(timeLength);
    }

    /**
     * Records a granted call at {@code callTime}, keeping the history sorted.
     *
     * <p>The time is inserted in place rather than appended because it may be earlier than calls
     * already recorded: callTime() returns {@code now} while fewer than {@code numberCalls} calls
     * are recorded, even when some of them lie in the future.
     */
    /* package */ void addCall(long callTime) {
        ListIterator<Long> i = callHistory.listIterator(callHistory.size());
        while (i.hasPrevious()) {
            if (i.previous() <= callTime) {
                i.next(); // step back over that entry so the new one is inserted after it
                break;
            }
        }
        i.add(callTime);
    }

    /** Forgets calls made at or before one period before {@code now} (the oldest entries). */
    private void cleanOld(long now) {
        long threshold = now - timeInMillis();
        while (!callHistory.isEmpty() && callHistory.getFirst() <= threshold) {
            callHistory.removeFirst();
        }
    }

    /**
     * Quotes the time at which one more call may be made, given the calls recorded so far.
     * The call is not recorded; use addCall() for that.
     *
     * @param now current time in milliseconds
     * @return a time no earlier than {@code now}
     */
    /* package */ long callTime(long now) {
        cleanOld(now);
        if (callHistory.size() < numberCalls) {
            return now;
        }
        long period = timeInMillis();
        // Estimate from the window that ends at the latest recorded call.
        long lastStart = callHistory.getLast() - period;
        long firstPeriodCall = lastStart;
        int count = 0;
        Iterator<Long> i = callHistory.descendingIterator();
        while (i.hasNext()) {
            long call = i.next();
            if (call < lastStart) {
                break;
            }
            count++;
            firstPeriodCall = call;
        }
        long candidate = count < numberCalls ? firstPeriodCall + 1 : firstPeriodCall + period + 1;
        // The estimate ignores calls before lastStart and may lie before now; only ever move it
        // later, to the first time that keeps every window within the limit.
        return earliestFit(Math.max(now, candidate), period);
    }

    /**
     * Returns the earliest time at or after {@code from} at which a call fits the limit.
     * A permitted interval can only begin at {@code from} or one millisecond after a recorded
     * call's period ends, so those instants are tried in increasing order.
     */
    private long earliestFit(long from, long period) {
        long time = from;
        while (!fits(time, period)) {
            long next = Long.MAX_VALUE;
            for (long call : callHistory) {
                long freed = call + period + 1;
                if (freed > time) {
                    next = freed; // history is sorted, so this is the smallest such instant
                    break;
                }
            }
            time = next;
        }
        return time;
    }

    /**
     * Returns whether a call at {@code time} keeps every run of {@code numberCalls + 1}
     * consecutive calls spread over more than {@code period} milliseconds.
     */
    private boolean fits(long time, long period) {
        // Only calls within one period of time can share a window with it; the history is sorted,
        // so the neighbours collected here (with time slotted in) are sorted as well.
        long[] near = new long[callHistory.size() + 1];
        int size = 0;
        int pos = -1;
        for (long call : callHistory) {
            if (call < time - period) {
                continue;
            }
            if (call > time + period) {
                break;
            }
            if (pos < 0 && call > time) {
                pos = size;
                near[size++] = time;
            }
            near[size++] = call;
        }
        if (pos < 0) {
            pos = size;
            near[size++] = time;
        }
        for (int start = Math.max(0, pos - numberCalls); start <= pos && start + numberCalls < size; start++) {
            if (near[start + numberCalls] - near[start] <= period) {
                return false;
            }
        }
        return true;
    }
}
//imports skipped for prettier code display
/* package */ abstract class AbstractChannelThrottler implements ChannelThrottler {
    /** Limit applied to every call, whatever its channel. */
    protected final Rate totalRate;
    /** Source of the current time in milliseconds. */
    protected final TimeProvider timeProvider;
    /** Scheduler supplied at construction, stored for subclasses. */
    protected final ScheduledExecutorService scheduler;
    /** Per-channel limits; a private copy, so later changes to the caller's map have no effect. */
    protected final Map<Object, Rate> channels = new HashMap<Object, Rate>();

    protected AbstractChannelThrottler(Rate totalRate, ScheduledExecutorService scheduler, Map<Object, Rate> channels, TimeProvider timeProvider) {
        this.totalRate = totalRate;
        this.timeProvider = timeProvider;
        this.scheduler = scheduler;
        this.channels.putAll(channels);
    }

    /**
     * Reserves a call time allowed by {@code totalRate} and, when {@code channel} is non-null, by
     * {@code channel} as well, and records it in each of those histories.
     *
     * <p>Synchronized so that concurrent callers reserve slots one at a time; within this class the
     * Rate histories are only read and written here.
     *
     * <p>NOTE(review): each Rate is asked for its own earliest time and the later of the two is
     * recorded in both. A time later than what a Rate quoted is not re-checked against that Rate's
     * already-recorded future calls; confirm this cannot overfill a window.
     *
     * @param channel the channel's limit, or {@code null} to apply only {@code totalRate}
     * @return the reserved time in milliseconds
     */
    protected synchronized long callTime(Rate channel) {
        long current = timeProvider.getCurrentTimeInMillis();
        long slot = totalRate.callTime(current);
        if (channel == null) {
            totalRate.addCall(slot);
            return slot;
        }
        long channelSlot = channel.callTime(current);
        if (channelSlot > slot) {
            slot = channelSlot;
        }
        channel.addCall(slot);
        totalRate.addCall(slot);
        return slot;
    }

    /**
     * Reserves a slot for {@code channelKey} and returns how long to wait for it.
     * Keys without an entry in {@code channels} are limited by {@code totalRate} only.
     *
     * @return milliseconds from now until the reserved slot, never negative
     */
    protected long getThrottleDelay(Object channelKey) {
        long slot = callTime(channels.get(channelKey));
        long wait = slot - timeProvider.getCurrentTimeInMillis();
        return Math.max(0L, wait);
    }
}
/**
 * {@link ChannelThrottler} that hands tasks out in the order they were queued.
 *
 * <p>Each submit() reserves a throttled time with callTime() and schedules one run of
 * {@code processQueueTask} at that time. That run executes whichever task is at the head of a FIFO
 * queue, so the schedule decides <em>when</em> a task is started and the queue decides <em>which</em>.
 */
public final class QueueChannelThrottler extends AbstractChannelThrottler {
    /**
     * Takes the task at the head of the queue and runs it unless it was cancelled; a cancelled task
     * still uses up the time slot it was scheduled for.
     */
    private final Runnable processQueueTask = new Runnable() {
        @Override public void run() {
            FutureTask<?> task = tasks.poll();
            if (task != null && !task.isCancelled()) {
                task.run();
            }
        }
    };
    // Written by submit() on caller threads and polled on the scheduler thread, so it must be a
    // thread-safe queue; an unsynchronized LinkedList can be corrupted by concurrent add/poll.
    private final Queue<FutureTask<?>> tasks = new java.util.concurrent.ConcurrentLinkedQueue<FutureTask<?>>();

    /**
     * Throttles by {@code totalRate} only, using the system clock and a scheduler created here.
     * NOTE(review): that scheduler is never shut down by this class and its thread may keep the JVM
     * alive; prefer the four-argument constructor with an executor whose lifecycle you manage.
     */
    public QueueChannelThrottler(Rate totalRate) {
        this(totalRate, Executors.newSingleThreadScheduledExecutor(), new HashMap<Object, Rate>(), TimeProvider.SYSTEM_PROVIDER);
    }

    /**
     * Throttles by {@code totalRate} and the per-channel limits in {@code channels}, using the system
     * clock and a scheduler created here (never shut down by this class; see the one-argument form).
     */
    public QueueChannelThrottler(Rate totalRate, Map<Object, Rate> channels) {
        this(totalRate, Executors.newSingleThreadScheduledExecutor(), channels, TimeProvider.SYSTEM_PROVIDER);
    }

    public QueueChannelThrottler(Rate totalRate, ScheduledExecutorService scheduler, Map<Object, Rate> channels, TimeProvider timeProvider) {
        super(totalRate, scheduler, channels, timeProvider);
    }

    /** Submits {@code task} limited only by the total rate. */
    @Override public Future<?> submit(Runnable task) {
        return submit(null, task);
    }

    /**
     * Queues {@code task} and schedules one queue-processing run at the throttled time.
     *
     * @param channelKey channel whose limit also applies; {@code null} (or a key with no configured
     *     limit) applies only the total rate
     * @param task the work to run
     * @return a future for {@code task}
     * @throws NullPointerException if {@code task} is null
     */
    @Override public Future<?> submit(Object channelKey, Runnable task) {
        if (task == null) {
            // Reject before callTime() so an invalid submission does not consume a rate slot.
            throw new NullPointerException("task");
        }
        long throttledTime = channelKey == null ? callTime(null) : callTime(channels.get(channelKey));
        FutureTask<Void> runTask = new FutureTask<Void>(task, null);
        tasks.add(runTask);
        long delay = throttledTime - timeProvider.getCurrentTimeInMillis();
        scheduler.schedule(processQueueTask, delay < 0 ? 0 : delay, TimeUnit.MILLISECONDS);
        return runTask;
    }
}
//imports skipped for prettier code display
public class QueueChannelThrottlerTest {
    private static final String CHANNEL1 = "CHANNEL1";
    private static final String CHANNEL2 = "CHANNEL2";
    private DeterministicScheduler scheduler;
    // Read by the TimeProvider built in setupThrottler(); tests move the clock with set().
    private final AtomicLong currentTime = new AtomicLong(0);
    private QueueChannelThrottler throttler;
    private final AtomicInteger count = new AtomicInteger(0);
    // Order numbers of OrderedTasks, in the order they actually ran.
    private final ConcurrentLinkedQueue<Integer> executionOrder = new ConcurrentLinkedQueue<Integer>();
    private final Runnable countIncrementTask = new Runnable() {@Override public void run() {count.incrementAndGet();}};

    @Before public void setupThrottler() {
        scheduler = new DeterministicScheduler();
        currentTime.set(0);
        count.set(0);
        executionOrder.clear();
        Map<Object, Rate> channels = new HashMap<Object, Rate>();
        channels.put(CHANNEL1, new Rate(3, 1, TimeUnit.SECONDS));
        channels.put(CHANNEL2, new Rate(1, 1, TimeUnit.SECONDS));
        throttler = new QueueChannelThrottler(new Rate(2, 1, TimeUnit.SECONDS), scheduler, channels, new TimeProvider() {
            @Override public long getCurrentTimeInMillis() {return currentTime.get();}
        });
    }

    @Test public void testTotalChannelWithNoDelay() throws Exception {
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
    }

    @Test public void testTotalChannelWithDelay() throws Exception {
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
        scheduler.tick(1000, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
    }

    @Test public void testTotalChannelWithDoubleDelay() throws Exception {
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
        scheduler.tick(500, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
        scheduler.tick(500, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(4, count.get());
        scheduler.tick(1000, TimeUnit.MILLISECONDS);
        assertEquals(5, count.get());
    }

    @Test public void testTotalChannelWithShortestDelay() throws Exception {
        throttler.submit(countIncrementTask);
        currentTime.set(777);
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        throttler.submit(countIncrementTask);
        throttler.submit(countIncrementTask);
        currentTime.set(877);
        scheduler.tick(100, TimeUnit.MILLISECONDS);
        throttler.submit(countIncrementTask);
        assertEquals(2, count.get());
        scheduler.tick(124, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        assertEquals(4, count.get());
    }

    @Test public void testChannel() throws Exception {
        throttler.submit(CHANNEL2, countIncrementTask);
        currentTime.set(777);
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL2, countIncrementTask);
        currentTime.set(877);
        scheduler.tick(100, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL2, countIncrementTask);
        assertEquals(1, count.get());
        scheduler.tick(124, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
        scheduler.tick(1000, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
    }

    @Test public void testChannelAndTotal() throws Exception {
        throttler.submit(CHANNEL1, countIncrementTask);
        currentTime.set(777);
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL1, countIncrementTask);
        throttler.submit(CHANNEL1, countIncrementTask);
        currentTime.set(877);
        scheduler.tick(100, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL1, countIncrementTask);
        assertEquals(2, count.get());
        scheduler.tick(124, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        assertEquals(4, count.get());
    }

    @Test public void testChannelAffectsTotal() throws Exception {
        throttler.submit(CHANNEL1, countIncrementTask);
        currentTime.set(777);
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL1, countIncrementTask);
        throttler.submit(countIncrementTask);
        currentTime.set(877);
        scheduler.tick(100, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL1, countIncrementTask);
        assertEquals(2, count.get());
        scheduler.tick(124, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        assertEquals(4, count.get());
    }

    /**
     * Records its order number when run. The check is made on the test thread afterwards, because
     * the throttler runs tasks inside a FutureTask, which would capture (and hide) an AssertionError
     * thrown from here.
     */
    private class OrderedTask implements Runnable {
        private final int order;
        public OrderedTask(int order) {this.order = order;}
        @Override public void run() {
            count.incrementAndGet();
            executionOrder.add(order);
        }
    }

    @Test public void testChannelCallsAreOrdered() throws Exception {
        throttler.submit(CHANNEL1, new OrderedTask(1));
        throttler.submit(CHANNEL2, new OrderedTask(2));
        throttler.submit(CHANNEL1, new OrderedTask(3));
        throttler.submit(CHANNEL2, new OrderedTask(4));
        throttler.submit(CHANNEL2, new OrderedTask(5));
        throttler.submit(CHANNEL1, new OrderedTask(6));
        throttler.submit(CHANNEL1, new OrderedTask(7));
        scheduler.tick(5000, TimeUnit.MILLISECONDS);
        assertEquals(7, count.get());
        for (int expected = 1; expected <= 7; expected++) {
            assertEquals(Integer.valueOf(expected), executionOrder.poll());
        }
    }

    @Test public void testMultiChannel() throws Exception {
        throttler.submit(CHANNEL1, countIncrementTask);
        currentTime.set(777);
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL2, countIncrementTask);
        throttler.submit(CHANNEL1, countIncrementTask);
        currentTime.set(877);
        scheduler.tick(100, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL2, countIncrementTask);
        throttler.submit(CHANNEL2, countIncrementTask);
        throttler.submit(CHANNEL1, countIncrementTask);
        throttler.submit(CHANNEL1, countIncrementTask);
        assertEquals(2, count.get());
        scheduler.tick(123, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
        scheduler.tick(778, TimeUnit.MILLISECONDS);
        assertEquals(4, count.get());
        scheduler.tick(1001, TimeUnit.MILLISECONDS);
        assertEquals(6, count.get());
        scheduler.tick(999, TimeUnit.MILLISECONDS);
        assertEquals(6, count.get());
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(7, count.get());
    }

    @Test public void testMultiChannelWithTotal() throws Exception {
        throttler.submit(CHANNEL1, countIncrementTask);
        currentTime.set(777);
        scheduler.tick(777, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL2, countIncrementTask);
        throttler.submit(countIncrementTask);
        currentTime.set(877);
        scheduler.tick(100, TimeUnit.MILLISECONDS);
        throttler.submit(CHANNEL2, countIncrementTask);
        throttler.submit(CHANNEL2, countIncrementTask);
        throttler.submit(countIncrementTask);
        throttler.submit(CHANNEL1, countIncrementTask);
        assertEquals(2, count.get());
        scheduler.tick(123, TimeUnit.MILLISECONDS);
        assertEquals(2, count.get());
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(3, count.get());
        scheduler.tick(778, TimeUnit.MILLISECONDS);
        assertEquals(4, count.get());
        scheduler.tick(1001, TimeUnit.MILLISECONDS);
        assertEquals(6, count.get());
        scheduler.tick(999, TimeUnit.MILLISECONDS);
        assertEquals(6, count.get());
        scheduler.tick(1, TimeUnit.MILLISECONDS);
        assertEquals(7, count.get());
    }

    @Test
    public void scheduledTasksMonotonicallyIncreasing() {
        int numCalls = 50;
        int totalCalls = 1000;
        int ratePeriod = 100;
        final CountDownLatch latch = new CountDownLatch(totalCalls);
        Rate rate = new Rate(numCalls, ratePeriod, TimeUnit.MILLISECONDS);
        QueueChannelThrottler throttler = new QueueChannelThrottler(rate);
        final ConcurrentLinkedQueue<Long> base = new ConcurrentLinkedQueue<Long>();
        for (int i = 0; i < totalCalls; i++) {
            throttler.submit(new Runnable() {
                @Override public void run() {
                    base.add(System.currentTimeMillis());
                    latch.countDown();
                }
            });
        }
        awaitAll(latch, totalCalls, numCalls, ratePeriod);
        assertEquals(totalCalls, base.size());
        long last = 0;
        for (Long next : base) {
            assertTrue(next >= last);
            last = next;
        }
    }

    @Test
    public void scheduledTasksShouldRunInOrder() {
        int numCalls = 50;
        int totalCalls = 1000;
        int ratePeriod = 100;
        CountDownLatch latch = new CountDownLatch(totalCalls);
        Rate rate = new Rate(numCalls, ratePeriod, TimeUnit.MILLISECONDS);
        QueueChannelThrottler throttler = new QueueChannelThrottler(rate);
        ConcurrentLinkedQueue<Integer> base = new ConcurrentLinkedQueue<Integer>();
        ConcurrentLinkedQueue<Integer> toCompare = new ConcurrentLinkedQueue<Integer>();
        for (int i = 0; i < totalCalls; i++) {
            throttler.submit(new RunnableImpl(i, toCompare, latch));
            base.add(i);
        }
        awaitAll(latch, totalCalls, numCalls, ratePeriod);
        assertEquals(base.size(), toCompare.size());
        // Drain both queues in step; a loop bounded by toCompare.size() while polling would stop halfway.
        Integer expected;
        while ((expected = base.poll()) != null) {
            assertEquals(expected, toCompare.poll());
        }
    }

    /**
     * Waits for every throttled task to finish, failing on timeout or interruption. The ideal run
     * time is about (totalCalls / numCalls) periods, so twice that is allowed to absorb scheduling jitter.
     */
    private static void awaitAll(CountDownLatch latch, int totalCalls, int numCalls, int ratePeriod) {
        long timeoutMillis = 2L * (totalCalls / numCalls) * ratePeriod;
        try {
            assertTrue("throttled tasks did not finish within " + timeoutMillis + " ms",
                    latch.await(timeoutMillis, TimeUnit.MILLISECONDS));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            fail("interrupted while waiting for throttled tasks");
        }
    }

    /** Appends its id to {@code collection} and counts down {@code latch} when run. */
    public class RunnableImpl implements Runnable {
        public final int id;
        private final ConcurrentLinkedQueue<Integer> collection;
        private final CountDownLatch latch;
        public RunnableImpl(int id, ConcurrentLinkedQueue<Integer> collection, CountDownLatch latch) {
            this.id = id;
            this.collection = collection;
            this.latch = latch;
        }
        @Override public void run() {
            collection.add(id);
            latch.countDown();
        }
    }
}
Leave a comment