You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

client.py 2.1 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import os
  2. import subprocess
  3. from django.core.files.temp import NamedTemporaryFile
  4. from django.db.backends.base.client import BaseDatabaseClient
  5. from django.utils.six import print_
  6. def _escape_pgpass(txt):
  7. """
  8. Escape a fragment of a PostgreSQL .pgpass file.
  9. """
  10. return txt.replace('\\', '\\\\').replace(':', '\\:')
  11. class DatabaseClient(BaseDatabaseClient):
  12. executable_name = 'psql'
  13. @classmethod
  14. def runshell_db(cls, conn_params):
  15. args = [cls.executable_name]
  16. host = conn_params.get('host', '')
  17. port = conn_params.get('port', '')
  18. dbname = conn_params.get('database', '')
  19. user = conn_params.get('user', '')
  20. passwd = conn_params.get('password', '')
  21. if user:
  22. args += ['-U', user]
  23. if host:
  24. args += ['-h', host]
  25. if port:
  26. args += ['-p', str(port)]
  27. args += [dbname]
  28. temp_pgpass = None
  29. try:
  30. if passwd:
  31. # Create temporary .pgpass file.
  32. temp_pgpass = NamedTemporaryFile(mode='w+')
  33. try:
  34. print_(
  35. _escape_pgpass(host) or '*',
  36. str(port) or '*',
  37. _escape_pgpass(dbname) or '*',
  38. _escape_pgpass(user) or '*',
  39. _escape_pgpass(passwd),
  40. file=temp_pgpass,
  41. sep=':',
  42. flush=True,
  43. )
  44. os.environ['PGPASSFILE'] = temp_pgpass.name
  45. except UnicodeEncodeError:
  46. # If the current locale can't encode the data, we let
  47. # the user input the password manually.
  48. pass
  49. subprocess.call(args)
  50. finally:
  51. if temp_pgpass:
  52. temp_pgpass.close()
  53. if 'PGPASSFILE' in os.environ: # unit tests need cleanup
  54. del os.environ['PGPASSFILE']
  55. def runshell(self):
  56. DatabaseClient.runshell_db(self.connection.get_connection_params())