diff --git a/lib/AbstractConnection.php b/lib/AbstractConnection.php index dae5fd2..a801be0 100644 --- a/lib/AbstractConnection.php +++ b/lib/AbstractConnection.php @@ -9,11 +9,11 @@ use Amp\Deferred; use Amp\Promise; use function Amp\call; -abstract class AbstractConnection implements Connection { +abstract class AbstractConnection implements Connection, Handle { use CallableMaker; /** @var \Amp\Postgres\Executor */ - private $executor; + private $handle; /** @var \Amp\Deferred|null Used to only allow one transaction at a time. */ private $busy; @@ -30,10 +30,10 @@ abstract class AbstractConnection implements Connection { abstract public static function connect(string $connectionString, CancellationToken $token = null): Promise; /** - * @param $executor; + * @param \Amp\Postgres\Handle $handle */ - public function __construct(Executor $executor) { - $this->executor = $executor; + public function __construct(Handle $handle) { + $this->handle = $handle; $this->release = $this->callableFromInstanceMethod("release"); } @@ -54,7 +54,7 @@ abstract class AbstractConnection implements Connection { $this->busy = new Deferred; try { - return $this->executor->{$methodName}(...$args); + return $this->handle->{$methodName}(...$args); } finally { $this->release(); } @@ -118,28 +118,42 @@ abstract class AbstractConnection implements Connection { switch ($isolation) { case Transaction::UNCOMMITTED: - yield $this->executor->query("BEGIN TRANSACTION ISOLATION LEVEL READ UNCOMMITTED"); + yield $this->handle->query("BEGIN TRANSACTION ISOLATION LEVEL READ UNCOMMITTED"); break; case Transaction::COMMITTED: - yield $this->executor->query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); + yield $this->handle->query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); break; case Transaction::REPEATABLE: - yield $this->executor->query("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"); + yield $this->handle->query("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"); break; case Transaction::SERIALIZABLE: - yield $this->executor->query("BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE"); + yield $this->handle->query("BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE"); break; default: throw new \Error("Invalid transaction type"); } - $transaction = new Transaction($this->executor, $isolation); + $transaction = new Transaction($this->handle, $isolation); $transaction->onComplete($this->release); return $transaction; }); } + + /** + * {@inheritdoc} + */ + public function quoteString(string $data): string { + return $this->handle->quoteString($data); + } + + /** + * {@inheritdoc} + */ + public function quoteName(string $name): string { + return $this->handle->quoteName($name); + } } diff --git a/lib/AbstractPool.php b/lib/AbstractPool.php index 09f1de8..953b5b1 100644 --- a/lib/AbstractPool.php +++ b/lib/AbstractPool.php @@ -2,11 +2,14 @@ namespace Amp\Postgres; +use Amp\CallableMaker; use Amp\Coroutine; use Amp\Deferred; use Amp\Promise; abstract class AbstractPool implements Pool { + use CallableMaker; + /** @var \SplQueue */ private $idle; @@ -28,6 +31,9 @@ abstract class AbstractPool implements Pool { /** @var int Number of listeners on listening connection. */ private $listenerCount = 0; + /** @var callable */ + private $push; + /** * @return \Amp\Promise<\Amp\Postgres\Connection> * @@ -36,9 +42,21 @@ abstract class AbstractPool implements Pool { abstract protected function createConnection(): Promise; public function __construct() { - $this->connections = new \SplObjectStorage(); - $this->idle = new \SplQueue(); - $this->busy = new \SplQueue(); + $this->connections = new \SplObjectStorage; + $this->idle = new \SplQueue; + $this->busy = new \SplQueue; + $this->push = $this->callableFromInstanceMethod("push"); + } + + /** + * @return \Amp\Promise<\Amp\Postgres\PooledConnection> + */ + public function getConnection(): Promise { + return new Coroutine($this->doGetConnection()); + } + + private function doGetConnection(): \Generator { + return new PooledConnection(yield from $this->pop(), $this->push); } /** diff --git a/lib/Handle.php b/lib/Handle.php new file mode 100644 index 0000000..439ba25 --- /dev/null +++ b/lib/Handle.php @@ -0,0 +1,24 @@ +statements[$name] = $storage = new Internal\StatementStorage; - $storage->promise = call(function () use ($name, $sql) { + $promise = $storage->promise = call(function () use ($name, $sql) { /** @var resource $result PostgreSQL result resource. */ $result = yield from $this->send("pg_send_prepare", $name, $sql); @@ -277,7 +277,7 @@ class PgSqlExecutor implements Executor { // @codeCoverageIgnoreEnd } }); - $storage->promise->onResolve(function ($exception) use ($storage, $name) { + $promise->onResolve(function ($exception) use ($storage, $name) { if ($exception) { unset($this->statements[$name]); return; @@ -285,7 +285,7 @@ class PgSqlExecutor implements Executor { $storage->promise = null; }); - return $storage->promise; + return $promise; } /** @@ -293,10 +293,10 @@ class PgSqlExecutor implements Executor { */ public function notify(string $channel, string $payload = ""): Promise { if ($payload === "") { - return $this->query(\sprintf("NOTIFY %s", $channel)); + return $this->query(\sprintf("NOTIFY %s", $this->quoteName($channel))); } - return $this->query(\sprintf("NOTIFY %s, '%s'", $channel, $payload)); + return $this->query(\sprintf("NOTIFY %s, %s", $this->quoteName($channel), $this->quoteString($payload))); } /** @@ -311,7 +311,7 @@ class PgSqlExecutor implements Executor { $this->listeners[$channel] = $emitter = new Emitter; try { - yield $this->query(\sprintf("LISTEN %s", $channel)); + yield $this->query(\sprintf("LISTEN %s", $this->quoteName($channel))); } catch (\Throwable $exception) { unset($this->listeners[$channel]); throw $exception; @@ -339,8 +339,22 @@ class PgSqlExecutor implements Executor { Loop::disable($this->poll); } - $promise = $this->query(\sprintf("UNLISTEN %s", $channel)); + $promise = $this->query(\sprintf("UNLISTEN %s", $this->quoteName($channel))); $promise->onResolve([$emitter, "complete"]); return $promise; } + + /** + * {@inheritdoc} + */ + public function quoteString(string $data): string { + return \pg_escape_literal($this->handle, $data); + } + + /** + * {@inheritdoc} + */ + public function quoteName(string $name): string { + return \pg_escape_identifier($this->handle, $name); + } } diff --git a/lib/Pool.php b/lib/Pool.php index 2d2a2db..207f8a1 100644 --- a/lib/Pool.php +++ b/lib/Pool.php @@ -2,7 +2,14 @@ namespace Amp\Postgres; +use Amp\Promise; + interface Pool extends Connection { + /** + * @return \Amp\Promise<\Amp\Postgres\PooledConnection> + */ + public function getConnection(): Promise; + /** * @return int Current number of connections in the pool. */ diff --git a/lib/PooledConnection.php b/lib/PooledConnection.php new file mode 100644 index 0000000..58b6a4a --- /dev/null +++ b/lib/PooledConnection.php @@ -0,0 +1,85 @@ +connection = $connection; + $this->push = $push; + } + + public function __destruct() { + ($this->push)($this->connection); + } + + /** + * {@inheritdoc} + */ + public function transaction(int $isolation = Transaction::COMMITTED): Promise { + return $this->connection->transaction($isolation); + } + + /** + * {@inheritdoc} + */ + public function listen(string $channel): Promise { + return $this->connection->listen($channel); + } + + /** + * {@inheritdoc} + */ + public function query(string $sql): Promise { + return $this->connection->query($sql); + } + + /** + * {@inheritdoc} + */ + public function execute(string $sql, ...$params): Promise { + return $this->connection->execute($sql, ...$params); + } + + /** + /** + * {@inheritdoc} + */ + public function prepare(string $sql): Promise { + return $this->connection->prepare($sql); + } + + /** + * {@inheritdoc} + */ + public function notify(string $channel, string $payload = ""): Promise { + return $this->connection->notify($channel, $payload); + } + + /** + * {@inheritdoc} + */ + public function quoteString(string $data): string { + return $this->connection->quoteString($data); + } + + /** + * {@inheritdoc} + */ + public function quoteName(string $name): string { + return $this->connection->quoteName($name); + } +} diff --git a/lib/PqConnection.php b/lib/PqConnection.php index cf19c70..f14ed6a 100644 --- a/lib/PqConnection.php +++ b/lib/PqConnection.php @@ -23,7 +23,6 @@ class PqConnection extends AbstractConnection { } catch (pq\Exception $exception) { return new Failure(new FailureException("Could not connect to PostgresSQL server", 0, $exception)); } - $connection->resetAsync(); $connection->nonblocking = true; $connection->unbuffered = true; @@ -68,6 +67,6 @@ class PqConnection extends AbstractConnection { * @param \pq\Connection $handle */ public function __construct(pq\Connection $handle) { - parent::__construct(new PqExecutor($handle)); + parent::__construct(new PqHandle($handle)); } } diff --git a/lib/PqExecutor.php b/lib/PqHandle.php similarity index 94% rename from lib/PqExecutor.php rename to lib/PqHandle.php index f873fda..84cb594 100644 --- a/lib/PqExecutor.php +++ b/lib/PqHandle.php @@ -13,7 +13,7 @@ use pq; use function Amp\call; use function Amp\coroutine; -class PqExecutor implements Executor { +class PqHandle implements Handle { use CallableMaker; /** @var \pq\Connection PostgreSQL connection object. */ @@ -190,6 +190,9 @@ class PqExecutor implements Executor { $this->deferred = new Deferred; Loop::enable($this->poll); + if (!$this->handle->flush()) { + Loop::enable($this->await); + } try { $result = yield $this->deferred->promise(); @@ -268,12 +271,12 @@ class PqExecutor implements Executor { $this->statements[$name] = $storage = new Internal\PqStatementStorage; - $storage->promise = call(function () use ($storage, $name, $sql) { + $promise = $storage->promise = call(function () use ($storage, $name, $sql) { $statement = yield from $this->send([$this->handle, "prepareAsync"], $name, $sql); $storage->statement = $statement; return new PqStatement($statement, $this->send, $this->deallocate); }); - $storage->promise->onResolve(function ($exception) use ($storage, $name) { + $promise->onResolve(function ($exception) use ($storage, $name) { if ($exception) { unset($this->statements[$name]); return; @@ -281,7 +284,7 @@ class PqExecutor implements Executor { $storage->promise = null; }); - return $storage->promise; + return $promise; } /** @@ -345,4 +348,18 @@ class PqExecutor implements Executor { $promise->onResolve([$emitter, "complete"]); return $promise; } + + /** + * {@inheritdoc} + */ + public function quoteString(string $data): string { + return $this->handle->quote($data); + } + + /** + * {@inheritdoc} + */ + public function quoteName(string $name): string { + return $this->handle->quoteName($name); + } } diff --git a/lib/Transaction.php b/lib/Transaction.php index bcdb2cf..dc35a3e 100644 --- a/lib/Transaction.php +++ b/lib/Transaction.php @@ -5,7 +5,7 @@ namespace Amp\Postgres; use Amp\CallableMaker; use Amp\Promise; -class Transaction implements Executor, Operation { +class Transaction implements Handle, Operation { use Internal\Operation, CallableMaker; const UNCOMMITTED = 0; @@ -13,19 +13,19 @@ class Transaction implements Executor, Operation { const REPEATABLE = 2; const SERIALIZABLE = 4; - /** @var \Amp\Postgres\Executor */ - private $executor; + /** @var \Amp\Postgres\Handle */ + private $handle; /** @var int */ private $isolation; /** - * @param \Amp\Postgres\Executor $executor + * @param \Amp\Postgres\Handle $handle * @param int $isolation * * @throws \Error If the isolation level is invalid. */ - public function __construct(Executor $executor, int $isolation = self::COMMITTED) { + public function __construct(Handle $handle, int $isolation = self::COMMITTED) { switch ($isolation) { case self::UNCOMMITTED: case self::COMMITTED: @@ -38,11 +38,11 @@ class Transaction implements Executor, Operation { throw new \Error("Isolation must be a valid transaction isolation level"); } - $this->executor = $executor; + $this->handle = $handle; } public function __destruct() { - if ($this->executor) { + if ($this->handle) { $this->rollback(); // Invokes $this->complete(). } } @@ -51,7 +51,7 @@ class Transaction implements Executor, Operation { * @return bool True if the transaction is active, false if it has been committed or rolled back. */ public function isActive(): bool { - return $this->executor !== null; + return $this->handle !== null; } /** @@ -67,11 +67,11 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function query(string $sql): Promise { - if ($this->executor === null) { + if ($this->handle === null) { throw new TransactionError("The transaction has been committed or rolled back"); } - return $this->executor->query($sql); + return $this->handle->query($sql); } /** @@ -80,11 +80,11 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function prepare(string $sql): Promise { - if ($this->executor === null) { + if ($this->handle === null) { throw new TransactionError("The transaction has been committed or rolled back"); } - return $this->executor->prepare($sql); + return $this->handle->prepare($sql); } /** @@ -93,11 +93,11 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function execute(string $sql, ...$params): Promise { - if ($this->executor === null) { + if ($this->handle === null) { throw new TransactionError("The transaction has been committed or rolled back"); } - return $this->executor->execute($sql, ...$params); + return $this->handle->execute($sql, ...$params); } @@ -107,11 +107,11 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function notify(string $channel, string $payload = ""): Promise { - if ($this->executor === null) { + if ($this->handle === null) { throw new TransactionError("The transaction has been committed or rolled back"); } - return $this->executor->notify($channel, $payload); + return $this->handle->notify($channel, $payload); } /** @@ -122,12 +122,12 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function commit(): Promise { - if ($this->executor === null) { + if ($this->handle === null) { throw new TransactionError("The transaction has been committed or rolled back"); } - $promise = $this->executor->query("COMMIT"); - $this->executor = null; + $promise = $this->handle->query("COMMIT"); + $this->handle = null; $promise->onResolve($this->callableFromInstanceMethod("complete")); return $promise; @@ -141,12 +141,12 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function rollback(): Promise { - if ($this->executor === null) { + if ($this->handle === null) { throw new TransactionError("The transaction has been committed or rolled back"); } - $promise = $this->executor->query("ROLLBACK"); - $this->executor = null; + $promise = $this->handle->query("ROLLBACK"); + $this->handle = null; $promise->onResolve($this->callableFromInstanceMethod("complete")); return $promise; @@ -162,12 +162,11 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function savepoint(string $identifier): Promise { - return $this->query("SAVEPOINT " . $identifier); + return $this->query("SAVEPOINT " . $this->quoteName($identifier)); } /** - * Rolls back to the savepoint with the given identifier. WARNING: Identifier is not sanitized, do not pass - * untrusted data. + * Rolls back to the savepoint with the given identifier. * * @param string $identifier Savepoint identifier. * @@ -176,7 +175,7 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function rollbackTo(string $identifier): Promise { - return $this->query("ROLLBACK TO " . $identifier); + return $this->query("ROLLBACK TO " . $this->quoteName($identifier)); } /** @@ -190,6 +189,32 @@ class Transaction implements Executor, Operation { * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. */ public function release(string $identifier): Promise { - return $this->query("RELEASE SAVEPOINT " . $identifier); + return $this->query("RELEASE SAVEPOINT " . $this->quoteName($identifier)); + } + + /** + * {@inheritdoc} + * + * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. + */ + public function quoteString(string $data): string { + if ($this->handle === null) { + throw new TransactionError("The transaction has been committed or rolled back"); + } + + return $this->handle->quoteString($data); + } + + /** + * {@inheritdoc} + * + * @throws \Amp\Postgres\TransactionError If the transaction has been committed or rolled back. + */ + public function quoteName(string $name): string { + if ($this->handle === null) { + throw new TransactionError("The transaction has been committed or rolled back"); + } + + return $this->handle->quoteName($name); } } diff --git a/test/AbstractPoolTest.php b/test/AbstractPoolTest.php index 25c7b7e..bba96cc 100644 --- a/test/AbstractPoolTest.php +++ b/test/AbstractPoolTest.php @@ -3,8 +3,9 @@ namespace Amp\Postgres\Test; use Amp\Loop; +use Amp\Postgres\AbstractConnection; use Amp\Postgres\CommandResult; -use Amp\Postgres\Connection; +use Amp\Postgres\PooledConnection; use Amp\Postgres\Statement; use Amp\Postgres\Transaction; use Amp\Postgres\TupleResult; @@ -22,10 +23,12 @@ abstract class AbstractPoolTest extends TestCase { abstract protected function createPool(array $connections); /** - * @return \PHPUnit_Framework_MockObject_MockObject|\Amp\Postgres\Connection + * @return \PHPUnit_Framework_MockObject_MockObject|\Amp\Postgres\AbstractConnection */ private function createConnection() { - return $this->createMock(Connection::class); + return $this->getMockBuilder(AbstractConnection::class) + ->disableOriginalConstructor() + ->getMock(); } /** @@ -33,7 +36,7 @@ abstract class AbstractPoolTest extends TestCase { * * @return \Amp\Postgres\Connection[]|\PHPUnit_Framework_MockObject_MockObject[] */ - private function makeConnectionSet($count) { + private function makeConnectionSet(int $count) { $connections = []; for ($i = 0; $i < $count; ++$i) { @@ -63,7 +66,7 @@ abstract class AbstractPoolTest extends TestCase { * @param string $resultClass * @param mixed ...$params */ - public function testSingleQuery($count, $method, $resultClass, ...$params) { + public function testSingleQuery(int $count, string $method, string $resultClass, ...$params) { $result = $this->getMockBuilder($resultClass) ->disableOriginalConstructor() ->getMock(); @@ -93,7 +96,7 @@ abstract class AbstractPoolTest extends TestCase { * @param string $resultClass * @param mixed ...$params */ - public function testConsecutiveQueries($count, $method, $resultClass, ...$params) { + public function testConsecutiveQueries(int $count, string $method, string $resultClass, ...$params) { $rounds = 3; $result = $this->getMockBuilder($resultClass) ->disableOriginalConstructor() @@ -136,7 +139,7 @@ abstract class AbstractPoolTest extends TestCase { * * @param int $count */ - public function testTransaction($count) { + public function testTransaction(int $count) { $connections = $this->makeConnectionSet($count); $connection = $connections[0]; @@ -163,7 +166,7 @@ abstract class AbstractPoolTest extends TestCase { * * @param int $count */ - public function testConsecutiveTransactions($count) { + public function testConsecutiveTransactions(int $count) { $rounds = 3; $result = $this->getMockBuilder(Transaction::class) ->disableOriginalConstructor() @@ -200,4 +203,36 @@ abstract class AbstractPoolTest extends TestCase { } }); } + + /** + * @dataProvider getConnectionCounts + * + * @param int $count + */ + public function testGetConnection(int $count) { + $connections = $this->makeConnectionSet($count); + $query = "SELECT * FROM test"; + + foreach ($connections as $connection) { + $connection->expects($this->once()) + ->method('query') + ->with($query); + } + + $pool = $this->createPool($connections); + + Loop::run(function () use ($pool, $query, $count) { + $promises = []; + for ($i = 0; $i < $count; ++$i) { + $promises[] = $pool->getConnection(); + } + + $results = yield Promise\all($promises); + + foreach ($results as $result) { + $this->assertInstanceof(PooledConnection::class, $result); + $result->query($query); + } + }); + } }