diff --git a/lib/DefaultResolver.php b/lib/DefaultResolver.php index 91a677a..caf1204 100644 --- a/lib/DefaultResolver.php +++ b/lib/DefaultResolver.php @@ -3,9 +3,12 @@ namespace Amp\Dns; use Amp; +use Amp\ByteStream\ResourceInputStream; +use Amp\ByteStream\ResourceOutputStream; +use Amp\ByteStream\StreamException; use Amp\Cache\ArrayCache; use Amp\Cache\Cache; -use Amp\CallableMaker; +use Amp\Coroutine; use Amp\Deferred; use Amp\Failure; use Amp\Loop; @@ -21,8 +24,6 @@ use LibDNS\Records\QuestionFactory; use function Amp\call; class DefaultResolver implements Resolver { - use CallableMaker; - const MAX_REQUEST_ID = 65536; const IDLE_TIMEOUT = 15000; const CACHE_PREFIX = "amphp.dns."; @@ -145,7 +146,7 @@ class DefaultResolver implements Resolver { }, $records); } - private function recurseWithHosts($name, array $types) { + private function recurseWithHosts($name, array $types): \Generator { /** @var Config $config */ $config = yield $this->loadConfig(); $hosts = $config->getKnownHosts(); @@ -166,7 +167,7 @@ class DefaultResolver implements Resolver { return yield from $this->doRecurse($name, $types); } - private function doRecurse($name, array $types) { + private function doRecurse($name, array $types): \Generator { if (\array_intersect($types, [Record::CNAME, Record::DNAME])) { throw new \Error("Cannot use recursion for CNAME and DNAME records"); } @@ -194,17 +195,15 @@ class DefaultResolver implements Resolver { throw new ResolutionException("CNAME or DNAME chain too long (possible recursion?)"); } - private function doRequest($uri, $name, $type) { - $server = $this->loadExistingServer($uri) ?: $this->loadNewServer($uri); - $useTCP = \substr($uri, 0, 6) == "tcp://"; - - if ($useTCP && isset($server->connect)) { - return call(function () use ($server, $uri, $name, $type) { - yield $server->connect; - return $this->doRequest($uri, $name, $type); - }); + private function loadServer($uri): Promise { + if ($server = $this->loadExistingServer($uri)) { + return new Success($server); } + return $this->loadNewServer($uri); + } + + private function doRequest($server, $name, $type): \Generator { // Get the next available request ID do { $requestId = $this->requestIdCounter++; @@ -226,27 +225,66 @@ class DefaultResolver implements Resolver { // Encode request message $requestPacket = $this->encoder->encode($request); - if ($useTCP) { + if ($server->tcp) { $requestPacket = \pack("n", \strlen($requestPacket)) . $requestPacket; } // Send request - // FIXME: Fix might not write all bytes if TCP is used, as the buffer might be full - $bytesWritten = @\fwrite($server->socket, $requestPacket); - if ($bytesWritten === false || $bytesWritten === 0 && (!\is_resource($server->socket) || !\feof($server->socket))) { - $exception = new ResolutionException("Request send failed"); + try { + yield $server->output->write($requestPacket); + } catch (StreamException $exception) { + $exception = new ResolutionException("Request send failed", 0, $exception); $this->unloadServer($server->id, $exception); throw $exception; } $deferred = new Deferred; $server->pendingRequests[$requestId] = true; - $this->pendingRequests[$requestId] = [$deferred, $name, $type, $uri]; + $this->pendingRequests[$requestId] = [$deferred, $name, $type, $server->uri]; - return $deferred->promise(); + $packet = yield $server->input->read(); + + if ($packet === null) { + $exception = new ResolutionException("Server connection failed"); + $this->unloadServer($server->id, $exception); + throw $exception; + } + + if ($server->tcp) { + if ($server->length === INF) { + $server->length = \unpack("n", $packet)[1]; + $packet = \substr($packet, 2); + } + + $server->buffer .= $packet; + + while ($server->length > \strlen($server->buffer)) { + $packet = yield $server->input->read(); + if ($packet === null) { + $exception = new ResolutionException("Server connection failed"); + $this->unloadServer($server->id, $exception); + throw $exception; + } + } + + while ($server->length <= \strlen($server->buffer)) { + $this->decodeResponsePacket($server->id, \substr($server->buffer, 0, $server->length)); + $server->buffer = \substr($server->buffer, $server->length); + if (\strlen($server->buffer) >= 2 + $server->length) { + $server->length = \unpack("n", $server->buffer)[1]; + $server->buffer = \substr($server->buffer, 2); + } else { + $server->length = INF; + } + } + } else { + $this->decodeResponsePacket($server->id, $packet); + } + + return yield $deferred->promise(); } - private function doResolve($name, array $types) { + private function doResolve($name, array $types): \Generator { /** @var Config $config */ $config = yield $this->loadConfig(); @@ -298,9 +336,11 @@ class DefaultResolver implements Resolver { $i = $attempt % \count($nameservers); $uri = "udp://" . $nameservers[$i]; + $server = yield $this->loadServer($uri); + $promises = []; foreach ($types as $type) { - $promises[] = $this->doRequest($uri, $name, $type); + $promises[] = new Coroutine($this->doRequest($server, $name, $type)); } try { @@ -341,7 +381,6 @@ class DefaultResolver implements Resolver { if (\is_resource($server->socket)) { unset($this->serverIdTimeoutMap[$server->id]); - Loop::enable($server->watcherId); return $server; } @@ -349,7 +388,7 @@ class DefaultResolver implements Resolver { return null; } - private function loadNewServer($uri) { + private function loadNewServer($uri): Promise { if (!$socket = @\stream_socket_client($uri, $errno, $errstr, 0, STREAM_CLIENT_ASYNC_CONNECT)) { throw new ResolutionException(\sprintf( "Connection to %s failed: [Error #%d] %s", @@ -367,32 +406,30 @@ class DefaultResolver implements Resolver { public $id; public $uri; public $server; - public $socket; public $buffer = ""; public $length = INF; public $pendingRequests = []; - public $watcherId; - public $connect; + public $tcp = false; + public $input; + public $output; }; $server->id = $id; $server->uri = $uri; - $server->socket = $socket; $server->pendingRequests = []; - $server->watcherId = Loop::onReadable($socket, $this->callableFromInstanceMethod("onReadable")); - Loop::disable($server->watcherId); + $server->input = new ResourceInputStream($socket); + $server->output = new ResourceOutputStream($socket); $this->serverIdMap[$id] = $server; $this->serverUriMap[$uri] = $server; if (\substr($uri, 0, 6) == "tcp://") { + $server->tcp = true; $deferred = new Deferred; - $server->connect = $deferred->promise(); - $watcher = Loop::onWritable($server->socket, static function ($watcher) use ($server, $deferred, &$timer) { + $watcher = Loop::onWritable($socket, static function ($watcher) use ($server, $deferred, &$timer) { Loop::cancel($watcher); Loop::cancel($timer); - $server->connect = null; - $deferred->resolve(); + $deferred->resolve($server); }); // TODO: Respect timeout $timer = Loop::delay(5000, function () use ($id, $deferred, $watcher, $uri) { @@ -400,9 +437,10 @@ class DefaultResolver implements Resolver { $this->unloadServer($id); $deferred->fail(new TimeoutException("Name resolution timed out, could not connect to server at $uri")); }); + return $deferred->promise(); } - return $server; + return new Success($server); } private function unloadServer($serverId, $error = null) { @@ -412,7 +450,6 @@ class DefaultResolver implements Resolver { } $server = $this->serverIdMap[$serverId]; - Loop::cancel($server->watcherId); unset( $this->serverIdMap[$serverId], $this->serverUriMap[$server->uri] @@ -428,37 +465,6 @@ class DefaultResolver implements Resolver { } } - private function onReadable($watcherId, $socket) { - $serverId = (int) $socket; - $packet = @\fread($socket, 512); - if ($packet != "") { - $server = $this->serverIdMap[$serverId]; - if (\substr($server->uri, 0, 6) == "tcp://") { - if ($server->length == INF) { - $server->length = \unpack("n", $packet)[1]; - $packet = \substr($packet, 2); - } - $server->buffer .= $packet; - while ($server->length <= \strlen($server->buffer)) { - $this->decodeResponsePacket($serverId, \substr($server->buffer, 0, $server->length)); - $server->buffer = substr($server->buffer, $server->length); - if (\strlen($server->buffer) >= 2 + $server->length) { - $server->length = \unpack("n", $server->buffer)[1]; - $server->buffer = \substr($server->buffer, 2); - } else { - $server->length = INF; - } - } - } else { - $this->decodeResponsePacket($serverId, $packet); - } - } else { - $this->unloadServer($serverId, new ResolutionException( - "Server connection failed" - )); - } - } - private function decodeResponsePacket($serverId, $packet) { try { $response = $this->decoder->decode($packet); @@ -485,13 +491,17 @@ class DefaultResolver implements Resolver { } private function processDecodedResponse($serverId, $requestId, $response) { + /** @var \Amp\Deferred $deferred */ list($deferred, $name, $type, $uri) = $this->pendingRequests[$requestId]; // Retry via tcp if message has been truncated if ($response->isTruncated()) { if (\substr($uri, 0, 6) != "tcp://") { $uri = \preg_replace("#[a-z.]+://#", "tcp://", $uri); - $deferred->resolve($this->doRequest($uri, $name, $type)); + $deferred->resolve(call(function () use ($uri, $name, $type) { + $server = yield $this->loadServer($uri); + return yield from $this->doRequest($server, $name, $type); + })); } else { $this->finalizeResult($serverId, $requestId, new ResolutionException( "Server returned truncated response" @@ -528,7 +538,6 @@ class DefaultResolver implements Resolver { ); if (empty($server->pendingRequests)) { $this->serverIdTimeoutMap[$server->id] = $this->now + self::IDLE_TIMEOUT; - Loop::disable($server->watcherId); Loop::enable($this->serverTimeoutWatcher); } if ($error) {