
restapi_bielik_v2

Nov 28th, 2024
import os
import time
import logging
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from llama_cpp import Llama
from threading import Lock

log_directory = "/home/tomand/code/bielik_serving_final"
os.makedirs(log_directory, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(log_directory, 'error_logging.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

request_logger = logging.getLogger('request_logger')
request_logger.setLevel(logging.INFO)
request_handler = logging.FileHandler(os.path.join(log_directory, 'app.log'))
request_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
request_logger.addHandler(request_handler)

app = FastAPI(title="Bielik API Service")

MODEL_PATH = "/home/tomand/code/models/Bielik-11B-v2.3-Instruct.Q4_K_M.gguf"
N_GPU_LAYERS = 0
CONTEXT_SIZE = 16384
MAX_OUTPUT_TOKENS = 512

# Global lock for thread-safe access to is_processing
processing_lock = Lock()
is_processing = False

class GenerationRequest(BaseModel):
    conversation_id: str
    conversation_transcription: str
    prompt: str

class GenerationResponse(BaseModel):
    conversation_id: str
    model_response: str

try:
    model = Llama(
        model_path=MODEL_PATH,
        n_ctx=CONTEXT_SIZE,
        n_gpu_layers=N_GPU_LAYERS
    )
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise

async def iterate_bytes(chunks):
    # Re-wrap the buffered response body so it can still be streamed to the
    # client after the logging middleware has consumed the original iterator.
    for chunk in chunks:
        yield chunk

@app.middleware("http")
async def log_requests(request: Request, call_next):
    global is_processing

    if request.url.path == "/health":
        return await call_next(request)

    client_ip = request.client.host

    # Check if the service is busy (only for the /generate endpoint).
    # Note: the flag itself is only raised inside the endpoint, so there is a
    # small window between this check and the flag being set.
    if request.url.path == "/generate" and request.method == "POST":
        with processing_lock:
            if is_processing:
                logger.info(f"Request rejected - service busy: {client_ip}")
                return JSONResponse(
                    status_code=503,
                    content={"detail": "Service is busy processing another request"}
                )

    if request.method == "POST":
        try:
            request_body = await request.json()
            request_logger.info(f"Request from {client_ip}: {request_body}")
        except Exception:
            request_body = {}
            logger.error("Could not parse request body as JSON")

    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time

    # Buffer the response body for logging, then re-attach it to the response.
    response_body = [chunk async for chunk in response.body_iterator]
    response.body_iterator = iterate_bytes(response_body)

    try:
        response_text = b''.join(response_body).decode()
        request_logger.info(f"Response to {client_ip}: {response_text}, Process time: {process_time:.4f} seconds")
    except Exception:
        logger.error("Could not decode response body")

    return response

@app.post("/generate", response_model=GenerationResponse)
async def generate_response(request: GenerationRequest):
    global is_processing

    try:
        # Set processing flag at the start of the request
        with processing_lock:
            is_processing = True

        # Build a ChatML-style prompt without the indentation of the source
        # file, so no stray whitespace ends up in the model input.
        prompt_text = (
            "<s><|im_start|>system\n"
            f"{request.prompt}<|im_end|>\n"
            "<|im_start|>user\n"
            f"{request.conversation_transcription}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

        total_tokens = len(model.tokenize(prompt_text.encode()))
        logger.info(f"Total tokens in input: {total_tokens}")

        if total_tokens > CONTEXT_SIZE - MAX_OUTPUT_TOKENS:
            logger.warning("Input is too long for the model context size")
            raise HTTPException(status_code=400, detail="Input text is too long")

        response = model(
            prompt_text,
            max_tokens=MAX_OUTPUT_TOKENS,
            temperature=0.7,
            echo=False
        )

        generated_text = response["choices"][0]["text"].strip()

        logger.info(f"Successfully generated response for conversation id {request.conversation_id}")

        return GenerationResponse(
            conversation_id=request.conversation_id,
            model_response=generated_text
        )
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) instead of turning them into 500s.
        raise
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always reset processing flag when done, regardless of success or failure
        with processing_lock:
            is_processing = False

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "context_size": CONTEXT_SIZE,
        "model_path": MODEL_PATH,
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)
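
The service exposes two routes: POST /generate, which takes conversation_id, conversation_transcription, and prompt (used as the system message) and returns model_response, and GET /health. Because the middleware rejects a second concurrent generation with a 503, a caller should be prepared to retry. The sketch below is a hypothetical client, not part of the paste; it assumes the service is running on localhost:8080 (as started by the __main__ block above) and that the requests package is installed. The conversation values are illustrative only.

# client_example.py - hypothetical client sketch, not part of the original paste.
import time
import requests

BASE_URL = "http://localhost:8080"  # assumed host/port, matching uvicorn.run above

def generate(conversation_id: str, transcription: str, system_prompt: str,
             retries: int = 3, backoff_s: float = 2.0) -> str:
    """Call /generate, retrying on 503 since the service handles one request at a time."""
    payload = {
        "conversation_id": conversation_id,
        "conversation_transcription": transcription,
        "prompt": system_prompt,
    }
    for attempt in range(retries):
        resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=600)
        if resp.status_code == 503:
            # Service is busy with another generation; wait and try again.
            time.sleep(backoff_s * (attempt + 1))
            continue
        resp.raise_for_status()
        return resp.json()["model_response"]
    raise RuntimeError("Service stayed busy after all retries")

if __name__ == "__main__":
    print(requests.get(f"{BASE_URL}/health", timeout=10).json())
    print(generate("demo-1", "User: Dzień dobry, czym się zajmujecie?",
                   "Jesteś pomocnym asystentem."))

The linear backoff here is only for illustration; a sensible retry policy depends on how long a single generation typically takes on the target hardware.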