Skip to content
Snippets Groups Projects
dnssd.py 6.39 KiB
Newer Older
Jannis Konrad's avatar
Jannis Konrad committed
import usocket
import utime
import uselect

# DNS resource-record TYPE codes handled by this module
# (RFC 1035 for A/PTR/TXT, RFC 3596 for AAAA, RFC 2782 for SRV).
TYPE_A    = 1   # IPv4 host address
TYPE_PTR  = 12  # domain name pointer
TYPE_TXT  = 16  # text strings
TYPE_AAAA = 28  # IPv6 host address
TYPE_SRV  = 33  # service locator

#Make request
def str2dnsstr(string):
  """Encode a dotted bytes name as uncompressed DNS wire-format labels.

  Example: b'_mqtt._tcp.local' -> b'\\x05_mqtt\\x04_tcp\\x05local\\x00'

  Empty labels (for example from the trailing dot of a fully-qualified
  name such as b'local.') are skipped. Encoding them would emit a
  zero-length label, which terminates the name early.

  Raises ValueError if a label is longer than 63 bytes; a larger length
  byte would be misread as a compression pointer / reserved label type.
  """
  labels = [label for label in string.split(b'.') if label]
  for label in labels:
    if len(label) > 63:
      raise ValueError('DNS label too long')
  return b''.join([bytes((len(label),)) + label for label in labels]) + b'\0'

def makeServiceQuery(service):
  """Build a one-question query packet asking for PTR records of *service*.

  *service* is a dotted bytes name such as b'_mqtt._tcp.local' and is
  encoded with str2dnsstr. The 12-byte header has ID 0, all flags clear
  (a standard query), one question and no other records. The question
  asks for QTYPE PTR in QCLASS IN, with the top class bit set to request
  a unicast response.
  """
  header = bytes((0, 0,    # ID
                  0, 0,    # Flags: standard query
                  0, 1,    # QDCOUNT: one question
                  0, 0,    # ANCOUNT: no answers
                  0, 0,    # NSCOUNT
                  0, 0))   # ARCOUNT
  question_tail = bytes((0x00, 0x0C,   # QTYPE: PTR record
                         0x80, 0x01))  # QCLASS: IN, top bit = unicast response
  return b''.join((header, str2dnsstr(service), question_tail))

def uint16(x):
  """Return the big-endian unsigned 16-bit value formed by x[0] and x[1]."""
  high, low = x[0], x[1]
  return high * 256 + low

#parse answer
def parseDNSstring(data, i, r=''):
  """Decode a (possibly compressed) DNS name starting at offset *i* of *data*.

  Returns (name, next_offset), where *name* is the dotted bytes name
  without a trailing dot and *next_offset* is the offset just past the
  name as stored at *i*. If the stored name ends in a compression
  pointer, *next_offset* is just past that 2-byte pointer.

  Returns None if the data is truncated or malformed, or if the name is
  empty (an empty/root name is treated as "no valid data").

  Compression pointers are followed only if they point strictly
  backwards. This guarantees the recursion terminates; a pointer to
  itself would otherwise recurse forever. *r* is a debug indentation
  prefix that grows by one space per followed pointer.
  """
  qname = b''
  n = len(data) if data else 0
  if n <= i:
    return None
  while True:
    if i >= n:
      # Ran off the end of the buffer before the terminating zero byte.
      return None
    length = data[i]
    if not length:
      break
    if (length & 0xC0) == 0xC0:
      # Compression pointer: the low 14 bits of this byte pair are an offset.
      if i + 1 >= n:
        return None
      p = ((length & 0x3F) << 8) | data[i+1]
      if p >= i:
        print('weird pointer, p>i', p, i)
        return None
      ret = parseDNSstring(data, p, r=' '+r)
      if ret:
        return qname + ret[0], i+2
      return ret
    if length & 0xC0:
      # 0x40 / 0x80 prefixes are reserved label types, not pointers.
      return None
    if n <= i + length:
      # Label claims more bytes than the buffer holds.
      return None
    qname += data[i+1:i+1+length] + b'.'
    i += length + 1
  if qname:
    return qname[:-1], i+1
  return None  # no valid data (empty name)

def _read_name(data, i):
  # Decode a name via parseDNSstring; report a name that runs off the end
  # of the buffer as malformed (None) instead of letting IndexError escape.
  try:
    return parseDNSstring(data, i)
  except IndexError:
    return None

def _parse_txt(rdata):
  # Split TXT RDATA into its <length byte><bytes> character-strings
  # (RFC 1035 3.3.14). TXT RDATA is not a domain name: it has no
  # terminating zero byte and no compression. Returns None if malformed.
  strings = []
  j = 0
  while j < len(rdata):
    end = j + 1 + rdata[j]
    if end > len(rdata):
      return None
    strings.append(rdata[j+1:end])
    j = end
  return strings

def _format_ipv6(raw):
  # Format 16 raw address bytes as eight colon-separated hex groups
  # (uncompressed form, e.g. 'fd3b:20c7:7582:0:8d99:8ee8:5391:67c5').
  return ':'.join(['%x' % ((raw[k] << 8) | raw[k+1]) for k in range(0, 16, 2)])

def _parse_record(data, i, adict):
  # Parse the resource record at offset i into adict. Returns the offset
  # just past the record, or None if it is truncated or malformed.
  # Non-IN records and unhandled types are consumed but not stored.
  if len(data) < i + 12:
    # Shortest record accepted: 2-byte pointer name + 10 fixed bytes.
    return None
  ret = _read_name(data, i)
  if not ret:
    return None
  aname, i = ret
  if len(data) < i + 10:  # TYPE(2) CLASS(2) TTL(4) RDLENGTH(2)
    return None
  atype = uint16(data[i:i+2])
  # The top CLASS bit is the mDNS cache-flush flag, not part of the class.
  aclass = uint16(data[i+2:i+4]) & 0x7FFF
  # data[i+4:i+8] is the TTL, which is not used.
  ardl = uint16(data[i+8:i+10])
  i += 10
  end = i + ardl
  if len(data) < end:  # incomplete data
    return None
  if aclass != 1:
    # Not class IN: skip the whole record so the next one is read from
    # the correct offset.
    return end

  if atype == TYPE_A:
    if ardl == 4:
      adict['A'].append({aname: '.'.join([str(x) for x in data[i:end]])})
  elif atype == TYPE_AAAA:
    if ardl == 16:
      adict['AAAA'].append({aname: _format_ipv6(data[i:end])})
  elif atype == TYPE_PTR:
    # Cut the buffer at the end of RDATA so the name cannot run into the
    # next record; pointers can still reach earlier parts of the packet.
    ret = _read_name(data[:end], i)
    if ret:
      adict['PTR'].append({aname: ret[0]})
  elif atype == TYPE_TXT:
    strings = _parse_txt(data[i:end])
    if strings:
      # Stored as one bytes value (non-empty strings joined with b'.') so
      # the value type stays bytes for existing callers.
      txt = b'.'.join([s for s in strings if s])
      if txt:
        adict['TXT'].append({aname: txt})
  elif atype == TYPE_SRV:
    # RDATA: priority(2) weight(2) port(2) target-name(>= 1)
    if ardl >= 7:
      prio = uint16(data[i:i+2])
      weight = uint16(data[i+2:i+4])
      port = uint16(data[i+4:i+6])
      ret = _read_name(data[:end], i+6)
      if ret:
        adict['SRV'].append({aname: {'prio': prio, 'weight': weight, 'port': port, 'target': ret[0]}})
  else:
    print('Other:', atype, data[i:end])
  return end

def parsePacket(data):
  """Parse a DNS/mDNS response packet into per-type record lists.

  Returns a dict {'A': [...], 'AAAA': [...], 'SRV': [...], 'TXT': [...],
  'PTR': [...]} whose lists hold single-entry dicts {record_name: value}.
  The values are:
    - A: a dotted IPv4 string.
    - AAAA: a colon-separated IPv6 string.
    - PTR: the target name as bytes.
    - TXT: the non-empty strings joined with b'.'.
    - SRV: a dict with 'prio', 'weight', 'port' and 'target'.

  Records from the answer, authority and additional sections are all
  collected. mDNS responders commonly send the SRV/TXT/address records for
  a PTR answer as additional records.

  Returns None if the packet is not a response, is truncated (TC bit set),
  has no answers, or has a malformed question or answer section. A
  malformed authority/additional record only stops parsing; the records
  collected so far are still returned.
  """
  if len(data) < 13:
    return None
  if not data[2] & 0x80:  # QR bit clear: not an answer
    return None
  if data[2] & 0x02:  # TC bit: data truncated, ignore for now
    return None

  questions = uint16(data[4:6])
  answers = uint16(data[6:8])
  extra = uint16(data[8:10]) + uint16(data[10:12])  # NSCOUNT + ARCOUNT
  if not answers:  # full 16-bit count, not just its low byte
    return None

  i = 12
  for _ in range(questions):
    ret = _read_name(data, i)
    if not ret:
      return None  # invalid data
    i = ret[1] + 4  # skip QTYPE and QCLASS
    if len(data) < i:
      return None

  adict = {'A': [], 'AAAA': [], 'SRV': [], 'TXT': [], 'PTR': []}
  for n in range(answers + extra):
    nxt = _parse_record(data, i, adict)
    if nxt is None:
      if n < answers:
        return None  # malformed answer section: reject the packet
      break  # malformed authority/additional record: keep what we have
    i = nxt

  return adict

def getUdps():
  """Create a UDP socket on port 5353 that has joined the mDNS multicast group.

  The socket is bound to 0.0.0.0:5353 and joins group 224.0.0.251 with
  local interface address 0.0.0.0, which lets the stack choose the
  interface. If binding or joining fails, the socket is closed before the
  exception propagates, so it is not leaked.
  """
  udps = usocket.socket(usocket.AF_INET, usocket.SOCK_DGRAM)
  try:
    addr = usocket.getaddrinfo("0.0.0.0", 5353)[0][-1]
    udps.bind(addr)
    # ip_mreq: 4-byte multicast group followed by 4-byte local interface.
    mreq = bytes([224, 0, 0, 251]) + bytes([0, 0, 0, 0])
    udps.setsockopt(usocket.IPPROTO_IP, usocket.IP_ADD_MEMBERSHIP, mreq)
  except Exception:
    udps.close()
    raise
  return udps

def service2ip(services, service):
  """Resolve *service* to (ipv4, port) tuples using parsed packets.

  *services* is a list of dicts as returned by parsePacket. Within each
  packet the chain is followed as PTR (service -> instance name), then
  SRV (instance -> target host and port), then A (host -> IPv4 string).
  Lookups never cross packets. Results keep discovery order, and
  duplicates are not removed.
  """
  found = []
  for packet in services:
    for ptr_record in packet['PTR']:
      if service not in ptr_record:
        continue
      instance = ptr_record[service]
      for srv_record in packet['SRV']:
        if instance not in srv_record:
          continue
        srv = srv_record[instance]
        target, port = srv['target'], srv['port']
        found.extend((a_record[target], port) for a_record in packet['A'] if target in a_record)
  return found

def test():
  """Offline self-test: resolve _mqtt._tcp.local from a canned mDNS response.

  Parses the embedded response packet and prints 'test:' followed by
  whether service2ip yields exactly [('192.168.42.126', 1883)].
  """
  canned = b'\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x05_mqtt\x04_tcp\x05local\x00\x00\x0c\x00\x01\x00\x00\x11\x94\x00 \x1dMosquitto MQTT server on arc4\xc0\x0c\xc0(\x00\x10\x80\x01\x00\x00\x11\x94\x00\x01\x00\xc0(\x00!\x80\x01\x00\x00\x00x\x00\r\x00\x00\x00\x00\x07[\x04arc4\xc0\x17\xc0g\x00\x1c\x80\x01\x00\x00\x00x\x00\x10\xfd; \xc7u\x82\x00\x00\x8d\x99\x8e\xe8S\x91g\xc5\xc0g\x00\x1c\x80\x01\x00\x00\x00x\x00\x10\xfd; \xc7u\x82\x00\x00\x00\x00\x00\x00\x00\x00\x03\xd1\xc0g\x00\x01\x80\x01\x00\x00\x00x\x00\x04\xc0\xa8*~'
  expected = [('192.168.42.126', 1883)]

  parsed = parsePacket(canned)
  resolved = service2ip([parsed], b'_mqtt._tcp.local')
  print('test:', resolved == expected)

def queryService(udps, service=b'_mqtt._tcp.local', timeout=1000):
  """Send an mDNS PTR query for *service* and resolve the replies.

  Sends one query (built by makeServiceQuery) to 224.0.0.251:5353 on
  *udps*. Every reply that arrives within *timeout* milliseconds is
  received and parsed. Returns service2ip() over all parsed packets,
  i.e. a list of (ipv4_string, port) tuples.

  Errors are reported with sys.print_exception rather than raised. A
  packet that fails to parse is skipped. Any other failure ends
  collection early and returns what was gathered so far.
  """
  import sys
  services = []
  try:
    udps.sendto(makeServiceQuery(service), ('224.0.0.251', 5353))

    # Register the socket once rather than rebuilding the poller per pass.
    poller = uselect.poll()
    poller.register(udps, uselect.POLLIN)

    starttime = utime.ticks_ms()
    while True:
      # Re-check the deadline before every wait so a steady stream of
      # packets cannot keep the loop running past *timeout*.
      remaining = timeout - utime.ticks_diff(utime.ticks_ms(), starttime)
      if remaining <= 0:
        break
      if not poller.poll(remaining):
        continue
      try:
        data, _addr = udps.recvfrom(1024)
      except OSError:
        # No data to evaluate
        continue
      try:
        s = parsePacket(data)
      except Exception as e:
        # One malformed packet must not abort the whole query.
        sys.print_exception(e)
        continue
      if s:
        services.append(s)
  except Exception as e:
    sys.print_exception(e)

  return service2ip(services, service)

if '__main__' == __name__:
  # Only when the module is run directly: run the canned-packet self-test.
  print("start test")
  test()