Author AI agents in code

This article shows how to author an AI agent in code using MLflow ChatModel. Azure Databricks uses MLflow ChatModel to ensure compatibility with Databricks AI agent features such as evaluation, tracing, and deployment.

What is ChatModel?

ChatModel is an MLflow class designed to simplify the creation of conversational AI agents. It provides a standardized interface for building models compatible with OpenAI’s ChatCompletion API.

ChatModel extends OpenAI’s ChatCompletion schema. This approach lets you maintain broad compatibility with platforms that support the ChatCompletion standard while adding your own custom functionality.

By using ChatModel, developers can create agents that are compatible with Databricks and MLflow tools for agent tracking, evaluation, and lifecycle management, which are essential for deploying production-ready models.

See MLflow: Getting Started with ChatModel.

Requirements

Databricks recommends installing the latest version of the MLflow Python client when developing agents.

To author and deploy agents using the approach in this article, you must meet the following requirements:

  • Install databricks-agents version 0.15.0 and above
  • Install mlflow version 2.20.0 and above
%pip install -U -qqqq databricks-agents>=0.15.0 mlflow>=2.20.0

Create a ChatModel agent

You can author your agent as a subclass of mlflow.pyfunc.ChatModel. This method provides the following benefits:

  • Allows you to write agent code compatible with the ChatCompletion schema using typed Python classes.
  • MLflow will automatically infer a chat completion-compatible signature when logging the agent, even without an input_example. This simplifies the process of registering and deploying the agent. See Infer Model Signature during logging.

The following code is best executed in a Databricks notebook. Notebooks provide a convenient environment for developing, testing, and iterating on your agent.

The MyAgent class extends mlflow.pyfunc.ChatModel, implementing the required predict method. This ensures compatibility with the Mosaic AI Agent Framework.

The class also includes the optional methods _create_chat_completion_chunk and predict_stream to handle streaming outputs.

from dataclasses import dataclass
from typing import Optional, Dict, List, Generator
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import (
    # Non-streaming helper classes
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionChunk,
    ChatMessage,
    ChatChoice,
    ChatParams,
    # Helper classes for streaming agent output
    ChatChoiceDelta,
    ChatChunkChoice,
)

class MyAgent(ChatModel):
    """
    Defines a custom agent that processes ChatCompletionRequests
    and returns ChatCompletionResponses.
    """
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatCompletionResponse:
        last_user_question_text = messages[-1].content
        response_message = ChatMessage(
            role="assistant",
            content=(
                f"I will always echo back your last question. Your last question was: {last_user_question_text}. "
            )
        )
        return ChatCompletionResponse(
            choices=[ChatChoice(message=response_message)]
        )

    def _create_chat_completion_chunk(self, content) -> ChatCompletionChunk:
        """Helper for constructing a ChatCompletionChunk instance for wrapping streaming agent output"""
        return ChatCompletionChunk(
                choices=[ChatChunkChoice(
                    delta=ChatChoiceDelta(
                        role="assistant",
                        content=content
                    )
                )]
            )

    def predict_stream(
        self, context, messages: List[ChatMessage], params: ChatParams
    ) -> Generator[ChatCompletionChunk, None, None]:
        last_user_question_text = messages[-1].content
        yield self._create_chat_completion_chunk(f"Echoing back your last question, word by word.")
        for word in last_user_question_text.split(" "):
            yield self._create_chat_completion_chunk(word)

agent = MyAgent()
model_input = ChatCompletionRequest(
    messages=[ChatMessage(role="user", content="What is Databricks?")]
)
response = agent.predict(context=None, messages=model_input.messages, params=ChatParams())
print(response)

Define the agent class MyAgent in one notebook, and create a separate driver notebook. The driver notebook logs the agent to Model Registry and deploys it using Model Serving.

This separation follows Databricks’ recommended workflow for logging models using MLflow’s Models from Code methodology.
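
The following is a minimal sketch of what such a driver notebook might contain. It assumes the agent code is saved in a file named agent.py that ends with a call to mlflow.models.set_model(MyAgent()), and it uses a placeholder Unity Catalog model name; replace these with values for your workspace.

import mlflow
from databricks import agents

# Log the agent using Models from Code by pointing python_model at the agent code file.
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",  # assumption: file defining MyAgent and calling mlflow.models.set_model(MyAgent())
    )

# Register the agent to Unity Catalog (placeholder three-level name).
mlflow.set_registry_uri("databricks-uc")
registered_model = mlflow.register_model(
    model_uri=logged_agent_info.model_uri,
    name="catalog.schema.my_agent",
)

# Deploy the registered agent to a Model Serving endpoint.
agents.deploy(model_name="catalog.schema.my_agent", model_version=registered_model.version)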

Example: Wrap LangChain in ChatModel

If you have an existing LangChain model and want to integrate it with other Mosaic AI agent features, you can wrap it in an MLflow ChatModel to ensure compatibility.

The following code sample wraps a LangChain runnable as a ChatModel in two steps:

  1. Wrap the final output of the LangChain chain with mlflow.langchain.output_parsers.ChatCompletionOutputParser to produce chat completion-compatible output.
  2. Define a LangchainAgent class that extends mlflow.pyfunc.ChatModel and implements two key methods:
    • predict: Handles synchronous predictions by invoking the chain and returning a formatted response.
    • predict_stream: Handles streaming predictions by invoking the chain and yielding chunks of responses.
from mlflow.langchain.output_parsers import ChatCompletionOutputParser
from mlflow.pyfunc import ChatModel
from typing import Optional, Dict, List, Generator
from mlflow.types.llm import (
    ChatCompletionResponse,
    ChatCompletionChunk,
    ChatMessage,
    ChatParams,
)

chain = (
    <your chain here>
    | ChatCompletionOutputParser()
)

class LangchainAgent(ChatModel):
    def _prepare_messages(self, messages: List[ChatMessage]):
        return {"messages": [m.to_dict() for m in messages]}

    def predict(
        self, context, messages: List[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        question = self._prepare_messages(messages)
        response_message = chain.invoke(question)
        return ChatCompletionResponse.from_dict(response_message)

    def predict_stream(
        self, context, messages: List[ChatMessage], params: ChatParams
    ) -> Generator[ChatCompletionChunk, None, None]:
        question = self._prepare_messages(messages)
        for chunk in chain.stream(question):
            yield ChatCompletionChunk.from_dict(chunk)
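
For illustration only, the <your chain here> placeholder could be filled with a chain like the following sketch. It assumes the databricks-langchain package and a placeholder Foundation Model APIs endpoint name; replace both with your own components.

from databricks_langchain import ChatDatabricks
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

# Hypothetical chain: extract the last user message, format a prompt, call an LLM endpoint,
# and parse the result into OpenAI chat completion format.
prompt = ChatPromptTemplate.from_template("Answer the user's question: {question}")
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")  # assumption: endpoint name

chain = (
    RunnableLambda(lambda request: {"question": request["messages"][-1]["content"]})
    | prompt
    | llm
    | ChatCompletionOutputParser()
)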

Use parameters to configure the agent

In the Agent Framework, you can use parameters to control how agents are executed. This allows you to quickly iterate by varying characteristics of your agent without changing the code. Parameters are key-value pairs that you define in a Python dictionary or a .yaml file.

To configure the code, create a ModelConfig, a set of key-value parameters. ModelConfig is either a Python dictionary or a .yaml file. For example, you can use a dictionary during development and then convert it to a .yaml file for production deployment and CI/CD. For details about ModelConfig, see the MLflow documentation.

An example ModelConfig is shown below.

llm_parameters:
  max_tokens: 500
  temperature: 0.01
model_serving_endpoint: databricks-dbrx-instruct
vector_search_index: ml.docs.databricks_docs_index
prompt_template: 'You are a hello world bot. Respond with a reply to the user''s
  question that indicates your prompt template came from a YAML file. Your response
  must use the word "YAML" somewhere. User''s question: {question}'
prompt_template_input_vars:
- question

To load the configuration from your code, use one of the following approaches:

# Example for loading from a .yml file
config_file = "configs/hello_world_config.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)

# Example of using a dictionary
config_dict = {
    "prompt_template": "You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
    "prompt_template_input_vars": ["question"],
    "model_serving_endpoint": "databricks-dbrx-instruct",
    "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
}

model_config = mlflow.models.ModelConfig(development_config=config_dict)

# Use model_config.get() to retrieve a parameter value
value = model_config.get("prompt_template")
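
The following is a sketch of how an agent class might consume these parameters, reading the ModelConfig in its constructor and applying the prompt template in predict. The file path and key names match the example configuration above; the class itself is hypothetical.

import mlflow
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatCompletionResponse, ChatChoice, ChatMessage, ChatParams

class ConfigurableAgent(ChatModel):
    """Hypothetical agent whose behavior is driven by a ModelConfig."""
    def __init__(self):
        # Assumption: the .yaml configuration shown above is stored at this path
        self.config = mlflow.models.ModelConfig(development_config="configs/hello_world_config.yml")
        self.prompt_template = self.config.get("prompt_template")
        self.llm_parameters = self.config.get("llm_parameters")
        self.endpoint_name = self.config.get("model_serving_endpoint")

    def predict(self, context, messages, params: ChatParams) -> ChatCompletionResponse:
        question = messages[-1].content
        prompt = self.prompt_template.format(question=question)
        # In a real agent, send `prompt` to self.endpoint_name using self.llm_parameters
        return ChatCompletionResponse(
            choices=[ChatChoice(message=ChatMessage(role="assistant", content=prompt))]
        )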

Set retriever schema

AI agents often use retrievers, a type of agent tool that finds and returns relevant documents using a Vector Search index. For more information on retrievers, see Unstructured retrieval AI agent tools.

To ensure that retrievers are traced properly, call mlflow.models.set_retriever_schema when you define your agent in code. Use set_retriever_schema to map the column names in the returned table to MLflow’s expected fields such as primary_key, text_column, and doc_uri.

# Define the retriever's schema by providing your column names
# These strings should be read from a config dictionary
mlflow.models.set_retriever_schema(
    name="vector_search",
    primary_key="chunk_id",
    text_column="text_column",
    doc_uri="doc_uri"
    # other_columns=["column1", "column2"],
)

Note

The doc_uri column is especially important when evaluating the retriever’s performance. doc_uri is the main identifier for documents returned by the retriever, allowing you to compare them against ground truth evaluation sets. See Evaluation sets.

You can also specify additional columns in your retriever’s schema by providing a list of column names with the other_columns field.

If you have multiple retrievers, you can define multiple schemas by using unique names for each retriever schema.
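
For example, an agent with two retrievers could register one schema per retriever under distinct names. The index and column names below are placeholders:

# Hypothetical example: one schema per retriever, distinguished by a unique name
mlflow.models.set_retriever_schema(
    name="docs_vector_search",
    primary_key="chunk_id",
    text_column="chunk_text",
    doc_uri="doc_uri",
)

mlflow.models.set_retriever_schema(
    name="support_tickets_vector_search",
    primary_key="ticket_id",
    text_column="ticket_text",
    doc_uri="ticket_url",
)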

Custom inputs and outputs

Some scenarios might require additional agent inputs, such as client_type and session_id, or outputs, such as retrieval source links, that should not be included in the chat history for future interactions.

For these scenarios, MLflow ChatModel natively supports augmenting OpenAI chat completion requests and responses with the custom_inputs and custom_outputs fields.

See the following examples to learn how to create custom inputs and outputs for PyFunc and LangGraph agents.
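
As a minimal sketch of this pattern (separate from the notebook examples below), an agent might read custom_inputs from ChatParams and attach custom_outputs to its response as follows. It assumes the custom_inputs and custom_outputs fields available in recent MLflow versions, and the field names client_type and source_links are hypothetical:

from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatCompletionResponse, ChatChoice, ChatMessage, ChatParams

class CustomIOAgent(ChatModel):
    def predict(self, context, messages, params: ChatParams) -> ChatCompletionResponse:
        # Read extra inputs supplied by the caller (hypothetical field name)
        custom_inputs = params.custom_inputs or {}
        client_type = custom_inputs.get("client_type", "unknown")

        response_message = ChatMessage(
            role="assistant",
            content=f"Handled a request from client type: {client_type}",
        )
        return ChatCompletionResponse(
            choices=[ChatChoice(message=response_message)],
            # Extra outputs returned to the caller but not added to the chat history
            custom_outputs={"source_links": ["https://example.com/source-doc"]},
        )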

Warning

The Agent Evaluation review app does not currently support rendering traces for agents with additional input fields.

PyFunc custom schemas

The following notebooks show a custom schema example using PyFunc.

PyFunc custom schema agent notebook

Get notebook

PyFunc custom schema driver notebook

Get notebook

LangGraph custom schemas

The following notebooks show a custom schema example using LangGraph. You can modify the wrap_output function in the notebooks to parse and extract information from the message stream.

LangGraph custom schema agent notebook

Get notebook

LangGraph custom schema driver notebook

Get notebook

Provide custom_inputs in the AI Playground and agent review app

If your agent accepts additional inputs using the custom_inputs field, you can manually provide these inputs in both the AI Playground and the agent review app.

  1. In either the AI Playground or the agent review app, select the gear icon.

  2. Enable custom_inputs.

  3. Provide a JSON object that matches your agent’s defined input schema, as shown in the example after these steps.

    Provide custom_inputs in the AI playground.
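
For example, assuming a hypothetical agent that accepts client_type and session_id as custom inputs, the JSON object might look like the following:

{
  "client_type": "web",
  "session_id": "1234-abcd"
}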

Example notebooks

The following notebooks create a simple “Hello, world” chain to illustrate building an agent in Databricks. The first example creates a simple chain, and the second example shows how to use parameters to minimize code changes during development.

Simple chain notebook

Get notebook

Simple chain driver notebook

Get notebook

Parameterized chain notebook

Get notebook

Parameterized chain driver notebook

Get notebook

Next steps