# gemini.py
import google.generativeai as genai
from dotenv import load_dotenv
import os

load_dotenv()

API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    raise ValueError("GEMINI_API_KEY is missing! Please check your .env file.")

genai.configure(api_key=API_KEY)

model = genai.GenerativeModel('gemini-2.0-flash')
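# --- Not part of the original paste: a hedged sketch of the .env file that
# gemini.py expects and a quick smoke test of the configured model. The key
# value below is a placeholder, not a real credential.
#
#   .env
#   GEMINI_API_KEY=your-api-key-here
#
if __name__ == "__main__":
    # One-off check that the API key and model name are accepted.
    test_response = model.generate_content("Say hello in one short sentence.")
    print(test_response.text)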
# main.py
from fastapi import FastAPI, Depends, HTTPException, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from sqlalchemy import func, exists
from datetime import datetime, timezone, timedelta
from app.database import SessionLocal
from app.models import *
from app.schemas import *
from app.gemini import *
import json
import asyncio
import re
import os

# FastAPI app
app = FastAPI()

# Store connected clients
active_connections = []

# Dictionary to store chat history per WebSocket session
chat_sessions = {}


@app.websocket("/chat_ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket to handle real-time AI study chat based on markdown content."""
    await websocket.accept()
    active_connections.append(websocket)

    session_id = id(websocket)  # Unique identifier for this session
    chat_sessions[session_id] = []  # Initialize chat history

    try:
        await websocket.send_text("Hi! How can I *help* you?")

        while True:
            data = await websocket.receive_text()
            try:
                request = json.loads(data)
                markdown_text = request.get("markdown", "").strip()
                user_prompt = request.get("prompt", "").strip()

                # Handle missing fields
                if not markdown_text or not user_prompt:
                    await websocket.send_text("Error: Both 'markdown' and 'prompt' fields are required.")
                    continue

                # Maintain context (limit to the last 5 messages)
                chat_context = "\n".join(chat_sessions[session_id][-5:])

                # Construct the AI prompt
                ai_prompt = f"""
You are an AI tutor. The following is a lecture in Markdown format:

{markdown_text}

Previous chat history:
{chat_context}

Based on this content, User: {user_prompt}

Your sole task is to assist with the provided markdown content. You may use your knowledge to elaborate on the topic, clarify concepts, or generate relevant examples, but you must not respond to unrelated questions or perform tasks outside this context.
"""

                # Run the AI model in a separate thread to avoid blocking the event loop
                response = await asyncio.to_thread(model.generate_content, ai_prompt)

                # Clean the response by collapsing extra newlines
                cleaned_response = re.sub(r'\n+', '\n', response.text).strip()

                # Store the conversation
                chat_sessions[session_id].append(f"User: {user_prompt}")
                chat_sessions[session_id].append(f"AI: {cleaned_response}")

                await websocket.send_text(cleaned_response)

            except json.JSONDecodeError:
                await websocket.send_text("Error: Invalid JSON format.")

    except WebSocketDisconnect:
        active_connections.remove(websocket)
        chat_sessions.pop(session_id, None)  # Remove chat history on disconnect
        print("Client disconnected")

    except Exception as e:
        print(f"Error in WebSocket: {e}")
        # Clean up connection state before closing so the lists don't leak entries
        if websocket in active_connections:
            active_connections.remove(websocket)
        chat_sessions.pop(session_id, None)
        await websocket.close()
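# --- Not part of the original paste: a minimal client-side sketch for talking
# to /chat_ws. It assumes the server is started with
# `uvicorn app.main:app --reload` on localhost:8000 and that the third-party
# `websockets` package is installed; the lecture text and prompt are
# placeholders. The server expects a JSON message with "markdown" and "prompt"
# fields and replies with plain text.
#
# client_example.py
import asyncio
import json

import websockets


async def chat_demo():
    async with websockets.connect("ws://localhost:8000/chat_ws") as ws:
        print(await ws.recv())  # greeting sent by the server on connect
        await ws.send(json.dumps({
            "markdown": "# Binary Search\nBinary search halves the search range each step.",
            "prompt": "Explain the time complexity of binary search.",
        }))
        print(await ws.recv())  # AI tutor reply


if __name__ == "__main__":
    asyncio.run(chat_demo())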