1
0
mirror of https://github.com/danog/dns.git synced 2024-12-03 09:57:56 +01:00

Use input/output streams

This commit is contained in:
Aaron Piotrowski 2017-06-22 19:38:36 -05:00
parent ba8a8423eb
commit fdeb03ca44
No known key found for this signature in database
GPG Key ID: ADD1EF783EDE9EEB

View File

@ -3,9 +3,12 @@
namespace Amp\Dns; namespace Amp\Dns;
use Amp; use Amp;
use Amp\ByteStream\ResourceInputStream;
use Amp\ByteStream\ResourceOutputStream;
use Amp\ByteStream\StreamException;
use Amp\Cache\ArrayCache; use Amp\Cache\ArrayCache;
use Amp\Cache\Cache; use Amp\Cache\Cache;
use Amp\CallableMaker; use Amp\Coroutine;
use Amp\Deferred; use Amp\Deferred;
use Amp\Failure; use Amp\Failure;
use Amp\Loop; use Amp\Loop;
@ -21,8 +24,6 @@ use LibDNS\Records\QuestionFactory;
use function Amp\call; use function Amp\call;
class DefaultResolver implements Resolver { class DefaultResolver implements Resolver {
use CallableMaker;
const MAX_REQUEST_ID = 65536; const MAX_REQUEST_ID = 65536;
const IDLE_TIMEOUT = 15000; const IDLE_TIMEOUT = 15000;
const CACHE_PREFIX = "amphp.dns."; const CACHE_PREFIX = "amphp.dns.";
@ -145,7 +146,7 @@ class DefaultResolver implements Resolver {
}, $records); }, $records);
} }
private function recurseWithHosts($name, array $types) { private function recurseWithHosts($name, array $types): \Generator {
/** @var Config $config */ /** @var Config $config */
$config = yield $this->loadConfig(); $config = yield $this->loadConfig();
$hosts = $config->getKnownHosts(); $hosts = $config->getKnownHosts();
@ -166,7 +167,7 @@ class DefaultResolver implements Resolver {
return yield from $this->doRecurse($name, $types); return yield from $this->doRecurse($name, $types);
} }
private function doRecurse($name, array $types) { private function doRecurse($name, array $types): \Generator {
if (\array_intersect($types, [Record::CNAME, Record::DNAME])) { if (\array_intersect($types, [Record::CNAME, Record::DNAME])) {
throw new \Error("Cannot use recursion for CNAME and DNAME records"); throw new \Error("Cannot use recursion for CNAME and DNAME records");
} }
@ -194,17 +195,15 @@ class DefaultResolver implements Resolver {
throw new ResolutionException("CNAME or DNAME chain too long (possible recursion?)"); throw new ResolutionException("CNAME or DNAME chain too long (possible recursion?)");
} }
private function doRequest($uri, $name, $type) { private function loadServer($uri): Promise {
$server = $this->loadExistingServer($uri) ?: $this->loadNewServer($uri); if ($server = $this->loadExistingServer($uri)) {
$useTCP = \substr($uri, 0, 6) == "tcp://"; return new Success($server);
if ($useTCP && isset($server->connect)) {
return call(function () use ($server, $uri, $name, $type) {
yield $server->connect;
return $this->doRequest($uri, $name, $type);
});
} }
return $this->loadNewServer($uri);
}
private function doRequest($server, $name, $type): \Generator {
// Get the next available request ID // Get the next available request ID
do { do {
$requestId = $this->requestIdCounter++; $requestId = $this->requestIdCounter++;
@ -226,27 +225,66 @@ class DefaultResolver implements Resolver {
// Encode request message // Encode request message
$requestPacket = $this->encoder->encode($request); $requestPacket = $this->encoder->encode($request);
if ($useTCP) { if ($server->tcp) {
$requestPacket = \pack("n", \strlen($requestPacket)) . $requestPacket; $requestPacket = \pack("n", \strlen($requestPacket)) . $requestPacket;
} }
// Send request // Send request
// FIXME: Fix might not write all bytes if TCP is used, as the buffer might be full try {
$bytesWritten = @\fwrite($server->socket, $requestPacket); yield $server->output->write($requestPacket);
if ($bytesWritten === false || $bytesWritten === 0 && (!\is_resource($server->socket) || !\feof($server->socket))) { } catch (StreamException $exception) {
$exception = new ResolutionException("Request send failed"); $exception = new ResolutionException("Request send failed", 0, $exception);
$this->unloadServer($server->id, $exception); $this->unloadServer($server->id, $exception);
throw $exception; throw $exception;
} }
$deferred = new Deferred; $deferred = new Deferred;
$server->pendingRequests[$requestId] = true; $server->pendingRequests[$requestId] = true;
$this->pendingRequests[$requestId] = [$deferred, $name, $type, $uri]; $this->pendingRequests[$requestId] = [$deferred, $name, $type, $server->uri];
return $deferred->promise(); $packet = yield $server->input->read();
if ($packet === null) {
$exception = new ResolutionException("Server connection failed");
$this->unloadServer($server->id, $exception);
throw $exception;
}
if ($server->tcp) {
if ($server->length === INF) {
$server->length = \unpack("n", $packet)[1];
$packet = \substr($packet, 2);
}
$server->buffer .= $packet;
while ($server->length > \strlen($server->buffer)) {
$packet = yield $server->input->read();
if ($packet === null) {
$exception = new ResolutionException("Server connection failed");
$this->unloadServer($server->id, $exception);
throw $exception;
}
}
while ($server->length <= \strlen($server->buffer)) {
$this->decodeResponsePacket($server->id, \substr($server->buffer, 0, $server->length));
$server->buffer = \substr($server->buffer, $server->length);
if (\strlen($server->buffer) >= 2 + $server->length) {
$server->length = \unpack("n", $server->buffer)[1];
$server->buffer = \substr($server->buffer, 2);
} else {
$server->length = INF;
}
}
} else {
$this->decodeResponsePacket($server->id, $packet);
}
return yield $deferred->promise();
} }
private function doResolve($name, array $types) { private function doResolve($name, array $types): \Generator {
/** @var Config $config */ /** @var Config $config */
$config = yield $this->loadConfig(); $config = yield $this->loadConfig();
@ -298,9 +336,11 @@ class DefaultResolver implements Resolver {
$i = $attempt % \count($nameservers); $i = $attempt % \count($nameservers);
$uri = "udp://" . $nameservers[$i]; $uri = "udp://" . $nameservers[$i];
$server = yield $this->loadServer($uri);
$promises = []; $promises = [];
foreach ($types as $type) { foreach ($types as $type) {
$promises[] = $this->doRequest($uri, $name, $type); $promises[] = new Coroutine($this->doRequest($server, $name, $type));
} }
try { try {
@ -341,7 +381,6 @@ class DefaultResolver implements Resolver {
if (\is_resource($server->socket)) { if (\is_resource($server->socket)) {
unset($this->serverIdTimeoutMap[$server->id]); unset($this->serverIdTimeoutMap[$server->id]);
Loop::enable($server->watcherId);
return $server; return $server;
} }
@ -349,7 +388,7 @@ class DefaultResolver implements Resolver {
return null; return null;
} }
private function loadNewServer($uri) { private function loadNewServer($uri): Promise {
if (!$socket = @\stream_socket_client($uri, $errno, $errstr, 0, STREAM_CLIENT_ASYNC_CONNECT)) { if (!$socket = @\stream_socket_client($uri, $errno, $errstr, 0, STREAM_CLIENT_ASYNC_CONNECT)) {
throw new ResolutionException(\sprintf( throw new ResolutionException(\sprintf(
"Connection to %s failed: [Error #%d] %s", "Connection to %s failed: [Error #%d] %s",
@ -367,32 +406,30 @@ class DefaultResolver implements Resolver {
public $id; public $id;
public $uri; public $uri;
public $server; public $server;
public $socket;
public $buffer = ""; public $buffer = "";
public $length = INF; public $length = INF;
public $pendingRequests = []; public $pendingRequests = [];
public $watcherId; public $tcp = false;
public $connect; public $input;
public $output;
}; };
$server->id = $id; $server->id = $id;
$server->uri = $uri; $server->uri = $uri;
$server->socket = $socket;
$server->pendingRequests = []; $server->pendingRequests = [];
$server->watcherId = Loop::onReadable($socket, $this->callableFromInstanceMethod("onReadable")); $server->input = new ResourceInputStream($socket);
Loop::disable($server->watcherId); $server->output = new ResourceOutputStream($socket);
$this->serverIdMap[$id] = $server; $this->serverIdMap[$id] = $server;
$this->serverUriMap[$uri] = $server; $this->serverUriMap[$uri] = $server;
if (\substr($uri, 0, 6) == "tcp://") { if (\substr($uri, 0, 6) == "tcp://") {
$server->tcp = true;
$deferred = new Deferred; $deferred = new Deferred;
$server->connect = $deferred->promise(); $watcher = Loop::onWritable($socket, static function ($watcher) use ($server, $deferred, &$timer) {
$watcher = Loop::onWritable($server->socket, static function ($watcher) use ($server, $deferred, &$timer) {
Loop::cancel($watcher); Loop::cancel($watcher);
Loop::cancel($timer); Loop::cancel($timer);
$server->connect = null; $deferred->resolve($server);
$deferred->resolve();
}); });
// TODO: Respect timeout // TODO: Respect timeout
$timer = Loop::delay(5000, function () use ($id, $deferred, $watcher, $uri) { $timer = Loop::delay(5000, function () use ($id, $deferred, $watcher, $uri) {
@ -400,9 +437,10 @@ class DefaultResolver implements Resolver {
$this->unloadServer($id); $this->unloadServer($id);
$deferred->fail(new TimeoutException("Name resolution timed out, could not connect to server at $uri")); $deferred->fail(new TimeoutException("Name resolution timed out, could not connect to server at $uri"));
}); });
return $deferred->promise();
} }
return $server; return new Success($server);
} }
private function unloadServer($serverId, $error = null) { private function unloadServer($serverId, $error = null) {
@ -412,7 +450,6 @@ class DefaultResolver implements Resolver {
} }
$server = $this->serverIdMap[$serverId]; $server = $this->serverIdMap[$serverId];
Loop::cancel($server->watcherId);
unset( unset(
$this->serverIdMap[$serverId], $this->serverIdMap[$serverId],
$this->serverUriMap[$server->uri] $this->serverUriMap[$server->uri]
@ -428,37 +465,6 @@ class DefaultResolver implements Resolver {
} }
} }
private function onReadable($watcherId, $socket) {
$serverId = (int) $socket;
$packet = @\fread($socket, 512);
if ($packet != "") {
$server = $this->serverIdMap[$serverId];
if (\substr($server->uri, 0, 6) == "tcp://") {
if ($server->length == INF) {
$server->length = \unpack("n", $packet)[1];
$packet = \substr($packet, 2);
}
$server->buffer .= $packet;
while ($server->length <= \strlen($server->buffer)) {
$this->decodeResponsePacket($serverId, \substr($server->buffer, 0, $server->length));
$server->buffer = substr($server->buffer, $server->length);
if (\strlen($server->buffer) >= 2 + $server->length) {
$server->length = \unpack("n", $server->buffer)[1];
$server->buffer = \substr($server->buffer, 2);
} else {
$server->length = INF;
}
}
} else {
$this->decodeResponsePacket($serverId, $packet);
}
} else {
$this->unloadServer($serverId, new ResolutionException(
"Server connection failed"
));
}
}
private function decodeResponsePacket($serverId, $packet) { private function decodeResponsePacket($serverId, $packet) {
try { try {
$response = $this->decoder->decode($packet); $response = $this->decoder->decode($packet);
@ -485,13 +491,17 @@ class DefaultResolver implements Resolver {
} }
private function processDecodedResponse($serverId, $requestId, $response) { private function processDecodedResponse($serverId, $requestId, $response) {
/** @var \Amp\Deferred $deferred */
list($deferred, $name, $type, $uri) = $this->pendingRequests[$requestId]; list($deferred, $name, $type, $uri) = $this->pendingRequests[$requestId];
// Retry via tcp if message has been truncated // Retry via tcp if message has been truncated
if ($response->isTruncated()) { if ($response->isTruncated()) {
if (\substr($uri, 0, 6) != "tcp://") { if (\substr($uri, 0, 6) != "tcp://") {
$uri = \preg_replace("#[a-z.]+://#", "tcp://", $uri); $uri = \preg_replace("#[a-z.]+://#", "tcp://", $uri);
$deferred->resolve($this->doRequest($uri, $name, $type)); $deferred->resolve(call(function () use ($uri, $name, $type) {
$server = yield $this->loadServer($uri);
return yield from $this->doRequest($server, $name, $type);
}));
} else { } else {
$this->finalizeResult($serverId, $requestId, new ResolutionException( $this->finalizeResult($serverId, $requestId, new ResolutionException(
"Server returned truncated response" "Server returned truncated response"
@ -528,7 +538,6 @@ class DefaultResolver implements Resolver {
); );
if (empty($server->pendingRequests)) { if (empty($server->pendingRequests)) {
$this->serverIdTimeoutMap[$server->id] = $this->now + self::IDLE_TIMEOUT; $this->serverIdTimeoutMap[$server->id] = $this->now + self::IDLE_TIMEOUT;
Loop::disable($server->watcherId);
Loop::enable($this->serverTimeoutWatcher); Loop::enable($this->serverTimeoutWatcher);
} }
if ($error) { if ($error) {