|
| 1 | +import os |
| 2 | +import psycopg2 |
| 3 | +from typing import Dict, Any |
| 4 | + |
| 5 | +def get_connection(): |
| 6 | + """Get PostgreSQL database connection""" |
| 7 | + return psycopg2.connect( |
| 8 | + dbname=os.getenv('POSTGRES_DB'), |
| 9 | + user=os.getenv('POSTGRES_USER'), |
| 10 | + password=os.getenv('POSTGRES_PASSWORD'), |
| 11 | + host=os.getenv('POSTGRES_HOST', 'localhost'), |
| 12 | + port=os.getenv('POSTGRES_PORT', '5432') |
| 13 | + ) |
| 14 | + |
| 15 | +def get_schema() -> Dict[str, Any]: |
| 16 | + """ |
| 17 | + Initialize connection and get database schema. |
| 18 | + Returns a dictionary containing the database schema. |
| 19 | + """ |
| 20 | + try: |
| 21 | + conn = get_connection() |
| 22 | + cursor = conn.cursor() |
| 23 | + |
| 24 | + # Query to get all tables in the current schema |
| 25 | + schema_query = """ |
| 26 | + SELECT table_name |
| 27 | + FROM information_schema.tables |
| 28 | + WHERE table_schema = 'public' |
| 29 | + AND table_type = 'BASE TABLE'; |
| 30 | + """ |
| 31 | + |
| 32 | + cursor.execute(schema_query) |
| 33 | + tables = cursor.fetchall() |
| 34 | + |
| 35 | + # Create schema dictionary |
| 36 | + schema = {} |
| 37 | + for (table_name,) in tables: |
| 38 | + # Get column information for each table |
| 39 | + column_query = """ |
| 40 | + SELECT column_name |
| 41 | + FROM information_schema.columns |
| 42 | + WHERE table_schema = 'public' |
| 43 | + AND table_name = %s; |
| 44 | + """ |
| 45 | + cursor.execute(column_query, (table_name,)) |
| 46 | + columns = [col[0] for col in cursor.fetchall()] |
| 47 | + schema[table_name] = columns |
| 48 | + |
| 49 | + cursor.close() |
| 50 | + conn.close() |
| 51 | + return schema |
| 52 | + |
| 53 | + except Exception as e: |
| 54 | + print(f"Error getting database schema: {str(e)}") |
| 55 | + return {} |
| 56 | + |
| 57 | +def execute_query(query: str) -> list: |
| 58 | + """ |
| 59 | + Execute a SQL query on the database. |
| 60 | + Args: |
| 61 | + query: SQL query to execute |
| 62 | + Returns: |
| 63 | + List of query results |
| 64 | + """ |
| 65 | + try: |
| 66 | + conn = get_connection() |
| 67 | + cursor = conn.cursor() |
| 68 | + |
| 69 | + # Execute the query |
| 70 | + cursor.execute(query) |
| 71 | + results = cursor.fetchall() |
| 72 | + |
| 73 | + cursor.close() |
| 74 | + conn.close() |
| 75 | + return results |
| 76 | + |
| 77 | + except Exception as e: |
| 78 | + print(f"Error executing query: {str(e)}") |
| 79 | + return [] |
0 commit comments