Chaining OpenAI Models: Better and Faster

This blog post explores combining AI models into an efficient, accurate, and cost-effective pipeline. We chain the response from a larger model (GPT-4o) into a smaller one (GPT-4o mini), which converts the verbose output into structured JSON. A Python script demonstrates the process step by step.


Overview

The goal is to fact-check a list of statements using a powerful model and then convert the verbose responses into a simple JSON format using a smaller, cheaper model.

The Need for Chaining Models

Large language models like GPT-4o offer advanced reasoning and comprehension abilities, making them well suited to tasks like fact-checking. However, they can be costly and slow for simple tasks like data formatting or conversion. By chaining the output of a large model into a smaller one, we can:

  • Optimize Costs: Use the expensive model only when necessary.
  • Improve Efficiency: Delegate simpler tasks to faster models.
  • Maintain Accuracy: Ensure complex reasoning tasks are handled by capable models.

While we use only OpenAI models in this post for simplicity, we especially recommend this approach for models that lack a strict JSON mode, such as Anthropic's Claude 3.5 Sonnet. (We also recommend using LiteLLM to expose Anthropic's API as an OpenAI-compatible endpoint.)

Step-by-Step Process

Let's break down the script and understand how each part contributes to the overall process.

1. Import Necessary Libraries

import asyncio
import logging
import httpx
from openai import AsyncOpenAI
import random
import os
import json
import numpy as np

2. Configure Logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
c_handler = logging.StreamHandler()
logger.addHandler(c_handler)

3. Retrieve API Credentials

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"

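We read the OpenAI API key and base URL from environment variables. OPENAI_API_KEY must be set; OPENAI_API_BASE is optional and falls back to the public OpenAI endpoint, so you only need it when routing requests through a proxy or another OpenAI-compatible server.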
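
The script hands the key straight to the client. If you prefer to validate configuration up front (for example, before pointing OPENAI_API_BASE at a LiteLLM proxy), a small helper can centralize the lookup and fallback. The name load_openai_config is hypothetical, and this sketch takes the environment as a mapping so it can be tested without touching real variables.

```python
import os
from typing import Mapping, Tuple

DEFAULT_BASE_URL = "https://api.openai.com/v1"


def load_openai_config(env: Mapping[str, str] = os.environ) -> Tuple[str, str]:
    """Return (api_key, base_url), failing fast if the key is missing.

    Hypothetical helper, not part of the openai package.
    """
    api_key = env.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set")
    # Same fallback as the script: an unset or empty base URL means the public API.
    base_url = env.get("OPENAI_API_BASE") or DEFAULT_BASE_URL
    return api_key, base_url


if __name__ == "__main__":
    # Illustrative placeholder key; never hard-code a real one.
    print(load_openai_config({"OPENAI_API_KEY": "sk-example"}))
```

The returned pair maps directly onto the api_key and base_url arguments of AsyncOpenAI.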

4. Initialize the OpenAI Client

client = AsyncOpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
    http_client=httpx.AsyncClient(
        http2=True,
        limits=httpx.Limits(
            max_connections=None,
            max_keepalive_connections=None,
            keepalive_expiry=None,
        ),
    ),
    max_retries=0,
    timeout=3600,
)

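We create an asynchronous OpenAI client backed by a custom httpx client: HTTP/2 is enabled, connection-pool limits and keep-alive expiry are removed, the SDK's automatic retries are disabled (max_retries=0), and the request timeout is raised to one hour. Note that http2=True requires the optional h2 package, which you can install with pip install "httpx[http2]".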

5. Define the Main Asynchronous Function

async def main():
    statements_to_check = [
        "SEZC means Special Economic Zone Company.",
        "SEZC means Special Economic Zone Corporation.",
    ]

The main function contains the core logic. We define a list of statements that we want to fact-check; the code in steps 6 through 12 runs inside this function, once per statement.

6. Fact-Checking with the Large Model

for statement_to_check in statements_to_check:
    logger.debug("--------")
    logger.debug(f"Checking statement: {statement_to_check}")

    response_text = await client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        max_tokens=8192,
        seed=random.randint(1, 0x7FFFFFFF),
        messages=[
            {
                "role": "system",
                "content": "Fact check the user's message.",
            },
            {
                "role": "user",
                "content": statement_to_check,
            },
        ],
        temperature=0.0,
    )

We iterate over each statement and ask gpt-4o-2024-08-06 to fact-check it. Setting temperature to 0.0 makes the output more repeatable, and the model returns a detailed answer about the statement's accuracy.

7. Extract and Log the Response

text_answer = response_text.choices[0].message.content
logger.debug(text_answer)

We extract the assistant's response from the API reply and log it for review.

Yes, SEZC stands for Special Economic Zone Company.

GPT-4o's response to the first statement.

No, SEZC typically stands for "Special Economic Zone Company," not "Special Economic Zone Corporation."

GPT-4o's response to the second statement.
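
In the openai Python SDK, message.content is typed as optional and can be None, for example when a reply consists of tool calls rather than text. The sketch below uses a hypothetical extract_text helper that fails loudly in that case instead of forwarding None to the smaller model; SimpleNamespace stands in for a real response object so the example runs offline.

```python
from types import SimpleNamespace
from typing import Any


def extract_text(response: Any) -> str:
    """Return the first choice's text, raising if the model returned none.

    Hypothetical helper; works with any object shaped like a chat completion.
    """
    content = response.choices[0].message.content
    if not content:
        raise ValueError("Model returned no text content")
    return content


# Offline stand-in shaped like a chat completion response (illustrative data).
fake = SimpleNamespace(
    choices=[
        SimpleNamespace(
            message=SimpleNamespace(
                content="Yes, SEZC stands for Special Economic Zone Company."
            )
        )
    ]
)

if __name__ == "__main__":
    print(extract_text(fake))
```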

8. Converting the Response to JSON with the Smaller Model

response_json = await client.chat.completions.create(
    response_format={"type": "json_object"},
    model="gpt-4o-mini-2024-07-18",
    max_tokens=20,
    logprobs=True,
    stop=[
        "rue",
        "alse",
        "}",
    ],
    seed=random.randint(1, 0x7FFFFFFF),
    messages=[
        {
            "role": "user",
            "content": f"Fact check this statement:\n\n```{statement_to_check}```",
        },
        {
            "role": "assistant",
            "content": text_answer,
        },
        {
            "role": "user",
            "content": 'Convert your answer into a JSON object that follows {"fact": boolean}',
        },
    ],
    temperature=0.0,
)

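We pass the large model's answer to a smaller model, gpt-4o-mini-2024-07-18, as if it were the smaller model's own previous assistant turn, and ask it to convert that answer into a JSON object. The smaller model is faster and cheaper, and reformatting an existing answer is well within its abilities.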
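
The three-message hand-off can be factored into a pure function, which makes it easy to inspect and test on its own. build_conversion_messages is a hypothetical name; the prompt strings are copied from the script.

```python
from typing import Dict, List


def build_conversion_messages(statement: str, answer: str) -> List[Dict[str, str]]:
    """Replay the large model's answer so the small model only has to reformat it."""
    return [
        {
            "role": "user",
            "content": f"Fact check this statement:\n\n```{statement}```",
        },
        # The large model's verbose answer is presented as the assistant's own prior turn.
        {"role": "assistant", "content": answer},
        # JSON mode requires the word "JSON" to appear somewhere in the messages.
        {
            "role": "user",
            "content": 'Convert your answer into a JSON object that follows {"fact": boolean}',
        },
    ]


if __name__ == "__main__":
    for message in build_conversion_messages(
        "SEZC means Special Economic Zone Company.",
        "Yes, SEZC stands for Special Economic Zone Company.",
    ):
        print(message["role"], "->", message["content"])
```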

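We also use the stop parameter to save a little on cost, a technique we covered in a previous blog post: generation ends as soon as the model starts writing true or false (or closes the object), so we don't pay for tokens we would throw away.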
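
To see why this stop list leaves just the first letter of the answer, here is a simplified, text-level simulation: the output is cut at the earliest stop sequence, and the stop string itself is not returned. This is only a sketch under that assumption; the real API applies stop sequences server-side during generation, so token boundaries (and therefore the final logprob token) may not line up exactly with this string view.

```python
from typing import Sequence

# Stop sequences used in the script.
STOPS = ["rue", "alse", "}"]


def apply_stop(text: str, stops: Sequence[str]) -> str:
    """Cut text at the earliest stop sequence; the stop string is dropped."""
    cut = len(text)
    for stop in stops:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


if __name__ == "__main__":
    for full in ('{"fact": true}', '{"fact": false}'):
        print(repr(full), "->", repr(apply_stop(full, STOPS)))
```

Both answers end right after their first letter, which is why the script only needs to inspect a single character.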

9. Handling the Response and Extracting Probabilities

if response_json.choices[0].logprobs is None:
    raise ValueError("No logprobs in response")

last_logprob_content = response_json.choices[0].logprobs.content[-1]
last_logprob_token = last_logprob_content.token
last_logprob_float = last_logprob_content.logprob
last_logprob_percentage = np.exp(last_logprob_float)

We check that the response includes log probabilities, which help assess the model's confidence in its output (covered in a previous blog post of ours). We then take the log probability of the last token, which holds the true or false answer, and exponentiate it to get a probability between 0 and 1 (stored under the name percentage).
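
The conversion is a single exponential, p = e^logprob. A minimal sketch using the standard library's math.exp instead of NumPy (a substitution; the script uses np.exp), applied to the two logprob values reported in step 11:

```python
import math


def logprob_to_probability(logprob: float) -> float:
    """Convert a natural-log probability into a probability in [0, 1]."""
    return math.exp(logprob)


if __name__ == "__main__":
    # The logprob values reported for the two example statements.
    for lp in (-1.9361265e-07, 0.0):
        print(lp, logprob_to_probability(lp))
```

For a single scalar, math.exp avoids the NumPy dependency and returns a plain float, which json.dumps serializes directly.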

10. Determining the Fact Status

if "t" in last_logprob_token.lower():
    is_fact = True
elif "f" in last_logprob_token.lower():
    is_fact = False
else:
    raise ValueError("Invalid token")

Based on that token (expected to be some form of true or false), we decide whether the statement is a fact. We check for a single letter only because, with the stop sequences used here, the last token may be a lone character rather than the full word.
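
The substring check ("t" in token) would also accept an unrelated token that happens to contain a t, such as fact, if generation were ever cut short before the value. For a stricter variant, a hypothetical token_to_bool helper can strip whitespace and JSON punctuation and then require the remainder to be a prefix of true or false. This sketch assumes the last token looks like one of those fragments; adjust it if your tokenization differs.

```python
def token_to_bool(token: str) -> bool:
    """Map the final, possibly truncated, token to True or False.

    Hypothetical helper: removes surrounding whitespace and leading JSON
    punctuation, then accepts any non-empty prefix of "true" or "false".
    """
    cleaned = token.strip().lstrip('{":').strip().lower()
    if cleaned and "true".startswith(cleaned):
        return True
    if cleaned and "false".startswith(cleaned):
        return False
    raise ValueError(f"Invalid token: {token!r}")


if __name__ == "__main__":
    for tok in (" t", "true", " f", '": t'):
        print(repr(tok), token_to_bool(tok))
```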

11. Creating and Logging the Result Object

result_obj = {
    "fact": is_fact,
    "logprob": last_logprob_float,
    "percentage": last_logprob_percentage,
}

logger.debug(json.dumps(result_obj, indent=2))

We compile the results into a JSON object and log it in a readable format.

{
  "fact": true,
  "logprob": -1.9361265e-07,
  "percentage": 0.9999998063873687
}

Final result for the first statement.

{
  "fact": false,
  "logprob": 0.0,
  "percentage": 1.0
}

Final result for the second statement.

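Perfect! These are exactly the results we were hoping for: the first statement is confirmed as a fact and the second is rejected, each with a probability of essentially 1.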

12. Minifying the JSON String

thin_json_str = json.dumps(result_obj, separators=(",", ":"))
logger.debug(f"Thin JSON: {repr(thin_json_str)}")

For compactness, we also produce a minified JSON string without extra whitespace and log it. This step is optional; we include it because a compact string is handy if you plan to pass this output to another LLM.
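
The difference between the logged and minified forms is easy to see with the second statement's result. This uses only the standard json module: indent=2 produces the multi-line form, while separators=(",", ":") drops the whitespace after commas and colons.

```python
import json

# Result for the second statement, as logged in step 11.
result_obj = {"fact": False, "logprob": 0.0, "percentage": 1.0}

pretty = json.dumps(result_obj, indent=2)  # human-readable, multi-line
thin = json.dumps(result_obj, separators=(",", ":"))  # single line, no spaces

if __name__ == "__main__":
    print(pretty)
    print(thin)
    print(len(pretty), "vs", len(thin), "characters")
```

Both strings decode to the same object; only the character count (and thus the token count when sent to another model) changes.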

13. Running the Main Function

if __name__ == "__main__":
    asyncio.run(main())

We check whether the script is being run directly and, if so, run main() on a new event loop with asyncio.run.

Full Code with Comments

Below is the complete script with comments explaining each step:

#!/usr/bin/env python3.11
# -*- coding: utf-8 -*-
# Author: David Manouchehri

import asyncio  # For asynchronous operations
import logging  # For logging messages
import httpx  # For making HTTP requests asynchronously
from openai import AsyncOpenAI  # OpenAI client
import random  # For generating random numbers
import os  # For accessing environment variables
import json  # For handling JSON data
import numpy as np  # For numerical operations

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
c_handler = logging.StreamHandler()
logger.addHandler(c_handler)

# Get OpenAI API credentials
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"

# Initialize OpenAI client
client = AsyncOpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
    http_client=httpx.AsyncClient(
        http2=True,
        limits=httpx.Limits(
            max_connections=None,
            max_keepalive_connections=None,
            keepalive_expiry=None,
        ),
    ),
    max_retries=0,
    timeout=3600,
)

# Main asynchronous function
async def main():
    # List of statements to fact-check
    statements_to_check = [
        "SEZC means Special Economic Zone Company.",
        "SEZC means Special Economic Zone Corporation.",
    ]

    # Iterate over each statement
    for statement_to_check in statements_to_check:
        logger.debug("--------")
        logger.debug(f"Checking statement: {statement_to_check}")

        # Fact-check using the large model
        response_text = await client.chat.completions.create(
            model="gpt-4o-2024-08-06",
            max_tokens=8192,
            seed=random.randint(1, 0x7FFFFFFF),
            messages=[
                {
                    "role": "system",
                    "content": "Fact check the user's message.",
                },
                {
                    "role": "user",
                    "content": statement_to_check,
                },
            ],
            temperature=0.0,
        )

        # Extract the assistant's response
        text_answer = response_text.choices[0].message.content
        logger.debug(text_answer)

        # Convert the response to JSON using the smaller model
        response_json = await client.chat.completions.create(
            response_format={"type": "json_object"},
            model="gpt-4o-mini-2024-07-18",
            max_tokens=20,
            logprobs=True,
            stop=[
                "rue",
                "alse",
                "}",
            ],
            seed=random.randint(1, 0x7FFFFFFF),
            messages=[
                {
                    "role": "user",
                    "content": f"Fact check this statement:\n\n```{statement_to_check}```",
                },
                {
                    "role": "assistant",
                    "content": text_answer,
                },
                {
                    "role": "user",
                    "content": 'Convert your answer into a JSON object that follows {"fact": boolean}',
                },
            ],
            temperature=0.0,
        )

        # Ensure response includes log probabilities
        if response_json.choices[0].logprobs is None:
            raise ValueError("No logprobs in response")

        # Extract the last token's log probability
        last_logprob_content = response_json.choices[0].logprobs.content[-1]
        last_logprob_token = last_logprob_content.token
        last_logprob_float = last_logprob_content.logprob
        last_logprob_percentage = np.exp(last_logprob_float)

        # Determine if the statement is a fact
        if "t" in last_logprob_token.lower():
            is_fact = True
        elif "f" in last_logprob_token.lower():
            is_fact = False
        else:
            raise ValueError("Invalid token")

        # Create result object
        result_obj = {
            "fact": is_fact,
            "logprob": last_logprob_float,
            "percentage": last_logprob_percentage,
        }

        # Log the result
        logger.debug(json.dumps(result_obj, indent=2))

        # Minify the JSON string
        thin_json_str = json.dumps(result_obj, separators=(",", ":"))
        logger.debug(f"Thin JSON: {repr(thin_json_str)}")

# Execute the main function
if __name__ == "__main__":
    asyncio.run(main())

Conclusion

By chaining the responses from a larger model into a smaller one, we can optimize both the cost and efficiency of our applications. The larger model handles complex tasks like fact-checking, while the smaller model takes care of formatting and conversion tasks.

Key Takeaways:

  • Resource Optimization: Use powerful models only when necessary.
  • Cost Efficiency: Leverage smaller models for simpler tasks.
  • Structured Outputs: Convert complex responses into structured data formats like JSON.