Question
I have a relatively simple FastAPI app that accepts a query and streams back the response from ChatGPT's API. ChatGPT is streaming back the result, and I can see this being printed to the console as it comes in.

What's not working is the `StreamingResponse` back via FastAPI: the response gets sent all at once instead. I'm really at a loss as to why this isn't working.
Here is the FastAPI app code:
```python
import os
import time
import openai
import fastapi
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse

auth_scheme = HTTPBearer()
app = fastapi.FastAPI()
openai.api_key = os.environ["OPENAI_API_KEY"]


def ask_statesman(query: str):
    #prompt = router(query)
    completion_reason = None
    response = ""
    while not completion_reason or completion_reason == "length":
        openai_stream = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": query}],
            temperature=0.0,
            stream=True,
        )
        for line in openai_stream:
            completion_reason = line["choices"][0]["finish_reason"]
            if "content" in line["choices"][0].delta:
                current_response = line["choices"][0].delta.content
                print(current_response)
                yield current_response
                time.sleep(0.25)


@app.post("/")
async def request_handler(auth_key: str, query: str):
    if auth_key != "123":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": auth_scheme.scheme_name},
        )
    else:
        stream_response = ask_statesman(query)
        return StreamingResponse(stream_response, media_type="text/plain")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, debug=True, log_level="debug")
```
And here is the very simple `test.py` file to test this:
```python
import requests

query = "How tall is the Eiffel tower?"
url = "http://localhost:8000"
params = {"auth_key": "123", "query": query}

response = requests.post(url, params=params, stream=True)
for chunk in response.iter_lines():
    if chunk:
        print(chunk.decode("utf-8"))
```
Answer
First, it wouldn't be good practice to use a `POST` request for requesting data from the server; using a `GET` request would be more suitable to your case. In addition to that, you shouldn't be sending credentials such as `auth_key` as part of the URL (i.e., in the query string); you should rather use [Headers](https://fastapi.tiangolo.com/tutorial/header-params/) and/or [Cookies](https://fastapi.tiangolo.com/tutorial/cookie-params/) (over HTTPS). Please have a look at this answer for more details and examples on the concepts of headers and cookies, as well as the risks involved when using query parameters instead. Helpful posts around this topic can also be found here and here, as well as here, here and here.
Second, if you are executing a blocking operation (i.e., a blocking I/O-bound or CPU-bound task) inside the [StreamingResponse](https://fastapi.tiangolo.com/advanced/custom-response/#streamingresponse)'s generator function, you should define the generator function with `def` instead of `async def`, as, otherwise, the blocking operation, as well as the `time.sleep()` function that is used inside your generator, would block the event loop. As explained here, if the function for streaming the response body is a normal `def` generator and not an `async def` one, FastAPI will use `iterate_in_threadpool()` to run the iterator/generator in a separate thread that is then `await`ed. If you prefer using an `async def` generator, then make sure to execute any blocking operations in an external `ThreadPool` (or `ProcessPool`) and `await` it, and use `await asyncio.sleep()` instead of `time.sleep()`, in case you need to add a delay to the execution of an operation. Have a look at this detailed answer for more details and examples.
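If you do keep an `async def` generator, the pattern looks roughly like the sketch below (`blocking_chunk()` is a made-up stand-in for a blocking call): `asyncio.to_thread()` (Python 3.9+) pushes the blocking work onto a worker thread, and `await asyncio.sleep()` replaces `time.sleep()`:

```python
import asyncio
import time


def blocking_chunk(i: int) -> bytes:
    time.sleep(0.05)  # stands in for a blocking I/O call
    return f"chunk {i}\n".encode()


async def streamer():
    for i in range(3):
        # run the blocking call in a worker thread, so the event loop stays free
        chunk = await asyncio.to_thread(blocking_chunk, i)
        yield chunk
        await asyncio.sleep(0.05)  # non-blocking delay


async def collect():
    return [c async for c in streamer()]


if __name__ == "__main__":
    print(asyncio.run(collect()))
```

The same `streamer()` could then be passed to `StreamingResponse` without starving other requests while the blocking call runs.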
Third, you are using `requests`' `iter_lines()` function, which iterates over the response data one line at a time. If, however, the response data did not include any line-break character (i.e., `\n`), you wouldn't see the data getting printed on the client's console as they arrive, until the entire response is received by the client and printed as a whole. In that case, you should instead use `iter_content()` and specify the `chunk_size` as desired (both cases are demonstrated in the example below).
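To see why that matters, here is a toy re-implementation (not `requests`' actual code) of a line-based iterator: it buffers incoming chunks until a `\n` appears, so a newline-free stream is only flushed when the connection closes:

```python
def iter_lines_sketch(chunks):
    # Buffer incoming chunks and only emit complete lines, mimicking the
    # general behaviour of line-based iterators such as iter_lines().
    pending = b""
    for chunk in chunks:
        pending += chunk
        while b"\n" in pending:
            line, pending = pending.split(b"\n", 1)
            yield line
    if pending:
        yield pending  # leftover data is flushed only when the stream ends


# No newlines: nothing is emitted until the stream is exhausted.
print(list(iter_lines_sketch([b"Hello, ", b"world, ", b"no line breaks"])))
# [b'Hello, world, no line breaks']

# With newlines: each line is emitted as soon as its b"\n" arrives.
print(list(iter_lines_sketch([b"first\nsec", b"ond\nthird"])))
# [b'first', b'second', b'third']
```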
Finally, if you would like the `StreamingResponse` to work in every browser (including Chrome), in the sense of being able to see the data as they stream in, you should set the `media_type` to a type other than `text/plain` (e.g., `application/json` or `text/event-stream`, see here), or disable MIME sniffing. As explained here, browsers will buffer `text/plain` responses for a certain amount (around 1445 bytes, as documented here), in order to check whether or not the content received is actually plain text. To avoid that, you can either set the `media_type` to `text/event-stream` (used for [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)), or keep using `text/plain`, but set the [X-Content-Type-Options](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Content-Type-Options) response header to `nosniff`, which would disable MIME sniffing (both options are demonstrated in the example below).
Working Example
app.py
```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import asyncio

app = FastAPI()


async def fake_data_streamer():
    for i in range(10):
        yield b'some fake data\n\n'
        await asyncio.sleep(0.5)


# If your generator contains blocking operations such as time.sleep(),
# then define the generator function with normal `def`; or, use `async def`,
# but run any blocking operations in an external ThreadPool, etc.
# (see the 2nd paragraph of this answer)
'''
import time

def fake_data_streamer():
    for i in range(10):
        yield b'some fake data\n\n'
        time.sleep(0.5)
'''


@app.get('/')
async def main():
    return StreamingResponse(fake_data_streamer(), media_type='text/event-stream')
    # or, use:
    '''
    headers = {'X-Content-Type-Options': 'nosniff'}
    return StreamingResponse(fake_data_streamer(), headers=headers, media_type='text/plain')
    '''
```
test.py (using Python `requests`)
```python
import requests

url = "http://localhost:8000/"

with requests.get(url, stream=True) as r:
    for chunk in r.iter_content(1024):  # or, for line in r.iter_lines():
        print(chunk)
```
test.py (using `httpx`; see this, as well as this and this, for the benefits of using `httpx` over `requests`)
```python
import httpx

url = 'http://127.0.0.1:8000/'

with httpx.stream('GET', url) as r:
    for chunk in r.iter_raw():  # or, for line in r.iter_lines():
        print(chunk)
```