import os
import time
import logging
from threading import Lock

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from llama_cpp import Llama
log_directory = "/home/tomand/code/bielik_serving_final"
os.makedirs(log_directory, exist_ok=True)

# Application/error logger: writes to error_logging.log and to the console
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(log_directory, 'error_logging.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Separate logger for request/response payloads, written to app.log
request_logger = logging.getLogger('request_logger')
request_logger.setLevel(logging.INFO)
request_handler = logging.FileHandler(os.path.join(log_directory, 'app.log'))
request_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
request_logger.addHandler(request_handler)
app = FastAPI(title="Bielik API Service")

MODEL_PATH = "/home/tomand/code/models/Bielik-11B-v2.3-Instruct.Q4_K_M.gguf"
N_GPU_LAYERS = 0
CONTEXT_SIZE = 16384
MAX_OUTPUT_TOKENS = 512

# Global lock for thread-safe access to is_processing
processing_lock = Lock()
is_processing = False
class GenerationRequest(BaseModel):
    conversation_id: str
    conversation_transcription: str
    prompt: str


class GenerationResponse(BaseModel):
    conversation_id: str
    model_response: str
try:
    model = Llama(
        model_path=MODEL_PATH,
        n_ctx=CONTEXT_SIZE,
        n_gpu_layers=N_GPU_LAYERS
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise
async def iterate_bytes(chunks):
    # Re-wrap already-consumed response chunks as an async iterator
    # so the body can still be sent to the client after logging.
    for chunk in chunks:
        yield chunk
@app.middleware("http")
async def log_requests(request: Request, call_next):
    global is_processing

    if request.url.path == "/health":
        return await call_next(request)

    client_ip = request.client.host

    # Check if service is busy (only for /generate endpoint)
    if request.url.path == "/generate" and request.method == "POST":
        with processing_lock:
            if is_processing:
                logger.info(f"Request rejected - service busy: {client_ip}")
                return JSONResponse(
                    status_code=503,
                    content={"detail": "Service is busy processing another request"}
                )

    if request.method == "POST":
        try:
            request_body = await request.json()
            request_logger.info(f"Request from {client_ip}: {request_body}")
        except Exception:
            request_body = {}
            logger.error("Could not parse request body as JSON")

    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time

    # Drain the response body so it can be logged, then restore the iterator
    response_body = [chunk async for chunk in response.body_iterator]
    response.body_iterator = iterate_bytes(response_body)
    try:
        response_text = b''.join(response_body).decode()
        request_logger.info(f"Response to {client_ip}: {response_text}, Process time: {process_time:.4f} seconds")
    except Exception:
        logger.error("Could not decode response body")

    return response
@app.post("/generate", response_model=GenerationResponse)
async def generate_response(request: GenerationRequest):
    global is_processing
    try:
        # Set processing flag at the start of the request
        with processing_lock:
            is_processing = True

        prompt_text = f"""<s><|im_start|> system
{request.prompt}<|im_end|>
<|im_start|> user
{request.conversation_transcription}<|im_end|>
<|im_start|> assistant
"""

        total_tokens = len(model.tokenize(prompt_text.encode()))
        logger.info(f"Total tokens in input: {total_tokens}")

        if total_tokens > CONTEXT_SIZE - MAX_OUTPUT_TOKENS:
            logger.warning("Input is too long for the model context size")
            raise HTTPException(status_code=400, detail="Input text is too long")

        response = model(
            prompt_text,
            max_tokens=MAX_OUTPUT_TOKENS,
            temperature=0.7,
            echo=False
        )
        generated_text = response["choices"][0]["text"].strip()

        logger.info(f"Successfully generated response for conversation id {request.conversation_id}")
        return GenerationResponse(
            conversation_id=request.conversation_id,
            model_response=generated_text
        )
    except HTTPException:
        # Propagate explicit HTTP errors (e.g. the 400 above) unchanged
        raise
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always reset processing flag when done, regardless of success or failure
        with processing_lock:
            is_processing = False
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "context_size": CONTEXT_SIZE,
        "model_path": MODEL_PATH,
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)
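
# --- Usage note ---
# A minimal client sketch for calling the service above. It assumes the server
# is reachable at http://localhost:8080 (the port configured in uvicorn.run) and
# that the `requests` library is installed; the payload values are hypothetical
# placeholders, only the field names come from GenerationRequest.

import requests

payload = {
    "conversation_id": "demo-1",
    "conversation_transcription": "User: Hello, how can I reset my password?",
    "prompt": "You are a helpful customer-support assistant.",
}

resp = requests.post("http://localhost:8080/generate", json=payload, timeout=300)
if resp.status_code == 503:
    # The middleware rejects concurrent generation requests with 503
    print("Service busy - retry later")
else:
    resp.raise_for_status()
    print(resp.json()["model_response"])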