merge: Fix rate limits under multi-node environments (!809)

View MR for information: https://activitypub.software/TransFem-org/Sharkey/-/merge_requests/809

Approved-by: dakkar <dakkar@thenautilus.net>
Approved-by: Marie <github@yuugi.dev>
This commit is contained in:
Hazelnoot 2024-12-15 16:53:48 +00:00
commit fd0ecb22cf
16 changed files with 719 additions and 400 deletions

View file

@ -5,16 +5,13 @@
import { Inject, Injectable } from '@nestjs/common';
import Redis from 'ioredis';
import { LoggerService } from '@/core/LoggerService.js';
import { TimeService } from '@/core/TimeService.js';
import { EnvService } from '@/core/EnvService.js';
import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed, hasMaxLimit, disabledLimitInfo, MaxLegacyLimit, MinLegacyLimit } from '@/misc/rate-limit-utils.js';
import { DI } from '@/di-symbols.js';
import type Logger from '@/logger.js';
import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed } from '@/misc/rate-limit-utils.js';
@Injectable()
export class SkRateLimiterService {
private readonly logger: Logger;
private readonly disabled: boolean;
constructor(
@ -24,32 +21,31 @@ export class SkRateLimiterService {
@Inject(DI.redis)
private readonly redisClient: Redis.Redis,
@Inject(LoggerService)
loggerService: LoggerService,
@Inject(EnvService)
envService: EnvService,
) {
this.logger = loggerService.getLogger('limiter');
this.disabled = envService.env.NODE_ENV !== 'production'; // TODO disable in TEST *only*
this.disabled = envService.env.NODE_ENV === 'test';
}
/**
* Check & increment a rate limit
* @param limit The limit definition
* @param actor Client who is calling this limit
* @param factor Scaling factor - smaller = larger limit (less restrictive)
*/
public async limit(limit: Keyed<RateLimit>, actor: string, factor = 1): Promise<LimitInfo> {
if (this.disabled || factor === 0) {
return {
blocked: false,
remaining: Number.MAX_SAFE_INTEGER,
resetSec: 0,
resetMs: 0,
fullResetSec: 0,
fullResetMs: 0,
};
return disabledLimitInfo;
}
if (factor < 0) {
throw new Error(`Rate limit factor is zero or negative: ${factor}`);
}
return await this.tryLimit(limit, actor, factor);
}
private async tryLimit(limit: Keyed<RateLimit>, actor: string, factor: number): Promise<LimitInfo> {
if (isLegacyRateLimit(limit)) {
return await this.limitLegacy(limit, actor, factor);
} else {
@ -58,141 +54,200 @@ export class SkRateLimiterService {
}
private async limitLegacy(limit: Keyed<LegacyRateLimit>, actor: string, factor: number): Promise<LimitInfo> {
const promises: Promise<LimitInfo | null>[] = [];
// The "min" limit - if present - is handled directly.
if (hasMinLimit(limit)) {
promises.push(
this.limitMin(limit, actor, factor),
);
if (hasMaxLimit(limit)) {
return await this.limitLegacyMinMax(limit, actor, factor);
} else if (hasMinLimit(limit)) {
return await this.limitLegacyMinOnly(limit, actor, factor);
} else {
return disabledLimitInfo;
}
// Convert the "max" limit into a leaky bucket with 1 drip / second rate.
if (limit.max != null && limit.duration != null) {
promises.push(
this.limitBucket({
type: 'bucket',
key: limit.key,
size: limit.max,
dripRate: Math.max(Math.round(limit.duration / limit.max), 1),
}, actor, factor),
);
}
const [lim1, lim2] = await Promise.all(promises);
return {
blocked: (lim1?.blocked || lim2?.blocked) ?? false,
remaining: Math.min(lim1?.remaining ?? Number.MAX_SAFE_INTEGER, lim2?.remaining ?? Number.MAX_SAFE_INTEGER),
resetSec: Math.max(lim1?.resetSec ?? 0, lim2?.resetSec ?? 0),
resetMs: Math.max(lim1?.resetMs ?? 0, lim2?.resetMs ?? 0),
fullResetSec: Math.max(lim1?.fullResetSec ?? 0, lim2?.fullResetSec ?? 0),
fullResetMs: Math.max(lim1?.fullResetMs ?? 0, lim2?.fullResetMs ?? 0),
};
}
private async limitMin(limit: Keyed<LegacyRateLimit> & { minInterval: number }, actor: string, factor: number): Promise<LimitInfo | null> {
if (limit.minInterval === 0) return null;
private async limitLegacyMinMax(limit: Keyed<MaxLegacyLimit>, actor: string, factor: number): Promise<LimitInfo> {
if (limit.duration === 0) return disabledLimitInfo;
if (limit.duration < 0) throw new Error(`Invalid rate limit ${limit.key}: duration is negative (${limit.duration})`);
if (limit.max < 1) throw new Error(`Invalid rate limit ${limit.key}: max is less than 1 (${limit.max})`);
// Derive initial dripRate from minInterval OR duration/max.
const initialDripRate = Math.max(limit.minInterval ?? Math.round(limit.duration / limit.max), 1);
// Calculate dripSize to reach max at exactly duration
const dripSize = Math.max(Math.round(limit.max / (limit.duration / initialDripRate)), 1);
// Calculate final dripRate from dripSize and duration/max
const dripRate = Math.max(Math.round(limit.duration / (limit.max / dripSize)), 1);
const bucketLimit: Keyed<BucketRateLimit> = {
type: 'bucket',
key: limit.key,
size: limit.max,
dripRate,
dripSize,
};
return await this.limitBucket(bucketLimit, actor, factor);
}
private async limitLegacyMinOnly(limit: Keyed<MinLegacyLimit>, actor: string, factor: number): Promise<LimitInfo> {
if (limit.minInterval === 0) return disabledLimitInfo;
if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`);
const counter = await this.getLimitCounter(limit, actor, 'min');
const minInterval = Math.max(Math.ceil(limit.minInterval * factor), 0);
// Update expiration
if (counter.c > 0) {
const isCleared = this.timeService.now - counter.t >= minInterval;
if (isCleared) {
counter.c = 0;
}
}
const blocked = counter.c > 0;
if (!blocked) {
counter.c++;
counter.t = this.timeService.now;
}
// Calculate limit status
const resetMs = Math.max(Math.ceil(minInterval - (this.timeService.now - counter.t)), 0);
const resetSec = Math.ceil(resetMs / 1000);
const limitInfo: LimitInfo = { blocked, remaining: 0, resetSec, resetMs, fullResetSec: resetSec, fullResetMs: resetMs };
// Update the limit counter, but not if blocked
if (!blocked) {
// Don't await, or we will slow down the API.
this.setLimitCounter(limit, actor, counter, resetSec, 'min')
.catch(err => this.logger.error(`Failed to update limit ${limit.key}:min for ${actor}:`, err));
}
return limitInfo;
const dripRate = Math.max(Math.round(limit.minInterval), 1);
const bucketLimit: Keyed<BucketRateLimit> = {
type: 'bucket',
key: limit.key,
size: 1,
dripRate,
dripSize: 1,
};
return await this.limitBucket(bucketLimit, actor, factor);
}
/**
* Implementation of Leaky Bucket rate limiting - see SkRateLimiterService.md for details.
*/
private async limitBucket(limit: Keyed<BucketRateLimit>, actor: string, factor: number): Promise<LimitInfo> {
if (limit.size < 1) throw new Error(`Invalid rate limit ${limit.key}: size is less than 1 (${limit.size})`);
if (limit.dripRate != null && limit.dripRate < 1) throw new Error(`Invalid rate limit ${limit.key}: dripRate is less than 1 (${limit.dripRate})`);
if (limit.dripSize != null && limit.dripSize < 1) throw new Error(`Invalid rate limit ${limit.key}: dripSize is less than 1 (${limit.dripSize})`);
const counter = await this.getLimitCounter(limit, actor, 'bucket');
// 0 - Calculate
const now = this.timeService.now;
const bucketSize = Math.max(Math.ceil(limit.size / factor), 1);
const dripRate = Math.ceil(limit.dripRate ?? 1000);
const dripSize = Math.ceil(limit.dripSize ?? 1);
const expirationSec = Math.max(Math.ceil((dripRate * Math.ceil(bucketSize / dripSize)) / 1000), 1);
// Update drips
if (counter.c > 0) {
const dripsSinceLastTick = Math.floor((this.timeService.now - counter.t) / dripRate) * dripSize;
counter.c = Math.max(counter.c - dripsSinceLastTick, 0);
// 1 - Read
const counterKey = createLimitKey(limit, actor, 'c');
const timestampKey = createLimitKey(limit, actor, 't');
const counter = await this.getLimitCounter(counterKey, timestampKey);
// 2 - Drip
const dripsSinceLastTick = Math.floor((now - counter.timestamp) / dripRate) * dripSize;
const deltaCounter = Math.min(dripsSinceLastTick, counter.counter);
const deltaTimestamp = dripsSinceLastTick * dripRate;
if (deltaCounter > 0) {
// Execute the next drip(s)
const results = await this.executeRedisMulti(
['get', timestampKey],
['incrby', timestampKey, deltaTimestamp],
['expire', timestampKey, expirationSec],
['get', timestampKey],
['decrby', counterKey, deltaCounter],
['expire', counterKey, expirationSec],
['get', counterKey],
);
const expectedTimestamp = counter.timestamp;
const canaryTimestamp = results[0] ? parseInt(results[0]) : 0;
counter.timestamp = results[3] ? parseInt(results[3]) : 0;
counter.counter = results[6] ? parseInt(results[6]) : 0;
// Check for a data collision and rollback
if (canaryTimestamp !== expectedTimestamp) {
const rollbackResults = await this.executeRedisMulti(
['decrby', timestampKey, deltaTimestamp],
['get', timestampKey],
['incrby', counterKey, deltaCounter],
['get', counterKey],
);
counter.timestamp = rollbackResults[1] ? parseInt(rollbackResults[1]) : 0;
counter.counter = rollbackResults[3] ? parseInt(rollbackResults[3]) : 0;
}
}
const blocked = counter.c >= bucketSize;
// 3 - Check
const blocked = counter.counter >= bucketSize;
if (!blocked) {
counter.c++;
counter.t = this.timeService.now;
if (counter.timestamp === 0) {
const results = await this.executeRedisMulti(
['set', timestampKey, now],
['expire', timestampKey, expirationSec],
['incr', counterKey],
['expire', counterKey, expirationSec],
['get', counterKey],
);
counter.timestamp = now;
counter.counter = results[4] ? parseInt(results[4]) : 0;
} else {
const results = await this.executeRedisMulti(
['incr', counterKey],
['expire', counterKey, expirationSec],
['get', counterKey],
);
counter.counter = results[2] ? parseInt(results[2]) : 0;
}
}
// Calculate how much time is needed to free up a bucket slot
const overflow = Math.max((counter.counter + 1) - bucketSize, 0);
const dripsNeeded = Math.ceil(overflow / dripSize);
const timeNeeded = Math.max((dripRate * dripsNeeded) - (this.timeService.now - counter.timestamp), 0);
// Calculate limit status
const remaining = Math.max(bucketSize - counter.c, 0);
const resetMs = remaining > 0 ? 0 : Math.max(dripRate - (this.timeService.now - counter.t), 0);
const remaining = Math.max(bucketSize - counter.counter, 0);
const resetMs = timeNeeded;
const resetSec = Math.ceil(resetMs / 1000);
const fullResetMs = Math.ceil(counter.c / dripSize) * dripRate;
const fullResetMs = Math.ceil(counter.counter / dripSize) * dripRate;
const fullResetSec = Math.ceil(fullResetMs / 1000);
const limitInfo: LimitInfo = { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs };
return { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs };
}
// Update the limit counter, but not if blocked
if (!blocked) {
// Don't await, or we will slow down the API.
this.setLimitCounter(limit, actor, counter, fullResetSec, 'bucket')
.catch(err => this.logger.error(`Failed to update limit ${limit.key} for ${actor}:`, err));
private async getLimitCounter(counterKey: string, timestampKey: string): Promise<LimitCounter> {
const [counter, timestamp] = await this.executeRedisMulti(
['get', counterKey],
['get', timestampKey],
);
return {
counter: counter ? parseInt(counter) : 0,
timestamp: timestamp ? parseInt(timestamp) : 0,
};
}
private async executeRedisMulti(...batch: RedisCommand[]): Promise<RedisResult[]> {
const results = await this.redisClient.multi(batch).exec();
// Transaction conflict (retryable)
if (!results) {
throw new ConflictError('Redis error: transaction conflict');
}
return limitInfo;
}
private async getLimitCounter(limit: Keyed<RateLimit>, actor: string, subject: string): Promise<LimitCounter> {
const key = createLimitKey(limit, actor, subject);
const value = await this.redisClient.get(key);
if (value == null) {
return { t: 0, c: 0 };
// Transaction failed (fatal)
if (results.length !== batch.length) {
throw new Error('Redis error: failed to execute batch');
}
return JSON.parse(value);
}
// Map responses
const errors: Error[] = [];
const responses: RedisResult[] = [];
for (const [error, response] of results) {
if (error) errors.push(error);
responses.push(response as RedisResult);
}
private async setLimitCounter(limit: Keyed<RateLimit>, actor: string, counter: LimitCounter, expiration: number, subject: string): Promise<void> {
const key = createLimitKey(limit, actor, subject);
const value = JSON.stringify(counter);
const expirationSec = Math.max(expiration, 1);
await this.redisClient.set(key, value, 'EX', expirationSec);
// Command failed (fatal)
if (errors.length > 0) {
const errorMessages = errors
.map((e, i) => `Error in command ${i}: ${e}`)
.join('\', \'');
throw new AggregateError(errors, `Redis error: failed to execute command(s): '${errorMessages}'`);
}
return responses;
}
}
function createLimitKey(limit: Keyed<RateLimit>, actor: string, subject: string): string {
return `rl_${actor}_${limit.key}_${subject}`;
// Not correct, but good enough for the basic commands we use.
type RedisResult = string | null;
type RedisCommand = [command: string, ...args: unknown[]];
function createLimitKey(limit: Keyed<RateLimit>, actor: string, value: string): string {
return `rl_${actor}_${limit.key}_${value}`;
}
export interface LimitCounter {
/** Timestamp */
t: number;
class ConflictError extends Error {}
/** Counter */
c: number;
interface LimitCounter {
timestamp: number;
counter: number;
}

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;

View file

@ -17,10 +17,11 @@ export const meta = {
allowGet: true,
cacheSec: 60 * 60,
// 10 calls per 5 seconds
// Burst up to 100, then 2/sec average
limit: {
duration: 1000 * 5,
max: 10,
type: 'bucket',
size: 100,
dripRate: 500,
},
} as const;