Perform batch LLM inference using AI Functions

Important

This feature is in Public Preview.

This article describes how to perform batch inference using AI Functions.

Requirements

  • A workspace in a Foundation Model APIs supported region.
  • Databricks Runtime 15.4 or above is recommended.
  • Query permission on the Delta table in Unity Catalog that contains the data you want to use.
  • To use ai_query() in Delta Live Tables, set pipelines.channel to ‘preview’ in the table properties. See Requirements for an example query.
  • For batch inference workloads using AI Functions, Databricks recommends Databricks Runtime 15.4 ML LTS for improved performance.

Batch LLM inference using task-specific AI Functions

You can run batch inference using task-specific AI functions. See Deploy batch inference pipelines for guidance on how to incorporate your task-specific AI function into a pipeline.

The following is an example of using the task-specific AI function, ai_translate:

SELECT
  writer_summary,
  ai_translate(writer_summary, "cn") AS cn_translation
FROM user.batch.news_summaries
LIMIT 500;

Batch LLM inference using ai_query

You can use the general-purpose AI function ai_query to perform batch inference. See the model types that ai_query supports and their associated models.

The examples in this section focus on the flexibility of ai_query and how to use it in batch inference pipelines and workflows.

ai_query and Databricks-hosted foundation models

When you use a Databricks-hosted and pre-provisioned foundation model for batch inference, Databricks configures a provisioned throughput endpoint on your behalf that scales automatically based on the workload.

To use this method for batch inference, specify the following in your request:

  • The pre-provisioned LLM you want to use in ai_query. Select from supported pre-provisioned LLMs.
  • The Unity Catalog input table and output table.
  • The model prompt and any model parameters.
SELECT text, ai_query(
    "databricks-meta-llama-3-1-8b-instruct",
    "Summarize the given text comprehensively, covering key points and main ideas concisely while retaining relevant details and examples. Ensure clarity and accuracy without unnecessary repetition or omissions: " || text
) AS summary
FROM uc_catalog.schema.table;
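The list above also calls for an output table, which this query does not create. As a minimal sketch, assuming you generate the statement from Python, the hypothetical helper `build_batch_inference_sql` below wraps the same ai_query call in a `CREATE OR REPLACE TABLE ... AS SELECT` statement that writes to an output table. The helper names, the `summary` column, and the output table name in the usage line are illustrative assumptions. String values are escaped with backslashes, assuming default Databricks SQL string-literal rules; table and column names are inserted as-is and must come from a trusted source.

```python
def sql_string_literal(value: str) -> str:
    """Quote a Python string as a single-quoted SQL literal using backslash escapes."""
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def build_batch_inference_sql(
    model: str,
    prompt: str,
    input_table: str,
    output_table: str,
    text_column: str = "text",
) -> str:
    """Build a CTAS statement that stores ai_query results in an output table.

    Hypothetical helper: table and column names are inserted verbatim.
    """
    return (
        f"CREATE OR REPLACE TABLE {output_table} AS\n"
        f"SELECT {text_column}, ai_query(\n"
        f"  {sql_string_literal(model)},\n"
        f"  {sql_string_literal(prompt)} || {text_column}\n"
        f") AS summary\n"
        f"FROM {input_table};"
    )


# Usage with hypothetical table names; replace them with your own Unity Catalog tables.
print(build_batch_inference_sql(
    "databricks-meta-llama-3-1-8b-instruct",
    "Summarize the given text: ",
    "uc_catalog.schema.table",
    "uc_catalog.schema.table_summaries",
))
```

You can then run the generated statement with `spark.sql(...)` in a notebook or paste it into the SQL editor.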

ai_query and custom or fine-tuned foundation models

The notebook examples in this section demonstrate batch inference workloads that use custom or fine-tuned foundation models to process multiple inputs. The examples require an existing model serving endpoint that uses Foundation Model APIs provisioned throughput.

LLM batch inference using a custom foundation model

The following example notebook creates a provisioned throughput endpoint and runs batch LLM inference using Python and the Meta Llama 3.1 70B model. It also has guidance on benchmarking your batch inference workload and creating a provisioned throughput model serving endpoint.

LLM batch inference with a custom foundation model and a provisioned throughput endpoint notebook


LLM batch inference using an embeddings model

The following example notebook creates a provisioned throughput endpoint and runs batch LLM inference using Python and your choice of either the GTE Large (English) or BGE Large (English) embeddings model.

LLM batch inference embeddings with a provisioned throughput endpoint notebook


Batch inference and structured data extraction

The following example notebook demonstrates how to use ai_query for basic structured data extraction, transforming raw, unstructured data into organized, usable information. The notebook also shows how to use Mosaic AI Agent Evaluation to measure extraction accuracy against ground truth data.

Batch inference and structured data extraction notebook

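For intuition about what "accuracy against ground truth" can mean for extraction, here is a much simpler stand-in than Mosaic AI Agent Evaluation (it does not reproduce that tool's metrics): per-field exact-match accuracy, the fraction of ground-truth records whose extracted value for a field matches exactly. The function `field_accuracy` and the record format are hypothetical.

```python
def field_accuracy(predictions: list, ground_truth: list) -> dict:
    """Per-field exact-match accuracy between extracted records and ground truth."""
    if len(predictions) != len(ground_truth):
        raise ValueError("predictions and ground_truth must have the same length")
    matches, totals = {}, {}
    for predicted, expected in zip(predictions, ground_truth):
        # Score only the fields that the ground-truth record defines.
        for field, expected_value in expected.items():
            totals[field] = totals.get(field, 0) + 1
            if predicted.get(field) == expected_value:
                matches[field] = matches.get(field, 0) + 1
    return {field: matches.get(field, 0) / totals[field] for field in totals}


preds = [{"title": "A", "category": "Sports"}, {"title": "B", "category": "Health"}]
truth = [{"title": "A", "category": "Sports"}, {"title": "C", "category": "Health"}]
print(field_accuracy(preds, truth))  # → {'title': 0.5, 'category': 1.0}
```

Exact matching is strict for free-text fields such as titles; the notebook's evaluation approach is better suited to those.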

Batch inference using BERT for named entity recognition

The following notebook shows batch inference with a traditional ML model, using BERT for named entity recognition.

Batch inference using BERT for named entity recognition notebook


Deploy batch inference pipelines

This section shows how you can integrate AI Functions into other Databricks data and AI products to build complete batch inference pipelines. These pipelines can perform end-to-end workflows that include ingestion, preprocessing, inference, and post-processing. Pipelines can be authored in SQL or Python and deployed as:

  • Delta Live Tables pipelines
  • Scheduled workflows using Databricks workflows
  • Streaming inference workflows using Structured Streaming

Perform incremental batch inference on Delta Live Tables

The following example uses Delta Live Tables to perform incremental batch inference on data that is continuously updated.

Step 1: Ingest raw news data from a volume

SQL

CREATE OR REFRESH STREAMING TABLE news_raw
COMMENT "Raw news articles ingested from volume."
AS SELECT *
FROM STREAM(read_files(
  '/Volumes/databricks_news_summarization_benchmarking_data/v01/csv',
  format => 'csv',
  header => true,
  mode => 'PERMISSIVE',
  multiLine => 'true'
));
Python

Import the packages and define the JSON schema of the LLM response as a Python variable.


import dlt
from pyspark.sql.functions import expr, get_json_object, concat

news_extraction_schema = (
    '{"type": "json_schema", "json_schema": {"name": "news_extraction", '
    '"schema": {"type": "object", "properties": {"title": {"type": "string"}, '
    '"category": {"type": "string", "enum": ["Politics", "Sports", "Technology", '
    '"Health", "Entertainment", "Business"]}}}, "strict": true}}'
)
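Hand-concatenating JSON fragments, as in the schema above, is easy to get wrong. As an alternative sketch (not the approach the original example uses), you can build an equivalent responseFormat value from a Python dictionary with `json.dumps`, which always produces valid JSON. The helper `build_response_format` and the `CATEGORIES` constant are hypothetical. Because the string is later embedded between single quotes in the ai_query SQL expression, the sketch also rejects any value containing a single quote.

```python
import json

CATEGORIES = ["Politics", "Sports", "Technology", "Health", "Entertainment", "Business"]


def build_response_format(name: str, categories: list) -> str:
    """Return a JSON-schema responseFormat string for ai_query (hypothetical helper)."""
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": name,
            "schema": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "category": {"type": "string", "enum": categories},
                },
            },
            "strict": True,
        },
    }
    text = json.dumps(response_format)
    # The value is placed inside a single-quoted SQL literal, so it must not contain one.
    if "'" in text:
        raise ValueError("responseFormat must not contain single quotes")
    return text


news_extraction_schema = build_response_format("news_extraction", CATEGORIES)
```

The resulting string has the same structure as the hand-written one, so the f-string in Step 2 can use it unchanged.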

Ingest your data from a Unity Catalog volume.

@dlt.table(
  comment="Raw news articles ingested from volume."
)
def news_raw():
  return (
    spark.readStream
      .format("cloudFiles")
      .option("cloudFiles.format", "csv")
      .option("header", True)
      .option("mode", "PERMISSIVE")
      .option("multiLine", "true")
      .load("/Volumes/databricks_news_summarization_benchmarking_data/v01/csv")
  )

Step 2: Apply LLM inference to extract title and category

SQL

CREATE OR REFRESH MATERIALIZED VIEW news_categorized
COMMENT "Extract category and title from news articles using LLM inference."
AS
SELECT
  inputs,
  ai_query(
    "databricks-meta-llama-3-3-70b-instruct",
    "Extract the category of the following news article: " || inputs,
    responseFormat => '{
      "type": "json_schema",
      "json_schema": {
        "name": "news_extraction",
        "schema": {
          "type": "object",
          "properties": {
            "title": { "type": "string" },
            "category": {
              "type": "string",
              "enum": ["Politics", "Sports", "Technology", "Health", "Entertainment", "Business"]
            }
          }
        },
        "strict": true
      }
    }'
  ) AS meta_data
FROM news_raw
LIMIT 2;
Python
@dlt.table(
  comment="Extract category and title from news articles using LLM inference."
)
def news_categorized():
  # Limit the number of rows to 2 as in the SQL version
  df_raw = spark.read.table("news_raw").limit(2)
  # Inject the JSON schema variable into the ai_query call using an f-string.
  return df_raw.withColumn(
    "meta_data",
    expr(
      f"ai_query('databricks-meta-llama-3-3-70b-instruct', "
      f"concat('Extract the category of the following news article: ', inputs), "
      f"responseFormat => '{news_extraction_schema}')"
    )
  )

Step 3: Validate the LLM inference output before summarization

SQL
CREATE OR REFRESH MATERIALIZED VIEW news_validated (
  CONSTRAINT valid_title EXPECT (size(split(get_json_object(meta_data, '$.title'), ' ')) >= 3),
  CONSTRAINT valid_category EXPECT (get_json_object(meta_data, '$.category') IN ('Politics', 'Sports', 'Technology', 'Health', 'Entertainment', 'Business'))
)
COMMENT "Validated news articles ensuring the title has at least 3 words and the category is valid."
AS
SELECT *
FROM news_categorized;
Python
@dlt.table(
  comment="Validated news articles ensuring the title has at least 3 words and the category is valid."
)
@dlt.expect("valid_title", "size(split(get_json_object(meta_data, '$.title'), ' ')) >= 3")
@dlt.expect("valid_category", "get_json_object(meta_data, '$.category') IN ('Politics', 'Sports', 'Technology', 'Health', 'Entertainment', 'Business')")
def news_validated():
  return spark.read.table("news_categorized")
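Before deploying the pipeline, it can help to check sample model responses against the same rules locally. The following minimal sketch approximates the two expectations in plain Python; the function `check_expectations` is hypothetical, `str.split(" ")` only approximates Spark's regex-based `split`, and the sketch treats a missing field or unparsable response as a violation.

```python
import json

VALID_CATEGORIES = {"Politics", "Sports", "Technology", "Health", "Entertainment", "Business"}


def check_expectations(meta_data: str) -> dict:
    """Evaluate the valid_title and valid_category rules against one ai_query response."""
    try:
        fields = json.loads(meta_data)
    except json.JSONDecodeError:
        fields = {}
    if not isinstance(fields, dict):
        fields = {}
    title = fields.get("title")
    category = fields.get("category")
    return {
        # Approximates size(split(title, ' ')) >= 3: at least three space-separated tokens.
        "valid_title": isinstance(title, str) and len(title.split(" ")) >= 3,
        # Mirrors category IN (...): an exact, case-sensitive match.
        "valid_category": isinstance(category, str) and category in VALID_CATEGORIES,
    }


print(check_expectations('{"title": "Local team wins final", "category": "Sports"}'))
# → {'valid_title': True, 'valid_category': True}
```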

Step 4: Summarize news articles from the validated data

SQL
CREATE OR REFRESH MATERIALIZED VIEW news_summarized
COMMENT "Summarized political news articles after validation."
AS
SELECT
  get_json_object(meta_data, '$.category') as category,
  get_json_object(meta_data, '$.title') as title,
  ai_query(
    "databricks-meta-llama-3-3-70b-instruct",
    "Summarize the following political news article in 2-3 sentences: " || inputs
  ) AS summary
FROM news_validated;
Python

@dlt.table(
  comment="Summarized political news articles after validation."
)
def news_summarized():
  df = spark.read.table("news_validated")
  return df.select(
    get_json_object("meta_data", "$.category").alias("category"),
    get_json_object("meta_data", "$.title").alias("title"),
    expr(
      "ai_query('databricks-meta-llama-3-3-70b-instruct', "
      "concat('Summarize the following political news article in 2-3 sentences: ', inputs))"
    ).alias("summary")
  )

Automate batch inference jobs using Databricks workflows

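Use Databricks workflows to schedule batch inference jobs and automate AI pipelines. For example, you can run the following opinion-mining query on product reviews as a scheduled job.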

SQL

SELECT
   *,
   ai_query('databricks-meta-llama-3-3-70b-instruct', request => concat("You are an opinion mining service. Given a piece of text, output an array of json results that extracts key user opinions, a classification, and a Positive, Negative, Neutral, or Mixed sentiment about that subject.


AVAILABLE CLASSIFICATIONS
Quality, Service, Design, Safety, Efficiency, Usability, Price


Examples below:


DOCUMENT
I got soup. It really did take only 20 minutes to make some pretty good soup. The noises it makes when it's blending are somewhat terrifying, but it gives a little beep to warn you before it does that. It made three or four large servings of soup. It's a single layer of steel, so the outside gets pretty hot. It can be hard to unplug the lid without knocking the blender against the side, which is not a nice sound. The soup was good and the recipes it comes with look delicious, but I'm not sure I'll use it often. 20 minutes of scary noises from the kitchen when I already need comfort food is not ideal for me. But if you aren't sensitive to loud sounds it does exactly what it says it does.


RESULT
[
 {'Classification': 'Efficiency', 'Comment': 'only 20 minutes','Sentiment': 'Positive'},
 {'Classification': 'Quality','Comment': 'pretty good soup','Sentiment': 'Positive'},
 {'Classification': 'Usability', 'Comment': 'noises it makes when it's blending are somewhat terrifying', 'Sentiment': 'Negative'},
 {'Classification': 'Safety','Comment': 'outside gets pretty hot','Sentiment': 'Negative'},
 {'Classification': 'Design','Comment': 'Hard to unplug the lid without knocking the blender against the side, which is not a nice sound', 'Sentiment': 'Negative'}
]


DOCUMENT
", REVIEW_TEXT, '\n\nRESULT\n')) as result
FROM catalog.schema.product_reviews
LIMIT 10

Python


import json
from pyspark.sql.functions import expr

# Define the opinion mining prompt as a multi-line string.
opinion_prompt = """You are an opinion mining service. Given a piece of text, output an array of json results that extracts key user opinions, a classification, and a Positive, Negative, Neutral, or Mixed sentiment about that subject.

AVAILABLE CLASSIFICATIONS
Quality, Service, Design, Safety, Efficiency, Usability, Price

Examples below:

DOCUMENT
I got soup. It really did take only 20 minutes to make some pretty good soup. The noises it makes when it's blending are somewhat terrifying, but it gives a little beep to warn you before it does that. It made three or four large servings of soup. It's a single layer of steel, so the outside gets pretty hot. It can be hard to unplug the lid without knocking the blender against the side, which is not a nice sound. The soup was good and the recipes it comes with look delicious, but I'm not sure I'll use it often. 20 minutes of scary noises from the kitchen when I already need comfort food is not ideal for me. But if you aren't sensitive to loud sounds it does exactly what it says it does.

RESULT
[
 {'Classification': 'Efficiency', 'Comment': 'only 20 minutes','Sentiment': 'Positive'},
 {'Classification': 'Quality','Comment': 'pretty good soup','Sentiment': 'Positive'},
 {'Classification': 'Usability', 'Comment': 'noises it makes when it's blending are somewhat terrifying', 'Sentiment': 'Negative'},
 {'Classification': 'Safety','Comment': 'outside gets pretty hot','Sentiment': 'Negative'},
 {'Classification': 'Design','Comment': 'Hard to unplug the lid without knocking the blender against the side, which is not a nice sound', 'Sentiment': 'Negative'}
]

DOCUMENT
"""

# Escape the prompt so it can be safely embedded in the SQL expression.
escaped_prompt = json.dumps(opinion_prompt)

# Read the source table and limit to 10 rows.
df = spark.table("catalog.schema.product_reviews").limit(10)

# Apply the LLM inference to each row, concatenating the prompt, the review text, and the tail string.
result_df = df.withColumn(
    "result",
    expr(f"ai_query('databricks-meta-llama-3-3-70b-instruct', request => concat({escaped_prompt}, REVIEW_TEXT, '\\n\\nRESULT\\n'))")
)

# Display the result DataFrame.
display(result_df)
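The prompt asks for JSON, but its few-shot example uses Python-style single-quoted entries, so responses may not be strict JSON. As a post-processing sketch (an assumption about how you might handle the result column, not part of the original example), the hypothetical `parse_opinions` function tries `json.loads`, falls back to `ast.literal_eval`, and returns `None` when neither succeeds. An entry containing an unescaped apostrophe, such as the `it's` in the example, defeats both parsers, which is why the function returns `None` instead of raising.

```python
import ast
import json
from typing import Optional


def parse_opinions(raw: str) -> Optional[list]:
    """Parse an opinion-mining response into a list of dicts, or return None."""
    text = raw.strip()
    # Keep only the outermost [...] span in case the model adds surrounding prose.
    start, end = text.find("["), text.rfind("]")
    if start == -1 or end < start:
        return None
    candidate = text[start:end + 1]
    # Try strict JSON first, then Python literal syntax (single-quoted strings).
    for parser in (json.loads, ast.literal_eval):
        try:
            parsed = parser(candidate)
        except (ValueError, SyntaxError, TypeError):
            continue
        if isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed):
            return parsed
    return None
```

You could apply this to rows collected from `result_df`, or wrap it in a UDF to parse results at scale.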

AI Functions using Structured Streaming

Apply AI inference in near real-time or micro-batch scenarios using ai_query and Structured Streaming.

Step 1: Read your static Delta table

Read your static Delta table as if it were a stream.


from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

# Rewrite the table into 50 files so that maxBytesPerTrigger can spread the existing rows across several micro-batches.
df = spark.table("enterprise.docs")  # Replace with your table name containing enterprise documents
df.repartition(50).write.format("delta").mode("overwrite").saveAsTable("enterprise.docs")
# The stream reads every existing row exactly once; with maxBytesPerTrigger set, that initial snapshot can span multiple micro-batches.
df_stream = spark.readStream.format("delta").option("maxBytesPerTrigger", "50K").table("enterprise.docs")

# Define the prompt outside the SQL expression.
prompt = (
    "You are provided with an enterprise document. Summarize the key points in a concise paragraph. "
    "Do not include extra commentary or suggestions. Document: "
)
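One plausible reason for rewriting the table into 50 files before setting maxBytesPerTrigger is that a file is the smallest unit a micro-batch can take. The following is a simplified model of that grouping, not Delta Lake's exact admission logic: files are taken in order, a new micro-batch starts when the next file would push a non-empty batch past the cap, and every micro-batch takes at least one file. The function `plan_micro_batches` is hypothetical.

```python
def plan_micro_batches(file_sizes: list, max_bytes: int) -> list:
    """Group file sizes (bytes) into micro-batches under a soft byte cap (simplified model)."""
    batches, current, current_bytes = [], [], 0
    for size in file_sizes:
        # Start a new batch if this file would push a non-empty batch past the cap.
        if current and current_bytes + size > max_bytes:
            batches.append(current)
            current, current_bytes = [], 0
        # A batch always takes at least one file, even one larger than the cap.
        current.append(size)
        current_bytes += size
    if current:
        batches.append(current)
    return batches


# One 1 MB file: a single micro-batch, despite the 50 KB cap.
print(len(plan_micro_batches([1_000_000], 50_000)))  # → 1
# The same data as 50 files of 20 KB each: 25 micro-batches of two files.
print(len(plan_micro_batches([20_000] * 50, 50_000)))  # → 25
```

Under this model, smaller files let the stream spread ai_query calls over many micro-batches instead of sending the whole table in one.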

Step 2: Apply ai_query

For static data, Spark applies ai_query to each existing row only once; after that, only newly added rows are processed.


df_transformed = df_stream.select(
    "document_text",
    F.expr(f"""
      ai_query(
        'databricks-meta-llama-3-1-8b-instruct',
        CONCAT('{prompt}', document_text)
      )
    """).alias("summary")
)

Step 3: Write the summarized output

Write the summarized output to another Delta table.


# Existing rows are processed only once (possibly across several micro-batches);
# later triggers process only rows that are newly appended to the source table.
query = df_transformed.writeStream \
    .format("delta") \
    .option("checkpointLocation", "/tmp/checkpoints/_docs_summary") \
    .outputMode("append") \
    .toTable("enterprise.docs_summary")

# Block until the streaming query is stopped or fails.
query.awaitTermination()