Maximizing read throughput of ScyllaDB table scan using shard-awareness


In short, to read a set of token ranges, do the following:

  1. Split the set of token ranges into smaller intervals.
  2. For each token range, find which nodes and shards own it.
  3. Read the token ranges in parallel while limiting per-node and per-shard concurrency.

[Figure: We want this…]

[Figure: …instead of this.]

The Task

We want to read a given set of token ranges as fast as possible. This can be a full table scan, a scan of only the token ranges for which a given Scylla node is the primary replica, or any other subset of a table. Maximal throughput is the goal.

Introduction

Read this first: Use parallel efficient full table scan with ScyllaDB to scan 475 million partitions x12 faster. The linked article relies on randomization, which doesn’t always lead to optimal performance. It also recommends splitting token ranges into larger intervals than what led to the best results in my case. But it’s a great start.

To run performance tests I implemented a Java application run directly on a Scylla node. The application used com.scylladb:java-driver-core-shaded:4.14.1.0 as the Scylla driver. The cluster was running ScyllaDB version 4.6.11-0.20221128.6c0825e2a: 28 nodes with 72 shards each (80 CPU cores in total), replication factor 3, Murmur3 partitioner.

For simplicity, we read token ranges only from a single Scylla node, and only those ranges for which this node is the primary replica. The idea can be extended, but then you need to track load not just per shard but per (replica, shard) tuple.

Step 0: Set routing token for SELECT token range statement

Assuming that your read query has the form SELECT token(key), key, ... FROM ... WHERE token(key) > :starttoken AND token(key) <= :endtoken BYPASS CACHE;, help the Scylla driver send the request where it belongs by calling setRoutingToken:

      ...
      .bind()
      .setToken("starttoken", tokenRange.getStart())
      .setToken("endtoken", tokenRange.getEnd())
      .setRoutingToken(tokenRange.getEnd());

Most probably you want neither to cache the data you read nor to evict data needed by other queries from the cache. That’s why BYPASS CACHE. Caching also doesn’t make performance testing any easier.
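
Putting it together, a minimal sketch of preparing and executing such a range query could look like this (the keyspace, table and key column names are placeholders; session and tokenRange are assumed to exist):

PreparedStatement rangeQuery = session.prepare(
    "SELECT token(key), key FROM my_keyspace.my_table " +
    "WHERE token(key) > :starttoken AND token(key) <= :endtoken BYPASS CACHE");

BoundStatement bound = rangeQuery.bind()
    .setToken("starttoken", tokenRange.getStart())
    .setToken("endtoken", tokenRange.getEnd())
    // Route the request to a node owning the end of the range.
    .setRoutingToken(tokenRange.getEnd());

CompletionStage<AsyncResultSet> resultSetFuture = session.executeAsync(bound);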

Step 1: Find owning shards

The sharding algorithm is absolutely critical to what’s written here. Take a look at the actual implementation of the int shardId(Token token) method in the ShardingInfo class, which we’ll use to find the shard for a particular token: https://github.com/scylladb/java-driver/blob/bd0659903a686825616deac985b12f1cdb29d180/core/src/main/java/com/datastax/oss/driver/internal/core/protocol/ShardingInfo.java#L63

Thanks to the nature of the biased-token-round-robin sharding algorithm shards are assigned to consecutive tokens (token ranges) in sequence:

  • Shard 0 owns token range ]startToken, endToken0]
  • Shard 1 owns token range ]endToken0, endToken1]
  • Shard 2 owns token range ]endToken1, endToken2]

A token range is an interval ]startToken, endToken] where startToken is exclusive and endToken is inclusive. Getting the shard for the end of a token range is straightforward: endShard = shardId(endToken). But because the start is exclusive, we need startShard = shardId(startToken + 1). Implementation note: be careful with the +1, because the token value is a Long and it goes all the way up to MAX_LONG.

With this knowledge, if we get startShard=2 and endShard=4 (assuming our token range is small enough – see the next step), then we know that the whole token range is owned by shards 2, 3 and 4. Special case: let’s say we have 5 shards and we get startShard=3 and endShard=0. Then the token range belongs to shards 3, 4 and 0 because of the round robin.
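
Putting the above together, here is a sketch of mapping a token range to the shards that own it. shardId(long) is a placeholder delegating to the ShardingInfo.shardId implementation linked above, and shardCount stands for the node’s number of shards obtained from the same class:

// Sketch: which shards own the token range ]startToken, endToken]?
List<Integer> owningShards(long startToken, long endToken) {
  // The start of the range is exclusive, so shift it by one token,
  // taking care not to overflow past Long.MAX_VALUE.
  long firstOwnedToken = startToken == Long.MAX_VALUE ? Long.MIN_VALUE : startToken + 1;
  int startShard = shardId(firstOwnedToken);
  int endShard = shardId(endToken);
  List<Integer> shards = new ArrayList<>();
  int shard = startShard;
  while (true) {
    shards.add(shard);
    if (shard == endShard) {
      break;
    }
    shard = (shard + 1) % shardCount; // wrap around: e.g. 3, 4, 0 with 5 shards
  }
  return shards;
}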

Step 2: Split token ranges

The goal is to have a set of token ranges where each is owned by only a few shards, so that we can read them concurrently with minimal contention on shared resources. On the other hand, we don’t want to make them too small, because then the overhead of running many tiny reads would hurt throughput.

You can split a token range by calling tokenRange.splitEvenly(numberOfSplits). You must find the optimal number of splits yourself by running performance tests.

In case you’re curious: for my setup I found that if I split all token ranges for a node so that in total I get 10,000 similarly-sized token ranges, then each token range is owned by 2 shards on average.
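
As an illustration, a sketch of producing roughly that many splits for one node could look like the following. targetTotalSplits is the tuning knob, and splitting every range into the same number of pieces is a simplification, since the node’s ranges are not all the same size:

Set<TokenRange> nodeRanges =
    session.getMetadata().getTokenMap().get().getTokenRanges(node);
int splitsPerRange = Math.max(1, targetTotalSplits / nodeRanges.size());
List<TokenRange> smallRanges = new ArrayList<>();
for (TokenRange range : nodeRanges) {
  // splitEvenly() divides one range into N contiguous sub-ranges.
  smallRanges.addAll(range.splitEvenly(splitsPerRange));
}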

Step 3: Evenly distribute

Now we have a set of small token ranges, and for each one we know which shards own it. We need to decide when to read which one. We want to maximize concurrency, evenly distribute the load between all shards, and avoid overloading.

To achieve this we need to keep track of which shards are busy reading and which are idle. Then, whenever we finish reading a token range, we pick the next one whose shards have the smallest load.

We also probably want to limit the total level of concurrency per Scylla node across all shards. This depends on your hardware: even if you have 80 shards per node and you concurrently read 80 token ranges, each owned by a single shard, you may find that the final throughput is no higher than when reading only 10 token ranges concurrently. Maybe that’s because all shards read from the same disk and that disk is slow.

Pseudocode for your inspiration:

TokenRange[] allTokenRangesToRead; // Algorithm input
int[] shardLoad; // Array length is the total number of shards (index = shard ID)

// Magic numbers that you need to figure out:
final int maxShardLoad = 2;
final int maxConcurrency = 42;

int concurrency() {
  // Number of shards with positive load
  return shardLoad.count(load -> load > 0)
}

boolean isOverloaded(int shards[]) {
  // Is there a shard with high load?
  return shards.exists(shardId -> shardLoad[shardId] >= maxShardLoad)
}

int getTokenRangeLoad(TokenRange tokenRange) {
  // Total load of all shards that own this token range
  return tokenRange.shards.map(shardId -> shardLoad[shardId]).sum()
}

TokenRange findTheBestTokenRange() {
  if (concurrency() >= maxConcurrency) {
    // We're already at the concurrency limit
    return null
  }
  // Don't consider token ranges whose shards are overloaded
  TokenRange[] notOverloaded = allTokenRangesToRead.filter(tokenRange ->
    !isOverloaded(tokenRange.shards)    
  )
  // Sort eligible token ranges by their load
  TokenRange[] sortedByLoad = notOverloaded.sortBy(tokenRange ->
    getTokenRangeLoad(tokenRange)
  )
  // Choose token range with the smallest load
  return sortedByLoad[0]
}

boolean readNextTokenRange() {
  tokenRange = findTheBestTokenRange()
  if (tokenRange == null) {
    return false
  }
  allTokenRangesToRead.remove(tokenRange)
  tokenRange.shards.foreach(shardId -> shardLoad[shardId] += 1)
  readAsync(tokenRange).onComplete(result -> {
    tokenRange.shards.foreach(shardId -> shardLoad[shardId] -= 1)
    readAsManyAsPossible()
  })
  return true
}

void readAsManyAsPossible() {
  if (allTokenRangesToRead.isEmpty) {
    // TODO: Tell the main thread to stop waiting. We're done.    
  } else {
    while (readNextTokenRange()) {}
  }
}

// Initialize reading and wait for all asynchronous reading to be done.
readAsManyAsPossible()
waitForAllToFinish()

Optimization variables

To summarize the previous steps, your goal is to find the best combination of the following (a sketch of a simple parameter sweep follows the list):

  1. Number of token range splits
    1. Too low reduces concurrency
    2. Too high increases per-request overhead
  2. Maximal number of concurrent requests per shard
    1. Too low wastes shard’s capacity
    2. Too high increases overhead related to task switching
  3. Maximal number of concurrent requests per node
    1. Too low wastes node’s capacity
    2. Too high probably doesn’t hurt as long as individual shards are not overloaded
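
One simple way to explore these combinations is a brute-force parameter sweep. The sketch below is purely illustrative: runScan() is a placeholder for your own benchmark harness and is assumed to return the measured throughput in bytes per second.

int[] splitCounts      = {1_000, 5_000, 10_000, 20_000};
int[] maxShardLoads    = {1, 2, 4};
int[] maxConcurrencies = {10, 20, 42, 80};

long bestThroughput = 0;
for (int splits : splitCounts) {
  for (int shardLoad : maxShardLoads) {
    for (int concurrency : maxConcurrencies) {
      long bytesPerSecond = runScan(splits, shardLoad, concurrency); // placeholder benchmark run
      if (bytesPerSecond > bestThroughput) {
        bestThroughput = bytesPerSecond;
        System.out.printf("New best: splits=%d, maxShardLoad=%d, maxConcurrency=%d -> %d B/s%n",
            splits, shardLoad, concurrency, bytesPerSecond);
      }
    }
  }
}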

Asynchronous paging with prefetch

Use the Scylla driver’s asynchronous API with paging. Don’t block any threads. Consider fetching the next page before you start processing the current result set.

Pseudocode for asynchronous prefetching:

resultSetFuture.thenCompose(rs -> readResultSet(rs));

private CompletionStage<Integer> readResultSet(AsyncResultSet resultSet) {
  if (resultSet.hasMorePages()) {
    // Trigger fetching of the next page first so that Scylla is not idle
    // while we're processing the rows.
    final CompletionStage<AsyncResultSet> nextPage = resultSet.fetchNextPage();
    // Scylla is busy fetching the next page, so we can use the time to
    // run our logic on the current one. doWhateverWithTheRows is assumed
    // to return the number of rows it processed.
    final int rowsProcessed = doWhateverWithTheRows(resultSet.currentPage());
    // Continue with the next page and accumulate the row count.
    return nextPage.thenCompose(rs -> readResultSet(rs))
        .thenApply(count -> count + rowsProcessed);
  } else {
    // Process the last page.
    return CompletableFuture.completedFuture(doWhateverWithTheRows(resultSet.currentPage()));
  }
}

Know your limits

Take a look at io_properties.yaml to see the I/O limits of your Scylla cluster. You can’t go faster than that. Example:

$ cat /etc/scylla.d/io_properties.yaml

disks:
- mountpoint: /srv/data/disk2
  read_iops: 300000
  read_bandwidth: 3400000000
  write_iops: 150000
  write_bandwidth: 4500000000
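
For example, assuming the bandwidth values are in bytes per second (as produced by iotune), reads from this disk are capped at 300,000 IOPS and roughly 3.4 GB/s, so no scheduling trick will push a scan past that.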

Running performance tests

To get the most stable environment and simplify monitoring I did the following:

  • Force the Scylla driver to connect to only one specific Scylla node by using NodeFilterToDistanceEvaluatorAdapter.
  • Read only token ranges owned by that Scylla node (token ranges for which the node is the primary replica): session.getMetadata().getTokenMap().get().getTokenRanges(node)

This allows you to open the Scylla Monitoring Grafana dashboard and select only this node.
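
For illustration, here is a sketch of both points using the withNodeFilter shortcut, which the driver internally wraps in NodeFilterToDistanceEvaluatorAdapter. The address and datacenter name are placeholders:

InetSocketAddress targetNode = new InetSocketAddress("10.0.0.1", 9042); // placeholder

CqlSession session = CqlSession.builder()
    .addContactPoint(targetNode)
    .withLocalDatacenter("dc1") // placeholder
    // Ignore every node except the one we want to test against.
    .withNodeFilter(node -> targetNode.equals(node.getEndPoint().resolve()))
    .build();

// Find the Node object for the target and get the token ranges for which
// it is the primary replica.
Node node = session.getMetadata().getNodes().values().stream()
    .filter(n -> targetNode.equals(n.getEndPoint().resolve()))
    .findFirst()
    .orElseThrow();
Set<TokenRange> primaryRanges =
    session.getMetadata().getTokenMap().get().getTokenRanges(node);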

Monitoring

To measure performance, add a counter to your application and increase it by the byte size of every Row you read. Also keep an eye on metrics collected both on the application (client) side and on the Scylla (server) side.

Client side

On the application side, collect the BYTES_RECEIVED metric: https://github.com/scylladb/java-driver/blob/bd0659903a686825616deac985b12f1cdb29d180/core/src/main/java/com/datastax/oss/driver/api/core/metrics/DefaultSessionMetric.java#L25
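
Session-level metrics are disabled by default, so they have to be enabled explicitly. A sketch of doing it programmatically (the same can be done in the driver configuration file):

DriverConfigLoader configLoader = DriverConfigLoader.programmaticBuilder()
    .withStringList(
        DefaultDriverOption.METRICS_SESSION_ENABLED,
        List.of(DefaultSessionMetric.BYTES_RECEIVED.getPath()))
    .build();

CqlSession session = CqlSession.builder()
    .withConfigLoader(configLoader)
    .build();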

Then add your own custom metrics for monitoring: remaining token ranges, number of token ranges in progress, load per shard, pages read per second, the total byte-size counter value, request latency, …

Server side

Of the Scylla Monitoring dashboards, I found it most valuable to switch to shard-level detail and keep looking at:

  • “Load” panel on “Detailed” dashboard.
  • “query I/O Queue bandwidth by Shard” panel on “Advanced” dashboard.

Reference