diff --git a/lib/Executor.php b/lib/Executor.php index e150aa6..02c72c5 100644 --- a/lib/Executor.php +++ b/lib/Executor.php @@ -5,6 +5,8 @@ namespace Amp\Postgres; use Amp\Promise; interface Executor { + const STATEMENT_NAME_PREFIX = "amp_"; + /** * @param string $sql * diff --git a/lib/Internal/PqStatementStorage.php b/lib/Internal/PqStatementStorage.php new file mode 100644 index 0000000..b2deffc --- /dev/null +++ b/lib/Internal/PqStatementStorage.php @@ -0,0 +1,8 @@ +query(\sprintf("DEALLOCATE %s", $name)); + private function sendDeallocate(string $name) { + \assert(isset($this->statements[$name]), "Named statement not found when deallocating"); + + $storage = $this->statements[$name]; + + if (--$storage->count) { + return; + } + + unset($this->statements[$name]); + + Promise\rethrow($this->query(\sprintf("DEALLOCATE %s", $name))); } /** @@ -225,11 +239,29 @@ class PgSqlExecutor implements Executor { * {@inheritdoc} */ public function prepare(string $sql): Promise { - return call(function () use ($sql) { - $name = "amphp" . \sha1($sql); + $name = self::STATEMENT_NAME_PREFIX . \sha1($sql); + + if (isset($this->statements[$name])) { + $storage = $this->statements[$name]; + ++$storage->count; + + if ($storage->promise) { + return $storage->promise; + } + + return new Success(new PgSqlStatement($name, $sql, $this->executeCallback, $this->deallocateCallback)); + } + + $this->statements[$name] = $storage = new Internal\StatementStorage; + + $storage->promise = call(function () use ($name, $sql) { yield from $this->send("pg_send_prepare", $name, $sql); return new PgSqlStatement($name, $sql, $this->executeCallback, $this->deallocateCallback); }); + $storage->promise->onResolve(function () use ($storage) { + $storage->promise = null; + }); + return $storage->promise; } /** @@ -274,9 +306,7 @@ class PgSqlExecutor implements Executor { * @throws \Error */ private function unlisten(string $channel): Promise { - if (!isset($this->listeners[$channel])) { - throw new \Error("Not listening on that channel"); - } + \assert(isset($this->listeners[$channel]), "Not listening on that channel"); $emitter = $this->listeners[$channel]; unset($this->listeners[$channel]); diff --git a/lib/PqExecutor.php b/lib/PqExecutor.php index d8fb1d0..698ce33 100644 --- a/lib/PqExecutor.php +++ b/lib/PqExecutor.php @@ -8,6 +8,7 @@ use Amp\Deferred; use Amp\Emitter; use Amp\Loop; use Amp\Promise; +use Amp\Success; use pq; use function Amp\call; use function Amp\coroutine; @@ -33,6 +34,9 @@ class PqExecutor implements Executor { /** @var \Amp\Emitter[] */ private $listeners; + /** @var \Amp\Postgres\Internal\PqStatementStorage[] */ + private $statements = []; + /** @var callable */ private $send; @@ -45,6 +49,9 @@ class PqExecutor implements Executor { /** @var callable */ private $release; + /** @var callable */ + private $deallocate; + /** * Connection constructor. * @@ -89,6 +96,7 @@ class PqExecutor implements Executor { $this->fetch = coroutine($this->callableFromInstanceMethod("fetch")); $this->unlisten = $this->callableFromInstanceMethod("unlisten"); $this->release = $this->callableFromInstanceMethod("release"); + $this->deallocate = $this->callableFromInstanceMethod("deallocate"); } /** @@ -134,10 +142,6 @@ class PqExecutor implements Executor { $this->deferred = null; } - if ($handle instanceof pq\Statement) { - return new PqStatement($handle, $this->send); - } - if (!$result instanceof pq\Result) { throw new FailureException("Unknown query result"); } @@ -152,6 +156,10 @@ class PqExecutor implements Executor { throw new QueryError("Empty query string"); case pq\Result::COMMAND_OK: + if ($handle instanceof pq\Statement) { + return $handle; // Will be wrapped into a PqStatement object. + } + return new PqCommandResult($result); case pq\Result::TUPLES_OK: @@ -208,6 +216,20 @@ class PqExecutor implements Executor { $deferred->resolve(); } + private function deallocate(string $name) { + \assert(isset($this->statements[$name]), "Named statement not found when deallocating"); + + $storage = $this->statements[$name]; + + if (--$storage->count) { + return; + } + + unset($this->statements[$name]); + + Promise\rethrow(new Coroutine($this->send([$storage->statement, "deallocateAsync"]))); + } + /** * {@inheritdoc} */ @@ -226,7 +248,30 @@ class PqExecutor implements Executor { * {@inheritdoc} */ public function prepare(string $sql): Promise { - return new Coroutine($this->send([$this->handle, "prepareAsync"], "amphp" . \sha1($sql), $sql)); + $name = self::STATEMENT_NAME_PREFIX . \sha1($sql); + + if (isset($this->statements[$name])) { + $storage = $this->statements[$name]; + ++$storage->count; + + if ($storage->promise) { + return $storage->promise; + } + + return new Success(new PqStatement($storage->statement, $name, $this->send, $this->deallocate)); + } + + $this->statements[$name] = $storage = new Internal\PqStatementStorage; + + $storage->promise = call(function () use ($storage, $name, $sql) { + $statement = yield from $this->send([$this->handle, "prepareAsync"], $name, $sql); + $storage->statement = $statement; + return new PqStatement($statement, $name, $this->send, $this->deallocate); + }); + $storage->promise->onResolve(function () use ($storage) { + $storage->promise = null; + }); + return $storage->promise; } /** @@ -277,9 +322,7 @@ class PqExecutor implements Executor { * @throws \Error */ private function unlisten(string $channel): Promise { - if (!isset($this->listeners[$channel])) { - throw new \Error("Not listening on that channel"); - } + \assert(isset($this->listeners[$channel]), "Not listening on that channel"); $emitter = $this->listeners[$channel]; unset($this->listeners[$channel]); diff --git a/lib/PqStatement.php b/lib/PqStatement.php index 5a7f724..6a0ecc1 100644 --- a/lib/PqStatement.php +++ b/lib/PqStatement.php @@ -9,22 +9,32 @@ class PqStatement implements Statement { /** @var \pq\Statement */ private $statement; + /** @var string */ + private $name; + /** @var callable */ private $execute; + /** @var callable */ + private $deallocate; + /** * @internal * * @param \pq\Statement $statement + * @param string $name * @param callable $execute + * @param callable $deallocate */ - public function __construct(pq\Statement $statement, callable $execute) { + public function __construct(pq\Statement $statement, string $name, callable $execute, callable $deallocate) { $this->statement = $statement; + $this->name = $name; $this->execute = $execute; + $this->deallocate = $deallocate; } public function __destruct() { - ($this->execute)([$this->statement, "deallocateAsync"]); + ($this->deallocate)($this->name); } /** diff --git a/test/AbstractConnectionTest.php b/test/AbstractConnectionTest.php index 1179a26..6baf924 100644 --- a/test/AbstractConnectionTest.php +++ b/test/AbstractConnectionTest.php @@ -9,6 +9,7 @@ use Amp\Postgres\CommandResult; use Amp\Postgres\Connection; use Amp\Postgres\Listener; use Amp\Postgres\QueryError; +use Amp\Postgres\Statement; use Amp\Postgres\Transaction; use Amp\Postgres\TransactionError; use Amp\Postgres\TupleResult; @@ -116,6 +117,92 @@ abstract class AbstractConnectionTest extends TestCase { }); } + /** + * @depends testPrepare + */ + public function testPrepareSameQuery() { + Loop::run(function () { + $sql = "SELECT * FROM test WHERE domain=\$1"; + + /** @var \Amp\Postgres\Statement $statement1 */ + $statement1 = yield $this->connection->prepare($sql); + + /** @var \Amp\Postgres\Statement $statement2 */ + $statement2 = yield $this->connection->prepare($sql); + + $this->assertInstanceOf(Statement::class, $statement1); + $this->assertInstanceOf(Statement::class, $statement2); + + unset($statement1); + + $data = $this->getData()[0]; + + /** @var \Amp\Postgres\TupleResult $result */ + $result = yield $statement2->execute($data[0]); + + $this->assertInstanceOf(TupleResult::class, $result); + + $this->assertSame(2, $result->numFields()); + + while (yield $result->advance()) { + $row = $result->getCurrent(); + $this->assertSame($data[0], $row['domain']); + $this->assertSame($data[1], $row['tld']); + } + }); + } + + /** + * @depends testPrepareSameQuery + */ + public function testSimultaneousPrepareSameQuery() { + Loop::run(function () { + $sql = "SELECT * FROM test WHERE domain=\$1"; + + $statement1 = $this->connection->prepare($sql); + $statement2 = $this->connection->prepare($sql); + + /** + * @var \Amp\Postgres\Statement $statement1 + * @var \Amp\Postgres\Statement $statement2 + */ + list($statement1, $statement2) = yield [$statement1, $statement2]; + + $this->assertInstanceOf(Statement::class, $statement1); + $this->assertInstanceOf(Statement::class, $statement2); + + $data = $this->getData()[0]; + + /** @var \Amp\Postgres\TupleResult $result */ + $result = yield $statement1->execute($data[0]); + + $this->assertInstanceOf(TupleResult::class, $result); + + $this->assertSame(2, $result->numFields()); + + while (yield $result->advance()) { + $row = $result->getCurrent(); + $this->assertSame($data[0], $row['domain']); + $this->assertSame($data[1], $row['tld']); + } + + unset($statement1); + + /** @var \Amp\Postgres\TupleResult $result */ + $result = yield $statement2->execute($data[0]); + + $this->assertInstanceOf(TupleResult::class, $result); + + $this->assertSame(2, $result->numFields()); + + while (yield $result->advance()) { + $row = $result->getCurrent(); + $this->assertSame($data[0], $row['domain']); + $this->assertSame($data[1], $row['tld']); + } + }); + } + public function testExecute() { Loop::run(function () { $data = $this->getData()[0]; @@ -154,7 +241,7 @@ abstract class AbstractConnectionTest extends TestCase { }); Loop::run(function () use ($callback) { - yield \Amp\Promise\all([$callback(0), $callback(1)]); + yield [$callback(0), $callback(1)]; }); } @@ -225,7 +312,7 @@ abstract class AbstractConnectionTest extends TestCase { })()); Loop::run(function () use ($promises) { - yield \Amp\Promise\all($promises); + yield $promises; }); } @@ -260,7 +347,7 @@ abstract class AbstractConnectionTest extends TestCase { })()); Loop::run(function () use ($promises) { - yield \Amp\Promise\all($promises); + yield $promises; }); }