diff --git a/app/api/gemini/gemini.py b/app/api/gemini/gemini.py index a0e1534..700adaf 100644 --- a/app/api/gemini/gemini.py +++ b/app/api/gemini/gemini.py @@ -1,9 +1,57 @@ -from fastapi import APIRouter +import os +from fastapi import APIRouter, HTTPException +from app.utils.make_meta import make_meta router = APIRouter() @router.get("/gemini") def root() -> dict: """GET /gemini endpoint.""" - return {"message": "Welcome to the Gemini API!"} + meta = make_meta("success", "Gemini endpoint says hello") + return {"meta": meta} + +@router.post("/gemini") +def gemini_post(payload: dict) -> dict: + """POST /gemini: send prompt to Gemini, returns completion google-genai SDK.""" + prompt = payload.get("prompt") + if not prompt: + raise HTTPException(status_code=400, detail="Missing 'prompt' in request body.") + api_key = os.getenv("GEMINI_API_KEY") + if not api_key: + raise HTTPException(status_code=500, detail="Gemini API key not configured.") + import logging + try: + from google import genai + client = genai.Client(api_key=api_key) + # Try a list of known Gemini model names + model_names = [ + "models/gemini-flash-latest", + "models/gemini-1.5-pro", + "models/gemini-1.5-flash", + "models/gemini-1.0-pro", + "models/gemini-pro", + "models/gemini-pro-vision" + ] + response = None + completion = None + used_model = None + errors = {} + for model_name in model_names: + try: + response = client.models.generate_content(model=model_name, contents=prompt) + completion = getattr(response, "text", None) + if completion: + used_model = model_name + break + except Exception as e: + errors[model_name] = str(e) + continue + if not completion: + error_details = " | ".join([f"{k}: {v}" for k, v in errors.items()]) + raise Exception(f"No available Gemini model succeeded for generate_content with your API key. Details: {error_details}") + except Exception as e: + meta = make_meta("error", f"Gemini API error: {str(e)}") + return {"meta": meta, "data": {}} + meta = make_meta("success", f"Gemini completion received from {used_model}") + return {"meta": meta, "data": {"prompt": prompt, "completion": completion}} diff --git a/requirements.txt b/requirements.txt index 036c0d8..3193d5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ python-dotenv>=1.0.0 psycopg2-binary>=2.9.0 python-multipart>=0.0.20 Faker>=25.2.0 +google-genai>=0.3.0 diff --git a/tests/test_gemini.py b/tests/test_gemini.py index 2e8eefd..51f8e6d 100644 --- a/tests/test_gemini.py +++ b/tests/test_gemini.py @@ -1,3 +1,20 @@ +import os +import pytest +def test_gemini_real_api(): + api_key = os.getenv("GEMINI_API_KEY") + if not api_key: + pytest.skip("GEMINI_API_KEY not set; skipping real Gemini API test.") + from google import genai + client = genai.Client(api_key=api_key) + try: + response = client.models.generate_content( + model="models/gemini-flash-latest", + contents="Say hello from Gemini!" + ) + completion = getattr(response, "text", None) + assert completion is not None and "hello" in completion.lower() + except Exception as e: + pytest.fail(f"Gemini real API call failed: {e}") import sys import os sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) @@ -8,7 +25,39 @@ client = TestClient(app) -def test_gemini_endpoint(): + +def test_gemini_get_endpoint(): response = client.get("/gemini") assert response.status_code == 200 - assert response.json() == {"message": "Welcome to the Gemini API!"} + data = response.json() + assert "meta" in data + assert data["meta"]["severity"] == "success" + assert "Gemini endpoint says hello" in data["meta"]["title"] + + +def test_gemini_post_endpoint(monkeypatch): + # Mock google-genai SDK to avoid real API call + class MockGenAIResponse: + text = "Test completion" + + class MockGenAIModel: + def generate_content(self, model, contents): + return MockGenAIResponse() + + class MockGenAIClient: + models = MockGenAIModel() + + monkeypatch.setattr("google.genai.Client", lambda *args, **kwargs: MockGenAIClient()) + + payload = {"prompt": "Test prompt"} + response = client.post("/gemini", json=payload) + assert response.status_code == 200 + data = response.json() + assert "meta" in data + assert data["meta"]["severity"] == "success" + assert "Gemini completion received" in data["meta"]["title"] + assert data["data"]["prompt"] == "Test prompt" + assert data["data"]["completion"] == "Test completion" + assert "data" in data + assert data["data"]["prompt"] == "Test prompt" + assert data["data"]["completion"] == "Test completion"