Perform batch LLM inference using ai_query
Important
This feature is in Public Preview.
This article describes how to perform batch inference using the built-in Databricks SQL function ai_query with an endpoint that uses Foundation Model APIs provisioned throughput. The examples and guidance in this article are recommended for batch inference workloads that use large language models (LLMs) to process multiple inputs.
You can use ai_query with either SQL or PySpark to run batch inference workloads. To run batch inference on your data, specify the following in ai_query:
- The Unity Catalog input table and output table
- The provisioned throughput endpoint name
- The model prompt and any model parameters
See ai_query function for more detail about this AI function.
Requirements
- A workspace in a Foundation Model APIs supported region.
- One of the following:
  - All-purpose compute with compute size i3.2xlarge or larger, running Databricks Runtime 15.4 ML LTS or above, with at least two workers.
  - A SQL warehouse of size Medium or larger.
- An existing model serving endpoint. See Provisioned throughput Foundation Model APIs to create a provisioned throughput endpoint.
- Query permission on the Delta table in Unity Catalog that contains the data you want to use.
- To use ai_query(), set the pipelines.channel table property to 'preview'. See Examples for a sample query.
Use ai_query and SQL
The following is a batch inference example using ai_query and SQL. This example sets max_tokens and temperature through modelParameters, and shows how to concatenate the prompt for your model with the input column using concat(). There are multiple ways to perform concatenation, such as ||, concat(), or format_string().
CREATE OR REPLACE TABLE ${output_table_name} AS (
  SELECT
    ${input_column_name},
    AI_QUERY(
      "${endpoint}",
      CONCAT("${prompt}", ${input_column_name}),
      modelParameters => named_struct('max_tokens', ${num_output_tokens}, 'temperature', ${temperature})
    ) AS response
  FROM ${input_table_name}
  LIMIT ${input_num_rows}
)
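The ${...} placeholders in the SQL example above stand for values you supply, for example as notebook parameters. If you drive the statement from Python instead, one option is the standard library's string.Template, which uses the same ${name} syntax. The following is a minimal sketch, not part of the ai_query API: the function name render_batch_inference_sql and all example values (table, column, and endpoint names) are hypothetical, and it assumes the values are trusted and contain no double quotes, because they are inserted verbatim into the SQL text.

```python
from string import Template

# The SQL statement from the example above, with its ${...} placeholders intact.
BATCH_INFERENCE_SQL = Template("""\
CREATE OR REPLACE TABLE ${output_table_name} AS (
  SELECT
    ${input_column_name},
    AI_QUERY(
      "${endpoint}",
      CONCAT("${prompt}", ${input_column_name}),
      modelParameters => named_struct('max_tokens', ${num_output_tokens}, 'temperature', ${temperature})
    ) AS response
  FROM ${input_table_name}
  LIMIT ${input_num_rows}
)""")


def render_batch_inference_sql(**params) -> str:
    """Fill every placeholder (hypothetical helper).

    Template.substitute raises KeyError if any placeholder is left without
    a value, so an incomplete parameter set fails before reaching Spark.
    Values are inserted verbatim; no SQL escaping is performed.
    """
    return BATCH_INFERENCE_SQL.substitute(**params)


if __name__ == "__main__":
    # Hypothetical values for illustration only.
    sql = render_batch_inference_sql(
        output_table_name="main.default.reviews_summarized",
        input_table_name="main.default.reviews",
        input_column_name="review_text",
        endpoint="my-provisioned-throughput-endpoint",
        prompt="Summarize the following review: ",
        num_output_tokens=100,
        temperature=0.0,
        input_num_rows=1000,
    )
    print(sql)
    # In a Databricks notebook, the rendered statement could then be run with spark.sql(sql).
```

Using substitute rather than safe_substitute is a deliberate choice here: a missing parameter raises immediately instead of leaving a literal ${...} in the statement.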
Use ai_query and PySpark
If you prefer Python, you can also run batch inference with ai_query and PySpark, as shown in the following example:
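# Assumes df and the endpoint_name, prompt, input_column_name, num_output_tokens,
# temperature, output_column_name, and output_table_name variables are already defined.
df_out = df.selectExpr(f"ai_query('{endpoint_name}', CONCAT('{prompt}', {input_column_name}), modelParameters => named_struct('max_tokens', {num_output_tokens}, 'temperature', {temperature})) AS {output_column_name}")
df_out.write.mode("overwrite").saveAsTable(output_table_name)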
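Prompts often contain apostrophes, which would end the single-quoted SQL string literal early and break the expression passed to selectExpr. The following minimal sketch builds the expression in plain Python and escapes the prompt first. It assumes Spark SQL's default string-literal parsing, in which a backslash escapes the next character, and backtick quoting for column names; the helper names sql_string_literal, sql_identifier, and build_ai_query_expr, as well as the example values, are hypothetical.

```python
def sql_string_literal(value: str) -> str:
    """Quote a Python string as a single-quoted Spark SQL literal.

    Assumes default parsing where a backslash escapes the next character:
    backslashes are doubled first, then single quotes are escaped.
    """
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def sql_identifier(name: str) -> str:
    """Quote a column name with backticks, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def build_ai_query_expr(
    endpoint_name: str,
    prompt: str,
    input_column_name: str,
    output_column_name: str,
    max_tokens: int,
    temperature: float,
) -> str:
    """Return an ai_query expression string intended for DataFrame.selectExpr."""
    return (
        f"ai_query({sql_string_literal(endpoint_name)}, "
        f"CONCAT({sql_string_literal(prompt)}, {sql_identifier(input_column_name)}), "
        f"modelParameters => named_struct('max_tokens', {int(max_tokens)}, "
        f"'temperature', {float(temperature)})) AS {sql_identifier(output_column_name)}"
    )


if __name__ == "__main__":
    expr = build_ai_query_expr(
        endpoint_name="my-provisioned-throughput-endpoint",  # hypothetical endpoint
        prompt="Summarize this customer's review: ",  # apostrophe gets escaped
        input_column_name="review_text",
        output_column_name="response",
        max_tokens=100,
        temperature=0.0,
    )
    print(expr)
    # On Databricks, this could replace the inline f-string:
    # df.selectExpr(expr).write.mode("overwrite").saveAsTable(output_table_name)
```

Keeping the string building in a separate function also makes the generated expression easy to inspect or unit-test before any cluster time is spent on inference.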
Batch inference example notebook using Python
The example notebook creates a provisioned throughput endpoint and runs batch LLM inference using Python and the Meta Llama 3.1 70B model. It also has guidance on benchmarking your batch inference workload and creating a provisioned throughput model serving endpoint.