Dask Offloading

Testing with a LocalCluster

Once you’ve written your function and are ready to move it over to Dask, start by testing it on a LocalCluster. Working with Dask workers introduces another layer of complexity where things can go wrong, which makes a LocalCluster the easiest way to prepare your code for offloading. Your code will execute in the notebook session, just like calling your function directly, so you can view print statements and debug errors normally rather than dealing with remote code execution before you’re ready. Once you’re satisfied with your code, you can switch over to a SLURMCluster to accelerate with a GPU.

from distributed import Client, LocalCluster

cluster = LocalCluster()
client = Client(cluster)

We can submit our function to the cluster with the client.submit method. This returns a future, and its result can be retrieved with future.result(). While we’re using a LocalCluster we can see the output of print statements; they will not be visible when executing remotely with a SLURMCluster. Similarly, the full stack trace is still visible when an error or assertion is raised within the function.

def client_test(input1, input2, error=False, test=False):
    # Force an error
    if error:
        assert 0 == 1
    
    # Print a message when testing
    if test: 
        print("When running in a local cluster you can see print statements!")

    return input1, input2
future = client.submit(client_test, "input1", "input2", test=True)
future.result()
When running in a local cluster you can see print statements!
('input1', 'input2')
future = client.submit(client_test, "input1", "input2", error=True)
future.result()
2024-05-31 06:32:22,017 - distributed.worker - WARNING - Compute Failed
Key:       client_test-821dd3f7995546bf5d9280be38a9afd3
Function:  client_test
args:      ('input1', 'input2')
kwargs:    {'error': True}
Exception: 'AssertionError()'
AssertionError: 
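If you’d rather inspect a failure than have result() raise it, the future itself exposes the error once the task has run. This uses the standard Future.status and Future.exception() accessors from distributed; the commented outputs below are illustrative.

future = client.submit(client_test, "input1", "input2", error=True)
future.status       # 'error' once the task has failed
future.exception()  # returns the remote exception, e.g. AssertionError()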
client.shutdown()

Running on a SLURMCluster

We can pass extra SLURM requirements via job_extra_directives to request a GPU for our jobs. To read more about configuring the SLURMCluster to interact with the SLURM queue, see Dask’s jobqueue documentation.

from dask_jobqueue import SLURMCluster
from distributed import Client
cluster = SLURMCluster(
    memory="128g",
    processes=1,
    cores=16,
    job_extra_directives=["--gres=gpu:1", "--partition=BigCats"],
)

cluster.scale(1)
client = Client(cluster)
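
As a quick sanity check that the job really was allocated a GPU, we can submit a small function that lists the devices visible to the worker. This sketch assumes nvidia-smi is available on the worker node, and list_gpus is just an illustrative helper.

import subprocess

def list_gpus():
    # Runs on the worker, so it reports the GPUs SLURM allocated to our job
    return subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True).stdout

client.submit(list_gpus).result()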

Since this code is executing remotely, we won’t see our print statements.

client.submit(client_test, "input1", "input2", test=True).result()
('input1', 'input2')

Dask will raise any errors that the function triggers locally, even when it is executing remotely, but you may not get the full stack trace.

client.submit(client_test, "input1", "input2", error=True).result()
AssertionError: 
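
If you do need the full details of a remote failure, hold on to the future and pull them back explicitly; Future.exception() and Future.traceback() are the standard accessors for this.

import traceback

future = client.submit(client_test, "input1", "input2", error=True)
future.exception()                                        # the remote exception object
print("".join(traceback.format_tb(future.traceback())))   # formatted remote traceback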

If you’re working with objects that are particularly memory intensive, consider using the client.scatter method to distribute large objects to the workers ahead of time for more efficient execution.

large_object = "Let's pretend that this string is actually a really big object like your dataset"
input1_future = client.scatter(large_object)
client.submit(client_test, input1_future, "input2").result()
("Let's pretend that this string is actually a really big object like your dataset",
 'input2')
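
If every worker will need the same large object, client.scatter also accepts broadcast=True to copy it to all workers up front rather than to a single one:

input1_future = client.scatter(large_object, broadcast=True)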
client.shutdown()
2024-05-31 06:32:32,584 - distributed.scheduler - ERROR - Removing worker 'tcp://192.168.0.208:38481' caused the cluster to lose scattered data, which can't be recovered: {'str-aed0f69a5b2b8dbc59a28f905628b181'} (stimulus_id='handle-worker-cleanup-1717137152.584407')
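
The scheduler warning on shutdown is expected here: closing the cluster removes the worker holding the scattered object, so Dask reports that data as lost. Since we no longer need it, the warning can be ignored.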

If we have more complex requirements, we can be more specific about the GPU type and QoS we need.

cluster = SLURMCluster(
    memory="128g",
    processes=1,
    cores=16,
    job_extra_directives=["--gres=gpu:3g.20gb:1", "--qos=lion", "--partition=BigCats"],
)

cluster.scale(1)
client = Client(cluster)
client.submit(client_test, "input1", "input2", test=True).result()
('input1', 'input2')
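
As before, shut the cluster down once you’re finished with it.

client.shutdown()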