diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 89ee3f778765..170115a12bc6 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -216,15 +216,26 @@ impl<'a> Tape<'a> { /// States based on #[derive(Debug, Copy, Clone)] enum DecoderState { - /// Decoding an object + /// Decoding an object - awaiting a '"' (new field) or '}' (done) /// /// Contains index of start [`TapeElement::StartObject`] + /// This state handles both the initial `{` and after `,` Object(u32), - /// Decoding a list + /// After a value in an object member - awaiting ',' (next field) or '}' (done) + /// + /// Contains index of start [`TapeElement::StartObject`] + ObjectAfterValue(u32), + /// Decoding a list - awaiting a value or ']' (done) /// /// Contains index of start [`TapeElement::StartList`] + /// This state handles both the initial `[` and after `,` List(u32), + /// After a value in a list - awaiting ',' (next element) or ']' (done) + /// + /// Contains index of start [`TapeElement::StartList`] + ListAfterValue(u32), String, + /// Skip whitespace and detect value type Value, Number, Colon, @@ -242,8 +253,8 @@ enum DecoderState { impl DecoderState { fn as_str(&self) -> &'static str { match self { - DecoderState::Object(_) => "object", - DecoderState::List(_) => "list", + DecoderState::Object(_) | DecoderState::ObjectAfterValue(_) => "object", + DecoderState::List(_) | DecoderState::ListAfterValue(_) => "list", DecoderState::String => "string", DecoderState::Value => "value", DecoderState::Number => "number", @@ -294,6 +305,45 @@ macro_rules! next { }; } +/// Evaluates to the next non-whitespace byte in the iterator or breaks the current loop +macro_rules! next_non_whitespace { + ($next:ident) => { + match $next.next_non_whitespace() { + Some(b) => b, + None => break, + } + }; +} + +/// Dispatches value type detection with optional special case and custom transition function +macro_rules! dispatch_value { + ($self:ident, $b:expr, |$s:ident| $transition:expr $(, $special:pat => $special_body:expr)?) => {{ + let $s = match $b { + $($special => $special_body,)? + b'"' => DecoderState::String, + b @ (b'-' | b'0'..=b'9') => { + $self.bytes.push(b); + DecoderState::Number + } + b'n' => DecoderState::Literal(Literal::Null, 1), + b'f' => DecoderState::Literal(Literal::False, 1), + b't' => DecoderState::Literal(Literal::True, 1), + b'[' => { + let idx = $self.elements.len() as u32; + $self.elements.push(TapeElement::StartList(u32::MAX)); + DecoderState::List(idx) + } + b'{' => { + let idx = $self.elements.len() as u32; + $self.elements.push(TapeElement::StartObject(u32::MAX)); + DecoderState::Object(idx) + } + b => return Err(err(b, "parsing value")), + }; + $transition + }}; +} + /// Implements a state machine for decoding JSON to a tape pub struct TapeDecoder { elements: Vec, @@ -338,6 +388,22 @@ impl TapeDecoder { } } + /// Write the closing elements for an object to the tape + fn end_object(&mut self, start_idx: u32) { + let end_idx = self.elements.len() as u32; + self.elements[start_idx as usize] = TapeElement::StartObject(end_idx); + self.elements.push(TapeElement::EndObject(start_idx)); + self.stack.pop(); + } + + /// Write the closing elements for a list to the tape + fn end_list(&mut self, start_idx: u32) { + let end_idx = self.elements.len() as u32; + self.elements[start_idx as usize] = TapeElement::StartList(end_idx); + self.elements.push(TapeElement::EndList(start_idx)); + self.stack.pop(); + } + pub fn decode(&mut self, buf: &[u8]) -> Result { let mut iter = BufIter::new(buf); @@ -345,52 +411,67 @@ impl TapeDecoder { let state = match self.stack.last_mut() { Some(l) => l, None => { - iter.skip_whitespace(); - if iter.is_empty() || self.cur_row >= self.batch_size { + if self.cur_row >= self.batch_size { break; } // Start of row + let b = next_non_whitespace!(iter); self.cur_row += 1; - self.stack.push(DecoderState::Value); + + // Detect value type and push appropriate state + dispatch_value!(self, b, |s| self.stack.push(s)); + self.stack.last_mut().unwrap() } }; match state { - // Decoding an object + // Expecting object member or close brace DecoderState::Object(start_idx) => { - iter.advance_until(|b| !json_whitespace(b) && b != b','); - match next!(iter) { + let start_idx = *start_idx; + match next_non_whitespace!(iter) { b'"' => { - self.stack.push(DecoderState::Value); + *state = DecoderState::ObjectAfterValue(start_idx); self.stack.push(DecoderState::Colon); self.stack.push(DecoderState::String); } - b'}' => { - let start_idx = *start_idx; - let end_idx = self.elements.len() as u32; - self.elements[start_idx as usize] = TapeElement::StartObject(end_idx); - self.elements.push(TapeElement::EndObject(start_idx)); - self.stack.pop(); - } - b => return Err(err(b, "parsing object")), + b'}' => self.end_object(start_idx), + b => return Err(err(b, "expected '\"' or '}'")), + } + } + // After value in object - expecting comma or close brace + DecoderState::ObjectAfterValue(start_idx) => { + let start_idx = *start_idx; + match next_non_whitespace!(iter) { + b',' => *state = DecoderState::Object(start_idx), + b'}' => self.end_object(start_idx), + b => return Err(err(b, "expected ',' or '}'")), } } - // Decoding a list + // Decoding a list - awaiting next element or ']' DecoderState::List(start_idx) => { - iter.advance_until(|b| !json_whitespace(b) && b != b','); - match iter.peek() { - Some(b']') => { - iter.next(); - let start_idx = *start_idx; - let end_idx = self.elements.len() as u32; - self.elements[start_idx as usize] = TapeElement::StartList(end_idx); - self.elements.push(TapeElement::EndList(start_idx)); - self.stack.pop(); + let start_idx = *start_idx; + dispatch_value!( + self, + next_non_whitespace!(iter), + |s| { + *state = DecoderState::ListAfterValue(start_idx); + self.stack.push(s); + }, + b']' => { + self.end_list(start_idx); + continue; } - Some(_) => self.stack.push(DecoderState::Value), - None => break, + ); + } + // After value in a list - expecting comma or close bracket + DecoderState::ListAfterValue(start_idx) => { + let start_idx = *start_idx; + match next_non_whitespace!(iter) { + b',' => *state = DecoderState::List(start_idx), + b']' => self.end_list(start_idx), + b => return Err(err(b, "expected ',' or ']'")), } } // Decoding a string @@ -409,29 +490,9 @@ impl TapeDecoder { b => unreachable!("{}", b), } } - state @ DecoderState::Value => { - iter.skip_whitespace(); - *state = match next!(iter) { - b'"' => DecoderState::String, - b @ b'-' | b @ b'0'..=b'9' => { - self.bytes.push(b); - DecoderState::Number - } - b'n' => DecoderState::Literal(Literal::Null, 1), - b'f' => DecoderState::Literal(Literal::False, 1), - b't' => DecoderState::Literal(Literal::True, 1), - b'[' => { - let idx = self.elements.len() as u32; - self.elements.push(TapeElement::StartList(u32::MAX)); - DecoderState::List(idx) - } - b'{' => { - let idx = self.elements.len() as u32; - self.elements.push(TapeElement::StartObject(u32::MAX)); - DecoderState::Object(idx) - } - b => return Err(err(b, "parsing value")), - }; + // Skip whitespace and detect value type + DecoderState::Value => { + *state = dispatch_value!(self, next_non_whitespace!(iter), |s| s); } DecoderState::Number => { let s = iter.advance_until(|b| { @@ -447,9 +508,8 @@ impl TapeDecoder { } } DecoderState::Colon => { - iter.skip_whitespace(); - match next!(iter) { - b':' => self.stack.pop(), + match next_non_whitespace!(iter) { + b':' => *state = DecoderState::Value, b => return Err(err(b, "parsing colon")), }; } @@ -676,8 +736,15 @@ impl<'a> BufIter<'a> { } } - fn skip_whitespace(&mut self) { - self.advance_until(|b| !json_whitespace(b)); + // Advance to the next non-whitespace char and consume it + fn next_non_whitespace(&mut self) -> Option { + for b in self.as_slice() { + self.pos += 1; + if !json_whitespace(*b) { + return Some(*b); + } + } + None } } @@ -685,9 +752,9 @@ impl Iterator for BufIter<'_> { type Item = u8; fn next(&mut self) -> Option { - let b = self.peek(); + let b = self.peek()?; self.pos += 1; - b + Some(b) } fn size_hint(&self) -> (usize, Option) { @@ -972,4 +1039,121 @@ mod tests { let res = decoder.decode(b"{\"test\": \"\\udc00\\udc01\"}"); assert!(res.is_err()); } + + #[test] + fn test_valid_comma_usage() { + // Verify that valid JSON with proper comma usage still works + + // Valid object with commas + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"{"a": 1, "b": 2, "c": 3}"#; + decoder.decode(json.as_bytes()).unwrap(); + let tape = decoder.finish().unwrap(); + let mut s = String::new(); + tape.serialize(&mut s, 1); + assert!(s.contains("\"a\"")); + assert!(s.contains("\"b\"")); + assert!(s.contains("\"c\"")); + + // Valid list with commas + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"[1, 2, 3, 4]"#; + decoder.decode(json.as_bytes()).unwrap(); + let tape = decoder.finish().unwrap(); + let mut s = String::new(); + tape.serialize(&mut s, 1); + assert!(s.contains("1")); + assert!(s.contains("2")); + assert!(s.contains("3")); + assert!(s.contains("4")); + + // Empty object (no commas needed) + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"{}"#; + decoder.decode(json.as_bytes()).unwrap(); + decoder.finish().unwrap(); + + // Empty list (no commas needed) + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"[]"#; + decoder.decode(json.as_bytes()).unwrap(); + decoder.finish().unwrap(); + } + + #[test] + fn test_reject_invalid_commas_in_objects() { + // Verify that the parser correctly rejects invalid JSON with extra commas in objects + + // Empty with comma - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"{,}"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("expected '\"' or '}'"), "Error was: {}", err); + + // Leading comma - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"{, "field": 10}"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("expected '\"' or '}'"), "Error was: {}", err); + + // Double comma between fields - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"{"a": 1,, "b": 2}"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("expected '\"' or '}'"), "Error was: {}", err); + + // Multiple commas - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"{"a": 1,,,, "b": 2}"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("expected '\"' or '}'"), "Error was: {}", err); + + // Trailing comma - intentionally allowed + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"{"a": 1,}"#; + decoder.decode(json.as_bytes()).unwrap(); + let tape = decoder.finish().unwrap(); + let mut s = String::new(); + tape.serialize(&mut s, 1); + assert!(s.contains("\"a\"")); + } + + #[test] + fn test_reject_invalid_commas_in_lists() { + // Verify that the parser correctly rejects invalid JSON with extra commas in lists + + // Empty with comma - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"[,]"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("parsing value"), "Error was: {}", err); + + // Leading comma - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"[, 1, 2]"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("parsing value"), "Error was: {}", err); + + // Double comma between elements - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"[1,, 2, 3]"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("parsing value"), "Error was: {}", err); + + // Multiple commas - should reject + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"[1,,,, 2]"#; + let err = decoder.decode(json.as_bytes()).unwrap_err().to_string(); + assert!(err.contains("parsing value"), "Error was: {}", err); + + // Trailing comma - intentionally allowed + let mut decoder = TapeDecoder::new(16, 2); + let json = r#"[1, 2,]"#; + decoder.decode(json.as_bytes()).unwrap(); + let tape = decoder.finish().unwrap(); + let mut s = String::new(); + tape.serialize(&mut s, 1); + assert!(s.contains("1")); + assert!(s.contains("2")); + } }