Building Streaming AI Responses with Axum and SSE

How we stream Claude and OpenAI responses token-by-token to the browser using Server-Sent Events in Rust.

ellix.ai Team · April 22, 2026 · 6 min read

Non-streaming AI responses feel slow. A 300-token answer takes 2–3 seconds to generate, and without streaming the user spends that whole time staring at a spinner. Streaming delivers tokens as they're generated. Text starts appearing as soon as the first token is ready, so the experience feels responsive even for long responses.

Here's how we built streaming in the aiassist.chat backend using Axum and Server-Sent Events — including the full handler, error handling, reconnection, and testing.

Why SSE over WebSockets, and Why Not Polling

There are three approaches to delivering tokens to the client: WebSockets, SSE, or polling.

Polling (repeated GET /response?conversation_id=X) is the worst option. It introduces latency equal to your poll interval, hammers your server with unnecessary requests, and the implementation complexity rivals SSE without any of the benefits. If you're polling for AI responses, migrate.

WebSockets are bidirectional. A chat widget sending one message and receiving one streaming response doesn't need bidirectionality — it's a fundamentally one-way flow after the message is sent. WebSockets also require a connection upgrade, which some corporate proxies and load balancers block or reset. Managing WebSocket connection state across reconnects adds meaningful complexity.

SSE is simpler. It runs over ordinary HTTP (1.1 or HTTP/2), doesn't require a connection upgrade, and has built-in reconnection semantics via the Last-Event-ID header. For a chat widget embedded on third-party sites, where you have no control over the network infrastructure between you and the visitor, SSE is the right primitive.

SSE isn't just simpler than WebSockets for this use case — it's more reliable in the hostile network environments that embedded widgets encounter.
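For reference, this is roughly what an SSE response looks like on the wire. The sketch below uses a hypothetical helper, encode_sse_event, which is not an axum API; axum's Event type does this formatting for you. It follows the field rules from the HTML Living Standard:

- an optional id: line
- an optional retry: line in milliseconds
- one data: line per line of payload
- a blank line that ends the event

Lines starting with a colon are comments, and that is how keep-alive pings travel.

```rust
/// Hypothetical helper that formats one Server-Sent Event frame.
/// In real code, axum's `Event` type produces the equivalent output.
fn encode_sse_event(id: Option<&str>, retry_ms: Option<u64>, data: &str) -> String {
    let mut frame = String::new();
    if let Some(id) = id {
        // On reconnect, the browser echoes the last id it saw in Last-Event-ID.
        frame.push_str(&format!("id: {id}\n"));
    }
    if let Some(ms) = retry_ms {
        // Tells the client how long to wait before reconnecting, in milliseconds.
        frame.push_str(&format!("retry: {ms}\n"));
    }
    // A multi-line payload becomes several data: lines; the client rejoins them with '\n'.
    for line in data.split('\n') {
        frame.push_str(&format!("data: {line}\n"));
    }
    // A blank line terminates the event.
    frame.push('\n');
    frame
}

/// Keep-alive pings are comment lines: they start with ':' and EventSource ignores them.
fn encode_sse_comment(text: &str) -> String {
    format!(": {text}\n\n")
}
```

Each token therefore costs only a few bytes of framing on top of its JSON payload.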

The Full Axum Handler

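Axum's Sse response type accepts any Stream of events. A tokio::sync::mpsc channel wrapped in ReceiverStream is a convenient way to feed it from a background task. Here's the complete handler, covering the authenticated tenant (injected as an extension), rate limiting, token budget checks, and error propagation: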

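use axum::{
    extract::{Extension, Json},
    response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;

// AppState, AuthenticatedTenant, AppError and run_rag_and_stream are
// application types and functions defined elsewhere in the codebase.

#[derive(Deserialize)]
pub struct ChatStreamRequest {
    pub message: String,
    pub conversation_id: Uuid,
}

// Deserialize lets the tests parse events back into this type;
// Debug lets channel send errors propagate with `?`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    Token { text: String },
    Done { conversation_id: Uuid },
    Error { message: String },
}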

pub async fn chat_stream_handler(
    Extension(tenant): Extension<AuthenticatedTenant>,
    Extension(state): Extension<AppState>,
    Json(req): Json<ChatStreamRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, AppError> {
    // Check rate limit before allocating the channel
    state
        .rate_limiter
        .check(tenant.id, 60, tenant.plan.requests_per_minute())
        .await
        .map_err(|_| AppError::RateLimitExceeded)?;

    // Check token budget
    let remaining_budget = state
        .billing
        .get_remaining_budget(tenant.id)
        .await?;

    if remaining_budget == 0 {
        return Err(AppError::BudgetExhausted);
    }

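    // Bounded channel: if the client reads slowly, the producer waits on send
    // instead of buffering the whole response in memory.
    let (tx, rx) = mpsc::channel::<StreamEvent>(32);
    let tenant_id = tenant.id;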

    tokio::spawn(async move {
        let result = run_rag_and_stream(
            &state,
            tenant_id,
            &req.message,
            req.conversation_id,
            tx.clone(),
        )
        .await;

        if let Err(e) = result {
            let _ = tx
                .send(StreamEvent::Error {
                    message: e.to_string(),
                })
                .await;
        }
    });

    let event_stream = ReceiverStream::new(rx).map(|evt| {
        let data = serde_json::to_string(&evt).unwrap_or_default();
        Ok(Event::default().data(data))
    });

    Ok(Sse::new(event_stream).keep_alive(
        KeepAlive::new()
            .interval(std::time::Duration::from_secs(15))
            .text("ping"),
    ))
}

The keep_alive with a 15-second interval is essential. Without it, a proxy or load balancer can close the connection as idle whenever no bytes are flowing, for example during retrieval before the first token arrives. AWS ALB's default idle timeout is 60 seconds, and some corporate proxies use much shorter ones. Axum sends the ping as an SSE comment line, which EventSource ignores, so the connection stays open without any visible content reaching the client.
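The property that matters in this handler is that the spawned task owns the senders. When it finishes, every sender is dropped, the receiver yields None, and axum ends the response.

The sketch below reproduces that shape with std::sync::mpsc and a thread instead of tokio, so it runs without dependencies. It is an analogy, not the production code. The names generate, stream_response, and fail_at are hypothetical: generate stands in for run_rag_and_stream, and fail_at simulates an upstream failure.

```rust
use std::sync::mpsc::{sync_channel, SyncSender};
use std::thread;

#[derive(Debug, PartialEq)]
enum StreamEvent {
    Token(String),
    Done,
    Error(String),
}

// Hypothetical stand-in for run_rag_and_stream: emits tokens, then a Done event.
// `fail_at` simulates the upstream model failing before the token at that index.
fn generate(tx: &SyncSender<StreamEvent>, tokens: &[&str], fail_at: Option<usize>) -> Result<(), String> {
    for (i, token) in tokens.iter().enumerate() {
        if fail_at == Some(i) {
            return Err("upstream model error".to_string());
        }
        // A send error means the receiver (the client) is gone, so stop generating.
        tx.send(StreamEvent::Token(token.to_string()))
            .map_err(|_| "client disconnected".to_string())?;
    }
    tx.send(StreamEvent::Done)
        .map_err(|_| "client disconnected".to_string())
}

fn stream_response(tokens: Vec<&'static str>, fail_at: Option<usize>) -> Vec<StreamEvent> {
    // Bounded channel, mirroring mpsc::channel(32) in the handler.
    let (tx, rx) = sync_channel::<StreamEvent>(32);
    thread::spawn(move || {
        if let Err(message) = generate(&tx, &tokens, fail_at) {
            // Best effort: the client may already have gone away.
            let _ = tx.send(StreamEvent::Error(message));
        }
        // `tx` is dropped when the thread ends, which closes the channel.
    });
    // Iteration ends once every sender has been dropped.
    rx.into_iter().collect()
}
```

One gap this shape doesn't cover: if the task panics, the senders are dropped and the stream ends without an Error event. One way to close it is a supervising task that holds its own sender and awaits the worker's JoinHandle. Awaiting the handle returns a JoinError for a panicked task.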

Redis-Based Session State for Streaming

For multi-instance deployments, streaming state needs to be coordinated across instances. If a user reconnects to a different backend instance, that instance needs to know the conversation context.

We store active streaming sessions in Redis with a short TTL:

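use serde::{Deserialize, Serialize};
use uuid::Uuid;

// Serialize is required by serde_json::to_string below;
// Deserialize lets another instance read the session back.
#[derive(Serialize, Deserialize)]
pub struct StreamingSession {
    pub conversation_id: Uuid,
    pub tenant_id: Uuid,
    pub started_at: i64,
    pub tokens_emitted: u32,
    pub last_event_id: Option<String>,
}

// RedisPool is assumed to be an app-level wrapper exposing set_ex,
// and Result is assumed to be anyhow::Result.
pub async fn create_streaming_session(
    redis: &RedisPool,
    session: &StreamingSession,
) -> Result<()> {
    let key = format!("stream_session:{}", session.conversation_id);
    let value = serde_json::to_string(session)?;

    // TTL of 10 minutes: longer than any realistic streaming response
    redis.set_ex(&key, &value, 600).await?;
    Ok(())
}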

When the stream finishes with Done or Error, we delete the session key. If the key expires instead (for example because the instance handling the stream crashed), the stream is treated as over. Either way the state cleans itself up, with no background cleanup jobs.
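The lifecycle can be modeled without Redis:

- write the session with a TTL
- delete it on completion
- treat expiry as "stream over"

The sketch below is an in-memory stand-in for SET key value EX seconds, GET, and DEL. SessionStore, its methods, and is_stream_active are hypothetical. The caller passes now explicitly so that expiry is deterministic.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical in-memory stand-in for the Redis calls used for streaming sessions.
struct SessionStore {
    entries: HashMap<String, (String, Instant)>, // key -> (value, expires_at)
}

impl SessionStore {
    fn new() -> Self {
        SessionStore { entries: HashMap::new() }
    }

    /// Like `SET key value EX ttl`: overwrites any existing value and resets the TTL.
    fn set_ex(&mut self, key: &str, value: &str, ttl: Duration, now: Instant) {
        self.entries.insert(key.to_string(), (value.to_string(), now + ttl));
    }

    /// Like `GET key`: an expired key is indistinguishable from a missing one.
    fn get(&self, key: &str, now: Instant) -> Option<&str> {
        match self.entries.get(key) {
            Some((value, expires_at)) if now < *expires_at => Some(value.as_str()),
            _ => None,
        }
    }

    /// Like `DEL key`: called when the stream finishes with Done or Error.
    fn del(&mut self, key: &str) {
        self.entries.remove(key);
    }
}

/// A stream counts as active only while its session key exists.
fn is_stream_active(store: &SessionStore, conversation_id: &str, now: Instant) -> bool {
    store.get(&format!("stream_session:{conversation_id}"), now).is_some()
}
```

Because an expired key reads exactly like a deleted one, readers need no separate "stale session" branch.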

Reconnection Handling When SSE Drops

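The SSE protocol has a built-in reconnection mechanism. When the connection drops, the browser's EventSource reconnects automatically after a delay and sends a Last-Event-ID header with the ID of the last event it received. The delay is browser-defined, typically a few seconds, and the server can adjust it with the retry: field. EventSource only issues GET requests, though. A POST endpoint like the handler above is usually consumed with fetch, and then the client has to detect the drop, reconnect, and set Last-Event-ID itself.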

To support seamless reconnection, emit event IDs on your token events:

// Emit tokens with sequential IDs
let event = Event::default()
    .id(token_index.to_string())
    .data(serde_json::to_string(&StreamEvent::Token { text: token })?);

On reconnection, your handler checks the Last-Event-ID and can resume from that point:

use axum::http::HeaderMap;

pub async fn chat_stream_handler(
    headers: HeaderMap,
    // ... other extractors
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, AppError> {
    let last_event_id: Option<u32> = headers
        .get("last-event-id")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse().ok());

    // If reconnecting mid-stream, replay tokens from last_event_id + 1
    // (stored in your Redis session)
    // ...
}

In practice, most streaming responses complete in under 10 seconds. Reconnection mid-stream is rare. But implementing it correctly prevents a poor UX where a network hiccup produces a truncated response with no recovery path.
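The replay step elided in the handler depends on how tokens are stored. Here is a minimal sketch under two assumptions:

- the session keeps every emitted token in order (for example in a Redis list);
- event IDs are zero-based token indices, emitted as in the token_index snippet above.

The client has seen everything up to and including last_event_id, so replay starts at the next index. The helper tokens_to_replay is hypothetical.

```rust
/// Hypothetical helper: returns the tokens a reconnecting client has not seen yet,
/// paired with their original event IDs.
/// Assumes event IDs are zero-based indices into `stored_tokens`.
fn tokens_to_replay(stored_tokens: &[String], last_event_id: Option<u32>) -> Vec<(u32, &str)> {
    // No Last-Event-ID header means a fresh connection: send everything.
    let start = match last_event_id {
        Some(id) => id as usize + 1,
        None => 0,
    };
    stored_tokens
        .iter()
        .enumerate()
        .skip(start)
        // Re-emit each token with its original ID so IDs stay stable across reconnects.
        .map(|(index, token)| (index as u32, token.as_str()))
        .collect()
}
```

After replaying, the handler would continue forwarding live tokens with IDs that pick up where the stored ones end.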

Testing Streaming Endpoints

Streaming endpoints are among the most commonly skipped test surfaces. Here's how we test them:

#[tokio::test]
async fn test_chat_stream_emits_tokens_and_done_event() {
    let app = build_test_app().await;
    let tenant = create_test_tenant(&app).await;

    let response = app
        .authenticated_request(tenant.api_key)
        .post("/chat/stream")
        .json(&json!({
            "message": "What is your refund policy?",
            "conversation_id": Uuid::new_v4(),
        }))
        .send()
        .await
        .expect("request failed");

    assert_eq!(response.status(), 200);
    assert_eq!(
        response.headers()["content-type"],
        "text/event-stream"
    );

    // Collect all SSE events
    let body = response.text().await.unwrap();
    let events: Vec<StreamEvent> = parse_sse_events(&body);

    // At least one token event
    assert!(events.iter().any(|e| matches!(e, StreamEvent::Token { .. })));

    // Exactly one Done event at the end
    let done_events: Vec<_> = events
        .iter()
        .filter(|e| matches!(e, StreamEvent::Done { .. }))
        .collect();
    assert_eq!(done_events.len(), 1);
    assert!(matches!(events.last().unwrap(), StreamEvent::Done { .. }));
}

#[tokio::test]
async fn test_chat_stream_sends_error_event_on_failure() {
    let app = build_test_app_with_failing_llm().await;
    let tenant = create_test_tenant(&app).await;

    let response = app
        .authenticated_request(tenant.api_key)
        .post("/chat/stream")
        // ...
        .send()
        .await
        .unwrap();

    let body = response.text().await.unwrap();
    let events: Vec<StreamEvent> = parse_sse_events(&body);

    // Must terminate with an error event, not a silent truncation
    assert!(matches!(events.last().unwrap(), StreamEvent::Error { .. }));
}

The key invariants to test: the stream always terminates (with either Done or Error), it never silently truncates, and error events carry a human-readable message.
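The tests lean on a parse_sse_events helper whose implementation isn't shown. The sketch below is a minimal, dependency-free guess at what such a helper might do; it is named parse_sse_body to keep it distinct. It:

- splits the body into events at blank lines;
- skips comment lines, which is where the ping keep-alives end up;
- joins data: lines.

It returns raw payloads. In the real tests, each payload would then be deserialized with serde_json::from_str, which requires StreamEvent to derive Deserialize.

```rust
/// One parsed SSE event: the optional `id:` field and the joined `data:` payload.
#[derive(Debug, PartialEq)]
struct SseEvent {
    id: Option<String>,
    data: String,
}

/// Hypothetical, simplified parser for a complete `text/event-stream` body.
/// Handles `id:` and `data:` fields and comments; ignores `event:` and `retry:`.
fn parse_sse_body(body: &str) -> Vec<SseEvent> {
    let mut events = Vec::new();
    let mut id: Option<String> = None;
    let mut data_lines: Vec<&str> = Vec::new();

    for line in body.lines() {
        if line.is_empty() {
            // A blank line ends the event; blocks without data (e.g. pings) produce nothing.
            if !data_lines.is_empty() {
                events.push(SseEvent { id: id.take(), data: data_lines.join("\n") });
            }
            data_lines.clear();
            id = None;
            continue;
        }
        if line.starts_with(':') {
            continue; // comment line, e.g. a keep-alive ping
        }
        // Split "field: value"; a single space after the colon is optional.
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "data" => data_lines.push(value),
            "id" => id = Some(value.to_string()),
            _ => {} // event:, retry: and unknown fields are ignored here
        }
    }
    // A trailing event without a terminating blank line is incomplete and dropped.
    events
}
```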

Token Budget Management During Streaming

Streaming complicates billing because you don't know the total token count until the stream ends. Our approach:

  1. Reserve a token budget at stream start based on the maximum expected response length
  2. Count tokens as they are emitted during the stream
  3. Commit the actual count at stream end and release the unused part of the reservation

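// At stream start: reserve the maximum budget
let reservation = billing.reserve_tokens(tenant_id, MAX_RESPONSE_TOKENS).await?;

// During stream: forward tokens and keep a running estimate
let mut tokens_emitted = 0u32;
while let Some(token) = llm_stream.next().await {
    // token.len() counts bytes, not tokens. This is a rough estimate
    // (about 4 characters per token for English text); prefer the
    // output-token count the provider reports in its final usage event.
    tokens_emitted += (token.len() as u32).div_ceil(4);
    if tx.send(StreamEvent::Token { text: token }).await.is_err() {
        // Client disconnected: stop generating, but still commit below
        break;
    }
}

// At stream end: commit actual usage (also reached after a disconnect)
billing.commit_tokens(tenant_id, reservation.id, tokens_emitted).await?;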

This prevents over-billing on short responses while still enforcing budget limits. The reservation is checked against the remaining budget atomically, so a tenant can't start 50 concurrent streams that together would exceed their budget.
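The article doesn't say how reserve_tokens achieves atomicity; in production that is often a conditional update in the database or a Redis script. As an in-process illustration only, the check-and-decrement can be a compare-and-swap on an AtomicU64, so two concurrent reservations can never both pass a check that only one of them fits. TokenBudget and its methods are hypothetical, and the reservation size used in the usage check is illustrative.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Hypothetical per-tenant budget: tokens that are neither spent nor reserved.
struct TokenBudget {
    available: AtomicU64,
}

impl TokenBudget {
    fn new(tokens: u64) -> Self {
        TokenBudget { available: AtomicU64::new(tokens) }
    }

    /// Reserve `amount` tokens or fail without changing anything.
    /// `fetch_update` retries the compare-and-swap if another thread changed the value first;
    /// `checked_sub` returning None aborts the update when the budget is too small.
    /// On failure, Err carries the balance that was available.
    fn reserve(&self, amount: u64) -> Result<u64, u64> {
        self.available
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| current.checked_sub(amount))
            .map(|_| amount)
    }

    /// Commit actual usage and return the unused part of the reservation to the pool.
    /// Assumes used <= reserved, since the response is capped at the reserved maximum.
    fn commit(&self, reserved: u64, used: u64) {
        let refund = reserved.saturating_sub(used);
        self.available.fetch_add(refund, Ordering::AcqRel);
    }

    fn remaining(&self) -> u64 {
        self.available.load(Ordering::Acquire)
    }
}
```

Suppose a tenant with 10,000 tokens left starts 50 streams at once, each reserving 4,096. Exactly two reservations succeed, and the rest fail immediately instead of overdrawing the budget.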