Source code for dea_tools.dask
# dask.py
"""
Tools for simplifying the creation of Dask clusters for parallelised computing.
License: The code in this notebook is licensed under the Apache License,
Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0). Digital Earth
Australia data is licensed under the Creative Commons by Attribution 4.0
license (https://creativecommons.org/licenses/by/4.0/).
Contact: If you need assistance, please post a question on the Open Data
Cube Discord chat (https://discord.com/invite/4hhBQVas5U) or on the GIS Stack
Exchange (https://gis.stackexchange.com/questions/ask?tags=open-data-cube)
using the `open-data-cube` tag (you can view previously asked questions
here: https://gis.stackexchange.com/questions/tagged/open-data-cube).
If you would like to report an issue with this script, you can file one on
GitHub (https://github.com/GeoscienceAustralia/dea-notebooks/issues/new).
Last modified: July 2025
"""
import os
from importlib.util import find_spec
import dask
import dask.distributed
from aiohttp import ClientConnectionError
from odc.io.cgroups import get_cpu_quota
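
# True when `jupyter-server-proxy` is importable (e.g. on the DEA Sandbox);
# used below to route the Dask dashboard link through the JupyterHub proxy.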
_HAVE_PROXY = bool(find_spec("jupyter_server_proxy"))
def create_local_dask_cluster(
display_client=True,
return_client=False,
configure_rio=True,
n_workers=1,
threads_per_worker=None,
memory_limit="spare_mem",
**kwargs,
):
"""
Create a local Dask cluster for parallelised computing using ``dask.distributed.Client``.
Example use:
from dea_dask import create_local_dask_cluster
create_local_dask_cluster()
Parameters
----------
display_client : bool, optional
An optional boolean indicating whether to display a summary of
the dask client, including a link to monitor progress of the
analysis. Set to False to hide this display.
return_client : bool, optional
An optional boolean indicating whether to return the dask client
object.
configure_rio : bool, optional
An optional boolean indicating whether to configure ``rasterio``
with cloud defaults and unsigned AWS access. Set to False to not
apply these defaults.
    n_workers : int, optional
        Number of workers to start. Defaults to 1, which works well
        for loading ODC data.
    threads_per_worker : int, optional
        Number of threads per worker. By default, this is set to the
        number of CPUs available on the machine.
    memory_limit : str, float, int, or None, optional
        Sets the memory limit per worker. Defaults to 'spare_mem',
        which splits 95% of the available system memory evenly among
        the workers, withholding the remaining 5% from the cluster.
        For other accepted values, see:
        https://distributed.dask.org/en/stable/api.html#distributed.Client
**kwargs:
Additional keyword arguments passed to ``dask.distributed.Client``.
For full options, see: https://distributed.dask.org/en/stable/api.html#distributed.Client
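
    Returns
    -------
    dask.distributed.Client or None
        The Dask client object if ``return_client=True``; otherwise None.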
"""
# Ensure that client links correctly launch on DEA Sandbox
if _HAVE_PROXY:
# Configure dashboard link to go over proxy
prefix = os.environ.get("JUPYTERHUB_SERVICE_PREFIX", "/")
dask.config.set({"distributed.dashboard.link": prefix + "proxy/{port}/status"})
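        # e.g. the link becomes "/user/<name>/proxy/8787/status" when running
        # under JupyterHub (the Dask dashboard serves on port 8787 by default)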
    # Count CPUs if threads_per_worker is not provided
    if threads_per_worker is None:
        cpu_quota = get_cpu_quota()
        threads_per_worker = round(cpu_quota) if cpu_quota is not None else os.cpu_count()
    # By default, split 95% of total system memory evenly among the workers.
    # Dask interprets a float `memory_limit` as a fraction of total system
    # memory per worker.
    if memory_limit == "spare_mem":
        memory_limit = 0.95 / n_workers
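    # For example, with n_workers=2 each worker receives a limit of 0.475
    # (47.5% of system memory), so the cluster as a whole uses at most 95%.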
# Start client
client = dask.distributed.Client(
n_workers=n_workers,
threads_per_worker=int(threads_per_worker),
memory_limit=memory_limit,
**kwargs,
)
    # Configure AWS and GDAL/rasterio access. Preferentially use datacube's
    # `configure_s3_access` function if datacube is installed, as it will
    # choose the correct settings automatically. If datacube is not
    # installed, fall back to the equivalent function from `odc.stac`.
if configure_rio:
try:
from datacube.utils.aws import configure_s3_access
except ImportError:
from odc.stac import configure_s3_access
configure_s3_access(cloud_defaults=True, aws_unsigned=True, client=client)
# Show the dask cluster settings
if display_client:
try:
from IPython.display import display # Check if IPython is available
display(client)
        except ImportError as e:
            raise ImportError(
                "IPython is not installed, but display_client=True was requested. "
                "Either set display_client=False, or install the required Jupyter "
                "dependencies via: pip install dea-tools[jupyter]"
            ) from e
# Return the client as an object
if return_client:
return client
    # Otherwise, return None
return None
def create_dask_gateway_cluster(profile="r5_L", workers=2):
"""
Create a cluster in our internal dask cluster.
Parameters
----------
profile : str
Possible values are:
- r5_L (2 cores, 15GB memory)
- r5_XL (4 cores, 31GB memory)
- r5_2XL (8 cores, 63GB memory)
- r5_4XL (16 cores, 127GB memory)
workers : int
Number of workers in the cluster.
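
    Returns
    -------
    dask_gateway.GatewayCluster
        The newly created cluster, scaled to ``workers`` workers.

    Example use::

        cluster = create_dask_gateway_cluster(profile="r5_XL", workers=4)
        client = cluster.get_client()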
"""
# Attempt to import dask_gateway and raise an error if not available
try:
from dask_gateway import Gateway
except ImportError as e:
raise ImportError(
"`dask_gateway` is required for `create_dask_gateway_cluster`. "
"Please install DEA Tools with the `[dask_gateway]` extra, e.g.: "
"`pip install dea-tools[dask_gateway]`"
) from e
try:
gateway = Gateway()
# Close any existing clusters
cluster_names = gateway.list_clusters()
        if len(cluster_names) > 0:
            print("Shutting down existing cluster(s):", cluster_names)
for n in cluster_names:
cluster = gateway.connect(n.name)
cluster.shutdown()
options = gateway.cluster_options()
options["profile"] = profile
        # Limit username to alphanumeric characters, as Kubernetes pods won't
        # launch if labels contain characters other than alphanumerics,
        # '-', '_' or '.'. Fall back to an empty string if JUPYTERHUB_USER
        # is unset.
        username = os.getenv("JUPYTERHUB_USER", "")
        options["jupyterhub_user"] = "".join(c if c.isalnum() else "-" for c in username)
cluster = gateway.new_cluster(options)
cluster.scale(workers)
return cluster
    except ClientConnectionError as e:
        raise ConnectionError("Access to the Dask Gateway cluster is unauthorized") from e
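

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the DEA Tools
    # API): start a local cluster, print its dashboard link, then close it.
    client = create_local_dask_cluster(display_client=False, return_client=True)
    print("Dask dashboard:", client.dashboard_link)
    client.close()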