# -*- Mode: Python; tab-width: 4 -*-

#
#	Author: Sam Rushing <rushing@nightmare.com>
#

RCS_ID =  '$Id: resolver.py,v 1.12.8.1 2003/07/21 16:37:34 chrism Exp $'


# Fast, low-overhead asynchronous name resolver.  uses 'pre-cooked'
# DNS requests, unpacks only as much as it needs of the reply.

# see rfc1035 for details

import string
import asyncore
import socket
import sys
import time
from counter import counter

if RCS_ID.startswith('$Id: '):
    VERSION = string.split(RCS_ID)[2]
else:
    VERSION = '0.0'

# header
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                      ID                       |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    QDCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    ANCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    NSCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    ARCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+


# question
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                                               |
#    /                     QNAME                     /
#    /                                               /
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     QTYPE                     |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     QCLASS                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+

# build a DNS address request, _quickly_
def fast_address_request (host, id=0):
    return (
            '%c%c' % (chr((id>>8)&0xff),chr(id&0xff))
            + '\001\000\000\001\000\000\000\000\000\000%s\000\000\001\000\001' % (
                    string.join (
                            map (
                                    lambda part: '%c%s' % (chr(len(part)),part),
                                    string.split (host, '.')
                                    ), ''
                            )
                    )
            )
    
def fast_ptr_request (host, id=0):
    return (
            '%c%c' % (chr((id>>8)&0xff),chr(id&0xff))
            + '\001\000\000\001\000\000\000\000\000\000%s\000\000\014\000\001' % (
                    string.join (
                            map (
                                    lambda part: '%c%s' % (chr(len(part)),part),
                                    string.split (host, '.')
                                    ), ''
                            )
                    )
            )
    
def unpack_name (r,pos):
    n = []
    while 1:
        ll = ord(r[pos])
        if (ll&0xc0):
                # compression
            pos = (ll&0x3f << 8) + (ord(r[pos+1]))
        elif ll == 0:
            break			
        else:
            pos = pos + 1
            n.append (r[pos:pos+ll])
            pos = pos + ll
    return string.join (n,'.')
    
def skip_name (r,pos):
    s = pos
    while 1:
        ll = ord(r[pos])
        if (ll&0xc0):
                # compression
            return pos + 2
        elif ll == 0:
            pos = pos + 1
            break
        else:
            pos = pos + ll + 1
    return pos
    
def unpack_ttl (r,pos):
    return reduce (
            lambda x,y: (x<<8)|y,
            map (ord, r[pos:pos+4])
            )
    
    # resource record
    #                                    1  1  1  1  1  1
    #      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                                               |
    #    /                                               /
    #    /                      NAME                     /
    #    |                                               |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                      TYPE                     |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                     CLASS                     |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                      TTL                      |
    #    |                                               |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                   RDLENGTH                    |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
    #    /                     RDATA                     /
    #    /                                               /
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    
def unpack_address_reply (r):
    ancount = (ord(r[6])<<8) + (ord(r[7]))
    # skip question, first name starts at 12,
    # this is followed by QTYPE and QCLASS
    pos = skip_name (r, 12) + 4
    if ancount:
            # we are looking very specifically for
            # an answer with TYPE=A, CLASS=IN (\000\001\000\001)
        for an in range(ancount):
            pos = skip_name (r, pos)
            if r[pos:pos+4] == '\000\001\000\001':
                return (
                        unpack_ttl (r,pos+4),
                        '%d.%d.%d.%d' % tuple(map(ord,r[pos+10:pos+14]))
                        )
                # skip over TYPE, CLASS, TTL, RDLENGTH, RDATA
            pos = pos + 8
            rdlength = (ord(r[pos])<<8) + (ord(r[pos+1]))
            pos = pos + 2 + rdlength
        return 0, None
    else:
        return 0, None
        
def unpack_ptr_reply (r):
    ancount = (ord(r[6])<<8) + (ord(r[7]))
    # skip question, first name starts at 12,
    # this is followed by QTYPE and QCLASS
    pos = skip_name (r, 12) + 4
    if ancount:
            # we are looking very specifically for
            # an answer with TYPE=PTR, CLASS=IN (\000\014\000\001)
        for an in range(ancount):
            pos = skip_name (r, pos)
            if r[pos:pos+4] == '\000\014\000\001':
                return (
                        unpack_ttl (r,pos+4),
                        unpack_name (r, pos+10)
                        )
                # skip over TYPE, CLASS, TTL, RDLENGTH, RDATA
            pos = pos + 8
            rdlength = (ord(r[pos])<<8) + (ord(r[pos+1]))
            pos = pos + 2 + rdlength
        return 0, None
    else:
        return 0, None
        
        
        # This is a UDP (datagram) resolver.
        
        #
        # It may be useful to implement a TCP resolver.  This would presumably
        # give us more reliable behavior when things get too busy.  A TCP
        # client would have to manage the connection carefully, since the
        # server is allowed to close it at will (the RFC recommends closing
        # after 2 minutes of idle time).
        #
        # Note also that the TCP client will have to prepend each request
        # with a 2-byte length indicator (see rfc1035).
        #
        
class resolver (asyncore.dispatcher):
    id = counter()
    def __init__ (self, server='127.0.0.1'):
        asyncore.dispatcher.__init__ (self)
        self.create_socket (socket.AF_INET, socket.SOCK_DGRAM)
        self.server = server
        self.request_map = {}
        self.last_reap_time = int(time.time())      # reap every few minutes
        
    def writable (self):
        return 0
        
    def log (self, *args):
        pass
        
    def handle_close (self):
        self.log_info('closing!')
        self.close()
        
    def handle_error (self):      # don't close the connection on error
        (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
        self.log_info(
                        'Problem with DNS lookup (%s:%s %s)' % (t, v, tbinfo),
                        'error')
        
    def get_id (self):
        return (self.id.as_long() % (1<<16))
        
    def reap (self):          # find DNS requests that have timed out
        now = int(time.time())
        if now - self.last_reap_time > 180:        # reap every 3 minutes
            self.last_reap_time = now              # update before we forget
            for k,(host,unpack,callback,when) in self.request_map.items():
                if now - when > 180:               # over 3 minutes old
                    del self.request_map[k]
                    try:                           # same code as in handle_read
                        callback (host, 0, None)   # timeout val is (0,None) 
                    except:
                        (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
                        self.log_info('%s %s %s' % (t,v,tbinfo), 'error')
                        
    def resolve (self, host, callback):
        self.reap()                                # first, get rid of old guys
        self.socket.sendto (
                fast_address_request (host, self.get_id()),
                (self.server, 53)
                )
        self.request_map [self.get_id()] = (
                host, unpack_address_reply, callback, int(time.time()))
        self.id.increment()
        
    def resolve_ptr (self, host, callback):
        self.reap()                                # first, get rid of old guys
        ip = string.split (host, '.')
        ip.reverse()
        ip = string.join (ip, '.') + '.in-addr.arpa'
        self.socket.sendto (
                fast_ptr_request (ip, self.get_id()),
                (self.server, 53)
                )
        self.request_map [self.get_id()] = (
                host, unpack_ptr_reply, callback, int(time.time()))
        self.id.increment()
        
    def handle_read (self):
        reply, whence = self.socket.recvfrom (512)
        # for security reasons we may want to double-check
        # that <whence> is the server we sent the request to.
        id = (ord(reply[0])<<8) + ord(reply[1])
        if self.request_map.has_key (id):
            host, unpack, callback, when = self.request_map[id]
            del self.request_map[id]
            ttl, answer = unpack (reply)
            try:
                callback (host, ttl, answer)
            except:
                (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
                self.log_info('%s %s %s' % ( t,v,tbinfo), 'error')
                
class rbl (resolver):

    def resolve_maps (self, host, callback):
        ip = string.split (host, '.')
        ip.reverse()
        ip = string.join (ip, '.') + '.rbl.maps.vix.com'
        self.socket.sendto (
                fast_ptr_request (ip, self.get_id()),
                (self.server, 53)
                )
        self.request_map [self.get_id()] = host, self.check_reply, callback
        self.id.increment()
        
    def check_reply (self, r):
            # we only need to check RCODE.
        rcode = (ord(r[3])&0xf)
        self.log_info('MAPS RBL; RCODE =%02x\n %s' % (rcode, repr(r)))
        return 0, rcode # (ttl, answer)
        
        
class hooked_callback:
    def __init__ (self, hook, callback):
        self.hook, self.callback = hook, callback
        
    def __call__ (self, *args):
        self.hook(*args)
        self.callback(*args)
        
class caching_resolver (resolver):
    "Cache DNS queries.  Will need to honor the TTL value in the replies"
    
    def __init__(self, *args):
        resolver.__init__(self, *args)
        self.cache = {}
        self.forward_requests = counter()
        self.reverse_requests = counter()
        self.cache_hits = counter()
        
    def resolve (self, host, callback):
        self.forward_requests.increment()
        if self.cache.has_key (host):
            when, ttl, answer = self.cache[host]
            # ignore TTL for now
            callback (host, ttl, answer)
            self.cache_hits.increment()
        else:
            resolver.resolve (
                    self,
                    host,
                    hooked_callback (
                            self.callback_hook,
                            callback
                            )
                    )
            
    def resolve_ptr (self, host, callback):
        self.reverse_requests.increment()
        if self.cache.has_key (host):
            when, ttl, answer = self.cache[host]
            # ignore TTL for now
            callback (host, ttl, answer)
            self.cache_hits.increment()
        else:
            resolver.resolve_ptr (
                    self,
                    host,
                    hooked_callback (
                            self.callback_hook,
                            callback
                            )
                    )
            
    def callback_hook (self, host, ttl, answer):
        self.cache[host] = time.time(), ttl, answer
        
    SERVER_IDENT = 'Caching DNS Resolver (V%s)' % VERSION
    
    def status (self):
        import status_handler
        import producers
        return producers.simple_producer (
                '<h2>%s</h2>'					% self.SERVER_IDENT
                + '<br>Server: %s'				% self.server
                + '<br>Cache Entries: %d'		% len(self.cache)
                + '<br>Outstanding Requests: %d' % len(self.request_map)
                + '<br>Forward Requests: %s'	% self.forward_requests
                + '<br>Reverse Requests: %s'	% self.reverse_requests
                + '<br>Cache Hits: %s'			% self.cache_hits
                )
        
        #test_reply = """\000\000\205\200\000\001\000\001\000\002\000\002\006squirl\011nightmare\003com\000\000\001\000\001\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\011nightmare\003com\000\000\002\000\001\000\001Q\200\000\002\300\014\3006\000\002\000\001\000\001Q\200\000\015\003ns1\003iag\003net\000\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\300]\000\001\000\001\000\000\350\227\000\004\314\033\322\005"""
        # def test_unpacker ():
        # 	print unpack_address_reply (test_reply)
        # 
        # import time
        # class timer:
        # 	def __init__ (self):
        # 		self.start = time.time()
        # 	def end (self):
        # 		return time.time() - self.start
        # 
        # # I get ~290 unpacks per second for the typical case, compared to ~48
        # # using dnslib directly.  also, that latter number does not include
        # # picking the actual data out.
        # 
        # def benchmark_unpacker():
        # 
        # 	r = range(1000)
        # 	t = timer()
        # 	for i in r:
        # 		unpack_address_reply (test_reply)
        # 	print '%.2f unpacks per second' % (1000.0 / t.end())
        
if __name__ == '__main__':
    import sys
    if len(sys.argv) == 1:
        print 'usage: %s [-r] [-s <server_IP>] host [host ...]' % sys.argv[0]
        sys.exit(0)
    elif ('-s' in sys.argv):
        i = sys.argv.index('-s')
        server = sys.argv[i+1]
        del sys.argv[i:i+2]
    else:
        server = '127.0.0.1'
        
    if ('-r' in sys.argv):
        reverse = 1
        i = sys.argv.index('-r')
        del sys.argv[i]
    else:
        reverse = 0
        
    if ('-m' in sys.argv):
        maps = 1
        sys.argv.remove ('-m')
    else:
        maps = 0
        
    if maps:
        r = rbl (server)
    else:
        r = caching_resolver(server)
        
    count = len(sys.argv) - 1
    
    def print_it (host, ttl, answer):
        global count
        print '%s: %s' % (host, answer)
        count = count - 1
        if not count:
            r.close()
            
    for host in sys.argv[1:]:
        if reverse:
            r.resolve_ptr (host, print_it)
        elif maps:
            r.resolve_maps (host, print_it)
        else:
            r.resolve (host, print_it)
            
            # hooked asyncore.loop()
    while asyncore.socket_map:
        asyncore.poll (30.0)
        print 'requests outstanding: %d' % len(r.request_map)
