# Source code for opendp.smartnoise.sql.reader.postgres

import os

from .base import SqlReader, NameCompare, Serializer
from .engine import Engine


class PostgresReader(SqlReader):
    """
    A dumb pipe that gets a rowset back from a database using
    a SQL string, and converts types to some useful subset
    """

    ENGINE = Engine.POSTGRES

    def __init__(
        self,
        host=None,
        database=None,
        user=None,
        password=None,
        port=None,
        conn=None,
        **kwargs
    ):
        """Create a reader for a PostgreSQL database.

        :param host: server host name; left out of the connection string when None.
        :param database: database name; left out of the connection string when None.
        :param user: user name; left out of the connection string when None.
        :param password: password; when None, the ``POSTGRES_PASSWORD``
            environment variable is used if it is set.
        :param port: server port; left out of the connection string when None.
        :param conn: an existing DB-API connection to reuse. When given, the
            other connection parameters are ignored and the connection is
            never closed by this reader.
        :param kwargs: accepted and ignored.
        """
        super().__init__(self.ENGINE)
        import psycopg2

        self.api = psycopg2
        self.conn = None
        if conn is not None:
            self.conn = conn
        else:
            # generate a connection string
            self.host = host
            self.database = database
            self.user = user
            self.port = port
            if password is None:
                # Fall back to the environment so credentials need not be passed in code.
                password = os.environ.get("POSTGRES_PASSWORD")
            self.password = password
            self._update_connection_string()

    def execute(self, query, *ignore, accuracy: bool = False):
        """Run a SQL string and return the header followed by the data rows.

        :param query: the SQL text to run. It must be a ``str``; ASTs go
            through ``execute_typed``.
        :param ignore: extra positional arguments, accepted and ignored.
        :param accuracy: accepted for interface compatibility; not used here.
        :return: ``[]`` when the statement produces no result set. Otherwise a
            list whose first element is a tuple of column names and whose
            remaining elements are the rows yielded by the cursor.
        :raises ValueError: if ``query`` is not a string.
        """
        if not isinstance(query, str):
            raise ValueError("Please pass strings to execute. To execute ASTs, use execute_typed.")
        # Only connections opened here are closed here; a caller-supplied
        # connection (self.conn) stays open and remains the caller's to manage.
        owns_connection = self.conn is None
        cnxn = self.api.connect(self.connection_string) if owns_connection else self.conn
        try:
            cursor = cnxn.cursor()
            try:
                cursor.execute(str(query))
                if cursor.description is None:
                    return []
                col_names = [tuple(desc[0] for desc in cursor.description)]
                # Materialize every row before the cursor/connection is closed.
                rows = list(cursor)
                return col_names + rows
            finally:
                cursor.close()
        finally:
            if owns_connection:
                cnxn.close()

    def _update_connection_string(self):
        """Rebuild ``self.connection_string`` from the stored connection settings.

        Produces a libpq keyword/value string in the order user, host, dbname,
        port, password. Settings that are ``None`` are left out, so libpq
        applies its own defaults instead of receiving the literal text
        ``'None'``. Quoted values have backslashes and single quotes escaped,
        so a value containing them (typically a password) cannot corrupt the
        string.
        """

        def quoted(value):
            # libpq: inside a single-quoted value, ' and \ must be backslash-escaped.
            escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
            return "'{0}'".format(escaped)

        parts = []
        if self.user is not None:
            parts.append("user=" + quoted(self.user))
        if self.host is not None:
            parts.append("host=" + quoted(self.host))
        if self.database is not None:
            parts.append("dbname=" + quoted(self.database))
        if self.port is not None:
            parts.append("port={0}".format(self.port))
        if self.password is not None:
            parts.append("password=" + quoted(self.password))
        self.connection_string = " ".join(parts)

    def switch_database(self, dbname):
        """Direct subsequent queries at the database ``dbname``.

        PostgreSQL cannot change the database of an open connection. psql's
        connect meta-command is a client-side feature, not SQL, so sending it
        to the server only produces a syntax error. Instead, the stored
        database name is updated and the connection string is rebuilt. Each
        later ``execute`` call opens its connection with the new string.

        :param dbname: name of the database to use from now on.
        :raises ValueError: if this reader wraps a caller-supplied connection,
            whose target database cannot be changed from here.
        """
        if self.conn is not None:
            raise ValueError(
                "Cannot switch databases on a caller-supplied connection; "
                "open a new connection to the target database instead."
            )
        self.database = dbname
        self._update_connection_string()
class PostgresNameCompare(NameCompare):
    """PostgreSQL-specific identifier matching and escaping rules."""

    def __init__(self, search_path=None):
        """Store the schema search path, defaulting to ``["public"]``."""
        if search_path is None:
            search_path = ["public"]
        self.search_path = search_path

    def identifier_match(self, query, meta):
        """Return True when identifier ``query`` refers to metadata name ``meta``.

        Both names are first passed through ``clean_escape``. If they are not
        identical after that, any escaped name that is already all lowercase
        has its double quotes removed, and the results are compared exactly.
        """

        def unquote_if_lowercase(name):
            # A quoted name that is already lowercase is compared without its
            # quotes (PostgreSQL folds unquoted identifiers to lowercase).
            if self.is_escaped(name) and name == name.lower():
                return name.replace('"', "")
            return name

        cleaned_query = self.clean_escape(query)
        cleaned_meta = self.clean_escape(meta)
        if cleaned_query == cleaned_meta:
            return True
        normalized_meta = unquote_if_lowercase(cleaned_meta)
        normalized_query = unquote_if_lowercase(cleaned_query)
        return normalized_meta == normalized_query

    def should_escape(self, identifier):
        """Return True when ``identifier`` needs double-quoting.

        Already-escaped names never do. Reserved words (case-insensitively)
        always do. Otherwise, escaping is needed if the name contains a space
        or any character that ``str.lower`` would change.
        """
        if self.is_escaped(identifier):
            return False
        if identifier.lower() in self.reserved():
            return True
        return identifier.replace(" ", "") != identifier.lower()


class PostgresSerializer(Serializer):
    """Serializer for the PostgreSQL engine; adds nothing to the base class."""

    def __init__(self):
        super(PostgresSerializer, self).__init__()