Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/master/lib/rbmysql/protocol.rb
Views: 11766
# coding: ascii-8bit
# Copyright (C) 2008-2012 TOMITA Masahiro
# mailto:[email protected]

require "socket"
require "timeout"
require "digest/sha1"
require "stringio"

class RbMysql
  # MySQL network protocol: packet framing, the connection handshake and the
  # client commands (text queries and prepared statements).
  class Protocol

    VERSION = 10
    # Largest payload one packet can carry (3-byte length field). A packet whose
    # payload is exactly this long means the logical packet continues in the next one.
    MAX_PACKET_LENGTH = 2**24-1

    # Convert a value in binary-protocol (prepared statement) row format to a Ruby value.
    # === Argument
    # pkt      :: [Packet] packet data, positioned at the value to decode
    # type     :: [Integer] field type (Field::TYPE_*)
    # unsigned :: [true or false] true if value is unsigned
    # === Return
    # Object :: converted value.
    # === Exception
    # RuntimeError :: the field type is not implemented
    def self.net2value(pkt, type, unsigned)
      case type
      when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB, Field::TYPE_JSON
        pkt.lcs
      when Field::TYPE_TINY
        v = pkt.utiny
        (unsigned || v < 128) ? v : v-256
      when Field::TYPE_SHORT
        v = pkt.ushort
        (unsigned || v < 32768) ? v : v-65536
      when Field::TYPE_INT24, Field::TYPE_LONG
        v = pkt.ulong
        (unsigned || v < 0x8000_0000) ? v : v-0x10000_0000
      when Field::TYPE_LONGLONG
        n1, n2 = pkt.ulong, pkt.ulong
        v = (n2 << 32) | n1
        (unsigned || v < 0x8000_0000_0000_0000) ? v : v-0x10000_0000_0000_0000
      when Field::TYPE_FLOAT
        pkt.read(4).unpack('e').first
      when Field::TYPE_DOUBLE
        pkt.read(8).unpack('E').first
      when Field::TYPE_DATE
        len = pkt.utiny
        y, m, d = pkt.read(len).unpack("vCC")
        RbMysql::Time.new(y, m, d, nil, nil, nil)
      when Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
        len = pkt.utiny
        y, m, d, h, mi, s, sp = pkt.read(len).unpack("vCCCCCV")
        RbMysql::Time.new(y, m, d, h, mi, s, false, sp)
      when Field::TYPE_TIME
        len = pkt.utiny
        sign, d, h, mi, s, sp = pkt.read(len).unpack("CVCCCV")
        h = d.to_i * 24 + h.to_i
        # A zero-length value (00:00:00) leaves sign nil; it must not be treated as negative.
        RbMysql::Time.new(0, 0, 0, h, mi, s, sign.to_i != 0, sp)
      when Field::TYPE_YEAR
        pkt.ushort
      when Field::TYPE_BIT
        pkt.lcs
      else
        raise "not implemented: type=#{type}"
      end
    end

    # Convert a Ruby value to binary-protocol parameter data.
    # === Argument
    # v :: [Object] Ruby value.
    # === Return
    # Integer :: type of column. Field::TYPE_*
    # String  :: netdata
    # === Exception
    # ProtocolError :: value too large / value is not supported
    def self.value2net(v)
      case v
      when nil
        type = Field::TYPE_NULL
        val = ""
      when Integer
        if -0x8000_0000 <= v && v < 0x8000_0000
          type = Field::TYPE_LONG
          val = [v].pack('V')
        elsif -0x8000_0000_0000_0000 <= v && v < 0x8000_0000_0000_0000
          type = Field::TYPE_LONGLONG
          val = [v&0xffffffff, v>>32].pack("VV")
        elsif 0x8000_0000_0000_0000 <= v && v <= 0xffff_ffff_ffff_ffff
          # The high bit of the 2-byte parameter type marks the value as unsigned.
          type = Field::TYPE_LONGLONG | 0x8000
          val = [v&0xffffffff, v>>32].pack("VV")
        else
          raise ProtocolError, "value too large: #{v}"
        end
      when Float
        type = Field::TYPE_DOUBLE
        val = [v].pack("E")
      when String
        type = Field::TYPE_STRING
        val = Packet.lcs(v)
      when ::Time
        type = Field::TYPE_DATETIME
        val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.usec].pack("CvCCCCCV")
      when RbMysql::Time
        type = Field::TYPE_DATETIME
        val = [11, v.year, v.month, v.day, v.hour, v.min, v.sec, v.second_part].pack("CvCCCCCV")
      else
        raise ProtocolError, "class #{v.class} is not supported"
      end
      return type, val
    end

    attr_reader :server_info
    attr_reader :server_version
    attr_reader :thread_id
    attr_reader :sqlstate
    attr_reader :affected_rows
    attr_reader :insert_id
    attr_reader :server_status
    attr_reader :warning_count
    attr_reader :message
    attr_accessor :charset

    # @state tracks where the connection is in the command cycle:
    # :INIT   :: Initial state (before authenticate()).
    # :READY  :: Ready for a command.
    # :FIELD  :: After query(); retr_fields() is needed.
    # :RESULT :: After retr_fields(); retr_all_records() or stmt_retr_all_records() is needed.

    # Make a socket connection to the server.
    # === Argument
    # host          :: [String / nil] if nil, "" or "localhost", connect through a UNIXSocket; otherwise through a TCPSocket
    # port          :: [Integer / nil] TCP port number (falls back to $MYSQL_TCP_PORT, the "mysql" service, then MYSQL_TCP_PORT)
    # socket        :: [String / BasicSocket / nil] UNIX socket file name, or an already-connected socket to use as-is
    # conn_timeout  :: [Integer] connect timeout (sec).
    # read_timeout  :: [Integer] read timeout (sec).
    # write_timeout :: [Integer] write timeout (sec).
    # === Exception
    # [ClientError] :: connection timeout
    def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
      @insert_id = 0
      @warning_count = 0
      @gc_stmt_queue = []   # statement ids from gc_stmt; closed on the next :READY transition
      set_state :INIT
      @read_timeout = read_timeout
      @write_timeout = write_timeout
      begin
        Timeout.timeout conn_timeout do
          if socket.is_a?(::BasicSocket)
            # An already-connected socket is used as-is, whatever the host is.
            @sock = socket
          elsif host.nil? || host.empty? || host == "localhost"
            socket ||= ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
            @sock = UNIXSocket.new socket
          elsif socket
            @sock = socket
          else
            port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
            @sock = TCPSocket.new host, port
          end
        end
      rescue Timeout::Error
        raise ClientError, "connection timeout"
      end
    end

    # Close the underlying socket.
    def close
      @sock.close
    end

    # Initial negotiation and authentication.
    # === Argument
    # user    :: [String / nil] username
    # passwd  :: [String / nil] password
    # db      :: [String / nil] default database name. nil: no default.
    # flag    :: [Integer] client flag
    # charset :: [RbMysql::Charset / nil] charset for connection. nil: use server's charset
    # === Exception
    # ProtocolError :: The old style password is not supported
    def authenticate(user, passwd, db, flag, charset)
      check_state :INIT
      @authinfo = [user, passwd, db, flag, charset]
      reset
      init_packet = InitialPacket.parse read
      @server_info = init_packet.server_version
      # e.g. "5.7.30-log" -> 50730
      @server_version = init_packet.server_version.split(/\D/)[0,3].inject{|a,b|a.to_i*100+b.to_i}
      @thread_id = init_packet.thread_id
      @client_flags = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_TRANSACTIONS | CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
      @client_flags |= CLIENT_CONNECT_WITH_DB if db
      @client_flags |= flag
      @charset = charset
      unless @charset
        @charset = Charset.by_number(init_packet.server_charset)
        @charset.encoding       # raise error if unsupported charset
      end
      netpw = encrypt_password passwd, init_packet.scramble_buff
      write AuthenticationPacket.serialize(@client_flags, 1024**3, @charset.number, user, netpw, db)
      raise ProtocolError, 'The old style password is not supported' if read.to_s == "\xfe"
      set_state :READY
    end

    # Quit command. Sends COM_QUIT and closes the socket.
    def quit_command
      synchronize do
        reset
        write [COM_QUIT].pack("C")
        # The server drops every prepared statement on disconnect. Queued ids must
        # not be flushed to the socket after it is closed.
        @gc_stmt_queue.clear
        close
      end
    end

    # Query command
    # === Argument
    # query :: [String] query string
    # === Return
    # [Integer / nil] number of fields of results. nil if no results.
    def query_command(query)
      check_state :READY
      begin
        reset
        write [COM_QUERY, @charset.convert(query)].pack("Ca*")
        get_result
      rescue
        set_state :READY
        raise
      end
    end

    # Get the result of a query / statement execution.
    # === Return
    # [Integer / nil] number of fields of results. nil if no results.
    # === Exception
    # ProtocolError :: server requested LOAD DATA LOCAL INFILE but CLIENT_LOCAL_FILES is not enabled
    def get_result
      res_packet = ResultPacket.parse read
      if res_packet.field_count.to_i > 0  # result data exists
        set_state :FIELD
        return res_packet.field_count
      end
      if res_packet.field_count.nil?   # LOAD DATA LOCAL INFILE
        if @client_flags.to_i & CLIENT_LOCAL_FILES == 0
          raise ProtocolError, 'Load data local infile forbidden'
        end
        filename = res_packet.message
        File.open(filename, "rb"){|f| write f}
        write nil  # EOF mark
        # The server answers the uploaded file with an OK packet holding the real
        # affected rows / insert id / status / warnings.
        res_packet = ResultPacket.parse read
      end
      @affected_rows, @insert_id, @server_status, @warning_count, @message =
        res_packet.affected_rows, res_packet.insert_id, res_packet.server_status, res_packet.warning_count, res_packet.message
      set_state :READY
      nil
    rescue
      set_state :READY
      raise
    end

    # Retrieve n fields
    # === Argument
    # n :: [Integer] number of fields
    # === Return
    # [Array of RbMysql::Field] field list
    def retr_fields(n)
      check_state :FIELD
      begin
        fields = n.times.map{Field.new FieldPacket.parse(read)}
        read_eof_packet
        set_state :RESULT
        fields
      rescue
        set_state :READY
        raise
      end
    end

    # Retrieve all records for a simple (text protocol) query
    # === Argument
    # fields :: [Array<RbMysql::Field>] field list
    # === Return
    # [Array of RawRecord] all records
    def retr_all_records(fields)
      check_state :RESULT
      enc = charset.encoding
      begin
        all_recs = []
        until (pkt = read).eof?
          all_recs.push RawRecord.new(pkt, fields, enc)
        end
        # EOF packet: 0xfe marker and 2-byte warning count, then 2-byte status flags.
        pkt.read(3)
        @server_status = pkt.ushort
        all_recs
      ensure
        set_state :READY
      end
    end

    # Field list command
    # === Argument
    # table :: [String] table name.
    # field :: [String / nil] field name that may contain wild card.
    # === Return
    # [Array of Field] field list
    def field_list_command(table, field)
      synchronize do
        reset
        write [COM_FIELD_LIST, table, 0, field].pack("Ca*Ca*")
        fields = []
        until (data = read).eof?
          fields.push Field.new(FieldPacket.parse(data))
        end
        return fields
      end
    end

    # Process info command
    # === Return
    # [Array of Field] field list
    def process_info_command
      check_state :READY
      begin
        reset
        write [COM_PROCESS_INFO].pack("C")
        field_count = read.lcb
        fields = field_count.times.map{Field.new FieldPacket.parse(read)}
        read_eof_packet
        set_state :RESULT
        return fields
      rescue
        set_state :READY
        raise
      end
    end

    # Ping command
    def ping_command
      simple_command [COM_PING].pack("C")
    end

    # Kill command
    def kill_command(pid)
      simple_command [COM_PROCESS_KILL, pid].pack("CV")
    end

    # Refresh command
    def refresh_command(op)
      simple_command [COM_REFRESH, op].pack("CC")
    end

    # Set option command
    def set_option_command(opt)
      simple_command [COM_SET_OPTION, opt].pack("Cv")
    end

    # Shutdown command
    def shutdown_command(level)
      simple_command [COM_SHUTDOWN, level].pack("CC")
    end

    # Statistics command
    def statistics_command
      simple_command [COM_STATISTICS].pack("C")
    end

    # Stmt prepare command
    # === Argument
    # stmt :: [String] prepared statement
    # === Return
    # [Integer] statement id
    # [Integer] number of parameters
    # [Array of Field] field list
    def stmt_prepare_command(stmt)
      synchronize do
        reset
        write [COM_STMT_PREPARE, charset.convert(stmt)].pack("Ca*")
        res_packet = PrepareResultPacket.parse read
        if res_packet.param_count > 0
          res_packet.param_count.times{read}    # skip parameter packet
          read_eof_packet
        end
        if res_packet.field_count > 0
          fields = res_packet.field_count.times.map{Field.new FieldPacket.parse(read)}
          read_eof_packet
        else
          fields = []
        end
        return res_packet.statement_id, res_packet.param_count, fields
      end
    end

    # Stmt execute command
    # === Argument
    # stmt_id :: [Integer] statement id
    # values  :: [Array] parameters
    # === Return
    # [Integer / nil] number of fields. nil if no results.
    def stmt_execute_command(stmt_id, values)
      check_state :READY
      begin
        reset
        write ExecutePacket.serialize(stmt_id, RbMysql::Stmt::CURSOR_TYPE_NO_CURSOR, values)
        get_result
      rescue
        set_state :READY
        raise
      end
    end

    # Retrieve all records for prepared statement
    # === Argument
    # fields  :: [Array of RbMysql::Fields] field list
    # charset :: [RbMysql::Charset]
    # === Return
    # [Array of StmtRawRecord] all records
    def stmt_retr_all_records(fields, charset)
      check_state :RESULT
      enc = charset.encoding
      begin
        all_recs = []
        until (pkt = read).eof?
          all_recs.push StmtRawRecord.new(pkt, fields, enc)
        end
        all_recs
      ensure
        set_state :READY
      end
    end

    # Stmt close command
    # === Argument
    # stmt_id :: [Integer] statement id
    def stmt_close_command(stmt_id)
      synchronize do
        reset
        write [COM_STMT_CLOSE, stmt_id].pack("CV")
      end
    end

    # Queue a statement id to be closed on the next transition to :READY.
    def gc_stmt(stmt_id)
      @gc_stmt_queue.push stmt_id
    end

    private

    # Raise unless the connection is in state +st+.
    def check_state(st)
      raise 'command out of sync' unless @state == st
    end

    # Change the connection state. On :READY, send COM_STMT_CLOSE for every queued statement id.
    def set_state(st)
      @state = st
      return unless st == :READY
      # GC is disabled while draining, presumably so finalizers cannot push onto
      # the queue mid-loop. The previous GC setting is restored afterwards.
      gc_was_disabled = GC.disable
      begin
        while (stmt_id = @gc_stmt_queue.shift)
          reset
          write [COM_STMT_CLOSE, stmt_id].pack("CV")
        end
      ensure
        GC.enable unless gc_was_disabled
      end
    end

    # Run the block as a command: requires :READY, and always returns to :READY.
    def synchronize
      begin
        check_state :READY
        return yield
      ensure
        set_state :READY
      end
    end

    # Reset sequence number
    def reset
      @seq = 0    # packet counter. reset by each command
    end

    # Read one logical packet (concatenating continuation packets).
    # === Return
    # [Packet] packet data
    # === Exception
    # [ProtocolError] invalid packet sequence number
    # [ClientError::ServerGoneError] connection closed by the server
    # [ClientError] read timeout
    # [RbMysql::ServerError] the server returned an error packet
    def read
      data = ''
      len = nil
      loop do
        begin
          Timeout.timeout @read_timeout do
            header = @sock.read(4)
            raise EOFError unless header && header.length == 4
            len1, len2, seq = header.unpack("CvC")
            len = (len2 << 8) + len1
            raise ProtocolError, "invalid packet: sequence number mismatch(#{seq} != #{@seq}(expected))" if @seq != seq
            @seq = (@seq + 1) % 256
            ret = @sock.read(len)
            raise EOFError unless ret && ret.length == len
            data.concat ret
          end
        rescue EOFError
          raise ClientError::ServerGoneError, 'MySQL server has gone away'
        rescue Timeout::Error
          raise ClientError, "read timeout"
        end
        # A full-size packet means the payload continues in the next packet.
        break unless len == MAX_PACKET_LENGTH
      end

      @sqlstate = "00000"

      # Error packet
      if data[0] == ?\xff
        f, errno, marker, @sqlstate, message = data.unpack("Cvaa5a*")
        unless marker == "#"
          f, errno, message = data.unpack("Cva*")    # Version 4.0 Error
          @sqlstate = ""
        end
        message.force_encoding(@charset.encoding) if @charset
        if RbMysql::ServerError::ERROR_MAP.key? errno
          raise RbMysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
        end
        raise RbMysql::ServerError.new(message, @sqlstate)
      end
      Packet.new(data)
    end

    # Write one logical packet, splitting it into MAX_PACKET_LENGTH chunks.
    # === Argument
    # data :: [String / IO / nil] packet data. If data is nil, write empty packet.
    # === Exception
    # [ClientError::ServerGoneError] broken pipe
    # [ClientError] write timeout
    def write(data)
      @sock.sync = false
      if data.nil?
        write_packet ""
      else
        data = StringIO.new data if data.is_a? String
        last_length = nil
        while (chunk = data.read(MAX_PACKET_LENGTH))
          write_packet chunk
          last_length = chunk.bytesize
        end
        # A payload ending exactly on a MAX_PACKET_LENGTH boundary must be
        # terminated by an empty packet; otherwise the server keeps waiting.
        write_packet "" if last_length == MAX_PACKET_LENGTH
      end
      @sock.sync = true
      Timeout.timeout @write_timeout do
        @sock.flush
      end
    rescue Errno::EPIPE
      raise ClientError::ServerGoneError, 'MySQL server has gone away'
    rescue Timeout::Error
      raise ClientError, "write timeout"
    end

    # Write a single packet (3-byte length, 1-byte sequence, payload) and advance the sequence number.
    def write_packet(payload)
      size = payload.bytesize
      Timeout.timeout @write_timeout do
        @sock.write [size%256, size/256, @seq].pack("CvC")
        @sock.write payload unless payload.empty?
      end
      @seq = (@seq + 1) % 256
    end

    # Read EOF packet
    # === Exception
    # [ProtocolError] packet is not EOF
    def read_eof_packet
      raise ProtocolError, "packet is not EOF" unless read.eof?
    end

    # Send simple command
    # === Argument
    # packet :: [String] packet data
    # === Return
    # [String] received data
    def simple_command(packet)
      synchronize do
        reset
        write packet
        read.to_s
      end
    end

    # Scramble the password with the server's scramble:
    # SHA1(plain) XOR SHA1(scramble + SHA1(SHA1(plain))).
    # === Argument
    # plain    :: [String] plain password.
    # scramble :: [String] scramble code from initial packet.
    # === Return
    # [String] encrypted password ("" when the password is nil or empty)
    def encrypt_password(plain, scramble)
      return "" if plain.nil? or plain.empty?
      hash_stage1 = Digest::SHA1.digest plain
      hash_stage2 = Digest::SHA1.digest hash_stage1
      return hash_stage1.unpack("C*").zip(Digest::SHA1.digest(scramble+hash_stage2).unpack("C*")).map{|a,b| a^b}.pack("C*")
    end

    # Initial handshake packet sent by the server right after connecting.
    class InitialPacket
      # === Exception
      # ProtocolError :: unsupported protocol version / malformed filler byte
      def self.parse(pkt)
        protocol_version = pkt.utiny
        server_version = pkt.string
        thread_id = pkt.ulong
        scramble_buff = pkt.read(8)
        f0 = pkt.utiny
        server_capabilities = pkt.ushort
        server_charset = pkt.utiny
        server_status = pkt.ushort
        _f1 = pkt.read(13)
        rest_scramble_buff = pkt.string
        raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
        raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
        scramble_buff.concat rest_scramble_buff
        self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
      end

      attr_reader :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff

      def initialize(*args)
        @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
      end
    end

    # Result packet: OK packet (field_count 0), LOAD DATA LOCAL INFILE request
    # (field_count nil) or result-set header (field_count > 0).
    class ResultPacket
      def self.parse(pkt)
        field_count = pkt.lcb
        if field_count == 0
          affected_rows = pkt.lcb
          insert_id = pkt.lcb
          server_status = pkt.ushort
          warning_count = pkt.ushort
          message = pkt.lcs
          return self.new(field_count, affected_rows, insert_id, server_status, warning_count, message)
        elsif field_count.nil?   # LOAD DATA LOCAL INFILE
          return self.new(nil, nil, nil, nil, nil, pkt.to_s)
        else
          return self.new(field_count)
        end
      end

      attr_reader :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message

      def initialize(*args)
        @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
      end
    end

    # Field (column definition) packet
    class FieldPacket
      # === Exception
      # ProtocolError :: non-zero filler
      def self.parse(pkt)
        _first = pkt.lcs
        db = pkt.lcs
        table = pkt.lcs
        org_table = pkt.lcs
        name = pkt.lcs
        org_name = pkt.lcs
        _f0 = pkt.utiny
        charsetnr = pkt.ushort
        length = pkt.ulong
        type = pkt.utiny
        flags = pkt.ushort
        decimals = pkt.utiny
        f1 = pkt.ushort

        raise ProtocolError, "invalid packet: f1=#{f1}" unless f1 == 0
        default = pkt.lcs
        return self.new(db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default)
      end

      attr_reader :db, :table, :org_table, :name, :org_name, :charsetnr, :length, :type, :flags, :decimals, :default

      def initialize(*args)
        @db, @table, @org_table, @name, @org_name, @charsetnr, @length, @type, @flags, @decimals, @default = args
      end
    end

    # Prepare result packet (response to COM_STMT_PREPARE)
    class PrepareResultPacket
      # === Exception
      # ProtocolError :: leading status byte or filler is not 0
      def self.parse(pkt)
        raise ProtocolError, "invalid packet" unless pkt.utiny == 0
        statement_id = pkt.ulong
        field_count = pkt.ushort
        param_count = pkt.ushort
        f = pkt.utiny
        warning_count = pkt.ushort
        raise ProtocolError, "invalid packet" unless f == 0x00
        self.new statement_id, field_count, param_count, warning_count
      end

      attr_reader :statement_id, :field_count, :param_count, :warning_count

      def initialize(*args)
        @statement_id, @field_count, @param_count, @warning_count = args
      end
    end

    # Authentication (handshake response) packet
    class AuthenticationPacket
      # Build the handshake response.
      # The character set is a fixed 1-byte field: collation ids >= 251 must not
      # be length-encoded, or every following field would be shifted.
      def self.serialize(client_flags, max_packet_size, charset_number, username, scrambled_password, databasename)
        [
          client_flags,
          max_packet_size,
          charset_number,
          "",                   # always 0x00 * 23
          username,
          Packet.lcs(scrambled_password),
          databasename
        ].pack("VVCa23Z*A*Z*")
      end
    end

    # Execute packet (COM_STMT_EXECUTE)
    class ExecutePacket
      # Layout: command, statement id, cursor type, iteration count (always 1),
      # NULL bitmap, new-params-bound flag (1), parameter types, parameter values.
      def self.serialize(statement_id, cursor_type, values)
        nbm = null_bitmap values
        netvalues = ""
        types = values.map do |v|
          t, n = Protocol.value2net v
          netvalues.concat n if v
          t
        end
        [RbMysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
      end

      # Make the NULL bitmap: one bit per value (LSB first), set when the value is nil.
      #
      # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
      def self.null_bitmap(values)
        bitmap = values.each_slice(8).map do |vals|
          vals.reverse.inject(0){|b, v| (b << 1) | (v ? 0 : 1)}
        end
        bitmap.pack("C*")
      end

    end
  end

  # One row of a text-protocol result set; columns are decoded when to_a is called.
  class RawRecord
    def initialize(packet, fields, encoding)
      @packet, @fields, @encoding = packet, fields, encoding
    end

    # === Return
    # [Array of String / nil] column values; BIT and binary-charset columns are not converted
    def to_a
      @fields.map do |f|
        if s = @packet.lcs
          unless f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
            s = Charset.convert_encoding(s, @encoding)
          end
        end
        s
      end
    end
  end

  # One row of a binary-protocol (prepared statement) result set.
  class StmtRawRecord
    # === Argument
    # pkt      :: [Packet]
    # fields   :: [Array of Fields]
    # encoding :: [Encoding]
    def initialize(packet, fields, encoding)
      @packet, @fields, @encoding = packet, fields, encoding
    end

    # Parse statement result packet
    # === Return
    # [Array of Object] one record
    def parse_record_packet
      @packet.utiny  # skip first byte
      # The NULL bitmap is offset by 2 bits, hence the +2 here and in the index below.
      null_bit_map = @packet.read((@fields.length+7+2)/8).unpack("b*").first
      rec = @fields.each_with_index.map do |f, i|
        if null_bit_map[i+2] == ?1
          nil
        else
          unsigned = f.flags & Field::UNSIGNED_FLAG != 0
          v = Protocol.net2value(@packet, f.type, unsigned)
          if v.is_a? Numeric or v.is_a? RbMysql::Time
            v
          elsif f.type == Field::TYPE_BIT or f.charsetnr == Charset::BINARY_CHARSET_NUMBER
            Charset.to_binary(v)
          else
            Charset.convert_encoding(v, @encoding)
          end
        end
      end
      rec
    end

    alias to_a parse_record_packet

  end
end