Pub/Sub - A Solution to “CUDA out of memory” for LLM apps.

Soohoon Choi

Nov 16, 2023

Background

If you've ever used the OpenAI API in a project, there is one error you've probably come across: Error 429: You exceeded your current quota. Alternatively, if you've hosted your own models, you might have seen this nice message right before your cluster crashes: RuntimeError: CUDA out of memory.


We here at Onboard AI have spent countless hours bashing our heads against the wall trying to avoid this error, and in this blog post we hope to document one of the architectures we have implemented to do just that.

In an early LLM-dependent application, unless you ask the user to provide their own model, there is often only a single model endpoint to be consumed by all live instances of the application. Creating more OpenAI accounts for more API endpoints can work, but one can only create so many accounts (though I would love to be proven wrong). Hosting your own models poses challenges of its own, such as proper scaling, maintaining reliability, and keeping high compute costs under control.

In our case, we had large LLM processing tasks that would sometimes take hours to complete, and the fact that there could be an arbitrary number of these tasks at any moment would sometimes bring down our entire service with the errors above. As we hit the limits of horizontally scaling compute, we needed a proper method of throttling requests to control how many requests our model endpoints received at any given moment.

The nature of our task (the preprocessing of an entire codebase to understand its structure and inner workings) had two notable characteristics:

  1. Each codebase request would fan out into a large number of LLM requests, leading to large bursts of requests at any given moment.

  2. They were asynchronous tasks in that we didn’t need to worry about streaming results to the end users as fast as possible.

We noted that if we could throttle the number of requests, we could control and optimize for the maximum throughput of our LLM endpoints/clusters at any given moment. This led us to pick a pub/sub model with a messaging queue to do just that. The architecture comprises a publisher, which creates the requests to be made to the LLM; a messaging queue, which throttles those requests; and a subscriber, which takes the messages in batches and processes the requests just below the expected rate limit/memory bandwidth of the model endpoints.
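As a rough illustration of the throttling math (the numbers below are made up for the example, not our production values), the subscriber's polling cadence can be derived from a target requests-per-minute budget:

// Hypothetical numbers, for illustration only.
const RATE_LIMIT_RPM = 3500   // what the model endpoint can tolerate per minute
const SAFETY_MARGIN = 0.8     // stay comfortably below the limit
const BATCH_SIZE = 10         // SQS caps a single receive at 10 messages

// How many requests we allow ourselves per minute, and how often we may pull a batch.
const budgetPerMinute = Math.floor(RATE_LIMIT_RPM * SAFETY_MARGIN) // 2800
const batchesPerMinute = Math.floor(budgetPerMinute / BATCH_SIZE)  // 280
const pollIntervalMs = Math.ceil(60_000 / batchesPerMinute)        // ~215 ms between polls

console.log({ budgetPerMinute, batchesPerMinute, pollIntervalMs })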

Sample App Tutorial

Here we implement a sample application following the architecture described above. We use a sample Express app hosted on AWS Fargate and AWS Simple Queue Service (SQS) for the messaging queue. All of the infrastructure necessary for deploying the code, plus the full application, can be found on GitHub:

GitHub Repository

The Publisher

In our publisher, we simply add an endpoint that triggers a workflow to transform the request and add a message to SQS. Any preprocessing that increases the number of LLM calls, such as chunking, should be done here so that the message queue reflects an accurate count of the LLM calls to be made. It should be noted that each message has a payload limit, so large prompt building via RAG should be done later in the subscriber application.

// /app/publisher/index.js
// part of a sample express app
import express from 'express'
import { SQSClient, SendMessageCommand } from '@aws-sdk/client-sqs'

const app = express()
app.use(express.json()) // parse JSON request bodies
const sqs = new SQSClient({ region: process.env.AWS_REGION })

const sample_process = async (body) => {
  // Your logic here (transformation, chunking, etc.)
  return body
}

app.post('/endpoint', async (req, res) => {
  try {
    const message = await sample_process(req.body) // process your request here
    await sqs.send(new SendMessageCommand({ // send to the queue
      QueueUrl: process.env.SQS_URL,
      MessageBody: JSON.stringify(message),
    }))
    res.status(200).send('Success')
  } catch (e) {
    res.status(500).send(e.message)
  }
})

app.listen(process.env.PORT || 3000)
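As an example of the chunking mentioned above, a preprocessing step in the publisher might look something like the sketch below. It reuses the sqs client from the snippet above; splitIntoChunks and the message fields (repo, path, chunk) are hypothetical stand-ins for your own logic, shown only to illustrate sending one message per LLM call while keeping each payload well under the SQS message size limit (256 KB at the time of writing).

// Hypothetical chunking step for the publisher; splitIntoChunks is not a
// library function, just a stand-in for your own splitting logic.
const splitIntoChunks = (text, maxChars = 4000) => {
  const chunks = []
  for (let i = 0; i < text.length; i += maxChars) {
    chunks.push(text.slice(i, i + maxChars))
  }
  return chunks
}

const enqueueFile = async (repo, path, fileContents) => {
  // One SQS message per chunk, so the queue length reflects the true number
  // of LLM calls to be made.
  for (const chunk of splitIntoChunks(fileContents)) {
    await sqs.send(new SendMessageCommand({
      QueueUrl: process.env.SQS_URL,
      MessageBody: JSON.stringify({ repo, path, chunk }),
    }))
  }
}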

The Messaging Queue

If you're not familiar with it, AWS SQS is essentially one big queue in the cloud with nice development properties. It ensures that messages remain in the queue until the consumer/subscriber notifies the service that a message has been successfully processed via a delete request. Depending on the task, you can also guarantee ordering by turning it into a FIFO queue. AWS SQS is dead easy to use. However, there are some drawbacks: you can only send/consume 10 messages per batch request, which makes it harder to batch process large numbers of requests.
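To work within the 10-message cap on the sending side, the publisher can split its messages into groups of ten and send each group with a single batch request. A minimal sketch (the sendAll helper is ours, not part of the SQS API):

import { SQSClient, SendMessageBatchCommand } from '@aws-sdk/client-sqs'

const sqs = new SQSClient({ region: process.env.AWS_REGION })

// Send an arbitrary number of payloads, 10 at a time (the SQS batch limit).
const sendAll = async (payloads) => {
  for (let i = 0; i < payloads.length; i += 10) {
    const batch = payloads.slice(i, i + 10)
    await sqs.send(new SendMessageBatchCommand({
      QueueUrl: process.env.SQS_URL,
      Entries: batch.map((payload, j) => ({
        Id: String(i + j), // must be unique within the batch
        MessageBody: JSON.stringify(payload),
      })),
    }))
  }
}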

The Subscriber

The subscriber's job should contain the bulk of the processing, from all of the context engineering to the actual LLM call and more. In our sample application, we make use of SQS's built-in long polling feature via the WaitTimeSeconds parameter to poll for messages.


// /app/subscriber/index.js
// part of a sample express app
import OpenAI from 'openai'
import {
  SQSClient,
  ReceiveMessageCommand,
  DeleteMessageBatchCommand,
} from '@aws-sdk/client-sqs'

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })
const sqs = new SQSClient({ region: process.env.AWS_REGION })

const sample_process = async (messages) => {
  // sample llm processing with openai
  // you can include all of your context engineering code here
  const results = await Promise.allSettled(
    messages.map(async (message) => {
      const body = JSON.parse(message.Body)
      const chatMessages = [
        {
          role: 'system',
          content: '<YOUR SYSTEM PROMPT>'
        },
        {
          role: 'user',
          content: body // assumes the message body is the prompt string
        },
      ]
      const response = await openai.chat.completions.create({
        model: process.env.OPENAI_MODEL,
        messages: chatMessages,
      })
      return response
    })
  )
  return results // refer to allSettled docs for type
}

async function main() {
  while (true) { // Poll the SQS queue
    try {
      const receiveParams = {
        QueueUrl: process.env.SQS_URL,
        MaxNumberOfMessages: 10, // SQS max
        VisibilityTimeout: 60, // how long until messages reappear in the queue
        WaitTimeSeconds: 10, // for long polling
      }
      const data = await sqs.send(
        new ReceiveMessageCommand(receiveParams)
      )
      if (data.Messages && data.Messages.length > 0) {
        const receiptHandles = data.Messages.map((message) => message.ReceiptHandle)
        const results = await sample_process(data.Messages)
        await sqs.send( // signify that we are done processing
          new DeleteMessageBatchCommand({
            QueueUrl: process.env.SQS_URL,
            Entries: receiptHandles.map((handle, i) => ({
              Id: String(i),
              ReceiptHandle: handle,
            })),
          }))
        console.log(results)
      }
    } catch (err) {
      console.log(err)
    }
  }
}

main()
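One refinement the sample above glosses over: because Promise.allSettled never throws, some requests in a batch may fail while others succeed, yet the code deletes every message in the batch. Since SQS only removes a message once you delete it, you could instead delete only the messages whose requests were fulfilled and let the rest reappear after the visibility timeout for a retry. A sketch of that idea, assuming the sqs client and DeleteMessageBatchCommand import from the snippet above:

// Delete only the messages whose LLM call succeeded; rejected ones become
// visible again after VisibilityTimeout and are retried on a later poll.
const deleteFulfilled = async (messages, results) => {
  const entries = messages
    .map((message, i) => ({ message, result: results[i], i }))
    .filter(({ result }) => result.status === 'fulfilled')
    .map(({ message, i }) => ({
      Id: String(i),
      ReceiptHandle: message.ReceiptHandle,
    }))
  if (entries.length === 0) return
  await sqs.send(new DeleteMessageBatchCommand({
    QueueUrl: process.env.SQS_URL,
    Entries: entries,
  }))
}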