Skip to content
Snippets Groups Projects

DNS-SD

  • Clone with SSH
  • Clone with HTTPS
  • Embed
  • Share
    The snippet can be accessed without any authentication.
    Authored by Jannis Konrad
    dnssd.py 6.39 KiB
    import usocket
    import utime
    import uselect
    
    # DNS resource-record TYPE codes recognised by parsePacket().
    TYPE_A    = 1   # IPv4 host address (RFC 1035)
    TYPE_PTR  = 12  # domain name pointer (RFC 1035)
    TYPE_TXT  = 16  # text strings (RFC 1035)
    TYPE_AAAA = 28  # IPv6 host address (RFC 3596)
    TYPE_SRV  = 33  # service locator (RFC 2782)
    
    #Make request
    def str2dnsstr(string):
      """Encode a dotted name as an uncompressed DNS wire-format name.

      Each label is written as a length byte followed by the label bytes, and
      the name is closed with a single zero byte (the root label).

      :param string: the name as bytes with labels separated by b'.', e.g.
        b'_mqtt._tcp.local'. A trailing dot (fully-qualified form) and the
        empty name are accepted; empty labels are dropped.
      :returns: the encoded name as bytes.
      :raises ValueError: if a label is longer than 63 bytes, which the wire
        format cannot represent (larger length bytes mean compression pointers).
      """
      # Empty labels (from a trailing dot, or the empty root name) must not be
      # emitted: a zero length byte would terminate the name prematurely.
      labels = [x for x in string.split(b'.') if x]
      for label in labels:
        if len(label) > 63:
          raise ValueError('DNS label longer than 63 bytes')
      return b''.join(bytes((len(x),)) + x for x in labels) + b'\0'
    
    def makeServiceQuery(service):
      """Build a single-question DNS query for the PTR records of *service*.

      The 12-byte header carries ID 0, all flag bits clear (standard query),
      one question and no answer/authority/additional records. The question
      asks for type PTR (0x000C); its class field is 0x8001, i.e. class IN
      with the top bit set (the mDNS "unicast response" bit).
      """
      header = bytes((
        0, 0,  # ID
        0, 0,  # flags: standard query
        0, 1,  # QDCOUNT: one question
        0, 0,  # ANCOUNT
        0, 0,  # NSCOUNT
        0, 0,  # ARCOUNT
      ))
      qtype_qclass = bytes((0x00, 0x0C, 0x80, 0x01))  # PTR; unicast bit + class IN
      return b''.join((header, str2dnsstr(service), qtype_qclass))
    
    def uint16(x):
      """Return the first two items of *x* read as a big-endian unsigned 16-bit value."""
      high, low = x[0], x[1]
      return (high << 8) + low
    
    #parse answer
    def parseDNSstring(data, i, r=''):
      """Decode a (possibly compressed) DNS name from *data* at offset *i*.

      :param data: the DNS message, or a prefix of it, as bytes. Compression
        pointers are offsets into this buffer.
      :param i: offset of the name's first length byte.
      :param r: indentation prefix for the commented-out debug prints; grows
        by one space per compression pointer followed.
      :returns: ``(name, next_offset)``. *name* is the labels joined with b'.'
        (no trailing dot). *next_offset* is the offset just past the name's
        own encoding: past the zero byte, or past the 2-byte pointer.
        Returns ``None`` in these cases:
          - the name is empty (root);
          - the name is truncated or has no terminating zero byte;
          - it uses a reserved label type;
          - it contains a compression pointer that does not point strictly
            before the start of this name.
      """
      if not data or len(data) <= i:
        return None
      start = i
      qname = b''
      while i < len(data) and data[i]:
        #print(r+'parsing 1:', i, hex(data[i]))
        length = data[i]
        kind = length & 0xc0
        if kind == 0xc0:
          # Compression pointer: 14-bit offset of the rest of the name.
          if i + 1 >= len(data):
            return None
          p = ((length & 0x3F) << 8) | data[i+1]
          # The target must lie before where this name began. Each hop then
          # moves strictly backwards, so a malicious pointer loop (including
          # a pointer to itself) cannot recurse forever.
          if p >= start:
            print('weird pointer, p>=start', p, start)
            return None
          #print(r+'recurse', p)
          ret = parseDNSstring(data, p, r=' '+r)
          #print(r+'back', ret)
          if ret:
            return qname+ret[0], i+2
          return None
        if kind:
          # 0x40 / 0x80 prefixes are reserved label types, not lengths.
          return None
        if len(data) <= length + i:
          return None  # label runs past the end of the data
        qname += data[i+1:i+1+length] + b'.'
        #print(r+'qname+=', data[i+1:i+1+length])
        i += length + 1
      if i >= len(data) or not qname:
        # Ran out of data before the terminating zero byte, or empty (root) name.
        return None
      return qname[:-1], i+1
    
    def parsePacket(data):
      """Parse a DNS/mDNS response into the records it carries, grouped by type.

      :param data: raw UDP payload as bytes.
      :returns: ``{'A': [...], 'AAAA': [...], 'SRV': [...], 'TXT': [...],
        'PTR': [...]}``. Each list holds one-entry dicts that map the record's
        owner name (bytes) to its decoded data:
          A    -> dotted-quad string
          AAAA -> eight colon-separated hex groups (uncompressed IPv6 text)
          PTR  -> target name (bytes)
          TXT  -> non-empty character-strings joined with b'.' (bytes)
          SRV  -> {'prio', 'weight', 'port', 'target'}
        Records are collected from the answer, authority and additional
        sections; mDNS responders commonly put SRV/TXT/A in the latter.
        Records whose class is not IN are skipped.
        Returns ``None`` in these cases:
          - the message is not a response;
          - it is truncated (TC bit set);
          - it has zero answers;
          - the header, a question or an answer record is malformed.
        A malformed authority/additional record only ends parsing early.
      """
      if len(data) < 13:
        return None
      if not data[2] & 0x80: #not an answer
        return None
      if data[2] & 0x02: #data truncated, ignore for now
        return None

      questions  = uint16(data[4:6])
      answers    = uint16(data[6:8])
      authority  = uint16(data[8:10])
      additional = uint16(data[10:12])
      if not answers: #Zero answers (whole 16-bit count, not just its low byte)
        return None

      i = 12
      for _ in range(questions):
        ret = parseDNSstring(data, i)
        if not ret:
          return None #invalid Data
        i = ret[1]
        if len(data) < i + 4: #QTYPE/QCLASS missing
          return None
        qtype = uint16(data[i:i+2])
        print('qtype:', hex(qtype))
        qclass = uint16(data[i+2:i+4])
        print('qclass', hex(qclass))
        i += 4

      adict = {'A': [], 'AAAA': [], 'SRV': [], 'TXT': [], 'PTR': []}
      for n in range(answers + authority + additional):
        i = _parseRecord(data, i, adict)
        if i is None:
          # A broken answer section invalidates the packet; a broken trailing
          # authority/additional section keeps what was already collected.
          return adict if n >= answers else None
      return adict

    def _parseRecord(data, i, adict):
      # Parse the resource record at offset i into adict. Returns the offset just
      # past the record, or None if it is truncated or malformed (in that case
      # nothing has been added to adict).
      if len(data) < i + 11: #incomplete data: owner name (>= 1 byte) + 10 fixed bytes
        return None
      if data[i] == 0:
        aname, i = b'', i + 1 #root owner name, e.g. an EDNS0 OPT pseudo-record
      else:
        ret = parseDNSstring(data, i)
        if not ret:
          return None
        aname, i = ret
      if len(data) < i + 10: #TYPE, CLASS, TTL, RDLENGTH
        return None
      atype = uint16(data[i:i+2])
      # The top bit of CLASS is the mDNS cache-flush flag; mask it off.
      aclass = uint16(data[i+2:i+4]) & 0x7FFF
      # data[i+4:i+8] is the TTL, which is not used.
      ardl = uint16(data[i+8:i+10])
      i += 10
      end = i + ardl
      if len(data) < end: #incomplete data
        return None
      if aclass != 1:
        # Not class IN: skip its data so the next record is read correctly.
        return end
      ardata = data[i:end]
      if atype == TYPE_A:
        if ardl == 4:
          adict['A'].append({aname: '.'.join(str(x) for x in ardata)})
      elif atype == TYPE_AAAA:
        if ardl == 16:
          ip = ':'.join('%x' % uint16(ardata[k:k+2]) for k in range(0, 16, 2))
          adict['AAAA'].append({aname: ip})
      elif atype == TYPE_PTR:
        # Slice at the record end so the name cannot run into the next record.
        ret = parseDNSstring(data[:end], i)
        if ret:
          adict['PTR'].append({aname: ret[0]})
      elif atype == TYPE_TXT:
        txt = _parseTXT(ardata)
        if txt:
          adict['TXT'].append({aname: txt})
      elif atype == TYPE_SRV:
        # priority, weight, port (16 bits each), then the target name.
        if ardl > 6:
          ret = parseDNSstring(data[:end], i + 6)
          if ret:
            adict['SRV'].append({aname: {'prio': uint16(ardata[0:2]),
                                         'weight': uint16(ardata[2:4]),
                                         'port': uint16(ardata[4:6]),
                                         'target': ret[0]}})
      else:
        print('Other:', atype, ardata)
      return end

    def _parseTXT(rdata):
      # Decode TXT rdata: a run of <length byte><bytes> character-strings
      # (RFC 1035 3.3.14). Non-empty strings are joined with b'.' so the value
      # stays a bytes object; b'' means there is nothing to report.
      # NOTE: strings that themselves contain b'.' cannot be told apart.
      parts = []
      j = 0
      while j < len(rdata):
        n = rdata[j]
        part = rdata[j+1:j+1+n] #a truncated final string is clipped to the data
        if part:
          parts.append(part)
        j += n + 1
      return b'.'.join(parts)
    
    def getUdps():
      """Open a UDP socket that receives mDNS traffic.

      Binds to port 5353 on all local addresses. It then joins the IPv4
      multicast group 224.0.0.251, with interface address 0.0.0.0 so the
      stack chooses the interface.

      :returns: the bound socket; the caller owns it and must close it.
      :raises: whatever getaddrinfo/bind/setsockopt raise (typically OSError).
        The socket is closed before the error propagates, so a failed call
        leaks nothing.
      """
      udps = usocket.socket(usocket.AF_INET, usocket.SOCK_DGRAM)
      try:
        addr = usocket.getaddrinfo("0.0.0.0", 5353)[0][-1]
        udps.bind(addr)
        # ip_mreq: multicast group address followed by the local interface address.
        mreq = bytes([224, 0, 0, 251]) + bytes([0, 0, 0, 0])
        udps.setsockopt(usocket.IPPROTO_IP, usocket.IP_ADD_MEMBERSHIP, mreq)
      except BaseException:
        # Sockets are a scarce resource on microcontrollers; don't leak one.
        udps.close()
        raise
      return udps
    
    def service2ip(services, service):
      """Resolve a DNS-SD service type to the (ip, port) pairs that offer it.

      Follows three record types:
        - PTR: service type -> instance name;
        - SRV: instance -> target host and port;
        - A: host -> IPv4 address.
      Records are matched across all given packets, because a responder may
      spread them over several messages.

      :param services: iterable of dicts as returned by parsePacket(). Falsy
        entries (e.g. None for an unparseable packet) are skipped.
      :param service: service type name as bytes, e.g. b'_mqtt._tcp.local'.
      :returns: list of (ip, port) tuples in discovery order, without duplicates.
      """
      packets = [s for s in services if s]

      def records(kind):
        # Every record of one type, from every packet, in arrival order.
        for s in packets:
          for rec in s[kind]:
            yield rec

      servers = []
      for ptr in records('PTR'):
        if service not in ptr:
          continue
        sname = ptr[service]
        for srv in records('SRV'):
          if sname not in srv:
            continue
          host = srv[sname]['target']
          port = srv[sname]['port']
          for a in records('A'):
            # The same server is usually announced in several responses.
            if host in a and (a[host], port) not in servers:
              servers.append((a[host], port))
      return servers
    
    def test():
      """Self-test: parse a captured mDNS response and resolve the MQTT service.

      The packet is a response (flags 0x8400) with no questions and six answer
      records for a Mosquitto broker:
        - PTR _mqtt._tcp.local -> 'Mosquitto MQTT server on arc4._mqtt._tcp.local';
        - a TXT record with empty data;
        - an SRV record (port 1883, target arc4.local);
        - two AAAA records;
        - an A record for arc4.local (192.168.42.126).
      Prints 'test: True' when service2ip() yields exactly
      [('192.168.42.126', 1883)].
      """
      data = b'\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x05_mqtt\x04_tcp\x05local\x00\x00\x0c\x00\x01\x00\x00\x11\x94\x00 \x1dMosquitto MQTT server on arc4\xc0\x0c\xc0(\x00\x10\x80\x01\x00\x00\x11\x94\x00\x01\x00\xc0(\x00!\x80\x01\x00\x00\x00x\x00\r\x00\x00\x00\x00\x07[\x04arc4\xc0\x17\xc0g\x00\x1c\x80\x01\x00\x00\x00x\x00\x10\xfd; \xc7u\x82\x00\x00\x8d\x99\x8e\xe8S\x91g\xc5\xc0g\x00\x1c\x80\x01\x00\x00\x00x\x00\x10\xfd; \xc7u\x82\x00\x00\x00\x00\x00\x00\x00\x00\x03\xd1\xc0g\x00\x01\x80\x01\x00\x00\x00x\x00\x04\xc0\xa8*~'
      
      services = parsePacket(data)
      print('test:', service2ip([services,], b'_mqtt._tcp.local') == [('192.168.42.126', 1883)])
    
    def queryService(udps, service=b'_mqtt._tcp.local', timeout=1000):
      """Send an mDNS PTR query for *service* and collect responses for *timeout* ms.

      :param udps: UDP socket as returned by getUdps(), bound to 5353 and
        joined to the mDNS multicast group.
      :param service: service type to browse, e.g. b'_mqtt._tcp.local'.
      :param timeout: how long to listen for responses, in milliseconds.
      :returns: list of (ip, port) tuples from service2ip(), possibly empty.
        Errors while sending or receiving are printed, and whatever was
        collected up to that point is still resolved and returned.
      """
      services = []  # defined before the try so the final return always works
      try:
        udps.sendto(makeServiceQuery(service), ('224.0.0.251', 5353))

        poller = uselect.poll()
        poller.register(udps, uselect.POLLIN)
        starttime = utime.ticks_ms()
        while True:
          remaining = timeout - utime.ticks_diff(utime.ticks_ms(), starttime)
          if remaining <= 0:
            break
          # Never wait past the deadline, so a busy link cannot keep us
          # listening beyond `timeout`.
          if not poller.poll(remaining):
            continue
          try:
            data, addr = udps.recvfrom(1024)
          except OSError:
            # No data to evaluate
            continue
          print(data, addr)
          try:
            s = parsePacket(data)
          except (IndexError, ValueError, RuntimeError) as e:
            # Any host on the link can send a malformed packet; drop only
            # that packet instead of ending the whole query.
            print('bad packet:', e)
            continue
          if s:
            services.append(s)
      except Exception as e:
        import sys
        sys.print_exception(e)

      return service2ip(services, service)
    
    if __name__ == '__main__':
      # Running the file directly executes the built-in parser self-test.
      print("start test")
      test()
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Finish editing this message first!
    Please register or sign in to comment